├── .gitignore ├── .gitmodules ├── README.md ├── __pycache__ ├── cifar.cpython-36.pyc └── wideresnet.cpython-36.pyc ├── cifar ├── AutoAugment │ └── autoaugment.py ├── cifar.py ├── data │ └── cifar-10-batches-py │ │ └── cifar_label_map_count_4000_index_0 ├── main.py ├── run_uda.sh └── wideresnet.py └── imagenet ├── autoaugment.py ├── data_split ├── labeled_images_0.10.pth └── unlabeled_images_0.90.pth ├── imagenet_dataset.py ├── logs ├── baseline_resnet18_S4L_bs256 ├── baseline_resnet18_S4L_bs512 └── baseline_resnet50_S4L_bs256 ├── scripts ├── run_baseline_resnet18.sh ├── run_baseline_resnet34.sh └── run_baseline_resnet50.sh ├── separate_labeled_unlabeled.py └── train_imagenet.py /.gitignore: -------------------------------------------------------------------------------- 1 | experiments/ 2 | *.pyc 3 | *.log 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jongchan/unsupervised_data_augmentation_pytorch/7fce04a05c2da4ca98de32bfc305bb99f511915e/.gitmodules -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Data Augmentation (UDA) 2 | A PyTorch implementation for [Unsupervised Data Augmentation](https://arxiv.org/abs/1904.12848). 3 | 4 | ## Disclaimer 5 | 6 | * This is not an official implementation. The official TensorFlow implementation is at this [Github link](https://github.com/google-research/uda). 7 | * Plan to implement CIFAR10 and ImageNet experiments. 8 | 9 | ## Updates 10 | 11 | - **2019.06.28**: CIFAR-10 with 4,000 labeled set achieves top-1 accuracy **93.69%** without TSA. 
(on paper, 94.33% without TSA) 12 | 13 | ## Performance 14 | 15 | ### CIFAR-10 16 | 17 | | Exp | Top-1 acc(%) in paper | Top-1 acc(%) | 18 | |-------------------|-----------------------|--------------| 19 | | Baseline | 79.74 | 83.94 | 20 | | UDA (without TSA) | 94.33 | 93.69 | 21 | | UDA | 94.90 | - | 22 | 23 | ### ImageNet (10% labeled) 24 | 25 | | Exp | Top-1 (paper) | Top-5 (paper) | Top-1 | Top-5 | 26 | |----------|---------------|-----------------------|--------|--------| 27 | | RN50 | 55.09 | 77.26 (80.43 in S4L) | 54.184 | 79.116 | 28 | | RN18 | - | - | 50.594 | 76.138 | 29 | | UDA(RN50)| 68.66 | 88.52 | - | - | 30 | | S4L(RN50)| - | 91.23 (ResNet50v2 4x) | - | - | 31 | 32 | ## TODO List 33 | 34 | - [x] CIFAR-10 baseline & UDA validation 35 | - [x] ImageNet ResNet50 baseline validation 36 | - [ ] ImageNet ResNet50 UDA validation 37 | 38 | ## MISC 39 | 40 | - CIFAR10 baseline on paper is from [Realistic Evaluation of Deep Semi-Supervised Learning Algorithms](https://papers.nips.cc/paper/7585-realistic-evaluation-of-deep-semi-supervised-learning-algorithms), and it may be sub-optimal or use a different data split from the UDA paper. A naive baseline with weight decay 5e-4 and 100K iterations with cosine annealing LR can achieve higher performance as shown in the table. 41 | - CIFAR10 labeled set is from the AutoAugment policy search subset. 42 | - CIFAR10 AutoAugment policy includes the full set (95 policies), rather than 25 policies. 43 | - ImageNet labeled set is a randomly selected 10% of each class. 44 | - ImageNet baseline settings are from [S4L: Self-Supervised Semi-Supervised Learning](https://arxiv.org/abs/1905.03670). 
45 | -------------------------------------------------------------------------------- /__pycache__/cifar.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jongchan/unsupervised_data_augmentation_pytorch/7fce04a05c2da4ca98de32bfc305bb99f511915e/__pycache__/cifar.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/wideresnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jongchan/unsupervised_data_augmentation_pytorch/7fce04a05c2da4ca98de32bfc305bb99f511915e/__pycache__/wideresnet.cpython-36.pyc -------------------------------------------------------------------------------- /cifar/AutoAugment/autoaugment.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageEnhance, ImageOps 2 | import numpy as np 3 | import random 4 | def create_cutout_mask(img_height, img_width, num_channels, size): 5 | """Creates a zero mask used for cutout of shape `img_height` x `img_width`. 6 | Args: 7 | img_height: Height of image cutout mask will be applied to. 8 | img_width: Width of image cutout mask will be applied to. 9 | num_channels: Number of channels in the image. 10 | size: Size of the zeros mask. 11 | Returns: 12 | A mask of shape `img_height` x `img_width` with all ones except for a 13 | square of zeros of shape `size` x `size`. This mask is meant to be 14 | elementwise multiplied with the original image. Additionally returns 15 | the `upper_coord` and `lower_coord` which specify where the cutout mask 16 | will be applied. 
17 | """ 18 | assert img_height == img_width 19 | 20 | # Sample center where cutout mask will be applied 21 | height_loc = np.random.randint(low=0, high=img_height) 22 | width_loc = np.random.randint(low=0, high=img_width) 23 | 24 | size = int(size) 25 | # Determine upper right and lower left corners of patch 26 | upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2)) 27 | lower_coord = (min(img_height, height_loc + size // 2), 28 | min(img_width, width_loc + size // 2)) 29 | mask_height = (lower_coord[0] - upper_coord[0]) 30 | mask_width = (lower_coord[1] - upper_coord[1]) 31 | assert mask_height > 0 32 | assert mask_width > 0 33 | 34 | mask = np.ones((img_height, img_width, num_channels)) 35 | zeros = np.zeros((mask_height, mask_width, num_channels)) 36 | mask[upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1], :] = ( 37 | zeros) 38 | return mask, upper_coord, lower_coord 39 | 40 | def _cutout_pil_impl(pil_img, level): 41 | """Apply cutout to pil_img at the specified level.""" 42 | img_height, img_width, num_channels = (32, 32, 3) 43 | _, upper_coord, lower_coord = ( 44 | create_cutout_mask(img_height, img_width, num_channels, level)) 45 | pixels = pil_img.load() # create the pixel map 46 | for i in range(upper_coord[0], lower_coord[0]): # for every col: 47 | for j in range(upper_coord[1], lower_coord[1]): # For every row 48 | pixels[i, j] = (125, 122, 113, 0) # set the colour accordingly 49 | return pil_img 50 | 51 | class ImageNetPolicy(object): 52 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 
53 | 54 | Example: 55 | >>> policy = ImageNetPolicy() 56 | >>> transformed = policy(image) 57 | 58 | Example as a PyTorch Transform: 59 | >>> transform=transforms.Compose([ 60 | >>> transforms.Resize(256), 61 | >>> ImageNetPolicy(), 62 | >>> transforms.ToTensor()]) 63 | """ 64 | def __init__(self, fillcolor=(128, 128, 128)): 65 | self.policies = [ 66 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 67 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 68 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 69 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 70 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 71 | 72 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 73 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 74 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 75 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 76 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 77 | 78 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 79 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 80 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 81 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 82 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 83 | 84 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 85 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 86 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 87 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 88 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 89 | 90 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 91 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 92 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 93 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 94 | SubPolicy(0.8, "equalize", 8, 0.6, 
"equalize", 3, fillcolor) 95 | ] 96 | 97 | 98 | def __call__(self, img): 99 | policy_idx = random.randint(0, len(self.policies) - 1) 100 | return self.policies[policy_idx](img) 101 | 102 | def __repr__(self): 103 | return "AutoAugment ImageNet Policy" 104 | 105 | class CIFAR10PolicyAll(object): 106 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 107 | 108 | Example: 109 | >>> policy = CIFAR10Policy() 110 | >>> transformed = policy(image) 111 | 112 | Example as a PyTorch Transform: 113 | >>> transform=transforms.Compose([ 114 | >>> transforms.Resize(256), 115 | >>> CIFAR10Policy(), 116 | >>> transforms.ToTensor()]) 117 | """ 118 | def __init__(self, fillcolor=(128, 128, 128)): 119 | self.policies = [ 120 | SubPolicy(0.1, "Invert", 7, 0.2, "Contrast", 6, fillcolor), 121 | SubPolicy(0.7, "Rotate", 2, 0.3, "TranslateX", 9, fillcolor), 122 | SubPolicy(0.8, "Sharpness", 1, 0.9, "Sharpness", 3, fillcolor), 123 | SubPolicy(0.5, "ShearY", 8, 0.7, "TranslateY", 9, fillcolor), 124 | SubPolicy(0.5, "AutoContrast", 8, 0.9, "Equalize", 2, fillcolor), 125 | SubPolicy(0.4, "Solarize", 5, 0.9, "AutoContrast", 3, fillcolor), 126 | SubPolicy(0.9, "TranslateY", 9, 0.7, "TranslateY", 9, fillcolor), 127 | SubPolicy(0.9, "AutoContrast", 2, 0.8, "Solarize", 3, fillcolor), 128 | SubPolicy(0.8, "Equalize", 8, 0.1, "Invert", 3, fillcolor), 129 | SubPolicy(0.7, "TranslateY", 9, 0.9, "AutoContrast", 1, fillcolor), 130 | SubPolicy(0.4, "Solarize", 5, 0.0, "AutoContrast", 2, fillcolor), 131 | SubPolicy(0.7, "TranslateY", 9, 0.7, "TranslateY", 9, fillcolor), 132 | SubPolicy(0.9, "AutoContrast", 0, 0.4, "Solarize", 3, fillcolor), 133 | SubPolicy(0.7, "Equalize", 5, 0.1, "Invert", 3, fillcolor), 134 | SubPolicy(0.7, "TranslateY", 9, 0.7, "TranslateY", 9, fillcolor), 135 | SubPolicy(0.4, "Solarize", 5, 0.9, "AutoContrast", 1, fillcolor), 136 | SubPolicy(0.8, "TranslateY", 9, 0.9, "TranslateY", 9, fillcolor), 137 | SubPolicy(0.8, "AutoContrast", 0, 0.7, "TranslateY", 9, fillcolor), 
138 | SubPolicy(0.2, "TranslateY", 7, 0.9, "Color", 6, fillcolor), 139 | SubPolicy(0.7, "Equalize", 6, 0.4, "Color", 9, fillcolor), 140 | SubPolicy(0.2, "ShearY", 7, 0.3, "Posterize", 7, fillcolor), 141 | SubPolicy(0.4, "Color", 3, 0.6, "Brightness", 7, fillcolor), 142 | SubPolicy(0.3, "Sharpness", 9, 0.7, "Brightness", 9, fillcolor), 143 | SubPolicy(0.6, "Equalize", 5, 0.5, "Equalize", 1, fillcolor), 144 | SubPolicy(0.6, "Contrast", 7, 0.6, "Sharpness", 5, fillcolor), 145 | SubPolicy(0.3, "Brightness", 7, 0.5, "AutoContrast", 8, fillcolor), 146 | SubPolicy(0.9, "AutoContrast", 4, 0.5, "AutoContrast", 6, fillcolor), 147 | SubPolicy(0.3, "Solarize", 5, 0.6, "Equalize", 5, fillcolor), 148 | SubPolicy(0.2, "TranslateY", 4, 0.3, "Sharpness", 3, fillcolor), 149 | SubPolicy(0.0, "Brightness", 8, 0.8, "Color", 8, fillcolor), 150 | SubPolicy(0.2, "Solarize", 6, 0.8, "Color", 6, fillcolor), 151 | SubPolicy(0.2, "Solarize", 6, 0.8, "AutoContrast", 1, fillcolor), 152 | SubPolicy(0.4, "Solarize", 1, 0.6, "Equalize", 5, fillcolor), 153 | SubPolicy(0.0, "Brightness", 0, 0.5, "Solarize", 2, fillcolor), 154 | SubPolicy(0.9, "AutoContrast", 5, 0.5, "Brightness", 3, fillcolor), 155 | SubPolicy(0.7, "Contrast", 5, 0.0, "Brightness", 2, fillcolor), 156 | SubPolicy(0.2, "Solarize", 8, 0.1, "Solarize", 5, fillcolor), 157 | SubPolicy(0.5, "Contrast", 1, 0.2, "TranslateY", 9, fillcolor), 158 | SubPolicy(0.6, "AutoContrast", 5, 0.0, "TranslateY", 9, fillcolor), 159 | SubPolicy(0.9, "AutoContrast", 4, 0.8, "Equalize", 4, fillcolor), 160 | SubPolicy(0.0, "Brightness", 7, 0.4, "Equalize", 7, fillcolor), 161 | SubPolicy(0.2, "Solarize", 5, 0.7, "Equalize", 5, fillcolor), 162 | SubPolicy(0.6, "Equalize", 8, 0.6, "Color", 2, fillcolor), 163 | SubPolicy(0.3, "Color", 7, 0.2, "Color", 4, fillcolor), 164 | SubPolicy(0.5, "AutoContrast", 2, 0.7, "Solarize", 2, fillcolor), 165 | SubPolicy(0.2, "AutoContrast", 0, 0.1, "Equalize", 0, fillcolor), 166 | SubPolicy(0.6, "ShearY", 5, 0.6, "Equalize", 5, 
fillcolor), 167 | SubPolicy(0.9, "Brightness", 3, 0.4, "AutoContrast", 1, fillcolor), 168 | SubPolicy(0.8, "Equalize", 8, 0.7, "Equalize", 7, fillcolor), 169 | SubPolicy(0.7, "Equalize", 7, 0.5, "Solarize", 0, fillcolor), 170 | SubPolicy(0.8, "Equalize", 4, 0.8, "TranslateY", 9, fillcolor), 171 | SubPolicy(0.8, "TranslateY", 9, 0.6, "TranslateY", 9, fillcolor), 172 | SubPolicy(0.9, "TranslateY", 0, 0.5, "TranslateY", 9, fillcolor), 173 | SubPolicy(0.5, "AutoContrast", 3, 0.3, "Solarize", 4, fillcolor), 174 | SubPolicy(0.5, "Solarize", 3, 0.4, "Equalize", 4, fillcolor), 175 | SubPolicy(0.7, "Color", 7, 0.5, "TranslateX", 8, fillcolor), 176 | SubPolicy(0.3, "Equalize", 7, 0.4, "AutoContrast", 8, fillcolor), 177 | SubPolicy(0.4, "TranslateY", 3, 0.2, "Sharpness", 6, fillcolor), 178 | SubPolicy(0.9, "Brightness", 6, 0.2, "Color", 8, fillcolor), 179 | SubPolicy(0.5, "Solarize", 2, 0.0, "Invert", 3, fillcolor), 180 | SubPolicy(0.1, "AutoContrast", 5, 0.0, "Brightness", 0, fillcolor), 181 | SubPolicy(0.2, "Cutout", 4, 0.1, "Equalize", 1, fillcolor), 182 | SubPolicy(0.7, "Equalize", 7, 0.6, "AutoContrast", 4, fillcolor), 183 | SubPolicy(0.1, "Color", 8, 0.2, "ShearY", 3, fillcolor), 184 | SubPolicy(0.4, "ShearY", 2, 0.7, "Rotate", 0, fillcolor), 185 | SubPolicy(0.1, "ShearY", 3, 0.9, "AutoContrast", 5, fillcolor), 186 | SubPolicy(0.3, "TranslateY", 6, 0.3, "Cutout", 3, fillcolor), 187 | SubPolicy(0.5, "Equalize", 0, 0.6, "Solarize", 6, fillcolor), 188 | SubPolicy(0.3, "AutoContrast", 5, 0.2, "Rotate", 7, fillcolor), 189 | SubPolicy(0.8, "Equalize", 2, 0.4, "Invert", 0, fillcolor), 190 | SubPolicy(0.9, "Equalize", 5, 0.7, "Color", 0, fillcolor), 191 | SubPolicy(0.1, "Equalize", 1, 0.1, "ShearY", 3, fillcolor), 192 | SubPolicy(0.7, "AutoContrast", 3, 0.7, "Equalize", 0, fillcolor), 193 | SubPolicy(0.5, "Brightness", 1, 0.1, "Contrast", 7, fillcolor), 194 | SubPolicy(0.1, "Contrast", 4, 0.6, "Solarize", 5, fillcolor), 195 | SubPolicy(0.2, "Solarize", 3, 0.0, "ShearX", 0, 
fillcolor), 196 | SubPolicy(0.3, "TranslateX", 0, 0.6, "TranslateX", 0, fillcolor), 197 | SubPolicy(0.5, "Equalize", 9, 0.6, "TranslateY", 7, fillcolor), 198 | SubPolicy(0.1, "ShearX", 0, 0.5, "Sharpness", 1, fillcolor), 199 | SubPolicy(0.8, "Equalize", 6, 0.3, "Invert", 6, fillcolor), 200 | SubPolicy(0.3, "AutoContrast", 9, 0.5, "Cutout", 3, fillcolor), 201 | SubPolicy(0.4, "ShearX", 4, 0.9, "AutoContrast", 2, fillcolor), 202 | SubPolicy(0.0, "ShearX", 3, 0.0, "Posterize", 3, fillcolor), 203 | SubPolicy(0.4, "Solarize", 3, 0.2, "Color", 4, fillcolor), 204 | SubPolicy(0.1, "Equalize", 4, 0.7, "Equalize", 6, fillcolor), 205 | SubPolicy(0.3, "Equalize", 8, 0.4, "AutoContrast", 3, fillcolor), 206 | SubPolicy(0.6, "Solarize", 4, 0.7, "AutoContrast", 6, fillcolor), 207 | SubPolicy(0.2, "AutoContrast", 9, 0.4, "Brightness", 8, fillcolor), 208 | SubPolicy(0.1, "Equalize", 0, 0.0, "Equalize", 6, fillcolor), 209 | SubPolicy(0.8, "Equalize", 4, 0.0, "Equalize", 4, fillcolor), 210 | SubPolicy(0.5, "Equalize", 5, 0.1, "AutoContrast", 2, fillcolor), 211 | SubPolicy(0.5, "Solarize", 5, 0.9, "AutoContrast", 5, fillcolor), 212 | SubPolicy(0.6, "AutoContrast", 1, 0.7, "AutoContrast", 8, fillcolor), 213 | SubPolicy(0.2, "Equalize", 0, 0.1, "AutoContrast", 2, fillcolor), 214 | SubPolicy(0.6, "Equalize", 9, 0.4, "Equalize", 4, fillcolor), 215 | ] 216 | 217 | 218 | def __call__(self, img): 219 | policy_idx = random.randint(0, len(self.policies) - 1) 220 | return self.policies[policy_idx](img) 221 | 222 | def __repr__(self): 223 | return "AutoAugment CIFAR10 Policy" 224 | 225 | class CIFAR10Policy(object): 226 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 
227 | 228 | Example: 229 | >>> policy = CIFAR10Policy() 230 | >>> transformed = policy(image) 231 | 232 | Example as a PyTorch Transform: 233 | >>> transform=transforms.Compose([ 234 | >>> transforms.Resize(256), 235 | >>> CIFAR10Policy(), 236 | >>> transforms.ToTensor()]) 237 | """ 238 | def __init__(self, fillcolor=(128, 128, 128)): 239 | self.policies = [ 240 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 241 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 242 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 243 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 244 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 245 | 246 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 247 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 248 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 249 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 250 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 251 | 252 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 253 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 254 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 255 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 256 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 257 | 258 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 259 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), 260 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 261 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 262 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 263 | 264 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 265 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 266 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 267 | SubPolicy(0.8, 
"equalize", 8, 0.1, "invert", 3, fillcolor), 268 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 269 | ] 270 | 271 | 272 | def __call__(self, img): 273 | policy_idx = random.randint(0, len(self.policies) - 1) 274 | return self.policies[policy_idx](img) 275 | 276 | def __repr__(self): 277 | return "AutoAugment CIFAR10 Policy" 278 | 279 | 280 | class SVHNPolicy(object): 281 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 282 | 283 | Example: 284 | >>> policy = SVHNPolicy() 285 | >>> transformed = policy(image) 286 | 287 | Example as a PyTorch Transform: 288 | >>> transform=transforms.Compose([ 289 | >>> transforms.Resize(256), 290 | >>> SVHNPolicy(), 291 | >>> transforms.ToTensor()]) 292 | """ 293 | def __init__(self, fillcolor=(128, 128, 128)): 294 | self.policies = [ 295 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 296 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 297 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 298 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 299 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 300 | 301 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 302 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 303 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 304 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 305 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 306 | 307 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 308 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 309 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 310 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 311 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 312 | 313 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 314 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 315 | SubPolicy(0.6, "invert", 4, 
0.8, "rotate", 4, fillcolor), 316 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 317 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 318 | 319 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 320 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 321 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 322 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 323 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 324 | ] 325 | 326 | 327 | def __call__(self, img): 328 | policy_idx = random.randint(0, len(self.policies) - 1) 329 | return self.policies[policy_idx](img) 330 | 331 | def __repr__(self): 332 | return "AutoAugment SVHN Policy" 333 | 334 | 335 | class SubPolicy(object): 336 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 337 | ranges = { 338 | "shearx": np.linspace(0, 0.3, 10), 339 | "sheary": np.linspace(0, 0.3, 10), 340 | "translatex": np.linspace(0, 150 / 331, 10), 341 | "translatey": np.linspace(0, 150 / 331, 10), 342 | "rotate": np.linspace(0, 30, 10), 343 | "color": np.linspace(0.0, 0.9, 10), 344 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 345 | "solarize": np.linspace(256, 0, 10), 346 | "contrast": np.linspace(0.0, 0.9, 10), 347 | "sharpness": np.linspace(0.0, 0.9, 10), 348 | "brightness": np.linspace(0.0, 0.9, 10), 349 | "autocontrast": [0] * 10, 350 | "equalize": [0] * 10, 351 | "invert": [0] * 10, 352 | "cutout": np.round(np.linspace(0, 20, 10), 0).astype(np.int), 353 | } 354 | 355 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 356 | def rotate_with_fill(img, magnitude): 357 | rot = img.convert("RGBA").rotate(magnitude) 358 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 359 | 360 | func = { 361 | "shearx": lambda img, magnitude: img.transform( 362 
| img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 363 | Image.BICUBIC, fillcolor=fillcolor), 364 | "sheary": lambda img, magnitude: img.transform( 365 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 366 | Image.BICUBIC, fillcolor=fillcolor), 367 | "translatex": lambda img, magnitude: img.transform( 368 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 369 | fillcolor=fillcolor), 370 | "translatey": lambda img, magnitude: img.transform( 371 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 372 | fillcolor=fillcolor), 373 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 374 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), 375 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 376 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 377 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 378 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 379 | 1 + magnitude * random.choice([-1, 1])), 380 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 381 | 1 + magnitude * random.choice([-1, 1])), 382 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 383 | 1 + magnitude * random.choice([-1, 1])), 384 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 385 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 386 | "invert": lambda img, magnitude: ImageOps.invert(img), 387 | "cutout": lambda img, magnitude: _cutout_pil_impl(img, magnitude) 388 | } 389 | 390 | # self.name = "{}_{:.2f}_and_{}_{:.2f}".format( 391 | # operation1, ranges[operation1][magnitude_idx1], 392 | # operation2, ranges[operation2][magnitude_idx2]) 393 | self.p1 = p1 394 | self.operation1 = func[operation1.lower()] 395 | 
self.magnitude1 = ranges[operation1.lower()][magnitude_idx1] 396 | self.p2 = p2 397 | self.operation2 = func[operation2.lower()] 398 | self.magnitude2 = ranges[operation2.lower()][magnitude_idx2] 399 | 400 | 401 | def __call__(self, img): 402 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1) 403 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2) 404 | return img 405 | -------------------------------------------------------------------------------- /cifar/cifar.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | import sys 7 | 8 | if sys.version_info[0] == 2: 9 | import cPickle as pickle 10 | else: 11 | import pickle 12 | import torch 13 | import torch.utils.data as data 14 | from torchvision.datasets.utils import download_url, check_integrity 15 | 16 | import torchvision.transforms as transforms 17 | import AutoAugment.autoaugment as autoaugment 18 | import json 19 | 20 | class Cutout(object): 21 | """Randomly mask out one or more patches from an image. 22 | 23 | Args: 24 | n_holes (int): Number of patches to cut out of each image. 25 | length (int): The length (in pixels) of each square patch. 26 | """ 27 | def __init__(self, n_holes, length): 28 | self.n_holes = n_holes 29 | self.length = length 30 | 31 | def __call__(self, img): 32 | """ 33 | Args: 34 | img (Tensor): Tensor image of size (C, H, W). 35 | Returns: 36 | Tensor: Image with n_holes of dimension length x length cut out of it. 
37 | """ 38 | h = img.size(1) 39 | w = img.size(2) 40 | 41 | mask = np.ones((h, w), np.float32) 42 | 43 | for n in range(self.n_holes): 44 | y = np.random.randint(h) 45 | x = np.random.randint(w) 46 | 47 | y1 = np.clip(y - self.length // 2, 0, h) 48 | y2 = np.clip(y + self.length // 2, 0, h) 49 | x1 = np.clip(x - self.length // 2, 0, w) 50 | x2 = np.clip(x + self.length // 2, 0, w) 51 | 52 | mask[y1: y2, x1: x2] = 0. 53 | 54 | mask = torch.from_numpy(mask) 55 | mask = mask.expand_as(img) 56 | img = img * mask 57 | 58 | return img 59 | 60 | class CIFAR10(data.Dataset): 61 | """`CIFAR10 `_ Dataset. 62 | 63 | Args: 64 | root (string): Root directory of dataset where directory 65 | ``cifar-10-batches-py`` exists or will be saved to if download is set to True. 66 | train (bool, optional): If True, creates dataset from training set, otherwise 67 | creates from test set. 68 | transform (callable, optional): A function/transform that takes in an PIL image 69 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 70 | target_transform (callable, optional): A function/transform that takes in the 71 | target and transforms it. 72 | download (bool, optional): If true, downloads the dataset from the internet and 73 | puts it in root directory. If dataset is already downloaded, it is not 74 | downloaded again. 
75 | 76 | """ 77 | base_folder = 'cifar-10-batches-py' 78 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 79 | filename = "cifar-10-python.tar.gz" 80 | tgz_md5 = 'c58f30108f718f92721af3b95e74349a' 81 | train_list = [ 82 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 83 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 84 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 85 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 86 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], 87 | ] 88 | 89 | test_list = [ 90 | ['test_batch', '40351d587109b95175f43aff81a1287e'], 91 | ] 92 | meta = { 93 | 'filename': 'batches.meta', 94 | 'key': 'label_names', 95 | 'md5': '5ff9c542aee3614f3951f8cda6e48888', 96 | } 97 | 98 | def __init__(self, args, train=True, uda=False, normalize=False, add_labeled_to_unlabeled=False): 99 | 100 | #super(CIFAR10, self).__init__(root) 101 | self.normalize = normalize 102 | self.args= args 103 | self.root = os.path.expanduser('./data') 104 | self.uda = uda 105 | 106 | self.train = train 107 | 108 | if self.args.use_cutout: 109 | self.autoaugment = transforms.Compose([ 110 | transforms.ToTensor(), 111 | Cutout(n_holes=1, length=16), 112 | transforms.ToPILImage(), 113 | ]) 114 | elif self.args.UDA_CUTOUT: 115 | print ("USE UDA CUTOUT") 116 | self.autoaugment = transforms.Compose([ 117 | autoaugment.CIFAR10Policy() if not args.cifar10_policy_all else autoaugment.CIFAR10PolicyAll(), 118 | transforms.ToTensor(), 119 | Cutout(n_holes=1, length=16), 120 | transforms.ToPILImage(), 121 | ]) 122 | else: 123 | self.autoaugment = transforms.Compose([ 124 | autoaugment.CIFAR10Policy() if not args.cifar10_policy_all else autoaugment.CIFAR10PolicyAll(), 125 | ]) 126 | 127 | if self.args.AutoAugment and not uda: 128 | print ("labeled set autoaugment") 129 | self.autoaugment_labeled = transforms.Compose([ 130 | autoaugment.CIFAR10Policy() if not args.cifar10_policy_all else autoaugment.CIFAR10PolicyAll(), 131 | ]) 132 | 
# NOTE(review): this chunk is the tail of cifar/cifar.py from a flattened repo
# dump — the end of CIFAR10.__init__, the remaining CIFAR10 methods, the
# CIFAR100 subclass, and an ad-hoc __main__ debugging block.
        # --- continuation of CIFAR10.__init__: earlier (unseen) branches set up
        # the UDA/unlabeled augmentation; `uda`, `args`, `self.args`,
        # `self.train`, and `add_labeled_to_unlabeled` are bound above. ---
        elif self.args.AutoAugment_cutout_only and not uda:
            # Labeled set: Cutout only (no AutoAugment policy).
            # ToTensor -> Cutout -> ToPILImage so the result can still be fed
            # through the PIL-based common training transform below.
            print ("labeled set autoaugment (cutout only)")
            self.autoaugment_labeled = transforms.Compose([
                transforms.ToTensor(),
                Cutout(n_holes=1, length=16),
                transforms.ToPILImage(),
            ])
        elif self.args.AutoAugment_all and not uda:
            # Labeled set: AutoAugment CIFAR-10 policy (the full 95-policy
            # "All" variant when cifar10_policy_all is set) followed by Cutout.
            print ("labeled set autoaugment (all)")
            self.autoaugment_labeled = transforms.Compose([
                autoaugment.CIFAR10Policy() if not args.cifar10_policy_all else autoaugment.CIFAR10PolicyAll(),
                transforms.ToTensor(),
                Cutout(n_holes=1, length=16),
                transforms.ToPILImage(),
            ])
        else:
            print ("labeled set no autoaugment")
            self.autoaugment_labeled = None

        if self.train:
            # Common preprocessing shared by labeled and unlabeled training
            # images: reflect-padded 32x32 random crop + horizontal flip.
            self.transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ])

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                # Python 2 pickles need no encoding argument; 'latin1' keeps
                # the byte strings inside the CIFAR batches intact on Python 3.
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                # CIFAR-10 batches store 'labels'; CIFAR-100 stores 'fine_labels'.
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

        if self.train:
            # The 4,000-image labeled split (the AutoAugment policy-search
            # subset per the README); the file holds one JSON object whose
            # 'values' entry is a list of example indices as strings.
            with open('./data/cifar-10-batches-py/cifar_label_map_count_4000_index_0', 'r') as f:
                label_map_str = f.readlines()[0]
                label_map = json.loads(label_map_str)['values']
                label_map = [int(label) for label in label_map]

            if self.uda:
                if add_labeled_to_unlabeled:
                    # Unlabeled set keeps every training image, labeled ones
                    # included; targets are dropped entirely.
                    print ("UDA with labeled set")
                    self.targets = None
                else:
                    # Unlabeled set = training images minus the labeled subset.
                    print ("UDA without labeled set")
                    self.data = np.delete( self.data, label_map, axis=0 )
                    self.targets = None
            else:
                # Supervised branch: restrict data/targets to the labeled subset.
                self.data = np.take(self.data, label_map, axis=0)
                self.targets = np.take(self.targets, label_map, axis=0)

        print ("loaded data count {}".format(len(self.data)))

    def _load_meta(self):
        """Load the class-name metadata file and build ``classes`` / ``class_to_idx``."""
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not check_integrity(path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                data = pickle.load(infile)
            else:
                data = pickle.load(infile, encoding='latin1')
            self.classes = data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
            For the UDA (unlabeled) split the return is instead
            (image, augmented_image), both float tensors of shape (3, 32, 32).
        """

        '''
        Labeled data
        X_raw --> [common preproc (filp/random crop)] --> GCN --> X

        Unlabeled data
        X_raw --> [common preproc (filp/random crop)] --> GCN --> X
        X_raw --> AutoAugment --> [common preproc (filp/random crop)] --> GCN --> X_aug

        Test data
        X_raw --> GCN --> X

        '''
        # During training __len__ reports eval_iter * batch size, which may
        # exceed the true dataset size, so wrap the index around.
        index = index % len(self.data)

        if self.uda:
            # UNLABELED
            img_raw = self.data[index] # image with shape 32,32,3
            img_raw = Image.fromarray(img_raw) # PIL image

            # self.autoaugment is the unlabeled-branch augmentation configured
            # in the portion of __init__ above this chunk -- TODO confirm.
            img_uda = self.transform(self.autoaugment(img_raw)) # torch tensor shape 3,32,32
            img = self.transform(img_raw) # torch tensor shape 3,32,32
            return img.type(torch.FloatTensor), img_uda.type(torch.FloatTensor)
        else:

            img, target = self.data[index], self.targets[index] # image with shape 32,32,3
            img = Image.fromarray(img) # PIL image
            if self.train:#LABELED
                if self.autoaugment_labeled is not None:
                    img = self.autoaugment_labeled(img)
                img = self.transform(img) # torch tensor shape 3,32,32
            else:#TEST
                img = transforms.ToTensor()(img)
            return img.type(torch.FloatTensor), target

        # NOTE(review): the triple-quoted block below is an unreachable legacy
        # implementation kept for reference (both branches above return).
        '''
        if self.uda:
            img = self.data[index] # image with shape 32,32,3
        else:
            img, target = self.data[index], self.targets[index] # image with shape 32,32,3
        img = Image.fromarray(img) # PIL image

        if self.train:

            if self.uda:
                # apply UDA
                img_uda = self.autoaugment(img) # torch tensor shape 3,32,32
                img_uda = self.transform(img_uda)

            img = self.transform(img) # torch tensor shape 3,32,32
            #img.clamp_(0.0, 1.0) # remove overflow due to gaussian noise

        if self.uda:
            return img.type(torch.FloatTensor), img_uda.type(torch.FloatTensor)
        else:
            img = transforms.ToTensor()(img)

        return img.type(torch.FloatTensor), target
        '''

    def __len__(self):
        # During training the reported length is the number of samples drawn
        # between evaluations (an iteration-based schedule), not the dataset
        # size; __getitem__ compensates with a modulo over the real data.
        if self.train:
            if self.uda:
                return self.args.eval_iter * self.args.batch_size_unsup
            else:
                return self.args.eval_iter * self.args.batch_size
        else:
            return len(self.data)

    def _check_integrity(self):
        """Return True iff every train/test batch file exists with the expected md5."""
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        """Download and extract the dataset archive unless it is already present."""
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_url(self.url, self.root, self.filename, self.tgz_md5)

        # extract file
        # NOTE(review): tarfile.extractall on an untrusted archive allows path
        # traversal (CVE-2007-4559). tgz_md5 is passed to download_url,
        # presumably for checksum verification -- confirm before pointing
        # `url` at an untrusted host.
        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)

    def extra_repr(self):
        # Human-readable split tag appended to the dataset's repr.
        return "Split: {}".format("Train" if self.train is True else "Test")


