├── .gitignore
├── Demo_hr
│   ├── Real_09_bg.jpg
│   └── Real_09_fg.png
├── LICENSE
├── PIH_train.py
├── README.md
├── __init__.py
├── dataset.py
├── demo.py
├── demo_data
│   └── train
│       ├── bg
│       │   ├── 00057ce6-67c5-411e-8087-3799b638518a_after_mask.png
│       │   └── 00057ce6-67c5-411e-8087-3799b638518a_before_mask.png
│       ├── masks
│       │   ├── 00057ce6-67c5-411e-8087-3799b638518a_after_mask.png
│       │   └── 00057ce6-67c5-411e-8087-3799b638518a_before_mask.png
│       └── real_images
│           ├── 00057ce6-67c5-411e-8087-3799b638518a_after.png
│           └── 00057ce6-67c5-411e-8087-3799b638518a_before.png
├── demo_light.py
├── environment.yml
├── github_images
│   ├── Figure5.png
│   ├── Figure_3.png
│   ├── Figure_5_final.png
│   ├── Figure_8_final.png
│   ├── Figure_teaser.png
│   ├── S1.png
│   └── demo.gif
├── inference.py
├── inference_scripts
│   ├── Inference.sh
│   ├── Inference_Composite.sh
│   ├── Inference_Composite_masking.sh
│   ├── Inference_Composite_masking_3.sh
│   ├── Inference_Composite_masking_3_depth.sh
│   ├── Inference_Composite_masking_3_highres.sh
│   ├── Inference_Composite_masking_3_noweb.sh
│   ├── Inference_Composite_masking_highres.sh
│   ├── Inference_Composite_masking_pixel.sh
│   ├── Inference_Composite_unet.sh
│   └── Inference_iHarmony.sh
├── model.py
├── pretrained
│   └── pretrained.placeholder
├── results
│   └── results.placeholder
├── scripts
│   ├── installation.sh
│   └── train_example.sh
└── utils
    ├── efficientnet_v2.py
    ├── mobilenet_v3.py
    ├── modules.py
    ├── networks.py
    ├── resnet.py
    ├── resnet_ibn.py
    ├── unet
    │   ├── __init__.py
    │   ├── unet_model.py
    │   └── unet_parts.py
    └── unet_dis.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | PIH_ResNet/results/
3 | .DS_Store
4 | __pycache__
5 | .python-version
6 |
--------------------------------------------------------------------------------
/Demo_hr/Real_09_bg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/Demo_hr/Real_09_bg.jpg
--------------------------------------------------------------------------------
/Demo_hr/Real_09_fg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/Demo_hr/Real_09_fg.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 Adobe Research
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Parametric Image Harmonization (PIH)
2 | [Project Page](http://people.eecs.berkeley.edu/~kewang/sprih/) | [Paper](https://arxiv.org/abs/2303.00157) | [Bibtex](https://people.eecs.berkeley.edu/~kewang/sprih/sprih.txt)
3 |
4 | Semi-supervised Parametric Real-world Image Harmonization.\
5 | _CVPR 2023_ \
6 | [Ke Wang](https://people.eecs.berkeley.edu/~kewang), [Michaël Gharbi](http://mgharbi.com/), [He Zhang](https://sites.google.com/site/hezhangsprinter/), [Zhihao Xia](https://likesum.github.io/), [Eli Shechtman](https://research.adobe.com/person/eli-shechtman/)
7 |
8 |
9 |
10 |
11 | **A novel semi-supervised training strategy and the first harmonization method that learns complex local appearance harmonization from unpaired real composites.**
12 |
13 | The code was developed by [Ke Wang](https://people.eecs.berkeley.edu/~kewang) while Ke was a research scientist intern at Adobe Research.
14 | 
15 | Please contact Ke (kewang@berkeley.edu) or Michaël (mgharbi@adobe.com) if you have any questions.
16 |
17 | **Results**
18 |
19 |
20 | Our results show better visual agreement with the ground truth than SOTA methods, in terms of color harmonization
21 | (rows 1, 2, and 4) and shading correction (row 3).
22 |
23 |
24 |
25 | RGB curves harmonize the global color/tone (center), while our shading map corrects the local shading in the harmonization output (right).
26 |
27 |
28 | ---
29 |
30 | ## Prerequisites
31 |
32 | - Linux
33 | - Python 3
34 | - NVIDIA GPU + CUDA CuDNN
35 | - [Conda](https://docs.conda.io/en/latest/) installed
36 |
37 |
38 | ---
39 |
40 | **Table of Contents:**
41 | 1. [Setup](#setup) - set up the environment
42 | 2. [Pretrained Models](#pretrained-models) - download pretrained models and resources
43 | 3. [Interactive Demo](#demo) - offline interactive demo
44 | 4. [Inference](#inference) - inference on high-resolution images with the pretrained model
45 | 5. [Dataset](#dataset) - prepare your own dataset for training
46 | 6. [Training](#training) - pipeline for training PIH
47 | 7. [Citation](#citation) - bibtex citation
48 |
49 |
50 |
51 | ---
52 |
53 | ## Setup
54 |
55 | - Clone this repo:
56 |
57 | ```bash
58 | git clone git@github.com:adobe/PIH.git
59 | ```
60 |
61 | - Install dependencies
62 |
63 | We provide an `environment.yml` to install the dependencies; you need to have [Conda](https://docs.conda.io/en/latest/) installed. Run
64 |
65 | ```
66 | conda env create -f environment.yml
67 | ```
68 | (essentially, this installs [PyTorch](https://pytorch.org/))
69 |
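The environment is named `pytorch_pih` (see `environment.yml`). As a quick sanity check of the install (assuming a CUDA-capable GPU is present):

```
conda activate pytorch_pih
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```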
70 | ---
71 |
72 | ## Pretrained models
73 |
74 | We provide our pre-trained model (93M parameters), trained on the *Artist Retouched Dataset*, at this [link](https://drive.google.com/file/d/1seW8qSnaBOQ4_S9bQ4ThVOdeJGYJ-f74/view?usp=sharing). Download it and put it in the folder:
75 |
76 | ```
77 | ./pretrained/
78 | ```
79 |
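If you prefer the command line, one way to fetch the checkpoint is sketched below, assuming [gdown](https://github.com/wkentaro/gdown) is installed separately (it is not part of `environment.yml`). The file id comes from the Google Drive link above, and the filename `ckpt_g39.pth` matches the default path used by `inference.py`:

```
pip install gdown
gdown 1seW8qSnaBOQ4_S9bQ4ThVOdeJGYJ-f74 -O pretrained/ckpt_g39.pth
```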
80 | ---
81 |
82 | ## Demo
83 |
84 | We provide an offline interactive demo built with [PyGame](https://www.pygame.org/news).
85 | 
86 | First, install the dependencies:
87 |
88 | ```
89 | python -m pip install -U pygame --user
90 | pip install pygame_gui
91 | pip install timm
92 | ```
93 |
94 | Then, simply run the following command to start the demo:
95 | ```
96 | python demo.py
97 | ```
98 |
99 | Here we provide a tutorial video for the demo.
100 |
101 |
102 |
103 | ---
104 |
105 | ## Inference
106 |
107 | We provide the inference code for evaluations:
108 |
109 | ```
110 | python inference.py --bg <background_image> --fg <foreground_image> --checkpoints <checkpoint_path> [--gpu]
111 | ```
112 |
113 | Notes:
114 | - The `--gpu` flag enables inference on the GPU (CUDA); the default is CPU.
115 | - `--checkpoints` specifies the path to the checkpoint.
116 |
117 |
118 | Example:
119 | ```
120 | python inference.py --bg Demo_hr/Real_09_bg.jpg --fg Demo_hr/Real_09_fg.png --checkpoints pretrained/ckpt_g39.pth --gpu
121 | ```
122 |
123 | Check the `results/` folder for output images.
124 |
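Note that `--fg` is expected to be an RGBA foreground cutout: `inference.py` uses its alpha channel as the mask and pastes it onto the background (resized to the foreground size) to form the naive composite that the model then harmonizes. A minimal sketch of that preprocessing, following `inference.py` (file names are the demo images above; the output name is arbitrary):

```python
from PIL import Image

fg = Image.open("Demo_hr/Real_09_fg.png")             # RGBA foreground cutout
mask = fg.split()[-1]                                  # alpha channel -> foreground mask
bg = Image.open("Demo_hr/Real_09_bg.jpg").resize(fg.size)
composite = Image.composite(fg, bg, mask)              # naive copy-paste composite
composite.save("results/naive_composite.png")          # illustrative output name
```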
125 | ---
126 |
127 | ## Dataset
128 |
129 |
130 |
131 | We provide a guideline for preparing the *Artist Retouched Dataset*.
132 |
133 | For an image with name `<Name>`, we organize the `data` directory like this:
134 |
135 | ```
136 | data
137 | |--train
138 |    |--bg
139 |       |-- <Name>_before.png
140 |       |-- <Name>_after.png
141 |    |--masks
142 |       |-- <Name>_before.png
143 |       |-- <Name>_after.png
144 |    |--real_images
145 |       |-- <Name>_before.png
146 |       |-- <Name>_after.png
147 | 
148 | |--test
149 |    |--bg
150 |       |-- <Name>_before.png
151 |       |-- <Name>_after.png
152 |    |--masks
153 |       |-- <Name>_before.png
154 |       |-- <Name>_after.png
155 |    |--real_images
156 |       |-- <Name>_before.png
157 |       |-- <Name>_after.png
158 | ```
159 |
160 | Notes:
161 | - bg (background): Backgrounds inpainted using the foreground masks. Here we use [LaMa](https://github.com/advimman/lama) to perform the inpainting.
162 | - masks: Foreground masks; they should be consistent between `before` and `after`.
163 | - real_images: Ground truth real images.
164 |
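For reference, the training loader `DataCompositeGAN` in `dataset.py` consumes this layout by globbing `masks/*_mask.png` and deriving the matching `real_images/` and `bg/` files by string replacement. A minimal sketch of wiring it into a PyTorch loader (batch size and worker count are arbitrary; the file naming follows the bundled `demo_data/train` example):

```python
from torch.utils.data import DataLoader
from dataset import DataCompositeGAN

# Expects <dir>/masks/*_mask.png plus matching files in real_images/ and bg/.
dataset = DataCompositeGAN("demo_data/train", augment=True, colorjitter=True)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

# Each batch: inpainted background, composite, foreground mask, real image,
# background mask, augmented real image, and the two source paths (see dataset.py).
bg, comp, mask_fg, real, mask_bg, real_aug, path_fg, path_bg = next(iter(loader))
```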
165 | ---
166 |
167 | ## Training
168 |
169 | Our approach uses a dual-stream semi-supervised training strategy to bridge the domain gap, alleviating the generalization issues that plague many state-of-the-art harmonization models.
170 |
171 |
172 |
173 |
174 | We provide the script `train_example.sh` to perform training.
175 |
176 | Training notes:
177 | - Modify `--dir_data` to point to your custom dataset.
178 | - The argument `recon_weight` corresponds to the weighting parameter that balances stream 1 and stream 2.
179 |
180 | Simply run:
181 | ```
182 | bash scripts/train_example.sh
183 | ```
184 | to start the training.
185 |
186 | ---
187 |
188 | ## Citation
189 | If you use this code for your research, please cite our paper.
190 |
191 | ```
192 | @inproceedings{wang2023semi,
193 |     title={Semi-supervised Parametric Real-world Image Harmonization},
194 |     author={Wang, Ke and Gharbi, Micha{\"e}l and Zhang, He and Xia, Zhihao and Shechtman, Eli},
195 |     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
196 |     year = {2023}
197 | }
198 | ```
199 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe. All rights reserved.
2 | # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License. You may obtain a copy
4 | # of the License at http://www.apache.org/licenses/LICENSE-2.0
5 |
6 | # Unless required by applicable law or agreed to in writing, software distributed under
7 | # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8 | # OF ANY KIND, either express or implied. See the License for the specific language
9 | # governing permissions and limitations under the License.
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe. All rights reserved.
2 | # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License. You may obtain a copy
4 | # of the License at http://www.apache.org/licenses/LICENSE-2.0
5 |
6 | # Unless required by applicable law or agreed to in writing, software distributed under
7 | # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8 | # OF ANY KIND, either express or implied. See the License for the specific language
9 | # governing permissions and limitations under the License.
10 |
11 |
12 | from glob import glob
13 | import os
14 | import numpy as np
15 | import torch
16 |
17 | from PIL import Image
18 |
19 | import torchvision.transforms as T
20 | import torchvision.transforms.functional as F
21 |
22 |
23 | from torch.utils.data import Dataset
24 | import random
25 | import sys
26 |
27 |
28 | class PIHData(Dataset):
29 | def __init__(self, data_directory, device=torch.device("cpu")):
30 |         """
31 | 
32 |         Supervised dataset of (composite, mask, ground-truth) triplets.
33 | 
34 |         For every ground-truth image `<name>_gt.jpg` found in `data_directory`,
35 |         the matching composite `<name>.jpg` and foreground mask `<name>_mask.jpg`
36 |         are loaded from the same folder. The composite and ground truth are
37 |         returned as RGB tensors, the mask as a single grayscale channel, and
38 |         the image path is returned alongside each sample.
39 | 
40 |         Parameters
41 |         ----------
42 |         data_directory : str
43 |             The directory containing the training image data.
44 |         device : torch.device
45 |             The device to load the data to.
46 |         """
47 |
48 | self.image_paths = glob(f"{data_directory}/*_gt.jpg")
49 | print(
50 | f"Using data from: {data_directory}\nFound {len(self.image_paths)} image paths."
51 | )
52 | self.device = device
53 | self.transforms = T.Compose([T.ToTensor()])
54 | self.transforms_mask = T.Compose([T.Grayscale(), T.ToTensor()])
55 |
56 | def __len__(self):
57 | return len(self.image_paths)
58 |
59 | def __getitem__(self, index):
60 | """Get image at the specified index.
61 |
62 | Parameters
63 | ----------
64 | index : int
65 | The image index.
66 |
67 | Returns
68 | -------
69 | patch: torch.Tensor
70 |
71 | """
72 |
73 | image_path = self.image_paths[index]
74 | ground_truth = Image.open(image_path)
75 | input_image = Image.open(image_path[: image_path.rindex("_")] + ".jpg")
76 | input_mask = Image.open(image_path[: image_path.rindex("_")] + "_mask.jpg")
77 |
78 | # original_image = np.load(self.image_paths[index])[None].astype(np.complex64)
79 |
80 | return (
81 | self.transforms(input_image),
82 | self.transforms_mask(input_mask),
83 | self.transforms(ground_truth),
84 | image_path,
85 | )
86 |
87 |
88 | class PIHDataRandom(Dataset):
89 | def __init__(self, data_directory, device=torch.device("cpu")):
90 |         """
91 | 
92 |         Self-supervised dataset that synthesizes composites on the fly.
93 | 
94 |         For every ground-truth image `<name>_gt.jpg` in `data_directory`, the
95 |         foreground mask `<name>_mask.jpg` is loaded, and a synthetic composite
96 |         is created by applying random brightness, contrast, saturation, and
97 |         per-channel polynomial color perturbations to the ground truth inside
98 |         the masked region.
99 | 
100 |         Parameters
101 |         ----------
102 |         data_directory : str
103 |             The directory containing the training image data.
104 |         device : torch.device
105 |             The device to load the data to.
106 |         """
107 |
108 | self.image_paths = glob(f"{data_directory}/*_gt.jpg")
109 | print(
110 | f"Using data from: {data_directory}\nFound {len(self.image_paths)} image paths."
111 | )
112 | self.device = device
113 | self.transforms = T.Compose([T.ToTensor()])
114 | self.transforms_mask = T.Compose([T.Grayscale(), T.ToTensor()])
115 |
116 | def __len__(self):
117 | return len(self.image_paths)
118 |
119 | def __getitem__(self, index):
120 | """Get image at the specified index.
121 |
122 | Parameters
123 | ----------
124 | index : int
125 | The image index.
126 |
127 | Returns
128 | -------
129 | patch: torch.Tensor
130 |
131 | """
132 |
133 | image_path = self.image_paths[index]
134 |
135 | ground_truth = self.transforms(Image.open(image_path))
136 | mask_torch = self.transforms(
137 | Image.open(image_path[: image_path.rindex("_")] + "_mask.jpg")
138 | )
139 |
140 | imag_torch = T.functional.adjust_contrast(
141 | ground_truth, (np.random.rand() * 0.4 + 0.8)
142 | )
143 |
144 | imag_torch = T.functional.adjust_brightness(
145 | imag_torch, (np.random.rand() * 0.4 + 0.8)
146 | )
147 |
148 | imag_torch = T.functional.adjust_saturation(
149 | imag_torch, (np.random.rand() * 0.4 + 0.8)
150 | )
151 |
152 |         # Random per-channel polynomial color transform (cross-channel / YCC-style)
153 | imag_torch[0, ...] = (
154 | imag_torch[0, ...] * (np.random.rand() * 0.3 + 0.70)
155 | + imag_torch[0, ...] * imag_torch[0, ...] * (np.random.rand() - 0.5) * 0.1
156 | + imag_torch[0, ...]
157 | * imag_torch[0, ...]
158 | * imag_torch[0, ...]
159 | * (np.random.rand() - 0.5)
160 | * 0.05
161 | )
162 |
163 | imag_torch[1, ...] = (
164 | imag_torch[1, ...] * (np.random.rand() * 0.3 + 0.70)
165 |             + imag_torch[1, ...] * imag_torch[1, ...] * (np.random.rand() - 0.5) * 0.1
166 | + imag_torch[1, ...]
167 | * imag_torch[1, ...]
168 | * imag_torch[1, ...]
169 | * (np.random.rand() - 0.5)
170 | * 0.05
171 | )
172 |
173 | imag_torch[2, ...] = (
174 | imag_torch[2, ...] * (np.random.rand() * 0.3 + 0.70)
175 | + imag_torch[2, ...] * imag_torch[2, ...] * (np.random.rand() - 0.5) * 0.1
176 | + imag_torch[2, ...]
177 | * imag_torch[2, ...]
178 | * imag_torch[2, ...]
179 | * (np.random.rand() - 0.5)
180 | * 0.05
181 | )
182 |
183 | imag_composite = ground_truth * (1 - mask_torch) + imag_torch * mask_torch
184 | # original_image = np.load(self.image_paths[index])[None].astype(np.complex64)
185 |
186 | return (
187 | imag_composite,
188 | mask_torch,
189 | ground_truth,
190 | image_path,
191 | )
192 |
193 |
194 | class PIHDataNGT(Dataset):
195 | def __init__(self, data_directory, device=torch.device("cpu")):
196 |         """
197 | 
198 |         Dataset variant without ground truth (for inference-style use).
199 | 
200 |         For every mask `<name>_mask.jpg` in `data_directory`, the matching
201 |         input image `<name>.jpg` is loaded. Since no ground truth is
202 |         available, the mask is returned twice so that the output tuple has
203 |         the same structure as the supervised datasets, followed by the
204 |         image path.
205 | 
206 |         Parameters
207 |         ----------
208 |         data_directory : str
209 |             The directory containing the training image data.
210 |         device : torch.device
211 |             The device to load the data to.
212 |         """
213 |
214 | self.image_paths = glob(f"{data_directory}/*_mask.jpg")
215 | print(
216 | f"Using data from: {data_directory}\nFound {len(self.image_paths)} image paths."
217 | )
218 | self.device = device
219 | self.transforms = T.Compose([T.ToTensor()])
220 | self.transforms_mask = T.Compose([T.Grayscale(), T.ToTensor()])
221 |
222 | def __len__(self):
223 | return len(self.image_paths)
224 |
225 | def __getitem__(self, index):
226 | """Get image at the specified index.
227 |
228 | Parameters
229 | ----------
230 | index : int
231 | The image index.
232 |
233 | Returns
234 | -------
235 | patch: torch.Tensor
236 |
237 | """
238 |
239 | image_path = self.image_paths[index]
240 | # ground_truth = Image.open(image_path)
241 | input_image = Image.open(image_path[: image_path.rindex("_")] + ".jpg")
242 | input_mask = Image.open(image_path[: image_path.rindex("_")] + "_mask.jpg")
243 |
244 | # original_image = np.load(self.image_paths[index])[None].astype(np.complex64)
245 |
246 | return (
247 | self.transforms(input_image),
248 | self.transforms_mask(input_mask),
249 | self.transforms_mask(input_mask),
250 | image_path,
251 | )
252 |
253 |
254 | class IhdDataset(Dataset):
255 | def __init__(self, opt):
256 | self.image_paths = []
257 | self.isTrain = opt.train
258 | if opt.train == True:
259 | print("loading training file")
260 | self.trainfile = opt.datadir + "IHD_train.txt"
261 | with open(self.trainfile, "r") as f:
262 | for line in f.readlines():
263 | self.image_paths.append(
264 | os.path.join(opt.datadir, "", line.rstrip())
265 | )
266 | else:
267 | print("loading test file")
268 | self.trainfile = opt.datadir + "IHD_test.txt"
269 | with open(self.trainfile, "r") as f:
270 | for line in f.readlines():
271 | self.image_paths.append(
272 | os.path.join(opt.datadir, "", line.rstrip())
273 | )
274 | self.transforms = T.Compose([T.ToTensor()])
275 | self.transforms_mask = T.Compose([T.Grayscale(), T.ToTensor()])
276 | self.image_size = 512
277 |
278 | print(
279 | f"Using data from: {opt.datadir}\nFound {len(self.image_paths)} image paths."
280 | )
281 |
282 | def __getitem__(self, index):
283 |
284 | path = self.image_paths[index]
285 | name_parts = path.split("_")
286 | mask_path = self.image_paths[index].replace("composite_images", "masks")
287 | mask_path = mask_path.replace(("_" + name_parts[-1]), ".png")
288 | target_path = self.image_paths[index].replace("composite_images", "real_images")
289 | target_path = target_path.replace(
290 | ("_" + name_parts[-2] + "_" + name_parts[-1]), ".jpg"
291 | )
292 |
293 | comp = Image.open(path).convert("RGB")
294 | real = Image.open(target_path).convert("RGB")
295 | mask = Image.open(mask_path).convert("RGB")
296 |
297 | if np.random.rand() > 0.5 and self.isTrain:
298 | comp, mask, real = F.hflip(comp), F.hflip(mask), F.hflip(real)
299 |
300 | if not (comp.size[0] == self.image_size and comp.size[1] == self.image_size):
301 | # assert 0
302 | comp = F.resize(comp, [self.image_size, self.image_size])
303 | mask = F.resize(mask, [self.image_size, self.image_size])
304 | real = F.resize(real, [self.image_size, self.image_size])
305 |
306 | comp = self.transforms(comp)
307 | mask = self.transforms_mask(mask)
308 |
309 | real = self.transforms(real)
310 |
311 | return (comp, mask, real, path)
312 |
313 | def __len__(self):
314 | """Return the total number of images."""
315 | return len(self.image_paths)
316 |
317 |
318 | class DataCompositeGAN(Dataset):
319 | def __init__(self, data_directory, ratio=1, augment=False, colorjitter=True, lowres=False,return_raw=False, ratio_constrain=False):
320 | """
321 |
322 | Parameters
323 | ----------
324 | data_directory : str
325 | The directory containing the training image data.
326 | """
327 | self.lowres = lowres
328 | self.image_paths = glob(f"{data_directory}/masks/*_mask.png")
329 |
330 | self.image_paths = self.image_paths[0 : int(len(self.image_paths) * ratio)]
331 |
332 | self.length = len(self.image_paths)
333 | print(
334 | f"Using data from: {data_directory}\nFound {len(self.image_paths)} image paths."
335 | )
336 | self.transforms = T.Compose([T.ToTensor()])
337 | self.transforms_mask = T.Compose([T.Grayscale(), T.ToTensor()])
338 | self.colorjitter = colorjitter
339 | if self.colorjitter:
340 | self.transform_color = T.ColorJitter(
341 | brightness=[0.65, 1.35], contrast=0.2, saturation=0, hue=0
342 | ) ## 0.3 0.7
343 | self.augment = augment
344 | self.returnraw = return_raw
345 | self.ratio_constrain = ratio_constrain
346 | if ratio_constrain:
347 | print("Using Constrained Ratio")
348 |
349 | def __len__(self):
350 | return len(self.image_paths)
351 |
352 | def __getitem__(self, index):
353 | """Get image at the specified index.
354 |
355 | Parameters
356 | ----------
357 | index : int
358 | The image index.
359 |
360 | Returns
361 | -------
362 | patch: torch.Tensor
363 |
364 | Foreground
365 | """
366 |
367 |         path_bg = self.image_paths[index]  # Background (target) sample
368 |
369 | path_fg = self.image_paths[np.random.randint(0, self.length)]
370 |
371 | ### fore-ground image loading
372 |
373 | path_fg_image = path_fg.replace("masks/", "real_images/")
374 | path_fg_image = path_fg_image.replace("_mask.png", ".png")
375 |
376 | path_fg_bg = path_fg.replace("masks/", "bg/")
377 |
378 | mask_fg = Image.open(path_fg)
379 |
380 | image_fg = Image.open(path_fg_image)
381 |
382 | image_fg_bg = Image.open(path_fg_bg)
383 |
384 |
385 |
386 | ### back-ground image loading
387 |
388 | path_bg_image = path_bg.replace("masks/", "real_images/")
389 | path_bg_image = path_bg_image.replace("_mask.png", ".png")
390 |
391 | path_bg_bg = path_bg.replace("masks/", "bg/")
392 |
393 | mask_bg = Image.open(path_bg)
394 |
395 | image_bg = Image.open(path_bg_image)
396 |
397 |
398 |
399 | image_bg_augment = image_bg
400 |
401 | if self.augment:
402 | if "before" in path_bg_image:
403 | path_bg_image_augment = path_bg_image.replace("before", "after")
404 | elif "after" in path_bg_image:
405 | path_bg_image_augment = path_bg_image.replace("after", "before")
406 |
407 |
408 | # image_bg_augment = Image.open(path_bg.replace("masks/",'composite/'))
409 | image_bg_augment = Image.open(path_bg_image_augment)
410 |
411 |
412 | image_bg_bg = Image.open(path_bg_bg)
413 | if self.lowres:
414 | mask_fg = mask_fg.resize((256,256))
415 | image_fg = image_fg.resize((256,256))
416 | image_fg_bg = image_fg_bg.resize((256,256))
417 | mask_bg = mask_bg.resize((256,256))
418 | image_bg = image_bg.resize((256,256))
419 | image_bg_augment = image_bg_augment.resize((256,256))
420 | image_bg_bg = image_bg_bg.resize((256,256))
421 |
422 |
423 |
424 |
425 | mask_bg_bbox = mask_bg.getbbox()
426 | mask_fg_bbox = mask_fg.getbbox()
427 |
428 |
429 |
430 | ## Target
431 | x_1_1, y_1_1, x_1_2, y_1_2 = mask_bg_bbox
432 | center_1_x = (x_1_1 + x_1_2) / 2
433 | center_1_y = (y_1_1 + y_1_2) / 2
434 |
435 | ##
436 | x_2_1, y_2_1, x_2_2, y_2_2 = mask_fg_bbox
437 | ration_x = (x_1_2 - x_1_1) / (x_2_2 - x_2_1) if x_2_2 != x_2_1 else 1
438 | ration_y = (y_1_2 - y_1_1) / (y_2_2 - y_2_1) if y_2_2 != y_2_1 else 1
439 |
440 | ## Scaling
441 |
442 | if not self.ratio_constrain:
443 |
444 | mask_fg_aff = F.affine(
445 | mask_fg, angle=0, translate=[0, 0], scale=min(ration_y, ration_x), shear=0
446 | )
447 | image_fg_aff = F.affine(
448 | image_fg, angle=0, translate=[0, 0], scale=min(ration_y, ration_x), shear=0
449 | )
450 | else:
451 |
452 | length_box = max(y_1_2-y_1_1,x_1_2-x_1_1)
453 |
454 | if length_box < 100:
455 | ration_x = (100) / (x_2_2 - x_2_1) if x_2_2 != x_2_1 else 1
456 | ration_y = (100) / (y_2_2 - y_2_1) if y_2_2 != y_2_1 else 1
457 |
458 | mask_fg_aff = F.affine(
459 | mask_fg, angle=0, translate=[0, 0], scale=min(ration_y, ration_x), shear=0)
460 | image_fg_aff = F.affine(
461 | image_fg, angle=0, translate=[0, 0], scale=min(ration_y, ration_x), shear=0
462 | )
463 |
464 |
465 |
466 | if mask_fg_aff.getbbox() == None:
467 | mask_fg_aff = F.affine(mask_fg, angle=0, translate=[0, 0], scale=1, shear=0)
468 |
469 | x_2_1_a, y_2_1_a, x_2_2_a, y_2_2_a = mask_fg_aff.getbbox()
470 | center_2_x_a = (x_2_1_a + x_2_2_a) / 2
471 | center_2_y_a = (y_2_1_a + y_2_2_a) / 2
472 |
473 | shift_fg_x = np.random.randint(-10, 10)
474 | shift_fg_y = np.random.randint(-10, 10)
475 |
476 | mask_fg_aff_all = F.affine(
477 | mask_fg_aff,
478 | angle=0,
479 | translate=[
480 | center_1_x - center_2_x_a + shift_fg_x,
481 | center_1_y - center_2_y_a + shift_fg_y,
482 | ],
483 | scale=1,
484 | shear=0,
485 | )
486 | image_fg_aff_all = F.affine(
487 | image_fg_aff,
488 | angle=0,
489 | translate=[
490 | center_1_x - center_2_x_a + shift_fg_x,
491 | center_1_y - center_2_y_a + shift_fg_y,
492 | ],
493 | scale=1,
494 | shear=0,
495 | )
496 |
497 | if self.colorjitter:
498 | if np.random.rand() < 1:
499 | # print("i love you one")
500 | image_fg_aff_all = self.transform_color(image_fg_aff_all)
501 |
502 | im_composite = Image.composite(image_fg_aff_all, image_bg_bg, mask_fg_aff_all)
503 |
504 | ## What we want to output? Background, im_composite, mask_fg_aff_all, real_image
505 |
506 | if self.returnraw:
507 | if self.colorjitter:
508 | if np.random.rand() < 1:
509 | # print("i love you two")
510 |
511 | image_bg_augment_f = self.transform_color(image_bg_augment)
512 | image_bg_augment = Image.composite(image_bg_augment_f, image_bg, mask_bg)
513 | else:
514 | image_bg_augment = Image.composite(image_bg_augment, image_bg, mask_bg)
515 |
516 |
517 | return (
518 | self.transforms(image_bg_bg),
519 | self.transforms(im_composite),
520 | self.transforms_mask(mask_fg_aff_all),
521 | self.transforms(image_bg),
522 | self.transforms_mask(mask_bg),
523 | self.transforms(image_bg_augment),
524 | path_fg,
525 | path_bg,
526 | )
527 |
528 | else:
529 |
530 | shift_bg_x = np.random.randint(-10, 10)
531 | shift_bg_y = np.random.randint(-10, 10)
532 |
533 | mask_bg_shift = F.affine(
534 | mask_bg,
535 | angle=0,
536 | translate=[
537 | shift_bg_x,
538 | shift_bg_y,
539 | ],
540 | scale=1,
541 | shear=0,
542 | )
543 |
544 | image_bg_shift = F.affine(
545 | image_bg,
546 | angle=0,
547 | translate=[
548 | shift_bg_x,
549 | shift_bg_y,
550 | ],
551 | scale=1,
552 | shear=0,
553 | )
554 |
555 | image_bg_augment_shift = F.affine(
556 | image_bg_augment,
557 | angle=0,
558 | translate=[
559 | shift_bg_x,
560 | shift_bg_y,
561 | ],
562 | scale=1,
563 | shear=0,
564 | )
565 |
566 | im_real = Image.composite(image_bg_shift, image_bg_bg, mask_bg_shift)
567 |
568 | if self.colorjitter:
569 | if np.random.rand() < 1:
570 | # print("i love you two")
571 |
572 | image_bg_augment_shift = self.transform_color(image_bg_augment_shift)
573 |
574 | im_real_augment = Image.composite(
575 | image_bg_augment_shift, image_bg_bg, mask_bg_shift
576 | )
577 |
578 | # Dataset output orders: 1. Background (inpainted) 2. Image Composite 3. Mask 4. Real Image
579 | return (
580 | self.transforms(image_bg_bg),
581 | self.transforms(im_composite),
582 | self.transforms_mask(mask_fg_aff_all),
583 | self.transforms(im_real),
584 | self.transforms_mask(mask_bg_shift),
585 | self.transforms(im_real_augment),
586 | path_fg,
587 | path_bg,
588 | )
589 |
590 |
591 | class PIHData_Composite(Dataset):
592 | def __init__(self, data_directory,lowres,original=False):
593 |         """
594 | 
595 |         Dataset of pre-computed composites for evaluation. For every
596 |         background `<name>_bg.jpg` in `data_directory`, the matching
597 |         composite, mask, and real (or ground-truth) images are loaded.
598 |         Unless `original` is True, all images are resized to 256x256
599 |         (`lowres=True`) or 512x512.
600 | 
601 |         Parameters
602 |         ----------
603 |         data_directory : str
604 |             The directory containing the evaluation image data.
605 |         lowres : bool
606 |             If True, resize images to 256x256; otherwise 512x512.
607 |         original : bool
608 |             If True, keep the original image resolution (no resizing).
609 |         """
610 |
611 | self.image_paths = glob(f"{data_directory}/*_bg.jpg")
612 | print(
613 | f"Using data from: {data_directory}\nFound {len(self.image_paths)} image paths."
614 | )
615 | self.transforms = T.Compose([T.ToTensor()])
616 | self.transforms_mask = T.Compose([T.Grayscale(), T.ToTensor()])
617 | self.lowres = lowres
618 | self.original = original
619 | if lowres:
620 | self.res = 256
621 | else:
622 | self.res = 512
623 |
624 | def __len__(self):
625 | return len(self.image_paths)
626 |
627 | def __getitem__(self, index):
628 | """Get image at the specified index.
629 |
630 | Parameters
631 | ----------
632 | index : int
633 | The image index.
634 |
635 | Returns
636 | -------
637 | patch: torch.Tensor
638 |
639 | """
640 |
641 | image_path = self.image_paths[index]
642 |
643 | if self.original:
644 | input_bg = Image.open(image_path)
645 |
646 | input_composite = Image.open(image_path.replace("bg", "composite"))
647 | input_mask = Image.open(image_path.replace("bg", "mask"))
648 | if os.path.exists(image_path.replace("bg", "real")):
649 | input_real = Image.open(image_path.replace("bg", "real"))
650 | else:
651 | input_real = Image.open(image_path.replace("bg", "gt"))
652 |
653 | else:
654 | input_bg = Image.open(image_path).resize((self.res, self.res))
655 |
656 | input_composite = Image.open(image_path.replace("bg", "composite")).resize(
657 | (self.res, self.res)
658 | )
659 | input_mask = Image.open(image_path.replace("bg", "mask")).resize((self.res, self.res))
660 | if os.path.exists(image_path.replace("bg", "real")):
661 | input_real = Image.open(image_path.replace("bg", "real")).resize((self.res, self.res))
662 | else:
663 | input_real = Image.open(image_path.replace("bg", "gt")).resize((self.res, self.res))
664 |
665 | # original_image = np.load(self.image_paths[index])[None].astype(np.complex64)
666 |
667 | return (
668 | self.transforms(input_bg),
669 | self.transforms(input_composite),
670 | self.transforms_mask(input_mask),
671 | self.transforms(input_real),
672 | image_path,
673 | )
674 |
675 |
676 | class DataCompositeGAN_iharmony(Dataset):
677 | def __init__(
678 | self, data_directory, ratio=1, augment=False, colorjitter=True, return_raw=False,lowres=False
679 | ):
680 | """
681 |
682 | Parameters
683 | ----------
684 | data_directory : str
685 | The directory containing the training image data.
686 | """
687 |
688 | self.image_paths = glob(f"{data_directory}/masks/*_mask.png")
689 |
690 | self.image_paths = self.image_paths[0 : int(len(self.image_paths) * ratio)]
691 |
692 | self.length = len(self.image_paths)
693 | print(
694 | f"Using data from: {data_directory}\nFound {len(self.image_paths)} image paths."
695 | )
696 | self.transforms = T.Compose([T.ToTensor()])
697 | self.transforms_mask = T.Compose([T.Grayscale(), T.ToTensor()])
698 | self.colorjitter = colorjitter
699 | self.lowres = lowres
700 | if self.colorjitter:
701 | self.transform_color = T.ColorJitter(
702 | brightness=0.3, contrast=0.1, saturation=0.0, hue=0.0
703 | )
704 | self.augment = augment
705 | self.return_raw = return_raw
706 |
707 | def __len__(self):
708 | return len(self.image_paths)
709 |
710 | def __getitem__(self, index):
711 | """Get image at the specified index.
712 |
713 | Parameters
714 | ----------
715 | index : int
716 | The image index.
717 |
718 | Returns
719 | -------
720 | patch: torch.Tensor
721 |
722 | Foreground
723 | """
724 |
725 |         path_bg = self.image_paths[index]  # Background (target) sample
726 |
727 | path_fg = self.image_paths[np.random.randint(0, self.length)]
728 |
729 | ### fore-ground image loading
730 |
731 | path_fg_image = path_fg.replace("masks/", "real_images/")
732 | path_fg_image = path_fg_image.replace("_mask.png", ".jpg")
733 |
734 | path_fg_bg = path_fg.replace("masks/", "bg/")
735 |
736 | mask_fg = Image.open(path_fg)
737 |
738 | image_fg = Image.open(path_fg_image)
739 |
740 | image_fg_bg = Image.open(path_fg_bg)
741 |
742 | ### back-ground image loading
743 |
744 | path_bg_image = path_bg.replace("masks/", "real_images/")
745 | path_bg_image = path_bg_image.replace("_mask.png", ".jpg")
746 |
747 | path_bg_bg = path_bg.replace("masks/", "bg/")
748 |
749 | mask_bg = Image.open(path_bg)
750 |
751 | image_bg = Image.open(path_bg_image)
752 |
753 | image_bg_augment = image_bg
754 |
755 | if self.augment:
756 | path_bg_image_augment = path_bg_image.replace(
757 | "real_images", "composite"
758 | ).replace(".jpg", "_composite.jpg")
759 | # if "before" in path_bg_image:
760 | # path_bg_image_augment = path_bg_image.replace("before", "after")
761 | # elif "after" in path_bg_image:
762 | # path_bg_image_augment = path_bg_image.replace("after", "before")
763 | image_bg_augment = Image.open(path_bg_image_augment)
764 |
765 | image_bg_bg = Image.open(path_bg_bg)
766 |
767 |
768 | if self.lowres:
769 | mask_fg = mask_fg.resize((256,256))
770 | image_fg = image_fg.resize((256,256))
771 | image_fg_bg = image_fg_bg.resize((256,256))
772 | mask_bg = mask_bg.resize((256,256))
773 | image_bg = image_bg.resize((256,256))
774 | image_bg_augment = image_bg_augment.resize((256,256))
775 | image_bg_bg = image_bg_bg.resize((256,256))
776 |
777 | mask_bg_bbox = mask_bg.getbbox()
778 | mask_fg_bbox = mask_fg.getbbox()
779 |
780 | ## Target
781 | x_1_1, y_1_1, x_1_2, y_1_2 = mask_bg_bbox
782 | center_1_x = (x_1_1 + x_1_2) / 2
783 | center_1_y = (y_1_1 + y_1_2) / 2
784 |
785 | ##
786 | x_2_1, y_2_1, x_2_2, y_2_2 = mask_fg_bbox
787 | ration_x = (x_1_2 - x_1_1) / (x_2_2 - x_2_1) if x_2_2 != x_2_1 else 1
788 | ration_y = (y_1_2 - y_1_1) / (y_2_2 - y_2_1) if y_2_2 != y_2_1 else 1
789 |
790 | ## Scaling
791 | mask_fg_aff = F.affine(
792 | mask_fg, angle=0, translate=[0, 0], scale=min(ration_y, ration_x), shear=0
793 | )
794 | image_fg_aff = F.affine(
795 | image_fg, angle=0, translate=[0, 0], scale=min(ration_y, ration_x), shear=0
796 | )
797 | if mask_fg_aff.getbbox() == None:
798 | mask_fg_aff = F.affine(mask_fg, angle=0, translate=[0, 0], scale=1, shear=0)
799 |
800 | x_2_1_a, y_2_1_a, x_2_2_a, y_2_2_a = mask_fg_aff.getbbox()
801 | center_2_x_a = (x_2_1_a + x_2_2_a) / 2
802 | center_2_y_a = (y_2_1_a + y_2_2_a) / 2
803 |
804 | shift_fg_x = np.random.randint(-10, 10)
805 | shift_fg_y = np.random.randint(-10, 10)
806 |
807 | mask_fg_aff_all = F.affine(
808 | mask_fg_aff,
809 | angle=0,
810 | translate=[
811 | center_1_x - center_2_x_a + shift_fg_x,
812 | center_1_y - center_2_y_a + shift_fg_y,
813 | ],
814 | scale=1,
815 | shear=0,
816 | )
817 | image_fg_aff_all = F.affine(
818 | image_fg_aff,
819 | angle=0,
820 | translate=[
821 | center_1_x - center_2_x_a + shift_fg_x,
822 | center_1_y - center_2_y_a + shift_fg_y,
823 | ],
824 | scale=1,
825 | shear=0,
826 | )
827 |
828 | if self.colorjitter:
829 | if np.random.rand() < 1:
830 | # print("i love you one")
831 | image_fg_aff_all = self.transform_color(image_fg_aff_all)
832 |
833 | im_composite = Image.composite(image_fg_aff_all, image_bg_bg, mask_fg_aff_all)
834 |
835 | ## What we want to output? Background, im_composite, mask_fg_aff_all, real_image
836 |
837 | if self.return_raw:
838 |
839 | if self.colorjitter:
840 | if np.random.rand() < 1:
841 | # print("i love you two")
842 |
843 | image_bg_augment_f = self.transform_color(image_bg_augment)
844 | image_bg_augment = Image.composite(image_bg_augment_f, image_bg_augment, mask_bg)
845 | # else:
846 | # image_bg_augment = Image.composite(image_bg_augment, image_bg, mask_bg)
847 |
848 |
849 |
850 | return (
851 | self.transforms(image_bg_bg),
852 | self.transforms(im_composite),
853 | self.transforms_mask(mask_fg_aff_all),
854 | self.transforms(image_bg),
855 | self.transforms_mask(mask_bg),
856 | self.transforms(image_bg_augment),
857 | path_fg,
858 | path_bg,
859 | )
860 | else:
861 |
862 | shift_bg_x = np.random.randint(-10, 10)
863 | shift_bg_y = np.random.randint(-10, 10)
864 |
865 | mask_bg_shift = F.affine(
866 | mask_bg,
867 | angle=0,
868 | translate=[
869 | shift_bg_x,
870 | shift_bg_y,
871 | ],
872 | scale=1,
873 | shear=0,
874 | )
875 |
876 | image_bg_shift = F.affine(
877 | image_bg,
878 | angle=0,
879 | translate=[
880 | shift_bg_x,
881 | shift_bg_y,
882 | ],
883 | scale=1,
884 | shear=0,
885 | )
886 |
887 | image_bg_augment_shift = F.affine(
888 | image_bg_augment,
889 | angle=0,
890 | translate=[
891 | shift_bg_x,
892 | shift_bg_y,
893 | ],
894 | scale=1,
895 | shear=0,
896 | )
897 |
898 | im_real = Image.composite(image_bg_shift, image_bg_bg, mask_bg_shift)
899 |
900 | if self.colorjitter:
901 | if np.random.rand() < 1:
902 | # print("i love you two")
903 |
904 | image_bg_augment_shift = self.transform_color(
905 | image_bg_augment_shift
906 | )
907 |
908 | im_real_augment = Image.composite(
909 | image_bg_augment_shift, image_bg_bg, mask_bg_shift
910 | )
911 |
912 | # Dataset output orders: 1. Background (inpainted) 2. Image Composite 3. Mask 4. Real Image
913 | return (
914 | self.transforms(image_bg_bg),
915 | self.transforms(im_composite),
916 | self.transforms_mask(mask_fg_aff_all),
917 | self.transforms(im_real),
918 | self.transforms_mask(mask_bg_shift),
919 | self.transforms(im_real_augment),
920 | path_fg,
921 | path_bg,
922 | )
923 |
--------------------------------------------------------------------------------
/demo_data/train/bg/00057ce6-67c5-411e-8087-3799b638518a_after_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/demo_data/train/bg/00057ce6-67c5-411e-8087-3799b638518a_after_mask.png
--------------------------------------------------------------------------------
/demo_data/train/bg/00057ce6-67c5-411e-8087-3799b638518a_before_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/demo_data/train/bg/00057ce6-67c5-411e-8087-3799b638518a_before_mask.png
--------------------------------------------------------------------------------
/demo_data/train/masks/00057ce6-67c5-411e-8087-3799b638518a_after_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/demo_data/train/masks/00057ce6-67c5-411e-8087-3799b638518a_after_mask.png
--------------------------------------------------------------------------------
/demo_data/train/masks/00057ce6-67c5-411e-8087-3799b638518a_before_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/demo_data/train/masks/00057ce6-67c5-411e-8087-3799b638518a_before_mask.png
--------------------------------------------------------------------------------
/demo_data/train/real_images/00057ce6-67c5-411e-8087-3799b638518a_after.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/demo_data/train/real_images/00057ce6-67c5-411e-8087-3799b638518a_after.png
--------------------------------------------------------------------------------
/demo_data/train/real_images/00057ce6-67c5-411e-8087-3799b638518a_before.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/demo_data/train/real_images/00057ce6-67c5-411e-8087-3799b638518a_before.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: pytorch_pih
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | dependencies:
6 | - python=3.8
7 | - cudatoolkit=11.3
8 | - pytorch=1.11.0
9 | - torchvision=0.12.0
10 | - cudnn
11 | - numpy
12 | - pandas
13 | - jupyter
14 | - pip
15 | - tqdm
16 | - ipython
17 | - pillow
18 | - pip:
19 | - matplotlib
20 | - opencv-python
21 |
22 |
23 |
--------------------------------------------------------------------------------
/github_images/Figure5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/github_images/Figure5.png
--------------------------------------------------------------------------------
/github_images/Figure_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/github_images/Figure_3.png
--------------------------------------------------------------------------------
/github_images/Figure_5_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/github_images/Figure_5_final.png
--------------------------------------------------------------------------------
/github_images/Figure_8_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/github_images/Figure_8_final.png
--------------------------------------------------------------------------------
/github_images/Figure_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/github_images/Figure_teaser.png
--------------------------------------------------------------------------------
/github_images/S1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/github_images/S1.png
--------------------------------------------------------------------------------
/github_images/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/github_images/demo.gif
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe. All rights reserved.
2 | # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License. You may obtain a copy
4 | # of the License at http://www.apache.org/licenses/LICENSE-2.0
5 |
6 | # Unless required by applicable law or agreed to in writing, software distributed under
7 | # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8 | # OF ANY KIND, either express or implied. See the License for the specific language
9 | # governing permissions and limitations under the License.
10 |
11 | import matplotlib.pyplot as plt
12 | import torch
13 | import numpy as np
14 | import torchvision.transforms as T
15 | import torchvision.transforms.functional as F
16 | import PIL
17 | from PIL import Image
18 | import cv2
19 | from model import Model_Composite_PL, Model_Composite
20 | from optparse import OptionParser
21 | import os
22 | import time
23 |
24 |
25 | transforms_mask = T.Compose([T.Grayscale(), T.ToTensor()])
26 | transform = T.Compose([T.ToTensor()])
27 | resize = T.Resize((512, 512))
28 |
29 |
30 | def get_concat_h(im1, im2):
31 | dst = Image.new("RGB", (im1.width + im2.width, im1.height))
32 | dst.paste(im1, (0, 0))
33 | dst.paste(im2, (im1.width, 0))
34 | return dst
35 |
36 |
37 |
38 |
39 | def get_args():
40 | parser = OptionParser()
41 | parser.add_option("--bg", help="Directory to the background image.")
42 | parser.add_option("--fg", help="Directory to the foreground image.")
43 |
44 |     parser.add_option("--checkpoints", "--ld", help="Path to the checkpoint, default is pretrained/ckpt_g39.pth")
45 |
46 | parser.add_option(
47 | "--gpu",
48 | action="store_true",
49 | help="If specified, will use GPU",
50 | )
51 |
52 | parser.add_option(
53 | "--light",
54 | action="store_true",
55 | help="If specified, will use light model",
56 | )
57 |
58 | (options, args) = parser.parse_args()
59 | return options
60 |
61 |
62 |
63 |
64 |
65 | class Evaluater:
66 | def __init__(self):
67 |
68 | self.args = get_args()
69 |
70 | self.name_cat = self.args.bg.split('/')[-1].split('.')[0]+'_'+self.args.fg.split('/')[-1].split('.')[0]
71 |
72 |
73 | self.fg = Image.open(self.args.fg)
74 |
75 | self.mask = self.fg.split()[-1]
76 |
77 | self.background = Image.open(self.args.bg).resize(self.fg.size)
78 |
79 | self.img_composite = Image.composite(self.fg, self.background, self.mask)
80 |
81 |
82 |
83 |
84 |
85 | if self.args.gpu:
86 | device = "cuda"
87 | else:
88 | device = "cpu"
89 |
90 | self.Model = Model_Composite_PL(
91 | dim=32,
92 | masking=True,
93 | brush=True,
94 | maskoffset=0.6,
95 | swap=True,
96 | Vit_bool=False,
97 | onlyupsample=True,
98 | aggupsample=True,
99 | light=self.args.light,
100 | Eff_bool=self.args.light,
101 | ).to(device)
102 |
103 | if self.args.checkpoints is not None:
104 | model_path = self.args.checkpoints
105 | else:
106 | model_path = os.getcwd() + '/pretrained/ckpt_g39.pth'
107 |
108 | checkpoint = torch.load(model_path, map_location=device)
109 | self.Model.load_state_dict(checkpoint["state_dict"])
110 |
111 | self.Model.eval()
112 |         self.bg_low = resize(self.background)
113 |         self.composite_low = resize(self.img_composite)
114 |         self.mask_low = resize(self.mask)
115 |
116 | # Load image
117 |
118 |
119 | self.torch_bg = transform(self.background).to(device)
120 | self.torch_composite = transform(self.img_composite).to(device)
121 | self.torch_mask = transforms_mask(self.mask).to(device)
122 |
123 | self.torch_bg_low = transform(self.bg_low).to(device)
124 | self.torch_composite_low = transform(self.composite_low).to(device)
125 | self.torch_mask_low = transforms_mask(self.mask_low).to(device)
126 |
127 | def evaluate(self):
128 |
129 | with torch.no_grad():
130 | inter_composite, output_composite, par1, par2 = self.Model(
131 | self.torch_bg_low[None, ...],
132 | self.torch_composite_low[None, ...],
133 | self.torch_mask_low[None, ...],
134 | )
135 |
136 |
137 | hr_intermediate = (
138 | self.Model.PL3D(self.Model.pl_table, self.torch_composite[None,...]) * self.torch_mask
139 | + (1 - self.torch_mask) * self.torch_bg
140 | )
141 |
142 | Gainmap_Resize = T.Resize(self.torch_bg.shape[-2:])
143 | # print(Gain_map)
144 |
145 | output_results = (
146 | hr_intermediate * Gainmap_Resize(self.Model.gainmap) * self.torch_mask
147 | + (1 - self.torch_mask) * self.torch_bg
148 | )
149 |
150 | output_lr = T.ToPILImage()(output_results[0,...])
151 | output_lr.save('results/%s_final.png'%(self.name_cat))
152 |
153 | output_gm = T.ToPILImage()( (Gainmap_Resize(self.Model.gainmap) * self.torch_mask)[0,...])
154 |
155 | output_gm.save('results/%s_gainmap.png'%(self.name_cat))
156 |
157 |
158 |
159 |
160 |
161 | #### Save Fig
162 |
163 | curves = par2.cpu().detach().numpy()
164 |
165 | red_curve = curves[0, 0, 0, 0, :]
166 | green_curve = curves[0, 1, 0, :, 0]
167 | blue_curve = curves[0, 2, :, 0, 0]
168 |
169 | plt.figure()
170 | plt.plot(np.linspace(0, 1, 32), red_curve, "r")
171 | plt.plot(np.linspace(0, 1, 32), green_curve, "g")
172 | plt.plot(np.linspace(0, 1, 32), blue_curve, "b")
173 | plt.ylim(0, 1)
174 | plt.legend(["Reg", "Green", "Blue"])
175 | plt.title("Learned Color Curves")
176 |
177 | plt.savefig("results/%s_color.jpg"%(self.name_cat))
178 |
179 |
180 |
181 | plt.close()
182 |
183 | im_final = get_concat_h( self.img_composite,get_concat_h(self.mask,output_lr))
184 |
185 | im_final.save('results/%s_results_summary.png'%(self.name_cat))
186 |
187 |
188 | if __name__ == "__main__":
189 | torch.backends.cudnn.benchmark = False
190 | torch.backends.cudnn.deterministic = True
191 | evaluater = Evaluater()
192 | evaluater.evaluate()
--------------------------------------------------------------------------------
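Usage note: inference.py is driven entirely by the options parsed above. A minimal sketch, assuming a checkpoint has been downloaded to pretrained/ckpt_g39.pth and using placeholder image paths (the foreground must be an RGBA image whose alpha channel serves as the mask; outputs go to results/):

    # GPU run with the default checkpoint location
    python inference.py --bg /path/to/background.jpg --fg /path/to/foreground.png --gpu

    # CPU run with an explicit checkpoint and the lighter model variant (--light, which also enables Eff_bool)
    python inference.py --bg /path/to/background.jpg --fg /path/to/foreground.png \
        --checkpoints pretrained/ckpt_g39.pth --light

The script writes <bg>_<fg>_final.png, <bg>_<fg>_gainmap.png, <bg>_<fg>_color.jpg, and <bg>_<fg>_results_summary.png into results/.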
/inference_scripts/Inference.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo Checkpoint: $1
4 | echo Model_Name: $2
5 | echo GPU: $3
6 |
7 | CUDA_VISIBLE_DEVICES=$3
8 |
9 | mkdir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/
10 |
11 | CUDA_VISIBLE_DEVICES=$3 python PIH_test.py --datadir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing/ \
12 | -g 0 \
13 | --checkpoints $1 \
14 | --tmp_results /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
15 | --bs 1 \
16 |
17 | CUDA_VISIBLE_DEVICES=$3 python PIH_test.py --datadir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_different_large/ \
18 | -g 0 \
19 | --checkpoints $1 \
20 | --tmp_results /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing_different/ \
21 | --bs 1 \
22 |
23 |
24 | CUDA_VISIBLE_DEVICES=$3 python PIH_test.py --datadir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_real_large/ \
25 | -g 0 \
26 | --checkpoints $1 \
27 | --tmp_results /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing_real/ \
28 | --ngt \
29 | --bs 1 \
30 |
31 | CUDA_VISIBLE_DEVICES=$3 python PIH_test.py --datadir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_real_aug/ \
32 | -g 0 \
33 | --checkpoints $1 \
34 | --tmp_results /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing_real_aug/ \
35 | --ngt \
36 | --bs 1 \
37 |
38 | mkdir /home/kewang/website_michael/results/$2/
39 |
40 | cd /home/kewang/website_michael/ttools2-main/ttools2/scripts/
41 |
42 | python /home/kewang/website_michael/ttools2-main/ttools2/scripts/image_gallery.py /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
43 | /home/kewang/website_michael/results/$2/results_testing/ \
44 | --order original intermediate results gt \
45 | --name $2-results_testing
46 |
47 | python /home/kewang/website_michael/ttools2-main/ttools2/scripts/image_gallery.py /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing_different/ \
48 | /home/kewang/website_michael/results/$2/results_testing_different/ \
49 | --order original intermediate results gt \
50 | --name $2-results_testing_different
51 |
52 | python /home/kewang/website_michael/ttools2-main/ttools2/scripts/image_gallery.py /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing_real/ \
53 | /home/kewang/website_michael/results/$2/results_testing_real/ \
54 | --order original intermediate results\
55 | --name $2-real
56 |
57 | python /home/kewang/website_michael/ttools2-main/ttools2/scripts/image_gallery.py /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing_real_aug/ \
58 | /home/kewang/website_michael/results/$2/results_testing_real_aug/ \
59 | --order original intermediate results\
60 | --name $2-real-aug
61 |
62 | # # Network hyperparameters
63 | # device=1
64 | # lr=1e-5
65 | # batch_size=16
66 | # date=202206070
67 | # name=iharmonysimplemodel
68 |
69 | # model_name=exp_${date}_batch_size_$((batch_size))_lr_${lr}_${name}
70 |
71 | # # Set folder names
72 | # dir_data=/mnt/localssd/Image_Harmonization_Dataset/
73 | # dir_log=/home/kewang/sensei-fs-symlink/users/kewang/projects/PIH/PIH_ResNet/results/$model_name
74 |
75 |
76 |
77 | # CUDA_VISIBLE_DEVICES=$device python PIH_train.py --datadir $dir_data \
78 | # -g 0 \
79 | # --logdir $dir_log \
80 | # --bs $batch_size \
81 | # --lr $lr \
82 | # --force_train_from_scratch \
83 | # --ihd \
84 |
--------------------------------------------------------------------------------
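Usage note: Inference.sh takes three positional arguments (checkpoint path, model/run name used for the output folders, GPU id) and shells out to PIH_test.py and an image_gallery.py helper under hard-coded /home/kewang/... paths, so it is only expected to run unchanged in the original authors' environment. A sketch with hypothetical values:

    bash inference_scripts/Inference.sh /path/to/ckpt.pth my_experiment 0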
/inference_scripts/Inference_Composite.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo Checkpoint: $1
4 | echo Model_Name: $2
5 | echo GPU: $3
6 |
7 | CUDA_VISIBLE_DEVICES=$3
8 |
9 | mkdir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/
10 |
11 | CUDA_VISIBLE_DEVICES=$3 python PIH_test_compositeGAN.py --datadir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_LR_self_2000/ \
12 | -g 0 \
13 | --checkpoints $1 \
14 | --tmp_results /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
15 | --bs 1 \
16 | --composite \
17 | --lut \
18 | --lut-dim 16 \
19 | --num-testing 2000 \
20 | --nocurve \
21 | --piecewiselinear \
22 |
23 | mkdir /home/kewang/website_michael/results/$2/
24 |
25 | cd /home/kewang/website_michael/ttools2-main/ttools2/scripts/
26 |
27 | python /home/kewang/website_michael/ttools2-main/ttools2/scripts/image_gallery.py /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
28 | /home/kewang/website_michael/results/$2/results_testing/ \
29 | --order bg mask original results real curves\
30 | --name $2-results_testing
31 |
32 |
33 |
--------------------------------------------------------------------------------
/inference_scripts/Inference_Composite_masking.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo Checkpoint: $1
4 | echo Model_Name: $2
5 | echo GPU: $3
6 | echo mask offset: $4
7 |
8 | echo $5
9 |
10 | if [ $5 == realhm ]
11 | then
12 | a=File_for_testing_composite_realhm
13 | elif [ $5 == realhr ]
14 | then
15 | a=File_for_testing_composite_realhr
16 | elif [ $5 == iharm ]
17 | then
18 | a=File_for_testing_composite_iharm
19 | elif [ $5 == LR ]
20 | then
21 | a=File_for_testing_composite_LR_2000
22 |
23 | elif [ $5 == LRself ]
24 | then
25 | a=File_for_testing_composite_LR_self_2000
26 |
27 |
28 | elif [ $5 == self ]
29 | then
30 | a=File_for_testing_self
31 |
32 | elif [ $5 == realself ]
33 | then
34 | a=File_for_testing_composite_realself
35 |
36 | elif [ $5 == adobe ]
37 | then
38 | a=File_for_testing_composite_adobereal
39 |
40 | else
41 | a=0
42 | fi
43 |
44 | echo Data stream: $a
45 |
46 | CUDA_VISIBLE_DEVICES=$3
47 |
48 | mkdir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/
49 | # /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_realhm/
50 | # /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_LR_self_2000/
51 |
52 | CUDA_VISIBLE_DEVICES=$3 python PIH_test_compositeGAN_masking.py --datadir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/$a \
53 | -g 0 \
54 | --checkpoints $1 \
55 | --tmp_results /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
56 | --bs 1 \
57 | --composite \
58 | --num-testing 50000 \
59 | --nocurve \
60 | --piecewiselinear \
61 | --masking \
62 | --brush \
63 | --maskoffset $4 \
64 | --swap \
65 | --onlyupsample \
66 | --twoinputs \
67 | --aggupsample \
68 | --dim 64 \
69 | # --effbool \
70 |
71 | # --lowres \
72 | # --effbool \
73 | # --dim 64 \
74 | # --lut \
75 | # --lut-dim 16 \
76 |
77 |
78 | # --vitbool \
79 |
80 | # --pihnetbool \
81 | # --pihnetbool \
82 | # --lut \
83 |
84 | # --lut \
85 | # --onlyupsample \
86 | # --maskconvkernel 3 \
87 |
88 | # # --onlyupsample \
89 | # # --maskconvkernel 3 \
90 | # # --maskoffset 0 \
91 |
92 | # # --onlyupsample \
93 | # # --maskconvkernel 3 \
94 | # # --maskoffset 0 \
95 | # # --mask
96 |
97 | # # --onlyupsample \
98 | # # --nosig \
99 |
100 |
101 | mkdir /home/kewang/website_michael/results/$2/
102 |
103 | cd /home/kewang/website_michael/ttools2-main/ttools2/scripts/
104 |
105 | python /home/kewang/website_michael/ttools2-main/ttools2/scripts/image_gallery.py /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
106 | /home/kewang/website_michael/results/$2/results_testing/ \
107 | --order bg mask original intermediate results real curves\
108 | --name $2-results_testing
109 |
110 |
111 |
--------------------------------------------------------------------------------
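Usage note: this masking variant adds two more positional arguments: $4 is forwarded to --maskoffset and $5 is a dataset keyword resolved by the if/elif chain above (realhm, realhr, iharm, LR, LRself, self, realself, or adobe). A sketch with hypothetical values:

    bash inference_scripts/Inference_Composite_masking.sh /path/to/ckpt.pth my_experiment 0 0.6 realhm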
/inference_scripts/Inference_Composite_masking_3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo Checkpoint: $1
4 | echo Model_Name: $2
5 | echo GPU: $3
6 | echo mask offset: $4
7 |
8 | echo $5
9 |
10 | if [ $5 == realhm ]
11 | then
12 | a=File_for_testing_composite_realhm
13 | elif [ $5 == realhr ]
14 | then
15 | a=File_for_testing_composite_realhr
16 | elif [ $5 == iharm ]
17 | then
18 | a=File_for_testing_composite_iharm
19 | elif [ $5 == LR ]
20 | then
21 | a=File_for_testing_composite_LR_2000
22 |
23 | elif [ $5 == LRself ]
24 | then
25 | a=File_for_testing_composite_LR_self_2000
26 |
27 |
28 | elif [ $5 == self ]
29 | then
30 | a=File_for_testing_self
31 |
32 |
33 | elif [ $5 == adobe ]
34 | then
35 | a=File_for_testing_composite_adobereal
36 |
37 | elif [ $5 == realself ]
38 | then
39 | a=File_for_testing_composite_realself
40 |
41 |
42 | else
43 | a=0
44 | fi
45 |
46 | echo Data stream: $a
47 |
48 | CUDA_VISIBLE_DEVICES=$3
49 |
50 | mkdir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/
51 | # /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_realhm/
52 | # /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_LR_self_2000/
53 |
54 | CUDA_VISIBLE_DEVICES=$3 python PIH_test_compositeGAN_masking.py --datadir /mnt/localssd/$a \
55 | -g 0 \
56 | --checkpoints $1 \
57 | --tmp_results /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
58 | --bs 1 \
59 | --composite \
60 | --num-testing 500 \
61 | --nocurve \
62 | --piecewiselinear \
63 | --masking \
64 | --brush \
65 | --maskoffset $4 \
66 | --swap \
67 | --onlyupsample \
68 | --aggupsample \
69 | --dim 32 \
70 | # --lut \
71 | # --lut-dim 16 \
72 | # --lowres \
73 | # --vitbool \
74 | # --ibn \
75 | # --effbool \
76 | # --ibn \
77 | # --vitbool \
78 | # --ibn \
79 | # --bgshadow \
80 | # --twoinputs \
81 |
82 |
83 | # --vitbool \
84 |
85 | # --pihnetbool \
86 | # --pihnetbool \
87 | # --lut \
88 |
89 | # --lut \
90 | # --onlyupsample \
91 | # --maskconvkernel 3 \
92 |
93 | # # --onlyupsample \
94 | # # --maskconvkernel 3 \
95 | # # --maskoffset 0 \
96 |
97 | # # --onlyupsample \
98 | # # --maskconvkernel 3 \
99 | # # --maskoffset 0 \
100 | # # --mask
101 |
102 | # # --onlyupsample \
103 | # # --nosig \
104 |
105 |
106 | mkdir /home/kewang/website_michael/results/$2/
107 |
108 | cd /home/kewang/website_michael/ttools2-main/ttools2/scripts/
109 |
110 | python /home/kewang/website_michael/ttools2-main/ttools2/scripts/image_gallery.py /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
111 | /home/kewang/website_michael/results/$2/results_testing/ \
112 | --order bg mask original intermediate results real curves\
113 | --name $2-results_testing
114 |
115 |
116 |
--------------------------------------------------------------------------------
/inference_scripts/Inference_Composite_masking_3_depth.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo Checkpoint: $1
4 | echo Model_Name: $2
5 | echo GPU: $3
6 | echo mask offset: $4
7 |
8 | echo $5
9 |
10 | echo dim $6
11 | if [ $5 == realhm ]
12 | then
13 | a=File_for_testing_composite_realhm
14 | elif [ $5 == realhr ]
15 | then
16 | a=File_for_testing_composite_realhr
17 | elif [ $5 == iharm ]
18 | then
19 | a=File_for_testing_composite_iharm
20 | elif [ $5 == LR ]
21 | then
22 | a=File_for_testing_composite_LR_2000
23 |
24 | elif [ $5 == LRself ]
25 | then
26 | a=File_for_testing_composite_LR_self_2000
27 |
28 | elif [ $5 == adobe ]
29 | then
30 | a=File_for_testing_composite_adobereal
31 |
32 | elif [ $5 == realself ]
33 | then
34 | a=File_for_testing_composite_realself
35 |
36 |
37 |
38 |
39 | elif [ $5 == self ]
40 | then
41 | a=File_for_testing_self
42 | else
43 | a=0
44 | fi
45 |
46 | echo Data stream: $a
47 |
48 | CUDA_VISIBLE_DEVICES=$3
49 |
50 | mkdir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/
51 | # /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_realhm/
52 | # /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_LR_self_2000/
53 |
54 | CUDA_VISIBLE_DEVICES=$3 python PIH_test_compositeGAN_masking.py --datadir /mnt/localssd/$a \
55 | -g 0 \
56 | --checkpoints $1 \
57 | --tmp_results /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
58 | --bs 1 \
59 | --composite \
60 | --num-testing 500 \
61 | --nocurve \
62 | --piecewiselinear \
63 | --masking \
64 | --brush \
65 | --maskoffset $4 \
66 | --swap \
67 | --onlyupsample \
68 | --aggupsample \
69 | --depthmap \
70 | --dim $6 \
71 | --bgshadow \
72 | --dual \
73 | # --ibn \
74 | # --effbool \
75 | # --twoinputs \
76 |
77 | # --ibn \
78 | # --depth \
79 | # --twoinputs \
80 |
81 |
82 | # --vitbool \
83 |
84 | # --pihnetbool \
85 | # --pihnetbool \
86 | # --lut \
87 |
88 | # --lut \
89 | # --onlyupsample \
90 | # --maskconvkernel 3 \
91 |
92 | # # --onlyupsample \
93 | # # --maskconvkernel 3 \
94 | # # --maskoffset 0 \
95 |
96 | # # --onlyupsample \
97 | # # --maskconvkernel 3 \
98 | # # --maskoffset 0 \
99 | # # --mask
100 |
101 | # # --onlyupsample \
102 | # # --nosig \
103 |
104 |
105 | mkdir /home/kewang/website_michael/results/$2/
106 |
107 | cd /home/kewang/website_michael/ttools2-main/ttools2/scripts/
108 |
109 | python /home/kewang/website_michael/ttools2-main/ttools2/scripts/image_gallery.py /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
110 | /home/kewang/website_michael/results/$2/results_testing/ \
111 | --order bg mask original intermediate results real curves\
112 | --name $2-results_testing
113 |
114 |
115 |
--------------------------------------------------------------------------------
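Usage note: the depth-aware script additionally takes $6, which is forwarded to --dim; the first five arguments follow the same convention as the other masking scripts. A sketch with hypothetical values:

    bash inference_scripts/Inference_Composite_masking_3_depth.sh /path/to/ckpt.pth my_experiment 0 0.6 realhm 32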
/inference_scripts/Inference_Composite_masking_3_highres.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo Checkpoint: $1
4 | echo Model_Name: $2
5 | echo GPU: $3
6 | echo mask offset: $4
7 |
8 | echo $5
9 |
10 | if [ $5 == realhm ]
11 | then
12 | a=File_for_testing_composite_realhm
13 |
14 |
15 | elif [ $5 == realhmfull ]
16 | then
17 | a=File_for_testing_composite_realhmfullres
18 |
19 | elif [ $5 == realhr ]
20 | then
21 | a=File_for_testing_composite_realhr
22 | elif [ $5 == iharm ]
23 | then
24 | a=File_for_testing_composite_iharm
25 | elif [ $5 == LR ]
26 | then
27 | a=File_for_testing_composite_LR_2000
28 |
29 | elif [ $5 == LRself ]
30 | then
31 | a=File_for_testing_composite_LR_self_1000_HR_final
32 |
33 |
34 |
35 | elif [ $5 == self ]
36 | then
37 | a=File_for_testing_self
38 |
39 |
40 | elif [ $5 == adobe ]
41 | then
42 | a=File_for_testing_composite_adobereal
43 |
44 | elif [ $5 == realself ]
45 | then
46 | a=File_for_testing_composite_realself
47 |
48 |
49 | else
50 | a=0
51 | fi
52 |
53 | echo Data stream: $a
54 |
55 | CUDA_VISIBLE_DEVICES=$3
56 |
57 | mkdir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/
58 | # /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_realhm/
59 | # /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_LR_self_2000/
60 |
61 | CUDA_VISIBLE_DEVICES=$3 python PIH_test_compositeGAN_masking_highres.py --datadir /mnt/localssd/$a \
62 | -g 0 \
63 | --checkpoints $1 \
64 | --tmp_results /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
65 | --bs 1 \
66 | --composite \
67 | --num-testing 500000 \
68 | --nocurve \
69 | --piecewiselinear \
70 | --masking \
71 | --brush \
72 | --maskoffset $4 \
73 | --swap \
74 | --onlyupsample \
75 | --aggupsample \
76 | --dim 32 \
77 | --original \
78 | --lrdata \
79 |
80 | # --lut \
81 | # --lut-dim 16 \
82 | # --lowres \
83 | # --vitbool \
84 | # --ibn \
85 | # --effbool \
86 | # --ibn \
87 | # --vitbool \
88 | # --ibn \
89 | # --bgshadow \
90 | # --twoinputs \
91 |
92 |
93 | # --vitbool \
94 |
95 | # --pihnetbool \
96 | # --pihnetbool \
97 | # --lut \
98 |
99 | # --lut \
100 | # --onlyupsample \
101 | # --maskconvkernel 3 \
102 |
103 | # # --onlyupsample \
104 | # # --maskconvkernel 3 \
105 | # # --maskoffset 0 \
106 |
107 | # # --onlyupsample \
108 | # # --maskconvkernel 3 \
109 | # # --maskoffset 0 \
110 | # # --mask
111 |
112 | # # --onlyupsample \
113 | # # --nosig \
114 |
115 |
116 | # mkdir /home/kewang/website_michael/results/$2/
117 |
118 | # cd /home/kewang/website_michael/ttools2-main/ttools2/scripts/
119 |
120 | # python /home/kewang/website_michael/ttools2-main/ttools2/scripts/image_gallery.py /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
121 | # /home/kewang/website_michael/results/$2/results_testing/ \
122 | # --order bg mask original intermediate results real curves\
123 | # --name $2-results_testing
124 |
125 |
126 |
--------------------------------------------------------------------------------
/inference_scripts/Inference_Composite_masking_3_noweb.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo Checkpoint: $1
4 | echo Model_Name: $2
5 | echo GPU: $3
6 | echo mask offset: $4
7 |
8 | echo $5
9 |
10 | if [ $5 == realhm ]
11 | then
12 | a=File_for_testing_composite_realhm
13 | elif [ $5 == realhr ]
14 | then
15 | a=File_for_testing_composite_realhr
16 | elif [ $5 == iharm ]
17 | then
18 | a=File_for_testing_composite_iharm
19 | elif [ $5 == LR ]
20 | then
21 | a=File_for_testing_composite_LR_2000
22 |
23 | elif [ $5 == LRself ]
24 | then
25 | a=File_for_testing_composite_LR_self_2000
26 |
27 |
28 | elif [ $5 == self ]
29 | then
30 | a=File_for_testing_self
31 |
32 |
33 | elif [ $5 == adobe ]
34 | then
35 | a=File_for_testing_composite_adobereal
36 |
37 | elif [ $5 == realself ]
38 | then
39 | a=File_for_testing_composite_realself
40 |
41 |
42 | else
43 | a=0
44 | fi
45 |
46 | echo Data stream: $a
47 |
48 | CUDA_VISIBLE_DEVICES=$3
49 |
50 | mkdir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/
51 | # /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_realhm/
52 | # /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_LR_self_2000/
53 |
54 | CUDA_VISIBLE_DEVICES=$3 python PIH_test_compositeGAN_masking_tiny.py --datadir /mnt/localssd/$a \
55 | -g 0 \
56 | --checkpoints $1 \
57 | --tmp_results /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
58 | --bs 1 \
59 | --composite \
60 | --num-testing 500 \
61 | --nocurve \
62 | --piecewiselinear \
63 | --masking \
64 | --brush \
65 | --maskoffset $4 \
66 | --swap \
67 | --onlyupsample \
68 | --aggupsample \
69 | --dim 32 \
70 | # --lut \
71 | # --lut-dim 16 \
72 | # --lowres \
73 | # --vitbool \
74 | # --ibn \
75 | # --effbool \
76 | # --ibn \
77 | # --vitbool \
78 | # --ibn \
79 | # --bgshadow \
80 | # --twoinputs \
81 |
82 |
83 | # --vitbool \
84 |
85 | # --pihnetbool \
86 | # --pihnetbool \
87 | # --lut \
88 |
89 | # --lut \
90 | # --onlyupsample \
91 | # --maskconvkernel 3 \
92 |
93 | # # --onlyupsample \
94 | # # --maskconvkernel 3 \
95 | # # --maskoffset 0 \
96 |
97 | # # --onlyupsample \
98 | # # --maskconvkernel 3 \
99 | # # --maskoffset 0 \
100 | # # --mask
101 |
102 | # # --onlyupsample \
103 | # # --nosig \
104 |
105 |
106 |
107 |
108 |
--------------------------------------------------------------------------------
/inference_scripts/Inference_Composite_masking_highres.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo Checkpoint: $1
4 | echo Model_Name: $2
5 | echo GPU: $3
6 | echo mask offset: $4
7 |
8 | echo $5
9 |
10 | if [ $5 == realhm ]
11 | then
12 | a=File_for_testing_composite_realhm
13 | elif [ $5 == realhr ]
14 | then
15 | a=File_for_testing_composite_realhr
16 | elif [ $5 == iharm ]
17 | then
18 | a=File_for_testing_composite_iharm
19 | elif [ $5 == LR ]
20 | then
21 | a=File_for_testing_composite_LR_2000
22 |
23 | elif [ $5 == LRself ]
24 | then
25 | a=File_for_testing_composite_LR_self_2000
26 |
27 |
28 | elif [ $5 == self ]
29 | then
30 | a=File_for_testing_self
31 |
32 | elif [ $5 == realself ]
33 | then
34 | a=File_for_testing_composite_realself
35 |
36 | elif [ $5 == adobe ]
37 | then
38 | a=File_for_testing_composite_adobereal
39 |
40 | elif [ $5 == adobe5k ]
41 | then
42 | a=File_for_testing_HAdobe5k
43 |
44 |
45 |
46 | else
47 | a=0
48 | fi
49 |
50 | echo Data stream: $a
51 |
52 | CUDA_VISIBLE_DEVICES=$3
53 |
54 | mkdir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/
55 | # /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_realhm/
56 | # /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_LR_self_2000/
57 |
58 | CUDA_VISIBLE_DEVICES=$3 python PIH_test_compositeGAN_masking_highres.py --datadir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/$a \
59 | -g 0 \
60 | --checkpoints $1 \
61 | --tmp_results /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
62 | --bs 1 \
63 | --composite \
64 | --num-testing 50000 \
65 | --nocurve \
66 | --piecewiselinear \
67 | --masking \
68 | --brush \
69 | --maskoffset $4 \
70 | --swap \
71 | --onlyupsample \
72 | --twoinputs \
73 | --aggupsample \
74 | --dim 64 \
75 | --original \
76 | # --effbool \
77 |
78 | # --lowres \
79 | # --effbool \
80 | # --dim 64 \
81 | # --lut \
82 | # --lut-dim 16 \
83 |
84 |
85 | # --vitbool \
86 |
87 | # --pihnetbool \
88 | # --pihnetbool \
89 | # --lut \
90 |
91 | # --lut \
92 | # --onlyupsample \
93 | # --maskconvkernel 3 \
94 |
95 | # # --onlyupsample \
96 | # # --maskconvkernel 3 \
97 | # # --maskoffset 0 \
98 |
99 | # # --onlyupsample \
100 | # # --maskconvkernel 3 \
101 | # # --maskoffset 0 \
102 | # # --mask
103 |
104 | # # --onlyupsample \
105 | # # --nosig \
106 |
107 |
108 | # mkdir /home/kewang/website_michael/results/$2/
109 |
110 | # cd /home/kewang/website_michael/ttools2-main/ttools2/scripts/
111 |
112 | # python /home/kewang/website_michael/ttools2-main/ttools2/scripts/image_gallery.py /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
113 | # /home/kewang/website_michael/results/$2/results_testing/ \
114 | # --order bg mask original intermediate results real curves\
115 | # --name $2-results_testing
116 |
117 |
118 |
--------------------------------------------------------------------------------
/inference_scripts/Inference_Composite_masking_pixel.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo Checkpoint: $1
4 | echo Model_Name: $2
5 | echo GPU: $3
6 | echo mask offset: $4
7 |
8 | echo $5
9 |
10 | if [ $5 == realhm ]
11 | then
12 | a=File_for_testing_composite_realhm
13 | elif [ $5 == realhr ]
14 | then
15 | a=File_for_testing_composite_realhr
16 | elif [ $5 == iharm ]
17 | then
18 | a=File_for_testing_composite_iharm
19 | elif [ $5 == LR ]
20 | then
21 | a=File_for_testing_composite_LR_2000
22 |
23 | elif [ $5 == LRself ]
24 | then
25 | a=File_for_testing_composite_LR_self_2000
26 |
27 |
28 | elif [ $5 == self ]
29 | then
30 | a=File_for_testing_self
31 |
32 | elif [ $5 == realself ]
33 | then
34 | a=File_for_testing_composite_realself
35 |
36 | elif [ $5 == adobe ]
37 | then
38 | a=File_for_testing_composite_adobereal
39 |
40 | else
41 | a=0
42 | fi
43 |
44 | echo Data stream: $a
45 |
46 | CUDA_VISIBLE_DEVICES=$3
47 |
48 | mkdir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/
49 | # /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_realhm/
50 | # /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_LR_self_2000/
51 |
52 | CUDA_VISIBLE_DEVICES=$3 python PIH_test_compositeGAN_masking.py --datadir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/$a \
53 | -g 0 \
54 | --checkpoints $1 \
55 | --tmp_results /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
56 | --bs 1 \
57 | --composite \
58 | --num-testing 500 \
59 | --nocurve \
60 | --piecewiselinear \
61 | --masking \
62 | --maskoffset $4 \
63 | --swap \
64 | --twoinputs \
65 | --dim 64 \
66 | # --lut \
67 | # --lut-dim 24 \
68 | # --effbool \
69 |
70 | # --lowres \
71 | # --effbool \
72 | # --dim 64 \
73 | # --lut \
74 | # --lut-dim 16 \
75 |
76 |
77 | # --vitbool \
78 |
79 | # --pihnetbool \
80 | # --pihnetbool \
81 | # --lut \
82 |
83 | # --lut \
84 | # --onlyupsample \
85 | # --maskconvkernel 3 \
86 |
87 | # # --onlyupsample \
88 | # # --maskconvkernel 3 \
89 | # # --maskoffset 0 \
90 |
91 | # # --onlyupsample \
92 | # # --maskconvkernel 3 \
93 | # # --maskoffset 0 \
94 | # # --mask
95 |
96 | # # --onlyupsample \
97 | # # --nosig \
98 |
99 |
100 | mkdir /home/kewang/website_michael/results/$2/
101 |
102 | cd /home/kewang/website_michael/ttools2-main/ttools2/scripts/
103 |
104 | python /home/kewang/website_michael/ttools2-main/ttools2/scripts/image_gallery.py /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
105 | /home/kewang/website_michael/results/$2/results_testing/ \
106 | --order bg mask original intermediate results real curves\
107 | --name $2-results_testing
108 |
109 |
110 |
--------------------------------------------------------------------------------
/inference_scripts/Inference_Composite_unet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo Checkpoint: $1
4 | echo Model_Name: $2
5 | echo GPU: $3
6 |
7 | CUDA_VISIBLE_DEVICES=$3
8 |
9 | mkdir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/
10 |
11 | CUDA_VISIBLE_DEVICES=$3 python PIH_test_compositeGAN.py --datadir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/File_for_testing_composite_2000/ \
12 | -g 0 \
13 | --checkpoints $1 \
14 | --tmp_results /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
15 | --bs 1 \
16 | --composite \
17 | --unet \
18 | --num-testing 2000 \
19 |
20 | mkdir /home/kewang/website_michael/results/$2/
21 |
22 | cd /home/kewang/website_michael/ttools2-main/ttools2/scripts/
23 |
24 | python /home/kewang/website_michael/ttools2-main/ttools2/scripts/image_gallery.py /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing/ \
25 | /home/kewang/website_michael/results/$2/results_testing/ \
26 | --order bg mask original intermediate results real \
27 | --name $2-results_testing
28 |
29 | # # Network hyperparameters
30 | # device=1
31 | # lr=1e-5
32 | # batch_size=16
33 | # date=202206070
34 | # name=iharmonysimplemodel
35 |
36 | # model_name=exp_${date}_batch_size_$((batch_size))_lr_${lr}_${name}
37 |
38 | # # Set folder names
39 | # dir_data=/mnt/localssd/Image_Harmonization_Dataset/
40 | # dir_log=/home/kewang/sensei-fs-symlink/users/kewang/projects/PIH/PIH_ResNet/results/$model_name
41 |
42 |
43 |
44 | # CUDA_VISIBLE_DEVICES=$device python PIH_train.py --datadir $dir_data \
45 | # -g 0 \
46 | # --logdir $dir_log \
47 | # --bs $batch_size \
48 | # --lr $lr \
49 | # --force_train_from_scratch \
50 | # --ihd \
51 |
--------------------------------------------------------------------------------
/inference_scripts/Inference_iHarmony.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo Checkpoint: $1
4 | echo Model_Name: $2
5 | echo GPU: $3
6 |
7 | CUDA_VISIBLE_DEVICES=$3
8 |
9 | mkdir /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/
10 |
11 |
12 |
13 | CUDA_VISIBLE_DEVICES=$3 python PIH_test.py --datadir /mnt/localssd/Image_Harmonization_Dataset/ \
14 | -g 0 \
15 | --checkpoints $1 \
16 | --tmp_results /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing_iHarmony/ \
17 | --bs 1 \
18 | --num-testing 1000 \
19 | --ihd \
20 |
21 |
22 | mkdir /home/kewang/website_michael/results/$2/
23 |
24 | cd /home/kewang/website_michael/ttools2-main/ttools2/scripts/
25 |
26 | python /home/kewang/website_michael/ttools2-main/ttools2/scripts/image_gallery.py /home/kewang/sensei-fs-symlink/users/kewang/projects/data_processing/results_images/$2/results_testing_iHarmony/ \
27 | /home/kewang/website_michael/results/$2/results_testing_iHarmony/ \
28 | --order original mask intermediate results gt\
29 | --name $2-results-testing-iHarmony
30 |
31 | # # Network hyperparameters
32 | # device=1
33 | # lr=1e-5
34 | # batch_size=16
35 | # date=202206070
36 | # name=iharmonysimplemodel
37 |
38 | # model_name=exp_${date}_batch_size_$((batch_size))_lr_${lr}_${name}
39 |
40 | # # Set folder names
41 | # dir_data=/mnt/localssd/Image_Harmonization_Dataset/
42 | # dir_log=/home/kewang/sensei-fs-symlink/users/kewang/projects/PIH/PIH_ResNet/results/$model_name
43 |
44 |
45 |
46 | # CUDA_VISIBLE_DEVICES=$device python PIH_train.py --datadir $dir_data \
47 | # -g 0 \
48 | # --logdir $dir_log \
49 | # --bs $batch_size \
50 | # --lr $lr \
51 | # --force_train_from_scratch \
52 | # --ihd \
53 |
--------------------------------------------------------------------------------
/pretrained/pretrained.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/pretrained/pretrained.placeholder
--------------------------------------------------------------------------------
/results/results.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adobe/PIH/2823cccf0778c6ea213a3d366f03864ac8ab82e6/results/results.placeholder
--------------------------------------------------------------------------------
/scripts/installation.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe. All rights reserved.
2 | # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License. You may obtain a copy
4 | # of the License at http://www.apache.org/licenses/LICENSE-2.0
5 |
6 | # Unless required by applicable law or agreed to in writing, software distributed under
7 | # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8 | # OF ANY KIND, either express or implied. See the License for the specific language
9 | # governing permissions and limitations under the License.
10 | conda create -n pytorch_pih python=3.9
11 | conda activate pytorch_pih
12 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
13 | pip install matplotlib
14 | pip install opencv-python
15 | pip install tqdm
16 |
--------------------------------------------------------------------------------
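Setup note: these commands are normally run one at a time in an interactive shell (conda activate requires an initialized conda) rather than as a single bash scripts/installation.sh invocation. A quick sanity check afterwards, as a sketch:

    python -c "import torch, torchvision, cv2, matplotlib; print(torch.__version__, torch.cuda.is_available())"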
/scripts/train_example.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe. All rights reserved.
2 | # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License. You may obtain a copy
4 | # of the License at http://www.apache.org/licenses/LICENSE-2.0
5 |
6 | # Unless required by applicable law or agreed to in writing, software distributed under
7 | # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8 | # OF ANY KIND, either express or implied. See the License for the specific language
9 | # governing permissions and limitations under the License.
10 |
11 | # Network hyperparameters
12 | device=0
13 | lr=4e-5
14 | lrd=4e-5
15 | batch_size=8
16 | date=Demo_train_example
17 | reconweight=None
18 | training_ratio=1
19 | inputdimD=3
20 | recon_ratio=0.5
21 | recon_weight=0.92 ## Used here
22 |
23 | name=iharmony_${inputdimD}_ratio_${training_ratio}_${recon_ratio}_reconweight_${recon_weight}
24 |
25 | model_name=exp_${date}_batch_size_$((batch_size))_lr_${lr}_${name}_device_${device}
26 |
27 | # Set folder names
28 | dir_data=demo_data/train/
29 | # dir_data=/mnt/localssd/LR_data/train/
30 |
31 |
32 | CUDA_VISIBLE_DEVICES=$device python PIH_train.py --datadir $dir_data \
33 | -g 0 \
34 | --onlysaveg \
35 | --bs $batch_size \
36 | --lr $lr \
37 | --lrd $lrd \
38 | --force_train_from_scratch \
39 | --tempdir \
40 | $model_name \
41 | --workers 8 \
42 | --trainingratio ${training_ratio} \
43 | --unetd \
44 | --inputdimD ${inputdimD} \
45 | --nocurve \
46 | --reconratio ${recon_ratio} \
47 | --piecewiselinear \
48 | --pl-dim 32 \
49 | --pairaugment \
50 | --purepairaugment \
51 | --lowdim \
52 | --ganlossmask \
53 | --reconwithgan \
54 | --reconweight ${recon_weight} \
55 | --masking \
56 | --brush \
57 | --maskoffset 0.5 \
58 | --swap \
59 | --onlyupsample \
60 | --joint \
61 | --lessskip \
62 | --aggupsample \
63 | --scheduler \
64 | --colorjitter \
65 |
66 |
67 |
68 |
69 |
70 |
71 |
--------------------------------------------------------------------------------
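Usage note: the run name is assembled from the hyperparameters at the top of the script; with the values shown, model_name expands to exp_Demo_train_example_batch_size_8_lr_4e-5_iharmony_3_ratio_1_0.5_reconweight_0.92_device_0, which is passed to PIH_train.py via --tempdir. To launch the demo training run, execute the script from the repository root so that demo_data/train/ resolves:

    bash scripts/train_example.sh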
/utils/efficientnet_v2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe. All rights reserved.
2 | # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License. You may obtain a copy
4 | # of the License at http://www.apache.org/licenses/LICENSE-2.0
5 |
6 | # Unless required by applicable law or agreed to in writing, software distributed under
7 | # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8 | # OF ANY KIND, either express or implied. See the License for the specific language
9 | # governing permissions and limitations under the License.
10 | import collections.abc as container_abc
11 | from collections import OrderedDict
12 | from math import ceil, floor
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from torch.utils import model_zoo
18 |
19 |
20 | def _pair(x):
21 | if isinstance(x, container_abc.Iterable):
22 | return x
23 | return (x, x)
24 |
25 |
26 | def torch_conv_out_spatial_shape(in_spatial_shape, kernel_size, stride):
27 | if in_spatial_shape is None:
28 | return None
29 | # in_spatial_shape -> [H,W]
30 | hin, win = _pair(in_spatial_shape)
31 | kh, kw = _pair(kernel_size)
32 | sh, sw = _pair(stride)
33 |
34 | # dilation and padding are ignored since they are always fixed in efficientnetV2
35 | hout = int(floor((hin - kh - 1) / sh + 1))
36 | wout = int(floor((win - kw - 1) / sw + 1))
37 | return hout, wout
38 |
39 |
40 | def get_activation(act_fn: str, **kwargs):
41 | if act_fn in ("silu", "swish"):
42 | return nn.SiLU(**kwargs)
43 | elif act_fn == "relu":
44 | return nn.ReLU(**kwargs)
45 | elif act_fn == "relu6":
46 | return nn.ReLU6(**kwargs)
47 | elif act_fn == "elu":
48 | return nn.ELU(**kwargs)
49 | elif act_fn == "leaky_relu":
50 | return nn.LeakyReLU(**kwargs)
51 | elif act_fn == "selu":
52 | return nn.SELU(**kwargs)
53 | elif act_fn == "mish":
54 | return nn.Mish(**kwargs)
55 | else:
56 | raise ValueError("Unsupported act_fn {}".format(act_fn))
57 |
58 |
59 | def round_filters(filters, width_coefficient, depth_divisor=8):
60 | """Round number of filters based on depth multiplier."""
61 | min_depth = depth_divisor
62 | filters *= width_coefficient
63 | new_filters = max(
64 | min_depth, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor
65 | )
66 | return int(new_filters)
67 |
68 |
69 | def round_repeats(repeats, depth_coefficient):
70 | """Round number of filters based on depth multiplier."""
71 | return int(ceil(depth_coefficient * repeats))
72 |
73 |
74 | class DropConnect(nn.Module):
75 | def __init__(self, rate=0.5):
76 | super(DropConnect, self).__init__()
77 | self.keep_prob = None
78 | self.set_rate(rate)
79 |
80 | def set_rate(self, rate):
81 | if not 0 <= rate < 1:
82 | raise ValueError("rate must be 0<=rate<1, got {} instead".format(rate))
83 | self.keep_prob = 1 - rate
84 |
85 | def forward(self, x):
86 | if self.training:
87 | random_tensor = self.keep_prob + torch.rand(
88 | [x.size(0), 1, 1, 1], dtype=x.dtype, device=x.device
89 | )
90 | binary_tensor = torch.floor(random_tensor)
91 | return torch.mul(torch.div(x, self.keep_prob), binary_tensor)
92 | else:
93 | return x
94 |
95 |
96 | class SamePaddingConv2d(nn.Module):
97 | def __init__(
98 | self,
99 | in_spatial_shape,
100 | in_channels,
101 | out_channels,
102 | kernel_size,
103 | stride,
104 | dilation=1,
105 | enforce_in_spatial_shape=False,
106 | **kwargs
107 | ):
108 | super(SamePaddingConv2d, self).__init__()
109 |
110 | self._in_spatial_shape = _pair(in_spatial_shape)
111 | # e.g. throw exception if input spatial shape does not match in_spatial_shape
112 | # when calling self.forward()
113 | self.enforce_in_spatial_shape = enforce_in_spatial_shape
114 | kernel_size = _pair(kernel_size)
115 | stride = _pair(stride)
116 | dilation = _pair(dilation)
117 |
118 | in_height, in_width = self._in_spatial_shape
119 | filter_height, filter_width = kernel_size
120 |         stride_height, stride_width = stride
121 |         dilation_height, dilation_width = dilation
122 | 
123 |         out_height = int(ceil(float(in_height) / float(stride_height)))
124 |         out_width = int(ceil(float(in_width) / float(stride_width)))
125 | 
126 |         pad_along_height = max(
127 |             (out_height - 1) * stride_height
128 | + filter_height
129 | + (filter_height - 1) * (dilation_height - 1)
130 | - in_height,
131 | 0,
132 | )
133 | pad_along_width = max(
134 | (out_width - 1) * stride_width
135 | + filter_width
136 | + (filter_width - 1) * (dilation_width - 1)
137 | - in_width,
138 | 0,
139 | )
140 |
141 | pad_top = pad_along_height // 2
142 | pad_bottom = pad_along_height - pad_top
143 | pad_left = pad_along_width // 2
144 | pad_right = pad_along_width - pad_left
145 |
146 | paddings = (pad_left, pad_right, pad_top, pad_bottom)
147 | if any(p > 0 for p in paddings):
148 | self.zero_pad = nn.ZeroPad2d(paddings)
149 | else:
150 | self.zero_pad = None
151 | self.conv = nn.Conv2d(
152 | in_channels=in_channels,
153 | out_channels=out_channels,
154 | kernel_size=kernel_size,
155 | stride=stride,
156 | dilation=dilation,
157 | **kwargs
158 | )
159 |
160 | self._out_spatial_shape = (out_height, out_width)
161 |
162 | @property
163 | def out_spatial_shape(self):
164 | return self._out_spatial_shape
165 |
166 | def check_spatial_shape(self, x):
167 | if (
168 | x.size(2) != self._in_spatial_shape[0]
169 | or x.size(3) != self._in_spatial_shape[1]
170 | ):
171 | raise ValueError(
172 | "Expected input spatial shape {}, got {} instead".format(
173 | self._in_spatial_shape, x.shape[2:]
174 | )
175 | )
176 |
177 | def forward(self, x):
178 | if self.enforce_in_spatial_shape:
179 | self.check_spatial_shape(x)
180 | if self.zero_pad is not None:
181 | x = self.zero_pad(x)
182 | x = self.conv(x)
183 | return x
184 |
185 |
186 | class SqueezeExcitate(nn.Module):
187 | def __init__(self, in_channels, se_size, activation=None):
188 | super(SqueezeExcitate, self).__init__()
189 | self.dim_reduce = nn.Conv2d(
190 | in_channels=in_channels, out_channels=se_size, kernel_size=1
191 | )
192 | self.dim_restore = nn.Conv2d(
193 | in_channels=se_size, out_channels=in_channels, kernel_size=1
194 | )
195 | self.activation = F.relu if activation is None else activation
196 |
197 | def forward(self, x):
198 | inp = x
199 | x = F.adaptive_avg_pool2d(x, (1, 1))
200 | x = self.dim_reduce(x)
201 | x = self.activation(x)
202 | x = self.dim_restore(x)
203 | x = torch.sigmoid(x)
204 | return torch.mul(inp, x)
205 |
206 |
207 | class MBConvBlockV2(nn.Module):
208 | def __init__(
209 | self,
210 | in_channels,
211 | out_channels,
212 | kernel_size,
213 | stride,
214 | expansion_factor,
215 | act_fn,
216 | act_kwargs=None,
217 | bn_epsilon=None,
218 | bn_momentum=None,
219 | se_size=None,
220 | drop_connect_rate=None,
221 | bias=False,
222 | tf_style_conv=False,
223 | in_spatial_shape=None,
224 | ):
225 |
226 | super().__init__()
227 |
228 | if act_kwargs is None:
229 | act_kwargs = {}
230 | exp_channels = in_channels * expansion_factor
231 |
232 | self.ops_lst = []
233 |
234 | # expansion convolution
235 | if expansion_factor != 1:
236 | self.expand_conv = nn.Conv2d(
237 | in_channels=in_channels,
238 | out_channels=exp_channels,
239 | kernel_size=1,
240 | bias=bias,
241 | )
242 |
243 | self.expand_bn = nn.BatchNorm2d(
244 | num_features=exp_channels, eps=bn_epsilon, momentum=bn_momentum
245 | )
246 |
247 | self.expand_act = get_activation(act_fn, **act_kwargs)
248 | self.ops_lst.extend([self.expand_conv, self.expand_bn, self.expand_act])
249 |
250 | # depth-wise convolution
251 | if tf_style_conv:
252 | self.dp_conv = SamePaddingConv2d(
253 | in_spatial_shape=in_spatial_shape,
254 | in_channels=exp_channels,
255 | out_channels=exp_channels,
256 | kernel_size=kernel_size,
257 | stride=stride,
258 | groups=exp_channels,
259 | bias=bias,
260 | )
261 | self.out_spatial_shape = self.dp_conv.out_spatial_shape
262 | else:
263 | self.dp_conv = nn.Conv2d(
264 | in_channels=exp_channels,
265 | out_channels=exp_channels,
266 | kernel_size=kernel_size,
267 | stride=stride,
268 | padding=1,
269 | groups=exp_channels,
270 | bias=bias,
271 | )
272 | self.out_spatial_shape = torch_conv_out_spatial_shape(
273 | in_spatial_shape, kernel_size, stride
274 | )
275 |
276 | self.dp_bn = nn.BatchNorm2d(
277 | num_features=exp_channels, eps=bn_epsilon, momentum=bn_momentum
278 | )
279 |
280 | self.dp_act = get_activation(act_fn, **act_kwargs)
281 | self.ops_lst.extend([self.dp_conv, self.dp_bn, self.dp_act])
282 |
283 | # Squeeze and Excitate
284 | if se_size is not None:
285 | self.se = SqueezeExcitate(
286 | exp_channels, se_size, activation=get_activation(act_fn, **act_kwargs)
287 | )
288 | self.ops_lst.append(self.se)
289 |
290 | # projection layer
291 | self.project_conv = nn.Conv2d(
292 | in_channels=exp_channels,
293 | out_channels=out_channels,
294 | kernel_size=1,
295 | bias=bias,
296 | )
297 |
298 | self.project_bn = nn.BatchNorm2d(
299 | num_features=out_channels, eps=bn_epsilon, momentum=bn_momentum
300 | )
301 |
302 | # no activation function in projection layer
303 |
304 | self.ops_lst.extend([self.project_conv, self.project_bn])
305 |
306 | self.skip_enabled = in_channels == out_channels and stride == 1
307 |
308 | if self.skip_enabled and drop_connect_rate is not None:
309 | self.drop_connect = DropConnect(drop_connect_rate)
310 | self.ops_lst.append(self.drop_connect)
311 |
312 | def forward(self, x):
313 | inp = x
314 | for op in self.ops_lst:
315 | x = op(x)
316 | if self.skip_enabled:
317 | return x + inp
318 | else:
319 | return x
320 |
321 |
322 | class FusedMBConvBlockV2(nn.Module):
323 | def __init__(
324 | self,
325 | in_channels,
326 | out_channels,
327 | kernel_size,
328 | stride,
329 | expansion_factor,
330 | act_fn,
331 | act_kwargs=None,
332 | bn_epsilon=None,
333 | bn_momentum=None,
334 | se_size=None,
335 | drop_connect_rate=None,
336 | bias=False,
337 | tf_style_conv=False,
338 | in_spatial_shape=None,
339 | ):
340 |
341 | super().__init__()
342 |
343 | if act_kwargs is None:
344 | act_kwargs = {}
345 | exp_channels = in_channels * expansion_factor
346 |
347 | self.ops_lst = []
348 |
349 | # expansion convolution
350 | expansion_out_shape = in_spatial_shape
351 | if expansion_factor != 1:
352 | if tf_style_conv:
353 | self.expand_conv = SamePaddingConv2d(
354 | in_spatial_shape=in_spatial_shape,
355 | in_channels=in_channels,
356 | out_channels=exp_channels,
357 | kernel_size=kernel_size,
358 | stride=stride,
359 | bias=bias,
360 | )
361 | expansion_out_shape = self.expand_conv.out_spatial_shape
362 | else:
363 | self.expand_conv = nn.Conv2d(
364 | in_channels=in_channels,
365 | out_channels=exp_channels,
366 | kernel_size=kernel_size,
367 | padding=1,
368 | stride=stride,
369 | bias=bias,
370 | )
371 | expansion_out_shape = torch_conv_out_spatial_shape(
372 | in_spatial_shape, kernel_size, stride
373 | )
374 |
375 | self.expand_bn = nn.BatchNorm2d(
376 | num_features=exp_channels, eps=bn_epsilon, momentum=bn_momentum
377 | )
378 |
379 | self.expand_act = get_activation(act_fn, **act_kwargs)
380 | self.ops_lst.extend([self.expand_conv, self.expand_bn, self.expand_act])
381 |
382 | # Squeeze and Excitate
383 | if se_size is not None:
384 | self.se = SqueezeExcitate(
385 | exp_channels, se_size, activation=get_activation(act_fn, **act_kwargs)
386 | )
387 | self.ops_lst.append(self.se)
388 |
389 | # projection layer
390 | kernel_size = 1 if expansion_factor != 1 else kernel_size
391 | stride = 1 if expansion_factor != 1 else stride
392 | if tf_style_conv:
393 | self.project_conv = SamePaddingConv2d(
394 | in_spatial_shape=expansion_out_shape,
395 | in_channels=exp_channels,
396 | out_channels=out_channels,
397 | kernel_size=kernel_size,
398 | stride=stride,
399 | bias=bias,
400 | )
401 | self.out_spatial_shape = self.project_conv.out_spatial_shape
402 | else:
403 | self.project_conv = nn.Conv2d(
404 | in_channels=exp_channels,
405 | out_channels=out_channels,
406 | kernel_size=kernel_size,
407 | stride=stride,
408 | padding=1 if kernel_size > 1 else 0,
409 | bias=bias,
410 | )
411 | self.out_spatial_shape = torch_conv_out_spatial_shape(
412 | expansion_out_shape, kernel_size, stride
413 | )
414 |
415 | self.project_bn = nn.BatchNorm2d(
416 | num_features=out_channels, eps=bn_epsilon, momentum=bn_momentum
417 | )
418 |
419 | self.ops_lst.extend([self.project_conv, self.project_bn])
420 |
421 | if expansion_factor == 1:
422 | self.project_act = get_activation(act_fn, **act_kwargs)
423 | self.ops_lst.append(self.project_act)
424 |
425 | self.skip_enabled = in_channels == out_channels and stride == 1
426 |
427 | if self.skip_enabled and drop_connect_rate is not None:
428 | self.drop_connect = DropConnect(drop_connect_rate)
429 | self.ops_lst.append(self.drop_connect)
430 |
431 | def forward(self, x):
432 | inp = x
433 | for op in self.ops_lst:
434 | x = op(x)
435 | if self.skip_enabled:
436 | return x + inp
437 | else:
438 | return x
439 |
440 |
441 | class EfficientNetV2(nn.Module):
442 | _models = {
443 | "b0": {
444 | "num_repeat": [1, 2, 2, 3, 5, 8],
445 | "kernel_size": [3, 3, 3, 3, 3, 3],
446 | "stride": [1, 2, 2, 2, 1, 2],
447 | "expand_ratio": [1, 4, 4, 4, 6, 6],
448 | "in_channel": [32, 16, 32, 48, 96, 112],
449 | "out_channel": [16, 32, 48, 96, 112, 192],
450 | "se_ratio": [None, None, None, 0.25, 0.25, 0.25],
451 | "conv_type": [1, 1, 1, 0, 0, 0],
452 | "is_feature_stage": [False, True, True, False, True, True],
453 | "width_coefficient": 1.0,
454 | "depth_coefficient": 1.0,
455 | "train_size": 192,
456 | "eval_size": 224,
457 | "dropout": 0.2,
458 | "weight_url": "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVBhWkZRcWNXR3dINmRLP2U9UUI5ZndH/root/content",
459 | "model_name": "efficientnet_v2_b0_21k_ft1k-a91e14c5.pth",
460 | },
461 | "b1": {
462 | "num_repeat": [1, 2, 2, 3, 5, 8],
463 | "kernel_size": [3, 3, 3, 3, 3, 3],
464 | "stride": [1, 2, 2, 2, 1, 2],
465 | "expand_ratio": [1, 4, 4, 4, 6, 6],
466 | "in_channel": [32, 16, 32, 48, 96, 112],
467 | "out_channel": [16, 32, 48, 96, 112, 192],
468 | "se_ratio": [None, None, None, 0.25, 0.25, 0.25],
469 | "conv_type": [1, 1, 1, 0, 0, 0],
470 | "is_feature_stage": [False, True, True, False, True, True],
471 | "width_coefficient": 1.0,
472 | "depth_coefficient": 1.1,
473 | "train_size": 192,
474 | "eval_size": 240,
475 | "dropout": 0.2,
476 | "weight_url": "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVJnVGV5UndSY2J2amwtP2U9dTBiV1lO/root/content",
477 | "model_name": "efficientnet_v2_b1_21k_ft1k-58f4fb47.pth",
478 | },
479 | "b2": {
480 | "num_repeat": [1, 2, 2, 3, 5, 8],
481 | "kernel_size": [3, 3, 3, 3, 3, 3],
482 | "stride": [1, 2, 2, 2, 1, 2],
483 | "expand_ratio": [1, 4, 4, 4, 6, 6],
484 | "in_channel": [32, 16, 32, 48, 96, 112],
485 | "out_channel": [16, 32, 48, 96, 112, 192],
486 | "se_ratio": [None, None, None, 0.25, 0.25, 0.25],
487 | "conv_type": [1, 1, 1, 0, 0, 0],
488 | "is_feature_stage": [False, True, True, False, True, True],
489 | "width_coefficient": 1.1,
490 | "depth_coefficient": 1.2,
491 | "train_size": 208,
492 | "eval_size": 260,
493 | "dropout": 0.3,
494 | "weight_url": "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVY4M2NySVFZbU41X0tGP2U9ZERZVmxK/root/content",
495 | "model_name": "efficientnet_v2_b2_21k_ft1k-db4ac0ee.pth",
496 | },
497 | "b3": {
498 | "num_repeat": [1, 2, 2, 3, 5, 8],
499 | "kernel_size": [3, 3, 3, 3, 3, 3],
500 | "stride": [1, 2, 2, 2, 1, 2],
501 | "expand_ratio": [1, 4, 4, 4, 6, 6],
502 | "in_channel": [32, 16, 32, 48, 96, 112],
503 | "out_channel": [16, 32, 48, 96, 112, 192],
504 | "se_ratio": [None, None, None, 0.25, 0.25, 0.25],
505 | "conv_type": [1, 1, 1, 0, 0, 0],
506 | "is_feature_stage": [False, True, True, False, True, True],
507 | "width_coefficient": 1.2,
508 | "depth_coefficient": 1.4,
509 | "train_size": 240,
510 | "eval_size": 300,
511 | "dropout": 0.3,
512 | "weight_url": "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVpkamdZUzhhaDdtTTZLP2U9anA4VWN2/root/content",
513 | "model_name": "efficientnet_v2_b3_21k_ft1k-3da5874c.pth",
514 | },
515 | "s": {
516 | "num_repeat": [2, 4, 4, 6, 9, 15],
517 | "kernel_size": [3, 3, 3, 3, 3, 3],
518 | "stride": [1, 2, 2, 2, 1, 2],
519 | "expand_ratio": [1, 4, 4, 4, 6, 6],
520 | "in_channel": [24, 24, 48, 64, 128, 160],
521 | "out_channel": [24, 48, 64, 128, 160, 256],
522 | "se_ratio": [None, None, None, 0.25, 0.25, 0.25],
523 | "conv_type": [1, 1, 1, 0, 0, 0],
524 | "is_feature_stage": [False, True, True, False, True, True],
525 | "width_coefficient": 1.0,
526 | "depth_coefficient": 1.0,
527 | "train_size": 300,
528 | "eval_size": 384,
529 | "dropout": 0.2,
530 | "weight_url": "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmllbFF5VWJOZzd0cmhBbm8/root/content",
531 | "model_name": "efficientnet_v2_s_21k_ft1k-dbb43f38.pth",
532 | },
533 | "m": {
534 | "num_repeat": [2, 5, 5, 7, 14, 18, 5],
535 | "kernel_size": [3, 3, 3, 3, 3, 3, 3],
536 | "stride": [1, 2, 2, 2, 1, 2, 1],
537 | "expand_ratio": [1, 4, 4, 4, 6, 6, 6],
538 | "in_channel": [24, 24, 48, 64, 128, 160, 224],
539 | "out_channel": [24, 48, 64, 128, 160, 224, 512],
540 | "se_ratio": [None, None, None, 0.25, 0.25, 0.25, 0.25],
541 | "conv_type": [1, 1, 1, 0, 0, 0, 0],
542 | "is_feature_stage": [False, True, True, False, True, False, True],
543 | "width_coefficient": 1.0,
544 | "depth_coefficient": 1.0,
545 | "train_size": 384,
546 | "eval_size": 480,
547 | "dropout": 0.3,
548 | "weight_url": "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmllN1ZDazRFb0o1bnlyNUE/root/content",
549 | "model_name": "efficientnet_v2_m_21k_ft1k-da8e56c0.pth",
550 | },
551 | "l": {
552 | "num_repeat": [4, 7, 7, 10, 19, 25, 7],
553 | "kernel_size": [3, 3, 3, 3, 3, 3, 3],
554 | "stride": [1, 2, 2, 2, 1, 2, 1],
555 | "expand_ratio": [1, 4, 4, 4, 6, 6, 6],
556 | "in_channel": [32, 32, 64, 96, 192, 224, 384],
557 | "out_channel": [32, 64, 96, 192, 224, 384, 640],
558 | "se_ratio": [None, None, None, 0.25, 0.25, 0.25, 0.25],
559 | "conv_type": [1, 1, 1, 0, 0, 0, 0],
560 | "is_feature_stage": [False, True, True, False, True, False, True],
561 | "feature_stages": [1, 2, 4, 6],
562 | "width_coefficient": 1.0,
563 | "depth_coefficient": 1.0,
564 | "train_size": 384,
565 | "eval_size": 480,
566 | "dropout": 0.4,
567 | "weight_url": "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlmcmIyRHEtQTBhUTBhWVE/root/content",
568 | "model_name": "efficientnet_v2_l_21k_ft1k-08121eee.pth",
569 | },
570 | "xl": {
571 | "num_repeat": [4, 8, 8, 16, 24, 32, 8],
572 | "kernel_size": [3, 3, 3, 3, 3, 3, 3],
573 | "stride": [1, 2, 2, 2, 1, 2, 1],
574 | "expand_ratio": [1, 4, 4, 4, 6, 6, 6],
575 | "in_channel": [32, 32, 64, 96, 192, 256, 512],
576 | "out_channel": [32, 64, 96, 192, 256, 512, 640],
577 | "se_ratio": [None, None, None, 0.25, 0.25, 0.25, 0.25],
578 | "conv_type": [1, 1, 1, 0, 0, 0, 0],
579 | "is_feature_stage": [False, True, True, False, True, False, True],
580 | "feature_stages": [1, 2, 4, 6],
581 | "width_coefficient": 1.0,
582 | "depth_coefficient": 1.0,
583 | "train_size": 384,
584 | "eval_size": 512,
585 | "dropout": 0.4,
586 | "weight_url": "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlmVXQtRHJLa21taUkxWkE/root/content",
587 | "model_name": "efficientnet_v2_xl_21k_ft1k-1fcc9744.pth",
588 | },
589 | }
590 |
591 | def __init__(
592 | self,
593 | model_name,
594 | in_channels=3,
595 | n_classes=1000,
596 | tf_style_conv=False,
597 | in_spatial_shape=None,
598 | activation="silu",
599 | activation_kwargs=None,
600 | bias=False,
601 | drop_connect_rate=0.2,
602 | dropout_rate=None,
603 | bn_epsilon=1e-3,
604 | bn_momentum=0.01,
605 | pretrained=False,
606 | progress=False,
607 | ):
608 | super().__init__()
609 |
610 | self.blocks = nn.ModuleList()
611 | self.model_name = model_name
612 | self.cfg = self._models[model_name]
613 |
614 | if tf_style_conv and in_spatial_shape is None:
615 | in_spatial_shape = self.cfg["eval_size"]
616 |
617 | activation_kwargs = {} if activation_kwargs is None else activation_kwargs
618 | dropout_rate = self.cfg["dropout"] if dropout_rate is None else dropout_rate
619 | _input_ch = in_channels
620 |
621 | self.feature_block_ids = []
622 |
623 | # stem
624 | if tf_style_conv:
625 | self.stem_conv = SamePaddingConv2d(
626 | in_spatial_shape=in_spatial_shape,
627 | in_channels=in_channels,
628 | out_channels=round_filters(
629 | self.cfg["in_channel"][0], self.cfg["width_coefficient"]
630 | ),
631 | kernel_size=3,
632 | stride=2,
633 | bias=bias,
634 | )
635 | in_spatial_shape = self.stem_conv.out_spatial_shape
636 | else:
637 | self.stem_conv = nn.Conv2d(
638 | in_channels=in_channels,
639 | out_channels=round_filters(
640 | self.cfg["in_channel"][0], self.cfg["width_coefficient"]
641 | ),
642 | kernel_size=3,
643 | stride=2,
644 | padding=1,
645 | bias=bias,
646 | )
647 |
648 | self.stem_bn = nn.BatchNorm2d(
649 | num_features=round_filters(
650 | self.cfg["in_channel"][0], self.cfg["width_coefficient"]
651 | ),
652 | eps=bn_epsilon,
653 | momentum=bn_momentum,
654 | )
655 |
656 | self.stem_act = get_activation(activation, **activation_kwargs)
657 |
658 | drop_connect_rates = self.get_dropconnect_rates(drop_connect_rate)
659 |
660 | stages = zip(
661 | *[
662 | self.cfg[x]
663 | for x in [
664 | "num_repeat",
665 | "kernel_size",
666 | "stride",
667 | "expand_ratio",
668 | "in_channel",
669 | "out_channel",
670 | "se_ratio",
671 | "conv_type",
672 | "is_feature_stage",
673 | ]
674 | ]
675 | )
676 |
677 | idx = 0
678 |
679 | for stage_args in stages:
680 | (
681 | num_repeat,
682 | kernel_size,
683 | stride,
684 | expand_ratio,
685 | in_channels,
686 | out_channels,
687 | se_ratio,
688 | conv_type,
689 | is_feature_stage,
690 | ) = stage_args
691 |
692 | in_channels = round_filters(in_channels, self.cfg["width_coefficient"])
693 | out_channels = round_filters(out_channels, self.cfg["width_coefficient"])
694 | num_repeat = round_repeats(num_repeat, self.cfg["depth_coefficient"])
695 |
696 | conv_block = MBConvBlockV2 if conv_type == 0 else FusedMBConvBlockV2
697 |
698 | for _ in range(num_repeat):
699 | se_size = (
700 | None if se_ratio is None else max(1, int(in_channels * se_ratio))
701 | )
702 | _b = conv_block(
703 | in_channels=in_channels,
704 | out_channels=out_channels,
705 | kernel_size=kernel_size,
706 | stride=stride,
707 | expansion_factor=expand_ratio,
708 | act_fn=activation,
709 | act_kwargs=activation_kwargs,
710 | bn_epsilon=bn_epsilon,
711 | bn_momentum=bn_momentum,
712 | se_size=se_size,
713 | drop_connect_rate=drop_connect_rates[idx],
714 | bias=bias,
715 | tf_style_conv=tf_style_conv,
716 | in_spatial_shape=in_spatial_shape,
717 | )
718 | self.blocks.append(_b)
719 | idx += 1
720 | if tf_style_conv:
721 | in_spatial_shape = _b.out_spatial_shape
722 | in_channels = out_channels
723 | stride = 1
724 |
725 | if is_feature_stage:
726 | self.feature_block_ids.append(idx - 1)
727 |
728 | head_conv_out_channels = round_filters(1280, self.cfg["width_coefficient"])
729 |
730 | self.head_conv = nn.Conv2d(
731 | in_channels=in_channels,
732 | out_channels=head_conv_out_channels,
733 | kernel_size=1,
734 | bias=bias,
735 | )
736 | self.head_bn = nn.BatchNorm2d(
737 | num_features=head_conv_out_channels, eps=bn_epsilon, momentum=bn_momentum
738 | )
739 | self.head_act = get_activation(activation, **activation_kwargs)
740 |
741 | self.dropout = nn.Dropout(p=dropout_rate)
742 |
743 | self.avpool = nn.AdaptiveAvgPool2d((1, 1))
744 | self.fc = nn.Linear(head_conv_out_channels, n_classes)
745 |
746 | if pretrained:
747 | self._load_state(_input_ch, n_classes, progress, tf_style_conv)
748 |
749 | return
750 |
751 | def _load_state(self, in_channels, n_classes, progress, tf_style_conv):
752 | state_dict = model_zoo.load_url(
753 | self.cfg["weight_url"], progress=progress, file_name=self.cfg["model_name"]
754 | )
755 |
756 | strict = True
757 |
758 | if not tf_style_conv:
759 | state_dict = OrderedDict(
760 | [
761 | (k.replace(".conv.", "."), v) if ".conv." in k else (k, v)
762 | for k, v in state_dict.items()
763 | ]
764 | )
765 |
766 | if in_channels != 3:
767 | if tf_style_conv:
768 | state_dict.pop("stem_conv.conv.weight")
769 | else:
770 | state_dict.pop("stem_conv.weight")
771 | strict = False
772 |
773 | if n_classes != 1000:
774 | state_dict.pop("fc.weight")
775 | state_dict.pop("fc.bias")
776 | strict = False
777 |
778 | self.load_state_dict(state_dict, strict=strict)
779 | print("Model weights loaded successfully.")
780 |
781 | def get_dropconnect_rates(self, drop_connect_rate):
782 | nr = self.cfg["num_repeat"]
783 | dc = self.cfg["depth_coefficient"]
784 | total = sum(round_repeats(nr[i], dc) for i in range(len(nr)))
785 | return [drop_connect_rate * i / total for i in range(total)]
786 |
787 | def get_features(self, x):
788 | x = self.stem_act(self.stem_bn(self.stem_conv(x)))
789 |
790 | features = []
791 | feat_idx = 0
792 | for block_idx, block in enumerate(self.blocks):
793 | x = block(x)
794 | if block_idx == self.feature_block_ids[feat_idx]:
795 | features.append(x)
796 | feat_idx += 1
797 |
798 | return features
799 |
800 | def forward(self, x):
801 | x = self.stem_act(self.stem_bn(self.stem_conv(x)))
802 | for block in self.blocks:
803 | x = block(x)
804 | x = self.head_act(self.head_bn(self.head_conv(x)))
805 | x = self.dropout(torch.flatten(self.avpool(x), 1))
806 | x = self.fc(x)
807 |
808 | return x
809 |
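
# --- Usage sketch (added for illustration; "_demo_efficientnet_v2" is a hypothetical
# helper, not part of the original file). It builds the "s" variant without pretrained
# weights and shows both the classifier output and the multi-scale feature maps,
# assuming the block/helper definitions earlier in this file.
def _demo_efficientnet_v2():
    model = EfficientNetV2("s", in_channels=3, n_classes=10, pretrained=False)
    model.eval()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        logits = model(x)                 # shape (1, 10)
        features = model.get_features(x)  # 4 feature maps for the "s" config
    return logits.shape, [f.shape for f in features]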
--------------------------------------------------------------------------------
/utils/mobilenet_v3.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe. All rights reserved.
2 | # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License. You may obtain a copy
4 | # of the License at http://www.apache.org/licenses/LICENSE-2.0
5 |
6 | # Unless required by applicable law or agreed to in writing, software distributed under
7 | # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8 | # OF ANY KIND, either express or implied. See the License for the specific language
9 | # governing permissions and limitations under the License.
10 | """
11 | Creates a MobileNetV3 Model as defined in:
12 | Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019).
13 | Searching for MobileNetV3
14 | arXiv preprint arXiv:1905.02244.
15 | """
16 |
17 | import torch.nn as nn
18 | import math
19 |
20 |
21 | __all__ = ['mobilenetv3_large', 'mobilenetv3_small']
22 |
23 |
24 | def _make_divisible(v, divisor, min_value=None):
25 | """
26 | This function is taken from the original tf repo.
27 | It ensures that all layers have a channel number that is divisible by 8
28 | It can be seen here:
29 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
30 | :param v:
31 | :param divisor:
32 | :param min_value:
33 | :return:
34 | """
35 | if min_value is None:
36 | min_value = divisor
37 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
38 | # Make sure that round down does not go down by more than 10%.
39 | if new_v < 0.9 * v:
40 | new_v += divisor
41 | return new_v
42 |
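
# --- Illustrative check (added; "_demo_make_divisible" is not part of the original file). ---
# Rounds to the nearest multiple of the divisor, but never lands more than 10% below v.
def _demo_make_divisible():
    assert _make_divisible(37, 8) == 40  # nearest multiple of 8
    assert _make_divisible(27, 8) == 32  # 24 would undershoot 27 by >10%, so bump up a step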
43 |
44 | class h_sigmoid(nn.Module):
45 | def __init__(self, inplace=True):
46 | super(h_sigmoid, self).__init__()
47 | self.relu = nn.ReLU6(inplace=inplace)
48 |
49 | def forward(self, x):
50 | return self.relu(x + 3) / 6
51 |
52 |
53 | class h_swish(nn.Module):
54 | def __init__(self, inplace=True):
55 | super(h_swish, self).__init__()
56 | self.sigmoid = h_sigmoid(inplace=inplace)
57 |
58 | def forward(self, x):
59 | return x * self.sigmoid(x)
60 |
61 |
62 | class SELayer(nn.Module):
63 | def __init__(self, channel, reduction=4):
64 | super(SELayer, self).__init__()
65 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
66 | self.fc = nn.Sequential(
67 | nn.Linear(channel, _make_divisible(channel // reduction, 8)),
68 | nn.ReLU(inplace=True),
69 | nn.Linear(_make_divisible(channel // reduction, 8), channel),
70 | h_sigmoid()
71 | )
72 |
73 | def forward(self, x):
74 | b, c, _, _ = x.size()
75 | y = self.avg_pool(x).view(b, c)
76 | y = self.fc(y).view(b, c, 1, 1)
77 | return x * y
78 |
79 |
80 | def conv_3x3_bn(inp, oup, stride):
81 | return nn.Sequential(
82 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
83 | nn.BatchNorm2d(oup),
84 | h_swish()
85 | )
86 |
87 |
88 | def conv_1x1_bn(inp, oup):
89 | return nn.Sequential(
90 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
91 | nn.BatchNorm2d(oup),
92 | h_swish()
93 | )
94 |
95 |
96 | class InvertedResidual(nn.Module):
97 | def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
98 | super(InvertedResidual, self).__init__()
99 | assert stride in [1, 2]
100 |
101 | self.identity = stride == 1 and inp == oup
102 |
103 | if inp == hidden_dim:
104 | self.conv = nn.Sequential(
105 | # dw
106 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False),
107 | nn.BatchNorm2d(hidden_dim),
108 | h_swish() if use_hs else nn.ReLU(inplace=True),
109 | # Squeeze-and-Excite
110 | SELayer(hidden_dim) if use_se else nn.Identity(),
111 | # pw-linear
112 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
113 | nn.BatchNorm2d(oup),
114 | )
115 | else:
116 | self.conv = nn.Sequential(
117 | # pw
118 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
119 | nn.BatchNorm2d(hidden_dim),
120 | h_swish() if use_hs else nn.ReLU(inplace=True),
121 | # dw
122 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False),
123 | nn.BatchNorm2d(hidden_dim),
124 | # Squeeze-and-Excite
125 | SELayer(hidden_dim) if use_se else nn.Identity(),
126 | h_swish() if use_hs else nn.ReLU(inplace=True),
127 | # pw-linear
128 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
129 | nn.BatchNorm2d(oup),
130 | )
131 |
132 | def forward(self, x):
133 | if self.identity:
134 | return x + self.conv(x)
135 | else:
136 | return self.conv(x)
137 |
138 |
139 | class MobileNetV3(nn.Module):
140 |     def __init__(self, cfgs, mode, input_f=3, num_classes=1000, width_mult=1.):
141 | super(MobileNetV3, self).__init__()
142 | # setting of inverted residual blocks
143 | self.cfgs = cfgs
144 | assert mode in ['large', 'small']
145 |
146 | # building first layer
147 | input_channel = _make_divisible(16 * width_mult, 8)
148 | layers = [conv_3x3_bn(input_f, input_channel, 2)]
149 | # building inverted residual blocks
150 | block = InvertedResidual
151 | for k, t, c, use_se, use_hs, s in self.cfgs:
152 | output_channel = _make_divisible(c * width_mult, 8)
153 | exp_size = _make_divisible(input_channel * t, 8)
154 | layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
155 | input_channel = output_channel
156 | self.features = nn.Sequential(*layers)
157 | # building last several layers
158 | self.conv = conv_1x1_bn(input_channel, exp_size)
159 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
160 | output_channel = {'large': 1280, 'small': 1024}
161 | output_channel = _make_divisible(output_channel[mode] * width_mult, 8) if width_mult > 1.0 else output_channel[mode]
162 | self.classifier = nn.Sequential(
163 | nn.Linear(exp_size, output_channel),
164 | h_swish(),
165 | nn.Dropout(0.2),
166 | nn.Linear(output_channel, num_classes),
167 | )
168 |
169 | self._initialize_weights()
170 |
171 | def forward(self, x):
172 | x = self.features(x)
173 | x = self.conv(x)
174 | x = self.avgpool(x)
175 | x = x.view(x.size(0), -1)
176 | x = self.classifier(x)
177 | return x
178 |
179 | def _initialize_weights(self):
180 | for m in self.modules():
181 | if isinstance(m, nn.Conv2d):
182 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
183 | m.weight.data.normal_(0, math.sqrt(2. / n))
184 | if m.bias is not None:
185 | m.bias.data.zero_()
186 | elif isinstance(m, nn.BatchNorm2d):
187 | m.weight.data.fill_(1)
188 | m.bias.data.zero_()
189 | elif isinstance(m, nn.Linear):
190 | m.weight.data.normal_(0, 0.01)
191 | m.bias.data.zero_()
192 |
193 |
194 | def mobilenetv3_large(**kwargs):
195 | """
196 | Constructs a MobileNetV3-Large model
197 | """
198 | cfgs = [
199 | # k, t, c, SE, HS, s
200 | [3, 1, 16, 0, 0, 1],
201 | [3, 4, 24, 0, 0, 2],
202 | [3, 3, 24, 0, 0, 1],
203 | [5, 3, 40, 1, 0, 2],
204 | [5, 3, 40, 1, 0, 1],
205 | [5, 3, 40, 1, 0, 1],
206 | [3, 6, 80, 0, 1, 2],
207 | [3, 2.5, 80, 0, 1, 1],
208 | [3, 2.3, 80, 0, 1, 1],
209 | [3, 2.3, 80, 0, 1, 1],
210 | [3, 6, 112, 1, 1, 1],
211 | [3, 6, 112, 1, 1, 1],
212 | [5, 6, 160, 1, 1, 2],
213 | [5, 6, 160, 1, 1, 1],
214 | [5, 6, 160, 1, 1, 1]
215 | ]
216 | return MobileNetV3(cfgs, mode='large', **kwargs)
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 | def mobilenetv3_small(input_f=3, num_classes=1000, **kwargs):
227 | """
228 | Constructs a MobileNetV3-Small model
229 | """
230 | cfgs = [
231 | # k, t, c, SE, HS, s
232 | [3, 1, 16, 1, 0, 2],
233 | [3, 4.5, 24, 0, 0, 2],
234 | [3, 3.67, 24, 0, 0, 1],
235 | [5, 4, 40, 1, 1, 2],
236 | [5, 6, 40, 1, 1, 1],
237 | [5, 6, 40, 1, 1, 1],
238 | [5, 3, 48, 1, 1, 1],
239 | [5, 3, 48, 1, 1, 1],
240 | [5, 6, 96, 1, 1, 2],
241 | [5, 6, 96, 1, 1, 1],
242 | [5, 6, 96, 1, 1, 1],
243 | ]
244 |
245 |     return MobileNetV3(cfgs, input_f=input_f, num_classes=num_classes, mode='small', **kwargs)
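

# --- Usage sketch (added for illustration; "_demo_mobilenetv3_small" is hypothetical). ---
# The small variant accepts a non-RGB input, e.g. a 7-channel composite/background/mask
# stack like the one used by the harmonization networks in this repo.
def _demo_mobilenetv3_small():
    import torch

    net = mobilenetv3_small(input_f=7, num_classes=30)
    net.eval()
    with torch.no_grad():
        out = net(torch.randn(2, 7, 256, 256))
    return out.shape  # torch.Size([2, 30])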
--------------------------------------------------------------------------------
/utils/modules.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe. All rights reserved.
2 | # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License. You may obtain a copy
4 | # of the License at http://www.apache.org/licenses/LICENSE-2.0
5 |
6 | # Unless required by applicable law or agreed to in writing, software distributed under
7 | # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8 | # OF ANY KIND, either express or implied. See the License for the specific language
9 | # governing permissions and limitations under the License.
10 | import torch
11 | import torch.nn as nn
12 |
13 |
14 | class IBN(nn.Module):
15 | r"""Instance-Batch Normalization layer from
16 | `"Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net"
17 | `
18 | Args:
19 | planes (int): Number of channels for the input tensor
20 | ratio (float): Ratio of instance normalization in the IBN layer
21 | """
22 |
23 | def __init__(self, planes, ratio=0.5):
24 | super(IBN, self).__init__()
25 | self.half = int(planes * ratio)
26 | self.IN = nn.InstanceNorm2d(self.half, affine=True)
27 | self.BN = nn.BatchNorm2d(planes - self.half)
28 |
29 | def forward(self, x):
30 | split = torch.split(x, self.half, 1)
31 | out1 = self.IN(split[0].contiguous())
32 | out2 = self.BN(split[1].contiguous())
33 | out = torch.cat((out1, out2), 1)
34 | return out
35 |
36 |
37 | class SELayer(nn.Module):
38 | def __init__(self, channel, reduction=16):
39 | super(SELayer, self).__init__()
40 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
41 | self.fc = nn.Sequential(
42 | nn.Linear(channel, int(channel / reduction), bias=False),
43 | nn.ReLU(inplace=True),
44 | nn.Linear(int(channel / reduction), channel, bias=False),
45 | nn.Sigmoid(),
46 | )
47 |
48 | def forward(self, x):
49 | b, c, _, _ = x.size()
50 | y = self.avg_pool(x).view(b, c)
51 | y = self.fc(y).view(b, c, 1, 1)
52 | return x * y.expand_as(x)
53 |
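
# --- Usage sketch (added for illustration; "_demo_ibn_se" is hypothetical). ---
# IBN instance-normalizes the first half of the channels and batch-normalizes the rest;
# SELayer reweights channels. Both preserve the input shape.
def _demo_ibn_se():
    x = torch.randn(2, 64, 32, 32)
    ibn = IBN(64, ratio=0.5)
    se = SELayer(64, reduction=16)
    return ibn(x).shape, se(x).shape  # both torch.Size([2, 64, 32, 32])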
--------------------------------------------------------------------------------
/utils/networks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe. All rights reserved.
2 | # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License. You may obtain a copy
4 | # of the License at http://www.apache.org/licenses/LICENSE-2.0
5 |
6 | # Unless required by applicable law or agreed to in writing, software distributed under
7 | # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8 | # OF ANY KIND, either express or implied. See the License for the specific language
9 | # governing permissions and limitations under the License.
10 | import torch
11 | import torch.nn as nn
12 | from torch.nn import init
13 | import functools
14 | from torch.optim import lr_scheduler
15 |
16 |
17 | ###############################################################################
18 | # Helper Functions
19 | ###############################################################################
20 |
21 |
22 | class Identity(nn.Module):
23 | def forward(self, x):
24 | return x
25 |
26 |
27 | def get_norm_layer(norm_type="instance"):
28 | """Return a normalization layer
29 | Parameters:
30 | norm_type (str) -- the name of the normalization layer: batch | instance | none
31 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
32 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
33 | """
34 | if norm_type == "batch":
35 | norm_layer = functools.partial(
36 | nn.BatchNorm2d, affine=True, track_running_stats=True
37 | )
38 | elif norm_type == "instance":
39 | norm_layer = functools.partial(
40 | nn.InstanceNorm2d, affine=False, track_running_stats=False
41 | )
42 | elif norm_type == "none":
43 |
44 | def norm_layer(x):
45 | return Identity()
46 |
47 | else:
48 | raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
49 | return norm_layer
50 |
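
# --- Usage sketch (added for illustration; "_demo_get_norm_layer" is hypothetical). ---
# get_norm_layer returns a factory; call it with a channel count to build the actual layer.
def _demo_get_norm_layer():
    norm_layer = get_norm_layer("instance")
    layer = norm_layer(64)  # nn.InstanceNorm2d(64, affine=False, track_running_stats=False)
    return layer(torch.randn(1, 64, 8, 8)).shape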
51 |
52 | def get_scheduler(optimizer, opt):
53 | """Return a learning rate scheduler
54 | Parameters:
55 | optimizer -- the optimizer of the network
56 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
57 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
58 |     For 'linear', we keep the same learning rate for the first opt.n_epochs epochs
59 |     and linearly decay the rate to zero over the next opt.n_epochs_decay epochs.
60 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
61 | See https://pytorch.org/docs/stable/optim.html for more details.
62 | """
63 | if opt.lr_policy == "linear":
64 |
65 | def lambda_rule(epoch):
66 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(
67 | opt.n_epochs_decay + 1
68 | )
69 | return lr_l
70 |
71 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
72 | elif opt.lr_policy == "step":
73 | scheduler = lr_scheduler.StepLR(
74 | optimizer, step_size=opt.lr_decay_iters, gamma=0.1
75 | )
76 | elif opt.lr_policy == "plateau":
77 | scheduler = lr_scheduler.ReduceLROnPlateau(
78 | optimizer, mode="min", factor=0.2, threshold=0.01, patience=5
79 | )
80 | elif opt.lr_policy == "cosine":
81 | scheduler = lr_scheduler.CosineAnnealingLR(
82 | optimizer, T_max=opt.n_epochs, eta_min=0
83 | )
84 | else:
85 |         raise NotImplementedError(
86 |             "learning rate policy [%s] is not implemented" % opt.lr_policy
87 | )
88 | return scheduler
89 |
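
# --- Usage sketch (added for illustration; "_demo_get_scheduler" is hypothetical and the
# Namespace below only mimics the option fields that get_scheduler reads). ---
def _demo_get_scheduler():
    from argparse import Namespace

    net = nn.Linear(4, 4)
    optimizer = torch.optim.Adam(net.parameters(), lr=2e-4)
    opt = Namespace(lr_policy="linear", epoch_count=1, n_epochs=100, n_epochs_decay=100)
    scheduler = get_scheduler(optimizer, opt)
    for _ in range(3):  # one optimizer/scheduler step per "epoch"
        optimizer.step()
        scheduler.step()
    return optimizer.param_groups[0]["lr"]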
90 |
91 | def init_weights(net, init_type="normal", init_gain=0.02):
92 | """Initialize network weights.
93 | Parameters:
94 | net (network) -- network to be initialized
95 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
96 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
97 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
98 | work better for some applications. Feel free to try yourself.
99 | """
100 |
101 | def init_func(m): # define the initialization function
102 | classname = m.__class__.__name__
103 | if hasattr(m, "weight") and (
104 | classname.find("Conv") != -1 or classname.find("Linear") != -1
105 | ):
106 | if init_type == "normal":
107 | init.normal_(m.weight.data, 0.0, init_gain)
108 | elif init_type == "xavier":
109 | init.xavier_normal_(m.weight.data, gain=init_gain)
110 | elif init_type == "kaiming":
111 | init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
112 | elif init_type == "orthogonal":
113 | init.orthogonal_(m.weight.data, gain=init_gain)
114 | else:
115 | raise NotImplementedError(
116 | "initialization method [%s] is not implemented" % init_type
117 | )
118 | if hasattr(m, "bias") and m.bias is not None:
119 | init.constant_(m.bias.data, 0.0)
120 | elif (
121 | classname.find("BatchNorm2d") != -1
122 | ): # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
123 | init.normal_(m.weight.data, 1.0, init_gain)
124 | init.constant_(m.bias.data, 0.0)
125 |
126 | print("initialize network with %s" % init_type)
127 | net.apply(init_func) # apply the initialization function
128 |
129 |
130 | def init_net(net, init_type="normal", init_gain=0.02, gpu_ids=[]):
131 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
132 | Parameters:
133 | net (network) -- the network to be initialized
134 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
135 |         init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
136 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
137 | Return an initialized network.
138 | """
139 | if len(gpu_ids) > 0:
140 | assert torch.cuda.is_available()
141 | net.to(gpu_ids[0])
142 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
143 | init_weights(net, init_type, init_gain=init_gain)
144 | return net
145 |
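
# --- Usage sketch (added for illustration; "_demo_init_net" is hypothetical). ---
# Initialize a throwaway conv net on CPU with Xavier weights (gpu_ids=[] keeps it off CUDA).
def _demo_init_net():
    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(True))
    return init_net(net, init_type="xavier", init_gain=0.02, gpu_ids=[])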
146 |
147 | def define_G(
148 | input_nc,
149 | output_nc,
150 | ngf,
151 | netG,
152 | norm="batch",
153 | use_dropout=False,
154 | init_type="normal",
155 | init_gain=0.02,
156 | gpu_ids=[],
157 | ):
158 | """Create a generator
159 | Parameters:
160 | input_nc (int) -- the number of channels in input images
161 | output_nc (int) -- the number of channels in output images
162 | ngf (int) -- the number of filters in the last conv layer
163 | netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
164 | norm (str) -- the name of normalization layers used in the network: batch | instance | none
165 | use_dropout (bool) -- if use dropout layers.
166 | init_type (str) -- the name of our initialization method.
167 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
168 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
169 | Returns a generator
170 | Our current implementation provides two types of generators:
171 | U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
172 | The original U-Net paper: https://arxiv.org/abs/1505.04597
173 | Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
174 | Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
175 | We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
176 |     The generator has been initialized by init_net. It uses ReLU for non-linearity.
177 | """
178 | net = None
179 | norm_layer = get_norm_layer(norm_type=norm)
180 |
181 | if netG == "resnet_9blocks":
182 | net = ResnetGenerator(
183 | input_nc,
184 | output_nc,
185 | ngf,
186 | norm_layer=norm_layer,
187 | use_dropout=use_dropout,
188 | n_blocks=9,
189 | )
190 | elif netG == "resnet_6blocks":
191 | net = ResnetGenerator(
192 | input_nc,
193 | output_nc,
194 | ngf,
195 | norm_layer=norm_layer,
196 | use_dropout=use_dropout,
197 | n_blocks=6,
198 | )
199 | elif netG == "unet_128":
200 | net = UnetGenerator(
201 | input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout
202 | )
203 | elif netG == "unet_256":
204 | net = UnetGenerator(
205 | input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout
206 | )
207 | else:
208 | raise NotImplementedError("Generator model name [%s] is not recognized" % netG)
209 | return init_net(net, init_type, init_gain, gpu_ids)
210 |
211 |
212 | def define_D(
213 | input_nc,
214 | ndf,
215 | netD,
216 | n_layers_D=3,
217 | norm="batch",
218 | init_type="normal",
219 | init_gain=0.02,
220 | gpu_ids=[],
221 | ):
222 | """Create a discriminator
223 | Parameters:
224 | input_nc (int) -- the number of channels in input images
225 | ndf (int) -- the number of filters in the first conv layer
226 | netD (str) -- the architecture's name: basic | n_layers | pixel
227 | n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
228 | norm (str) -- the type of normalization layers used in the network.
229 | init_type (str) -- the name of the initialization method.
230 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
231 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
232 | Returns a discriminator
233 | Our current implementation provides three types of discriminators:
234 | [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
235 | It can classify whether 70×70 overlapping patches are real or fake.
236 | Such a patch-level discriminator architecture has fewer parameters
237 | than a full-image discriminator and can work on arbitrarily-sized images
238 | in a fully convolutional fashion.
239 | [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
240 |             with the parameter n_layers_D (default=3, as used in [basic] (PatchGAN)).
241 | [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
242 | It encourages greater color diversity but has no effect on spatial statistics.
243 |     The discriminator has been initialized by init_net. It uses Leaky ReLU for non-linearity.
244 | """
245 | net = None
246 | norm_layer = get_norm_layer(norm_type=norm)
247 |
248 | if netD == "basic": # default PatchGAN classifier
249 | net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
250 | elif netD == "n_layers": # more options
251 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
252 | elif netD == "pixel": # classify if each pixel is real or fake
253 | net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
254 | else:
255 | raise NotImplementedError(
256 | "Discriminator model name [%s] is not recognized" % netD
257 | )
258 | return init_net(net, init_type, init_gain, gpu_ids)
259 |
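
# --- Usage sketch (added for illustration; "_demo_define_g_d" is hypothetical; small
# sizes, CPU only). Builds a 9-block ResNet generator and the default 70x70 PatchGAN
# discriminator, then runs a tensor through both.
def _demo_define_g_d():
    netG = define_G(input_nc=3, output_nc=3, ngf=64, netG="resnet_9blocks", norm="instance")
    netD = define_D(input_nc=3, ndf=64, netD="basic", norm="instance")
    x = torch.randn(1, 3, 128, 128)
    with torch.no_grad():
        fake = netG(x)      # (1, 3, 128, 128), in [-1, 1] because of the final Tanh
        score = netD(fake)  # patch-level prediction map
    return fake.shape, score.shape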
260 |
261 | ##############################################################################
262 | # Classes
263 | ##############################################################################
264 | class GANLoss(nn.Module):
265 | """Define different GAN objectives.
266 | The GANLoss class abstracts away the need to create the target label tensor
267 | that has the same size as the input.
268 | """
269 |
270 | def __init__(
271 | self,
272 | gan_mode,
273 | gan_loss_mask=False,
274 | target_real_label=1.0,
275 | target_fake_label=0.0,
276 | ):
277 | """Initialize the GANLoss class.
278 | Parameters:
279 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
280 | target_real_label (bool) - - label for a real image
281 | target_fake_label (bool) - - label of a fake image
282 | Note: Do not use sigmoid as the last layer of Discriminator.
283 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
284 | """
285 | super(GANLoss, self).__init__()
286 | self.register_buffer("real_label", torch.tensor(target_real_label))
287 | self.register_buffer("fake_label", torch.tensor(target_fake_label))
288 | self.gan_mode = gan_mode
289 | self.gan_loss_mask = gan_loss_mask
290 | if gan_mode == "lsgan":
291 | self.loss = nn.MSELoss()
292 | elif gan_mode == "vanilla":
293 | self.loss = nn.BCEWithLogitsLoss()
294 | elif gan_mode in ["wgangp"]:
295 | self.loss = None
296 | else:
297 | raise NotImplementedError("gan mode %s not implemented" % gan_mode)
298 |
299 | def get_target_tensor(self, prediction, target_is_real, mask=None):
300 | """Create label tensors with the same size as the input.
301 | Parameters:
302 |             prediction (tensor) - - typically the prediction from a discriminator
303 | target_is_real (bool) - - if the ground truth label is for real images or fake images
304 | Returns:
305 | A label tensor filled with ground truth label, and with the size of the input
306 | """
307 |
308 | if target_is_real:
309 | target_tensor = self.real_label
310 | return target_tensor.expand_as(prediction)
311 |
312 | else:
313 | if self.gan_loss_mask:
314 | return 1 - mask
315 | else:
316 | target_tensor = self.fake_label
317 | return target_tensor.expand_as(prediction)
318 |
319 | def __call__(self, prediction, target_is_real, mask=None):
320 | """Calculate loss given Discriminator's output and grount truth labels.
321 | Parameters:
322 | prediction (tensor) - - tpyically the prediction output from a discriminator
323 | target_is_real (bool) - - if the ground truth label is for real images or fake images
324 | Returns:
325 | the calculated loss.
326 | """
327 | if self.gan_mode in ["lsgan", "vanilla"]:
328 | target_tensor = self.get_target_tensor(prediction, target_is_real, mask)
329 | loss = self.loss(prediction, target_tensor)
330 | elif self.gan_mode == "wgangp":
331 | if target_is_real:
332 | loss = -prediction.mean()
333 | else:
334 | loss = prediction.mean()
335 | return loss
336 |
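
# --- Usage sketch (added for illustration; "_demo_gan_loss" is hypothetical). ---
# LSGAN loss for a patch prediction map against "real" and "fake" targets.
def _demo_gan_loss():
    criterion = GANLoss("lsgan")
    prediction = torch.randn(4, 1, 30, 30)
    loss_real = criterion(prediction, True)
    loss_fake = criterion(prediction, False)
    return loss_real.item(), loss_fake.item()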
337 |
338 | def cal_gradient_penalty(
339 | netD, real_data, fake_data, device, type="mixed", constant=1.0, lambda_gp=10.0
340 | ):
341 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
342 | Arguments:
343 | netD (network) -- discriminator network
344 | real_data (tensor array) -- real images
345 | fake_data (tensor array) -- generated images from the generator
346 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
347 | type (str) -- if we mix real and fake data or not [real | fake | mixed].
348 | constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2
349 | lambda_gp (float) -- weight for this loss
350 | Returns the gradient penalty loss
351 | """
352 | if lambda_gp > 0.0:
353 | if (
354 | type == "real"
355 | ): # either use real images, fake images, or a linear interpolation of two.
356 | interpolatesv = real_data
357 | elif type == "fake":
358 | interpolatesv = fake_data
359 | elif type == "mixed":
360 | alpha = torch.rand(real_data.shape[0], 1, device=device)
361 | alpha = (
362 | alpha.expand(
363 | real_data.shape[0], real_data.nelement() // real_data.shape[0]
364 | )
365 | .contiguous()
366 | .view(*real_data.shape)
367 | )
368 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
369 | else:
370 | raise NotImplementedError("{} not implemented".format(type))
371 | interpolatesv.requires_grad_(True)
372 | disc_interpolates = netD(interpolatesv)
373 | gradients = torch.autograd.grad(
374 | outputs=disc_interpolates,
375 | inputs=interpolatesv,
376 | grad_outputs=torch.ones(disc_interpolates.size()).to(device),
377 | create_graph=True,
378 | retain_graph=True,
379 | only_inputs=True,
380 | )
381 |         gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
382 | gradient_penalty = (
383 | ((gradients + 1e-16).norm(2, dim=1) - constant) ** 2
384 | ).mean() * lambda_gp # added eps
385 | return gradient_penalty, gradients
386 | else:
387 | return 0.0, None
388 |
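
# --- Usage sketch (added for illustration; "_demo_gradient_penalty" is hypothetical;
# tiny tensors, CPU only). Computes the WGAN-GP penalty for a 1x1 PixelDiscriminator
# (defined later in this file) on mixed real/fake samples.
def _demo_gradient_penalty():
    netD = PixelDiscriminator(input_nc=3, ndf=8)
    real = torch.randn(2, 3, 16, 16)
    fake = torch.randn(2, 3, 16, 16)
    gp, _ = cal_gradient_penalty(netD, real, fake, device="cpu")
    return gp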
389 |
390 | class ResnetGenerator(nn.Module):
391 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
392 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
393 | """
394 |
395 | def __init__(
396 | self,
397 | input_nc,
398 | output_nc,
399 | ngf=64,
400 | norm_layer=nn.BatchNorm2d,
401 | use_dropout=False,
402 | n_blocks=6,
403 | padding_type="reflect",
404 | ):
405 | """Construct a Resnet-based generator
406 | Parameters:
407 | input_nc (int) -- the number of channels in input images
408 | output_nc (int) -- the number of channels in output images
409 | ngf (int) -- the number of filters in the last conv layer
410 | norm_layer -- normalization layer
411 | use_dropout (bool) -- if use dropout layers
412 | n_blocks (int) -- the number of ResNet blocks
413 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
414 | """
415 | assert n_blocks >= 0
416 | super(ResnetGenerator, self).__init__()
417 | if type(norm_layer) == functools.partial:
418 | use_bias = norm_layer.func == nn.InstanceNorm2d
419 | else:
420 | use_bias = norm_layer == nn.InstanceNorm2d
421 |
422 | model = [
423 | nn.ReflectionPad2d(3),
424 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
425 | norm_layer(ngf),
426 | nn.ReLU(True),
427 | ]
428 |
429 | n_downsampling = 2
430 | for i in range(n_downsampling): # add downsampling layers
431 | mult = 2**i
432 | model += [
433 | nn.Conv2d(
434 | ngf * mult,
435 | ngf * mult * 2,
436 | kernel_size=3,
437 | stride=2,
438 | padding=1,
439 | bias=use_bias,
440 | ),
441 | norm_layer(ngf * mult * 2),
442 | nn.ReLU(True),
443 | ]
444 |
445 | mult = 2**n_downsampling
446 | for i in range(n_blocks): # add ResNet blocks
447 |
448 | model += [
449 | ResnetBlock(
450 | ngf * mult,
451 | padding_type=padding_type,
452 | norm_layer=norm_layer,
453 | use_dropout=use_dropout,
454 | use_bias=use_bias,
455 | )
456 | ]
457 |
458 | for i in range(n_downsampling): # add upsampling layers
459 | mult = 2 ** (n_downsampling - i)
460 | model += [
461 | nn.ConvTranspose2d(
462 | ngf * mult,
463 | int(ngf * mult / 2),
464 | kernel_size=3,
465 | stride=2,
466 | padding=1,
467 | output_padding=1,
468 | bias=use_bias,
469 | ),
470 | norm_layer(int(ngf * mult / 2)),
471 | nn.ReLU(True),
472 | ]
473 | model += [nn.ReflectionPad2d(3)]
474 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
475 | model += [nn.Tanh()]
476 |
477 | self.model = nn.Sequential(*model)
478 |
479 | def forward(self, input):
480 | """Standard forward"""
481 | return self.model(input)
482 |
483 |
484 | class ResnetBlock(nn.Module):
485 | """Define a Resnet block"""
486 |
487 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
488 | """Initialize the Resnet block
489 | A resnet block is a conv block with skip connections
490 |         We construct a conv block with the build_conv_block function,
491 |         and implement skip connections in the forward function.
492 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
493 | """
494 | super(ResnetBlock, self).__init__()
495 | self.conv_block = self.build_conv_block(
496 | dim, padding_type, norm_layer, use_dropout, use_bias
497 | )
498 |
499 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
500 | """Construct a convolutional block.
501 | Parameters:
502 | dim (int) -- the number of channels in the conv layer.
503 | padding_type (str) -- the name of padding layer: reflect | replicate | zero
504 | norm_layer -- normalization layer
505 | use_dropout (bool) -- if use dropout layers.
506 | use_bias (bool) -- if the conv layer uses bias or not
507 | Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
508 | """
509 | conv_block = []
510 | p = 0
511 | if padding_type == "reflect":
512 | conv_block += [nn.ReflectionPad2d(1)]
513 | elif padding_type == "replicate":
514 | conv_block += [nn.ReplicationPad2d(1)]
515 | elif padding_type == "zero":
516 | p = 1
517 | else:
518 | raise NotImplementedError("padding [%s] is not implemented" % padding_type)
519 |
520 | conv_block += [
521 | nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
522 | norm_layer(dim),
523 | nn.ReLU(True),
524 | ]
525 | if use_dropout:
526 | conv_block += [nn.Dropout(0.5)]
527 |
528 | p = 0
529 | if padding_type == "reflect":
530 | conv_block += [nn.ReflectionPad2d(1)]
531 | elif padding_type == "replicate":
532 | conv_block += [nn.ReplicationPad2d(1)]
533 | elif padding_type == "zero":
534 | p = 1
535 | else:
536 | raise NotImplementedError("padding [%s] is not implemented" % padding_type)
537 | conv_block += [
538 | nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
539 | norm_layer(dim),
540 | ]
541 |
542 | return nn.Sequential(*conv_block)
543 |
544 | def forward(self, x):
545 | """Forward function (with skip connections)"""
546 | out = x + self.conv_block(x) # add skip connections
547 | return out
548 |
549 |
550 | class UnetGenerator(nn.Module):
551 | """Create a Unet-based generator"""
552 |
553 | def __init__(
554 | self,
555 | input_nc,
556 | output_nc,
557 | num_downs,
558 | ngf=64,
559 | norm_layer=nn.BatchNorm2d,
560 | use_dropout=False,
561 | ):
562 | """Construct a Unet generator
563 | Parameters:
564 | input_nc (int) -- the number of channels in input images
565 | output_nc (int) -- the number of channels in output images
566 |             num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
567 |                                 an image of size 128x128 will become of size 1x1 at the bottleneck
568 | ngf (int) -- the number of filters in the last conv layer
569 | norm_layer -- normalization layer
570 | We construct the U-Net from the innermost layer to the outermost layer.
571 | It is a recursive process.
572 | """
573 | super(UnetGenerator, self).__init__()
574 | # construct unet structure
575 | unet_block = UnetSkipConnectionBlock(
576 | ngf * 8,
577 | ngf * 8,
578 | input_nc=None,
579 | submodule=None,
580 | norm_layer=norm_layer,
581 | innermost=True,
582 | ) # add the innermost layer
583 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
584 | unet_block = UnetSkipConnectionBlock(
585 | ngf * 8,
586 | ngf * 8,
587 | input_nc=None,
588 | submodule=unet_block,
589 | norm_layer=norm_layer,
590 | use_dropout=use_dropout,
591 | )
592 | # gradually reduce the number of filters from ngf * 8 to ngf
593 | unet_block = UnetSkipConnectionBlock(
594 | ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer
595 | )
596 | unet_block = UnetSkipConnectionBlock(
597 | ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer
598 | )
599 | unet_block = UnetSkipConnectionBlock(
600 | ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer
601 | )
602 | self.model = UnetSkipConnectionBlock(
603 | output_nc,
604 | ngf,
605 | input_nc=input_nc,
606 | submodule=unet_block,
607 | outermost=True,
608 | norm_layer=norm_layer,
609 | ) # add the outermost layer
610 |
611 | def forward(self, input):
612 | """Standard forward"""
613 | return self.model(input)
614 |
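
# --- Usage sketch (added for illustration; "_demo_unet_generator" is hypothetical). ---
# An 8-level U-Net needs inputs whose side length is divisible by 2**8, e.g. 256.
def _demo_unet_generator():
    net = UnetGenerator(input_nc=3, output_nc=3, num_downs=8, ngf=64)
    with torch.no_grad():
        out = net(torch.randn(1, 3, 256, 256))
    return out.shape  # torch.Size([1, 3, 256, 256])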
615 |
616 | class UnetSkipConnectionBlock(nn.Module):
617 | """Defines the Unet submodule with skip connection.
618 | X -------------------identity----------------------
619 | |-- downsampling -- |submodule| -- upsampling --|
620 | """
621 |
622 | def __init__(
623 | self,
624 | outer_nc,
625 | inner_nc,
626 | input_nc=None,
627 | submodule=None,
628 | outermost=False,
629 | innermost=False,
630 | norm_layer=nn.BatchNorm2d,
631 | use_dropout=False,
632 | ):
633 | """Construct a Unet submodule with skip connections.
634 | Parameters:
635 | outer_nc (int) -- the number of filters in the outer conv layer
636 | inner_nc (int) -- the number of filters in the inner conv layer
637 | input_nc (int) -- the number of channels in input images/features
638 | submodule (UnetSkipConnectionBlock) -- previously defined submodules
639 | outermost (bool) -- if this module is the outermost module
640 | innermost (bool) -- if this module is the innermost module
641 | norm_layer -- normalization layer
642 | use_dropout (bool) -- if use dropout layers.
643 | """
644 | super(UnetSkipConnectionBlock, self).__init__()
645 | self.outermost = outermost
646 | if type(norm_layer) == functools.partial:
647 | use_bias = norm_layer.func == nn.InstanceNorm2d
648 | else:
649 | use_bias = norm_layer == nn.InstanceNorm2d
650 | if input_nc is None:
651 | input_nc = outer_nc
652 | downconv = nn.Conv2d(
653 | input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
654 | )
655 | downrelu = nn.LeakyReLU(0.2, True)
656 | downnorm = norm_layer(inner_nc)
657 | uprelu = nn.ReLU(True)
658 | upnorm = norm_layer(outer_nc)
659 |
660 | if outermost:
661 | upconv = nn.ConvTranspose2d(
662 | inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1
663 | )
664 | down = [downconv]
665 | up = [uprelu, upconv, nn.Tanh()]
666 | model = down + [submodule] + up
667 | elif innermost:
668 | upconv = nn.ConvTranspose2d(
669 | inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
670 | )
671 | down = [downrelu, downconv]
672 | up = [uprelu, upconv, upnorm]
673 | model = down + up
674 | else:
675 | upconv = nn.ConvTranspose2d(
676 | inner_nc * 2,
677 | outer_nc,
678 | kernel_size=4,
679 | stride=2,
680 | padding=1,
681 | bias=use_bias,
682 | )
683 | down = [downrelu, downconv, downnorm]
684 | up = [uprelu, upconv, upnorm]
685 |
686 | if use_dropout:
687 | model = down + [submodule] + up + [nn.Dropout(0.5)]
688 | else:
689 | model = down + [submodule] + up
690 |
691 | self.model = nn.Sequential(*model)
692 |
693 | def forward(self, x):
694 | if self.outermost:
695 | return self.model(x)
696 | else: # add skip connections
697 | return torch.cat([x, self.model(x)], 1)
698 |
699 |
700 | class NLayerDiscriminator(nn.Module):
701 | """Defines a PatchGAN discriminator"""
702 |
703 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
704 | """Construct a PatchGAN discriminator
705 | Parameters:
706 | input_nc (int) -- the number of channels in input images
707 | ndf (int) -- the number of filters in the last conv layer
708 | n_layers (int) -- the number of conv layers in the discriminator
709 | norm_layer -- normalization layer
710 | """
711 | super(NLayerDiscriminator, self).__init__()
712 | if (
713 | type(norm_layer) == functools.partial
714 | ): # no need to use bias as BatchNorm2d has affine parameters
715 | use_bias = norm_layer.func == nn.InstanceNorm2d
716 | else:
717 | use_bias = norm_layer == nn.InstanceNorm2d
718 |
719 | kw = 4
720 | padw = 1
721 | sequence = [
722 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
723 | nn.LeakyReLU(0.2, True),
724 | ]
725 | nf_mult = 1
726 | nf_mult_prev = 1
727 | for n in range(1, n_layers): # gradually increase the number of filters
728 | nf_mult_prev = nf_mult
729 | nf_mult = min(2**n, 8)
730 | sequence += [
731 | nn.Conv2d(
732 | ndf * nf_mult_prev,
733 | ndf * nf_mult,
734 | kernel_size=kw,
735 | stride=2,
736 | padding=padw,
737 | bias=use_bias,
738 | ),
739 | norm_layer(ndf * nf_mult),
740 | nn.LeakyReLU(0.2, True),
741 | ]
742 |
743 | nf_mult_prev = nf_mult
744 | nf_mult = min(2**n_layers, 8)
745 | sequence += [
746 | nn.Conv2d(
747 | ndf * nf_mult_prev,
748 | ndf * nf_mult,
749 | kernel_size=kw,
750 | stride=1,
751 | padding=padw,
752 | bias=use_bias,
753 | ),
754 | norm_layer(ndf * nf_mult),
755 | nn.LeakyReLU(0.2, True),
756 | ]
757 |
758 | sequence += [
759 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
760 | ] # output 1 channel prediction map
761 | self.model = nn.Sequential(*sequence)
762 |
763 | def forward(self, input):
764 | """Standard forward."""
765 | return self.model(input)
766 |
767 |
768 | class PixelDiscriminator(nn.Module):
769 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
770 |
771 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
772 | """Construct a 1x1 PatchGAN discriminator
773 | Parameters:
774 | input_nc (int) -- the number of channels in input images
775 | ndf (int) -- the number of filters in the last conv layer
776 | norm_layer -- normalization layer
777 | """
778 | super(PixelDiscriminator, self).__init__()
779 | if (
780 | type(norm_layer) == functools.partial
781 | ): # no need to use bias as BatchNorm2d has affine parameters
782 | use_bias = norm_layer.func == nn.InstanceNorm2d
783 | else:
784 | use_bias = norm_layer == nn.InstanceNorm2d
785 |
786 | self.net = [
787 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
788 | nn.LeakyReLU(0.2, True),
789 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
790 | norm_layer(ndf * 2),
791 | nn.LeakyReLU(0.2, True),
792 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
793 | ]
794 |
795 | self.net = nn.Sequential(*self.net)
796 |
797 | def forward(self, input):
798 | """Standard forward."""
799 | return self.net(input)
800 |
--------------------------------------------------------------------------------
/utils/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe. All rights reserved.
2 | # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License. You may obtain a copy
4 | # of the License at http://www.apache.org/licenses/LICENSE-2.0
5 |
6 | # Unless required by applicable law or agreed to in writing, software distributed under
7 | # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8 | # OF ANY KIND, either express or implied. See the License for the specific language
9 | # governing permissions and limitations under the License.
10 | import torch.nn as nn
11 | import torch.utils.model_zoo as model_zoo
12 | import torch
13 | import torch.nn.functional as f
14 | # from vit_pytorch import ViT
15 |
16 | from utils.efficientnet_v2 import EfficientNetV2
17 | from utils.mobilenet_v3 import MobileNetV3, mobilenetv3_small
18 |
19 | __all__ = [
20 | "ResNet",
21 | "resnet18",
22 | "resnet34",
23 | "resnet50",
24 | "resnet101",
25 | "resnet152",
26 | "resnext50_32x4d",
27 | "resnext101_32x8d",
28 | ]
29 |
30 |
31 | model_urls = {
32 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
33 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
34 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
35 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
36 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
37 | }
38 |
39 |
40 | def conv3x3(in_planes, out_planes, stride=1, groups=1):
41 | """3x3 convolution with padding"""
42 | return nn.Conv2d(
43 | in_planes,
44 | out_planes,
45 | kernel_size=3,
46 | stride=stride,
47 | padding=1,
48 | groups=groups,
49 | bias=False,
50 | )
51 |
52 |
53 | def conv1x1(in_planes, out_planes, stride=1):
54 | """1x1 convolution"""
55 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
56 |
57 |
58 | class BasicBlock(nn.Module):
59 | expansion = 1
60 |
61 | def __init__(
62 | self,
63 | inplanes,
64 | planes,
65 | stride=1,
66 | downsample=None,
67 | groups=1,
68 | base_width=64,
69 | norm_layer=None,
70 | ):
71 | super(BasicBlock, self).__init__()
72 | if norm_layer is None:
73 | norm_layer = nn.BatchNorm2d
74 | if groups != 1 or base_width != 64:
75 | raise ValueError("BasicBlock only supports groups=1 and base_width=64")
76 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
77 | self.conv1 = conv3x3(inplanes, planes, stride)
78 | self.bn1 = norm_layer(planes)
79 | self.relu = nn.ReLU(inplace=False)
80 | self.conv2 = conv3x3(planes, planes)
81 | self.bn2 = norm_layer(planes)
82 | self.downsample = downsample
83 | self.stride = stride
84 |
85 | def forward(self, x):
86 | identity = x
87 |
88 | out = self.conv1(x)
89 | out = self.bn1(out)
90 | out = self.relu(out)
91 |
92 | out = self.conv2(out)
93 | out = self.bn2(out)
94 |
95 | if self.downsample is not None:
96 | identity = self.downsample(x)
97 |
98 | out = out + identity
99 | out = self.relu(out)
100 |
101 | return out
102 |
103 |
104 | class Bottleneck(nn.Module):
105 | expansion = 4
106 |
107 | def __init__(
108 | self,
109 | inplanes,
110 | planes,
111 | stride=1,
112 | downsample=None,
113 | groups=1,
114 | base_width=64,
115 | norm_layer=None,
116 | ):
117 | super(Bottleneck, self).__init__()
118 | if norm_layer is None:
119 | norm_layer = nn.BatchNorm2d
120 | width = int(planes * (base_width / 64.0)) * groups
121 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
122 | self.conv1 = conv1x1(inplanes, width)
123 | self.bn1 = norm_layer(width)
124 | self.conv2 = conv3x3(width, width, stride, groups)
125 | self.bn2 = norm_layer(width)
126 | self.conv3 = conv1x1(width, planes * self.expansion)
127 | self.bn3 = norm_layer(planes * self.expansion)
128 | self.relu = nn.ReLU(inplace=False)
129 | self.downsample = downsample
130 | self.stride = stride
131 |
132 | def forward(self, x):
133 | identity = x
134 |
135 | out = self.conv1(x)
136 | out = self.bn1(out)
137 | out = self.relu(out)
138 |
139 | out = self.conv2(out)
140 | out = self.bn2(out)
141 | out = self.relu(out)
142 |
143 | out = self.conv3(out)
144 | out = self.bn3(out)
145 |
146 | if self.downsample is not None:
147 | identity = self.downsample(x)
148 |
149 | out = out + identity
150 | out = self.relu(out)
151 |
152 | return out
153 |
154 |
155 | class ResNet(nn.Module):
156 | def __init__(
157 | self,
158 | block,
159 | layers,
160 | input_f=2,
161 | num_classes=1000,
162 | zero_init_residual=False,
163 | groups=1,
164 | width_per_group=64,
165 | norm_layer=None,
166 | sigmoid=False,
167 | ):
168 | super(ResNet, self).__init__()
169 | if norm_layer is None:
170 | norm_layer = nn.BatchNorm2d
171 |
172 | self.inplanes = 64
173 | self.groups = groups
174 | self.base_width = width_per_group
175 | print(input_f)
176 | self.conv1 = nn.Conv2d(
177 | input_f, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
178 | )
179 | self.bn1 = norm_layer(self.inplanes)
180 | self.relu = nn.ReLU(inplace=False)
181 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
182 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
183 | self.layer2 = self._make_layer(
184 | block, 128, layers[1], stride=2, norm_layer=norm_layer
185 | )
186 | self.layer3 = self._make_layer(
187 | block, 256, layers[2], stride=2, norm_layer=norm_layer
188 | )
189 | self.layer4 = self._make_layer(
190 | block, 512, layers[3], stride=2, norm_layer=norm_layer
191 | )
192 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
193 | self.fc = nn.Linear(512 * block.expansion, num_classes)
194 | self.sigmoid = sigmoid
195 | for m in self.modules():
196 | if isinstance(m, nn.Conv2d):
197 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
198 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
199 | nn.init.constant_(m.weight, 1)
200 | nn.init.constant_(m.bias, 0)
201 |
202 | # Zero-initialize the last BN in each residual branch,
203 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
204 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
205 | if zero_init_residual:
206 | for m in self.modules():
207 | if isinstance(m, Bottleneck):
208 | nn.init.constant_(m.bn3.weight, 0)
209 | elif isinstance(m, BasicBlock):
210 | nn.init.constant_(m.bn2.weight, 0)
211 |
212 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):
213 | if norm_layer is None:
214 | norm_layer = nn.BatchNorm2d
215 | downsample = None
216 | if stride != 1 or self.inplanes != planes * block.expansion:
217 | downsample = nn.Sequential(
218 | conv1x1(self.inplanes, planes * block.expansion, stride),
219 | norm_layer(planes * block.expansion),
220 | )
221 |
222 | layers = []
223 | layers.append(
224 | block(
225 | self.inplanes,
226 | planes,
227 | stride,
228 | downsample,
229 | self.groups,
230 | self.base_width,
231 | norm_layer,
232 | )
233 | )
234 | self.inplanes = planes * block.expansion
235 | for _ in range(1, blocks):
236 | layers.append(
237 | block(
238 | self.inplanes,
239 | planes,
240 | groups=self.groups,
241 | base_width=self.base_width,
242 | norm_layer=norm_layer,
243 | )
244 | )
245 |
246 | return nn.Sequential(*layers)
247 |
248 | def forward(self, x):
249 | x = self.conv1(x)
250 | x = self.bn1(x)
251 | x = self.relu(x)
252 | x = self.maxpool(x)
253 |
254 | x = self.layer1(x)
255 | x = self.layer2(x)
256 | x = self.layer3(x)
257 | x = self.layer4(x)
258 |
259 | x = self.avgpool(x)
260 | x1 = x.view(x.size(0), -1)
261 | x = self.fc(x1)
262 | if self.sigmoid:
263 | x = nn.Sigmoid()(x)
264 | else:
265 | pass
266 | # x = f.normalize(x, p=2, dim=1)
267 | return x, x1
268 |
269 |
270 | class ResNet_PIH(nn.Module):
271 | def __init__(
272 | self,
273 | block,
274 | layers,
275 | input_f=2,
276 | zero_init_residual=False,
277 | groups=1,
278 | width_per_group=64,
279 | norm_layer=None,
280 | ):
281 | super(ResNet_PIH, self).__init__()
282 | if norm_layer is None:
283 | norm_layer = nn.BatchNorm2d
284 |
285 | self.inplanes = 64
286 | self.groups = groups
287 | self.base_width = width_per_group
288 | print(input_f)
289 | self.conv1 = nn.Conv2d(
290 | input_f, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
291 | )
292 | self.bn1 = norm_layer(self.inplanes)
293 | self.relu = nn.ReLU(inplace=False)
294 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
295 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
296 | self.layer2 = self._make_layer(
297 | block, 128, layers[1], stride=2, norm_layer=norm_layer
298 | )
299 | self.layer3 = self._make_layer(
300 | block, 256, layers[2], stride=2, norm_layer=norm_layer
301 | )
302 | self.layer4 = self._make_layer(
303 | block, 512, layers[3], stride=2, norm_layer=norm_layer
304 | )
305 | self.avgpool = nn.AdaptiveAvgPool2d((3, 3))
306 | for m in self.modules():
307 | if isinstance(m, nn.Conv2d):
308 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
309 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
310 | nn.init.constant_(m.weight, 1)
311 | nn.init.constant_(m.bias, 0)
312 |
313 | # Zero-initialize the last BN in each residual branch,
314 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
315 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
316 | if zero_init_residual:
317 | for m in self.modules():
318 | if isinstance(m, Bottleneck):
319 | nn.init.constant_(m.bn3.weight, 0)
320 | elif isinstance(m, BasicBlock):
321 | nn.init.constant_(m.bn2.weight, 0)
322 |
323 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):
324 | if norm_layer is None:
325 | norm_layer = nn.BatchNorm2d
326 | downsample = None
327 | if stride != 1 or self.inplanes != planes * block.expansion:
328 | downsample = nn.Sequential(
329 | conv1x1(self.inplanes, planes * block.expansion, stride),
330 | norm_layer(planes * block.expansion),
331 | )
332 |
333 | layers = []
334 | layers.append(
335 | block(
336 | self.inplanes,
337 | planes,
338 | stride,
339 | downsample,
340 | self.groups,
341 | self.base_width,
342 | norm_layer,
343 | )
344 | )
345 | self.inplanes = planes * block.expansion
346 | for _ in range(1, blocks):
347 | layers.append(
348 | block(
349 | self.inplanes,
350 | planes,
351 | groups=self.groups,
352 | base_width=self.base_width,
353 | norm_layer=norm_layer,
354 | )
355 | )
356 |
357 | return nn.Sequential(*layers)
358 |
359 | def forward(self, x):
360 | x = self.conv1(x)
361 | x = self.bn1(x)
362 | x = self.relu(x)
363 | x = self.maxpool(x)
364 |
365 | x = self.layer1(x)
366 | x = self.layer2(x)
367 | x = self.layer3(x)
368 | x = self.layer4(x)
369 |
370 | x = self.avgpool(x)
371 | x = x.view(x.size(0), -1)
372 |
373 | return x
374 |
375 |
376 |
377 |
378 |
379 | class EffNetV2(nn.Module):
380 | def __init__(
381 | self, pretrained=False, input_f=7, num_classes=1000, sigmoid=False, **kwargs
382 | ):
383 | super(EffNetV2, self).__init__()
384 | self.num_classes = num_classes
385 |
386 | self.EffNet = EfficientNetV2(
387 | "s", in_channels=input_f, n_classes=self.num_classes, pretrained=False
388 |         )  # EfficientNetV2-S backbone; the wrapper's pretrained flag is not forwarded (pretrained=False is hard-coded)
389 | self.sigmoid = sigmoid
390 |
391 | def forward(self, x):
392 |
393 | feature_output = self.EffNet(x)
394 | if self.sigmoid:
395 | feature_output = nn.Sigmoid()(feature_output)
396 |
397 |         return feature_output, 0  # the trailing 0 is a placeholder, matching the two-value return of the ResNet models
398 |
399 |
400 |
401 |
402 | class MobileNetV3(nn.Module):
403 | def __init__(
404 | self, pretrained=False, input_f=7, num_classes=1000, sigmoid=False, **kwargs
405 | ):
406 | super(MobileNetV3, self).__init__()
407 | self.num_classes = num_classes
408 |
409 |         self.EffNet = mobilenetv3_small(input_f, num_classes)  # MobileNetV3-Small backbone
410 | self.sigmoid = sigmoid
411 |
412 | def forward(self, x):
413 |
414 | feature_output = self.EffNet(x)
415 | if self.sigmoid:
416 | feature_output = nn.Sigmoid()(feature_output)
417 |
418 | return feature_output, 0
419 |
420 |
421 | class PIHNet(nn.Module):
422 | def __init__(
423 | self, pretrained=False, input_f=7, num_classes=1000, sigmoid=False, **kwargs
424 | ):
425 | super(PIHNet, self).__init__()
426 | self.model_bg = ResNet_PIH(BasicBlock, [3, 4, 6, 3], input_f=4, **kwargs)
427 |
428 | self.model_fg = ResNet_PIH(BasicBlock, [3, 4, 6, 3], input_f=4, **kwargs)
429 |
430 | self.model_cp = ResNet_PIH(BasicBlock, [3, 4, 6, 3], input_f=4, **kwargs)
431 |
432 | self.classifier = nn.Sequential(
433 | nn.Linear(512 * 3 * 3, 1024),
434 | nn.ReLU(True),
435 | nn.Linear(1024, num_classes),
436 | )
437 | self.sigmoid = sigmoid
438 |
439 | def forward(self, x):
440 |         input_image = x[:, :3, ...]  # composite input image (RGB)
441 |         bg_image = x[:, 3:6, ...]  # background image (RGB)
442 |         mask_image = x[:, 6:, ...]  # foreground mask
443 | feature_bg = self.model_bg(torch.cat((1 - mask_image, bg_image), 1))
444 | feature_fg = self.model_fg(torch.cat((mask_image, input_image * mask_image), 1))
445 | feature_cp = self.model_cp(torch.cat((mask_image, input_image), 1))
446 |
447 | feature_all = feature_bg + feature_fg + feature_cp
448 | feature_output = self.classifier(feature_all)
449 | if self.sigmoid:
450 | feature_output = nn.Sigmoid()(feature_output)
451 | return feature_output, 0
452 |
453 |
454 | def resnet18(pretrained=False, input_f=4, num_classes=1000, **kwargs):
455 | """Constructs a ResNet-18 model.
456 | Args:
457 | pretrained (bool): If True, returns a model pre-trained on ImageNet
458 | """
459 | model = ResNet(
460 | BasicBlock, [2, 2, 2, 2], input_f=input_f, num_classes=num_classes, **kwargs
461 | )
462 | if pretrained:
463 | model.load_state_dict(model_zoo.load_url(model_urls["resnet18"]))
464 | return model
465 |
466 |
467 | def resnet18_m(pretrained=False, num_classes=1000, **kwargs):
468 |     """Constructs a ResNet-18 model with a single-channel input (input_f=1).
469 | Args:
470 | pretrained (bool): If True, returns a model pre-trained on ImageNet
471 | """
472 | model = ResNet(
473 | BasicBlock, [2, 2, 2, 2], input_f=1, num_classes=num_classes, **kwargs
474 | )
475 | if pretrained:
476 | model.load_state_dict(model_zoo.load_url(model_urls["resnet18"]))
477 | return model
478 |
479 |
480 | def resnet34(pretrained=False, input_f=4, num_classes=1000, sigmoid=False, **kwargs):
481 | """Constructs a ResNet-34 model.
482 | Args:
483 | pretrained (bool): If True, returns a model pre-trained on ImageNet
484 | """
485 | model = ResNet(
486 | BasicBlock,
487 | [3, 4, 6, 3],
488 | input_f=input_f,
489 | num_classes=num_classes,
490 | sigmoid=sigmoid,
491 | **kwargs
492 | )
493 | if pretrained:
494 | model.load_state_dict(model_zoo.load_url(model_urls["resnet34"]))
495 | return model
496 |
497 |
498 | def resnet50(pretrained=False, input_f=4, num_classes=1000, sigmoid=False, **kwargs):
499 | """Constructs a ResNet-50 model.
500 | Args:
501 | pretrained (bool): If True, returns a model pre-trained on ImageNet
502 | """
503 | model = ResNet(
504 | Bottleneck,
505 | [3, 4, 6, 3],
506 | num_classes=num_classes,
507 | input_f=input_f,
508 | sigmoid=sigmoid,
509 | **kwargs
510 | )
511 | if pretrained:
512 | model.load_state_dict(model_zoo.load_url(model_urls["resnet50"]))
513 | return model
514 |
515 |
516 | def resnet101(pretrained=False, input_f=4, num_classes=1000, sigmoid=False, **kwargs):
517 | """Constructs a ResNet-101 model.
518 | Args:
519 | pretrained (bool): If True, returns a model pre-trained on ImageNet
520 | """
521 | # print("Using Resnet 101")
522 | model = ResNet(Bottleneck, [3, 4, 23, 3],
523 | num_classes=num_classes,
524 | input_f=input_f,
525 | sigmoid=sigmoid, **kwargs)
526 | if pretrained:
527 | model.load_state_dict(model_zoo.load_url(model_urls["resnet101"]))
528 | return model
529 |
530 |
531 | def resnet152(pretrained=False, **kwargs):
532 | """Constructs a ResNet-152 model.
533 | Args:
534 | pretrained (bool): If True, returns a model pre-trained on ImageNet
535 | """
536 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
537 | if pretrained:
538 | model.load_state_dict(model_zoo.load_url(model_urls["resnet152"]))
539 | return model
540 |
541 |
542 | def resnext50_32x4d(pretrained=False, **kwargs):
543 | model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs)
544 | # if pretrained:
545 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d']))
546 | return model
547 |
548 |
549 | def resnext101_32x8d(pretrained=False, **kwargs):
550 | model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs)
551 | # if pretrained:
552 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d']))
553 | return model
554 |
--------------------------------------------------------------------------------
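As a quick orientation for the file above: PIHNet runs three ResNet_PIH streams (background, foreground, composite) over 4-channel slices of a 7-channel input and sums their pooled features before the classifier. A minimal usage sketch, assuming the repository root is on the import path; the num_classes value below is arbitrary, not taken from the training scripts:

import torch
from utils.resnet import PIHNet

# 7-channel input: what appears to be composite RGB (0:3), background RGB (3:6)
# and a foreground mask (6:7), matching the slicing in PIHNet.forward.
net = PIHNet(num_classes=30, sigmoid=True)   # num_classes is illustrative
x = torch.randn(2, 7, 256, 256)
params, _ = net(x)        # the trailing 0 in the return is a placeholder
print(params.shape)       # torch.Size([2, 30])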
/utils/resnet_ibn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Adobe. All rights reserved.
2 | # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License. You may obtain a copy
4 | # of the License at http://www.apache.org/licenses/LICENSE-2.0
5 |
6 | # Unless required by applicable law or agreed to in writing, software distributed under
7 | # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8 | # OF ANY KIND, either express or implied. See the License for the specific language
9 | # governing permissions and limitations under the License.
10 | import math
11 | import warnings
12 |
13 | import torch
14 | import torch.nn as nn
15 |
16 | from utils.modules import IBN
17 |
18 |
19 | __all__ = [
20 | "ResNet_IBN",
21 | "resnet18_ibn_a",
22 | "resnet34_ibn_a",
23 | "resnet50_ibn_a",
24 | "resnet101_ibn_a",
25 | "resnet152_ibn_a",
26 | "resnet18_ibn_b",
27 | "resnet34_ibn_b",
28 | "resnet50_ibn_b",
29 | "resnet101_ibn_b",
30 | "resnet152_ibn_b",
31 | ]
32 |
33 |
34 | model_urls = {
35 | "resnet18_ibn_a": "https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth",
36 | "resnet34_ibn_a": "https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth",
37 | "resnet50_ibn_a": "https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth",
38 | "resnet101_ibn_a": "https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth",
39 | "resnet18_ibn_b": "https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_b-bc2f3c11.pth",
40 | "resnet34_ibn_b": "https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_b-04134c37.pth",
41 | "resnet50_ibn_b": "https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_b-9ca61e85.pth",
42 | "resnet101_ibn_b": "https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_b-c55f6dba.pth",
43 | }
44 |
45 |
46 | class BasicBlock_IBN(nn.Module):
47 | expansion = 1
48 |
49 | def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None):
50 | super(BasicBlock_IBN, self).__init__()
51 | self.conv1 = nn.Conv2d(
52 | inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False
53 | )
54 | if ibn == "a":
55 | self.bn1 = IBN(planes)
56 | else:
57 | self.bn1 = nn.BatchNorm2d(planes)
58 | self.relu = nn.ReLU(inplace=True)
59 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
60 | self.bn2 = nn.BatchNorm2d(planes)
61 | self.IN = nn.InstanceNorm2d(planes, affine=True) if ibn == "b" else None
62 | self.downsample = downsample
63 | self.stride = stride
64 |
65 | def forward(self, x):
66 | residual = x
67 |
68 | out = self.conv1(x)
69 | out = self.bn1(out)
70 | out = self.relu(out)
71 |
72 | out = self.conv2(out)
73 | out = self.bn2(out)
74 |
75 | if self.downsample is not None:
76 | residual = self.downsample(x)
77 |
78 | out += residual
79 | if self.IN is not None:
80 | out = self.IN(out)
81 | out = self.relu(out)
82 |
83 | return out
84 |
85 |
86 | class Bottleneck_IBN(nn.Module):
87 | expansion = 4
88 |
89 | def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None):
90 | super(Bottleneck_IBN, self).__init__()
91 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
92 | if ibn == "a":
93 | self.bn1 = IBN(planes)
94 | else:
95 | self.bn1 = nn.BatchNorm2d(planes)
96 | self.conv2 = nn.Conv2d(
97 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
98 | )
99 | self.bn2 = nn.BatchNorm2d(planes)
100 | self.conv3 = nn.Conv2d(
101 | planes, planes * self.expansion, kernel_size=1, bias=False
102 | )
103 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
104 | self.IN = nn.InstanceNorm2d(planes * 4, affine=True) if ibn == "b" else None
105 | self.relu = nn.ReLU(inplace=True)
106 | self.downsample = downsample
107 | self.stride = stride
108 |
109 | def forward(self, x):
110 | residual = x
111 |
112 | out = self.conv1(x)
113 | out = self.bn1(out)
114 | out = self.relu(out)
115 |
116 | out = self.conv2(out)
117 | out = self.bn2(out)
118 | out = self.relu(out)
119 |
120 | out = self.conv3(out)
121 | out = self.bn3(out)
122 |
123 | if self.downsample is not None:
124 | residual = self.downsample(x)
125 |
126 | out += residual
127 | if self.IN is not None:
128 | out = self.IN(out)
129 | out = self.relu(out)
130 |
131 | return out
132 |
133 |
134 | class ResNet_IBN(nn.Module):
135 | def __init__(
136 | self,
137 | block,
138 | layers,
139 | ibn_cfg=("a", "a", "a", None),
140 | input_f=7,
141 | num_classes=1000,
142 | sigmoid=False,
143 | ):
144 | self.inplanes = 64
145 | self.sigmoid = sigmoid
146 | super(ResNet_IBN, self).__init__()
147 | self.conv1 = nn.Conv2d(
148 | input_f, 64, kernel_size=7, stride=2, padding=3, bias=False
149 | )
150 | if ibn_cfg[0] == "b":
151 | self.bn1 = nn.InstanceNorm2d(64, affine=True)
152 | else:
153 | self.bn1 = nn.BatchNorm2d(64)
154 | self.relu = nn.ReLU(inplace=True)
155 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
156 | self.layer1 = self._make_layer(block, 64, layers[0], ibn=ibn_cfg[0])
157 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, ibn=ibn_cfg[1])
158 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, ibn=ibn_cfg[2])
159 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, ibn=ibn_cfg[3])
160 | self.avgpool = nn.AvgPool2d(7)
161 |         self.fc = nn.Linear(512 * 4 * block.expansion, num_classes)  # the extra factor of 4 assumes the pooled map is 2x2 (e.g. ~512x512 inputs)
162 |
163 | for m in self.modules():
164 | if isinstance(m, nn.Conv2d):
165 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
166 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
167 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
168 | m.weight.data.fill_(1)
169 | m.bias.data.zero_()
170 |
171 | def _make_layer(self, block, planes, blocks, stride=1, ibn=None):
172 | downsample = None
173 | if stride != 1 or self.inplanes != planes * block.expansion:
174 | downsample = nn.Sequential(
175 | nn.Conv2d(
176 | self.inplanes,
177 | planes * block.expansion,
178 | kernel_size=1,
179 | stride=stride,
180 | bias=False,
181 | ),
182 | nn.BatchNorm2d(planes * block.expansion),
183 | )
184 |
185 | layers = []
186 | layers.append(
187 | block(
188 | self.inplanes, planes, None if ibn == "b" else ibn, stride, downsample
189 | )
190 | )
191 | self.inplanes = planes * block.expansion
192 | for i in range(1, blocks):
193 | layers.append(
194 | block(
195 | self.inplanes,
196 | planes,
197 | None if (ibn == "b" and i < blocks - 1) else ibn,
198 | )
199 | )
200 |
201 | return nn.Sequential(*layers)
202 |
203 | def forward(self, x):
204 | x = self.conv1(x)
205 | x = self.bn1(x)
206 | x = self.relu(x)
207 | x = self.maxpool(x)
208 |
209 | x = self.layer1(x)
210 | x = self.layer2(x)
211 | x = self.layer3(x)
212 | x = self.layer4(x)
213 |
214 | x = self.avgpool(x)
215 | x = x.view(x.size(0), -1)
216 | x = self.fc(x)
217 |
218 |         if self.sigmoid:
219 |             # squash the outputs into (0, 1) when requested
220 |             x = nn.Sigmoid()(x)
221 |
222 |
223 | return x, 0
224 |
225 |
226 | def resnet18_ibn_a(pretrained=False, **kwargs):
227 | """Constructs a ResNet-18-IBN-a model.
228 | Args:
229 | pretrained (bool): If True, returns a model pre-trained on ImageNet
230 | """
231 | model = ResNet_IBN(
232 | block=BasicBlock_IBN,
233 | layers=[2, 2, 2, 2],
234 | ibn_cfg=("a", "a", "a", None),
235 | **kwargs
236 | )
237 | if pretrained:
238 | model.load_state_dict(
239 | torch.hub.load_state_dict_from_url(model_urls["resnet18_ibn_a"])
240 | )
241 | return model
242 |
243 |
244 | def resnet34_ibn_a(pretrained=False, **kwargs):
245 | """Constructs a ResNet-34-IBN-a model.
246 | Args:
247 | pretrained (bool): If True, returns a model pre-trained on ImageNet
248 | """
249 | model = ResNet_IBN(
250 | block=BasicBlock_IBN,
251 | layers=[3, 4, 6, 3],
252 | ibn_cfg=("a", "a", "a", None),
253 | **kwargs
254 | )
255 | if pretrained:
256 | model.load_state_dict(
257 | torch.hub.load_state_dict_from_url(model_urls["resnet34_ibn_a"])
258 | )
259 | return model
260 |
261 |
262 | def resnet50_ibn_a(pretrained=False, **kwargs):
263 | """Constructs a ResNet-50-IBN-a model.
264 | Args:
265 | pretrained (bool): If True, returns a model pre-trained on ImageNet
266 | """
267 | model = ResNet_IBN(
268 | block=Bottleneck_IBN,
269 | layers=[3, 4, 6, 3],
270 | ibn_cfg=("a", "a", "a", None),
271 | **kwargs
272 | )
273 | if pretrained:
274 | model.load_state_dict(
275 | torch.hub.load_state_dict_from_url(model_urls["resnet50_ibn_a"])
276 | )
277 | return model
278 |
279 |
280 | def resnet101_ibn_a(pretrained=False, **kwargs):
281 | """Constructs a ResNet-101-IBN-a model.
282 | Args:
283 | pretrained (bool): If True, returns a model pre-trained on ImageNet
284 | """
285 | model = ResNet_IBN(
286 | block=Bottleneck_IBN,
287 | layers=[3, 4, 23, 3],
288 | ibn_cfg=("a", "a", "a", None),
289 | **kwargs
290 | )
291 | if pretrained:
292 | model.load_state_dict(
293 | torch.hub.load_state_dict_from_url(model_urls["resnet101_ibn_a"])
294 | )
295 | return model
296 |
297 |
298 | def resnet152_ibn_a(pretrained=False, **kwargs):
299 | """Constructs a ResNet-152-IBN-a model.
300 | Args:
301 | pretrained (bool): If True, returns a model pre-trained on ImageNet
302 | """
303 | model = ResNet_IBN(
304 | block=Bottleneck_IBN,
305 | layers=[3, 8, 36, 3],
306 | ibn_cfg=("a", "a", "a", None),
307 | **kwargs
308 | )
309 | if pretrained:
310 | warnings.warn("Pretrained model not available for ResNet-152-IBN-a!")
311 | return model
312 |
313 |
314 | def resnet18_ibn_b(pretrained=False, **kwargs):
315 | """Constructs a ResNet-18-IBN-b model.
316 | Args:
317 | pretrained (bool): If True, returns a model pre-trained on ImageNet
318 | """
319 | model = ResNet_IBN(
320 | block=BasicBlock_IBN,
321 | layers=[2, 2, 2, 2],
322 | ibn_cfg=("b", "b", None, None),
323 | **kwargs
324 | )
325 | if pretrained:
326 | model.load_state_dict(
327 | torch.hub.load_state_dict_from_url(model_urls["resnet18_ibn_b"])
328 | )
329 | return model
330 |
331 |
332 | def resnet34_ibn_b(pretrained=False, **kwargs):
333 | """Constructs a ResNet-34-IBN-b model.
334 | Args:
335 | pretrained (bool): If True, returns a model pre-trained on ImageNet
336 | """
337 | model = ResNet_IBN(
338 | block=BasicBlock_IBN,
339 | layers=[3, 4, 6, 3],
340 | ibn_cfg=("b", "b", None, None),
341 | **kwargs
342 | )
343 | if pretrained:
344 | model.load_state_dict(
345 | torch.hub.load_state_dict_from_url(model_urls["resnet34_ibn_b"])
346 | )
347 | return model
348 |
349 |
350 | def resnet50_ibn_b(
351 | pretrained=False, input_f=4, num_classes=1000, sigmoid=False, **kwargs
352 | ):
353 | """Constructs a ResNet-50-IBN-b model.
354 | Args:
355 |         pretrained (bool): Ignored; this variant never loads pretrained weights
356 | """
357 | model = ResNet_IBN(
358 | block=Bottleneck_IBN,
359 | layers=[3, 4, 6, 3],
360 | ibn_cfg=("b", "b", None, None),
361 | input_f=input_f,
362 | num_classes=num_classes,
363 | sigmoid=sigmoid,
364 | **kwargs
365 | )
366 | return model
367 |
368 |
369 | def resnet101_ibn_b(pretrained=False, **kwargs):
370 | """Constructs a ResNet-101-IBN-b model.
371 | Args:
372 | pretrained (bool): If True, returns a model pre-trained on ImageNet
373 | """
374 | model = ResNet_IBN(
375 | block=Bottleneck_IBN,
376 | layers=[3, 4, 23, 3],
377 | ibn_cfg=("b", "b", None, None),
378 | **kwargs
379 | )
380 | if pretrained:
381 | model.load_state_dict(
382 | torch.hub.load_state_dict_from_url(model_urls["resnet101_ibn_b"])
383 | )
384 | return model
385 |
386 |
387 | def resnet152_ibn_b(pretrained=False, **kwargs):
388 | """Constructs a ResNet-152-IBN-b model.
389 | Args:
390 |         pretrained (bool): No pretrained weights are available for this variant; a warning is issued if True
391 | """
392 | model = ResNet_IBN(
393 | block=Bottleneck_IBN,
394 | layers=[3, 8, 36, 3],
395 | ibn_cfg=("b", "b", None, None),
396 | **kwargs
397 | )
398 | if pretrained:
399 | warnings.warn("Pretrained model not available for ResNet-152-IBN-b!")
400 | return model
401 |
--------------------------------------------------------------------------------
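Similarly, the IBN-b backbone defined above can be instantiated for a 4-channel (image plus mask) input. The resolution and num_classes in this sketch are assumptions chosen so that the fixed AvgPool2d(7) / Linear(512 * 4 * block.expansion, ...) sizing works out; they are not taken from the training configuration:

import torch
from utils.resnet_ibn import resnet50_ibn_b

model = resnet50_ibn_b(input_f=4, num_classes=30, sigmoid=True)

# At 512x512 the layer4 map is 16x16; AvgPool2d(7) leaves a 2x2 map, and the
# flattened 2*2*2048 features match Linear(512 * 4 * block.expansion, ...).
x = torch.randn(1, 4, 512, 512)
out, _ = model(x)         # forward returns (predictions, 0)
print(out.shape)          # torch.Size([1, 30])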
/utils/unet/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet_model import UNet
2 |
--------------------------------------------------------------------------------
/utils/unet/unet_model.py:
--------------------------------------------------------------------------------
1 | """ Full assembly of the parts to form the complete network """
2 |
3 | from .unet_parts import *
4 |
5 |
6 | class UNet(nn.Module):
7 | def __init__(self, n_channels, n_classes, bilinear=False):
8 | super(UNet, self).__init__()
9 | self.n_channels = n_channels
10 | self.n_classes = n_classes
11 | self.bilinear = bilinear
12 |
13 | self.inc = DoubleConv(n_channels, 64)
14 | self.down1 = Down(64, 128)
15 | self.down2 = Down(128, 256)
16 | self.down3 = Down(256, 512)
17 | factor = 2 if bilinear else 1
18 | self.down4 = Down(512, 1024 // factor)
19 | self.up1 = Up(1024, 512 // factor, bilinear)
20 | self.up2 = Up(512, 256 // factor, bilinear)
21 | self.up3 = Up(256, 128 // factor, bilinear)
22 | self.up4 = Up(128, 64, bilinear)
23 | self.outc = OutConv(64, n_classes)
24 |
25 | def forward(self, x):
26 | x1 = self.inc(x)
27 | x2 = self.down1(x1)
28 | x3 = self.down2(x2)
29 | x4 = self.down3(x3)
30 | x5 = self.down4(x4)
31 | x = self.up1(x5, x4)
32 | x = self.up2(x, x3)
33 | x = self.up3(x, x2)
34 | x = self.up4(x, x1)
35 |         logits = self.outc(x)  # note: OutConv already applies a sigmoid, so these are probabilities rather than raw logits
36 | return logits
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
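A minimal shape check for the U-Net assembled above (the import path follows utils/unet/__init__.py; the channel counts and resolution are illustrative):

import torch
from utils.unet import UNet

net = UNet(n_channels=3, n_classes=1, bilinear=False)
x = torch.randn(1, 3, 256, 256)
y = net(x)                # OutConv ends with a sigmoid, so y lies in (0, 1)
print(y.shape)            # torch.Size([1, 1, 256, 256])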
/utils/unet/unet_parts.py:
--------------------------------------------------------------------------------
1 | """ Parts of the U-Net model """
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class DoubleConv(nn.Module):
9 | """(convolution => [BN] => ReLU) * 2"""
10 |
11 | def __init__(self, in_channels, out_channels, mid_channels=None):
12 | super().__init__()
13 | if not mid_channels:
14 | mid_channels = out_channels
15 | self.double_conv = nn.Sequential(
16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
17 | nn.BatchNorm2d(mid_channels),
18 | nn.ReLU(inplace=True),
19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
20 | nn.BatchNorm2d(out_channels),
21 | nn.ReLU(inplace=True)
22 | )
23 |
24 | def forward(self, x):
25 | return self.double_conv(x)
26 |
27 |
28 | class Down(nn.Module):
29 | """Downscaling with maxpool then double conv"""
30 |
31 | def __init__(self, in_channels, out_channels):
32 | super().__init__()
33 | self.maxpool_conv = nn.Sequential(
34 | nn.MaxPool2d(2),
35 | DoubleConv(in_channels, out_channels)
36 | )
37 |
38 | def forward(self, x):
39 | return self.maxpool_conv(x)
40 |
41 |
42 | class Up(nn.Module):
43 | """Upscaling then double conv"""
44 |
45 | def __init__(self, in_channels, out_channels, bilinear=True):
46 | super().__init__()
47 |
48 | # if bilinear, use the normal convolutions to reduce the number of channels
49 | if bilinear:
50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
52 | else:
53 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
54 | self.conv = DoubleConv(in_channels, out_channels)
55 |
56 | def forward(self, x1, x2):
57 | x1 = self.up(x1)
58 | # input is CHW
59 | diffY = x2.size()[2] - x1.size()[2]
60 | diffX = x2.size()[3] - x1.size()[3]
61 |
62 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
63 | diffY // 2, diffY - diffY // 2])
64 | # if you have padding issues, see
65 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
66 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
67 | x = torch.cat([x2, x1], dim=1)
68 | return self.conv(x)
69 |
70 |
71 | class OutConv(nn.Module):
72 | def __init__(self, in_channels, out_channels):
73 | super(OutConv, self).__init__()
74 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
75 | self.sigmoid = nn.Sigmoid()
76 | def forward(self, x):
77 | return self.sigmoid(self.conv(x))
78 |
--------------------------------------------------------------------------------
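The padding in Up.forward exists so that an upsampled map can be concatenated with a skip connection whose spatial size is odd or otherwise off by a pixel. A small sketch with assumed channel counts (in bilinear mode, x1 and x2 are each expected to carry in_channels // 2 channels):

import torch
from utils.unet.unet_parts import Up

up = Up(in_channels=128, out_channels=64, bilinear=True)
x1 = torch.randn(1, 64, 12, 12)   # deeper map; upsampled to 24x24
x2 = torch.randn(1, 64, 25, 25)   # skip connection with an odd spatial size
y = up(x1, x2)                    # x1 is zero-padded to 25x25 before the concat
print(y.shape)                    # torch.Size([1, 64, 25, 25])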