├── LICENSE
├── README.md
├── backbone
│   ├── ResNet.py
│   └── __init__.py
├── config.py
├── data
│   ├── OBdataset.py
│   ├── all_transforms.py
│   └── data
│       ├── test_data_pair.csv
│       ├── test_pair_new.json
│       ├── train_data_pair.csv
│       └── train_pair_new.json
├── network
│   ├── BaseBlocks.py
│   ├── DynamicModules.py
│   ├── ObPlaNet_simple.py
│   ├── __init__.py
│   └── tensor_ops.py
├── prepare_multi_fg_scales.py
├── requirements.txt
├── test.py
├── test_multi_fg_scales.py
├── train.py
└── utils
    └── misc.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 BCMI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **FOPA: Fast Object Placement Assessment**
2 | =====
3 | This is the PyTorch implementation of **FOPA** for the following research paper. **FOPA is the first discriminative approach for the object placement task.**
4 | > **Fast Object Placement Assessment** [[arXiv]](https://arxiv.org/pdf/2205.14280.pdf)
5 | >
6 | > Li Niu, Qingyang Liu, Zhenchen Liu, Jiangtong Li
7 |
8 | **Our FOPA has been integrated into our image composition toolbox [libcom](https://github.com/bcmi/libcom). Welcome to visit and try it \(^▽^)/**
9 |
10 | If you want to change the backbone to a transformer, you can refer to [TopNet](https://github.com/bcmi/TopNet-Object-Placement).
11 |
12 | ## Setup
13 | All the code has been tested with PyTorch 1.7.0. Follow the instructions below to run the project.
14 |
15 | First, clone the repository:
16 | ```
17 | git clone git@github.com:bcmi/FOPA-Fast-Object-Placement-Assessment.git
18 | ```
19 | Then, install Anaconda and create a virtual environment:
20 | ```
21 | conda create -n fopa
22 | conda activate fopa
23 | ```
24 | Install PyTorch 1.7.0 (a higher version should also be fine):
25 | ```
26 | conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=10.2 -c pytorch
27 | ```
28 | Install necessary packages:
29 | ```
30 | pip install -r requirements.txt
31 | ```
32 |
33 |
34 | ## Data Preparation
35 | Download and extract data from [Baidu Cloud](https://pan.baidu.com/s/10JBpXBMZybEl5FTqBlq-hQ) (access code: 4zf9) or [Dropbox](https://www.dropbox.com/scl/fi/c05wk038piy224sba6jpi/data.rar?rlkey=tghrxjjgo2g93le64tb1xymvq&st=u9nf6hbf&dl=0).
36 | Download the SOPA encoder from [Baidu Cloud](https://pan.baidu.com/s/1hQGm3ryRONRZpNpU66SJZA) (access code: 1x3n) or [Dropbox](https://www.dropbox.com/scl/fi/tlkbmqebokjloe0i1yfpy/SOPA.pth.tar?rlkey=8mzzc53wy6rjqz69o5lkzusau&st=32t23vwm&dl=0).
37 | Put them in "data/data", which should then contain the following directories and files:
38 | ```
39 |
40 | bg/ # background images
41 | fg/ # foreground images
42 | mask/ # foreground masks
43 | train(test)_pair_new.json # json annotations
44 | train(test)_data_pair.csv # csv files
45 | SOPA.pth.tar # SOPA encoder
46 | ```
47 |
48 | Download our pretrained model from [Baidu Cloud](https://pan.baidu.com/s/15-OBaYE0CF-nDoJrNcCRaw) (access code: uqvb) or [Dropbox](https://www.dropbox.com/scl/fi/q3i6fryoumzr15piuq9pr/best_weight.pth?rlkey=wahho3h18k3ntsaw9pvdyfvea&st=vp2dhpa5&dl=0), and save it as './best_weight.pth'.
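
As a quick sanity check, you can verify the expected layout with a short script like the following (illustrative only; it just checks that the paths described above exist):
```
import os

expected = ["bg", "fg", "mask", "train_pair_new.json", "test_pair_new.json", "SOPA.pth.tar"]
for name in expected:
    path = os.path.join("data/data", name)
    print(("OK      " if os.path.exists(path) else "MISSING ") + path)
print(("OK      " if os.path.exists("best_weight.pth") else "MISSING ") + "best_weight.pth")
```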
49 |
50 | ## Training
51 | Before training, modify "config.py" according to your needs. After that, run:
52 | ```
53 | python train.py
54 | ```
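
For reference, the options you are most likely to change in "config.py" are the dataset root and a few training hyper-parameters. The values below are the defaults shipped with this repository:
```
# config.py (excerpt)
datasets_root = "./data/data"       # folder containing bg/, fg/, mask/ and the json annotations

arg_config = {
    "model": "ObPlaNet_resnet18",   # model name
    "epoch_num": 25,                # number of training epochs
    "lr": 0.0005,                   # initial learning rate
    "batch_size": 8,                # reduce if you run out of GPU memory
    "num_workers": 6,               # dataloader workers
    "input_size": 256,              # network input size
    "gpu_id": 0,                    # which GPU to use
    "ex_name": "demo",              # experiment name used for the output folder
    # ... the remaining keys can usually be left at their defaults
}
```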
55 |
56 | ## Test
57 | To get the F1 score and balanced accuracy of a specified model, run:
58 | ```
59 | python test.py --mode evaluate
60 | ```
61 |
62 | The results obtained with our released model should be F1: 0.778302, bAcc: 0.838696.
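
Both metrics are computed over the annotated pixels (see `test.py`): with TP/TN/FP/FN counted per pixel, precision P = TP / (TP + FP), recall R = TP / (TP + FN), F1 = 2PR / (P + R), and bAcc = (TP / (TP + FN) + TN / (TN + FP)) / 2.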
63 |
64 |
65 | To get the heatmaps predicted by FOPA, run:
66 | ```
67 | python test.py --mode heatmap
68 | ```
69 |
70 | To get the optimal composite images based on the predicted heatmaps, run:
71 | ```
72 | python test.py --mode composite
73 | ```
74 |
75 |
76 | ## Multiple Foreground Scales
77 | To test multiple foreground scales for each foreground-background pair, first run the following command, which generates 'test_data_16scales.json' in './data/data' as well as 'test_16scales' folders in './data/data/fg' and './data/data/mask'.
78 | ```
79 | python prepare_multi_fg_scales.py
80 | ```
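
The 16 scales are i / (16 + 2) for i = 1, ..., 16 (see `prepare_multi_fg_scales.py`), i.e. roughly 0.06 to 0.89 of the background's limiting dimension, with the foreground aspect ratio preserved.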
81 |
82 | Then, to get the heatmaps of multi-scale foregrounds for each foreground-background pair, run:
83 | ```
84 | python test_multi_fg_scales.py --mode heatmap
85 | ```
86 |
87 | Finally, to get the composite images with top scores for each foreground-background pair, run:
88 | ```
89 | python test_multi_fg_scales.py --mode composite
90 | ```
91 |
92 | ## Evaluation on Discriminative Task
93 |
94 | We show the results reported in the paper. FOPA achieves results comparable to SOPA.
95 |
96 | | Method | F1 | bAcc |
97 | | :----: | :---: | :----: |
98 | | SOPA | 0.780 | 0.842 |
99 | | FOPA | 0.776 | 0.840 |
117 |
118 | ## Evaluation on Generation Task
119 |
120 | Given each background-foreground pair in the test set, we predict 16 rationality score maps for 16 foreground scales and generate composite images with the top 50 rationality scores. Then, we randomly sample one of the 50 generated composite images per background-foreground pair for Acc and FID evaluation, using the test scripts provided by [GracoNet](https://github.com/bcmi/GracoNet-Object-Placement). The generated composite images used for evaluation can be downloaded from [Baidu Cloud](https://pan.baidu.com/s/1qqDiXF4tEhizEoI_2BwkrA) (access code: ppft) or [Google Drive](https://drive.google.com/file/d/1yvuoVum_-FMK7lOvrvpx35IdvrV58bTm/view?usp=share_link). The test results of baselines and our method are shown below:
121 |
122 | | Method | Acc | FID |
123 | | :------: | :---: | :----: |
124 | | TERSE | 0.679 | 46.94 |
125 | | PlaceNet | 0.683 | 36.69 |
126 | | GracoNet | 0.847 | 27.75 |
127 | | IOPRE | 0.895 | 21.59 |
128 | | FOPA | 0.932 | 19.76 |
158 |
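If you want to redo the sampling step yourself, the sketch below (an illustrative assumption, not an official script) picks one composite per foreground-background pair from the folders produced by `python test_multi_fg_scales.py --mode composite`:
```
import os, random, shutil

# path produced by test_multi_fg_scales.py; <experiment> and <epoch> are placeholders
composite_root = "output/<experiment>/pth/<epoch>_state_final_test_16scales_composite"
eval_dir = "sampled_composites"
os.makedirs(eval_dir, exist_ok=True)

random.seed(0)
for pair in sorted(os.listdir(composite_root)):   # one sub-folder per foreground-background pair
    images = [f for f in os.listdir(os.path.join(composite_root, pair)) if f.endswith(".jpg")]
    chosen = random.choice(images)                # sample one of the generated composites
    shutil.copy(os.path.join(composite_root, pair, chosen), os.path.join(eval_dir, chosen))
```
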
159 | ## Other Resources
160 |
161 | + [Awesome-Object-Placement](https://github.com/bcmi/Awesome-Object-Placement)
162 | + [Awesome-Image-Composition](https://github.com/bcmi/Awesome-Object-Insertion)
163 |
164 |
165 | ## Bibtex
166 |
167 | If you find this work useful for your research, please cite our paper using the following BibTeX [[arXiv](https://arxiv.org/pdf/2205.14280.pdf)]:
168 |
169 | ```
170 | @article{niu2022fast,
171 | title={Fast Object Placement Assessment},
172 | author={Niu, Li and Liu, Qingyang and Liu, Zhenchen and Li, Jiangtong},
173 | journal={arXiv preprint arXiv:2205.14280},
174 | year={2022}
175 | }
176 | ```
177 |
--------------------------------------------------------------------------------
/backbone/ResNet.py:
--------------------------------------------------------------------------------
1 | # import torchvision.models as models
2 | # import torch.nn as nn
3 | # # https://pytorch.org/docs/stable/torchvision/models.html#id3
4 | #
5 | import torch
6 | import torch.nn as nn
7 | import torch.utils.model_zoo as model_zoo
8 |
9 |
10 | model_urls = {
11 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
12 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
13 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
14 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
15 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
16 | }
17 |
18 | def conv3x3(in_planes, out_planes, stride=1):
19 | """3x3 convolution with padding"""
20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
21 |
22 |
23 | def conv1x1(in_planes, out_planes, stride=1):
24 | """1x1 convolution"""
25 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
26 |
27 |
28 | class BasicBlock(nn.Module):
29 | expansion = 1
30 |
31 | def __init__(self, inplanes, planes, stride=1, downsample=None):
32 | super(BasicBlock, self).__init__()
33 | self.conv1 = conv3x3(inplanes, planes, stride)
34 | self.bn1 = nn.BatchNorm2d(planes)
35 | self.relu = nn.ReLU(inplace=True)
36 | self.conv2 = conv3x3(planes, planes)
37 | self.bn2 = nn.BatchNorm2d(planes)
38 | self.downsample = downsample
39 | self.stride = stride
40 |
41 | def forward(self, x):
42 | identity = x
43 |
44 | out = self.conv1(x)
45 | out = self.bn1(out)
46 | out = self.relu(out)
47 |
48 | out = self.conv2(out)
49 | out = self.bn2(out)
50 |
51 | if self.downsample is not None:
52 | identity = self.downsample(x)
53 |
54 | out += identity
55 | out = self.relu(out)
56 |
57 | return out
58 |
59 |
60 | class Bottleneck(nn.Module):
61 | expansion = 4
62 |
63 | def __init__(self, inplanes, planes, stride=1, downsample=None):
64 | super(Bottleneck, self).__init__()
65 | self.conv1 = conv1x1(inplanes, planes)
66 | self.bn1 = nn.BatchNorm2d(planes)
67 | self.conv2 = conv3x3(planes, planes, stride)
68 | self.bn2 = nn.BatchNorm2d(planes)
69 | self.conv3 = conv1x1(planes, planes * self.expansion)
70 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
71 | self.relu = nn.ReLU(inplace=True)
72 | self.downsample = downsample
73 | self.stride = stride
74 |
75 | def forward(self, x):
76 | identity = x
77 |
78 | out = self.conv1(x)
79 | out = self.bn1(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv2(out)
83 | out = self.bn2(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv3(out)
87 | out = self.bn3(out)
88 |
89 | if self.downsample is not None:
90 | identity = self.downsample(x)
91 |
92 | out += identity
93 | out = self.relu(out)
94 |
95 | return out
96 |
97 |
98 | class ResNet(nn.Module):
99 | def __init__(self, block, layers, zero_init_residual=False):
100 | super(ResNet, self).__init__()
101 | self.inplanes = 64
102 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
103 | self.bn1 = nn.BatchNorm2d(64)
104 | self.relu = nn.ReLU(inplace=True)
105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
106 | self.layer1 = self._make_layer(block, 64, layers[0])
107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) # 6
109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) # 3
110 |
111 | for m in self.modules():
112 | if isinstance(m, nn.Conv2d):
113 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
114 | elif isinstance(m, nn.BatchNorm2d):
115 | nn.init.constant_(m.weight, 1)
116 | nn.init.constant_(m.bias, 0)
117 |
118 | # Zero-initialize the last BN in each residual branch,
119 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
120 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
121 | if zero_init_residual:
122 | for m in self.modules():
123 | if isinstance(m, Bottleneck):
124 | nn.init.constant_(m.bn3.weight, 0)
125 | elif isinstance(m, BasicBlock):
126 | nn.init.constant_(m.bn2.weight, 0)
127 |
128 | def _make_layer(self, block, planes, blocks, stride=1):
129 | downsample = None
130 | if stride != 1 or self.inplanes != planes * block.expansion:
131 | downsample = nn.Sequential(
132 | conv1x1(self.inplanes, planes * block.expansion, stride), nn.BatchNorm2d(planes * block.expansion),
133 | )
134 |
135 | layers = []
136 | layers.append(block(self.inplanes, planes, stride, downsample))
137 | self.inplanes = planes * block.expansion
138 | for _ in range(1, blocks):
139 | layers.append(block(self.inplanes, planes))
140 |
141 | return nn.Sequential(*layers)
142 |
143 | def forward(self, x):
144 | x = self.conv1(x)
145 | x = self.bn1(x)
146 | x = self.relu(x)
147 | x = self.maxpool(x)
148 |
149 | x = self.layer1(x)
150 | x = self.layer2(x)
151 | x = self.layer3(x)
152 | x = self.layer4(x)
153 |
154 | return x
155 |
156 | def resnet18(pretrained=False, **kwargs):
157 | """Constructs a ResNet-18 model.
158 |
159 | Args:
160 | pretrained (bool): If True, returns a model pre-trained on ImageNet
161 | """
162 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
163 | if pretrained:
164 | pretrained_dict = model_zoo.load_url(model_urls["resnet18"])
165 |
166 | model_dict = model.state_dict()
167 | # 1. filter out unnecessary keys
168 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
169 | # 2. overwrite entries in the existing state dict
170 | model_dict.update(pretrained_dict)
171 | # 3. load the new state dict
172 | model.load_state_dict(model_dict)
173 | return model
174 |
175 |
176 | def Backbone_ResNet18_in3(pretrained=True):
177 | if pretrained:
178 | print("The backbone model loads the pretrained parameters...")
179 | net = pretrained_resnet18_4ch(pretrained=True)
180 | div_2 = nn.Sequential(*list(net.children())[:3])
181 | div_4 = nn.Sequential(*list(net.children())[3:5])
182 | div_8 = net.layer2
183 | div_16 = net.layer3
184 | div_32 = net.layer4
185 |
186 | return div_2, div_4, div_8, div_16, div_32
187 |
188 |
189 | def Backbone_ResNet18_in3_1(pretrained=True):
190 | if pretrained:
191 | print("The backbone model loads the pretrained parameters...")
192 | net = resnet18(pretrained=pretrained)
193 |
194 | model_dict = net.state_dict()
195 | conv1 = model_dict['conv1.weight']
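# Build a filter for the extra (mask) input channel by collapsing each pretrained
# RGB filter with the standard luma weights (0.299, 0.587, 0.114); the result is
# concatenated to the original 3-channel weights below to form a 4-channel conv1.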
196 | new = torch.zeros(64, 1, 7, 7)
197 | for i, output_channel in enumerate(conv1):
198 | new[i] = 0.299 * output_channel[0] + 0.587 * output_channel[1] + 0.114 * output_channel[2]
199 | net.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=3, bias=False)
200 | model_dict['conv1.weight'] = torch.cat((conv1, new), dim=1)
201 | net.load_state_dict(model_dict)
202 |
203 | div_1 = nn.Sequential(*list(net.children())[:1])
204 | div_2 = nn.Sequential(*list(net.children())[1:3])
205 | div_4 = nn.Sequential(*list(net.children())[3:5])
206 | div_8 = net.layer2
207 | div_16 = net.layer3
208 | # div_32 = make_layer_4(BasicBlock, 448, 2, stride=2)
209 | div_32 = net.layer4
210 | return div_1, div_2, div_4, div_8, div_16, div_32
211 |
212 |
213 |
214 | def pretrained_resnet18_4ch(pretrained=True, **kwargs):
215 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
216 | model.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=3, bias=False)
217 |
218 | if pretrained:
219 | # load the pretrained binary classification model for slow object placement assessment (SOPA)
220 | checkpoint = torch.load('./data/data/SOPA.pth.tar')
221 | model.load_state_dict(checkpoint['state_dict'], strict=False)
222 |
223 | return model
224 |
225 |
226 |
227 |
228 |
--------------------------------------------------------------------------------
/backbone/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/FOPA-Fast-Object-Placement-Assessment/7f990e06b6b234bfd1e107a30067b610c217c915/backbone/__init__.py
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | __all__ = ["proj_root", "arg_config"]
4 |
5 | proj_root = os.path.dirname(__file__)
6 | datasets_root = "./data/data"
7 |
8 | tr_data_path = os.path.join(datasets_root, "train_pair_new.json")
9 | ts_data_path = os.path.join(datasets_root, "test_pair_new.json")
10 |
11 | coco_dir = './data/data/train2017'
12 | bg_dir = os.path.join(datasets_root, "bg")
13 | fg_dir = os.path.join(datasets_root, "fg")
14 | mask_dir = os.path.join(datasets_root, "mask")
15 |
16 | arg_config = {
17 |
18 | "model": "ObPlaNet_resnet18", # model name
19 | "epoch_num": 25,
20 | "lr": 0.0005,
21 | "train_data_path": tr_data_path,
22 | "test_data_path": ts_data_path,
23 | "bg_dir": bg_dir,
24 | "fg_dir": fg_dir,
25 | "mask_dir": mask_dir,
26 |
27 | "print_freq": 10, # >0, frequency of log print
28 | "prefix": (".jpg", ".png"),
29 | "reduction": "mean", # “mean” or “sum”
30 | "optim": "Adam_trick", # optimizer
31 | "weight_decay": 5e-4, # set as 0.0001 when finetuning
32 | "momentum": 0.9,
33 | "nesterov": False,
34 | "lr_type": "all_decay", # learning rate schedule
35 | "lr_decay": 0.9, # poly
36 | "batch_size": 8,
37 | "num_workers": 6,
38 | "input_size": 256, # input size
39 | "gpu_id": 0,
40 | "ex_name":"demo", # experiment name
41 | }
42 |
--------------------------------------------------------------------------------
/data/OBdataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import torch
4 | import numpy as np
5 |
6 | from PIL import Image
7 | from torch.utils.data import DataLoader, Dataset
8 | from torchvision import transforms
9 | from data.all_transforms import Compose, JointResize
10 |
11 |
12 | class CPDataset(Dataset):
13 | def __init__(self, file, bg_dir, fg_dir, mask_dir, in_size, datatype='train'):
14 | """
15 | initialize dataset
16 |
17 | Args:
18 | file(str): file with training/test data information
19 | bg_dir(str): folder with background images
20 | fg_dir(str): folder with foreground images
21 | mask_dir(str): folder with mask images
22 | in_size(int): input size of network
23 | datatype(str): "train" or "test"
24 | """
25 |
26 | self.datatype = datatype
27 | self.data = _collect_info(file, bg_dir, fg_dir, mask_dir, datatype)
28 | self.insize = in_size
29 |
30 | self.train_triple_transform = Compose([JointResize(in_size)])
31 | self.train_img_transform = transforms.Compose(
32 | [
33 | transforms.ToTensor(),
34 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # operates on tensors
35 | ]
36 | )
37 | self.train_mask_transform = transforms.ToTensor()
38 |
39 | self.transforms_flip = transforms.Compose([
40 | transforms.RandomHorizontalFlip(p=1)
41 | ])
42 |
43 | def __len__(self):
44 | return len(self.data)
45 |
46 | def __getitem__(self, index):
47 | """
48 | load each item
49 | return:
50 | i: the image index,
51 | bg_t:(1 * 3 * in_size * in_size) background image,
52 | mask_t:(1 * 1 * in_size * in_size) scaled foreground mask
53 | fg_t:(1 * 3 * in_size * in_size) scaled foreground image
54 | target_t: (1 * in_size * in_size) pixel-wise binary labels
55 | labels_num: (int) the number of annotated pixels
56 | """
57 | i, _, bg_path, fg_path, mask_path, scale, pos_label, neg_label, fg_path_2, mask_path_2, w, h = self.data[index]
58 |
59 | fg_name = fg_path.split('/')[-1][:-4]
60 | ## save_name: fg_bg_w_h_scale.jpg
61 | save_name = fg_name + '_' + str(scale) + '.jpg'
62 |
63 | bg_img = Image.open(bg_path)
64 | fg_img = Image.open(fg_path)
65 | mask = Image.open(mask_path)
66 | if len(bg_img.split()) != 3:
67 | bg_img = bg_img.convert("RGB")
68 | if len(fg_img.split()) != 3:
69 | fg_img = fg_img.convert("RGB")
70 | if len(mask.split()) != 1:
71 | mask = mask.convert("L")
72 |
73 | is_flip = False
74 | if self.datatype == 'train' and np.random.uniform() < 0.5:
75 | is_flip = True
76 |
77 | # make composite images which are used in feature mimicking
78 | fg_tocp = Image.open(fg_path_2).convert("RGB")
79 | mask_tocp = Image.open(mask_path_2).convert("L")
80 | composite_list = []
81 | for pos in pos_label:
82 | x_, y_ = pos
83 | x = int(x_ - w / 2)
84 | y = int(y_ - h / 2)
85 | composite_list.append(make_composite(fg_tocp, mask_tocp, bg_img, [x, y, w, h], is_flip))
86 |
87 | for pos in neg_label:
88 | x_, y_ = pos
89 | x = int(x_ - w / 2)
90 | y = int(y_ - h / 2)
91 | composite_list.append(make_composite(fg_tocp, mask_tocp, bg_img, [x, y, w, h], is_flip))
92 |
93 | composite_list_ = torch.stack(composite_list, dim=0)
94 | composite_cat = torch.zeros(50 - len(composite_list), 4, 256, 256)
95 | composite_list = torch.cat((composite_list_, composite_cat), dim=0)
96 |
97 | # positive pixels are 1, negative pixels are 0, other pixels are 255
98 | # feature_pos: record the positions of annotated pixels
99 | target, feature_pos = _obtain_target(bg_img.size[0], bg_img.size[1], self.insize, pos_label, neg_label, is_flip)
100 | for i in range(50 - len(feature_pos)):
101 | feature_pos.append((0, 0)) # pad the length to 50
102 | feature_pos = torch.Tensor(feature_pos)
103 |
104 | # resize the foreground/background to 256, convert them to tensors
105 | bg_t, fg_t, mask_t = self.train_triple_transform(bg_img, fg_img, mask)
106 | mask_t = self.train_mask_transform(mask_t)
107 | fg_t = self.train_img_transform(fg_t)
108 | bg_t = self.train_img_transform(bg_t)
109 |
110 | if is_flip == True:
111 | fg_t = self.transforms_flip(fg_t)
112 | bg_t = self.transforms_flip(bg_t)
113 | mask_t = self.transforms_flip(mask_t)
114 |
115 | # tensor is normalized to [0,1],map back to [0, 255] for ease of computation
116 | target_t = self.train_mask_transform(target) * 255
117 | labels_num = (target_t != 255).sum()
118 |
119 | return i, bg_t, mask_t, fg_t, target_t.squeeze(), labels_num, composite_list, feature_pos, w, h, save_name
120 |
121 |
122 | def _obtain_target(original_width, original_height, in_size, pos_label, neg_label, isflip=False):
123 | """
124 | put 0, 1 labels on a 256x256 score map
125 | Args:
126 | original_width(int): width of original background
127 | original_height(int): height of original background
128 | in_size(int): input size of network
129 | pos_label(list): positive pixels in original background
130 | neg_label(list): negative pixels in original background
131 | return:
132 | target_r: score map with ground-truth labels
133 | """
134 | target = np.uint8(np.ones((in_size, in_size)) * 255)
135 | feature_pos = []
136 | for pos in pos_label:
137 | x, y = pos
138 | x_new = int(x * in_size / original_width)
139 | y_new = int(y * in_size / original_height)
140 | target[y_new, x_new] = 1.
141 | if isflip:
142 | x_new = 256 - x_new
143 | feature_pos.append((x_new, y_new))
144 | for pos in neg_label:
145 | x, y = pos
146 | x_new = int(x * in_size / original_width)
147 | y_new = int(y * in_size / original_height)
148 | target[y_new, x_new] = 0.
149 | if isflip:
150 | x_new = 256 - x_new
151 | feature_pos.append((x_new, y_new))
152 | target_r = Image.fromarray(target)
153 | if isflip:
154 | target_r = transforms.RandomHorizontalFlip(p=1)(target_r)
155 | return target_r, feature_pos
156 |
157 |
158 | def _collect_info(json_file, bg_dir, fg_dir, mask_dir, datatype='train'):
159 | """
160 | load json file and return required information
161 | Args:
162 | json_file(str): json file with train/test information
163 | bg_dir(str): folder with background images
164 | fg_dir(str): folder with foreground images
165 | mask_dir(str): folder with foreground masks
166 | datatype(str): "train" or "test"
167 | return:
168 | index(int): the sample index
169 | background image path, foreground image path, foreground mask image
170 | foreground scale, the locations of positive/negative pixels
171 | """
172 | f_json = json.load(open(json_file, 'r'))
173 | return [
174 | (
175 | index,
176 | row['scID'].rjust(12,'0'),
177 | os.path.join(bg_dir, "%012d.jpg" % int(row['scID'])), # background image path
178 | os.path.join(fg_dir, "{}/{}_{}_{}_{}.jpg".format(datatype, int(row['annID']), int(row['scID']), # scaled foreground image path
179 | int(row['newWidth']), int(row['newHeight']))),
180 |
181 | os.path.join(mask_dir, "{}/{}_{}_{}_{}.jpg".format(datatype, int(row['annID']), int(row['scID']), # scaled foreground mask path
182 | int(row['newWidth']), int(row['newHeight']))),
183 | row['scale'],
184 | row['pos_label'], row['neg_label'],
185 | os.path.join(fg_dir, "foreground/{}.jpg".format(int(row['annID']))), # original foreground image path
186 | os.path.join(fg_dir, "foreground/mask_{}.jpg".format(int(row['annID']))), # original foreground mask path
187 | int(row['newWidth']), int(row['newHeight']) # scaled foreground width and height
188 | )
189 | for index, row in enumerate(f_json)
190 | ]
191 |
192 |
193 | def _to_center(bbox):
194 | """conver bbox to center pixel"""
195 | x, y, width, height = bbox
196 | return x + width // 2, y + height // 2
197 |
198 |
199 | def create_loader(table_path, bg_dir, fg_dir, mask_dir, in_size, datatype, batch_size, num_workers, shuffle):
200 | dset = CPDataset(table_path, bg_dir, fg_dir, mask_dir, in_size, datatype)
201 | data_loader = DataLoader(dset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
202 |
203 | return data_loader
204 |
205 |
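# Paste the foreground (resized to w x h) onto the background at position (x, y),
# blend with the mask, and return a 4-channel 256x256 tensor (RGB composite + mask),
# optionally horizontally flipped.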
206 | def make_composite(fg_img, mask_img, bg_img, pos, isflip=False):
207 | x, y, w, h = pos
208 | bg_h = bg_img.height
209 | bg_w = bg_img.width
210 | # resize foreground to expected size [h, w]
211 | fg_transform = transforms.Compose([
212 | transforms.Resize((h, w)),
213 | transforms.ToTensor(),
214 | ])
215 | top = max(y, 0)
216 | bottom = min(y + h, bg_h)
217 | left = max(x, 0)
218 | right = min(x + w, bg_w)
219 | fg_img_ = fg_transform(fg_img)
220 | mask_img_ = fg_transform(mask_img)
221 | fg_img = torch.zeros(3, bg_h, bg_w)
222 | mask_img = torch.zeros(3, bg_h, bg_w)
223 | fg_img[:, top:bottom, left:right] = fg_img_[:, top - y:bottom - y, left - x:right - x]
224 | mask_img[:, top:bottom, left:right] = mask_img_[:, top - y:bottom - y, left - x:right - x]
225 | bg_img = transforms.ToTensor()(bg_img)
226 | blended = fg_img * mask_img + bg_img * (1 - mask_img)
227 | com_pic = transforms.ToPILImage()(blended).convert('RGB')
228 | if isflip == False:
229 | com_pic = transforms.Compose(
230 | [
231 | transforms.Resize((256, 256)),
232 | transforms.ToTensor()
233 | ]
234 | )(com_pic)
235 | mask_img = transforms.ToPILImage()(mask_img).convert('L')
236 | mask_img = transforms.Compose(
237 | [
238 | transforms.Resize((256, 256)),
239 | transforms.ToTensor()
240 | ]
241 | )(mask_img)
242 | com_pic = torch.cat((com_pic, mask_img), dim=0)
243 | else:
244 | com_pic = transforms.Compose(
245 | [
246 | transforms.Resize((256, 256)),
247 | transforms.RandomHorizontalFlip(p=1),
248 | transforms.ToTensor()
249 | ]
250 | )(com_pic)
251 | mask_img = transforms.ToPILImage()(mask_img).convert('L')
252 | mask_img = transforms.Compose(
253 | [
254 | transforms.Resize((256, 256)),
255 | transforms.RandomHorizontalFlip(p=1),
256 | transforms.ToTensor()
257 | ]
258 | )(mask_img)
259 | com_pic = torch.cat((com_pic, mask_img), dim=0)
260 | return com_pic
261 |
262 | def make_composite_PIL(fg_img, mask_img, bg_img, pos, return_mask=False):
263 | x, y, w, h = pos
264 | bg_h = bg_img.height
265 | bg_w = bg_img.width
266 |
267 | top = max(y, 0)
268 | bottom = min(y + h, bg_h)
269 | left = max(x, 0)
270 | right = min(x + w, bg_w)
271 | fg_img_ = fg_img.resize((w,h))
272 | mask_img_ = mask_img.resize((w,h))
273 |
274 | fg_img_ = np.array(fg_img_)
275 | mask_img_ = np.array(mask_img_, dtype=np.float_)/255
276 | bg_img = np.array(bg_img)
277 |
278 | fg_img = np.zeros((bg_h, bg_w, 3), dtype=np.uint8)
279 | mask_img = np.zeros((bg_h, bg_w, 3), dtype=np.float_)
280 |
281 | fg_img[top:bottom, left:right, :] = fg_img_[top - y:bottom - y, left - x:right - x, :]
282 | mask_img[top:bottom, left:right, :] = mask_img_[top - y:bottom - y, left - x:right - x, :]
283 | composite_img = fg_img * mask_img + bg_img * (1 - mask_img)
284 |
285 |
286 | composite_img = Image.fromarray(composite_img.astype(np.uint8))
287 | if return_mask==False:
288 | return composite_img
289 | else:
290 | composite_msk = Image.fromarray((mask_img*255).astype(np.uint8))
291 | return composite_img, composite_msk
292 |
293 |
--------------------------------------------------------------------------------
/data/all_transforms.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | class JointResize(object):
4 | def __init__(self, size):
5 | if isinstance(size, int):
6 | self.size = (size, size)
7 | elif isinstance(size, tuple):
8 | self.size = size
9 | else:
10 | raise RuntimeError("size should be int or tuple")
11 |
12 | def __call__(self, bg, fg, mask):
13 | bg = bg.resize(self.size, Image.BILINEAR)
14 | fg = fg.resize(self.size, Image.BILINEAR)
15 | mask = mask.resize(self.size, Image.NEAREST)
16 | return bg, fg, mask
17 |
18 | class Compose(object):
19 | def __init__(self, transforms):
20 | self.transforms = transforms
21 |
22 | def __call__(self, bg, fg, mask):
23 | for t in self.transforms:
24 | bg, fg, mask = t(bg, fg, mask)
25 | return bg, fg, mask
26 |
--------------------------------------------------------------------------------
/network/BaseBlocks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class BasicConv2d(nn.Module):
5 | def __init__(
6 | self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False,
7 | ):
8 | super(BasicConv2d, self).__init__()
9 |
10 | self.basicconv = nn.Sequential(
11 | nn.Conv2d(
12 | in_planes,
13 | out_planes,
14 | kernel_size=kernel_size,
15 | stride=stride,
16 | padding=padding,
17 | dilation=dilation,
18 | groups=groups,
19 | bias=bias,
20 | ),
21 | nn.BatchNorm2d(out_planes),
22 | nn.ReLU(inplace=True),
23 | )
24 |
25 | def forward(self, x):
26 | return self.basicconv(x)
27 |
--------------------------------------------------------------------------------
/network/DynamicModules.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class simpleDFN(nn.Module):
6 | def __init__(self, in_xC, in_yC, out_C, kernel_size=3, down_factor=4):
7 | """use nn.Unfold to realize dynamic convolution
8 |
9 | Args:
10 | in_xC (int): channel number of first input
11 | in_yC (int): channel number of second input
12 | out_C (int): channel number of output
13 | kernel_size (int): the size of generated conv kernel
14 | down_factor (int): reduce the model parameters when generating conv kernel
15 | """
16 | super(simpleDFN, self).__init__()
17 | self.kernel_size = kernel_size
18 | self.fuse = nn.Conv2d(in_xC, out_C, 3, 1, 1)
19 | self.out_C = out_C
20 | self.gernerate_kernel = nn.Sequential(
21 | # nn.Conv2d(in_yC, in_yC, 3, 1, 1),
22 | # DenseLayer(in_yC, in_yC, k=down_factor),
23 | nn.Conv2d(in_yC, in_xC, 1),
24 | )
25 | self.unfold = nn.Unfold(kernel_size=3, dilation=1, padding=1, stride=1)
26 | self.pool = nn.AdaptiveAvgPool2d(self.kernel_size)
27 | self.in_planes = in_yC
28 |
29 | def forward(self, x, y): # x:bg y:fg
30 | kernel = self.gernerate_kernel(self.pool(y))
31 | batch_size, in_planes, height, width = x.size()
32 | x = x.view(1, -1, height, width)
33 | kernel = kernel.view(-1, 1, self.kernel_size, self.kernel_size)
34 | if self.kernel_size == 3:
35 | output = F.conv2d(x, kernel, bias=None, stride=1, padding=1, groups=self.in_planes * batch_size)
36 | elif self.kernel_size == 1:
37 | output = F.conv2d(x, kernel, bias=None, stride=1, padding=0, groups=self.in_planes * batch_size)
38 | elif self.kernel_size == 5:
39 | output = F.conv2d(x, kernel, bias=None, stride=1, padding=2, groups=self.in_planes * batch_size)
40 | else:
41 | output = F.conv2d(x, kernel, bias=None, stride=1, padding=3, groups=self.in_planes * batch_size)
42 | output = output.view(batch_size, -1, height, width)
43 | return self.fuse(output)
44 |
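# Illustrative shape check (our addition, not part of the original file). In ObPlaNet,
# simpleDFN is used with in_xC == in_yC == 64; the grouped convolution above relies on
# that equality.
if __name__ == "__main__":
    import torch
    dfn = simpleDFN(in_xC=64, in_yC=64, out_C=512, kernel_size=3)
    bg_feat = torch.randn(2, 64, 256, 256)  # x: background feature map
    fg_feat = torch.randn(2, 64, 8, 8)      # y: foreground features used to generate kernels
    print(dfn(bg_feat, fg_feat).shape)      # torch.Size([2, 512, 256, 256])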
--------------------------------------------------------------------------------
/network/ObPlaNet_simple.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import torch.nn as nn
4 | from torchvision import transforms
5 |
6 | sys.path.append("..")
7 | from backbone.ResNet import Backbone_ResNet18_in3, Backbone_ResNet18_in3_1
8 | from network.BaseBlocks import BasicConv2d
9 | from network.DynamicModules import simpleDFN
10 | from network.tensor_ops import cus_sample, upsample_add
11 |
12 | class ObPlaNet_resnet18(nn.Module):
13 | def __init__(self, pretrained=True, ks=3, scale=3):
14 | super(ObPlaNet_resnet18, self).__init__()
15 | self.Eiters = 0
16 | self.upsample_add = upsample_add
17 | self.upsample = cus_sample
18 | self.to_pil = transforms.ToPILImage()
19 | self.scale = scale
20 |
21 | self.add_mask = True
22 |
23 | (
24 | self.bg_encoder1,
25 | self.bg_encoder2,
26 | self.bg_encoder4,
27 | self.bg_encoder8,
28 | self.bg_encoder16,
29 | ) = Backbone_ResNet18_in3(pretrained=pretrained)
30 |
31 | # freeze background encoder
32 | for p in self.parameters():
33 | p.requires_grad = False
34 |
35 | (
36 | self.fg_encoder1,
37 | self.fg_encoder2,
38 | self.fg_encoder4,
39 | self.fg_encoder8,
40 | self.fg_encoder16,
41 | self.fg_encoder32,
42 | ) = Backbone_ResNet18_in3_1(pretrained=pretrained)
43 |
44 | if self.add_mask:
45 | self.mask_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
46 |
47 | # dynamic conv
48 | self.fg_trans16 = nn.Conv2d(512, 64, 1)
49 | self.fg_trans8 = nn.Conv2d(256, 64, 1)
50 | self.selfdc_16 = simpleDFN(64, 64, 512, ks, 4)
51 | self.selfdc_8 = simpleDFN(64, 64, 512, ks, 4)
52 |
53 | self.upconv16 = BasicConv2d(512, 256, kernel_size=3, stride=1, padding=1)
54 | self.upconv8 = BasicConv2d(256, 128, kernel_size=3, stride=1, padding=1)
55 | self.upconv4 = BasicConv2d(128, 64, kernel_size=3, stride=1, padding=1)
56 | self.upconv2 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
57 | self.upconv1 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
58 |
59 | self.classifier = nn.Conv2d(512, 2, 1)
60 |
61 | def forward(self, bg_in_data, fg_in_data, mask_in_data=None, mode='test'):
62 | """
63 | Args:
64 | bg_in_data: (batch_size * 3 * H * W) background image
65 | fg_in_data: (batch_size * 3 * H * W) scaled foreground image
66 | mask_in_data: (batch_size * 1 * H * W) scaled foreground mask
67 | mode: "train" or "test"
68 | """
69 | if ('train' == mode):
70 | self.Eiters += 1
71 |
72 | # extract background and foreground features
73 | black_mask = torch.zeros(mask_in_data.size()).to(mask_in_data.device)
74 | bg_in_data_ = torch.cat([bg_in_data, black_mask], dim=1)
75 | bg_in_data_1 = self.bg_encoder1(bg_in_data_) # torch.Size([2, 64, 128, 128])
76 | fg_cat_mask = torch.cat([fg_in_data, mask_in_data], dim=1)
77 | fg_in_data_1 = self.fg_encoder1(fg_cat_mask) # torch.Size([2, 64, 128, 128])
78 |
79 |
80 | bg_in_data_2 = self.bg_encoder2(bg_in_data_1) # torch.Size([2, 64, 64, 64])
81 | fg_in_data_2 = self.fg_encoder2(fg_in_data_1) # torch.Size([2, 64, 128, 128])
82 | bg_in_data_4 = self.bg_encoder4(bg_in_data_2) # torch.Size([2, 128, 32, 32])
83 | fg_in_data_4 = self.fg_encoder4(fg_in_data_2) # torch.Size([2, 64, 64, 64])
84 | del fg_in_data_1, fg_in_data_2
85 |
86 | bg_in_data_8 = self.bg_encoder8(bg_in_data_4) # torch.Size([2, 256, 16, 16])
87 | fg_in_data_8 = self.fg_encoder8(fg_in_data_4) # torch.Size([2, 128, 32, 32])
88 | bg_in_data_16 = self.bg_encoder16(bg_in_data_8) # torch.Size([2, 512, 8, 8])
89 | fg_in_data_16 = self.fg_encoder16(fg_in_data_8) # torch.Size([2, 256, 16, 16])
90 | fg_in_data_32 = self.fg_encoder32(fg_in_data_16) # torch.Size([2, 512, 8, 8])
91 |
92 | in_data_8_aux = self.fg_trans8(fg_in_data_16) # torch.Size([2, 64, 16, 16])
93 | in_data_16_aux = self.fg_trans16(fg_in_data_32) # torch.Size([2, 64, 8, 8])
94 |
95 | # Unet decoder
96 | bg_out_data_16 = bg_in_data_16 # torch.Size([2, 512, 8, 8])
97 |
98 | bg_out_data_8 = self.upsample_add(self.upconv16(bg_out_data_16), bg_in_data_8) # torch.Size([2, 256, 16, 16])
99 | bg_out_data_4 = self.upsample_add(self.upconv8(bg_out_data_8), bg_in_data_4) # torch.Size([2, 128, 32, 32])
100 | bg_out_data_2 = self.upsample_add(self.upconv4(bg_out_data_4), bg_in_data_2) # torch.Size([2, 64, 64, 64])
101 | bg_out_data_1 = self.upsample_add(self.upconv2(bg_out_data_2), bg_in_data_1) # torch.Size([2, 64, 128, 128])
102 | del bg_out_data_2, bg_out_data_4, bg_out_data_8, bg_out_data_16
103 |
104 | bg_out_data = self.upconv1(self.upsample(bg_out_data_1, scale_factor=2)) # torch.Size([2, 64, 256, 256])
105 |
106 | # fuse foreground and background features using dynamic conv
107 | fuse_out = self.upsample_add(self.selfdc_16(bg_out_data_1, in_data_8_aux), \
108 | self.selfdc_8(bg_out_data, in_data_16_aux)) # torch.Size([2, 512, 256, 256])
109 |
110 | out_data = self.classifier(fuse_out) # torch.Size([2, 2, 256, 256])
111 |
112 | return out_data, fuse_out
113 |
114 |
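# Illustrative usage sketch (our addition, not part of the original file). Note that
# instantiating the model loads the SOPA encoder from ./data/data/SOPA.pth.tar via
# backbone/ResNet.py, so the data preparation step in the README must be done first.
if __name__ == "__main__":
    net = ObPlaNet_resnet18(pretrained=False).eval()
    bg = torch.randn(2, 3, 256, 256)    # background images
    fg = torch.randn(2, 3, 256, 256)    # scaled foreground images
    mask = torch.randn(2, 1, 256, 256)  # scaled foreground masks
    with torch.no_grad():
        out, fuse = net(bg, fg, mask, mode='test')
    print(out.shape)   # torch.Size([2, 2, 256, 256]) pixel-wise rationality scores
    print(fuse.shape)  # torch.Size([2, 512, 256, 256]) fused feature map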
--------------------------------------------------------------------------------
/network/__init__.py:
--------------------------------------------------------------------------------
1 | from network.ObPlaNet_simple import *
2 |
--------------------------------------------------------------------------------
/network/tensor_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def cus_sample(feat, **kwargs):
5 | assert len(kwargs.keys()) == 1 and list(kwargs.keys())[0] in ["size", "scale_factor"]
6 | return F.interpolate(feat, **kwargs, mode="bilinear", align_corners=True)
7 |
8 |
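# Resize every input except the last to the spatial size of the last one and sum them
# (FPN-style feature fusion).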
9 | def upsample_add(*xs):
10 | y = xs[-1]
11 | for x in xs[:-1]:
12 | y = y + F.interpolate(x, size=y.size()[2:], mode="bilinear", align_corners=False)
13 | return y
14 |
15 |
16 | def upsample_cat(*xs):
17 | y = xs[-1]
18 | out = []
19 | for x in xs[:-1]:
20 | out.append(F.interpolate(x, size=y.size()[2:], mode="bilinear", align_corners=False))
21 | return torch.cat([*out, y], dim=1)
22 |
23 |
24 | def upsample_reduce(b, a):
25 | _, C, _, _ = b.size()
26 | N, _, H, W = a.size()
27 |
28 | b = F.interpolate(b, size=(H, W), mode="bilinear", align_corners=False)
29 | a = a.reshape(N, -1, C, H, W).mean(1)
30 |
31 | return b + a
32 |
33 |
34 | def shuffle_channels(x, groups):
35 | N, C, H, W = x.size()
36 | x = x.reshape(N, groups, C // groups, H, W).permute(0, 2, 1, 3, 4)
37 | return x.reshape(N, C, H, W)
38 |
--------------------------------------------------------------------------------
/prepare_multi_fg_scales.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import json
4 |
5 | import numpy as np
6 | from PIL import Image
7 | from tqdm import tqdm
8 | from config import arg_config
9 |
10 | fg_scale_num = 16
11 | save_img_flag = True
12 |
13 |
14 | def collect_info(json_file, bg_dir, fg_dir):
15 |
16 | f_json = json.load(open(json_file, 'r'))
17 | return [
18 | (
19 | row['imgID'], row['annID'], row['scID'],
20 | os.path.join(bg_dir, "%012d.jpg" % int(row['scID'])),
21 | os.path.join(fg_dir, "foreground/{}.jpg".format(int(row['annID']))),
22 | os.path.join(fg_dir, "foreground/mask_{}.jpg".format(int(row['annID'])))
23 | )
24 | for _, row in enumerate(f_json)
25 | ]
26 |
27 |
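# 16 evenly spaced foreground scales: i / (fg_scale_num + 2) for i = 1..16, i.e. 1/18 ... 16/18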
28 | fg_scales = list(range(1, fg_scale_num+1))
29 | fg_scales = [i/(1+fg_scale_num+1) for i in fg_scales]
30 |
31 | fg_bg_dict = dict()
32 | args = arg_config
33 | data = collect_info(args["test_data_path"], args["bg_dir"], args["fg_dir"])
34 |
35 | csv_dir = './data/data'
36 | scaled_fg_dir = f'./data/data/fg/test_{fg_scale_num}scales/'
37 | scaled_mask_dir = f'./data/data/mask/test_{fg_scale_num}scales/'
38 |
39 | os.makedirs(scaled_fg_dir, exist_ok=True)
40 | os.makedirs(scaled_mask_dir, exist_ok=True)
41 |
42 | csv_file = os.path.join(csv_dir, f'test_data_{fg_scale_num}scales.csv')
43 | json_file = csv_file.replace('.csv', '.json')
44 |
45 | file = open(csv_file, mode='w', newline='')
46 | writer = csv.writer(file)
47 |
48 |
49 | csv_head = ['imgID', 'annID', 'scID', 'scale', 'newWidth', 'newHeight', 'pos_label', 'neg_label']
50 | writer.writerow(csv_head)
51 |
52 |
53 |
54 | for _,index in enumerate(tqdm(range(len(data)))):
55 | imgID, fg_id, bg_id, bg_path, fg_path, mask_path = data[index]
56 | if (fg_id, bg_id) in fg_bg_dict.keys():
57 | continue
58 | fg_bg_dict[(fg_id, bg_id)] = 1
59 |
60 |
61 | bg_img = Image.open(bg_path)
62 | if len(bg_img.split()) != 3:
63 | bg_img = bg_img.convert("RGB")
64 | bg_img_aspect = bg_img.height/bg_img.width
65 | fg_tocp = Image.open(fg_path).convert("RGB")
66 | mask_tocp = Image.open(mask_path).convert("RGB")
67 | fg_tocp_aspect = fg_tocp.height/fg_tocp.width
68 |
69 | for fg_scale in fg_scales:
70 | if fg_tocp_aspect>bg_img_aspect:
71 | new_height = bg_img.height*fg_scale
72 | new_width = new_height/fg_tocp.height*fg_tocp.width
73 | else:
74 | new_width = bg_img.width*fg_scale
75 | new_height = new_width/fg_tocp.width*fg_tocp.height
76 |
77 | new_height = int(new_height)
78 | new_width = int(new_width)
79 |
80 | if save_img_flag:
81 | top = int((bg_img.height-new_height)/2)
82 | bottom = top+new_height
83 | left = int((bg_img.width-new_width)/2)
84 | right = left+new_width
85 |
86 | fg_img_ = fg_tocp.resize((new_width, new_height))
87 | mask_ = mask_tocp.resize((new_width, new_height))
88 |
89 | fg_img_ = np.array(fg_img_)
90 | mask_ = np.array(mask_)
91 |
92 | fg_img = np.zeros((bg_img.height, bg_img.width, 3), dtype=np.uint8)
93 | mask = np.zeros((bg_img.height, bg_img.width, 3), dtype=np.uint8)
94 |
95 | fg_img[top:bottom, left:right, :] = fg_img_
96 | mask[top:bottom, left:right, :] = mask_
97 |
98 | fg_img = Image.fromarray(fg_img.astype(np.uint8))
99 | mask = Image.fromarray(mask.astype(np.uint8))
100 |
101 | basename = f'{fg_id}_{bg_id}_{new_width}_{new_height}.jpg'
102 | fg_img_path = os.path.join(scaled_fg_dir, basename)
103 | mask_path = os.path.join(scaled_mask_dir, basename)
104 | fg_img.save(fg_img_path)
105 | mask.save(mask_path)
106 |
107 | writer.writerow([imgID, fg_id, bg_id, fg_scale, new_width, new_height, None, None])
108 |
109 |
110 | file.close()
111 |
112 | # convert csv file to json file
113 | csv_data = []
114 | with open(csv_file, mode='r') as file:
115 | reader = csv.DictReader(file)
116 | for row in reader:
117 | if row['pos_label']=="":
118 | row['pos_label'] = [[0,0]]
119 | if row['neg_label']=="":
120 | row['neg_label'] = [[0,0]]
121 | csv_data.append(row)
122 |
123 | with open(json_file, mode='w') as file:
124 | json.dump(csv_data, file, indent=4)
125 |
126 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | tensorboard_logger
3 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | from pprint import pprint
7 | from torchvision import transforms
8 | from tqdm import tqdm
9 |
10 | import network
11 | from config import arg_config
12 | from data.OBdataset import create_loader, _collect_info
13 | from data.OBdataset import make_composite_PIL
14 |
15 |
16 | class Evaluator:
17 | def __init__(self, args, checkpoint_path):
18 | super(Evaluator, self).__init__()
19 | self.args = args
20 | self.dev = torch.device("cuda:0")
21 | self.to_pil = transforms.ToPILImage()
22 | self.checkpoint_path = checkpoint_path
23 | pprint(self.args)
24 |
25 | print('load pretrained weights from ', checkpoint_path)
26 | self.net = getattr(network, self.args["model"])(
27 | pretrained=False).to(self.dev)
28 | self.net.load_state_dict(torch.load(checkpoint_path, map_location=self.dev), strict=False)
29 | self.net = self.net.to(self.dev).eval()
30 | self.softmax = torch.nn.Softmax(dim=1)
31 |
32 | def evalutate_model(self, datatype):
33 | '''
34 | calculate F1 and bAcc metrics
35 | '''
36 |
37 | correct = 0
38 | total = 0
39 | TP = 0
40 | TN = 0
41 | FP = 0
42 | FN = 0
43 |
44 | assert datatype=='train' or datatype=='test'
45 |
46 | self.ts_loader = create_loader(
47 | self.args[f"{datatype}_data_path"], self.args["bg_dir"], self.args["fg_dir"], self.args["mask_dir"],
48 | self.args["input_size"], datatype, self.args["batch_size"], self.args["num_workers"], False,
49 | )
50 |
51 | with torch.no_grad():
52 |
53 | for _, test_data in enumerate(tqdm(self.ts_loader)):
54 | _, test_bgs, test_masks, test_fgs, test_targets, nums, composite_list, feature_pos, _, _, _ = test_data
55 | test_bgs = test_bgs.to(self.dev, non_blocking=True)
56 | test_masks = test_masks.to(self.dev, non_blocking=True)
57 | test_fgs = test_fgs.to(self.dev, non_blocking=True)
58 | nums = nums.to(self.dev, non_blocking=True)
59 | composite_list = composite_list.to(self.dev, non_blocking=True)
60 | feature_pos = feature_pos.to(self.dev, non_blocking=True)
61 |
62 | test_outs, _ = self.net(test_bgs, test_fgs, test_masks, 'val')
63 | test_preds = np.argmax(test_outs.cpu().numpy(), axis=1)
64 | test_targets = test_targets.cpu().numpy()
65 |
66 | TP += ((test_preds == 1) & (test_targets == 1)).sum()
67 | TN += ((test_preds == 0) & (test_targets == 0)).sum()
68 | FP += ((test_preds == 1) & (test_targets == 0)).sum()
69 | FN += ((test_preds == 0) & (test_targets == 1)).sum()
70 |
71 | correct += (test_preds == test_targets).sum()
72 | total += nums.sum()
73 |
74 | precision = TP / (TP + FP)
75 | recall = TP / (TP + FN)
76 | fscore = (2 * precision * recall) / (precision + recall)
77 | weighted_acc = (TP / (TP + FN) + TN / (TN + FP)) * 0.5
78 |
79 | print('F-1 Measure: %f, ' % fscore)
80 | print('Weighted acc measure: %f, ' % weighted_acc)
81 |
82 | def get_heatmap(self, datatype):
83 | '''
84 | generate heatmap for each pair of scaled foreground and background
85 | '''
86 |
87 | save_dir, base_name = os.path.split(self.checkpoint_path)
88 | heatmap_dir = os.path.join(save_dir, base_name.replace('.pth', f'_{datatype}_heatmap'))
89 |
90 | if not os.path.exists(heatmap_dir):
91 | print(f"Create directory {heatmap_dir}")
92 | os.makedirs(heatmap_dir)
93 |
94 |
95 |
96 | self.ts_loader = create_loader(
97 | self.args[f"{datatype}_data_path"], self.args["bg_dir"], self.args["fg_dir"], self.args["mask_dir"],
98 | self.args["input_size"], datatype, 1, self.args["num_workers"], False,
99 | )
100 |
101 | with torch.no_grad():
102 | for _, test_data in enumerate(tqdm(self.ts_loader)):
103 | _, test_bgs, test_masks, test_fgs, _, nums, composite_list, feature_pos, _, _, save_name = test_data
104 | test_bgs = test_bgs.to(self.dev, non_blocking=True)
105 | test_masks = test_masks.to(self.dev, non_blocking=True)
106 | test_fgs = test_fgs.to(self.dev, non_blocking=True)
107 | nums = nums.to(self.dev, non_blocking=True)
108 | composite_list = composite_list.to(self.dev, non_blocking=True)
109 | feature_pos = feature_pos.to(self.dev, non_blocking=True)
110 |
111 | test_outs, _ = self.net(test_bgs, test_fgs, test_masks, 'test')
112 | test_outs = self.softmax(test_outs)
113 |
114 | test_outs = test_outs[:,1,:,:]
115 | test_outs = transforms.ToPILImage()(test_outs)
116 | test_outs.save(os.path.join(heatmap_dir, save_name[0]))
117 |
118 | def generate_composite(self, datatype, composite_num):
119 | '''
120 | generate composite images for each pair of scaled foreground and background
121 | '''
122 |
123 | save_dir, base_name = os.path.split(self.checkpoint_path)
124 | heatmap_dir = os.path.join(save_dir, base_name.replace('.pth', f'_{datatype}_heatmap'))
125 | if not os.path.exists(heatmap_dir):
126 | print(f"{heatmap_dir} does not exist! Please first use 'heatmap' mode to generate heatmaps")
127 |
128 | data = _collect_info(self.args[f"{datatype}_data_path"], self.args["bg_dir"], self.args["fg_dir"], self.args["mask_dir"], 'test')
129 | for index in range(len(data)):
130 | _, _, bg_path, fg_path, _, scale, _, _, fg_path_2, mask_path_2, w, h = data[index]
131 |
132 | fg_name = fg_path.split('/')[-1][:-4]
133 | save_name = fg_name + '_' + str(scale)
134 |
135 | bg_img = Image.open(bg_path)
136 | if len(bg_img.split()) != 3:
137 | bg_img = bg_img.convert("RGB")
138 | fg_tocp = Image.open(fg_path_2).convert("RGB")
139 | mask_tocp = Image.open(mask_path_2).convert("RGB")
140 |
141 | composite_dir = os.path.join(save_dir, base_name.replace('.pth', f'_{datatype}_composite'), save_name)
142 | if not os.path.exists(composite_dir):
143 | print(f"Create directory {composite_dir}")
144 | os.makedirs(composite_dir)
145 |
146 | heatmap = Image.open(os.path.join(heatmap_dir, save_name+'.jpg'))
147 | heatmap = np.array(heatmap)
148 |
149 | # exclude boundary
150 | heatmap_center = np.zeros_like(heatmap, dtype=np.float_)
151 | hb= int(h/bg_img.height*heatmap.shape[0]/2)
152 | wb = int(w/bg_img.width*heatmap.shape[1]/2)
153 | heatmap_center[hb:-hb, wb:-wb] = heatmap[hb:-hb, wb:-wb]
154 |
155 | # sort pixels in a descending order based on the heatmap
156 | sorted_indices = np.argsort(-heatmap_center, axis=None)
157 | sorted_indices = np.unravel_index(sorted_indices, heatmap_center.shape)
158 | for i in range(composite_num):
159 | y_, x_ = sorted_indices[0][i], sorted_indices[1][i]
160 | x_ = x_/heatmap.shape[1]*bg_img.width
161 | y_ = y_/heatmap.shape[0]*bg_img.height
162 | x = int(x_ - w / 2)
163 | y = int(y_ - h / 2)
164 | # make composite image with foreground, background, and placement
165 | composite_img = make_composite_PIL(fg_tocp, mask_tocp, bg_img, [x, y, w, h])
166 | save_img_path = os.path.join(composite_dir, f'{save_name}_{int(x_)}_{int(y_)}.jpg')
167 | composite_img.save(save_img_path)
168 | print(save_img_path)
169 |
170 |
171 | if __name__ == "__main__":
172 |
173 | parser = argparse.ArgumentParser()
174 | # "evaluate": calculate F1 and bAcc
175 | # "heatmap": generate FOPA heatmap
176 | # "composite": generate composite images based on the heatmap
177 | parser.add_argument('--mode', type=str, default= "composite")
178 | # datatype: "train" or "test"
179 | parser.add_argument('--datatype', type=str, default= "test")
180 | parser.add_argument('--path', type=str, default= "demo2023-05-19-22:36:47.952468")
181 | parser.add_argument('--epoch', type=int, default= 23)
182 | args = parser.parse_args()
183 |
184 | #full_path = os.path.join('output', args.path, 'pth', f'{args.epoch}_state_final.pth')
185 | full_path = 'best_weight.pth'
186 |
187 | if not os.path.exists(full_path):
188 | print(f'{full_path} does not exist!')
189 | else:
190 | evaluator = Evaluator(arg_config, checkpoint_path=full_path)
191 | if args.mode== "evaluate":
192 | evaluator.evalutate_model(args.datatype)
193 | elif args.mode== "heatmap":
194 | evaluator.get_heatmap(args.datatype)
195 | elif args.mode== "composite":
196 | evaluator.generate_composite(args.datatype, 50)
197 | else:
198 | print(f'There is no {args.mode} mode.')
199 |
200 |
--------------------------------------------------------------------------------
/test_multi_fg_scales.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | from pprint import pprint
7 | from torchvision import transforms
8 | from tqdm import tqdm
9 |
10 | import network
11 | from config import arg_config
12 | from data.OBdataset import create_loader, _collect_info
13 | from data.OBdataset import make_composite_PIL
14 |
15 |
16 | class Evaluator:
17 | def __init__(self, args, checkpoint_path):
18 | super(Evaluator, self).__init__()
19 | self.args = args
20 | self.dev = torch.device("cuda:0")
21 | self.to_pil = transforms.ToPILImage()
22 | self.checkpoint_path = checkpoint_path
23 | pprint(self.args)
24 |
25 | print('load pretrained weights from ', checkpoint_path)
26 | self.net = getattr(network, self.args["model"])(
27 | pretrained=False).to(self.dev)
28 | self.net.load_state_dict(torch.load(checkpoint_path, map_location=self.dev), strict=False)
29 | self.net = self.net.to(self.dev).eval()
30 | self.softmax = torch.nn.Softmax(dim=1)
31 |
32 | def get_heatmap_multi_scales(self, fg_scale_num):
33 | '''
34 | generate heatmap for each pair of scaled foreground and background
35 | '''
36 |
37 | datatype= f"test_{fg_scale_num}scales"
38 |
39 | save_dir, base_name = os.path.split(self.checkpoint_path)
40 | heatmap_dir = os.path.join(save_dir, base_name.replace('.pth', f'_{datatype}_heatmap'))
41 |
42 | if not os.path.exists(heatmap_dir):
43 | print(f"Create directory {heatmap_dir}")
44 | os.makedirs(heatmap_dir)
45 |
46 |
47 | json_path = os.path.join('./data/data', f"test_data_{fg_scale_num}scales.json")
48 |
49 | self.ts_loader = create_loader(
50 | json_path, self.args["bg_dir"], self.args["fg_dir"], self.args["mask_dir"],
51 | self.args["input_size"], datatype, 1, self.args["num_workers"], False,
52 | )
53 |
54 | with torch.no_grad():
55 | for _, test_data in enumerate(tqdm(self.ts_loader)):
56 | _, test_bgs, test_masks, test_fgs, _, nums, composite_list, feature_pos, _, _, save_name = test_data
57 | test_bgs = test_bgs.to(self.dev, non_blocking=True)
58 | test_masks = test_masks.to(self.dev, non_blocking=True)
59 | test_fgs = test_fgs.to(self.dev, non_blocking=True)
60 | nums = nums.to(self.dev, non_blocking=True)
61 | composite_list = composite_list.to(self.dev, non_blocking=True)
62 | feature_pos = feature_pos.to(self.dev, non_blocking=True)
63 |
64 | test_outs, _ = self.net(test_bgs, test_fgs, test_masks, 'test')
65 | test_outs = self.softmax(test_outs)
66 |
67 | test_outs = test_outs[:,1,:,:]
68 | test_outs = transforms.ToPILImage()(test_outs)
69 | test_outs.save(os.path.join(heatmap_dir, save_name[0]))
70 |
71 | def generate_composite_multi_scales(self, fg_scale_num, composite_num):
72 | '''
73 | generate composite images for each pair of scaled foreground and background
74 | '''
75 |
76 | fg_scales = list(range(1, fg_scale_num+1))
77 | fg_scales = [i/(1+fg_scale_num+1) for i in fg_scales]
78 |
79 | icount = 0
80 |
81 | save_dir, base_name = os.path.split(self.checkpoint_path)
82 | heatmap_dir = os.path.join(save_dir, base_name.replace('.pth', f'_test_{fg_scale_num}scales_heatmap'))
83 | if not os.path.exists(heatmap_dir):
84 | print(f"{heatmap_dir} does not exist! Please first use 'heatmap' mode to generate heatmaps")
85 |
86 | json_path = os.path.join('./data/data', f"test_data_{fg_scale_num}scales.json")
87 |
88 | data = _collect_info(json_path, self.args["bg_dir"], self.args["fg_dir"], self.args["mask_dir"], 'test')
89 | for index in range(len(data)):
90 | _, _, bg_path, fg_path, _, scale, _, _, fg_path_2, mask_path_2, w, h = data[index]
91 |
92 | fg_name = fg_path.split('/')[-1][:-4]
93 | save_name = fg_name + '_' + str(scale)
94 | segs = fg_name.split('_')
95 | fg_id, bg_id = segs[0], segs[1]
96 | if icount==0:
97 |
98 | bg_img = Image.open(bg_path)
99 | if len(bg_img.split()) != 3:
100 | bg_img = bg_img.convert("RGB")
101 | fg_tocp = Image.open(fg_path_2).convert("RGB")
102 | mask_tocp = Image.open(mask_path_2).convert("RGB")
103 |
104 | composite_dir = os.path.join(save_dir, base_name.replace('.pth', f'_test_{fg_scale_num}scales_composite'), f'{fg_id}_{bg_id}')
105 | if not os.path.exists(composite_dir):
106 | print(f"Create directory {composite_dir}")
107 | os.makedirs(composite_dir)
108 |
109 | heatmap_center_list = []
110 | fg_size_list = []
111 |
112 | icount += 1
113 | heatmap = Image.open(os.path.join(heatmap_dir, save_name+'.jpg'))
114 | heatmap = np.array(heatmap)
115 | # exclude boundary
116 | heatmap_center = np.zeros_like(heatmap, dtype=np.float_)
117 | hb= int(h/bg_img.height*heatmap.shape[0]/2)
118 | wb = int(w/bg_img.width*heatmap.shape[1]/2)
119 | heatmap_center[hb:-hb, wb:-wb] = heatmap[hb:-hb, wb:-wb]
120 | heatmap_center_list.append(heatmap_center)
121 | fg_size_list.append((h,w))
122 |
123 | if icount==fg_scale_num:
124 | icount = 0
125 | heatmap_center_stack = np.stack(heatmap_center_list)
126 | # sort pixels in a descending order based on the heatmap
127 | sorted_indices = np.argsort(-heatmap_center_stack, axis=None)
128 | sorted_indices = np.unravel_index(sorted_indices, heatmap_center_stack.shape)
129 | for i in range(composite_num):
130 | iscale, y_, x_ = sorted_indices[0][i], sorted_indices[1][i], sorted_indices[2][i]
131 | h, w = fg_size_list[iscale]
132 | x_ = x_/heatmap.shape[1]*bg_img.width
133 | y_ = y_/heatmap.shape[0]*bg_img.height
134 | x = int(x_ - w / 2)
135 | y = int(y_ - h / 2)
136 | # make composite image with foreground, background, and placement
137 | composite_img, composite_msk = make_composite_PIL(fg_tocp, mask_tocp, bg_img, [x, y, w, h], return_mask=True)
138 | save_img_path = os.path.join(composite_dir, f'{fg_id}_{bg_id}_{x}_{y}_{w}_{h}.jpg')
139 | save_msk_path = os.path.join(composite_dir, f'{fg_id}_{bg_id}_{x}_{y}_{w}_{h}.png')
140 | composite_img.save(save_img_path)
141 | composite_msk.save(save_msk_path)
142 | print(save_img_path)
143 |
144 |
145 |
146 | if __name__ == "__main__":
147 | print("cuda: ", torch.cuda.is_available())
148 | parser = argparse.ArgumentParser()
149 | parser.add_argument('--mode', type=str, default= "composite")
150 | parser.add_argument('--path', type=str, default= "demo2023-05-19-22:36:47.952468")
151 | parser.add_argument('--epoch', type=int, default= 20)
152 | args = parser.parse_args()
153 |
154 | fg_scale_num = 16
155 | composite_num = 50
156 |
157 | full_path = os.path.join('output', args.path, 'pth', f'{args.epoch}_state_final.pth')
158 |
159 | if not os.path.exists(full_path):
160 | print(f'{full_path} does not exist!')
161 | else:
162 | evaluator = Evaluator(arg_config, checkpoint_path=full_path)
163 | if args.mode== "heatmap":
164 | evaluator.get_heatmap_multi_scales(fg_scale_num)
165 | elif args.mode== "composite":
166 | evaluator.generate_composite_multi_scales(fg_scale_num, composite_num)
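    # Example invocations (hypothetical experiment folder under ./output), running the
    # 'heatmap' mode first and then the 'composite' mode:
    #   python test_multi_fg_scales.py --mode heatmap --path <exp_name> --epoch 20
    #   python test_multi_fg_scales.py --mode composite --path <exp_name> --epoch 20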
167 |
168 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from datetime import datetime
4 | from pprint import pprint
5 |
6 | import numpy as np
7 | import torch
8 | import torch.backends.cudnn as torchcudnn
9 | from torch.nn import CrossEntropyLoss
10 | from torch.optim import SGD, Adam
11 | from torchvision import transforms
12 |
13 | import argparse
14 | import random
15 | import network
16 | import tensorboard_logger as tb_logger
17 | import torch.nn as nn
18 |
19 | from backbone.ResNet import pretrained_resnet18_4ch
20 | from config import arg_config, proj_root
21 | from data.OBdataset import create_loader
22 | from utils.misc import AvgMeter, construct_path_dict, make_log, pre_mkdir
23 |
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('--ex_name', type=str, default=arg_config["ex_name"])
26 | parser.add_argument('--alpha', type=float, default=16.)
27 | parser.add_argument('--resume', action='store_true', help='resume from checkpoint')
28 |
29 | user_args = parser.parse_args()
30 | datetime_str = str(datetime.now())
31 | datetime_str = '-'.join(datetime_str.split())
32 | user_args.ex_name += datetime_str
33 |
34 | def setup_seed(seed):
35 | torch.manual_seed(seed)
36 | torch.cuda.manual_seed_all(seed)
37 | np.random.seed(seed)
38 | random.seed(seed)
39 | torch.backends.cudnn.deterministic = True
40 |
41 | # set random seed
42 | setup_seed(0)
43 | torchcudnn.benchmark = True
44 | torchcudnn.enabled = True
45 | torchcudnn.deterministic = True
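# Note: cudnn.benchmark=True lets cuDNN pick the fastest algorithm per input size, which can
# be non-deterministic; with deterministic=True also set above, exact run-to-run
# reproducibility is therefore not guaranteed.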
46 |
47 |
48 | class Trainer:
49 | def __init__(self, args):
50 | super(Trainer, self).__init__()
51 | self.args = args
52 | self.to_pil = transforms.ToPILImage()
53 | pprint(self.args)
54 |
55 | self.path = construct_path_dict(proj_root=proj_root, exp_name=user_args.ex_name) # self.args["Experiment_name"])
56 | pre_mkdir(path_config=self.path)
57 |
58 | # backup used file
59 | shutil.copy(f"{proj_root}/config.py", self.path["cfg_log"])
60 | shutil.copy(f"{proj_root}/train.py", self.path["trainer_log"])
61 | shutil.copy(f"{proj_root}/data/OBdataset.py", self.path["dataset_log"])
62 | shutil.copy(f"{proj_root}/network/ObPlaNet_simple.py", self.path["network_log"])
63 |
64 | # training data loader
65 | self.tr_loader = create_loader(
66 | self.args["train_data_path"], self.args["bg_dir"], self.args["fg_dir"], self.args["mask_dir"],
67 | self.args["input_size"], 'train', self.args["batch_size"], self.args["num_workers"], True,
68 | )
69 |
70 | # load model
71 | self.dev = torch.device(f'cuda:{arg_config["gpu_id"]}')
72 | self.net = getattr(network, self.args["model"])(pretrained=True).to(self.dev)
73 |
74 | # loss functions
75 | self.loss = CrossEntropyLoss(ignore_index=255, reduction=self.args["reduction"]).to(self.dev)
76 |
77 | # optimizer
78 | self.opti = self.make_optim()
79 |
80 | # record loss
81 | tb_logger.configure(self.path['pth_log'], flush_secs=5)
82 |
83 | self.end_epoch = self.args["epoch_num"]
84 | if user_args.resume:
85 | try:
86 | self.resume_checkpoint(load_path=self.path["final_full_net"], mode="all")
87 |             except Exception:
88 |                 print(f"{self.path['final_full_net']} cannot be loaded, so we load {self.path['final_state_net']} instead")
89 | self.resume_checkpoint(load_path=self.path["final_state_net"], mode="onlynet")
90 | self.start_epoch = self.end_epoch
91 | else:
92 | self.start_epoch = 0
93 | self.iter_num = self.end_epoch * len(self.tr_loader)
94 |
95 |
96 | def train(self):
97 |
98 | for curr_epoch in range(self.start_epoch, self.end_epoch):
99 | self.net.train()
100 | train_loss_record = AvgMeter()
101 | mimicking_loss_record = AvgMeter()
102 |
103 | # change learning rate
104 | if self.args["lr_type"] == "poly":
105 | self.change_lr(curr_epoch)
106 | elif self.args["lr_type"] == "decay":
107 | self.change_lr(curr_epoch)
108 | elif self.args["lr_type"] == "all_decay":
109 | self.change_lr(curr_epoch)
110 | else:
111 | raise NotImplementedError
112 |
113 | for train_batch_id, train_data in enumerate(self.tr_loader):
114 | curr_iter = curr_epoch * len(self.tr_loader) + train_batch_id
115 |
116 | self.opti.zero_grad()
117 |
118 | _, train_bgs, train_masks, train_fgs, train_targets, num, composite_list, feature_pos, _, _, _ = train_data
119 |
120 | train_bgs = train_bgs.to(self.dev, non_blocking=True)
121 | train_masks = train_masks.to(self.dev, non_blocking=True)
122 | train_fgs = train_fgs.to(self.dev, non_blocking=True)
123 | train_targets = train_targets.to(self.dev, non_blocking=True)
124 | num = num.to(self.dev, non_blocking=True)
125 | composite_list = composite_list.to(self.dev, non_blocking=True)
126 | feature_pos = feature_pos.to(self.dev, non_blocking=True)
127 |
128 | # model training
129 | train_outs, feature_map = self.net(train_bgs, train_fgs, train_masks, 'train')
130 |
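                # Total loss: pixel-wise cross-entropy on the predicted heatmap plus the
                # feature-mimicking term weighted by --alpha (default 16.0).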
131 | mimicking_loss = feature_mimicking(composite_list, feature_pos, feature_map, num, self.dev)
132 | out_loss = self.loss(train_outs, train_targets.long())
133 | train_loss = out_loss + user_args.alpha*mimicking_loss
134 | train_loss.backward()
135 | self.opti.step()
136 |
137 | train_iter_loss = out_loss.item()
138 | mimicking_iter_loss = mimicking_loss.item()
139 | train_batch_size = train_bgs.size(0)
140 | train_loss_record.update(train_iter_loss, train_batch_size)
141 | mimicking_loss_record.update(mimicking_iter_loss, train_batch_size)
142 |
143 | tb_logger.log_value('loss', train_loss.item(), step=self.net.Eiters)
144 |
145 | if self.args["print_freq"] > 0 and (curr_iter + 1) % self.args["print_freq"] == 0:
146 | log = (
147 | f"[I:{curr_iter}/{self.iter_num}][E:{curr_epoch}:{self.end_epoch}]>"
148 | f"(L2)[Avg:{train_loss_record.avg:.3f}|Cur:{train_iter_loss:.3f}]"
149 | f"(Lm)[Avg:{mimicking_loss_record.avg:.3f}][Cur:{mimicking_iter_loss:.3f}]"
150 | )
151 | print(log)
152 | make_log(self.path["tr_log"], log)
153 |
154 | save_dir, save_name = os.path.split(self.path["final_full_net"])
155 | epoch_full_net_path = os.path.join(save_dir, str(curr_epoch + 1)+'_'+save_name)
156 | save_dir, save_name = os.path.split(self.path["final_state_net"])
157 | epoch_state_net_path = os.path.join(save_dir, str(curr_epoch + 1)+'_'+save_name)
158 |
159 | self.save_checkpoint(curr_epoch + 1, full_net_path=epoch_full_net_path, state_net_path=epoch_state_net_path)
160 |
161 |
162 | def change_lr(self, curr):
163 | total_num = self.end_epoch
164 | if self.args["lr_type"] == "poly":
165 | ratio = pow((1 - float(curr) / total_num), self.args["lr_decay"])
166 | self.opti.param_groups[0]["lr"] = self.opti.param_groups[0]["lr"] * ratio
167 | self.opti.param_groups[1]["lr"] = self.opti.param_groups[0]["lr"]
168 | elif self.args["lr_type"] == "decay":
169 | ratio = 0.1
170 | if (curr % 9 == 0):
171 | self.opti.param_groups[0]["lr"] = self.opti.param_groups[0]["lr"] * ratio
172 | self.opti.param_groups[1]["lr"] = self.opti.param_groups[0]["lr"]
173 | elif self.args["lr_type"] == "all_decay":
174 | lr = self.args["lr"] * (0.5 ** (curr // 2))
175 | for param_group in self.opti.param_groups:
176 | param_group['lr'] = lr
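            # e.g. with a base lr of 1e-3, epochs 0-1 use 1e-3, epochs 2-3 use 5e-4,
            # epochs 4-5 use 2.5e-4, and so on (halved every two epochs).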
177 | else:
178 | raise NotImplementedError
179 |
180 | def make_optim(self):
181 | if self.args["optim"] == "sgd_trick":
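            # "sgd_trick": biases and BatchNorm parameters go into a group with
            # weight_decay=0, while all remaining parameters keep the global weight decay.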
182 | params = [
183 | {
184 | "params": [p for name, p in self.net.named_parameters() if ("bias" in name or "bn" in name)],
185 | "weight_decay": 0,
186 | },
187 | {
188 | "params": [
189 | p for name, p in self.net.named_parameters() if ("bias" not in name and "bn" not in name)
190 | ]
191 | },
192 | ]
193 | optimizer = SGD(
194 | params,
195 | lr=self.args["lr"],
196 | momentum=self.args["momentum"],
197 | weight_decay=self.args["weight_decay"],
198 | nesterov=self.args["nesterov"],
199 | )
200 | elif self.args["optim"] == "f3_trick":
201 | backbone, head = [], []
202 | for name, params_tensor in self.net.named_parameters():
203 | if "encoder" in name:
204 | backbone.append(params_tensor)
205 | else:
206 | head.append(params_tensor)
207 | params = [
208 | {"params": backbone, "lr": 0.1 * self.args["lr"]},
209 | {"params": head, "lr": self.args["lr"]},
210 | ]
211 | optimizer = SGD(
212 | params=params,
213 | momentum=self.args["momentum"],
214 | weight_decay=self.args["weight_decay"],
215 | nesterov=self.args["nesterov"],
216 | )
217 | elif self.args["optim"] == "Adam_trick":
218 | optimizer = Adam(filter(lambda p: p.requires_grad, self.net.parameters()), lr=self.args["lr"])
219 | else:
220 | raise NotImplementedError
221 | print("optimizer = ", optimizer)
222 | return optimizer
223 |
224 | def save_checkpoint(self, current_epoch, full_net_path, state_net_path):
225 | state_dict = {
226 | "epoch": current_epoch,
227 | "net_state": self.net.state_dict(),
228 | "opti_state": self.opti.state_dict(),
229 | }
230 | torch.save(state_dict, full_net_path)
231 | torch.save(self.net.state_dict(), state_net_path)
232 |
233 | def resume_checkpoint(self, load_path, mode="all"):
234 | """
235 | Args:
236 | load_path (str): path of pretrained model
237 |             mode (str): 'all': resume all information; 'onlynet': only resume model parameters
238 | """
239 | if os.path.exists(load_path) and os.path.isfile(load_path):
240 | print(f" =>> loading checkpoint '{load_path}' <<== ")
241 | checkpoint = torch.load(load_path, map_location=self.dev)
242 | if mode == "all":
243 | self.start_epoch = checkpoint["epoch"]
244 | self.net.load_state_dict(checkpoint["net_state"])
245 | self.opti.load_state_dict(checkpoint["opti_state"])
246 | print(f" ==> loaded checkpoint '{load_path}' (epoch {checkpoint['epoch']})")
247 | elif mode == "onlynet":
248 | self.net.load_state_dict(checkpoint)
249 | print(f" ==> loaded checkpoint '{load_path}' " f"(only has the net's weight params) <<== ")
250 | else:
251 | raise NotImplementedError
252 | else:
253 | raise Exception(f"{load_path} is not correct.")
254 |
255 |
256 | def feature_mimicking(composites, feature_pos, feature_map, num, device):
257 |
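    # Feature mimicking: for each positive placement, the network feature sampled at that
    # location (feature_pos stores its (x, y) coordinates on the feature map) is pulled by
    # an MSE loss towards the average-pooled feature that pretrained_resnet18_4ch extracts
    # from the corresponding composite. Note that this encoder is re-instantiated on every
    # call, which is simple but slow; constructing it once outside the training loop would
    # avoid the repeated construction.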
258 | net_ = pretrained_resnet18_4ch(pretrained=True).to(device)
259 |
260 | composite_cat_list = []
261 | pos_feature = torch.zeros(int(num.sum()), 512, 1, 1).to(device)
262 | count = 0
263 | for i in range(num.shape[0]):
264 | composite_cat_list.append(composites[i, :num[i], :, :, :])
265 | for j in range(num[i]):
266 | pos_feature[count, :, 0, 0] = feature_map[i, :, int(feature_pos[i, j, 1]), int(feature_pos[i, j, 0])]
267 | count += 1
268 | composites_ = torch.cat(composite_cat_list, dim=0)
269 | composite_feature = net_(composites_)
270 | composite_feature = nn.AdaptiveAvgPool2d(1)(composite_feature)
271 |     pos_feature = pos_feature.view(-1, 512)
272 |     composite_feature = composite_feature.view(-1, 512)
273 |
274 | mimicking_loss_criter = nn.MSELoss()
275 | mimicking_loss = mimicking_loss_criter(pos_feature, composite_feature)
276 |
277 | return mimicking_loss
278 |
279 |
280 | if __name__ == "__main__":
281 | trainer = Trainer(arg_config)
282 | print(f" ===========>> {datetime.now()}: begin training <<=========== ")
283 | trainer.train()
284 | print(f" ===========>> {datetime.now()}: end training <<=========== ")
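# Example invocation (hypothetical experiment name):
#   python train.py --ex_name fopa_demo --alpha 16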
285 |
286 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 |
4 | class AvgMeter(object):
5 | def __init__(self):
6 | self.reset()
7 |
8 | def reset(self):
9 | self.val = 0
10 | self.avg = 0
11 | self.sum = 0
12 | self.count = 0
13 |
14 | def update(self, val, n=1):
15 | self.val = val
16 | self.sum += val * n
17 | self.count += n
18 | self.avg = self.sum / self.count
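    # Minimal usage sketch: meter = AvgMeter(); meter.update(loss.item(), batch_size)
    # keeps a running average weighted by the number of samples contributed each update.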
19 |
20 |
21 | def pre_mkdir(path_config):
22 | check_mkdir(path_config["pth_log"])
23 | check_mkdir(path_config["pth"])
24 | make_log(path_config["te_log"], f"=== te_log {datetime.now()} ===")
25 | make_log(path_config["tr_log"], f"=== tr_log {datetime.now()} ===")
26 |
27 |
28 | def check_mkdir(dir_name):
29 | if not os.path.exists(dir_name):
30 | os.makedirs(dir_name)
31 |
32 |
33 | def make_log(path, context):
34 | with open(path, "a") as log:
35 | log.write(f"{context}\n")
36 |
37 |
38 | def check_dir_path_valid(path: list):
39 | for p in path:
40 | if p:
41 | assert os.path.exists(p)
42 | assert os.path.isdir(p)
43 |
44 |
45 | def construct_path_dict(proj_root, exp_name):
46 | ckpt_path = os.path.join(proj_root, "output")
47 |
48 | pth_log_path = os.path.join(ckpt_path, exp_name)
49 |
50 | tb_path = os.path.join(pth_log_path, "tb")
51 | save_path = os.path.join(pth_log_path, "pre")
52 | pth_path = os.path.join(pth_log_path, "pth")
53 |
54 | final_full_model_path = os.path.join(pth_path, "checkpoint_final.pth.tar")
55 | final_state_path = os.path.join(pth_path, "state_final.pth")
56 |
57 | tr_log_path = os.path.join(pth_log_path, f"tr_{str(datetime.now())[:10]}.txt")
58 | te_log_path = os.path.join(pth_log_path, f"te_{str(datetime.now())[:10]}.txt")
59 | cfg_log_path = os.path.join(pth_log_path, f"cfg_{str(datetime.now())[:10]}.txt")
60 | trainer_log_path = os.path.join(pth_log_path, f"trainer_{str(datetime.now())[:10]}.txt")
61 | dataset_log_path = os.path.join(pth_log_path, f"dataset_{str(datetime.now())[:10]}.txt")
62 | network_log_path = os.path.join(pth_log_path, f"network_{str(datetime.now())[:10]}.txt")
63 |
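    # Resulting layout (sketch): everything lives under output/<exp_name>/, with tb/, pre/
    # and pth/ subdirectory paths (checkpoints checkpoint_final.pth.tar and state_final.pth
    # go in pth/) plus dated tr_/te_/cfg_/trainer_/dataset_/network_ text logs.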
64 | path_config = {
65 | "ckpt_path": ckpt_path,
66 | "pth_log": pth_log_path,
67 | "tb": tb_path,
68 | "save": save_path,
69 | "pth": pth_path,
70 | "final_full_net": final_full_model_path,
71 | "final_state_net": final_state_path,
72 | "tr_log": tr_log_path,
73 | "te_log": te_log_path,
74 | "cfg_log": cfg_log_path,
75 | "trainer_log": trainer_log_path,
76 | "dataset_log": dataset_log_path,
77 | "network_log": network_log_path,
78 | }
79 |
80 | return path_config
81 |
--------------------------------------------------------------------------------