class CIFAR100(CIFAR10):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    This is a subclass of the `CIFAR10` Dataset.
    """
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }


# Ad-hoc manual check: build normalized and unnormalized loaders, pull one
# batch from each, apply GCN, and drop into a debugger for inspection.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=100)
    parser.add_argument('--eval-iter', type=int, default=10000)
    parser.add_argument('--batch-size-unsup', type=int, default=960)
    parser.add_argument('--gaussian-noise-level',type=float, default=0.15)
    args = parser.parse_args()

    cifar10_unnormalize = CIFAR10(args, False, False, normalize=False)
    cifar10_normalize = CIFAR10(args, False, False, normalize=True)

    loader_normalize = torch.utils.data.DataLoader( cifar10_normalize, batch_size=args.batch_size, shuffle=False, num_workers=1, pin_memory=True )
    loader_unnormalize = torch.utils.data.DataLoader( cifar10_unnormalize, batch_size=args.batch_size, shuffle=False, num_workers=1, pin_memory=True )

    batch_normalize = next(iter(loader_normalize))[0]
    batch_unnormalize = next(iter(loader_unnormalize))[0]

    from train_semi_3 import global_contrast_normalize, ZCA
    zca_params = torch.load('./zca_params.pth')
    zca = ZCA(zca_params)
    batch_unnormalize_gcn = global_contrast_normalize( batch_unnormalize )
    import ipdb;ipdb.set_trace()
-------------------------------------------------------------------------------- 1 | {"values": ["12221", "18919", "5222", "24534", "2716", "6710", "31016", "23466", "22453", "37466", "25263", "6654", "5966", "10199", "6541", "20001", "43087", "10115", "24878", "43781", "28268", "7268", "1974", "43948", "9364", "38024", "25881", "16324", "27992", "44476", "7088", "2597", "12541", "17832", "24984", "41667", "37010", "8746", "21469", "10082", "12833", "29632", "6481", "11697", "2785", "15291", "6422", "7636", "7396", "22788", "3504", "22736", "27796", "13590", "5193", "11954", "13666", "19376", "19611", "33959", "5843", "17729", "20234", "2792", "20492", "41665", "192", "30728", "6270", "10274", "10465", "10134", "10003", "11723", "17776", "4137", "15103", "4781", "36177", "28849", "43710", "37223", "28175", "5219", "12057", "1529", "42151", "27524", "20377", "12090", "2813", "29459", "30162", "4857", "43308", "18170", "38660", "7378", "40738", "19290", "22463", "38241", "41408", "35193", "23491", "38472", "27511", "2691", "15619", "39301", "26784", "353", "32939", "4412", "44219", "33888", "16591", "32117", "38806", "9266", "3246", "13646", "3751", "29653", "28820", "16392", "11678", "16477", "26580", "30878", "21172", "16453", "4556", "19786", "34752", "35669", "13379", "1916", "6425", "17973", "39028", "14896", "18741", "7756", "44818", "41831", "25448", "44937", "27493", "2552", "28151", "2162", "23133", "9986", "18902", "7250", "41144", "29678", "11986", "6108", "4945", "11636", "26231", "16675", "8627", "36925", "28488", "26163", "14626", "33680", "28632", "1751", "32710", "43613", "42795", "32310", "20551", "23618", "18386", "42993", "31959", "44852", "43863", "42843", "9828", "19467", "9639", "36364", "10566", "39675", "7445", "33833", "8368", "25900", "20210", "35046", "28491", "3916", "9111", "23722", "26650", "10363", "20868", "26440", "39672", "6912", "41518", "34183", "30826", "43626", "44054", "1764", "10688", "16087", "13900", "24489", "30531", "8887", 
"30034", "41208", "3053", "15765", "38904", "1414", "13310", "8743", "19033", "25111", "15427", "22625", "19560", "12844", "13682", "12969", "30360", "9443", "37145", "32798", "3652", "10869", "6988", "26109", "8141", "3661", "44313", "31663", "4371", "16729", "20007", "2136", "43696", "13738", "34619", "9751", "19337", "3102", "23573", "6350", "26345", "28258", "41262", "39945", "27458", "41465", "41493", "23633", "14166", "20735", "9892", "23354", "22014", "9794", "3521", "14121", "31850", "12463", "25388", "4772", "18616", "26885", "39413", "42188", "14537", "40990", "25233", "22064", "18209", "32756", "480", "26379", "8622", "20340", "22780", "25434", "28136", "3210", "25459", "14718", "40772", "30449", "1016", "12798", "30153", "44083", "28015", "14331", "44588", "32218", "30001", "6332", "32080", "39598", "30082", "29483", "18345", "15938", "17400", "39888", "38991", "43866", "32494", "43926", "4134", "4084", "2650", "1170", "42103", "3257", "40761", "43043", "18706", "19944", "26647", "7771", "30199", "7836", "8461", "30108", "2529", "33130", "6291", "24145", "38117", "35409", "36896", "19916", "43221", "6636", "22499", "7204", "20051", "19688", "26783", "12181", "37339", "25174", "21857", "6876", "44175", "1386", "14341", "18454", "20080", "36524", "15152", "30762", "42466", "17011", "1330", "10107", "34654", "6779", "41942", "404", "22477", "34421", "30854", "35025", "32505", "2845", "32452", "42423", "38094", "39467", "19028", "21112", "35552", "20834", "36309", "24577", "24856", "7743", "40245", "34314", "33024", "32180", "11461", "3789", "36804", "14115", "40971", "17905", "1311", "7861", "27210", "20531", "41186", "27634", "3178", "32499", "38930", "8625", "18894", "16482", "22414", "31032", "1959", "6966", "1309", "20038", "30918", "31886", "37564", "10814", "28669", "12627", "1412", "41538", "15176", "29423", "10740", "31364", "18383", "5217", "5040", "39131", "9748", "22198", "36777", "41199", "35652", "18808", "29533", "17147", "4358", "12226", 
"43409", "23413", "5012", "27378", "8821", "41584", "35051", "15934", "29070", "26539", "41742", "31223", "12407", "4716", "12135", "27344", "4689", "9932", "33034", "10299", "14031", "5428", "25806", "18104", "28436", "21684", "13316", "41608", "23338", "26461", "44332", "440", "4900", "8932", "9368", "10250", "14876", "24919", "32023", "139", "146", "1392", "9268", "30443", "34141", "39655", "12101", "7544", "22919", "24681", "43623", "34006", "26975", "12449", "10499", "4654", "34409", "25903", "29573", "28490", "2469", "14745", "7050", "28921", "29900", "30922", "1980", "3382", "40712", "4701", "20191", "40825", "23352", "3283", "16610", "1639", "44639", "38272", "39777", "10543", "41836", "42627", "27672", "20818", "32120", "41564", "29769", "19950", "23048", "29450", "29615", "24797", "3835", "5261", "5812", "24139", "6823", "38677", "23139", "20589", "17639", "21725", "1754", "41887", "19195", "29505", "31601", "8050", "29818", "27044", "7559", "4528", "18829", "21932", "13077", "39782", "31344", "22595", "4603", "5884", "8509", "26853", "36408", "41839", "9608", "11627", "11421", "7263", "42014", "33297", "17453", "3998", "12513", "31356", "11229", "17130", "23982", "10795", "22151", "44228", "7127", "20651", "30320", "8153", "4176", "3732", "19947", "30131", "21271", "30657", "38872", "40292", "27860", "34083", "31812", "30445", "22773", "19233", "9179", "39616", "17521", "13160", "17096", "20849", "16828", "14545", "21371", "40451", "13448", "15205", "7273", "35412", "21383", "442", "25404", "11373", "38923", "38445", "11671", "38043", "17630", "21308", "11262", "28224", "6605", "21266", "887", "42981", "31806", "39966", "27445", "16511", "9539", "21678", "30075", "28919", "36856", "34717", "27236", "43655", "10694", "38496", "18730", "40044", "5451", "13701", "4156", "26129", "32510", "274", "22799", "39609", "22634", "22814", "3163", "5565", "13711", "19966", "1555", "39129", "13358", "23720", "34024", "34549", "16114", "34750", "29747", "41427", 
"17587", "12112", "12686", "34424", "29412", "27194", "31628", "33139", "3048", "22466", "23086", "34663", "23540", "6231", "4977", "42256", "24612", "29327", "115", "24962", "38670", "14238", "11176", "21874", "144", "23799", "24436", "33422", "44899", "2372", "35422", "17085", "10294", "30274", "29648", "20811", "10454", "11443", "14774", "14619", "23067", "20298", "43380", "22577", "25424", "9327", "39333", "36630", "21795", "24335", "15329", "30923", "29471", "3673", "3248", "35244", "5112", "11532", "13499", "44095", "7503", "41996", "44179", "7853", "34250", "28840", "31171", "18711", "42323", "14290", "21082", "18970", "6274", "44919", "20181", "13641", "41423", "13732", "38457", "15448", "570", "35827", "10876", "7041", "8556", "27179", "5498", "42537", "6630", "37503", "34991", "40373", "31967", "28971", "20995", "2245", "21420", "29853", "9526", "40672", "7854", "4936", "33148", "15923", "12185", "16617", "5769", "42076", "19436", "17510", "17506", "17890", "936", "4507", "20243", "38767", "24308", "3812", "39779", "40947", "42599", "6758", "15332", "29572", "30212", "32069", "12747", "34437", "39985", "2070", "28115", "14333", "32839", "14181", "16627", "31117", "39423", "33752", "2733", "29101", "17417", "17885", "36934", "4229", "38892", "12145", "35752", "40302", "16887", "39969", "12351", "40758", "1003", "5643", "8193", "23916", "16770", "20230", "1059", "24328", "41945", "22901", "25421", "4246", "21912", "36143", "30007", "5047", "989", "40655", "33149", "23492", "41308", "16855", "42213", "32612", "8508", "21280", "3641", "30242", "19251", "5403", "37955", "5068", "40674", "17950", "7028", "43954", "9819", "17160", "22953", "1750", "31390", "4718", "41986", "35219", "854", "32532", "40363", "27943", "18038", "13361", "16126", "21222", "12357", "2910", "22596", "15524", "18548", "32900", "715", "22871", "15692", "17575", "15033", "37418", "8108", "6523", "29499", "39705", "670", "14662", "36184", "3526", "7602", "44843", "43185", "44884", "31974", 
"7873", "26581", "1448", "22091", "39947", "24624", "16937", "12528", "29303", "3286", "36494", "39393", "7567", "32903", "20642", "19762", "38876", "8132", "20713", "35575", "19302", "16594", "791", "22831", "32273", "15815", "22147", "26090", "21811", "36322", "41691", "5051", "38819", "33726", "11613", "36760", "30343", "43769", "34738", "10379", "15657", "34287", "20296", "25178", "22974", "32514", "41975", "14027", "12884", "16475", "3502", "36690", "7348", "8766", "27437", "5466", "22189", "41316", "11682", "9935", "10077", "44123", "18363", "27299", "26377", "15874", "17932", "1642", "6652", "14386", "11500", "26701", "27767", "28526", "21163", "1148", "26724", "18888", "12476", "745", "7246", "345", "5506", "3729", "19037", "11841", "17464", "29334", "35808", "37208", "40799", "12789", "4547", "42018", "18019", "1884", "14221", "30276", "11101", "42286", "43491", "34874", "25028", "39883", "43739", "27160", "25809", "28714", "4671", "24941", "19606", "965", "14047", "5655", "13174", "26414", "8118", "35484", "26411", "13422", "25805", "31552", "8603", "36685", "17811", "33795", "35379", "10022", "2853", "7464", "23290", "10020", "19086", "15822", "4510", "43605", "41952", "326", "42894", "8289", "31826", "19343", "13354", "34053", "24061", "16652", "43038", "38758", "19671", "4102", "42374", "22995", "34524", "16056", "43481", "12682", "30420", "32408", "5475", "15023", "9035", "5562", "1372", "36000", "37677", "26428", "15303", "34351", "2185", "31352", "9289", "28783", "7927", "7316", "35284", "10341", "19826", "8131", "25860", "22628", "41929", "1498", "20749", "22065", "38857", "10917", "26805", "21190", "14661", "21103", "26670", "28810", "3997", "44618", "23507", "25931", "7708", "7026", "15060", "4842", "33074", "33741", "41603", "38588", "3492", "31610", "42514", "39462", "5190", "36794", "30736", "28017", "44573", "12437", "44137", "21637", "36313", "15317", "4666", "33590", "38007", "32685", "3264", "19567", "40330", "31487", "6127", "36882", 
"26140", "42112", "44201", "36955", "6546", "13600", "16317", "31581", "2849", "12249", "19146", "13106", "17063", "40417", "26527", "15932", "12788", "40441", "10035", "13131", "13568", "40189", "6487", "8397", "34519", "41485", "13724", "2254", "24594", "35148", "38747", "43916", "387", "30621", "8146", "28895", "25018", "24346", "13731", "35605", "43346", "11672", "16448", "41762", "28777", "6125", "21081", "31377", "41447", "11960", "31989", "21786", "5570", "44745", "5532", "37981", "39813", "35893", "22009", "13855", "17219", "15726", "16807", "13018", "12637", "11811", "44811", "41449", "32539", "39012", "34256", "36582", "23173", "4145", "19312", "36014", "21455", "35986", "44581", "640", "20637", "24891", "20854", "16959", "39114", "19951", "30853", "36798", "16186", "23084", "28555", "32782", "38212", "21524", "12684", "36340", "7082", "24274", "37284", "9392", "7533", "36164", "33101", "33330", "23707", "28406", "12408", "3332", "8796", "12188", "42330", "18923", "41697", "14248", "26604", "17540", "13566", "33332", "16971", "41727", "2487", "26518", "722", "14273", "35464", "20890", "31618", "32385", "23349", "2932", "9651", "18294", "13505", "10290", "16347", "15062", "31756", "18568", "29149", "36223", "22136", "17435", "23277", "23603", "16653", "32699", "21585", "37635", "35494", "28984", "3165", "19150", "14040", "32191", "23701", "39307", "39987", "7453", "1351", "15079", "26262", "2066", "33142", "11273", "43037", "36203", "39279", "27601", "43451", "36512", "5041", "35709", "2755", "41484", "34884", "42044", "4417", "29009", "35764", "24221", "32642", "39732", "11963", "16873", "19954", "33861", "9313", "37595", "13847", "17125", "11519", "17315", "44130", "25484", "22126", "11425", "12745", "10724", "28720", "25982", "36220", "3871", "14176", "9211", "8420", "31488", "18637", "15164", "15937", "18212", "31922", "31430", "6587", "1136", "9501", "12921", "29822", "7345", "21963", "11050", "6504", "38722", "39542", "26056", "11787", "30574", 
"16700", "18127", "27687", "28551", "17437", "10968", "3590", "33020", "7921", "25250", "6182", "23059", "44323", "21326", "12516", "19824", "9408", "39350", "43237", "30727", "272", "31787", "807", "11617", "37288", "30949", "44268", "5581", "42922", "4606", "4699", "41551", "8216", "2077", "14450", "44310", "33440", "27105", "20824", "42669", "8197", "30794", "35123", "8045", "18273", "35826", "18723", "9958", "38802", "17550", "18411", "19444", "4317", "38860", "17704", "40489", "36345", "30807", "25597", "40569", "40152", "11133", "39786", "5848", "29825", "101", "10181", "24714", "32262", "38155", "16983", "13181", "4868", "40170", "30322", "22731", "7789", "37691", "34583", "31145", "20945", "25306", "4980", "36771", "5482", "5993", "3396", "43267", "39245", "28586", "27455", "40159", "2715", "24032", "36441", "43653", "7622", "9812", "9042", "7839", "37860", "32432", "30650", "18083", "11388", "8545", "3040", "10010", "37578", "41545", "31759", "21793", "40808", "3529", "515", "24701", "1503", "30786", "445", "34682", "43462", "11476", "18977", "40627", "803", "14373", "41919", "5901", "27066", "33311", "26499", "36967", "34724", "22174", "44978", "44229", "36731", "2935", "7733", "19005", "36453", "42017", "31838", "42209", "40766", "122", "4897", "38259", "28228", "37960", "21341", "37367", "25916", "13415", "34304", "22430", "267", "5715", "19045", "25308", "22273", "17754", "11819", "21473", "25413", "4203", "14964", "20600", "38023", "33294", "10483", "39458", "23255", "13051", "35247", "34055", "7618", "10296", "25782", "26903", "11792", "12672", "39385", "29813", "28307", "21483", "12468", "3070", "23614", "19927", "24658", "10750", "25007", "4126", "26439", "18705", "11497", "11071", "34016", "41627", "30539", "15799", "12880", "15663", "11272", "39444", "44430", "27752", "25593", "14209", "32194", "42630", "43437", "21264", "17979", "21652", "44783", "6237", "24808", "27988", "2836", "19475", "38293", "10052", "5008", "9736", "41114", "27951", 
"4597", "41520", "17441", "41278", "8682", "38307", "10179", "15614", "15097", "28396", "3195", "40964", "5149", "24427", "11302", "29069", "19184", "39477", "20215", "40801", "44912", "39690", "17444", "40970", "15839", "39023", "44166", "42215", "38947", "41040", "18014", "9291", "26207", "26917", "37723", "2851", "11624", "32225", "14294", "44048", "14285", "43688", "24650", "14247", "34310", "24591", "21378", "41453", "33708", "601", "37764", "22729", "22697", "14274", "26221", "29548", "17144", "23200", "15699", "17465", "32260", "7337", "12401", "21754", "29375", "44257", "22592", "34151", "37050", "35922", "2144", "3437", "36710", "24672", "29317", "44225", "16767", "1677", "7023", "15348", "22940", "33410", "21975", "11771", "20211", "6389", "17478", "36199", "13278", "6227", "15288", "27253", "41896", "30571", "34242", "31380", "4546", "21040", "40501", "40276", "39302", "6158", "42936", "10325", "41129", "10448", "1496", "35617", "31057", "11921", "43173", "19732", "22864", "19956", "20548", "32555", "23675", "17674", "36515", "39013", "4550", "23324", "31982", "21205", "14504", "38385", "6268", "41875", "40013", "25852", "22444", "21407", "25267", "21431", "26050", "4808", "27743", "42482", "10277", "16321", "43829", "37219", "6600", "37336", "6204", "17523", "13024", "14", "42367", "40960", "27953", "6401", "37833", "18802", "33510", "39212", "25102", "19038", "25686", "903", "938", "17538", "37787", "43181", "12717", "4602", "26749", "29896", "405", "40471", "30600", "30143", "19035", "18210", "26410", "14563", "15487", "10476", "38870", "30573", "16694", "9934", "18684", "29280", "37179", "2087", "4281", "17529", "39865", "36866", "14641", "40588", "3211", "13875", "41709", "21437", "7882", "24775", "29883", "43327", "28390", "14573", "37759", "9319", "22761", "21195", "31843", "22045", "31294", "9304", "24950", "18629", "38986", "10547", "11180", "40582", "19858", "35259", "5783", "33430", "3956", "34872", "17003", "6468", "15093", "39091", "17108", 
"17675", "22708", "15292", "44793", "12656", "31958", "6993", "21836", "33985", "7421", "12157", "22645", "11184", "44335", "7701", "31100", "35095", "20650", "20381", "24070", "2921", "13768", "7129", "21046", "24628", "21806", "13706", "42651", "15199", "34122", "29882", "44004", "848", "3054", "1289", "7643", "7746", "32196", "18621", "37548", "32162", "20840", "265", "17768", "15464", "43447", "37458", "18372", "29768", "2584", "24449", "8979", "11726", "24774", "40435", "37755", "33393", "38051", "8055", "16025", "37362", "24657", "40579", "24382", "7670", "43379", "30606", "23713", "28070", "28244", "21660", "3757", "15032", "39185", "3445", "970", "18600", "2583", "5803", "33087", "13215", "4845", "20786", "43348", "25894", "17948", "17107", "44681", "15581", "39769", "42560", "36526", "26940", "19064", "41324", "41570", "7703", "41979", "9193", "30656", "12276", "62", "37104", "6878", "36930", "32047", "15476", "35540", "40840", "30594", "16629", "37199", "4405", "43015", "6476", "7155", "17168", "39829", "12361", "35906", "13005", "7096", "11371", "17935", "38636", "5379", "9166", "4931", "7716", "11491", "36021", "10983", "21861", "27725", "37941", "23817", "23025", "1948", "13498", "7442", "7538", "15209", "42345", "42902", "36792", "3901", "21770", "8866", "134", "23857", "10562", "41249", "7092", "25353", "44050", "44145", "36086", "7757", "40693", "27984", "15685", "34570", "24360", "29864", "4112", "9533", "20963", "24060", "25495", "41043", "39344", "34281", "39892", "30405", "39292", "28599", "18027", "33969", "29084", "37168", "22802", "36620", "26690", "34139", "13374", "24252", "35940", "38312", "9662", "9868", "851", "4225", "6730", "2881", "43372", "3012", "43297", "23973", "41234", "32371", "33439", "7318", "643", "37767", "43026", "19022", "18563", "41813", "12527", "32284", "21453", "43086", "39336", "43398", "13746", "2507", "28068", "23636", "31537", "28751", "9842", "28647", "25849", "43123", "44432", "606", "891", "21119", "26734", 
"30340", "44450", "8185", "11886", "24180", "4179", "17352", "1231", "33564", "35223", "41976", "16768", "3772", "6975", "26778", "14921", "29935", "26039", "44182", "38463", "8445", "24686", "25708", "11327", "10286", "10463", "36413", "9470", "34401", "27658", "2421", "11945", "13030", "10380", "41662", "13340", "11837", "16273", "12871", "33042", "2901", "14724", "17480", "32526", "39584", "13163", "42148", "18799", "17265", "27663", "42607", "42707", "19256", "26218", "29598", "19340", "16986", "18045", "5030", "5421", "15470", "17183", "6145", "19149", "3813", "4414", "28787", "43064", "42505", "808", "32634", "19326", "12056", "19635", "25937", "3407", "39399", "14686", "28473", "3711", "42296", "14771", "8524", "25137", "15798", "22535", "22672", "30066", "4228", "25325", "24120", "20421", "27097", "23642", "43088", "34087", "42248", "21373", "33163", "5105", "35134", "11488", "20833", "10038", "770", "2660", "8669", "41100", "29995", "16142", "32874", "38901", "33615", "16078", "18390", "31539", "7356", "13961", "27593", "38587", "2750", "29098", "33762", "29484", "36211", "1894", "38917", "2191", "3861", "28614", "24398", "13546", "15578", "24165", "11080", "23998", "9412", "15248", "16314", "15460", "3310", "2999", "17819", "17857", "2520", "34004", "6704", "22986", "16250", "39760", "35590", "35790", "1730", "38634", "25964", "37797", "12852", "22214", "28124", "5676", "11006", "32559", "16021", "10577", "4893", "30969", "26515", "5173", "11404", "13667", "32585", "14206", "34261", "4802", "15713", "20217", "28924", "25328", "43512", "29682", "34036", "40871", "6429", "4661", "4171", "28887", "12304", "5830", "25311", "25217", "42806", "14217", "4364", "39631", "35079", "42849", "32328", "13609", "11660", "2132", "34534", "9579", "32696", "41682", "31993", "23557", "32581", "23591", "21644", "41974", "10872", "40578", "36617", "28314", "22254", "24191", "41995", "14899", "21286", "33451", "42572", "37441", "37762", "42769", "14311", "1243", "4014", 
"1004", "38265", "20726", "25344", "35843", "18915", "35995", "448", "1337", "26196", "38526", "26927", "44204", "15133", "22272", "33629", "44815", "34947", "17004", "18136", "784", "38712", "5473", "34551", "19513", "43358", "25241", "34645", "28226", "11721", "41539", "19021", "23984", "2050", "40946", "8818", "23803", "40793", "18250", "24855", "30190", "38089", "4840", "13870", "17141", "24269", "1828", "37246", "27297", "15121", "8550", "7416", "12592", "23860", "6309", "35612", "35619", "16420", "33043", "34815", "13082", "38096", "30423", "33481", "1182", "44038", "2288", "15620", "39437", "591", "15166", "6283", "2580", "2248", "10358", "15509", "485", "38113", "40154", "31362", "31473", "20192", "35570", "6505", "15373", "42737", "3625", "23879", "30086", "26911", "32917", "27871", "39134", "2664", "20272", "40763", "41997", "40905", "9436", "19718", "42877", "8551", "38618", "28305", "37659", "17965", "17696", "4775", "6255", "3009", "13512", "38354", "26247", "40902", "32402", "11579", "4327", "24770", "35309", "40571", "6890", "27768", "23026", "31135", "10632", "25568", "39096", "44648", "24704", "13263", "8619", "23556", "16243", "14873", "42989", "17680", "22062", "13396", "23670", "165", "13308", "44075", "13652", "5133", "17667", "35759", "25181", "10984", "26013", "26608", "40295", "32869", "10027", "42395", "11406", "10337", "16443", "19201", "16724", "35533", "20317", "31722", "6229", "5630", "33112", "16891", "36210", "33063", "39287", "32415", "36749", "23369", "42349", "18219", "35024", "24503", "19586", "35920", "811", "39774", "35138", "12079", "7932", "22578", "14815", "25863", "10754", "23142", "39735", "11177", "3136", "43443", "19969", "9425", "18347", "34206", "17708", "26594", "3744", "16483", "38015", "26521", "44383", "35848", "21109", "21296", "31334", "2820", "41136", "37886", "28570", "25389", "18880", "18647", "39665", "38974", "15671", "5996", "44267", "22383", "43913", "17455", "13599", "33384", "44955", "27919", "5296", 
"3797", "17614", "16020", "21720", "16147", "4076", "18393", "42160", "7913", "29717", "26133", "8975", "3371", "20481", "34491", "13225", "1025", "1770", "17386", "13589", "35742", "27884", "23096", "19409", "24854", "1250", "22448", "9286", "30396", "7270", "18640", "19202", "21336", "31245", "14536", "42383", "4797", "19490", "32837", "27153", "26467", "28099", "22498", "40392", "33609", "21625", "28366", "32480", "29516", "5177", "35382", "39686", "1383", "26363", "13239", "20300", "42131", "37474", "41072", "3578", "9396", "43035", "43709", "4111", "1508", "22132", "8478", "39734", "32659", "41304", "29920", "35356", "16877", "11865", "29710", "41893", "33175", "40744", "1681", "17706", "11054", "10715", "5459", "14491", "5381", "12744", "11692", "8549", "28975", "12984", "32972", "21523", "37241", "41806", "36553", "19658", "6572", "40586", "4990", "10631", "4284", "29982", "23510", "36575", "5778", "35473", "22252", "19381", "6982", "34664", "36883", "36881", "24683", "3676", "34679", "35209", "32300", "33593", "3721", "23425", "12606", "27647", "21635", "379", "43877", "3327", "37349", "34778", "42979", "15590", "33622", "19759", "1164", "25055", "859", "31179", "36379", "4914", "12876", "1646", "32815", "21690", "6036", "9180", "33360", "7671", "36832", "30273", "31559", "20717", "29378", "31970", "14041", "35455", "36579", "31398", "40315", "11378", "43254", "14032", "3259", "16514", "29677", "41971", "39370", "28434", "39146", "36697", "22285", "9012", "11571", "40928", "29820", "11969", "10519", "11270", "36944", "32676", "27181", "27719", "25992", "31375", "5631", "14814", "2246", "30919", "20263", "10087", "6749", "26913", "6777", "35368", "37909", "43119", "6421", "1110", "24087", "33227", "1416", "28769", "28646", "43240", "13472", "26496", "2538", "34652", "44060", "36562", "10434", "32721", "22490", "29381", "30934", "22492", "44542", "40086", "34042", "4350", "8246", "41080", "14127", "8501", "7774", "3276", "18054", "18161", "42140", "42328", 
"29453", "4471", "17938", "20200", "10749", "40369", "44693", "17993", "2866", "23504", "32776", "13658", "26694", "22268", "3138", "23061", "5605", "20525", "19110", "17870", "6895", "30764", "8139", "37904", "16551", "24744", "33554", "33920", "14212", "36520", "14364", "25563", "5609", "29524", "31146", "32592", "27843", "4882", "14902", "17719", "32863", "2972", "17361", "19469", "31481", "37113", "7398", "13839", "38294", "29357", "29760", "269", "9977", "32309", "22104", "23582", "18590", "16941", "24321", "25195", "40796", "28141", "5808", "16225", "34099", "30493", "35104", "24276", "24815", "31866", "30715", "148", "33400", "22467", "39507", "11196", "15881", "32332", "26003", "1812", "28509", "24076", "13237", "4649", "33528", "23360", "27230", "40822", "23613", "34171", "11623", "5522", "23748", "40665", "27347", "3503", "21402", "35881", "40511", "23883", "36970", "21078", "42134", "21685", "17310", "18846", "38006", "653", "39156", "9641", "3141", "40688", "17869", "38284", "39906", "8468", "12559", "41082", "8896", "39110", "25345", "14272", "29157", "32230", "1541", "22540", "10224", "30867", "5014", "16450", "40714", "16886", "9695", "1717", "13119", "25349", "21723", "16824", "27607", "6706", "8741", "40994", "35835", "22474", "4608", "8691", "26160", "35667", "42678", "42585", "9551", "11894", "38687", "17735", "7143", "15068", "28080", "13475", "24709", "25613", "37315", "21010", "41079", "20367", "4763", "32836", "39004", "44353", "17862", "4403", "19983", "2186", "1022", "27766", "6074", "38593", "42644", "21288", "43477", "34530", "12693", "35772", "7153", "13502", "16999", "11717", "14898", "10760", "30658", "38522", "24557", "10094", "7114", "9910", "699", "40852", "44829", "13591", "32049", "34672", "35021", "34554", "17320", "26697", "37760", "21773", "41842", "7007", "38279", "22854", "36252", "16590", "4554", "24499", "43750", "4634", "7795", "27851", "13013", "20919", "8456", "12486", "9833", "15853", "28167", "35729", "19744", "31333", 
"7303", "44600", "40027", "39587", "35043", "7905", "5060", "2596", "16360", "36357", "20515", "16996", "20015", "33885", "35408", "39974", "37136", "11154", "28443", "18344", "8833", "3101", "15999", "11141", "9775", "18671", "16310", "25391", "2349", "28384", "23554", "352", "43449", "43986", "41676", "17038", "36918", "1661", "8565", "31228", "40765", "42010", "31150", "38206", "19720", "3940", "12415", "19870", "19928", "3127", "849", "7139", "4505", "2987", "26598", "24860", "35411", "35714", "28972", "22684", "41885", "38066", "21669", "41270", "39296", "14367", "31673", "32357", "2257", "40234", "34920", "20120", "16357", "33923", "31772", "41050", "27862", "18967", "26073", "9488", "35726", "14730", "3859", "40197", "19107", "8033", "35042", "17787", "25831", "40351", "7876", "28893", "17533", "23041", "40853", "29076", "11374", "35834", "16041", "20708", "8624", "6928", "38153", "27151", "19629", "19068", "24665", "36071", "35901", "35586", "26098", "42901", "34041", "28253", "41376", "42801", "22159", "2088", "34889", "10854", "44942", "9904", "43727", "28734", "7806", "3122", "38948", "17498", "35006", "10937", "42589", "11144", "44499", "44571", "8931", "34739", "25714", "42885", "30608", "44662", "5681", "25943", "31237", "33201", "13097", "13628", "25061", "38886", "22324", "29229", "33134", "20657", "5930", "5472", "10563", "20516", "24967", "4858", "13421", "32437", "12714", "13820", "37792", "21929", "13484", "3221", "31718", "38519", "33711", "20471", "42681", "7152", "19929", "38167", "37386", "17246", "29759", "1213", "35781", "20766", "34840", "7511", "11018", "25080", "33576", "32901", "31005", "19001", "15243", "32778", "2736", "27854", "10385", "24424", "28069", "10737", "28430", "39765", "9756", "8294", "25933", "24176", "12512", "2604", "9029", "2517", "13339", "11910", "20320", "21223", "38524", "22403", "36657", "33403", "15912", "19921", "27614", "15540", "8643", "30676", "21760", "15107", "44946", "37598", "18615", "595", "38608", 
"44731", "18477", "22677", "32741", "13637", "35553", "17463", "29539", "41111", "4474", "5563", "8727", "30603", "2145", "42321", "7439", "6233", "34300", "12719", "19164", "23150", "26285", "3344", "11906", "43291", "14320", "19604", "22333", "3745", "39235", "43301", "14205", "10485", "348", "38648", "13550", "13960", "10652", "26185", "27126", "12267", "28996", "37925", "18066", "40593", "16740", "26569", "14302", "43157", "39319", "18013", "33093", "25074", "19980", "8463", "35544", "37954", "23820", "30858", "8134", "38149", "11376", "33067", "31011", "10784", "42142", "31686", "14568", "23157", "8020", "3369", "14769", "41157", "29961", "34161", "43389", "44963", "16252", "34759", "42566", "7266", "31290", "28264", "17094", "15715", "21427", "42293", "19990", "27669", "35266", "20095", "10500", "34822", "14624", "8667", "42428", "42615", "43583", "39428", "19025", "40839", "24133", "14901", "6260", "23273", "26659", "25097", "3920", "25177", "14520", "16538", "12353", "27088", "20532", "35185", "25467", "30940", "30861", "36653", "12195", "3556", "29050", "40477", "7175", "40664", "19757", "13762", "6828", "10097", "4754", "4473", "5405", "18737", "3109", "7436", "13995", "40391", "24965", "41955", "8580", "13287", "7864", "34288", "38869", "28074", "18449", "6670", "7564", "26674", "40049", "19085", "40262", "31869", "42500", "36879", "692", "23463", "31412", "40136", "15792", "3557", "38181", "17134", "35330", "28362", "15940", "17469", "42551", "13468", "29772", "9874", "33963", "44756", "32337", "34838", "9879", "4604", "41011", "16577", "9249", "20947", "23252", "25970", "13490", "5286", "17968", "44126", "9693", "29591", "12166", "10252", "43293", "16686", "915", "4991", "13606", "21282", "42542", "4738", "42733", "15440", "535", "13854", "28961", "3358", "22421", "5662", "1957", "25489", "33132", "26596", "9783", "12364", "23756", "11486", "8199", "30125", "8277", "44550", "1143", "17683", "26444", "6934", "41274", "27480", "21572", "29366", "14533", 
"39034", "34273", "36998", "37688", "13886", "40474", "19349", "37158", "29306", "21705", "24809", "32109", "14681", "29374", "42613", "34244", "5999", "23853", "13378", "16973", "34593", "1607", "8341", "8755", "39916", "11919", "42814", "2713", "18079", "15747", "5351", "31983", "12214", "4244", "18010", "20644", "15392", "12924", "4870", "24623", "39489", "23909", "44589", "14600", "44837", "333", "30379", "8824", "28616", "5597", "5380", "25381", "29847", "17409", "14214", "1888", "2643", "44147", "15283", "7962", "17496", "38893", "20023", "44841", "15507", "23121", "3786", "8842", "2943", "31716", "3301", "16111", "25830", "37160", "28861", "26092", "7791", "27855", "1732", "26668", "729", "21544", "2772", "28303", "1168", "29546", "22274", "32981", "43242", "44987", "12367", "19761", "36919", "27205", "28960", "11591", "7721", "41590", "2152", "20433", "19779", "44176", "33865", "7928", "4272", "3324", "2712", "21796", "4592", "21965", "2216", "25643", "33094", "43952", "31597", "40854", "8264", "43607", "20326", "15633", "9204", "15331", "31368", "40635", "3687", "2169", "44838", "30765", "31186", "17338", "1866", "6437", "41701", "41504", "23375", "33808", "1029", "35096", "2644", "982", "10644", "28274", "5703", "822", "22843", "9375", "6501", "28400", "23066", "21362", "34597", "6698", "30672", "5069", "17239", "17461", "28289", "3111", "7274", "37424", "41247", "15860", "664", "21363", "32221", "26989", "16838", "34040", "12554", "1373", "26154", "20327", "37096", "8077", "7758", "33204", "37093", "16017", "33428", "40374", "42790", "9261", "40356", "31693", "1119", "20843", "39541", "40041", "4706", "10900", "20122", "16935", "28775", "22036", "31650", "43071", "1270", "36095", "39249", "30213", "28908", "35172", "8219", "14239", "39651", "15075", "28948", "25509", "2382", "43271", "14692", "9118", "33944", "5883", "24072", "38851", "22291", "29140", "19193", "40636", "28703", "39415", "41575", "15779", "32305", "7693", "12892", "39316", "10287", 
"13359", "8371", "21584", "16498", "11856", "39667", "19433", "593", "12105", "33195", "23631", "5945", "27220", "12966", "29285", "30651", "44716", "6258", "17284", "863", "7887", "19532", "44144", "10802", "43513", "32947", "35239", "10426", "33492", "27056", "19131", "19764", "6002", "19812", "8098", "23264", "37917", "20579", "33859", "39021", "38716", "24944", "5094", "36201", "30866", "32314", "21856", "8479", "16466", "22845", "27795", "16064", "30259", "8823", "27705", "6470", "37863", "10388", "8966", "2323", "19828", "6826", "23153", "11139", "42693", "14953", "29511", "13168", "32125", "3377", "38756", "20603", "2440", "35887", "17286", "4918", "33269", "33906", "23", "23995", "24836", "10520", "10972", "13383", "21314", "5790", "41964", "36205", "39953", "36432", "41708", "37707", "28771", "7599", "27680", "4735", "17350", "10925", "37647", "2108", "17919", "41348", "15679", "35040", "36049", "2760", "23726", "22520", "20208", "41474", "27113", "1172", "112", "36942", "25934", "35766", "37198", "34589", "31309", "2904", "5435", "9654", "6662", "41417", "44940", "24663", "35022", "8171", "31003", "14696", "36369", "31933", "10774", "31999", "8175", "13371", "15100", "9517", "9068", "147", "26909", "7120", "2419", "16433", "41517", "9421", "29191", "18810", "43298", "30059", "7174", "40734", "41235", "27130", "16502", "2314", "38190", "29530", "41744", "27331", "7941", "7798", "19698", "31502", "33156", "6183", "436", "6215", "30611", "18907", "44091", "39986", "12979", "22489", "285", "9010", "7845", "26147", "19846", "12803", "44925", "34265", "39079", "14632", "24379", "21110", "30282", "20104", "38131", "34260", "11992", "38950", "31102", "12054", "4149", "41350", "14554", "28094", "28601", "40123", "42384", "41115", "14154", "37186", "12797", "37332", "27697", "15511", "30894", "19294", "15531", "41499", "3803", "38938", "10482", "36825", "19519", "28695", "20334", "7252", "16076", "39761", "26351", "9805", "35447", "29563", "2233", "9020", "9005", 
"40942", "1710", "10767", "37514", "20957", "2944", "37638", "42840", "20351", "23501", "9509", "41693", "35374", "5407", "18740", "12233", "14263", "30811", "7729", "19724", "44696", "33129", "42931", "20236", "25297", "39958", "15852", "6212", "466", "28218", "11876", "16688", "21695", "8326", "27476", "22481", "39404", "19118", "3912", "41638", "19141", "8664", "7541", "43178", "23810", "41398", "13720", "37922", "38604", "36850", "31778", "36275", "43618", "11933", "592", "31645", "8781", "40228", "8408", "38743", "11743", "28031", "737", "5445", "23794", "4114", "14794", "13802", "25483", "13678", "15049", "23262", "33200", "473", "12177", "11307", "29957", "44142", "15658", "40008", "34693", "29960", "32533", "1475", "21867", "41692", "21219", "27018", "17840", "42271", "9744", "10991", "41699", "8331", "38746", "17367", "42448", "24849", "11166", "11746", "5371", "819", "16846", "12870", "17152", "19389", "19291", "14149", "28312", "44515", "18483", "7104", "22191", "22856", "33193", "28634", "16082", "35064", "43825", "25460", "11408", "16496", "39118", "13352", "12346", "19415", "44153", "5223", "22949", "28738", "9664", "34791", "34958", "13186", "24728", "17734"]} -------------------------------------------------------------------------------- /cifar/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | import torch.utils.data 13 | import torchvision.transforms as transforms 14 | import torchvision.datasets as datasets 15 | from torch.autograd import Variable 16 | 17 | from wideresnet import WideResNet 18 | import numpy as np 19 | 20 | 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--dataset', type=str, default='cifar10') 24 | 
parser.add_argument('--arch', type=str, default='WRN-28-2') 25 | parser.add_argument('--name', type=str, required=True) 26 | 27 | # HYPER PARAMS 28 | parser.add_argument('--optimizer', type=str, default='Adam') 29 | parser.add_argument('--dropout-rate', type=float, default=0.3) 30 | parser.add_argument('--lr', type=float, default=3e-2) 31 | parser.add_argument('--final-lr', type=float, default=1.2e-4) 32 | parser.add_argument('--batch-size', type=int, default=100) 33 | parser.add_argument('--l1-reg', type=float, default=0.0) 34 | parser.add_argument('--l2-reg', type=float, default=1e-4) 35 | parser.add_argument('--lr-decay-rate', type=float, default=0.2) 36 | parser.add_argument('--max-iter', type=int, default=100000) 37 | parser.add_argument('--lr-decay-at', nargs='+', default=[80000, 90000]) 38 | parser.add_argument('--lr-decay', type=str, default='step') 39 | parser.add_argument('--normalization', type=str, default='GCN_ZCA') 40 | parser.add_argument('--num-workers', type=int, default=20) 41 | parser.add_argument('--print-freq', type=int, default=20) 42 | parser.add_argument('--split', type=int, default=0) 43 | parser.add_argument('--eval-iter', type=int, default=2000) 44 | parser.add_argument('--nesterov', action='store_true') 45 | 46 | parser.add_argument('--AutoAugment', action='store_true') 47 | parser.add_argument('--AutoAugment-cutout-only', action='store_true') 48 | parser.add_argument('--AutoAugment-all', action='store_true') 49 | parser.add_argument('--UDA', action='store_true') 50 | parser.add_argument('--UDA-CUTOUT', action='store_true') 51 | parser.add_argument('--use-cutout', action='store_true') 52 | parser.add_argument('--TSA', type=str, default=None) 53 | parser.add_argument('--batch-size-unsup', type=int, default=960) 54 | parser.add_argument('--unsup-loss-weight', type=float, default=1.0) 55 | parser.add_argument('--cifar10-policy-all', action='store_true') 56 | parser.add_argument('--clip-grad-norm', default=-1, type=float) 57 | 
def TSA_th(cur_step, tsa_mode=None, max_iter=None, num_classes=None):
    """Training Signal Annealing (TSA) confidence threshold at `cur_step`.

    Anneals from chance level (1/num_classes) up to 1.0 over `max_iter`
    iterations following a 'linear', 'log', or 'exp' schedule; any other
    mode disables TSA by returning 1.0.

    Args:
        cur_step: current training iteration.
        tsa_mode: 'linear' | 'log' | 'exp' | anything else (no TSA).
            Defaults to the global ``args.TSA`` for backward compatibility.
        max_iter: total training iterations; defaults to ``args.max_iter``.
        num_classes: number of classes; defaults to 10 for cifar10 else 100.
            (The original hard-coded 10, which is inconsistent with train()'s
            10/100 selection and wrong for CIFAR-100.)

    Returns:
        float threshold in (0, 1].
    """
    global args
    if tsa_mode is None:
        tsa_mode = args.TSA
    if max_iter is None:
        max_iter = args.max_iter
    if num_classes is None:
        # BUG FIX: keep consistent with train(), which uses 100 classes for CIFAR-100.
        num_classes = 10 if getattr(args, 'dataset', 'cifar10') == 'cifar10' else 100

    start = 1.0 / float(num_classes)              # chance-level probability
    progress = float(cur_step) / float(max_iter)  # fraction of training done
    if tsa_mode == 'linear':
        alpha = progress
    elif tsa_mode == 'log':
        alpha = 1.0 - np.exp(-progress * 5.0)
    elif tsa_mode == 'exp':
        alpha = np.exp((progress - 1.0) * 5.0)
    else:
        # No (or unknown) TSA schedule: never mask out examples.
        return 1.0
    return alpha * (1.0 - start) + start
class ZCA(object):
    """Batched ZCA whitening transform.

    Flattens each sample, subtracts the precomputed dataset mean, multiplies
    by the whitening matrix, and restores the 3x32x32 image shape.

    NOTE(review): assumes zca_params['meanX'] is the flattened mean (D,) and
    zca_params['W'] is the (D, D) whitening matrix — confirm against the
    zca_params.pth file produced offline. Both buffers are kept on the GPU.
    """

    def __init__(self, zca_params):
        mean = torch.FloatTensor(zca_params['meanX'])
        self.meanX = mean.unsqueeze(0).cuda()  # (1, D) so it broadcasts over the batch
        self.W = torch.FloatTensor(zca_params['W']).cuda()

    def __call__(self, sample):
        batch = sample.size(0)
        flat = sample.view(batch, -1)
        whitened = torch.matmul(flat - self.meanX, self.W)
        return whitened.view(batch, 3, 32, 32)
implemented'.format(args.arch)) 138 | if args.optimizer == 'Adam': 139 | print ("use Adam optimizer") 140 | optimizer = torch.optim.Adam( model.parameters(), lr=args.lr, weight_decay=args.l2_reg ) 141 | elif args.optimizer == 'SGD': 142 | print ("use SGD optimizer") 143 | optimizer = torch.optim.SGD( model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.l2_reg, nesterov=args.nesterov) 144 | 145 | if args.lr_decay=='cosine': 146 | print ("use cosine lr scheduler") 147 | global scheduler 148 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.max_iter, eta_min=args.final_lr) 149 | 150 | global batch_time, losses_sup, losses_unsup, top1, losses_l1, losses_unsup 151 | batch_time, losses_sup, losses_unsup, top1, losses_l1, losses_unsup = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() 152 | t = time.time() 153 | model.train() 154 | iter_sup = iter(train_loader) 155 | for train_iter in range(args.max_iter): 156 | # TRAIN 157 | lr = adjust_learning_rate(optimizer, train_iter + 1) 158 | train(model, iter_sup, optimizer, train_iter, data_iterator_uda=iter_uda) 159 | 160 | # LOGGING 161 | if (train_iter+1) % args.print_freq == 0: 162 | print('ITER: [{0}/{1}]\t' 163 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 164 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 165 | 'L1 Loss {l1_loss.val:.4f} ({l1_loss.avg:.4f})\t' 166 | 'Unsup Loss {unsup_loss.val:.4f} ({unsup_loss.avg:.4f})\t' 167 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 168 | 'Learning rate {2} TSA th {3}'.format( 169 | train_iter, args.max_iter, lr, TSA_th(train_iter), 170 | batch_time=batch_time, 171 | loss=losses_sup, 172 | l1_loss=losses_l1, 173 | unsup_loss=losses_unsup, 174 | top1=top1)) 175 | 176 | if (train_iter+1)%args.eval_iter == 0 or train_iter+1 == args.max_iter: 177 | # EVAL 178 | print ("evaluation at iter {}".format(train_iter)) 179 | prec1 = test(model, test_loader) 180 | 181 | is_best = prec1 > best_prec1 182 | 
def train(model, data_iterator, optimizer, iteration, data_iterator_uda=None):
    """Run one optimization step.

    Computes the supervised cross-entropy (optionally masked by Training
    Signal Annealing), the optional UDA KL-consistency loss on unlabeled
    batches, and an L1 penalty, then performs a single optimizer step.
    Updates the global meters (top1, losses_sup, losses_l1, losses_unsup,
    batch_time).

    Args:
        model: the network being trained (on GPU).
        data_iterator: iterator over the labeled loader yielding (input, target).
        optimizer: torch optimizer wrapping model.parameters().
        iteration: current training iteration (drives the TSA threshold).
        data_iterator_uda: iterator over the unlabeled loader yielding
            (clean, augmented) pairs; required when args.UDA is set.
    """
    global args
    global batch_time, losses_sup, top1, losses_l1, losses_unsup
    t = time.time()

    input, target = next(data_iterator)
    input, target = input.cuda(), target.cuda().long()
    input = global_contrast_normalize(input)

    output = model(input)
    tsa_th = TSA_th(iteration)
    prec1 = accuracy(output.data, target, topk=(1,))[0]

    # Supervised loss, optionally TSA-masked: examples the model already
    # predicts with confidence above the annealed threshold are dropped.
    if args.TSA is None:
        loss_sup = torch.nn.functional.cross_entropy(output, target, reduction='mean')
    else:
        num_classes = 10 if args.dataset == 'cifar10' else 100
        target_onehot = torch.FloatTensor(input.size(0), num_classes).cuda()
        target_onehot.zero_()
        target_onehot.scatter_(1, target.unsqueeze(1), 1)
        output_softmax = torch.nn.functional.softmax(output, dim=1).detach()
        gt_softmax = (target_onehot * output_softmax).sum(dim=1)
        loss_mask = (gt_softmax <= tsa_th).float()
        loss_sup = torch.sum(torch.nn.functional.cross_entropy(output, target, reduction='none') * loss_mask) / (loss_mask.sum() + 1e-6)

    if args.UDA:
        input_unsup, input_unsup_aug = next(data_iterator_uda)
        input_unsup = global_contrast_normalize(input_unsup.cuda())
        input_unsup_aug = global_contrast_normalize(input_unsup_aug.cuda())

        # BUG FIX: only the clean branch is a fixed target. The original
        # wrapped BOTH forwards in no_grad, so the consistency loss had no
        # gradient path and could not train the model.
        with torch.no_grad():
            output_unsup = model(input_unsup)
        output_unsup_aug = model(input_unsup_aug)

        loss_unsup = torch.nn.functional.kl_div(
            torch.nn.functional.log_softmax(output_unsup_aug, dim=1),
            torch.nn.functional.softmax(output_unsup, dim=1).detach(),
            reduction='batchmean') * args.unsup_loss_weight
    else:
        loss_unsup = None

    # L1 regularization over all parameters (norm of the concatenated vector).
    all_params = torch.cat([p.view(-1) for p in model.parameters()])
    loss_l1 = args.l1_reg * torch.norm(all_params, 1)

    if loss_unsup is not None:
        loss_all = loss_sup + loss_unsup + loss_l1
    else:
        loss_all = loss_sup + loss_l1

    optimizer.zero_grad()
    loss_all.backward()
    # BUG FIX: clip AFTER backward — the original clipped before gradients
    # existed, so clipping acted on stale/zero grads and did nothing useful.
    if args.clip_grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
    optimizer.step()

    top1.update(prec1.item(), input.size(0))
    losses_sup.update(loss_sup.data.item(), input.size(0))
    losses_l1.update(loss_l1.data.item(), input.size(0))
    if loss_unsup is not None:
        losses_unsup.update(loss_unsup.data.item(), args.batch_size_unsup)
    batch_time.update(time.time() - t)
def test(model, val_loader):
    """Perform validation on the validation set; returns average top-1 accuracy (%)."""
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        # BUG FIX: `async` became a reserved keyword in Python 3.7, making
        # `target.cuda(async=True)` a SyntaxError; PyTorch 0.4 renamed the
        # kwarg to `non_blocking`.
        target = target.cuda(non_blocking=True).long()
        input = input.cuda()
        input = global_contrast_normalize(input)

        # compute output (no gradients needed at eval time)
        with torch.no_grad():
            output = model(input)
            loss = torch.nn.functional.cross_entropy(output, target)

        # measure accuracy and record loss
        prec1 = accuracy(output.data, target, topk=(1,))[0]
        losses.update(loss.data.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      i, len(val_loader), batch_time=batch_time, loss=losses,
                      top1=top1))

    print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
    return top1.avg

# Keep pytest-style "test*" collectors from trying to run the evaluation
# helper as a unit test.
test.__test__ = False


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        # val: most recent value; sum/count back the running average.
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, it):
    """Set and return the learning rate for iteration `it`.

    Linear warmup for the first `args.warmup_steps` iterations, then one of
    'step', 'linear', or 'cosine' decay (cosine delegates to the global
    `scheduler`). Reads hyper-parameters from the global `args`.
    """
    if args.warmup_steps > 0 and args.warmup_steps > it:
        # Linear warmup from 0 up to args.lr.
        lr = float(it) / float(args.warmup_steps) * args.lr
        # BUG FIX: the warmup LR was computed but never written into the
        # optimizer's param groups, so warmup previously had no effect.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    else:
        if args.lr_decay == 'step':
            lr = args.lr
            for lr_decay_at in args.lr_decay_at:
                lr *= args.lr_decay_rate ** int(it >= int(lr_decay_at))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        elif args.lr_decay == 'linear':
            lr = args.final_lr + (args.lr - args.final_lr) * float(args.max_iter - it) / float(args.max_iter)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        elif args.lr_decay == 'cosine':
            global scheduler
            scheduler.step()
            # NOTE(review): get_lr() returns a list and is deprecated in newer
            # torch (get_last_lr()); kept as-is for the old torch this targets.
            lr = scheduler.get_lr()
        else:
            raise ValueError('unknown lr decay method {}'.format(args.lr_decay))
    return lr

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Returns a list with one 0-dim tensor per k: the percentage of samples
    whose true label appears among the top-k predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # BUG FIX: reshape() instead of view() — the sliced mask need not be
        # contiguous on newer PyTorch, where view() raises for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def initialize_loader():
    """Build a shuffled train loader and a deterministic test loader from the globals."""
    global dataset_train, dataset_test, args
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, num_workers=5, pin_memory=True)
    return train_loader, test_loader

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Saves checkpoint to disk; also copies it to model_best.pth.tar when best."""
    global exp_dir
    filename = os.path.join(exp_dir, filename)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(exp_dir, 'model_best.pth.tar'))
