├── .gitignore ├── LICENSE ├── README.md ├── acvc-pytorch.yml ├── analysis └── GeneralizationExpProcessor.py ├── assets ├── ACVC_CAM.png └── ACVC_flow.png ├── losses ├── AttentionConsistency.py ├── Distillation.py └── JSDivergence.py ├── models └── ResNet.py ├── preprocessing ├── Datasets.py └── image │ ├── ACVCGenerator.py │ ├── AblationGenerator.py │ ├── AugMixGenerator.py │ ├── CutMixGenerator.py │ ├── CutOutGenerator.py │ ├── ImageGenerator.py │ ├── MixUpGenerator.py │ └── RandAugmentGenerator.py ├── requirements.txt ├── run.py ├── run_experiments.sh ├── settings.ini ├── testers └── DomainGeneralization_tester.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | results -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 EML Tübingen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ACVC 2 | > [**Attention Consistency on Visual Corruptions for Single-Source Domain Generalization**](https://openaccess.thecvf.com/content/CVPR2022W/L3D-IVU/html/Cugu_Attention_Consistency_on_Visual_Corruptions_for_Single-Source_Domain_Generalization_CVPRW_2022_paper.html) 3 | > [Ilke Cugu](https://cuguilke.github.io/), 4 | > [Massimiliano Mancini](https://www.eml-unitue.de/people/massimiliano-mancini), 5 | > [Yanbei Chen](https://www.eml-unitue.de/people/yanbei-chen), 6 | > [Zeynep Akata](https://www.eml-unitue.de/people/zeynep-akata) 7 | > *IEEE Computer Vision and Pattern Recognition Workshops (CVPRW), 2022* 8 | 9 |

10 | [figure placeholder: the original README renders images here; see assets/ACVC_flow.png and assets/ACVC_CAM.png]
11 |

12 |
13 | The official PyTorch implementation of the **CVPR 2022, L3D-IVU Workshop** paper titled "Attention Consistency on Visual Corruptions for Single-Source Domain Generalization". This repository contains: (1) our single-source domain generalization benchmark that aims at generalizing from natural images to other domains such as paintings, cliparts and sketches, (2) our adaptations of well-known advanced data augmentation techniques from the literature, and (3) our final model ACVC, which fuses visual corruptions with an attention consistency loss.
14 |
15 | ## Dependencies
16 | ```
17 | torch~=1.5.1+cu101
18 | numpy~=1.19.5
19 | torchvision~=0.6.1+cu101
20 | Pillow~=8.3.1
21 | matplotlib~=3.1.1
22 | sklearn~=0.0
23 | scikit-learn~=0.24.1
24 | scipy~=1.6.1
25 | imagecorruptions~=1.1.2
26 | tqdm~=4.58.0
27 | pycocotools~=2.0.0
28 | ```
29 |
30 | - We also include a YAML script `acvc-pytorch.yml` that is prepared for an easy Anaconda environment setup.
31 |
32 | - One can also use the `requirements.txt` if one knows one's craft.
33 |
34 | ## Training
35 |
36 | Training is done via `run.py`. To get the up-to-date list of commands:
37 | ```shell
38 | python run.py --help
39 | ```
40 |
41 | We include a sample script `run_experiments.sh` for a quick start.
42 |
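For instance, a hypothetical single-source PACS run with ACVC might look as follows (an illustrative command, not an official recipe: the dataset identifiers are inferred from `analysis/GeneralizationExpProcessor.py`, so double-check them against `python run.py --help`):
```shell
python run.py --loss CrossEntropy AttentionConsistency --corruption_mode acvc --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch
```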
43 | ## Analysis
44 |
45 | The benchmark results are prepared by `analysis/GeneralizationExpProcessor.py`, which outputs LaTeX tables of the cumulative results in a .tex file.
46 |
47 | - For example:
48 | ```shell
49 | python GeneralizationExpProcessor.py --path generalization.json --to_dir ./results --image_format pdf
50 | ```
51 |
52 | - You can also run distributed experiments, and merge the results later on:
53 | ```shell
54 | python GeneralizationExpProcessor.py --merge_logs generalization_gpu0.json generalization_gpu1.json
55 | ```
56 |
57 | ## Case Study: COCO benchmark
58 |
59 | The COCO benchmark is especially useful for further studies on ACVC since it includes segmentation masks per image.
60 |
61 | Here are the steps to make it work:
62 | 1. For this benchmark you only need 10 classes:
63 | ```
64 | airplane
65 | bicycle
66 | bus
67 | car
68 | horse
69 | knife
70 | motorcycle
71 | skateboard
72 | train
73 | truck
74 | ```
75 |
76 |
77 | 2. Download COCO 2017 [trainset](http://images.cocodataset.org/zips/train2017.zip), [valset](http://images.cocodataset.org/zips/val2017.zip), and [annotations](http://images.cocodataset.org/annotations/annotations_trainval2017.zip)
78 |
79 |
80 | 3. Extract the annotations zip file into a folder named `COCO` inside your choice of `data_dir` (For example: `datasets/COCO`)
81 |
82 |
83 | 4. Extract train and val set zip files into a subfolder named `downloads` (For example: `datasets/COCO/downloads`)
84 |
85 |
86 | 5. Download [DomainNet (clean version)](https://ai.bu.edu/M3SDA/)
87 |
88 |
89 | 6. Create a new `DomainNet` folder next to your `COCO` folder
90 |
91 |
92 | 7. Extract each domain's zip file under its respective subfolder (For example: `datasets/DomainNet/clipart`)
93 |
94 |
95 | 8. Back in the project, use the `--first_run` argument once while running the training script:
96 | ```shell
97 | python run.py --loss CrossEntropy --epochs 1 --corruption_mode None --data_dir datasets --first_run --train_dataset COCO --test_datasets DomainNet:Real --print_config
98 | ```
99 |
100 |
101 | 9. If everything works fine, you will see `train2017` and `val2017` folders under `COCO`
102 |
103 |
104 | 10. Both folders must contain 10 subfolders that belong to the shared classes between COCO and DomainNet
105 |
106 |
107 | 11. Now, try running ACVC as well:
108 | ```shell
109 | python run.py --loss CrossEntropy AttentionConsistency --epochs 1 --corruption_mode acvc --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real --print_config
110 | ```
111 |
112 |
113 | 12. All good? Then you are good to go with the COCO section of `run_experiments.sh` to run multiple experiments
114 |
115 |
116 | 13. That's it!
117 |
118 | ## Citation
119 |
120 | If you use this code in your research, please cite:
121 |
122 | ```bibtex
123 | @InProceedings{Cugu_2022_CVPR,
124 |     author    = {Cugu, Ilke and Mancini, Massimiliano and Chen, Yanbei and Akata, Zeynep},
125 |     title     = {Attention Consistency on Visual Corruptions for Single-Source Domain Generalization},
126 |     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
127 |     month     = {June},
128 |     year      = {2022},
129 |     pages     = {4165-4174}
130 | }
131 | ```
132 |
133 | ## References
134 |
135 | We indicate inside each file whenever a function or script is borrowed from an external source.
136 | Specifically, for the visual corruption implementations we benefit from:
137 |
138 | - The imagecorruptions library of [Autonomous Driving when Winter is Coming](https://github.com/bethgelab/imagecorruptions).
139 |
140 | Consider citing this work as well if you use it in your project.
141 |
--------------------------------------------------------------------------------
/acvc-pytorch.yml:
--------------------------------------------------------------------------------
1 | name: acvc-pytorch
2 | channels:
3 |   - pytorch
4 |   - defaults
5 | dependencies:
6 |   - _libgcc_mutex=0.1=main
7 |   - _openmp_mutex=4.5=1_gnu
8 |   - blas=1.0=mkl
9 |   - ca-certificates=2021.5.25=h06a4308_1
10 |   - certifi=2021.5.30=py37h06a4308_0
11 |   - cudatoolkit=10.1.243=h6bb024c_0
12 |   - freetype=2.10.4=h5ab3b9f_0
13 |   - intel-openmp=2021.2.0=h06a4308_610
14 |   - jpeg=9b=h024ee3a_2
15 |   - lcms2=2.12=h3be6417_0
16 |   - ld_impl_linux-64=2.35.1=h7274673_9
17 |   - libffi=3.3=he6710b0_2
18 |   - libgcc-ng=9.3.0=h5101ec6_17
19 |   - libgomp=9.3.0=h5101ec6_17
20 |   - libpng=1.6.37=hbc83047_0
21 |   - libstdcxx-ng=9.3.0=hd4cf53a_17
22 |   - libtiff=4.2.0=h85742a9_0
23 |   - libwebp-base=1.2.0=h27cfd23_0
24 |   - lz4-c=1.9.3=h2531618_0
25 |   - mkl=2021.2.0=h06a4308_296
26 |   - mkl-service=2.3.0=py37h27cfd23_1
27 |   - mkl_fft=1.3.0=py37h42c9631_2
28 |   - mkl_random=1.2.1=py37ha9443f7_2
29 |   - ncurses=6.2=he6710b0_1
30 |   - ninja=1.10.2=hff7bd54_1
31 |   - numpy=1.20.2=py37h2d18471_0
32 |   - numpy-base=1.20.2=py37hfae3a4d_0
33 |   - olefile=0.46=py37_0
34 |   - openssl=1.1.1k=h27cfd23_0
35 |   - pillow=8.2.0=py37he98fc37_0
36 |   - pip=21.1.2=py37h06a4308_0
37 |   - python=3.7.10=h12debd9_4
38 |   - pytorch=1.5.1=py3.7_cuda10.1.243_cudnn7.6.3_0
39 |   - readline=8.1=h27cfd23_0
40 |   - setuptools=52.0.0=py37h06a4308_0
41 |   - six=1.16.0=pyhd3eb1b0_0
42 |   - sqlite=3.36.0=hc218d9a_0
43 |   - tk=8.6.10=hbc83047_0
44 |   - torchvision=0.6.1=py37_cu101
45 |   - wheel=0.36.2=pyhd3eb1b0_0
46 |   - xz=5.2.5=h7b6447c_0
47 |   - zlib=1.2.11=h7b6447c_3
48 |   - zstd=1.4.9=haebb681_0
49 |   - pip:
50 |     - absl-py==0.13.0
51 |     - astor==0.8.1
52 |     - cached-property==1.5.2
53 |     - click==8.0.1
54 |     - cycler==0.10.0
55 |     - cython==0.29.23
56 |     - gast==0.2.2
57 |     - google-pasta==0.2.0
58 |     - grpcio==1.38.1
59 |     - h5py==3.3.0
60 |     - importlib-metadata==4.6.0
61 |     - joblib==1.0.1
62 |     - keras-applications==1.0.8
63 |     - keras-preprocessing==1.1.2
64 |     - kiwisolver==1.3.1
65 |     - markdown==3.3.4
66 |     - matplotlib==3.1.1
67 |     - matplotlib-venn==0.11.6
68 |     - opencv-python==4.5.1.48
69 |     - opt-einsum==3.3.0
70 |     - overrides==3.1.0
71 |     - pandas==1.3.0
72 |     - protobuf==3.17.3
73 |     - pycocotools==2.0.2
74 |     - pyparsing==2.4.7
75 |     - python-dateutil==2.8.1
76 |     - pytz==2021.1
77 |     - pyyaml==5.4.1
78 |     - scikit-learn==0.24.2
79 |     - scipy==1.7.0
80 |     - tensorboard==1.15.0
81 |     - tensorboardx==2.4
82 |     - tensorflow-estimator==1.15.1
83 |     - termcolor==1.1.0
84 |     - threadpoolctl==2.1.0
85 |     - tqdm==4.58.0
86 |     - typing-extensions==3.10.0.0
87 |     - typing-utils==0.1.0
88 |     - werkzeug==2.0.1
89 |     - wrapt==1.12.1
90 |     - zipp==3.4.1
91 | prefix: /home/ubuntu/anaconda3/envs/acvc-pytorch
--------------------------------------------------------------------------------
/analysis/GeneralizationExpProcessor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 | from tools import plot_learning_curve
6 | 
7 | class GeneralizationExpProcessor:
8 |     """
9 |     Custom class to analyze and visualize the cumulative experimental results
10 | 
11 |     # Arguments
12 |         :param to_dir: (string) directory in which the resulting tables and plots are saved
13 |         :param path: (string) absolute path to the 'experiment.json' file if it already exists
14 |     """
15 | 
16 |     def __init__(self,
17 |                  config,
18 |                  to_dir=".",
19 |                  include_std=True,
20 |                  image_format="pdf",
21 |                  path="generalization.json"):
22 |         self.config = config
23 |         self.to_dir = to_dir
24 |         self.include_std = include_std
25 |         self.image_format = image_format
26 |         self.path = path
27 |         self.hist = {}
28 | 
29 |         if os.path.isfile(self.path):
30 |             self.load_data()
31 | 
32 |         if not os.path.exists(self.to_dir):
33 |             os.makedirs(self.to_dir)
34 | 
35 |     def save_data(self):
36 |         """
37 |         Overwrites the experiment history
38 |         """
39 |         with open(self.path, "w") as hist_file:
40 |             json.dump(self.hist, hist_file)
41 | 
42 |     def load_data(self):
43 |         """
44 |         Loads the current experiment history to append new results
45 |         """
46 |         with open(self.path, "r") as hist_file:
47 |             self.hist = json.load(hist_file)
48 | 
49 |     def merge_logs(self, paths):
50 |         print("Merge started...")
51 |         for path in paths:
52 |             print("Processing... %s" % path)
53 |             with open(path) as hist_file:
54 |                 temp_hist = json.load(hist_file)
55 | 
56 |             for key in temp_hist:
57 |                 if key in self.hist:
58 |                     self.hist[key].extend(temp_hist[key])
59 |                 else:
60 |                     self.hist[key] = temp_hist[key]
61 | 
62 |         # Save the changes
63 |         self.save_data()
64 |         print("Merge completed.")
65 |         print("New file is saved at %s" % self.path)
66 | 
67 |     def get_info_loss_mode(self, s):
68 |         temp = 4 if s.find("info=") == -1 else 5
69 |         i = max(s.find("info="), s.find("aug="))
70 |         result = ""
71 |         while s[i + temp] != "]":
72 |             result += s[i + temp]
73 |             i += 1
74 | 
75 |         return result
76 | 
77 |     def process_key(self, key):
78 |         if key == "None":
79 |             return "Baseline"
80 | 
81 |         return key
82 | 
83 |     def run_PACS_analysis(self, hist):
84 |         pacs = [
85 |             "PACS:Photo",
86 |             "PACS:Art",
87 |             "PACS:Cartoon",
88 |             "PACS:Sketch"
89 |         ]
90 | 
91 |         prefix = [
92 |             "\\begin{table*}",
93 |             "\t \\begin{adjustbox}{width=1\\textwidth}",
94 |             "\t \\centering",
95 |             "\t\t \\begin{tabular}{lcccccc} ",
96 |             "\t\t \\toprule",
97 |             "\t\t & Photo & Art & Cartoon & Sketch & Avg. & Max. 
\\\\", 98 | "\t\t \\midrule", 99 | ] 100 | 101 | # History contains repeated experiments, so compute mean & std 102 | hist_flag = False 103 | processed_hist = {} 104 | cumulative_hist = {} 105 | for i in hist: 106 | processed_hist[i] = {} 107 | cumulative_hist[i] = {"acc": [], "loss": [], "val_acc": [], "val_loss": []} 108 | for entry in hist[i]: 109 | for key in entry: 110 | if key == "history": # Cumulative learning curve 111 | for j in entry[key]: 112 | cumulative_hist[i][j].append(entry[key][j]) 113 | hist_flag = True 114 | 115 | else: 116 | if key in processed_hist[i]: 117 | processed_hist[i][key].append(entry[key]) 118 | else: 119 | processed_hist[i][key] = [entry[key]] 120 | 121 | # Plot cumulative learning curve 122 | if hist_flag: 123 | for i in cumulative_hist: 124 | temp_hist = {} 125 | for key in cumulative_hist[i]: 126 | temp = np.array(cumulative_hist[i][key]) 127 | temp_avg = np.mean(temp, axis=0) 128 | temp_std = np.std(temp, axis=0) 129 | temp_hist[key] = temp_avg 130 | temp_hist["%s_std" % key] = temp_std 131 | plot_learning_curve(temp_hist, os.path.join(self.to_dir, "%s_PACS_learning_curve.%s" % (self.get_info_loss_mode(i), self.image_format))) 132 | 133 | # Prepare results in LaTeX format 134 | results = [] 135 | for i in processed_hist: 136 | 137 | entry = processed_hist[i] 138 | temp_obj = "" if "vanilla" in i else " + contrastive" 139 | info_loss_mode = self.get_info_loss_mode(i) 140 | key = info_loss_mode + temp_obj 141 | key = key.replace("_", " ") 142 | key = self.process_key(key) 143 | 144 | accs = {} 145 | stds = {} 146 | for dataset in entry: 147 | cumulative_results = 100 * np.array(entry[dataset]) 148 | avg = np.mean(cumulative_results) 149 | std = np.std(cumulative_results) 150 | accs[dataset] = float(avg) 151 | stds[dataset] = float(std) 152 | 153 | # Calculate avg performance per entry 154 | exp_count = len(entry["PACS:Photo"]) 155 | avg_list = np.zeros(exp_count) 156 | for j in range(exp_count): 157 | avg_list[j] = np.mean(100 * np.array([entry[temp][j] for temp in entry if temp != "PACS:Photo" and temp in pacs])) 158 | 159 | # PACS table 160 | temp = "\t\t %s " % key 161 | while len(temp) < 52: 162 | temp += " " 163 | for j in pacs: 164 | temp += "& $%.2f \\pm %.1f$ " % (accs[j], stds[j]) if self.include_std else "& $%.2f$ " % (accs[j]) 165 | avg, std, top = float(np.mean(avg_list)), float(np.std(avg_list)), float(avg_list.max()) 166 | temp += "& $%.2f \\pm %.1f$ & $%.2f$ \\\\" % (avg, std, top) 167 | results.append(temp) 168 | 169 | body = ["\n".join(prefix)] 170 | body.append("\n".join(results)) 171 | body.append("\t\t \\bottomrule") 172 | body.append("\t\t \\end{tabular}") 173 | body.append("\t \\end{adjustbox}") 174 | body.append("\t \\caption{ResNet-18 results on PACS benchmark}") 175 | body.append("\\end{table*}") 176 | 177 | # Export the LaTeX file 178 | with open(os.path.join(self.to_dir, "result_PACS.tex"), '+w') as tex_file: 179 | tex_file.write("\n".join(body)) 180 | 181 | def run_COCO_analysis(self, hist): 182 | coco = [ 183 | "COCO", 184 | "DomainNet:Real", 185 | "DomainNet:Painting", 186 | "DomainNet:Infograph", 187 | "DomainNet:Clipart", 188 | "DomainNet:Sketch", 189 | "DomainNet:Quickdraw" 190 | ] 191 | 192 | prefix = [ 193 | "\\begin{table*}", 194 | "\t \\centering", 195 | #"\t\t \\begin{tabular}{lcccccccc} ", 196 | "\t\t \\begin{tabular}{lccccccccc} ", 197 | "\t\t \\toprule", 198 | #"\t\t & COCO & Real & Painting & Infograph & Clipart & Sketch & Avg. & Max. 
\\\\", 199 | "\t\t & COCO & Real & Painting & Infograph & Clipart & Sketch & Quickdraw & Avg. & Max. \\\\", 200 | "\t\t \\midrule", 201 | ] 202 | 203 | # History contains repeated experiments, so compute mean & std 204 | hist_flag = False 205 | processed_hist = {} 206 | cumulative_hist = {} 207 | for i in hist: 208 | processed_hist[i] = {} 209 | cumulative_hist[i] = {"acc": [], "loss": [], "val_acc": [], "val_loss": []} 210 | for entry in hist[i]: 211 | for key in entry: 212 | if key == "history": # Cumulative learning curve 213 | for j in entry[key]: 214 | cumulative_hist[i][j].append(entry[key][j]) 215 | hist_flag = True 216 | 217 | else: 218 | if key in processed_hist[i]: 219 | processed_hist[i][key].append(entry[key]) 220 | else: 221 | processed_hist[i][key] = [entry[key]] 222 | 223 | # Plot cumulative learning curve 224 | if hist_flag: 225 | for i in cumulative_hist: 226 | temp_hist = {} 227 | for key in cumulative_hist[i]: 228 | temp = np.array(cumulative_hist[i][key]) 229 | temp_avg = np.mean(temp, axis=0) 230 | temp_std = np.std(temp, axis=0) 231 | temp_hist[key] = temp_avg 232 | temp_hist["%s_std" % key] = temp_std 233 | plot_learning_curve(temp_hist, os.path.join(self.to_dir, "%s_COCO_learning_curve.%s" % (self.get_info_loss_mode(i), self.image_format))) 234 | 235 | # Prepare results in LaTeX format 236 | results = [] 237 | for i in processed_hist: 238 | entry = processed_hist[i] 239 | temp_obj = "" if "vanilla" in i else " + contrastive" 240 | info_loss_mode = self.get_info_loss_mode(i) 241 | key = info_loss_mode + temp_obj 242 | key = key.replace("_", " ") 243 | key = self.process_key(key) 244 | 245 | accs = {} 246 | stds = {} 247 | for dataset in entry: 248 | cumulative_results = 100 * np.array(entry[dataset]) 249 | avg = np.mean(cumulative_results) 250 | std = np.std(cumulative_results) 251 | accs[dataset] = float(avg) 252 | stds[dataset] = float(std) 253 | 254 | # Calculate avg performance per entry 255 | exp_count = len(entry["COCO"]) 256 | avg_list = np.zeros(exp_count) 257 | for j in range(exp_count): 258 | avg_list[j] = np.mean(100 * np.array([entry[temp][j] for temp in entry if temp != "COCO" and temp in coco])) 259 | 260 | # COCO table 261 | temp = "\t\t %s " % key 262 | while len(temp) < 52: 263 | temp += " " 264 | for j in coco: 265 | temp += "& $%.2f \\pm %.1f$ " % (accs[j], stds[j]) if self.include_std else "& $%.2f$ " % (accs[j]) 266 | avg, std, top = float(np.mean(avg_list)), float(np.std(avg_list)), float(avg_list.max()) 267 | temp += "& $%.2f \\pm %.1f$ & $%.2f$ \\\\" % (avg, std, top) 268 | results.append(temp) 269 | 270 | body = ["\n".join(prefix)] 271 | body.append("\n".join(results)) 272 | body.append("\t\t \\bottomrule") 273 | body.append("\t\t \\end{tabular}") 274 | body.append("\t \\vspace{-3pt}") 275 | body.append("\t \\caption{ResNet-18 results on COCO benchmark}") 276 | body.append("\\end{table*}") 277 | 278 | # Export the LaTeX file 279 | with open(os.path.join(self.to_dir, "result_COCO.tex"), '+w') as tex_file: 280 | tex_file.write("\n".join(body)) 281 | 282 | def run_FullDomainNet_analysis(self, hist): 283 | domainnet = [ 284 | "FullDomainNet:Real", 285 | "FullDomainNet:Painting", 286 | "FullDomainNet:Infograph", 287 | "FullDomainNet:Clipart", 288 | "FullDomainNet:Sketch", 289 | "FullDomainNet:Quickdraw" 290 | ] 291 | 292 | prefix = [ 293 | "\\begin{table*}", 294 | "\t \\centering", 295 | "\t\t \\begin{tabular}{lcccccccc} ", 296 | "\t\t \\toprule", 297 | "\t\t & Real & Painting & Infograph & Clipart & Sketch & Quickdraw & Avg. & Max. 
\\\\", 298 | "\t\t \\midrule", 299 | ] 300 | 301 | # History contains repeated experiments, so compute mean & std 302 | hist_flag = False 303 | processed_hist = {} 304 | cumulative_hist = {} 305 | for i in hist: 306 | processed_hist[i] = {} 307 | cumulative_hist[i] = {"acc": [], "loss": [], "val_acc": [], "val_loss": []} 308 | for entry in hist[i]: 309 | for key in entry: 310 | if key == "history": # Cumulative learning curve 311 | for j in entry[key]: 312 | cumulative_hist[i][j].append(entry[key][j]) 313 | hist_flag = True 314 | 315 | else: 316 | if key in processed_hist[i]: 317 | processed_hist[i][key].append(entry[key]) 318 | else: 319 | processed_hist[i][key] = [entry[key]] 320 | 321 | # Plot cumulative learning curve 322 | if hist_flag: 323 | for i in cumulative_hist: 324 | temp_hist = {} 325 | for key in cumulative_hist[i]: 326 | temp = np.array(cumulative_hist[i][key]) 327 | temp_avg = np.mean(temp, axis=0) 328 | temp_std = np.std(temp, axis=0) 329 | temp_hist[key] = temp_avg 330 | temp_hist["%s_std" % key] = temp_std 331 | plot_learning_curve(temp_hist, os.path.join(self.to_dir, "%s_DomainNet_learning_curve.%s" % (self.get_info_loss_mode(i), self.image_format))) 332 | 333 | # Prepare results in LaTeX format 334 | results = [] 335 | for i in processed_hist: 336 | entry = processed_hist[i] 337 | temp_obj = "" if "vanilla" in i else " + contrastive" 338 | info_loss_mode = self.get_info_loss_mode(i) 339 | key = info_loss_mode + temp_obj 340 | key = key.replace("_", " ") 341 | key = self.process_key(key) 342 | 343 | accs = {} 344 | stds = {} 345 | for dataset in entry: 346 | cumulative_results = 100 * np.array(entry[dataset]) 347 | avg = np.mean(cumulative_results) 348 | std = np.std(cumulative_results) 349 | accs[dataset] = float(avg) 350 | stds[dataset] = float(std) 351 | 352 | # Calculate avg performance per entry 353 | exp_count = len(entry["FullDomainNet:Real"]) 354 | avg_list = np.zeros(exp_count) 355 | for j in range(exp_count): 356 | avg_list[j] = np.mean(100 * np.array([entry[temp][j] for temp in entry if temp != "FullDomainNet:Real" and temp in domainnet])) 357 | 358 | # DomainNet table 359 | temp = "\t\t %s " % key 360 | while len(temp) < 52: 361 | temp += " " 362 | for j in domainnet: 363 | temp += "& $%.2f \\pm %.1f$ " % (accs[j], stds[j]) if self.include_std else "& $%.2f$ " % (accs[j]) 364 | avg, std, top = float(np.mean(avg_list)), float(np.std(avg_list)), float(avg_list.max()) 365 | temp += "& $%.2f \\pm %.1f$ & $%.2f$ \\\\" % (avg, std, top) 366 | results.append(temp) 367 | 368 | body = ["\n".join(prefix)] 369 | body.append("\n".join(results)) 370 | body.append("\t\t \\bottomrule") 371 | body.append("\t\t \\end{tabular}") 372 | body.append("\t \\vspace{-3pt}") 373 | body.append("\t \\caption{ResNet-18 results on DomainNet benchmark}") 374 | body.append("\\end{table*}") 375 | 376 | # Export the LaTeX file 377 | with open(os.path.join(self.to_dir, "result_DomainNet.tex"), '+w') as tex_file: 378 | tex_file.write("\n".join(body)) 379 | 380 | def run(self): 381 | benchmarks = {"PACS": {"hist": {}, "analysis_func": self.run_PACS_analysis}, 382 | "COCO": {"hist": {}, "analysis_func": self.run_COCO_analysis}, 383 | "FullDomainNet": {"hist": {}, "analysis_func": self.run_FullDomainNet_analysis}} 384 | 385 | for experiment in self.hist: 386 | 387 | if "PACS" in experiment: 388 | benchmarks["PACS"]["hist"][experiment] = self.hist[experiment] 389 | 390 | elif "COCO" in experiment: 391 | benchmarks["COCO"]["hist"][experiment] = self.hist[experiment] 392 | 393 | elif "FullDomainNet" 
in experiment:
394 |                 benchmarks["FullDomainNet"]["hist"][experiment] = self.hist[experiment]
395 | 
396 |         for benchmark in benchmarks:
397 |             if len(benchmarks[benchmark]["hist"]) > 0:
398 |                 benchmarks[benchmark]["analysis_func"](benchmarks[benchmark]["hist"])
399 | 
400 | if __name__ == '__main__':
401 |     # Dynamic parameters
402 |     parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
403 |     parser.add_argument("--path", default="../generalization.json", help="filepath to experiment history (JSON file)", type=str)
404 |     parser.add_argument("--to_dir", default="../results", help="filepath to save charts, models, etc.", type=str)
405 |     parser.add_argument("--merge_logs", help="merges experiment results for the given JSON file paths", nargs="+")
406 |     parser.add_argument("--image_format", default="png", help="image file format for the exported plots (e.g. png, pdf)", type=str)
407 |     args = vars(parser.parse_args())
408 | 
409 |     hist_path = args["path"]
410 |     experimentProcessor = GeneralizationExpProcessor(config=args, path=hist_path, include_std=True, image_format=args["image_format"], to_dir=args["to_dir"])
411 |     if args["merge_logs"] is None:
412 |         experimentProcessor.run()
413 |     else:
414 |         experimentProcessor.merge_logs(paths=args["merge_logs"])
--------------------------------------------------------------------------------
/assets/ACVC_CAM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/ACVC/2f93b23fa94735eb1079c2a9fec9e62ea6a4194e/assets/ACVC_CAM.png
--------------------------------------------------------------------------------
/assets/ACVC_flow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/ACVC/2f93b23fa94735eb1079c2a9fec9e62ea6a4194e/assets/ACVC_flow.png
--------------------------------------------------------------------------------
/losses/AttentionConsistency.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import nn
4 | from torch.autograd import Variable
5 | 
6 | class AttentionConsistency(nn.Module):
7 |     def __init__(self, lambd=6e-2, T=1.0):
8 |         super().__init__()
9 |         self.name = "AttentionConsistency"
10 |         self.T = T
11 |         self.lambd = lambd
12 | 
13 |     def CAM_neg(self, c):
14 |         result = c.reshape(c.size(0), c.size(1), -1)
15 |         result = -nn.functional.log_softmax(result / self.T, dim=2) / result.size(2)
16 |         result = result.sum(2)
17 | 
18 |         return result
19 | 
20 |     def CAM_pos(self, c):
21 |         result = c.reshape(c.size(0), c.size(1), -1)
22 |         result = nn.functional.softmax(result / self.T, dim=2)
23 | 
24 |         return result
25 | 
26 |     def forward(self, c, ci_list, y, segmentation_masks=None):
27 |         """
28 |         CAM (batch_size, num_classes, feature_map.shape[0], feature_map.shape[1]) based loss
29 | 
30 |         Arguments:
31 |             :param c: (Torch.tensor) clean image's CAM
32 |             :param ci_list: (list of Torch.tensor) augmented images' CAMs
33 |             :param y: (Torch.tensor) ground truth labels
34 |             :param segmentation_masks: (numpy.array) optional ground-truth masks used as the target distribution instead of the CAM mixture
35 |             :return:
36 |         """
37 |         c1 = c.clone()
38 |         c1 = Variable(c1)
39 |         c0 = self.CAM_neg(c)
40 | 
41 |         # Top-k negative classes
42 |         c1 = c1.sum(2).sum(2)
43 |         index = torch.zeros(c1.size())
44 |         c1[range(c0.size(0)), y] = - float("Inf")
45 |         topk_ind = torch.topk(c1, 3, dim=1)[1]
46 |         index[torch.tensor(range(c1.size(0))).unsqueeze(1), topk_ind] = 1
47 |         index = index > 0.5
48 | 
49 |         # Negative CAM loss
50 |         neg_loss = c0[index].sum() / c0.size(0)
51 |         for ci in ci_list:
52 |             ci = self.CAM_neg(ci)
53 |             neg_loss += ci[index].sum() / ci.size(0)
54 |         neg_loss /= len(ci_list) + 1
55 | 
56 |         # Positive CAM loss
57 |         index = torch.zeros(c1.size())
58 |         true_ind = [[i] for i in y]
59 |         index[torch.tensor(range(c1.size(0))).unsqueeze(1), true_ind] = 1
60 |         index = index > 0.5
61 |         p0 = self.CAM_pos(c)[index]
62 |         pi_list = [self.CAM_pos(ci)[index] for ci in ci_list]
63 | 
64 |         # Middle ground for Jensen-Shannon divergence
65 |         p_count = 1 + len(pi_list)
66 |         if segmentation_masks is None:
67 |             p_mixture = p0.detach().clone()
68 |             for pi in pi_list:
69 |                 p_mixture += pi
70 |             p_mixture = torch.clamp(p_mixture / p_count, 1e-7, 1).log()
71 | 
72 |         else:
73 |             mask = np.interp(segmentation_masks, (segmentation_masks.min(), segmentation_masks.max()), (0, 1))
74 |             p_mixture = torch.from_numpy(mask).cuda()
75 |             p_mixture = p_mixture.reshape(p_mixture.size(0), -1)
76 |             p_mixture = torch.nn.functional.normalize(p_mixture, dim=1)
77 | 
78 |         pos_loss = nn.functional.kl_div(p_mixture, p0, reduction='batchmean')
79 |         for pi in pi_list:
80 |             pos_loss += nn.functional.kl_div(p_mixture, pi, reduction='batchmean')
81 |         pos_loss /= p_count
82 | 
83 |         loss = pos_loss + neg_loss
84 |         return self.lambd * loss
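# Editor's note: illustrative usage sketch, not part of the original file.
# It assumes a backbone like models/ResNet.py, whose forward pass returns
# end_points['CAM'] of shape (batch_size, num_classes, H, W):
#
#     criterion = AttentionConsistency(lambd=6e-2)
#     _, clean_pts = model(x)                                  # clean view
#     _, aug_pts = model(x_aug)                                # corrupted view
#     loss = criterion(clean_pts['CAM'], [aug_pts['CAM']], y)  # y: class labels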
--------------------------------------------------------------------------------
/losses/Distillation.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | 
4 | class Distillation(nn.Module):
5 |     def __init__(self, temperature=2.0):
6 |         super().__init__()
7 |         self.T = temperature
8 |         self.criterion = nn.KLDivLoss()
9 | 
10 |     def forward(self, z_s, z_t):
11 |         # KL divergence between the temperature-softened student and teacher distributions
12 |         return self.criterion(F.log_softmax(z_s / self.T, dim=1), F.softmax(z_t / self.T, dim=1))  # * (self.T * self.T)
--------------------------------------------------------------------------------
/losses/JSDivergence.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | 
5 | class JSDivergence(nn.Module):
6 |     def __init__(self, lambd=12):
7 |         super().__init__()
8 |         self.name = "JSDivergence"
9 |         self.lambd = lambd
10 | 
11 |     def forward(self, p0, pi_list, y=None):
12 |         p_count = 1 + len(pi_list)
13 |         p_mixture = p0.detach().clone()
14 |         for pi in pi_list:
15 |             p_mixture += pi
16 |         p_mixture = torch.clamp(p_mixture / p_count, 1e-7, 1).log()
17 | 
18 |         loss = F.kl_div(p_mixture, p0, reduction='batchmean')
19 |         for pi in pi_list:
20 |             loss += F.kl_div(p_mixture, pi, reduction='batchmean')
21 |         loss /= p_count
22 | 
23 |         return self.lambd * loss
--------------------------------------------------------------------------------
/models/ResNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision.models.utils import load_state_dict_from_url
5 | 
6 | model_urls = {
7 |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
8 |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
9 |     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
10 |     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
11 |     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
12 | }
13 | 
14 | 
15 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
16 |     """3x3 convolution with padding"""
17 |     return nn.Conv2d(in_planes, 
out_planes, kernel_size=3, stride=stride, 18 | padding=dilation, groups=groups, bias=False, dilation=dilation) 19 | 20 | 21 | def conv1x1(in_planes, out_planes, stride=1): 22 | """1x1 convolution""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 30 | base_width=64, dilation=1, norm_layer=None): 31 | super(BasicBlock, self).__init__() 32 | if norm_layer is None: 33 | norm_layer = nn.BatchNorm2d 34 | if groups != 1 or base_width != 64: 35 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 36 | if dilation > 1: 37 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 38 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 39 | self.conv1 = conv3x3(inplanes, planes, stride) 40 | self.bn1 = norm_layer(planes) 41 | self.relu = nn.ReLU(inplace=True) 42 | self.conv2 = conv3x3(planes, planes) 43 | self.bn2 = norm_layer(planes) 44 | self.downsample = downsample 45 | self.stride = stride 46 | 47 | def forward(self, x): 48 | identity = x 49 | 50 | out = self.conv1(x) 51 | out = self.bn1(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | 57 | if self.downsample is not None: 58 | identity = self.downsample(x) 59 | 60 | out += identity 61 | out = self.relu(out) 62 | 63 | return out 64 | 65 | 66 | class Bottleneck(nn.Module): 67 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 68 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 69 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 70 | # This variant is also known as ResNet V1.5 and improves accuracy according to 71 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
72 | 73 | expansion = 4 74 | 75 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 76 | base_width=64, dilation=1, norm_layer=None): 77 | super(Bottleneck, self).__init__() 78 | if norm_layer is None: 79 | norm_layer = nn.BatchNorm2d 80 | width = int(planes * (base_width / 64.)) * groups 81 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 82 | self.conv1 = conv1x1(inplanes, width) 83 | self.bn1 = norm_layer(width) 84 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 85 | self.bn2 = norm_layer(width) 86 | self.conv3 = conv1x1(width, planes * self.expansion) 87 | self.bn3 = norm_layer(planes * self.expansion) 88 | self.relu = nn.ReLU(inplace=True) 89 | self.downsample = downsample 90 | self.stride = stride 91 | 92 | def forward(self, x): 93 | identity = x 94 | 95 | out = self.conv1(x) 96 | out = self.bn1(out) 97 | out = self.relu(out) 98 | 99 | out = self.conv2(out) 100 | out = self.bn2(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv3(out) 104 | out = self.bn3(out) 105 | 106 | if self.downsample is not None: 107 | identity = self.downsample(x) 108 | 109 | out += identity 110 | out = self.relu(out) 111 | 112 | return out 113 | 114 | 115 | class ResNet(nn.Module): 116 | 117 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 118 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 119 | norm_layer=None): 120 | super(ResNet, self).__init__() 121 | if norm_layer is None: 122 | norm_layer = nn.BatchNorm2d 123 | self._norm_layer = norm_layer 124 | 125 | self.inplanes = 64 126 | self.dilation = 1 127 | if replace_stride_with_dilation is None: 128 | # each element in the tuple indicates if we should replace 129 | # the 2x2 stride with a dilated convolution instead 130 | replace_stride_with_dilation = [False, False, False] 131 | if len(replace_stride_with_dilation) != 3: 132 | raise ValueError("replace_stride_with_dilation should be None " 133 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 134 | self.groups = groups 135 | self.base_width = width_per_group 136 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 137 | bias=False) 138 | self.bn1 = norm_layer(self.inplanes) 139 | self.relu = nn.ReLU(inplace=True) 140 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 141 | self.layer1 = self._make_layer(block, 64, layers[0]) 142 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 143 | dilate=replace_stride_with_dilation[0]) 144 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 145 | dilate=replace_stride_with_dilation[1]) 146 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 147 | dilate=replace_stride_with_dilation[2]) 148 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 149 | self.fc = nn.Linear(512 * block.expansion, num_classes) 150 | 151 | for m in self.modules(): 152 | if isinstance(m, nn.Conv2d): 153 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 154 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 155 | nn.init.constant_(m.weight, 1) 156 | nn.init.constant_(m.bias, 0) 157 | 158 | # Zero-initialize the last BN in each residual branch, 159 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
160 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
161 |         if zero_init_residual:
162 |             for m in self.modules():
163 |                 if isinstance(m, Bottleneck):
164 |                     nn.init.constant_(m.bn3.weight, 0)
165 |                 elif isinstance(m, BasicBlock):
166 |                     nn.init.constant_(m.bn2.weight, 0)
167 | 
168 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
169 |         norm_layer = self._norm_layer
170 |         downsample = None
171 |         previous_dilation = self.dilation
172 |         if dilate:
173 |             self.dilation *= stride
174 |             stride = 1
175 |         if stride != 1 or self.inplanes != planes * block.expansion:
176 |             downsample = nn.Sequential(
177 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
178 |                 norm_layer(planes * block.expansion),
179 |             )
180 | 
181 |         layers = []
182 |         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
183 |                             self.base_width, previous_dilation, norm_layer))
184 |         self.inplanes = planes * block.expansion
185 |         for _ in range(1, blocks):
186 |             layers.append(block(self.inplanes, planes, groups=self.groups,
187 |                                 base_width=self.base_width, dilation=self.dilation,
188 |                                 norm_layer=norm_layer))
189 | 
190 |         return nn.Sequential(*layers)
191 | 
192 |     def _forward_impl(self, x):
193 |         end_points = {}
194 | 
195 |         # See note [TorchScript super()]
196 |         x = self.conv1(x)
197 |         x = self.bn1(x)
198 |         x = self.relu(x)
199 |         x = self.maxpool(x)
200 | 
201 |         x = self.layer1(x)
202 |         x = self.layer2(x)
203 |         x = self.layer3(x)
204 |         x = self.layer4(x)
205 |         end_points['Feature'] = x
206 | 
207 |         x = self.avgpool(x)
208 |         x = torch.flatten(x, 1)
209 |         end_points['Embedding'] = x
210 | 
211 |         x = self.fc(x)
212 |         end_points['Predictions'] = F.softmax(input=x, dim=-1)
213 | 
214 |         # Taken from: https://github.com/GuoleiSun/HNC_loss
215 |         end_points['CAM'] = F.conv2d(end_points['Feature'], self.fc.weight.view(self.fc.out_features, end_points['Feature'].size(1), 1, 1)) + self.fc.bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
216 | 
217 |         return x, end_points
218 | 
219 |     def forward(self, x):
220 |         return self._forward_impl(x)
221 | 
222 | 
223 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
224 |     model = ResNet(block, layers, **kwargs)
225 |     if pretrained:
226 |         state_dict = load_state_dict_from_url(model_urls[arch],
227 |                                               progress=progress)
228 |         model.load_state_dict(state_dict)
229 |     return model
230 | 
231 | 
232 | def resnet18(pretrained=False, progress=True, **kwargs):
233 |     r"""ResNet-18 model from
234 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
235 | 
236 |     Args:
237 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
238 |         progress (bool): If True, displays a progress bar of the download to stderr
239 |     """
240 |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
241 |                    **kwargs)
242 | 
243 | 
244 | def resnet34(pretrained=False, progress=True, **kwargs):
245 |     r"""ResNet-34 model from
246 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
247 | 
248 |     Args:
249 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
250 |         progress (bool): If True, displays a progress bar of the download to stderr
251 |     """
252 |     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
253 |                    **kwargs)
254 | 
255 | 
256 | def resnet50(pretrained=False, progress=True, **kwargs):
257 |     r"""ResNet-50 model from
258 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
259 | 
260 |     Args:
261 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
262 |         progress (bool): If True, displays a progress bar of the download to stderr
263 |     """
264 |     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
265 |                    **kwargs)
266 | 
267 | 
268 | def resnet101(pretrained=False, progress=True, **kwargs):
269 |     r"""ResNet-101 model from
270 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
271 | 
272 |     Args:
273 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
274 |         progress (bool): If True, displays a progress bar of the download to stderr
275 |     """
276 |     return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
277 |                    **kwargs)
278 | 
279 | 
280 | def resnet152(pretrained=False, progress=True, **kwargs):
281 |     r"""ResNet-152 model from
282 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
283 | 
284 |     Args:
285 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
286 |         progress (bool): If True, displays a progress bar of the download to stderr
287 |     """
288 |     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
289 |                    **kwargs)
290 | 
291 | def get_resnet(depth, num_classes, freeze=False):
292 |     model = None
293 | 
294 |     if depth == 18:
295 |         model = resnet18(pretrained=True)
296 |         num_ftrs = model.fc.in_features
297 |         model.fc = torch.nn.Linear(num_ftrs, num_classes)
298 |         torch.nn.init.kaiming_normal_(model.fc.weight)
299 | 
300 |     elif depth == 50:
301 |         model = resnet50(pretrained=True)
302 |         num_ftrs = model.fc.in_features
303 |         model.fc = torch.nn.Linear(num_ftrs, num_classes)
304 |         torch.nn.init.kaiming_normal_(model.fc.weight)
305 | 
306 |     if freeze:
307 |         for param in model.parameters():
308 |             param.requires_grad = False
309 | 
310 |         for param in model.fc.parameters():
311 |             param.requires_grad = True
312 | 
313 |     return model
314 | 
315 | 
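# Editor's note: illustrative usage sketch, not part of the original file.
#
#     model = get_resnet(depth=18, num_classes=7)   # e.g. the 7 PACS classes
#     logits, end_points = model(images)            # images: (batch, 3, 224, 224)
#     cams = end_points['CAM']                      # input to losses/AttentionConsistency.py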
--------------------------------------------------------------------------------
/preprocessing/Datasets.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import, division
2 | import os
3 | import torchvision
4 | import numpy as np
5 | from PIL import Image
6 | from tqdm import tqdm
7 | from pycocotools.coco import COCO
8 | from tools import log, resize_image
9 | 
10 | def preprocess_dataset(x, train, img_mean_mode):
11 |     # Compute image mean if applicable
12 |     if img_mean_mode is not None:
13 |         if train:
14 | 
15 |             if img_mean_mode == "per_channel":
16 |                 x_ = np.copy(x)
17 |                 x_ = x_.astype('float32') / 255.0
18 |                 img_mean = np.array([np.mean(x_[:, :, :, 0]), np.mean(x_[:, :, :, 1]), np.mean(x_[:, :, :, 2])])
19 | 
20 |             elif img_mean_mode == "imagenet":
21 |                 img_mean = np.array([0.485, 0.456, 0.406])
22 | 
23 |             else:
24 |                 raise Exception("Invalid img_mean_mode..!")
25 |             np.save("img_mean.npy", img_mean)
26 | 
27 |     return x
28 | 
29 | def normalize_dataset(x, img_mean_mode):
30 |     if img_mean_mode is not None:
31 | 
32 |         if img_mean_mode == "imagenet":
33 |             img_mean = np.array([0.485, 0.456, 0.406])
34 |             img_std = np.array([0.229, 0.224, 0.225])
35 | 
36 |         else:
37 |             assert os.path.exists("img_mean.npy"), "Image mean file cannot be found!" 
38 | img_mean = np.load("img_mean.npy") 39 | img_std = np.array([1.0, 1.0, 1.0]) 40 | 41 | transform = torchvision.transforms.Compose([ 42 | torchvision.transforms.ToTensor(), 43 | torchvision.transforms.Normalize(mean=img_mean, std=img_std) 44 | ]) 45 | 46 | else: 47 | transform = torchvision.transforms.Compose([ 48 | torchvision.transforms.ToTensor() 49 | ]) 50 | 51 | x_ = transform(x) 52 | 53 | return x_ 54 | 55 | def load_PACS(subset="photo", train=True, img_mean_mode="imagenet", distillation=False, data_dir="../../datasets"): 56 | subset = "art_painting" if subset.lower() == "art" else subset.lower() 57 | root_path = os.path.join(os.path.dirname(__file__), data_dir, 'PACS') 58 | data_path = os.path.join(root_path, subset.lower()) 59 | classes = {"dog": 0, "elephant": 1, "giraffe": 2, "guitar": 3, "horse": 4, "house": 5, "person": 6} 60 | img_dim = (224, 224) 61 | 62 | imagedata = [] 63 | labels = [] 64 | teacher_logits = None 65 | if subset == "photo": 66 | label_file = "photo_train.txt" if train else "photo_val.txt" 67 | label_file = os.path.join(root_path, label_file) 68 | 69 | # Gather images and labels 70 | with open(label_file, "r") as f_label: 71 | for line in f_label: 72 | temp = line[:-1].split(" ") 73 | img = Image.open(os.path.join(root_path, temp[0])) 74 | img = resize_image(img, img_dim) 75 | imagedata.append(np.array(img)) 76 | labels.append(int(temp[1]) - 1) 77 | 78 | imagedata = np.array(imagedata) 79 | labels = np.array(labels) 80 | 81 | # Include teacher logits as well if applicable 82 | if train and distillation: 83 | logit_file = os.path.join(root_path, "teacher_logits.npy") 84 | assert os.path.exists(logit_file), "Teacher logits cannot be found for PACS!" 85 | teacher_logits = np.load(logit_file) 86 | 87 | else: 88 | for class_dir in os.listdir(data_path): 89 | label = classes[class_dir] 90 | path = os.path.join(data_path, class_dir) 91 | 92 | for img_file in os.listdir(path): 93 | if img_file.endswith("jpg") or img_file.endswith("png"): 94 | img_path = os.path.join(path, img_file) 95 | img = Image.open(img_path) 96 | img = resize_image(img, img_dim) 97 | imagedata.append(np.array(img)) 98 | labels.append(label) 99 | 100 | imagedata = np.array(imagedata) 101 | labels = np.array(labels) 102 | 103 | # Normalize the data 104 | imagedata = preprocess_dataset(imagedata, train=train, img_mean_mode=img_mean_mode) 105 | result = {"images": imagedata, "labels": labels} if train else (imagedata, labels) 106 | 107 | if not teacher_logits is None: 108 | result["teacher_logits"] = teacher_logits 109 | 110 | return result 111 | 112 | def prep_COCO(save_masks_as_image=True, data_dir="../../datasets"): 113 | data_path = os.path.join(os.path.dirname(__file__), data_dir, 'COCO') 114 | classes = {"airplane": 0, "bicycle": 1, "bus": 2, "car": 3, "horse": 4, "knife": 5, "motorcycle": 6, 115 | "skateboard": 7, "train": 8, "truck": 9} 116 | object_scene_ratio_lower_threshold = 0.1 117 | object_scene_ratio_upper_threshold = 1.0 118 | class_names = [""] * 10 119 | class_ids = [0] * 10 120 | img_dim = (224, 224) 121 | 122 | # Prepare training data 123 | train_data_path = os.path.join(data_path, "annotations", "instances_train2017.json") 124 | coco = COCO(train_data_path) 125 | 126 | catIds = coco.getCatIds() 127 | cats = coco.loadCats(catIds) 128 | for cat in cats: 129 | cat_name = cat["name"] 130 | if cat_name in classes: 131 | class_names[classes[cat_name]] = cat_name 132 | class_ids[classes[cat_name]] = cat["id"] 133 | 134 | total_img_count = 0 135 | landing_dir = os.path.join(data_path, 
"downloads", "train2017") 136 | target_dir = os.path.join(data_path, "train2017") 137 | if not os.path.exists(target_dir): 138 | os.makedirs(target_dir) 139 | 140 | for i in range(len(class_ids)): 141 | class_id = class_ids[i] 142 | class_name = class_names[i] 143 | target_class_dir = os.path.join(target_dir, class_name) 144 | 145 | if not os.path.exists(target_class_dir): 146 | os.makedirs(target_class_dir) 147 | 148 | img_ids = coco.getImgIds(catIds=class_id) 149 | for img_id in tqdm(img_ids): 150 | img_info = coco.loadImgs(img_id) 151 | assert len(img_info) == 1, "Image retrieval problem in COCO training set!" 152 | img_info = img_info[0] 153 | 154 | ann_id = coco.getAnnIds(imgIds=img_id, catIds=class_id) 155 | anns = coco.loadAnns(ann_id) 156 | 157 | # Generate binary mask 158 | mask = np.zeros((img_info['height'], img_info['width'])) 159 | for j in range(len(anns)): 160 | mask = np.maximum(coco.annToMask(anns[j]), mask) 161 | 162 | if object_scene_ratio_lower_threshold < ( 163 | np.sum(mask) / mask.size) <= object_scene_ratio_upper_threshold: 164 | total_img_count += 1 165 | 166 | # Copy relevant image to dest and save its corresponding binary mask 167 | source_path = os.path.join(landing_dir, img_info["file_name"]) 168 | assert os.path.exists(source_path), "Image is not found in the source path!" 169 | dest_path = os.path.join(target_class_dir, img_info["file_name"]) 170 | img = Image.open(source_path) 171 | img = resize_image(img, img_dim) 172 | img.save(dest_path) 173 | 174 | if save_masks_as_image: 175 | mask_img_path = os.path.join(target_class_dir, "%s_mask.jpg" % img_info["file_name"].split(".jpg")[0]) 176 | mask_img = np.array(mask * 255, dtype=np.uint8) 177 | mask_img = Image.fromarray(mask_img) 178 | mask_img = resize_image(mask_img, img_dim) 179 | mask_img.save(mask_img_path) 180 | else: 181 | mask_path = os.path.join(target_class_dir, img_info["file_name"].replace("jpg", "npy")) 182 | np.save(mask_path, mask) 183 | log("%s COCO training images are prepared." % total_img_count) 184 | 185 | # Prepare validation data 186 | val_data_path = os.path.join(data_path, "annotations", "instances_val2017.json") 187 | coco = COCO(val_data_path) 188 | 189 | total_img_count = 0 190 | landing_dir = os.path.join(data_path, "downloads", "val2017") 191 | target_dir = os.path.join(data_path, "val2017") 192 | if not os.path.exists(target_dir): 193 | os.makedirs(target_dir) 194 | 195 | for i in range(len(class_ids)): 196 | class_id = class_ids[i] 197 | class_name = class_names[i] 198 | target_class_dir = os.path.join(target_dir, class_name) 199 | 200 | if not os.path.exists(target_class_dir): 201 | os.makedirs(target_class_dir) 202 | 203 | img_ids = coco.getImgIds(catIds=class_id) 204 | for img_id in tqdm(img_ids): 205 | img_info = coco.loadImgs(img_id) 206 | assert len(img_info) == 1, "Image retrieval problem in COCO validation set!" 
207 | img_info = img_info[0] 208 | 209 | ann_id = coco.getAnnIds(imgIds=img_id, catIds=class_id) 210 | anns = coco.loadAnns(ann_id) 211 | 212 | # Generate binary mask 213 | mask = np.zeros((img_info['height'], img_info['width'])) 214 | for j in range(len(anns)): 215 | mask = np.maximum(coco.annToMask(anns[j]), mask) 216 | 217 | if object_scene_ratio_lower_threshold < (np.sum(mask) / mask.size) <= object_scene_ratio_upper_threshold: 218 | total_img_count += 1 219 | 220 | # Copy relevant image to dest and save its corresponding binary mask 221 | source_path = os.path.join(landing_dir, img_info["file_name"]) 222 | assert os.path.exists(source_path), "Image is not found in the source path!" 223 | dest_path = os.path.join(target_class_dir, img_info["file_name"]) 224 | img = Image.open(source_path) 225 | img = resize_image(img, img_dim) 226 | img.save(dest_path) 227 | 228 | if save_masks_as_image: 229 | mask_img_path = os.path.join(target_class_dir, "%s_mask.jpg" % img_info["file_name"].split(".jpg")[0]) 230 | mask_img = np.array(mask * 255, dtype=np.uint8) 231 | mask_img = Image.fromarray(mask_img) 232 | mask_img = resize_image(mask_img, img_dim) 233 | mask_img.save(mask_img_path) 234 | else: 235 | mask_path = os.path.join(target_class_dir, img_info["file_name"].replace("jpg", "npy")) 236 | np.save(mask_path, mask) 237 | 238 | log("%s COCO validation images are prepared." % total_img_count) 239 | 240 | def load_COCO(train=True, first_run=False, img_mean_mode="imagenet", distillation=False, data_dir="../../datasets"): 241 | if first_run: 242 | prep_COCO(data_dir=data_dir) 243 | 244 | data_path = os.path.join(os.path.dirname(__file__), data_dir, 'COCO') 245 | classes = {"airplane": 0, "bicycle": 1, "bus": 2, "car": 3, "horse": 4, "knife": 5, "motorcycle": 6, 246 | "skateboard": 7, "train": 8, "truck": 9} 247 | per_class_img_limit = 1000 248 | img_dim = (224, 224) 249 | 250 | if train: 251 | # Training data 252 | x_train = [] 253 | y_train = [] 254 | masks = [] 255 | for class_dir in classes: 256 | label = classes[class_dir] 257 | path = os.path.join(data_path, "train2017", class_dir) 258 | 259 | file_list = [i for i in sorted(os.listdir(path)) if i.endswith("jpg") and "mask" not in i][:per_class_img_limit] 260 | for img_file in file_list: 261 | img_path = os.path.join(path, img_file) 262 | img = Image.open(img_path) 263 | img = resize_image(img, img_dim) 264 | mask_path = os.path.join(path, "%s_mask.jpg" % img_file.split(".jpg")[0]) 265 | mask = Image.open(mask_path).convert('L') 266 | mask = resize_image(mask, img_dim) 267 | x_train.append(np.array(img)) 268 | y_train.append(label) 269 | masks.append(np.array(mask)) 270 | 271 | x_train = np.array(x_train) 272 | y_train = np.array(y_train) 273 | masks = np.array(masks) 274 | 275 | # Normalize the data 276 | x_train = preprocess_dataset(x_train, train=True, img_mean_mode=img_mean_mode) 277 | masks = masks.astype('float32') / 255.0 278 | masks = masks[:, :, :, np.newaxis] 279 | result = {"images": x_train, "labels": y_train, "segmentation_masks": masks} 280 | 281 | if distillation: 282 | logit_file = os.path.join(data_path, "teacher_logits.npy") 283 | assert os.path.exists(logit_file), "Teacher logits cannot be found for COCO!" 
284 |             result["teacher_logits"] = np.load(logit_file)
285 | 
286 |     else:
287 |         # Validation data
288 |         x_val = []
289 |         y_val = []
290 |         for class_dir in classes:
291 |             label = classes[class_dir]
292 |             path = os.path.join(data_path, "val2017", class_dir)
293 | 
294 |             file_list = [i for i in sorted(os.listdir(path)) if i.endswith("jpg") and "mask" not in i]
295 |             for img_file in file_list:
296 |                 img_path = os.path.join(path, img_file)
297 |                 img = Image.open(img_path)
298 |                 img = resize_image(img, img_dim)
299 |                 x_val.append(np.array(img))
300 |                 y_val.append(label)
301 | 
302 |         x_val = np.array(x_val)
303 |         y_val = np.array(y_val)
304 | 
305 |         # Normalize the data
306 |         x_val = preprocess_dataset(x_val, train=False, img_mean_mode=img_mean_mode)
307 |         result = (x_val, y_val)
308 | 
309 |     return result
310 | 
311 | def load_DomainNet(subset, img_mean_mode="imagenet", data_dir="../../datasets"):
312 |     data_path = os.path.join(os.path.dirname(__file__), data_dir, 'DomainNet', subset.lower())
313 |     classes = {"airplane": 0, "bicycle": 1, "bus": 2, "car": 3, "horse": 4, "knife": 5, "motorcycle": 6,
314 |                "skateboard": 7, "train": 8, "truck": 9}
315 |     img_dim = (224, 224)
316 | 
317 |     imagedata = []
318 |     labels = []
319 |     for class_dir in classes:
320 |         label = classes[class_dir]
321 |         path = os.path.join(data_path, class_dir)
322 | 
323 |         for img_file in os.listdir(path):
324 |             if img_file.endswith("jpg") or img_file.endswith("png"):
325 |                 img_path = os.path.join(path, img_file)
326 |                 img = Image.open(img_path)
327 |                 img = resize_image(img, img_dim)
328 |                 imagedata.append(np.array(img))
329 |                 labels.append(label)
330 | 
331 |     imagedata = np.array(imagedata)
332 |     imagedata = preprocess_dataset(imagedata, train=False, img_mean_mode=img_mean_mode)
333 |     labels = np.array(labels)
334 | 
335 |     return imagedata, labels
336 | 
337 | def load_FullDomainNet(subset, train=True, img_mean_mode="imagenet", distillation=False, data_dir="../../datasets"):
338 |     data_path = os.path.join(os.path.dirname(__file__), data_dir, 'FullDomainNet')
339 |     img_dim = (224, 224)
340 |     subset = subset.lower()
341 |     if subset == "real":
342 |         labelfile = os.path.join(data_path, "real_train.txt") if train else os.path.join(data_path, "real_test.txt")
343 |     else:
344 |         labelfile = os.path.join(data_path, "%s.txt" % subset)  # was hard-coded to "fast.txt", an apparent debugging leftover
345 | 
346 |     # Gather image paths and labels
347 |     imagepath = []
348 |     labels = []
349 |     with open(labelfile, "r") as f_label:
350 |         for line in f_label:
351 |             temp = line[:-1].split(" ")
352 |             imagepath.append(temp[0])
353 |             labels.append(int(temp[1]))
354 |     labels = np.array(labels)
355 | 
356 |     imagedata = np.empty([labels.shape[0]] + list(img_dim) + [3], dtype="uint8")
357 |     for i in range(len(labels)):
358 |         img_path = os.path.join(data_path, imagepath[i])
359 |         img = Image.open(img_path)
360 |         img = resize_image(img, img_dim)
361 |         imagedata[i] = np.array(img)
362 | 
363 |     imagedata = preprocess_dataset(imagedata, train=train, img_mean_mode=img_mean_mode)
364 | 
365 |     return {"images": imagedata, "labels": labels} if train else (imagedata, labels)
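# Editor's note: illustrative usage sketch, not part of the original file.
#
#     train_data = load_PACS("photo", train=True)    # {"images": ..., "labels": ...}
#     x_art, y_art = load_PACS("art", train=False)   # held-out target domain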
--------------------------------------------------------------------------------
/preprocessing/image/ACVCGenerator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tools import shuffle_data
4 | from preprocessing.Datasets import normalize_dataset
5 | from PIL import Image as PILImage
6 | from scipy.stats import truncnorm
7 | from preprocessing.image.RandAugmentGenerator import RandAugment
8 | from imagecorruptions import corrupt, get_corruption_names
9 | 
10 | class ACVCGenerator:
11 |     def __init__(self,
12 |                  dataset,
13 |                  batch_size,
14 |                  stage,
15 |                  epochs,
16 |                  corruption_mode,
17 |                  corruption_dist,
18 |                  img_mean_mode=None,
19 |                  rand_aug=False,
20 |                  seed=13,
21 |                  orig_plus_aug=True):
22 |         """
23 |         19 out of 22 corruptions used in this generator are taken from ImageNet-10.C:
24 |         - https://github.com/bethgelab/imagecorruptions
25 | 
26 |         :param dataset: (dict) with keys "images", "labels", and optionally "teacher_logits" / "segmentation_masks"
27 |         :param batch_size: (int) # of inputs in a mini-batch
28 |         :param stage: (str) train | test
29 |         :param epochs: (int) # of full training passes
30 |         :param corruption_mode: (str) required for VisCo experiments
31 |         :param corruption_dist: (str) required to determine the corruption rate
32 |         :param img_mean_mode: (str) use this to revert the image into its original form before augmentation
33 |         :param rand_aug: (bool) enable/disable on-the-fly random data augmentation
34 |         :param seed: (int) seed for input shuffle
35 |         :param orig_plus_aug: (bool) if True, original images will be kept in the batch along with corrupted ones
36 |         """
37 | 
38 |         if stage not in ['train', 'test']:
39 |             raise ValueError('invalid stage!')
40 | 
41 |         # Settings
42 |         self.batch_size = batch_size
43 |         self.stage = stage
44 |         self.epochs = epochs
45 |         self.corruption_mode = corruption_mode
46 |         self.corruption_dist = corruption_dist
47 |         self.img_mean_mode = img_mean_mode
48 |         self.rand_aug = rand_aug
49 |         self.seed = seed
50 |         self.orig_plus_aug = orig_plus_aug
51 | 
52 |         # Preparation
53 |         self.configuration()
54 |         self.load_data(dataset)
55 |         self.random_augmentation = RandAugment(1, 5)
56 |         if self.img_mean_mode is not None:
57 |             self.img_mean = np.load("img_mean.npy")
58 | 
59 |     def configuration(self):
60 |         self.shuffle_count = 1
61 |         self.current_index = 0
62 | 
63 |     def shuffle(self):
64 |         self.image_count = len(self.labels)
65 |         self.current_index = 0
66 |         self.images, self.labels, self.teacher_logits, self.segmentation_masks = shuffle_data(samples=self.images,
67 |                                                                                               labels=self.labels,
68 |                                                                                               teacher_logits=self.teacher_logits,
69 |                                                                                               segmentation_masks=self.segmentation_masks,
70 |                                                                                               seed=self.seed + self.shuffle_count)
71 |         self.shuffle_count += 1
72 | 
73 |     def load_data(self, dataset):
74 |         self.images = dataset["images"]
75 |         self.labels = dataset["labels"]
76 |         self.teacher_logits = dataset["teacher_logits"] if "teacher_logits" in dataset else None
77 |         self.segmentation_masks = dataset["segmentation_masks"] if "segmentation_masks" in dataset else None
78 | 
79 |         self.len_images = len(self.images)
80 |         self.len_labels = len(self.labels)
81 |         assert self.len_images == self.len_labels
82 |         self.image_count = self.len_labels
83 | 
84 |         if self.stage == 'train':
85 |             self.images, self.labels, self.teacher_logits, self.segmentation_masks = shuffle_data(samples=self.images,
86 |                                                                                                   labels=self.labels,
87 |                                                                                                   teacher_logits=self.teacher_logits,
88 |                                                                                                   segmentation_masks=self.segmentation_masks,
89 |                                                                                                   seed=self.seed)
90 | 
91 |     def get_batch_count(self):
92 |         return (self.len_labels // self.batch_size) + 1
93 | 
94 |     def get_truncated_normal(self, mean=0, sd=1, low=0, upp=10):
95 |         return truncnorm((low - mean) / sd, (upp - mean) / sd, loc=mean, scale=sd)
96 | 
97 |     def get_severity(self):
98 |         return np.random.randint(1, 6)
99 | 
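# Editor's note (added): the helpers below implement the Fourier-based custom
# corruptions. draw_circle marks a centered low-frequency disk in the shifted
# spectrum, which high_pass_filter inverts and discards, while constant_amplitude
# and phase_scaling perturb the amplitude and phase spectra, respectively.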
| Output: 107 | np.array of the given shape that is True inside a circle of the given diameter around the center 108 | """ 109 | assert len(shape) == 2 110 | TF = np.zeros(shape, dtype="bool") 111 | center = np.array(TF.shape) / 2.0 112 | 113 | for iy in range(shape[0]): 114 | for ix in range(shape[1]): 115 | TF[iy, ix] = (iy - center[0]) ** 2 + (ix - center[1]) ** 2 < diamiter ** 2 116 | return TF 117 | 118 | def filter_circle(self, TFcircle, fft_img_channel): 119 | temp = np.zeros(fft_img_channel.shape[:2], dtype=complex) 120 | temp[TFcircle] = fft_img_channel[TFcircle] 121 | return temp 122 | 123 | def inv_FFT_all_channel(self, fft_img): 124 | img_reco = [] 125 | for ichannel in range(fft_img.shape[2]): 126 | img_reco.append(np.fft.ifft2(np.fft.ifftshift(fft_img[:, :, ichannel]))) 127 | img_reco = np.array(img_reco) 128 | img_reco = np.transpose(img_reco, (1, 2, 0)) 129 | return img_reco 130 | 131 | def high_pass_filter(self, x, severity): 132 | x = x.astype("float32") / 255. 133 | c = [.01, .02, .03, .04, .05][severity - 1] 134 | 135 | d = int(c * x.shape[0]) 136 | TFcircle = self.draw_cicle(shape=x.shape[:2], diamiter=d) 137 | TFcircle = ~TFcircle 138 | 139 | fft_img = np.zeros_like(x, dtype=complex) 140 | for ichannel in range(fft_img.shape[2]): 141 | fft_img[:, :, ichannel] = np.fft.fftshift(np.fft.fft2(x[:, :, ichannel])) 142 | 143 | # Apply the filter to each channel 144 | fft_img_filtered = [] 145 | for ichannel in range(fft_img.shape[2]): 146 | fft_img_channel = fft_img[:, :, ichannel] 147 | temp = self.filter_circle(TFcircle, fft_img_channel) 148 | fft_img_filtered.append(temp) 149 | fft_img_filtered = np.array(fft_img_filtered) 150 | fft_img_filtered = np.transpose(fft_img_filtered, (1, 2, 0)) 151 | x = np.clip(np.abs(self.inv_FFT_all_channel(fft_img_filtered)), a_min=0, a_max=1) 152 | 153 | x = PILImage.fromarray((x * 255.).astype("uint8")) 154 | return x 155 | 156 | def constant_amplitude(self, x, severity): 157 | """ 158 | A visual corruption based on amplitude information of a Fourier-transformed image 159 | 160 | Adopted from: https://github.com/MediaBrain-SJTU/FACT 161 | """ 162 | x = x.astype("float32") / 255. 163 | c = [.05, .1, .15, .2, .25][severity - 1] 164 | 165 | # FFT 166 | x_fft = np.fft.fft2(x, axes=(0, 1)) 167 | x_abs, x_pha = np.fft.fftshift(np.abs(x_fft), axes=(0, 1)), np.angle(x_fft) 168 | 169 | # Amplitude replacement 170 | beta = 1.0 - c 171 | x_abs = np.ones_like(x_abs) * max(0, beta) 172 | 173 | # Inverse FFT 174 | x_abs = np.fft.ifftshift(x_abs, axes=(0, 1)) 175 | x = x_abs * (np.e ** (1j * x_pha)) 176 | x = np.real(np.fft.ifft2(x, axes=(0, 1))) 177 | 178 | x = PILImage.fromarray((x * 255.).astype("uint8")) 179 | return x 180 | 181 | def phase_scaling(self, x, severity): 182 | """ 183 | A visual corruption based on phase information of a Fourier-transformed image 184 | 185 | Adopted from: https://github.com/MediaBrain-SJTU/FACT 186 | """ 187 | x = x.astype("float32") / 255.
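# A minimal standalone sketch of the phase-scaling idea implemented below,
# assuming a float HxWxC array `img` in [0, 1] (illustrative only, not code
# this repo calls):
#
#   f = np.fft.fft2(img, axes=(0, 1))
#   flat = np.real(np.fft.ifft2(np.abs(f) * np.exp(1j * alpha * np.angle(f)), axes=(0, 1)))
#
# With alpha = 1 - c as below, the Fourier amplitude is kept while the phase
# is scaled toward zero, so higher severity removes more structural information.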
188 | c = [.1, .2, .3, .4, .5][severity - 1] 189 | 190 | # FFT 191 | x_fft = np.fft.fft2(x, axes=(0, 1)) 192 | x_abs, x_pha = np.fft.fftshift(np.abs(x_fft), axes=(0, 1)), np.angle(x_fft) 193 | 194 | # Phase scaling 195 | alpha = 1.0 - c 196 | x_pha = x_pha * max(0, alpha) 197 | 198 | # Inverse FFT 199 | x_abs = np.fft.ifftshift(x_abs, axes=(0, 1)) 200 | x = x_abs * (np.e ** (1j * x_pha)) 201 | x = np.real(np.fft.ifft2(x, axes=(0, 1))) 202 | 203 | x = PILImage.fromarray((x * 255.).astype("uint8")) 204 | return x 205 | 206 | def apply_corruption(self, x, corruption_name): 207 | severity = self.get_severity() 208 | 209 | custom_corruptions = {"high_pass_filter": self.high_pass_filter, 210 | "constant_amplitude": self.constant_amplitude, 211 | "phase_scaling": self.phase_scaling} 212 | 213 | if corruption_name in get_corruption_names('all'): 214 | x = corrupt(x, corruption_name=corruption_name, severity=severity) 215 | x = PILImage.fromarray(x) 216 | 217 | elif corruption_name in custom_corruptions: 218 | x = custom_corruptions[corruption_name](x, severity=severity) 219 | 220 | else: 221 | raise ValueError("%s is not a supported corruption!" % corruption_name) 222 | 223 | return x 224 | 225 | def acvc(self, x): 226 | i = np.random.randint(0, 22) 227 | corruption_func = {0: "fog", 228 | 1: "snow", 229 | 2: "frost", 230 | 3: "spatter", 231 | 4: "zoom_blur", 232 | 5: "defocus_blur", 233 | 6: "glass_blur", 234 | 7: "gaussian_blur", 235 | 8: "motion_blur", 236 | 9: "speckle_noise", 237 | 10: "shot_noise", 238 | 11: "impulse_noise", 239 | 12: "gaussian_noise", 240 | 13: "jpeg_compression", 241 | 14: "pixelate", 242 | 15: "elastic_transform", 243 | 16: "brightness", 244 | 17: "saturate", 245 | 18: "contrast", 246 | 19: "high_pass_filter", 247 | 20: "constant_amplitude", 248 | 21: "phase_scaling"} 249 | return self.apply_corruption(x, corruption_func[i]) 250 | 251 | def corruption(self, x, segmentation_mask=None): 252 | if self.rand_aug and np.random.uniform(0, 1) > 0.5: 253 | x_ = self.random_augmentation(PILImage.fromarray(x)) 254 | 255 | else: 256 | x_ = np.copy(x) 257 | x_ = self.acvc(x_) 258 | 259 | return x_ 260 | 261 | def get_batch(self, epoch=None): 262 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2]) 263 | teacher_logits = None if self.teacher_logits is None else [] 264 | 265 | if self.orig_plus_aug: 266 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:])) 267 | images = torch.zeros(tensor_shape, dtype=torch.float32) 268 | augmented_images = torch.zeros(tensor_shape, dtype=torch.float32) 269 | for i in range(self.batch_size): 270 | # Avoid overflow 271 | if self.current_index > self.image_count - 1: 272 | if self.stage == "train": 273 | self.shuffle() 274 | else: 275 | self.current_index = 0 276 | 277 | x = self.images[self.current_index] 278 | y = self.labels[self.current_index] 279 | mask = None if self.segmentation_masks is None else self.segmentation_masks[self.current_index] 280 | images[i] = normalize_dataset(PILImage.fromarray(x), img_mean_mode=self.img_mean_mode) 281 | labels[i] = y 282 | 283 | augmented_x = self.corruption(x, mask) 284 | augmented_images[i] = normalize_dataset(augmented_x, img_mean_mode=self.img_mean_mode) 285 | 286 | if teacher_logits is not None: 287 | teacher_logits.append(self.teacher_logits[self.current_index]) 288 | 289 | self.current_index += 1 290 | 291 | # Include teacher logits as soft labels if applicable 292 | if teacher_logits is not None: 293 | labels = [labels,
np.array(teacher_logits)] 294 | 295 | batches = [(images, labels), (augmented_images, labels)] 296 | 297 | else: 298 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:])) 299 | images = torch.zeros(tensor_shape, dtype=torch.float32) 300 | for i in range(self.batch_size): 301 | # Avoid overflow 302 | if self.current_index > self.image_count - 1: 303 | if self.stage == "train": 304 | self.shuffle() 305 | else: 306 | self.current_index = 0 307 | 308 | x = self.images[self.current_index] 309 | y = self.labels[self.current_index] 310 | mask = None if self.segmentation_masks is None else self.segmentation_masks[self.current_index] 311 | 312 | augmented_x = self.corruption(x, mask) 313 | images[i] = normalize_dataset(augmented_x, img_mean_mode=self.img_mean_mode) 314 | labels[i] = y 315 | 316 | if teacher_logits is not None: 317 | teacher_logits.append(self.teacher_logits[self.current_index]) 318 | 319 | self.current_index += 1 320 | 321 | # Include teacher logits as soft labels if applicable 322 | if teacher_logits is not None: 323 | labels = [labels, np.array(teacher_logits)] 324 | 325 | batches = [(images, labels)] 326 | 327 | return batches -------------------------------------------------------------------------------- /preprocessing/image/AblationGenerator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tools import shuffle_data 4 | from preprocessing.Datasets import normalize_dataset 5 | from PIL import Image as PILImage 6 | from scipy.stats import truncnorm 7 | from preprocessing.image.RandAugmentGenerator import RandAugment 8 | from imagecorruptions import corrupt, get_corruption_names 9 | 10 | class AblationGenerator: 11 | def __init__(self, 12 | dataset, 13 | batch_size, 14 | stage, 15 | epochs, 16 | corruption_mode, 17 | corruption_dist, 18 | img_mean_mode=None, 19 | rand_aug=False, 20 | seed=13, 21 | orig_plus_aug=True): 22 | """ 23 | 19 out of 22 corruptions used in this generator are taken from ImageNet-C: 24 | - https://github.com/bethgelab/imagecorruptions 25 | 26 | :param dataset: (tuple) x, y, segmentation mask (optional) 27 | :param batch_size: (int) # of inputs in a mini-batch 28 | :param stage: (str) train | test 29 | :param epochs: (int) # of full training passes 30 | :param corruption_mode: (str) required for VisCo experiments 31 | :param corruption_dist: (str) required to determine the corruption rate 32 | :param img_mean_mode: (str) use this to revert the image into its original form before augmentation 33 | :param rand_aug: (bool) enable/disable on-the-fly random data augmentation 34 | :param seed: (int) seed for input shuffle 35 | :param orig_plus_aug: (bool) if True, original images will be kept in the batch along with corrupted ones 36 | """ 37 | 38 | if stage not in ['train', 'test']: 39 | raise ValueError('invalid stage!') 40 | 41 | # Settings 42 | self.batch_size = batch_size 43 | self.stage = stage 44 | self.epochs = epochs 45 | self.corruption_mode = corruption_mode 46 | self.corruption_dist = corruption_dist 47 | self.img_mean_mode = img_mean_mode 48 | self.rand_aug = rand_aug 49 | self.seed = seed 50 | self.orig_plus_aug = orig_plus_aug 51 | 52 | # Preparation 53 | self.configuration() 54 | self.load_data(dataset) 55 | self.random_augmentation = RandAugment(1, 5) 56 | if self.img_mean_mode is not None: 57 | self.img_mean = np.load("img_mean.npy") 58 | 59 | def configuration(self): 60 | self.shuffle_count = 1 61 | self.current_index = 0 62 | 63 | def
shuffle(self): 64 | self.image_count = len(self.labels) 65 | self.current_index = 0 66 | self.images, self.labels, self.teacher_logits, self.segmentation_masks = shuffle_data(samples=self.images, 67 | labels=self.labels, 68 | teacher_logits=self.teacher_logits, 69 | segmentation_masks=self.segmentation_masks, 70 | seed=self.seed + self.shuffle_count) 71 | self.shuffle_count += 1 72 | 73 | def load_data(self, dataset): 74 | self.images = dataset["images"] 75 | self.labels = dataset["labels"] 76 | self.teacher_logits = dataset["teacher_logits"] if "teacher_logits" in dataset else None 77 | self.segmentation_masks = dataset["segmentation_masks"] if "segmentation_masks" in dataset else None 78 | 79 | self.len_images = len(self.images) 80 | self.len_labels = len(self.labels) 81 | assert self.len_images == self.len_labels 82 | self.image_count = self.len_labels 83 | 84 | if self.stage == 'train': 85 | self.images, self.labels, self.teacher_logits, self.segmentation_masks = shuffle_data(samples=self.images, 86 | labels=self.labels, 87 | teacher_logits=self.teacher_logits, 88 | segmentation_masks=self.segmentation_masks, 89 | seed=self.seed) 90 | 91 | def get_batch_count(self): 92 | return (self.len_labels // self.batch_size) + 1 93 | 94 | def get_truncated_normal(self, mean=0, sd=1, low=0, upp=10): 95 | return truncnorm((low - mean) / sd, (upp - mean) / sd, loc=mean, scale=sd) 96 | 97 | def get_severity(self): 98 | return 1  # severity fixed at 1; random draw np.random.randint(1, 6) disabled 99 | 100 | def get_random_boundingbox(self, img, l_param): 101 | width = img.shape[0] 102 | height = img.shape[1] 103 | 104 | r_x = np.random.randint(width) 105 | r_y = np.random.randint(height) 106 | 107 | r_l = np.sqrt(1 - l_param) 108 | r_w = int(width * r_l) 109 | r_h = int(height * r_l) 110 | 111 | if r_x + r_w < width: 112 | bbox_x1 = r_x 113 | bbox_x2 = r_w 114 | else: 115 | bbox_x1 = width - r_w 116 | bbox_x2 = width 117 | if r_y + r_h < height: 118 | bbox_y1 = r_y 119 | bbox_y2 = r_h 120 | else: 121 | bbox_y1 = height - r_h 122 | bbox_y2 = height 123 | 124 | return bbox_x1, bbox_y1, bbox_x2, bbox_y2 125 | 126 | def draw_cicle(self, shape, diamiter): 127 | """ 128 | Input: 129 | shape : tuple (height, width) 130 | diameter : scalar 131 | 132 | Output: 133 | np.array of the given shape that is True inside a circle of the given diameter around the center 134 | """ 135 | assert len(shape) == 2 136 | TF = np.zeros(shape, dtype="bool") 137 | center = np.array(TF.shape) / 2.0 138 | 139 | for iy in range(shape[0]): 140 | for ix in range(shape[1]): 141 | TF[iy, ix] = (iy - center[0]) ** 2 + (ix - center[1]) ** 2 < diamiter ** 2 142 | return TF 143 | 144 | def filter_circle(self, TFcircle, fft_img_channel): 145 | temp = np.zeros(fft_img_channel.shape[:2], dtype=complex) 146 | temp[TFcircle] = fft_img_channel[TFcircle] 147 | return temp 148 | 149 | def inv_FFT_all_channel(self, fft_img): 150 | img_reco = [] 151 | for ichannel in range(fft_img.shape[2]): 152 | img_reco.append(np.fft.ifft2(np.fft.ifftshift(fft_img[:, :, ichannel]))) 153 | img_reco = np.array(img_reco) 154 | img_reco = np.transpose(img_reco, (1, 2, 0)) 155 | return img_reco 156 | 157 | def high_pass_filter(self, x, severity): 158 | x = x.astype("float32") / 255.
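# The high-pass filter below keeps only frequencies outside a small centered
# disk of the shifted spectrum; the severity constant c sets the disk size as
# a fraction of image height. A rough vectorized equivalent of the mask built
# by draw_cicle/filter_circle (illustrative sketch only):
#
#   h, w = x.shape[:2]
#   yy, xx = np.ogrid[:h, :w]
#   keep = (yy - h / 2.0) ** 2 + (xx - w / 2.0) ** 2 >= (c * h) ** 2
#
# `keep` plays the role of ~TFcircle without the per-pixel Python loops.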
159 | c = [.01, .02, .03, .04, .05][severity - 1] 160 | 161 | d = int(c * x.shape[0]) 162 | TFcircle = self.draw_cicle(shape=x.shape[:2], diamiter=d) 163 | TFcircle = ~TFcircle 164 | 165 | fft_img = np.zeros_like(x, dtype=complex) 166 | for ichannel in range(fft_img.shape[2]): 167 | fft_img[:, :, ichannel] = np.fft.fftshift(np.fft.fft2(x[:, :, ichannel])) 168 | 169 | # Apply the filter to each channel 170 | fft_img_filtered = [] 171 | for ichannel in range(fft_img.shape[2]): 172 | fft_img_channel = fft_img[:, :, ichannel] 173 | temp = self.filter_circle(TFcircle, fft_img_channel) 174 | fft_img_filtered.append(temp) 175 | fft_img_filtered = np.array(fft_img_filtered) 176 | fft_img_filtered = np.transpose(fft_img_filtered, (1, 2, 0)) 177 | x = np.clip(np.abs(self.inv_FFT_all_channel(fft_img_filtered)), a_min=0, a_max=1) 178 | 179 | x = PILImage.fromarray((x * 255.).astype("uint8")) 180 | return x 181 | 182 | def constant_amplitude(self, x, severity): 183 | """ 184 | A visual corruption based on amplitude information of a Fourier-transformed image 185 | 186 | Adopted from: https://github.com/MediaBrain-SJTU/FACT 187 | """ 188 | x = x.astype("float32") / 255. 189 | c = [.05, .1, .15, .2, .25][severity - 1] 190 | 191 | # FFT 192 | x_fft = np.fft.fft2(x, axes=(0, 1)) 193 | x_abs, x_pha = np.fft.fftshift(np.abs(x_fft), axes=(0, 1)), np.angle(x_fft) 194 | 195 | # Amplitude replacement 196 | beta = 1.0 - c 197 | x_abs = np.ones_like(x_abs) * max(0, beta) 198 | 199 | # Inverse FFT 200 | x_abs = np.fft.ifftshift(x_abs, axes=(0, 1)) 201 | x = x_abs * (np.e ** (1j * x_pha)) 202 | x = np.real(np.fft.ifft2(x, axes=(0, 1))) 203 | 204 | x = PILImage.fromarray((x * 255.).astype("uint8")) 205 | return x 206 | 207 | def phase_scaling(self, x, severity): 208 | """ 209 | A visual corruption based on phase information of a Fourier-transformed image 210 | 211 | Adopted from: https://github.com/MediaBrain-SJTU/FACT 212 | """ 213 | x = x.astype("float32") / 255. 214 | c = [.1, .2, .3, .4, .5][severity - 1] 215 | 216 | # FFT 217 | x_fft = np.fft.fft2(x, axes=(0, 1)) 218 | x_abs, x_pha = np.fft.fftshift(np.abs(x_fft), axes=(0, 1)), np.angle(x_fft) 219 | 220 | # Phase scaling 221 | alpha = 1.0 - c 222 | x_pha = x_pha * max(0, alpha) 223 | 224 | # Inverse FFT 225 | x_abs = np.fft.ifftshift(x_abs, axes=(0, 1)) 226 | x = x_abs * (np.e ** (1j * x_pha)) 227 | x = np.real(np.fft.ifft2(x, axes=(0, 1))) 228 | 229 | x = PILImage.fromarray((x * 255.).astype("uint8")) 230 | return x 231 | 232 | def apply_corruption(self, x, corruption_name): 233 | severity = self.get_severity() 234 | 235 | custom_corruptions = {"high_pass_filter": self.high_pass_filter, 236 | "constant_amplitude": self.constant_amplitude, 237 | "phase_scaling": self.phase_scaling} 238 | 239 | if corruption_name in get_corruption_names('all'): 240 | x = corrupt(x, corruption_name=corruption_name, severity=severity) 241 | x = PILImage.fromarray(x) 242 | 243 | elif corruption_name in custom_corruptions: 244 | x = custom_corruptions[corruption_name](x, severity=severity) 245 | 246 | else: 247 | raise ValueError("%s is not a supported corruption!" % corruption_name) 248 | 249 | return x 250 | 251 | def weather(self, x): 252 | i = np.random.randint(0, 4) 253 | corruption_func = {0: "fog", 254 | 1: "snow", 255 | 2: "frost", 256 | 3: "spatter"} 257 | return self.apply_corruption(x, corruption_func[i]) 258 | 259 | def blur(self, x): 260 | i = np.random.randint(0, 5) 261 | corruption_func = {0: "zoom_blur", 262 | 1: "defocus_blur", 263 | 2: "glass_blur", 264 | 3: "gaussian_blur", 265 | 4: "motion_blur"} 266 | return self.apply_corruption(x, corruption_func[i]) 267 | 268 | def noise(self, x): 269 | i = np.random.randint(0, 4) 270 | corruption_func = {0: "speckle_noise", 271 | 1: "shot_noise", 272 | 2: "impulse_noise", 273 | 3: "gaussian_noise"} 274 | return self.apply_corruption(x, corruption_func[i]) 275 | 276 | def digital(self, x): 277 | i = np.random.randint(0, 6) 278 | corruption_func = {0: "jpeg_compression", 279 | 1: "pixelate", 280 | 2: "elastic_transform", 281 | 3: "brightness", 282 | 4: "saturate", 283 | 5: "contrast"} 284 | return self.apply_corruption(x, corruption_func[i]) 285 | 286 | def fourier(self, x): 287 | i = np.random.randint(0, 3) 288 | corruption_func = {0: "high_pass_filter", 289 | 1: "constant_amplitude", 290 | 2: "phase_scaling"} 291 | return self.apply_corruption(x, corruption_func[i]) 292 | 293 | def all(self, x): 294 | i = np.random.randint(0, 22) if self.corruption_mode == "all++" else np.random.randint(0, 19) 295 | corruption_func = {0: "fog", 296 | 1: "snow", 297 | 2: "frost", 298 | 3: "zoom_blur", 299 | 4: "defocus_blur", 300 | 5: "glass_blur", 301 | 6: "gaussian_blur", 302 | 7: "motion_blur", 303 | 8: "speckle_noise", 304 | 9: "shot_noise", 305 | 10: "impulse_noise", 306 | 11: "gaussian_noise", 307 | 12: "jpeg_compression", 308 | 13: "pixelate", 309 | 14: "spatter", 310 | 15: "elastic_transform", 311 | 16: "brightness", 312 | 17: "saturate", 313 | 18: "contrast", 314 | 19: "high_pass_filter", 315 | 20: "constant_amplitude", 316 | 21: "phase_scaling"} 317 | return self.apply_corruption(x, corruption_func[i]) 318 | 319 | def corruption(self, x, segmentation_mask=None): 320 | if self.rand_aug and np.random.uniform(0, 1) > 0.5: 321 | x_ = self.random_augmentation(PILImage.fromarray(x)) 322 | 323 | else: 324 | super_funcs = {"weather": self.weather, 325 | "blur": self.blur, 326 | "noise": self.noise, 327 | "digital": self.digital, 328 | "all": self.all, 329 | "all++": self.all} 330 | x_ = np.copy(x) 331 | corruption_mode = self.corruption_mode.split("_")[0] 332 | 333 | # Apply corruption 334 | if corruption_mode in super_funcs: 335 | x_ = super_funcs[corruption_mode](x_) 336 | else: 337 | x_ = self.apply_corruption(x, corruption_mode) 338 | 339 | return x_ 340 | 341 | def get_batch(self, epoch=None): 342 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2]) 343 | teacher_logits = None if self.teacher_logits is None else [] 344 | 345 | if self.orig_plus_aug: 346 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:])) 347 | images = torch.zeros(tensor_shape, dtype=torch.float32) 348 | augmented_images = torch.zeros(tensor_shape, dtype=torch.float32) 349 | for i in range(self.batch_size): 350 | # Avoid overflow 351 | if self.current_index > self.image_count - 1: 352 | if self.stage == "train": 353 | self.shuffle() 354 | else: 355 | self.current_index = 0 356 | 357 | x = self.images[self.current_index] 358 | y = self.labels[self.current_index] 359 | mask = None if self.segmentation_masks is None else self.segmentation_masks[self.current_index] 360
| images[i] = normalize_dataset(PILImage.fromarray(x), img_mean_mode=self.img_mean_mode) 361 | labels[i] = y 362 | 363 | augmented_x = self.corruption(x, mask) 364 | augmented_images[i] = normalize_dataset(augmented_x, img_mean_mode=self.img_mean_mode) 365 | 366 | if teacher_logits is not None: 367 | teacher_logits.append(self.teacher_logits[self.current_index]) 368 | 369 | self.current_index += 1 370 | 371 | # Include teacher logits as soft labels if applicable 372 | if teacher_logits is not None: 373 | labels = [labels, np.array(teacher_logits)] 374 | 375 | batches = [(images, labels), (augmented_images, labels)] 376 | 377 | else: 378 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:])) 379 | images = torch.zeros(tensor_shape, dtype=torch.float32) 380 | for i in range(self.batch_size): 381 | # Avoid over flow 382 | if self.current_index > self.image_count - 1: 383 | if self.stage == "train": 384 | self.shuffle() 385 | else: 386 | self.current_index = 0 387 | 388 | x = self.images[self.current_index] 389 | y = self.labels[self.current_index] 390 | mask = None if self.segmentation_masks is None else self.segmentation_masks[self.current_index] 391 | 392 | augmented_x = self.corruption(x, mask) 393 | images[i] = normalize_dataset(augmented_x, img_mean_mode=self.img_mean_mode) 394 | labels[i] = y 395 | 396 | if teacher_logits is not None: 397 | teacher_logits.append(self.teacher_logits[self.current_index]) 398 | 399 | self.current_index += 1 400 | 401 | # Include teacher logits as soft labels if applicable 402 | if teacher_logits is not None: 403 | labels = [labels, np.array(teacher_logits)] 404 | 405 | batches = [(images, labels)] 406 | 407 | return batches 408 | 409 | def test(self, x, img_id): 410 | 411 | for f in get_corruption_names("all"): 412 | x_ = np.copy(x) 413 | x_ = self.apply_corruption(x_, f) 414 | x_.save("example_%d_%s.jpeg" % (img_id, f)) 415 | print("%s testing is done." % f) 416 | 417 | for f in ["high_pass_filter", "constant_amplitude", "phase_scaling"]: 418 | x_ = np.copy(x) 419 | x_ = self.apply_corruption(x_, f) 420 | x_.save("example_%d_%s.jpeg" % (img_id, f)) 421 | print("%s testing is done." % f) 422 | 423 | exit(1) -------------------------------------------------------------------------------- /preprocessing/image/AugMixGenerator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image, ImageOps, ImageEnhance 4 | from tools import shuffle_data 5 | from preprocessing.Datasets import normalize_dataset 6 | 7 | # Copyright 2019 Google LLC 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # https://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | # ============================================================================== 21 | """Base augmentations operators.""" 22 | 23 | # ImageNet code should change this value 24 | IMAGE_SIZE = 224 25 | 26 | 27 | def int_parameter(level, maxval): 28 | """Helper function to scale `val` between 0 and maxval . 
29 | 30 | Args: 31 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 32 | maxval: Maximum value that the operation can have. This will be scaled to 33 | level/PARAMETER_MAX. 34 | 35 | Returns: 36 | An int that results from scaling `maxval` according to `level`. 37 | """ 38 | return int(level * maxval / 10) 39 | 40 | 41 | def float_parameter(level, maxval): 42 | """Helper function to scale `val` between 0 and maxval. 43 | 44 | Args: 45 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 46 | maxval: Maximum value that the operation can have. This will be scaled to 47 | level/PARAMETER_MAX. 48 | 49 | Returns: 50 | A float that results from scaling `maxval` according to `level`. 51 | """ 52 | return float(level) * maxval / 10. 53 | 54 | 55 | def sample_level(n): 56 | return np.random.uniform(low=0.1, high=n) 57 | 58 | 59 | def autocontrast(pil_img, _): 60 | return ImageOps.autocontrast(pil_img) 61 | 62 | 63 | def equalize(pil_img, _): 64 | return ImageOps.equalize(pil_img) 65 | 66 | 67 | def posterize(pil_img, level): 68 | level = int_parameter(sample_level(level), 4) 69 | return ImageOps.posterize(pil_img, 4 - level) 70 | 71 | 72 | def rotate(pil_img, level): 73 | degrees = int_parameter(sample_level(level), 30) 74 | if np.random.uniform() > 0.5: 75 | degrees = -degrees 76 | return pil_img.rotate(degrees, resample=Image.BILINEAR) 77 | 78 | 79 | def solarize(pil_img, level): 80 | level = int_parameter(sample_level(level), 256) 81 | return ImageOps.solarize(pil_img, 256 - level) 82 | 83 | 84 | def shear_x(pil_img, level): 85 | level = float_parameter(sample_level(level), 0.3) 86 | if np.random.uniform() > 0.5: 87 | level = -level 88 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 89 | Image.AFFINE, (1, level, 0, 0, 1, 0), 90 | resample=Image.BILINEAR) 91 | 92 | 93 | def shear_y(pil_img, level): 94 | level = float_parameter(sample_level(level), 0.3) 95 | if np.random.uniform() > 0.5: 96 | level = -level 97 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 98 | Image.AFFINE, (1, 0, 0, level, 1, 0), 99 | resample=Image.BILINEAR) 100 | 101 | 102 | def translate_x(pil_img, level): 103 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3) 104 | if np.random.random() > 0.5: 105 | level = -level 106 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 107 | Image.AFFINE, (1, 0, level, 0, 1, 0), 108 | resample=Image.BILINEAR) 109 | 110 | 111 | def translate_y(pil_img, level): 112 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3) 113 | if np.random.random() > 0.5: 114 | level = -level 115 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 116 | Image.AFFINE, (1, 0, 0, 0, 1, level), 117 | resample=Image.BILINEAR) 118 | 119 | 120 | # operation that overlaps with ImageNet-C's test set 121 | def color(pil_img, level): 122 | level = float_parameter(sample_level(level), 1.8) + 0.1 123 | return ImageEnhance.Color(pil_img).enhance(level) 124 | 125 | 126 | # operation that overlaps with ImageNet-C's test set 127 | def contrast(pil_img, level): 128 | level = float_parameter(sample_level(level), 1.8) + 0.1 129 | return ImageEnhance.Contrast(pil_img).enhance(level) 130 | 131 | 132 | # operation that overlaps with ImageNet-C's test set 133 | def brightness(pil_img, level): 134 | level = float_parameter(sample_level(level), 1.8) + 0.1 135 | return ImageEnhance.Brightness(pil_img).enhance(level) 136 | 137 | 138 | # operation that overlaps with ImageNet-C's test set 139 | def sharpness(pil_img, level): 140 | level = 
float_parameter(sample_level(level), 1.8) + 0.1 141 | return ImageEnhance.Sharpness(pil_img).enhance(level) 142 | 143 | augmentations_augmix = [ 144 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, 145 | translate_x, translate_y 146 | ] 147 | 148 | augmentations_augmix_all = [ 149 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, 150 | translate_x, translate_y, color, contrast, brightness, sharpness 151 | ] 152 | 153 | class AugMixGenerator: 154 | def __init__(self, 155 | dataset, 156 | batch_size, 157 | stage, 158 | aug_prob_coeff=1., 159 | mixture_width=3, 160 | mixture_depth=-1, 161 | aug_severity=1, 162 | img_mean_mode=None, 163 | seed=13, 164 | orig_plus_aug=True): 165 | """ 166 | :param dataset: (tuple) x, y, segmentation mask (optional) 167 | :param batch_size: (int) # of inputs in a mini-batch 168 | :param stage: (str) train | test 169 | :param aug_prob_coeff: (float) alpha in the paper 170 | :param aug_severity: (int) from the repository 171 | :param mixture_width: (int) from the paper 172 | :param mixture_depth: (int) from the paper 173 | :param img_mean_mode: (str) use this for image normalization 174 | :param seed: (int) seed for input shuffle 175 | :param orig_plus_aug: (bool) if True, original images will be kept in the batch along with corrupted ones 176 | """ 177 | 178 | if stage not in ['train', 'test']: 179 | raise ValueError('invalid stage!') 180 | 181 | # Settings 182 | self.batch_size = batch_size 183 | self.stage = stage 184 | self.aug_prob_coeff = aug_prob_coeff 185 | self.mixture_width = mixture_width 186 | self.mixture_depth = mixture_depth 187 | self.aug_severity = aug_severity 188 | self.img_mean_mode = img_mean_mode 189 | self.seed = seed 190 | self.orig_plus_aug = orig_plus_aug 191 | 192 | # Preparation 193 | self.configuration() 194 | self.load_data(dataset) 195 | 196 | def configuration(self): 197 | self.shuffle_count = 1 198 | self.current_index = 0 199 | 200 | def shuffle(self): 201 | self.image_count = len(self.labels) 202 | self.current_index = 0 203 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images, 204 | labels=self.labels, 205 | teacher_logits=self.teacher_logits, 206 | seed=self.seed + self.shuffle_count) 207 | self.shuffle_count += 1 208 | 209 | def load_data(self, dataset): 210 | self.images = dataset["images"] 211 | self.labels = dataset["labels"] 212 | self.teacher_logits = dataset["teacher_logits"] if "teacher_logits" in dataset else None 213 | 214 | self.len_images = len(self.images) 215 | self.len_labels = len(self.labels) 216 | assert self.len_images == self.len_labels 217 | self.image_count = self.len_labels 218 | 219 | if self.stage == 'train': 220 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images, 221 | labels=self.labels, 222 | teacher_logits=self.teacher_logits, 223 | seed=self.seed) 224 | 225 | def get_batch_count(self): 226 | return (self.len_labels // self.batch_size) + 1 227 | 228 | def augmix(self, x): 229 | aug_list = augmentations_augmix_all 230 | 231 | ws = np.float32(np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width)) 232 | m = np.float32(np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff)) 233 | 234 | mix = np.zeros_like(np.array(x)).astype("float32") 235 | for i in range(self.mixture_width): 236 | x_ = x.copy() 237 | depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4) 238 | for _ in range(depth): 239 | op = np.random.choice(aug_list) 240 | x_ = op(x_,
self.aug_severity) 241 | 242 | # Preprocessing commutes since all coefficients are convex 243 | mix += ws[i] * np.array(x_).astype("float32") 244 | 245 | mix = Image.fromarray(mix.astype("uint8")) 246 | x_ = Image.blend(x, mix, m) 247 | 248 | return x_ 249 | 250 | def augment(self, x): 251 | 252 | x_ = [self.augmix(x) for _ in range(2)] 253 | 254 | return x_ 255 | 256 | def get_batch(self, epoch=None): 257 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2]) 258 | teacher_logits = None if self.teacher_logits is None else [] 259 | expansion_coeff = 2 260 | 261 | if self.orig_plus_aug: 262 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:])) 263 | images = torch.zeros(tensor_shape, dtype=torch.float32) 264 | augmented_images = [torch.zeros(tensor_shape, dtype=torch.float32) for _ in range(expansion_coeff)] 265 | for i in range(self.batch_size): 266 | # Avoid over flow 267 | if self.current_index > self.image_count - 1: 268 | if self.stage == "train": 269 | self.shuffle() 270 | else: 271 | self.current_index = 0 272 | 273 | x = Image.fromarray(self.images[self.current_index]) 274 | y = self.labels[self.current_index] 275 | images[i] = normalize_dataset(x, self.img_mean_mode) 276 | labels[i] = y 277 | 278 | augmented_x = self.augment(x) 279 | for j in range(expansion_coeff): 280 | augmented_images[j][i] = normalize_dataset(augmented_x[j], img_mean_mode=self.img_mean_mode) 281 | 282 | if teacher_logits is not None: 283 | teacher_logits.append(self.teacher_logits[self.current_index]) 284 | 285 | self.current_index += 1 286 | 287 | # Include teacher logits as soft labels if applicable 288 | if teacher_logits is not None: 289 | labels = [labels, np.array(teacher_logits)] 290 | 291 | batches = [(images, labels)] 292 | for i in range(expansion_coeff): 293 | batches.append((augmented_images[i], labels)) 294 | 295 | else: 296 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:])) 297 | augmented_images = [torch.zeros(tensor_shape, dtype=torch.float32) for _ in range(expansion_coeff)] 298 | for i in range(self.batch_size): 299 | # Avoid over flow 300 | if self.current_index > self.image_count - 1: 301 | if self.stage == "train": 302 | self.shuffle() 303 | else: 304 | self.current_index = 0 305 | 306 | x = Image.fromarray(self.images[self.current_index]) 307 | y = self.labels[self.current_index] 308 | labels[i] = y 309 | 310 | augmented_x = self.augment(x) 311 | for j in range(expansion_coeff): 312 | augmented_images[j][i] = normalize_dataset(augmented_x[j], img_mean_mode=self.img_mean_mode) 313 | 314 | if teacher_logits is not None: 315 | teacher_logits.append(self.teacher_logits[self.current_index]) 316 | 317 | self.current_index += 1 318 | 319 | # Include teacher logits as soft labels if applicable 320 | if teacher_logits is not None: 321 | labels = [labels, np.array(teacher_logits)] 322 | 323 | batches = [] 324 | for i in range(expansion_coeff): 325 | batches.append((augmented_images[i], labels)) 326 | 327 | return batches 328 | -------------------------------------------------------------------------------- /preprocessing/image/CutMixGenerator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from tools import shuffle_data 6 | from preprocessing.Datasets import normalize_dataset 7 | 8 | class CutMixGenerator: 9 | def __init__(self, 10 | dataset, 11 | batch_size, 12 | stage, 13 | 
img_mean_mode=None, 14 | seed=13, 15 | orig_plus_aug=True): 16 | """ 17 | :param dataset: (tuple) x, y, segmentation mask (optional) 18 | :param batch_size: (int) # of inputs in a mini-batch 19 | :param stage: (str) train | test 20 | :param img_mean_mode: (str) use this for image normalization 21 | :param seed: (int) seed for input shuffle 22 | :param orig_plus_aug: (bool) if True, original images will be kept in the batch along with corrupted ones 23 | """ 24 | 25 | if stage not in ['train', 'test']: 26 | raise ValueError('invalid stage!') 27 | 28 | # Settings 29 | self.batch_size = batch_size 30 | self.stage = stage 31 | self.img_mean_mode = img_mean_mode 32 | self.seed = seed 33 | self.orig_plus_aug = orig_plus_aug 34 | 35 | # Preparation 36 | self.configuration() 37 | self.load_data(dataset) 38 | 39 | def configuration(self): 40 | self.shuffle_count = 1 41 | self.current_index = 0 42 | 43 | def shuffle(self): 44 | self.image_count = len(self.labels) 45 | self.current_index = 0 46 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images, 47 | labels=self.labels, 48 | teacher_logits=self.teacher_logits, 49 | seed=self.seed + self.shuffle_count) 50 | self.shuffle_count += 1 51 | 52 | def load_data(self, dataset): 53 | self.images = dataset["images"] 54 | self.labels = dataset["labels"] 55 | self.teacher_logits = dataset["teacher_logits"] if "teacher_logits" in dataset else None 56 | 57 | self.len_images = len(self.images) 58 | self.len_labels = len(self.labels) 59 | assert self.len_images == self.len_labels 60 | self.image_count = self.len_labels 61 | 62 | if self.stage == 'train': 63 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images, 64 | labels=self.labels, 65 | teacher_logits=self.teacher_logits, 66 | seed=self.seed) 67 | 68 | def get_batch_count(self): 69 | return (self.len_labels // self.batch_size) + 1 70 | 71 | def get_random_boundingbox(self, img, l_param): 72 | width = img.shape[0] 73 | height = img.shape[1] 74 | 75 | r_x = np.random.randint(width) 76 | r_y = np.random.randint(height) 77 | 78 | r_l = np.sqrt(1 - l_param) 79 | r_w = int(width * r_l) 80 | r_h = int(height * r_l) 81 | 82 | if r_x + r_w < width: 83 | bbox_x1 = r_x 84 | bbox_x2 = r_w 85 | else: 86 | bbox_x1 = width - r_w 87 | bbox_x2 = width 88 | if r_y + r_h < height: 89 | bbox_y1 = r_y 90 | bbox_y2 = r_h 91 | else: 92 | bbox_y1 = height - r_h 93 | bbox_y2 = height 94 | 95 | return bbox_x1, bbox_y1, bbox_x2, bbox_y2 96 | 97 | def cutmix(self, image_batch, label_batch, beta=1.0): 98 | l_param = np.random.beta(beta, beta, self.batch_size) 99 | rand_index = torch.randperm(self.batch_size) 100 | 101 | x = image_batch.detach().clone() 102 | y = np.zeros_like(label_batch) 103 | 104 | for i in range(self.batch_size): 105 | bx1, by1, bx2, by2 = self.get_random_boundingbox(x[i], l_param[i]) 106 | x[i][:, bx1:bx2, by1:by2] = image_batch[rand_index[i]][:, bx1:bx2, by1:by2] 107 | y[i] = label_batch[rand_index[i]] 108 | 109 | # Adjust lambda to exactly match pixel ratio 110 | l_param[i] = 1 - ((bx2 - bx1) * (by2 - by1) / (x.size()[-1] * x.size()[-2])) 111 | 112 | return x, (label_batch, y, l_param) 113 | 114 | def get_batch(self, epoch=None): 115 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2]) 116 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:])) 117 | images = torch.zeros(tensor_shape, dtype=torch.float32) 118 | for i in range(self.batch_size): 119 | # Avoid overflow 120 | if
self.current_index > self.image_count - 1: 121 | if self.stage == "train": 122 | self.shuffle() 123 | else: 124 | self.current_index = 0 125 | 126 | x = self.images[self.current_index] 127 | y = self.labels[self.current_index] 128 | images[i] = normalize_dataset(Image.fromarray(x), img_mean_mode=self.img_mean_mode) 129 | labels[i] = y 130 | 131 | self.current_index += 1 132 | 133 | augmented_images, augmented_labels = self.cutmix(images, labels) 134 | 135 | if self.orig_plus_aug: 136 | batches = [(images, (labels, labels, (np.ones_like(augmented_labels[2])))), (augmented_images, augmented_labels)] 137 | 138 | else: 139 | batches = [(augmented_images, augmented_labels)] 140 | 141 | return batches 142 | -------------------------------------------------------------------------------- /preprocessing/image/CutOutGenerator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | from tools import shuffle_data 5 | from preprocessing.Datasets import normalize_dataset 6 | 7 | class CutOutGenerator: 8 | def __init__(self, 9 | dataset, 10 | batch_size, 11 | stage, 12 | img_mean_mode=None, 13 | seed=13, 14 | orig_plus_aug=True): 15 | """ 16 | :param dataset: (tuple) x, y, segmentation mask (optional) 17 | :param batch_size: (int) # of inputs in a mini-batch 18 | :param stage: (str) train | test 19 | :param img_mean_mode: (str) use this for image normalization 20 | :param seed: (int) seed for input shuffle 21 | :param orig_plus_aug: (bool) if True, original images will be kept in the batch along with corrupted ones 22 | """ 23 | 24 | if stage not in ['train', 'test']: 25 | raise ValueError('invalid stage!') 26 | 27 | # Settings 28 | self.batch_size = batch_size 29 | self.stage = stage 30 | self.img_mean_mode = img_mean_mode 31 | self.seed = seed 32 | self.orig_plus_aug = orig_plus_aug 33 | 34 | # Preparation 35 | self.configuration() 36 | self.load_data(dataset) 37 | 38 | def configuration(self): 39 | self.shuffle_count = 1 40 | self.current_index = 0 41 | 42 | def shuffle(self): 43 | self.image_count = len(self.labels) 44 | self.current_index = 0 45 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images, 46 | labels=self.labels, 47 | teacher_logits=self.teacher_logits, 48 | seed=self.seed + self.shuffle_count) 49 | self.shuffle_count += 1 50 | 51 | def load_data(self, dataset): 52 | self.images = dataset["images"] 53 | self.labels = dataset["labels"] 54 | self.teacher_logits = dataset["teacher_logits"] if "teacher_logits" in dataset else None 55 | 56 | self.len_images = len(self.images) 57 | self.len_labels = len(self.labels) 58 | assert self.len_images == self.len_labels 59 | self.image_count = self.len_labels 60 | 61 | if self.stage == 'train': 62 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images, 63 | labels=self.labels, 64 | teacher_logits=self.teacher_logits, 65 | seed=self.seed) 66 | 67 | def get_batch_count(self): 68 | return (self.len_labels // self.batch_size) + 1 69 | 70 | def get_random_eraser(self, p=0.5, s_l=0.02, s_h=0.4, r_1=0.3, r_2=1 / 0.3, v_l=0.0, v_h=1.0): 71 | """ 72 | This CutOut implementation is taken from: 73 | - https://github.com/yu4u/cutout-random-erasing 74 | 75 | ...and modified for info loss experiments 76 | 77 | # Arguments: 78 | :param p: (float) the probability that random erasing is performed 79 | :param s_l: (float) minimum proportion of erased area against input image 80 | :param s_h: (float) maximum
proportion of erased area against input image 81 | :param r_1: (float) minimum aspect ratio of erased area 82 | :param r_2: (float) maximum aspect ratio of erased area 83 | :param v_l: (float) minimum value for erased area 84 | :param v_h: (float) maximum value for erased area 85 | :param fill: (str) fill-in mode for the cropped area 86 | 87 | :return: (np.array) augmented image 88 | """ 89 | def eraser(orig_img): 90 | input_img = np.copy(orig_img) 91 | if input_img.ndim == 3: 92 | img_h, img_w, img_c = input_img.shape 93 | elif input_img.ndim == 2: 94 | img_h, img_w = input_img.shape 95 | 96 | p_1 = np.random.rand() 97 | 98 | if p_1 > p: 99 | return input_img 100 | 101 | while True: 102 | s = np.random.uniform(s_l, s_h) * img_h * img_w 103 | r = np.random.uniform(r_1, r_2) 104 | w = int(np.sqrt(s / r)) 105 | h = int(np.sqrt(s * r)) 106 | left = np.random.randint(0, img_w) 107 | top = np.random.randint(0, img_h) 108 | 109 | if left + w <= img_w and top + h <= img_h: 110 | break 111 | 112 | input_img[top:top + h, left:left + w] = 0 113 | 114 | return input_img 115 | 116 | return eraser 117 | 118 | def cutout(self, x): 119 | 120 | eraser = self.get_random_eraser() 121 | x_ = eraser(x) 122 | 123 | return x_ 124 | 125 | def get_batch(self, epoch=None): 126 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2]) 127 | 128 | if self.orig_plus_aug: 129 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:])) 130 | images = torch.zeros(tensor_shape, dtype=torch.float32) 131 | augmented_images = torch.zeros(tensor_shape, dtype=torch.float32) 132 | for i in range(self.batch_size): 133 | # Avoid over flow 134 | if self.current_index > self.image_count - 1: 135 | if self.stage == "train": 136 | self.shuffle() 137 | else: 138 | self.current_index = 0 139 | 140 | x = self.images[self.current_index] 141 | y = self.labels[self.current_index] 142 | images[i] = normalize_dataset(Image.fromarray(x), img_mean_mode=self.img_mean_mode) 143 | labels[i] = y 144 | 145 | augmented_x = self.cutout(x) 146 | augmented_images[i] = normalize_dataset(Image.fromarray(augmented_x), img_mean_mode=self.img_mean_mode) 147 | 148 | self.current_index += 1 149 | 150 | batches = [(images, labels), (augmented_images, labels)] 151 | 152 | else: 153 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:])) 154 | images = torch.zeros(tensor_shape, dtype=torch.float32) 155 | for i in range(self.batch_size): 156 | # Avoid over flow 157 | if self.current_index > self.image_count - 1: 158 | if self.stage == "train": 159 | self.shuffle() 160 | else: 161 | self.current_index = 0 162 | 163 | x = self.images[self.current_index] 164 | y = self.labels[self.current_index] 165 | 166 | augmented_x = self.cutout(x) 167 | images[i] = normalize_dataset(Image.fromarray(augmented_x), img_mean_mode=self.img_mean_mode) 168 | labels[i] = y 169 | 170 | self.current_index += 1 171 | 172 | batches = [(images, labels)] 173 | 174 | return batches 175 | -------------------------------------------------------------------------------- /preprocessing/image/ImageGenerator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | from tools import shuffle_data 5 | from preprocessing.Datasets import normalize_dataset 6 | 7 | class ImageGenerator: 8 | def __init__(self, 9 | dataset, 10 | batch_size, 11 | stage, 12 | img_mean_mode=None, 13 | seed=13): 14 | """ 15 | :param dataset: (tuple) x, 
y 16 | :param batch_size: (int) # of inputs in a mini-batch 17 | :param stage: (str) train | test 18 | :param img_mean_mode: (str) use this for image normalization 19 | :param seed: (int) seed for input shuffle 20 | """ 21 | 22 | if stage not in ['train', 'test']: 23 | raise ValueError('invalid stage!') 24 | 25 | # Settings 26 | self.batch_size = batch_size 27 | self.stage = stage 28 | self.img_mean_mode = img_mean_mode 29 | self.seed = seed 30 | 31 | # Preparation 32 | self.configuration() 33 | self.load_data(dataset) 34 | 35 | def configuration(self): 36 | self.shuffle_count = 1 37 | self.current_index = 0 38 | 39 | def shuffle(self): 40 | self.image_count = len(self.labels) 41 | self.current_index = 0 42 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images, 43 | labels=self.labels, 44 | teacher_logits=self.teacher_logits, 45 | seed=self.seed + self.shuffle_count) 46 | self.shuffle_count += 1 47 | 48 | def load_data(self, dataset): 49 | self.images = dataset["images"] if type(dataset) is dict else dataset[0] 50 | self.labels = dataset["labels"] if type(dataset) is dict else dataset[1] 51 | self.teacher_logits = dataset["teacher_logits"] if type(dataset) is dict and "teacher_logits" in dataset else None 52 | 53 | self.len_images = len(self.images) 54 | self.len_labels = len(self.labels) 55 | assert self.len_images == self.len_labels 56 | self.image_count = self.len_labels 57 | 58 | if self.stage == 'train': 59 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images, 60 | labels=self.labels, 61 | teacher_logits=self.teacher_logits, 62 | seed=self.seed) 63 | 64 | def get_batch_count(self): 65 | return (self.len_labels // self.batch_size) + 1 66 | 67 | def get_batch(self, epoch=None): 68 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2]) 69 | teacher_logits = None if self.teacher_logits is None else [] 70 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:])) 71 | images = torch.zeros(tensor_shape, dtype=torch.float32) 72 | for i in range(self.batch_size): 73 | # Avoid overflow 74 | if self.current_index > self.image_count - 1: 75 | if self.stage == "train": 76 | self.shuffle() 77 | else: 78 | self.current_index = 0 79 | 80 | image = Image.fromarray(self.images[self.current_index]) 81 | images[i] = normalize_dataset(image, img_mean_mode=self.img_mean_mode) 82 | labels[i] = self.labels[self.current_index] 83 | 84 | if teacher_logits is not None: 85 | teacher_logits.append(self.teacher_logits[self.current_index]) 86 | 87 | self.current_index += 1 88 | 89 | # Include teacher logits as soft labels if applicable 90 | if teacher_logits is not None: 91 | labels = [labels, np.array(teacher_logits)] 92 | 93 | return [(images, labels)] 94 | -------------------------------------------------------------------------------- /preprocessing/image/MixUpGenerator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from tools import shuffle_data 6 | from preprocessing.Datasets import normalize_dataset 7 | 8 | class MixUpGenerator: 9 | def __init__(self, 10 | dataset, 11 | batch_size, 12 | stage, 13 | img_mean_mode=None, 14 | seed=13, 15 | orig_plus_aug=True): 16 | """ 17 | :param dataset: (tuple) x, y, segmentation mask (optional) 18 | :param batch_size: (int) # of inputs in a mini-batch 19 | :param stage: (str) train | test 20 | :param img_mean_mode: (str)
use this for image normalization 21 | :param seed: (int) seed for input shuffle 22 | :param orig_plus_aug: (bool) if True, original images will be kept in the batch along with corrupted ones 23 | """ 24 | 25 | if stage not in ['train', 'test']: 26 | raise ValueError('invalid stage!') 27 | 28 | # Settings 29 | self.batch_size = batch_size 30 | self.stage = stage 31 | self.img_mean_mode = img_mean_mode 32 | self.seed = seed 33 | self.orig_plus_aug = orig_plus_aug 34 | 35 | # Preparation 36 | self.configuration() 37 | self.load_data(dataset) 38 | 39 | def configuration(self): 40 | self.shuffle_count = 1 41 | self.current_index = 0 42 | 43 | def shuffle(self): 44 | self.image_count = len(self.labels) 45 | self.current_index = 0 46 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images, 47 | labels=self.labels, 48 | teacher_logits=self.teacher_logits, 49 | seed=self.seed + self.shuffle_count) 50 | self.shuffle_count += 1 51 | 52 | def load_data(self, dataset): 53 | self.images = dataset["images"] 54 | self.labels = dataset["labels"] 55 | self.teacher_logits = dataset["teacher_logits"] if "teacher_logits" in dataset else None 56 | 57 | self.len_images = len(self.images) 58 | self.len_labels = len(self.labels) 59 | assert self.len_images == self.len_labels 60 | self.image_count = self.len_labels 61 | 62 | if self.stage == 'train': 63 | self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images, 64 | labels=self.labels, 65 | teacher_logits=self.teacher_logits, 66 | seed=self.seed) 67 | 68 | def get_batch_count(self): 69 | return (self.len_labels // self.batch_size) + 1 70 | 71 | def mixup(self, image_batch, label_batch, beta=1.0): 72 | l_param = np.random.beta(beta, beta, self.batch_size) 73 | rand_index = np.random.permutation(self.batch_size) 74 | 75 | x = image_batch.detach().clone() 76 | y = np.zeros_like(label_batch) 77 | 78 | for i in range(self.batch_size): 79 | # Mix with the untouched source batch so already-mixed rows are not reused 80 | x[i] = l_param[i] * x[i] + (1 - l_param[i]) * image_batch[rand_index[i]] 81 | y[i] = label_batch[rand_index[i]] 82 | 83 | return x, (label_batch, y, l_param) 84 | 85 | def get_batch(self, epoch=None): 86 | tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2]) 87 | labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:])) 88 | images = torch.zeros(tensor_shape, dtype=torch.float32) 89 | for i in range(self.batch_size): 90 | # Avoid overflow 91 | if self.current_index > self.image_count - 1: 92 | if self.stage == "train": 93 | self.shuffle() 94 | else: 95 | self.current_index = 0 96 | 97 | x = self.images[self.current_index] 98 | y = self.labels[self.current_index] 99 | images[i] = normalize_dataset(Image.fromarray(x), img_mean_mode=self.img_mean_mode) 100 | labels[i] = y 101 | 102 | self.current_index += 1 103 | 104 | augmented_images, augmented_labels = self.mixup(images, labels) 105 | 106 | if self.orig_plus_aug: 107 | batches = [(images, (labels, labels, (np.ones_like(labels)))), (augmented_images, augmented_labels)] 108 | 109 | else: 110 | batches = [(augmented_images, augmented_labels)] 111 | 112 | return batches 113 | -------------------------------------------------------------------------------- /preprocessing/image/RandAugmentGenerator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 4 | import numpy as np 5 | from PIL import Image 6 | from tools import shuffle_data 7 | from
preprocessing.Datasets import normalize_dataset 8 | 9 | def ShearX(img, v): # [-0.3, 0.3] 10 | assert -0.3 <= v <= 0.3 11 | if random.random() > 0.5: 12 | v = -v 13 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 14 | 15 | def ShearY(img, v): # [-0.3, 0.3] 16 | assert -0.3 <= v <= 0.3 17 | if random.random() > 0.5: 18 | v = -v 19 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 20 | 21 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 22 | assert -0.45 <= v <= 0.45 23 | if random.random() > 0.5: 24 | v = -v 25 | v = v * img.size[0] 26 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 27 | 28 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 29 | assert 0 <= v 30 | if random.random() > 0.5: 31 | v = -v 32 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 33 | 34 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 35 | assert -0.45 <= v <= 0.45 36 | if random.random() > 0.5: 37 | v = -v 38 | v = v * img.size[1] 39 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 40 | 41 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 42 | assert 0 <= v 43 | if random.random() > 0.5: 44 | v = -v 45 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 46 | 47 | def Rotate(img, v): # [-30, 30] 48 | assert -30 <= v <= 30 49 | if random.random() > 0.5: 50 | v = -v 51 | return img.rotate(v) 52 | 53 | def AutoContrast(img, _): 54 | return PIL.ImageOps.autocontrast(img) 55 | 56 | def Invert(img, _): 57 | return PIL.ImageOps.invert(img) 58 | 59 | def Equalize(img, _): 60 | return PIL.ImageOps.equalize(img) 61 | 62 | def Flip(img, _): # not from the paper 63 | return PIL.ImageOps.mirror(img) 64 | 65 | def Solarize(img, v): # [0, 256] 66 | assert 0 <= v <= 256 67 | return PIL.ImageOps.solarize(img, v) 68 | 69 | def SolarizeAdd(img, addition=0, threshold=128): 70 | img_np = np.array(img).astype(np.int) 71 | img_np = img_np + addition 72 | img_np = np.clip(img_np, 0, 255) 73 | img_np = img_np.astype(np.uint8) 74 | img = Image.fromarray(img_np) 75 | return PIL.ImageOps.solarize(img, threshold) 76 | 77 | def Posterize(img, v): # [4, 8] 78 | v = int(v) 79 | v = max(1, v) 80 | return PIL.ImageOps.posterize(img, v) 81 | 82 | def Contrast(img, v): # [0.1,1.9] 83 | assert 0.1 <= v <= 1.9 84 | return PIL.ImageEnhance.Contrast(img).enhance(v) 85 | 86 | def Color(img, v): # [0.1,1.9] 87 | assert 0.1 <= v <= 1.9 88 | return PIL.ImageEnhance.Color(img).enhance(v) 89 | 90 | def Brightness(img, v): # [0.1,1.9] 91 | assert 0.1 <= v <= 1.9 92 | return PIL.ImageEnhance.Brightness(img).enhance(v) 93 | 94 | def Sharpness(img, v): # [0.1,1.9] 95 | assert 0.1 <= v <= 1.9 96 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 97 | 98 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 99 | assert 0.0 <= v <= 0.2 100 | if v <= 0.: 101 | return img 102 | 103 | v = v * img.size[0] 104 | return CutoutAbs(img, v) 105 | 106 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 107 | # assert 0 <= v <= 20 108 | if v < 0: 109 | return img 110 | w, h = img.size 111 | x0 = np.random.uniform(w) 112 | y0 = np.random.uniform(h) 113 | 114 | x0 = int(max(0, x0 - v / 2.)) 115 | y0 = int(max(0, y0 - v / 2.)) 116 | x1 = min(w, x0 + v) 117 | y1 = min(h, y0 + v) 118 | 119 | xy = (x0, y0, x1, y1) 120 | color = (125, 123, 114) 121 | # color = (0, 0, 0) 122 | img = img.copy() 123 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 124 | return img 125 | 
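# Quick usage sketch for CutoutAbs above ("demo.png" is a placeholder path,
# not a file shipped with this repo):
#
#   from PIL import Image
#   img = Image.open("demo.png").convert("RGB")
#   out = CutoutAbs(img, 40)  # paints a grey patch of up to 40x40 pixels at a random spot
#   out.save("demo_cutout.png")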
126 | def SamplePairing(imgs): # [0, 0.4] 127 | def f(img1, v): 128 | i = np.random.choice(len(imgs)) 129 | img2 = PIL.Image.fromarray(imgs[i]) 130 | return PIL.Image.blend(img1, img2, v) 131 | 132 | return f 133 | 134 | def Identity(img, v): 135 | return img 136 | 137 | def augment_list(): # 16 operations and their ranges 138 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 139 | # l = [ 140 | # (Identity, 0., 1.0), 141 | # (ShearX, 0., 0.3), # 0 142 | # (ShearY, 0., 0.3), # 1 143 | # (TranslateX, 0., 0.33), # 2 144 | # (TranslateY, 0., 0.33), # 3 145 | # (Rotate, 0, 30), # 4 146 | # (AutoContrast, 0, 1), # 5 147 | # (Invert, 0, 1), # 6 148 | # (Equalize, 0, 1), # 7 149 | # (Solarize, 0, 110), # 8 150 | # (Posterize, 4, 8), # 9 151 | # # (Contrast, 0.1, 1.9), # 10 152 | # (Color, 0.1, 1.9), # 11 153 | # (Brightness, 0.1, 1.9), # 12 154 | # (Sharpness, 0.1, 1.9), # 13 155 | # # (Cutout, 0, 0.2), # 14 156 | # # (SamplePairing(imgs), 0, 0.4), # 15 157 | # ] 158 | 159 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 160 | l = [ 161 | (AutoContrast, 0, 1), 162 | (Equalize, 0, 1), 163 | (Invert, 0, 1), 164 | (Rotate, 0, 30), 165 | (Posterize, 0, 4), 166 | (Solarize, 0, 256), 167 | (SolarizeAdd, 0, 110), 168 | (Color, 0.1, 1.9), 169 | (Contrast, 0.1, 1.9), 170 | (Brightness, 0.1, 1.9), 171 | (Sharpness, 0.1, 1.9), 172 | (ShearX, 0., 0.3), 173 | (ShearY, 0., 0.3), 174 | (CutoutAbs, 0, 40), 175 | (TranslateXabs, 0., 100), 176 | (TranslateYabs, 0., 100), 177 | ] 178 | 179 | return l 180 | 181 | class RandAugment: 182 | """ 183 | Adopted from: https://github.com/ildoonet/pytorch-randaugment 184 | """ 185 | def __init__(self, n, m): 186 | self.n = n 187 | self.m = m # [0, 30] 188 | self.augment_list = augment_list() 189 | 190 | def __call__(self, img): 191 | 192 | ops = random.choices(self.augment_list, k=self.n) 193 | for op, minval, maxval in ops: 194 | val = (float(self.m) / 30) * float(maxval - minval) + minval 195 | img = op(img, val)  # chain the sampled operations 196 | 197 | return img 198 | 199 | class RandAugmentGenerator: 200 | def __init__(self, 201 | dataset, 202 | batch_size, 203 | stage, 204 | corruption_mode, 205 | aug_count=1, 206 | aug_rate=5, 207 | img_mean_mode=None, 208 | seed=13, 209 | orig_plus_aug=True): 210 | """ 211 | :param dataset: (tuple) x, y, segmentation mask (optional) 212 | :param batch_size: (int) # of inputs in a mini-batch 213 | :param stage: (str) train | test 214 | :param corruption_mode: (str) required for adv RandAugment experiments 215 | :param aug_count: (int) N in the paper 216 | :param aug_rate: (int) M in the paper 217 | :param img_mean_mode: (str) use this for image normalization 218 | :param seed: (int) seed for input shuffle 219 | :param orig_plus_aug: (bool) if True, original images will be kept in the batch along with corrupted ones 220 | """ 221 | 222 | if stage not in ['train', 'test']: 223 | raise ValueError('invalid stage!') 224 | 225 | # Settings 226 | self.batch_size = batch_size 227 | self.stage = stage 228 | self.corruption_mode = corruption_mode 229 | self.img_mean_mode = img_mean_mode 230 | self.seed = seed 231 | self.orig_plus_aug = orig_plus_aug 232 | 233 | # Preparation 234 | self.configuration() 235 | self.load_data(dataset) 236 | self.random_augmentation = RandAugment(aug_count, aug_rate) 237 | 238 | def configuration(self): 239 | self.shuffle_count = 1 240 | self.current_index = 0 241 | 242 | def shuffle(self): 243 |
243 |         self.image_count = len(self.labels)
244 |         self.current_index = 0
245 |         self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images,
246 |                                                                         labels=self.labels,
247 |                                                                         teacher_logits=self.teacher_logits,
248 |                                                                         seed=self.seed + self.shuffle_count)
249 |         self.shuffle_count += 1
250 | 
251 |     def load_data(self, dataset):
252 |         self.images = dataset["images"]
253 |         self.labels = dataset["labels"]
254 |         self.teacher_logits = dataset["teacher_logits"] if "teacher_logits" in dataset else None
255 | 
256 |         self.len_images = len(self.images)
257 |         self.len_labels = len(self.labels)
258 |         assert self.len_images == self.len_labels
259 |         self.image_count = self.len_labels
260 | 
261 |         if self.stage == 'train':
262 |             self.images, self.labels, self.teacher_logits, _ = shuffle_data(samples=self.images,
263 |                                                                             labels=self.labels,
264 |                                                                             teacher_logits=self.teacher_logits,
265 |                                                                             seed=self.seed)
266 | 
267 |     def get_batch_count(self):
268 |         return (self.len_labels // self.batch_size) + 1
269 | 
270 |     def augment(self, x, n):
271 | 
272 |         x_ = [self.random_augmentation(x) for _ in range(n)]
273 | 
274 |         return x_
275 | 
276 |     def get_batch(self, epoch=None):
277 |         tensor_shape = (self.batch_size, self.images.shape[3], self.images.shape[1], self.images.shape[2])
278 |         teacher_logits = None if self.teacher_logits is None else []
279 |         expansion_coeff = 5 if self.corruption_mode == "randaugment++" else 1
280 | 
281 |         if self.orig_plus_aug:
282 |             labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
283 |             images = torch.zeros(tensor_shape, dtype=torch.float32)
284 |             augmented_images = [torch.zeros(tensor_shape, dtype=torch.float32) for _ in range(expansion_coeff)]
285 |             for i in range(self.batch_size):
286 |                 # Avoid overflow
287 |                 if self.current_index > self.image_count - 1:
288 |                     if self.stage == "train":
289 |                         self.shuffle()
290 |                     else:
291 |                         self.current_index = 0
292 | 
293 |                 x = Image.fromarray(self.images[self.current_index])
294 |                 y = self.labels[self.current_index]
295 |                 images[i] = normalize_dataset(x, self.img_mean_mode)
296 |                 labels[i] = y
297 | 
298 |                 augmented_x = self.augment(x, expansion_coeff)
299 |                 for j in range(expansion_coeff):
300 |                     augmented_images[j][i] = normalize_dataset(augmented_x[j], img_mean_mode=self.img_mean_mode)
301 | 
302 |                 if teacher_logits is not None:
303 |                     teacher_logits.append(self.teacher_logits[self.current_index])
304 | 
305 |                 self.current_index += 1
306 | 
307 |             # Include teacher logits as soft labels if applicable
308 |             if teacher_logits is not None:
309 |                 labels = [labels, np.array(teacher_logits)]
310 | 
311 |             batches = [(images, labels)]
312 |             for i in range(expansion_coeff):
313 |                 batches.append((augmented_images[i], labels))
314 | 
315 |         else:
316 |             labels = np.zeros(tuple([self.batch_size] + list(self.labels.shape)[1:]))
317 |             augmented_images = [torch.zeros(tensor_shape, dtype=torch.float32) for _ in range(expansion_coeff)]
318 |             for i in range(self.batch_size):
319 |                 # Avoid overflow
320 |                 if self.current_index > self.image_count - 1:
321 |                     if self.stage == "train":
322 |                         self.shuffle()
323 |                     else:
324 |                         self.current_index = 0
325 | 
326 |                 x = Image.fromarray(self.images[self.current_index])
327 |                 y = self.labels[self.current_index]
328 |                 labels[i] = y
329 | 
330 |                 augmented_x = self.augment(x, expansion_coeff)
331 |                 for j in range(expansion_coeff):
332 |                     augmented_images[j][i] = normalize_dataset(augmented_x[j], img_mean_mode=self.img_mean_mode)
333 | 
334 |                 if teacher_logits is not None:
335 | 
teacher_logits.append(self.teacher_logits[self.current_index]) 336 | 337 | self.current_index += 1 338 | 339 | # Include teacher logits as soft labels if applicable 340 | if teacher_logits is not None: 341 | labels = [labels, np.array(teacher_logits)] 342 | 343 | batches = [] 344 | for i in range(expansion_coeff): 345 | batches.append((augmented_images[i], labels)) 346 | 347 | return batches 348 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch~=1.5.1+cu101 2 | numpy~=1.19.5 3 | torchvision~=0.6.1+cu101 4 | Pillow~=8.3.1 5 | matplotlib~=3.1.1 6 | sklearn~=0.0 7 | scikit-learn~=0.24.1 8 | scipy~=1.6.1 9 | imagecorruptions~=1.1.2 10 | tqdm~=4.58.0 11 | pycocotools~=2.0.0 -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import configparser 4 | from models.ResNet import get_resnet 5 | from testers.DomainGeneralization_tester import DomainGeneralization_tester 6 | from tools import * 7 | 8 | datasets_str = """ 9 | Supported benchmarks: 10 | - PACS 11 | - COCO 12 | - DomainNet 13 | """ 14 | 15 | def create_config_file(config): 16 | # Default configurations 17 | config["DEFAULT"] = {"version": "1.0.1", 18 | "model": "resnet", 19 | "depth": 18, 20 | "lr": 4e-3, 21 | "batch_size": 128, 22 | "epochs": 30, 23 | "optimizer": "sgd", 24 | "momentum": 0.9, 25 | "temperature": 1.0, 26 | "img_mean_mode": "imagenet", 27 | "corruption_mode": "None", 28 | "corruption_dist": "uniform", 29 | "only_corrupted": False, 30 | "loss": "CrossEntropy", 31 | "train_dataset": "PACS:Photo", 32 | "test_datasets": "None", 33 | "print_config": True, 34 | "data_dir": "../../datasets", 35 | "first_run": False, 36 | "model_dir": ".", 37 | "save_model": False, 38 | "knowledge_distillation": False, 39 | "random_aug": False} 40 | 41 | with open("settings.ini", "w+") as config_file: 42 | config.write(config_file) 43 | 44 | if __name__ == '__main__': 45 | 46 | # Dynamic parameters 47 | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) 48 | parser.add_argument("--model", help="selected neural net architecture", type=str) 49 | parser.add_argument("--depth", help="# of layers", type=int) 50 | parser.add_argument("--lr", help="learning rate", type=float) 51 | parser.add_argument("--batch_size", help="batch size (must be an even number!)", type=int) 52 | parser.add_argument("--epochs", help="# of epochs", type=int) 53 | parser.add_argument("--optimizer", help="optimization algorithm", type=str) 54 | parser.add_argument("--momentum", help="momentum (only relevant if the 'optimizer' algorithm is using it)", type=float) 55 | parser.add_argument("--weight_decay", help="L2 regularization penalty", type=float) 56 | parser.add_argument("--temperature", help="temperature for contrastive distillation loss", type=float) 57 | parser.add_argument("--pretrained_weights", help="imagenet | None", type=str) 58 | parser.add_argument("--img_mean_mode", help="image mean subtraction mode for dataset preprocessing, options: None | per_pixel | per_channel", type=str) 59 | parser.add_argument("--corruption_mode", help="visual corruptions on inputs as a data augmentation method", type=str) 60 | parser.add_argument("--corruption_dist", help="distribution from which the corruption rate is randomly sampled per image", type=str) 61 | 
parser.add_argument("--only_corrupted", help="when info loss applied, only the corrupted images will be in the mini-batch", action="store_true") 62 | parser.add_argument("--loss", help="loss function(s)", nargs="+") 63 | parser.add_argument("--train_dataset", help=datasets_str, type=str) 64 | parser.add_argument("--test_datasets", help="list of test sets for domain generalization experiments", nargs="+") 65 | parser.add_argument("--to_path", help="filepath to save models with custom names", type=str) 66 | parser.add_argument("--data_dir", help="filepath to save datasets", type=str) 67 | parser.add_argument("--first_run", help="to initiate COCO preprocessing", action="store_true") 68 | parser.add_argument("--model_dir", help="filepath to save models", type=str) 69 | parser.add_argument("--print_config", help="prints the active configurations", action="store_true") 70 | parser.add_argument("--save_model", help="to save the trained models", action="store_true") 71 | parser.add_argument("--knowledge_distillation", help="to disable batch normalization", action="store_true") 72 | parser.add_argument("--random_aug", help="to enable random data augmentation", action="store_true") 73 | args = vars(parser.parse_args()) 74 | 75 | # Static parameters 76 | config = configparser.ConfigParser(allow_no_value=True) 77 | try: 78 | if not os.path.exists("settings.ini"): 79 | create_config_file(config) 80 | 81 | # Override the default values if specified 82 | config.read("settings.ini") 83 | temp = dict(config["DEFAULT"]) 84 | temp.update({k: v for k, v in args.items() if v is not None}) 85 | config.read_dict({"DEFAULT": temp}) 86 | config = config["DEFAULT"] 87 | 88 | # Assign the active values 89 | version = config["version"] 90 | arch = config["model"].lower() 91 | depth = int(config["depth"]) 92 | lr = float(config["lr"]) 93 | batch_size = int(config["batch_size"]) 94 | epochs = int(config["epochs"]) 95 | optimizer = config["optimizer"] 96 | momentum = float(config["momentum"]) 97 | weight_decay = float(config["weight_decay"]) if "weight_decay" in config else .0 98 | temperature = float(config["temperature"]) 99 | pretrained_weights = config["pretrained_weights"] if "pretrained_weights" in config else None 100 | img_mean_mode = config["img_mean_mode"] if config["img_mean_mode"].lower() != "none" else None 101 | corruption_mode = config["corruption_mode"] if config["corruption_mode"].lower() != "none" else None 102 | corruption_dist = config["corruption_dist"] 103 | loss = config["loss"] 104 | train_dataset = config["train_dataset"] 105 | test_datasets = config["test_datasets"] 106 | to_path = config["to_path"] if "to_path" in config else None 107 | data_dir = config["data_dir"] 108 | model_dir = config["model_dir"] 109 | FIRST_RUN = config["first_run"] 110 | PRINT_CONFIG = config.getboolean("print_config") 111 | SAVE_MODEL = config.getboolean("save_model") 112 | KNOWLEDGE_DISTILLATION = config.getboolean("knowledge_distillation") 113 | ONLY_CORRUPTED = config.getboolean("only_corrupted") 114 | RANDOM_AUG = config.getboolean("random_aug") 115 | log("Configuration is completed.") 116 | except Exception as e: 117 | log("Error: " + str(e), LogType.ERROR) 118 | log("Configuration fault! New settings.ini is created. Restart the program.", LogType.ERROR) 119 | create_config_file(config) 120 | exit(1) 121 | 122 | # Process benchmark parameters 123 | log("Single-source domain generalization experiment...") 124 | 125 | # Process selected neural net 126 | if arch not in ["resnet"]: 127 | log("Nice try... 
but %s is not a supported neural net architecture!" % arch, LogType.ERROR) 128 | exit(1) 129 | 130 | # Process selected datasets for benchmarking 131 | datasets = ["COCO", 132 | "PACS:Photo", 133 | "FullDomainNet:Real"] 134 | # Dataset checker 135 | if train_dataset not in datasets: 136 | log("Nice try... but %s is not an allowed dataset!" % train_dataset, LogType.ERROR) 137 | exit(1) 138 | 139 | # Process selected test datasets for domain generalization 140 | if args["test_datasets"] is not None and len(args["test_datasets"]) > 0: 141 | supported_datasets = ["PACS:Art", 142 | "PACS:Cartoon", 143 | "PACS:Sketch", 144 | "PACS:Photo", 145 | "DomainNet:Real", 146 | "DomainNet:Infograph", 147 | "DomainNet:Clipart", 148 | "DomainNet:Painting", 149 | "DomainNet:Quickdraw", 150 | "DomainNet:Sketch", 151 | "FullDomainNet:Infograph", 152 | "FullDomainNet:Clipart", 153 | "FullDomainNet:Painting", 154 | "FullDomainNet:Quickdraw", 155 | "FullDomainNet:Sketch"] 156 | # Dataset checker 157 | for s in args["test_datasets"]: 158 | if s not in supported_datasets: 159 | log("Nice try... but %s is not an allowed dataset!" % s, LogType.ERROR) 160 | exit(1) 161 | 162 | # Handle specific dataset selections 163 | test_datasets = args["test_datasets"] 164 | elif test_datasets == "None": 165 | test_datasets = None 166 | else: 167 | test_datasets = [test_datasets] 168 | 169 | # Process loss function(s) 170 | if args["loss"] is not None and len(args["loss"]) > 0: 171 | if len(args["loss"]) == 1: 172 | loss = args["loss"][0] 173 | else: 174 | loss = args["loss"] 175 | 176 | # Process the mini-batch state 177 | orig_plus_aug = False if ONLY_CORRUPTED else True 178 | 179 | # Log the active configuration if needed 180 | if PRINT_CONFIG: 181 | log_config(config) 182 | 183 | # Prepare the benchmark 184 | tester = DomainGeneralization_tester(train_dataset=train_dataset, 185 | test_dataset=test_datasets, 186 | img_mean_mode=img_mean_mode, 187 | data_dir=data_dir, 188 | distillation=KNOWLEDGE_DISTILLATION, 189 | first_run=FIRST_RUN, 190 | wait=True) 191 | tester.activate() # manually trigger the dataset loader 192 | n_classes = tester.get_n_classes() 193 | 194 | # Build the baseline model 195 | model_name = "%s[%s][img_mean=%s][aug=%s]" % (get_arch_name(arch, depth), train_dataset, img_mean_mode, corruption_mode) 196 | #model_name = "%s[%s][img_mean=%s][aug=%s_T%s]" % (get_arch_name(arch, depth), train_dataset, img_mean_mode, corruption_mode, temperature) 197 | 198 | if arch == "resnet": 199 | model = get_resnet(depth, n_classes) 200 | 201 | log("Baseline model is ready.") 202 | 203 | # Train the baseline model 204 | log("Baseline model training...") 205 | hist, score = tester.run(model, 206 | name=model_name, 207 | optimizer=optimizer, 208 | lr=lr, 209 | momentum=momentum, 210 | weight_decay=weight_decay, 211 | loss=loss, 212 | batch_size=batch_size, 213 | epochs=epochs, 214 | corruption_mode=corruption_mode, 215 | corruption_dist=corruption_dist, 216 | orig_plus_aug=orig_plus_aug, 217 | temperature=temperature, 218 | rand_aug=RANDOM_AUG) 219 | 220 | log("%s Test accuracy: %s" % (model_name, score)) 221 | log("----------------------------------------------------------------") 222 | 223 | # Plot and save the learning curve 224 | chart_path = "%s_learning_curve.png" % model_name 225 | chart_path = chart_path.replace(":", "_") 226 | plot_learning_curve(hist, chart_path) 227 | 228 | # Save the baseline model & print its structure 229 | if SAVE_MODEL: 230 | if to_path is None: 231 | torch.save(model, os.path.join(model_dir, 
"%s.pth" % model_name)) 232 | else: 233 | torch.save(model, to_path) 234 | print(model) 235 | 236 | del model 237 | 238 | log("Done.") 239 | -------------------------------------------------------------------------------- /run_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "" 3 | echo " --------------------------------- " 4 | echo " | | " 5 | echo "----------------- Domain Generalization Experiments ------------------" 6 | echo " | | " 7 | echo " --------------------------------- " 8 | echo "" 9 | echo "----------------------------------------------------------------------" 10 | for i in 1 2 3 4 5; do 11 | echo "COCO experiments..." 12 | for depth in 18; do 13 | for lr in 4e-3; do 14 | for img_mean in imagenet; do 15 | 16 | # Baseline 17 | echo "Exp #"$i" [Baseline]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 18 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode None --print_config > dump 19 | 20 | # MixUp experiments 21 | echo "Exp #"$i" [MixUp]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 22 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode mixup --print_config > dump 23 | 24 | # CutOut experiments 25 | echo "Exp #"$i" [CutOut]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 26 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode cutout --print_config > dump 27 | 28 | # CutMix experiments 29 | echo "Exp #"$i" [CutMix]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 30 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode cutmix --print_config > dump 31 | 32 | # RandAugment experiments 33 | echo "Exp #"$i" [RandAugment]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 34 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode randaugment --print_config > dump 35 | 36 | # AugMix experiments 37 | echo "Exp #"$i" [AugMix]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 38 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --loss 
CrossEntropy JSDivergence --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode augmix --print_config > dump 39 | 40 | # VC experiments 41 | echo "Exp #"$i" [VC]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 42 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode vc --print_config > dump 43 | 44 | # ACVC experiments 45 | echo "Exp #"$i" [ACVC]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 46 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --loss CrossEntropy AttentionConsistency --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset COCO --test_datasets DomainNet:Real DomainNet:Infograph DomainNet:Clipart DomainNet:Painting DomainNet:Quickdraw DomainNet:Sketch --corruption_mode acvc --print_config > dump 47 | 48 | done 49 | done 50 | done 51 | echo "PACS experiments..." 52 | for depth in 18; do 53 | for lr in 4e-3; do 54 | for img_mean in imagenet; do 55 | 56 | # Baseline 57 | echo "Exp #"$i" [Baseline]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 58 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode None --print_config > dump 59 | 60 | # MixUp experiments 61 | echo "Exp #"$i" [MixUp]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 62 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode mixup --print_config > dump 63 | 64 | # CutOut experiments 65 | echo "Exp #"$i" [CutOut]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 66 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode cutout --print_config > dump 67 | 68 | # CutMix experiments 69 | echo "Exp #"$i" [CutMix]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 70 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode cutmix --print_config > dump 71 | 72 | # RandAugment experiments 73 | echo "Exp #"$i" [RandAugment]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 74 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets 
--train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode randaugment --print_config > dump 75 | 76 | # AugMix experiments 77 | echo "Exp #"$i" [AugMix]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 78 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --loss CrossEntropy JSDivergence --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode augmix --print_config > dump 79 | 80 | echo "Exp #"$i" [VC]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 81 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode vc --print_config > dump 82 | 83 | echo "Exp #"$i" [ACVC]: Training variant = {depth: "$depth", lr: "$lr", img_mean:"$img_mean"}" 84 | python run.py --model resnet --depth $depth --pretrained_weights imagenet --lr $lr --loss CrossEntropy AttentionConsistency --optimizer sgd --batch_size 128 --img_mean_mode $img_mean --epochs 30 --data_dir datasets --train_dataset PACS:Photo --test_datasets PACS:Art PACS:Cartoon PACS:Sketch --corruption_mode acvc --print_config > dump 85 | 86 | done 87 | done 88 | done 89 | done 90 | echo "Done." 91 | -------------------------------------------------------------------------------- /settings.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | version = 1.0.1 3 | model = resnet 4 | depth = 18 5 | lr = 4e-3 6 | batch_size = 128 7 | epochs = 30 8 | optimizer = sgd 9 | momentum = 0.9 10 | temperature = 1.0 11 | img_mean_mode = imagenet 12 | corruption_mode = None 13 | corruption_dist = uniform 14 | only_corrupted = False 15 | loss = CrossEntropy 16 | train_dataset = PACS:Photo 17 | test_datasets = None 18 | print_config = True 19 | save_model = False 20 | data_dir = ../../datasets 21 | first_run = False 22 | model_dir = . 
23 | knowledge_distillation = False 24 | random_aug = False 25 | 26 | -------------------------------------------------------------------------------- /testers/DomainGeneralization_tester.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | from tools import log, compute_accuracy 7 | from losses.AttentionConsistency import AttentionConsistency 8 | from losses.Distillation import Distillation 9 | from losses.JSDivergence import JSDivergence 10 | from preprocessing import Datasets 11 | from preprocessing.image.ImageGenerator import ImageGenerator 12 | from preprocessing.image.ACVCGenerator import ACVCGenerator 13 | from preprocessing.image.MixUpGenerator import MixUpGenerator 14 | from preprocessing.image.CutOutGenerator import CutOutGenerator 15 | from preprocessing.image.CutMixGenerator import CutMixGenerator 16 | from preprocessing.image.AblationGenerator import AblationGenerator 17 | from preprocessing.image.RandAugmentGenerator import RandAugmentGenerator 18 | from preprocessing.image.AugMixGenerator import AugMixGenerator 19 | 20 | class LR_Scheduler: 21 | def __init__(self, base_lr, dataset, optimizer): 22 | self.base_lr = base_lr 23 | self.dataset = dataset 24 | self.optimizer = optimizer 25 | self.epoch = 0 26 | 27 | def step(self): 28 | self.epoch += 1 29 | lr = self.base_lr 30 | 31 | if self.epoch > 24: 32 | lr *= 1e-1 33 | 34 | for param_group in self.optimizer.param_groups: 35 | param_group['lr'] = lr 36 | 37 | class DomainGeneralization_tester: 38 | def __init__(self, train_dataset, test_dataset, img_mean_mode=None, distillation=False, wait=False, data_dir=None, first_run=False): 39 | """ 40 | Conditional constructor to enable creating testers without immediately triggering the dataset loader. 
41 | 
42 |         # Arguments
43 |             :param train_dataset: (str) name of the training set
44 |             :param test_dataset: (str) name of the test set
45 |             :param img_mean_mode: (str) image mean subtraction mode for dataset preprocessing
46 |             :param distillation: (bool) enable/disable knowledge distillation loss
47 |             :param wait: (bool) if True, the constructor will only create the instance and
48 |                          wait for manual activation to actually load the dataset
49 |             :param data_dir: (str) relative filepath to datasets
50 |             :param first_run: (bool) to initiate COCO benchmark preprocessing
51 |         """
52 |         # Base class variables
53 |         ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
54 |         self.dataset = None
55 |         self.model = None
56 |         self.x_test = None
57 |         self.y_test = None
58 |         self.training_set_size = None
59 |         self.img_mean_mode = img_mean_mode
60 |         self.distillation = distillation
61 |         self.data_dir = os.path.join(ROOT_DIR, "../../datasets") if data_dir is None else "../%s" % data_dir
62 |         self.first_run = first_run
63 | 
64 |         # Loss func init
65 |         self.classification_loss = None
66 |         self.contrastive_loss = None
67 |         self.distillation_loss = None
68 | 
69 |         # Extra config
70 |         self.train_dataset = train_dataset
71 |         self.test_dataset = test_dataset
72 |         if not wait:
73 |             self.activate()
74 | 
75 |     def activate(self):
76 |         """
77 |         Call this function to manually start dataset deployment
78 |         """
79 | 
80 |         # Deploy the dataset
81 |         if "COCO" == self.train_dataset:
82 |             self.dataset = Datasets.load_COCO(first_run=self.first_run, train=True, img_mean_mode=self.img_mean_mode, distillation=self.distillation, data_dir=self.data_dir)
83 |             self.x_test, self.y_test = Datasets.load_COCO(first_run=False, train=False, img_mean_mode=self.img_mean_mode, data_dir=self.data_dir)
84 | 
85 |         elif "FullDomainNet" in self.train_dataset:
86 |             subset = self.train_dataset.split(":")[1]
87 |             self.dataset = Datasets.load_FullDomainNet(subset, train=True, img_mean_mode=self.img_mean_mode, distillation=self.distillation, data_dir=self.data_dir)
88 |             self.x_test, self.y_test = Datasets.load_FullDomainNet(subset, train=False, img_mean_mode=self.img_mean_mode, data_dir=self.data_dir)
89 | 
90 |         elif "PACS" in self.train_dataset:
91 |             subset = self.train_dataset.split(":")[1]
92 |             self.dataset = Datasets.load_PACS(subset, train=True, img_mean_mode=self.img_mean_mode, distillation=self.distillation, data_dir=self.data_dir)
93 |             self.x_test, self.y_test = Datasets.load_PACS(subset, train=False, img_mean_mode=self.img_mean_mode, data_dir=self.data_dir)
94 | 
95 |         else:
96 |             assert False, "Train dataset: %s is not supported yet!" % self.train_dataset
97 | 
98 |         # Support for multiple test sets
99 |         self.x_test = {self.train_dataset: self.x_test}
100 |         self.y_test = {self.train_dataset: self.y_test}
101 | 
102 |         for test_set in self.test_dataset:
103 | 
104 |             if "FullDomainNet" in test_set:  # was `self.train_dataset`, which misrouted FullDomainNet test sets to the DomainNet loader below
105 |                 subset = test_set.split(":")[1]
106 |                 x_temp, y_temp = Datasets.load_FullDomainNet(subset, train=False, img_mean_mode=self.img_mean_mode, data_dir=self.data_dir)
107 |                 self.x_test[test_set] = x_temp
108 |                 self.y_test[test_set] = y_temp
109 | 
110 |             elif "DomainNet" in test_set:
111 |                 subset = test_set.split(":")[1]
112 |                 x_temp, y_temp = Datasets.load_DomainNet(subset, img_mean_mode=self.img_mean_mode, data_dir=self.data_dir)
113 |                 self.x_test[test_set] = x_temp
114 |                 self.y_test[test_set] = y_temp
115 | 
116 |             elif "PACS" in test_set:
117 |                 subset = test_set.split(":")[1]
118 |                 x_temp, y_temp = Datasets.load_PACS(subset, train=False, img_mean_mode=self.img_mean_mode, data_dir=self.data_dir)
119 |                 self.x_test[test_set] = x_temp
120 |                 self.y_test[test_set] = y_temp
121 | 
122 |             else:
123 |                 assert False, "Test dataset: %s is not supported yet!" % test_set
124 | 
125 |     def record_generalization_results(self, result_dict, path="generalization.json"):
126 |         if result_dict is not None:
127 |             # Load the previous records if they exist
128 |             hist_cache = {}
129 |             if os.path.isfile(path):
130 |                 with open(path, "r") as hist_file:
131 |                     hist_cache = json.load(hist_file)
132 | 
133 |             # Record new results
134 |             for model_name in result_dict:
135 |                 if model_name in hist_cache:
136 |                     hist_cache[model_name].append(result_dict[model_name])
137 |                 else:
138 |                     hist_cache[model_name] = [result_dict[model_name]]
139 | 
140 |             # Save the updated records
141 |             with open(path, "w+") as hist_file:
142 |                 json.dump(hist_cache, hist_file)
143 | 
144 |     def get_n_classes(self):
145 | 
146 |         if self.train_dataset == "COCO":
147 |             return 10
148 | 
149 |         elif "FullDomainNet" in self.train_dataset:
150 |             return 345
151 | 
152 |         elif "PACS" in self.train_dataset:
153 |             return 7
154 | 
155 |         else:
156 |             assert False, "Error: update tester.get_n_classes() for %s dataset" % self.train_dataset
157 | 
158 |     def get_optimizer(self, optimizer, lr=None, momentum=None, weight_decay=0.):
159 |         selected_optimizer = optimizer.lower()
160 |         if selected_optimizer == "adam":
161 |             return torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-7, weight_decay=weight_decay)
162 |         elif selected_optimizer == "sgd":
163 |             return torch.optim.SGD(self.model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
164 |         elif selected_optimizer == "nsgd":
165 |             return torch.optim.SGD(self.model.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=weight_decay)
166 |         else:
167 |             return None
168 | 
169 |     def get_contrastive_element(self, loss):
170 |         if loss == "AttentionConsistency":
171 |             return "CAM"
172 | 
173 |         elif loss == "JSDivergence":
174 |             return "Predictions"
175 | 
176 |         else:
177 |             raise ValueError("Unsupported contrastive loss")
178 | 
179 |     def get_loss_func(self, loss, temperature=None):
180 |         if loss == "CrossEntropy":
181 |             return torch.nn.CrossEntropyLoss()
182 | 
183 |         elif loss == "AttentionConsistency":
184 |             return AttentionConsistency(T=temperature)
185 | 
186 |         elif loss == "Distillation":
187 |             assert temperature is not None, "Distillation requires temperature as an argument"
188 |             return Distillation(temperature=temperature)
189 | 
190 |         elif loss == "JSDivergence":
191 |             return JSDivergence()
192 | 
193 |         else:
194 |             return None
195 | 
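    # Editor's note (added): for MixUp/CutMix-style batches the label is the pair (y1, y2)
    # plus a mixing coefficient l_param, and the classification loss below is the usual
    # convex combination: L = l_param * CE(outputs, y1) + (1 - l_param) * CE(outputs, y2).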
196 |     def get_multi_class_loss(self, outputs, y1, y2, l_param):  # `self` was missing from the signature although the body uses it
197 |         return torch.mean(self.classification_loss(outputs, y1) * l_param + self.classification_loss(outputs, y2) * (1. - l_param))
198 | 
199 |     def test(self, x, y):
200 |         self.model.eval()
201 | 
202 |         with torch.no_grad():
203 |             preds = []
204 |             for i in range(x.shape[0]):
205 |                 img = np.copy(x[i])
206 |                 img = Image.fromarray(img)
207 |                 img = Datasets.normalize_dataset(img, img_mean_mode=self.img_mean_mode).cuda()
208 |                 pred = self.model(img[None, ...])[-1]['Predictions']
209 |                 pred = pred.cpu().data.numpy()
210 |                 preds.append(pred)
211 | 
212 |             preds = np.array(preds)
213 |             accuracy = compute_accuracy(predictions=preds, labels=y)
214 | 
215 |         return accuracy
216 | 
217 |     def validate(self):
218 |         self.model.eval()
219 |         batch_count = self.valGenerator.get_batch_count()
220 | 
221 |         with torch.no_grad():
222 |             val_acc = 0
223 |             val_loss = 0
224 |             for i in range(batch_count):
225 |                 batches = self.valGenerator.get_batch()
226 | 
227 |                 for x_val, y_val in batches:
228 |                     images = x_val.cuda()
229 |                     labels = torch.from_numpy(y_val).long().cuda()
230 |                     outputs, end_points = self.model(images)
231 | 
232 |                     predictions = end_points['Predictions']
233 |                     predictions = predictions.cpu().data.numpy()
234 | 
235 |                     val_loss += self.classification_loss(outputs, labels).cpu().item()
236 |                     val_acc += compute_accuracy(predictions=predictions, labels=y_val)
237 | 
238 |         # Switch back to training mode
239 |         self.model.train()
240 | 
241 |         return val_loss / batch_count, val_acc / batch_count
242 | 
243 |     def train(self, epochs, multi_class=False):
244 |         hist = {"acc": [], "loss": [], "val_acc": [], "val_loss": []}
245 |         batch_count = self.trainGenerator.get_batch_count()
246 | 
247 |         for epoch in range(epochs):
248 |             train_acc = 0
249 |             train_loss = 0
250 |             self.model.train()
251 |             for i in range(batch_count):
252 |                 batches = self.trainGenerator.get_batch(epoch)
253 |                 losses = []
254 | 
255 |                 end_points_list = {}
256 |                 for x_train, y_train in batches:
257 |                     # Forward with the adapted parameters
258 |                     inputs = x_train.cuda()
259 |                     outputs, end_points = self.model(x=inputs)
260 | 
261 |                     # Knowledge distillation loss
262 |                     if self.distillation_loss is not None:
263 |                         teacher_logits = torch.from_numpy(y_train[1]).cuda().squeeze()
264 |                         y_train = y_train[0]
265 |                         losses.append(self.distillation_loss(outputs, teacher_logits))
266 | 
267 |                     # Accumulate orig + augmented image representation/embedding if there is a contrastive loss
268 |                     if self.contrastive_loss is not None:
269 |                         for loss in self.contrastive_loss:
270 |                             contrastive_element = self.get_contrastive_element(loss.name)
271 |                             if contrastive_element not in end_points_list:
272 |                                 end_points_list[contrastive_element] = []
273 |                             end_points_list[contrastive_element].append(end_points[contrastive_element])
274 | 
275 |                     # Classification loss
276 |                     if multi_class:
277 |                         y1 = torch.from_numpy(y_train[0]).long().cuda()
278 |                         y2 = torch.from_numpy(y_train[1]).long().cuda()
279 |                         l_param = torch.from_numpy(y_train[2]).cuda()
280 |                         loss = self.get_multi_class_loss(outputs, y1, y2, l_param)
281 |                         y_train = y_train[0]
282 | 
283 |                     else:
284 |                         labels = torch.from_numpy(y_train).long().cuda()
285 |                         loss = self.classification_loss(outputs, labels)
286 |                     losses.append(loss)
287 |                     train_loss += loss.cpu().item()
288 | 
289 |                     # Acc
290 |                     predictions = end_points['Predictions']
291 |                     predictions = predictions.cpu().data.numpy()
292 |                     train_acc += compute_accuracy(predictions=predictions, labels=y_train)
293 | 
294 |                 # Contrastive loss
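                # Editor's note (added): end_points_list groups, per contrastive objective, the
                # element returned by the model for each sub-batch (the CAM for AttentionConsistency,
                # the prediction distribution for JSDivergence); f0 comes from the first sub-batch
                # (the clean images when orig_plus_aug is True) and the loss enforces consistency
                # between f0 and the augmented views.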
295 |                 if self.contrastive_loss is not None:
296 |                     for loss in self.contrastive_loss:
297 |                         contrastive_element = self.get_contrastive_element(loss.name)
298 |                         f0 = end_points_list[contrastive_element][0]
299 |                         losses.append(loss(f0, end_points_list[contrastive_element][1:], y_train))
300 | 
301 |                 # Init the grad to zeros first
302 |                 self.optimizer.zero_grad()
303 | 
304 |                 # Backward your network
305 |                 loss = sum(losses)
306 |                 loss.backward()
307 | 
308 |                 # Optimize the parameters
309 |                 self.optimizer.step()
310 | 
311 |             # Learning rate scheduler
312 |             self.scheduler.step()
313 | 
314 |             # Validation
315 |             val_loss, val_acc = self.validate()
316 | 
317 |             # Record learning history
318 |             hist["acc"].append(train_acc / batch_count)
319 |             hist["loss"].append(train_loss / batch_count)
320 |             hist["val_acc"].append(val_acc)
321 |             hist["val_loss"].append(val_loss)
322 | 
323 |         return hist
324 | 
325 |     def run(self,
326 |             model,
327 |             name,
328 |             optimizer="adam",
329 |             lr=1e-3,
330 |             momentum=None,
331 |             weight_decay=.0,
332 |             loss="CrossEntropy",
333 |             batch_size=128,
334 |             epochs=200,
335 |             corruption_mode=None,
336 |             corruption_dist="uniform",
337 |             orig_plus_aug=True,
338 |             temperature=1.0,
339 |             rand_aug=False):
340 |         """
341 |         Runs the benchmark
342 | 
343 |         # Arguments
344 |             :param model: PyTorch model
345 |             :param name: (str) model name
346 |             :param optimizer: (str) name of the selected PyTorch optimizer
347 |             :param lr: (float) learning rate
348 |             :param momentum: (float) only relevant for the optimization algorithms that use momentum
349 |             :param weight_decay: (float) only relevant for the optimization algorithms that use weight decay
350 |             :param loss: (str | list) name of the selected loss function(s)
351 |             :param batch_size: (int) # of inputs in a mini-batch
352 |             :param epochs: (int) # of full training passes
353 |             :param corruption_mode: (str) required for VisCo experiments
354 |             :param corruption_dist: (str) required for VisCo experiments
355 |             :param orig_plus_aug: (bool) required for VisCo experiments
356 |             :param temperature: (float) temperature for contrastive distillation loss
357 |             :param rand_aug: (bool) enable/disable random data augmentation
358 | 
359 |         :return: (history, score)
360 |         """
361 |         # Custom configurations
362 |         self.model = model
363 |         self.model.cuda()
364 |         is_contrastive = False
365 |         multi_class = False
366 | 
367 |         # Set the optimizer
368 |         self.optimizer = self.get_optimizer(optimizer, lr=lr, momentum=momentum, weight_decay=weight_decay)
369 | 
370 |         # Set the learning rate scheduler
371 |         self.scheduler = LR_Scheduler(lr, self.train_dataset, self.optimizer)
372 | 
373 |         # Set loss functions
374 |         if type(loss) is list:
375 |             self.classification_loss = self.get_loss_func(loss[0])
376 |             self.contrastive_loss = []
377 |             for i in range(1, len(loss)):
378 |                 self.contrastive_loss.append(self.get_loss_func(loss[i], temperature=temperature))
379 |             is_contrastive = True
380 |         else:
381 |             self.classification_loss = self.get_loss_func(loss)
382 |         if self.distillation:
383 |             self.distillation_loss = self.get_loss_func("Distillation", temperature=temperature)
384 | 
385 |         # Training data generator
386 |         corruption_mode = None if corruption_mode is None else corruption_mode.lower()
387 |         if corruption_mode is None:
388 |             self.trainGenerator = ImageGenerator(dataset=self.dataset,
389 |                                                  batch_size=batch_size,
390 |                                                  stage="train",
391 |                                                  img_mean_mode=self.img_mean_mode,
392 |                                                  seed=13)
393 | 
394 |         elif corruption_mode == "mixup":
395 |             multi_class = True
396 |             self.trainGenerator = 
MixUpGenerator(dataset=self.dataset, 397 | batch_size=batch_size, 398 | stage="train", 399 | img_mean_mode=self.img_mean_mode, 400 | seed=13, 401 | orig_plus_aug=orig_plus_aug) 402 | 403 | elif corruption_mode == "cutout": 404 | self.trainGenerator = CutOutGenerator(dataset=self.dataset, 405 | batch_size=batch_size, 406 | stage="train", 407 | img_mean_mode=self.img_mean_mode, 408 | seed=13, 409 | orig_plus_aug=orig_plus_aug) 410 | 411 | elif corruption_mode == "cutmix": 412 | multi_class = True 413 | self.trainGenerator = CutMixGenerator(dataset=self.dataset, 414 | batch_size=batch_size, 415 | stage="train", 416 | img_mean_mode=self.img_mean_mode, 417 | seed=13, 418 | orig_plus_aug=orig_plus_aug) 419 | 420 | elif "augmix" == corruption_mode: 421 | self.trainGenerator = AugMixGenerator(dataset=self.dataset, 422 | batch_size=batch_size, 423 | stage="train", 424 | img_mean_mode=self.img_mean_mode, 425 | seed=13, 426 | orig_plus_aug=orig_plus_aug) 427 | 428 | elif "randaugment" in corruption_mode: 429 | self.trainGenerator = RandAugmentGenerator(dataset=self.dataset, 430 | batch_size=batch_size, 431 | stage="train", 432 | corruption_mode=corruption_mode, 433 | img_mean_mode=self.img_mean_mode, 434 | seed=13, 435 | orig_plus_aug=orig_plus_aug) 436 | 437 | elif "acvc" == corruption_mode or "vc" == corruption_mode: 438 | self.trainGenerator = ACVCGenerator(dataset=self.dataset, 439 | batch_size=batch_size, 440 | stage="train", 441 | epochs=epochs, 442 | corruption_mode=corruption_mode, 443 | corruption_dist=corruption_dist, 444 | img_mean_mode=self.img_mean_mode, 445 | rand_aug=rand_aug, 446 | seed=13, 447 | orig_plus_aug=orig_plus_aug) 448 | 449 | else: 450 | self.trainGenerator = AblationGenerator(dataset=self.dataset, 451 | batch_size=batch_size, 452 | stage="train", 453 | epochs=epochs, 454 | corruption_mode=corruption_mode, 455 | corruption_dist=corruption_dist, 456 | img_mean_mode=self.img_mean_mode, 457 | rand_aug=rand_aug, 458 | seed=13, 459 | orig_plus_aug=orig_plus_aug) 460 | 461 | # Validation data generator 462 | self.valGenerator = ImageGenerator(dataset=(self.x_test[self.train_dataset], self.y_test[self.train_dataset]), 463 | batch_size=batch_size, 464 | stage="test") 465 | 466 | # Train the model 467 | hist = self.train(epochs=epochs, multi_class=multi_class) 468 | 469 | # Evaluate the model with the test datasets 470 | log("-----------------------------------------------------------------------------") 471 | learning_mode = "contrastive" if is_contrastive else "vanilla" 472 | key = "%s[%s]" % (name, learning_mode) 473 | generalization_results = {} 474 | scores = [] 475 | for test_set in self.x_test: 476 | score = self.test(self.x_test[test_set], self.y_test[test_set]) 477 | generalization_results[test_set] = score 478 | log("%s %s Test accuracy: %.4f" % (name, test_set, score)) 479 | 480 | # Avg. single-source domain generalization accuracy 481 | if test_set != self.train_dataset: 482 | scores.append(score) 483 | avg_score = float(np.mean(np.array(scores))) 484 | log("%s Avg. 
DG Test accuracy: %.4f" % (name, avg_score))
485 | 
486 |         if generalization_results == {}:
487 |             generalization_results = None
488 |         else:
489 |             generalization_results["history"] = hist
490 |             generalization_results = {key: generalization_results}
491 |             self.record_generalization_results(generalization_results)
492 |         log("-----------------------------------------------------------------------------")
493 |         score = self.test(self.x_test[self.train_dataset], self.y_test[self.train_dataset])
494 | 
495 |         return hist, score
-------------------------------------------------------------------------------- /tools.py: --------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import warnings
4 | import itertools
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from PIL import Image
8 | from sklearn.metrics import confusion_matrix, accuracy_score
9 | from time import localtime, strftime
10 | from enum import Enum
11 | 
12 | logging.basicConfig(filename=os.path.join(os.path.dirname(os.path.abspath(__file__)), "experiment.log"), filemode="a", format="%(message)s", level=logging.INFO)
13 | 
14 | # Limit unwanted logging messages from packages
15 | warnings.filterwarnings("ignore", category=DeprecationWarning)
16 | matplotlib_logger = logging.getLogger('matplotlib')
17 | matplotlib_logger.setLevel(logging.ERROR)
18 | 
19 | class LogType(Enum):
20 |     DEBUG = 0
21 |     INFO = 1
22 |     WARNING = 2
23 |     ERROR = 3
24 | 
25 | def log(msg, log_type=LogType.INFO, to_file=True, to_stdout=True):
26 |     msg = "%s %s" % (get_time(), msg)
27 | 
28 |     if to_stdout:
29 |         print(msg)
30 |     if to_file and log_type == LogType.DEBUG:
31 |         logging.debug(msg)
32 |     elif to_file and log_type == LogType.INFO:
33 |         logging.info(msg)
34 |     elif to_file and log_type == LogType.WARNING:
35 |         logging.warning(msg)
36 |     elif to_file and log_type == LogType.ERROR:
37 |         logging.error(msg)
38 | 
39 | def log_config(config):
40 |     log("Active Configuration:")
41 |     log("--------------------")
42 |     for key in config:
43 |         residual = 24 - len(key)
44 |         temp = ""
45 |         while len(temp) < residual:
46 |             temp += " "
47 |         log("%s%s: %s" % (key, temp, config[key]))
48 | 
49 | def to_scientific(x):
50 |     return "{:.0e}".format(x)
51 | 
52 | def get_time():
53 |     return "[%s]" % strftime("%a, %d %b %Y %X", localtime())
54 | 
55 | def get_arch_name(arch, depth=""):
56 |     name = "Unknown"
57 | 
58 |     if arch == "resnet":
59 |         name = "ResNet%s" % depth
60 | 
61 |     return name
62 | 
63 | def plot_learning_curve(training_hist, chart_path):
64 |     """
65 |     Plots the learning curve of the given training history
66 | 
67 |     # Arguments
68 |         :param training_hist: (dict) training history with "acc", "loss", "val_acc", and "val_loss" keys, as returned by the tester's train() loop
69 |         :param chart_path: (String) file path for the output chart
70 |     """
71 |     is_ok = True
72 | 
73 |     # Error handler for missing values
74 |     for key in ["acc", "loss", "val_acc", "val_loss"]:
75 |         if key not in training_hist:
76 |             is_ok = False
77 | 
78 |     if is_ok:
79 |         # Start building the learning curve graph
80 |         fig, ax1 = plt.subplots(figsize=(14, 9))
81 |         epoch_list = list(range(1, len(training_hist['acc']) + 1))
82 | 
83 |         # Plotting training and test losses
84 |         train_loss, = ax1.plot(epoch_list, training_hist['loss'], color='red', alpha=.5)
85 |         if "loss_std" in training_hist:
86 |             ax1.fill_between(epoch_list,
87 |                              training_hist['loss'] + training_hist['loss_std'],
88 |                              training_hist['loss'] - training_hist['loss_std'],
89 |                              color="red",
90 |                              alpha=.3)
91 |         val_loss, = ax1.plot(epoch_list, 
training_hist['val_loss'], linewidth=2, color='green') 92 | if "val_loss_std" in training_hist: 93 | ax1.fill_between(epoch_list, 94 | training_hist['val_loss'] + training_hist['val_loss_std'], 95 | training_hist['val_loss'] - training_hist['val_loss_std'], 96 | color="green", 97 | alpha=.3) 98 | ax1.set_xlabel('Epochs') 99 | ax1.set_ylabel('Loss') 100 | 101 | # Plotting test accuracy 102 | ax2 = ax1.twinx() 103 | train_accuracy, = ax2.plot(epoch_list, training_hist['acc'], linewidth=1, color='orange') 104 | if "acc_std" in training_hist: 105 | ax2.fill_between(epoch_list, 106 | training_hist['acc'] + training_hist['acc_std'], 107 | training_hist['acc'] - training_hist['acc_std'], 108 | color="orange", 109 | alpha=.3) 110 | val_accuracy, = ax2.plot(epoch_list, training_hist['val_acc'], linewidth=2, color='blue') 111 | if "val_acc_std" in training_hist: 112 | ax2.fill_between(epoch_list, 113 | training_hist['val_acc'] + training_hist['val_acc_std'], 114 | training_hist['val_acc'] - training_hist['val_acc_std'], 115 | color="blue", 116 | alpha=.3) 117 | ax2.set_ylim(bottom=0, top=1) 118 | ax2.set_ylabel('Accuracy') 119 | 120 | # Adding legend 121 | plt.legend([train_loss, val_loss, train_accuracy, val_accuracy], ['Training Loss', 'Validation Loss', 'Training Accuracy', 'Validation Accuracy'], loc=7, bbox_to_anchor=(1, 0.8)) 122 | plt.title('Learning Curve') 123 | 124 | # Saving learning curve 125 | plt.savefig(chart_path) 126 | plt.close(fig) 127 | 128 | def plot_confusion_matrix(y_test, y_preds, chart_path, n_classes, class_labels=None): 129 | class_labels = [""]*n_classes if class_labels is None else class_labels 130 | 131 | #Generate the normalized confusion matrix 132 | cm = confusion_matrix(y_test, y_preds) 133 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 134 | 135 | fig = plt.figure(figsize=(33, 26)) 136 | plt.imshow(cm, interpolation='nearest', cmap=plt.get_cmap('Blues')) 137 | plt.title("Confusion Matrix") 138 | plt.colorbar() 139 | tick_marks = np.arange(n_classes) 140 | plt.xticks(tick_marks, class_labels, rotation=30) 141 | plt.yticks(tick_marks, class_labels) 142 | thresh = cm.max() / 2. 
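    # Editor's note (added): with the row-normalized matrix, cells above half of the maximum
    # value get white text and the rest black, keeping the numbers readable on the 'Blues' colormap.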
143 |     for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
144 |         plt.text(j, i, format(cm[i, j], '.1f'),
145 |                  horizontalalignment="center",
146 |                  color="white" if cm[i, j] > thresh else "black")
147 |     plt.tight_layout()
148 |     plt.ylabel('True label')
149 |     plt.xlabel('Predicted label')
150 | 
151 |     # Save the confusion matrix
152 |     plt.savefig(chart_path)
153 |     plt.close(fig)
154 | 
155 | def resize_image(img, target_dim):
156 |     new_img = img.resize(target_dim, Image.LANCZOS)  # Image.ANTIALIAS is a deprecated alias of LANCZOS and was removed in Pillow 10
157 |     return new_img
158 | 
159 | def shuffle_data_old(samples, labels, segmentation_masks=None, seed=13):
160 |     num = len(labels)
161 |     shuffle_index = np.random.RandomState(seed=seed).permutation(np.arange(num))
162 |     shuffled_samples = samples[shuffle_index]
163 |     shuffled_labels = labels[shuffle_index]
164 |     shuffled_masks = None if segmentation_masks is None else segmentation_masks[shuffle_index]
165 |     return shuffled_samples, shuffled_labels, shuffled_masks
166 | 
167 | def shuffle_data(samples, labels, teacher_logits=None, segmentation_masks=None, seed=13):
168 |     np.random.seed(seed)
169 |     random_state = np.random.get_state()
170 |     np.random.shuffle(samples)
171 |     np.random.set_state(random_state)
172 |     np.random.shuffle(labels)
173 | 
174 |     if segmentation_masks is not None:
175 |         np.random.set_state(random_state)
176 |         np.random.shuffle(segmentation_masks)
177 | 
178 |     if teacher_logits is not None:
179 |         np.random.set_state(random_state)
180 |         np.random.shuffle(teacher_logits)
181 | 
182 |     return samples, labels, teacher_logits, segmentation_masks
183 | 
184 | def get_contrastive_loss(loss, orig_plus_corrupted=False):
185 |     if type(loss) is list:
186 |         return loss[1]
187 | 
188 |     else:
189 |         if orig_plus_corrupted:
190 |             return "orig_plus_corrupted"
191 | 
192 |         else:
193 |             return "None"
194 | 
195 | def compute_accuracy(predictions, labels):
196 |     if np.ndim(labels) == 2:
197 |         y_true = np.argmax(labels, axis=-1)
198 |     else:
199 |         y_true = labels
200 |     accuracy = accuracy_score(y_true=y_true, y_pred=np.argmax(predictions, axis=-1))
201 |     return accuracy
202 | 
--------------------------------------------------------------------------------
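# --- Editor's aside: a minimal sanity check for the helpers in tools.py (not part of the
# repository); toy numpy arrays are assumed. shuffle_data keeps arrays index-aligned and
# compute_accuracy accepts either one-hot or integer labels.
#
#   import numpy as np
#   x = np.arange(6).reshape(3, 2)
#   y = np.array([0, 1, 2])
#   x, y, _, _ = shuffle_data(x, y, seed=13)   # same permutation applied to both arrays
#   preds = np.eye(3)[y]                       # perfect one-hot "predictions"
#   assert compute_accuracy(predictions=preds, labels=y) == 1.0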