├── .gitignore
├── README.md
├── config
│   └── comm.yml
├── datasets
│   ├── __init__.py
│   ├── aircraft.py
│   ├── car.py
│   ├── cub.py
│   ├── nabirds.py
│   └── tfs.py
├── imgs
│   └── overview.jpg
├── main.py
├── networks
│   ├── __init__.py
│   ├── densenet_ft.py
│   ├── efficientnet_ft.py
│   ├── inception.py
│   ├── inception_ft.py
│   ├── modelutil.py
│   ├── resnet.py
│   └── resnet_ft.py
├── trainer
│   ├── comm_test.py
│   └── comm_train.py
└── utils
    ├── __init__.py
    ├── conf.py
    ├── custom_ops.py
    ├── dataloader.py
    ├── eval.py
    ├── io.py
    ├── log.py
    ├── misc.py
    ├── mixmethod.py
    ├── stat.py
    └── trainutil.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | data/*
 2 | LossAccPlotter
 3 | predictions
 4 | outimgs
 5 | temp
 6 | tmp
 7 | *tmp
 8 | *_tmp
 9 | *_temp
10 | *.jpg
11 | *.png
12 | utils/glasbey
13 |
14 |
15 |
16 | *.checkpoint
17 | results
18 | *__pycache__*
19 | *.pyc
20 | *.pyo
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # SnapMix: Semantically Proportional Mixing for Augmenting Fine-grained Data (AAAI 2021)
 2 |
 3 | PyTorch implementation of SnapMix | [paper](https://arxiv.org/abs/2012.04846)
 4 |
 5 | ## Method Overview
 6 |
 7 | ![SnapMix](./imgs/overview.jpg)
 8 |
 9 | ## Cite
10 | ```
11 | @inproceedings{huang2021snapmix,
12 |     title={SnapMix: Semantically Proportional Mixing for Augmenting Fine-grained Data},
13 |     author={Huang, Shaoli and Wang, Xinchao and Tao, Dacheng},
14 |     year={2021},
15 |     booktitle={AAAI Conference on Artificial Intelligence},
16 | }
17 | ```
18 |
19 | ## Setup
20 | ### Install Package Dependencies
21 | ```
22 | torch
23 | torchvision
24 | PyYAML
25 | easydict
26 | tqdm
27 | scikit-learn
28 | efficientnet_pytorch
29 | pandas
30 | opencv-python
31 | ```
32 | ### Datasets
33 | ***create a soft link to the dataset directory***
34 |
35 | CUB dataset
36 | ```
37 | ln -s /your-path-to/CUB-dataset data/cub
38 | ```
39 | Car dataset
40 | ```
41 | ln -s /your-path-to/Car-dataset data/car
42 | ```
43 | Aircraft dataset
44 | ```
45 | ln -s /your-path-to/Aircraft-dataset data/aircraft
46 | ```
47 |
48 | ## Training
49 |
50 | ### Training with ImageNet pre-trained weights
51 |
52 |
53 | ***1. Baseline and Baseline+***
54 |
55 | To train a model on the CUB dataset with the ResNet-50 backbone:
56 |
57 | ``` python main.py ``` # baseline
58 |
59 | ``` python main.py --midlevel``` # baseline+
60 |
61 | To train a model on another dataset or with a different backbone, specify the following arguments:
62 |
63 | ``` --netname: network architecture (four families are supported: ResNet, DenseNet, InceptionV3, and EfficientNet) ```
64 |
65 | ``` --dataset: dataset name```
66 |
67 | For example,
68 |
69 | ``` python main.py --netname resnet18 --dataset cub ``` # ResNet-18 backbone on the CUB dataset
70 |
71 | ``` python main.py --netname efficientnet-b0 --dataset cub ``` # EfficientNet-b0 backbone on the CUB dataset
72 |
73 | ``` python main.py --netname inceptionv3 --dataset aircraft ``` # InceptionV3 backbone on the Aircraft dataset
74 |
75 |
76 | ***2. Training with mixing augmentation***
77 |
78 | Applying SnapMix in training (we used prob=1.0 and beta=5 for SnapMix in most of the experiments):
79 |
80 | ```python main.py --mixmethod snapmix --beta 5 --netname resnet50 --dataset cub ``` # SnapMix
81 |
82 | ```python main.py --mixmethod snapmix --beta 5 --netname resnet50 --dataset cub --midlevel ``` # SnapMix + mid-level supervision (baseline+)
83 |
84 | Applying other augmentation methods (CutMix, Cutout, and MixUp are currently supported) in training:
85 |
86 | ```python main.py --mixmethod cutmix --beta 3 --netname resnet50 --dataset cub ``` # training with CutMix
87 |
88 | ```python main.py --mixmethod mixup --prob 0.5 --netname resnet50 --dataset cub ``` # training with MixUp
89 |
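For reference, here is a minimal sketch of the idea behind these flags: SnapMix sets the label weights of the two mixed images by the *semantic* proportion of the swapped region, estimated from each image's class activation map, rather than by its pixel area. This is a simplified illustration (a single shared box, no patch resizing); the actual implementation used in training lives in `utils/mixmethod.py` (not shown in this section), and the `model.module.classifier` access assumes the `nn.DataParallel`-wrapped backbones defined in `networks/`.

```python
import numpy as np
import torch
import torch.nn.functional as F

def rand_bbox(size, lam):
    # sample a random box covering roughly (1 - lam) of the image area
    W, H = size[3], size[2]
    cut_rat = np.sqrt(1. - lam)
    cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)
    cx, cy = np.random.randint(W), np.random.randint(H)
    bbx1, bbx2 = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
    bby1, bby2 = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

@torch.no_grad()
def snapmix_sketch(input, target, beta, model):
    # the networks in this repo return (logits, feature_map, mid_logits)
    _, fms, _ = model(input)
    # class activation map of each image w.r.t. its own label
    weights = model.module.classifier.weight               # (num_class, C)
    cams = F.relu(torch.einsum('bchw,bc->bhw', fms, weights[target]))
    cams = F.interpolate(cams.unsqueeze(1), size=input.shape[2:],
                         mode='bilinear', align_corners=False).squeeze(1)
    cams = cams / (cams.sum(dim=(1, 2), keepdim=True) + 1e-8)  # semantic-percent map

    index = torch.randperm(input.size(0), device=input.device)
    bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), np.random.beta(beta, beta))
    input[:, :, bby1:bby2, bbx1:bbx2] = input[index, :, bby1:bby2, bbx1:bbx2]

    # label weights = semantic proportion removed from A / pasted in from B
    lam_a = 1. - cams[:, bby1:bby2, bbx1:bbx2].sum(dim=(1, 2))
    lam_b = cams[index][:, bby1:bby2, bbx1:bbx2].sum(dim=(1, 2))
    return input, target, target[index], lam_a, lam_b
```

Unlike CutMix, the two weights need not sum to one, which is what lets SnapMix assign a near-zero weight to a pasted background patch.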
90 | ***3. Results***
91 |
92 | ***ResNet architecture.***
93 |
94 | | Backbone | Method | CUB | Car | Aircraft |
95 | |:--------|:--------|--------:|------:|--------:|
96 | |Resnet-18 | Baseline| 82.35% | 91.15% | 87.80% |
97 | |Resnet-18 | Baseline + SnapMix| 84.29% | 93.12% | 90.17% |
98 | |Resnet-34 | Baseline| 84.98% | 92.02% | 89.92% |
99 | |Resnet-34 | Baseline + SnapMix| 87.06% | 93.95% | 92.36% |
100 | |Resnet-50 | Baseline| 85.49% | 93.04% | 91.07% |
101 | |Resnet-50 | Baseline + SnapMix| 87.75% | 94.30% | 92.08% |
102 | |Resnet-101 | Baseline| 85.62% | 93.09% | 91.59% |
103 | |Resnet-101 | Baseline + SnapMix| 88.45% | 94.44% | 93.74% |
104 | |Resnet-50 | Baseline+| 87.13% | 93.80% | 91.68% |
105 | |Resnet-50 | Baseline+ + SnapMix| 88.70% | 95.00% | 93.24% |
106 | |Resnet-101 | Baseline+| 87.81% | 93.94% | 91.85% |
107 | |Resnet-101 | Baseline+ + SnapMix| 89.32% | 94.84% | 94.05% |
108 |
109 |
110 | ***InceptionV3 architecture.***
111 |
112 | | Backbone | Method | CUB |
113 | |:--------|:--------|--------:|
114 | |InceptionV3 | Baseline| 82.22% |
115 | |InceptionV3 | Baseline + SnapMix| 85.54%|
116 |
117 |
118 | ***DenseNet architecture.***
119 |
120 | | Backbone | Method | CUB |
121 | |:--------|:--------|--------:|
122 | |DenseNet121 | Baseline| 84.23% |
123 | |DenseNet121| Baseline + SnapMix| 87.42%|
124 |
125 |
126 | ### Training from scratch
127 |
128 | To train a model without using ImageNet pretrained weights:
129 |
130 | ```python main.py --mixmethod snapmix --prob 0.5 --netname resnet18 --dataset cub --pretrained 0``` # ResNet-18 backbone
131 |
132 | ```python main.py --mixmethod snapmix --prob 0.5 --netname resnet50 --dataset cub --pretrained 0 ``` # ResNet-50 backbone
133 |
134 | ***Results***
135 |
136 | | Backbone | Method | CUB |
137 | |:--------|:--------|--------:|
138 | |Resnet-18 | Baseline| 64.98% |
139 | |Resnet-18 | Baseline + SnapMix| 70.31%|
140 | |Resnet-50 | Baseline| 66.92% |
141 | |Resnet-50| Baseline + SnapMix| 72.17%|
142 |
--------------------------------------------------------------------------------
/config/comm.yml:
--------------------------------------------------------------------------------
 1 |
 2 | exp_name: snapmix
 3 | train_proc: comm
 4 | test_proc: comm
 5 |
 6 | prams_group: ['ftlayer','freshlayer']
 7 | lr_group: [0.001,0.01]
 8 | lrstep: [80,150,180]
 9 | lrgamma: 0.1
10 |
11 | criterion: CrossEntropyLoss
12 | reduction: none
13 | warp: True
--------------------------------------------------------------------------------
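The `prams_group`/`lr_group` pair above gives each named parameter group its own learning rate: pretrained (`ftlayer`) layers train at 0.001 and freshly initialized (`freshlayer`) layers at 0.01, both decayed by `lrgamma` at the `lrstep` milestones. A plausible minimal wiring of this config into an optimizer, assuming the `nn.DataParallel` wrapper from `main.py`; the real setup is built by `get_train_setting` in `utils/trainutil.py` (not shown here), and the SGD momentum and weight decay below are assumed values:

```python
import torch.optim as optim

def build_optimizer(model, conf):
    # every backbone in networks/ exposes get_params('ftlayer'/'freshlayer')
    param_groups = [{'params': model.module.get_params(name), 'lr': lr}
                    for name, lr in zip(conf.prams_group, conf.lr_group)]
    optimizer = optim.SGD(param_groups, momentum=0.9, weight_decay=1e-4)  # assumed hyperparameters
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=conf.lrstep,
                                               gamma=conf.lrgamma)
    return optimizer, scheduler
```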
59 | """ 60 | url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' 61 | class_types = ('variant', 'family', 'manufacturer') 62 | splits = ('train', 'val', 'trainval', 'test') 63 | 64 | def __init__(self, root='data/aircraft', transform=None, 65 | target_transform=None, train=True, loader=default_loader): 66 | 67 | 68 | self.root = os.path.expanduser(root) 69 | self.class_type = 'variant' 70 | self.split = 'trainval' if train else 'test' 71 | self.classes_file = os.path.join(self.root, 'data', 72 | 'images_%s_%s.txt' % (self.class_type, self.split)) 73 | 74 | (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) 75 | samples = make_dataset(self.root, image_ids, targets) 76 | 77 | paths = [] 78 | labels = [] 79 | for sample in samples: 80 | path,label = sample 81 | paths.append(path) 82 | labels.append(label) 83 | 84 | datadict = {'path':paths,'label':labels} 85 | data = pd.DataFrame(datadict) 86 | imgs = data.reset_index(drop=True) 87 | 88 | 89 | self.transform = transform 90 | self.target_transform = target_transform 91 | self.loader = loader 92 | 93 | self.imgs = imgs 94 | self.classes = classes 95 | self.class_to_idx = class_to_idx 96 | self.tta = tta 97 | 98 | def __getitem__(self, index): 99 | item = self.imgs.iloc[index] 100 | path = item['path'] 101 | target = item['label'] 102 | img = self.loader(path) 103 | img = self.transform(img) 104 | return img, target 105 | 106 | def __len__(self): 107 | return len(self.imgs) 108 | 109 | def __repr__(self): 110 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 111 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 112 | fmt_str += ' Root Location: {}\n'.format(self.root) 113 | tmp = ' Transforms (if any): ' 114 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 115 | tmp = ' Target Transforms (if any): ' 116 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 117 | return fmt_str 118 | 119 | def _check_exists(self): 120 | return os.path.exists(os.path.join(self.root, 'data', 'images')) and \ 121 | os.path.exists(self.classes_file) 122 | 123 | def get_dataset(conf): 124 | 125 | datadir = 'data/aircraft' 126 | 127 | if conf and 'datadir' in conf: 128 | datadir = conf.datadir 129 | 130 | conf['num_class'] = 100 131 | 132 | transform_train,transform_test = get_aircraft_transform(conf) 133 | 134 | ds_train = ImageLoader(datadir, train=True, transform=transform_train) 135 | ds_test = ImageLoader(datadir, train=False, transform=transform_test) 136 | 137 | 138 | return ds_train,ds_test 139 | -------------------------------------------------------------------------------- /datasets/car.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import PIL 3 | from PIL import Image 4 | import os 5 | import pandas as pd 6 | import math 7 | from torch.utils.data.sampler import WeightedRandomSampler 8 | import numpy as np 9 | from scipy.io import loadmat 10 | from datasets.tfs import get_car_transform 11 | 12 | import pdb 13 | 14 | def pil_loader(path): 15 | with open(path, 'rb') as f: 16 | img = Image.open(f) 17 | return img.convert('RGB') 18 | 19 | def get_mat_frame(path, img_folder): 20 | results = {} 21 | tmp_mat = loadmat(path) 22 | anno = tmp_mat['annotations'][0] 23 | results['path'] = [os.path.join(img_folder, anno[i][-1][0]) for i in range(anno.shape[0])] 24 | results['label'] = [anno[i][-2][0, 0] for i in 
/datasets/car.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import PIL
 3 | from PIL import Image
 4 | import os
 5 | import pandas as pd
 6 | import math
 7 | from torch.utils.data.sampler import WeightedRandomSampler
 8 | import numpy as np
 9 | from scipy.io import loadmat
10 | from datasets.tfs import get_car_transform
11 |
12 | import pdb
13 |
14 | def pil_loader(path):
15 |     with open(path, 'rb') as f:
16 |         img = Image.open(f)
17 |         return img.convert('RGB')
18 |
19 | def get_mat_frame(path, img_folder):
20 |     results = {}
21 |     tmp_mat = loadmat(path)
22 |     anno = tmp_mat['annotations'][0]
23 |     results['path'] = [os.path.join(img_folder, anno[i][-1][0]) for i in range(anno.shape[0])]
24 |     results['label'] = [anno[i][-2][0, 0] for i in range(anno.shape[0])]
25 |     return results
26 |
27 |
28 | class ImageLoader(torch.utils.data.Dataset):
29 |
30 |     def __init__(self, root='Stanford_Cars', transform=None, target_transform=None, train=False, loader=pil_loader):
31 |
32 |         img_folder = root
33 |         pd_train = pd.DataFrame.from_dict(get_mat_frame(os.path.join(root, 'devkit', 'cars_train_annos.mat'), 'cars_train'))
34 |         pd_test = pd.DataFrame.from_dict(get_mat_frame(os.path.join(root, 'devkit', 'cars_test_annos_withlabels.mat'), 'cars_test'))
35 |         data = pd.concat([pd_train, pd_test])
36 |         data['train_flag'] = pd.Series(data.path.isin(pd_train['path']))
37 |         data = data[data['train_flag'] == train]
38 |         data['label'] = data['label'] - 1
39 |
40 |         imgs = data.reset_index(drop=True)
41 |
42 |         if len(imgs) == 0:
43 |             raise RuntimeError('no images found; check that data/car links to the Stanford Cars dataset')
44 |         self.root = img_folder
45 |         self.imgs = imgs
46 |         self.transform = transform
47 |         self.target_transform = target_transform
48 |         self.loader = loader
49 |         self.train = train
50 |
51 |     def __getitem__(self, index):
52 |         item = self.imgs.iloc[index]
53 |         file_path = item['path']
54 |         target = item['label']
55 |
56 |         img = self.loader(os.path.join(self.root, file_path))
57 |         img = self.transform(img)
58 |
59 |         return img, target
60 |
61 |     def __len__(self):
62 |         return len(self.imgs)
63 |
64 |
65 | def get_dataset(conf):
66 |
67 |     datadir = 'data/car'
68 |
69 |     if conf and 'datadir' in conf:
70 |         datadir = conf.datadir
71 |
72 |     conf['num_class'] = 196
73 |
74 |     transform_train,transform_test = get_car_transform(conf)
75 |
76 |     ds_train = ImageLoader(datadir, train=True, transform=transform_train)
77 |     ds_test = ImageLoader(datadir, train=False, transform=transform_test)
78 |
79 |
80 |     return ds_train,ds_test
81 |
--------------------------------------------------------------------------------
/datasets/cub.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import PIL
 3 | from PIL import Image
 4 | import os
 5 | import pandas as pd
 6 | import math
 7 | from torch.utils.data.sampler import WeightedRandomSampler
 8 | import numpy as np
 9 | import torchvision.datasets as tvdataset
10 | from datasets.tfs import get_cub_transform
11 |
12 | import pdb
13 |
14 | def pil_loader(path):
15 |     with open(path, 'rb') as f:
16 |         img = Image.open(f)
17 |         return img.convert('RGB')
18 |
19 | class ImageLoader(torch.utils.data.Dataset):
20 |
21 |     def __init__(self, root, transform=None, target_transform=None, train=False, loader=pil_loader):
22 |         img_folder = os.path.join(root, "images")
23 |         img_paths = pd.read_csv(os.path.join(root, "images.txt"), sep=" ", header=None, names=['idx', 'path'])
24 |         img_labels = pd.read_csv(os.path.join(root, "image_class_labels.txt"), sep=" ", header=None, names=['idx', 'label'])
25 |         train_test_split = pd.read_csv(os.path.join(root, "train_test_split.txt"), sep=" ", header=None, names=['idx', 'train_flag'])
26 |         bounding_box = pd.read_csv(os.path.join(root, "bounding_boxes.txt"), sep=" ", header=None, names=['idx', 'x', 'y', 'w', 'h'])
27 |         data = pd.concat([img_paths, img_labels, train_test_split, bounding_box], axis=1)
28 |         data['label'] = data['label'] - 1
29 |         alldata = data.copy()
30 |
31 |         data = data[data['train_flag'] == train]
32 |         imgs = data.reset_index(drop=True)
33 |
34 |         if len(imgs) == 0:
35 |             raise RuntimeError('no images found; check that data/cub links to the CUB-200-2011 dataset')
36 |         self.root = img_folder
37 |         self.imgs = imgs
38 |         self.transform = transform
39 |         self.target_transform = target_transform
40 |         self.loader = loader
41 |         self.train = 
train 42 | print('num of data:{}'.format(len(imgs))) 43 | 44 | def __getitem__(self, index): 45 | """ 46 | Args: 47 | index (int): Index 48 | Returns: 49 | tuple: (image, target) where target is class_index of the target class. 50 | """ 51 | item = self.imgs.iloc[index] 52 | file_path = item['path'] 53 | target = item['label'] 54 | img = self.loader(os.path.join(self.root, file_path)) 55 | img = self.transform(img) 56 | 57 | return img, target 58 | 59 | def __len__(self): 60 | return len(self.imgs) 61 | 62 | 63 | def get_dataset(conf): 64 | 65 | datadir = 'data/cub' 66 | 67 | if conf and 'datadir' in conf: 68 | datadir = conf.datadir 69 | 70 | conf['num_class'] = 200 71 | 72 | transform_train,transform_test = get_cub_transform(conf) 73 | 74 | ds_train = ImageLoader(datadir, train=True, transform=transform_train) 75 | ds_test = ImageLoader(datadir, train=False, transform=transform_test) 76 | 77 | 78 | return ds_train,ds_test 79 | -------------------------------------------------------------------------------- /datasets/nabirds.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import PIL 3 | from PIL import Image 4 | import os 5 | import pandas as pd 6 | import math 7 | from torch.utils.data.sampler import WeightedRandomSampler 8 | from datasets.tfs import get_nabirds_transform 9 | import numpy as np 10 | 11 | import pdb 12 | 13 | def pil_loader(path): 14 | with open(path, 'rb') as f: 15 | img = Image.open(f) 16 | return img.convert('RGB') 17 | 18 | class ImageLoader(torch.utils.data.Dataset): 19 | 20 | def __init__(self, root, transform=None, target_transform=None, train=False, loader=pil_loader, tta=None): 21 | img_folder = os.path.join(root, "images") 22 | img_paths = pd.read_csv(os.path.join(root, "images.txt"), sep=" ", header=None, names=['idx', 'path']) 23 | img_labels = pd.read_csv(os.path.join(root, "image_class_labels.txt"), sep=" ", header=None, names=['idx', 'cat_num']) 24 | train_test_split = pd.read_csv(os.path.join(root, "train_test_split.txt"), sep=" ", header=None, names=['idx', 'train_flag']) 25 | data = pd.concat([img_paths, img_labels, train_test_split], axis=1) 26 | cat_nums = data.cat_num.unique().tolist() 27 | cat_nums.sort() 28 | data['label'] = data['cat_num'].apply(lambda x: cat_nums.index(x)) 29 | data = data[data['train_flag'] == train] 30 | 31 | imgs = data.reset_index(drop=True) 32 | 33 | if len(imgs) == 0: 34 | raise(RuntimeError("no csv file")) 35 | self.root = img_folder 36 | self.imgs = imgs 37 | self.transform = transform 38 | self.target_transform = target_transform 39 | self.loader = loader 40 | self.train = train 41 | self.tta = tta 42 | 43 | def __getitem__(self, index): 44 | """ 45 | Args: 46 | index (int): Index 47 | Returns: 48 | tuple: (image, target) where target is class_index of the target class. 
49 | """ 50 | item = self.imgs.iloc[index] 51 | file_path = item['path'] 52 | target = item['label'] 53 | 54 | img = self.loader(os.path.join(self.root, file_path)) 55 | 56 | if self.tta is None: 57 | img = self.transform(img) 58 | 59 | elif self.tta == 'flip': 60 | img_1 = self.transform(img) 61 | img_2 = img.transpose(PIL.Image.FLIP_LEFT_RIGHT) 62 | img_2 = self.transform(img_2) 63 | img = torch.stack((img_1, img_2), dim=0) 64 | else: 65 | img = self.transform(img) 66 | 67 | return img, target 68 | 69 | def __len__(self): 70 | return len(self.imgs) 71 | 72 | def get_dataset(conf): 73 | 74 | datadir = 'data/nabirds' 75 | 76 | if conf and 'datadir' in conf: 77 | datadir = conf.datadir 78 | 79 | conf['num_class'] = 555 80 | 81 | transform_train,transform_test = get_nabirds_transform(conf) 82 | 83 | ds_train = ImageLoader(datadir, train=True, transform=transform_train) 84 | ds_test = ImageLoader(datadir, train=False, transform=transform_test) 85 | 86 | 87 | return ds_train,ds_test 88 | -------------------------------------------------------------------------------- /datasets/tfs.py: -------------------------------------------------------------------------------- 1 | 2 | import torchvision.transforms as transforms 3 | from utils import * 4 | 5 | resizedict = {'224':256,'448':512,'112':128} 6 | 7 | 8 | def get_aircraft_transform(conf=None): 9 | return get_cub_transform(conf) 10 | 11 | def get_car_transform(conf=None): 12 | return get_cub_transform(conf) 13 | 14 | 15 | def get_nabirds_transform(conf=None): 16 | return get_cub_transform(conf) 17 | 18 | def get_cub_transform(conf=None): 19 | 20 | resize = 256 21 | cropsize = 224 22 | 23 | if conf and 'cropsize' in conf: 24 | cropsize = conf.cropsize 25 | resize = resizedict[str(cropsize)] 26 | 27 | if 'warp' in conf: 28 | 29 | if conf.warp: 30 | print('using warping') 31 | resize = (resize,resize) 32 | cropsize = (cropsize,cropsize) 33 | 34 | 35 | 36 | normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 37 | 38 | if resize == cropsize: 39 | tflist = [transforms.RandomResizedCrop(cropsize)] 40 | else: 41 | tflist = [transforms.Resize(resize),transforms.RandomCrop(cropsize)] 42 | 43 | transform_train = transforms.Compose(tflist + [ 44 | #transforms.RandomRotation(15), 45 | transforms.RandomCrop(cropsize), 46 | transforms.RandomHorizontalFlip(), 47 | transforms.ToTensor(), 48 | normalize]) 49 | 50 | transform_test = transforms.Compose([ 51 | transforms.Resize(resize), 52 | transforms.CenterCrop(cropsize), 53 | transforms.ToTensor(), 54 | normalize 55 | ]) 56 | 57 | return transform_train,transform_test 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /imgs/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shaoli-Huang/SnapMix/245bda44444879563fafad90062655e2911cc75d/imgs/overview.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import warnings 3 | import torch 4 | import torch.nn as nn 5 | 6 | import networks 7 | import trainer 8 | import logging 9 | import numpy as np 10 | from utils import get_config,set_env,set_logger,set_outdir 11 | from utils import get_dataloader 12 | from utils import get_train_setting,load_checkpoint,get_proc,save_checkpoint 13 | import os 14 | 15 | 16 | def main(conf): 17 | 18 | warnings.filterwarnings("ignore") 19 | best_score = 0. 
/imgs/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shaoli-Huang/SnapMix/245bda44444879563fafad90062655e2911cc75d/imgs/overview.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
  1 |
  2 | import warnings
  3 | import torch
  4 | import torch.nn as nn
  5 |
  6 | import networks
  7 | import trainer
  8 | import logging
  9 | import numpy as np
 10 | from utils import get_config,set_env,set_logger,set_outdir
 11 | from utils import get_dataloader
 12 | from utils import get_train_setting,load_checkpoint,get_proc,save_checkpoint
 13 | import os
 14 |
 15 |
 16 | def main(conf):
 17 |
 18 |     warnings.filterwarnings("ignore")
 19 |     best_score = 0.
 20 |     val_score = 0
 21 |     val_loss = 0
 22 |     epoch_start = 0
 23 |
 24 |
 25 |     # dataloader
 26 |     train_loader,val_loader = get_dataloader(conf)
 27 |
 28 |     # model
 29 |     model = networks.get_model(conf)
 30 |     model = nn.DataParallel(model).cuda()
 31 |
 32 |     if conf.weightfile is not None:
 33 |         wmodel = networks.get_model(conf)
 34 |         wmodel = nn.DataParallel(wmodel).cuda()
 35 |         checkpoint_dict = load_checkpoint(wmodel, conf.weightfile)
 36 |         if 'best_score' in checkpoint_dict:
 37 |             print('best score: {}'.format(checkpoint_dict['best_score']))
 38 |     else:
 39 |         wmodel = model
 40 |
 41 |     # training setting
 42 |     criterion,optimizer,scheduler = get_train_setting(model,conf)
 43 |
 44 |     # training and evaluation procedures for each epoch
 45 |     train,validate = get_proc(conf)
 46 |
 47 |     if conf.resume:
 48 |         checkpoint_dict = load_checkpoint(model, conf.resume)
 49 |         epoch_start = checkpoint_dict['epoch']
 50 |         if 'best_score' in checkpoint_dict:
 51 |             best_score = checkpoint_dict['best_score']
 52 |             print('best score: {}'.format(best_score))
 53 |         print('Resuming training process from epoch {}...'.format(epoch_start))
 54 |         optimizer.load_state_dict(checkpoint_dict['optimizer'])
 55 |         scheduler.load_state_dict(checkpoint_dict['scheduler'])
 56 |         print('Resuming lr scheduler')
 57 |         print(checkpoint_dict['scheduler'])
 58 |
 59 |     if conf.evaluate:
 60 |         print(validate(val_loader, model, criterion, conf))
 61 |         return
 62 |
 63 |     detach_epoch = conf.epochs + 1
 64 |     if 'detach_epoch' in conf:
 65 |         detach_epoch = conf.detach_epoch
 66 |
 67 |     start_eval = 0
 68 |     if 'start_eval' in conf:
 69 |         start_eval = conf.start_eval
 70 |
 71 |
 72 |     ## ------main loop-----
 73 |     for epoch in range(epoch_start, conf.epochs):
 74 |         lr = optimizer.param_groups[0]['lr']
 75 |         logging.info("Epoch: [{} | {}] LR: {}".format(epoch+1, conf.epochs, lr))
 76 |
 77 |         if epoch == detach_epoch:
 78 |             model.module.set_detach(False)
 79 |
 80 |         tmp_loss = train(train_loader, model, criterion, optimizer, conf, wmodel)
 81 |         logging.info('Epoch: {} train_loss: {:.4f}'.format(epoch+1, tmp_loss))
 82 |         scheduler.step()
 83 |
 84 |         if epoch > start_eval:
 85 |             with torch.no_grad():
 86 |                 val_score,val_loss,mscore,ascore = validate(val_loader, model, criterion, conf)
 87 |                 comscore = val_score
 88 |                 if 'midlevel' in conf:
 89 |                     if conf.midlevel:
 90 |                         comscore = ascore
 91 |                 is_best = comscore > best_score
 92 |                 best_score = max(comscore, best_score)
 93 |                 logging.info('Epoch: {} loss: {:.4f}, gs: {:.4f}, bs: {:.4f}, ms: {:.4f}, as: {:.4f}'.format(
 94 |                     epoch+1, val_loss, val_score, best_score, mscore, ascore))
 95 |                 save_checkpoint(
 96 |                     {'epoch': epoch + 1,
 97 |                      'state_dict': model.module.state_dict(),
 98 |                      'optimizer' : optimizer.state_dict(),
 99 |                      'scheduler' : scheduler.state_dict(),
100 |                      'best_score': best_score
101 |                     }, is_best, outdir=conf['outdir'])
102 |
103 |
104 |     print('Best val acc: {}'.format(best_score))
105 |     return 0
106 |
107 |
108 | if __name__ == '__main__':
109 |
110 |     # get configs and set envs
111 |     conf = get_config()
112 |     set_env(conf)
113 |     # generate outdir name
114 |     set_outdir(conf)
115 |     # Set the logger
116 |     set_logger(conf)
117 |     main(conf)
--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .modelutil import get_model
--------------------------------------------------------------------------------
/networks/densenet_ft.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sys
 3 | import numpy as np
 4 | from math import floor
 5 | import torch
 6 | import torch.nn as nn
 7 | import torch.nn.functional as F
 8 | import copy
 9 | import cv2
10 |
11 | import torchvision
12 | from torchvision import models
13 | import pdb
14 |
15 | dimdict = {'densenet121':1024,'densenet201':1920}
16 |
17 | class DenseNet(nn.Module):
18 |
19 |     def __init__(self,conf):
20 |         super(DenseNet, self).__init__()
21 |         basenet = eval('models.'+conf.netname)(pretrained=conf.pretrained)
22 |         self.feature = nn.Sequential(*list(basenet.children())[:-1])
23 |         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
24 |         indim = dimdict[conf.netname]
25 |         self.classifier = nn.Linear(indim, conf.num_class)
26 |
27 |     def set_detach(self,isdetach=True):
28 |         # mid-level supervision is not implemented for DenseNet;
29 |         # the stub keeps the interface uniform with the other backbones
30 |         pass
31 |
32 |     def forward(self, x):
33 |         x = self.feature(x)
34 |         x = F.relu(x, inplace=True)
35 |         fea_pool = self.avg_pool(x).view(x.size(0), -1)
36 |         logits = self.classifier(fea_pool)
37 |         return logits,x.detach(),None
38 |
39 |     def _init_weight(self, block):
40 |         for m in block.modules():
41 |             if isinstance(m, nn.Conv2d):
42 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
43 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
44 |                 nn.init.constant_(m.weight, 1)
45 |                 nn.init.constant_(m.bias, 0)
46 |
47 |     def get_params(self, param_name):
48 |         ftlayer_params = list(self.feature.parameters())
49 |         ftlayer_params_ids = list(map(id, ftlayer_params))
50 |         freshlayer_params = filter(lambda p: id(p) not in ftlayer_params_ids, self.parameters())
51 |
52 |         return eval(param_name+'_params')
53 |
54 |
55 | def get_net(conf):
56 |     return DenseNet(conf)
--------------------------------------------------------------------------------
/networks/efficientnet_ft.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sys
 3 | import numpy as np
 4 | from math import floor
 5 | import torch
 6 | import torch.nn as nn
 7 | import torch.nn.functional as F
 8 | import copy
 9 | import cv2
10 | from efficientnet_pytorch import EfficientNet
11 |
12 | import torchvision
13 | import pdb
14 |
15 | # final feature-map channel count of each EfficientNet variant
16 | dimdict = {'efficientnet-b0':1280,'efficientnet-b1':1280,'efficientnet-b2':1408,'efficientnet-b3':1536,'efficientnet-b4':1792,'efficientnet-b5':2048,'efficientnet-b6':2304,'efficientnet-b7':2560}
17 |
18 | class ENet(nn.Module):
19 |
20 |     def __init__(self,conf):
21 |         super(ENet, self).__init__()
22 |         if conf.pretrained:
23 |             self.basemodel = EfficientNet.from_pretrained(conf.netname)
24 |         else:
25 |             self.basemodel = EfficientNet.from_name(conf.netname)
26 |         feadim = dimdict[conf.netname]
27 |         self.classifier = nn.Linear(feadim, conf.num_class)
28 |         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
29 |         self._dropout = nn.Dropout(0.5)
30 |
31 |     def set_detach(self,isdetach=True):
32 |         # mid-level supervision is not implemented for EfficientNet;
33 |         # the stub keeps the interface uniform with the other backbones
34 |         pass
35 |
36 |     def forward(self, x):
37 |         x = self.basemodel.extract_features(x)
38 |         fea_pool = self.avg_pool(x).view(x.size(0), -1)
39 |         logits = self.classifier(fea_pool)
40 |         return logits,x.detach(),None
41 |
42 |     def get_params(self, param_name):
43 |         ftlayer_params = list(self.basemodel.parameters())
44 |         ftlayer_params_ids = list(map(id, ftlayer_params))
45 |         freshlayer_params = filter(lambda p: id(p) not in ftlayer_params_ids, self.parameters())
46 |
47 |         return eval(param_name+'_params')
48 |
49 |
50 | def get_net(conf):
51 |     return ENet(conf)
--------------------------------------------------------------------------------
/networks/inception.py:
-------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import warnings 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | #from torchvision.utils import load_state_dict_from_url 8 | from torch.hub import load_state_dict_from_url 9 | 10 | #import torchvision.models.utils as tutil 11 | from typing import Callable, Any, Optional, Tuple, List 12 | 13 | 14 | __all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs'] 15 | 16 | 17 | model_urls = { 18 | # Inception v3 ported from TensorFlow 19 | 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', 20 | } 21 | 22 | InceptionOutputs = namedtuple('InceptionOutputs', ['logits','conv', 'aux_logits']) 23 | InceptionOutputs.__annotations__ = {'logits': torch.Tensor, 'conv': Optional[torch.Tensor], 'aux_logits': Optional[torch.Tensor]} 24 | 25 | # Script annotations failed with _GoogleNetOutputs = namedtuple ... 26 | # _InceptionOutputs set here for backwards compat 27 | _InceptionOutputs = InceptionOutputs 28 | 29 | 30 | def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3": 31 | r"""Inception v3 model architecture from 32 | `"Rethinking the Inception Architecture for Computer Vision" `_. 33 | 34 | .. note:: 35 | **Important**: In contrast to the other models the inception_v3 expects tensors with a size of 36 | N x 3 x 299 x 299, so ensure your images are sized accordingly. 37 | 38 | Args: 39 | pretrained (bool): If True, returns a model pre-trained on ImageNet 40 | progress (bool): If True, displays a progress bar of the download to stderr 41 | aux_logits (bool): If True, add an auxiliary branch that can improve training. 42 | Default: *True* 43 | transform_input (bool): If True, preprocesses the input according to the method with which it 44 | was trained on ImageNet. Default: *False* 45 | """ 46 | if pretrained: 47 | if 'transform_input' not in kwargs: 48 | kwargs['transform_input'] = True 49 | if 'aux_logits' in kwargs: 50 | original_aux_logits = kwargs['aux_logits'] 51 | kwargs['aux_logits'] = True 52 | else: 53 | original_aux_logits = True 54 | kwargs['init_weights'] = False # we are loading weights from a pretrained model 55 | model = Inception3(**kwargs) 56 | state_dict = load_state_dict_from_url(model_urls['inception_v3_google'], 57 | progress=progress) 58 | model.load_state_dict(state_dict) 59 | num_ftrs = model.AuxLogits.fc.in_features 60 | numclass = 200 61 | model.AuxLogits.fc = nn.Linear(num_ftrs, numclass) 62 | num_ftrs = model.fc.in_features 63 | model.fc = nn.Linear(num_ftrs,numclass) 64 | if not original_aux_logits: 65 | model.aux_logits = False 66 | del model.AuxLogits 67 | return model 68 | 69 | 70 | return Inception3(**kwargs) 71 | 72 | 73 | class Inception3(nn.Module): 74 | 75 | def __init__( 76 | self, 77 | num_classes: int = 1000, 78 | aux_logits: bool = True, 79 | transform_input: bool = False, 80 | inception_blocks: Optional[List[Callable[..., nn.Module]]] = None, 81 | init_weights: Optional[bool] = None 82 | ) -> None: 83 | super(Inception3, self).__init__() 84 | if inception_blocks is None: 85 | inception_blocks = [ 86 | BasicConv2d, InceptionA, InceptionB, InceptionC, 87 | InceptionD, InceptionE, InceptionAux 88 | ] 89 | if init_weights is None: 90 | warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of ' 91 | 'torchvision. 
If you wish to keep the old behavior (which leads to long initialization times' 92 | ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning) 93 | init_weights = True 94 | assert len(inception_blocks) == 7 95 | conv_block = inception_blocks[0] 96 | inception_a = inception_blocks[1] 97 | inception_b = inception_blocks[2] 98 | inception_c = inception_blocks[3] 99 | inception_d = inception_blocks[4] 100 | inception_e = inception_blocks[5] 101 | inception_aux = inception_blocks[6] 102 | 103 | self.aux_logits = aux_logits 104 | self.transform_input = transform_input 105 | self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2) 106 | self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) 107 | self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) 108 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2) 109 | self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) 110 | self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) 111 | self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2) 112 | self.Mixed_5b = inception_a(192, pool_features=32) 113 | self.Mixed_5c = inception_a(256, pool_features=64) 114 | self.Mixed_5d = inception_a(288, pool_features=64) 115 | self.Mixed_6a = inception_b(288) 116 | self.Mixed_6b = inception_c(768, channels_7x7=128) 117 | self.Mixed_6c = inception_c(768, channels_7x7=160) 118 | self.Mixed_6d = inception_c(768, channels_7x7=160) 119 | self.Mixed_6e = inception_c(768, channels_7x7=192) 120 | if aux_logits: 121 | self.AuxLogits = inception_aux(768, num_classes) 122 | self.Mixed_7a = inception_d(768) 123 | self.Mixed_7b = inception_e(1280) 124 | self.Mixed_7c = inception_e(2048) 125 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 126 | self.dropout = nn.Dropout() 127 | self.fc = nn.Linear(2048, num_classes) 128 | if init_weights: 129 | for m in self.modules(): 130 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 131 | import scipy.stats as stats 132 | stddev = m.stddev if hasattr(m, 'stddev') else 0.1 133 | X = stats.truncnorm(-2, 2, scale=stddev) 134 | values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) 135 | values = values.view(m.weight.size()) 136 | with torch.no_grad(): 137 | m.weight.copy_(values) 138 | elif isinstance(m, nn.BatchNorm2d): 139 | nn.init.constant_(m.weight, 1) 140 | nn.init.constant_(m.bias, 0) 141 | 142 | def get_params(self, param_name): 143 | ftlayer_params = list(self.AuxLogits.fc.parameters()) +\ 144 | list(self.fc.parameters()) 145 | ftlayer_params_ids = list(map(id, ftlayer_params)) 146 | freshlayer_params = filter(lambda p: id(p) not in ftlayer_params_ids, self.parameters()) 147 | 148 | return eval(param_name+'_params') 149 | 150 | def _transform_input(self, x: Tensor) -> Tensor: 151 | if self.transform_input: 152 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 153 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 154 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 155 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 156 | return x 157 | 158 | def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]: 159 | # N x 3 x 299 x 299 160 | x = self.Conv2d_1a_3x3(x) 161 | # N x 32 x 149 x 149 162 | x = self.Conv2d_2a_3x3(x) 163 | # N x 32 x 147 x 147 164 | x = self.Conv2d_2b_3x3(x) 165 | # N x 64 x 147 x 147 166 | x = self.maxpool1(x) 167 | # N x 64 x 73 x 73 168 | x = self.Conv2d_3b_1x1(x) 169 | # N x 80 x 73 x 73 170 | x = self.Conv2d_4a_3x3(x) 171 | # N x 192 x 71 x 71 172 | x = self.maxpool2(x) 173 
| # N x 192 x 35 x 35 174 | x = self.Mixed_5b(x) 175 | # N x 256 x 35 x 35 176 | x = self.Mixed_5c(x) 177 | # N x 288 x 35 x 35 178 | x = self.Mixed_5d(x) 179 | # N x 288 x 35 x 35 180 | x = self.Mixed_6a(x) 181 | # N x 768 x 17 x 17 182 | x = self.Mixed_6b(x) 183 | # N x 768 x 17 x 17 184 | x = self.Mixed_6c(x) 185 | # N x 768 x 17 x 17 186 | x = self.Mixed_6d(x) 187 | # N x 768 x 17 x 17 188 | x = self.Mixed_6e(x) 189 | # N x 768 x 17 x 17 190 | aux_defined = self.training and self.aux_logits 191 | if aux_defined: 192 | aux = self.AuxLogits(x) 193 | else: 194 | aux = None 195 | # N x 768 x 17 x 17 196 | x = self.Mixed_7a(x) 197 | # N x 1280 x 8 x 8 198 | x = self.Mixed_7b(x) 199 | # N x 2048 x 8 x 8 200 | conv = self.Mixed_7c(x) 201 | # N x 2048 x 8 x 8 202 | # Adaptive average pooling 203 | x = self.avgpool(conv) 204 | # N x 2048 x 1 x 1 205 | x = self.dropout(x) 206 | # N x 2048 x 1 x 1 207 | x = torch.flatten(x, 1) 208 | # N x 2048 209 | x = self.fc(x) 210 | # N x 1000 (num_classes) 211 | return x,conv, aux 212 | 213 | @torch.jit.unused 214 | def eager_outputs(self, x: Tensor,conv: Optional[Tensor], aux: Optional[Tensor]) -> InceptionOutputs: 215 | if self.training and self.aux_logits: 216 | return InceptionOutputs(x, conv,aux) 217 | else: 218 | return x # type: ignore[return-value] 219 | 220 | def forward(self, x: Tensor) -> InceptionOutputs: 221 | x = self._transform_input(x) 222 | x,conv, aux = self._forward(x) 223 | aux_defined = self.training and self.aux_logits 224 | if torch.jit.is_scripting(): 225 | if not aux_defined: 226 | warnings.warn("Scripted Inception3 always returns Inception3 Tuple") 227 | return InceptionOutputs(x, aux) 228 | else: 229 | return self.eager_outputs(x,conv, aux) 230 | 231 | 232 | class InceptionA(nn.Module): 233 | 234 | def __init__( 235 | self, 236 | in_channels: int, 237 | pool_features: int, 238 | conv_block: Optional[Callable[..., nn.Module]] = None 239 | ) -> None: 240 | super(InceptionA, self).__init__() 241 | if conv_block is None: 242 | conv_block = BasicConv2d 243 | self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) 244 | 245 | self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) 246 | self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) 247 | 248 | self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) 249 | self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) 250 | self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) 251 | 252 | self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) 253 | 254 | def _forward(self, x: Tensor) -> List[Tensor]: 255 | branch1x1 = self.branch1x1(x) 256 | 257 | branch5x5 = self.branch5x5_1(x) 258 | branch5x5 = self.branch5x5_2(branch5x5) 259 | 260 | branch3x3dbl = self.branch3x3dbl_1(x) 261 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 262 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 263 | 264 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 265 | branch_pool = self.branch_pool(branch_pool) 266 | 267 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 268 | return outputs 269 | 270 | def forward(self, x: Tensor) -> Tensor: 271 | outputs = self._forward(x) 272 | return torch.cat(outputs, 1) 273 | 274 | 275 | class InceptionB(nn.Module): 276 | 277 | def __init__( 278 | self, 279 | in_channels: int, 280 | conv_block: Optional[Callable[..., nn.Module]] = None 281 | ) -> None: 282 | super(InceptionB, self).__init__() 283 | if conv_block is None: 284 | conv_block = BasicConv2d 285 | self.branch3x3 = 
conv_block(in_channels, 384, kernel_size=3, stride=2) 286 | 287 | self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) 288 | self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) 289 | self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) 290 | 291 | def _forward(self, x: Tensor) -> List[Tensor]: 292 | branch3x3 = self.branch3x3(x) 293 | 294 | branch3x3dbl = self.branch3x3dbl_1(x) 295 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 296 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 297 | 298 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 299 | 300 | outputs = [branch3x3, branch3x3dbl, branch_pool] 301 | return outputs 302 | 303 | def forward(self, x: Tensor) -> Tensor: 304 | outputs = self._forward(x) 305 | return torch.cat(outputs, 1) 306 | 307 | 308 | class InceptionC(nn.Module): 309 | 310 | def __init__( 311 | self, 312 | in_channels: int, 313 | channels_7x7: int, 314 | conv_block: Optional[Callable[..., nn.Module]] = None 315 | ) -> None: 316 | super(InceptionC, self).__init__() 317 | if conv_block is None: 318 | conv_block = BasicConv2d 319 | self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) 320 | 321 | c7 = channels_7x7 322 | self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) 323 | self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 324 | self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) 325 | 326 | self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) 327 | self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 328 | self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 329 | self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 330 | self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 331 | 332 | self.branch_pool = conv_block(in_channels, 192, kernel_size=1) 333 | 334 | def _forward(self, x: Tensor) -> List[Tensor]: 335 | branch1x1 = self.branch1x1(x) 336 | 337 | branch7x7 = self.branch7x7_1(x) 338 | branch7x7 = self.branch7x7_2(branch7x7) 339 | branch7x7 = self.branch7x7_3(branch7x7) 340 | 341 | branch7x7dbl = self.branch7x7dbl_1(x) 342 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 343 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 344 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 345 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 346 | 347 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 348 | branch_pool = self.branch_pool(branch_pool) 349 | 350 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 351 | return outputs 352 | 353 | def forward(self, x: Tensor) -> Tensor: 354 | outputs = self._forward(x) 355 | return torch.cat(outputs, 1) 356 | 357 | 358 | class InceptionD(nn.Module): 359 | 360 | def __init__( 361 | self, 362 | in_channels: int, 363 | conv_block: Optional[Callable[..., nn.Module]] = None 364 | ) -> None: 365 | super(InceptionD, self).__init__() 366 | if conv_block is None: 367 | conv_block = BasicConv2d 368 | self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) 369 | self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) 370 | 371 | self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) 372 | self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) 373 | self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) 374 | self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) 375 | 376 | def _forward(self, x: Tensor) -> 
List[Tensor]: 377 | branch3x3 = self.branch3x3_1(x) 378 | branch3x3 = self.branch3x3_2(branch3x3) 379 | 380 | branch7x7x3 = self.branch7x7x3_1(x) 381 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3) 382 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3) 383 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3) 384 | 385 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 386 | outputs = [branch3x3, branch7x7x3, branch_pool] 387 | return outputs 388 | 389 | def forward(self, x: Tensor) -> Tensor: 390 | outputs = self._forward(x) 391 | return torch.cat(outputs, 1) 392 | 393 | 394 | class InceptionE(nn.Module): 395 | 396 | def __init__( 397 | self, 398 | in_channels: int, 399 | conv_block: Optional[Callable[..., nn.Module]] = None 400 | ) -> None: 401 | super(InceptionE, self).__init__() 402 | if conv_block is None: 403 | conv_block = BasicConv2d 404 | self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) 405 | 406 | self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) 407 | self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) 408 | self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) 409 | 410 | self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) 411 | self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) 412 | self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) 413 | self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) 414 | 415 | self.branch_pool = conv_block(in_channels, 192, kernel_size=1) 416 | 417 | def _forward(self, x: Tensor) -> List[Tensor]: 418 | branch1x1 = self.branch1x1(x) 419 | 420 | branch3x3 = self.branch3x3_1(x) 421 | branch3x3 = [ 422 | self.branch3x3_2a(branch3x3), 423 | self.branch3x3_2b(branch3x3), 424 | ] 425 | branch3x3 = torch.cat(branch3x3, 1) 426 | 427 | branch3x3dbl = self.branch3x3dbl_1(x) 428 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 429 | branch3x3dbl = [ 430 | self.branch3x3dbl_3a(branch3x3dbl), 431 | self.branch3x3dbl_3b(branch3x3dbl), 432 | ] 433 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 434 | 435 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 436 | branch_pool = self.branch_pool(branch_pool) 437 | 438 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 439 | return outputs 440 | 441 | def forward(self, x: Tensor) -> Tensor: 442 | outputs = self._forward(x) 443 | return torch.cat(outputs, 1) 444 | 445 | 446 | class InceptionAux(nn.Module): 447 | 448 | def __init__( 449 | self, 450 | in_channels: int, 451 | num_classes: int, 452 | conv_block: Optional[Callable[..., nn.Module]] = None 453 | ) -> None: 454 | super(InceptionAux, self).__init__() 455 | if conv_block is None: 456 | conv_block = BasicConv2d 457 | self.conv0 = conv_block(in_channels, 128, kernel_size=1) 458 | self.conv1 = conv_block(128, 768, kernel_size=5) 459 | self.conv1.stddev = 0.01 # type: ignore[assignment] 460 | self.fc = nn.Linear(768, num_classes) 461 | self.fc.stddev = 0.001 # type: ignore[assignment] 462 | 463 | def forward(self, x: Tensor) -> Tensor: 464 | # N x 768 x 17 x 17 465 | x = F.avg_pool2d(x, kernel_size=5, stride=3) 466 | # N x 768 x 5 x 5 467 | x = self.conv0(x) 468 | # N x 128 x 5 x 5 469 | x = self.conv1(x) 470 | # N x 768 x 1 x 1 471 | # Adaptive average pooling 472 | x = F.adaptive_avg_pool2d(x, (1, 1)) 473 | # N x 768 x 1 x 1 474 | x = torch.flatten(x, 1) 475 | # N x 768 476 | x = self.fc(x) 477 | # N x 1000 478 | return x 479 | 480 | 481 | class BasicConv2d(nn.Module): 482 | 483 | def 
__init__( 484 | self, 485 | in_channels: int, 486 | out_channels: int, 487 | **kwargs: Any 488 | ) -> None: 489 | super(BasicConv2d, self).__init__() 490 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 491 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 492 | 493 | def forward(self, x: Tensor) -> Tensor: 494 | x = self.conv(x) 495 | x = self.bn(x) 496 | return F.relu(x, inplace=True) 497 | -------------------------------------------------------------------------------- /networks/inception_ft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | from math import floor 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import copy 9 | import cv2 10 | 11 | import torchvision 12 | from networks.inception import inception_v3 13 | import pdb 14 | 15 | 16 | def get_net(conf): 17 | return inception_v3(pretrained=conf.pretrained) 18 | -------------------------------------------------------------------------------- /networks/modelutil.py: -------------------------------------------------------------------------------- 1 | 2 | import imp 3 | import os 4 | import torch.nn as nn 5 | 6 | def get_model(conf): 7 | if 'resne' in conf.netname: 8 | net_type = 'resnet_ft' 9 | elif 'densenet' in conf.netname: 10 | net_type = 'densenet_ft' 11 | elif 'inception' in conf.netname: 12 | net_type = 'inception_ft' 13 | elif 'efficient' in conf.netname: 14 | net_type = 'efficientnet_ft' 15 | else: 16 | print('{} type not support'.format(conf.netname)) 17 | 18 | 19 | src_file = os.path.join('networks',net_type+'.py') 20 | netimp = imp.load_source('networks',src_file) 21 | net = netimp.get_net(conf) 22 | return net 23 | 24 | def count_params(net): 25 | print(sum(p.numel() for p in net.parameters() if p.requires_grad)) 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /networks/resnet.py: -------------------------------------------------------------------------------- 1 | # Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 2 | 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = conv3x3(planes, planes) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.relu = nn.ReLU(inplace=True) 23 | 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | 30 | out = self.conv1(x) 31 | out = self.bn1(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv2(out) 35 | out = self.bn2(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | 46 | class Bottleneck(nn.Module): 47 | expansion = 4 48 | 49 | def __init__(self, inplanes, planes, stride=1, downsample=None): 50 | super(Bottleneck, self).__init__() 51 | 52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.conv2 = 
nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False) 57 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion) 58 | self.relu = nn.ReLU(inplace=True) 59 | 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | residual = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv3(out) 75 | out = self.bn3(out) 76 | if self.downsample is not None: 77 | residual = self.downsample(x) 78 | 79 | out += residual 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | class ResNet(nn.Module): 85 | def __init__(self, dataset, depth, num_classes, bottleneck=False): 86 | super(ResNet, self).__init__() 87 | self.dataset = dataset 88 | if self.dataset.startswith('cifar'): 89 | self.inplanes = 16 90 | print(bottleneck) 91 | if bottleneck == True: 92 | n = int((depth - 2) / 9) 93 | block = Bottleneck 94 | else: 95 | n = int((depth - 2) / 6) 96 | block = BasicBlock 97 | 98 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 99 | self.bn1 = nn.BatchNorm2d(self.inplanes) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.layer1 = self._make_layer(block, 16, n) 102 | self.layer2 = self._make_layer(block, 32, n, stride=2) 103 | self.layer3 = self._make_layer(block, 64, n, stride=2) 104 | self.avgpool = nn.AvgPool2d(8) 105 | #self.fc = nn.Linear(64 * block.expansion, num_classes) 106 | self.classifier = nn.Linear(64 * block.expansion, num_classes) 107 | 108 | elif dataset == 'imagenet': 109 | blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} 110 | layers ={18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]} 111 | assert layers[depth], 'invalid detph for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)' 112 | 113 | self.inplanes = 64 114 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 115 | self.bn1 = nn.BatchNorm2d(64) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 118 | self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0]) 119 | self.layer2 = self._make_layer(blocks[depth], 128, layers[depth][1], stride=2) 120 | self.layer3 = self._make_layer(blocks[depth], 256, layers[depth][2], stride=2) 121 | self.layer4 = self._make_layer(blocks[depth], 512, layers[depth][3], stride=2) 122 | self.avgpool = nn.AvgPool2d(7) 123 | #self.fc = nn.Linear(512 * blocks[depth].expansion, num_classes) 124 | self.classifier = nn.Linear(512 * blocks[depth].expansion, num_classes) 125 | 126 | for m in self.modules(): 127 | if isinstance(m, nn.Conv2d): 128 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 129 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
130 |             elif isinstance(m, nn.BatchNorm2d):
131 |                 m.weight.data.fill_(1)
132 |                 m.bias.data.zero_()
133 |
134 |     def _make_layer(self, block, planes, blocks, stride=1):
135 |         downsample = None
136 |         if stride != 1 or self.inplanes != planes * block.expansion:
137 |             downsample = nn.Sequential(
138 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
139 |                           kernel_size=1, stride=stride, bias=False),
140 |                 nn.BatchNorm2d(planes * block.expansion),
141 |             )
142 |
143 |         layers = []
144 |         layers.append(block(self.inplanes, planes, stride, downsample))
145 |         self.inplanes = planes * block.expansion
146 |         for i in range(1, blocks):
147 |             layers.append(block(self.inplanes, planes))
148 |
149 |         return nn.Sequential(*layers)
150 |
151 |     def forward(self, x):
152 |         if self.dataset == 'cifar10' or self.dataset == 'cifar100':
153 |             x = self.conv1(x)
154 |             x = self.bn1(x)
155 |             x = self.relu(x)
156 |
157 |             x = self.layer1(x)
158 |             x = self.layer2(x)
159 |             conv4 = self.layer3(x)
160 |
161 |             x = self.avgpool(conv4)
162 |             x = x.view(x.size(0), -1)
163 |             x = self.classifier(x)
164 |
165 |         elif self.dataset == 'imagenet':
166 |             x = self.conv1(x)
167 |             x = self.bn1(x)
168 |             x = self.relu(x)
169 |             x = self.maxpool(x)
170 |
171 |             x = self.layer1(x)
172 |             x = self.layer2(x)
173 |             x = self.layer3(x)
174 |             conv4 = self.layer4(x)
175 |
176 |             x = self.avgpool(conv4)
177 |             x = x.view(x.size(0), -1)
178 |             x = self.classifier(x)
179 |
180 |         return x,conv4,None
181 |
182 | def get_net(conf):
183 |     return ResNet(conf.dataset, conf.depth, conf.num_class, True)
184 |
--------------------------------------------------------------------------------
/networks/resnet_ft.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sys
 3 | import numpy as np
 4 | from math import floor
 5 | import torch
 6 | import torch.nn as nn
 7 | import torch.nn.functional as F
 8 | import copy
 9 | import cv2
10 |
11 | import torchvision
12 | from torchvision import models
13 | import pdb
14 |
15 | class ResNet(nn.Module):
16 |
17 |     def __init__(self,conf):
18 |         super(ResNet, self).__init__()
19 |         basenet = eval('models.'+conf.netname)(pretrained=conf.pretrained)
20 |         self.conv3 = nn.Sequential(*list(basenet.children())[:-4])
21 |         self.conv4 = list(basenet.children())[-4]
22 |         self.midlevel = False
23 |         self.isdetach = True
24 |         if 'midlevel' in conf:
25 |             self.midlevel = conf.midlevel
26 |         if 'isdetach' in conf:
27 |             self.isdetach = conf.isdetach
28 |
29 |         mid_dim = 1024
30 |         feadim = 2048
31 |         if conf.netname in ['resnet18','resnet34']:
32 |             mid_dim = 256
33 |             feadim = 512
34 |
35 |         if self.midlevel:
36 |             self.mcls = nn.Linear(mid_dim, conf.num_class)
37 |             self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
38 |             self.conv4_1 = nn.Sequential(nn.Conv2d(mid_dim, mid_dim, 1, 1), nn.ReLU())
39 |         self.conv5 = list(basenet.children())[-3]
40 |         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
41 |         self.classifier = nn.Linear(feadim, conf.num_class)
42 |
43 |     def set_detach(self,isdetach=True):
44 |         self.isdetach = isdetach
45 |
46 |     def forward(self, x):
47 |         x = self.conv3(x)
48 |         conv4 = self.conv4(x)
49 |         x = self.conv5(conv4)
50 |         fea_pool = self.avg_pool(x).view(x.size(0), -1)
51 |         logits = self.classifier(fea_pool)
52 |
53 |         if self.midlevel:
54 |             if self.isdetach:
55 |                 conv4_1 = conv4.detach()
56 |             else:
57 |                 conv4_1 = conv4
58 |             conv4_1 = self.conv4_1(conv4_1)
59 |             pool4_1 = self.max_pool(conv4_1).view(conv4_1.size(0),-1)
60 |             mlogits = self.mcls(pool4_1)
61 |         else:
62 |             mlogits = None
63 |
64 |         return logits,x.detach(),mlogits
65 |
66 |     def _init_weight(self, block):
67 |         for m in block.modules():
68 |             if isinstance(m, nn.Conv2d):
69 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
70 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
71 |                 nn.init.constant_(m.weight, 1)
72 |                 nn.init.constant_(m.bias, 0)
73 |
74 |     def get_params(self, param_name):
75 |         ftlayer_params = list(self.conv3.parameters()) +\
76 |                          list(self.conv4.parameters()) +\
77 |                          list(self.conv5.parameters())
78 |         ftlayer_params_ids = list(map(id, ftlayer_params))
79 |         freshlayer_params = filter(lambda p: id(p) not in ftlayer_params_ids, self.parameters())
80 |
81 |         return eval(param_name+'_params')
82 |
83 |
84 | def get_net(conf):
85 |     return ResNet(conf)
--------------------------------------------------------------------------------
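The `midlevel` branch above is what `--midlevel` (baseline+) enables: a second classifier over `layer3` features (detached by default) whose logits are combined with the main head at test time. A small sketch of the three-tuple interface that every backbone in this repo returns; the config dict here is a hypothetical minimal one, which `main.py` normally assembles from the CLI and YAML:

```python
import torch
from easydict import EasyDict as edict
from networks.resnet_ft import get_net

conf = edict({'netname': 'resnet50', 'pretrained': False,
              'num_class': 200, 'midlevel': True})
net = get_net(conf)
logits, feat, mlogits = net(torch.randn(2, 3, 224, 224))
print(logits.shape, feat.shape, mlogits.shape)
# torch.Size([2, 200]) torch.Size([2, 2048, 7, 7]) torch.Size([2, 200])
```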
--------------------------------------------------------------------------------
/networks/resnet_ft.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | from math import floor
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import copy
9 | import cv2
10 | 
11 | import torchvision
12 | from torchvision import models
13 | import pdb
14 | 
15 | class ResNet(nn.Module):
16 | 
17 |     def __init__(self, conf):
18 |         super(ResNet, self).__init__()
19 |         basenet = eval('models.'+conf.netname)(pretrained=conf.pretrained)
20 |         self.conv3 = nn.Sequential(*list(basenet.children())[:-4])
21 |         self.conv4 = list(basenet.children())[-4]
22 |         self.midlevel = False
23 |         self.isdetach = True
24 |         if 'midlevel' in conf:
25 |             self.midlevel = conf.midlevel
26 |         if 'isdetach' in conf:
27 |             self.isdetach = conf.isdetach
28 | 
29 |         mid_dim = 1024
30 |         feadim = 2048
31 |         if conf.netname in ['resnet18', 'resnet34']:
32 |             mid_dim = 256
33 |             feadim = 512
34 | 
35 |         if self.midlevel:
36 |             self.mcls = nn.Linear(mid_dim, conf.num_class)
37 |             self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
38 |             self.conv4_1 = nn.Sequential(nn.Conv2d(mid_dim, mid_dim, 1, 1), nn.ReLU())
39 |         self.conv5 = list(basenet.children())[-3]
40 |         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
41 |         self.classifier = nn.Linear(feadim, conf.num_class)
42 | 
43 |     def set_detach(self, isdetach=True):
44 |         self.isdetach = isdetach
45 | 
46 |     def forward(self, x):
47 |         x = self.conv3(x)
48 |         conv4 = self.conv4(x)
49 |         x = self.conv5(conv4)
50 |         fea_pool = self.avg_pool(x).view(x.size(0), -1)
51 |         logits = self.classifier(fea_pool)
52 | 
53 |         if self.midlevel:
54 |             if self.isdetach:
55 |                 conv4_1 = conv4.detach()
56 |             else:
57 |                 conv4_1 = conv4
58 |             conv4_1 = self.conv4_1(conv4_1)
59 |             pool4_1 = self.max_pool(conv4_1).view(conv4_1.size(0), -1)
60 |             mlogits = self.mcls(pool4_1)
61 |         else:
62 |             mlogits = None
63 | 
64 |         return logits, x.detach(), mlogits
65 | 
66 | 
67 |     def _init_weight(self, block):
68 |         for m in block.modules():
69 |             if isinstance(m, nn.Conv2d):
70 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
71 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
72 |                 nn.init.constant_(m.weight, 1)
73 |                 nn.init.constant_(m.bias, 0)
74 | 
75 |     def get_params(self, param_name):
76 |         ftlayer_params = list(self.conv3.parameters()) +\
77 |                          list(self.conv4.parameters()) +\
78 |                          list(self.conv5.parameters())
79 |         ftlayer_params_ids = list(map(id, ftlayer_params))
80 |         freshlayer_params = filter(lambda p: id(p) not in ftlayer_params_ids, self.parameters())  # consumed via eval() below
81 | 
82 |         return eval(param_name+'_params')
83 | 
84 | 
85 | def get_net(conf):
86 |     return ResNet(conf)
87 | 
--------------------------------------------------------------------------------
/trainer/comm_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import numpy as np
5 | from tqdm import tqdm
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision
10 | import cv2
11 | import time
12 | import logging
13 | 
14 | from utils import *
15 | 
16 | def validate(val_loader, model, criterion, conf):
17 | 
18 |     batch_time = AverageMeter()
19 |     data_time = AverageMeter()
20 |     losses = AverageMeter()
21 | 
22 |     scores = AverageAccMeter()
23 |     mscores = AverageAccMeter()
24 |     ascores = AverageAccMeter()
25 |     end = time.time()
26 |     model.eval()
27 | 
28 |     time_start = time.time()
29 |     pbar = tqdm(val_loader, dynamic_ncols=True, total=len(val_loader))
30 | 
31 |     for idx, (input, target) in enumerate(pbar):
32 |         # measure data loading time
33 |         data_time.add(time.time() - end)
34 |         input = input.cuda()
35 |         target = target.cuda()
36 | 
37 |         if 'inception' in conf.netname:
38 |             output = model(input)
39 |         else:
40 |             output, _, moutput = model(input)
41 |         scores.add(output.data, target)
42 |         if 'midlevel' in conf:
43 |             if conf.midlevel:
44 |                 mscores.add(moutput.data, target)
45 |                 ascores.add(output.data + moutput.data, target)
46 | 
47 |         loss = torch.mean(criterion(output, target))
48 |         losses.add(loss.item(), input.size(0))
49 |         del loss, output
50 | 
51 |         # measure elapsed time
52 |         batch_time.add(time.time() - end)
53 |         end = time.time()
54 |         pbar.set_postfix(batch_time=batch_time.value(), data_time=data_time.value(), loss=losses.value())
55 | 
56 |     return scores.value(), losses.value(), mscores.value(), ascores.value()
57 | 
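`validate()` reports four numbers: top-1 accuracy of the main head, the mean loss, and (when `--midlevel` is set) the accuracy of the mid-level head and of the summed logits. A hedged sketch of consuming it; `model`, `val_loader`, and `conf` are assumed to be built elsewhere (see `utils.dataloader`, `networks/`, and `utils.conf`):

```python
import torch.nn as nn
from trainer.comm_test import validate

# element-wise criterion, matching what the trainers expect (they average manually)
criterion = nn.CrossEntropyLoss(reduction='none').cuda()
acc, loss, macc, aacc = validate(val_loader, model, criterion, conf)
print('top-1: {:.2f}%  loss: {:.4f}'.format(acc, loss))
if conf.get('midlevel', False):
    print('mid-level: {:.2f}%  combined: {:.2f}%'.format(macc, aacc))
```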
--------------------------------------------------------------------------------
/trainer/comm_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import numpy as np
5 | from tqdm import tqdm
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision
10 | import cv2
11 | import time
12 | import logging
13 | from utils import *
14 | 
15 | 
16 | 
17 | def train(train_loader, model, criterion, optimizer, conf, wmodel=None):
18 | 
19 |     batch_time = AverageMeter()
20 |     data_time = AverageMeter()
21 |     losses = AverageMeter()
22 |     scores = AverageAccMeter()
23 |     end = time.time()
24 |     model.train()
25 | 
26 |     time_start = time.time()
27 |     pbar = tqdm(train_loader, dynamic_ncols=True, total=len(train_loader))
28 |     mixmethod = None
29 |     clsw = None
30 |     if 'mixmethod' in conf:
31 |         if 'baseline' not in conf.mixmethod:
32 |             mixmethod = conf.mixmethod
33 |             if wmodel is None:
34 |                 wmodel = model
35 | 
36 |     for idx, (input, target) in enumerate(pbar):
37 | 
38 |         # measure data loading time
39 |         data_time.add(time.time() - end)
40 |         input = input.cuda()
41 |         target = target.cuda()
42 | 
43 |         if 'baseline' not in conf.mixmethod:
44 |             input, target_a, target_b, lam_a, lam_b = eval(mixmethod)(input, target, conf, wmodel)  # resolves to snapmix/cutmix/... in utils.mixmethod
45 | 
46 |             output, _, moutput = model(input)
47 | 
48 |             loss_a = criterion(output, target_a)
49 |             loss_b = criterion(output, target_b)
50 |             loss = torch.mean(loss_a * lam_a + loss_b * lam_b)
51 | 
52 |             if 'inception' in conf.netname:
53 |                 loss1_a = criterion(moutput, target_a)
54 |                 loss1_b = criterion(moutput, target_b)
55 |                 loss1 = torch.mean(loss1_a * lam_a + loss1_b * lam_b)
56 |                 loss += 0.4*loss1
57 | 
58 |             if 'midlevel' in conf:
59 |                 if conf.midlevel:
60 |                     loss_ma = criterion(moutput, target_a)
61 |                     loss_mb = criterion(moutput, target_b)
62 |                     loss += torch.mean(loss_ma * lam_a + loss_mb * lam_b)
63 |         else:
64 |             output, _, moutput = model(input)
65 |             loss = torch.mean(criterion(output, target))
66 | 
67 |             if 'inception' in conf.netname:
68 |                 loss += 0.4*torch.mean(criterion(moutput, target))
69 | 
70 |             if 'midlevel' in conf:
71 |                 if conf.midlevel:
72 |                     loss += torch.mean(criterion(moutput, target))
73 | 
74 |         # measure accuracy and record loss
75 |         losses.add(loss.item(), input.size(0))
76 | 
77 |         optimizer.zero_grad()
78 |         loss.backward()
79 |         optimizer.step()
80 | 
81 |         # measure elapsed time
82 |         batch_time.add(time.time() - end)
83 |         end = time.time()
84 | 
85 |         pbar.set_postfix(batch_time=batch_time.value(), data_time=data_time.value(), loss=losses.value(), score=0)
86 | 
87 |     return losses.value()
88 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .conf import *
2 | from .log import *
3 | from .io import *
4 | from .stat import *
5 | from .custom_ops import *
6 | from .trainutil import *
7 | from .dataloader import *
8 | from .mixmethod import *
9 | 
10 | 
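The core of `train()` is the doubly-weighted loss: after mixing, every sample carries two labels, and its per-sample losses are combined with the semantic weights `lam_a` and `lam_b` before averaging. A self-contained illustration with toy tensors (CPU only; the label values are made up):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(reduction='none')  # per-sample losses, as the trainer requires
logits   = torch.randn(4, 10)
target_a = torch.tensor([1, 2, 3, 4])              # original labels
target_b = torch.tensor([5, 6, 7, 8])              # labels of the pasted patches
lam_a    = torch.tensor([0.7, 0.9, 0.5, 1.0])      # per-sample weights, e.g. from snapmix
lam_b    = 1.0 - lam_a

loss = torch.mean(criterion(logits, target_a) * lam_a +
                  criterion(logits, target_b) * lam_b)
print(loss.item())
```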
--------------------------------------------------------------------------------
/utils/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pprint
4 | import yaml
5 | import numpy as np
6 | import torch
7 | import random
8 | import torch.backends.cudnn as cudnn
9 | from easydict import EasyDict as edict
10 | 
11 | 
12 | parser = argparse.ArgumentParser(description='PyTorch Training')
13 | arg_lists = []  # groups created via add_argument_group()
14 | def str2bool(v):
15 |     return v.lower() in ('true', '1')
16 | 
17 | 
18 | def add_argument_group(name):
19 |     arg = parser.add_argument_group(name)
20 |     arg_lists.append(arg)
21 |     return arg
22 | 
23 | # ------------------------------
24 | def parser2dict():
25 |     config, unparsed = parser.parse_known_args()
26 |     cfg = edict(config.__dict__)
27 |     # print("Config:\n" + pprint.pformat(cfg))
28 |     return edict(cfg)
29 | 
30 | 
31 | # ------------------------------
32 | def _merge_a_into_b(a, b):
33 |     """Merge config dictionary a into config dictionary b, clobbering the
34 |     options in b whenever they are also specified in a.
35 |     """
36 |     if type(a) is not edict:
37 |         return
38 | 
39 |     for k, v in a.items():
40 |         # a must specify keys that are in b
41 |         #if k not in b:
42 |         #    raise KeyError('{} is not a valid config key'.format(k))
43 | 
44 |         # recursively merge dicts
45 |         if type(v) is edict:
46 |             try:
47 |                 _merge_a_into_b(a[k], b[k])
48 |             except:
49 |                 print('Error under config key: {}'.format(k))
50 |                 raise
51 |         else:
52 |             #if k not in b:
53 |             b[k] = v
54 | 
55 | def print_conf(opt):
56 |     """Format options into a printable string.
57 |     It shows the current value of every option; the caller
58 |     (utils.log.set_logger) writes the string to the log file.
59 |     """
60 |     message = ''
61 |     message += '----------------- Options ---------------\n'
62 |     for k, v in sorted(vars(opt).items()):
63 |         comment = ''
64 |         # default = self.parser.get_default(k)
65 |         # if v != default:
66 |         #     comment = '\t[default: %s]' % str(default)
67 |         message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
68 |     message += '----------------- End -------------------'
69 |     return message
70 | 
71 | 
72 | 
73 | 
74 | def cfg_from_file(cfg):
75 |     """Load a config from file filename and merge it into the default options.
76 |     """
77 | 
78 |     filename = cfg.config
79 |     # args from yaml file
80 |     with open(filename, 'r') as f:
81 |         yaml_cfg = edict(yaml.safe_load(f))
82 | 
83 |     _merge_a_into_b(yaml_cfg, cfg)
84 | 
85 |     return cfg
86 | 
87 | 
88 | def get_config():
89 | 
90 |     # args from argparser
91 |     cfg = parser2dict()
92 |     cfg = cfg_from_file(cfg)
93 |     if 'mixmethod' in cfg:
94 |         cfg['mixmethod'] = cfg['mixmethod'].split(',')
95 |         if len(cfg['mixmethod']) == 1:
96 |             cfg['mixmethod'] = cfg['mixmethod'][0]
97 | 
98 |     if not cfg.pretrained:
99 |         cfg['lr_group'] = [0.01, 0.01]
100 |         cfg['epochs'] = 300
101 | 
102 |     if cfg['epochs'] == 300:
103 |         cfg['lrstep'] = [150, 225, 270]
104 | 
105 |     if cfg['epochs'] == 100:
106 |         cfg['lrstep'] = [40, 70]
107 | 
108 | 
109 |     if cfg.dataset in ['nabirds', 'cub']:
110 |         cfg['warp'] = False
111 | 
112 |     return cfg
113 | 
114 | 
115 | def set_env(cfg):
116 |     # set seeding
117 |     random.seed(cfg.seed)
118 |     np.random.seed(cfg.seed)  # numpy vars
119 |     torch.manual_seed(cfg.seed)  # cpu vars
120 |     torch.cuda.manual_seed(cfg.seed)  # gpu vars
121 |     torch.cuda.manual_seed_all(cfg.seed)  # multi-gpu vars
122 |     if 'cudnn' in cfg:
123 |         torch.backends.cudnn.benchmark = cfg.cudnn
124 |     else:
125 |         torch.backends.cudnn.benchmark = False
126 | 
127 |     cudnn.deterministic = True
128 |     os.environ["NUMEXPR_MAX_THREADS"] = '16'
129 |     os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu_ids
130 | 
131 | 
132 | 
133 | # ----------------------------------------------------------------------------------------
134 | # base
135 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', help='number of data loading workers (default: 16)')
136 | parser.add_argument('-b', '--batch_size', default=16, type=int, metavar='N', help='mini-batch size (default: 16)')
137 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2  0,2. use -1 for CPU')
138 | parser.add_argument('--weightfile', default=None, type=str, metavar='PATH', help='path to model (default: none)')
139 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
140 | parser.add_argument('--seed', default=0, type=int, help='seeding for all random operations')
141 | parser.add_argument('--config', default='config/comm.yml', type=str, help='path to the config file')
142 | 
143 | # train
144 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
145 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')
146 | parser.add_argument('--resume', default='', type=str, metavar='path', help='path to latest checkpoint (default: none)')
147 | parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
148 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', help='initial learning rate')
149 | 
150 | parser.add_argument('--pretrained', default=1, type=float, help='use ImageNet pretrained weights (1) or train from scratch (0)')
151 | 
152 | # others
153 | 
154 | parser.add_argument('--mixmethod', default='baseline', type=str, help='mixing augmentation: baseline, snapmix, cutmix, cutout, mixup, as_cutmix')
155 | parser.add_argument('--netname', default='resnet50', type=str, help='network backbone name')
156 | parser.add_argument('--prob', type=float, default=1.0, help='probability of applying the mixing augmentation')
157 | parser.add_argument('--beta', type=float, default=1.0, help='beta distribution parameter for sampling lambda')
158 | parser.add_argument('--dataset', default='cub', type=str, help='dataset')
159 | parser.add_argument('--cropsize', default=448, type=int, metavar='N', help='cropsize')
160 | parser.add_argument('--midlevel', dest='midlevel', action='store_true', help='add the mid-level classification branch (baseline+)')
161 | parser.add_argument('--train_proc', default='comm', type=str, help='training procedure (trainer/<name>_train.py)')
162 | parser.add_argument('--start_eval', default=-1, type=int, metavar='N', help='epoch at which to start evaluation')
163 | 
164 | 
165 | 
166 | 
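`cfg_from_file()` merges the YAML config on top of the argparse defaults: every key in the YAML clobbers the default value, recursing into nested dicts. A runnable miniature of that merge (`merge` mirrors `_merge_a_into_b` above; the keys are made up for the demo):

```python
from easydict import EasyDict as edict

def merge(a, b):                       # simplified _merge_a_into_b
    for k, v in a.items():
        if isinstance(v, edict):
            merge(v, b[k])
        else:
            b[k] = v

defaults = edict({'lr': 0.01, 'epochs': 200, 'netname': 'resnet50'})
yaml_cfg = edict({'epochs': 100, 'lrgamma': 0.1})
merge(yaml_cfg, defaults)
print(defaults.epochs, defaults.lrgamma)   # 100 0.1 -- YAML wins, new keys are added
```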
--------------------------------------------------------------------------------
/utils/custom_ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | def rand_bbox(size, lam, center=False, attcen=None):
4 |     if len(size) == 4:
5 |         W = size[2]
6 |         H = size[3]
7 |     elif len(size) == 3:
8 |         W = size[1]
9 |         H = size[2]
10 |     elif len(size) == 2:
11 |         W = size[0]
12 |         H = size[1]
13 |     else:
14 |         raise Exception
15 | 
16 |     cut_rat = np.sqrt(1. - lam)
17 |     cut_w = int(W * cut_rat)  # builtin int: np.int was removed in NumPy >= 1.24
18 |     cut_h = int(H * cut_rat)
19 | 
20 |     if attcen is None:
21 |         # uniform
22 |         cx = 0
23 |         cy = 0
24 |         if W > 0 and H > 0:
25 |             cx = np.random.randint(W)
26 |             cy = np.random.randint(H)
27 |         if center:
28 |             cx = int(W/2)
29 |             cy = int(H/2)
30 |     else:
31 |         cx = attcen[0]
32 |         cy = attcen[1]
33 | 
34 |     bbx1 = np.clip(cx - cut_w // 2, 0, W)
35 |     bby1 = np.clip(cy - cut_h // 2, 0, H)
36 |     bbx2 = np.clip(cx + cut_w // 2, 0, W)
37 |     bby2 = np.clip(cy + cut_h // 2, 0, H)
38 | 
39 |     return bbx1, bby1, bbx2, bby2
40 | 
41 | def get_bbox(imgsize=(224, 224), beta=1.0):
42 | 
43 |     r = np.random.rand(1)
44 |     lam = np.random.beta(beta, beta)
45 |     bbx1, bby1, bbx2, bby2 = rand_bbox(imgsize, lam)
46 | 
47 |     return [bbx1, bby1, bbx2, bby2]
48 | 
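`rand_bbox()` draws a box whose side ratio is `sqrt(1 - lam)`, so the box covers roughly a `(1 - lam)` fraction of the image, and less whenever it is clipped at a border. A quick sanity check, assuming the repo root is importable:

```python
import numpy as np
from utils.custom_ops import rand_bbox

np.random.seed(0)
lam = 0.7
bbx1, bby1, bbx2, bby2 = rand_bbox((16, 3, 224, 224), lam)   # batch-shaped size tuple
area_frac = (bbx2 - bbx1) * (bby2 - bby1) / (224 * 224)
print(round(area_frac, 3))   # close to 1 - lam = 0.3
```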
--------------------------------------------------------------------------------
/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | 
2 | from torch.utils import data
3 | import imp
4 | import os
5 | 
6 | 
7 | def get_dataloader(conf):
8 | 
9 |     src_file = os.path.join('datasets', conf.dataset+'.py')
10 |     dataimp = imp.load_source('loader', src_file)
11 |     ds_train, ds_test = dataimp.get_dataset(conf)
12 |     if 'trainshuffle' in conf:
13 |         trainshuffle = conf.trainshuffle
14 |     else:
15 |         trainshuffle = True
16 | 
17 |     print('train shuffle:', trainshuffle)
18 |     train_loader = data.DataLoader(ds_train, batch_size=conf.batch_size, shuffle=trainshuffle, num_workers=conf.workers, pin_memory=True)
19 |     val_loader = data.DataLoader(ds_test, batch_size=conf.batch_size, shuffle=False, num_workers=conf.workers, pin_memory=True)
20 | 
21 |     return train_loader, val_loader
22 | 
23 | 
24 | 
25 | 
26 | 
27 | 
28 | 
29 | 
30 | 
31 | 
--------------------------------------------------------------------------------
/utils/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import pdb
3 | 
4 | __all__ = ['accuracy', 'wrong_index']
5 | 
6 | def accuracy(output, target, topk=(1,)):
7 |     """Computes the top-k accuracy for the specified values of k"""
8 |     maxk = max(topk)
9 |     batch_size = target.size(0)
10 | 
11 |     _, pred = output.topk(maxk, 1, True, True)
12 |     pred = pred.t()
13 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
14 | 
15 |     res = []
16 |     for k in topk:
17 |         correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape: the slice may be non-contiguous in newer torch
18 |         res.append(correct_k.mul_(100.0 / batch_size))
19 |     return res
20 | 
21 | 
22 | def wrong_index(output, target, index, topk=(1,)):
23 |     maxk = max(topk)
24 |     batch_size = target.size(0)
25 | 
26 |     _, pred = output.topk(maxk, 1, True, True)
27 |     pred = pred.t()
28 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
29 | 
30 |     #res = []
31 |     #for k in topk:
32 |     correct_k = correct[:1].reshape(-1)
33 |     out = index[~correct_k]
34 |     pred = pred.view(-1)
35 |     pre = pred[~correct_k]
36 |     #correct_k = correct[:k].view(-1).float().sum(0)
37 |     #res.append(correct_k.mul_(100.0 / batch_size))
38 |     return out, pre
39 | 
40 | 
41 | 
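A toy check of `accuracy()`: it returns one percentage tensor per requested `k` (here only top-1).

```python
import torch
from utils.eval import accuracy   # assumes the repo root is on PYTHONPATH

output = torch.tensor([[0.1, 0.9, 0.0],
                       [0.8, 0.1, 0.1]])
target = torch.tensor([1, 2])
top1, = accuracy(output, target, topk=(1,))
print(top1)   # tensor(50.) -- only the first sample's argmax matches
```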
--------------------------------------------------------------------------------
/utils/io.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import os.path as path
4 | from datetime import datetime
5 | import shutil
6 | from tqdm import tqdm
7 | import math
8 | from urllib.request import urlretrieve
9 | 
10 | 
11 | # ---------------load checkpoint--------------------
12 | def load_checkpoint(model, pth_file):
13 |     print('==> Reading from model checkpoint..')
14 |     assert os.path.isfile(pth_file), 'Error: no model checkpoint file found!'
15 |     checkpoint = torch.load(pth_file)
16 | 
17 |     pretrained_dict = checkpoint['state_dict']
18 |     model_dict = model.module.state_dict()
19 |     model_dict.update(pretrained_dict)
20 | 
21 |     model.module.load_state_dict(model_dict)
22 |     print("=> loaded model checkpoint '{}' (epoch {})"
23 |           .format(pth_file, checkpoint['epoch']))
24 | 
25 |     return checkpoint
26 | 
27 | 
28 | # ---------------save checkpoint--------------------
29 | def save_checkpoint(state, is_best=False, outdir='checkpoint', filename='checkpoint.pth', iteral=50):
30 |     # iteral: every `iteral` epochs the checkpoint is kept as a separate snapshot
31 |     epochnum = state['epoch']
32 |     filepath = os.path.join(outdir, filename)
33 |     epochpath = str(epochnum)+'_'+filename
34 |     epochpath = os.path.join(outdir, epochpath)
35 |     if epochnum % iteral == 0:
36 |         savepath = epochpath
37 |     else:
38 |         savepath = filepath
39 |     torch.save(state, savepath)
40 |     if is_best:
41 |         shutil.copyfile(savepath, os.path.join(outdir, 'model_best.pth.tar'))
42 | 
43 | 
44 | 
45 | def set_outdir(conf):
46 | 
47 |     default_outdir = 'results'
48 |     if 'timedir' in conf:
49 |         timestr = datetime.now().strftime('%d-%m-%Y_%I_%M-%S_%p')
50 |         outdir = os.path.join(default_outdir, conf.exp_name, \
51 |                 conf.net_type+'_'+conf.dataset, timestr)
52 |     else:
53 |         outdir = os.path.join(default_outdir, conf.exp_name, \
54 |                 conf.netname+'_'+conf.dataset)
55 | 
56 |     prefix = 'bs_'+str(conf.batch_size)+'_seed_'+str(conf.seed)
57 | 
58 |     if conf.weightfile:
59 |         prefix = 'ft_'+prefix
60 | 
61 |     if not conf.pretrained:
62 |         prefix = 'scratch_'+prefix
63 | 
64 |     if 'midlevel' in conf:
65 |         if conf.midlevel:
66 |             prefix += 'mid_'
67 |     if 'mixmethod' in conf:
68 |         if isinstance(conf.mixmethod, list):
69 |             prefix += '_'.join(conf.mixmethod)
70 |         else:
71 |             prefix += conf.mixmethod+'_'
72 |     if 'prob' in conf:
73 |         prefix += '_p'+str(conf.prob)
74 |     if 'beta' in conf:
75 |         prefix += '_b'+str(conf.beta)
76 | 
77 |     outdir = os.path.join(outdir, prefix)
78 |     ensure_dir(outdir)
79 |     conf['outdir'] = outdir
80 | 
81 |     return conf
82 | 
83 | 
84 | 
85 | # check if dir exists, if not create new folder
86 | def ensure_dir(dir_name):
87 |     if not os.path.exists(dir_name):
88 |         os.makedirs(dir_name)
89 |         print('{} is created'.format(dir_name))
90 | 
91 | 
92 | def ensure_file(file_path):
93 | 
94 |     newpath = file_path
95 |     if os.path.exists(file_path):
96 |         timestr = datetime.now().strftime('%d-%m-%Y_%I_%M-%S_%p_')
97 |         newpath = path.join(path.dirname(file_path), timestr + path.basename(file_path))
98 |     return newpath
99 | 
100 | 
101 | 
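The checkpoint cycle implemented above: epochs that are multiples of `iteral` get their own snapshot file, other epochs overwrite `checkpoint.pth`, and the best model is additionally copied to `model_best.pth.tar`. A hedged sketch (`model` is assumed to exist and, since `load_checkpoint` accesses `model.module`, to be wrapped in `nn.DataParallel`):

```python
from utils.io import save_checkpoint, load_checkpoint, ensure_dir

ensure_dir('checkpoint')
state = {'epoch': 50, 'state_dict': model.module.state_dict()}
save_checkpoint(state, is_best=True, outdir='checkpoint', iteral=50)
# -> writes checkpoint/50_checkpoint.pth and copies it to checkpoint/model_best.pth.tar
checkpoint = load_checkpoint(model, 'checkpoint/model_best.pth.tar')
```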
--------------------------------------------------------------------------------
/utils/log.py:
--------------------------------------------------------------------------------
1 | 
2 | import logging
3 | import utils
4 | from utils.conf import print_conf
5 | import os
6 | 
7 | 
8 | def set_logger(cfg):
9 |     """Set the logger to log info to the terminal and to a file.
10 | 
11 |     In general, it is useful to have a logger so that every output to the terminal is saved
12 |     in a permanent file. Here we save it to `<cfg.outdir>/train.log` (or `test.log` when evaluating).
13 | 
14 |     Example:
15 |     ```
16 |     logging.info("Starting training...")
17 |     ```
18 | 
19 |     Args:
20 |         cfg: (edict) config; must contain `outdir` and `evaluate`
21 |     """
22 | 
23 |     if 'loglevel' in cfg:
24 |         loglevel = eval('logging.'+cfg.loglevel)
25 |     else:
26 |         loglevel = logging.INFO
27 | 
28 | 
29 |     if cfg.evaluate:
30 |         outname = 'test.log'
31 |     else:
32 |         outname = 'train.log'
33 | 
34 |     outdir = cfg['outdir']
35 |     log_path = os.path.join(outdir, outname)
36 | 
37 | 
38 |     logger = logging.getLogger()
39 |     logger.setLevel(loglevel)
40 | 
41 |     if not logger.handlers:
42 |         # Logging to a file
43 |         file_handler = logging.FileHandler(log_path)
44 |         file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
45 |         logger.addHandler(file_handler)
46 | 
47 |         # Logging to console
48 |         stream_handler = logging.StreamHandler()
49 |         stream_handler.setFormatter(logging.Formatter('%(message)s'))
50 |         logger.addHandler(stream_handler)
51 | 
52 |     logging.info(print_conf(cfg))
53 |     logging.info('writing logs to file {}'.format(log_path))
54 | 
55 | 
56 | 
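A minimal sketch of driving `set_logger()`: it needs `cfg['outdir']` (normally filled in by `utils.io.set_outdir`) plus the `evaluate` flag, and then mirrors every `logging` call to the console and to `<outdir>/train.log` or `test.log`:

```python
import logging
from easydict import EasyDict as edict
from utils.io import ensure_dir
from utils.log import set_logger

cfg = edict({'outdir': 'results/demo', 'evaluate': False})  # hypothetical minimal config
ensure_dir(cfg.outdir)
set_logger(cfg)
logging.info('Starting training...')   # appears on the console and in results/demo/train.log
```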
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.optim.lr_scheduler import _LRScheduler
3 | import random
4 | import numpy as np
5 | from sklearn.metrics import accuracy_score
6 | import torch
7 | import torch.nn as nn
8 | import torch.backends.cudnn as cudnn
9 | import shutil
10 | import pdb
11 | import logging
12 | 
13 | def set_logger(log_path):
14 |     """Set the logger to log info to the terminal and to the file `log_path`.
15 | 
16 |     In general, it is useful to have a logger so that every output to the terminal is saved
17 |     in a permanent file. Here we save it to `model_dir/train.log`.
18 | 
19 |     Example:
20 |     ```
21 |     logging.info("Starting training...")
22 |     ```
23 | 
24 |     Args:
25 |         log_path: (string) where to log
26 |     """
27 |     logger = logging.getLogger()
28 |     logger.setLevel(logging.INFO)
29 | 
30 |     if not logger.handlers:
31 |         # Logging to a file
32 |         file_handler = logging.FileHandler(log_path)
33 |         file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
34 |         logger.addHandler(file_handler)
35 | 
36 |         # Logging to console
37 |         stream_handler = logging.StreamHandler()
38 |         stream_handler.setFormatter(logging.Formatter('%(message)s'))
39 |         logger.addHandler(stream_handler)
40 | 
41 | 
42 | def count_parameters(model):
43 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
44 | 
45 | 
46 | # check if dir exists, if not create new folder
47 | def ensure_dir(dir_name):
48 |     if not os.path.exists(dir_name):
49 |         os.makedirs(dir_name)
50 | 
51 | 
52 | # ---------------save checkpoint--------------------
53 | def save_checkpoint(state, is_best=False, checkpoint='checkpoint', filename='checkpoint.pth'):
54 |     filepath = os.path.join(checkpoint, filename)
55 |     torch.save(state, filepath)
56 |     if is_best:
57 |         shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))
58 | 
59 | 
60 | # ---------------update meter--------------------
61 | def update_meter(dict_meter, dict_content, batch_size):
62 |     idx = 0
63 |     for key, value in dict_meter.items():
64 |         if type(batch_size) == list:
65 |             bs = batch_size[idx]
66 |         else:
67 |             bs = batch_size
68 | 
69 |         if isinstance(dict_content[key], torch.Tensor):
70 |             value.update(dict_content[key].item(), bs)
71 |         else:
72 |             value.update(dict_content[key], bs)
73 |         idx += 1
74 | 
75 | 
76 | # ---------------load checkpoint--------------------
77 | def load_checkpoint(model, pth_file):
78 |     print('==> Reading from model checkpoint..')
79 |     assert os.path.isfile(pth_file), 'Error: no model checkpoint file found!'
80 |     checkpoint = torch.load(pth_file)
81 |     # args.start_epoch = checkpoint['epoch']
82 |     # best_prec1 = checkpoint['best_prec1']
83 | 
84 |     pretrained_dict = checkpoint['state_dict']
85 |     model_dict = model.module.state_dict()
86 |     model_dict.update(pretrained_dict)
87 | 
88 |     # model.module.load_state_dict(checkpoint['state_dict'])
89 |     model.module.load_state_dict(model_dict)
90 |     print("=> loaded model checkpoint '{}' (epoch {})"
91 |           .format(pth_file, checkpoint['epoch']))
92 | 
93 |     # results = {'model': model, 'checkpoint': checkpoint}
94 |     return checkpoint
95 | 
96 | 
97 | # ---------------running mean--------------------
98 | class RunningMean:
99 |     def __init__(self):
100 |         self.val = 0
101 |         self.avg = 0
102 |         self.sum = 0
103 |         self.count = 0
104 | 
105 |     def update(self, val, n=1.):
106 |         self.val = val
107 |         self.sum += val * n
108 |         self.count += n
109 |         self.avg = self.sum / self.count
110 | 
111 |     @property
112 |     def value(self):
113 |         # if self.count:
114 |         #     return float(self.total_value) / self.count
115 |         # else:
116 |         #     return 0
117 |         return self.avg
118 | 
119 |     def __str__(self):
120 |         return str(self.value)
121 | 
122 | # ---------------more accurate Acc--------------------
123 | class RunningAcc:
124 |     def __init__(self):
125 |         self.avg = 0.
126 |         self.pred = []
127 |         self.tgt = []
128 | 
129 |     def update(self, logits, tgt):
130 |         pred = torch.argmax(logits, dim=1)
131 |         self.pred.extend(pred.cpu().numpy().tolist())
132 |         self.tgt.extend(tgt.cpu().numpy().tolist())
133 | 
134 |     @property
135 |     def value(self):
136 |         self.avg = accuracy_score(self.tgt, self.pred)  # (y_true, y_pred)
137 |         return self.avg*100
138 | 
139 |     def __str__(self):
140 |         return str(self.value)
141 | 
142 | def set_seeding(seed):
143 |     random.seed(seed)
144 |     np.random.seed(seed)  # numpy vars
145 |     torch.manual_seed(seed)  # cpu vars
146 |     torch.cuda.manual_seed(seed)  # gpu vars
147 |     torch.cuda.manual_seed_all(seed)  # multi-gpu vars
148 |     torch.backends.cudnn.benchmark = False
149 |     cudnn.deterministic = True
150 | 
151 | 
152 | def init_params(net):
153 |     '''Init layer parameters.'''
154 |     for m in net.modules():
155 |         if isinstance(m, nn.Conv2d):
156 |             nn.init.kaiming_normal_(m.weight, mode='fan_out')
157 |             if m.bias is not None:
158 |                 nn.init.constant_(m.bias, 0)
159 |         elif isinstance(m, nn.BatchNorm2d):
160 |             nn.init.constant_(m.weight, 1)
161 |             nn.init.constant_(m.bias, 0)
162 |         elif isinstance(m, nn.Linear):
163 |             nn.init.normal_(m.weight, std=1e-3)
164 |             if m.bias is not None:
165 |                 nn.init.constant_(m.bias, 0)
166 | 
167 | 
--------------------------------------------------------------------------------
/utils/mixmethod.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | import torch.nn as nn
4 | import imp
5 | import numpy as np
6 | import utils
7 | import os
8 | import torch.nn.functional as F
9 | import random
10 | import copy
11 | 
12 | def get_spm(input, target, conf, model):
13 | 
14 |     imgsize = (conf.cropsize, conf.cropsize)
15 |     bs = input.size(0)
16 |     with torch.no_grad():
17 |         output, fms, _ = model(input)
18 |         if 'inception' in conf.netname:
19 |             clsw = model.module.fc
20 |         else:
21 |             clsw = model.module.classifier
22 |         weight = clsw.weight.data
23 |         bias = clsw.bias.data
24 |         weight = weight.view(weight.size(0), weight.size(1), 1, 1)
25 |         fms = F.relu(fms)
26 |         poolfea = F.adaptive_avg_pool2d(fms, (1, 1)).squeeze()
27 |         clslogit = F.softmax(clsw.forward(poolfea), dim=1)
28 |         logitlist = []
29 |         for i in range(bs):
30 |             logitlist.append(clslogit[i, target[i]])
31 |         clslogit = torch.stack(logitlist)
32 | 
33 |         out = F.conv2d(fms, weight, bias=bias)
34 | 
35 |         outmaps = []
36 |         for i in range(bs):
37 |             evimap = out[i, target[i]]
38 |             outmaps.append(evimap)
39 | 
40 |         outmaps = torch.stack(outmaps)
41 |         if imgsize is not None:
42 |             outmaps = outmaps.view(outmaps.size(0), 1, outmaps.size(1), outmaps.size(2))
43 |             outmaps = F.interpolate(outmaps, imgsize, mode='bilinear', align_corners=False)
44 | 
45 |         outmaps = outmaps.squeeze()
46 | 
47 |         for i in range(bs):
48 |             outmaps[i] -= outmaps[i].min()
49 |             outmaps[i] /= outmaps[i].sum()
50 | 
51 | 
52 |     return outmaps, clslogit
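`get_spm()` is essentially Class Activation Mapping: convolving the final feature maps with the classifier weights yields one spatial score map per class, and the target class's map, normalized to sum to 1, becomes the semantic weight map that SnapMix uses below. A self-contained illustration of that trick with stand-in random tensors:

```python
import torch
import torch.nn.functional as F

fms    = torch.randn(2, 512, 14, 14).relu()       # stand-in for backbone feature maps
weight = torch.randn(10, 512)                     # stand-in for a 10-class fc weight
target = torch.tensor([3, 7])

out  = F.conv2d(fms, weight.view(10, 512, 1, 1))  # (2, 10, 14, 14): one map per class
maps = out[torch.arange(2), target]               # each sample's target-class map
maps = maps - maps.flatten(1).min(1).values.view(-1, 1, 1)
maps = maps / (maps.flatten(1).sum(1).view(-1, 1, 1) + 1e-8)
print(maps.sum(dim=(1, 2)))                       # ~1 per sample: a spatial distribution
```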
53 | 
54 | 
55 | 
56 | def snapmix(input, target, conf, model=None):
57 | 
58 |     r = np.random.rand(1)
59 |     lam_a = torch.ones(input.size(0))
60 |     lam_b = 1 - lam_a
61 |     target_b = target.clone()
62 | 
63 |     if r < conf.prob:
64 |         wfmaps, _ = get_spm(input, target, conf, model)
65 |         bs = input.size(0)
66 |         lam = np.random.beta(conf.beta, conf.beta)
67 |         lam1 = np.random.beta(conf.beta, conf.beta)
68 |         rand_index = torch.randperm(bs).cuda()
69 |         wfmaps_b = wfmaps[rand_index, :, :]
70 |         target_b = target[rand_index]
71 | 
72 |         same_label = target == target_b
73 |         bbx1, bby1, bbx2, bby2 = utils.rand_bbox(input.size(), lam)
74 |         bbx1_1, bby1_1, bbx2_1, bby2_1 = utils.rand_bbox(input.size(), lam1)
75 | 
76 |         area = (bby2-bby1)*(bbx2-bbx1)
77 |         area1 = (bby2_1-bby1_1)*(bbx2_1-bbx1_1)
78 | 
79 |         if area1 > 0 and area > 0:
80 |             ncont = input[rand_index, :, bbx1_1:bbx2_1, bby1_1:bby2_1].clone()
81 |             ncont = F.interpolate(ncont, size=(bbx2-bbx1, bby2-bby1), mode='bilinear', align_corners=True)
82 |             input[:, :, bbx1:bbx2, bby1:bby2] = ncont
83 |             lam_a = 1 - wfmaps[:, bbx1:bbx2, bby1:bby2].sum(2).sum(1)/(wfmaps.sum(2).sum(1)+1e-8)
84 |             lam_b = wfmaps_b[:, bbx1_1:bbx2_1, bby1_1:bby2_1].sum(2).sum(1)/(wfmaps_b.sum(2).sum(1)+1e-8)
85 |             tmp = lam_a.clone()
86 |             lam_a[same_label] += lam_b[same_label]
87 |             lam_b[same_label] += tmp[same_label]
88 |             lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
89 |             lam_a[torch.isnan(lam_a)] = lam
90 |             lam_b[torch.isnan(lam_b)] = 1-lam
91 | 
92 |     return input, target, target_b, lam_a.cuda(), lam_b.cuda()
93 | 
94 | 
95 | def as_cutmix(input, target, conf, model=None):
96 | 
97 |     r = np.random.rand(1)
98 |     lam_a = torch.ones(input.size(0))
99 |     lam_b = 1 - lam_a
100 |     target_b = target.clone()
101 | 
102 |     if r < conf.prob:
103 |         bs = input.size(0)
104 |         lam = np.random.beta(conf.beta, conf.beta)
105 |         rand_index = torch.randperm(bs).cuda()
106 |         target_b = target[rand_index]
107 | 
108 |         bbx1, bby1, bbx2, bby2 = utils.rand_bbox(input.size(), lam)
109 |         bbx1_1, bby1_1, bbx2_1, bby2_1 = utils.rand_bbox(input.size(), lam)
110 | 
111 |         if (bby2_1-bby1_1)*(bbx2_1-bbx1_1) > 4 and (bby2-bby1)*(bbx2-bbx1) > 4:
112 |             ncont = input[rand_index, :, bbx1_1:bbx2_1, bby1_1:bby2_1].clone()
113 |             ncont = F.interpolate(ncont, size=(bbx2-bbx1, bby2-bby1), mode='bilinear', align_corners=True)
114 |             input[:, :, bbx1:bbx2, bby1:bby2] = ncont
115 |             # adjust lambda to exactly match pixel ratio
116 |             lam_a = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
117 |             lam_a *= torch.ones(input.size(0))
118 |             lam_b = 1 - lam_a
119 | 
120 |     return input, target, target_b, lam_a.cuda(), lam_b.cuda()
121 | 
122 | def cutmix(input, target, conf, model=None):
123 | 
124 |     r = np.random.rand(1)
125 |     lam_a = torch.ones(input.size(0)).cuda()
126 |     target_b = target.clone()
127 | 
128 |     if r < conf.prob:
129 |         bs = input.size(0)
130 |         lam = np.random.beta(conf.beta, conf.beta)
131 |         rand_index = torch.randperm(bs).cuda()
132 |         target_b = target[rand_index]
133 |         input_b = input[rand_index].clone()
134 |         bbx1, bby1, bbx2, bby2 = utils.rand_bbox(input.size(), lam)
135 |         input[:, :, bbx1:bbx2, bby1:bby2] = input_b[:, :, bbx1:bbx2, bby1:bby2]
136 | 
137 |         # adjust lambda to exactly match pixel ratio
138 |         lam_a = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
139 |         lam_a *= torch.ones(input.size(0))
140 | 
141 |     lam_b = 1 - lam_a
142 | 
143 | 
144 |     return input, target, target_b, lam_a.cuda(), lam_b.cuda()
145 | 
146 | 
147 | 
148 | def cutout(input, target, conf=None, model=None):
149 | 
150 |     r = np.random.rand(1)
151 |     lam = torch.ones(input.size(0)).cuda()
152 |     target_b = target.clone()
153 |     lam_a = lam
154 |     lam_b = 1-lam
155 | 
156 |     if r < conf.prob:
157 |         bs = input.size(0)
158 |         lam = 0.75
159 |         bbx1, bby1, bbx2, bby2 = utils.rand_bbox(input.size(), lam)
160 |         input[:, :, bbx1:bbx2, bby1:bby2] = 0
161 | 
162 |     return input, target, target_b, lam_a.cuda(), lam_b.cuda()
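For `cutmix`/`as_cutmix`/`cutout` the label weight is purely geometric: `lam_a = 1 - box_area / image_area`, applied uniformly to the whole batch, whereas `snapmix` above replaces it with per-sample CAM mass. The arithmetic on the default 448-pixel crops:

```python
W = H = 448                                   # conf.cropsize for the fine-grained datasets
bbx1, bby1, bbx2, bby2 = 100, 120, 300, 360   # an example box from rand_bbox
box_area = (bbx2 - bbx1) * (bby2 - bby1)
lam_a = 1 - box_area / (W * H)
print(round(lam_a, 4), round(1 - lam_a, 4))   # 0.7608 0.2392: original vs. pasted label weight
```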
163 | 
164 | 
165 | def mixup(input, target, conf, model=None):
166 |     r = np.random.rand(1)
167 |     lam_a = torch.ones(input.size(0)).cuda()
168 |     bs = input.size(0)
169 |     target_a = target
170 |     target_b = target
171 | 
172 |     if r < conf.prob:
173 |         rand_index = torch.randperm(bs).cuda()
174 |         target_b = target[rand_index]
175 |         lam = np.random.beta(conf.beta, conf.beta)
176 |         lam_a = lam_a*lam
177 |         input = input * lam + input[rand_index] * (1-lam)
178 | 
179 |     lam_b = 1 - lam_a
180 | 
181 |     return input, target, target_b, lam_a.cuda(), lam_b.cuda()
182 | 
183 | 
184 | 
185 | 
--------------------------------------------------------------------------------
/utils/stat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pdb
3 | import math
4 | from sklearn.metrics import average_precision_score as aps
5 | import numpy as np
6 | 
7 | class AverageMeter(object):
8 |     """Computes and stores the average and current value"""
9 | 
10 |     def __init__(self):
11 |         self.reset()
12 | 
13 |     def reset(self):
14 |         self.val = 0
15 |         self.avg = 0
16 |         self.sum = 0
17 |         self.count = 0
18 | 
19 |     def add(self, val, n=1):
20 |         self.val = val
21 |         self.sum += val * n
22 |         self.count += n
23 | 
24 |     def value(self):
25 |         return self.sum / self.count
26 | 
27 | class AverageAccMeter(object):
28 | 
29 |     def __init__(self):
30 |         self.reset()
31 | 
32 |     def reset(self):
33 |         self.val = 0
34 |         self.avg = 0
35 |         self.sum = 0
36 |         self.count = 0
37 | 
38 |     def add(self, output, target):
39 |         n = output.size(0)
40 |         self.val = self.accuracy(output, target).item()
41 |         self.sum += self.val * n
42 |         self.count += n
43 | 
44 |     def value(self):
45 |         if self.sum == 0:
46 |             return 0
47 |         else:
48 |             return self.sum / self.count
49 | 
50 |     def accuracy(self, output, target, topk=(1,)):
51 |         """Computes the top-k accuracy for the specified values of k"""
52 |         maxk = max(topk)
53 |         batch_size = target.size(0)
54 | 
55 |         _, pred = output.topk(maxk, 1, True, True)
56 |         pred = pred.t()
57 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
58 | 
59 |         res = []
60 |         for k in topk:
61 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: the slice may be non-contiguous in newer torch
62 |             res.append(correct_k.mul_(100.0 / batch_size))
63 |             #wrong_k = batch_size - correct_k
64 |             #res.append(wrong_k.mul_(100.0 / batch_size))
65 | 
66 |         return res[0]
67 | 
68 | 
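Toy usage of the meters: `add()` accumulates weighted sums and `value()` returns the running mean (a percentage for `AverageAccMeter`).

```python
import torch
from utils.stat import AverageMeter, AverageAccMeter

losses, scores = AverageMeter(), AverageAccMeter()
output = torch.tensor([[2.0, 0.1], [0.2, 1.5]])   # both samples predicted correctly
target = torch.tensor([0, 1])
losses.add(0.35, n=output.size(0))
scores.add(output, target)
print(losses.value(), scores.value())             # 0.35 100.0
```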
--------------------------------------------------------------------------------
/utils/trainutil.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | import torch.nn as nn
4 | import imp
5 | import numpy as np
6 | import utils
7 | import os
8 | import torch.nn.functional as F
9 | import random
10 | import copy
11 | 
12 | 
13 | def get_sgd(params, conf):
14 |     return torch.optim.SGD(params, conf.lr, momentum=conf.momentum,\
15 |             weight_decay=conf.weight_decay, nesterov=True)
16 | 
17 | 
18 | # criterion
19 | 
20 | def get_criterion(conf):
21 |     reduction = 'mean'
22 |     if 'reduction' in conf:
23 |         reduction = conf.reduction
24 |     return eval('nn.'+conf.criterion)(reduction=reduction).cuda()
25 | 
26 | # LR scheduler
27 | def get_multisteplr(optim, conf):
28 | 
29 |     return torch.optim.lr_scheduler.MultiStepLR(optim, \
30 |             milestones=conf.lrstep, gamma=conf.lrgamma, last_epoch=-1)
31 | 
32 | def get_proc(conf):
33 | 
34 |     if 'train_proc' not in conf:
35 |         train_proc = conf.net_type
36 |     else:
37 |         train_proc = conf.train_proc
38 | 
39 |     if 'test_proc' not in conf:
40 |         test_proc = conf.net_type
41 |     else:
42 |         test_proc = conf.test_proc
43 | 
44 |     trainfile = '{}_train.py'.format(train_proc)
45 |     testfile = '{}_test.py'.format(test_proc)
46 |     trainpy = os.path.join('trainer', trainfile)
47 |     testpy = os.path.join('trainer', testfile)
48 |     train = imp.load_source('train', trainpy).train
49 |     validate = imp.load_source('validate', testpy).validate
50 | 
51 |     return train, validate
52 | 
53 | 
54 | 
55 | # parameters
56 | 
57 | def get_params(model, conf=None):
58 | 
59 |     if conf is not None and 'prams_group' in conf:  # note: 'prams_group' (sic) is the key expected in the config
60 |         prams_group = conf.prams_group
61 |         lr_group = conf.lr_group
62 |         params = []
63 |         for pram, lr in zip(prams_group, lr_group):
64 |             params.append({'params': model.module.get_params(pram), 'lr': lr})
65 | 
66 |         return params
67 | 
68 |     return model.parameters()
69 | 
70 | 
71 | def get_train_setting(model, conf):
72 | 
73 |     optim = 'sgd'
74 |     criterion = 'cross_entropy'
75 |     lrscheduler = 'multisteplr'
76 | 
77 |     if 'optim' in conf:
78 |         optim = conf.optim
79 | 
80 |     if 'criterion' in conf:
81 |         criterion = conf.criterion
82 | 
83 |     if 'lrscheduler' in conf:
84 |         lrscheduler = conf.lrscheduler
85 | 
86 | 
87 |     criterion = get_criterion(conf)
88 |     optim = eval('get_'+optim)(get_params(model, conf), conf)
89 |     lrscheduler = eval('get_'+lrscheduler)(optim, conf)
90 | 
91 |     return criterion, optim, lrscheduler
92 | 
93 | 
94 | 
95 | 
96 | 
97 | 
--------------------------------------------------------------------------------
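For reference, a hedged end-to-end sketch of how `main.py` presumably wires these helpers together. The exact `main.py` flow is not shown in this excerpt; `exp_name` and `test_proc`/`net_type` are expected to come from `config/comm.yml`, and `model` is assumed to be built from `networks/*` and wrapped in `nn.DataParallel`:

```python
from utils import (get_config, set_env, set_outdir, set_logger,
                   get_dataloader, get_train_setting, get_proc)

conf = get_config()                  # argparse defaults merged with config/comm.yml
set_env(conf)                        # seeding + CUDA device selection
conf = set_outdir(conf)              # results/<exp_name>/<netname>_<dataset>/<prefix>
set_logger(conf)

train_loader, val_loader = get_dataloader(conf)
train, validate = get_proc(conf)     # loads trainer/comm_train.py and trainer/comm_test.py
# model construction from networks/* omitted here
criterion, optimizer, scheduler = get_train_setting(model, conf)
for epoch in range(conf.epochs):
    train(train_loader, model, criterion, optimizer, conf)
    scheduler.step()
```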