├── utils │   ├── consensus_loss.py │   ├── whitening.py │   ├── folder.py │   └── batch_norm.py ├── README.md ├── usps_mnist.py └── resnet50_dwt_mec_officehome.py /utils/consensus_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class MinEntropyConsensusLoss(nn.Module): 6 | def __init__(self, num_classes, device): 7 | super(MinEntropyConsensusLoss, self).__init__() 8 | self.num_classes = num_classes 9 | self.device = device 10 | 11 | def forward(self, x, y): 12 | i = torch.eye(self.num_classes, device=self.device).unsqueeze(0) 13 | x = F.log_softmax(x, dim=1) 14 | y = F.log_softmax(y, dim=1) 15 | 16 | x = x.unsqueeze(-1) 17 | y = y.unsqueeze(-1) 18 | 19 | ce_x = (- 1.0 * i * x).sum(1) 20 | ce_y = (- 1.0 * i * y).sum(1) 21 | 22 | ce = 0.5 * (ce_x + ce_y).min(1)[0].mean() 23 | 24 | return ce 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Domain Whitening Transform for Unsupervised Domain Adaptation (CVPR 2019) 2 | 3 | 4 | Official PyTorch GitHub repository for the paper [Unsupervised Domain Adaptation using Feature-Whitening and Consensus Loss](http://openaccess.thecvf.com/content_CVPR_2019/html/Roy_Unsupervised_Domain_Adaptation_Using_Feature-Whitening_and_Consensus_Loss_CVPR_2019_paper.html), published at the Conference on Computer Vision and Pattern Recognition (**CVPR**), held in Long Beach, California, in June 2019. 5 | 6 | ### Prerequisites 7 | * PyTorch 1.0 8 | * Python 3.5 9 | 10 | ### Usage 11 | - Office-Home: To run the experiments on the [OfficeHome](http://hemanthdv.org/OfficeHome-Dataset/) dataset, first download the dataset from [this](https://drive.google.com/file/d/0B81rNlvomiwed0V1YUxQdC1uOTg/view) page. Next, download the ResNet50 checkpoint pre-trained on ImageNet with the BatchNorm layers (in the first conv layer and the first Res block) replaced by *whitening* normalization layers. The pre-trained weights are available [here](https://drive.google.com/file/d/1Iw3pCXdiAiJJnZDzh7UToBNQipIVeMS2/view?usp=sharing).
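Both `--s_dset_path` and `--t_dset_path` should point at the standard one-subfolder-per-class layout consumed by `utils/folder.py` (the directory and file names below are illustrative):

```
Art/
├── Alarm_Clock/
│   ├── 00001.jpg
│   └── ...
└── Backpack/
    ├── 00001.jpg
    └── ...
```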
12 | 13 | ``` 14 | python resnet50_dwt_mec_officehome.py --s_dset_path path-to-source-dataset-folder --t_dset_path path-to-target-dataset-folder --resnet_path path-to-pre-trained-resnet50-weights 15 | ``` 16 | 17 | - USPS -> MNIST: 18 | ``` 19 | python usps_mnist.py --group_size 4 --source 'usps' --target 'mnist' 20 | ``` 21 | 22 | If you find this code useful for your research, please cite our paper: 23 | ``` 24 | @inproceedings{roy2019unsupervised, 25 | title={Unsupervised Domain Adaptation using Feature-Whitening and Consensus Loss}, 26 | author={Roy, Subhankar and Siarohin, Aliaksandr and Sangineto, Enver and Bulo, Samuel Rota and Sebe, Nicu and Ricci, Elisa}, 27 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 28 | year={2019} 29 | } 30 | ``` 31 | 32 | -------------------------------------------------------------------------------- /utils/whitening.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import conv2d 4 | 5 | class _Whitening(nn.Module): 6 | 7 | def __init__(self, num_features, group_size, running_m=None, running_var=None, momentum=0.1, track_running_stats=True, eps=1e-3, alpha=1): 8 | super(_Whitening, self).__init__() 9 | self.num_features = num_features 10 | self.momentum = momentum 11 | self.track_running_stats = track_running_stats 12 | self.eps = eps 13 | self.alpha = alpha 14 | self.group_size = min(self.num_features, group_size) 15 | self.num_groups = self.num_features // self.group_size 16 | self.running_m = running_m 17 | self.running_var = running_var 18 | 19 | if self.track_running_stats and self.running_m is not None: 20 | self.register_buffer('running_mean', self.running_m) 21 | self.register_buffer('running_variance', self.running_var) 22 | else: 23 | self.register_buffer('running_mean', torch.zeros([1, self.num_features, 1, 1], out=torch.cuda.FloatTensor() if torch.cuda.is_available() else torch.FloatTensor())) 24 | self.register_buffer('running_variance', torch.ones([self.num_groups, self.group_size, self.group_size], out=torch.cuda.FloatTensor() if torch.cuda.is_available() else torch.FloatTensor())) 25 | """ 26 | if self.track_running_stats: 27 | self.register_buffer('running_mean', torch.zeros([1, self.num_features, 1, 1], out=torch.cuda.FloatTensor() if torch.cuda.is_available() else torch.FloatTensor())) 28 | self.register_buffer('running_variance', torch.zeros([self.num_groups, self.group_size, self.group_size], out=torch.cuda.FloatTensor() if torch.cuda.is_available() else torch.FloatTensor())) 29 | """ 30 | 31 | def _check_input_dim(self, input): 32 | raise NotImplementedError 33 | 34 | def _check_group_size(self): 35 | raise NotImplementedError 36 | 37 | def forward(self, x): 38 | self._check_input_dim(x) 39 | self._check_group_size() 40 | 41 | m = x.mean(0).view(self.num_features, -1).mean(-1).view(1, -1, 1, 1) 42 | if not self.training and self.track_running_stats: # for inference 43 | m = self.running_mean 44 | xn = x - m 45 | 46 | T = xn.permute(1,0,2,3).contiguous().view(self.num_groups, self.group_size,-1) 47 | f_cov = torch.bmm(T, T.permute(0,2,1)) / T.shape[-1] 48 | f_cov_shrinked = (1-self.eps) * f_cov + self.eps * torch.eye(self.group_size, out=torch.cuda.FloatTensor() if torch.cuda.is_available() else torch.FloatTensor()).repeat(self.num_groups, 1, 1) 49 | 50 | if 
not self.training and self.track_running_stats: # for inference 51 | f_cov_shrinked = (1-self.eps) * self.running_variance + self.eps * torch.eye(self.group_size, out=torch.cuda.FloatTensor() if torch.cuda.is_available() else torch.FloatTensor()).repeat(self.num_groups, 1, 1) 52 | 53 | inv_sqrt = torch.inverse(torch.cholesky(f_cov_shrinked)).contiguous().view(self.num_features, self.group_size, 1, 1) 54 | 55 | decorrelated = conv2d(xn, inv_sqrt, groups = self.num_groups) 56 | 57 | if self.training and self.track_running_stats: 58 | self.running_mean = torch.add(self.momentum * m.detach(), (1 - self.momentum) * self.running_mean, out=self.running_mean) 59 | self.running_variance = torch.add(self.momentum * f_cov.detach(), (1 - self.momentum) * self.running_variance, out=self.running_variance) 60 | 61 | return decorrelated 62 | 63 | class WTransform2d(_Whitening): 64 | def _check_input_dim(self, input): 65 | if input.dim() != 4: 66 | raise ValueError('expected 4D input (got {}D input)'.format(input.dim())) 67 | 68 | def _check_group_size(self): 69 | if self.num_features % self.group_size != 0: 70 | raise ValueError('expected number of channels divisible by group_size (got {} group_size\ 71 | for {} number of features)'.format(self.group_size, self.num_features)) 72 | -------------------------------------------------------------------------------- /utils/folder.py: -------------------------------------------------------------------------------- 1 | """ 2 | File modified from: https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 3 | """ 4 | 5 | import torch.utils.data as data 6 | 7 | from PIL import Image 8 | 9 | import os 10 | import os.path 11 | import sys 12 | 13 | 14 | def has_file_allowed_extension(filename, extensions): 15 | """Checks whether a file has an allowed extension. 16 | 17 | Args: 18 | filename (string): path to a file 19 | extensions (iterable of strings): extensions to consider (lowercase) 20 | 21 | Returns: 22 | bool: True if the filename ends with one of the given extensions 23 | """ 24 | filename_lower = filename.lower() 25 | return any(filename_lower.endswith(ext) for ext in extensions) 26 | 27 | 28 | def is_image_file(filename): 29 | """Checks whether a file has an allowed image extension. 30 | 31 | Args: 32 | filename (string): path to a file 33 | 34 | Returns: 35 | bool: True if the filename ends with a known image extension 36 | """ 37 | return has_file_allowed_extension(filename, IMG_EXTENSIONS) 38 | 39 | 40 | def make_dataset(dir, class_to_idx, extensions): 41 | images = [] 42 | dir = os.path.expanduser(dir) 43 | for target in sorted(class_to_idx.keys()): 44 | d = os.path.join(dir, target) 45 | if not os.path.isdir(d): 46 | continue 47 | 48 | for root, _, fnames in sorted(os.walk(d)): 49 | for fname in sorted(fnames): 50 | if has_file_allowed_extension(fname, extensions): 51 | path = os.path.join(root, fname) 52 | item = (path, class_to_idx[target]) 53 | images.append(item) 54 | 55 | return images 56 | 57 | 58 | class DatasetFolder(data.Dataset): 59 | """A generic data loader where the samples are arranged in this way: :: 60 | 61 | root/class_x/xxx.ext 62 | root/class_x/xxy.ext 63 | root/class_x/xxz.ext 64 | 65 | root/class_y/123.ext 66 | root/class_y/nsdf3.ext 67 | root/class_y/asd932_.ext 68 | 69 | Args: 70 | root (string): Root directory path. 71 | loader (callable): A function to load a sample given its path. 72 | extensions (list[string]): A list of allowed extensions.
73 | transform (callable, optional): A function/transform that takes in 74 | a sample and returns a transformed version. 75 | E.g., ``transforms.RandomCrop`` for images. 76 | transform_aug (callable, optional): A second function/transform that takes 77 | in the sample and returns an augmented version of it. 78 | 79 | Attributes: 80 | classes (list): List of the class names. 81 | class_to_idx (dict): Dict with items (class_name, class_index). 82 | samples (list): List of (sample path, class_index) tuples 83 | targets (list): The class_index value for each image in the dataset 84 | """ 85 | 86 | def __init__(self, root, loader, extensions, transform=None, transform_aug=None): 87 | classes, class_to_idx = self._find_classes(root) 88 | samples = make_dataset(root, class_to_idx, extensions) 89 | if len(samples) == 0: 90 | raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" 91 | "Supported extensions are: " + ",".join(extensions))) 92 | 93 | self.root = root 94 | self.loader = loader 95 | self.extensions = extensions 96 | 97 | self.classes = classes 98 | self.class_to_idx = class_to_idx 99 | self.samples = samples 100 | self.targets = [s[1] for s in samples] 101 | 102 | self.transform = transform 103 | self.transform_aug = transform_aug 104 | 105 | def _find_classes(self, dir): 106 | """ 107 | Finds the class folders in a dataset. 108 | 109 | Args: 110 | dir (string): Root directory path. 111 | 112 | Returns: 113 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 114 | 115 | Ensures: 116 | No class is a subdirectory of another. 117 | """ 118 | if sys.version_info >= (3, 5): 119 | # Faster and available in Python 3.5 and above 120 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 121 | else: 122 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 123 | classes.sort() 124 | class_to_idx = {classes[i]: i for i in range(len(classes))} 125 | return classes, class_to_idx 126 | 127 | def __getitem__(self, index): 128 | """ 129 | Args: 130 | index (int): Index 131 | 132 | Returns: 133 | tuple: (sample, target), or (sample, sample_aug, target) when ``transform_aug`` is set; target is the class index of the target class.
134 | """ 135 | path, target = self.samples[index] 136 | sample = self.loader(path) 137 | 138 | if self.transform_aug is not None: 139 | sample_aug = self.transform_aug(sample) 140 | 141 | if self.transform is not None: 142 | sample = self.transform(sample) 143 | 144 | if self.transform_aug is not None: 145 | return sample, sample_aug, target 146 | else: 147 | return sample, target 148 | 149 | def __len__(self): 150 | return len(self.samples) 151 | 152 | def __repr__(self): 153 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 154 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 155 | fmt_str += ' Root Location: {}\n'.format(self.root) 156 | tmp = ' Transforms (if any): ' 157 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 158 | tmp = ' Augmentation Transforms (if any): ' 159 | fmt_str += '{0}{1}'.format(tmp, self.transform_aug.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 160 | return fmt_str 161 | 162 | 163 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 164 | 165 | 166 | def pil_loader(path): 167 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 168 | with open(path, 'rb') as f: 169 | img = Image.open(f) 170 | return img.convert('RGB') 171 | 172 | 173 | def accimage_loader(path): 174 | import accimage 175 | try: 176 | return accimage.Image(path) 177 | except IOError: 178 | # Potentially a decoding problem, fall back to PIL.Image 179 | return pil_loader(path) 180 | 181 | 182 | def default_loader(path): 183 | from torchvision import get_image_backend 184 | if get_image_backend() == 'accimage': 185 | return accimage_loader(path) 186 | else: 187 | return pil_loader(path) 188 | 189 | 190 | class ImageFolder(DatasetFolder): 191 | """A generic data loader where the images are arranged in this way: :: 192 | 193 | root/dog/xxx.png 194 | root/dog/xxy.png 195 | root/dog/xxz.png 196 | 197 | root/cat/123.png 198 | root/cat/nsdf3.png 199 | root/cat/asd932_.png 200 | 201 | Args: 202 | root (string): Root directory path. 203 | transform (callable, optional): A function/transform that takes in a PIL image 204 | and returns a transformed version. E.g., ``transforms.RandomCrop`` 205 | transform_aug (callable, optional): A second function/transform that takes in the 206 | same PIL image and returns an augmented version. 207 | loader (callable, optional): A function to load an image given its path. 208 | 209 | Attributes: 210 | classes (list): List of the class names. 211 | class_to_idx (dict): Dict with items (class_name, class_index). 212 | imgs (list): List of (image path, class_index) tuples 213 | """ 214 | def __init__(self, root, transform=None, transform_aug=None, loader=default_loader): 215 | super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 216 | transform=transform, 217 | transform_aug=transform_aug) 218 | self.imgs = self.samples -------------------------------------------------------------------------------- /utils/batch_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.module import Module 3 | from torch.nn.parameter import Parameter 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | 7 | """ 8 | File modified from: 9 | https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py 10 | """ 11 | 12 | # TODO: check contiguous in THNN 13 | # TODO: use separate backend functions?
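Editorial note: unlike `torch.nn.BatchNorm2d`, the classes below receive their running statistics through the constructor (`running_m`, `running_v`); in `resnet50_dwt_mec_officehome.py` these tensors are pulled out of the pre-trained checkpoint, one set per domain branch. A minimal sketch with dummy statistics standing in for checkpoint values:

```python
import sys
sys.path.append('utils')
import torch
from batch_norm import BatchNorm2d

running_m = torch.zeros(64)  # stands in for e.g. bn_dict['layer2.0.bn1.running_mean']
running_v = torch.ones(64)   # stands in for e.g. bn_dict['layer2.0.bn1.running_var']
bn = BatchNorm2d(num_features=64, running_m=running_m, running_v=running_v, affine=False)

bn.eval()  # the inference path normalizes with the injected statistics
out = bn(torch.randn(2, 64, 8, 8))
```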
14 | class _BatchNorm(Module): 15 | _version = 2 16 | 17 | def __init__(self, num_features, running_m, running_v, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): 18 | super(_BatchNorm, self).__init__() 19 | self.num_features = num_features 20 | self.eps = eps 21 | self.momentum = momentum 22 | self.affine = affine 23 | self.running_m = running_m 24 | self.running_v = running_v 25 | self.track_running_stats = track_running_stats 26 | if self.affine: 27 | self.weight = Parameter(torch.Tensor(num_features)) 28 | self.bias = Parameter(torch.Tensor(num_features)) 29 | else: 30 | self.register_parameter('weight', None) 31 | self.register_parameter('bias', None) 32 | if self.track_running_stats: 33 | self.register_buffer('running_mean', self.running_m) 34 | self.register_buffer('running_var', self.running_v) 35 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 36 | else: 37 | self.register_parameter('running_mean', None) 38 | self.register_parameter('running_var', None) 39 | self.register_parameter('num_batches_tracked', None) 40 | self.reset_parameters() 41 | 42 | def reset_running_stats(self): 43 | if self.track_running_stats: 44 | self.num_batches_tracked.zero_() 45 | 46 | def reset_parameters(self): 47 | self.reset_running_stats() 48 | if self.affine: 49 | init.uniform_(self.weight) 50 | init.zeros_(self.bias) 51 | def _check_input_dim(self, input): 52 | raise NotImplementedError 53 | 54 | def forward(self, input): 55 | self._check_input_dim(input) 56 | 57 | exponential_average_factor = 0.0 58 | 59 | if self.training and self.track_running_stats: 60 | self.num_batches_tracked += 1 61 | if self.momentum is None: # use cumulative moving average 62 | exponential_average_factor = 1.0 / self.num_batches_tracked.item() 63 | else: # use exponential moving average 64 | exponential_average_factor = self.momentum 65 | 66 | return F.batch_norm( 67 | input, self.running_mean, self.running_var, self.weight, self.bias, 68 | self.training or not self.track_running_stats, 69 | exponential_average_factor, self.eps) 70 | 71 | def extra_repr(self): 72 | return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \ 73 | 'track_running_stats={track_running_stats}'.format(**self.__dict__) 74 | 75 | def _load_from_state_dict(self, state_dict, prefix, metadata, strict, 76 | missing_keys, unexpected_keys, error_msgs): 77 | version = metadata.get('version', None) 78 | 79 | if (version is None or version < 2) and self.track_running_stats: 80 | # at version 2: added num_batches_tracked buffer 81 | # this should have a default value of 0 82 | num_batches_tracked_key = prefix + 'num_batches_tracked' 83 | if num_batches_tracked_key not in state_dict: 84 | state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) 85 | 86 | super(_BatchNorm, self)._load_from_state_dict( 87 | state_dict, prefix, metadata, strict, 88 | missing_keys, unexpected_keys, error_msgs) 89 | 90 | 91 | class BatchNorm1d(_BatchNorm): 92 | r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D 93 | inputs with optional additional channel dimension) as described in the paper 94 | `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ . 95 | 96 | .. 
math:: 97 | 98 | y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 99 | 100 | The mean and standard-deviation are calculated per-dimension over 101 | the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors 102 | of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled 103 | from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0. 104 | 105 | Also by default, during training this layer keeps running estimates of its 106 | computed mean and variance, which are then used for normalization during 107 | evaluation. The running estimates are kept with a default :attr:`momentum` 108 | of 0.1. 109 | 110 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 111 | keep running estimates, and batch statistics are instead used during 112 | evaluation time as well. 113 | 114 | .. note:: 115 | This :attr:`momentum` argument is different from one used in optimizer 116 | classes and the conventional notion of momentum. Mathematically, the 117 | update rule for running statistics here is 118 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momemtum} \times x_t`, 119 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 120 | new observed value. 121 | 122 | Because the Batch Normalization is done over the `C` dimension, computing statistics 123 | on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization. 124 | 125 | Args: 126 | num_features: :math:`C` from an expected input of size 127 | :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` 128 | eps: a value added to the denominator for numerical stability. 129 | Default: 1e-5 130 | momentum: the value used for the running_mean and running_var 131 | computation. Can be set to ``None`` for cumulative moving average 132 | (i.e. simple average). Default: 0.1 133 | affine: a boolean value that when set to ``True``, this module has 134 | learnable affine parameters. Default: ``True`` 135 | track_running_stats: a boolean value that when set to ``True``, this 136 | module tracks the running mean and variance, and when set to ``False``, 137 | this module does not track such statistics and always uses batch 138 | statistics in both training and eval modes. Default: ``True`` 139 | 140 | Shape: 141 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 142 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 143 | 144 | Examples:: 145 | 146 | >>> # With Learnable Parameters 147 | >>> m = nn.BatchNorm1d(100) 148 | >>> # Without Learnable Parameters 149 | >>> m = nn.BatchNorm1d(100, affine=False) 150 | >>> input = torch.randn(20, 100) 151 | >>> output = m(input) 152 | 153 | .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`: 154 | https://arxiv.org/abs/1502.03167 155 | """ 156 | 157 | def _check_input_dim(self, input): 158 | if input.dim() != 2 and input.dim() != 3: 159 | raise ValueError('expected 2D or 3D input (got {}D input)' 160 | .format(input.dim())) 161 | 162 | 163 | class BatchNorm2d(_BatchNorm): 164 | r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs 165 | with additional channel dimension) as described in the paper 166 | `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ . 167 | 168 | .. 
math:: 169 | 170 | y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 171 | 172 | The mean and standard-deviation are calculated per-dimension over 173 | the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors 174 | of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled 175 | from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0. 176 | 177 | Also by default, during training this layer keeps running estimates of its 178 | computed mean and variance, which are then used for normalization during 179 | evaluation. The running estimates are kept with a default :attr:`momentum` 180 | of 0.1. 181 | 182 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 183 | keep running estimates, and batch statistics are instead used during 184 | evaluation time as well. 185 | 186 | .. note:: 187 | This :attr:`momentum` argument is different from one used in optimizer 188 | classes and the conventional notion of momentum. Mathematically, the 189 | update rule for running statistics here is 190 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momemtum} \times x_t`, 191 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 192 | new observed value. 193 | 194 | Because the Batch Normalization is done over the `C` dimension, computing statistics 195 | on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization. 196 | 197 | Args: 198 | num_features: :math:`C` from an expected input of size 199 | :math:`(N, C, H, W)` 200 | eps: a value added to the denominator for numerical stability. 201 | Default: 1e-5 202 | momentum: the value used for the running_mean and running_var 203 | computation. Can be set to ``None`` for cumulative moving average 204 | (i.e. simple average). Default: 0.1 205 | affine: a boolean value that when set to ``True``, this module has 206 | learnable affine parameters. Default: ``True`` 207 | track_running_stats: a boolean value that when set to ``True``, this 208 | module tracks the running mean and variance, and when set to ``False``, 209 | this module does not track such statistics and always uses batch 210 | statistics in both training and eval modes. Default: ``True`` 211 | 212 | Shape: 213 | - Input: :math:`(N, C, H, W)` 214 | - Output: :math:`(N, C, H, W)` (same shape as input) 215 | 216 | Examples:: 217 | 218 | >>> # With Learnable Parameters 219 | >>> m = nn.BatchNorm2d(100) 220 | >>> # Without Learnable Parameters 221 | >>> m = nn.BatchNorm2d(100, affine=False) 222 | >>> input = torch.randn(20, 100, 35, 45) 223 | >>> output = m(input) 224 | 225 | .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`: 226 | https://arxiv.org/abs/1502.03167 227 | """ 228 | 229 | def _check_input_dim(self, input): 230 | if input.dim() != 4: 231 | raise ValueError('expected 4D input (got {}D input)' 232 | .format(input.dim())) 233 | 234 | 235 | class BatchNorm3d(_BatchNorm): 236 | r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs 237 | with additional channel dimension) as described in the paper 238 | `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ . 239 | 240 | .. 
math:: 241 | 242 | y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 243 | 244 | The mean and standard-deviation are calculated per-dimension over 245 | the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors 246 | of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled 247 | from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0. 248 | 249 | Also by default, during training this layer keeps running estimates of its 250 | computed mean and variance, which are then used for normalization during 251 | evaluation. The running estimates are kept with a default :attr:`momentum` 252 | of 0.1. 253 | 254 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 255 | keep running estimates, and batch statistics are instead used during 256 | evaluation time as well. 257 | 258 | .. note:: 259 | This :attr:`momentum` argument is different from one used in optimizer 260 | classes and the conventional notion of momentum. Mathematically, the 261 | update rule for running statistics here is 262 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momemtum} \times x_t`, 263 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 264 | new observed value. 265 | 266 | Because the Batch Normalization is done over the `C` dimension, computing statistics 267 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization 268 | or Spatio-temporal Batch Normalization. 269 | 270 | Args: 271 | num_features: :math:`C` from an expected input of size 272 | :math:`(N, C, D, H, W)` 273 | eps: a value added to the denominator for numerical stability. 274 | Default: 1e-5 275 | momentum: the value used for the running_mean and running_var 276 | computation. Can be set to ``None`` for cumulative moving average 277 | (i.e. simple average). Default: 0.1 278 | affine: a boolean value that when set to ``True``, this module has 279 | learnable affine parameters. Default: ``True`` 280 | track_running_stats: a boolean value that when set to ``True``, this 281 | module tracks the running mean and variance, and when set to ``False``, 282 | this module does not track such statistics and always uses batch 283 | statistics in both training and eval modes. Default: ``True`` 284 | 285 | Shape: 286 | - Input: :math:`(N, C, D, H, W)` 287 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 288 | 289 | Examples:: 290 | 291 | >>> # With Learnable Parameters 292 | >>> m = nn.BatchNorm3d(100) 293 | >>> # Without Learnable Parameters 294 | >>> m = nn.BatchNorm3d(100, affine=False) 295 | >>> input = torch.randn(20, 100, 35, 45, 10) 296 | >>> output = m(input) 297 | 298 | .. 
_`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`: 299 | https://arxiv.org/abs/1502.03167 300 | """ 301 | 302 | def _check_input_dim(self, input): 303 | if input.dim() != 5: 304 | raise ValueError('expected 5D input (got {}D input)' 305 | .format(input.dim())) -------------------------------------------------------------------------------- /usps_mnist.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | sys.path.append('utils') 5 | import argparse 6 | import gzip 7 | import os 8 | import pickle 9 | import urllib.request 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | from PIL import Image 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.optim as optim 18 | import torch.utils.data as data 19 | from torch.optim import lr_scheduler 20 | from torchvision import datasets, transforms 21 | 22 | from whitening import WTransform2d 23 | 24 | usps_dataset_multiplier = 6 25 | 26 | class USPS(data.Dataset): 27 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl" 28 | 29 | def __init__(self, root, train=True, transform=None, transform_aug=None, download=False): 30 | """Initialize the USPS dataset.""" 31 | self.root = os.path.expanduser(root) 32 | self.filename = "usps_28x28.pkl" 33 | self.train = train 34 | # Num of train = 7438, Num of test = 1860 35 | self.transform = transform 36 | self.transform_aug = transform_aug 37 | self.dataset_size = None 38 | 39 | if download: 40 | self.download() 41 | 42 | if not self._check_exists(): 43 | raise RuntimeError("Dataset not found." + 44 | " You can use download=True to download it.") 45 | self.train_data, self.train_labels = self.load_samples() 46 | 47 | if self.train: 48 | self.train_data = np.repeat(self.train_data, usps_dataset_multiplier, axis=0) 49 | self.train_labels = np.repeat(self.train_labels, usps_dataset_multiplier, axis=0) 50 | 51 | total_num_samples = self.train_labels.shape[0] 52 | indices = np.arange(total_num_samples) 53 | np.random.shuffle(indices) 54 | self.train_data = self.train_data[indices[0: usps_dataset_multiplier * self.dataset_size], ::] 55 | self.train_labels = self.train_labels[indices[0: usps_dataset_multiplier * self.dataset_size]] 56 | 57 | # self.train_data *= 255.0 58 | self.train_data = self.train_data.transpose((0, 2, 3, 1)) # NCHW -> NHWC 59 | 60 | def __getitem__(self, index): 61 | """ Get images and target labels for the data loader 62 | 63 | Args: 64 | index (int): Index 65 | Returns: 66 | tuple (image, target): where target is the index of the class 67 | """ 68 | 69 | img, label = self.train_data[index], self.train_labels[index] 70 | 71 | if self.transform_aug is not None: 72 | img_aug = self.transform_aug(img) 73 | 74 | if self.transform is not None: 75 | img = self.transform(img) 76 | 77 | label = torch.squeeze(torch.LongTensor([np.int64(label).item()])) 78 | 79 | if self.transform_aug is not None: 80 | return img, img_aug, label 81 | else: 82 | return img, label 83 | 84 | def __len__(self): 85 | """ Return the size of the dataset """ 86 | if self.train: 87 | return usps_dataset_multiplier * self.dataset_size 88 | else: 89 | return self.dataset_size 90 | 91 | def _check_exists(self): 92 | return os.path.exists(os.path.join(self.root, self.filename)) 93 | 94 | def download(self): 95 | filename = os.path.join(self.root, self.filename) 96 | dirname = os.path.dirname(filename)
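# create the data directory on first use; skip the download when the pickle is already cached locally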
97 | if not os.path.isdir(dirname): 98 | os.makedirs(dirname) 99 | if os.path.isfile(filename): 100 | return 101 | print("Download %s to %s" % (self.url, os.path.abspath(filename))) 102 | urllib.request.urlretrieve(self.url, filename) 103 | print("Done") 104 | return 105 | 106 | def load_samples(self): 107 | filename = os.path.join(self.root, self.filename) 108 | f = gzip.open(filename, "rb") 109 | data_set = pickle.load(f, encoding="bytes") 110 | f.close() 111 | 112 | if self.train: 113 | images = data_set[0][0] 114 | labels = data_set[0][1] 115 | self.dataset_size = labels.shape[0] 116 | else: 117 | images = data_set[1][0] 118 | labels = data_set[1][1] 119 | self.dataset_size = labels.shape[0] 120 | return images, labels 121 | 122 | 123 | class MNIST(data.Dataset): 124 | """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset. 125 | Args: 126 | root (string): Root directory of dataset where ``processed/training.pt`` 127 | and ``processed/test.pt`` exist. The files must already be present, 128 | since (unlike the torchvision version) this variant does not download them. 129 | train (bool, optional): If True, creates dataset from ``training.pt``, 130 | otherwise from ``test.pt``. 131 | transform (callable, optional): A function/transform that takes in a PIL image 132 | and returns a transformed version. E.g., ``transforms.RandomCrop`` 133 | transform_aug (callable, optional): A second function/transform that takes in 134 | the same PIL image and returns an augmented view; when set, 135 | ``__getitem__`` returns (img, img_aug, target_label). 136 | """ 137 | 138 | 139 | processed_folder = 'processed' 140 | training_file = 'training.pt' 141 | test_file = 'test.pt' 142 | 143 | def __init__(self, root, train=True, transform=None, transform_aug=None): 144 | self.root = os.path.expanduser(root) 145 | self.transform = transform 146 | self.transform_aug = transform_aug 147 | self.train = train # training set or test set 148 | 149 | if self.train: 150 | data_file = self.training_file 151 | else: 152 | data_file = self.test_file 153 | self.data, self.targets = torch.load(os.path.join(self.root, self.processed_folder, data_file)) 154 | 155 | def __getitem__(self, index): 156 | """ 157 | Args: 158 | index (int): Index 159 | Returns: 160 | tuple: (image, target_label), or (image, image_aug, target_label) when 161 | ``transform_aug`` is set; target_label is the index of the target class 162 | """ 163 | img, target_label = self.data[index], self.targets[index] 164 | 165 | # doing this so that it is consistent with all other datasets 166 | # to return a PIL Image 167 | img = Image.fromarray(img.numpy(), mode='L') 168 | 169 | if self.transform_aug is not None: 170 | img_aug = self.transform_aug(img) 171 | 172 | if self.transform is not None: 173 | img = self.transform(img) 174 | 175 | if self.transform_aug is not None: 176 | return img, img_aug, target_label 177 | else: 178 | return img, target_label 179 | 180 | def __len__(self): 181 | return len(self.data) 182 | 183 | class EntropyLoss(nn.Module): 184 | ''' Module to compute the entropy loss ''' 185 | def __init__(self): 186 | super(EntropyLoss, self).__init__() 187 | 188 | def forward(self, x): 189 | p = F.softmax(x, dim=1) 190 | q = F.log_softmax(x, dim=1) 191 | b = p * q 192 | b = -1.0 * b.sum(-1).mean() 193 | #b = -1.0 * b.sum() 194 | return b 195 | 196 | class LeNet(nn.Module): 197 | def __init__(self, group_size): 198 | super(LeNet, self).__init__() 199 | self.conv1 = nn.Conv2d(1, 32, kernel_size=5, padding=2) 200 | self.ws1 = WTransform2d(num_features=32, group_size=group_size) 201 | self.wt1 = WTransform2d(num_features=32,
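# target-domain whitening branch: ws1/ws2 normalize the source half of the mini-batch, wt1/wt2 the target half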
group_size=group_size) 202 | self.gamma1 = nn.Parameter(torch.ones(32, 1, 1)) 203 | self.beta1 = nn.Parameter(torch.zeros(32, 1, 1)) 204 | #self.conv1_drop = nn.Dropout2d() 205 | 206 | self.conv2 = nn.Conv2d(32, 48, kernel_size=5, padding=2) 207 | self.ws2 = WTransform2d(num_features=48, group_size=group_size) 208 | self.wt2 = WTransform2d(num_features=48, group_size=group_size) 209 | self.gamma2 = nn.Parameter(torch.ones(48, 1, 1)) 210 | self.beta2 = nn.Parameter(torch.zeros(48, 1, 1)) 211 | #self.conv2_drop = nn.Dropout2d() 212 | 213 | self.fc3 = nn.Linear(2352, 100) 214 | self.bns3 = nn.BatchNorm1d(100, affine=False) 215 | self.bnt3 = nn.BatchNorm1d(100, affine=False) 216 | self.gamma3 = nn.Parameter(torch.ones(1, 100)) 217 | self.beta3 = nn.Parameter(torch.zeros(1, 100)) 218 | 219 | self.fc4 = nn.Linear(100, 100) 220 | self.bns4 = nn.BatchNorm1d(100, affine=False) 221 | self.bnt4 = nn.BatchNorm1d(100, affine=False) 222 | self.gamma4 = nn.Parameter(torch.ones(1, 100)) 223 | self.beta4 = nn.Parameter(torch.zeros(1, 100)) 224 | 225 | self.fc5 = nn.Linear(100, 10) 226 | self.bns5 = nn.BatchNorm1d(10, affine=False) 227 | self.bnt5 = nn.BatchNorm1d(10, affine=False) 228 | self.gamma5 = nn.Parameter(torch.ones(1, 10)) 229 | self.beta5 = nn.Parameter(torch.zeros(1, 10)) 230 | 231 | def forward(self, x): 232 | 233 | if self.training: 234 | x = self.conv1(x) 235 | x_source, x_target = torch.split(x, split_size_or_sections=x.shape[0] // 2, dim=0) 236 | #x = self.conv1_drop(F.max_pool2d(F.relu(torch.cat((self.ws1(x_source), self.wt1(x_target)), dim=0)*self.gamma1 + self.beta1), kernel_size=2, stride=2)) 237 | x = F.max_pool2d(F.relu(torch.cat((self.ws1(x_source), self.wt1(x_target)), dim=0)*self.gamma1 + self.beta1), kernel_size=2, stride=2) 238 | 239 | x = self.conv2(x) 240 | x_source, x_target = torch.split(x, split_size_or_sections=x.shape[0] // 2, dim=0) 241 | #x = self.conv2_drop(F.max_pool2d(F.relu(torch.cat((self.ws2(x_source), self.wt2(x_target)), dim=0)*self.gamma2 + self.beta2), kernel_size=2, stride=2)) 242 | x = F.max_pool2d(F.relu(torch.cat((self.ws2(x_source), self.wt2(x_target)), dim=0)*self.gamma2 + self.beta2), kernel_size=2, stride=2) 243 | 244 | x = x.view(x.shape[0], -1) 245 | x = self.fc3(x) 246 | x_source, x_target = torch.split(x, split_size_or_sections=x.shape[0] // 2, dim=0) 247 | #x = F.dropout(F.relu(torch.cat((self.bns3(x_source), self.bnt3(x_target)), dim=0)*self.gamma3 + self.beta3), training=self.training) 248 | x = F.relu(torch.cat((self.bns3(x_source), self.bnt3(x_target)), dim=0)*self.gamma3 + self.beta3) 249 | 250 | x = self.fc4(x) 251 | x_source, x_target = torch.split(x, split_size_or_sections=x.shape[0] // 2, dim=0) 252 | #x = F.dropout(F.relu(torch.cat((self.bns4(x_source), self.bnt4(x_target)), dim=0)*self.gamma4 + self.beta4), training=self.training) 253 | x = F.relu(torch.cat((self.bns4(x_source), self.bnt4(x_target)), dim=0)*self.gamma4 + self.beta4) 254 | 255 | x = self.fc5(x) 256 | x_source, x_target = torch.split(x, split_size_or_sections=x.shape[0] // 2, dim=0) 257 | x = torch.cat((self.bns5(x_source), self.bnt5(x_target)), dim=0)*self.gamma5 + self.beta5 258 | else: 259 | x = self.conv1(x) 260 | #x = self.conv1_drop(F.max_pool2d(F.relu(self.wt1(x)*self.gamma1 + self.beta1), kernel_size=2, stride=2)) 261 | x = F.max_pool2d(F.relu(self.wt1(x)*self.gamma1 + self.beta1), kernel_size=2, stride=2) 262 | 263 | x = self.conv2(x) 264 | #x = self.conv2_drop(F.max_pool2d(F.relu(self.wt2(x)*self.gamma2 + self.beta2), kernel_size=2, stride=2)) 265 | x = 
F.max_pool2d(F.relu(self.wt2(x)*self.gamma2 + self.beta2), kernel_size=2, stride=2) 266 | 267 | x = x.view(x.shape[0], -1) 268 | x = self.fc3(x) 269 | #x = F.dropout(F.relu(self.bnt3(x)*self.gamma3 + self.beta3), training=self.training) 270 | x = F.relu(self.bnt3(x)*self.gamma3 + self.beta3) 271 | 272 | x = self.fc4(x) 273 | #x = F.dropout(F.relu(self.bnt4(x)*self.gamma4 + self.beta4), training=self.training) 274 | x = F.relu(self.bnt4(x)*self.gamma4 + self.beta4) 275 | 276 | x = self.fc5(x) 277 | x = self.bnt5(x)*self.gamma5 + self.beta5 278 | return x 279 | 280 | 281 | def train(args, model, device, source_train_loader, target_train_loader, optimizer, epoch, lambda_entropy_loss): 282 | model.train() 283 | for batch_idx, (source, target) in enumerate(zip(source_train_loader, target_train_loader)): 284 | source_data = source[0] 285 | source_y = source[1] 286 | target_data = target[0] 287 | 288 | data = torch.cat((source_data, target_data), dim=0) # concat the source and target mini-batches 289 | data, source_y = data.to(device), source_y.to(device) 290 | 291 | optimizer.zero_grad() 292 | output = model(data) 293 | 294 | source_output, target_output = torch.split(output, split_size_or_sections=output.shape[0] // 2, dim=0) 295 | 296 | entropy_criterion = EntropyLoss() 297 | 298 | cls_loss = F.nll_loss(F.log_softmax(source_output, dim=1), source_y) 299 | entropy_l = lambda_entropy_loss * entropy_criterion(target_output) 300 | 301 | loss = cls_loss + entropy_l 302 | loss.backward() 303 | optimizer.step() 304 | 305 | if batch_idx % args.log_interval == 0: 306 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tClassification Loss: {:.6f} \tEntropy Loss: {:.6f}'.format( 307 | epoch, batch_idx * len(target_data), len(source_train_loader.dataset), 308 | 100. * batch_idx / len(source_train_loader), cls_loss.item(), entropy_l.item())) 309 | 310 | def test(args, model, device, target_test_loader): 311 | model.eval() 312 | test_cls_loss = 0. 313 | correct = 0 314 | with torch.no_grad(): 315 | for data, target in target_test_loader: 316 | data, target = data.to(device), target.to(device) 317 | output = model(data) 318 | test_cls_loss += F.nll_loss(F.log_softmax(output, dim=1), target, reduction='sum').item() 319 | pred = F.softmax(output, dim=1).max(1, keepdim=True)[1] # get the index of the max probability 320 | correct += pred.eq(target.view_as(pred)).sum().item() 321 | 322 | test_cls_loss /= len(target_test_loader.dataset) 323 | print('\nTest set: Classification loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( 324 | test_cls_loss, correct, len(target_test_loader.dataset), 325 | 100. * correct / len(target_test_loader.dataset))) 326 | 327 | return 100. * correct / len(target_test_loader.dataset)
328 | 329 | def main(): 330 | # Training settings 331 | parser = argparse.ArgumentParser(description='PyTorch DIAL example') 332 | parser.add_argument('--num_workers', default=2, type=int) 333 | parser.add_argument('--source_batch_size', type=int, default=32, help='input source batch size for training (default: 32)') 334 | parser.add_argument('--target_batch_size', type=int, default=32, help='input target batch size for training (default: 32)') 335 | parser.add_argument('--test_batch_size', type=int, default=100, help='input batch size for testing (default: 100)') 336 | parser.add_argument('--source', type=str, default='usps', help='source dataset name') 337 | parser.add_argument('--target', type=str, default='mnist', help='target dataset name') 338 | parser.add_argument('--epochs', type=int, default=120, help='number of epochs to train (default: 120)') 339 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)') 340 | parser.add_argument('--sgd_momentum', type=float, default=0.5, help='SGD momentum (default: 0.5)') 341 | parser.add_argument('--running_momentum', type=float, default=0.1, help='Running momentum for statistics (default: 0.1)') 342 | parser.add_argument('--lambda_entropy_loss', type=float, default=0.1, help='Value of lambda for the entropy loss (default: 0.1)') 343 | parser.add_argument('--log_interval', type=int, default=100, help='how many batches to wait before logging training status') 344 | parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') 345 | parser.add_argument('--from_script', action='store_true', help="use this flag for bulk running from script") 346 | parser.add_argument('--run', default=0, type=int, help="use this flag for bulk running from script") 347 | parser.add_argument('--method', default='bn', help="use this flag for bulk running from script") 348 | parser.add_argument('--group_size', type=int, default=32, help='group size for the whitening matrix (default: 32)') 349 | args = parser.parse_args() 350 | assert args.source != args.target, "source and target datasets cannot be the same" 351 | torch.manual_seed(args.seed) 352 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 353 | 354 | """ MNIST train and test data loaders """ 355 | train_loader_mnist = torch.utils.data.DataLoader( 356 | MNIST('../data/mnist', train=True, 357 | transform=transforms.Compose([ 358 | transforms.ToTensor(), 359 | transforms.Normalize(mean=[0.1307], std=[0.3081]) 360 | ])), 361 | batch_size=args.source_batch_size, shuffle=True, num_workers=args.num_workers, drop_last=True) 362 | 363 | test_loader_mnist = torch.utils.data.DataLoader( 364 | MNIST('../data/mnist', train=False, 365 | transform=transforms.Compose([ 366 | transforms.ToTensor(), 367 | transforms.Normalize(mean=[0.1307], std=[0.3081]) 368 | ])), 369 | batch_size=args.test_batch_size, shuffle=True, num_workers=args.num_workers) 370 | 371 | """ USPS train and test data loaders """ 372 | train_loader_usps = torch.utils.data.DataLoader( 373 | USPS(root='../data/usps', train=True, 374 | transform=transforms.Compose([ 375 | transforms.ToTensor(), 376 | transforms.Normalize(mean=[0.5], std=[0.5]) 377 | ]) 378 | ,download=True), batch_size=args.target_batch_size, shuffle=True, num_workers=args.num_workers, drop_last=True) 379 | 380 | test_loader_usps = torch.utils.data.DataLoader( 381 | USPS(root='../data/usps', train=False, 382 | transform=transforms.Compose([ 383 | transforms.ToTensor(), 384 
| transforms.Normalize(mean=[0.5], std=[0.5]) 385 | ]) 386 | ,download=False), batch_size=args.test_batch_size, shuffle=True, num_workers=args.num_workers) 387 | 388 | model = LeNet(group_size=args.group_size).to(device) 389 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4) 390 | exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.1) 391 | 392 | if args.source == 'mnist' and args.target == 'usps': 393 | source_train_loader = train_loader_mnist 394 | target_train_loader = train_loader_usps 395 | test_loader = test_loader_usps 396 | elif args.source == 'usps' and args.target == 'mnist': 397 | source_train_loader = train_loader_usps 398 | target_train_loader = train_loader_mnist 399 | test_loader = test_loader_mnist 400 | 401 | for epoch in range(args.epochs): 402 | exp_lr_scheduler.step() 403 | train(args, model, device, source_train_loader, target_train_loader, optimizer, epoch, args.lambda_entropy_loss) 404 | test(args, model, device, test_loader) 405 | 406 | 407 | if __name__ == '__main__': 408 | main() -------------------------------------------------------------------------------- /resnet50_dwt_mec_officehome.py: -------------------------------------------------------------------------------- 1 | """ 2 | File modified from: 3 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | """ 5 | 6 | from __future__ import print_function 7 | 8 | import sys 9 | sys.path.append('utils') 10 | 11 | import argparse 12 | import os 13 | import numpy as np 14 | from PIL import Image 15 | import scipy.io as sio 16 | import cv2 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.utils.model_zoo as model_zoo 21 | import torch.optim as optim 22 | from torch.optim import lr_scheduler 23 | import torchvision 24 | import torch.nn.functional as F 25 | from torchvision import datasets, models, transforms 26 | 27 | import batch_norm 28 | import folder 29 | import consensus_loss 30 | import whitening 31 | 32 | def conv3x3(in_planes, out_planes, stride=1): 33 | """3x3 convolution with padding""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 35 | 36 | def conv1x1(in_planes, out_planes, stride=1): 37 | """1x1 convolution""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 39 | 40 | class whitening_scale_shift(nn.Module): 41 | def __init__(self, planes, group_size, running_mean, running_variance, track_running_stats=True, affine=True): 42 | super(whitening_scale_shift, self).__init__() 43 | self.planes = planes 44 | self.group_size = group_size 45 | self.track_running_stats = track_running_stats 46 | self.affine = affine 47 | self.running_mean = running_mean 48 | self.running_variance = running_variance 49 | 50 | self.wh = whitening.WTransform2d(self.planes, 51 | self.group_size, 52 | running_m=self.running_mean, 53 | running_var=self.running_variance, 54 | track_running_stats=self.track_running_stats) 55 | if self.affine: 56 | self.gamma = nn.Parameter(torch.ones(self.planes, 1, 1)) 57 | self.beta = nn.Parameter(torch.zeros(self.planes, 1, 1)) 58 | 59 | def forward(self, x): 60 | out = self.wh(x) 61 | if self.affine: 62 | out = out * self.gamma + self.beta 63 | return out 64 | 65 | 66 | class Bottleneck(nn.Module): 67 | expansion = 4 68 | 69 | def __init__(self, inplanes, planes, layer, sub_layer, bn_dict, group_size=4, stride=1, downsample=None): 70 | super(Bottleneck, self).__init__() 71 | self.expansion = 4 72 | self.conv1 = 
conv1x1(inplanes, planes) 73 | if layer == 1: 74 | self.bns1 = whitening_scale_shift(planes=planes, 75 | group_size=group_size, 76 | running_mean=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.wh.running_mean'], 77 | running_variance=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.wh.running_variance'], 78 | affine=False) 79 | self.bnt1 = whitening_scale_shift(planes=planes, 80 | group_size=group_size, 81 | running_mean=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.wh.running_mean'], 82 | running_variance=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.wh.running_variance'], 83 | affine=False) 84 | self.bnt1_aug = whitening_scale_shift(planes=planes, 85 | group_size=group_size, 86 | running_mean=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.wh.running_mean'], 87 | running_variance=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.wh.running_variance'], 88 | affine=False) 89 | self.gamma1 = nn.Parameter(bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.gamma']) 90 | self.beta1 = nn.Parameter(bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.beta']) 91 | else: 92 | self.bns1 = batch_norm.BatchNorm2d(num_features=planes, 93 | running_m=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.running_mean'], 94 | running_v=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.running_var'], 95 | affine=False) 96 | self.bnt1 = batch_norm.BatchNorm2d(num_features=planes, 97 | running_m=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.running_mean'], 98 | running_v=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.running_var'], 99 | affine=False) 100 | self.bnt1_aug = batch_norm.BatchNorm2d(num_features=planes, 101 | running_m=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.running_mean'], 102 | running_v=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.running_var'], 103 | affine=False) 104 | self.gamma1 = nn.Parameter(bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.weight'].view(-1, 1, 1)) 105 | self.beta1 = nn.Parameter(bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn1.bias'].view(-1, 1, 1)) 106 | 107 | self.conv2 = conv3x3(planes, planes, stride) 108 | if layer == 1: 109 | self.bns2 = whitening_scale_shift(planes=planes, 110 | group_size=group_size, 111 | running_mean=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.wh.running_mean'], 112 | running_variance=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.wh.running_variance'], 113 | affine=False) 114 | self.bnt2 = whitening_scale_shift(planes=planes, 115 | group_size=group_size, 116 | running_mean=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.wh.running_mean'], 117 | running_variance=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.wh.running_variance'], 118 | affine=False) 119 | self.bnt2_aug = whitening_scale_shift(planes=planes, 120 | group_size=group_size, 121 | running_mean=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.wh.running_mean'], 122 | running_variance=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.wh.running_variance'], 123 | affine=False) 124 | self.gamma2 = nn.Parameter(bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.gamma']) 125 | self.beta2 = nn.Parameter(bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.beta']) 126 | else: 127 | self.bns2 = batch_norm.BatchNorm2d(num_features=planes, 128 | running_m=bn_dict['layer' + str(layer) + '.' 
+ str(sub_layer) + '.bn2.running_mean'], 129 | running_v=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.running_var'], 130 | affine=False) 131 | self.bnt2 = batch_norm.BatchNorm2d(num_features=planes, 132 | running_m=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.running_mean'], 133 | running_v=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.running_var'], 134 | affine=False) 135 | self.bnt2_aug = batch_norm.BatchNorm2d(num_features=planes, 136 | running_m=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.running_mean'], 137 | running_v=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.running_var'], 138 | affine=False) 139 | self.gamma2 = nn.Parameter(bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.weight'].view(-1, 1, 1)) 140 | self.beta2 = nn.Parameter(bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn2.bias'].view(-1, 1, 1)) 141 | 142 | self.conv3 = conv1x1(planes, planes * self.expansion) 143 | if layer == 1: 144 | self.bns3 = whitening_scale_shift(planes=planes * self.expansion, 145 | group_size=group_size, 146 | running_mean=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.wh.running_mean'], 147 | running_variance=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.wh.running_variance'], 148 | affine=False) 149 | self.bnt3 = whitening_scale_shift(planes=planes * self.expansion, 150 | group_size=group_size, 151 | running_mean=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.wh.running_mean'], 152 | running_variance=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.wh.running_variance'], 153 | affine=False) 154 | self.bnt3_aug = whitening_scale_shift(planes=planes * self.expansion, 155 | group_size=group_size, 156 | running_mean=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.wh.running_mean'], 157 | running_variance=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.wh.running_variance'], 158 | affine=False) 159 | self.gamma3 = nn.Parameter(bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.gamma']) 160 | self.beta3 = nn.Parameter(bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.beta']) 161 | else: 162 | self.bns3 = batch_norm.BatchNorm2d(num_features=planes * self.expansion, 163 | running_m=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.running_mean'], 164 | running_v=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.running_var'], 165 | affine=False) 166 | self.bnt3 = batch_norm.BatchNorm2d(num_features=planes * self.expansion, 167 | running_m=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.running_mean'], 168 | running_v=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.running_var'], 169 | affine=False) 170 | self.bnt3_aug = batch_norm.BatchNorm2d(num_features=planes * self.expansion, 171 | running_m=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.running_mean'], 172 | running_v=bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.running_var'], 173 | affine=False) 174 | self.gamma3 = nn.Parameter(bn_dict['layer' + str(layer) + '.' + str(sub_layer) + '.bn3.weight'].view(-1, 1, 1)) 175 | self.beta3 = nn.Parameter(bn_dict['layer' + str(layer) + '.' 
+ str(sub_layer) + '.bn3.bias'].view(-1, 1, 1)) 176 | self.relu = nn.ReLU(inplace=True) 177 | self.downsample = downsample 178 | self.stride = stride 179 | 180 | if self.downsample is not None: 181 | if layer == 1: 182 | self.downsample_bns = whitening_scale_shift(planes=planes * self.expansion, 183 | group_size=group_size, 184 | running_mean=bn_dict['layer' + str(layer) + '.0.downsample_bn.wh.running_mean'], 185 | running_variance=bn_dict['layer' + str(layer) + '.0.downsample_bn.wh.running_variance'], 186 | affine=False) 187 | self.downsample_bnt = whitening_scale_shift(planes=planes * self.expansion, 188 | group_size=group_size, 189 | running_mean=bn_dict['layer' + str(layer) + '.0.downsample_bn.wh.running_mean'], 190 | running_variance=bn_dict['layer' + str(layer) + '.0.downsample_bn.wh.running_variance'], 191 | affine=False) 192 | self.downsample_bnt_aug = whitening_scale_shift(planes=planes * self.expansion, 193 | group_size=group_size, 194 | running_mean=bn_dict['layer' + str(layer) + '.0.downsample_bn.wh.running_mean'], 195 | running_variance=bn_dict['layer' + str(layer) + '.0.downsample_bn.wh.running_variance'], 196 | affine=False) 197 | self.downsample_gamma = nn.Parameter(bn_dict['layer' + str(layer) + '.0.downsample_bn.gamma']) 198 | self.downsample_beta = nn.Parameter(bn_dict['layer' + str(layer) + '.0.downsample_bn.beta']) 199 | else: 200 | self.downsample_bns = batch_norm.BatchNorm2d(num_features=planes * self.expansion, 201 | running_m=bn_dict['layer' + str(layer) + '.0.downsample_bn.running_mean'], 202 | running_v=bn_dict['layer' + str(layer) + '.0.downsample_bn.running_var'], 203 | affine=False) 204 | self.downsample_bnt = batch_norm.BatchNorm2d(num_features=planes * self.expansion, 205 | running_m=bn_dict['layer' + str(layer) + '.0.downsample_bn.running_mean'], 206 | running_v=bn_dict['layer' + str(layer) + '.0.downsample_bn.running_var'], 207 | affine=False) 208 | self.downsample_bnt_aug = batch_norm.BatchNorm2d(num_features=planes * self.expansion, 209 | running_m=bn_dict['layer' + str(layer) + '.0.downsample_bn.running_mean'], 210 | running_v=bn_dict['layer' + str(layer) + '.0.downsample_bn.running_var'], 211 | affine=False) 212 | self.downsample_gamma = nn.Parameter(bn_dict['layer' + str(layer) + '.0.downsample_bn.weight'].view(-1, 1, 1)) 213 | self.downsample_beta = nn.Parameter(bn_dict['layer' + str(layer) + '.0.downsample_bn.bias'].view(-1, 1, 1)) 214 | 215 | def forward(self, x): 216 | if self.training: 217 | # to do 218 | identity = x 219 | out = self.conv1(x) 220 | out_s, out_t, out_t_dup = torch.split(out, split_size_or_sections=out.shape[0] // 3, dim=0) 221 | out = torch.cat((self.bns1(out_s), torch.cat((self.bnt1(out_t), self.bnt1_aug(out_t_dup)), dim=0) ), dim=0) * self.gamma1 + self.beta1 222 | out = self.relu(out) 223 | 224 | out = self.conv2(out) 225 | out_s, out_t, out_t_dup = torch.split(out, split_size_or_sections=out.shape[0] // 3, dim=0) 226 | out = torch.cat((self.bns2(out_s), torch.cat((self.bnt2(out_t), self.bnt2_aug(out_t_dup)), dim=0) ), dim=0) * self.gamma2 + self.beta2 227 | out = self.relu(out) 228 | 229 | out = self.conv3(out) 230 | out_s, out_t, out_t_dup = torch.split(out, split_size_or_sections=out.shape[0] // 3, dim=0) 231 | out = torch.cat((self.bns3(out_s), torch.cat((self.bnt3(out_t), self.bnt3_aug(out_t_dup)), dim=0) ), dim=0) * self.gamma3 + self.beta3 232 | 233 | if self.downsample is not None: 234 | identity = self.downsample(x) 235 | identity_s, identity_t, identity_t_dup = torch.split(identity, 
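# the training mini-batch is stacked as [source | target | augmented target], so it is split into thirds and each third goes through its own normalization branch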
215 |     def forward(self, x):
216 |         if self.training:
217 |             # training mode: the batch is [source | target | augmented target]
218 |             identity = x
219 |             out = self.conv1(x)
220 |             out_s, out_t, out_t_dup = torch.split(out, split_size_or_sections=out.shape[0] // 3, dim=0)
221 |             out = torch.cat((self.bns1(out_s), torch.cat((self.bnt1(out_t), self.bnt1_aug(out_t_dup)), dim=0)), dim=0) * self.gamma1 + self.beta1
222 |             out = self.relu(out)
223 |
224 |             out = self.conv2(out)
225 |             out_s, out_t, out_t_dup = torch.split(out, split_size_or_sections=out.shape[0] // 3, dim=0)
226 |             out = torch.cat((self.bns2(out_s), torch.cat((self.bnt2(out_t), self.bnt2_aug(out_t_dup)), dim=0)), dim=0) * self.gamma2 + self.beta2
227 |             out = self.relu(out)
228 |
229 |             out = self.conv3(out)
230 |             out_s, out_t, out_t_dup = torch.split(out, split_size_or_sections=out.shape[0] // 3, dim=0)
231 |             out = torch.cat((self.bns3(out_s), torch.cat((self.bnt3(out_t), self.bnt3_aug(out_t_dup)), dim=0)), dim=0) * self.gamma3 + self.beta3
232 |
233 |             if self.downsample is not None:
234 |                 identity = self.downsample(x)
235 |                 identity_s, identity_t, identity_t_dup = torch.split(identity, split_size_or_sections=identity.shape[0] // 3, dim=0)
236 |                 identity = torch.cat((self.downsample_bns(identity_s),
237 |                                       torch.cat((self.downsample_bnt(identity_t), self.downsample_bnt_aug(identity_t_dup)), dim=0)), dim=0) * self.downsample_gamma + self.downsample_beta
238 |
239 |             out = out.clone() + identity
240 |             out = self.relu(out)
241 |         else:
242 |             identity = x
243 |
244 |             out = self.conv1(x)
245 |             out = self.bnt1(out) * self.gamma1 + self.beta1
246 |             out = self.relu(out)
247 |
248 |             out = self.conv2(out)
249 |             out = self.bnt2(out) * self.gamma2 + self.beta2
250 |             out = self.relu(out)
251 |
252 |             out = self.conv3(out)
253 |             out = self.bnt3(out) * self.gamma3 + self.beta3
254 |
255 |             if self.downsample is not None:
256 |                 identity = self.downsample(x)
257 |                 identity = self.downsample_bnt(identity) * self.downsample_gamma + self.downsample_beta
258 |
259 |             out = out.clone() + identity
260 |             out = self.relu(out)
261 |
262 |         return out
263 |
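# NOTE (illustrative sketch, not part of the original file): the training
# branch above assumes every mini-batch is the concatenation
# [source | target | augmented target] in three equal parts, e.g.:
#
#   batch = torch.cat((source_x, target_x, target_x_aug), dim=0)  # 3B images
#   s, t, t_aug = torch.split(batch, batch.shape[0] // 3, dim=0)  # B each
#
# Each third is normalized with its own domain-specific statistics and the
# results are concatenated back, so the convolutions still see one batch.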
264 | class ResNet(nn.Module):
265 |
266 |     def __init__(self, block, layers, state_dict, num_classes=65, zero_init_residual=False, group_size=4):
267 |         super(ResNet, self).__init__()
268 |         self.inplanes = 64
269 |         self.bn_dict = compute_bn_stats(state_dict)
270 |
271 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
272 |         self.bns1 = whitening_scale_shift(planes=64,
273 |                                           group_size=group_size,
274 |                                           running_mean=self.bn_dict['bn1.wh.running_mean'],
275 |                                           running_variance=self.bn_dict['bn1.wh.running_variance'],
276 |                                           affine=False)
277 |         self.bnt1 = whitening_scale_shift(planes=64,
278 |                                           group_size=group_size,
279 |                                           running_mean=self.bn_dict['bn1.wh.running_mean'],
280 |                                           running_variance=self.bn_dict['bn1.wh.running_variance'],
281 |                                           affine=False)
282 |         self.bnt1_aug = whitening_scale_shift(planes=64,
283 |                                               group_size=group_size,
284 |                                               running_mean=self.bn_dict['bn1.wh.running_mean'],
285 |                                               running_variance=self.bn_dict['bn1.wh.running_variance'],
286 |                                               affine=False)
287 |         self.gamma1 = nn.Parameter(self.bn_dict['bn1.gamma'])
288 |         self.beta1 = nn.Parameter(self.bn_dict['bn1.beta'])
289 |
290 |         self.relu = nn.ReLU(inplace=True)
291 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
292 |         self.layer1 = self._make_layer(block, 64, layers[0], self.bn_dict, layer=1)
293 |         self.layer2 = self._make_layer(block, 128, layers[1], self.bn_dict, stride=2, layer=2)
294 |         self.layer3 = self._make_layer(block, 256, layers[2], self.bn_dict, stride=2, layer=3)
295 |         self.layer4 = self._make_layer(block, 512, layers[3], self.bn_dict, stride=2, layer=4)
296 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
297 |         self.fc_out = nn.Linear(512 * block.expansion, num_classes)
298 |
299 |         for m in self.modules():
300 |             if isinstance(m, nn.Conv2d):
301 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
302 |             elif isinstance(m, nn.BatchNorm2d):
303 |                 nn.init.constant_(m.weight, 1)
304 |                 nn.init.constant_(m.bias, 0)
305 |
306 |         # Zero-initialize the last BN in each residual branch,
307 |         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
308 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
309 |         if zero_init_residual:  # note: the Bottleneck above names its last norm layers bns3/bnt3, so this branch would need adapting before use
310 |             for m in self.modules():
311 |                 if isinstance(m, Bottleneck):
312 |                     nn.init.constant_(m.bn3.weight, 0)
313 |                 elif isinstance(m, BasicBlock):
314 |                     nn.init.constant_(m.bn2.weight, 0)
315 |
316 |     def _make_layer(self, block, planes, blocks, bn_dict, layer=1, group_size=4, stride=1):
317 |         downsample = None
318 |         if stride != 1 or self.inplanes != planes * block.expansion:
319 |             downsample = nn.Sequential(
320 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
321 |                 #nn.BatchNorm2d(planes * block.expansion),
322 |             )
323 |
324 |         layers = []
325 |         layers.append(block(self.inplanes, planes, layer, 0, bn_dict, group_size, stride, downsample))
326 |         self.inplanes = planes * block.expansion
327 |         for i in range(1, blocks):
328 |             layers.append(block(self.inplanes, planes, layer, i, bn_dict, group_size))
329 |
330 |         return nn.Sequential(*layers)
331 |
332 |     def forward(self, x):
333 |         if self.training:
334 |             x = self.conv1(x)
335 |             x_s, x_t, x_t_dup = torch.split(x, split_size_or_sections=x.shape[0] // 3, dim=0)
336 |             x = torch.cat((self.bns1(x_s), torch.cat((self.bnt1(x_t), self.bnt1_aug(x_t_dup)), dim=0)), dim=0) * self.gamma1 + self.beta1
337 |             x = self.relu(x)
338 |             x = self.maxpool(x)
339 |
340 |             x = self.layer1(x)
341 |             x = self.layer2(x)
342 |             x = self.layer3(x)
343 |             x = self.layer4(x)
344 |
345 |             x = self.avgpool(x)
346 |             x = x.view(x.size(0), -1)
347 |             x = self.fc_out(x)
348 |         else:
349 |             x = self.conv1(x)
350 |             x = self.bnt1(x) * self.gamma1 + self.beta1
351 |             x = self.relu(x)
352 |             x = self.maxpool(x)
353 |
354 |             x = self.layer1(x)
355 |             x = self.layer2(x)
356 |             x = self.layer3(x)
357 |             x = self.layer4(x)
358 |
359 |             x = self.avgpool(x)
360 |             x = x.view(x.size(0), -1)
361 |             x = self.fc_out(x)
362 |
363 |         return x
364 |
365 | def resnet50(weights_path, device):
366 |
367 |     state_dict_ = torch.load(weights_path, map_location=device)
368 |     state_dict_model = state_dict_['state_dict']
369 |
370 |     modified_state_dict = {}
371 |     for key in state_dict_model.keys():
372 |         mod_key = key[7:]  # strip the 'module.' prefix added by nn.DataParallel
373 |         modified_state_dict.update({mod_key: state_dict_model[key]})
374 |
375 |     model = ResNet(Bottleneck, [3, 4, 6, 3], modified_state_dict)
376 |     model.load_state_dict(modified_state_dict, strict=False)
377 |
378 |     return model
379 |
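# NOTE (illustrative usage, not part of the original file):
#
#   device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#   model = resnet50('../data/models/model_best_gr_4.pth.tar', device).to(device)
#
# strict=False in load_state_dict() tolerates the name mismatch between the
# checkpoint and the triplicated per-domain layers (bns*/bnt*/bnt*_aug),
# whose statistics are instead injected through bn_dict in the constructors.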
380 | def eval_pass_collect_stats(args, model, device, target_test_loader):
381 |     # Run a bunch of forward passes to collect the target statistics before evaluating on the test set
382 |     model.train(mode=True)
383 |     with torch.no_grad():
384 |         for i in range(10):
385 |             print("Pass {} ...".format(i))
386 |             for data, _ in target_test_loader:
387 |                 data = torch.cat((data, data, data), dim=0)  # the source statistics no longer matter once training is done
388 |                 data = data.to(device)
389 |                 output = model(data)
390 |
391 | def train_infinite_collect_stats(args, model, device, source_train_loader,
392 |                                  target_train_loader, optimizer, lambda_mec_loss,
393 |                                  target_test_loader):
394 |
395 |     source_iter = iter(source_train_loader)
396 |     target_iter = iter(target_train_loader)
397 |
398 |     exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[6000], gamma=0.1)
399 |
400 |     for i in range(args.num_iters):
401 |         model.train()
402 |
403 |         exp_lr_scheduler.step()  # note: from PyTorch 1.1 onwards, scheduler.step() should be called after optimizer.step()
404 |         try:
405 |             source_data, source_y = next(source_iter)
406 |         except StopIteration:  # restart the source loader once it is exhausted
407 |             source_iter = iter(source_train_loader)
408 |             source_data, source_y = next(source_iter)
409 |
410 |         try:
411 |             target_data, target_data_dup, _ = next(target_iter)
412 |         except StopIteration:  # restart the target loader once it is exhausted
413 |             target_iter = iter(target_train_loader)
414 |             target_data, target_data_dup, _ = next(target_iter)
415 |
416 |         data = torch.cat((source_data, target_data, target_data_dup), dim=0)  # concat the source and target mini-batches
417 |         data, source_y = data.to(device), source_y.to(device)
418 |
419 |         optimizer.zero_grad()
420 |         output = model(data)
421 |         source_output, target_output, target_output_dup = torch.split(output, split_size_or_sections=output.shape[0] // 3, dim=0)
422 |
423 |         mec_criterion = consensus_loss.MinEntropyConsensusLoss(num_classes=args.num_classes, device=device)
424 |
425 |         cls_loss = F.nll_loss(F.log_softmax(source_output, dim=1), source_y)
426 |         mec_loss = lambda_mec_loss * mec_criterion(target_output, target_output_dup)
427 |
428 |         loss = cls_loss + mec_loss
429 |         loss.backward()
430 |
431 |         optimizer.step()
432 |
433 |         if i % args.log_interval == 0:
434 |             print('Train Iter: [{}/{}]\tClassification Loss: {:.6f} \t MEC Loss: {:.6f}'.format(
435 |                 i, args.num_iters, cls_loss.item(), mec_loss.item()
436 |             ))
437 |
438 |         if (i + 1) % args.check_acc_step == 0:
439 |             test(args, model, device, target_test_loader)
440 |
441 |     print("Training is complete...")
442 |     print("Running a bunch of forward passes to estimate the population statistics of target...")
443 |     eval_pass_collect_stats(args, model, device, target_test_loader)
444 |     print("Finally computing the precision on the test set...")
445 |     test(args, model, device, target_test_loader)
446 |
447 | def test(args, model, device, target_test_loader):
448 |     model.eval()
449 |     test_loss = 0.
450 |     correct = 0
451 |     with torch.no_grad():
452 |         for data, target in target_test_loader:
453 |             data, target = data.to(device), target.to(device)
454 |             output = model(data)
455 |             test_loss += F.nll_loss(F.log_softmax(output, dim=1), target, reduction='sum').item()
456 |             pred = F.softmax(output, dim=1).max(1, keepdim=True)[1]  # index of the class with the highest probability
457 |             correct += pred.eq(target.view_as(pred)).sum().item()
458 |
459 |     test_loss /= len(target_test_loader.dataset)
460 |     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
461 |         test_loss, correct, len(target_test_loader.dataset),
462 |         100. * correct / len(target_test_loader.dataset)))
463 |
464 |     return 100. * correct / len(target_test_loader.dataset)
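# NOTE (illustrative, not part of the original file): the evaluation protocol
# wired together above is, in call order,
#
#   eval_pass_collect_stats(args, model, device, target_test_loader)  # refresh target running stats in train mode, no grad
#   test(args, model, device, target_test_loader)                     # then score with model.eval()
#
# i.e. the target population statistics are re-estimated before the final test.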
465 |
466 | def compute_bn_stats(state_dict):
467 |     #state_dict = torch.load(path)  # e.g. '/home/sroy/.torch/models/resnet50-19c8e357.pth'
468 |
469 |     bn_key_names = []
470 |     for name, param in state_dict.items():
471 |         if name.find('bn') != -1:
472 |             bn_key_names.append(name)
473 |         elif name.find('downsample') != -1:
474 |             bn_key_names.append(name)
475 |
476 |     # keeping only the batch norm specific elements in the dictionary
477 |     bn_dict = {k: v for k, v in state_dict.items() if k in bn_key_names}
478 |
479 |     return bn_dict
480 |
481 | def _random_affine_augmentation(x):
482 |     M = np.float32([[1 + np.random.normal(0.0, 0.1), np.random.normal(0.0, 0.1), 0],
483 |                     [np.random.normal(0.0, 0.1), 1 + np.random.normal(0.0, 0.1), 0]])
484 |     rows, cols = x.shape[1:3]
485 |     dst = cv2.warpAffine(np.transpose(x.numpy(), [1, 2, 0]), M, (cols, rows))
486 |     dst = np.transpose(dst, [2, 0, 1])
487 |     return torch.from_numpy(dst)
488 |
489 | def _gaussian_blur(x, sigma=0.1):
490 |     ksize = int(sigma + 0.5) * 8 + 1  # always odd; evaluates to 1 for the default sigma=0.1
491 |     dst = cv2.GaussianBlur(x.numpy(), (ksize, ksize), sigma)
492 |     return torch.from_numpy(dst)
493 |
494 |
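# NOTE (illustrative usage, not part of the original file): given a CHW float
# tensor x as produced by transforms.ToTensor(), the two helpers compose as
#
#   x_aug = _gaussian_blur(_random_affine_augmentation(x))
#
# which is the order in which data_transform_dup below applies them to build
# the second, perturbed view of each target image.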
495 | def main():
496 |
497 |     # Training settings
498 |     parser = argparse.ArgumentParser(description='PyTorch DWT-MEC OfficeHome')
499 |     parser.add_argument('--num_workers', default=2, type=int)
500 |     parser.add_argument('--source_batch_size', type=int, default=18, help='input source batch size for training (default: 18)')
501 |     parser.add_argument('--target_batch_size', type=int, default=18, help='input target batch size for training (default: 18)')
502 |     parser.add_argument('--test_batch_size', type=int, default=10, help='input batch size for testing (default: 10)')
503 |     parser.add_argument('--s_dset_path', type=str, default='../data/OfficeHomeDataset_10072016/Art', help="The source dataset path")
504 |     parser.add_argument('--t_dset_path', type=str, default='../data/OfficeHomeDataset_10072016/Clipart', help="The target dataset path")
505 |     parser.add_argument('--resnet_path', type=str, default='../data/models/model_best_gr_4.pth.tar', help="The pre-trained model path")
506 |     parser.add_argument('--img_resize', type=int, default=256, help='size of the input image')
507 |     parser.add_argument('--img_crop_size', type=int, default=224, help='size of the cropped image')
508 |     parser.add_argument('--num_iters', type=int, default=10000, help='number of iterations to train (default: 10000)')
509 |     parser.add_argument('--check_acc_step', type=int, default=100, help='number of iteration steps between validation accuracy checks (default: 100)')
510 |     parser.add_argument('--lr_change_step', type=int, default=1000)  # note: currently unused; the scheduler hard-codes milestones=[6000]
511 |     parser.add_argument('--lr', type=float, default=1e-2, help='learning rate (default: 0.01)')
512 |     parser.add_argument('--num_classes', type=int, default=65, help='number of classes in the dataset')
513 |     parser.add_argument('--sgd_momentum', type=float, default=0.5, help='SGD momentum (default: 0.5)')  # note: currently unused; the optimizer below hard-codes momentum=0.9
514 |     parser.add_argument('--running_momentum', type=float, default=0.1, help='running momentum for the domain statistics (default: 0.1)')
515 |     parser.add_argument('--lambda_mec_loss', type=float, default=0.1, help='value of lambda for the MEC loss (default: 0.1)')
516 |     parser.add_argument('--log_interval', type=int, default=10, help='how many batches to wait before logging training status')
517 |     parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
518 |
519 |     args = parser.parse_args()
520 |
521 |     # set the seed
522 |     torch.manual_seed(args.seed)
523 |
524 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
525 |
526 |     # transformation for the source data, the clean target view and the test data
527 |     data_transform = transforms.Compose([
528 |         transforms.Resize((args.img_resize, args.img_resize)),  # resize before the random crop
529 |         transforms.RandomCrop((args.img_crop_size, args.img_crop_size)),
530 |         transforms.ToTensor(),
531 |         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
532 |     ])
533 |
534 |     # stronger augmentation for the duplicated target view
535 |     data_transform_dup = transforms.Compose([
536 |         transforms.Resize((args.img_resize, args.img_resize)),
537 |         transforms.RandomCrop((args.img_crop_size, args.img_crop_size)),
538 |         transforms.RandomHorizontalFlip(),
539 |         transforms.ToTensor(),
540 |         transforms.Lambda(lambda x: _random_affine_augmentation(x)),
541 |         transforms.Lambda(lambda x: _gaussian_blur(x)),
542 |         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
543 |     ])
544 |
545 |
546 |     # train data sets
547 |     source_dataset = folder.ImageFolder(root=args.s_dset_path,
548 |                                         transform=data_transform)
549 |     target_dataset = folder.ImageFolder(root=args.t_dset_path,
550 |                                         transform=data_transform,
551 |                                         transform_aug=data_transform_dup)
552 |
553 |     # test data sets
554 |     target_dataset_test = folder.ImageFolder(root=args.t_dset_path,
555 |                                              transform=data_transform)
556 |
557 |     # '''''''''''' Train loaders ''''''''''''''' #
558 |     source_trainloader = torch.utils.data.DataLoader(source_dataset,
559 |                                                      batch_size=args.source_batch_size,
560 |                                                      shuffle=True,
561 |                                                      num_workers=args.num_workers,
562 |                                                      drop_last=True)
563 |
564 |     target_trainloader = torch.utils.data.DataLoader(target_dataset,
565 |                                                      batch_size=args.source_batch_size,  # must match the source batch size: the model splits each joint batch into three equal parts
566 |                                                      shuffle=True,
567 |                                                      num_workers=args.num_workers,
568 |                                                      drop_last=True)
569 |
570 |     # '''''''''''' Test loader ''''''''''''''' #
571 |     target_testloader = torch.utils.data.DataLoader(target_dataset_test,
572 |                                                     batch_size=args.test_batch_size,
573 |                                                     shuffle=True,
574 |                                                     num_workers=args.num_workers)
575 |
576 |     model = resnet50(args.resnet_path, device).to(device)
577 |
578 |     final_layer_params = []
579 |     rest_of_the_net_params = []
580 |
581 |     for name, param in model.named_parameters():
582 |         if name.startswith('fc_out'):
583 |             final_layer_params.append(param)
584 |         else:
585 |             rest_of_the_net_params.append(param)
586 |
587 |     optimizer = optim.SGD([  # backbone at lr * 0.1, the new fc_out classifier at the full lr
588 |         {'params': rest_of_the_net_params},
589 |         {'params': final_layer_params, 'lr': args.lr}
590 |     ], lr=args.lr * 0.1, momentum=0.9, weight_decay=5e-4)
591 |
592 |
593 |     train_infinite_collect_stats(args=args,
594 |                                  model=model,
595 |                                  device=device,
596 |                                  source_train_loader=source_trainloader,
597 |                                  target_train_loader=target_trainloader,
598 |                                  optimizer=optimizer,
599 |                                  lambda_mec_loss=args.lambda_mec_loss,
600 |                                  target_test_loader=target_testloader)
601 |
602 | if __name__ == '__main__':
603 |     main()
--------------------------------------------------------------------------------