├── LICENSE
├── README.md
├── backbone
│   ├── ResNet.py
│   └── __init__.py
├── config.py
├── data
│   ├── OBdataset.py
│   ├── all_transforms.py
│   └── data
│       ├── test_data_pair.csv
│       ├── test_pair_new.json
│       ├── train_data_pair.csv
│       └── train_pair_new.json
├── network
│   ├── BaseBlocks.py
│   ├── DynamicModules.py
│   ├── ObPlaNet_simple.py
│   ├── __init__.py
│   └── tensor_ops.py
├── prepare_multi_fg_scales.py
├── requirements.txt
├── test.py
├── test_multi_fg_scales.py
├── train.py
└── utils
    └── misc.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 BCMI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
**FOPA: Fast Object Placement Assessment**
=====
This is the PyTorch implementation of **FOPA** for the following research paper. **FOPA is the first discriminative approach for the object placement task.**
> **Fast Object Placement Assessment** [[arXiv]](https://arxiv.org/pdf/2205.14280.pdf)
>
> Li Niu, Qingyang Liu, Zhenchen Liu, Jiangtong Li

**Our FOPA has been integrated into our image composition toolbox libcom https://github.com/bcmi/libcom. Welcome to visit and try \(^▽^)/**

If you want to change the backbone to a transformer, you can refer to [TopNet](https://github.com/bcmi/TopNet-Object-Placement).

## Setup
All the code has been tested on PyTorch 1.7.0. Follow the instructions below to run the project.

First, clone the repository:
```
git clone git@github.com:bcmi/FOPA-Fast-Object-Placement-Assessment.git
```
Then, install Anaconda and create a virtual environment:
```
conda create -n fopa
conda activate fopa
```
Install PyTorch 1.7.0 (higher versions should also work):
```
conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=10.2 -c pytorch
```
Install the necessary packages:
```
pip install -r requirements.txt
```


## Data Preparation
Download and extract the data from [Baidu Cloud](https://pan.baidu.com/s/10JBpXBMZybEl5FTqBlq-hQ) (access code: 4zf9) or [Dropbox](https://www.dropbox.com/scl/fi/c05wk038piy224sba6jpi/data.rar?rlkey=tghrxjjgo2g93le64tb1xymvq&st=u9nf6hbf&dl=0).
Download the SOPA encoder from [Baidu Cloud](https://pan.baidu.com/s/1hQGm3ryRONRZpNpU66SJZA) (access code: 1x3n) or [Dropbox](https://www.dropbox.com/scl/fi/tlkbmqebokjloe0i1yfpy/SOPA.pth.tar?rlkey=8mzzc53wy6rjqz69o5lkzusau&st=32t23vwm&dl=0).
Put them in "data/data". It should contain the following directories and files:
```

bg/                          # background images
fg/                          # foreground images
mask/                        # foreground masks
train(test)_pair_new.json    # json annotations
train(test)_data_pair.csv    # csv files
SOPA.pth.tar                 # SOPA encoder
```

Download our pretrained model from [Baidu Cloud](https://pan.baidu.com/s/15-OBaYE0CF-nDoJrNcCRaw) (access code: uqvb) or [Dropbox](https://www.dropbox.com/scl/fi/q3i6fryoumzr15piuq9pr/best_weight.pth?rlkey=wahho3h18k3ntsaw9pvdyfvea&st=vp2dhpa5&dl=0), and save it as './best_weight.pth'.

## Training
Before training, modify "config.py" according to your needs. Then run:
```
python train.py
```

## Test
To get the F1 score and balanced accuracy of a specified model, run:
```
python test.py --mode evaluate
```

The results obtained with our released model should be F1: 0.778302, bAcc: 0.838696.


To get the heatmaps predicted by FOPA, run:
```
python test.py --mode heatmap
```

To get the optimal composite images based on the predicted heatmaps, run:
```
python test.py --mode composite
```


## Multiple Foreground Scales
To test multiple foreground scales for each foreground-background pair, first run the following command to generate 'test_data_16scales.json' in './data/data' and the 'test_16scales' folders in './data/data/fg' and './data/data/mask':
```
python prepare_multi_fg_scales.py
```

Then, to get the heatmaps of multi-scale foregrounds for each foreground-background pair, run:
```
python test_multi_fg_scales.py --mode heatmap
```

Finally, to get the composite images with the top scores for each foreground-background pair, run:
```
python test_multi_fg_scales.py --mode composite
```

## Evaluation on Discriminative Task

We show the results reported in the paper.
FOPA achieves results comparable to SOPA.

| Method | F1 | bAcc |
| :---: | :---: | :---: |
| SOPA | 0.780 | 0.842 |
| FOPA | 0.776 | 0.840 |

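
The README does not spell out how `test.py --mode evaluate` computes F1 and bAcc. As a rough illustration only, the sketch below assumes both metrics are computed over the annotated pixels of a label map laid out like the one built by `_obtain_target` in `data/OBdataset.py` (1 = positive, 0 = negative, 255 = unannotated), with predicted placement scores thresholded at 0.5. `f1_and_bacc` is a hypothetical helper, not a function from this repository, and the released `test.py` may threshold or aggregate differently.

```python
import numpy as np


def f1_and_bacc(scores, target, threshold=0.5, ignore_index=255):
    """Hypothetical helper (not part of this repo): F1 and balanced accuracy.

    scores: float array (H x W) of predicted placement probabilities.
    target: uint8 array (H x W); 1 = positive, 0 = negative, 255 = unannotated.
    Only annotated pixels contribute; the 0.5 threshold is an assumption.
    """
    valid = target != ignore_index
    pred = (scores[valid] >= threshold).astype(np.uint8)
    gt = target[valid].astype(np.uint8)

    tp = int(np.sum((pred == 1) & (gt == 1)))
    tn = int(np.sum((pred == 0) & (gt == 0)))
    fp = int(np.sum((pred == 1) & (gt == 0)))
    fn = int(np.sum((pred == 0) & (gt == 1)))

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0        # true-positive rate
    specificity = tn / (tn + fp) if tn + fp else 0.0   # true-negative rate
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    bacc = (recall + specificity) / 2
    return f1, bacc
```

Balanced accuracy averages the true-positive and true-negative rates, so it is not dominated by whichever label has more annotated pixels.
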

## Evaluation on Generation Task

Given each background-foreground pair in the test set, we predict 16 rationality score maps for 16 foreground scales and generate composite images with the top 50 rationality scores. Then, for each background-foreground pair, we randomly sample one of the 50 generated composite images for Acc and FID evaluation, using the test scripts provided by [GracoNet](https://github.com/bcmi/GracoNet-Object-Placement). The generated composite images used for evaluation can be downloaded from [Baidu Cloud](https://pan.baidu.com/s/1qqDiXF4tEhizEoI_2BwkrA) (access code: ppft) or [Google Drive](https://drive.google.com/file/d/1yvuoVum_-FMK7lOvrvpx35IdvrV58bTm/view?usp=share_link). The test results of the baselines and our method are shown below:

| Method | Acc | FID |
| :---: | :---: | :---: |
| TERSE | 0.679 | 46.94 |
| PlaceNet | 0.683 | 36.69 |
| GracoNet | 0.847 | 27.75 |
| IOPRE | 0.895 | 21.59 |
| FOPA | 0.932 | 19.76 |

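
The top-50 selection described above can be sketched in NumPy as follows. This is only an illustration under stated assumptions: the 16 score maps are stacked into one `(num_scales, H, W)` array, every (scale, pixel) pair counts as one candidate placement, and no non-maximum suppression is applied. `top_k_placements` and `sample_one` are hypothetical helpers; `test_multi_fg_scales.py` may rank or filter candidates differently.

```python
import numpy as np


def top_k_placements(heatmaps, k=50):
    """Hypothetical helper: the k highest-scoring (scale_idx, y, x, score) tuples.

    heatmaps: array of shape (num_scales, H, W), one rationality score map per
    foreground scale (an assumed layout, not necessarily the repo's).
    """
    flat = heatmaps.reshape(-1)
    k = min(k, flat.size)
    # argpartition picks the k largest scores without a full sort ...
    idx = np.argpartition(flat, -k)[-k:]
    # ... and only those k are then sorted in descending order.
    idx = idx[np.argsort(flat[idx])[::-1]]
    scales, ys, xs = np.unravel_index(idx, heatmaps.shape)
    return [(int(s), int(y), int(x), float(flat[i]))
            for s, y, x, i in zip(scales, ys, xs, idx)]


def sample_one(placements, seed=0):
    """Hypothetical helper: draw one candidate uniformly at random."""
    rng = np.random.default_rng(seed)
    return placements[rng.integers(len(placements))]
```

Using `np.argpartition` avoids sorting all 16 × H × W scores; only the k selected candidates are sorted.
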

## Other Resources

+ [Awesome-Object-Placement](https://github.com/bcmi/Awesome-Object-Placement)
+ [Awesome-Image-Composition](https://github.com/bcmi/Awesome-Object-Insertion)


## BibTeX

If you find this work useful for your research, please cite our paper using the following BibTeX [[arXiv](https://arxiv.org/pdf/2205.14280.pdf)]:

```
@article{niu2022fast,
  title={Fast Object Placement Assessment},
  author={Niu, Li and Liu, Qingyang and Liu, Zhenchen and Li, Jiangtong},
  journal={arXiv preprint arXiv:2205.14280},
  year={2022}
}
```

--------------------------------------------------------------------------------
/backbone/ResNet.py:
--------------------------------------------------------------------------------
# import torchvision.models as models
# import torch.nn as nn
# # https://pytorch.org/docs/stable/torchvision/models.html#id3
#
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)  # 6
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)  # 3

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride), nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        pretrained_dict = model_zoo.load_url(model_urls["resnet18"])

        model_dict = model.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)
    return model


def Backbone_ResNet18_in3(pretrained=True):
    if pretrained:
        print("The backbone model loads the pretrained parameters...")
    net = pretrained_resnet18_4ch(pretrained=True)
    div_2 = nn.Sequential(*list(net.children())[:3])
    div_4 = nn.Sequential(*list(net.children())[3:5])
    div_8 = net.layer2
    div_16 = net.layer3
    div_32 = net.layer4

    return div_2, div_4, div_8, div_16, div_32


def Backbone_ResNet18_in3_1(pretrained=True):
    if pretrained:
        print("The backbone model loads the pretrained parameters...")
    net = resnet18(pretrained=pretrained)

    # Extend conv1 from 3 to 4 input channels: the filter for the extra (mask) channel
    # is initialized as the luminance-weighted sum (0.299 R + 0.587 G + 0.114 B)
    # of the pretrained RGB filters.
    model_dict = net.state_dict()
    conv1 = model_dict['conv1.weight']
    new = torch.zeros(64, 1, 7, 7)
    for i, output_channel in enumerate(conv1):
        new[i] = 0.299 * output_channel[0] + 0.587 * output_channel[1] + 0.114 * output_channel[2]
    net.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=3, bias=False)
    model_dict['conv1.weight'] = torch.cat((conv1, new), dim=1)
    net.load_state_dict(model_dict)

    div_1 = nn.Sequential(*list(net.children())[:1])
    div_2 = nn.Sequential(*list(net.children())[1:3])
    div_4 = nn.Sequential(*list(net.children())[3:5])
    div_8 = net.layer2
    div_16 = net.layer3
    # div_32 = make_layer_4(BasicBlock, 448, 2, stride=2)
    div_32 = net.layer4
    return div_1, div_2, div_4, div_8, div_16, div_32



def pretrained_resnet18_4ch(pretrained=True, **kwargs):
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    model.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=3, bias=False)

    if pretrained:
        # load the pretrained binary classification model for slow object placement assessment (SOPA);
        # map_location="cpu" lets the checkpoint load on machines without a GPU
        checkpoint = torch.load('./data/data/SOPA.pth.tar', map_location="cpu")
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    return model

--------------------------------------------------------------------------------
/backbone/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/FOPA-Fast-Object-Placement-Assessment/7f990e06b6b234bfd1e107a30067b610c217c915/backbone/__init__.py
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import os

__all__ = ["proj_root", "arg_config"]

proj_root = os.path.dirname(__file__)
datasets_root = "./data/data"

tr_data_path = os.path.join(datasets_root, "train_pair_new.json")
ts_data_path = os.path.join(datasets_root, "test_pair_new.json")

coco_dir = './data/data/train2017'
bg_dir = os.path.join(datasets_root, "bg")
fg_dir = os.path.join(datasets_root, "fg")
mask_dir = os.path.join(datasets_root, "mask")

arg_config = {

    "model": "ObPlaNet_resnet18",  # model name
    "epoch_num": 25,
    "lr": 0.0005,
    "train_data_path": tr_data_path,
    "test_data_path": ts_data_path,
    "bg_dir": bg_dir,
    "fg_dir": fg_dir,
    "mask_dir": mask_dir,

    "print_freq": 10,  # >0, frequency of log print
    "prefix": (".jpg", ".png"),
    "reduction": "mean",  # "mean" or "sum"
    "optim": "Adam_trick",  # optimizer
    "weight_decay": 5e-4,  # set as 0.0001 when finetuning
    "momentum": 0.9,
    "nesterov": False,
    "lr_type": "all_decay",  # learning rate schedule
    "lr_decay": 0.9,  # poly
    "batch_size": 8,
    "num_workers": 6,
    "input_size": 256,  # input size
    "gpu_id": 0,
    "ex_name": "demo",  # experiment name
}

--------------------------------------------------------------------------------
/data/OBdataset.py:
--------------------------------------------------------------------------------
import json
import os
import torch
import numpy as np

from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from data.all_transforms import Compose, JointResize


class CPDataset(Dataset):
    def __init__(self, file, bg_dir, fg_dir, mask_dir, in_size, datatype='train'):
        """
        initialize dataset

        Args:
            file(str): file with training/test data information
            bg_dir(str): folder with background images
            fg_dir(str): folder with foreground images
            mask_dir(str): folder with mask images
            in_size(int): input size of network
            datatype(str): "train" or "test"
        """

        self.datatype = datatype
        self.data = _collect_info(file, bg_dir, fg_dir, mask_dir, datatype)
        self.insize = in_size

        self.train_triple_transform = Compose([JointResize(in_size)])
        self.train_img_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # operates on tensors
            ]
        )
        self.train_mask_transform = transforms.ToTensor()

        self.transforms_flip = transforms.Compose([
            transforms.RandomHorizontalFlip(p=1)
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """
        load each item
        return:
            i: the image index,
            bg_t:(1 * 3 * in_size * in_size) background image,
            mask_t:(1 * 1 * in_size * in_size) scaled foreground mask
            fg_t:(1 * 3 * in_size * in_size) scaled foreground image
            target_t: (1 * in_size * in_size) pixel-wise binary labels
            labels_num: (int) the number of annotated pixels
        """
        i, _, bg_path, fg_path, mask_path, scale, pos_label, neg_label, fg_path_2, mask_path_2, w, h = self.data[index]

        fg_name = fg_path.split('/')[-1][:-4]
        ## save_name: fg_bg_w_h_scale.jpg
        save_name = fg_name + '_' + str(scale) + '.jpg'

        bg_img = Image.open(bg_path)
        fg_img = Image.open(fg_path)
        mask = Image.open(mask_path)
        if len(bg_img.split()) != 3:
            bg_img = bg_img.convert("RGB")
        if len(fg_img.split()) != 3:
            fg_img = fg_img.convert("RGB")
        if len(mask.split()) == 3:
            mask = mask.convert("L")

        is_flip = False
        if self.datatype == 'train' and np.random.uniform() < 0.5:
            is_flip = True

        # make composite images which are used in feature mimicking
        fg_tocp = Image.open(fg_path_2).convert("RGB")
        mask_tocp = Image.open(mask_path_2).convert("L")
        composite_list = []
        for pos in pos_label:
            x_, y_ = pos
            x = int(x_ - w / 2)
            y = int(y_ - h / 2)
            composite_list.append(make_composite(fg_tocp, mask_tocp, bg_img, [x, y, w, h], is_flip))

        for pos in neg_label:
            x_, y_ = pos
            x = int(x_ - w / 2)
            y = int(y_ - h / 2)
            composite_list.append(make_composite(fg_tocp, mask_tocp, bg_img, [x, y, w, h], is_flip))

        composite_list_ = torch.stack(composite_list, dim=0)
        composite_cat = torch.zeros(50 - len(composite_list), 4, 256, 256)
        composite_list = torch.cat((composite_list_, composite_cat), dim=0)

        # positive pixels are 1, negative pixels are 0, other pixels are 255
        # feature_pos: record the positions of annotated pixels
        target, feature_pos = _obtain_target(bg_img.size[0], bg_img.size[1], self.insize, pos_label, neg_label, is_flip)
        # use `_` so the sample index `i` returned below is not overwritten
        for _ in range(50 - len(feature_pos)):
            feature_pos.append((0, 0))  # pad the length to 50
        feature_pos = torch.Tensor(feature_pos)

        # resize the foreground/background to 256, convert them to tensors
        bg_t, fg_t, mask_t = self.train_triple_transform(bg_img, fg_img, mask)
        mask_t = self.train_mask_transform(mask_t)
        fg_t = self.train_img_transform(fg_t)
        bg_t = self.train_img_transform(bg_t)

        if is_flip:
            fg_t = self.transforms_flip(fg_t)
            bg_t = self.transforms_flip(bg_t)
            mask_t = self.transforms_flip(mask_t)

        # ToTensor scales the uint8 target to [0, 1]; map it back to [0, 255] for ease of computation
        target_t = self.train_mask_transform(target) * 255
        labels_num = (target_t != 255).sum()

        return i, bg_t, mask_t, fg_t, target_t.squeeze(), labels_num, composite_list, feature_pos, w, h, save_name


def _obtain_target(original_width, original_height, in_size, pos_label, neg_label, isflip=False):
    """
    put 0, 1 labels on an in_size x in_size score map
    Args:
        original_width(int): width of original background
        original_height(int): height of original background
        in_size(int): input size of network
        pos_label(list): positive pixels in original background
        neg_label(list): negative pixels in original background
        isflip(bool): whether the sample is horizontally flipped
    return:
        target_r: score map with ground-truth labels
        feature_pos: (x, y) positions of the annotated pixels on the resized map
    """
    target = np.uint8(np.ones((in_size, in_size)) * 255)
    feature_pos = []
    for pos in pos_label:
        x, y = pos
        x_new = int(x * in_size / original_width)
        y_new = int(y * in_size / original_height)
        target[y_new, x_new] = 1.
        if isflip:
            x_new = 256 - x_new
        feature_pos.append((x_new, y_new))
    for pos in neg_label:
        x, y = pos
        x_new = int(x * in_size / original_width)
        y_new = int(y * in_size / original_height)
        target[y_new, x_new] = 0.
149 | if isflip: 150 | x_new = 256 - x_new 151 | feature_pos.append((x_new, y_new)) 152 | target_r = Image.fromarray(target) 153 | if isflip: 154 | target_r = transforms.RandomHorizontalFlip(p=1)(target_r) 155 | return target_r, feature_pos 156 | 157 | 158 | def _collect_info(json_file, bg_dir, fg_dir, mask_dir, datatype='train'): 159 | """ 160 | load json file and return required information 161 | Args: 162 | json_file(str): json file with train/test information 163 | bg_dir(str): folder with background images 164 | fg_dir(str): folder with foreground images 165 | mask_dir(str): folder with foreground masks 166 | datatype(str): "train" or "test" 167 | return: 168 | index(int): the sample index 169 | background image path, foreground image path, foreground mask image 170 | foreground scale, the locations of positive/negative pixels 171 | """ 172 | f_json = json.load(open(json_file, 'r')) 173 | return [ 174 | ( 175 | index, 176 | row['scID'].rjust(12,'0'), 177 | os.path.join(bg_dir, "%012d.jpg" % int(row['scID'])), # background image path 178 | os.path.join(fg_dir, "{}/{}_{}_{}_{}.jpg".format(datatype, int(row['annID']), int(row['scID']), # scaled foreground image path 179 | int(row['newWidth']), int(row['newHeight']))), 180 | 181 | os.path.join(mask_dir, "{}/{}_{}_{}_{}.jpg".format(datatype, int(row['annID']), int(row['scID']), # scaled foreground mask path 182 | int(row['newWidth']), int(row['newHeight']))), 183 | row['scale'], 184 | row['pos_label'], row['neg_label'], 185 | os.path.join(fg_dir, "foreground/{}.jpg".format(int(row['annID']))), # original foreground image path 186 | os.path.join(fg_dir, "foreground/mask_{}.jpg".format(int(row['annID']))), # original foreground mask path 187 | int(row['newWidth']), int(row['newHeight']) # scaled foreground width and height 188 | ) 189 | for index, row in enumerate(f_json) 190 | ] 191 | 192 | 193 | def _to_center(bbox): 194 | """conver bbox to center pixel""" 195 | x, y, width, height = bbox 196 | return x + width // 
2, y + height // 2 197 | 198 | 199 | def create_loader(table_path, bg_dir, fg_dir, mask_dir, in_size, datatype, batch_size, num_workers, shuffle): 200 | dset = CPDataset(table_path, bg_dir, fg_dir, mask_dir, in_size, datatype) 201 | data_loader = DataLoader(dset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle) 202 | 203 | return data_loader 204 | 205 | 206 | def make_composite(fg_img, mask_img, bg_img, pos, isflip=False): 207 | x, y, w, h = pos 208 | bg_h = bg_img.height 209 | bg_w = bg_img.width 210 | # resize foreground to expected size [h, w] 211 | fg_transform = transforms.Compose([ 212 | transforms.Resize((h, w)), 213 | transforms.ToTensor(), 214 | ]) 215 | top = max(y, 0) 216 | bottom = min(y + h, bg_h) 217 | left = max(x, 0) 218 | right = min(x + w, bg_w) 219 | fg_img_ = fg_transform(fg_img) 220 | mask_img_ = fg_transform(mask_img) 221 | fg_img = torch.zeros(3, bg_h, bg_w) 222 | mask_img = torch.zeros(3, bg_h, bg_w) 223 | fg_img[:, top:bottom, left:right] = fg_img_[:, top - y:bottom - y, left - x:right - x] 224 | mask_img[:, top:bottom, left:right] = mask_img_[:, top - y:bottom - y, left - x:right - x] 225 | bg_img = transforms.ToTensor()(bg_img) 226 | blended = fg_img * mask_img + bg_img * (1 - mask_img) 227 | com_pic = transforms.ToPILImage()(blended).convert('RGB') 228 | if isflip == False: 229 | com_pic = transforms.Compose( 230 | [ 231 | transforms.Resize((256, 256)), 232 | transforms.ToTensor() 233 | ] 234 | )(com_pic) 235 | mask_img = transforms.ToPILImage()(mask_img).convert('L') 236 | mask_img = transforms.Compose( 237 | [ 238 | transforms.Resize((256, 256)), 239 | transforms.ToTensor() 240 | ] 241 | )(mask_img) 242 | com_pic = torch.cat((com_pic, mask_img), dim=0) 243 | else: 244 | com_pic = transforms.Compose( 245 | [ 246 | transforms.Resize((256, 256)), 247 | transforms.RandomHorizontalFlip(p=1), 248 | transforms.ToTensor() 249 | ] 250 | )(com_pic) 251 | mask_img = transforms.ToPILImage()(mask_img).convert('L') 252 | mask_img = 
transforms.Compose( 253 | [ 254 | transforms.Resize((256, 256)), 255 | transforms.RandomHorizontalFlip(p=1), 256 | transforms.ToTensor() 257 | ] 258 | )(mask_img) 259 | com_pic = torch.cat((com_pic, mask_img), dim=0) 260 | return com_pic 261 | 262 | def make_composite_PIL(fg_img, mask_img, bg_img, pos, return_mask=False): 263 | x, y, w, h = pos 264 | bg_h = bg_img.height 265 | bg_w = bg_img.width 266 | 267 | top = max(y, 0) 268 | bottom = min(y + h, bg_h) 269 | left = max(x, 0) 270 | right = min(x + w, bg_w) 271 | fg_img_ = fg_img.resize((w,h)) 272 | mask_img_ = mask_img.resize((w,h)) 273 | 274 | fg_img_ = np.array(fg_img_) 275 | mask_img_ = np.array(mask_img_, dtype=np.float_)/255 276 | bg_img = np.array(bg_img) 277 | 278 | fg_img = np.zeros((bg_h, bg_w, 3), dtype=np.uint8) 279 | mask_img = np.zeros((bg_h, bg_w, 3), dtype=np.float_) 280 | 281 | fg_img[top:bottom, left:right, :] = fg_img_[top - y:bottom - y, left - x:right - x, :] 282 | mask_img[top:bottom, left:right, :] = mask_img_[top - y:bottom - y, left - x:right - x, :] 283 | composite_img = fg_img * mask_img + bg_img * (1 - mask_img) 284 | 285 | 286 | composite_img = Image.fromarray(composite_img.astype(np.uint8)) 287 | if return_mask==False: 288 | return composite_img 289 | else: 290 | composite_msk = Image.fromarray((mask_img*255).astype(np.uint8)) 291 | return composite_img, composite_msk 292 | 293 | -------------------------------------------------------------------------------- /data/all_transforms.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | class JointResize(object): 4 | def __init__(self, size): 5 | if isinstance(size, int): 6 | self.size = (size, size) 7 | elif isinstance(size, tuple): 8 | self.size = size 9 | else: 10 | raise RuntimeError("size should be int or tuple") 11 | 12 | def __call__(self, bg, fg, mask): 13 | bg = bg.resize(self.size, Image.BILINEAR) 14 | fg = fg.resize(self.size, Image.BILINEAR) 15 | mask = 
mask.resize(self.size, Image.NEAREST) 16 | return bg, fg, mask 17 | 18 | class Compose(object): 19 | def __init__(self, transforms): 20 | self.transforms = transforms 21 | 22 | def __call__(self, bg, fg, mask): 23 | for t in self.transforms: 24 | bg, fg, mask = t(bg, fg, mask) 25 | return bg, fg, mask 26 | -------------------------------------------------------------------------------- /network/BaseBlocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BasicConv2d(nn.Module): 5 | def __init__( 6 | self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, 7 | ): 8 | super(BasicConv2d, self).__init__() 9 | 10 | self.basicconv = nn.Sequential( 11 | nn.Conv2d( 12 | in_planes, 13 | out_planes, 14 | kernel_size=kernel_size, 15 | stride=stride, 16 | padding=padding, 17 | dilation=dilation, 18 | groups=groups, 19 | bias=bias, 20 | ), 21 | nn.BatchNorm2d(out_planes), 22 | nn.ReLU(inplace=True), 23 | ) 24 | 25 | def forward(self, x): 26 | return self.basicconv(x) 27 | -------------------------------------------------------------------------------- /network/DynamicModules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class simpleDFN(nn.Module): 6 | def __init__(self, in_xC, in_yC, out_C, kernel_size=3, down_factor=4): 7 | """use nn.Unfold to realize dynamic convolution 8 | 9 | Args: 10 | in_xC (int): channel number of first input 11 | in_yC (int): channel number of second input 12 | out_C (int): channel number of output 13 | kernel_size (int): the size of generated conv kernel 14 | down_factor (int): reduce the model parameters when generating conv kernel 15 | """ 16 | super(simpleDFN, self).__init__() 17 | self.kernel_size = kernel_size 18 | self.fuse = nn.Conv2d(in_xC, out_C, 3, 1, 1) 19 | self.out_C = out_C 20 | self.gernerate_kernel = 
nn.Sequential( 21 | # nn.Conv2d(in_yC, in_yC, 3, 1, 1), 22 | # DenseLayer(in_yC, in_yC, k=down_factor), 23 | nn.Conv2d(in_yC, in_xC, 1), 24 | ) 25 | self.unfold = nn.Unfold(kernel_size=3, dilation=1, padding=1, stride=1) 26 | self.pool = nn.AdaptiveAvgPool2d(self.kernel_size) 27 | self.in_planes = in_yC 28 | 29 | def forward(self, x, y): # x:bg y:fg 30 | kernel = self.gernerate_kernel(self.pool(y)) 31 | batch_size, in_planes, height, width = x.size() 32 | x = x.view(1, -1, height, width) 33 | kernel = kernel.view(-1, 1, self.kernel_size, self.kernel_size) 34 | if self.kernel_size == 3: 35 | output = F.conv2d(x, kernel, bias=None, stride=1, padding=1, groups=self.in_planes * batch_size) 36 | elif self.kernel_size == 1: 37 | output = F.conv2d(x, kernel, bias=None, stride=1, padding=0, groups=self.in_planes * batch_size) 38 | elif self.kernel_size == 5: 39 | output = F.conv2d(x, kernel, bias=None, stride=1, padding=2, groups=self.in_planes * batch_size) 40 | else: 41 | output = F.conv2d(x, kernel, bias=None, stride=1, padding=3, groups=self.in_planes * batch_size) 42 | output = output.view(batch_size, -1, height, width) 43 | return self.fuse(output) 44 | -------------------------------------------------------------------------------- /network/ObPlaNet_simple.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | from torchvision import transforms 5 | 6 | sys.path.append("..") 7 | from backbone.ResNet import Backbone_ResNet18_in3, Backbone_ResNet18_in3_1 8 | from network.BaseBlocks import BasicConv2d 9 | from network.DynamicModules import simpleDFN 10 | from network.tensor_ops import cus_sample, upsample_add 11 | 12 | class ObPlaNet_resnet18(nn.Module): 13 | def __init__(self, pretrained=True, ks=3, scale=3): 14 | super(ObPlaNet_resnet18, self).__init__() 15 | self.Eiters = 0 16 | self.upsample_add = upsample_add 17 | self.upsample = cus_sample 18 | self.to_pil = 
transforms.ToPILImage() 19 | self.scale = scale 20 | 21 | self.add_mask = True 22 | 23 | ( 24 | self.bg_encoder1, 25 | self.bg_encoder2, 26 | self.bg_encoder4, 27 | self.bg_encoder8, 28 | self.bg_encoder16, 29 | ) = Backbone_ResNet18_in3(pretrained=pretrained) 30 | 31 | # freeze background encoder (only the bg_encoder* modules have been created at this point) 32 | for p in self.parameters(): 33 | p.requires_grad = False 34 | 35 | ( 36 | self.fg_encoder1, 37 | self.fg_encoder2, 38 | self.fg_encoder4, 39 | self.fg_encoder8, 40 | self.fg_encoder16, 41 | self.fg_encoder32, 42 | ) = Backbone_ResNet18_in3_1(pretrained=pretrained) 43 | 44 | if self.add_mask: 45 | self.mask_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) # defined but not used in forward() 46 | 47 | # dynamic conv 48 | self.fg_trans16 = nn.Conv2d(512, 64, 1) 49 | self.fg_trans8 = nn.Conv2d(256, 64, 1) 50 | self.selfdc_16 = simpleDFN(64, 64, 512, ks, 4) 51 | self.selfdc_8 = simpleDFN(64, 64, 512, ks, 4) 52 | 53 | self.upconv16 = BasicConv2d(512, 256, kernel_size=3, stride=1, padding=1) 54 | self.upconv8 = BasicConv2d(256, 128, kernel_size=3, stride=1, padding=1) 55 | self.upconv4 = BasicConv2d(128, 64, kernel_size=3, stride=1, padding=1) 56 | self.upconv2 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1) 57 | self.upconv1 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1) 58 | 59 | self.classifier = nn.Conv2d(512, 2, 1) 60 | 61 | def forward(self, bg_in_data, fg_in_data, mask_in_data=None, mode='test'): 62 | """ 63 | Args: 64 | bg_in_data: (batch_size * 3 * H * W) background image 65 | fg_in_data: (batch_size * 3 * H * W) scaled foreground image 66 | mask_in_data: (batch_size * 1 * H * W) scaled foreground mask; required despite the None default 67 | mode: "train" increments self.Eiters; any other value (e.g. "test", "val") means inference 68 | """ 69 | if ('train' == mode): 70 | self.Eiters += 1 71 | 72 | # extract background and foreground features 73 | black_mask = torch.zeros_like(mask_in_data) 74 | bg_in_data_ = torch.cat([bg_in_data, black_mask], dim=1) 75 | bg_in_data_1 = self.bg_encoder1(bg_in_data_) # torch.Size([2, 64, 128, 128])
76 | fg_cat_mask = torch.cat([fg_in_data, mask_in_data], dim=1) 77 | fg_in_data_1 = self.fg_encoder1(fg_cat_mask) # torch.Size([2, 64, 128, 128]) 78 | 79 | 80 | bg_in_data_2 = self.bg_encoder2(bg_in_data_1) # torch.Size([2, 64, 64, 64]) 81 | fg_in_data_2 = self.fg_encoder2(fg_in_data_1) # torch.Size([2, 64, 128, 128]) 82 | bg_in_data_4 = self.bg_encoder4(bg_in_data_2) # torch.Size([2, 128, 32, 32]) 83 | fg_in_data_4 = self.fg_encoder4(fg_in_data_2) # torch.Size([2, 64, 64, 64]) 84 | del fg_in_data_1, fg_in_data_2 85 | 86 | bg_in_data_8 = self.bg_encoder8(bg_in_data_4) # torch.Size([2, 256, 16, 16]) 87 | fg_in_data_8 = self.fg_encoder8(fg_in_data_4) # torch.Size([2, 128, 32, 32]) 88 | bg_in_data_16 = self.bg_encoder16(bg_in_data_8) # torch.Size([2, 512, 8, 8]) 89 | fg_in_data_16 = self.fg_encoder16(fg_in_data_8) # torch.Size([2, 256, 16, 16]) 90 | fg_in_data_32 = self.fg_encoder32(fg_in_data_16) # torch.Size([2, 512, 8, 8]) 91 | 92 | in_data_8_aux = self.fg_trans8(fg_in_data_16) # torch.Size([2, 64, 16, 16]) 93 | in_data_16_aux = self.fg_trans16(fg_in_data_32) # torch.Size([2, 64, 8, 8]) 94 | 95 | # Unet decoder 96 | bg_out_data_16 = bg_in_data_16 # torch.Size([2, 512, 8, 8]) 97 | 98 | bg_out_data_8 = self.upsample_add(self.upconv16(bg_out_data_16), bg_in_data_8) # torch.Size([2, 256, 16, 16]) 99 | bg_out_data_4 = self.upsample_add(self.upconv8(bg_out_data_8), bg_in_data_4) # torch.Size([2, 128, 32, 32]) 100 | bg_out_data_2 = self.upsample_add(self.upconv4(bg_out_data_4), bg_in_data_2) # torch.Size([2, 64, 64, 64]) 101 | bg_out_data_1 = self.upsample_add(self.upconv2(bg_out_data_2), bg_in_data_1) # torch.Size([2, 64, 128, 128]) 102 | del bg_out_data_2, bg_out_data_4, bg_out_data_8, bg_out_data_16 103 | 104 | bg_out_data = self.upconv1(self.upsample(bg_out_data_1, scale_factor=2)) # torch.Size([2, 64, 256, 256]) 105 | 106 | # fuse foreground and background features using dynamic conv 107 | fuse_out = self.upsample_add(self.selfdc_16(bg_out_data_1, in_data_8_aux), \ 
108 | self.selfdc_8(bg_out_data, in_data_16_aux)) # torch.Size([2, 512, 256, 256]) 109 | 110 | out_data = self.classifier(fuse_out) # torch.Size([2, 2, 256, 256]) 111 | 112 | return out_data, fuse_out 113 | 114 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | from network.ObPlaNet_simple import * 2 | -------------------------------------------------------------------------------- /network/tensor_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def cus_sample(feat, **kwargs): 5 | assert len(kwargs.keys()) == 1 and list(kwargs.keys())[0] in ["size", "scale_factor"] 6 | return F.interpolate(feat, **kwargs, mode="bilinear", align_corners=True) 7 | 8 | 9 | def upsample_add(*xs): # bilinearly resize every input to the last input's H x W and sum them 10 | y = xs[-1] 11 | for x in xs[:-1]: 12 | y = y + F.interpolate(x, size=y.size()[2:], mode="bilinear", align_corners=False) 13 | return y 14 | 15 | 16 | def upsample_cat(*xs): 17 | y = xs[-1] 18 | out = [] 19 | for x in xs[:-1]: 20 | out.append(F.interpolate(x, size=y.size()[2:], mode="bilinear", align_corners=False)) 21 | return torch.cat([*out, y], dim=1) 22 | 23 | 24 | def upsample_reduce(b, a): 25 | _, C, _, _ = b.size() 26 | N, _, H, W = a.size() 27 | 28 | b = F.interpolate(b, size=(H, W), mode="bilinear", align_corners=False) 29 | a = a.reshape(N, -1, C, H, W).mean(1) 30 | 31 | return b + a 32 | 33 | 34 | def shuffle_channels(x, groups): 35 | N, C, H, W = x.size() 36 | x = x.reshape(N, groups, C // groups, H, W).permute(0, 2, 1, 3, 4) 37 | return x.reshape(N, C, H, W) 38 | -------------------------------------------------------------------------------- /prepare_multi_fg_scales.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import json 4 | 5 | import numpy as np 6 | from PIL import
Image 7 | from tqdm import tqdm 8 | from config import arg_config 9 | 10 | fg_scale_num = 16 11 | save_img_flag = True 12 | 13 | 14 | def collect_info(json_file, bg_dir, fg_dir): 15 | with open(json_file, 'r') as f: 16 | f_json = json.load(f) 17 | return [ 18 | ( 19 | row['imgID'], row['annID'], row['scID'], 20 | os.path.join(bg_dir, "%012d.jpg" % int(row['scID'])), 21 | os.path.join(fg_dir, "foreground/{}.jpg".format(int(row['annID']))), 22 | os.path.join(fg_dir, "foreground/mask_{}.jpg".format(int(row['annID']))) 23 | ) 24 | for row in f_json 25 | ] 26 | 27 | 28 | fg_scales = list(range(1, fg_scale_num+1)) 29 | fg_scales = [i/(1+fg_scale_num+1) for i in fg_scales] # i/(fg_scale_num+2), i.e. 1/18, ..., 16/18 for 16 scales 30 | 31 | fg_bg_dict = dict() 32 | args = arg_config 33 | data = collect_info(args["test_data_path"], args["bg_dir"], args["fg_dir"]) 34 | 35 | csv_dir = './data/data' 36 | scaled_fg_dir = f'./data/data/fg/test_{fg_scale_num}scales/' 37 | scaled_mask_dir = f'./data/data/mask/test_{fg_scale_num}scales/' 38 | 39 | os.makedirs(scaled_fg_dir, exist_ok=True) 40 | os.makedirs(scaled_mask_dir, exist_ok=True) 41 | 42 | csv_file = os.path.join(csv_dir, f'test_data_{fg_scale_num}scales.csv') 43 | json_file = csv_file.replace('.csv', '.json') 44 | 45 | file = open(csv_file, mode='w', newline='') 46 | writer = csv.writer(file) 47 | 48 | 49 | csv_head = ['imgID', 'annID', 'scID', 'scale', 'newWidth', 'newHeight', 'pos_label', 'neg_label'] 50 | writer.writerow(csv_head) 51 | 52 | 53 | 54 | for index in tqdm(range(len(data))): 55 | imgID, fg_id, bg_id, bg_path, fg_path, mask_path = data[index] 56 | if (fg_id, bg_id) in fg_bg_dict: 57 | continue 58 | fg_bg_dict[(fg_id, bg_id)] = 1 59 | 60 | 61 | bg_img = Image.open(bg_path) 62 | if len(bg_img.split()) != 3: 63 | bg_img = bg_img.convert("RGB") 64 | bg_img_aspect = bg_img.height/bg_img.width 65 | fg_tocp = Image.open(fg_path).convert("RGB") 66 | mask_tocp = Image.open(mask_path).convert("RGB") 67 | fg_tocp_aspect = fg_tocp.height/fg_tocp.width 68 |
69 | for fg_scale in fg_scales: 70 | if fg_tocp_aspect>bg_img_aspect: 71 | new_height = bg_img.height*fg_scale 72 | new_width = new_height/fg_tocp.height*fg_tocp.width 73 | else: 74 | new_width = bg_img.width*fg_scale 75 | new_height = new_width/fg_tocp.width*fg_tocp.height 76 | 77 | new_height = int(new_height) 78 | new_width = int(new_width) 79 | 80 | if save_img_flag: 81 | top = int((bg_img.height-new_height)/2) 82 | bottom = top+new_height 83 | left = int((bg_img.width-new_width)/2) 84 | right = left+new_width 85 | 86 | fg_img_ = fg_tocp.resize((new_width, new_height)) 87 | mask_ = mask_tocp.resize((new_width, new_height)) 88 | 89 | fg_img_ = np.array(fg_img_) 90 | mask_ = np.array(mask_) 91 | 92 | fg_img = np.zeros((bg_img.height, bg_img.width, 3), dtype=np.uint8) 93 | mask = np.zeros((bg_img.height, bg_img.width, 3), dtype=np.uint8) 94 | 95 | fg_img[top:bottom, left:right, :] = fg_img_ 96 | mask[top:bottom, left:right, :] = mask_ 97 | 98 | fg_img = Image.fromarray(fg_img.astype(np.uint8)) 99 | mask = Image.fromarray(mask.astype(np.uint8)) 100 | 101 | basename = f'{fg_id}_{bg_id}_{new_width}_{new_height}.jpg' 102 | fg_img_path = os.path.join(scaled_fg_dir, basename) 103 | mask_path = os.path.join(scaled_mask_dir, basename) 104 | fg_img.save(fg_img_path) 105 | mask.save(mask_path) 106 | 107 | writer.writerow([imgID, fg_id, bg_id, fg_scale, new_width, new_height, None, None]) 108 | 109 | 110 | file.close() 111 | 112 | # convert csv file to json file 113 | csv_data = [] 114 | with open(csv_file, mode='r') as file: 115 | reader = csv.DictReader(file) 116 | for row in reader: 117 | if row['pos_label']=="": 118 | row['pos_label'] = [[0,0]] 119 | if row['neg_label']=="": 120 | row['neg_label'] = [[0,0]] 121 | csv_data.append(row) 122 | 123 | with open(json_file, mode='w') as file: 124 | json.dump(csv_data, file, indent=4) 125 | 126 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | tqdm 2 | tensorboard_logger 3 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | from pprint import pprint 7 | from torchvision import transforms 8 | from tqdm import tqdm 9 | 10 | import network 11 | from config import arg_config 12 | from data.OBdataset import create_loader, _collect_info 13 | from data.OBdataset import make_composite_PIL 14 | 15 | 16 | class Evaluator: 17 | def __init__(self, args, checkpoint_path): 18 | super(Evaluator, self).__init__() 19 | self.args = args 20 | self.dev = torch.device("cuda:0") 21 | self.to_pil = transforms.ToPILImage() 22 | self.checkpoint_path = checkpoint_path 23 | pprint(self.args) 24 | 25 | print('load pretrained weights from ', checkpoint_path) 26 | self.net = getattr(network, self.args["model"])( 27 | pretrained=False).to(self.dev) 28 | self.net.load_state_dict(torch.load(checkpoint_path, map_location=self.dev), strict=False) 29 | self.net = self.net.to(self.dev).eval() 30 | self.softmax = torch.nn.Softmax(dim=1) 31 | 32 | def evalutate_model(self, datatype): 33 | ''' 34 | calculate F1 and bAcc metrics 35 | ''' 36 | 37 | correct = 0 38 | total = 0 39 | TP = 0 40 | TN = 0 41 | FP = 0 42 | FN = 0 43 | 44 | assert datatype=='train' or datatype=='test' 45 | 46 | self.ts_loader = create_loader( 47 | self.args[f"{datatype}_data_path"], self.args["bg_dir"], self.args["fg_dir"], self.args["mask_dir"], 48 | self.args["input_size"], datatype, self.args["batch_size"], self.args["num_workers"], False, 49 | ) 50 | 51 | with torch.no_grad(): 52 | 53 | for _, test_data in enumerate(tqdm(self.ts_loader)): 54 | _, test_bgs, test_masks, test_fgs, test_targets, nums, composite_list, feature_pos, _, _, _ = test_data 55 | test_bgs = 
test_bgs.to(self.dev, non_blocking=True) 56 | test_masks = test_masks.to(self.dev, non_blocking=True) 57 | test_fgs = test_fgs.to(self.dev, non_blocking=True) 58 | nums = nums.to(self.dev, non_blocking=True) 59 | composite_list = composite_list.to(self.dev, non_blocking=True) 60 | feature_pos = feature_pos.to(self.dev, non_blocking=True) 61 | 62 | test_outs, _ = self.net(test_bgs, test_fgs, test_masks, 'val') 63 | test_preds = np.argmax(test_outs.cpu().numpy(), axis=1) 64 | test_targets = test_targets.cpu().numpy() 65 | 66 | TP += ((test_preds == 1) & (test_targets == 1)).sum() 67 | TN += ((test_preds == 0) & (test_targets == 0)).sum() 68 | FP += ((test_preds == 1) & (test_targets == 0)).sum() 69 | FN += ((test_preds == 0) & (test_targets == 1)).sum() 70 | 71 | correct += (test_preds == test_targets).sum() 72 | total += nums.sum() 73 | 74 | precision = TP / (TP + FP) 75 | recall = TP / (TP + FN) 76 | fscore = (2 * precision * recall) / (precision + recall) 77 | weighted_acc = (TP / (TP + FN) + TN / (TN + FP)) * 0.5 78 | 79 | print('F-1 Measure: %f, ' % fscore) 80 | print('Weighted acc measure: %f, ' % weighted_acc) 81 | 82 | def get_heatmap(self, datatype): 83 | ''' 84 | generate heatmap for each pair of scaled foreground and background 85 | ''' 86 | 87 | save_dir, base_name = os.path.split(self.checkpoint_path) 88 | heatmap_dir = os.path.join(save_dir, base_name.replace('.pth', f'_{datatype}_heatmap')) 89 | 90 | if not os.path.exists(heatmap_dir): 91 | print(f"Create directory {heatmap_dir}") 92 | os.makedirs(heatmap_dir) 93 | 94 | 95 | 96 | self.ts_loader = create_loader( 97 | self.args[f"{datatype}_data_path"], self.args["bg_dir"], self.args["fg_dir"], self.args["mask_dir"], 98 | self.args["input_size"], datatype, 1, self.args["num_workers"], False, 99 | ) 100 | 101 | with torch.no_grad(): 102 | for _, test_data in enumerate(tqdm(self.ts_loader)): 103 | _, test_bgs, test_masks, test_fgs, _, nums, composite_list, feature_pos, _, _, save_name = test_data 104 | 
test_bgs = test_bgs.to(self.dev, non_blocking=True) 105 | test_masks = test_masks.to(self.dev, non_blocking=True) 106 | test_fgs = test_fgs.to(self.dev, non_blocking=True) 107 | nums = nums.to(self.dev, non_blocking=True) 108 | composite_list = composite_list.to(self.dev, non_blocking=True) 109 | feature_pos = feature_pos.to(self.dev, non_blocking=True) 110 | 111 | test_outs, _ = self.net(test_bgs, test_fgs, test_masks, 'test') 112 | test_outs = self.softmax(test_outs) 113 | 114 | test_outs = test_outs[:,1,:,:] 115 | test_outs = transforms.ToPILImage()(test_outs) 116 | test_outs.save(os.path.join(heatmap_dir, save_name[0])) 117 | 118 | def generate_composite(self, datatype, composite_num): 119 | ''' 120 | generate composite images for each pair of scaled foreground and background 121 | ''' 122 | 123 | save_dir, base_name = os.path.split(self.checkpoint_path) 124 | heatmap_dir = os.path.join(save_dir, base_name.replace('.pth', f'_{datatype}_heatmap')) 125 | if not os.path.exists(heatmap_dir): 126 | print(f"{heatmap_dir} does not exist! 
Please first use 'heatmap' mode to generate heatmaps") 127 | 128 | data = _collect_info(self.args[f"{datatype}_data_path"], self.args["bg_dir"], self.args["fg_dir"], self.args["mask_dir"], 'test') 129 | for index in range(len(data)): 130 | _, _, bg_path, fg_path, _, scale, _, _, fg_path_2, mask_path_2, w, h = data[index] 131 | 132 | fg_name = fg_path.split('/')[-1][:-4] 133 | save_name = fg_name + '_' + str(scale) 134 | 135 | bg_img = Image.open(bg_path) 136 | if len(bg_img.split()) != 3: 137 | bg_img = bg_img.convert("RGB") 138 | fg_tocp = Image.open(fg_path_2).convert("RGB") 139 | mask_tocp = Image.open(mask_path_2).convert("RGB") 140 | 141 | composite_dir = os.path.join(save_dir, base_name.replace('.pth', f'_{datatype}_composite'), save_name) 142 | if not os.path.exists(composite_dir): 143 | print(f"Create directory {composite_dir}") 144 | os.makedirs(composite_dir) 145 | 146 | heatmap = Image.open(os.path.join(heatmap_dir, save_name+'.jpg')) 147 | heatmap = np.array(heatmap) 148 | 149 | # exclude boundary: keep only centers at least half a foreground size away from the border 150 | heatmap_center = np.zeros_like(heatmap, dtype=np.float64) 151 | hb = int(h/bg_img.height*heatmap.shape[0]/2) 152 | wb = int(w/bg_img.width*heatmap.shape[1]/2) 153 | heatmap_center[hb:heatmap.shape[0]-hb, wb:heatmap.shape[1]-wb] = heatmap[hb:heatmap.shape[0]-hb, wb:heatmap.shape[1]-wb] 154 | 155 | # sort pixels in a descending order based on the heatmap 156 | sorted_indices = np.argsort(-heatmap_center, axis=None) 157 | sorted_indices = np.unravel_index(sorted_indices, heatmap_center.shape) 158 | for i in range(composite_num): 159 | y_, x_ = sorted_indices[0][i], sorted_indices[1][i] 160 | x_ = x_/heatmap.shape[1]*bg_img.width 161 | y_ = y_/heatmap.shape[0]*bg_img.height 162 | x = int(x_ - w / 2) 163 | y = int(y_ - h / 2) 164 | # make composite image with foreground, background, and placement 165 | composite_img = make_composite_PIL(fg_tocp, mask_tocp, bg_img, [x, y, w, h]) 166 | save_img_path = os.path.join(composite_dir, f'{save_name}_{int(x_)}_{int(y_)}.jpg') 167 | composite_img.save(save_img_path) 168 |
print(save_img_path) 169 | 170 | 171 | if __name__ == "__main__": 172 | 173 | parser = argparse.ArgumentParser() 174 | # "evaluate": calculate F1 and bAcc 175 | # "heatmap": generate FOPA heatmap 176 | # "composite": generate composite images based on the heatmap 177 | parser.add_argument('--mode', type=str, default= "composite") 178 | # datatype: "train" or "test" 179 | parser.add_argument('--datatype', type=str, default= "test") 180 | parser.add_argument('--path', type=str, default= "demo2023-05-19-22:36:47.952468") 181 | parser.add_argument('--epoch', type=int, default= 23) 182 | args = parser.parse_args() 183 | 184 | #full_path = os.path.join('output', args.path, 'pth', f'{args.epoch}_state_final.pth') 185 | full_path = 'best_weight.pth' 186 | 187 | if not os.path.exists(full_path): 188 | print(f'{full_path} does not exist!') 189 | else: 190 | evaluator = Evaluator(arg_config, checkpoint_path=full_path) 191 | if args.mode== "evaluate": 192 | evaluator.evalutate_model(args.datatype) 193 | elif args.mode== "heatmap": 194 | evaluator.get_heatmap(args.datatype) 195 | elif args.mode== "composite": 196 | evaluator.generate_composite(args.datatype, 50) 197 | else: 198 | print(f'There is no {args.mode} mode.') 199 | 200 | -------------------------------------------------------------------------------- /test_multi_fg_scales.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | from pprint import pprint 7 | from torchvision import transforms 8 | from tqdm import tqdm 9 | 10 | import network 11 | from config import arg_config 12 | from data.OBdataset import create_loader, _collect_info 13 | from data.OBdataset import make_composite_PIL 14 | 15 | 16 | class Evaluator: 17 | def __init__(self, args, checkpoint_path): 18 | super(Evaluator, self).__init__() 19 | self.args = args 20 | self.dev = torch.device("cuda:0") 21 | self.to_pil = 
transforms.ToPILImage() 22 | self.checkpoint_path = checkpoint_path 23 | pprint(self.args) 24 | 25 | print('load pretrained weights from ', checkpoint_path) 26 | self.net = getattr(network, self.args["model"])( 27 | pretrained=False).to(self.dev) 28 | self.net.load_state_dict(torch.load(checkpoint_path, map_location=self.dev), strict=False) 29 | self.net = self.net.to(self.dev).eval() 30 | self.softmax = torch.nn.Softmax(dim=1) 31 | 32 | def get_heatmap_multi_scales(self, fg_scale_num): 33 | ''' 34 | generate heatmap for each pair of scaled foreground and background 35 | ''' 36 | 37 | datatype= f"test_{fg_scale_num}scales" 38 | 39 | save_dir, base_name = os.path.split(self.checkpoint_path) 40 | heatmap_dir = os.path.join(save_dir, base_name.replace('.pth', f'_{datatype}_heatmap')) 41 | 42 | if not os.path.exists(heatmap_dir): 43 | print(f"Create directory {heatmap_dir}") 44 | os.makedirs(heatmap_dir) 45 | 46 | 47 | json_path = os.path.join('./data/data', f"test_data_{fg_scale_num}scales.json") 48 | 49 | self.ts_loader = create_loader( 50 | json_path, self.args["bg_dir"], self.args["fg_dir"], self.args["mask_dir"], 51 | self.args["input_size"], datatype, 1, self.args["num_workers"], False, 52 | ) 53 | 54 | with torch.no_grad(): 55 | for _, test_data in enumerate(tqdm(self.ts_loader)): 56 | _, test_bgs, test_masks, test_fgs, _, nums, composite_list, feature_pos, _, _, save_name = test_data 57 | test_bgs = test_bgs.to(self.dev, non_blocking=True) 58 | test_masks = test_masks.to(self.dev, non_blocking=True) 59 | test_fgs = test_fgs.to(self.dev, non_blocking=True) 60 | nums = nums.to(self.dev, non_blocking=True) 61 | composite_list = composite_list.to(self.dev, non_blocking=True) 62 | feature_pos = feature_pos.to(self.dev, non_blocking=True) 63 | 64 | test_outs, _ = self.net(test_bgs, test_fgs, test_masks, 'test') 65 | test_outs = self.softmax(test_outs) 66 | 67 | test_outs = test_outs[:,1,:,:] 68 | test_outs = transforms.ToPILImage()(test_outs) 69 | 
test_outs.save(os.path.join(heatmap_dir, save_name[0])) 70 | 71 | def generate_composite_multi_scales(self, fg_scale_num, composite_num): 72 | ''' 73 | generate composite images for each pair of scaled foreground and background 74 | ''' 75 | 76 | fg_scales = list(range(1, fg_scale_num+1)) 77 | fg_scales = [i/(1+fg_scale_num+1) for i in fg_scales] 78 | 79 | icount = 0 80 | 81 | save_dir, base_name = os.path.split(self.checkpoint_path) 82 | heatmap_dir = os.path.join(save_dir, base_name.replace('.pth', f'_test_{fg_scale_num}scales_heatmap')) 83 | if not os.path.exists(heatmap_dir): 84 | print(f"{heatmap_dir} does not exist! Please first use 'heatmap' mode to generate heatmaps") 85 | 86 | json_path = os.path.join('./data/data', f"test_data_{fg_scale_num}scales.json") 87 | 88 | data = _collect_info(json_path, self.args["bg_dir"], self.args["fg_dir"], self.args["mask_dir"], 'test') 89 | for index in range(len(data)): 90 | _, _, bg_path, fg_path, _, scale, _, _, fg_path_2, mask_path_2, w, h = data[index] 91 | 92 | fg_name = fg_path.split('/')[-1][:-4] 93 | save_name = fg_name + '_' + str(scale) 94 | segs = fg_name.split('_') 95 | fg_id, bg_id = segs[0], segs[1] 96 | if icount==0: # assumes each (fg, bg) pair occupies fg_scale_num consecutive rows, as written by prepare_multi_fg_scales.py 97 | 98 | bg_img = Image.open(bg_path) 99 | if len(bg_img.split()) != 3: 100 | bg_img = bg_img.convert("RGB") 101 | fg_tocp = Image.open(fg_path_2).convert("RGB") 102 | mask_tocp = Image.open(mask_path_2).convert("RGB") 103 | 104 | composite_dir = os.path.join(save_dir, base_name.replace('.pth', f'_test_{fg_scale_num}scales_composite'), f'{fg_id}_{bg_id}') 105 | if not os.path.exists(composite_dir): 106 | print(f"Create directory {composite_dir}") 107 | os.makedirs(composite_dir) 108 | 109 | heatmap_center_list = [] 110 | fg_size_list = [] 111 | 112 | icount += 1 113 | heatmap = Image.open(os.path.join(heatmap_dir, save_name+'.jpg')) 114 | heatmap = np.array(heatmap) 115 | # exclude boundary 116 | heatmap_center = np.zeros_like(heatmap, dtype=np.float64) 117 | hb=
int(h/bg_img.height*heatmap.shape[0]/2) 118 | wb = int(w/bg_img.width*heatmap.shape[1]/2) 119 | heatmap_center[hb:heatmap.shape[0]-hb, wb:heatmap.shape[1]-wb] = heatmap[hb:heatmap.shape[0]-hb, wb:heatmap.shape[1]-wb] 120 | heatmap_center_list.append(heatmap_center) 121 | fg_size_list.append((h,w)) 122 | 123 | if icount==fg_scale_num: 124 | icount = 0 125 | heatmap_center_stack = np.stack(heatmap_center_list) 126 | # sort pixels in a descending order based on the heatmap 127 | sorted_indices = np.argsort(-heatmap_center_stack, axis=None) 128 | sorted_indices = np.unravel_index(sorted_indices, heatmap_center_stack.shape) 129 | for i in range(composite_num): 130 | iscale, y_, x_ = sorted_indices[0][i], sorted_indices[1][i], sorted_indices[2][i] 131 | h, w = fg_size_list[iscale] 132 | x_ = x_/heatmap.shape[1]*bg_img.width 133 | y_ = y_/heatmap.shape[0]*bg_img.height 134 | x = int(x_ - w / 2) 135 | y = int(y_ - h / 2) 136 | # make composite image with foreground, background, and placement 137 | composite_img, composite_msk = make_composite_PIL(fg_tocp, mask_tocp, bg_img, [x, y, w, h], return_mask=True) 138 | save_img_path = os.path.join(composite_dir, f'{fg_id}_{bg_id}_{x}_{y}_{w}_{h}.jpg') 139 | save_msk_path = os.path.join(composite_dir, f'{fg_id}_{bg_id}_{x}_{y}_{w}_{h}.png') 140 | composite_img.save(save_img_path) 141 | composite_msk.save(save_msk_path) 142 | print(save_img_path) 143 | 144 | 145 | 146 | if __name__ == "__main__": 147 | print("cuda: ", torch.cuda.is_available()) 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument('--mode', type=str, default= "composite") 150 | parser.add_argument('--path', type=str, default= "demo2023-05-19-22:36:47.952468") 151 | parser.add_argument('--epoch', type=int, default= 20) 152 | args = parser.parse_args() 153 | 154 | fg_scale_num = 16 155 | composite_num = 50 156 | 157 | full_path = os.path.join('output', args.path, 'pth', f'{args.epoch}_state_final.pth') 158 | 159 | if not os.path.exists(full_path): 160 | print(f'{full_path} does not exist!') 161 | else: 162 | evaluator
= Evaluator(arg_config, checkpoint_path=full_path) 163 | if args.mode== "heatmap": 164 | evaluator.get_heatmap_multi_scales(fg_scale_num) 165 | elif args.mode== "composite": 166 | evaluator.generate_composite_multi_scales(fg_scale_num, composite_num) 167 | 168 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from datetime import datetime 4 | from pprint import pprint 5 | 6 | import numpy as np 7 | import torch 8 | import torch.backends.cudnn as torchcudnn 9 | from torch.nn import CrossEntropyLoss 10 | from torch.optim import SGD, Adam 11 | from torchvision import transforms 12 | 13 | import argparse 14 | import random 15 | import network 16 | import tensorboard_logger as tb_logger 17 | import torch.nn as nn 18 | 19 | from backbone.ResNet import pretrained_resnet18_4ch 20 | from config import arg_config, proj_root 21 | from data.OBdataset import create_loader 22 | from utils.misc import AvgMeter, construct_path_dict, make_log, pre_mkdir 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--ex_name', type=str, default=arg_config["ex_name"]) 26 | parser.add_argument('--alpha', type=float, default=16.) 
27 | parser.add_argument('--resume', action='store_true', help='resume from checkpoint') 28 | 29 | user_args = parser.parse_args() 30 | datetime_str = str(datetime.now()) 31 | datetime_str = '-'.join(datetime_str.split()) 32 | user_args.ex_name += datetime_str 33 | 34 | def setup_seed(seed): 35 | torch.manual_seed(seed) 36 | torch.cuda.manual_seed_all(seed) 37 | np.random.seed(seed) 38 | random.seed(seed) 39 | torch.backends.cudnn.deterministic = True 40 | 41 | # set random seed 42 | setup_seed(0) 43 | torchcudnn.benchmark = True # autotuning may select different conv algorithms between runs; set False if exact reproducibility matters 44 | torchcudnn.enabled = True 45 | torchcudnn.deterministic = True 46 | 47 | 48 | class Trainer: 49 | def __init__(self, args): 50 | super(Trainer, self).__init__() 51 | self.args = args 52 | self.to_pil = transforms.ToPILImage() 53 | pprint(self.args) 54 | 55 | self.path = construct_path_dict(proj_root=proj_root, exp_name=user_args.ex_name) # self.args["Experiment_name"]) 56 | pre_mkdir(path_config=self.path) 57 | 58 | # backup used file 59 | shutil.copy(f"{proj_root}/config.py", self.path["cfg_log"]) 60 | shutil.copy(f"{proj_root}/train.py", self.path["trainer_log"]) 61 | shutil.copy(f"{proj_root}/data/OBdataset.py", self.path["dataset_log"]) 62 | shutil.copy(f"{proj_root}/network/ObPlaNet_simple.py", self.path["network_log"]) 63 | 64 | # training data loader 65 | self.tr_loader = create_loader( 66 | self.args["train_data_path"], self.args["bg_dir"], self.args["fg_dir"], self.args["mask_dir"], 67 | self.args["input_size"], 'train', self.args["batch_size"], self.args["num_workers"], True, 68 | ) 69 | 70 | # load model 71 | self.dev = torch.device(f'cuda:{arg_config["gpu_id"]}') 72 | self.net = getattr(network, self.args["model"])(pretrained=True).to(self.dev) 73 | 74 | # loss functions 75 | self.loss = CrossEntropyLoss(ignore_index=255, reduction=self.args["reduction"]).to(self.dev) 76 | 77 | # optimizer 78 | self.opti = self.make_optim() 79 | 80 | # record loss 81 | tb_logger.configure(self.path['pth_log'], flush_secs=5) 82 | 83 | self.end_epoch
= self.args["epoch_num"] 84 | if user_args.resume: 85 | try: 86 | self.resume_checkpoint(load_path=self.path["final_full_net"], mode="all") 87 | except Exception: 88 | print(f"{self.path['final_full_net']} does not exist and we will load {self.path['final_state_net']}") 89 | self.resume_checkpoint(load_path=self.path["final_state_net"], mode="onlynet") 90 | self.start_epoch = self.end_epoch 91 | else: 92 | self.start_epoch = 0 93 | self.iter_num = self.end_epoch * len(self.tr_loader) 94 | 95 | 96 | def train(self): 97 | 98 | for curr_epoch in range(self.start_epoch, self.end_epoch): 99 | self.net.train() 100 | train_loss_record = AvgMeter() 101 | mimicking_loss_record = AvgMeter() 102 | 103 | # change learning rate 104 | if self.args["lr_type"] == "poly": 105 | self.change_lr(curr_epoch) 106 | elif self.args["lr_type"] == "decay": 107 | self.change_lr(curr_epoch) 108 | elif self.args["lr_type"] == "all_decay": 109 | self.change_lr(curr_epoch) 110 | else: 111 | raise NotImplementedError 112 | 113 | for train_batch_id, train_data in enumerate(self.tr_loader): 114 | curr_iter = curr_epoch * len(self.tr_loader) + train_batch_id 115 | 116 | self.opti.zero_grad() 117 | 118 | _, train_bgs, train_masks, train_fgs, train_targets, num, composite_list, feature_pos, _, _, _ = train_data 119 | 120 | train_bgs = train_bgs.to(self.dev, non_blocking=True) 121 | train_masks = train_masks.to(self.dev, non_blocking=True) 122 | train_fgs = train_fgs.to(self.dev, non_blocking=True) 123 | train_targets = train_targets.to(self.dev, non_blocking=True) 124 | num = num.to(self.dev, non_blocking=True) 125 | composite_list = composite_list.to(self.dev, non_blocking=True) 126 | feature_pos = feature_pos.to(self.dev, non_blocking=True) 127 | 128 | # model training 129 | train_outs, feature_map = self.net(train_bgs, train_fgs, train_masks, 'train') 130 | 131 | mimicking_loss = feature_mimicking(composite_list, feature_pos, feature_map, num, self.dev) 132 | out_loss = self.loss(train_outs,
train_targets.long()) 133 | train_loss = out_loss + user_args.alpha*mimicking_loss 134 | train_loss.backward() 135 | self.opti.step() 136 | 137 | train_iter_loss = out_loss.item() 138 | mimicking_iter_loss = mimicking_loss.item() 139 | train_batch_size = train_bgs.size(0) 140 | train_loss_record.update(train_iter_loss, train_batch_size) 141 | mimicking_loss_record.update(mimicking_iter_loss, train_batch_size) 142 | 143 | tb_logger.log_value('loss', train_loss.item(), step=self.net.Eiters) 144 | 145 | if self.args["print_freq"] > 0 and (curr_iter + 1) % self.args["print_freq"] == 0: 146 | log = ( 147 | f"[I:{curr_iter}/{self.iter_num}][E:{curr_epoch}:{self.end_epoch}]>" 148 | f"(Lce)[Avg:{train_loss_record.avg:.3f}|Cur:{train_iter_loss:.3f}]" 149 | f"(Lm)[Avg:{mimicking_loss_record.avg:.3f}][Cur:{mimicking_iter_loss:.3f}]" 150 | ) 151 | print(log) 152 | make_log(self.path["tr_log"], log) 153 | 154 | save_dir, save_name = os.path.split(self.path["final_full_net"]) 155 | epoch_full_net_path = os.path.join(save_dir, str(curr_epoch + 1)+'_'+save_name) 156 | save_dir, save_name = os.path.split(self.path["final_state_net"]) 157 | epoch_state_net_path = os.path.join(save_dir, str(curr_epoch + 1)+'_'+save_name) 158 | 159 | self.save_checkpoint(curr_epoch + 1, full_net_path=epoch_full_net_path, state_net_path=epoch_state_net_path) 160 | 161 | 162 | def change_lr(self, curr): 163 | total_num = self.end_epoch 164 | if self.args["lr_type"] == "poly": 165 | ratio = pow((1 - float(curr) / total_num), self.args["lr_decay"]) 166 | for param_group in self.opti.param_groups: 167 | param_group["lr"] = param_group.setdefault("initial_lr", param_group["lr"]) * ratio # scale each group's initial lr instead of compounding 168 | elif self.args["lr_type"] == "decay": 169 | ratio = 0.1 170 | if (curr % 9 == 0): # note: also true at curr == 0 171 | for param_group in self.opti.param_groups: 172 | param_group["lr"] = param_group["lr"] * ratio 173 | elif self.args["lr_type"] == "all_decay": 174 | lr = self.args["lr"]
* (0.5 ** (curr // 2)) 175 | for param_group in self.opti.param_groups: 176 | param_group['lr'] = lr 177 | else: 178 | raise NotImplementedError 179 | 180 | def make_optim(self): 181 | if self.args["optim"] == "sgd_trick": 182 | params = [ 183 | { 184 | "params": [p for name, p in self.net.named_parameters() if ("bias" in name or "bn" in name)], 185 | "weight_decay": 0, 186 | }, 187 | { 188 | "params": [ 189 | p for name, p in self.net.named_parameters() if ("bias" not in name and "bn" not in name) 190 | ] 191 | }, 192 | ] 193 | optimizer = SGD( 194 | params, 195 | lr=self.args["lr"], 196 | momentum=self.args["momentum"], 197 | weight_decay=self.args["weight_decay"], 198 | nesterov=self.args["nesterov"], 199 | ) 200 | elif self.args["optim"] == "f3_trick": 201 | backbone, head = [], [] 202 | for name, params_tensor in self.net.named_parameters(): 203 | if "encoder" in name: 204 | backbone.append(params_tensor) 205 | else: 206 | head.append(params_tensor) 207 | params = [ 208 | {"params": backbone, "lr": 0.1 * self.args["lr"]}, 209 | {"params": head, "lr": self.args["lr"]}, 210 | ] 211 | optimizer = SGD( 212 | params=params, 213 | momentum=self.args["momentum"], 214 | weight_decay=self.args["weight_decay"], 215 | nesterov=self.args["nesterov"], 216 | ) 217 | elif self.args["optim"] == "Adam_trick": 218 | optimizer = Adam(filter(lambda p: p.requires_grad, self.net.parameters()), lr=self.args["lr"]) 219 | else: 220 | raise NotImplementedError 221 | print("optimizer = ", optimizer) 222 | return optimizer 223 | 224 | def save_checkpoint(self, current_epoch, full_net_path, state_net_path): 225 | state_dict = { 226 | "epoch": current_epoch, 227 | "net_state": self.net.state_dict(), 228 | "opti_state": self.opti.state_dict(), 229 | } 230 | torch.save(state_dict, full_net_path) 231 | torch.save(self.net.state_dict(), state_net_path) 232 | 233 | def resume_checkpoint(self, load_path, mode="all"): 234 | """ 235 | Args: 236 | load_path (str): path of pretrained model 237 | 
            mode (str): 'all' resumes the epoch, network weights and optimizer state;
                'onlynet' loads only the network weights (a plain state-dict file).
        """
        if not os.path.isfile(load_path):
            raise FileNotFoundError(f"{load_path} is not an existing checkpoint file.")
        print(f" =>> loading checkpoint '{load_path}' <<== ")
        checkpoint = torch.load(load_path, map_location=self.dev)
        if mode == "all":
            self.start_epoch = checkpoint["epoch"]
            self.net.load_state_dict(checkpoint["net_state"])
            self.opti.load_state_dict(checkpoint["opti_state"])
            print(f" ==> loaded checkpoint '{load_path}' (epoch {checkpoint['epoch']})")
        elif mode == "onlynet":
            self.net.load_state_dict(checkpoint)
            print(f" ==> loaded checkpoint '{load_path}' (only has the net's weight params) <<== ")
        else:
            raise NotImplementedError


def feature_mimicking(composites, feature_pos, feature_map, num, device):
    # A pretrained 4-channel ResNet-18 is instantiated on every call; it is not part of
    # self.opti, so its weights are never updated by training.
    net_ = pretrained_resnet18_4ch(pretrained=True).to(device)

    composite_cat_list = []
    pos_feature = torch.zeros(int(num.sum()), 512, 1, 1, device=device)
    count = 0
    for i in range(num.shape[0]):
        # Keep the first num[i] composites of sample i.
        composite_cat_list.append(composites[i, :num[i], :, :, :])
        for j in range(num[i]):
            # feature_pos[..., 0] is the column (x) index, feature_pos[..., 1] the row (y) index.
            pos_feature[count, :, 0, 0] = feature_map[i, :, int(feature_pos[i, j, 1]), int(feature_pos[i, j, 0])]
            count += 1
    composites_ = torch.cat(composite_cat_list, dim=0)
    composite_feature = net_(composites_)
    composite_feature = nn.AdaptiveAvgPool2d(1)(composite_feature)
    # Flatten both feature tensors to (N, 512) before computing the MSE.
    pos_feature = pos_feature.flatten(1)
    composite_feature = composite_feature.flatten(1)

    mimicking_loss_criter = nn.MSELoss()
    mimicking_loss = mimicking_loss_criter(pos_feature, composite_feature)

    return mimicking_loss


if __name__ == "__main__":
    trainer = Trainer(arg_config)
    print(f" ===========>> {datetime.now()}: begin training <<=========== ")
    trainer.train()
    print(f" ===========>> {datetime.now()}: end training <<=========== ")


--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
import os
from datetime import datetime


class AvgMeter(object):
    """Tracks the most recent value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # val is weighted by n, e.g. a batch-mean loss weighted by the batch size.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def pre_mkdir(path_config):
    check_mkdir(path_config["pth_log"])
    check_mkdir(path_config["pth"])
    make_log(path_config["te_log"], f"=== te_log {datetime.now()} ===")
    make_log(path_config["tr_log"], f"=== tr_log {datetime.now()} ===")


def check_mkdir(dir_name):
    os.makedirs(dir_name, exist_ok=True)


def make_log(path, context):
    with open(path, "a") as log:
        log.write(f"{context}\n")


def check_dir_path_valid(path: list):
    for p in path:
        if p:
            assert os.path.exists(p)
            assert os.path.isdir(p)


def construct_path_dict(proj_root, exp_name):
    ckpt_path = os.path.join(proj_root, "output")

    pth_log_path = os.path.join(ckpt_path, exp_name)

    tb_path = os.path.join(pth_log_path, "tb")
    save_path = os.path.join(pth_log_path, "pre")
    pth_path = os.path.join(pth_log_path, "pth")

    final_full_model_path = os.path.join(pth_path, "checkpoint_final.pth.tar")
    final_state_path = os.path.join(pth_path, "state_final.pth")

    tr_log_path = os.path.join(pth_log_path, f"tr_{str(datetime.now())[:10]}.txt")
    te_log_path = os.path.join(pth_log_path, f"te_{str(datetime.now())[:10]}.txt")
    cfg_log_path = os.path.join(pth_log_path, f"cfg_{str(datetime.now())[:10]}.txt")
    trainer_log_path = os.path.join(pth_log_path,
                                    f"trainer_{str(datetime.now())[:10]}.txt")
    dataset_log_path = os.path.join(pth_log_path, f"dataset_{str(datetime.now())[:10]}.txt")
    network_log_path = os.path.join(pth_log_path, f"network_{str(datetime.now())[:10]}.txt")

    path_config = {
        "ckpt_path": ckpt_path,
        "pth_log": pth_log_path,
        "tb": tb_path,
        "save": save_path,
        "pth": pth_path,
        "final_full_net": final_full_model_path,
        "final_state_net": final_state_path,
        "tr_log": tr_log_path,
        "te_log": te_log_path,
        "cfg_log": cfg_log_path,
        "trainer_log": trainer_log_path,
        "dataset_log": dataset_log_path,
        "network_log": network_log_path,
    }

    return path_config

--------------------------------------------------------------------------------
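The `poly` and `all_decay` branches of `change_lr` in `train.py` correspond to closed-form schedules. The following is a minimal, dependency-free sketch (the helper names `poly_lr`, `compounded_poly_lr` and `all_decay_lr` are hypothetical, not part of the repository) that computes the learning rate for a given epoch. It assumes `change_lr` is called once per epoch starting at epoch 0, and contrasts decaying from the base rate with repeatedly multiplying the already-decayed rate.

```python
import math


def poly_lr(base_lr, epoch, total_epochs, power):
    """Polynomial decay computed directly from the base learning rate."""
    return base_lr * (1 - epoch / total_epochs) ** power


def compounded_poly_lr(base_lr, epoch, total_epochs, power):
    """Multiplies the current rate by the poly ratio at every epoch 0..epoch (compounding)."""
    lr = base_lr
    for e in range(epoch + 1):
        lr *= (1 - e / total_epochs) ** power
    return lr


def all_decay_lr(base_lr, epoch):
    """Halves the base learning rate every two epochs."""
    return base_lr * 0.5 ** (epoch // 2)


if __name__ == "__main__":
    # Hypothetical settings: base lr 0.01, 10 epochs, power 0.9.
    for epoch in range(4):
        print(epoch,
              round(poly_lr(0.01, epoch, 10, 0.9), 6),
              round(compounded_poly_lr(0.01, epoch, 10, 0.9), 6),
              all_decay_lr(1e-3, epoch))
```

Decaying from the base rate keeps the schedule a function of the epoch alone; the compounded variant shrinks much faster because every earlier ratio stays in the product.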
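At the end of every epoch, `train.py` writes two extra files next to the `final_full_net` and `final_state_net` paths built by `construct_path_dict`, prefixing each file name with the 1-based epoch number. A small sketch of that naming rule, using a hypothetical helper `epoch_checkpoint_path` and a hypothetical experiment name `my_exp`:

```python
import os


def epoch_checkpoint_path(final_path, epoch):
    """Prefix the file name of final_path with the epoch number, keeping its directory."""
    save_dir, save_name = os.path.split(final_path)
    return os.path.join(save_dir, f"{epoch}_{save_name}")


if __name__ == "__main__":
    pth_dir = os.path.join("output", "my_exp", "pth")
    # After the first epoch (curr_epoch == 0), files are tagged with "1_".
    print(epoch_checkpoint_path(os.path.join(pth_dir, "checkpoint_final.pth.tar"), 1))
    print(epoch_checkpoint_path(os.path.join(pth_dir, "state_final.pth"), 1))
```

Under these assumptions, the first epoch produces `output/my_exp/pth/1_checkpoint_final.pth.tar` (for `resume_checkpoint(..., mode="all")`) and `output/my_exp/pth/1_state_final.pth` (for `mode="onlynet"`).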
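In `feature_mimicking`, the feature at each placement position is read from the network's feature map as `feature_map[i, :, y, x]`, where `feature_pos[..., 0]` supplies the column and `feature_pos[..., 1]` the row, and is compared by MSE with the pooled ResNet-18 feature of the matching composite image. The sketch below reproduces only that indexing and the loss, with nested lists standing in for tensors and fixed vectors standing in for the pretrained network's output; the helper names `gather_position_features` and `mse` are hypothetical.

```python
def gather_position_features(feature_map, positions):
    """Return one C-dim vector per (x, y) position from a C x H x W nested list.

    Like feature_map[:, y, x]: x selects the column and y selects the row.
    """
    return [[channel[int(y)][int(x)] for channel in feature_map] for x, y in positions]


def mse(a, b):
    """Mean squared error over all elements of two equally shaped lists of vectors."""
    diffs = [(u - v) ** 2 for row_a, row_b in zip(a, b) for u, v in zip(row_a, row_b)]
    return sum(diffs) / len(diffs)


if __name__ == "__main__":
    # Hypothetical 2-channel, 2 x 3 feature map (C x H x W).
    fmap = [
        [[0, 1, 2], [3, 4, 5]],
        [[10, 11, 12], [13, 14, 15]],
    ]
    student = gather_position_features(fmap, [(2, 1), (0, 0)])
    # Stand-in for the pooled features of the two composite images.
    teacher = [[5, 15], [1, 9]]
    print(student, mse(student, teacher))
```

Swapping the two coordinates would silently read the wrong cell (or go out of range on non-square maps), which is why the `(x, y)` to `[y][x]` convention is worth checking against the data pipeline.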