├── .gitignore ├── README.md ├── __pycache__ ├── augment.cpython-38.pyc ├── cifar_remixmatch.cpython-38.pyc └── wideresnet.cpython-38.pyc ├── augment.py ├── cifar_remixmatch.py ├── remixmatch.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── eval.cpython-38.pyc │ ├── logger.cpython-38.pyc │ └── misc.cpython-38.pyc ├── eval.py ├── logger.py └── misc.py └── wideresnet.py /.gitignore: -------------------------------------------------------------------------------- 1 | experiment/* 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReMixmatch-pytorch 2 | -------------------------------------------------------------------------------- /__pycache__/augment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lim3944/ReMixmatch-pytorch/02bc850f24af676ecea75984255b9b2e9a54509a/__pycache__/augment.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/cifar_remixmatch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lim3944/ReMixmatch-pytorch/02bc850f24af676ecea75984255b9b2e9a54509a/__pycache__/cifar_remixmatch.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/wideresnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lim3944/ReMixmatch-pytorch/02bc850f24af676ecea75984255b9b2e9a54509a/__pycache__/wideresnet.cpython-38.pyc -------------------------------------------------------------------------------- /augment.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/kekmodel/FixMatch-pytorch/blob/master/dataset/randaugment.py 3 | """ 4 | 5 | # code in this file is adpated from 6 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 7 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py 8 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py 9 | 10 | import logging 11 | import random 12 | 13 | import numpy as np 14 | import PIL 15 | import PIL.ImageOps 16 | import PIL.ImageEnhance 17 | import PIL.ImageDraw 18 | from PIL import Image 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | PARAMETER_MAX = 10 23 | 24 | 25 | def AutoContrast(img, **kwarg): 26 | return PIL.ImageOps.autocontrast(img) 27 | 28 | 29 | def Brightness(img, v, max_v, bias=0): 30 | v = _float_parameter(v, max_v) + bias 31 | return PIL.ImageEnhance.Brightness(img).enhance(v) 32 | 33 | 34 | def Color(img, v, max_v, bias=0): 35 | v = _float_parameter(v, max_v) + bias 36 | return PIL.ImageEnhance.Color(img).enhance(v) 37 | 38 | 39 | def Contrast(img, v, max_v, bias=0): 40 | v = _float_parameter(v, max_v) + bias 41 | return PIL.ImageEnhance.Contrast(img).enhance(v) 42 | 43 | 44 | def Cutout(img, v, max_v, bias=0): 45 | if v == 0: 46 | return img 47 | v = _float_parameter(v, max_v) + bias 48 | v = int(v * min(img.size)) 49 | return CutoutAbs(img, v) 50 | 51 | 52 | def CutoutAbs(img, v, **kwarg): 53 | w, h = img.size 54 | x0 = np.random.uniform(0, w) 55 | y0 = np.random.uniform(0, h) 56 | x0 = int(max(0, x0 - v / 2.)) 57 | y0 = int(max(0, y0 - v / 2.)) 58 
|     x1 = int(min(w, x0 + v))
59 |     y1 = int(min(h, y0 + v))
60 |     xy = (x0, y0, x1, y1)
61 |     # gray
62 |     color = (127, 127, 127)
63 |     img = img.copy()
64 |     PIL.ImageDraw.Draw(img).rectangle(xy, color)
65 |     return img
66 | 
67 | 
68 | def Equalize(img, **kwarg):
69 |     return PIL.ImageOps.equalize(img)
70 | 
71 | 
72 | def Identity(img, **kwarg):
73 |     return img
74 | 
75 | 
76 | def Invert(img, **kwarg):
77 |     return PIL.ImageOps.invert(img)
78 | 
79 | 
80 | def Posterize(img, v, max_v, bias=0):
81 |     v = _int_parameter(v, max_v) + bias
82 |     return PIL.ImageOps.posterize(img, v)
83 | 
84 | 
85 | def Rotate(img, v, max_v, bias=0):
86 |     v = _int_parameter(v, max_v) + bias
87 |     if random.random() < 0.5:
88 |         v = -v
89 |     return img.rotate(v)
90 | 
91 | 
92 | def Sharpness(img, v, max_v, bias=0):
93 |     v = _float_parameter(v, max_v) + bias
94 |     return PIL.ImageEnhance.Sharpness(img).enhance(v)
95 | 
96 | 
97 | def ShearX(img, v, max_v, bias=0):
98 |     v = _float_parameter(v, max_v) + bias
99 |     if random.random() < 0.5:
100 |         v = -v
101 |     return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
102 | 
103 | 
104 | def ShearY(img, v, max_v, bias=0):
105 |     v = _float_parameter(v, max_v) + bias
106 |     if random.random() < 0.5:
107 |         v = -v
108 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
109 | 
110 | 
111 | def Solarize(img, v, max_v, bias=0):
112 |     v = _int_parameter(v, max_v) + bias
113 |     return PIL.ImageOps.solarize(img, 256 - v)
114 | 
115 | 
116 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
117 |     v = _int_parameter(v, max_v) + bias
118 |     if random.random() < 0.5:
119 |         v = -v
120 |     img_np = np.array(img).astype(int)  # use the builtin int; np.int is removed in NumPy >= 1.24
121 |     img_np = img_np + v
122 |     img_np = np.clip(img_np, 0, 255)
123 |     img_np = img_np.astype(np.uint8)
124 |     img = Image.fromarray(img_np)
125 |     return PIL.ImageOps.solarize(img, threshold)
126 | 
127 | 
128 | def TranslateX(img, v, max_v, bias=0):
129 |     v = _float_parameter(v, max_v) + bias
130 |     if random.random() < 0.5:
131 |         v = -v
132 |     v = int(v * img.size[0])
133 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
134 | 
135 | 
136 | def TranslateY(img, v, max_v, bias=0):
137 |     v = _float_parameter(v, max_v) + bias
138 |     if random.random() < 0.5:
139 |         v = -v
140 |     v = int(v * img.size[1])
141 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
142 | 
143 | 
144 | def _float_parameter(v, max_v):
145 |     return float(v) * max_v / PARAMETER_MAX
146 | 
147 | 
148 | def _int_parameter(v, max_v):
149 |     return int(v * max_v / PARAMETER_MAX)
150 | 
151 | 
152 | def fixmatch_augment_pool():
153 |     # FixMatch paper
154 |     augs = [(AutoContrast, None, None),
155 |             (Brightness, 0.9, 0.05),
156 |             (Color, 0.9, 0.05),
157 |             (Contrast, 0.9, 0.05),
158 |             (Equalize, None, None),
159 |             (Identity, None, None),
160 |             (Posterize, 4, 4),
161 |             (Rotate, 30, 0),
162 |             (Sharpness, 0.9, 0.05),
163 |             (ShearX, 0.3, 0),
164 |             (ShearY, 0.3, 0),
165 |             (Solarize, 256, 0),
166 |             (TranslateX, 0.3, 0),
167 |             (TranslateY, 0.3, 0)]
168 |     return augs
169 | 
170 | 
171 | def my_augment_pool():
172 |     # Test
173 |     augs = [(AutoContrast, None, None),
174 |             (Brightness, 1.8, 0.1),
175 |             (Color, 1.8, 0.1),
176 |             (Contrast, 1.8, 0.1),
177 |             (Cutout, 0.2, 0),
178 |             (Equalize, None, None),
179 |             (Invert, None, None),
180 |             (Posterize, 4, 4),
181 |             (Rotate, 30, 0),
182 |             (Sharpness, 1.8, 0.1),
183 |             (ShearX, 0.3, 0),
184 |             (ShearY, 0.3, 0),
185 |             (Solarize, 256, 0),
186 |             (SolarizeAdd, 110, 0),
187 |             (TranslateX, 0.45, 0),
188 |             (TranslateY, 0.45, 0)]
189 |     return augs
190 | 
191 | 
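# Illustrative sketch of how a pool entry is applied (this helper is hypothetical and
# not part of the original file). Each entry is an (op, max_v, bias) tuple: the op takes
# a magnitude v in [1, PARAMETER_MAX] and rescales it via _float_parameter / _int_parameter,
# so Brightness with v=5, max_v=0.9, bias=0.05 enhances by 5 * 0.9 / 10 + 0.05 = 0.5.
def _example_apply_pool_entry():
    from PIL import Image as _Image
    img = _Image.new('RGB', (32, 32), color=(127, 127, 127))
    op, max_v, bias = (Brightness, 0.9, 0.05)  # one entry of fixmatch_augment_pool()
    return op(img, v=5, max_v=max_v, bias=bias)
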
192 | class RandAugmentPC(object): 193 | def __init__(self, n, m): 194 | assert n >= 1 195 | assert 1 <= m <= 10 196 | self.n = n 197 | self.m = m 198 | self.augment_pool = my_augment_pool() 199 | 200 | def __call__(self, img): 201 | ops = random.choices(self.augment_pool, k=self.n) 202 | for op, max_v, bias in ops: 203 | prob = np.random.uniform(0.2, 0.8) 204 | if random.random() + prob >= 1: 205 | img = op(img, v=self.m, max_v=max_v, bias=bias) 206 | img = CutoutAbs(img, 16) 207 | return img 208 | 209 | 210 | class RandAugmentMC(object): 211 | def __init__(self, n, m): 212 | assert n >= 1 213 | assert 1 <= m <= 10 214 | self.n = n 215 | self.m = m 216 | self.augment_pool = fixmatch_augment_pool() 217 | 218 | def __call__(self, img): 219 | ops = random.choices(self.augment_pool, k=self.n) 220 | for op, max_v, bias in ops: 221 | v = np.random.randint(1, self.m) 222 | if random.random() < 0.5: 223 | img = op(img, v=v, max_v=max_v, bias=bias) 224 | img = CutoutAbs(img, 16) 225 | return img -------------------------------------------------------------------------------- /cifar_remixmatch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | from PIL import Image 5 | from torchvision import datasets 6 | from torchvision import transforms 7 | 8 | import augment 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | cifar10_mean = (0.4914, 0.4822, 0.4465) 13 | cifar10_std = (0.2471, 0.2435, 0.2616) 14 | cifar100_mean = (0.5071, 0.4867, 0.4408) 15 | cifar100_std = (0.2675, 0.2565, 0.2761) 16 | normal_mean = (0.5, 0.5, 0.5) 17 | normal_std = (0.5, 0.5, 0.5) 18 | 19 | 20 | def get_cifar10(root, num_labeled, num_expand_x, num_expand_u): 21 | #base dataset 22 | base_dataset = datasets.CIFAR10(root, train=True, download=True) 23 | 24 | # split dataset labeled,unlabeled 25 | train_labeled_idxs, train_unlabeled_idxs = x_u_split( 26 | base_dataset.targets, num_labeled, num_expand_x, num_expand_u, num_classes=10) 27 | 28 | # Augment dataset 29 | # labeled data, strong augmentatinon 30 | transform_labeled = transforms.Compose([ 31 | transforms.RandomHorizontalFlip(), 32 | transforms.RandomCrop(size=32, 33 | padding=int(32*0.125), 34 | padding_mode='reflect'), 35 | augment.RandAugmentMC(n=2, m=10), 36 | transforms.ToTensor(), 37 | transforms.Normalize(mean=cifar10_mean, std=cifar10_std) 38 | ]) 39 | 40 | 41 | train_labeled_dataset = CIFAR10SSL( 42 | root, train_labeled_idxs, train=True, 43 | transform=transform_labeled 44 | ) 45 | 46 | # unlabeled data, strong + weak augmentation 47 | train_unlabeled_dataset = CIFAR10SSL( 48 | root, train_unlabeled_idxs, train=True, 49 | transform=TransformRemix(mean=cifar10_mean, std=cifar10_std) 50 | ) 51 | 52 | # validation data 53 | transform_val = transforms.Compose([ 54 | transforms.ToTensor(), 55 | transforms.Normalize(mean=cifar10_mean, std=cifar10_std) 56 | ]) 57 | 58 | # dataset 59 | test_dataset = datasets.CIFAR10( 60 | root, train=False, transform=transform_val, download=False) 61 | logger.info("Dataset: CIFAR10") 62 | logger.info(f"Labeled examples: {len(train_labeled_idxs)}" 63 | f" Unlabeled examples: {len(train_unlabeled_idxs)}") 64 | 65 | return train_labeled_dataset,train_unlabeled_dataset, test_dataset 66 | 67 | 68 | def get_cifar100(root, num_labeled, num_expand_x, num_expand_u): 69 | 70 | transform_labeled = transforms.Compose([ 71 | transforms.RandomHorizontalFlip(), 72 | transforms.RandomCrop(size=32, 73 | padding=int(32*0.125), 74 | padding_mode='reflect'), 75 | 
transforms.ToTensor(),
76 |         transforms.Normalize(mean=cifar100_mean, std=cifar100_std)])
77 | 
78 |     transform_val = transforms.Compose([
79 |         transforms.ToTensor(),
80 |         transforms.Normalize(mean=cifar100_mean, std=cifar100_std)])
81 | 
82 |     base_dataset = datasets.CIFAR100(
83 |         root, train=True, download=True)
84 | 
85 |     train_labeled_idxs, train_unlabeled_idxs = x_u_split(
86 |         base_dataset.targets, num_labeled, num_expand_x, num_expand_u, num_classes=100)
87 | 
88 |     train_labeled_dataset = CIFAR100SSL(
89 |         root, train_labeled_idxs, train=True,
90 |         transform=transform_labeled)
91 | 
92 |     train_unlabeled_dataset = CIFAR100SSL(
93 |         root, train_unlabeled_idxs, train=True,
94 |         transform=TransformRemix(mean=cifar100_mean, std=cifar100_std))
95 | 
96 |     test_dataset = datasets.CIFAR100(
97 |         root, train=False, transform=transform_val, download=False)
98 | 
99 |     logger.info("Dataset: CIFAR100")
100 |     logger.info(f"Labeled examples: {len(train_labeled_idxs)}"
101 |                 f" Unlabeled examples: {len(train_unlabeled_idxs)}")
102 | 
103 |     return train_labeled_dataset, train_unlabeled_dataset, test_dataset
104 | 
105 | 
106 | def x_u_split(labels,
107 |               num_labeled,
108 |               num_expand_x,
109 |               num_expand_u,
110 |               num_classes):
111 |     label_per_class = num_labeled // num_classes
112 |     labels = np.array(labels)
113 |     labeled_idx = []
114 |     unlabeled_idx = []
115 |     for i in range(num_classes):
116 |         idx = np.where(labels == i)[0]
117 |         np.random.shuffle(idx)
118 |         labeled_idx.extend(idx[:label_per_class])
119 |         unlabeled_idx.extend(idx[label_per_class:])
120 | 
121 |     expand_labeled = num_expand_x // len(labeled_idx)
122 |     expand_unlabeled = num_expand_u // len(unlabeled_idx)
123 |     labeled_idx = np.hstack(
124 |         [labeled_idx for _ in range(expand_labeled)])
125 |     unlabeled_idx = np.hstack(
126 |         [unlabeled_idx for _ in range(expand_unlabeled)])
127 | 
128 |     if len(labeled_idx) < num_expand_x:
129 |         diff = num_expand_x - len(labeled_idx)
130 |         labeled_idx = np.hstack(
131 |             (labeled_idx, np.random.choice(labeled_idx, diff)))
132 |     else:
133 |         assert len(labeled_idx) == num_expand_x
134 | 
135 |     if len(unlabeled_idx) < num_expand_u:
136 |         diff = num_expand_u - len(unlabeled_idx)
137 |         unlabeled_idx = np.hstack(
138 |             (unlabeled_idx, np.random.choice(unlabeled_idx, diff)))
139 |     else:
140 |         assert len(unlabeled_idx) == num_expand_u
141 | 
142 |     return labeled_idx, unlabeled_idx
143 | 
144 | 
145 | class TransformRemix(object):
146 |     def __init__(self, mean, std):
147 |         self.weak = transforms.Compose([
148 |             transforms.RandomHorizontalFlip(),
149 |             transforms.RandomCrop(size=32,
150 |                                   padding=int(32*0.125),
151 |                                   padding_mode='reflect')])
152 |         self.strong = transforms.Compose([
153 |             transforms.RandomHorizontalFlip(),
154 |             transforms.RandomCrop(size=32,
155 |                                   padding=int(32*0.125),
156 |                                   padding_mode='reflect'),
157 |             augment.RandAugmentMC(n=2, m=10)])
158 |         self.normalize = transforms.Compose([
159 |             transforms.ToTensor(),
160 |             transforms.Normalize(mean=mean, std=std)])
161 | 
162 |     def __call__(self, x):
163 |         weak = self.weak(x)
164 |         strong = self.strong(x)
165 |         return self.normalize(weak), self.normalize(strong)
166 | 
167 | 
168 | class CIFAR10SSL(datasets.CIFAR10):
169 |     def __init__(self, root, indexs, train=True,
170 |                  transform=None, target_transform=None,
171 |                  download=False):
172 |         super().__init__(root, train=train,
173 |                          transform=transform,
174 |                          target_transform=target_transform,
175 |                          download=download)
176 |         if indexs is not None:
177 |             self.data = self.data[indexs]
178 |             self.targets = np.array(self.targets)[indexs]
179 | 
180 |
def __getitem__(self, index): 181 | img, target = self.data[index], self.targets[index] 182 | img = Image.fromarray(img) 183 | 184 | if self.transform is not None: 185 | img = self.transform(img) 186 | 187 | if self.target_transform is not None: 188 | target = self.target_transform(target) 189 | 190 | return img, target 191 | 192 | 193 | class CIFAR100SSL(datasets.CIFAR100): 194 | def __init__(self, root, indexs, train=True, 195 | transform=None, target_transform=None, 196 | download=False): 197 | super().__init__(root, train=train, 198 | transform=transform, 199 | target_transform=target_transform, 200 | download=download) 201 | if indexs is not None: 202 | self.data = self.data[indexs] 203 | self.targets = np.array(self.targets)[indexs] 204 | 205 | def __getitem__(self, index): 206 | img, target = self.data[index], self.targets[index] 207 | img = Image.fromarray(img) 208 | 209 | if self.transform is not None: 210 | img = self.transform(img) 211 | 212 | if self.target_transform is not None: 213 | target = self.target_transform(target) 214 | 215 | return img, target -------------------------------------------------------------------------------- /remixmatch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import random 6 | import shutil 7 | import time 8 | from copy import deepcopy 9 | from collections import OrderedDict 10 | 11 | import numpy as np 12 | import torch 13 | import torch.backends.cudnn as cudnn 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.distributions as dist 17 | import torch.optim as optim 18 | from torch.optim.lr_scheduler import LambdaLR 19 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 20 | from torch.utils.data.distributed import DistributedSampler 21 | from torch.utils.tensorboard import SummaryWriter 22 | from tqdm import tqdm 23 | 24 | from torchvision import transforms 25 | import random 26 | 27 | from cifar_remixmatch import get_cifar10, get_cifar100 28 | from utils import AverageMeter, accuracy 29 | 30 | import wideresnet as models 31 | 32 | import pdb 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | DATASET_GETTERS = {'cifar10': get_cifar10, 37 | 'cifar100': get_cifar100} 38 | 39 | best_acc = 0 40 | 41 | 42 | def save_checkpoint(state, is_best, checkpoint, filename='checkpoint.pth.tar'): 43 | filepath = os.path.join(checkpoint, filename) 44 | torch.save(state, filepath) 45 | if is_best: 46 | shutil.copyfile(filepath, os.path.join(checkpoint, 47 | 'model_best.pth.tar')) 48 | 49 | 50 | def set_seed(args): 51 | random.seed(args.seed) 52 | np.random.seed(args.seed) 53 | torch.manual_seed(args.seed) 54 | if args.n_gpu > 0: 55 | torch.cuda.manual_seed_all(args.seed) 56 | 57 | 58 | def get_cosine_schedule_with_warmup(optimizer, 59 | num_warmup_steps, 60 | num_training_steps, 61 | num_cycles=7./16., 62 | last_epoch=-1): 63 | def _lr_lambda(current_step): 64 | if current_step < num_warmup_steps: 65 | return float(current_step) / float(max(1, num_warmup_steps)) 66 | no_progress = float(current_step - num_warmup_steps) / \ 67 | float(max(1, num_training_steps - num_warmup_steps)) 68 | return max(0., math.cos(math.pi * num_cycles * no_progress)) 69 | 70 | return LambdaLR(optimizer, _lr_lambda, last_epoch) 71 | 72 | 73 | def main(): 74 | parser = argparse.ArgumentParser(description='PyTorch ReMixMatch Training') 75 | parser.add_argument('--gpu-id', default='0', type=int, 76 | help='id(s) for 
CUDA_VISIBLE_DEVICES') 77 | parser.add_argument('--num-workers', type=int, default=4, 78 | help='number of workers') 79 | parser.add_argument('--dataset', default='cifar10', type=str, 80 | choices=['cifar10', 'cifar100'], 81 | help='dataset name') 82 | parser.add_argument('--num-labeled', type=int, default=4000, 83 | help='number of labeled data') 84 | parser.add_argument('--arch', default='wideresnet', type=str, 85 | choices=['wideresnet', 'resnext'], 86 | help='dataset name') 87 | parser.add_argument('--epochs', default=500, type=int, 88 | help='number of total epochs to run') 89 | parser.add_argument('--start-epoch', default=0, type=int, 90 | help='manual epoch number (useful on restarts)') 91 | parser.add_argument('--batch-size', default=64, type=int, 92 | help='train batchsize') 93 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float, 94 | help='initial learning rate') 95 | parser.add_argument('--warmup', default=0, type=float, 96 | help='warmup epochs (unlabeled data based)') 97 | parser.add_argument('--wdecay', default=5e-4, type=float, 98 | help='weight decay') 99 | parser.add_argument('--nesterov', action='store_true', default=True, 100 | help='use nesterov momentum') 101 | parser.add_argument('--use-ema', action='store_true', default=True, 102 | help='use EMA model') 103 | parser.add_argument('--ema-decay', default=0.999, type=float, 104 | help='EMA decay rate') 105 | parser.add_argument('--mu', default=7, type=int, 106 | help='coefficient of unlabeled batch size') 107 | parser.add_argument('--lambda-u', default=1, type=float, 108 | help='coefficient of unlabeled loss') 109 | parser.add_argument('--threshold', default=0.95, type=float, 110 | help='pseudo label threshold') 111 | parser.add_argument('--k-img', default=65536, type=int, 112 | help='number of labeled examples') 113 | parser.add_argument('--out', default='experiment/', 114 | help='directory to output the result') 115 | parser.add_argument('--resume', default='', type=str, 116 | help='path to latest checkpoint (default: none)') 117 | parser.add_argument('--seed', type=int, default=-1, 118 | help="random seed (-1: don't use random seed)") 119 | parser.add_argument("--amp", action="store_true", 120 | help="use 16-bit (mixed) precision through NVIDIA apex AMP") 121 | parser.add_argument("--opt_level", type=str, default="O1", 122 | help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
123 | "See details at https://nvidia.github.io/apex/amp.html") 124 | parser.add_argument("--local_rank", type=int, default=-1, 125 | help="For distributed training: local_rank") 126 | parser.add_argument('--no-progress', action='store_true', 127 | help="don't use progress bar"), 128 | parser.add_argument("--beta", type=float, default=0.5, 129 | help="mixup rate") 130 | 131 | 132 | args = parser.parse_args() 133 | 134 | global best_acc 135 | 136 | def create_model(args): 137 | if args.arch == 'wideresnet': 138 | model = models.build_wideresnet(depth=args.model_depth, 139 | widen_factor=args.model_width, 140 | dropout=0, 141 | num_classes=args.num_classes) 142 | ''' 143 | elif args.arch == 'resnext': 144 | import models.resnext as models 145 | model = models.build_resnext(cardinality=args.model_cardinality, 146 | depth=args.model_depth, 147 | width=args.model_width, 148 | num_classes=args.num_classes) 149 | ''' 150 | logger.info("Total params: {:.2f}M".format( 151 | sum(p.numel() for p in model.parameters())/1e6)) 152 | 153 | return model 154 | 155 | if args.local_rank == -1: 156 | device = torch.device('cuda', args.gpu_id) 157 | args.world_size = 1 158 | args.n_gpu = torch.cuda.device_count() 159 | else: 160 | torch.cuda.set_device(args.local_rank) 161 | device = torch.device('cuda', args.local_rank) 162 | torch.distributed.init_process_group(backend='nccl') 163 | args.world_size = torch.distributed.get_world_size() 164 | args.n_gpu = 1 165 | 166 | args.device = device 167 | 168 | if args.dataset == 'cifar10': 169 | args.num_classes = 10 170 | if args.arch == 'wideresnet': 171 | args.model_depth = 28 172 | args.model_width = 2 173 | if args.arch == 'resnext': 174 | args.model_cardinality = 4 175 | args.model_depth = 28 176 | args.model_width = 4 177 | 178 | elif args.dataset == 'cifar100': 179 | args.num_classes = 100 180 | if args.arch == 'wideresnet': 181 | args.model_depth = 28 182 | args.model_width = 10 183 | if args.arch == 'resnext': 184 | args.model_cardinality = 8 185 | args.model_depth = 29 186 | args.model_width = 64 187 | 188 | logging.basicConfig( 189 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 190 | datefmt="%m/%d/%Y %H:%M:%S", 191 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 192 | 193 | logger.warning( 194 | f"Process rank: {args.local_rank}, " 195 | f"device: {args.device}, " 196 | f"n_gpu: {args.n_gpu}, " 197 | f"distributed training: {bool(args.local_rank != -1)}, " 198 | f"16-bits training: {args.amp}",) 199 | 200 | logger.info(dict(args._get_kwargs())) 201 | 202 | if args.seed != -1: 203 | set_seed(args) 204 | 205 | args.out = args.out + "result_"+str(args.num_labeled) 206 | if args.local_rank in [-1, 0]: 207 | os.makedirs(args.out, exist_ok=True) 208 | writer = SummaryWriter(args.out) 209 | 210 | if args.local_rank not in [-1, 0]: 211 | torch.distributed.barrier() 212 | 213 | labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[args.dataset]( 214 | '~/htlim/data/'+args.dataset, args.num_labeled, args.k_img, args.k_img * args.mu) 215 | 216 | model = create_model(args) 217 | 218 | ''' 219 | model_rot = models.classifier(in_channel=3, num_classes=4, filters=32) 220 | optimizer_rot = optim.SGD(model_rot.parameters(), lr=args.lr, 221 | momentum=0.9, nesterov=args.nesterov) 222 | 223 | model_rot.to(args.device) 224 | ''' 225 | #multi GPU 226 | ''' 227 | if torch.cuda.device_count() > 1: 228 | model = nn.DataParallel(model) 229 | model_rot = nn.DataParallel(model_rot) 230 | ''' 231 | 232 | if args.local_rank == 0: 233 
| torch.distributed.barrier() 234 | 235 | model.to(args.device) 236 | 237 | train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler 238 | 239 | labeled_trainloader = DataLoader( 240 | labeled_dataset, 241 | sampler=train_sampler(labeled_dataset), 242 | batch_size=args.batch_size, 243 | num_workers=args.num_workers, 244 | drop_last=True) 245 | 246 | unlabeled_trainloader = DataLoader( 247 | unlabeled_dataset, 248 | sampler=train_sampler(unlabeled_dataset), 249 | batch_size=args.batch_size*args.mu, 250 | num_workers=args.num_workers, 251 | drop_last=True) 252 | 253 | test_loader = DataLoader( 254 | test_dataset, 255 | sampler=SequentialSampler(test_dataset), 256 | batch_size=args.batch_size, 257 | num_workers=args.num_workers) 258 | 259 | optimizer = optim.SGD(model.parameters(), lr=args.lr, 260 | momentum=0.9, nesterov=args.nesterov) 261 | 262 | args.iteration = args.k_img // args.batch_size // args.world_size 263 | args.total_steps = args.epochs * args.iteration 264 | scheduler = get_cosine_schedule_with_warmup( 265 | optimizer, args.warmup * args.iteration, args.total_steps) 266 | 267 | if args.use_ema: 268 | ema_model = ModelEMA(args, model, args.ema_decay, device) 269 | 270 | start_epoch = 0 271 | 272 | if args.local_rank not in [-1, 0]: 273 | torch.distributed.barrier() 274 | 275 | if args.resume: 276 | logger.info("==> Resuming from checkpoint..") 277 | assert os.path.isfile( 278 | args.resume), "Error: no checkpoint directory found!" 279 | args.out = os.path.dirname(args.resume) 280 | checkpoint = torch.load(args.resume) 281 | best_acc = checkpoint['best_acc'] 282 | start_epoch = checkpoint['epoch'] 283 | model.load_state_dict(checkpoint['state_dict']) 284 | #model_rot.load_state_dict(checkpoint['state_dict2']) 285 | if args.use_ema: 286 | ema_model.ema.load_state_dict(checkpoint['ema_state_dict']) 287 | optimizer.load_state_dict(checkpoint['optimizer']) 288 | scheduler.load_state_dict(checkpoint['scheduler']) 289 | 290 | if args.local_rank == 0: 291 | torch.distributed.barrier() 292 | 293 | if args.amp: 294 | from apex import amp 295 | model, optimizer = amp.initialize( 296 | model, optimizer, opt_level=args.opt_level) 297 | 298 | if args.local_rank != -1: 299 | model = torch.nn.parallel.DistributedDataParallel( 300 | model, device_ids=[args.local_rank], 301 | output_device=args.local_rank, find_unused_parameters=True) 302 | 303 | logger.info("***** Running training *****") 304 | logger.info(f" Task = {args.dataset}@{args.num_labeled}") 305 | logger.info(f" Num Epochs = {args.epochs}") 306 | logger.info(f" Batch size per GPU = {args.batch_size}") 307 | logger.info( 308 | f" Total train batch size = {args.batch_size*args.world_size}") 309 | logger.info(f" Total optimization steps = {args.total_steps}") 310 | 311 | test_accs = [] 312 | 313 | 314 | model.zero_grad() 315 | #model_rot.zero_grad() 316 | 317 | for epoch in range(start_epoch, args.epochs): 318 | train_loss, train_loss_x, train_loss_u, train_loss_us, train_loss_r = train( 319 | args, labeled_trainloader,unlabeled_trainloader, 320 | model, optimizer, ema_model, scheduler, epoch) 321 | 322 | if args.no_progress: 323 | logger.info("Epoch {}. train_loss: {:.4f}. train_loss_x: {:.4f}. train_loss_u: {:.4f}. train_loss_us: {:.4f}. train_loss_r: {:.4f}." 
324 | .format(epoch+1, train_loss, train_loss_x, train_loss_u, train_loss_us, train_loss_r)) 325 | 326 | if args.use_ema: 327 | test_model = ema_model.ema 328 | else: 329 | test_model = model 330 | 331 | test_loss, test_acc = test(args, test_loader, test_model, epoch) 332 | 333 | if args.local_rank in [-1, 0]: 334 | writer.add_scalar('train/1.train_loss', train_loss, epoch) 335 | writer.add_scalar('train/2.train_loss_x', train_loss_x, epoch) 336 | writer.add_scalar('train/3.train_loss_u', train_loss_u, epoch) 337 | writer.add_scalar('train/3.train_loss_us', train_loss_us, epoch) 338 | writer.add_scalar('train/3.train_loss_r', train_loss_r, epoch) 339 | #writer.add_scalar('train/4.mask', mask_prob, epoch) 340 | writer.add_scalar('test/1.test_acc', test_acc, epoch) 341 | writer.add_scalar('test/2.test_loss', test_loss, epoch) 342 | 343 | is_best = test_acc > best_acc 344 | best_acc = max(test_acc, best_acc) 345 | if args.local_rank == -1 or torch.distributed.get_rank() == 0: 346 | model_to_save = model.module if hasattr(model, "module") else model 347 | #model_rot_to_save = model_rot.module if hasattr(model_rot, "module") else model_rot 348 | if args.use_ema: 349 | ema_to_save = ema_model.ema.module if hasattr( 350 | ema_model.ema, "module") else ema_model.ema 351 | save_checkpoint({ 352 | 'epoch': epoch + 1, 353 | 'state_dict': model_to_save.state_dict(), 354 | #'state_dict2':model_rot_to_save.state_dict(), 355 | 'ema_state_dict': ema_to_save.state_dict() if args.use_ema else None, 356 | 'acc': test_acc, 357 | 'best_acc': best_acc, 358 | 'optimizer': optimizer.state_dict(), 359 | #'optimizer_rot':optimizer_rot.state_dict(), 360 | 'scheduler': scheduler.state_dict(), 361 | }, is_best, args.out) 362 | 363 | test_accs.append(test_acc) 364 | logger.info('Best top-1 acc: {:.2f}'.format(best_acc)) 365 | logger.info('Mean top-1 acc: {:.2f}\n'.format( 366 | np.mean(test_accs[-20:]))) 367 | 368 | if args.local_rank in [-1, 0]: 369 | writer.close() 370 | 371 | def random_rotate(x): 372 | b4 = x.shape[0] // 4 373 | x1 = x[:b4] 374 | x2 = torch.rot90(x[b4:], 1, [2,3]) 375 | x3 = torch.rot90(x2[b4:], 1, [2,3]) 376 | x4 = torch.rot90(x3[b4:], 1, [2,3]) 377 | l = np.zeros(b4, np.int32) 378 | l = torch.from_numpy(np.concatenate([l, l + 1, l + 2, l + 3])) 379 | return torch.cat((x1,x2[:b4],x3[:b4],x4), dim=0), l 380 | 381 | ''' 382 | def random_rotate(x): 383 | b4 = x.shape[0] // 4 384 | x, xt = x[:2 * b4], torch.transpose(x[2 * b4:], 3,2) 385 | l = np.zeros(b4, np.int32) 386 | l = torch.from_numpy(np.concatenate([l, l + 1, l + 2, l + 3])) 387 | return np.concatenate([x[:b4], torch.flip(x[b4:],[1]), torch.flip(xt[:b4], [0,1]), torch.flip(xt[b4:],[0])], axis=0), l 388 | ''' 389 | 390 | def train(args, labeled_trainloader, unlabeled_trainloader, 391 | model, optimizer, ema_model, scheduler, epoch): 392 | if args.amp: 393 | from apex import amp 394 | batch_time = AverageMeter() 395 | data_time = AverageMeter() 396 | losses = AverageMeter() 397 | losses_x = AverageMeter() 398 | losses_u = AverageMeter() 399 | losses_us = AverageMeter() 400 | losses_r = AverageMeter() 401 | losses_ua = AverageMeter() 402 | end = time.time() 403 | 404 | if not args.no_progress: 405 | p_bar = tqdm(range(args.iteration), 406 | disable=args.local_rank not in [-1, 0]) 407 | 408 | train_loader = zip(labeled_trainloader, unlabeled_trainloader) 409 | 410 | criterion = nn.CrossEntropyLoss() 411 | 412 | model.train() 413 | #model_rot.train() 414 | 415 | for batch_idx, (data_x, data_u) in enumerate(train_loader): #data_x : labeled data , 
data_u : unlabeled data
416 |         inputs_x, targets_x = data_x
417 |         (inputs_u_w, inputs_u_s), targets_uo = data_u  # inputs_u_w: weakly augmented unlabeled batch, inputs_u_s: strongly augmented unlabeled batch
418 |         data_time.update(time.time() - end)
419 |         batch_size = inputs_x.shape[0]
420 | 
421 |         # unlabeled inputs
422 |         #inputs_u = torch.cat((inputs_u_w, inputs_u_s),dim=0)
423 | 
424 | 
425 |         # rotation self-supervision on the strongly augmented unlabeled data
426 |         rot_y, rot_l = random_rotate(inputs_u_s)
427 |         rot_y = rot_y.to(args.device)
428 |         logits_rot = model(rot_y)
429 | 
430 |         rot_l = rot_l.type(torch.int64)
431 |         rot_l = rot_l.to(args.device)
432 | 
433 |         # supervised loss
434 |         logits_x = model(inputs_x.to(args.device))
435 | 
436 |         # pseudo label
437 |         with torch.no_grad():
438 |             # compute guessed labels of unlabeled samples
439 |             # NOTE: distribution alignment still needs fixing - the ratio below is a scalar, so it cancels after renormalization; full ReMixMatch aligns each class with a running average of the model's predictions
440 |             logits_u_temp = model(inputs_u_w.to(args.device))
441 |             #q = torch.cat((torch.softmax(logits_u_w, dim=1), torch.softmax(logits_u_s, dim=1)), dim=0)
442 |             q = torch.softmax(logits_u_temp, dim=1)
443 |             q = q * (torch.softmax(logits_x, dim=1)).mean()/q.mean()
444 | 
445 |             pt = q**(1/0.5)  # temperature sharpening with T = 0.5
446 |             targets_u = pt / pt.sum(dim=1, keepdim=True)
447 |             pseudo_label = targets_u.detach()
448 | 
449 | 
450 |         _, targets_u = torch.max(pseudo_label, dim=-1)
451 |         targets_u = targets_u.long()
452 | 
453 |         # unlabeled strong loss
454 |         logits_u_s = model(inputs_u_s.to(args.device))
455 | 
456 | 
457 |         # concat & shuffle for mixup
458 |         inputs_m = torch.cat((inputs_x, inputs_u_s), dim=0).to(args.device)
459 |         #targets_m= torch.cat((targets_x.to(args.device), targets_u),dim=0).to(args.device)
460 | 
461 |         l = 0.75  # fixed mixup coefficient (args.beta is currently unused)
462 | 
463 |         idx = torch.randperm(inputs_m.size(0))
464 | 
465 |         input_a, input_b = inputs_m, inputs_m[idx]
466 |         #target_a, target_b = targets_m, targets_m[idx]
467 |         mixed_input = l*input_a + (1-l)*input_b
468 |         #mixed_target = l*target_a + (1-l)*target_b
469 | 
470 |         # interleave labeled and unlabeled samples between batches to get correct batchnorm calculation
471 |         mixed_input = list(torch.split(mixed_input, batch_size))  # split mixed_input into chunks of batch_size samples (1 labeled chunk + mu unlabeled chunks)
472 |         mixed_input = interleave(mixed_input, batch_size)  # interleave the chunks so each forward batch mixes labeled and unlabeled samples
473 |         logits = [model(mixed_input[0])]
474 |         for input in mixed_input[1:]:
475 |             logits.append(model(input))  # logits is now a list of per-chunk outputs
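        # Added illustration (hypothetical sizes, not from the original code): with
        # batch_size = 4 and chunks X = [x0 x1 x2 x3], U1 = [u0 u1 u2 u3], U2 = [v0 v1 v2 v3],
        # interleave() produces
        #   [x0 u1 v2 v3], [u0 x1 u2 u3], [v0 v1 x2 x3]
        # so each batch fed to the model above mixes labeled and unlabeled samples and the
        # BatchNorm statistics stay comparable across the forward passes. The swap is its own
        # inverse, so the interleave() call below restores the original ordering on the logits.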
476 |         # logits[0] : p_b
477 |         # logits[1:] : q_b
478 |         # put interleaved samples back
479 |         logits = interleave(logits, batch_size)
480 |         logits_m_x = logits[0]
481 |         logits_m_u = torch.cat(logits[1:], dim=0)
482 | 
483 |         #logits_m = model(mixed_input)
484 | 
485 |         #mixed_target = mixed_target.long()
486 | 
487 |         # unlabeled mixup loss
488 |         #Lu = criterion(logits_m_u[len(logits_m_u)//2:], mixed_target[batch_size//2+len(mixed_target)//2:])
489 | 
490 |         # loss
491 |         #Lx = criterion(logits_x,targets_x.to(args.device))
492 | 
493 |         # rotation loss
494 |         loss_rot = F.cross_entropy(logits_rot, rot_l)
495 |         #loss_rot=0
496 |         Lx = criterion(logits_m_x, targets_x.to(args.device))
497 |         #Lu = criterion(logits_m_u, mixed_target[batch_size:])
498 |         Lu = criterion(logits_m_u, targets_u)
499 |         Lus = (F.cross_entropy(logits_u_s, targets_u, reduction='none')).mean()
500 | 
501 |         # fraction of pseudo labels that match the true unlabeled targets (monitoring only)
502 | 
503 |         cnt = 0
504 |         for i, j in zip(targets_u, targets_uo):
505 |             if i == j:
506 |                 cnt += 1
507 |         Lua = cnt / targets_u.size()[0]
508 | 
509 |         loss = Lx + 1.5*Lu + 0.5*Lus + 0.5*loss_rot
510 | 
511 |         losses.update(loss.item())
512 |         losses_x.update(Lx.item())
513 |         losses_u.update(Lu.item())
514 |         losses_us.update(Lus.item())
515 |         losses_ua.update(Lua)
516 |         losses_r.update(loss_rot.item())
517 | 
518 |         if args.amp:
519 |             with amp.scale_loss(loss, optimizer) as scaled_loss:
520 |                 scaled_loss.backward()
521 |         else:
522 |             loss.backward()
523 | 
524 |         optimizer.step()
525 |         scheduler.step()
526 |         if args.use_ema:
527 |             ema_model.update(model)
528 | 
529 |         model.zero_grad()
530 |         #model_rot.zero_grad()
531 | 
532 |         batch_time.update(time.time() - end)
533 |         end = time.time()
534 |         #mask_prob = mask.mean().item()
535 |         if not args.no_progress:
536 |             p_bar.set_description("Train Epoch: {epoch}/{epochs:4}. Iter: {batch:4}/{iter:4}. LR: {lr:.6f}. Data: {data:.3f}s. Batch: {bt:.3f}s. Loss: {loss:.4f}. Loss_x: {loss_x:.4f}. Loss_u: {loss_u:.4f}. Loss_us: {loss_us:.4f}. Loss_ua: {loss_ua:.4f}.
Loss_r: {loss_r:.4f}.".format( 537 | epoch=epoch + 1, 538 | epochs=args.epochs, 539 | batch=batch_idx + 1, 540 | iter=args.iteration, 541 | lr=scheduler.get_last_lr()[0], 542 | data=data_time.avg, 543 | bt=batch_time.avg, 544 | loss=losses.avg, 545 | loss_x=losses_x.avg, 546 | loss_u=losses_u.avg, 547 | loss_us=losses_us.avg, 548 | loss_ua=losses_ua.avg, 549 | loss_r=losses_r.avg 550 | #mask=mask_prob 551 | )) 552 | p_bar.update() 553 | if not args.no_progress: 554 | p_bar.close() 555 | return losses.avg, losses_x.avg, losses_u.avg, losses_us.avg, losses_r.avg 556 | 557 | 558 | def test(args, test_loader, model, epoch): 559 | batch_time = AverageMeter() 560 | data_time = AverageMeter() 561 | losses = AverageMeter() 562 | top1 = AverageMeter() 563 | top5 = AverageMeter() 564 | end = time.time() 565 | 566 | if not args.no_progress: 567 | test_loader = tqdm(test_loader, 568 | disable=args.local_rank not in [-1, 0]) 569 | 570 | with torch.no_grad(): 571 | for batch_idx, (inputs, targets) in enumerate(test_loader): 572 | data_time.update(time.time() - end) 573 | model.eval() 574 | 575 | inputs = inputs.to(args.device) 576 | targets = targets.to(args.device) 577 | outputs = model(inputs) 578 | loss = F.cross_entropy(outputs, targets) 579 | 580 | prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) 581 | losses.update(loss.item(), inputs.shape[0]) 582 | top1.update(prec1.item(), inputs.shape[0]) 583 | top5.update(prec5.item(), inputs.shape[0]) 584 | batch_time.update(time.time() - end) 585 | end = time.time() 586 | if not args.no_progress: 587 | test_loader.set_description("Test Iter: {batch:4}/{iter:4}. Data: {data:.3f}s. Batch: {bt:.3f}s. Loss: {loss:.4f}. top1: {top1:.2f}. top5: {top5:.2f}. ".format( 588 | batch=batch_idx + 1, 589 | iter=len(test_loader), 590 | data=data_time.avg, 591 | bt=batch_time.avg, 592 | loss=losses.avg, 593 | top1=top1.avg, 594 | top5=top5.avg, 595 | )) 596 | if not args.no_progress: 597 | test_loader.close() 598 | 599 | logger.info("top-1 acc: {:.2f}".format(top1.avg)) 600 | logger.info("top-5 acc: {:.2f}".format(top5.avg)) 601 | return losses.avg, top1.avg 602 | 603 | 604 | class ModelEMA(object): 605 | def __init__(self, args, model, decay, device='', resume=''): 606 | self.ema = deepcopy(model) 607 | self.ema.eval() 608 | self.decay = decay 609 | self.device = device 610 | self.wd = args.lr * args.wdecay 611 | if device: 612 | self.ema.to(device=device) 613 | self.ema_has_module = hasattr(self.ema, 'module') 614 | if resume: 615 | self._load_checkpoint(resume) 616 | for p in self.ema.parameters(): 617 | p.requires_grad_(False) 618 | 619 | def _load_checkpoint(self, checkpoint_path): 620 | checkpoint = torch.load(checkpoint_path) 621 | assert isinstance(checkpoint, dict) 622 | if 'ema_state_dict' in checkpoint: 623 | new_state_dict = OrderedDict() 624 | for k, v in checkpoint['ema_state_dict'].items(): 625 | if self.ema_has_module: 626 | name = 'module.' + k if not k.startswith('module') else k 627 | else: 628 | name = k 629 | new_state_dict[name] = v 630 | self.ema.load_state_dict(new_state_dict) 631 | 632 | def update(self, model): 633 | needs_module = hasattr(model, 'module') and not self.ema_has_module 634 | with torch.no_grad(): 635 | msd = model.state_dict() 636 | for k, ema_v in self.ema.state_dict().items(): 637 | if needs_module: 638 | k = 'module.' + k 639 | model_v = msd[k].detach() 640 | if self.device: 641 | model_v = model_v.to(device=self.device) 642 | ema_v.copy_(ema_v * self.decay + (1. 
- self.decay) * model_v) 643 | # weight decay 644 | if 'bn' not in k: 645 | msd[k] = msd[k] * (1. - self.wd) 646 | 647 | def interleave_offsets(batch, nu): 648 | groups = [batch // (nu + 1)] * (nu + 1) 649 | for x in range(batch - sum(groups)): 650 | groups[-x - 1] += 1 651 | offsets = [0] 652 | for g in groups: 653 | offsets.append(offsets[-1] + g) 654 | assert offsets[-1] == batch 655 | return offsets 656 | 657 | def interleave(xy, batch): 658 | nu = len(xy) - 1 659 | offsets = interleave_offsets(batch, nu) 660 | xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy] 661 | for i in range(1, nu + 1): 662 | xy[0][i], xy[i][i] = xy[i][i], xy[0][i] 663 | return [torch.cat(v, dim=0) for v in xy] 664 | 665 | if __name__ == '__main__': 666 | cudnn.benchmark = True 667 | main() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .eval import * 6 | 7 | # progress bar 8 | import os, sys 9 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 10 | from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lim3944/ReMixmatch-pytorch/02bc850f24af676ecea75984255b9b2e9a54509a/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/eval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lim3944/ReMixmatch-pytorch/02bc850f24af676ecea75984255b9b2e9a54509a/utils/__pycache__/eval.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lim3944/ReMixmatch-pytorch/02bc850f24af676ecea75984255b9b2e9a54509a/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lim3944/ReMixmatch-pytorch/02bc850f24af676ecea75984255b9b2e9a54509a/utils/__pycache__/misc.cpython-38.pyc -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style 
logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | import matplotlib.pyplot as plt 5 | import os 6 | import sys 7 | import numpy as np 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | def savefig(fname, dpi=None): 12 | dpi = 150 if dpi == None else dpi 13 | plt.savefig(fname, dpi=dpi) 14 | 15 | def plot_overlap(logger, names=None): 16 | names = logger.names if names == None else names 17 | numbers = logger.numbers 18 | for _, name in enumerate(names): 19 | x = np.arange(len(numbers[name])) 20 | plt.plot(x, np.asarray(numbers[name])) 21 | return [logger.title + '(' + name + ')' for name in names] 22 | 23 | class Logger(object): 24 | '''Save training process to log file with simple plot function.''' 25 | def __init__(self, fpath, title=None, resume=False): 26 | self.file = None 27 | self.resume = resume 28 | self.title = '' if title == None else title 29 | if fpath is not None: 30 | if resume: 31 | self.file = open(fpath, 'r') 32 | name = self.file.readline() 33 | self.names = name.rstrip().split('\t') 34 | self.numbers = {} 35 | for _, name in enumerate(self.names): 36 | self.numbers[name] = [] 37 | 38 | for numbers in self.file: 39 | numbers = numbers.rstrip().split('\t') 40 | for i in range(0, len(numbers)): 41 | self.numbers[self.names[i]].append(numbers[i]) 42 | self.file.close() 43 | self.file = open(fpath, 'a') 44 | else: 45 | self.file = open(fpath, 'w') 46 | 47 | def set_names(self, names): 48 | if self.resume: 49 | pass 50 | # initialize numbers as empty list 51 | self.numbers = {} 52 | self.names = names 53 | for _, name in enumerate(self.names): 54 | self.file.write(name) 55 | self.file.write('\t') 56 | self.numbers[name] = [] 57 | self.file.write('\n') 58 | self.file.flush() 59 | 60 | 61 | def append(self, numbers): 62 | assert len(self.names) == len(numbers), 'Numbers do not match names' 63 | for index, num in enumerate(numbers): 64 | self.file.write("{0:.6f}".format(num)) 65 | self.file.write('\t') 66 | self.numbers[self.names[index]].append(num) 67 | self.file.write('\n') 68 | self.file.flush() 69 | 70 | def plot(self, names=None): 71 | names = self.names if names == None else names 72 | numbers = self.numbers 73 | for _, name in enumerate(names): 74 | x = np.arange(len(numbers[name])) 75 | plt.plot(x, np.asarray(numbers[name])) 76 | plt.legend([self.title + '(' + name + ')' for name in names]) 77 | plt.grid(True) 78 | 79 | def close(self): 80 | if self.file is not None: 81 | self.file.close() 82 | 83 | class LoggerMonitor(object): 84 | '''Load and visualize multiple logs.''' 85 | def __init__ (self, paths): 86 | '''paths is a distionary with {name:filepath} pair''' 87 | self.loggers = [] 88 | for title, path in paths.items(): 89 | logger = Logger(path, title=title, resume=True) 90 | self.loggers.append(logger) 91 | 92 | def plot(self, names=None): 93 | plt.figure() 94 | plt.subplot(121) 95 | legend_text = [] 96 | for logger in self.loggers: 97 | legend_text += plot_overlap(logger, names) 98 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
99 | plt.grid(True) 100 | 101 | if __name__ == '__main__': 102 | # # Example 103 | # logger = Logger('test.txt') 104 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 105 | 106 | # length = 100 107 | # t = np.arange(length) 108 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 111 | 112 | # for i in range(0, length): 113 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 114 | # logger.plot() 115 | 116 | # Example: logger monitor 117 | paths = { 118 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 119 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 120 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 121 | } 122 | 123 | field = ['Valid Acc.'] 124 | 125 | monitor = LoggerMonitor(paths) 126 | monitor.plot(names=field) 127 | savefig('test.eps') -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | 16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 17 | 18 | 19 | def get_mean_and_std(dataset): 20 | '''Compute the mean and std value of dataset.''' 21 | dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 22 | 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal(m.weight, mode='fan_out') 39 | if m.bias: 40 | init.constant(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant(m.weight, 1) 43 | init.constant(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal(m.weight, std=1e-3) 46 | if m.bias: 47 | init.constant(m.bias, 0) 48 | 49 | def mkdir_p(path): 50 | '''make dir if not exist''' 51 | try: 52 | os.makedirs(path) 53 | except OSError as exc: # Python >2.5 54 | if exc.errno == errno.EEXIST and os.path.isdir(path): 55 | pass 56 | else: 57 | raise 58 | 59 | class AverageMeter(object): 60 | """Computes and stores the average and current value 61 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 62 | """ 63 | def __init__(self): 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = 0 68 | self.avg = 0 69 | self.sum = 0 70 | self.count = 0 71 | 72 | def update(self, val, n=1): 73 | self.val = val 74 | self.sum += val * n 75 | self.count += n 76 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /wideresnet.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/kekmodel/FixMatch-pytorch/blob/master/models/wideresnet.py 3 | """ 4 | 5 | import logging 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def mish(x): 15 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function (https://arxiv.org/abs/1908.08681)""" 16 | return x * torch.tanh(F.softplus(x)) 17 | 18 | 19 | class PSBatchNorm2d(nn.BatchNorm2d): 20 | """How Does BN Increase Collapsed Neural Network Filters? (https://arxiv.org/abs/2001.11216)""" 21 | 22 | def __init__(self, num_features, alpha=0.1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True): 23 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 24 | self.alpha = alpha 25 | 26 | def forward(self, x): 27 | return super().forward(x) + self.alpha 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0, activate_before_residual=False): 32 | super(BasicBlock, self).__init__() 33 | self.bn1 = PSBatchNorm2d(in_planes, momentum=0.001) 34 | self.relu1 = mish 35 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 36 | padding=1, bias=False) 37 | self.bn2 = PSBatchNorm2d(out_planes, momentum=0.001) 38 | self.relu2 = mish 39 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 40 | padding=1, bias=False) 41 | self.drop_rate = drop_rate 42 | self.equalInOut = (in_planes == out_planes) 43 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 44 | padding=0, bias=False) or None 45 | self.activate_before_residual = activate_before_residual 46 | 47 | def forward(self, x): 48 | if not self.equalInOut and self.activate_before_residual == True: 49 | x = self.relu1(self.bn1(x)) 50 | else: 51 | out = self.relu1(self.bn1(x)) 52 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 53 | if self.drop_rate > 0: 54 | out = F.dropout(out, p=self.drop_rate, training=self.training) 55 | out = self.conv2(out) 56 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 57 | 58 | 59 | class NetworkBlock(nn.Module): 60 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate=0.0, activate_before_residual=False): 61 | super(NetworkBlock, self).__init__() 62 | self.layer = self._make_layer( 63 | block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual) 64 | 65 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual): 66 | layers = [] 67 | for i in range(int(nb_layers)): 68 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, 69 | i == 0 and stride or 1, drop_rate, activate_before_residual)) 70 | return nn.Sequential(*layers) 71 | 72 | def forward(self, x): 73 | return self.layer(x) 74 | 75 | 76 | class WideResNet(nn.Module): 77 | def __init__(self, num_classes, depth=28, widen_factor=2, drop_rate=0.0): 78 | super(WideResNet, self).__init__() 79 | channels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 80 | assert((depth - 4) % 6 == 0) 81 | n = (depth - 4) / 6 82 | block = BasicBlock 83 | # 1st conv before any network block 84 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1, 85 | padding=1, bias=False) 86 | # 1st block 87 | self.block1 = NetworkBlock( 88 | n, channels[0], 
channels[1], block, 1, drop_rate, activate_before_residual=True) 89 | # 2nd block 90 | self.block2 = NetworkBlock( 91 | n, channels[1], channels[2], block, 2, drop_rate) 92 | # 3rd block 93 | self.block3 = NetworkBlock( 94 | n, channels[2], channels[3], block, 2, drop_rate) 95 | # global average pooling and classifier 96 | self.bn1 = PSBatchNorm2d(channels[3], momentum=0.001) 97 | self.relu = mish 98 | self.fc = nn.Linear(channels[3], num_classes) 99 | self.channels = channels[3] 100 | 101 | for m in self.modules(): 102 | if isinstance(m, nn.Conv2d): 103 | nn.init.kaiming_normal_(m.weight, 104 | mode='fan_out', 105 | nonlinearity='leaky_relu') 106 | elif isinstance(m, PSBatchNorm2d): 107 | nn.init.constant_(m.weight, 1.0) 108 | nn.init.constant_(m.bias, 0.0) 109 | elif isinstance(m, nn.Linear): 110 | nn.init.xavier_normal_(m.weight) 111 | nn.init.constant_(m.bias, 0.0) 112 | 113 | def forward(self, x): 114 | out = self.conv1(x) 115 | out = self.block1(out) 116 | out = self.block2(out) 117 | out = self.block3(out) 118 | out = self.relu(self.bn1(out)) 119 | out = F.adaptive_avg_pool2d(out, 1) 120 | out = out.view(-1, self.channels) 121 | return self.fc(out) 122 | 123 | 124 | def build_wideresnet(depth, widen_factor, dropout, num_classes): 125 | logger.info(f"Model: WideResNet {depth}x{widen_factor}") 126 | return WideResNet(depth=depth, 127 | widen_factor=widen_factor, 128 | drop_rate=dropout, 129 | num_classes=num_classes) --------------------------------------------------------------------------------
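
A minimal sanity-check sketch for the model factory above (assuming the CIFAR-10 configuration that remixmatch.py selects for the wideresnet arch, i.e. depth 28, widen_factor 2, 10 classes; the snippet itself is not part of the repository):

import torch
from wideresnet import build_wideresnet

model = build_wideresnet(depth=28, widen_factor=2, dropout=0, num_classes=10)
dummy = torch.randn(4, 3, 32, 32)   # CIFAR-sized dummy batch
logits = model(dummy)
print(logits.shape)                 # expected: torch.Size([4, 10])

Training is launched through remixmatch.py, for example: python remixmatch.py --dataset cifar10 --num-labeled 4000 --out experiment/ ; note that main() currently hard-codes the dataset root to '~/htlim/data/' + args.dataset, so that path has to exist or be edited before running.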