#relu = nn.ReLU
relu = nn.LeakyReLU  # network-wide activation; slope configured per layer via `leakiness`

class BasicBlock(nn.Module):
    """Pre-activation Wide ResNet basic block: (BN -> LeakyReLU -> 3x3 conv) x2.

    When the channel counts differ, a strided 1x1 convolution projects the
    shortcut; otherwise the identity is used.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0, leakiness=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = relu(negative_slope=leakiness, inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = relu(negative_slope=leakiness, inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # IDIOM FIX: conditional expression instead of the fragile `and/or` trick.
        self.convShortcut = None if self.equalInOut else nn.Conv2d(
            in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        if not self.equalInOut:
            # Pre-activation is shared by the residual path and the projection.
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)

class NetworkBlock(nn.Module):
    """A stage of `nb_layers` blocks; only the first may change stride/width."""

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, leakiness=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate, leakiness=leakiness)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate, leakiness=0.0):
        layers = []
        for i in range(int(nb_layers)):
            # IDIOM FIX: ternaries instead of `a and b or c` (which breaks on falsy b).
            layers.append(block(in_planes if i == 0 else out_planes,
                                out_planes,
                                stride if i == 0 else 1,
                                dropRate, leakiness=leakiness))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)

class WideResNet(nn.Module):
    """Wide ResNet (WRN-depth-widen_factor) for 32x32 inputs.

    Args:
        depth: total conv depth; must satisfy (depth - 4) % 6 == 0.
        num_classes: size of the final linear classifier.
        widen_factor: channel multiplier applied per stage.
        dropRate: dropout probability inside each BasicBlock.
        leakiness: negative slope of the LeakyReLU activations.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, leakiness=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert (depth - 4) % 6 == 0, 'depth must be of the form 6n+4'
        # BUG FIX (idiom): integer division; `/` yields a float under Python 3.
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # three stages at spatial resolutions 32 / 16 / 8
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, leakiness=leakiness)
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate, leakiness=leakiness)
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate, leakiness=leakiness)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = relu(negative_slope=leakiness, inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        # He initialization for convs; unit-gain BN; zero linear bias.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)  # 8x8 -> 1x1 for 32x32 inputs
        out = out.view(-1, self.nChannels)
        return self.fc(out)
class ImageNetPolicy(object):
    """Apply one of the 25 strongest AutoAugment sub-policies found on
    ImageNet, chosen uniformly at random on every call.

    Example:
        >>> policy = ImageNetPolicy()
        >>> transformed = policy(image)

    Example as a PyTorch Transform:
        >>> transform = transforms.Compose([
        >>>     transforms.Resize(256),
        >>>     ImageNetPolicy(),
        >>>     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        # Each spec is (p1, op1, magnitude_idx1, p2, op2, magnitude_idx2).
        specs = [
            (0.4, "posterize", 8, 0.6, "rotate", 9),
            (0.6, "solarize", 5, 0.6, "autocontrast", 5),
            (0.8, "equalize", 8, 0.6, "equalize", 3),
            (0.6, "posterize", 7, 0.6, "posterize", 6),
            (0.4, "equalize", 7, 0.2, "solarize", 4),

            (0.4, "equalize", 4, 0.8, "rotate", 8),
            (0.6, "solarize", 3, 0.6, "equalize", 7),
            (0.8, "posterize", 5, 1.0, "equalize", 2),
            (0.2, "rotate", 3, 0.6, "solarize", 8),
            (0.6, "equalize", 8, 0.4, "posterize", 6),

            (0.8, "rotate", 8, 0.4, "color", 0),
            (0.4, "rotate", 9, 0.6, "equalize", 2),
            (0.0, "equalize", 7, 0.8, "equalize", 8),
            (0.6, "invert", 4, 1.0, "equalize", 8),
            (0.6, "color", 4, 1.0, "contrast", 8),

            (0.8, "rotate", 8, 1.0, "color", 2),
            (0.8, "color", 8, 0.8, "solarize", 7),
            (0.4, "sharpness", 7, 0.6, "invert", 8),
            (0.6, "shearX", 5, 1.0, "equalize", 9),
            (0.4, "color", 0, 0.6, "equalize", 3),

            (0.4, "equalize", 7, 0.2, "solarize", 4),
            (0.6, "solarize", 5, 0.6, "autocontrast", 5),
            (0.6, "invert", 4, 1.0, "equalize", 8),
            (0.6, "color", 4, 1.0, "contrast", 8),
            (0.8, "equalize", 8, 0.6, "equalize", 3),
        ]
        self.policies = [SubPolicy(*spec, fillcolor) for spec in specs]

    def __call__(self, img):
        # Same RNG usage as the original (randint then index).
        idx = random.randint(0, len(self.policies) - 1)
        return self.policies[idx](img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"
fillcolor), 98 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 99 | 100 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 101 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 102 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 103 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 104 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 105 | ] 106 | 107 | 108 | def __call__(self, img): 109 | policy_idx = random.randint(0, len(self.policies) - 1) 110 | return self.policies[policy_idx](img) 111 | 112 | def __repr__(self): 113 | return "AutoAugment CIFAR10 Policy" 114 | 115 | 116 | class SVHNPolicy(object): 117 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 118 | 119 | Example: 120 | >>> policy = SVHNPolicy() 121 | >>> transformed = policy(image) 122 | 123 | Example as a PyTorch Transform: 124 | >>> transform=transforms.Compose([ 125 | >>> transforms.Resize(256), 126 | >>> SVHNPolicy(), 127 | >>> transforms.ToTensor()]) 128 | """ 129 | def __init__(self, fillcolor=(128, 128, 128)): 130 | self.policies = [ 131 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 132 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 133 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 134 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 135 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 136 | 137 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 138 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 139 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 140 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 141 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 142 | 143 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 144 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 145 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, 
class SubPolicy(object):
    """One AutoAugment sub-policy: two image operations applied in sequence.

    Each operation fires independently with its own probability (``p1``,
    ``p2``) and uses a magnitude looked up from a fixed 10-level range by
    ``magnitude_idx1`` / ``magnitude_idx2``.

    Args:
        p1, p2 (float): probabilities of applying operation 1 / operation 2.
        operation1, operation2 (str): operation names (keys of ``ranges``).
        magnitude_idx1, magnitude_idx2 (int): index 0-9 into the magnitude range.
        fillcolor (tuple): RGB fill color used by the affine transforms.
    """
    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        # Magnitude schedules: 10 evenly spaced levels per operation
        # (autocontrast/equalize/invert take no magnitude).
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            # FIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
            # the builtin int is the documented replacement.
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            # Rotate via RGBA and composite over a gray canvas so the exposed
            # corners are filled instead of black.
            # NOTE(review): uses a hard-coded (128,)*4 fill rather than the
            # ``fillcolor`` argument — kept as-is to preserve behavior.
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)

        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        """Apply each operation with its probability; return the (possibly
        transformed) image."""
        if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
        return img
def find_classes(dir):
    """Return (sorted class names, {class_name: index}) for subdirs of *dir*."""
    classes = sorted(entry for entry in os.listdir(dir)
                     if os.path.isdir(os.path.join(dir, entry)))
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


def pil_loader(path):
    """Load an image with PIL and convert it to RGB."""
    # Open via a file object to avoid a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')


def accimage_loader(path):
    """Load an image with accimage, falling back to PIL on decode errors."""
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem; PIL is more forgiving.
        return pil_loader(path)


def default_loader(path):
    """Dispatch to the loader matching torchvision's configured image backend."""
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    return pil_loader(path)


def load_db(db_path, class_to_idx):
    """Flatten a saved {class_name: [image_path, ...]} split file into a
    list of (image_path, class_index) tuples, iterating classes in sorted
    order for determinism."""
    db = torch.load(db_path)
    return [(image_path, class_to_idx[key])
            for key in sorted(db.keys())
            for image_path in db[key]]
= imgs 75 | self.classes = classes 76 | self.class_to_idx = class_to_idx 77 | self.transform = transform 78 | self.target_transform = target_transform 79 | self.loader = loader 80 | self.is_unlabeled = is_unlabeled 81 | self.autoaugment = ImageNetPolicy() 82 | 83 | self.indices = [i for i in range(len(imgs))] 84 | random.shuffle(self.indices) 85 | if self.is_unlabeled: 86 | self.total_train_count = args.batch_size_unlabeled * args.max_iter * args.unlabeled_iter 87 | else: 88 | self.total_train_count = args.batch_size * args.max_iter 89 | 90 | print ("sample count {}".format(len(self.indices))) 91 | print ("total sample count {}".format(self.total_train_count)) 92 | 93 | def __getitem__(self, index): 94 | """ 95 | Args: 96 | index (int): Index 97 | 98 | Returns: 99 | tuple: (image, target) where target is class_index of the target class. 100 | """ 101 | #if self.is_unlabeled: 102 | # print ("reading index {}".format(index)) 103 | random_index = self.indices[index%len(self.indices)] 104 | path, target = self.imgs[random_index] 105 | img = self.loader(path) 106 | 107 | if self.is_unlabeled: 108 | aug_img = self.autoaugment(img) 109 | aug_img = self.transform(aug_img) 110 | 111 | if self.transform is not None: 112 | img = self.transform(img) 113 | if self.target_transform is not None: 114 | target = self.target_transform(target) 115 | 116 | if self.is_unlabeled: 117 | return img, aug_img 118 | else: 119 | return img, target 120 | 121 | def __len__(self): 122 | return self.total_train_count 123 | -------------------------------------------------------------------------------- /imagenet/scripts/run_baseline_resnet18.sh: -------------------------------------------------------------------------------- 1 | 2 | # Settings from S4L paper. 200 epochs, base LR 0.1, LR decay at 140, 160, 180. Batch size unknown. 
weight decay 0.001 3 | CUDA_VISIBLE_DEVICES=4,5,6,7 \ 4 | python train_imagenet.py ./ImageNet/ \ 5 | --arch resnet18 \ 6 | --workers 30 \ 7 | --batch-size 512 \ 8 | --batch-size-unlabeled 1024 \ 9 | --unlabeled-iter 15 \ 10 | --print-freq 1 \ 11 | --lr 0.3 \ 12 | --weight-decay 0.001 \ 13 | --max-iter 40000 \ 14 | --lr-drop-iter 13000 26000 35000 \ 15 | --warmup --warmup-iter 2500 \ 16 | --save_dir checkpoint/resnet18_UDA_bs512_bs15360_40K \ 17 | # 2>&1 | tee logs/resnet18_UDA_bs512_bs15360_40K 18 | 19 | #CUDA_VISIBLE_DEVICES=4,5,6,7 \ 20 | # python train_imagenet.py ./ImageNet/ \ 21 | # --arch resnet18 \ 22 | # --workers 20 \ 23 | # --batch-size 512 \ 24 | # --lr 0.2 \ 25 | # --weight-decay 0.001 \ 26 | # --max-iter 100000 \ 27 | # --lr-drop-iter 70000 80000 90000 \ 28 | # --warmup --warmup-iter 2500 \ 29 | # --save_dir checkpoint/baseline_resnet18_S4L_bs512 \ 30 | # 2>&1 | tee logs/baseline_resnet18_S4L_bs512 & 31 | -------------------------------------------------------------------------------- /imagenet/scripts/run_baseline_resnet34.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 4 | python train_imagenet.py ./ImageNet/ \ 5 | --arch resnet34 \ 6 | --workers 20 \ 7 | --batch-size 256 \ 8 | --lr 0.1 \ 9 | >> logs/baseline_resnet34_bs256.log && 10 | 11 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 12 | python train_imagenet.py ./ImageNet/ \ 13 | --arch resnet34 \ 14 | --workers 20 \ 15 | --batch-size 128 \ 16 | --lr 0.05 \ 17 | >> logs/baseline_resnet34_bs128.log && 18 | 19 | CUDA_VISIBLE_DEVICES=0,1 \ 20 | python train_imagenet.py ./ImageNet/ \ 21 | --arch resnet34 \ 22 | --workers 20 \ 23 | --batch-size 64 \ 24 | --lr 0.025 \ 25 | >> logs/baseline_resnet34_bs64.log && 26 | 27 | CUDA_VISIBLE_DEVICES=2,3 \ 28 | python train_imagenet.py ./ImageNet/ \ 29 | --arch resnet34 \ 30 | --workers 20 \ 31 | --batch-size 32 \ 32 | --lr 0.0125 \ 33 | >> logs/baseline_resnet34_bs32.log & 34 | 
def separate_and_save_dataset(dataset_root, labeled_portion, seed=123):
    """Split an ImageFolder-style dataset into labeled/unlabeled subsets.

    For every class directory under *dataset_root*, shuffles its image
    paths (deterministically via *seed*) and marks the first
    ``labeled_portion`` fraction as labeled.  Saves two
    ``{class_name: [path, ...]}`` dicts under ``data_split/``.

    Args:
        dataset_root (str): directory with one subdirectory per class.
        labeled_portion (float): fraction in [0, 1] kept as labeled.
        seed (int): RNG seed so the split is reproducible.
    """
    random.seed(seed)
    dataset_cls_dirs = glob.glob(os.path.join(dataset_root, '*'))

    labeled_images = {}
    unlabeled_images = {}

    for cls_idx, dataset_cls_dir in enumerate(sorted(dataset_cls_dirs)):
        # basename never contains '/', so the original .replace('/','') was
        # a no-op and has been dropped.
        cls_key = os.path.basename(dataset_cls_dir)
        print("cls {}/{} cls_key {}".format(cls_idx, len(dataset_cls_dirs), cls_key))

        image_paths = glob.glob(os.path.join(dataset_cls_dir, '*'))
        random.shuffle(image_paths)
        print("total {} images".format(len(image_paths)))
        labeled_count = int(len(image_paths) * labeled_portion)

        labeled_images[cls_key] = image_paths[:labeled_count]
        unlabeled_images[cls_key] = image_paths[labeled_count:]

    # FIX: the original assumed data_split/ already existed; create it so
    # the saves below cannot fail on a fresh checkout.
    os.makedirs('data_split', exist_ok=True)
    torch.save(labeled_images, 'data_split/labeled_images_{:.2f}.pth'.format(labeled_portion))
    torch.save(unlabeled_images, 'data_split/unlabeled_images_{:.2f}.pth'.format(1 - labeled_portion))
ImageNet Training') 26 | parser.add_argument('data', metavar='DIR', 27 | help='path to dataset') 28 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 29 | choices=model_names, 30 | help='model architecture: ' + 31 | ' | '.join(model_names) + 32 | ' (default: resnet18)') 33 | parser.add_argument('-j', '--workers', default=20, type=int, metavar='N', 34 | help='number of data loading workers (default: 4)') 35 | parser.add_argument('--max-iter', default=40000, type=int) 36 | parser.add_argument('--lr-drop-iter', nargs="+", default=[40000//3, 40000*2//3, 40000*8//9]) 37 | parser.add_argument('--eval-iter', default=500, type=int) 38 | parser.add_argument('--print-freq', default=10, type=int) 39 | parser.add_argument('--warmup', action='store_true') 40 | parser.add_argument('--warmup-iter', type=int, default=40000*5//90) 41 | parser.add_argument('-bu', '--batch-size-unlabeled', default=0, type=int) 42 | parser.add_argument('-ui', '--unlabeled-iter', default=30, type=int) 43 | parser.add_argument('-b', '--batch-size', default=512, type=int, 44 | metavar='N', 45 | help='mini-batch size (default: 256), this is the total ' 46 | 'batch size of all GPUs on the current node when ' 47 | 'using Data Parallel or Distributed Data Parallel') 48 | parser.add_argument('--lr', '--learning-rate', default=0.3, type=float, 49 | metavar='LR', help='initial learning rate', dest='lr') 50 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 51 | help='momentum') 52 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 53 | metavar='W', help='weight decay (default: 1e-4)', 54 | dest='weight_decay') 55 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 56 | help='path to latest checkpoint (default: none)') 57 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 58 | help='evaluate model on validation set') 59 | parser.add_argument('--pretrained', dest='pretrained', 
action='store_true', 60 | help='use pre-trained model') 61 | parser.add_argument('--seed', default=None, type=int, 62 | help='seed for initializing training. ') 63 | parser.add_argument('--gpu', default=None, type=int, 64 | help='GPU id to use.') 65 | parser.add_argument('--save_dir', required=True, type=str) 66 | 67 | best_acc1 = 0 68 | 69 | 70 | def main(): 71 | args = parser.parse_args() 72 | args.lr_drop_iter = [int(val) for val in args.lr_drop_iter] 73 | 74 | if args.seed is not None: 75 | random.seed(args.seed) 76 | torch.manual_seed(args.seed) 77 | cudnn.deterministic = True 78 | warnings.warn('You have chosen to seed training. ' 79 | 'This will turn on the CUDNN deterministic setting, ' 80 | 'which can slow down your training considerably! ' 81 | 'You may see unexpected behavior when restarting ' 82 | 'from checkpoints.') 83 | 84 | if args.gpu is not None: 85 | warnings.warn('You have chosen a specific GPU. This will completely ' 86 | 'disable data parallelism.') 87 | 88 | ngpus_per_node = torch.cuda.device_count() 89 | main_worker(args.gpu, ngpus_per_node, args) 90 | 91 | 92 | def main_worker(gpu, ngpus_per_node, args): 93 | global best_acc1 94 | args.gpu = gpu 95 | 96 | if args.gpu is not None: 97 | print("Use GPU: {} for training".format(args.gpu)) 98 | 99 | # create model 100 | if args.pretrained: 101 | print("=> using pre-trained model '{}'".format(args.arch)) 102 | model = models.__dict__[args.arch](pretrained=True) 103 | else: 104 | print("=> creating model '{}'".format(args.arch)) 105 | model = models.__dict__[args.arch]() 106 | 107 | if args.gpu is not None: 108 | torch.cuda.set_device(args.gpu) 109 | model = model.cuda(args.gpu) 110 | else: 111 | # DataParallel will divide and allocate batch_size to all available GPUs 112 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 113 | model.features = torch.nn.DataParallel(model.features) 114 | model.cuda() 115 | else: 116 | model = torch.nn.DataParallel(model).cuda() 117 | 118 | # 
define loss function (criterion) and optimizer 119 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 120 | 121 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 122 | momentum=args.momentum, 123 | weight_decay=args.weight_decay) 124 | 125 | # optionally resume from a checkpoint 126 | if args.resume: 127 | if os.path.isfile(args.resume): 128 | print("=> loading checkpoint '{}'".format(args.resume)) 129 | checkpoint = torch.load(args.resume) 130 | args.start_epoch = checkpoint['epoch'] 131 | best_acc1 = checkpoint['best_acc1'] 132 | if args.gpu is not None: 133 | # best_acc1 may be from a checkpoint from a different GPU 134 | best_acc1 = best_acc1.to(args.gpu) 135 | model.load_state_dict(checkpoint['state_dict']) 136 | optimizer.load_state_dict(checkpoint['optimizer']) 137 | print("=> loaded checkpoint '{}' (epoch {})" 138 | .format(args.resume, checkpoint['epoch'])) 139 | else: 140 | print("=> no checkpoint found at '{}'".format(args.resume)) 141 | 142 | cudnn.benchmark = True 143 | 144 | # Data loading code 145 | traindir = os.path.join(args.data, 'train') 146 | valdir = os.path.join(args.data, 'val') 147 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 148 | std=[0.229, 0.224, 0.225]) 149 | 150 | train_labeled_dataset = ImageNet( 151 | traindir, args, 152 | transforms.Compose([ 153 | transforms.RandomResizedCrop(224), 154 | transforms.RandomHorizontalFlip(), 155 | transforms.ToTensor(), 156 | normalize, 157 | ]), 158 | db_path='./data_split/labeled_images_0.10.pth', 159 | ) 160 | 161 | train_unlabeled_dataset = ImageNet( 162 | traindir, args, 163 | transforms.Compose([ 164 | transforms.RandomResizedCrop(224), 165 | transforms.RandomHorizontalFlip(), 166 | transforms.ToTensor(), 167 | normalize, 168 | ]), 169 | db_path='./data_split/unlabeled_images_0.90.pth', 170 | is_unlabeled=True, 171 | ) 172 | 173 | train_labeled_loader = torch.utils.data.DataLoader( 174 | train_labeled_dataset, batch_size=args.batch_size, shuffle=True, 175 | 
num_workers=args.workers, pin_memory=True, sampler=None) 176 | 177 | if args.batch_size_unlabeled > 0: 178 | train_unlabeled_loader = torch.utils.data.DataLoader( 179 | train_unlabeled_dataset, batch_size=args.batch_size_unlabeled, shuffle=False, 180 | num_workers=args.workers, pin_memory=True, sampler=None) 181 | else: 182 | train_unlabeled_loader = None 183 | 184 | val_loader = torch.utils.data.DataLoader( 185 | datasets.ImageFolder(valdir, transforms.Compose([ 186 | transforms.Resize(256), 187 | transforms.CenterCrop(224), 188 | transforms.ToTensor(), 189 | normalize, 190 | ])), 191 | batch_size=args.batch_size, shuffle=False, 192 | num_workers=args.workers, pin_memory=True) 193 | 194 | entropy_criterion = HLoss() 195 | 196 | iter_sup = iter(train_labeled_loader) 197 | if train_unlabeled_loader is None: 198 | iter_unsup = None 199 | else: 200 | iter_unsup = iter(train_unlabeled_loader) 201 | 202 | model.train() 203 | meters = initialize_meters() 204 | for train_iter in range(args.max_iter): 205 | 206 | lr = adjust_learning_rate(optimizer, train_iter + 1, args) 207 | 208 | train(iter_sup, model, optimizer, criterion, iter_unsup, entropy_criterion, meters, args) 209 | 210 | if (train_iter+1) % args.print_freq == 0: 211 | print('ITER: [{0}/{1}]\t' 212 | 'Data time {data_time.val:.3f} ({data_time.avg:.3f})\t' 213 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 214 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 215 | 'HLoss {h_loss.val:.4f} ({h_loss.avg:.4f})\t' 216 | 'Unsup Loss {unsup_loss.val:.4f} ({unsup_loss.avg:.4f})\t' 217 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 218 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t' 219 | 'Learning rate {2}'.format( 220 | train_iter, args.max_iter, lr, 221 | data_time=meters['data_time'], 222 | batch_time=meters['batch_time'], 223 | loss=meters['losses'], 224 | h_loss=meters['losses_entropy'], 225 | unsup_loss=meters['losses_unsup'], 226 | top1=meters['top1'], 227 | top5=meters['top5'])) 228 | if (train_iter+1) % 
def initialize_meters():
    """Create a fresh dict of AverageMeters for one logging window."""
    return {
        'batch_time': AverageMeter('Time', ':6.3f'),
        'data_time': AverageMeter('Data', ':6.3f'),
        'losses': AverageMeter('Loss', ':.4e'),
        'losses_entropy': AverageMeter('Loss entropy', ':.4e'),
        'losses_unsup': AverageMeter('Loss unsup', ':.4e'),
        'top1': AverageMeter('Acc@1', ':6.2f'),
        'top5': AverageMeter('Acc@5', ':6.2f'),
    }
= 0 287 | for sub_batch_idx in range(args.unlabeled_iter): 288 | t1 = time.time() 289 | images_unlabeled, images_unlabeled_aug = next(iter_unsup) 290 | data_time += time.time() - t1 291 | images_unlabeled, images_unlabeled_aug = images_unlabeled.cuda(), images_unlabeled_aug.cuda() 292 | #images_unlabeled = images_unlabeled_all[ sub_batch_idx*args.batch_size: min( (sub_batch_idx+1)*args.batch_size, images_unlabeled_all.size(0))] 293 | #images_unlabeled_aug = images_unlabeled_all_aug[ sub_batch_idx*args.batch_size: min( (sub_batch_idx+1)*args.batch_size, images_unlabeled_all.size(0))] 294 | with torch.no_grad(): 295 | output_unlabeled = model(images_unlabeled) 296 | output_unlabeled_aug = model(images_unlabeled_aug) 297 | 298 | # Technique 1: entropy loss for augmented images 299 | entropy_weight = 1.0 300 | loss_entropy = entropy_weight * entropy_criterion(output_unlabeled_aug) 301 | meters['losses_entropy'].update( loss_entropy.item(), images_unlabeled.size(0) ) 302 | 303 | # Technique 2: Softmax temperature control for unsupervised loss 304 | temperature = 0.4 305 | loss_kl = torch.nn.functional.kl_div( 306 | torch.nn.functional.log_softmax(output_unlabeled_aug, dim=1), 307 | torch.nn.functional.softmax(output_unlabeled / temperature, dim=1).detach(), 308 | reduction='none') 309 | 310 | # Technique 3: confidence-based masking 311 | threshold = 0.5 312 | max_y_unlabeled = torch.max( torch.nn.functional.softmax( output_unlabeled / temperature, dim=1 ), 1, keepdim=True )[0] 313 | mask = (max_y_unlabeled > threshold).type(torch.cuda.FloatTensor) 314 | 315 | loss_kl = torch.sum(loss_kl * mask) / (mask.mean()+1e-8) 316 | 317 | loss_unsup = loss_entropy + loss_kl 318 | unsup_loss_weight = 20.0 319 | loss_unsup = loss_unsup / args.unlabeled_iter * unsup_loss_weight 320 | loss_unsup.backward() 321 | loss_unsup_all += loss_unsup.item() 322 | meters['losses_unsup'].update( loss_unsup_all, args.batch_size_unlabeled * args.unlabeled_iter ) 323 | 324 | 325 | # measure accuracy 
class HLoss(nn.Module):
    """Entropy of the softmax distribution, summed over the whole batch.

    Used as a regularizer that pushes unlabeled predictions toward
    confident (low-entropy) distributions.
    """
    def __init__(self):
        super(HLoss, self).__init__()

    def forward(self, x):
        # H(p) = -sum p * log p, computed from logits x (sum over classes
        # and batch, matching the original implementation).
        probs = F.softmax(x, dim=1)
        log_probs = F.log_softmax(x, dim=1)
        return -(probs * log_probs).sum()
def save_checkpoint(state, is_best, save_dir, filename='checkpoint.pth.tar'):
    """Serialize *state* into *save_dir*; duplicate it as model_best.pth.tar
    when this checkpoint is the best so far."""
    os.makedirs(save_dir, exist_ok=True)
    ckpt_path = os.path.join(save_dir, filename)
    torch.save(state, ckpt_path)
    if is_best:
        shutil.copyfile(ckpt_path, os.path.join(save_dir, 'model_best.pth.tar'))


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        # Start a fresh accumulation window.
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # val is the latest observation; n its sample weight (batch size).
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(**self.__dict__)


class ProgressMeter(object):
    """Pretty-prints a batch counter followed by its meters' summaries."""
    def __init__(self, num_batches, *meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def print(self, batch):
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(meter) for meter in self.meters)
        print('\t'.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        # Counter field width follows the total batch count, e.g. [  7/100].
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output: (batch, num_classes) scores/logits.
        target: (batch,) ground-truth class indices.
        topk: iterable of k values to report.

    Returns:
        list of 1-element tensors, each the top-k accuracy in percent.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # FIX: correct[:k] is a non-contiguous slice of a transposed
            # tensor; .view(-1) raises a RuntimeError on modern PyTorch,
            # so use .reshape(-1), which copies only when necessary.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res