├── .gitignore
├── LICENSE
├── README.md
├── acvc-pytorch.yml
├── analysis
│   └── GeneralizationExpProcessor.py
├── assets
│   ├── ACVC_CAM.png
│   └── ACVC_flow.png
├── losses
│   ├── AttentionConsistency.py
│   ├── Distillation.py
│   └── JSDivergence.py
├── models
│   └── ResNet.py
├── preprocessing
│   ├── Datasets.py
│   └── image
│       ├── ACVCGenerator.py
│       ├── AblationGenerator.py
│       ├── AugMixGenerator.py
│       ├── CutMixGenerator.py
│       ├── CutOutGenerator.py
│       ├── ImageGenerator.py
│       ├── MixUpGenerator.py
│       └── RandAugmentGenerator.py
├── requirements.txt
├── run.py
├── run_experiments.sh
├── settings.ini
├── testers
│   └── DomainGeneralization_tester.py
└── tools.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | results
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 EML Tübingen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ACVC
2 | > [**Attention Consistency on Visual Corruptions for Single-Source Domain Generalization**](https://openaccess.thecvf.com/content/CVPR2022W/L3D-IVU/html/Cugu_Attention_Consistency_on_Visual_Corruptions_for_Single-Source_Domain_Generalization_CVPRW_2022_paper.html)
3 | > [Ilke Cugu](https://cuguilke.github.io/),
4 | > [Massimiliano Mancini](https://www.eml-unitue.de/people/massimiliano-mancini),
5 | > [Yanbei Chen](https://www.eml-unitue.de/people/yanbei-chen),
6 | > [Zeynep Akata](https://www.eml-unitue.de/people/zeynep-akata)
7 | > *IEEE Computer Vision and Pattern Recognition Workshops (CVPRW), 2022*
8 |
9 |
10 |
11 |
12 |
13 | The official PyTorch implementation of the **CVPR 2022, L3D-IVU Workshop** paper titled "Attention Consistency on Visual Corruptions for Single-Source Domain Generalization". This repository contains: (1) our single-source domain generalization benchmark that aims at generalizing from natural images to other domains such as paintings, cliparts, and sketches, (2) our adaptations of well-known advanced data augmentation techniques from the literature, and (3) our final model ACVC, which fuses visual corruptions with an attention consistency loss.
14 |
15 | ## Dependencies
16 | ```
17 | torch~=1.5.1+cu101
18 | numpy~=1.19.5
19 | torchvision~=0.6.1+cu101
20 | Pillow~=8.3.1
21 | matplotlib~=3.1.1
22 | sklearn~=0.0
23 | scikit-learn~=0.24.1
24 | scipy~=1.6.1
25 | imagecorruptions~=1.1.2
26 | tqdm~=4.58.0
27 | pycocotools~=2.0.0
28 | ```
29 |
30 | - We also include a YAML file `acvc-pytorch.yml` that is prepared for an easy Anaconda environment setup, as shown below.
31 |
32 | - One can also use the `requirements.txt` if one knows one's craft.
33 |
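- For example (a minimal sketch, assuming Anaconda is installed; the environment name `acvc-pytorch` comes from the YAML file):
```shell
conda env create -f acvc-pytorch.yml
conda activate acvc-pytorch
```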
34 | ## Training
35 |
36 | Training is done via `run.py`. To get the up-to-date list of commands:
37 | ```shell
38 | python run.py --help
39 | ```
40 |
41 | We include a sample script `run_experiments.sh` for a quick start.
42 |
43 | ## Analysis
44 |
45 | The benchmark results are prepared by `analysis/GeneralizationExpProcessor.py`, which outputs LaTeX tables of the cumulative results as `.tex` files (e.g., `result_PACS.tex`) under the given `--to_dir`.
46 |
47 | - For example:
48 | ```shell
49 | python GeneralizationExpProcessor.py --path generalization.json --to_dir ./results --image_format pdf
50 | ```
51 |
52 | - You can also run distributed experiments, and merge the results later on:
53 | ```shell
54 | python GeneralizationExpProcessor.py --merge_logs generalization_gpu0.json generalization_gpu1.json
55 | ```
56 |
57 | ## Case Study: COCO benchmark
58 |
59 | The COCO benchmark is especially useful for further studies on ACVC since it provides a segmentation mask for each image.
60 |
61 | Here are the steps to make it work:
62 | 1. For this benchmark you only need 10 classes:
63 | ```
64 | airplane
65 | bicycle
66 | bus
67 | car
68 | horse
69 | knife
70 | motorcycle
71 | skateboard
72 | train
73 | truck
74 | ```
75 |
76 |
77 | 2. Download the COCO 2017 [trainset](http://images.cocodataset.org/zips/train2017.zip), [valset](http://images.cocodataset.org/zips/val2017.zip), and [annotations](http://images.cocodataset.org/annotations/annotations_trainval2017.zip)
78 |
79 |
80 | 3. Extract the annotations zip file into a folder named `COCO` inside your choice of `data_dir` (For example: `datasets/COCO`)
81 |
82 |
83 | 4. Extract train and val set zip files into a subfolder named `downloads` (For example: `datasets/COCO/downloads`)
84 |
85 |
86 | 5. Download [DomainNet (clean version)](https://ai.bu.edu/M3SDA/)
87 |
88 |
89 | 6. Create a new `DomainNet` folder next to your `COCO` folder
90 |
91 |
92 | 7. Extract each domain's zip file under its respective subfolder (For example: `datasets/DomainNet/clipart`)
93 |
94 |
95 | 8. Back in the project, use the `--first_run` argument once when running the training script:
96 | ```shell
97 | python run.py --loss CrossEntropy --epochs 1 --corruption_mode None --data_dir datasets --first_run --train_dataset COCO --test_datasets DomainNet:Real --print_config
98 | ```
99 |
100 |
101 | 9. If everything works fine, you will see `train2017` and `val2017` folders under `COCO`
102 |
103 |
104 | 10. Both folders must contain 10 subfolders corresponding to the classes shared between COCO and DomainNet
105 |
106 |
107 | 11. Now, try running ACVC as well:
108 | ```shell
109 | python run.py --loss CrossEntropy AttentionConsistency --epochs 1 --corruption_mode acvc --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real --print_config
110 | ```
111 |
112 |
113 | 12. All good? Then, you are good to go with the COCO section of `run_experiments.sh` to run multiple experiments
114 |
115 |
116 | 13. That's it!
117 |
118 | ## Citation
119 |
120 | If you use this code in your research, please cite:
121 |
122 | ```bibtex
123 | @InProceedings{Cugu_2022_CVPR,
124 | author = {Cugu, Ilke and Mancini, Massimiliano and Chen, Yanbei and Akata, Zeynep},
125 | title = {Attention Consistency on Visual Corruptions for Single-Source Domain Generalization},
126 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
127 | month = {June},
128 | year = {2022},
129 | pages = {4165-4174}
130 | }
131 | ```
132 |
133 | ## References
134 |
135 | We indicate inside each file if a function or script is borrowed from an external source.
136 | Specifically, for the visual corruption implementations we benefit from:
137 |
138 | - The imagecorruptions library of [Autonomous Driving when Winter is Coming](https://github.com/bethgelab/imagecorruptions).
139 |
140 | Consider citing this work as well if you use it in your project.
141 |
--------------------------------------------------------------------------------
/acvc-pytorch.yml:
--------------------------------------------------------------------------------
1 | name: acvc-pytorch
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=4.5=1_gnu
8 | - blas=1.0=mkl
9 | - ca-certificates=2021.5.25=h06a4308_1
10 | - certifi=2021.5.30=py37h06a4308_0
11 | - cudatoolkit=10.1.243=h6bb024c_0
12 | - freetype=2.10.4=h5ab3b9f_0
13 | - intel-openmp=2021.2.0=h06a4308_610
14 | - jpeg=9b=h024ee3a_2
15 | - lcms2=2.12=h3be6417_0
16 | - ld_impl_linux-64=2.35.1=h7274673_9
17 | - libffi=3.3=he6710b0_2
18 | - libgcc-ng=9.3.0=h5101ec6_17
19 | - libgomp=9.3.0=h5101ec6_17
20 | - libpng=1.6.37=hbc83047_0
21 | - libstdcxx-ng=9.3.0=hd4cf53a_17
22 | - libtiff=4.2.0=h85742a9_0
23 | - libwebp-base=1.2.0=h27cfd23_0
24 | - lz4-c=1.9.3=h2531618_0
25 | - mkl=2021.2.0=h06a4308_296
26 | - mkl-service=2.3.0=py37h27cfd23_1
27 | - mkl_fft=1.3.0=py37h42c9631_2
28 | - mkl_random=1.2.1=py37ha9443f7_2
29 | - ncurses=6.2=he6710b0_1
30 | - ninja=1.10.2=hff7bd54_1
31 | - numpy=1.20.2=py37h2d18471_0
32 | - numpy-base=1.20.2=py37hfae3a4d_0
33 | - olefile=0.46=py37_0
34 | - openssl=1.1.1k=h27cfd23_0
35 | - pillow=8.2.0=py37he98fc37_0
36 | - pip=21.1.2=py37h06a4308_0
37 | - python=3.7.10=h12debd9_4
38 | - pytorch=1.5.1=py3.7_cuda10.1.243_cudnn7.6.3_0
39 | - readline=8.1=h27cfd23_0
40 | - setuptools=52.0.0=py37h06a4308_0
41 | - six=1.16.0=pyhd3eb1b0_0
42 | - sqlite=3.36.0=hc218d9a_0
43 | - tk=8.6.10=hbc83047_0
44 | - torchvision=0.6.1=py37_cu101
45 | - wheel=0.36.2=pyhd3eb1b0_0
46 | - xz=5.2.5=h7b6447c_0
47 | - zlib=1.2.11=h7b6447c_3
48 | - zstd=1.4.9=haebb681_0
49 | - pip:
50 | - absl-py==0.13.0
51 | - astor==0.8.1
52 | - cached-property==1.5.2
53 | - click==8.0.1
54 | - cycler==0.10.0
55 | - cython==0.29.23
56 | - gast==0.2.2
57 | - google-pasta==0.2.0
58 | - grpcio==1.38.1
59 | - h5py==3.3.0
60 | - importlib-metadata==4.6.0
61 | - joblib==1.0.1
62 | - keras-applications==1.0.8
63 | - keras-preprocessing==1.1.2
64 | - kiwisolver==1.3.1
65 | - markdown==3.3.4
66 | - matplotlib==3.1.1
67 | - matplotlib-venn==0.11.6
68 | - opencv-python==4.5.1.48
69 | - opt-einsum==3.3.0
70 | - overrides==3.1.0
71 | - pandas==1.3.0
72 | - protobuf==3.17.3
73 | - pycocotools==2.0.2
74 | - pyparsing==2.4.7
75 | - python-dateutil==2.8.1
76 | - pytz==2021.1
77 | - pyyaml==5.4.1
78 | - scikit-learn==0.24.2
79 | - scipy==1.7.0
80 | - tensorboard==1.15.0
81 | - tensorboardx==2.4
82 | - tensorflow-estimator==1.15.1
83 | - termcolor==1.1.0
84 | - threadpoolctl==2.1.0
85 | - tqdm==4.58.0
86 | - typing-extensions==3.10.0.0
87 | - typing-utils==0.1.0
88 | - werkzeug==2.0.1
89 | - wrapt==1.12.1
90 | - zipp==3.4.1
91 | prefix: /home/ubuntu/anaconda3/envs/acvc-pytorch
92 |
--------------------------------------------------------------------------------
/analysis/GeneralizationExpProcessor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 | from tools import plot_learning_curve
6 |
7 | class GeneralizationExpProcessor:
8 | """
9 |     Custom class to analyze and visualize the cumulative experimental results
10 |
11 | # Arguments
12 |         :param to_dir: (string) directory in which the resulting tables and charts are saved
13 | :param path: (string) absolute path to 'experiment.json' file if already exists
14 | """
15 |
16 | def __init__(self,
17 | config,
18 | to_dir=".",
19 | include_std=True,
20 | image_format="pdf",
21 | path="generalization.json"):
22 | self.config = config
23 | self.to_dir = to_dir
24 | self.include_std = include_std
25 | self.image_format = image_format
26 | self.path = path
27 | self.hist = {}
28 |
29 | if os.path.isfile(self.path):
30 | self.load_data()
31 |
32 | if not os.path.exists(self.to_dir):
33 | os.makedirs(self.to_dir)
34 |
35 | def save_data(self):
36 | """
37 | Overwrites the experiment history
38 | """
39 | with open(self.path, "w") as hist_file:
40 | json.dump(self.hist, hist_file)
41 |
42 | def load_data(self):
43 | """
44 | Loads the current experiment history to append new results
45 | """
46 | with open(self.path, "r") as hist_file:
47 | self.hist = json.load(hist_file)
48 |
49 | def merge_logs(self, paths):
50 | print("Merge started...")
51 | for path in paths:
52 | print("Processing... %s" % path)
53 | with open(path) as hist_file:
54 | temp_hist = json.load(hist_file)
55 |
56 | for key in temp_hist:
57 | if key in self.hist:
58 | self.hist[key].extend(temp_hist[key])
59 | else:
60 | self.hist[key] = temp_hist[key]
61 |
62 | # Save the changes
63 | self.save_data()
64 | print("Merge completed.")
65 | print("New file is saved at %s" % self.path)
66 |
67 | def get_info_loss_mode(self, s):
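        # Parse the value of the "aug=" (4 chars) or "info=" (5 chars) tag out of an experiment key,
        # e.g. a key containing "[aug=ACVC]" yields "ACVC" (example key format assumed from the parsing below)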
68 | temp = 4 if s.find("info=") == -1 else 5
69 | i = max(s.find("info="), s.find("aug="))
70 | result = ""
71 | while s[i + temp] != "]":
72 | result += s[i + temp]
73 | i += 1
74 |
75 | return result
76 |
77 | def process_key(self, key):
78 | if key == "None":
79 | return "Baseline"
80 |
81 | return key
82 |
83 | def run_PACS_analysis(self, hist):
84 | pacs = [
85 | "PACS:Photo",
86 | "PACS:Art",
87 | "PACS:Cartoon",
88 | "PACS:Sketch"
89 | ]
90 |
91 | prefix = [
92 | "\\begin{table*}",
93 | "\t \\begin{adjustbox}{width=1\\textwidth}",
94 | "\t \\centering",
95 | "\t\t \\begin{tabular}{lcccccc} ",
96 | "\t\t \\toprule",
97 | "\t\t & Photo & Art & Cartoon & Sketch & Avg. & Max. \\\\",
98 | "\t\t \\midrule",
99 | ]
100 |
101 | # History contains repeated experiments, so compute mean & std
102 | hist_flag = False
103 | processed_hist = {}
104 | cumulative_hist = {}
105 | for i in hist:
106 | processed_hist[i] = {}
107 | cumulative_hist[i] = {"acc": [], "loss": [], "val_acc": [], "val_loss": []}
108 | for entry in hist[i]:
109 | for key in entry:
110 | if key == "history": # Cumulative learning curve
111 | for j in entry[key]:
112 | cumulative_hist[i][j].append(entry[key][j])
113 | hist_flag = True
114 |
115 | else:
116 | if key in processed_hist[i]:
117 | processed_hist[i][key].append(entry[key])
118 | else:
119 | processed_hist[i][key] = [entry[key]]
120 |
121 | # Plot cumulative learning curve
122 | if hist_flag:
123 | for i in cumulative_hist:
124 | temp_hist = {}
125 | for key in cumulative_hist[i]:
126 | temp = np.array(cumulative_hist[i][key])
127 | temp_avg = np.mean(temp, axis=0)
128 | temp_std = np.std(temp, axis=0)
129 | temp_hist[key] = temp_avg
130 | temp_hist["%s_std" % key] = temp_std
131 | plot_learning_curve(temp_hist, os.path.join(self.to_dir, "%s_PACS_learning_curve.%s" % (self.get_info_loss_mode(i), self.image_format)))
132 |
133 | # Prepare results in LaTeX format
134 | results = []
135 | for i in processed_hist:
136 |
137 | entry = processed_hist[i]
138 | temp_obj = "" if "vanilla" in i else " + contrastive"
139 | info_loss_mode = self.get_info_loss_mode(i)
140 | key = info_loss_mode + temp_obj
141 | key = key.replace("_", " ")
142 | key = self.process_key(key)
143 |
144 | accs = {}
145 | stds = {}
146 | for dataset in entry:
147 | cumulative_results = 100 * np.array(entry[dataset])
148 | avg = np.mean(cumulative_results)
149 | std = np.std(cumulative_results)
150 | accs[dataset] = float(avg)
151 | stds[dataset] = float(std)
152 |
153 | # Calculate avg performance per entry
154 | exp_count = len(entry["PACS:Photo"])
155 | avg_list = np.zeros(exp_count)
156 | for j in range(exp_count):
157 | avg_list[j] = np.mean(100 * np.array([entry[temp][j] for temp in entry if temp != "PACS:Photo" and temp in pacs]))
158 |
159 | # PACS table
160 | temp = "\t\t %s " % key
161 | while len(temp) < 52:
162 | temp += " "
163 | for j in pacs:
164 | temp += "& $%.2f \\pm %.1f$ " % (accs[j], stds[j]) if self.include_std else "& $%.2f$ " % (accs[j])
165 | avg, std, top = float(np.mean(avg_list)), float(np.std(avg_list)), float(avg_list.max())
166 | temp += "& $%.2f \\pm %.1f$ & $%.2f$ \\\\" % (avg, std, top)
167 | results.append(temp)
168 |
169 | body = ["\n".join(prefix)]
170 | body.append("\n".join(results))
171 | body.append("\t\t \\bottomrule")
172 | body.append("\t\t \\end{tabular}")
173 | body.append("\t \\end{adjustbox}")
174 | body.append("\t \\caption{ResNet-18 results on PACS benchmark}")
175 | body.append("\\end{table*}")
176 |
177 | # Export the LaTeX file
178 | with open(os.path.join(self.to_dir, "result_PACS.tex"), '+w') as tex_file:
179 | tex_file.write("\n".join(body))
180 |
181 | def run_COCO_analysis(self, hist):
182 | coco = [
183 | "COCO",
184 | "DomainNet:Real",
185 | "DomainNet:Painting",
186 | "DomainNet:Infograph",
187 | "DomainNet:Clipart",
188 | "DomainNet:Sketch",
189 | "DomainNet:Quickdraw"
190 | ]
191 |
192 | prefix = [
193 | "\\begin{table*}",
194 | "\t \\centering",
195 | #"\t\t \\begin{tabular}{lcccccccc} ",
196 | "\t\t \\begin{tabular}{lccccccccc} ",
197 | "\t\t \\toprule",
198 | #"\t\t & COCO & Real & Painting & Infograph & Clipart & Sketch & Avg. & Max. \\\\",
199 | "\t\t & COCO & Real & Painting & Infograph & Clipart & Sketch & Quickdraw & Avg. & Max. \\\\",
200 | "\t\t \\midrule",
201 | ]
202 |
203 | # History contains repeated experiments, so compute mean & std
204 | hist_flag = False
205 | processed_hist = {}
206 | cumulative_hist = {}
207 | for i in hist:
208 | processed_hist[i] = {}
209 | cumulative_hist[i] = {"acc": [], "loss": [], "val_acc": [], "val_loss": []}
210 | for entry in hist[i]:
211 | for key in entry:
212 | if key == "history": # Cumulative learning curve
213 | for j in entry[key]:
214 | cumulative_hist[i][j].append(entry[key][j])
215 | hist_flag = True
216 |
217 | else:
218 | if key in processed_hist[i]:
219 | processed_hist[i][key].append(entry[key])
220 | else:
221 | processed_hist[i][key] = [entry[key]]
222 |
223 | # Plot cumulative learning curve
224 | if hist_flag:
225 | for i in cumulative_hist:
226 | temp_hist = {}
227 | for key in cumulative_hist[i]:
228 | temp = np.array(cumulative_hist[i][key])
229 | temp_avg = np.mean(temp, axis=0)
230 | temp_std = np.std(temp, axis=0)
231 | temp_hist[key] = temp_avg
232 | temp_hist["%s_std" % key] = temp_std
233 | plot_learning_curve(temp_hist, os.path.join(self.to_dir, "%s_COCO_learning_curve.%s" % (self.get_info_loss_mode(i), self.image_format)))
234 |
235 | # Prepare results in LaTeX format
236 | results = []
237 | for i in processed_hist:
238 | entry = processed_hist[i]
239 | temp_obj = "" if "vanilla" in i else " + contrastive"
240 | info_loss_mode = self.get_info_loss_mode(i)
241 | key = info_loss_mode + temp_obj
242 | key = key.replace("_", " ")
243 | key = self.process_key(key)
244 |
245 | accs = {}
246 | stds = {}
247 | for dataset in entry:
248 | cumulative_results = 100 * np.array(entry[dataset])
249 | avg = np.mean(cumulative_results)
250 | std = np.std(cumulative_results)
251 | accs[dataset] = float(avg)
252 | stds[dataset] = float(std)
253 |
254 | # Calculate avg performance per entry
255 | exp_count = len(entry["COCO"])
256 | avg_list = np.zeros(exp_count)
257 | for j in range(exp_count):
258 | avg_list[j] = np.mean(100 * np.array([entry[temp][j] for temp in entry if temp != "COCO" and temp in coco]))
259 |
260 | # COCO table
261 | temp = "\t\t %s " % key
262 | while len(temp) < 52:
263 | temp += " "
264 | for j in coco:
265 | temp += "& $%.2f \\pm %.1f$ " % (accs[j], stds[j]) if self.include_std else "& $%.2f$ " % (accs[j])
266 | avg, std, top = float(np.mean(avg_list)), float(np.std(avg_list)), float(avg_list.max())
267 | temp += "& $%.2f \\pm %.1f$ & $%.2f$ \\\\" % (avg, std, top)
268 | results.append(temp)
269 |
270 | body = ["\n".join(prefix)]
271 | body.append("\n".join(results))
272 | body.append("\t\t \\bottomrule")
273 | body.append("\t\t \\end{tabular}")
274 | body.append("\t \\vspace{-3pt}")
275 | body.append("\t \\caption{ResNet-18 results on COCO benchmark}")
276 | body.append("\\end{table*}")
277 |
278 | # Export the LaTeX file
279 | with open(os.path.join(self.to_dir, "result_COCO.tex"), '+w') as tex_file:
280 | tex_file.write("\n".join(body))
281 |
282 | def run_FullDomainNet_analysis(self, hist):
283 | domainnet = [
284 | "FullDomainNet:Real",
285 | "FullDomainNet:Painting",
286 | "FullDomainNet:Infograph",
287 | "FullDomainNet:Clipart",
288 | "FullDomainNet:Sketch",
289 | "FullDomainNet:Quickdraw"
290 | ]
291 |
292 | prefix = [
293 | "\\begin{table*}",
294 | "\t \\centering",
295 | "\t\t \\begin{tabular}{lcccccccc} ",
296 | "\t\t \\toprule",
297 | "\t\t & Real & Painting & Infograph & Clipart & Sketch & Quickdraw & Avg. & Max. \\\\",
298 | "\t\t \\midrule",
299 | ]
300 |
301 | # History contains repeated experiments, so compute mean & std
302 | hist_flag = False
303 | processed_hist = {}
304 | cumulative_hist = {}
305 | for i in hist:
306 | processed_hist[i] = {}
307 | cumulative_hist[i] = {"acc": [], "loss": [], "val_acc": [], "val_loss": []}
308 | for entry in hist[i]:
309 | for key in entry:
310 | if key == "history": # Cumulative learning curve
311 | for j in entry[key]:
312 | cumulative_hist[i][j].append(entry[key][j])
313 | hist_flag = True
314 |
315 | else:
316 | if key in processed_hist[i]:
317 | processed_hist[i][key].append(entry[key])
318 | else:
319 | processed_hist[i][key] = [entry[key]]
320 |
321 | # Plot cumulative learning curve
322 | if hist_flag:
323 | for i in cumulative_hist:
324 | temp_hist = {}
325 | for key in cumulative_hist[i]:
326 | temp = np.array(cumulative_hist[i][key])
327 | temp_avg = np.mean(temp, axis=0)
328 | temp_std = np.std(temp, axis=0)
329 | temp_hist[key] = temp_avg
330 | temp_hist["%s_std" % key] = temp_std
331 | plot_learning_curve(temp_hist, os.path.join(self.to_dir, "%s_DomainNet_learning_curve.%s" % (self.get_info_loss_mode(i), self.image_format)))
332 |
333 | # Prepare results in LaTeX format
334 | results = []
335 | for i in processed_hist:
336 | entry = processed_hist[i]
337 | temp_obj = "" if "vanilla" in i else " + contrastive"
338 | info_loss_mode = self.get_info_loss_mode(i)
339 | key = info_loss_mode + temp_obj
340 | key = key.replace("_", " ")
341 | key = self.process_key(key)
342 |
343 | accs = {}
344 | stds = {}
345 | for dataset in entry:
346 | cumulative_results = 100 * np.array(entry[dataset])
347 | avg = np.mean(cumulative_results)
348 | std = np.std(cumulative_results)
349 | accs[dataset] = float(avg)
350 | stds[dataset] = float(std)
351 |
352 | # Calculate avg performance per entry
353 | exp_count = len(entry["FullDomainNet:Real"])
354 | avg_list = np.zeros(exp_count)
355 | for j in range(exp_count):
356 | avg_list[j] = np.mean(100 * np.array([entry[temp][j] for temp in entry if temp != "FullDomainNet:Real" and temp in domainnet]))
357 |
358 | # DomainNet table
359 | temp = "\t\t %s " % key
360 | while len(temp) < 52:
361 | temp += " "
362 | for j in domainnet:
363 | temp += "& $%.2f \\pm %.1f$ " % (accs[j], stds[j]) if self.include_std else "& $%.2f$ " % (accs[j])
364 | avg, std, top = float(np.mean(avg_list)), float(np.std(avg_list)), float(avg_list.max())
365 | temp += "& $%.2f \\pm %.1f$ & $%.2f$ \\\\" % (avg, std, top)
366 | results.append(temp)
367 |
368 | body = ["\n".join(prefix)]
369 | body.append("\n".join(results))
370 | body.append("\t\t \\bottomrule")
371 | body.append("\t\t \\end{tabular}")
372 | body.append("\t \\vspace{-3pt}")
373 | body.append("\t \\caption{ResNet-18 results on DomainNet benchmark}")
374 | body.append("\\end{table*}")
375 |
376 | # Export the LaTeX file
377 | with open(os.path.join(self.to_dir, "result_DomainNet.tex"), '+w') as tex_file:
378 | tex_file.write("\n".join(body))
379 |
380 | def run(self):
381 | benchmarks = {"PACS": {"hist": {}, "analysis_func": self.run_PACS_analysis},
382 | "COCO": {"hist": {}, "analysis_func": self.run_COCO_analysis},
383 | "FullDomainNet": {"hist": {}, "analysis_func": self.run_FullDomainNet_analysis}}
384 |
385 | for experiment in self.hist:
386 |
387 | if "PACS" in experiment:
388 | benchmarks["PACS"]["hist"][experiment] = self.hist[experiment]
389 |
390 | elif "COCO" in experiment:
391 | benchmarks["COCO"]["hist"][experiment] = self.hist[experiment]
392 |
393 | elif "FullDomainNet" in experiment:
394 | benchmarks["FullDomainNet"]["hist"][experiment] = self.hist[experiment]
395 |
396 | for benchmark in benchmarks:
397 |             if len(benchmarks[benchmark]["hist"]) > 0:
398 | benchmarks[benchmark]["analysis_func"](benchmarks[benchmark]["hist"])
399 |
400 | if __name__ == '__main__':
401 | # Dynamic parameters
402 | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
403 | parser.add_argument("--path", default="../generalization.json", help="filepath to experiment history (JSON file)", type=str)
404 | parser.add_argument("--to_dir", default="../results", help="filepath to save charts, models, etc.", type=str)
405 | parser.add_argument("--merge_logs", help="merges experiment results for the given JSON file paths", nargs="+")
406 | parser.add_argument("--image_format", default="png", help="", type=str)
407 | args = vars(parser.parse_args())
408 |
409 | hist_path = args["path"]
410 | experimentProcessor = GeneralizationExpProcessor(config=args, path=hist_path, include_std=True, image_format=args["image_format"], to_dir=args["to_dir"])
411 | if args["merge_logs"] is None:
412 | experimentProcessor.run()
413 | else:
414 | experimentProcessor.merge_logs(paths=args["merge_logs"])
--------------------------------------------------------------------------------
/assets/ACVC_CAM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/ACVC/2f93b23fa94735eb1079c2a9fec9e62ea6a4194e/assets/ACVC_CAM.png
--------------------------------------------------------------------------------
/assets/ACVC_flow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/ACVC/2f93b23fa94735eb1079c2a9fec9e62ea6a4194e/assets/ACVC_flow.png
--------------------------------------------------------------------------------
/losses/AttentionConsistency.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import nn
4 | from torch.autograd import Variable
5 |
6 | class AttentionConsistency(nn.Module):
7 | def __init__(self, lambd=6e-2, T=1.0):
8 | super().__init__()
9 | self.name = "AttentionConsistency"
10 | self.T = T
11 | self.lambd = lambd
12 |
13 | def CAM_neg(self, c):
14 | result = c.reshape(c.size(0), c.size(1), -1)
15 | result = -nn.functional.log_softmax(result / self.T, dim=2) / result.size(2)
16 | result = result.sum(2)
17 |
18 | return result
19 |
20 | def CAM_pos(self, c):
21 | result = c.reshape(c.size(0), c.size(1), -1)
22 | result = nn.functional.softmax(result / self.T, dim=2)
23 |
24 | return result
25 |
26 | def forward(self, c, ci_list, y, segmentation_masks=None):
27 | """
28 |         CAM (batch_size, num_classes, feature_map.shape[0], feature_map.shape[1]) based loss
29 | 
30 |         Arguments:
31 |             :param c: (torch.Tensor) clean image's CAM
32 |             :param ci_list: (list of torch.Tensor) augmented images' CAMs
33 |             :param y: (torch.Tensor) ground-truth labels
34 |             :param segmentation_masks: (numpy.array) optional ground-truth masks used as the JSD target instead of the CAM mixture
35 | :return:
36 | """
37 | c1 = c.clone()
38 | c1 = Variable(c1)
39 | c0 = self.CAM_neg(c)
40 |
41 | # Top-k negative classes
42 | c1 = c1.sum(2).sum(2)
43 | index = torch.zeros(c1.size())
44 | c1[range(c0.size(0)), y] = - float("Inf")
45 | topk_ind = torch.topk(c1, 3, dim=1)[1]
46 | index[torch.tensor(range(c1.size(0))).unsqueeze(1), topk_ind] = 1
47 | index = index > 0.5
48 |
49 | # Negative CAM loss
50 | neg_loss = c0[index].sum() / c0.size(0)
51 | for ci in ci_list:
52 | ci = self.CAM_neg(ci)
53 | neg_loss += ci[index].sum() / ci.size(0)
54 | neg_loss /= len(ci_list) + 1
55 |
56 | # Positive CAM loss
57 | index = torch.zeros(c1.size())
58 | true_ind = [[i] for i in y]
59 | index[torch.tensor(range(c1.size(0))).unsqueeze(1), true_ind] = 1
60 | index = index > 0.5
61 | p0 = self.CAM_pos(c)[index]
62 | pi_list = [self.CAM_pos(ci)[index] for ci in ci_list]
63 |
64 | # Middle ground for Jensen-Shannon divergence
65 | p_count = 1 + len(pi_list)
66 | if segmentation_masks is None:
67 | p_mixture = p0.detach().clone()
68 | for pi in pi_list:
69 | p_mixture += pi
70 | p_mixture = torch.clamp(p_mixture / p_count, 1e-7, 1).log()
71 |
72 | else:
73 | mask = np.interp(segmentation_masks, (segmentation_masks.min(), segmentation_masks.max()), (0, 1))
74 | p_mixture = torch.from_numpy(mask).cuda()
75 | p_mixture = p_mixture.reshape(p_mixture.size(0), -1)
76 | p_mixture = torch.nn.functional.normalize(p_mixture, dim=1)
77 |
78 | pos_loss = nn.functional.kl_div(p_mixture, p0, reduction='batchmean')
79 | for pi in pi_list:
80 | pos_loss += nn.functional.kl_div(p_mixture, pi, reduction='batchmean')
81 | pos_loss /= p_count
82 |
83 | loss = pos_loss + neg_loss
84 | return self.lambd * loss
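
if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original repo): CAMs are
    # shaped (batch_size, num_classes, H, W), as produced by end_points['CAM'] in
    # models/ResNet.py; labels are integer class indices.
    torch.manual_seed(0)
    c = torch.randn(8, 10, 7, 7)                            # clean-image CAMs
    ci_list = [torch.randn(8, 10, 7, 7) for _ in range(2)]  # augmented-image CAMs
    y = torch.randint(0, 10, (8,))                          # ground-truth labels
    criterion = AttentionConsistency()
    print(criterion(c, ci_list, y).item())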
--------------------------------------------------------------------------------
/losses/Distillation.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 |
4 | class Distillation(nn.Module):
5 | def __init__(self, temperature=2.0):
6 | super().__init__()
7 | self.T = temperature
8 | self.criterion = nn.KLDivLoss()
9 |
10 | def forward(self, z_s, z_t):
11 |
12 |         return self.criterion(F.log_softmax(z_s / self.T, dim=1), F.softmax(z_t / self.T, dim=1))  # * (self.T * self.T)
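
if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original repo): student and
    # teacher logits of shape (batch_size, num_classes); during training the teacher
    # logits come from the precomputed teacher_logits.npy loaded in preprocessing/Datasets.py.
    import torch
    torch.manual_seed(0)
    z_s, z_t = torch.randn(8, 10), torch.randn(8, 10)
    print(Distillation()(z_s, z_t).item())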
--------------------------------------------------------------------------------
/losses/JSDivergence.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | class JSDivergence(nn.Module):
6 | def __init__(self, lambd=12):
7 | super().__init__()
8 | self.name = "JSDivergence"
9 | self.lambd = lambd
10 |
11 | def forward(self, p0, pi_list, y=None):
12 | p_count = 1 + len(pi_list)
13 | p_mixture = p0.detach().clone()
14 | for pi in pi_list:
15 | p_mixture += pi
16 | p_mixture = torch.clamp(p_mixture / p_count, 1e-7, 1).log()
17 |
18 | loss = F.kl_div(p_mixture, p0, reduction='batchmean')
19 | for pi in pi_list:
20 | loss += F.kl_div(p_mixture, pi, reduction='batchmean')
21 | loss /= p_count
22 |
23 | return self.lambd * loss
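
if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original repo): inputs are
    # softmax probabilities, e.g. the 'Predictions' output of models/ResNet.py, for one
    # clean view and a list of augmented views of the same batch.
    torch.manual_seed(0)
    p0 = F.softmax(torch.randn(8, 10), dim=1)
    pi_list = [F.softmax(torch.randn(8, 10), dim=1) for _ in range(2)]
    print(JSDivergence()(p0, pi_list).item())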
--------------------------------------------------------------------------------
/models/ResNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision.models.utils import load_state_dict_from_url
5 |
6 | model_urls = {
7 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
8 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
9 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
10 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
11 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
12 | }
13 |
14 |
15 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
16 | """3x3 convolution with padding"""
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
18 | padding=dilation, groups=groups, bias=False, dilation=dilation)
19 |
20 |
21 | def conv1x1(in_planes, out_planes, stride=1):
22 | """1x1 convolution"""
23 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
24 |
25 |
26 | class BasicBlock(nn.Module):
27 | expansion = 1
28 |
29 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
30 | base_width=64, dilation=1, norm_layer=None):
31 | super(BasicBlock, self).__init__()
32 | if norm_layer is None:
33 | norm_layer = nn.BatchNorm2d
34 | if groups != 1 or base_width != 64:
35 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
36 | if dilation > 1:
37 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
38 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
39 | self.conv1 = conv3x3(inplanes, planes, stride)
40 | self.bn1 = norm_layer(planes)
41 | self.relu = nn.ReLU(inplace=True)
42 | self.conv2 = conv3x3(planes, planes)
43 | self.bn2 = norm_layer(planes)
44 | self.downsample = downsample
45 | self.stride = stride
46 |
47 | def forward(self, x):
48 | identity = x
49 |
50 | out = self.conv1(x)
51 | out = self.bn1(out)
52 | out = self.relu(out)
53 |
54 | out = self.conv2(out)
55 | out = self.bn2(out)
56 |
57 | if self.downsample is not None:
58 | identity = self.downsample(x)
59 |
60 | out += identity
61 | out = self.relu(out)
62 |
63 | return out
64 |
65 |
66 | class Bottleneck(nn.Module):
67 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
68 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
69 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
70 | # This variant is also known as ResNet V1.5 and improves accuracy according to
71 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
72 |
73 | expansion = 4
74 |
75 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
76 | base_width=64, dilation=1, norm_layer=None):
77 | super(Bottleneck, self).__init__()
78 | if norm_layer is None:
79 | norm_layer = nn.BatchNorm2d
80 | width = int(planes * (base_width / 64.)) * groups
81 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
82 | self.conv1 = conv1x1(inplanes, width)
83 | self.bn1 = norm_layer(width)
84 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
85 | self.bn2 = norm_layer(width)
86 | self.conv3 = conv1x1(width, planes * self.expansion)
87 | self.bn3 = norm_layer(planes * self.expansion)
88 | self.relu = nn.ReLU(inplace=True)
89 | self.downsample = downsample
90 | self.stride = stride
91 |
92 | def forward(self, x):
93 | identity = x
94 |
95 | out = self.conv1(x)
96 | out = self.bn1(out)
97 | out = self.relu(out)
98 |
99 | out = self.conv2(out)
100 | out = self.bn2(out)
101 | out = self.relu(out)
102 |
103 | out = self.conv3(out)
104 | out = self.bn3(out)
105 |
106 | if self.downsample is not None:
107 | identity = self.downsample(x)
108 |
109 | out += identity
110 | out = self.relu(out)
111 |
112 | return out
113 |
114 |
115 | class ResNet(nn.Module):
116 |
117 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
118 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
119 | norm_layer=None):
120 | super(ResNet, self).__init__()
121 | if norm_layer is None:
122 | norm_layer = nn.BatchNorm2d
123 | self._norm_layer = norm_layer
124 |
125 | self.inplanes = 64
126 | self.dilation = 1
127 | if replace_stride_with_dilation is None:
128 | # each element in the tuple indicates if we should replace
129 | # the 2x2 stride with a dilated convolution instead
130 | replace_stride_with_dilation = [False, False, False]
131 | if len(replace_stride_with_dilation) != 3:
132 | raise ValueError("replace_stride_with_dilation should be None "
133 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
134 | self.groups = groups
135 | self.base_width = width_per_group
136 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
137 | bias=False)
138 | self.bn1 = norm_layer(self.inplanes)
139 | self.relu = nn.ReLU(inplace=True)
140 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
141 | self.layer1 = self._make_layer(block, 64, layers[0])
142 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
143 | dilate=replace_stride_with_dilation[0])
144 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
145 | dilate=replace_stride_with_dilation[1])
146 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
147 | dilate=replace_stride_with_dilation[2])
148 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
149 | self.fc = nn.Linear(512 * block.expansion, num_classes)
150 |
151 | for m in self.modules():
152 | if isinstance(m, nn.Conv2d):
153 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
154 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
155 | nn.init.constant_(m.weight, 1)
156 | nn.init.constant_(m.bias, 0)
157 |
158 | # Zero-initialize the last BN in each residual branch,
159 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
160 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
161 | if zero_init_residual:
162 | for m in self.modules():
163 | if isinstance(m, Bottleneck):
164 | nn.init.constant_(m.bn3.weight, 0)
165 | elif isinstance(m, BasicBlock):
166 | nn.init.constant_(m.bn2.weight, 0)
167 |
168 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
169 | norm_layer = self._norm_layer
170 | downsample = None
171 | previous_dilation = self.dilation
172 | if dilate:
173 | self.dilation *= stride
174 | stride = 1
175 | if stride != 1 or self.inplanes != planes * block.expansion:
176 | downsample = nn.Sequential(
177 | conv1x1(self.inplanes, planes * block.expansion, stride),
178 | norm_layer(planes * block.expansion),
179 | )
180 |
181 | layers = []
182 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
183 | self.base_width, previous_dilation, norm_layer))
184 | self.inplanes = planes * block.expansion
185 | for _ in range(1, blocks):
186 | layers.append(block(self.inplanes, planes, groups=self.groups,
187 | base_width=self.base_width, dilation=self.dilation,
188 | norm_layer=norm_layer))
189 |
190 | return nn.Sequential(*layers)
191 |
192 | def _forward_impl(self, x):
193 | end_points = {}
194 |
195 | # See note [TorchScript super()]
196 | x = self.conv1(x)
197 | x = self.bn1(x)
198 | x = self.relu(x)
199 | x = self.maxpool(x)
200 |
201 | x = self.layer1(x)
202 | x = self.layer2(x)
203 | x = self.layer3(x)
204 | x = self.layer4(x)
205 | end_points['Feature'] = x
206 |
207 | x = self.avgpool(x)
208 | x = torch.flatten(x, 1)
209 | end_points['Embedding'] = x
210 |
211 | x = self.fc(x)
212 | end_points['Predictions'] = F.softmax(input=x, dim=-1)
213 |
214 | # Taken from: https://github.com/GuoleiSun/HNC_loss
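        # i.e., applying the fc layer as a 1x1 convolution over the spatial feature map
        # yields one class activation map (CAM) per class, before global average pooling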
215 | end_points['CAM'] = F.conv2d(end_points['Feature'], self.fc.weight.view(self.fc.out_features, end_points['Feature'].size(1), 1, 1)) + self.fc.bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
216 |
217 | return x, end_points
218 |
219 | def forward(self, x):
220 | return self._forward_impl(x)
221 |
222 |
223 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
224 | model = ResNet(block, layers, **kwargs)
225 | if pretrained:
226 | state_dict = load_state_dict_from_url(model_urls[arch],
227 | progress=progress)
228 | model.load_state_dict(state_dict)
229 | return model
230 |
231 |
232 | def resnet18(pretrained=False, progress=True, **kwargs):
233 | r"""ResNet-18 model from
234 | `"Deep Residual Learning for Image Recognition" `_
235 |
236 | Args:
237 | pretrained (bool): If True, returns a model pre-trained on ImageNet
238 | progress (bool): If True, displays a progress bar of the download to stderr
239 | """
240 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
241 | **kwargs)
242 |
243 |
244 | def resnet34(pretrained=False, progress=True, **kwargs):
245 | r"""ResNet-34 model from
246 | `"Deep Residual Learning for Image Recognition" `_
247 |
248 | Args:
249 | pretrained (bool): If True, returns a model pre-trained on ImageNet
250 | progress (bool): If True, displays a progress bar of the download to stderr
251 | """
252 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
253 | **kwargs)
254 |
255 |
256 | def resnet50(pretrained=False, progress=True, **kwargs):
257 | r"""ResNet-50 model from
258 | `"Deep Residual Learning for Image Recognition" `_
259 |
260 | Args:
261 | pretrained (bool): If True, returns a model pre-trained on ImageNet
262 | progress (bool): If True, displays a progress bar of the download to stderr
263 | """
264 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
265 | **kwargs)
266 |
267 |
268 | def resnet101(pretrained=False, progress=True, **kwargs):
269 | r"""ResNet-101 model from
270 | `"Deep Residual Learning for Image Recognition" `_
271 |
272 | Args:
273 | pretrained (bool): If True, returns a model pre-trained on ImageNet
274 | progress (bool): If True, displays a progress bar of the download to stderr
275 | """
276 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
277 | **kwargs)
278 |
279 |
280 | def resnet152(pretrained=False, progress=True, **kwargs):
281 | r"""ResNet-152 model from
282 | `"Deep Residual Learning for Image Recognition" `_
283 |
284 | Args:
285 | pretrained (bool): If True, returns a model pre-trained on ImageNet
286 | progress (bool): If True, displays a progress bar of the download to stderr
287 | """
288 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
289 | **kwargs)
290 |
291 | def get_resnet(depth, num_classes, freeze=False):
292 | model = None
293 |
294 | if depth == 18:
295 | model = resnet18(pretrained=True)
296 | num_ftrs = model.fc.in_features
297 | model.fc = torch.nn.Linear(num_ftrs, num_classes)
298 | torch.nn.init.kaiming_normal_(model.fc.weight)
299 |
300 | elif depth == 50:
301 | model = resnet50(pretrained=True)
302 | num_ftrs = model.fc.in_features
303 | model.fc = torch.nn.Linear(num_ftrs, num_classes)
304 | torch.nn.init.kaiming_normal_(model.fc.weight)
305 |
306 | if freeze:
307 | for param in model.parameters():
308 | param.requires_grad = False
309 |
310 | for param in model.fc.parameters():
311 | param.requires_grad = True
312 |
313 | return model
314 |
315 |
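if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original repo): build an
    # ImageNet-pretrained ResNet-18 with a 7-way head (e.g., for PACS) and check the
    # output shapes. Note: this downloads the pretrained weights on first use.
    model = get_resnet(depth=18, num_classes=7)
    model.eval()  # avoid updating BatchNorm statistics in this quick check
    x = torch.randn(2, 3, 224, 224)
    logits, end_points = model(x)
    print(logits.shape, end_points['CAM'].shape)  # torch.Size([2, 7]) torch.Size([2, 7, 7, 7])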
--------------------------------------------------------------------------------
/preprocessing/Datasets.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import, division
2 | import os
3 | import torchvision
4 | import numpy as np
5 | from PIL import Image
6 | from tqdm import tqdm
7 | from pycocotools.coco import COCO
8 | from tools import log, resize_image
9 |
10 | def preprocess_dataset(x, train, img_mean_mode):
11 | # Compute image mean if applicable
12 | if img_mean_mode is not None:
13 | if train:
14 |
15 | if img_mean_mode == "per_channel":
16 | x_ = np.copy(x)
17 | x_ = x_.astype('float32') / 255.0
18 | img_mean = np.array([np.mean(x_[:, :, :, 0]), np.mean(x_[:, :, :, 1]), np.mean(x_[:, :, :, 2])])
19 |
20 | elif img_mean_mode == "imagenet":
21 | img_mean = np.array([0.485, 0.456, 0.406])
22 |
23 | else:
24 |                 raise ValueError("Invalid img_mean_mode!")
25 | np.save("img_mean.npy", img_mean)
26 |
27 | return x
28 |
29 | def normalize_dataset(x, img_mean_mode):
30 | if img_mean_mode is not None:
31 |
32 | if img_mean_mode == "imagenet":
33 | img_mean = np.array([0.485, 0.456, 0.406])
34 | img_std = np.array([0.229, 0.224, 0.225])
35 |
36 | else:
37 | assert os.path.exists("img_mean.npy"), "Image mean file cannot be found!"
38 | img_mean = np.load("img_mean.npy")
39 | img_std = np.array([1.0, 1.0, 1.0])
40 |
41 | transform = torchvision.transforms.Compose([
42 | torchvision.transforms.ToTensor(),
43 | torchvision.transforms.Normalize(mean=img_mean, std=img_std)
44 | ])
45 |
46 | else:
47 | transform = torchvision.transforms.Compose([
48 | torchvision.transforms.ToTensor()
49 | ])
50 |
51 | x_ = transform(x)
52 |
53 | return x_
54 |
55 | def load_PACS(subset="photo", train=True, img_mean_mode="imagenet", distillation=False, data_dir="../../datasets"):
56 | subset = "art_painting" if subset.lower() == "art" else subset.lower()
57 | root_path = os.path.join(os.path.dirname(__file__), data_dir, 'PACS')
58 | data_path = os.path.join(root_path, subset.lower())
59 | classes = {"dog": 0, "elephant": 1, "giraffe": 2, "guitar": 3, "horse": 4, "house": 5, "person": 6}
60 | img_dim = (224, 224)
61 |
62 | imagedata = []
63 | labels = []
64 | teacher_logits = None
65 | if subset == "photo":
66 | label_file = "photo_train.txt" if train else "photo_val.txt"
67 | label_file = os.path.join(root_path, label_file)
68 |
69 | # Gather images and labels
70 | with open(label_file, "r") as f_label:
71 | for line in f_label:
72 | temp = line[:-1].split(" ")
73 | img = Image.open(os.path.join(root_path, temp[0]))
74 | img = resize_image(img, img_dim)
75 | imagedata.append(np.array(img))
76 | labels.append(int(temp[1]) - 1)
77 |
78 | imagedata = np.array(imagedata)
79 | labels = np.array(labels)
80 |
81 | # Include teacher logits as well if applicable
82 | if train and distillation:
83 | logit_file = os.path.join(root_path, "teacher_logits.npy")
84 | assert os.path.exists(logit_file), "Teacher logits cannot be found for PACS!"
85 | teacher_logits = np.load(logit_file)
86 |
87 | else:
88 | for class_dir in os.listdir(data_path):
89 | label = classes[class_dir]
90 | path = os.path.join(data_path, class_dir)
91 |
92 | for img_file in os.listdir(path):
93 | if img_file.endswith("jpg") or img_file.endswith("png"):
94 | img_path = os.path.join(path, img_file)
95 | img = Image.open(img_path)
96 | img = resize_image(img, img_dim)
97 | imagedata.append(np.array(img))
98 | labels.append(label)
99 |
100 | imagedata = np.array(imagedata)
101 | labels = np.array(labels)
102 |
103 | # Normalize the data
104 | imagedata = preprocess_dataset(imagedata, train=train, img_mean_mode=img_mean_mode)
105 | result = {"images": imagedata, "labels": labels} if train else (imagedata, labels)
106 |
107 |     if teacher_logits is not None:
108 | result["teacher_logits"] = teacher_logits
109 |
110 | return result
111 |
112 | def prep_COCO(save_masks_as_image=True, data_dir="../../datasets"):
113 | data_path = os.path.join(os.path.dirname(__file__), data_dir, 'COCO')
114 | classes = {"airplane": 0, "bicycle": 1, "bus": 2, "car": 3, "horse": 4, "knife": 5, "motorcycle": 6,
115 | "skateboard": 7, "train": 8, "truck": 9}
116 | object_scene_ratio_lower_threshold = 0.1
117 | object_scene_ratio_upper_threshold = 1.0
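    # Keep an image only if its class mask covers more than 10% (and at most 100%) of the scene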
118 | class_names = [""] * 10
119 | class_ids = [0] * 10
120 | img_dim = (224, 224)
121 |
122 | # Prepare training data
123 | train_data_path = os.path.join(data_path, "annotations", "instances_train2017.json")
124 | coco = COCO(train_data_path)
125 |
126 | catIds = coco.getCatIds()
127 | cats = coco.loadCats(catIds)
128 | for cat in cats:
129 | cat_name = cat["name"]
130 | if cat_name in classes:
131 | class_names[classes[cat_name]] = cat_name
132 | class_ids[classes[cat_name]] = cat["id"]
133 |
134 | total_img_count = 0
135 | landing_dir = os.path.join(data_path, "downloads", "train2017")
136 | target_dir = os.path.join(data_path, "train2017")
137 | if not os.path.exists(target_dir):
138 | os.makedirs(target_dir)
139 |
140 | for i in range(len(class_ids)):
141 | class_id = class_ids[i]
142 | class_name = class_names[i]
143 | target_class_dir = os.path.join(target_dir, class_name)
144 |
145 | if not os.path.exists(target_class_dir):
146 | os.makedirs(target_class_dir)
147 |
148 | img_ids = coco.getImgIds(catIds=class_id)
149 | for img_id in tqdm(img_ids):
150 | img_info = coco.loadImgs(img_id)
151 | assert len(img_info) == 1, "Image retrieval problem in COCO training set!"
152 | img_info = img_info[0]
153 |
154 | ann_id = coco.getAnnIds(imgIds=img_id, catIds=class_id)
155 | anns = coco.loadAnns(ann_id)
156 |
157 | # Generate binary mask
158 | mask = np.zeros((img_info['height'], img_info['width']))
159 | for j in range(len(anns)):
160 | mask = np.maximum(coco.annToMask(anns[j]), mask)
161 |
162 | if object_scene_ratio_lower_threshold < (
163 | np.sum(mask) / mask.size) <= object_scene_ratio_upper_threshold:
164 | total_img_count += 1
165 |
166 | # Copy relevant image to dest and save its corresponding binary mask
167 | source_path = os.path.join(landing_dir, img_info["file_name"])
168 | assert os.path.exists(source_path), "Image is not found in the source path!"
169 | dest_path = os.path.join(target_class_dir, img_info["file_name"])
170 | img = Image.open(source_path)
171 | img = resize_image(img, img_dim)
172 | img.save(dest_path)
173 |
174 | if save_masks_as_image:
175 | mask_img_path = os.path.join(target_class_dir, "%s_mask.jpg" % img_info["file_name"].split(".jpg")[0])
176 | mask_img = np.array(mask * 255, dtype=np.uint8)
177 | mask_img = Image.fromarray(mask_img)
178 | mask_img = resize_image(mask_img, img_dim)
179 | mask_img.save(mask_img_path)
180 | else:
181 | mask_path = os.path.join(target_class_dir, img_info["file_name"].replace("jpg", "npy"))
182 | np.save(mask_path, mask)
183 | log("%s COCO training images are prepared." % total_img_count)
184 |
185 | # Prepare validation data
186 | val_data_path = os.path.join(data_path, "annotations", "instances_val2017.json")
187 | coco = COCO(val_data_path)
188 |
189 | total_img_count = 0
190 | landing_dir = os.path.join(data_path, "downloads", "val2017")
191 | target_dir = os.path.join(data_path, "val2017")
192 | if not os.path.exists(target_dir):
193 | os.makedirs(target_dir)
194 |
195 | for i in range(len(class_ids)):
196 | class_id = class_ids[i]
197 | class_name = class_names[i]
198 | target_class_dir = os.path.join(target_dir, class_name)
199 |
200 | if not os.path.exists(target_class_dir):
201 | os.makedirs(target_class_dir)
202 |
203 | img_ids = coco.getImgIds(catIds=class_id)
204 | for img_id in tqdm(img_ids):
205 | img_info = coco.loadImgs(img_id)
206 | assert len(img_info) == 1, "Image retrieval problem in COCO validation set!"
207 | img_info = img_info[0]
208 |
209 | ann_id = coco.getAnnIds(imgIds=img_id, catIds=class_id)
210 | anns = coco.loadAnns(ann_id)
211 |
212 | # Generate binary mask
213 | mask = np.zeros((img_info['height'], img_info['width']))
214 | for j in range(len(anns)):
215 | mask = np.maximum(coco.annToMask(anns[j]), mask)
216 |
217 | if object_scene_ratio_lower_threshold < (np.sum(mask) / mask.size) <= object_scene_ratio_upper_threshold:
218 | total_img_count += 1
219 |
220 | # Copy relevant image to dest and save its corresponding binary mask
221 | source_path = os.path.join(landing_dir, img_info["file_name"])
222 | assert os.path.exists(source_path), "Image is not found in the source path!"
223 | dest_path = os.path.join(target_class_dir, img_info["file_name"])
224 | img = Image.open(source_path)
225 | img = resize_image(img, img_dim)
226 | img.save(dest_path)
227 |
228 | if save_masks_as_image:
229 | mask_img_path = os.path.join(target_class_dir, "%s_mask.jpg" % img_info["file_name"].split(".jpg")[0])
230 | mask_img = np.array(mask * 255, dtype=np.uint8)
231 | mask_img = Image.fromarray(mask_img)
232 | mask_img = resize_image(mask_img, img_dim)
233 | mask_img.save(mask_img_path)
234 | else:
235 | mask_path = os.path.join(target_class_dir, img_info["file_name"].replace("jpg", "npy"))
236 | np.save(mask_path, mask)
237 |
238 | log("%s COCO validation images are prepared." % total_img_count)
239 |
240 | def load_COCO(train=True, first_run=False, img_mean_mode="imagenet", distillation=False, data_dir="../../datasets"):
241 | if first_run:
242 | prep_COCO(data_dir=data_dir)
243 |
244 | data_path = os.path.join(os.path.dirname(__file__), data_dir, 'COCO')
245 | classes = {"airplane": 0, "bicycle": 1, "bus": 2, "car": 3, "horse": 4, "knife": 5, "motorcycle": 6,
246 | "skateboard": 7, "train": 8, "truck": 9}
247 | per_class_img_limit = 1000
248 | img_dim = (224, 224)
249 |
250 | if train:
251 | # Training data
252 | x_train = []
253 | y_train = []
254 | masks = []
255 | for class_dir in classes:
256 | label = classes[class_dir]
257 | path = os.path.join(data_path, "train2017", class_dir)
258 |
259 | file_list = [i for i in sorted(os.listdir(path)) if i.endswith("jpg") and "mask" not in i][:per_class_img_limit]
260 | for img_file in file_list:
261 | img_path = os.path.join(path, img_file)
262 | img = Image.open(img_path)
263 | img = resize_image(img, img_dim)
264 | mask_path = os.path.join(path, "%s_mask.jpg" % img_file.split(".jpg")[0])
265 | mask = Image.open(mask_path).convert('L')
266 | mask = resize_image(mask, img_dim)
267 | x_train.append(np.array(img))
268 | y_train.append(label)
269 | masks.append(np.array(mask))
270 |
271 | x_train = np.array(x_train)
272 | y_train = np.array(y_train)
273 | masks = np.array(masks)
274 |
275 | # Normalize the data
276 | x_train = preprocess_dataset(x_train, train=True, img_mean_mode=img_mean_mode)
277 | masks = masks.astype('float32') / 255.0
278 | masks = masks[:, :, :, np.newaxis]
279 | result = {"images": x_train, "labels": y_train, "segmentation_masks": masks}
280 |
281 | if distillation:
282 | logit_file = os.path.join(data_path, "teacher_logits.npy")
283 | assert os.path.exists(logit_file), "Teacher logits cannot be found for COCO!"
284 | result["teacher_logits"] = np.load(logit_file)
285 |
286 | else:
287 | # Validation data
288 | x_val = []
289 | y_val = []
290 | for class_dir in classes:
291 | label = classes[class_dir]
292 | path = os.path.join(data_path, "val2017", class_dir)
293 |
294 | file_list = [i for i in sorted(os.listdir(path)) if i.endswith("jpg") and "mask" not in i]
295 | for img_file in file_list:
296 | img_path = os.path.join(path, img_file)
297 | img = Image.open(img_path)
298 | img = resize_image(img, img_dim)
299 | x_val.append(np.array(img))
300 | y_val.append(label)
301 |
302 | x_val = np.array(x_val)
303 | y_val = np.array(y_val)
304 |
305 | # Normalize the data
306 | x_val = preprocess_dataset(x_val, train=False, img_mean_mode=img_mean_mode)
307 | result = (x_val, y_val)
308 |
309 | return result
310 |
311 | def load_DomainNet(subset, img_mean_mode="imagenet", data_dir="../../datasets"):
312 | data_path = os.path.join(os.path.dirname(__file__), data_dir, 'DomainNet', subset.lower())
313 | classes = {"airplane": 0, "bicycle": 1, "bus": 2, "car": 3, "horse": 4, "knife": 5, "motorcycle": 6,
314 | "skateboard": 7, "train": 8, "truck": 9}
315 | img_dim = (224, 224)
316 |
317 | imagedata = []
318 | labels = []
319 | for class_dir in classes:
320 | label = classes[class_dir]
321 | path = os.path.join(data_path, class_dir)
322 |
323 | for img_file in os.listdir(path):
324 | if img_file.endswith("jpg") or img_file.endswith("png"):
325 | img_path = os.path.join(path, img_file)
326 | img = Image.open(img_path)
327 |                 img = resize_image(img, img_dim)
328 | imagedata.append(np.array(img))
329 | labels.append(label)
330 |
331 | imagedata = np.array(imagedata)
332 | imagedata = preprocess_dataset(imagedata, train=False, img_mean_mode=img_mean_mode)
333 | labels = np.array(labels)
334 |
335 | return imagedata, labels
336 |
337 | def load_FullDomainNet(subset, train=True, img_mean_mode="imagenet", distillation=False, data_dir="../../datasets"):
338 | data_path = os.path.join(os.path.dirname(__file__), data_dir, 'FullDomainNet')
339 | img_dim = (224, 224)
340 | subset = subset.lower()
341 | if subset == "real":
342 | labelfile = os.path.join(data_path, "real_train.txt") if train else os.path.join(data_path, "real_test.txt")
343 | else:
344 | labelfile = os.path.join(data_path, "fast.txt") #os.path.join(data_path, "%s.txt" % subset)
345 |
346 | # Gather image paths and labels
347 | imagepath = []
348 | labels = []
349 | with open(labelfile, "r") as f_label:
350 | for line in f_label:
351 | temp = line[:-1].split(" ")
352 | imagepath.append(temp[0])
353 | labels.append(int(temp[1]))
354 | labels = np.array(labels)
355 |
356 | imagedata = np.empty([labels.shape[0]] + list(img_dim) + [3], dtype="uint8")
357 | for i in range(len(labels)):
358 | img_path = os.path.join(data_path, imagepath[i])
359 | img = Image.open(img_path)
360 | img = resize_image(img, img_dim)
361 | imagedata[i] = np.array(img)
362 |
363 | imagedata = preprocess_dataset(imagedata, train=train, img_mean_mode=img_mean_mode)
364 |
365 | return {"images": imagedata, "labels": labels} if train else (imagedata, labels)
--------------------------------------------------------------------------------
/preprocessing/image/ACVCGenerator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tools import shuffle_data
4 | from preprocessing.Datasets import normalize_dataset
5 | from PIL import Image as PILImage
6 | from scipy.stats import truncnorm
7 | from preprocessing.image.RandAugmentGenerator import RandAugment
8 | from imagecorruptions import corrupt, get_corruption_names
9 |
10 | class ACVCGenerator:
11 | def __init__(self,
12 | dataset,
13 | batch_size,
14 | stage,
15 | epochs,
16 | corruption_mode,
17 | corruption_dist,
18 | img_mean_mode=None,
19 | rand_aug=False,
20 | seed=13,
21 | orig_plus_aug=True):
22 | """
23 | 19 out of 22 corruptions used in this generator are taken from ImageNet-10.C:
24 | - https://github.com/bethgelab/imagecorruptions
25 |
26 | :param dataset: (tuple) x, y, segmentation mask (optional)
27 | :param batch_size: (int) # of inputs in a mini-batch
28 | :param stage: (str) train | test
29 | :param epochs: (int) # of full training passes
30 |         :param corruption_mode: (str) required for VisCo experiments
31 |         :param corruption_dist: (str) required to determine the corruption rate
32 | :param img_mean_mode: (str) use this to revert the image into its original form before augmentation
33 | :param rand_aug: (bool) enable/disable on-the-fly random data augmentation
34 | :param seed: (int) seed for input shuffle
35 | :param orig_plus_aug: (bool) if True, original images will be kept in the batch along with corrupted ones
36 | """
37 |
38 | if stage not in ['train', 'test']:
39 |             raise ValueError('invalid stage!')
40 |
41 | # Settings
42 | self.batch_size = batch_size
43 | self.stage = stage
44 | self.epochs = epochs
45 | self.corruption_mode = corruption_mode
46 | self.corruption_dist = corruption_dist
47 | self.img_mean_mode = img_mean_mode
48 | self.rand_aug = rand_aug
49 | self.seed = seed
50 | self.orig_plus_aug = orig_plus_aug
51 |
52 | # Preparation
53 | self.configuration()
54 | self.load_data(dataset)
55 | self.random_augmentation = RandAugment(1, 5)
56 | if self.img_mean_mode is not None:
57 | self.img_mean = np.load("img_mean.npy")
58 |
59 | def configuration(self):
60 | self.shuffle_count = 1
61 | self.current_index = 0
62 |
63 | def shuffle(self):
64 | self.image_count = len(self.labels)
65 | self.current_index = 0
66 | self.images, self.labels, self.teacher_logits, self.segmentation_masks = shuffle_data(samples=self.images,
67 | labels=self.labels,
68 | teacher_logits=self.teacher_logits,
69 | segmentation_masks=self.segmentation_masks,
70 | seed=self.seed + self.shuffle_count)
71 | self.shuffle_count += 1
72 |
73 | def load_data(self, dataset):
74 | self.images = dataset["images"]
75 | self.labels = dataset["labels"]
76 | self.teacher_logits = dataset["teacher_logits"] if "teacher_logits" in dataset else None
77 | self.segmentation_masks = dataset["segmentation_masks"] if "segmentation_masks" in dataset else None
78 |
79 | self.len_images = len(self.images)
80 | self.len_labels = len(self.labels)
81 | assert self.len_images == self.len_labels
82 | self.image_count = self.len_labels
83 |
84 | if self.stage == 'train':
85 | self.images, self.labels, self.teacher_logits, self.segmentation_masks = shuffle_data(samples=self.images,
86 | labels=self.labels,
87 | teacher_logits=self.teacher_logits,
88 | segmentation_masks=self.segmentation_masks,
89 | seed=self.seed)
90 |
91 | def get_batch_count(self):
92 | return (self.len_labels // self.batch_size) + 1
93 |
94 | def get_truncated_normal(self, mean=0, sd=1, low=0, upp=10):
95 | return truncnorm((low - mean) / sd, (upp - mean) / sd, loc=mean, scale=sd)
96 |
97 | def get_severity(self):
98 | return np.random.randint(1, 6)
99 |
100 | def draw_cicle(self, shape, diamiter):
101 | """
102 | Input:
103 | shape : tuple (height, width)
104 | diameter : scalar
105 |
106 | Output:
107 |         Boolean np.array of the given shape: True inside the circle of the given diameter centered on the array
108 | """
109 | assert len(shape) == 2
110 | TF = np.zeros(shape, dtype="bool")
111 | center = np.array(TF.shape) / 2.0
112 |
113 | for iy in range(shape[0]):
114 | for ix in range(shape[1]):
115 | TF[iy, ix] = (iy - center[0]) ** 2 + (ix - center[1]) ** 2 < diamiter ** 2
116 | return TF
117 |
118 | def filter_circle(self, TFcircle, fft_img_channel):
119 | temp = np.zeros(fft_img_channel.shape[:2], dtype=complex)
120 | temp[TFcircle] = fft_img_channel[TFcircle]
121 | return temp
122 |
123 | def inv_FFT_all_channel(self, fft_img):
124 | img_reco = []
125 | for ichannel in range(fft_img.shape[2]):
126 | img_reco.append(np.fft.ifft2(np.fft.ifftshift(fft_img[:, :, ichannel])))
127 | img_reco = np.array(img_reco)
128 | img_reco = np.transpose(img_reco, (1, 2, 0))
129 | return img_reco
130 |
131 | def high_pass_filter(self, x, severity):
132 | x = x.astype("float32") / 255.
133 | c = [.01, .02, .03, .04, .05][severity - 1]
134 |
135 | d = int(c * x.shape[0])
136 | TFcircle = self.draw_cicle(shape=x.shape[:2], diamiter=d)
137 | TFcircle = ~TFcircle
138 |
139 | fft_img = np.zeros_like(x, dtype=complex)
140 | for ichannel in range(fft_img.shape[2]):
141 | fft_img[:, :, ichannel] = np.fft.fftshift(np.fft.fft2(x[:, :, ichannel]))
142 |
143 | # For each channel, pass filter
144 | fft_img_filtered = []
145 | for ichannel in range(fft_img.shape[2]):
146 | fft_img_channel = fft_img[:, :, ichannel]
147 | temp = self.filter_circle(TFcircle, fft_img_channel)
148 | fft_img_filtered.append(temp)
149 | fft_img_filtered = np.array(fft_img_filtered)
150 | fft_img_filtered = np.transpose(fft_img_filtered, (1, 2, 0))
151 | x = np.clip(np.abs(self.inv_FFT_all_channel(fft_img_filtered)), a_min=0, a_max=1)
152 |
153 | x = PILImage.fromarray((x * 255.).astype("uint8"))
154 | return x
155 |
156 | def constant_amplitude(self, x, severity):
157 | """
158 | A visual corruption based on amplitude information of a Fourier-transformed image
159 |
160 |         Adapted from: https://github.com/MediaBrain-SJTU/FACT
161 | """
162 | x = x.astype("float32") / 255.
163 | c = [.05, .1, .15, .2, .25][severity - 1]
164 |
165 | # FFT
166 | x_fft = np.fft.fft2(x, axes=(0, 1))
167 | x_abs, x_pha = np.fft.fftshift(np.abs(x_fft), axes=(0, 1)), np.angle(x_fft)
168 |
169 | # Amplitude replacement
170 | beta = 1.0 - c
171 | x_abs = np.ones_like(x_abs) * max(0, beta)
172 |
173 | # Inverse FFT
174 | x_abs = np.fft.ifftshift(x_abs, axes=(0, 1))
175 | x = x_abs * (np.e ** (1j * x_pha))
176 | x = np.real(np.fft.ifft2(x, axes=(0, 1)))
177 |
178 |         x = PILImage.fromarray((np.clip(x, 0., 1.) * 255.).astype("uint8"))
179 | return x
180 |
181 | def phase_scaling(self, x, severity):
182 | """
183 | A visual corruption based on phase information of a Fourier-transformed image
184 |
185 |         Adapted from: https://github.com/MediaBrain-SJTU/FACT
186 | """
187 | x = x.astype("float32") / 255.
188 | c = [.1, .2, .3, .4, .5][severity - 1]
189 |
190 | # FFT
191 | x_fft = np.fft.fft2(x, axes=(0, 1))
192 | x_abs, x_pha = np.fft.fftshift(np.abs(x_fft), axes=(0, 1)), np.angle(x_fft)
193 |
194 | # Phase scaling
195 | alpha = 1.0 - c
196 | x_pha = x_pha * max(0, alpha)
197 |
198 | # Inverse FFT
199 | x_abs = np.fft.ifftshift(x_abs, axes=(0, 1))
200 | x = x_abs * (np.e ** (1j * x_pha))
201 | x = np.real(np.fft.ifft2(x, axes=(0, 1)))
202 |
203 |         x = PILImage.fromarray((np.clip(x, 0., 1.) * 255.).astype("uint8"))
204 | return x
205 |
206 | def apply_corruption(self, x, corruption_name):
207 | severity = self.get_severity()
208 |
209 | custom_corruptions = {"high_pass_filter": self.high_pass_filter,
210 | "constant_amplitude": self.constant_amplitude,
211 | "phase_scaling": self.phase_scaling}
212 |
213 | if corruption_name in get_corruption_names('all'):
214 | x = corrupt(x, corruption_name=corruption_name, severity=severity)
215 | x = PILImage.fromarray(x)
216 |
217 | elif corruption_name in custom_corruptions:
218 | x = custom_corruptions[corruption_name](x, severity=severity)
219 |
220 | else:
221 |             raise ValueError("%s is not a supported corruption!" % corruption_name)
222 |
223 | return x
224 |
225 | def acvc(self, x):
226 | i = np.random.randint(0, 22)
227 | corruption_func = {0: "fog",
228 | 1: "snow",
229 | 2: "frost",
230 | 3: "spatter",
231 | 4: "zoom_blur",
232 | 5: "defocus_blur",
233 | 6: "glass_blur",
234 | 7: "gaussian_blur",
235 | 8: "motion_blur",
236 | 9: "speckle_noise",
237 | 10: "shot_noise",
238 | 11: "impulse_noise",
239 | 12: "gaussian_noise",
240 | 13: "jpeg_compression",
241 | 14: "pixelate",
242 | 15: "elastic_transform",
243 | 16: "brightness",
244 | 17: "saturate",
245 | 18: "contrast",
246 | 19: "high_pass_filter",
247 | 20: "constant_amplitude",
248 | 21: "phase_scaling"}
249 | return self.apply_corruption(x, corruption_func[i])
250 |
251 | def corruption(self, x, segmentation_mask=None):
252 | if self.rand_aug and np.random.uniform(0, 1) > 0.5:
253 | x_ = self.random_augmentation(PILImage.fromarray(x))
254 |
255 | else:
256 | x_ = np.copy(x)
257 | x_ = self.acvc(x_)
258 |
259 | return x_
260 |
261 | def get_batch(self, epoch=None):
262 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2])
263 | teacher_logits = None if self.teacher_logits is None else []
264 |
265 | if self.orig_plus_aug:
266 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
267 | images = torch.zeros(tensor_shape, dtype=torch.float32)
268 | augmented_images = torch.zeros(tensor_shape, dtype=torch.float32)
269 | for i in range(self.batch_size):
270 |                 # Avoid overflow (wrap around the end of the dataset)
271 | if self.current_index > self.image_count - 1:
272 | if self.stage == "train":
273 | self.shuffle()
274 | else:
275 | self.current_index = 0
276 |
277 | x = self.images[self.current_index]
278 | y = self.labels[self.current_index]
279 | mask = None if self.segmentation_masks is None else self.segmentation_masks[self.current_index]
280 | images[i] = normalize_dataset(PILImage.fromarray(x), img_mean_mode=self.img_mean_mode)
281 | labels[i] = y
282 |
283 | augmented_x = self.corruption(x, mask)
284 | augmented_images[i] = normalize_dataset(augmented_x, img_mean_mode=self.img_mean_mode)
285 |
286 | if teacher_logits is not None:
287 | teacher_logits.append(self.teacher_logits[self.current_index])
288 |
289 | self.current_index += 1
290 |
291 | # Include teacher logits as soft labels if applicable
292 | if teacher_logits is not None:
293 | labels = [labels, np.array(teacher_logits)]
294 |
295 | batches = [(images, labels), (augmented_images, labels)]
296 |
297 | else:
298 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
299 | images = torch.zeros(tensor_shape, dtype=torch.float32)
300 | for i in range(self.batch_size):
301 |                 # Avoid overflow (wrap around the end of the dataset)
302 | if self.current_index > self.image_count - 1:
303 | if self.stage == "train":
304 | self.shuffle()
305 | else:
306 | self.current_index = 0
307 |
308 | x = self.images[self.current_index]
309 | y = self.labels[self.current_index]
310 | mask = None if self.segmentation_masks is None else self.segmentation_masks[self.current_index]
311 |
312 | augmented_x = self.corruption(x, mask)
313 | images[i] = normalize_dataset(augmented_x, img_mean_mode=self.img_mean_mode)
314 | labels[i] = y
315 |
316 | if teacher_logits is not None:
317 | teacher_logits.append(self.teacher_logits[self.current_index])
318 |
319 | self.current_index += 1
320 |
321 | # Include teacher logits as soft labels if applicable
322 | if teacher_logits is not None:
323 | labels = [labels, np.array(teacher_logits)]
324 |
325 | batches = [(images, labels)]
326 |
327 | return batches
--------------------------------------------------------------------------------
/preprocessing/image/AblationGenerator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tools import shuffle_data
4 | from preprocessing.Datasets import normalize_dataset
5 | from PIL import Image as PILImage
6 | from scipy.stats import truncnorm
7 | from preprocessing.image.RandAugmentGenerator import RandAugment
8 | from imagecorruptions import corrupt, get_corruption_names
9 |
10 | class AblationGenerator:
11 | def __init__(self,
12 | dataset,
13 | batch_size,
14 | stage,
15 | epochs,
16 | corruption_mode,
17 | corruption_dist,
18 | img_mean_mode=None,
19 | rand_aug=False,
20 | seed=13,
21 | orig_plus_aug=True):
22 | """
23 |         19 of the 22 corruptions used in this generator are taken from ImageNet-C:
24 | - https://github.com/bethgelab/imagecorruptions
25 |
26 |         :param dataset: (dict) "images", "labels", and optionally "teacher_logits" / "segmentation_masks"
27 | :param batch_size: (int) # of inputs in a mini-batch
28 | :param stage: (str) train | test
29 | :param epochs: (int) # of full training passes
30 |         :param corruption_mode: (str) required for VisCo experiments
31 |         :param corruption_dist: (str) required to determine the corruption rate
32 | :param img_mean_mode: (str) use this to revert the image into its original form before augmentation
33 | :param rand_aug: (bool) enable/disable on-the-fly random data augmentation
34 | :param seed: (int) seed for input shuffle
35 | :param orig_plus_aug: (bool) if True, original images will be kept in the batch along with corrupted ones
36 | """
37 |
38 | if stage not in ['train', 'test']:
39 |             raise ValueError('invalid stage!')
40 |
41 | # Settings
42 | self.batch_size = batch_size
43 | self.stage = stage
44 | self.epochs = epochs
45 | self.corruption_mode = corruption_mode
46 | self.corruption_dist = corruption_dist
47 | self.img_mean_mode = img_mean_mode
48 | self.rand_aug = rand_aug
49 | self.seed = seed
50 | self.orig_plus_aug = orig_plus_aug
51 |
52 | # Preparation
53 | self.configuration()
54 | self.load_data(dataset)
55 | self.random_augmentation = RandAugment(1, 5)
56 | if self.img_mean_mode is not None:
57 | self.img_mean = np.load("img_mean.npy")
58 |
59 | def configuration(self):
60 | self.shuffle_count = 1
61 | self.current_index = 0
62 |
63 | def shuffle(self):
64 | self.image_count = len(self.labels)
65 | self.current_index = 0
66 | self.images, self.labels, self.teacher_logits, self.segmentation_masks = shuffle_data(samples=self.images,
67 | labels=self.labels,
68 | teacher_logits=self.teacher_logits,
69 | segmentation_masks=self.segmentation_masks,
70 | seed=self.seed + self.shuffle_count)
71 | self.shuffle_count += 1
72 |
73 | def load_data(self, dataset):
74 | self.images = dataset["images"]
75 | self.labels = dataset["labels"]
76 | self.teacher_logits = dataset["teacher_logits"] if "teacher_logits" in dataset else None
77 | self.segmentation_masks = dataset["segmentation_masks"] if "segmentation_masks" in dataset else None
78 |
79 | self.len_images = len(self.images)
80 | self.len_labels = len(self.labels)
81 | assert self.len_images == self.len_labels
82 | self.image_count = self.len_labels
83 |
84 | if self.stage == 'train':
85 | self.images, self.labels, self.teacher_logits, self.segmentation_masks = shuffle_data(samples=self.images,
86 | labels=self.labels,
87 | teacher_logits=self.teacher_logits,
88 | segmentation_masks=self.segmentation_masks,
89 | seed=self.seed)
90 |
91 | def get_batch_count(self):
92 | return (self.len_labels // self.batch_size) + 1
93 |
94 | def get_truncated_normal(self, mean=0, sd=1, low=0, upp=10):
95 | return truncnorm((low - mean) / sd, (upp - mean) / sd, loc=mean, scale=sd)
96 |
97 | def get_severity(self):
98 |         return 1  # severity is fixed at 1 for the ablations; ACVCGenerator samples np.random.randint(1, 6)
99 |
100 | def get_random_boundingbox(self, img, l_param):
101 | width = img.shape[0]
102 | height = img.shape[1]
103 |
104 | r_x = np.random.randint(width)
105 | r_y = np.random.randint(height)
106 |
107 | r_l = np.sqrt(1 - l_param)
108 | r_w = int(width * r_l)
109 | r_h = int(height * r_l)
110 |
111 | if r_x + r_w < width:
112 | bbox_x1 = r_x
113 |             bbox_x2 = r_x + r_w
114 | else:
115 | bbox_x1 = width - r_w
116 | bbox_x2 = width
117 | if r_y + r_h < height:
118 | bbox_y1 = r_y
119 |             bbox_y2 = r_y + r_h
120 | else:
121 | bbox_y1 = height - r_h
122 | bbox_y2 = height
123 |
124 | return bbox_x1, bbox_y1, bbox_x2, bbox_y2
125 |
126 | def draw_cicle(self, shape, diamiter):
127 | """
128 | Input:
129 | shape : tuple (height, width)
130 | diameter : scalar
131 |
132 | Output:
133 |         Boolean np.array of the given shape: True inside the circle of the given diameter centered on the array
134 | """
135 | assert len(shape) == 2
136 | TF = np.zeros(shape, dtype="bool")
137 | center = np.array(TF.shape) / 2.0
138 |
139 | for iy in range(shape[0]):
140 | for ix in range(shape[1]):
141 | TF[iy, ix] = (iy - center[0]) ** 2 + (ix - center[1]) ** 2 < diamiter ** 2
142 | return TF
143 |
144 | def filter_circle(self, TFcircle, fft_img_channel):
145 | temp = np.zeros(fft_img_channel.shape[:2], dtype=complex)
146 | temp[TFcircle] = fft_img_channel[TFcircle]
147 | return temp
148 |
149 | def inv_FFT_all_channel(self, fft_img):
150 | img_reco = []
151 | for ichannel in range(fft_img.shape[2]):
152 | img_reco.append(np.fft.ifft2(np.fft.ifftshift(fft_img[:, :, ichannel])))
153 | img_reco = np.array(img_reco)
154 | img_reco = np.transpose(img_reco, (1, 2, 0))
155 | return img_reco
156 |
157 | def high_pass_filter(self, x, severity):
158 | x = x.astype("float32") / 255.
159 | c = [.01, .02, .03, .04, .05][severity - 1]
160 |
161 | d = int(c * x.shape[0])
162 | TFcircle = self.draw_cicle(shape=x.shape[:2], diamiter=d)
163 | TFcircle = ~TFcircle
164 |
165 | fft_img = np.zeros_like(x, dtype=complex)
166 | for ichannel in range(fft_img.shape[2]):
167 | fft_img[:, :, ichannel] = np.fft.fftshift(np.fft.fft2(x[:, :, ichannel]))
168 |
169 | # For each channel, pass filter
170 | fft_img_filtered = []
171 | for ichannel in range(fft_img.shape[2]):
172 | fft_img_channel = fft_img[:, :, ichannel]
173 | temp = self.filter_circle(TFcircle, fft_img_channel)
174 | fft_img_filtered.append(temp)
175 | fft_img_filtered = np.array(fft_img_filtered)
176 | fft_img_filtered = np.transpose(fft_img_filtered, (1, 2, 0))
177 | x = np.clip(np.abs(self.inv_FFT_all_channel(fft_img_filtered)), a_min=0, a_max=1)
178 |
179 | x = PILImage.fromarray((x * 255.).astype("uint8"))
180 | return x
181 |
182 | def constant_amplitude(self, x, severity):
183 | """
184 | A visual corruption based on amplitude information of a Fourier-transformed image
185 |
186 |         Adapted from: https://github.com/MediaBrain-SJTU/FACT
187 | """
188 | x = x.astype("float32") / 255.
189 | c = [.05, .1, .15, .2, .25][severity - 1]
190 |
191 | # FFT
192 | x_fft = np.fft.fft2(x, axes=(0, 1))
193 | x_abs, x_pha = np.fft.fftshift(np.abs(x_fft), axes=(0, 1)), np.angle(x_fft)
194 |
195 | # Amplitude replacement
196 | beta = 1.0 - c
197 | x_abs = np.ones_like(x_abs) * max(0, beta)
198 |
199 | # Inverse FFT
200 | x_abs = np.fft.ifftshift(x_abs, axes=(0, 1))
201 | x = x_abs * (np.e ** (1j * x_pha))
202 | x = np.real(np.fft.ifft2(x, axes=(0, 1)))
203 |
204 |         x = PILImage.fromarray((np.clip(x, 0., 1.) * 255.).astype("uint8"))
205 | return x
206 |
207 | def phase_scaling(self, x, severity):
208 | """
209 | A visual corruption based on phase information of a Fourier-transformed image
210 |
211 |         Adapted from: https://github.com/MediaBrain-SJTU/FACT
212 | """
213 | x = x.astype("float32") / 255.
214 | c = [.1, .2, .3, .4, .5][severity - 1]
215 |
216 | # FFT
217 | x_fft = np.fft.fft2(x, axes=(0, 1))
218 | x_abs, x_pha = np.fft.fftshift(np.abs(x_fft), axes=(0, 1)), np.angle(x_fft)
219 |
220 | # Phase scaling
221 | alpha = 1.0 - c
222 | x_pha = x_pha * max(0, alpha)
223 |
224 | # Inverse FFT
225 | x_abs = np.fft.ifftshift(x_abs, axes=(0, 1))
226 | x = x_abs * (np.e ** (1j * x_pha))
227 | x = np.real(np.fft.ifft2(x, axes=(0, 1)))
228 |
229 |         x = PILImage.fromarray((np.clip(x, 0., 1.) * 255.).astype("uint8"))
230 | return x
231 |
232 | def apply_corruption(self, x, corruption_name):
233 | severity = self.get_severity()
234 |
235 | custom_corruptions = {"high_pass_filter": self.high_pass_filter,
236 | "constant_amplitude": self.constant_amplitude,
237 | "phase_scaling": self.phase_scaling}
238 |
239 | if corruption_name in get_corruption_names('all'):
240 | x = corrupt(x, corruption_name=corruption_name, severity=severity)
241 | x = PILImage.fromarray(x)
242 |
243 | elif corruption_name in custom_corruptions:
244 | x = custom_corruptions[corruption_name](x, severity=severity)
245 |
246 | else:
247 |             raise ValueError("%s is not a supported corruption!" % corruption_name)
248 |
249 | return x
250 |
251 | def weather(self, x):
252 | i = np.random.randint(0, 4)
253 | corruption_func = {0: "fog",
254 | 1: "snow",
255 | 2: "frost",
256 | 3: "spatter"}
257 | return self.apply_corruption(x, corruption_func[i])
258 |
259 | def blur(self, x):
260 | i = np.random.randint(0, 5)
261 | corruption_func = {0: "zoom_blur",
262 | 1: "defocus_blur",
263 | 2: "glass_blur",
264 | 3: "gaussian_blur",
265 | 4: "motion_blur"}
266 | return self.apply_corruption(x, corruption_func[i])
267 |
268 | def noise(self, x):
269 | i = np.random.randint(0, 4)
270 | corruption_func = {0: "speckle_noise",
271 | 1: "shot_noise",
272 | 2: "impulse_noise",
273 | 3: "gaussian_noise"}
274 | return self.apply_corruption(x, corruption_func[i])
275 |
276 | def digital(self, x):
277 | i = np.random.randint(0, 6)
278 | corruption_func = {0: "jpeg_compression",
279 | 1: "pixelate",
280 | 2: "elastic_transform",
281 | 3: "brightness",
282 | 4: "saturate",
283 | 5: "contrast"}
284 | return self.apply_corruption(x, corruption_func[i])
285 |
286 | def fourier(self, x):
287 | i = np.random.randint(0, 3)
288 | corruption_func = {0: "high_pass_filter",
289 | 1: "constant_amplitude",
290 | 2: "phase_scaling"}
291 | return self.apply_corruption(x, corruption_func[i])
292 |
293 | def all(self, x):
294 | i = np.random.randint(0, 22) if self.corruption_mode == "all++" else np.random.randint(0, 19)
295 | corruption_func = {0: "fog",
296 | 1: "snow",
297 | 2: "frost",
298 | 3: "zoom_blur",
299 | 4: "defocus_blur",
300 | 5: "glass_blur",
301 | 6: "gaussian_blur",
302 | 7: "motion_blur",
303 | 8: "speckle_noise",
304 | 9: "shot_noise",
305 | 10: "impulse_noise",
306 | 11: "gaussian_noise",
307 | 12: "jpeg_compression",
308 | 13: "pixelate",
309 | 14: "spatter",
310 | 15: "elastic_transform",
311 | 16: "brightness",
312 | 17: "saturate",
313 | 18: "contrast",
314 | 19: "high_pass_filter",
315 | 20: "constant_amplitude",
316 | 21: "phase_scaling"}
317 | return self.apply_corruption(x, corruption_func[i])
318 |
319 | def corruption(self, x, segmentation_mask=None):
320 | if self.rand_aug and np.random.uniform(0, 1) > 0.5:
321 | x_ = self.random_augmentation(PILImage.fromarray(x))
322 |
323 | else:
324 | super_funcs = {"weather": self.weather,
325 | "blur": self.blur,
326 | "noise": self.noise,
327 | "digital": self.digital,
328 |                            "fourier": self.fourier, "all": self.all,
329 | "all++": self.all}
330 | x_ = np.copy(x)
331 | corruption_mode = self.corruption_mode.split("_")[0]
332 |
333 | # Apply corruption
334 | if corruption_mode in super_funcs:
335 | x_ = super_funcs[corruption_mode](x_)
336 | else:
337 |                 x_ = self.apply_corruption(x_, corruption_mode)
338 |
339 | return x_
340 |
341 | def get_batch(self, epoch=None):
342 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2])
343 | teacher_logits = None if self.teacher_logits is None else []
344 |
345 | if self.orig_plus_aug:
346 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
347 | images = torch.zeros(tensor_shape, dtype=torch.float32)
348 | augmented_images = torch.zeros(tensor_shape, dtype=torch.float32)
349 | for i in range(self.batch_size):
350 |                 # Avoid overflow (wrap around the end of the dataset)
351 | if self.current_index > self.image_count - 1:
352 | if self.stage == "train":
353 | self.shuffle()
354 | else:
355 | self.current_index = 0
356 |
357 | x = self.images[self.current_index]
358 | y = self.labels[self.current_index]
359 | mask = None if self.segmentation_masks is None else self.segmentation_masks[self.current_index]
360 | images[i] = normalize_dataset(PILImage.fromarray(x), img_mean_mode=self.img_mean_mode)
361 | labels[i] = y
362 |
363 | augmented_x = self.corruption(x, mask)
364 | augmented_images[i] = normalize_dataset(augmented_x, img_mean_mode=self.img_mean_mode)
365 |
366 | if teacher_logits is not None:
367 | teacher_logits.append(self.teacher_logits[self.current_index])
368 |
369 | self.current_index += 1
370 |
371 | # Include teacher logits as soft labels if applicable
372 | if teacher_logits is not None:
373 | labels = [labels, np.array(teacher_logits)]
374 |
375 | batches = [(images, labels), (augmented_images, labels)]
376 |
377 | else:
378 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
379 | images = torch.zeros(tensor_shape, dtype=torch.float32)
380 | for i in range(self.batch_size):
381 |                 # Avoid overflow (wrap around the end of the dataset)
382 | if self.current_index > self.image_count - 1:
383 | if self.stage == "train":
384 | self.shuffle()
385 | else:
386 | self.current_index = 0
387 |
388 | x = self.images[self.current_index]
389 | y = self.labels[self.current_index]
390 | mask = None if self.segmentation_masks is None else self.segmentation_masks[self.current_index]
391 |
392 | augmented_x = self.corruption(x, mask)
393 | images[i] = normalize_dataset(augmented_x, img_mean_mode=self.img_mean_mode)
394 | labels[i] = y
395 |
396 | if teacher_logits is not None:
397 | teacher_logits.append(self.teacher_logits[self.current_index])
398 |
399 | self.current_index += 1
400 |
401 | # Include teacher logits as soft labels if applicable
402 | if teacher_logits is not None:
403 | labels = [labels, np.array(teacher_logits)]
404 |
405 | batches = [(images, labels)]
406 |
407 | return batches
408 |
409 | def test(self, x, img_id):
410 |
411 | for f in get_corruption_names("all"):
412 | x_ = np.copy(x)
413 | x_ = self.apply_corruption(x_, f)
414 | x_.save("example_%d_%s.jpeg" % (img_id, f))
415 | print("%s testing is done." % f)
416 |
417 | for f in ["high_pass_filter", "constant_amplitude", "phase_scaling"]:
418 | x_ = np.copy(x)
419 | x_ = self.apply_corruption(x_, f)
420 | x_.save("example_%d_%s.jpeg" % (img_id, f))
421 | print("%s testing is done." % f)
422 |
423 | exit(1)
--------------------------------------------------------------------------------
/preprocessing/image/AugMixGenerator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from PIL import Image, ImageOps, ImageEnhance
4 | from tools import shuffle_data
5 | from preprocessing.Datasets import normalize_dataset
6 |
7 | # Copyright 2019 Google LLC
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # https://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | # ==============================================================================
21 | """Base augmentations operators."""
22 |
23 | # ImageNet code should change this value
24 | IMAGE_SIZE = 224
25 |
26 |
27 | def int_parameter(level, maxval):
28 |   """Helper function to scale `val` between 0 and maxval.
29 |
30 | Args:
31 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
32 | maxval: Maximum value that the operation can have. This will be scaled to
33 | level/PARAMETER_MAX.
34 |
35 | Returns:
36 | An int that results from scaling `maxval` according to `level`.
37 | """
38 | return int(level * maxval / 10)
39 |
40 |
41 | def float_parameter(level, maxval):
42 | """Helper function to scale `val` between 0 and maxval.
43 |
44 | Args:
45 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
46 | maxval: Maximum value that the operation can have. This will be scaled to
47 | level/PARAMETER_MAX.
48 |
49 | Returns:
50 | A float that results from scaling `maxval` according to `level`.
51 | """
52 | return float(level) * maxval / 10.
53 |
54 |
55 | def sample_level(n):
56 | return np.random.uniform(low=0.1, high=n)
57 |
58 |
59 | def autocontrast(pil_img, _):
60 | return ImageOps.autocontrast(pil_img)
61 |
62 |
63 | def equalize(pil_img, _):
64 | return ImageOps.equalize(pil_img)
65 |
66 |
67 | def posterize(pil_img, level):
68 | level = int_parameter(sample_level(level), 4)
69 | return ImageOps.posterize(pil_img, 4 - level)
70 |
71 |
72 | def rotate(pil_img, level):
73 | degrees = int_parameter(sample_level(level), 30)
74 | if np.random.uniform() > 0.5:
75 | degrees = -degrees
76 | return pil_img.rotate(degrees, resample=Image.BILINEAR)
77 |
78 |
79 | def solarize(pil_img, level):
80 | level = int_parameter(sample_level(level), 256)
81 | return ImageOps.solarize(pil_img, 256 - level)
82 |
83 |
84 | def shear_x(pil_img, level):
85 | level = float_parameter(sample_level(level), 0.3)
86 | if np.random.uniform() > 0.5:
87 | level = -level
88 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
89 | Image.AFFINE, (1, level, 0, 0, 1, 0),
90 | resample=Image.BILINEAR)
91 |
92 |
93 | def shear_y(pil_img, level):
94 | level = float_parameter(sample_level(level), 0.3)
95 | if np.random.uniform() > 0.5:
96 | level = -level
97 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
98 | Image.AFFINE, (1, 0, 0, level, 1, 0),
99 | resample=Image.BILINEAR)
100 |
101 |
102 | def translate_x(pil_img, level):
103 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3)
104 | if np.random.random() > 0.5:
105 | level = -level
106 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
107 | Image.AFFINE, (1, 0, level, 0, 1, 0),
108 | resample=Image.BILINEAR)
109 |
110 |
111 | def translate_y(pil_img, level):
112 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3)
113 | if np.random.random() > 0.5:
114 | level = -level
115 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
116 | Image.AFFINE, (1, 0, 0, 0, 1, level),
117 | resample=Image.BILINEAR)
118 |
119 |
120 | # operation that overlaps with ImageNet-C's test set
121 | def color(pil_img, level):
122 | level = float_parameter(sample_level(level), 1.8) + 0.1
123 | return ImageEnhance.Color(pil_img).enhance(level)
124 |
125 |
126 | # operation that overlaps with ImageNet-C's test set
127 | def contrast(pil_img, level):
128 | level = float_parameter(sample_level(level), 1.8) + 0.1
129 | return ImageEnhance.Contrast(pil_img).enhance(level)
130 |
131 |
132 | # operation that overlaps with ImageNet-C's test set
133 | def brightness(pil_img, level):
134 | level = float_parameter(sample_level(level), 1.8) + 0.1
135 | return ImageEnhance.Brightness(pil_img).enhance(level)
136 |
137 |
138 | # operation that overlaps with ImageNet-C's test set
139 | def sharpness(pil_img, level):
140 | level = float_parameter(sample_level(level), 1.8) + 0.1
141 | return ImageEnhance.Sharpness(pil_img).enhance(level)
142 |
143 | augmentations_augmix = [
144 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
145 | translate_x, translate_y
146 | ]
147 |
148 | augmentations_augmix_all = [
149 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
150 | translate_x, translate_y, color, contrast, brightness, sharpness
151 | ]
152 |
153 | class AugMixGenerator:
154 | def __init__(self,
155 | dataset,
156 | batch_size,
157 | stage,
158 |                  aug_prob_coeff=1.,
159 |                  mixture_width=3,
160 |                  mixture_depth=-1,
161 |                  aug_severity=1,
162 | img_mean_mode=None,
163 | seed=13,
164 | orig_plus_aug=True):
165 | """
166 |         :param dataset: (dict) "images", "labels", and optionally "teacher_logits"
167 | :param batch_size: (int) # of inputs in a mini-batch
168 | :param stage: (str) train | test
169 | :param aug_prob_coeff: (float) alpha in the paper
170 |         :param aug_severity: (int) from the official repository
171 | :param mixture_width: (int) from the paper
172 | :param mixture_depth: (int) from the paper
173 | :param img_mean_mode: (str) use this for image normalization
174 | :param seed: (int) seed for input shuffle
175 | :param orig_plus_aug: (bool) if True, original images will be kept in the batch along with corrupted ones
176 | """
177 |
178 | if stage not in ['train', 'test']:
179 |             raise ValueError('invalid stage!')
180 |
181 | # Settings
182 | self.batch_size = batch_size
183 | self.stage = stage
184 | self.aug_prob_coeff = aug_prob_coeff
185 | self.mixture_width = mixture_width
186 | self.mixture_depth = mixture_depth
187 | self.aug_severity = aug_severity
188 | self.img_mean_mode = img_mean_mode
189 | self.seed = seed
190 | self.orig_plus_aug = orig_plus_aug
191 |
192 | # Preparation
193 | self.configuration()
194 | self.load_data(dataset)
195 |
196 | def configuration(self):
197 | self.shuffle_count = 1
198 | self.current_index = 0
199 |
200 | def shuffle(self):
201 | self.image_count = len(self.labels)
202 | self.current_index = 0
203 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images,
204 | labels=self.labels,
205 | teacher_logits=self.teacher_logits,
206 | seed=self.seed + self.shuffle_count)
207 | self.shuffle_count += 1
208 |
209 | def load_data(self, dataset):
210 | self.images = dataset["images"]
211 | self.labels = dataset["labels"]
212 | self.teacher_logits = dataset["teacher_logits"] if "teacher_logits" in dataset else None
213 |
214 | self.len_images = len(self.images)
215 | self.len_labels = len(self.labels)
216 | assert self.len_images == self.len_labels
217 | self.image_count = self.len_labels
218 |
219 | if self.stage == 'train':
220 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images,
221 | labels=self.labels,
222 | teacher_logits=self.teacher_logits,
223 | seed=self.seed)
224 |
225 | def get_batch_count(self):
226 | return (self.len_labels // self.batch_size) + 1
227 |
228 | def augmix(self, x):
229 | aug_list = augmentations_augmix_all
230 |
231 | ws = np.float32(np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width))
232 | m = np.float32(np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff))
233 |
234 | mix = np.zeros_like(np.array(x)).astype("float32")
235 | for i in range(self.mixture_width):
236 | x_ = x.copy()
237 | depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4)
238 | for _ in range(depth):
239 | op = np.random.choice(aug_list)
240 | x_ = op(x_, self.aug_severity)
241 |
242 | # Preprocessing commutes since all coefficients are convex
243 | mix += ws[i] * np.array(x_).astype("float32")
244 |
245 | mix = Image.fromarray(mix.astype("uint8"))
246 | x_ = Image.blend(x, mix, m)
247 |
248 | return x_
249 |
250 | def augment(self, x):
251 |
252 | x_ = [self.augmix(x) for _ in range(2)]
253 |
254 | return x_
255 |
256 | def get_batch(self, epoch=None):
257 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2])
258 | teacher_logits = None if self.teacher_logits is None else []
259 | expansion_coeff = 2
260 |
261 | if self.orig_plus_aug:
262 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
263 | images = torch.zeros(tensor_shape, dtype=torch.float32)
264 | augmented_images = [torch.zeros(tensor_shape, dtype=torch.float32) for _ in range(expansion_coeff)]
265 | for i in range(self.batch_size):
266 |                 # Avoid overflow (wrap around the end of the dataset)
267 | if self.current_index > self.image_count - 1:
268 | if self.stage == "train":
269 | self.shuffle()
270 | else:
271 | self.current_index = 0
272 |
273 | x = Image.fromarray(self.images[self.current_index])
274 | y = self.labels[self.current_index]
275 | images[i] = normalize_dataset(x, self.img_mean_mode)
276 | labels[i] = y
277 |
278 | augmented_x = self.augment(x)
279 | for j in range(expansion_coeff):
280 | augmented_images[j][i] = normalize_dataset(augmented_x[j], img_mean_mode=self.img_mean_mode)
281 |
282 | if teacher_logits is not None:
283 | teacher_logits.append(self.teacher_logits[self.current_index])
284 |
285 | self.current_index += 1
286 |
287 | # Include teacher logits as soft labels if applicable
288 | if teacher_logits is not None:
289 | labels = [labels, np.array(teacher_logits)]
290 |
291 | batches = [(images, labels)]
292 | for i in range(expansion_coeff):
293 | batches.append((augmented_images[i], labels))
294 |
295 | else:
296 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
297 | augmented_images = [torch.zeros(tensor_shape, dtype=torch.float32) for _ in range(expansion_coeff)]
298 | for i in range(self.batch_size):
299 |                 # Avoid overflow (wrap around the end of the dataset)
300 | if self.current_index > self.image_count - 1:
301 | if self.stage == "train":
302 | self.shuffle()
303 | else:
304 | self.current_index = 0
305 |
306 | x = Image.fromarray(self.images[self.current_index])
307 | y = self.labels[self.current_index]
308 | labels[i] = y
309 |
310 | augmented_x = self.augment(x)
311 | for j in range(expansion_coeff):
312 | augmented_images[j][i] = normalize_dataset(augmented_x[j], img_mean_mode=self.img_mean_mode)
313 |
314 | if teacher_logits is not None:
315 | teacher_logits.append(self.teacher_logits[self.current_index])
316 |
317 | self.current_index += 1
318 |
319 | # Include teacher logits as soft labels if applicable
320 | if teacher_logits is not None:
321 | labels = [labels, np.array(teacher_logits)]
322 |
323 | batches = []
324 | for i in range(expansion_coeff):
325 | batches.append((augmented_images[i], labels))
326 |
327 | return batches
328 |
--------------------------------------------------------------------------------
/preprocessing/image/CutMixGenerator.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | from tools import shuffle_data
6 | from preprocessing.Datasets import normalize_dataset
7 |
8 | class CutMixGenerator:
9 | def __init__(self,
10 | dataset,
11 | batch_size,
12 | stage,
13 | img_mean_mode=None,
14 | seed=13,
15 | orig_plus_aug=True):
16 | """
17 |         :param dataset: (dict) "images", "labels", and optionally "teacher_logits"
18 | :param batch_size: (int) # of inputs in a mini-batch
19 | :param stage: (str) train | test
20 | :param img_mean_mode: (str) use this for image normalization
21 | :param seed: (int) seed for input shuffle
22 | :param orig_plus_aug: (bool) if True, original images will be kept in the batch along with corrupted ones
23 | """
24 |
25 | if stage not in ['train', 'test']:
26 |             raise ValueError('invalid stage!')
27 |
28 | # Settings
29 | self.batch_size = batch_size
30 | self.stage = stage
31 | self.img_mean_mode = img_mean_mode
32 | self.seed = seed
33 | self.orig_plus_aug = orig_plus_aug
34 |
35 | # Preparation
36 | self.configuration()
37 | self.load_data(dataset)
38 |
39 | def configuration(self):
40 | self.shuffle_count = 1
41 | self.current_index = 0
42 |
43 | def shuffle(self):
44 | self.image_count = len(self.labels)
45 | self.current_index = 0
46 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images,
47 | labels=self.labels,
48 | teacher_logits=self.teacher_logits,
49 | seed=self.seed + self.shuffle_count)
50 | self.shuffle_count += 1
51 |
52 | def load_data(self, dataset):
53 | self.images = dataset["images"]
54 | self.labels = dataset["labels"]
55 | self.teacher_logits = dataset["teacher_logits"] if "teacher_logits" in dataset else None
56 |
57 | self.len_images = len(self.images)
58 | self.len_labels = len(self.labels)
59 | assert self.len_images == self.len_labels
60 | self.image_count = self.len_labels
61 |
62 | if self.stage == 'train':
63 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images,
64 | labels=self.labels,
65 | teacher_logits=self.teacher_logits,
66 | seed=self.seed)
67 |
68 | def get_batch_count(self):
69 | return (self.len_labels // self.batch_size) + 1
70 |
71 | def get_random_boundingbox(self, img, l_param):
72 | width = img.shape[0]
73 | height = img.shape[1]
74 |
75 | r_x = np.random.randint(width)
76 | r_y = np.random.randint(height)
77 |
78 | r_l = np.sqrt(1 - l_param)
79 | r_w = int(width * r_l)
80 | r_h = int(height * r_l)
81 |
82 | if r_x + r_w < width:
83 | bbox_x1 = r_x
84 |             bbox_x2 = r_x + r_w
85 | else:
86 | bbox_x1 = width - r_w
87 | bbox_x2 = width
88 | if r_y + r_h < height:
89 | bbox_y1 = r_y
90 |             bbox_y2 = r_y + r_h
91 | else:
92 | bbox_y1 = height - r_h
93 | bbox_y2 = height
94 |
95 | return bbox_x1, bbox_y1, bbox_x2, bbox_y2
96 |
97 | def cutmix(self, image_batch, label_batch, beta=1.0):
98 | l_param = np.random.beta(beta, beta, self.batch_size)
99 | rand_index = torch.randperm(self.batch_size)
100 |
101 | x = image_batch.detach().clone()
102 | y = np.zeros_like(label_batch)
103 |
104 | for i in range(self.batch_size):
105 | bx1, by1, bx2, by2 = self.get_random_boundingbox(x[i], l_param[i])
106 | x[i][:, bx1:bx2, by1:by2] = image_batch[rand_index[i]][:, bx1:bx2, by1:by2]
107 | y[i] = label_batch[rand_index[i]]
108 |
109 | # Adjust lambda to exactly match pixel ratio
110 | l_param[i] = 1 - ((bx2 - bx1) * (by2 - by1) / (x.size()[-1] * x.size()[-2]))
111 |
112 | return x, (label_batch, y, l_param)
113 |
114 | def get_batch(self, epoch=None):
115 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2])
116 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
117 | images = torch.zeros(tensor_shape, dtype=torch.float32)
118 | for i in range(self.batch_size):
119 |             # Avoid overflow (wrap around the end of the dataset)
120 | if self.current_index > self.image_count - 1:
121 | if self.stage == "train":
122 | self.shuffle()
123 | else:
124 | self.current_index = 0
125 |
126 | x = self.images[self.current_index]
127 | y = self.labels[self.current_index]
128 | images[i] = normalize_dataset(Image.fromarray(x), img_mean_mode=self.img_mean_mode)
129 | labels[i] = y
130 |
131 | self.current_index += 1
132 |
133 | augmented_images, augmented_labels = self.cutmix(images, labels)
134 |
135 | if self.orig_plus_aug:
136 | batches = [(images, (labels, labels, (np.ones_like(augmented_labels[2])))), (augmented_images, augmented_labels)]
137 |
138 | else:
139 | batches = [(augmented_images, augmented_labels)]
140 |
141 | return batches
142 |
--------------------------------------------------------------------------------
/preprocessing/image/CutOutGenerator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from PIL import Image
4 | from tools import shuffle_data
5 | from preprocessing.Datasets import normalize_dataset
6 |
7 | class CutOutGenerator:
8 | def __init__(self,
9 | dataset,
10 | batch_size,
11 | stage,
12 | img_mean_mode=None,
13 | seed=13,
14 | orig_plus_aug=True):
15 | """
16 |         :param dataset: (dict) "images", "labels", and optionally "teacher_logits"
17 | :param batch_size: (int) # of inputs in a mini-batch
18 | :param stage: (str) train | test
19 | :param img_mean_mode: (str) use this for image normalization
20 | :param seed: (int) seed for input shuffle
21 | :param orig_plus_aug: (bool) if True, original images will be kept in the batch along with corrupted ones
22 | """
23 |
24 | if stage not in ['train', 'test']:
25 |             raise ValueError('invalid stage!')
26 |
27 | # Settings
28 | self.batch_size = batch_size
29 | self.stage = stage
30 |         self.img_mean_mode = img_mean_mode
31 | self.seed = seed
32 | self.orig_plus_aug = orig_plus_aug
33 |
34 | # Preparation
35 | self.configuration()
36 | self.load_data(dataset)
37 |
38 | def configuration(self):
39 | self.shuffle_count = 1
40 | self.current_index = 0
41 |
42 | def shuffle(self):
43 | self.image_count = len(self.labels)
44 | self.current_index = 0
45 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images,
46 | labels=self.labels,
47 | teacher_logits=self.teacher_logits,
48 | seed=self.seed + self.shuffle_count)
49 | self.shuffle_count += 1
50 |
51 | def load_data(self, dataset):
52 | self.images = dataset["images"]
53 | self.labels = dataset["labels"]
54 | self.teacher_logits = dataset["teacher_logits"] if "teacher_logits" in dataset else None
55 |
56 | self.len_images = len(self.images)
57 | self.len_labels = len(self.labels)
58 | assert self.len_images == self.len_labels
59 | self.image_count = self.len_labels
60 |
61 | if self.stage == 'train':
62 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images,
63 | labels=self.labels,
64 | teacher_logits=self.teacher_logits,
65 | seed=self.seed)
66 |
67 | def get_batch_count(self):
68 | return (self.len_labels // self.batch_size) + 1
69 |
70 | def get_random_eraser(self, p=0.5, s_l=0.02, s_h=0.4, r_1=0.3, r_2=1 / 0.3, v_l=0.0, v_h=1.0):
71 | """
72 | This CutOut implementation is taken from:
73 | - https://github.com/yu4u/cutout-random-erasing
74 |
75 | ...and modified for info loss experiments
76 |
77 | # Arguments:
78 | :param p: (float) the probability that random erasing is performed
79 | :param s_l: (float) minimum proportion of erased area against input image
80 | :param s_h: (float) maximum proportion of erased area against input image
81 | :param r_1: (float) minimum aspect ratio of erased area
82 | :param r_2: (float) maximum aspect ratio of erased area
83 | :param v_l: (float) minimum value for erased area
84 | :param v_h: (float) maximum value for erased area
85 |         (the erased area is always filled with zeros)
86 |
87 |         :return: (function) eraser that takes an image array and returns an augmented copy
88 | """
89 | def eraser(orig_img):
90 | input_img = np.copy(orig_img)
91 | if input_img.ndim == 3:
92 | img_h, img_w, img_c = input_img.shape
93 | elif input_img.ndim == 2:
94 | img_h, img_w = input_img.shape
95 |
96 | p_1 = np.random.rand()
97 |
98 | if p_1 > p:
99 | return input_img
100 |
101 | while True:
102 | s = np.random.uniform(s_l, s_h) * img_h * img_w
103 | r = np.random.uniform(r_1, r_2)
104 | w = int(np.sqrt(s / r))
105 | h = int(np.sqrt(s * r))
106 | left = np.random.randint(0, img_w)
107 | top = np.random.randint(0, img_h)
108 |
109 | if left + w <= img_w and top + h <= img_h:
110 | break
111 |
112 | input_img[top:top + h, left:left + w] = 0
113 |
114 | return input_img
115 |
116 | return eraser
117 |
118 | def cutout(self, x):
119 |
120 | eraser = self.get_random_eraser()
121 | x_ = eraser(x)
122 |
123 | return x_
124 |
125 | def get_batch(self, epoch=None):
126 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2])
127 |
128 | if self.orig_plus_aug:
129 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
130 | images = torch.zeros(tensor_shape, dtype=torch.float32)
131 | augmented_images = torch.zeros(tensor_shape, dtype=torch.float32)
132 | for i in range(self.batch_size):
133 |                 # Avoid overflow (wrap around the end of the dataset)
134 | if self.current_index > self.image_count - 1:
135 | if self.stage == "train":
136 | self.shuffle()
137 | else:
138 | self.current_index = 0
139 |
140 | x = self.images[self.current_index]
141 | y = self.labels[self.current_index]
142 | images[i] = normalize_dataset(Image.fromarray(x), img_mean_mode=self.img_mean_mode)
143 | labels[i] = y
144 |
145 | augmented_x = self.cutout(x)
146 | augmented_images[i] = normalize_dataset(Image.fromarray(augmented_x), img_mean_mode=self.img_mean_mode)
147 |
148 | self.current_index += 1
149 |
150 | batches = [(images, labels), (augmented_images, labels)]
151 |
152 | else:
153 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
154 | images = torch.zeros(tensor_shape, dtype=torch.float32)
155 | for i in range(self.batch_size):
156 |                 # Avoid overflow (wrap around the end of the dataset)
157 | if self.current_index > self.image_count - 1:
158 | if self.stage == "train":
159 | self.shuffle()
160 | else:
161 | self.current_index = 0
162 |
163 | x = self.images[self.current_index]
164 | y = self.labels[self.current_index]
165 |
166 | augmented_x = self.cutout(x)
167 | images[i] = normalize_dataset(Image.fromarray(augmented_x), img_mean_mode=self.img_mean_mode)
168 | labels[i] = y
169 |
170 | self.current_index += 1
171 |
172 | batches = [(images, labels)]
173 |
174 | return batches
175 |
--------------------------------------------------------------------------------
/preprocessing/image/ImageGenerator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from PIL import Image
4 | from tools import shuffle_data
5 | from preprocessing.Datasets import normalize_dataset
6 |
7 | class ImageGenerator:
8 | def __init__(self,
9 | dataset,
10 | batch_size,
11 | stage,
12 | img_mean_mode=None,
13 | seed=13):
14 | """
15 |         :param dataset: (dict or tuple) {"images": x, "labels": y} or (x, y)
16 | :param batch_size: (int) # of inputs in a mini-batch
17 | :param stage: (str) train | test
18 | :param img_mean_mode: (str) use this for image normalization
19 | :param seed: (int) seed for input shuffle
20 | """
21 |
22 | if stage not in ['train', 'test']:
23 |             raise ValueError('invalid stage!')
24 |
25 | # Settings
26 | self.batch_size = batch_size
27 | self.stage = stage
28 | self.img_mean_mode = img_mean_mode
29 | self.seed = seed
30 |
31 | # Preparation
32 | self.configuration()
33 | self.load_data(dataset)
34 |
35 | def configuration(self):
36 | self.shuffle_count = 1
37 | self.current_index = 0
38 |
39 | def shuffle(self):
40 | self.image_count = len(self.labels)
41 | self.current_index = 0
42 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images,
43 | labels=self.labels,
44 | teacher_logits=self.teacher_logits,
45 | seed=self.seed + self.shuffle_count)
46 | self.shuffle_count += 1
47 |
48 | def load_data(self, dataset):
49 |         self.images = dataset["images"] if isinstance(dataset, dict) else dataset[0]
50 |         self.labels = dataset["labels"] if isinstance(dataset, dict) else dataset[1]
51 |         self.teacher_logits = dataset["teacher_logits"] if isinstance(dataset, dict) and "teacher_logits" in dataset else None
52 |
53 | self.len_images = len(self.images)
54 | self.len_labels = len(self.labels)
55 | assert self.len_images == self.len_labels
56 | self.image_count = self.len_labels
57 |
58 | if self.stage == 'train':
59 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images,
60 | labels=self.labels,
61 | teacher_logits=self.teacher_logits,
62 | seed=self.seed)
63 |
64 | def get_batch_count(self):
65 | return (self.len_labels // self.batch_size) + 1
66 |
67 | def get_batch(self, epoch=None):
68 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2])
69 | teacher_logits = None if self.teacher_logits is None else []
70 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
71 | images = torch.zeros(tensor_shape, dtype=torch.float32)
72 | for i in range(self.batch_size):
73 |             # Avoid overflow (wrap around the end of the dataset)
74 | if self.current_index > self.image_count - 1:
75 | if self.stage == "train":
76 | self.shuffle()
77 | else:
78 | self.current_index = 0
79 |
80 | image = Image.fromarray(self.images[self.current_index])
81 | images[i] = normalize_dataset(image, img_mean_mode=self.img_mean_mode)
82 | labels[i] = self.labels[self.current_index]
83 |
84 | if teacher_logits is not None:
85 | teacher_logits.append(self.teacher_logits[self.current_index])
86 |
87 | self.current_index += 1
88 |
89 | # Include teacher logits as soft labels if applicable
90 | if teacher_logits is not None:
91 | labels = [labels, np.array(teacher_logits)]
92 |
93 | return [(images, labels)]
94 |
--------------------------------------------------------------------------------
/preprocessing/image/MixUpGenerator.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | from tools import shuffle_data
6 | from preprocessing.Datasets import normalize_dataset
7 |
8 | class MixUpGenerator:
9 | def __init__(self,
10 | dataset,
11 | batch_size,
12 | stage,
13 | img_mean_mode=None,
14 | seed=13,
15 | orig_plus_aug=True):
16 | """
17 |         :param dataset: (dict) "images", "labels", and optionally "teacher_logits"
18 | :param batch_size: (int) # of inputs in a mini-batch
19 | :param stage: (str) train | test
20 | :param img_mean_mode: (str) use this for image normalization
21 | :param seed: (int) seed for input shuffle
22 | :param orig_plus_aug: (bool) if True, original images will be kept in the batch along with corrupted ones
23 | """
24 |
25 | if stage not in ['train', 'test']:
26 |             raise ValueError('invalid stage!')
27 |
28 | # Settings
29 | self.batch_size = batch_size
30 | self.stage = stage
31 | self.img_mean_mode = img_mean_mode
32 | self.seed = seed
33 | self.orig_plus_aug = orig_plus_aug
34 |
35 | # Preparation
36 | self.configuration()
37 | self.load_data(dataset)
38 |
39 | def configuration(self):
40 | self.shuffle_count = 1
41 | self.current_index = 0
42 |
43 | def shuffle(self):
44 | self.image_count = len(self.labels)
45 | self.current_index = 0
46 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images,
47 | labels=self.labels,
48 | teacher_logits=self.teacher_logits,
49 | seed=self.seed + self.shuffle_count)
50 | self.shuffle_count += 1
51 |
52 | def load_data(self, dataset):
53 | self.images = dataset["images"]
54 | self.labels = dataset["labels"]
55 | self.teacher_logits = dataset["teacher_logits"] if "teacher_logits" in dataset else None
56 |
57 | self.len_images = len(self.images)
58 | self.len_labels = len(self.labels)
59 | assert self.len_images == self.len_labels
60 | self.image_count = self.len_labels
61 |
62 | if self.stage == 'train':
63 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images,
64 | labels=self.labels,
65 | teacher_logits=self.teacher_logits,
66 | seed=self.seed)
67 |
68 | def get_batch_count(self):
69 | return (self.len_labels // self.batch_size) + 1
70 |
71 | def mixup(self, image_batch, label_batch, beta=1.0):
72 | l_param = np.random.beta(beta, beta, self.batch_size)
73 | rand_index = np.random.permutation(self.batch_size)
74 |
75 | x = image_batch.detach().clone()
76 | y = np.zeros_like(label_batch)
77 |
78 | for i in range(self.batch_size):
79 | x[i] = l_param[i] * x[i] + (1 - l_param[i]) * x[rand_index[i]]
80 | y[i] = label_batch[rand_index[i]]
81 |
82 | return x, (label_batch, y, l_param)
83 |
84 | def get_batch(self, epoch=None):
85 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2])
86 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
87 | images = torch.zeros(tensor_shape, dtype=torch.float32)
88 | for i in range(self.batch_size):
89 |             # Avoid overflow (wrap around the end of the dataset)
90 | if self.current_index > self.image_count - 1:
91 | if self.stage == "train":
92 | self.shuffle()
93 | else:
94 | self.current_index = 0
95 |
96 | x = self.images[self.current_index]
97 | y = self.labels[self.current_index]
98 | images[i] = normalize_dataset(Image.fromarray(x), img_mean_mode=self.img_mean_mode)
99 | labels[i] = y
100 |
101 | self.current_index += 1
102 |
103 | augmented_images, augmented_labels = self.mixup(images, labels)
104 |
105 | if self.orig_plus_aug:
106 | batches = [(images, (labels, labels, (np.ones_like(labels)))), (augmented_images, augmented_labels)]
107 |
108 | else:
109 | batches = [(augmented_images, augmented_labels)]
110 |
111 | return batches
112 |
--------------------------------------------------------------------------------
/preprocessing/image/RandAugmentGenerator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
4 | import numpy as np
5 | from PIL import Image
6 | from tools import shuffle_data
7 | from preprocessing.Datasets import normalize_dataset
8 |
9 | def ShearX(img, v): # [-0.3, 0.3]
10 | assert -0.3 <= v <= 0.3
11 | if random.random() > 0.5:
12 | v = -v
13 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
14 |
15 | def ShearY(img, v): # [-0.3, 0.3]
16 | assert -0.3 <= v <= 0.3
17 | if random.random() > 0.5:
18 | v = -v
19 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
20 |
21 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
22 | assert -0.45 <= v <= 0.45
23 | if random.random() > 0.5:
24 | v = -v
25 | v = v * img.size[0]
26 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
27 |
28 | def TranslateXabs(img, v): # [0, 150] in absolute pixels
29 | assert 0 <= v
30 | if random.random() > 0.5:
31 | v = -v
32 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
33 |
34 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
35 | assert -0.45 <= v <= 0.45
36 | if random.random() > 0.5:
37 | v = -v
38 | v = v * img.size[1]
39 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
40 |
41 | def TranslateYabs(img, v): # [0, 150] in absolute pixels
42 | assert 0 <= v
43 | if random.random() > 0.5:
44 | v = -v
45 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
46 |
47 | def Rotate(img, v): # [-30, 30]
48 | assert -30 <= v <= 30
49 | if random.random() > 0.5:
50 | v = -v
51 | return img.rotate(v)
52 |
53 | def AutoContrast(img, _):
54 | return PIL.ImageOps.autocontrast(img)
55 |
56 | def Invert(img, _):
57 | return PIL.ImageOps.invert(img)
58 |
59 | def Equalize(img, _):
60 | return PIL.ImageOps.equalize(img)
61 |
62 | def Flip(img, _): # not from the paper
63 | return PIL.ImageOps.mirror(img)
64 |
65 | def Solarize(img, v): # [0, 256]
66 | assert 0 <= v <= 256
67 | return PIL.ImageOps.solarize(img, v)
68 |
69 | def SolarizeAdd(img, addition=0, threshold=128):
70 |     img_np = np.array(img).astype(int)
71 | img_np = img_np + addition
72 | img_np = np.clip(img_np, 0, 255)
73 | img_np = img_np.astype(np.uint8)
74 | img = Image.fromarray(img_np)
75 | return PIL.ImageOps.solarize(img, threshold)
76 |
77 | def Posterize(img, v): # [4, 8]
78 | v = int(v)
79 | v = max(1, v)
80 | return PIL.ImageOps.posterize(img, v)
81 |
82 | def Contrast(img, v): # [0.1,1.9]
83 | assert 0.1 <= v <= 1.9
84 | return PIL.ImageEnhance.Contrast(img).enhance(v)
85 |
86 | def Color(img, v): # [0.1,1.9]
87 | assert 0.1 <= v <= 1.9
88 | return PIL.ImageEnhance.Color(img).enhance(v)
89 |
90 | def Brightness(img, v): # [0.1,1.9]
91 | assert 0.1 <= v <= 1.9
92 | return PIL.ImageEnhance.Brightness(img).enhance(v)
93 |
94 | def Sharpness(img, v): # [0.1,1.9]
95 | assert 0.1 <= v <= 1.9
96 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
97 |
98 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
99 | assert 0.0 <= v <= 0.2
100 | if v <= 0.:
101 | return img
102 |
103 | v = v * img.size[0]
104 | return CutoutAbs(img, v)
105 |
106 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
107 | # assert 0 <= v <= 20
108 | if v < 0:
109 | return img
110 | w, h = img.size
111 | x0 = np.random.uniform(w)
112 | y0 = np.random.uniform(h)
113 |
114 | x0 = int(max(0, x0 - v / 2.))
115 | y0 = int(max(0, y0 - v / 2.))
116 | x1 = min(w, x0 + v)
117 | y1 = min(h, y0 + v)
118 |
119 | xy = (x0, y0, x1, y1)
120 | color = (125, 123, 114)
121 | # color = (0, 0, 0)
122 | img = img.copy()
123 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
124 | return img
125 |
126 | def SamplePairing(imgs): # [0, 0.4]
127 | def f(img1, v):
128 | i = np.random.choice(len(imgs))
129 | img2 = PIL.Image.fromarray(imgs[i])
130 | return PIL.Image.blend(img1, img2, v)
131 |
132 | return f
133 |
134 | def Identity(img, v):
135 | return img
136 |
137 | def augment_list(): # 16 operations and their ranges
138 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
139 | # l = [
140 | # (Identity, 0., 1.0),
141 | # (ShearX, 0., 0.3), # 0
142 | # (ShearY, 0., 0.3), # 1
143 | # (TranslateX, 0., 0.33), # 2
144 | # (TranslateY, 0., 0.33), # 3
145 | # (Rotate, 0, 30), # 4
146 | # (AutoContrast, 0, 1), # 5
147 | # (Invert, 0, 1), # 6
148 | # (Equalize, 0, 1), # 7
149 | # (Solarize, 0, 110), # 8
150 | # (Posterize, 4, 8), # 9
151 | # # (Contrast, 0.1, 1.9), # 10
152 | # (Color, 0.1, 1.9), # 11
153 | # (Brightness, 0.1, 1.9), # 12
154 | # (Sharpness, 0.1, 1.9), # 13
155 | # # (Cutout, 0, 0.2), # 14
156 | # # (SamplePairing(imgs), 0, 0.4), # 15
157 | # ]
158 |
159 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
160 | l = [
161 | (AutoContrast, 0, 1),
162 | (Equalize, 0, 1),
163 | (Invert, 0, 1),
164 | (Rotate, 0, 30),
165 | (Posterize, 0, 4),
166 | (Solarize, 0, 256),
167 | (SolarizeAdd, 0, 110),
168 | (Color, 0.1, 1.9),
169 | (Contrast, 0.1, 1.9),
170 | (Brightness, 0.1, 1.9),
171 | (Sharpness, 0.1, 1.9),
172 | (ShearX, 0., 0.3),
173 | (ShearY, 0., 0.3),
174 | (CutoutAbs, 0, 40),
175 | (TranslateXabs, 0., 100),
176 | (TranslateYabs, 0., 100),
177 | ]
178 |
179 | return l
180 |
181 | class RandAugment:
182 | """
183 | Adapted from: https://github.com/ildoonet/pytorch-randaugment
184 | """
185 | def __init__(self, n, m):
186 | self.n = n
187 | self.m = m # [0, 30]
188 | self.augment_list = augment_list()
189 |
190 | def __call__(self, img):
191 |
192 | ops = random.choices(self.augment_list, k=self.n)
193 | for op, minval, maxval in ops:
194 | val = (float(self.m) / 30) * float(maxval - minval) + minval
195 | img = op(img, val)  # feed the result back in so all n sampled ops compose
196 |
197 | return img
198 |
199 | class RandAugmentGenerator:
200 | def __init__(self,
201 | dataset,
202 | batch_size,
203 | stage,
204 | corruption_mode,
205 | aug_count=1,
206 | aug_rate=5,
207 | img_mean_mode=None,
208 | seed=13,
209 | orig_plus_aug=True):
210 | """
211 | :param dataset: (tuple) x, y, segmentation mask (optional)
212 | :param batch_size: (int) # of inputs in a mini-batch
213 | :param stage: (str) train | test
214 | :param corruption_mode: (str) required to distinguish the randaugment / randaugment++ variants
215 | :param aug_count: (int) N in the paper
216 | :param aug_rate: (int) M in the paper
217 | :param img_mean_mode: (str) use this for image normalization
218 | :param seed: (int) seed for input shuffle
219 | :param orig_plus_aug: (bool) if True, original images will be kept in the batch along with corrupted ones
220 | """
221 |
222 | if stage not in ['train', 'test']:
223 | raise ValueError("invalid stage!")  # asserting on an exception instance always passes
224 |
225 | # Settings
226 | self.batch_size = batch_size
227 | self.stage = stage
228 | self.corruption_mode = corruption_mode
229 | self.img_mean_mode = img_mean_mode
230 | self.seed = seed
231 | self.orig_plus_aug = orig_plus_aug
232 |
233 | # Preparation
234 | self.configuration()
235 | self.load_data(dataset)
236 | self.random_augmentation = RandAugment(aug_count, aug_rate)
237 |
238 | def configuration(self):
239 | self.shuffle_count = 1
240 | self.current_index = 0
241 |
242 | def shuffle(self):
243 | self.image_count = len(self.labels)
244 | self.current_index = 0
245 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images,
246 | labels=self.labels,
247 | teacher_logits=self.teacher_logits,
248 | seed=self.seed + self.shuffle_count)
249 | self.shuffle_count += 1
250 |
251 | def load_data(self, dataset):
252 | self.images = dataset["images"]
253 | self.labels = dataset["labels"]
254 | self.teacher_logits = dataset["teacher_logits"] if "teacher_logits" in dataset else None
255 |
256 | self.len_images = len(self.images)
257 | self.len_labels = len(self.labels)
258 | assert self.len_images == self.len_labels
259 | self.image_count = self.len_labels
260 |
261 | if self.stage == 'train':
262 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images,
263 | labels=self.labels,
264 | teacher_logits=self.teacher_logits,
265 | seed=self.seed)
266 |
267 | def get_batch_count(self):
268 | return (self.len_labels // self.batch_size) + 1
269 |
270 | def augment(self, x, n):
271 |
272 | x_ = [self.random_augmentation(x) for _ in range(n)]
273 |
274 | return x_
275 |
276 | def get_batch(self, epoch=None):
277 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2])
278 | teacher_logits = None if self.teacher_logits is None else []
279 | expansion_coeff = 5 if self.corruption_mode == "randaugment++" else 1
280 |
281 | if self.orig_plus_aug:
282 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
283 | images = torch.zeros(tensor_shape, dtype=torch.float32)
284 | augmented_images = [torch.zeros(tensor_shape, dtype=torch.float32) for _ in range(expansion_coeff)]
285 | for i in range(self.batch_size):
286 | # Avoid over flow
287 | if self.current_index > self.image_count - 1:
288 | if self.stage == "train":
289 | self.shuffle()
290 | else:
291 | self.current_index = 0
292 |
293 | x = Image.fromarray(self.images[self.current_index])
294 | y = self.labels[self.current_index]
295 | images[i] = normalize_dataset(x, self.img_mean_mode)
296 | labels[i] = y
297 |
298 | augmented_x = self.augment(x, expansion_coeff)
299 | for j in range(expansion_coeff):
300 | augmented_images[j][i] = normalize_dataset(augmented_x[j], img_mean_mode=self.img_mean_mode)
301 |
302 | if teacher_logits is not None:
303 | teacher_logits.append(self.teacher_logits[self.current_index])
304 |
305 | self.current_index += 1
306 |
307 | # Include teacher logits as soft labels if applicable
308 | if teacher_logits is not None:
309 | labels = [labels, np.array(teacher_logits)]
310 |
311 | batches = [(images, labels)]
312 | for i in range(expansion_coeff):
313 | batches.append((augmented_images[i], labels))
314 |
315 | else:
316 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
317 | augmented_images = [torch.zeros(tensor_shape, dtype=torch.float32) for _ in range(expansion_coeff)]
318 | for i in range(self.batch_size):
319 | # Avoid over flow
320 | if self.current_index > self.image_count - 1:
321 | if self.stage == "train":
322 | self.shuffle()
323 | else:
324 | self.current_index = 0
325 |
326 | x = Image.fromarray(self.images[self.current_index])
327 | y = self.labels[self.current_index]
328 | labels[i] = y
329 |
330 | augmented_x = self.augment(x, expansion_coeff)
331 | for j in range(expansion_coeff):
332 | augmented_images[j][i] = normalize_dataset(augmented_x[j], img_mean_mode=self.img_mean_mode)
333 |
334 | if teacher_logits is not None:
335 | teacher_logits.append(self.teacher_logits[self.current_index])
336 |
337 | self.current_index += 1
338 |
339 | # Include teacher logits as soft labels if applicable
340 | if teacher_logits is not None:
341 | labels = [labels, np.array(teacher_logits)]
342 |
343 | batches = []
344 | for i in range(expansion_coeff):
345 | batches.append((augmented_images[i], labels))
346 |
347 | return batches
348 |
--------------------------------------------------------------------------------
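For reference, `RandAugment(n, m)` above samples `n` operations per call and applies each at magnitude `m` on the shared `[0, 30]` scale, linearly mapped into that operation's `(minval, maxval)` range. A minimal usage sketch (illustrative only; the random input image is a stand-in):

```python
import numpy as np
from PIL import Image
from preprocessing.image.RandAugmentGenerator import RandAugment

augmenter = RandAugment(n=2, m=5)   # N and M from the paper: 2 ops at magnitude 5/30
img = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
augmented = augmenter(img)          # PIL image with both sampled ops applied in sequence
```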
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch~=1.5.1+cu101
2 | numpy~=1.19.5
3 | torchvision~=0.6.1+cu101
4 | Pillow~=8.3.1
5 | matplotlib~=3.1.1
6 | sklearn~=0.0
7 | scikit-learn~=0.24.1
8 | scipy~=1.6.1
9 | imagecorruptions~=1.1.2
10 | tqdm~=4.58.0
11 | pycocotools~=2.0.0
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import configparser
4 | from models.ResNet import get_resnet
5 | from testers.DomainGeneralization_tester import DomainGeneralization_tester
6 | from tools import *
7 |
8 | datasets_str = """
9 | Supported benchmarks:
10 | - PACS
11 | - COCO
12 | - DomainNet
13 | """
14 |
15 | def create_config_file(config):
16 | # Default configurations
17 | config["DEFAULT"] = {"version": "1.0.1",
18 | "model": "resnet",
19 | "depth": 18,
20 | "lr": 4e-3,
21 | "batch_size": 128,
22 | "epochs": 30,
23 | "optimizer": "sgd",
24 | "momentum": 0.9,
25 | "temperature": 1.0,
26 | "img_mean_mode": "imagenet",
27 | "corruption_mode": "None",
28 | "corruption_dist": "uniform",
29 | "only_corrupted": False,
30 | "loss": "CrossEntropy",
31 | "train_dataset": "PACS:Photo",
32 | "test_datasets": "None",
33 | "print_config": True,
34 | "data_dir": "../../datasets",
35 | "first_run": False,
36 | "model_dir": ".",
37 | "save_model": False,
38 | "knowledge_distillation": False,
39 | "random_aug": False}
40 |
41 | with open("settings.ini", "w+") as config_file:
42 | config.write(config_file)
43 |
44 | if __name__ == '__main__':
45 |
46 | # Dynamic parameters
47 | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
48 | parser.add_argument("--model", help="selected neural net architecture", type=str)
49 | parser.add_argument("--depth", help="# of layers", type=int)
50 | parser.add_argument("--lr", help="learning rate", type=float)
51 | parser.add_argument("--batch_size", help="batch size (must be an even number!)", type=int)
52 | parser.add_argument("--epochs", help="# of epochs", type=int)
53 | parser.add_argument("--optimizer", help="optimization algorithm", type=str)
54 | parser.add_argument("--momentum", help="momentum (only relevant if the 'optimizer' algorithm is using it)", type=float)
55 | parser.add_argument("--weight_decay", help="L2 regularization penalty", type=float)
56 | parser.add_argument("--temperature", help="temperature for contrastive distillation loss", type=float)
57 | parser.add_argument("--pretrained_weights", help="imagenet | None", type=str)
58 | parser.add_argument("--img_mean_mode", help="image mean subtraction mode for dataset preprocessing, options: None | imagenet | per_pixel | per_channel", type=str)
59 | parser.add_argument("--corruption_mode", help="visual corruptions on inputs as a data augmentation method", type=str)
60 | parser.add_argument("--corruption_dist", help="distribution from which the corruption rate is randomly sampled per image", type=str)
61 | parser.add_argument("--only_corrupted", help="when info loss applied, only the corrupted images will be in the mini-batch", action="store_true")
62 | parser.add_argument("--loss", help="loss function(s)", nargs="+")
63 | parser.add_argument("--train_dataset", help=datasets_str, type=str)
64 | parser.add_argument("--test_datasets", help="list of test sets for domain generalization experiments", nargs="+")
65 | parser.add_argument("--to_path", help="filepath to save models with custom names", type=str)
66 | parser.add_argument("--data_dir", help="filepath to save datasets", type=str)
67 | parser.add_argument("--first_run", help="to initiate COCO preprocessing", action="store_true")
68 | parser.add_argument("--model_dir", help="filepath to save models", type=str)
69 | parser.add_argument("--print_config", help="prints the active configurations", action="store_true")
70 | parser.add_argument("--save_model", help="to save the trained models", action="store_true")
71 | parser.add_argument("--knowledge_distillation", help="to enable knowledge distillation loss", action="store_true")
72 | parser.add_argument("--random_aug", help="to enable random data augmentation", action="store_true")
73 | args = vars(parser.parse_args())
74 |
75 | # Static parameters
76 | config = configparser.ConfigParser(allow_no_value=True)
77 | try:
78 | if not os.path.exists("settings.ini"):
79 | create_config_file(config)
80 |
81 | # Override the default values if specified
82 | config.read("settings.ini")
83 | temp = dict(config["DEFAULT"])
84 | temp.update({k: v for k, v in args.items() if v is not None})
85 | config.read_dict({"DEFAULT": temp})
86 | config = config["DEFAULT"]
87 |
88 | # Assign the active values
89 | version = config["version"]
90 | arch = config["model"].lower()
91 | depth = int(config["depth"])
92 | lr = float(config["lr"])
93 | batch_size = int(config["batch_size"])
94 | epochs = int(config["epochs"])
95 | optimizer = config["optimizer"]
96 | momentum = float(config["momentum"])
97 | weight_decay = float(config["weight_decay"]) if "weight_decay" in config else .0
98 | temperature = float(config["temperature"])
99 | pretrained_weights = config["pretrained_weights"] if "pretrained_weights" in config else None
100 | img_mean_mode = config["img_mean_mode"] if config["img_mean_mode"].lower() != "none" else None
101 | corruption_mode = config["corruption_mode"] if config["corruption_mode"].lower() != "none" else None
102 | corruption_dist = config["corruption_dist"]
103 | loss = config["loss"]
104 | train_dataset = config["train_dataset"]
105 | test_datasets = config["test_datasets"]
106 | to_path = config["to_path"] if "to_path" in config else None
107 | data_dir = config["data_dir"]
108 | model_dir = config["model_dir"]
109 | FIRST_RUN = config.getboolean("first_run")  # plain indexing returns the string "False", which is truthy
110 | PRINT_CONFIG = config.getboolean("print_config")
111 | SAVE_MODEL = config.getboolean("save_model")
112 | KNOWLEDGE_DISTILLATION = config.getboolean("knowledge_distillation")
113 | ONLY_CORRUPTED = config.getboolean("only_corrupted")
114 | RANDOM_AUG = config.getboolean("random_aug")
115 | log("Configuration is completed.")
116 | except Exception as e:
117 | log("Error: " + str(e), LogType.ERROR)
118 | log("Configuration fault! New settings.ini is created. Restart the program.", LogType.ERROR)
119 | create_config_file(config)
120 | exit(1)
121 |
122 | # Process benchmark parameters
123 | log("Single-source domain generalization experiment...")
124 |
125 | # Process selected neural net
126 | if arch not in ["resnet"]:
127 | log("Nice try... but %s is not a supported neural net architecture!" % arch, LogType.ERROR)
128 | exit(1)
129 |
130 | # Process selected datasets for benchmarking
131 | datasets = ["COCO",
132 | "PACS:Photo",
133 | "FullDomainNet:Real"]
134 | # Dataset checker
135 | if train_dataset not in datasets:
136 | log("Nice try... but %s is not an allowed dataset!" % train_dataset, LogType.ERROR)
137 | exit(1)
138 |
139 | # Process selected test datasets for domain generalization
140 | if args["test_datasets"] is not None and len(args["test_datasets"]) > 0:
141 | supported_datasets = ["PACS:Art",
142 | "PACS:Cartoon",
143 | "PACS:Sketch",
144 | "PACS:Photo",
145 | "DomainNet:Real",
146 | "DomainNet:Infograph",
147 | "DomainNet:Clipart",
148 | "DomainNet:Painting",
149 | "DomainNet:Quickdraw",
150 | "DomainNet:Sketch",
151 | "FullDomainNet:Infograph",
152 | "FullDomainNet:Clipart",
153 | "FullDomainNet:Painting",
154 | "FullDomainNet:Quickdraw",
155 | "FullDomainNet:Sketch"]
156 | # Dataset checker
157 | for s in args["test_datasets"]:
158 | if s not in supported_datasets:
159 | log("Nice try... but %s is not an allowed dataset!" % s, LogType.ERROR)
160 | exit(1)
161 |
162 | # Handle specific dataset selections
163 | test_datasets = args["test_datasets"]
164 | elif test_datasets == "None":
165 | test_datasets = None
166 | else:
167 | test_datasets = [test_datasets]
168 |
169 | # Process loss function(s)
170 | if args["loss"] is not None and len(args["loss"]) > 0:
171 | if len(args["loss"]) == 1:
172 | loss = args["loss"][0]
173 | else:
174 | loss = args["loss"]
175 |
176 | # Process the mini-batch state
177 | orig_plus_aug = not ONLY_CORRUPTED
178 |
179 | # Log the active configuration if needed
180 | if PRINT_CONFIG:
181 | log_config(config)
182 |
183 | # Prepare the benchmark
184 | tester = DomainGeneralization_tester(train_dataset=train_dataset,
185 | test_dataset=test_datasets,
186 | img_mean_mode=img_mean_mode,
187 | data_dir=data_dir,
188 | distillation=KNOWLEDGE_DISTILLATION,
189 | first_run=FIRST_RUN,
190 | wait=True)
191 | tester.activate() # manually trigger the dataset loader
192 | n_classes = tester.get_n_classes()
193 |
194 | # Build the baseline model
195 | model_name = "%s[%s][img_mean=%s][aug=%s]" % (get_arch_name(arch, depth), train_dataset, img_mean_mode, corruption_mode)
196 | #model_name = "%s[%s][img_mean=%s][aug=%s_T%s]" % (get_arch_name(arch, depth), train_dataset, img_mean_mode, corruption_mode, temperature)
197 |
198 | if arch == "resnet":
199 | model = get_resnet(depth, n_classes)
200 |
201 | log("Baseline model is ready.")
202 |
203 | # Train the baseline model
204 | log("Baseline model training...")
205 | hist, score = tester.run(model,
206 | name=model_name,
207 | optimizer=optimizer,
208 | lr=lr,
209 | momentum=momentum,
210 | weight_decay=weight_decay,
211 | loss=loss,
212 | batch_size=batch_size,
213 | epochs=epochs,
214 | corruption_mode=corruption_mode,
215 | corruption_dist=corruption_dist,
216 | orig_plus_aug=orig_plus_aug,
217 | temperature=temperature,
218 | rand_aug=RANDOM_AUG)
219 |
220 | log("%s Test accuracy: %s" % (model_name, score))
221 | log("----------------------------------------------------------------")
222 |
223 | # Plot and save the learning curve
224 | chart_path = "%s_learning_curve.png" % model_name
225 | chart_path = chart_path.replace(":", "_")
226 | plot_learning_curve(hist, chart_path)
227 |
228 | # Save the baseline model & print its structure
229 | if SAVE_MODEL:
230 | if to_path is None:
231 | torch.save(model, os.path.join(model_dir, "%s.pth" % model_name))
232 | else:
233 | torch.save(model, to_path)
234 | print(model)
235 |
236 | del model
237 |
238 | log("Done.")
239 |
--------------------------------------------------------------------------------
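`run.py` resolves its configuration in two layers: `settings.ini` supplies the defaults, and any CLI flag that was actually passed (i.e. parsed to a non-`None` value) overrides the matching key. A condensed sketch of that precedence rule (illustrative only; the inline dicts stand in for the real file and parser output):

```python
import configparser

config = configparser.ConfigParser(allow_no_value=True)
config.read_dict({"DEFAULT": {"lr": "4e-3", "epochs": "30"}})  # stand-in for settings.ini

args = {"lr": "1e-3", "epochs": None}                          # stand-in for parsed CLI flags
temp = dict(config["DEFAULT"])
temp.update({k: v for k, v in args.items() if v is not None})  # only explicit flags override
config.read_dict({"DEFAULT": temp})

assert float(config["DEFAULT"]["lr"]) == 1e-3  # CLI value wins
assert int(config["DEFAULT"]["epochs"]) == 30  # default survives
```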
/run_experiments.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo ""
3 | echo " --------------------------------- "
4 | echo " | | "
5 | echo "----------------- Domain Generalization Experiments ------------------"
6 | echo " | | "
7 | echo " --------------------------------- "
8 | echo ""
9 | echo "----------------------------------------------------------------------"
10 | for i in 1 2 3 4 5; do
11 | echo "COCO experiments..."
12 | for depth in 18; do
13 | for lr in 4e-3; do
14 | for img_mean in imagenet; do
15 |
16 | # Baseline
17 | echo "Exp #"$i" [Baseline]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
18 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode None --print_config > dump
19 |
20 | # MixUp experiments
21 | echo "Exp #"$i" [MixUp]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
22 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode mixup --print_config > dump
23 |
24 | # CutOut experiments
25 | echo "Exp #"$i" [CutOut]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
26 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode cutout --print_config > dump
27 |
28 | # CutMix experiments
29 | echo "Exp #"$i" [CutMix]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
30 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode cutmix --print_config > dump
31 |
32 | # RandAugment experiments
33 | echo "Exp #"$i" [RandAugment]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
34 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode randaugment --print_config > dump
35 |
36 | # AugMix experiments
37 | echo "Exp #"$i" [AugMix]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
38 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --loss CrossEntropy JSDivergence --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode augmix --print_config > dump
39 |
40 | # VC experiments
41 | echo "Exp #"$i" [VC]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
42 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode vc --print_config > dump
43 |
44 | # ACVC experiments
45 | echo "Exp #"$i" [ACVC]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
46 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --loss CrossEntropy AttentionConsistency --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode acvc --print_config > dump
47 |
48 | done
49 | done
50 | done
51 | echo "PACS experiments..."
52 | for depth in 18; do
53 | for lr in 4e-3; do
54 | for img_mean in imagenet; do
55 |
56 | # Baseline
57 | echo "Exp #"$i" [Baseline]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
58 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode None --print_config > dump
59 |
60 | # MixUp experiments
61 | echo "Exp #"$i" [MixUp]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
62 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode mixup --print_config > dump
63 |
64 | # CutOut experiments
65 | echo "Exp #"$i" [CutOut]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
66 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode cutout --print_config > dump
67 |
68 | # CutMix experiments
69 | echo "Exp #"$i" [CutMix]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
70 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode cutmix --print_config > dump
71 |
72 | # RandAugment experiments
73 | echo "Exp #"$i" [RandAugment]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
74 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode randaugment --print_config > dump
75 |
76 | # AugMix experiments
77 | echo "Exp #"$i" [AugMix]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
78 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --loss CrossEntropy JSDivergence --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode augmix --print_config > dump
79 |
80 | echo "Exp #"$i" [VC]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
81 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode vc --print_config > dump
82 |
83 | echo "Exp #"$i" [ACVC]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}"
84 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --loss CrossEntropy AttentionConsistency --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode acvc --print_config > dump
85 |
86 | done
87 | done
88 | done
89 | done
90 | echo "Done."
91 |
--------------------------------------------------------------------------------
/settings.ini:
--------------------------------------------------------------------------------
1 | [DEFAULT]
2 | version = 1.0.1
3 | model = resnet
4 | depth = 18
5 | lr = 4e-3
6 | batch_size = 128
7 | epochs = 30
8 | optimizer = sgd
9 | momentum = 0.9
10 | temperature = 1.0
11 | img_mean_mode = imagenet
12 | corruption_mode = None
13 | corruption_dist = uniform
14 | only_corrupted = False
15 | loss = CrossEntropy
16 | train_dataset = PACS:Photo
17 | test_datasets = None
18 | print_config = True
19 | save_model = False
20 | data_dir = ../../datasets
21 | first_run = False
22 | model_dir = .
23 | knowledge_distillation = False
24 | random_aug = False
25 |
26 |
--------------------------------------------------------------------------------
/testers/DomainGeneralization_tester.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | from tools import log, compute_accuracy
7 | from losses.AttentionConsistency import AttentionConsistency
8 | from losses.Distillation import Distillation
9 | from losses.JSDivergence import JSDivergence
10 | from preprocessing import Datasets
11 | from preprocessing.image.ImageGenerator import ImageGenerator
12 | from preprocessing.image.ACVCGenerator import ACVCGenerator
13 | from preprocessing.image.MixUpGenerator import MixUpGenerator
14 | from preprocessing.image.CutOutGenerator import CutOutGenerator
15 | from preprocessing.image.CutMixGenerator import CutMixGenerator
16 | from preprocessing.image.AblationGenerator import AblationGenerator
17 | from preprocessing.image.RandAugmentGenerator import RandAugmentGenerator
18 | from preprocessing.image.AugMixGenerator import AugMixGenerator
19 |
20 | class LR_Scheduler:
21 | def __init__(self, base_lr, dataset, optimizer):
22 | self.base_lr = base_lr
23 | self.dataset = dataset
24 | self.optimizer = optimizer
25 | self.epoch = 0
26 |
27 | def step(self):
28 | self.epoch += 1
29 | lr = self.base_lr
30 |
31 | if self.epoch > 24:
32 | lr *= 1e-1
33 |
34 | for param_group in self.optimizer.param_groups:
35 | param_group['lr'] = lr
36 |
37 | class DomainGeneralization_tester:
38 | def __init__(self, train_dataset, test_dataset, img_mean_mode=None, distillation=False, wait=False, data_dir=None, first_run=False):
39 | """
40 | Conditional constructor to enable creating testers without immediately triggering the dataset loader.
41 |
42 | # Arguments
43 | :param train_dataset: (str) name of the training set
44 | :param test_dataset: (str) name of the test set
45 | :param img_mean_mode: (str) image mean subtraction mode for dataset preprocessing
46 | :param distillation: (bool) enable/disable knowledge distillation loss
47 |         :param wait: (bool) If True, then the constructor will only create the instance and
48 | wait for manual activation to actually load the dataset
49 | :param data_dir: (str) relative filepath to datasets
50 | :param first_run: (bool) to initiate COCO benchmark preprocessing
51 | """
52 | # Base class variables
53 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
54 | self.dataset = None
55 | self.model = None
56 | self.x_test = None
57 | self.y_test = None
58 | self.training_set_size = None
59 | self.img_mean_mode = img_mean_mode
60 | self.distillation = distillation
61 | self.data_dir = os.path.join(ROOT_DIR, "../../datasets") if data_dir is None else "../%s" % data_dir
62 | self.first_run = first_run
63 |
64 | # Loss func init
65 | self.classification_loss = None
66 | self.contrastive_loss = None
67 | self.distillation_loss = None
68 |
69 | # Extra config
70 | self.train_dataset = train_dataset
71 | self.test_dataset = test_dataset
72 | if not wait:
73 | self.activate()
74 |
75 | def activate(self):
76 | """
77 | Call this function to manually start dataset deployment
78 | """
79 |
80 | # Deploy the dataset
81 | if "COCO" == self.train_dataset:
82 | self.dataset = Datasets.load_COCO(first_run=self.first_run, train=True, img_mean_mode=self.img_mean_mode, distillation=self.distillation, data_dir=self.data_dir)
83 | self.x_test, self.y_test = Datasets.load_COCO(first_run=False, train=False, img_mean_mode=self.img_mean_mode, data_dir=self.data_dir)
84 |
85 | elif "FullDomainNet" in self.train_dataset:
86 | subset = self.train_dataset.split(":")[1]
87 | self.dataset = Datasets.load_FullDomainNet(subset, train=True, img_mean_mode=self.img_mean_mode, distillation=self.distillation, data_dir=self.data_dir)
88 | self.x_test, self.y_test = Datasets.load_FullDomainNet(subset, train=False, img_mean_mode=self.img_mean_mode, data_dir=self.data_dir)
89 |
90 | elif "PACS" in self.train_dataset:
91 | subset = self.train_dataset.split(":")[1]
92 | self.dataset = Datasets.load_PACS(subset, train=True, img_mean_mode=self.img_mean_mode, distillation=self.distillation, data_dir=self.data_dir)
93 | self.x_test, self.y_test = Datasets.load_PACS(subset, train=False, img_mean_mode=self.img_mean_mode, data_dir=self.data_dir)
94 |
95 | else:
96 | assert False, "Train dataset: %s is not supported yet!" % self.train_dataset
97 |
98 | # Support for multiple test sets
99 | self.x_test = {self.train_dataset: self.x_test}
100 | self.y_test = {self.train_dataset: self.y_test}
101 |
102 | for test_set in self.test_dataset:
103 |
104 | if "FullDomainNet" in self.train_dataset:
105 | subset = test_set.split(":")[1]
106 | x_temp, y_temp = Datasets.load_FullDomainNet(subset, train=False, img_mean_mode=self.img_mean_mode, data_dir=self.data_dir)
107 | self.x_test[test_set] = x_temp
108 | self.y_test[test_set] = y_temp
109 |
110 | elif "DomainNet" in test_set:
111 | subset = test_set.split(":")[1]
112 | x_temp, y_temp = Datasets.load_DomainNet(subset, img_mean_mode=self.img_mean_mode, data_dir=self.data_dir)
113 | self.x_test[test_set] = x_temp
114 | self.y_test[test_set] = y_temp
115 |
116 | elif "PACS" in test_set:
117 | subset = test_set.split(":")[1]
118 | x_temp, y_temp = Datasets.load_PACS(subset, train=False, img_mean_mode=self.img_mean_mode, data_dir=self.data_dir)
119 | self.x_test[test_set] = x_temp
120 | self.y_test[test_set] = y_temp
121 |
122 | else:
123 | assert False, "Test dataset: %s is not supported yet!" % test_set
124 |
125 | def record_generalization_results(self, result_dict, path="generalization.json"):
126 | if result_dict is not None:
127 | # Load the previous records if exist
128 | hist_cache = {}
129 | if os.path.isfile(path):
130 | with open(path, "r") as hist_file:
131 | hist_cache = json.load(hist_file)
132 |
133 | # Record new results
134 | for model_name in result_dict:
135 | if model_name in hist_cache:
136 | hist_cache[model_name].append(result_dict[model_name])
137 | else:
138 | hist_cache[model_name] = [result_dict[model_name]]
139 |
140 | # Save the updated records
141 | with open(path, "w+") as hist_file:
142 | json.dump(hist_cache, hist_file)
143 |
144 | def get_n_classes(self):
145 |
146 | if self.train_dataset == "COCO":
147 | return 10
148 |
149 | elif "FullDomainNet" in self.train_dataset:
150 | return 345
151 |
152 | elif "PACS" in self.train_dataset:
153 | return 7
154 |
155 | else:
156 | assert False, "Error: update tester.get_n_classes() for %s dataset" % self.train_dataset
157 |
158 | def get_optimizer(self, optimizer, lr=None, momentum=None, weight_decay=0.):
159 | selected_optimizer = optimizer.lower()
160 | if selected_optimizer == "adam":
161 | return torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-7, weight_decay=weight_decay)
162 | elif selected_optimizer == "sgd":
163 | return torch.optim.SGD(self.model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
164 | elif selected_optimizer == "nsgd":
165 | return torch.optim.SGD(self.model.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=weight_decay)
166 | else:
167 | return None
168 |
169 | def get_contrastive_element(self, loss):
170 | if loss == "AttentionConsistency":
171 | return "CAM"
172 |
173 | elif loss == "JSDivergence":
174 | return "Predictions"
175 |
176 | else:
177 | raise ValueError("Unsupported contrastive loss")
178 |
179 | def get_loss_func(self, loss, temperature=None):
180 | if loss == "CrossEntropy":
181 | return torch.nn.CrossEntropyLoss()
182 |
183 | elif loss == "AttentionConsistency":
184 | return AttentionConsistency(T=temperature)
185 |
186 | elif loss == "Distillation":
187 | assert temperature is not None, "Distillation requires temperature as an argument"
188 | return Distillation(temperature=temperature)
189 |
190 | elif loss == "JSDivergence":
191 | return JSDivergence()
192 |
193 | else:
194 | return None
195 |
196 | def get_multi_class_loss(self, outputs, y1, y2, l_param):
197 | return torch.mean(self.classification_loss(outputs, y1) * l_param + self.classification_loss(outputs, y2) * (1. - l_param))
198 |
199 | def test(self, x, y):
200 | self.model.eval()
201 |
202 | with torch.no_grad():
203 | preds = []
204 | for i in range(x.shape[0]):
205 | img = np.copy(x[i])
206 | img = Image.fromarray(img)
207 | img = Datasets.normalize_dataset(img, img_mean_mode=self.img_mean_mode).cuda()
208 | pred = self.model(img[None, ...])[-1]['Predictions']
209 | pred = pred.cpu().data.numpy()
210 | preds.append(pred)
211 |
212 | preds = np.array(preds)
213 | accuracy = compute_accuracy(predictions=preds, labels=y)
214 |
215 | return accuracy
216 |
217 | def validate(self):
218 | self.model.eval()
219 | batch_count = self.valGenerator.get_batch_count()
220 |
221 | with torch.no_grad():
222 | val_acc = 0
223 | val_loss = 0
224 | for i in range(batch_count):
225 | batches = self.valGenerator.get_batch()
226 |
227 | for x_val, y_val in batches:
228 | images = x_val.cuda()
229 | labels = torch.from_numpy(y_val).long().cuda()
230 | outputs, end_points = self.model(images)
231 |
232 | predictions = end_points['Predictions']
233 | predictions = predictions.cpu().data.numpy()
234 |
235 | val_loss += self.classification_loss(outputs, labels).cpu().item()
236 | val_acc += compute_accuracy(predictions=predictions, labels=y_val)
237 |
238 | # Switch back to training mode
239 | self.model.train()
240 |
241 | return val_loss / batch_count, val_acc / batch_count
242 |
243 | def train(self, epochs, multi_class=False):
244 | hist = {"acc": [], "loss": [], "val_acc": [], "val_loss": []}
245 | batch_count = self.trainGenerator.get_batch_count()
246 |
247 | for epoch in range(epochs):
248 | train_acc = 0
249 | train_loss = 0
250 | self.model.train()
251 | for i in range(batch_count):
252 | batches = self.trainGenerator.get_batch(epoch)
253 | losses = []
254 |
255 | end_points_list = {}
256 | for x_train, y_train in batches:
257 | # Forward with the adapted parameters
258 | inputs = x_train.cuda()
259 | outputs, end_points = self.model(x=inputs)
260 |
261 | # Knowledge distillation loss
262 | if self.distillation_loss is not None:
263 | teacher_logits = torch.from_numpy(y_train[1]).cuda().squeeze()
264 | y_train = y_train[0]
265 | losses.append(self.distillation_loss(outputs, teacher_logits))
266 |
267 | # Accumulate orig + augmented image representation/embedding if there is a contrastive loss
268 | if self.contrastive_loss is not None:
269 | for loss in self.contrastive_loss:
270 | contrastive_element = self.get_contrastive_element(loss.name)
271 | if contrastive_element not in end_points_list:
272 | end_points_list[contrastive_element] = []
273 | end_points_list[contrastive_element].append(end_points[contrastive_element])
274 |
275 | # Classification loss
276 | if multi_class:
277 | y1 = torch.from_numpy(y_train[0]).long().cuda()
278 | y2 = torch.from_numpy(y_train[1]).long().cuda()
279 | l_param = torch.from_numpy(y_train[2]).cuda()
280 | loss = self.get_multi_class_loss(outputs, y1, y2, l_param)
281 | y_train = y_train[0]
282 |
283 | else:
284 | labels = torch.from_numpy(y_train).long().cuda()
285 | loss = self.classification_loss(outputs, labels)
286 | losses.append(loss)
287 | train_loss += loss.cpu().item()
288 |
289 | # Acc
290 | predictions = end_points['Predictions']
291 | predictions = predictions.cpu().data.numpy()
292 | train_acc += compute_accuracy(predictions=predictions, labels=y_train)
293 |
294 | # Contrastive loss
295 | if self.contrastive_loss is not None:
296 | for loss in self.contrastive_loss:
297 | contrastive_element = self.get_contrastive_element(loss.name)
298 | f0 = end_points_list[contrastive_element][0]
299 | losses.append(loss(f0, end_points_list[contrastive_element][1:], y_train))
300 |
301 | # Init the grad to zeros first
302 | self.optimizer.zero_grad()
303 |
304 | # Backward your network
305 | loss = sum(losses)
306 | loss.backward()
307 |
308 | # Optimize the parameters
309 | self.optimizer.step()
310 |
311 | # Learning rate scheduler
312 | self.scheduler.step()
313 |
314 | # Validation
315 | val_loss, val_acc = self.validate()
316 |
317 | # Record learning history
318 | hist["acc"].append(train_acc / batch_count)
319 | hist["loss"].append(train_loss / batch_count)
320 | hist["val_acc"].append(val_acc)
321 | hist["val_loss"].append(val_loss)
322 |
323 | return hist
324 |
325 | def run(self,
326 | model,
327 | name,
328 | optimizer="adam",
329 | lr=1e-3,
330 | momentum=None,
331 | weight_decay=.0,
332 | loss="CrossEntropy",
333 | batch_size=128,
334 | epochs=200,
335 | corruption_mode=None,
336 | corruption_dist="uniform",
337 | orig_plus_aug=True,
338 | temperature=1.0,
339 | rand_aug=False):
340 | """
341 | Runs the benchmark
342 |
343 | # Arguments
344 | :param model: PyTorch model
345 | :param name: (str) model name
346 | :param optimizer: (str) name of the selected PyTorch optimizer (adam | sgd | nsgd)
347 | :param lr: (float) learning rate
348 | :param momentum: (float) only relevant for the optimization algorithms that use momentum
349 | :param weight_decay: (float) only relevant for the optimization algorithms that use weight decay
350 | :param loss: (str | list) name of the selected loss function(s)
351 | :param batch_size: (int) # of inputs in a mini-batch
352 | :param epochs: (int) # of full training passes
353 | :param corruption_mode: (str) required for VisCo experiments
354 | :param corruption_dist: (str) required for VisCo experiments
355 | :param orig_plus_aug: (bool) required for VisCo experiments
356 | :param temperature: (float) temperature for contrastive distillation loss
357 | :param rand_aug: (bool) enable/disable random data augmentation
358 |
359 | :return: (history, score)
360 | """
361 | # Custom configurations
362 | self.model = model
363 | self.model.cuda()
364 | is_contrastive = False
365 | multi_class = False
366 |
367 | # Set the optimizer
368 | self.optimizer = self.get_optimizer(optimizer, lr=lr, momentum=momentum, weight_decay=weight_decay)
369 |
370 | # Set the learning rate scheduler
371 | self.scheduler = LR_Scheduler(lr, self.train_dataset, self.optimizer)
372 |
373 | # Set loss functions
374 | if type(loss) is list:
375 | self.classification_loss = self.get_loss_func(loss[0])
376 | self.contrastive_loss = []
377 | for i in range(1, len(loss)):
378 | self.contrastive_loss.append(self.get_loss_func(loss[i], temperature=temperature))
379 | is_contrastive = True
380 | else:
381 | self.classification_loss = self.get_loss_func(loss)
382 | if self.distillation:
383 | self.distillation_loss = self.get_loss_func("Distillation", temperature=temperature)
384 |
385 | # Training data generator
386 | corruption_mode = None if corruption_mode is None else corruption_mode.lower()
387 | if corruption_mode is None:
388 | self.trainGenerator = ImageGenerator(dataset=self.dataset,
389 | batch_size=batch_size,
390 | stage="train",
391 | img_mean_mode=self.img_mean_mode,
392 | seed=13)
393 |
394 | elif corruption_mode == "mixup":
395 | multi_class = True
396 | self.trainGenerator = MixUpGenerator(dataset=self.dataset,
397 | batch_size=batch_size,
398 | stage="train",
399 | img_mean_mode=self.img_mean_mode,
400 | seed=13,
401 | orig_plus_aug=orig_plus_aug)
402 |
403 | elif corruption_mode == "cutout":
404 | self.trainGenerator = CutOutGenerator(dataset=self.dataset,
405 | batch_size=batch_size,
406 | stage="train",
407 | img_mean_mode=self.img_mean_mode,
408 | seed=13,
409 | orig_plus_aug=orig_plus_aug)
410 |
411 | elif corruption_mode == "cutmix":
412 | multi_class = True
413 | self.trainGenerator = CutMixGenerator(dataset=self.dataset,
414 | batch_size=batch_size,
415 | stage="train",
416 | img_mean_mode=self.img_mean_mode,
417 | seed=13,
418 | orig_plus_aug=orig_plus_aug)
419 |
420 | elif "augmix" == corruption_mode:
421 | self.trainGenerator = AugMixGenerator(dataset=self.dataset,
422 | batch_size=batch_size,
423 | stage="train",
424 | img_mean_mode=self.img_mean_mode,
425 | seed=13,
426 | orig_plus_aug=orig_plus_aug)
427 |
428 | elif "randaugment" in corruption_mode:
429 | self.trainGenerator = RandAugmentGenerator(dataset=self.dataset,
430 | batch_size=batch_size,
431 | stage="train",
432 | corruption_mode=corruption_mode,
433 | img_mean_mode=self.img_mean_mode,
434 | seed=13,
435 | orig_plus_aug=orig_plus_aug)
436 |
437 | elif "acvc" == corruption_mode or "vc" == corruption_mode:
438 | self.trainGenerator = ACVCGenerator(dataset=self.dataset,
439 | batch_size=batch_size,
440 | stage="train",
441 | epochs=epochs,
442 | corruption_mode=corruption_mode,
443 | corruption_dist=corruption_dist,
444 | img_mean_mode=self.img_mean_mode,
445 | rand_aug=rand_aug,
446 | seed=13,
447 | orig_plus_aug=orig_plus_aug)
448 |
449 | else:
450 | self.trainGenerator = AblationGenerator(dataset=self.dataset,
451 | batch_size=batch_size,
452 | stage="train",
453 | epochs=epochs,
454 | corruption_mode=corruption_mode,
455 | corruption_dist=corruption_dist,
456 | img_mean_mode=self.img_mean_mode,
457 | rand_aug=rand_aug,
458 | seed=13,
459 | orig_plus_aug=orig_plus_aug)
460 |
461 | # Validation data generator
462 | self.valGenerator = ImageGenerator(dataset=(self.x_test[self.train_dataset], self.y_test[self.train_dataset]),
463 | batch_size=batch_size,
464 | stage="test")
465 |
466 | # Train the model
467 | hist = self.train(epochs=epochs, multi_class=multi_class)
468 |
469 | # Evaluate the model with the test datasets
470 | log("-----------------------------------------------------------------------------")
471 | learning_mode = "contrastive" if is_contrastive else "vanilla"
472 | key = "%s[%s]" % (name, learning_mode)
473 | generalization_results = {}
474 | scores = []
475 | for test_set in self.x_test:
476 | score = self.test(self.x_test[test_set], self.y_test[test_set])
477 | generalization_results[test_set] = score
478 | log("%s %s Test accuracy: %.4f" % (name, test_set, score))
479 |
480 | # Avg. single-source domain generalization accuracy
481 | if test_set != self.train_dataset:
482 | scores.append(score)
483 | avg_score = float(np.mean(np.array(scores)))
484 | log("%s Avg. DG Test accuracy: %.4f" % (name, avg_score))
485 |
486 | if generalization_results == {}:
487 | generalization_results = None
488 | else:
489 | generalization_results["history"] = hist
490 | generalization_results = {key: generalization_results}
491 | self.record_generalization_results(generalization_results)
492 | log("-----------------------------------------------------------------------------")
493 | score = self.test(self.x_test[self.train_dataset], self.y_test[self.train_dataset])
494 |
495 | return hist, score
--------------------------------------------------------------------------------
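The core of `train()` above: every generator returns a list of batch views (original plus augmented), each view contributes a classification loss, contrastive terms compare endpoints collected across views, and a single backward pass covers the sum. A condensed sketch of one step (illustrative only; the model is assumed to return `(logits, end_points)` with a `"CAM"` entry, as in this repo):

```python
import torch

def train_step(model, optimizer, classification_loss, contrastive_losses, batches):
    losses, cams = [], []
    for x, y in batches:                    # e.g. [(original, y), (corrupted, y)]
        labels = torch.from_numpy(y).long().cuda()
        outputs, end_points = model(x.cuda())
        losses.append(classification_loss(outputs, labels))
        cams.append(end_points["CAM"])      # accumulated for attention consistency

    for loss_fn in contrastive_losses:      # e.g. [AttentionConsistency(T=1.0)]
        losses.append(loss_fn(cams[0], cams[1:], y))

    optimizer.zero_grad()
    total = sum(losses)                     # one backward pass over all views
    total.backward()
    optimizer.step()
    return total.item()
```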
/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import warnings
4 | import itertools
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from PIL import Image
8 | from sklearn.metrics import confusion_matrix, accuracy_score
9 | from time import localtime, strftime
10 | from enum import Enum
11 |
12 | logging.basicConfig(filename=os.path.join(os.path.dirname(os.path.abspath(__file__)), "experiment.log"), filemode="a", format="%(message)s", level=logging.INFO)
13 |
14 | # Limit unwanted logging messages from packages
15 | warnings.filterwarnings("ignore", category=DeprecationWarning)
16 | matplotlib_logger = logging.getLogger('matplotlib')
17 | matplotlib_logger.setLevel(logging.ERROR)
18 |
19 | class LogType(Enum):
20 | DEBUG = 0
21 | INFO = 1
22 | WARNING = 2
23 | ERROR = 3
24 |
25 | def log(msg, log_type=LogType.INFO, to_file=True, to_stdout=True):
26 | msg = "%s %s" % (get_time(), msg)
27 |
28 | if to_stdout:
29 | print(msg)
30 | if to_file and log_type == LogType.DEBUG:
31 | logging.debug(msg)
32 | elif to_file and log_type == LogType.INFO:
33 | logging.info(msg)
34 | elif to_file and log_type == LogType.WARNING:
35 | logging.warning(msg)
36 | elif to_file and log_type == LogType.ERROR:
37 | logging.error(msg)
38 |
39 | def log_config(config):
40 | log("Active Configuration:")
41 | log("--------------------")
42 | for key in config:
43 | residual = 24 - len(key)
44 | temp = ""
45 | while len(temp) < residual:
46 | temp += " "
47 | log("%s%s: %s" % (key, temp, config[key]))
48 |
49 | def to_scientific(x):
50 | return "{:.0e}".format(x)
51 |
52 | def get_time():
53 | return "[%s]" % strftime("%a, %d %b %Y %X", localtime())
54 |
55 | def get_arch_name(arch, depth=""):
56 | name = "Unknown"
57 |
58 | if arch == "resnet":
59 | name = "ResNet%s" % depth
60 |
61 | return name
62 |
63 | def plot_learning_curve(training_hist, chart_path):
64 | """
65 | Plots the learning curve of the given training history
66 |
67 | # Arguments
68 |         :param training_hist: (dict) training history with "acc", "loss", "val_acc", "val_loss" keys
69 | :param chart_path: (String) file path for the output chart
70 | """
71 | is_ok = True
72 |
73 | # Error handler for missing values
74 | for key in ["acc", "loss", "val_acc", "val_loss"]:
75 | if key not in training_hist:
76 | is_ok = False
77 |
78 | if is_ok:
79 | # Starting building the learning curve graph
80 | fig, ax1 = plt.subplots(figsize=(14, 9))
81 | epoch_list = list(range(1, len(training_hist['acc']) + 1))
82 |
83 | # Plotting training and test losses
84 | train_loss, = ax1.plot(epoch_list, training_hist['loss'], color='red', alpha=.5)
85 | if "loss_std" in training_hist:
86 | ax1.fill_between(epoch_list,
87 | training_hist['loss'] + training_hist['loss_std'],
88 | training_hist['loss'] - training_hist['loss_std'],
89 | color="red",
90 | alpha=.3)
91 | val_loss, = ax1.plot(epoch_list, training_hist['val_loss'], linewidth=2, color='green')
92 | if "val_loss_std" in training_hist:
93 | ax1.fill_between(epoch_list,
94 | training_hist['val_loss'] + training_hist['val_loss_std'],
95 | training_hist['val_loss'] - training_hist['val_loss_std'],
96 | color="green",
97 | alpha=.3)
98 | ax1.set_xlabel('Epochs')
99 | ax1.set_ylabel('Loss')
100 |
101 | # Plotting test accuracy
102 | ax2 = ax1.twinx()
103 | train_accuracy, = ax2.plot(epoch_list, training_hist['acc'], linewidth=1, color='orange')
104 | if "acc_std" in training_hist:
105 | ax2.fill_between(epoch_list,
106 | training_hist['acc'] + training_hist['acc_std'],
107 | training_hist['acc'] - training_hist['acc_std'],
108 | color="orange",
109 | alpha=.3)
110 | val_accuracy, = ax2.plot(epoch_list, training_hist['val_acc'], linewidth=2, color='blue')
111 | if "val_acc_std" in training_hist:
112 | ax2.fill_between(epoch_list,
113 | training_hist['val_acc'] + training_hist['val_acc_std'],
114 | training_hist['val_acc'] - training_hist['val_acc_std'],
115 | color="blue",
116 | alpha=.3)
117 | ax2.set_ylim(bottom=0, top=1)
118 | ax2.set_ylabel('Accuracy')
119 |
120 | # Adding legend
121 | plt.legend([train_loss, val_loss, train_accuracy, val_accuracy], ['Training Loss', 'Validation Loss', 'Training Accuracy', 'Validation Accuracy'], loc=7, bbox_to_anchor=(1, 0.8))
122 | plt.title('Learning Curve')
123 |
124 | # Saving learning curve
125 | plt.savefig(chart_path)
126 | plt.close(fig)
127 |
128 | def plot_confusion_matrix(y_test, y_preds, chart_path, n_classes, class_labels=None):
129 | class_labels = [""]*n_classes if class_labels is None else class_labels
130 |
131 | # Generate the normalized confusion matrix
132 | cm = confusion_matrix(y_test, y_preds)
133 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
134 |
135 | fig = plt.figure(figsize=(33, 26))
136 | plt.imshow(cm, interpolation='nearest', cmap=plt.get_cmap('Blues'))
137 | plt.title("Confusion Matrix")
138 | plt.colorbar()
139 | tick_marks = np.arange(n_classes)
140 | plt.xticks(tick_marks, class_labels, rotation=30)
141 | plt.yticks(tick_marks, class_labels)
142 | thresh = cm.max() / 2.
143 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
144 | plt.text(j, i, format(cm[i, j], '.1f'),
145 | horizontalalignment="center",
146 | color="white" if cm[i, j] > thresh else "black")
147 | plt.tight_layout()
148 | plt.ylabel('True label')
149 | plt.xlabel('Predicted label')
150 |
151 | # Saving confusion matrix
152 | plt.savefig(chart_path)
153 | plt.close(fig)
154 |
155 | def resize_image(img, target_dim):
156 | new_img = img.resize(target_dim, Image.ANTIALIAS)
157 | return new_img
158 |
159 | def shuffle_data_old(samples, labels, segmentation_masks=None, seed=13):
160 | num = len(labels)
161 | shuffle_index = np.random.RandomState(seed=seed).permutation(np.arange(num))
162 | shuffled_samples = samples[shuffle_index]
163 | shuffled_labels = labels[shuffle_index]
164 | shuffled_masks = None if segmentation_masks is None else segmentation_masks[shuffle_index]
165 | return shuffled_samples, shuffled_labels, shuffled_masks
166 |
167 | def shuffle_data(samples, labels, teacher_logits=None, segmentation_masks=None, seed=13):
168 | np.random.seed(seed)
169 | random_state = np.random.get_state()
170 | np.random.shuffle(samples)
171 | np.random.set_state(random_state)
172 | np.random.shuffle(labels)
173 |
174 | if segmentation_masks is not None:
175 | np.random.set_state(random_state)
176 | np.random.shuffle(segmentation_masks)
177 |
178 | if teacher_logits is not None:
179 | np.random.set_state(random_state)
180 | np.random.shuffle(teacher_logits)
181 |
182 | return samples, labels, teacher_logits, segmentation_masks
183 |
184 | def get_contrastive_loss(loss, orig_plus_corrupted=False):
185 | if type(loss) is list:
186 | return loss[1]
187 |
188 | else:
189 | if orig_plus_corrupted:
190 | return "orig_plus_corrupted"
191 |
192 | else:
193 | return "None"
194 |
195 | def compute_accuracy(predictions, labels):
196 | if np.ndim(labels) == 2:
197 | y_true = np.argmax(labels, axis=-1)
198 | else:
199 | y_true = labels
200 | accuracy = accuracy_score(y_true=y_true, y_pred=np.argmax(predictions, axis=-1))
201 | return accuracy
202 |
--------------------------------------------------------------------------------
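`shuffle_data` keeps all arrays aligned by capturing the RNG state once and restoring it before each in-place shuffle, so every array receives the identical permutation. A quick check (illustrative only):

```python
import numpy as np
from tools import shuffle_data

x = np.arange(10).reshape(5, 2)  # x[i] = (2i, 2i + 1)
y = np.arange(5)                 # y[i] = i
x_s, y_s, _, _ = shuffle_data(samples=x, labels=y, seed=13)
assert all(x_s[i, 0] == 2 * y_s[i] for i in range(5))  # pairs stay matched after shuffling
```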