├── LICENSE ├── Readme.md ├── cache_feats.py ├── docs ├── README.md ├── _config.yml ├── _includes │ ├── google_analytics.html │ └── image_with_caption.html ├── _layouts │ ├── default.html │ └── project.html ├── assets │ └── images │ │ ├── CompRess_pseudo-coda.png │ │ ├── abilative.png │ │ ├── compare_graph.png │ │ ├── compress_poster_final.pdf │ │ ├── query_img_large.jpg │ │ ├── result_table1.png │ │ ├── result_table1_swav.png │ │ ├── result_table2.png │ │ ├── result_table3.png │ │ ├── result_table4.png │ │ ├── result_table5.png │ │ ├── results_table.png │ │ ├── teaser.gif │ │ └── teaser.png ├── bib.txt ├── index.html └── libs │ ├── custom │ ├── my_css.css │ └── my_js.js │ ├── external │ ├── font-awesome-4.7.0 │ │ ├── css │ │ │ ├── font-awesome.css │ │ │ └── font-awesome.min.css │ │ ├── fonts │ │ │ ├── FontAwesome.otf │ │ │ ├── fontawesome-webfont.eot │ │ │ ├── fontawesome-webfont.svg │ │ │ ├── fontawesome-webfont.ttf │ │ │ ├── fontawesome-webfont.woff │ │ │ └── fontawesome-webfont.woff2 │ │ ├── less │ │ │ ├── animated.less │ │ │ ├── bordered-pulled.less │ │ │ ├── core.less │ │ │ ├── fixed-width.less │ │ │ ├── font-awesome.less │ │ │ ├── icons.less │ │ │ ├── larger.less │ │ │ ├── list.less │ │ │ ├── mixins.less │ │ │ ├── path.less │ │ │ ├── rotated-flipped.less │ │ │ ├── screen-reader.less │ │ │ ├── stacked.less │ │ │ └── variables.less │ │ └── scss │ │ │ ├── _animated.scss │ │ │ ├── _bordered-pulled.scss │ │ │ ├── _core.scss │ │ │ ├── _fixed-width.scss │ │ │ ├── _icons.scss │ │ │ ├── _larger.scss │ │ │ ├── _list.scss │ │ │ ├── _mixins.scss │ │ │ ├── _path.scss │ │ │ ├── _rotated-flipped.scss │ │ │ ├── _screen-reader.scss │ │ │ ├── _stacked.scss │ │ │ ├── _variables.scss │ │ │ └── font-awesome.scss │ ├── jquery-3.1.1.min.js │ ├── skeleton │ │ ├── normalize.css │ │ └── skeleton.css │ ├── skeleton_tabs │ │ ├── skeleton-tabs.css │ │ └── skeleton-tabs.js │ └── timeline.css │ └── icon.png ├── eval_cluster_alignment.py ├── eval_knn.py ├── eval_linear.py ├── kmeans.py ├── models ├── alexnet.py ├── mobilenet.py ├── resnet.py └── resnet50x4.py ├── nn └── compress_loss.py ├── tools.py ├── train_kmeans.py ├── train_student.py ├── train_student_one_queue.py ├── train_student_without_momentum.py └── util.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 UMBC Vision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # CompRess: Self-Supervised Learning by Compressing Representations 5 | 6 |

7 | 8 |

9 | 10 | This repository is the official implementation of CompRess: Self-Supervised Learning by Compressing Representations. 11 | 12 | Project webpage: [https://umbcvision.github.io/CompRess/ 13 | ](https://umbcvision.github.io/CompRess/) 14 | 15 | ``` 16 | @Article{abbasi2020compress, 17 | author = {Koohpayegani, Soroush Abbasi and Tejankar, Ajinkya and Pirsiavash, Hamed}, 18 | title = {CompRess: Self-Supervised Learning by Compressing Representations}, 19 | journal = {Advances in neural information processing systems}, 20 | year = {2020}, 21 | } 22 | ``` 23 | 24 | [comment]: <> (📋Optional: include a graphic explaining your approach/main result, bibtex entry, link to demos, blog posts and tutorials) 25 | 26 | ## Requirements 27 | 28 | Install PyTorch and prepare the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet). We used Python 3.7 for our experiments. 29 | 30 | 31 | - Install PyTorch ([pytorch.org](http://pytorch.org)) 32 | 33 | 34 | To run the Nearest Neighbor (NN) and Cluster Alignment evaluations, you need to install FAISS. 35 | 36 | FAISS: 37 | - Install FAISS ([https://github.com/facebookresearch/faiss/blob/master/INSTALL.md](https://github.com/facebookresearch/faiss/blob/master/INSTALL.md)) 38 | 39 | 40 | 41 | 42 | 43 | [comment]: <> (📋Describe how to set up the environment, e.g. pip/conda/docker commands, download datasets, etc...) 44 | 45 | ## Training 46 | 47 | Our code is based on the unofficial MoCo implementation from [https://github.com/HobbitLong/CMC](https://github.com/HobbitLong/CMC). 48 | 49 | 50 | 51 | 52 | 53 | 54 | To train the student(s) using the pretrained teachers from the paper: 55 | 56 | 57 | Download the pretrained official MoCo ResNet50 model from [https://github.com/facebookresearch/moco](https://github.com/facebookresearch/moco). 58 | 59 | Then train the student using the pretrained model: 60 | 61 | ```train 62 | python train_student.py \ 63 | --teacher_arch resnet50 \ 64 | --teacher \ 65 | --student_arch mobilenet \ 66 | --checkpoint_path \ 67 | 68 | ``` 69 | To train the student(s) using the cached teachers from the paper: 70 | 71 | 72 | We converted the TensorFlow SimCLRv1 ResNet50x4 ([https://github.com/google-research/simclr](https://github.com/google-research/simclr)) to PyTorch. Optionally, you can download the pretrained SimCLR ResNet50x4 PyTorch model from [here](https://drive.google.com/file/d/1fZ2gfHRjVSFz9Hf2PHsPUao9ZKmUXg4z/view?usp=sharing). 73 | 74 | First, run this command to compute and store the cached features. 75 | ```train 76 | python cache_feats.py \ 77 | --weight \ 78 | --save \ 79 | --arch resnet50x4 \ 80 | --data_pre_processing SimCLR \ 81 | 82 | ``` 83 | 84 | 85 | Then train the student using the cached features: 86 | 87 | ```train 88 | python train_student.py \ 89 | --cache_teacher \ 90 | --teacher \ 91 | --student_arch mobilenet \ 92 | --checkpoint_path \ 93 | 94 | ``` 95 | 96 | To train the student(s) without the momentum framework, run train_student_without_momentum.py instead of train_student.py. 97 | 98 | [comment]: <> (📋Describe how to train the models, with example commands on how to train the models in your paper, including the full training procedure and appropriate hyperparameters.) 99 | ## Evaluation 100 | 101 | To run the Nearest Neighbor evaluation on ImageNet, run: 102 | 103 | ```eval 104 | python eval_knn.py \ 105 | --arch alexnet \ 106 | --weights \ 107 | --save \ 108 | 109 | ``` 110 | Note that the above command caches the features as well.
After the first run, you can add the "--load_cache" flag to load the cached features from a file. 111 | 112 | To run the Cluster Alignment evaluation on ImageNet, run: 113 | 114 | ```eval 115 | python eval_cluster_alignment.py \ 116 | --weights \ 117 | --arch resnet18 \ 118 | --save \ 119 | --visualization \ 120 | --confusion_matrix \ 121 | 122 | ``` 123 | 124 | 125 | To run the Linear Classifier evaluation on ImageNet, run: 126 | 127 | ```eval 128 | 129 | python eval_linear.py \ 130 | --arch alexnet \ 131 | --weights \ 132 | --save \ 133 | 134 | ``` 135 | 136 | 137 | 138 | 139 | 140 | ## Results 141 | 142 |

143 | 144 |

145 | 146 | "SOTA Self-Supervised" refers to SimCLR for RexNet50x4 and MoCo for all other architectures. 147 | 148 | Our model achieves the following performance on ImageNet: 149 | 150 | 151 | | Model name | Teacher | Top-1 Linear Classifier Accuracy | Top-1 Nearest Neighbor Accuracy | Top-1 Cluster Alignment Accuracy| Pre-trained | 152 | | ------------------ | --------- |----------------------------------| ----------------- | ------- | ----------------- | 153 | | CompRess(Resnet50) | SimCLR ResNet50x4(cached) | 71.6% | 63.4% | 42.0% | [Pre-trained Resnet50](https://drive.google.com/file/d/15rzzSkcedEuCE7Cm8yLXopA5PqHUQscb/view?usp=sharing) | 154 | | CompRess(Mobilenet)| MoCoV2 ResNet50 | 63.0% | 54.4% | 35.5% | [Pre-trained Mobilenet](https://drive.google.com/file/d/1gNkO48iREh6M6uuLd8TGqaOm3ChWiAdc/view?usp=sharing) | 155 | | CompRess(Resnet18) | MoCoV2 ResNet50 | 61.7% | 53.4% | 34.7% | [Pre-trained Resnet18](https://drive.google.com/file/d/1L-RCmD4gMeicxJhIeqNKU09_sH8R3bwS/view?usp=sharing) | 156 | | CompRess(Resnet18) | SwAV ResNet50 | 65.6% | 56.0% | 26.3% | [Pre-trained Resnet18](https://drive.google.com/file/d/1ZtPUAuq_S6-Yqtuajb-BdffKm--eyxPw/view?usp=sharing) | 157 | | CompRess(Alexnet) | SimCLR ResNet50x4(cached) | 57.6% | 52.3% | 33.3% | [Pre-trained Alexnet](https://drive.google.com/file/d/1wiEdfk6unXKtYRL1faIMoZMXnShaxBMU/view?usp=sharing) | 158 | 159 | 160 | 161 | 162 | 163 | ## License 164 | 165 | This project is under the MIT license. 166 | 167 | 168 | -------------------------------------------------------------------------------- /cache_feats.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import builtins 3 | import os 4 | import random 5 | import shutil 6 | import time 7 | import warnings 8 | from collections import Counter 9 | from random import shuffle 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.backends.cudnn as cudnn 15 | import torch.distributed as dist 16 | import torch.optim 17 | import torch.utils.data 18 | import torch.utils.data.distributed 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | import torchvision.models as models 22 | from torch.utils.data import DataLoader 23 | from tools import * 24 | from models.resnet50x4 import Resnet50_X4 as resnet50x4 25 | from models.resnet_swav import resnet50w5 26 | from eval_linear import load_weights 27 | from eval_knn import get_feats, faiss_knn, ImageFolderEx 28 | import numpy as np 29 | import time 30 | import pickle 31 | 32 | 33 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 34 | parser.add_argument('data', metavar='DIR', 35 | help='path to dataset') 36 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 37 | help='number of data loading workers (default: 32)') 38 | parser.add_argument('-b', '--batch-size', default=256, type=int, 39 | metavar='N', 40 | help='mini-batch size (default: 256), this is the total ' 41 | 'batch size of all GPUs on the current node when ' 42 | 'using Data Parallel or Distributed Data Parallel') 43 | parser.add_argument('--arch', type=str, default='resnet50x4', 44 | choices=['resnet50x4','resnet50', 'resnet50w5']) 45 | parser.add_argument('--data_pre_processing', type=str, default='SimCLR', 46 | choices=['SimCLR','MoCo']) 47 | parser.add_argument('-p', '--print-freq', default=10, type=int, 48 | metavar='N', help='print frequency (default: 10)') 49 | parser.add_argument('--save', 
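# the train/val features are later saved here as train_feats.pth.tar and val_feats.pth.tar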
type=str, default='output/cached_feats', 50 | help='directory to store cached features') 51 | parser.add_argument('--weights', default='', type=str, 52 | help='path to pretrained model checkpoint') 53 | 54 | 55 | def main(): 56 | global logger 57 | 58 | args = parser.parse_args() 59 | makedirs(args.save) 60 | 61 | logger = get_logger( 62 | logpath=os.path.join(args.save, 'logs'), 63 | filepath=os.path.abspath(__file__) 64 | ) 65 | def print_pass(*args): 66 | logger.info(*args) 67 | builtins.print = print_pass 68 | 69 | print(args) 70 | 71 | main_worker(args) 72 | 73 | 74 | def get_model(args): 75 | model = None 76 | 77 | if args.arch == 'resnet50x4': 78 | model = resnet50x4() 79 | checkpoint = torch.load(args.weights) 80 | model.load_state_dict(checkpoint['state_dict'], strict=True) 81 | model.fc = nn.Sequential() 82 | elif args.arch == 'resnet50': 83 | model = models.resnet50() 84 | model.fc = nn.Sequential() 85 | checkpoint = torch.load(args.weights) 86 | sd = checkpoint['state_dict'] 87 | sd = {k: v for k, v in sd.items() if 'encoder_q' in k} 88 | sd = {k: v for k, v in sd.items() if 'fc' not in k} 89 | sd = {k.replace('module.encoder_q.', ''): v for k, v in sd.items()} 90 | model.load_state_dict(sd, strict=True) 91 | elif args.arch == 'resnet50w5': 92 | model = resnet50w5() 93 | model.l2norm = None 94 | load_weights(model, args.weights) 95 | 96 | for p in model.parameters(): 97 | p.requires_grad = False 98 | 99 | return model 100 | 101 | 102 | def get_data_loader(args): 103 | if args.data_pre_processing == 'SimCLR': 104 | # Data loaders 105 | traindir = os.path.join(args.data, 'train') 106 | valdir = os.path.join(args.data, 'val') 107 | 108 | train_loader = torch.utils.data.DataLoader( 109 | ImageFolderEx(traindir, transforms.Compose([ 110 | transforms.Resize(256), 111 | transforms.CenterCrop(224), 112 | transforms.ToTensor(), 113 | ])), 114 | batch_size=args.batch_size, shuffle=False, 115 | num_workers=args.workers, pin_memory=True) 116 | 117 | val_loader = torch.utils.data.DataLoader( 118 | ImageFolderEx(valdir, transforms.Compose([ 119 | transforms.Resize(256), 120 | transforms.CenterCrop(224), 121 | transforms.ToTensor(), 122 | ])), 123 | batch_size=args.batch_size, shuffle=False, 124 | num_workers=args.workers, pin_memory=True) 125 | elif args.data_pre_processing == "MoCo": 126 | mean = [0.485, 0.456, 0.406] 127 | std = [0.229, 0.224, 0.225] 128 | normalize = transforms.Normalize(mean=mean, std=std) 129 | 130 | # Data loaders 131 | traindir = os.path.join(args.data, 'train') 132 | valdir = os.path.join(args.data, 'val') 133 | 134 | train_loader = torch.utils.data.DataLoader( 135 | ImageFolderEx(traindir, transforms.Compose([ 136 | transforms.Resize(256), 137 | transforms.CenterCrop(224), 138 | transforms.ToTensor(), 139 | normalize, 140 | ])), 141 | batch_size=args.batch_size, shuffle=False, 142 | num_workers=args.workers, pin_memory=True) 143 | 144 | val_loader = torch.utils.data.DataLoader( 145 | ImageFolderEx(valdir, transforms.Compose([ 146 | transforms.Resize(256), 147 | transforms.CenterCrop(224), 148 | transforms.ToTensor(), 149 | normalize, 150 | ])), 151 | batch_size=args.batch_size, shuffle=False, 152 | num_workers=args.workers, pin_memory=True) 153 | 154 | return train_loader, val_loader 155 | 156 | 157 | def normalize(x): 158 | return x / x.norm(2, dim=1, keepdim=True) 159 | 160 | 161 | def main_worker(args): 162 | model = get_model(args) 163 | model = nn.DataParallel(model).cuda() 164 | 165 | train_loader, val_loader = get_data_loader(args) 166 | 167 | model.eval() 
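# keep the frozen teacher in eval mode so BatchNorm uses its running statistics and dropout stays off while features are cached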
168 | 169 | cudnn.benchmark = True 170 | 171 | feats_file = '%s/train_feats.pth.tar' % args.save 172 | print('get train feats =>') 173 | # train_feats, train_labels, train_inds = torch.load(feats_file) 174 | train_feats, train_labels, train_inds = get_feats(train_loader, model, args.print_freq) 175 | torch.save((train_feats, train_labels, train_inds), feats_file) 176 | 177 | feats_file = '%s/val_feats.pth.tar' % args.save 178 | print('get val feats =>') 179 | # val_feats, val_labels, val_inds = torch.load(feats_file) 180 | val_feats, val_labels, val_inds = get_feats(val_loader, model, args.print_freq) 181 | torch.save((val_feats, val_labels, val_inds), feats_file) 182 | 183 | train_feats = normalize(train_feats) 184 | val_feats = normalize(val_feats) 185 | acc = faiss_knn(train_feats, train_labels, val_feats, val_labels, k=1) 186 | print(' * Acc {:.2f}'.format(acc)) 187 | 188 | 189 | if __name__ == '__main__': 190 | main() 191 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # CompReSS framework Webpage 2 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | source: . 2 | destination: ./_site 3 | includes: ./_includes 4 | 5 | collections: 6 | projects: 7 | output: true 8 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /docs/_includes/google_analytics.html: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_includes/image_with_caption.html: -------------------------------------------------------------------------------- 1 |
2 | {{ include.description }} 3 |
{{ include.description }}
4 |
-------------------------------------------------------------------------------- /docs/_layouts/default.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | 8 | {{ page.title }} 9 | 10 | 11 | 12 | 14 | 15 | 16 | 18 | 19 | 20 | 22 | 23 | 24 | 25 | 26 | 28 | 29 | 30 | 32 | 33 | 34 | 36 | 37 | 38 | 39 | 41 | 42 | 43 | 45 | 46 | 47 | 48 | 50 | 54 | 55 | 56 | 57 | 58 | 60 |
61 | 62 |
63 |
64 |

CompRess:

65 |
66 |
67 |

Self-Supervised Learning by Compressing Representations

68 |
69 |
70 | 71 | {{content}} 72 | 73 | 94 | 95 |
96 | 97 | 98 | {% include google_analytics.html %} 99 | 100 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /docs/_layouts/project.html: -------------------------------------------------------------------------------- 1 | --- 2 | layout: default 3 | --- 4 |
5 | 6 | {% if page.subtitle == "" or page.subtitle == nil %} 7 |

{{ page.title }}

8 | {% else %} 9 |
10 |

{{ page.title }}

11 |
{{ page.subtitle }}
12 |
13 | {% endif %} 14 | 15 | {{ content }} 16 | 17 |
-------------------------------------------------------------------------------- /docs/assets/images/CompRess_pseudo-coda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/assets/images/CompRess_pseudo-coda.png -------------------------------------------------------------------------------- /docs/assets/images/abilative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/assets/images/abilative.png -------------------------------------------------------------------------------- /docs/assets/images/compare_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/assets/images/compare_graph.png -------------------------------------------------------------------------------- /docs/assets/images/compress_poster_final.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/assets/images/compress_poster_final.pdf -------------------------------------------------------------------------------- /docs/assets/images/query_img_large.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/assets/images/query_img_large.jpg -------------------------------------------------------------------------------- /docs/assets/images/result_table1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/assets/images/result_table1.png -------------------------------------------------------------------------------- /docs/assets/images/result_table1_swav.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/assets/images/result_table1_swav.png -------------------------------------------------------------------------------- /docs/assets/images/result_table2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/assets/images/result_table2.png -------------------------------------------------------------------------------- /docs/assets/images/result_table3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/assets/images/result_table3.png -------------------------------------------------------------------------------- /docs/assets/images/result_table4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/assets/images/result_table4.png -------------------------------------------------------------------------------- /docs/assets/images/result_table5.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/assets/images/result_table5.png -------------------------------------------------------------------------------- /docs/assets/images/results_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/assets/images/results_table.png -------------------------------------------------------------------------------- /docs/assets/images/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/assets/images/teaser.gif -------------------------------------------------------------------------------- /docs/assets/images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/assets/images/teaser.png -------------------------------------------------------------------------------- /docs/bib.txt: -------------------------------------------------------------------------------- 1 | @inproceedings{koohpayegani2020compress, 2 | title={CompRess: Self-Supervised Learning by Compressing Representations}, 3 | author={Koohpayegani, Soroush Abbasi and Tejankar, Ajinkya and Pirsiavash, Hamed}, 4 | booktitle={NeurIPS}, 5 | year={2020} 6 | } 7 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | --- 2 | layout: default 3 | title: CompRess 4 | --- 5 | 6 | 9 |
10 |

University of Maryland, Baltimore County

11 |
12 |
13 |

denote equal contribution

14 |
15 | 16 | 24 | 25 |
26 |

27 | 28 |
Abstract
29 |

Self-supervised learning aims to learn good representations with unlabeled data. Recent works have shown that larger models benefit more from self-supervised 30 | learning than smaller models. As a result, the gap between supervised and self-supervised learning has been greatly reduced for larger models. In this work, 31 | instead of designing a new pseudo task for self-supervised learning, we develop a model compression method to compress an already learned, deep self-supervised model 32 | (teacher) to a smaller one (student). We train the student model so that it mimics the relative similarity between the datapoints in the teacher's embedding space. 33 | For AlexNet, our method outperforms all previous methods including the fully supervised model on ImageNet linear evaluation (59.0% compared to 56.5%) and on nearest 34 | neighbor evaluation (50.7% compared to 41.4%). To the best of our knowledge, this is the first time a self-supervised AlexNet has outperformed supervised one on 35 | ImageNet classification.

36 | 37 |
Contributions
38 |

The goal of self-supervised learning is to learn good representations without annotations. We find that recent improvements in self-supervised learning have 40 | significantly reduced the gap between it and supervised learning for higher-capacity models, but the gap is still huge for smaller, simpler models. Thus, we propose 41 | to learn representations by training a high-capacity model with an off-the-shelf self-supervised method and then compressing it to a smaller model with a novel 42 | compression method. 43 | 44 | Using our framework (CompRess), we train the student on unlabeled ImageNet, which is better than training a self-supervised model from scratch. 45 | Our method reduces the gap between SOTA self-supervised and supervised models, and it even outperforms the supervised model for the AlexNet architecture. 46 |

47 | 48 |

49 | 50 |
Method
51 |

52 | We propose to capture the similarity of each data point to the other training data in the teacher’s embedding space and then transfer that knowledge to the student. 53 | For each image, we compare it with a random set of anchor points and convert the similarities to a probability distribution over the anchor points. This distribution represents each image in terms of its nearest neighbors. 54 |  We want to transfer this knowledge to the student so we get the same distribution for the student as well. 55 | Finally, we train the student to minimize the KL divergence between the two distributions. Intuitively, we want each data point to have the same neighbors in both teacher and student embeddings. 56 | 57 |

58 | 59 |

60 | 61 |
Results
62 | 63 |

64 | Looking at the results, we find that our method outperforms other compression methods by a large margin across 3 different evaluation benchmarks 65 | and 2 different teacher SSL methods. This demonstrates the general applicability of our method. 66 | 67 | In addition, when we compress ResNet-50x4 to AlexNet, we get 59.0% for Linear, 50.7% for Nearest Neighbor (NN), and 27.6% for Cluster Alignment (CA), 68 | which outperforms the supervised model. On NN, our ResNet-50 is only 1 point worse than its ResNet-50x4 teacher. Note that models below the teacher row use the 69 | student architecture. Since a forward pass through the teacher is expensive for ResNet-50x4, we do not compare with CRD, Reg, and Reg-BN. 70 | 71 |

72 |

73 | 74 |

75 | 76 |

77 | To evaluate the effect of the teacher’s SSL method, we use SwAV ResNet-50 as the teacher and compress it to ResNet-18. 78 | We still get better accuracy compared to other distillation methods. Note that SwAV (concurrent work) [2] is different from MoCo and SimCLR 79 | in that it performs contrastive learning through online clustering. 80 | 81 |

82 |

83 | 84 |

85 | 86 | 87 | 88 |

89 | Comparison with SOTA self-supervised methods. We report results from the best layer for each method, which is written in parentheses: 90 | ‘f7’ refers to the ‘fc7’ layer and ‘c4’ refers to the ‘conv4’ layer. R50x4 refers to the teacher that is trained with SimCLR. ∗ refers to 10-crop 91 | evaluation. † denotes concurrent methods. 92 | 93 |

94 | 95 |

96 | 97 | 98 |

98 | 99 | We evaluate AlexNet compressed from ResNet-50x4 on the PASCAL-VOC classification and detection tasks. For the classification task, we only train a linear classifier on 100 | top of the frozen backbone, in contrast to the baselines, which fine-tune all layers. Our model is on par with the ImageNet-supervised model. 101 | For classification, we denote the fine-tuned layers in parentheses. For detection, all layers are fine-tuned. ∗ denotes the bigger AlexNet [23]. 102 | 103 | 104 |

105 | 106 |

107 |

108 | 109 |

110 | Another advantage of our method is data efficiency. For instance, on 1% of ImageNet, our ResNet-50 student is significantly better 111 | than other state-of-the-art SSL methods like BYOL and SwAV. Note that we only train a single linear layer on top of the frozen features, 112 | while the other methods in this table fine-tune the whole network. ∗ denotes concurrent methods. 113 | 114 | 115 |

116 | 117 |

118 |

119 | 120 | 121 |
CompRess Pseudo-Code
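As a rough, self-contained illustration of the objective described in the Method section above, the core loss can be sketched in PyTorch as follows. The function and variable names here are hypothetical and the temperature is only an example value, not the exact hyperparameters of the paper; see nn/compress_loss.py in the repository for the actual implementation.

```python
import torch
import torch.nn.functional as F

def similarity_kl_loss(student_feats, teacher_feats, anchor_feats, temperature=0.07):
    # All inputs are assumed to be L2-normalized embeddings:
    #   student_feats, teacher_feats: (batch, dim) features of the same images
    #   anchor_feats: (num_anchors, dim) random anchor points in the teacher's space
    logits_teacher = teacher_feats @ anchor_feats.t() / temperature
    logits_student = student_feats @ anchor_feats.t() / temperature
    # The teacher's similarities define a target distribution over the anchors,
    # and the student is trained to match it by minimizing the KL divergence.
    target = F.softmax(logits_teacher, dim=1)
    log_pred = F.log_softmax(logits_student, dim=1)
    return F.kl_div(log_pred, target, reduction='batchmean')

# Toy usage with random 128-d features, 8 query images, and 1024 anchor points.
student = F.normalize(torch.randn(8, 128), dim=1)
teacher = F.normalize(torch.randn(8, 128), dim=1)
anchors = F.normalize(torch.randn(1024, 128), dim=1)
print(similarity_kl_loss(student, teacher, anchors).item())
```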
122 |

123 | 124 | 125 |
Cluster Alignment Result
126 | 127 |

Our method is not designed specifically to learn good clustering. However, since it achieves good nearest neighbor results, we evaluate our features on clustering the ImageNet dataset. 128 | The goal is to use k-means to cluster our self-supervised features trained on unlabeled ImageNet, map each cluster to an ImageNet category, and then evaluate 129 | on the ImageNet validation set. In order to map clusters to categories, we first calculate the similarity between all (cluster, category) pairs as the number of common 130 | images divided by the size of the cluster. Then, we find the best mapping between clusters and categories using the Hungarian algorithm. 131 | In the following, we show randomly selected images (columns) from randomly selected clusters (rows) for our best AlexNet model. 132 | This is done with no manual inspection or cherry-picking. Note that most rows are aligned with semantic categories.
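As a simplified sketch of the mapping step just described (the helper below is illustrative, not the repository's eval_cluster_alignment.py, and it assumes the number of clusters equals the number of categories), the Hungarian matching can be written as:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def cluster_alignment_accuracy(cluster_ids, labels, num_classes=1000):
    # cluster_ids: k-means cluster index per image (int array)
    # labels: ground-truth class index per image (int array)
    num_clusters = int(cluster_ids.max()) + 1
    # overlap[c, y] = number of images assigned to cluster c that have label y
    overlap = np.zeros((num_clusters, num_classes))
    np.add.at(overlap, (cluster_ids, labels), 1)
    # similarity = number of common images divided by the cluster size
    similarity = overlap / np.maximum(overlap.sum(axis=1, keepdims=True), 1)
    # Hungarian algorithm gives the best one-to-one cluster-to-category mapping
    rows, cols = linear_sum_assignment(-similarity)
    mapping = dict(zip(rows, cols))
    predictions = np.array([mapping[c] for c in cluster_ids])
    return (predictions == labels).mean()
```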

134 |

135 | 136 |
References
137 |
[1] Mathilde Caron et al. “Deep clustering for unsupervised learning of visual features.” In: Proceedings of the European Conference on Computer Vision (ECCV). 2018, pp. 132–149. 138 |
[2] Mathilde Caron et al. “Unsupervised learning of visual features by contrasting cluster assignments.” In: arXiv preprint arXiv:2006.09882 (2020). 139 |
[3] Ting Chen et al. “A Simple Framework for Contrastive Learning of Visual Representations.” In: arXiv preprint arXiv:2002.05709 (2020). 140 |
[4] Xinlei Chen et al. “Improved Baselines with Momentum Contrastive Learning.” In: arXiv preprint arXiv:2003.04297 (2020). 141 |
[5] Carl Doersch, Abhinav Gupta, and Alexei A Efros. “Unsupervised visual representation learning by context prediction.” In: Proceedings of the IEEE International Conference on Computer Vision. 2015, pp. 1422–1430. 142 |
[6] Jeff Donahue, Philipp Krähenbühl, and Trevor Darrell. “Adversarial feature learning.” In: International Conference on Learning Representations, ICLR. 2016. 143 |
[7] Zeyu Feng, Chang Xu, and Dacheng Tao. “Self-supervised representation learning by rotation feature decoupling.” In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019, pp. 10364–10374. 144 |
[8] Spyros Gidaris, Praveer Singh, and Nikos Komodakis. “Unsupervised Representation Learning by Predicting Image Rotations.” In: International Conference on Learning Representations. 2018. URL: https://openreview.net/forum?id=S1v4N2l0-. 145 |
[9] Jean-Bastien Grill et al. “Bootstrap your own latent: A new approach to self-supervised learning.” In: arXiv preprint arXiv:2006.07733 (2020). 146 |
[10] Kaiming He et al. “Momentum contrast for unsupervised visual representation learning.” In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020, pp. 9729–9738. 147 |
[11] Jiabo Huang et al. “Unsupervised Deep Learning by Neighbourhood Discovery.” In: ed. by Kamalika Chaudhuri and Ruslan Salakhutdinov. Vol. 97. Proceedings of Machine Learning Research. Long Beach, California, USA: PMLR, Sept. 2019, pp. 2849–2858. URL: http://proceedings.mlr.press/v97/huang19b.html. 148 |
[12] Simon Jenni and Paolo Favaro. “Self-supervised feature learning by learning to spot artifacts.” In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018, pp. 2733–2742. 149 |
[13] Ishan Misra and Laurens van der Maaten. “Self-supervised learning of pretext-invariant representations.” In: arXiv preprint arXiv:1912.01991 (2019). 150 |
[14] Mehdi Noroozi and Paolo Favaro. “Unsupervised learning of visual representations by solving jigsaw puzzles.” In: European Conference on Computer Vision. Springer. 2016, pp. 69–84. 151 |
[15] Mehdi Noroozi, Hamed Pirsiavash, and Paolo Favaro. “Representation learning by learning to count.” In: Proceedings of the IEEE International Conference on Computer Vision. 2017, pp. 5898–5906. 152 |
[16] Mehdi Noroozi et al. “Boosting self-supervised learning via knowledge transfer.” In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018, pp. 9359–9367. 153 |
[17] Deepak Pathak et al. “Context encoders: Feature learning by inpainting.” In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2016, pp. 2536–2544. 154 |
[18] Yonglong Tian, Dilip Krishnan, and Phillip Isola. “Contrastive multiview coding.” In: arXiv preprint arXiv:1906.05849 (2019). 155 |
[19] Yonglong Tian, Dilip Krishnan, and Phillip Isola. “Contrastive Representation Distillation.” In: International Conference on Learning Representations. 2020. URL: https://openreview.net/forum?id=SkgpBJrtvS. 156 |
[20] Yonglong Tian et al. “What makes for good views for contrastive learning.” In: arXiv preprint arXiv:2005.10243 (2020). 157 |
[21] Zhirong Wu et al. “Unsupervised feature learning via non-parametric instance discrimination.” In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018, pp. 3733–3742. 158 |
[22] Xueting Yan et al. “ClusterFit: Improving Generalization of Visual Representations.” In: CVPR. 2020. 159 |
[23] Asano YM., Rupprecht C., and Vedaldi A. “Self-labelling via simultaneous clustering and representation learning.” In: International Conference on Learning Representations. 2020. URL: https://openreview.net/forum?id=Hyx-jyBFPr. 160 |
[24] Liheng Zhang et al. “Aet vs. aed: Unsupervised representation learning by auto-encoding transformations rather than data.” In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019, pp. 2547–2555. 161 |
[25] Richard Zhang, Phillip Isola, and Alexei A Efros. “Colorful image colorization.” In: European Conference on Computer Vision. Springer. 2016, pp. 649–666. 162 |
[26] Richard Zhang, Phillip Isola, and Alexei A Efros. “Split-brain autoencoders: Unsupervised learning by cross-channel prediction.” In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2017, pp. 1058–1067. 163 |
[27] Chengxu Zhuang, Alex Lin Zhai, and Daniel Yamins. “Local aggregation for unsupervised learning ofvisual embeddings.” In:Proceedings of the IEEE International Conference on Computer Vision. 2019,pp. 6002–6012. 164 | 165 | -------------------------------------------------------------------------------- /docs/libs/custom/my_css.css: -------------------------------------------------------------------------------- 1 | h4 { 2 | letter-spacing: .1rem; } 3 | .container { 4 | max-width: 900px; } 5 | .header { 6 | margin-top: 60px; 7 | margin-bottom: 60px; 8 | text-align: left; } 9 | .main-description h1 { 10 | margin-bottom: 0px; } 11 | .main-description p { 12 | margin-bottom: 0px; } 13 | .main-description .fa { 14 | font-size: 22px; 15 | margin-top: 5px; } 16 | .paper { 17 | margin-bottom: 20px } 18 | .paper .title { 19 | /*text-transform: uppercase; */} 20 | .paper p { 21 | margin-bottom: 2px } 22 | .paper-buttons { 23 | margin-top: 2px } 24 | .paper .button { 25 | display: inline-block; 26 | height: 20px; 27 | padding: 0 5px; 28 | font-weight: 100; 29 | line-height: 20px; 30 | margin-bottom: 0px; 31 | text-transform: none; } 32 | .section-header { 33 | text-transform: uppercase; 34 | font-size: 1.5rem; 35 | letter-spacing: .2rem; 36 | font-weight: 600; } 37 | .docs-section { 38 | border-top: 1px solid #eee; 39 | padding: 4rem 0; 40 | margin-bottom: 0;} 41 | .footer { 42 | text-align: center; 43 | color: rgba(0, 0, 0, 0.4); 44 | border-top: 1px solid #eee; 45 | padding: 15px 0 40px 0; 46 | margin-bottom: 0;} 47 | .navbar { 48 | display: none; } 49 | .title-subtitle h3 { 50 | margin-bottom: 0px; } 51 | .title-subtitle h5 { 52 | color: rgba(0, 0, 0, 0.60); } 53 | .image { 54 | text-align: center; } 55 | 56 | /* ================ Project styling ================ */ 57 | .project-container { 58 | margin-bottom: 30px; 59 | } 60 | .project-image-container { 61 | border: 1px solid rgba(0,0,0,0.2); 62 | padding: 3px; 63 | transition: 0.3s; 64 | } 65 | .project-image-container:hover { 66 | transition: 0.3s; 67 | } 68 | .project-image-container:hover img { 69 | opacity: 0.4; 70 | } 71 | .project-caption { 72 | padding: 3px; 73 | } 74 | .menu { 75 | text-align: center; 76 | } 77 | .menu ul { 78 | list-style: none; 79 | } 80 | .menu li { 81 | display: inline; 82 | padding: 0px 10px; 83 | } 84 | 85 | /* Larger than phone */ 86 | @media (min-width: 550px) { 87 | .header { 88 | margin-top: 5rem; } 89 | } 90 | 91 | /* Larger than tablet */ 92 | @media (min-width: 750px) { 93 | /* Navbar */ 94 | .navbar + .docs-section { 95 | border-top-width: 0; } 96 | .navbar, 97 | .navbar-spacer { 98 | display: block; 99 | width: 100%; 100 | height: 6.5rem; 101 | background: #fff; 102 | z-index: 99; 103 | border-top: 1px solid #eee; 104 | border-bottom: 1px solid #eee; } 105 | .navbar-spacer { 106 | display: none; } 107 | .navbar > .container { 108 | width: 100%; } 109 | .navbar-list { 110 | list-style: none; 111 | margin-bottom: 0; } 112 | .navbar-item { 113 | position: relative; 114 | float: left; 115 | margin-bottom: 0; } 116 | .navbar-link { 117 | text-transform: uppercase; 118 | font-size: 11px; 119 | font-weight: 600; 120 | letter-spacing: .2rem; 121 | margin-right: 35px; 122 | text-decoration: none; 123 | line-height: 6.5rem; 124 | color: #222; } 125 | .navbar-link.active { 126 | color: #33C3F0; } 127 | .has-docked-nav .navbar { 128 | position: fixed; 129 | top: 0; 130 | left: 0; } 131 | .has-docked-nav .navbar-spacer { 132 | display: block; } 133 | /* Re-overiding the width 100% declaration to match size of % 
based container */ 134 | .has-docked-nav .navbar > .container { 135 | width: 80%; } 136 | 137 | /* Popover */ 138 | .popover.open { 139 | display: block; 140 | } 141 | .popover { 142 | display: none; 143 | position: absolute; 144 | top: 0; 145 | left: 0; 146 | background: #fff; 147 | border: 1px solid #eee; 148 | border-radius: 4px; 149 | top: 92%; 150 | left: -50%; 151 | -webkit-filter: drop-shadow(0 0 6px rgba(0,0,0,.1)); 152 | -moz-filter: drop-shadow(0 0 6px rgba(0,0,0,.1)); 153 | filter: drop-shadow(0 0 6px rgba(0,0,0,.1)); } 154 | .popover-item:first-child .popover-link:after, 155 | .popover-item:first-child .popover-link:before { 156 | bottom: 100%; 157 | left: 50%; 158 | border: solid transparent; 159 | content: " "; 160 | height: 0; 161 | width: 0; 162 | position: absolute; 163 | pointer-events: none; } 164 | .popover-item:first-child .popover-link:after { 165 | border-color: rgba(255, 255, 255, 0); 166 | border-bottom-color: #fff; 167 | border-width: 10px; 168 | margin-left: -10px; } 169 | .popover-item:first-child .popover-link:before { 170 | border-color: rgba(238, 238, 238, 0); 171 | border-bottom-color: #eee; 172 | border-width: 11px; 173 | margin-left: -11px; } 174 | .popover-list { 175 | padding: 0; 176 | margin: 0; 177 | list-style: none; } 178 | .popover-item { 179 | padding: 0; 180 | margin: 0; } 181 | .popover-link { 182 | position: relative; 183 | color: #222; 184 | display: block; 185 | padding: 8px 20px; 186 | border-bottom: 1px solid #eee; 187 | text-decoration: none; 188 | text-transform: uppercase; 189 | font-size: 1.0rem; 190 | font-weight: 600; 191 | text-align: center; 192 | letter-spacing: .1rem; } 193 | .popover-item:first-child .popover-link { 194 | border-radius: 4px 4px 0 0; } 195 | .popover-item:last-child .popover-link { 196 | border-radius: 0 0 4px 4px; 197 | border-bottom-width: 0; } 198 | .popover-link:hover { 199 | color: #fff; 200 | background: #33C3F0; } 201 | .popover-link:hover, 202 | .popover-item:first-child .popover-link:hover:after { 203 | border-bottom-color: #33C3F0; } 204 | } 205 | -------------------------------------------------------------------------------- /docs/libs/custom/my_js.js: -------------------------------------------------------------------------------- 1 | $(document).ready(function() { 2 | 3 | // Variables 4 | var $codeSnippets = $('.code-example-body'), 5 | $nav = $('.navbar'), 6 | $body = $('body'), 7 | $window = $(window), 8 | $popoverLink = $('[data-popover]'), 9 | navOffsetTop = $nav.offset().top, 10 | $document = $(document), 11 | entityMap = { 12 | "&": "&", 13 | "<": "<", 14 | ">": ">", 15 | '"': '"', 16 | "'": ''', 17 | "/": '/' 18 | } 19 | 20 | function init() { 21 | $window.on('scroll', onScroll) 22 | $window.on('resize', resize) 23 | $popoverLink.on('click', openPopover) 24 | $document.on('click', closePopover) 25 | $('a[href^="#"]').on('click', smoothScroll) 26 | buildSnippets(); 27 | } 28 | 29 | function smoothScroll(e) { 30 | e.preventDefault(); 31 | $(document).off("scroll"); 32 | var target = this.hash, 33 | menu = target; 34 | $target = $(target); 35 | $('html, body').stop().animate({ 36 | 'scrollTop': $target.offset().top-40 37 | }, 0, 'swing', function () { 38 | window.location.hash = target; 39 | $(document).on("scroll", onScroll); 40 | }); 41 | } 42 | 43 | function openPopover(e) { 44 | e.preventDefault() 45 | closePopover(); 46 | var popover = $($(this).data('popover')); 47 | popover.toggleClass('open') 48 | e.stopImmediatePropagation(); 49 | } 50 | 51 | function closePopover(e) { 52 | 
if($('.popover.open').length > 0) { 53 | $('.popover').removeClass('open') 54 | } 55 | } 56 | 57 | $("#button").click(function() { 58 | $('html, body').animate({ 59 | scrollTop: $("#elementtoScrollToID").offset().top 60 | }, 2000); 61 | }); 62 | 63 | function resize() { 64 | $body.removeClass('has-docked-nav') 65 | navOffsetTop = $nav.offset().top 66 | onScroll() 67 | } 68 | 69 | function onScroll() { 70 | if(navOffsetTop < $window.scrollTop() && !$body.hasClass('has-docked-nav')) { 71 | $body.addClass('has-docked-nav') 72 | } 73 | if(navOffsetTop > $window.scrollTop() && $body.hasClass('has-docked-nav')) { 74 | $body.removeClass('has-docked-nav') 75 | } 76 | } 77 | 78 | function escapeHtml(string) { 79 | return String(string).replace(/[&<>"'\/]/g, function (s) { 80 | return entityMap[s]; 81 | }); 82 | } 83 | 84 | function buildSnippets() { 85 | $codeSnippets.each(function() { 86 | var newContent = escapeHtml($(this).html()) 87 | $(this).html(newContent) 88 | }) 89 | } 90 | 91 | 92 | init(); 93 | 94 | }); -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/fonts/FontAwesome.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/libs/external/font-awesome-4.7.0/fonts/FontAwesome.otf -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/libs/external/font-awesome-4.7.0/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/libs/external/font-awesome-4.7.0/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/libs/external/font-awesome-4.7.0/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/libs/external/font-awesome-4.7.0/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/less/animated.less: -------------------------------------------------------------------------------- 1 | // Animated Icons 2 | // -------------------------- 3 | 4 | .@{fa-css-prefix}-spin { 5 | -webkit-animation: fa-spin 2s infinite linear; 6 | animation: fa-spin 2s infinite linear; 7 | } 8 | 9 | .@{fa-css-prefix}-pulse { 10 | -webkit-animation: fa-spin 1s infinite steps(8); 11 | animation: fa-spin 1s infinite steps(8); 12 | } 13 | 14 | 
@-webkit-keyframes fa-spin { 15 | 0% { 16 | -webkit-transform: rotate(0deg); 17 | transform: rotate(0deg); 18 | } 19 | 100% { 20 | -webkit-transform: rotate(359deg); 21 | transform: rotate(359deg); 22 | } 23 | } 24 | 25 | @keyframes fa-spin { 26 | 0% { 27 | -webkit-transform: rotate(0deg); 28 | transform: rotate(0deg); 29 | } 30 | 100% { 31 | -webkit-transform: rotate(359deg); 32 | transform: rotate(359deg); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/less/bordered-pulled.less: -------------------------------------------------------------------------------- 1 | // Bordered & Pulled 2 | // ------------------------- 3 | 4 | .@{fa-css-prefix}-border { 5 | padding: .2em .25em .15em; 6 | border: solid .08em @fa-border-color; 7 | border-radius: .1em; 8 | } 9 | 10 | .@{fa-css-prefix}-pull-left { float: left; } 11 | .@{fa-css-prefix}-pull-right { float: right; } 12 | 13 | .@{fa-css-prefix} { 14 | &.@{fa-css-prefix}-pull-left { margin-right: .3em; } 15 | &.@{fa-css-prefix}-pull-right { margin-left: .3em; } 16 | } 17 | 18 | /* Deprecated as of 4.4.0 */ 19 | .pull-right { float: right; } 20 | .pull-left { float: left; } 21 | 22 | .@{fa-css-prefix} { 23 | &.pull-left { margin-right: .3em; } 24 | &.pull-right { margin-left: .3em; } 25 | } 26 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/less/core.less: -------------------------------------------------------------------------------- 1 | // Base Class Definition 2 | // ------------------------- 3 | 4 | .@{fa-css-prefix} { 5 | display: inline-block; 6 | font: normal normal normal @fa-font-size-base/@fa-line-height-base FontAwesome; // shortening font declaration 7 | font-size: inherit; // can't have font-size inherit on line above, so need to override 8 | text-rendering: auto; // optimizelegibility throws things off #1094 9 | -webkit-font-smoothing: antialiased; 10 | -moz-osx-font-smoothing: grayscale; 11 | 12 | } 13 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/less/fixed-width.less: -------------------------------------------------------------------------------- 1 | // Fixed Width Icons 2 | // ------------------------- 3 | .@{fa-css-prefix}-fw { 4 | width: (18em / 14); 5 | text-align: center; 6 | } 7 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/less/font-awesome.less: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Font Awesome 4.7.0 by @davegandy - http://fontawesome.io - @fontawesome 3 | * License - http://fontawesome.io/license (Font: SIL OFL 1.1, CSS: MIT License) 4 | */ 5 | 6 | @import "variables.less"; 7 | @import "mixins.less"; 8 | @import "path.less"; 9 | @import "core.less"; 10 | @import "larger.less"; 11 | @import "fixed-width.less"; 12 | @import "list.less"; 13 | @import "bordered-pulled.less"; 14 | @import "animated.less"; 15 | @import "rotated-flipped.less"; 16 | @import "stacked.less"; 17 | @import "icons.less"; 18 | @import "screen-reader.less"; 19 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/less/larger.less: -------------------------------------------------------------------------------- 1 | // Icon Sizes 2 | // ------------------------- 3 | 4 | /* makes the font 33% larger relative to the icon container */ 5 | .@{fa-css-prefix}-lg { 6 | font-size: (4em / 3); 7 | line-height: (3em / 4); 8 | vertical-align: -15%; 9 | } 10 | .@{fa-css-prefix}-2x { font-size: 2em; } 11 | .@{fa-css-prefix}-3x { font-size: 3em; } 12 | .@{fa-css-prefix}-4x { font-size: 4em; } 13 | .@{fa-css-prefix}-5x { font-size: 5em; } 14 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/less/list.less: -------------------------------------------------------------------------------- 1 | // List Icons 2 | // ------------------------- 3 | 4 | .@{fa-css-prefix}-ul { 5 | padding-left: 0; 6 | margin-left: @fa-li-width; 7 | list-style-type: none; 8 | > li { position: relative; } 9 | } 10 | .@{fa-css-prefix}-li { 11 | position: absolute; 12 | left: -@fa-li-width; 13 | width: @fa-li-width; 14 | top: (2em / 14); 15 | text-align: center; 16 | &.@{fa-css-prefix}-lg { 17 | left: (-@fa-li-width + (4em / 14)); 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/less/mixins.less: -------------------------------------------------------------------------------- 1 | // Mixins 2 | // -------------------------- 3 | 4 | .fa-icon() { 5 | display: inline-block; 6 | font: normal normal normal @fa-font-size-base/@fa-line-height-base FontAwesome; // shortening font declaration 7 | font-size: inherit; // can't have font-size inherit on line above, so need to override 8 | text-rendering: auto; // optimizelegibility throws things off #1094 9 | -webkit-font-smoothing: antialiased; 10 | -moz-osx-font-smoothing: grayscale; 11 | 12 | } 13 | 14 | .fa-icon-rotate(@degrees, @rotation) { 15 | -ms-filter: "progid:DXImageTransform.Microsoft.BasicImage(rotation=@{rotation})"; 16 | -webkit-transform: rotate(@degrees); 17 | -ms-transform: rotate(@degrees); 18 | transform: rotate(@degrees); 19 | } 20 | 21 | .fa-icon-flip(@horiz, @vert, @rotation) { 22 | -ms-filter: "progid:DXImageTransform.Microsoft.BasicImage(rotation=@{rotation}, mirror=1)"; 23 | -webkit-transform: scale(@horiz, @vert); 24 | -ms-transform: scale(@horiz, @vert); 25 | transform: scale(@horiz, @vert); 26 | } 27 | 28 | 29 | // Only display content to screen readers. A la Bootstrap 4. 30 | // 31 | // See: http://a11yproject.com/posts/how-to-hide-content/ 32 | 33 | .sr-only() { 34 | position: absolute; 35 | width: 1px; 36 | height: 1px; 37 | padding: 0; 38 | margin: -1px; 39 | overflow: hidden; 40 | clip: rect(0,0,0,0); 41 | border: 0; 42 | } 43 | 44 | // Use in conjunction with .sr-only to only display content when it's focused. 
45 | // 46 | // Useful for "Skip to main content" links; see http://www.w3.org/TR/2013/NOTE-WCAG20-TECHS-20130905/G1 47 | // 48 | // Credit: HTML5 Boilerplate 49 | 50 | .sr-only-focusable() { 51 | &:active, 52 | &:focus { 53 | position: static; 54 | width: auto; 55 | height: auto; 56 | margin: 0; 57 | overflow: visible; 58 | clip: auto; 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/less/path.less: -------------------------------------------------------------------------------- 1 | /* FONT PATH 2 | * -------------------------- */ 3 | 4 | @font-face { 5 | font-family: 'FontAwesome'; 6 | src: url('@{fa-font-path}/fontawesome-webfont.eot?v=@{fa-version}'); 7 | src: url('@{fa-font-path}/fontawesome-webfont.eot?#iefix&v=@{fa-version}') format('embedded-opentype'), 8 | url('@{fa-font-path}/fontawesome-webfont.woff2?v=@{fa-version}') format('woff2'), 9 | url('@{fa-font-path}/fontawesome-webfont.woff?v=@{fa-version}') format('woff'), 10 | url('@{fa-font-path}/fontawesome-webfont.ttf?v=@{fa-version}') format('truetype'), 11 | url('@{fa-font-path}/fontawesome-webfont.svg?v=@{fa-version}#fontawesomeregular') format('svg'); 12 | // src: url('@{fa-font-path}/FontAwesome.otf') format('opentype'); // used when developing fonts 13 | font-weight: normal; 14 | font-style: normal; 15 | } 16 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/less/rotated-flipped.less: -------------------------------------------------------------------------------- 1 | // Rotated & Flipped Icons 2 | // ------------------------- 3 | 4 | .@{fa-css-prefix}-rotate-90 { .fa-icon-rotate(90deg, 1); } 5 | .@{fa-css-prefix}-rotate-180 { .fa-icon-rotate(180deg, 2); } 6 | .@{fa-css-prefix}-rotate-270 { .fa-icon-rotate(270deg, 3); } 7 | 8 | .@{fa-css-prefix}-flip-horizontal { .fa-icon-flip(-1, 1, 0); } 9 | .@{fa-css-prefix}-flip-vertical { .fa-icon-flip(1, -1, 2); } 10 | 11 | // Hook for IE8-9 12 | // ------------------------- 13 | 14 | :root .@{fa-css-prefix}-rotate-90, 15 | :root .@{fa-css-prefix}-rotate-180, 16 | :root .@{fa-css-prefix}-rotate-270, 17 | :root .@{fa-css-prefix}-flip-horizontal, 18 | :root .@{fa-css-prefix}-flip-vertical { 19 | filter: none; 20 | } 21 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/less/screen-reader.less: -------------------------------------------------------------------------------- 1 | // Screen Readers 2 | // ------------------------- 3 | 4 | .sr-only { .sr-only(); } 5 | .sr-only-focusable { .sr-only-focusable(); } 6 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/less/stacked.less: -------------------------------------------------------------------------------- 1 | // Stacked Icons 2 | // ------------------------- 3 | 4 | .@{fa-css-prefix}-stack { 5 | position: relative; 6 | display: inline-block; 7 | width: 2em; 8 | height: 2em; 9 | line-height: 2em; 10 | vertical-align: middle; 11 | } 12 | .@{fa-css-prefix}-stack-1x, .@{fa-css-prefix}-stack-2x { 13 | position: absolute; 14 | left: 0; 15 | width: 100%; 16 | text-align: center; 17 | } 18 | .@{fa-css-prefix}-stack-1x { line-height: inherit; } 19 | .@{fa-css-prefix}-stack-2x { font-size: 2em; } 20 | .@{fa-css-prefix}-inverse { color: @fa-inverse; } 21 | -------------------------------------------------------------------------------- 
/docs/libs/external/font-awesome-4.7.0/scss/_animated.scss: -------------------------------------------------------------------------------- 1 | // Spinning Icons 2 | // -------------------------- 3 | 4 | .#{$fa-css-prefix}-spin { 5 | -webkit-animation: fa-spin 2s infinite linear; 6 | animation: fa-spin 2s infinite linear; 7 | } 8 | 9 | .#{$fa-css-prefix}-pulse { 10 | -webkit-animation: fa-spin 1s infinite steps(8); 11 | animation: fa-spin 1s infinite steps(8); 12 | } 13 | 14 | @-webkit-keyframes fa-spin { 15 | 0% { 16 | -webkit-transform: rotate(0deg); 17 | transform: rotate(0deg); 18 | } 19 | 100% { 20 | -webkit-transform: rotate(359deg); 21 | transform: rotate(359deg); 22 | } 23 | } 24 | 25 | @keyframes fa-spin { 26 | 0% { 27 | -webkit-transform: rotate(0deg); 28 | transform: rotate(0deg); 29 | } 30 | 100% { 31 | -webkit-transform: rotate(359deg); 32 | transform: rotate(359deg); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/scss/_bordered-pulled.scss: -------------------------------------------------------------------------------- 1 | // Bordered & Pulled 2 | // ------------------------- 3 | 4 | .#{$fa-css-prefix}-border { 5 | padding: .2em .25em .15em; 6 | border: solid .08em $fa-border-color; 7 | border-radius: .1em; 8 | } 9 | 10 | .#{$fa-css-prefix}-pull-left { float: left; } 11 | .#{$fa-css-prefix}-pull-right { float: right; } 12 | 13 | .#{$fa-css-prefix} { 14 | &.#{$fa-css-prefix}-pull-left { margin-right: .3em; } 15 | &.#{$fa-css-prefix}-pull-right { margin-left: .3em; } 16 | } 17 | 18 | /* Deprecated as of 4.4.0 */ 19 | .pull-right { float: right; } 20 | .pull-left { float: left; } 21 | 22 | .#{$fa-css-prefix} { 23 | &.pull-left { margin-right: .3em; } 24 | &.pull-right { margin-left: .3em; } 25 | } 26 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/scss/_core.scss: -------------------------------------------------------------------------------- 1 | // Base Class Definition 2 | // ------------------------- 3 | 4 | .#{$fa-css-prefix} { 5 | display: inline-block; 6 | font: normal normal normal #{$fa-font-size-base}/#{$fa-line-height-base} FontAwesome; // shortening font declaration 7 | font-size: inherit; // can't have font-size inherit on line above, so need to override 8 | text-rendering: auto; // optimizelegibility throws things off #1094 9 | -webkit-font-smoothing: antialiased; 10 | -moz-osx-font-smoothing: grayscale; 11 | 12 | } 13 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/scss/_fixed-width.scss: -------------------------------------------------------------------------------- 1 | // Fixed Width Icons 2 | // ------------------------- 3 | .#{$fa-css-prefix}-fw { 4 | width: (18em / 14); 5 | text-align: center; 6 | } 7 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/scss/_larger.scss: -------------------------------------------------------------------------------- 1 | // Icon Sizes 2 | // ------------------------- 3 | 4 | /* makes the font 33% larger relative to the icon container */ 5 | .#{$fa-css-prefix}-lg { 6 | font-size: (4em / 3); 7 | line-height: (3em / 4); 8 | vertical-align: -15%; 9 | } 10 | .#{$fa-css-prefix}-2x { font-size: 2em; } 11 | .#{$fa-css-prefix}-3x { font-size: 3em; } 12 | .#{$fa-css-prefix}-4x { font-size: 4em; } 13 | .#{$fa-css-prefix}-5x { 
font-size: 5em; } 14 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/scss/_list.scss: -------------------------------------------------------------------------------- 1 | // List Icons 2 | // ------------------------- 3 | 4 | .#{$fa-css-prefix}-ul { 5 | padding-left: 0; 6 | margin-left: $fa-li-width; 7 | list-style-type: none; 8 | > li { position: relative; } 9 | } 10 | .#{$fa-css-prefix}-li { 11 | position: absolute; 12 | left: -$fa-li-width; 13 | width: $fa-li-width; 14 | top: (2em / 14); 15 | text-align: center; 16 | &.#{$fa-css-prefix}-lg { 17 | left: -$fa-li-width + (4em / 14); 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/scss/_mixins.scss: -------------------------------------------------------------------------------- 1 | // Mixins 2 | // -------------------------- 3 | 4 | @mixin fa-icon() { 5 | display: inline-block; 6 | font: normal normal normal #{$fa-font-size-base}/#{$fa-line-height-base} FontAwesome; // shortening font declaration 7 | font-size: inherit; // can't have font-size inherit on line above, so need to override 8 | text-rendering: auto; // optimizelegibility throws things off #1094 9 | -webkit-font-smoothing: antialiased; 10 | -moz-osx-font-smoothing: grayscale; 11 | 12 | } 13 | 14 | @mixin fa-icon-rotate($degrees, $rotation) { 15 | -ms-filter: "progid:DXImageTransform.Microsoft.BasicImage(rotation=#{$rotation})"; 16 | -webkit-transform: rotate($degrees); 17 | -ms-transform: rotate($degrees); 18 | transform: rotate($degrees); 19 | } 20 | 21 | @mixin fa-icon-flip($horiz, $vert, $rotation) { 22 | -ms-filter: "progid:DXImageTransform.Microsoft.BasicImage(rotation=#{$rotation}, mirror=1)"; 23 | -webkit-transform: scale($horiz, $vert); 24 | -ms-transform: scale($horiz, $vert); 25 | transform: scale($horiz, $vert); 26 | } 27 | 28 | 29 | // Only display content to screen readers. A la Bootstrap 4. 30 | // 31 | // See: http://a11yproject.com/posts/how-to-hide-content/ 32 | 33 | @mixin sr-only { 34 | position: absolute; 35 | width: 1px; 36 | height: 1px; 37 | padding: 0; 38 | margin: -1px; 39 | overflow: hidden; 40 | clip: rect(0,0,0,0); 41 | border: 0; 42 | } 43 | 44 | // Use in conjunction with .sr-only to only display content when it's focused. 
45 | // 46 | // Useful for "Skip to main content" links; see http://www.w3.org/TR/2013/NOTE-WCAG20-TECHS-20130905/G1 47 | // 48 | // Credit: HTML5 Boilerplate 49 | 50 | @mixin sr-only-focusable { 51 | &:active, 52 | &:focus { 53 | position: static; 54 | width: auto; 55 | height: auto; 56 | margin: 0; 57 | overflow: visible; 58 | clip: auto; 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/scss/_path.scss: -------------------------------------------------------------------------------- 1 | /* FONT PATH 2 | * -------------------------- */ 3 | 4 | @font-face { 5 | font-family: 'FontAwesome'; 6 | src: url('#{$fa-font-path}/fontawesome-webfont.eot?v=#{$fa-version}'); 7 | src: url('#{$fa-font-path}/fontawesome-webfont.eot?#iefix&v=#{$fa-version}') format('embedded-opentype'), 8 | url('#{$fa-font-path}/fontawesome-webfont.woff2?v=#{$fa-version}') format('woff2'), 9 | url('#{$fa-font-path}/fontawesome-webfont.woff?v=#{$fa-version}') format('woff'), 10 | url('#{$fa-font-path}/fontawesome-webfont.ttf?v=#{$fa-version}') format('truetype'), 11 | url('#{$fa-font-path}/fontawesome-webfont.svg?v=#{$fa-version}#fontawesomeregular') format('svg'); 12 | // src: url('#{$fa-font-path}/FontAwesome.otf') format('opentype'); // used when developing fonts 13 | font-weight: normal; 14 | font-style: normal; 15 | } 16 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/scss/_rotated-flipped.scss: -------------------------------------------------------------------------------- 1 | // Rotated & Flipped Icons 2 | // ------------------------- 3 | 4 | .#{$fa-css-prefix}-rotate-90 { @include fa-icon-rotate(90deg, 1); } 5 | .#{$fa-css-prefix}-rotate-180 { @include fa-icon-rotate(180deg, 2); } 6 | .#{$fa-css-prefix}-rotate-270 { @include fa-icon-rotate(270deg, 3); } 7 | 8 | .#{$fa-css-prefix}-flip-horizontal { @include fa-icon-flip(-1, 1, 0); } 9 | .#{$fa-css-prefix}-flip-vertical { @include fa-icon-flip(1, -1, 2); } 10 | 11 | // Hook for IE8-9 12 | // ------------------------- 13 | 14 | :root .#{$fa-css-prefix}-rotate-90, 15 | :root .#{$fa-css-prefix}-rotate-180, 16 | :root .#{$fa-css-prefix}-rotate-270, 17 | :root .#{$fa-css-prefix}-flip-horizontal, 18 | :root .#{$fa-css-prefix}-flip-vertical { 19 | filter: none; 20 | } 21 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/scss/_screen-reader.scss: -------------------------------------------------------------------------------- 1 | // Screen Readers 2 | // ------------------------- 3 | 4 | .sr-only { @include sr-only(); } 5 | .sr-only-focusable { @include sr-only-focusable(); } 6 | -------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/scss/_stacked.scss: -------------------------------------------------------------------------------- 1 | // Stacked Icons 2 | // ------------------------- 3 | 4 | .#{$fa-css-prefix}-stack { 5 | position: relative; 6 | display: inline-block; 7 | width: 2em; 8 | height: 2em; 9 | line-height: 2em; 10 | vertical-align: middle; 11 | } 12 | .#{$fa-css-prefix}-stack-1x, .#{$fa-css-prefix}-stack-2x { 13 | position: absolute; 14 | left: 0; 15 | width: 100%; 16 | text-align: center; 17 | } 18 | .#{$fa-css-prefix}-stack-1x { line-height: inherit; } 19 | .#{$fa-css-prefix}-stack-2x { font-size: 2em; } 20 | .#{$fa-css-prefix}-inverse { color: $fa-inverse; } 21 | 
-------------------------------------------------------------------------------- /docs/libs/external/font-awesome-4.7.0/scss/font-awesome.scss: -------------------------------------------------------------------------------- 1 | /*! 2 | * Font Awesome 4.7.0 by @davegandy - http://fontawesome.io - @fontawesome 3 | * License - http://fontawesome.io/license (Font: SIL OFL 1.1, CSS: MIT License) 4 | */ 5 | 6 | @import "variables"; 7 | @import "mixins"; 8 | @import "path"; 9 | @import "core"; 10 | @import "larger"; 11 | @import "fixed-width"; 12 | @import "list"; 13 | @import "bordered-pulled"; 14 | @import "animated"; 15 | @import "rotated-flipped"; 16 | @import "stacked"; 17 | @import "icons"; 18 | @import "screen-reader"; 19 | -------------------------------------------------------------------------------- /docs/libs/external/skeleton/normalize.css: -------------------------------------------------------------------------------- 1 | /*! normalize.css v3.0.2 | MIT License | git.io/normalize */ 2 | 3 | /** 4 | * 1. Set default font family to sans-serif. 5 | * 2. Prevent iOS text size adjust after orientation change, without disabling 6 | * user zoom. 7 | */ 8 | 9 | html { 10 | font-family: sans-serif; /* 1 */ 11 | -ms-text-size-adjust: 100%; /* 2 */ 12 | -webkit-text-size-adjust: 100%; /* 2 */ 13 | } 14 | 15 | /** 16 | * Remove default margin. 17 | */ 18 | 19 | body { 20 | margin: 0; 21 | } 22 | 23 | /* HTML5 display definitions 24 | ========================================================================== */ 25 | 26 | /** 27 | * Correct `block` display not defined for any HTML5 element in IE 8/9. 28 | * Correct `block` display not defined for `details` or `summary` in IE 10/11 29 | * and Firefox. 30 | * Correct `block` display not defined for `main` in IE 11. 31 | */ 32 | 33 | article, 34 | aside, 35 | details, 36 | figcaption, 37 | figure, 38 | footer, 39 | header, 40 | hgroup, 41 | main, 42 | menu, 43 | nav, 44 | section, 45 | summary { 46 | display: block; 47 | } 48 | 49 | /** 50 | * 1. Correct `inline-block` display not defined in IE 8/9. 51 | * 2. Normalize vertical alignment of `progress` in Chrome, Firefox, and Opera. 52 | */ 53 | 54 | audio, 55 | canvas, 56 | progress, 57 | video { 58 | display: inline-block; /* 1 */ 59 | vertical-align: baseline; /* 2 */ 60 | } 61 | 62 | /** 63 | * Prevent modern browsers from displaying `audio` without controls. 64 | * Remove excess height in iOS 5 devices. 65 | */ 66 | 67 | audio:not([controls]) { 68 | display: none; 69 | height: 0; 70 | } 71 | 72 | /** 73 | * Address `[hidden]` styling not present in IE 8/9/10. 74 | * Hide the `template` element in IE 8/9/11, Safari, and Firefox < 22. 75 | */ 76 | 77 | [hidden], 78 | template { 79 | display: none; 80 | } 81 | 82 | /* Links 83 | ========================================================================== */ 84 | 85 | /** 86 | * Remove the gray background color from active links in IE 10. 87 | */ 88 | 89 | a { 90 | background-color: transparent; 91 | } 92 | 93 | /** 94 | * Improve readability when focused and also mouse hovered in all browsers. 95 | */ 96 | 97 | a:active, 98 | a:hover { 99 | outline: 0; 100 | } 101 | 102 | /* Text-level semantics 103 | ========================================================================== */ 104 | 105 | /** 106 | * Address styling not present in IE 8/9/10/11, Safari, and Chrome. 107 | */ 108 | 109 | abbr[title] { 110 | border-bottom: 1px dotted; 111 | } 112 | 113 | /** 114 | * Address style set to `bolder` in Firefox 4+, Safari, and Chrome. 
115 | */ 116 | 117 | b, 118 | strong { 119 | font-weight: bold; 120 | } 121 | 122 | /** 123 | * Address styling not present in Safari and Chrome. 124 | */ 125 | 126 | dfn { 127 | font-style: italic; 128 | } 129 | 130 | /** 131 | * Address variable `h1` font-size and margin within `section` and `article` 132 | * contexts in Firefox 4+, Safari, and Chrome. 133 | */ 134 | 135 | h1 { 136 | font-size: 2em; 137 | margin: 0.67em 0; 138 | } 139 | 140 | /** 141 | * Address styling not present in IE 8/9. 142 | */ 143 | 144 | mark { 145 | background: #ff0; 146 | color: #000; 147 | } 148 | 149 | /** 150 | * Address inconsistent and variable font size in all browsers. 151 | */ 152 | 153 | small { 154 | font-size: 80%; 155 | } 156 | 157 | /** 158 | * Prevent `sub` and `sup` affecting `line-height` in all browsers. 159 | */ 160 | 161 | sub, 162 | sup { 163 | font-size: 75%; 164 | line-height: 0; 165 | position: relative; 166 | vertical-align: baseline; 167 | } 168 | 169 | sup { 170 | top: -0.5em; 171 | } 172 | 173 | sub { 174 | bottom: -0.25em; 175 | } 176 | 177 | /* Embedded content 178 | ========================================================================== */ 179 | 180 | /** 181 | * Remove border when inside `a` element in IE 8/9/10. 182 | */ 183 | 184 | img { 185 | border: 0; 186 | } 187 | 188 | /** 189 | * Correct overflow not hidden in IE 9/10/11. 190 | */ 191 | 192 | svg:not(:root) { 193 | overflow: hidden; 194 | } 195 | 196 | /* Grouping content 197 | ========================================================================== */ 198 | 199 | /** 200 | * Address margin not present in IE 8/9 and Safari. 201 | */ 202 | 203 | figure { 204 | margin: 1em 40px; 205 | } 206 | 207 | /** 208 | * Address differences between Firefox and other browsers. 209 | */ 210 | 211 | hr { 212 | -moz-box-sizing: content-box; 213 | box-sizing: content-box; 214 | height: 0; 215 | } 216 | 217 | /** 218 | * Contain overflow in all browsers. 219 | */ 220 | 221 | pre { 222 | overflow: auto; 223 | } 224 | 225 | /** 226 | * Address odd `em`-unit font size rendering in all browsers. 227 | */ 228 | 229 | code, 230 | kbd, 231 | pre, 232 | samp { 233 | font-family: monospace, monospace; 234 | font-size: 1em; 235 | } 236 | 237 | /* Forms 238 | ========================================================================== */ 239 | 240 | /** 241 | * Known limitation: by default, Chrome and Safari on OS X allow very limited 242 | * styling of `select`, unless a `border` property is set. 243 | */ 244 | 245 | /** 246 | * 1. Correct color not being inherited. 247 | * Known issue: affects color of disabled elements. 248 | * 2. Correct font properties not being inherited. 249 | * 3. Address margins set differently in Firefox 4+, Safari, and Chrome. 250 | */ 251 | 252 | button, 253 | input, 254 | optgroup, 255 | select, 256 | textarea { 257 | color: inherit; /* 1 */ 258 | font: inherit; /* 2 */ 259 | margin: 0; /* 3 */ 260 | } 261 | 262 | /** 263 | * Address `overflow` set to `hidden` in IE 8/9/10/11. 264 | */ 265 | 266 | button { 267 | overflow: visible; 268 | } 269 | 270 | /** 271 | * Address inconsistent `text-transform` inheritance for `button` and `select`. 272 | * All other form control elements do not inherit `text-transform` values. 273 | * Correct `button` style inheritance in Firefox, IE 8/9/10/11, and Opera. 274 | * Correct `select` style inheritance in Firefox. 275 | */ 276 | 277 | button, 278 | select { 279 | text-transform: none; 280 | } 281 | 282 | /** 283 | * 1. 
Avoid the WebKit bug in Android 4.0.* where (2) destroys native `audio` 284 | * and `video` controls. 285 | * 2. Correct inability to style clickable `input` types in iOS. 286 | * 3. Improve usability and consistency of cursor style between image-type 287 | * `input` and others. 288 | */ 289 | 290 | button, 291 | html input[type="button"], /* 1 */ 292 | input[type="reset"], 293 | input[type="submit"] { 294 | -webkit-appearance: button; /* 2 */ 295 | cursor: pointer; /* 3 */ 296 | } 297 | 298 | /** 299 | * Re-set default cursor for disabled elements. 300 | */ 301 | 302 | button[disabled], 303 | html input[disabled] { 304 | cursor: default; 305 | } 306 | 307 | /** 308 | * Remove inner padding and border in Firefox 4+. 309 | */ 310 | 311 | button::-moz-focus-inner, 312 | input::-moz-focus-inner { 313 | border: 0; 314 | padding: 0; 315 | } 316 | 317 | /** 318 | * Address Firefox 4+ setting `line-height` on `input` using `!important` in 319 | * the UA stylesheet. 320 | */ 321 | 322 | input { 323 | line-height: normal; 324 | } 325 | 326 | /** 327 | * It's recommended that you don't attempt to style these elements. 328 | * Firefox's implementation doesn't respect box-sizing, padding, or width. 329 | * 330 | * 1. Address box sizing set to `content-box` in IE 8/9/10. 331 | * 2. Remove excess padding in IE 8/9/10. 332 | */ 333 | 334 | input[type="checkbox"], 335 | input[type="radio"] { 336 | box-sizing: border-box; /* 1 */ 337 | padding: 0; /* 2 */ 338 | } 339 | 340 | /** 341 | * Fix the cursor style for Chrome's increment/decrement buttons. For certain 342 | * `font-size` values of the `input`, it causes the cursor style of the 343 | * decrement button to change from `default` to `text`. 344 | */ 345 | 346 | input[type="number"]::-webkit-inner-spin-button, 347 | input[type="number"]::-webkit-outer-spin-button { 348 | height: auto; 349 | } 350 | 351 | /** 352 | * 1. Address `appearance` set to `searchfield` in Safari and Chrome. 353 | * 2. Address `box-sizing` set to `border-box` in Safari and Chrome 354 | * (include `-moz` to future-proof). 355 | */ 356 | 357 | input[type="search"] { 358 | -webkit-appearance: textfield; /* 1 */ 359 | -moz-box-sizing: content-box; 360 | -webkit-box-sizing: content-box; /* 2 */ 361 | box-sizing: content-box; 362 | } 363 | 364 | /** 365 | * Remove inner padding and search cancel button in Safari and Chrome on OS X. 366 | * Safari (but not Chrome) clips the cancel button when the search input has 367 | * padding (and `textfield` appearance). 368 | */ 369 | 370 | input[type="search"]::-webkit-search-cancel-button, 371 | input[type="search"]::-webkit-search-decoration { 372 | -webkit-appearance: none; 373 | } 374 | 375 | /** 376 | * Define consistent border, margin, and padding. 377 | */ 378 | 379 | fieldset { 380 | border: 1px solid #c0c0c0; 381 | margin: 0 2px; 382 | padding: 0.35em 0.625em 0.75em; 383 | } 384 | 385 | /** 386 | * 1. Correct `color` not being inherited in IE 8/9/10/11. 387 | * 2. Remove padding so people aren't caught out if they zero out fieldsets. 388 | */ 389 | 390 | legend { 391 | border: 0; /* 1 */ 392 | padding: 0; /* 2 */ 393 | } 394 | 395 | /** 396 | * Remove default vertical scrollbar in IE 8/9/10/11. 397 | */ 398 | 399 | textarea { 400 | overflow: auto; 401 | } 402 | 403 | /** 404 | * Don't inherit the `font-weight` (applied by a rule above). 405 | * NOTE: the default cannot safely be changed in Chrome and Safari on OS X. 
406 | */ 407 | 408 | optgroup { 409 | font-weight: bold; 410 | } 411 | 412 | /* Tables 413 | ========================================================================== */ 414 | 415 | /** 416 | * Remove most spacing between table cells. 417 | */ 418 | 419 | table { 420 | border-collapse: collapse; 421 | border-spacing: 0; 422 | } 423 | 424 | td, 425 | th { 426 | padding: 0; 427 | } -------------------------------------------------------------------------------- /docs/libs/external/skeleton/skeleton.css: -------------------------------------------------------------------------------- 1 | /* 2 | * Skeleton V2.0.4 3 | * Copyright 2014, Dave Gamache 4 | * www.getskeleton.com 5 | * Free to use under the MIT license. 6 | * http://www.opensource.org/licenses/mit-license.php 7 | * 12/29/2014 8 | */ 9 | 10 | 11 | /* Table of contents 12 | –––––––––––––––––––––––––––––––––––––––––––––––––– 13 | - Grid 14 | - Base Styles 15 | - Typography 16 | - Links 17 | - Buttons 18 | - Forms 19 | - Lists 20 | - Code 21 | - Tables 22 | - Spacing 23 | - Utilities 24 | - Clearing 25 | - Media Queries 26 | */ 27 | 28 | 29 | /* Grid 30 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 31 | .container { 32 | position: relative; 33 | width: 100%; 34 | max-width: 960px; 35 | margin: 0 auto; 36 | padding: 0 20px; 37 | box-sizing: border-box; } 38 | .column, 39 | .columns { 40 | width: 100%; 41 | float: left; 42 | box-sizing: border-box; } 43 | 44 | /* For devices larger than 400px */ 45 | @media (min-width: 400px) { 46 | .container { 47 | width: 85%; 48 | padding: 0; } 49 | } 50 | 51 | /* For devices larger than 550px */ 52 | @media (min-width: 550px) { 53 | .container { 54 | width: 80%; } 55 | .column, 56 | .columns { 57 | margin-left: 4%; } 58 | .column:first-child, 59 | .columns:first-child { 60 | margin-left: 0; } 61 | 62 | .one.column, 63 | .one.columns { width: 4.66666666667%; } 64 | .two.columns { width: 13.3333333333%; } 65 | .three.columns { width: 22%; } 66 | .four.columns { width: 30.6666666667%; } 67 | .five.columns { width: 39.3333333333%; } 68 | .six.columns { width: 48%; } 69 | .seven.columns { width: 56.6666666667%; } 70 | .eight.columns { width: 65.3333333333%; } 71 | .nine.columns { width: 74.0%; } 72 | .ten.columns { width: 82.6666666667%; } 73 | .eleven.columns { width: 91.3333333333%; } 74 | .twelve.columns { width: 100%; margin-left: 0; } 75 | 76 | .one-third.column { width: 30.6666666667%; } 77 | .two-thirds.column { width: 65.3333333333%; } 78 | 79 | .one-half.column { width: 48%; } 80 | 81 | /* Offsets */ 82 | .offset-by-one.column, 83 | .offset-by-one.columns { margin-left: 8.66666666667%; } 84 | .offset-by-two.column, 85 | .offset-by-two.columns { margin-left: 17.3333333333%; } 86 | .offset-by-three.column, 87 | .offset-by-three.columns { margin-left: 26%; } 88 | .offset-by-four.column, 89 | .offset-by-four.columns { margin-left: 34.6666666667%; } 90 | .offset-by-five.column, 91 | .offset-by-five.columns { margin-left: 43.3333333333%; } 92 | .offset-by-six.column, 93 | .offset-by-six.columns { margin-left: 52%; } 94 | .offset-by-seven.column, 95 | .offset-by-seven.columns { margin-left: 60.6666666667%; } 96 | .offset-by-eight.column, 97 | .offset-by-eight.columns { margin-left: 69.3333333333%; } 98 | .offset-by-nine.column, 99 | .offset-by-nine.columns { margin-left: 78.0%; } 100 | .offset-by-ten.column, 101 | .offset-by-ten.columns { margin-left: 86.6666666667%; } 102 | .offset-by-eleven.column, 103 | .offset-by-eleven.columns { margin-left: 95.3333333333%; } 104 | 105 | 
.offset-by-one-third.column, 106 | .offset-by-one-third.columns { margin-left: 34.6666666667%; } 107 | .offset-by-two-thirds.column, 108 | .offset-by-two-thirds.columns { margin-left: 69.3333333333%; } 109 | 110 | .offset-by-one-half.column, 111 | .offset-by-one-half.columns { margin-left: 52%; } 112 | 113 | } 114 | 115 | 116 | /* Base Styles 117 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 118 | /* NOTE 119 | html is set to 62.5% so that all the REM measurements throughout Skeleton 120 | are based on 10px sizing. So basically 1.5rem = 15px :) */ 121 | html { 122 | font-size: 62.5%; } 123 | body { 124 | font-size: 1.5em; /* currently ems cause chrome bug misinterpreting rems on body element */ 125 | line-height: 1.6; 126 | font-weight: 400; 127 | font-family: "Raleway", "HelveticaNeue", "Helvetica Neue", Helvetica, Arial, sans-serif; 128 | color: #222; } 129 | 130 | 131 | /* Typography 132 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 133 | h1, h2, h3, h4, h5, h6 { 134 | margin-top: 0; 135 | margin-bottom: 2rem; 136 | font-weight: 300; } 137 | h1 { font-size: 4.0rem; line-height: 1.2; letter-spacing: -.1rem;} 138 | h2 { font-size: 3.6rem; line-height: 1.25; letter-spacing: -.1rem; } 139 | h3 { font-size: 3.0rem; line-height: 1.3; letter-spacing: -.1rem; } 140 | h4 { font-size: 2.4rem; line-height: 1.35; letter-spacing: -.08rem; } 141 | h5 { font-size: 1.8rem; line-height: 1.5; letter-spacing: -.05rem; } 142 | h6 { font-size: 1.5rem; line-height: 1.6; letter-spacing: 0; } 143 | 144 | /* Larger than phablet */ 145 | @media (min-width: 550px) { 146 | h1 { font-size: 5.0rem; } 147 | h2 { font-size: 4.2rem; } 148 | h3 { font-size: 3.6rem; } 149 | h4 { font-size: 3.0rem; } 150 | h5 { font-size: 2.4rem; } 151 | h6 { font-size: 1.5rem; } 152 | } 153 | 154 | p { 155 | margin-top: 0; } 156 | 157 | 158 | /* Links 159 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 160 | a { 161 | color: #1EAEDB; } 162 | a:hover { 163 | color: #0FA0CE; } 164 | 165 | 166 | /* Buttons 167 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 168 | .button, 169 | button, 170 | input[type="submit"], 171 | input[type="reset"], 172 | input[type="button"] { 173 | display: inline-block; 174 | height: 38px; 175 | padding: 0 30px; 176 | color: #555; 177 | text-align: center; 178 | font-size: 11px; 179 | font-weight: 600; 180 | line-height: 38px; 181 | letter-spacing: .1rem; 182 | text-transform: uppercase; 183 | text-decoration: none; 184 | white-space: nowrap; 185 | background-color: transparent; 186 | border-radius: 4px; 187 | border: 1px solid #bbb; 188 | cursor: pointer; 189 | box-sizing: border-box; } 190 | .button:hover, 191 | button:hover, 192 | input[type="submit"]:hover, 193 | input[type="reset"]:hover, 194 | input[type="button"]:hover, 195 | .button:focus, 196 | button:focus, 197 | input[type="submit"]:focus, 198 | input[type="reset"]:focus, 199 | input[type="button"]:focus { 200 | color: #333; 201 | border-color: #888; 202 | outline: 0; } 203 | .button.button-primary, 204 | button.button-primary, 205 | input[type="submit"].button-primary, 206 | input[type="reset"].button-primary, 207 | input[type="button"].button-primary { 208 | color: #FFF; 209 | background-color: #33C3F0; 210 | border-color: #33C3F0; } 211 | .button.button-primary:hover, 212 | button.button-primary:hover, 213 | input[type="submit"].button-primary:hover, 214 | input[type="reset"].button-primary:hover, 215 | input[type="button"].button-primary:hover, 216 | .button.button-primary:focus, 217 | button.button-primary:focus, 
218 | input[type="submit"].button-primary:focus, 219 | input[type="reset"].button-primary:focus, 220 | input[type="button"].button-primary:focus { 221 | color: #FFF; 222 | background-color: #1EAEDB; 223 | border-color: #1EAEDB; } 224 | 225 | 226 | /* Forms 227 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 228 | input[type="email"], 229 | input[type="number"], 230 | input[type="search"], 231 | input[type="text"], 232 | input[type="tel"], 233 | input[type="url"], 234 | input[type="password"], 235 | textarea, 236 | select { 237 | height: 38px; 238 | padding: 6px 10px; /* The 6px vertically centers text on FF, ignored by Webkit */ 239 | background-color: #fff; 240 | border: 1px solid #D1D1D1; 241 | border-radius: 4px; 242 | box-shadow: none; 243 | box-sizing: border-box; } 244 | /* Removes awkward default styles on some inputs for iOS */ 245 | input[type="email"], 246 | input[type="number"], 247 | input[type="search"], 248 | input[type="text"], 249 | input[type="tel"], 250 | input[type="url"], 251 | input[type="password"], 252 | textarea { 253 | -webkit-appearance: none; 254 | -moz-appearance: none; 255 | appearance: none; } 256 | textarea { 257 | min-height: 65px; 258 | padding-top: 6px; 259 | padding-bottom: 6px; } 260 | input[type="email"]:focus, 261 | input[type="number"]:focus, 262 | input[type="search"]:focus, 263 | input[type="text"]:focus, 264 | input[type="tel"]:focus, 265 | input[type="url"]:focus, 266 | input[type="password"]:focus, 267 | textarea:focus, 268 | select:focus { 269 | border: 1px solid #33C3F0; 270 | outline: 0; } 271 | label, 272 | legend { 273 | display: block; 274 | margin-bottom: .5rem; 275 | font-weight: 600; } 276 | fieldset { 277 | padding: 0; 278 | border-width: 0; } 279 | input[type="checkbox"], 280 | input[type="radio"] { 281 | display: inline; } 282 | label > .label-body { 283 | display: inline-block; 284 | margin-left: .5rem; 285 | font-weight: normal; } 286 | 287 | 288 | /* Lists 289 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 290 | ul { 291 | list-style: circle inside; } 292 | ol { 293 | list-style: decimal inside; } 294 | ol, ul { 295 | padding-left: 0; 296 | margin-top: 0; } 297 | ul ul, 298 | ul ol, 299 | ol ol, 300 | ol ul { 301 | margin: 1.5rem 0 1.5rem 3rem; 302 | font-size: 90%; } 303 | li { 304 | margin-bottom: 1rem; } 305 | 306 | 307 | /* Code 308 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 309 | code { 310 | padding: .2rem .5rem; 311 | margin: 0 .2rem; 312 | font-size: 90%; 313 | white-space: nowrap; 314 | background: #F1F1F1; 315 | border: 1px solid #E1E1E1; 316 | border-radius: 4px; } 317 | pre > code { 318 | display: block; 319 | padding: 1rem 1.5rem; 320 | white-space: pre; } 321 | 322 | 323 | /* Tables 324 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 325 | th, 326 | td { 327 | padding: 12px 15px; 328 | text-align: left; 329 | border-bottom: 1px solid #E1E1E1; } 330 | th:first-child, 331 | td:first-child { 332 | padding-left: 0; } 333 | th:last-child, 334 | td:last-child { 335 | padding-right: 0; } 336 | 337 | 338 | /* Spacing 339 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 340 | button, 341 | .button { 342 | margin-bottom: 1rem; } 343 | input, 344 | textarea, 345 | select, 346 | fieldset { 347 | margin-bottom: 1.5rem; } 348 | pre, 349 | blockquote, 350 | dl, 351 | figure, 352 | table, 353 | p, 354 | ul, 355 | ol, 356 | form { 357 | margin-bottom: 2.5rem; } 358 | 359 | 360 | /* Utilities 361 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 362 | .u-full-width { 363 | width: 100%; 
364 | box-sizing: border-box; } 365 | .u-max-full-width { 366 | max-width: 100%; 367 | box-sizing: border-box; } 368 | .u-pull-right { 369 | float: right; } 370 | .u-pull-left { 371 | float: left; } 372 | 373 | 374 | /* Misc 375 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 376 | hr { 377 | margin-top: 3rem; 378 | margin-bottom: 3.5rem; 379 | border-width: 0; 380 | border-top: 1px solid #E1E1E1; } 381 | 382 | 383 | /* Clearing 384 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 385 | 386 | /* Self Clearing Goodness */ 387 | .container:after, 388 | .row:after, 389 | .u-cf { 390 | content: ""; 391 | display: table; 392 | clear: both; } 393 | 394 | 395 | /* Media Queries 396 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 397 | /* 398 | Note: The best way to structure the use of media queries is to create the queries 399 | near the relevant code. For example, if you wanted to change the styles for buttons 400 | on small devices, paste the mobile query code up in the buttons section and style it 401 | there. 402 | */ 403 | 404 | 405 | /* Larger than mobile */ 406 | @media (min-width: 400px) {} 407 | 408 | /* Larger than phablet (also point when grid becomes active) */ 409 | @media (min-width: 550px) {} 410 | 411 | /* Larger than tablet */ 412 | @media (min-width: 750px) {} 413 | 414 | /* Larger than desktop */ 415 | @media (min-width: 1000px) {} 416 | 417 | /* Larger than Desktop HD */ 418 | @media (min-width: 1200px) {} 419 | -------------------------------------------------------------------------------- /docs/libs/external/skeleton_tabs/skeleton-tabs.css: -------------------------------------------------------------------------------- 1 | 2 | ul.tab-nav { 3 | list-style: none; 4 | border-bottom: 1px solid #bbb; 5 | padding-left: 5px; 6 | } 7 | 8 | ul.tab-nav li { 9 | display: inline; 10 | } 11 | 12 | ul.tab-nav li .button { 13 | font-size: 13px; 14 | font-weight: 500; 15 | border-bottom-left-radius: 0; 16 | border-bottom-right-radius: 0; 17 | margin-bottom: -1px; 18 | border-bottom: none; 19 | } 20 | 21 | ul.tab-nav li .active.button { 22 | border-bottom: 1px solid #fff; 23 | color: #000; 24 | font-weight: 600; 25 | } 26 | 27 | .tab-content .tab-pane { 28 | display: none; 29 | } 30 | 31 | .tab-content .tab-pane.active { 32 | display: block; 33 | } -------------------------------------------------------------------------------- /docs/libs/external/skeleton_tabs/skeleton-tabs.js: -------------------------------------------------------------------------------- 1 | 2 | $(function() { 3 | $('ul.tab-nav li .button').click(function() { 4 | var href = $(this).attr('data-ref'); 5 | 6 | $('li .active.button', $(this).parent().parent()).removeClass('active'); 7 | $(this).addClass('active'); 8 | 9 | $('.tab-pane.active', $(href).parent()).removeClass('active'); 10 | $(href).addClass('active'); 11 | 12 | /* 13 | var toScroll = $(this).parent().parent().parent().parent(); 14 | 15 | $('html, body').animate({ 16 | scrollTop: toScroll.offset().top 17 | }, 1000); 18 | */ 19 | 20 | return false; 21 | }); 22 | }); 23 | -------------------------------------------------------------------------------- /docs/libs/external/timeline.css: -------------------------------------------------------------------------------- 1 | /* ================ The Timeline ================ */ 2 | .timeline { 3 | position: relative; 4 | width: 800px; 5 | margin: 0 auto; 6 | margin-top: 20px; 7 | padding: 1em 0; 8 | list-style-type: none; 9 | } 10 | 11 | .timeline:before { 12 | position: absolute; 13 | left: 
50%; 14 | top: 0; 15 | content: ' '; 16 | display: block; 17 | width: 6px; 18 | height: 100%; 19 | margin-left: -3px; 20 | background: rgb(80,80,80); 21 | background: -moz-linear-gradient(top, rgba(80,80,80,0) 0%, rgb(80,80,80) 8%, rgb(80,80,80) 92%, rgba(80,80,80,0) 100%); 22 | background: -webkit-gradient(linear, left top, left bottom, color-stop(0%,rgba(30,87,153,1)), color-stop(100%,rgba(125,185,232,1))); 23 | background: -webkit-linear-gradient(top, rgba(80,80,80,0) 0%, rgb(80,80,80) 8%, rgb(80,80,80) 92%, rgba(80,80,80,0) 100%); 24 | background: -o-linear-gradient(top, rgba(80,80,80,0) 0%, rgb(80,80,80) 8%, rgb(80,80,80) 92%, rgba(80,80,80,0) 100%); 25 | background: -ms-linear-gradient(top, rgba(80,80,80,0) 0%, rgb(80,80,80) 8%, rgb(80,80,80) 92%, rgba(80,80,80,0) 100%); 26 | background: linear-gradient(to bottom, rgba(80,80,80,0) 0%, rgb(80,80,80) 8%, rgb(80,80,80) 92%, rgba(80,80,80,0) 100%); 27 | 28 | z-index: 5; 29 | } 30 | 31 | .timeline li { 32 | padding: 1em 0; 33 | } 34 | 35 | .timeline li:after { 36 | content: ""; 37 | display: block; 38 | height: 0; 39 | clear: both; 40 | visibility: hidden; 41 | } 42 | 43 | .direction-l { 44 | position: relative; 45 | width: 370px; 46 | float: left; 47 | text-align: right; 48 | } 49 | 50 | .direction-r { 51 | position: relative; 52 | width: 370px; 53 | float: right; 54 | } 55 | 56 | .flag-wrapper { 57 | position: relative; 58 | display: inline-block; 59 | 60 | text-align: center; 61 | } 62 | 63 | .flag { 64 | position: relative; 65 | display: inline; 66 | background: rgb(248,248,248); 67 | padding: 6px 10px; 68 | border-radius: 5px; 69 | 70 | font-weight: 600; 71 | text-align: left; 72 | } 73 | 74 | .direction-l .flag { 75 | -webkit-box-shadow: -1px 1px 1px rgba(0,0,0,0.15), 0 0 1px rgba(0,0,0,0.15); 76 | -moz-box-shadow: -1px 1px 1px rgba(0,0,0,0.15), 0 0 1px rgba(0,0,0,0.15); 77 | box-shadow: -1px 1px 1px rgba(0,0,0,0.15), 0 0 1px rgba(0,0,0,0.15); 78 | } 79 | 80 | .direction-r .flag { 81 | -webkit-box-shadow: 1px 1px 1px rgba(0,0,0,0.15), 0 0 1px rgba(0,0,0,0.15); 82 | -moz-box-shadow: 1px 1px 1px rgba(0,0,0,0.15), 0 0 1px rgba(0,0,0,0.15); 83 | box-shadow: 1px 1px 1px rgba(0,0,0,0.15), 0 0 1px rgba(0,0,0,0.15); 84 | } 85 | 86 | .direction-l .flag:before, 87 | .direction-r .flag:before { 88 | position: absolute; 89 | top: 50%; 90 | right: -40px; 91 | content: ' '; 92 | display: block; 93 | width: 12px; 94 | height: 12px; 95 | margin-top: -10px; 96 | background: #fff; 97 | border-radius: 10px; 98 | border: 4px solid rgb(255,80,80); 99 | z-index: 10; 100 | } 101 | 102 | .direction-r .flag:before { 103 | left: -40px; 104 | } 105 | 106 | .direction-l .flag:after { 107 | content: ""; 108 | position: absolute; 109 | left: 100%; 110 | top: 50%; 111 | height: 0; 112 | width: 0; 113 | margin-top: -8px; 114 | border: solid transparent; 115 | border-left-color: rgb(248,248,248); 116 | border-width: 8px; 117 | pointer-events: none; 118 | } 119 | 120 | .direction-r .flag:after { 121 | content: ""; 122 | position: absolute; 123 | right: 100%; 124 | top: 50%; 125 | height: 0; 126 | width: 0; 127 | margin-top: -8px; 128 | border: solid transparent; 129 | border-right-color: rgb(248,248,248); 130 | border-width: 8px; 131 | pointer-events: none; 132 | } 133 | 134 | .time-wrapper { 135 | display: inline; 136 | 137 | line-height: 1em; 138 | font-size: 0.66666em; 139 | color: rgb(250,80,80); 140 | vertical-align: middle; 141 | } 142 | 143 | .direction-l .time-wrapper { 144 | float: left; 145 | } 146 | 147 | .direction-r .time-wrapper { 148 | float: right; 
149 | } 150 | 151 | .time { 152 | display: inline-block; 153 | padding: 4px 6px; 154 | background: rgb(248,248,248); 155 | } 156 | 157 | .desc { 158 | margin: 1em 0.75em 0 0; 159 | font-size: 0.77777em; 160 | /* font-style: italic; */ 161 | line-height: 1.5em; 162 | } 163 | 164 | .direction-r .desc { 165 | margin: 1em 0 0 0.75em; 166 | } 167 | 168 | /* ================ Timeline Media Queries ================ */ 169 | 170 | @media screen and (max-width: 1000px) { 171 | .timeline { 172 | width: 640px; 173 | } 174 | 175 | .direction-l { 176 | position: relative; 177 | width: 290px; 178 | float: left; 179 | text-align: right; 180 | } 181 | 182 | .direction-r { 183 | position: relative; 184 | width: 290px; 185 | float: right; 186 | } 187 | } 188 | 189 | @media screen and (max-width: 800px) { 190 | .timeline { 191 | width: 620px; 192 | } 193 | 194 | .direction-l { 195 | position: relative; 196 | width: 280px; 197 | float: left; 198 | text-align: right; 199 | } 200 | 201 | .direction-r { 202 | position: relative; 203 | width: 280px; 204 | float: right; 205 | } 206 | } 207 | 208 | @media screen and (max-width: 660px) { 209 | 210 | .timeline { 211 | width: 100%; 212 | padding: 4em 0 1em 0; 213 | } 214 | 215 | .timeline li { 216 | padding: 2em 0; 217 | } 218 | 219 | .direction-l, 220 | .direction-r { 221 | float: none; 222 | width: 100%; 223 | 224 | text-align: center; 225 | } 226 | 227 | .flag-wrapper { 228 | text-align: center; 229 | } 230 | 231 | .flag { 232 | background: rgb(255,255,255); 233 | z-index: 15; 234 | } 235 | 236 | .direction-l .flag:before, 237 | .direction-r .flag:before { 238 | position: absolute; 239 | top: -30px; 240 | left: 50%; 241 | content: ' '; 242 | display: block; 243 | width: 12px; 244 | height: 12px; 245 | margin-left: -9px; 246 | background: #fff; 247 | border-radius: 10px; 248 | border: 4px solid rgb(255,80,80); 249 | z-index: 10; 250 | } 251 | 252 | .direction-l .flag:after, 253 | .direction-r .flag:after { 254 | content: ""; 255 | position: absolute; 256 | left: 50%; 257 | top: -8px; 258 | height: 0; 259 | width: 0; 260 | margin-left: -8px; 261 | border: solid transparent; 262 | border-bottom-color: rgb(255,255,255); 263 | border-width: 8px; 264 | pointer-events: none; 265 | } 266 | 267 | .time-wrapper { 268 | display: block; 269 | position: relative; 270 | margin: 4px 0 0 0; 271 | z-index: 14; 272 | } 273 | 274 | .direction-l .time-wrapper { 275 | float: none; 276 | } 277 | 278 | .direction-r .time-wrapper { 279 | float: none; 280 | } 281 | 282 | .desc { 283 | position: relative; 284 | margin: 1em 0 0 0; 285 | padding: 1em; 286 | background: rgb(245,245,245); 287 | -webkit-box-shadow: 0 0 1px rgba(0,0,0,0.20); 288 | -moz-box-shadow: 0 0 1px rgba(0,0,0,0.20); 289 | box-shadow: 0 0 1px rgba(0,0,0,0.20); 290 | 291 | z-index: 15; 292 | } 293 | 294 | .direction-l .desc, 295 | .direction-r .desc { 296 | position: relative; 297 | margin: 1em 1em 0 1em; 298 | padding: 1em; 299 | 300 | z-index: 15; 301 | } 302 | } 303 | 304 | @media screen and (min-width: 400px ?? 
max-width: 660px) { 305 | .direction-l .desc, 306 | .direction-r .desc { 307 | margin: 1em 4em 0 4em; 308 | } 309 | } 310 | -------------------------------------------------------------------------------- /docs/libs/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMBCvision/CompRess/c5e57edce75da96482fd36eac484c5aca9676945/docs/libs/icon.png -------------------------------------------------------------------------------- /eval_knn.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | from collections import Counter, OrderedDict 3 | from random import shuffle 4 | import argparse 5 | import os 6 | import random 7 | import shutil 8 | import time 9 | import warnings 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.backends.cudnn as cudnn 15 | import torch.optim 16 | import torch.utils.data 17 | from torch.utils.data import DataLoader 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | import torchvision.models as models 21 | import torch.nn.functional as F 22 | import numpy as np 23 | import faiss 24 | 25 | from tools import * 26 | from models.resnet import resnet18, resnet50 27 | from models.alexnet import AlexNet as alexnet 28 | from models.mobilenet import MobileNetV2 as mobilenet 29 | from models.resnet_swav import resnet50w5 30 | from eval_linear import load_weights 31 | 32 | 33 | parser = argparse.ArgumentParser(description='NN evaluation') 34 | parser.add_argument('data', metavar='DIR', help='path to dataset') 35 | parser.add_argument('-j', '--workers', default=8, type=int, 36 | help='number of data loading workers (default: 4)') 37 | parser.add_argument('-a', '--arch', type=str, default='alexnet', 38 | choices=['alexnet' , 'resnet18' , 'resnet50', 'mobilenet' , 39 | 'moco_alexnet' , 'moco_resnet18' , 'moco_resnet50', 'moco_mobilenet', 'resnet50w5', 40 | 'sup_alexnet' , 'sup_resnet18' , 'sup_resnet50', 'sup_mobilenet', 'pt_alexnet']) 41 | 42 | 43 | parser.add_argument('-b', '--batch-size', default=256, type=int, 44 | help='mini-batch size (default: 256), this is the total ' 45 | 'batch size of all GPUs on the current node when ' 46 | 'using Data Parallel or Distributed Data Parallel') 47 | parser.add_argument('-p', '--print-freq', default=90, type=int, 48 | help='print frequency (default: 10)') 49 | parser.add_argument('--save', default='./output/cluster_alignment_1', type=str, 50 | help='experiment output directory') 51 | parser.add_argument('--weights', dest='weights', type=str, 52 | help='pre-trained model weights') 53 | parser.add_argument('--load_cache', action='store_true', 54 | help='should the features be recomputed or loaded from the cache') 55 | parser.add_argument('-k', default=1, type=int, help='k in kNN') 56 | 57 | 58 | def main(): 59 | global logger 60 | 61 | args = parser.parse_args() 62 | makedirs(args.save) 63 | 64 | logger = get_logger( 65 | logpath=os.path.join(args.save, 'logs'), 66 | filepath=os.path.abspath(__file__) 67 | ) 68 | def print_pass(*args): 69 | logger.info(*args) 70 | builtins.print = print_pass 71 | 72 | print(args) 73 | 74 | main_worker(args) 75 | 76 | 77 | def get_model(args): 78 | 79 | model = None 80 | if args.arch == 'alexnet' : 81 | model = alexnet() 82 | model.fc = nn.Sequential() 83 | model = torch.nn.DataParallel(model).cuda() 84 | checkpoint = torch.load(args.weights) 85 | msg = 
model.load_state_dict(checkpoint['model'], strict=False) 86 | print(msg) 87 | 88 | elif args.arch == 'pt_alexnet' : 89 | model = models.alexnet(num_classes=16000) 90 | checkpoint = torch.load(args.weights) 91 | sd = checkpoint['state_dict'] 92 | sd = {k.replace('module.', ''): v for k, v in sd.items()} 93 | msg = model.load_state_dict(sd, strict=True) 94 | classif = list(model.classifier.children())[:5] 95 | model.classifier = nn.Sequential(*classif) 96 | model = torch.nn.DataParallel(model).cuda() 97 | print(model) 98 | print(msg) 99 | 100 | 101 | elif args.arch == 'resnet18' : 102 | model = resnet18() 103 | model.fc = nn.Sequential() 104 | model = torch.nn.DataParallel(model).cuda() 105 | checkpoint = torch.load(args.weights) 106 | model.load_state_dict(checkpoint['model'], strict=False) 107 | 108 | elif args.arch == 'mobilenet' : 109 | model = mobilenet() 110 | model.fc = nn.Sequential() 111 | model = torch.nn.DataParallel(model).cuda() 112 | checkpoint = torch.load(args.weights) 113 | model.load_state_dict(checkpoint['model'] , strict=False) 114 | 115 | elif args.arch == 'resnet50' : 116 | model = resnet50() 117 | model.fc = nn.Sequential() 118 | model = torch.nn.DataParallel(model).cuda() 119 | checkpoint = torch.load(args.weights) 120 | model.load_state_dict(checkpoint['model'], strict=False) 121 | 122 | elif args.arch == 'moco_alexnet' : 123 | model = alexnet() 124 | model.fc = nn.Sequential() 125 | model = nn.Sequential(OrderedDict([('encoder_q', model)])) 126 | model = model.cuda() 127 | checkpoint = torch.load(args.weights) 128 | model.load_state_dict(checkpoint['state_dict'] , strict=False) 129 | 130 | elif args.arch == 'moco_resnet18' : 131 | model = resnet18().cuda() 132 | model = nn.Sequential(OrderedDict([('encoder_q' , model)])) 133 | model = torch.nn.DataParallel(model).cuda() 134 | checkpoint = torch.load(args.weights) 135 | model.load_state_dict(checkpoint['state_dict'] , strict=False) 136 | model.module.encoder_q.fc = nn.Sequential() 137 | 138 | elif args.arch == 'moco_mobilenet' : 139 | model = mobilenet() 140 | model.fc = nn.Sequential() 141 | model = nn.Sequential(OrderedDict([('encoder_q', model)])) 142 | model = torch.nn.DataParallel(model).cuda() 143 | checkpoint = torch.load(args.weights) 144 | model.load_state_dict(checkpoint['state_dict'], strict=False) 145 | 146 | elif args.arch == 'moco_resnet50' : 147 | model = resnet50().cuda() 148 | model = nn.Sequential(OrderedDict([('encoder_q' , model)])) 149 | model = torch.nn.DataParallel(model).cuda() 150 | checkpoint = torch.load(args.weights) 151 | model.load_state_dict(checkpoint['state_dict'] , strict=False) 152 | model.module.encoder_q.fc = nn.Sequential() 153 | 154 | elif args.arch == 'resnet50w5': 155 | model = resnet50w5() 156 | model.l2norm = None 157 | load_weights(model, args.weights) 158 | model = torch.nn.DataParallel(model).cuda() 159 | 160 | elif args.arch == 'sup_alexnet' : 161 | model = models.alexnet(pretrained=True) 162 | modules = list(model.children())[:-1] 163 | classifier_modules = list(model.classifier.children())[:-1] 164 | modules.append(Flatten()) 165 | modules.append(nn.Sequential(*classifier_modules)) 166 | model = nn.Sequential(*modules) 167 | model = model.cuda() 168 | 169 | elif args.arch == 'sup_resnet18' : 170 | model = models.resnet18(pretrained=True) 171 | model.fc = nn.Sequential() 172 | model = torch.nn.DataParallel(model).cuda() 173 | 174 | elif args.arch == 'sup_mobilenet' : 175 | model = models.mobilenet_v2(pretrained=True) 176 | model.classifier = nn.Sequential() 177 | model 
= torch.nn.DataParallel(model).cuda() 178 | 179 | elif args.arch == 'sup_resnet50' : 180 | model = models.resnet50(pretrained=True) 181 | model.fc = nn.Sequential() 182 | model = torch.nn.DataParallel(model).cuda() 183 | 184 | for param in model.parameters(): 185 | param.requires_grad = False 186 | 187 | return model 188 | 189 | 190 | class ImageFolderEx(datasets.ImageFolder) : 191 | def __getitem__(self, index): 192 | sample, target = super(ImageFolderEx, self).__getitem__(index) 193 | return index, sample, target 194 | 195 | 196 | def get_loaders(dataset_dir, bs, workers): 197 | traindir = os.path.join(dataset_dir, 'train') 198 | valdir = os.path.join(dataset_dir, 'val') 199 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 200 | std=[0.229, 0.224, 0.225]) 201 | 202 | train_loader = DataLoader( 203 | ImageFolderEx(traindir, transforms.Compose([ 204 | transforms.Resize(256), 205 | transforms.CenterCrop(224), 206 | transforms.ToTensor(), 207 | normalize, 208 | ])), 209 | batch_size=bs, shuffle=False, 210 | num_workers=workers, pin_memory=True, 211 | ) 212 | 213 | val_loader = DataLoader( 214 | ImageFolderEx(valdir, transforms.Compose([ 215 | transforms.Resize(256), 216 | transforms.CenterCrop(224), 217 | transforms.ToTensor(), 218 | normalize, 219 | ])), 220 | batch_size=bs, shuffle=False, 221 | num_workers=workers, pin_memory=True, 222 | ) 223 | 224 | return train_loader, val_loader 225 | 226 | 227 | def main_worker(args): 228 | 229 | start = time.time() 230 | # Get train/val loader 231 | # --------------------------------------------------------------- 232 | train_loader, val_loader = get_loaders(args.data, args.batch_size, args.workers) 233 | 234 | # Create and load the model 235 | # If you want to evaluate your model, modify this part and load your model 236 | # ------------------------------------------------------------------------ 237 | # MODIFY 'get_model' TO EVALUATE YOUR MODEL 238 | model = get_model(args) 239 | 240 | # ------------------------------------------------------------------------ 241 | # Forward training samples throw the model and cache feats 242 | # ------------------------------------------------------------------------ 243 | cudnn.benchmark = True 244 | 245 | cached_feats = '%s/train_feats.pth.tar' % args.save 246 | if args.load_cache and os.path.exists(cached_feats): 247 | print('load train feats from cache =>') 248 | train_feats, train_labels, train_inds = torch.load(cached_feats) 249 | else: 250 | print('get train feats =>') 251 | train_feats, train_labels, train_inds = get_feats(train_loader, model, args.print_freq) 252 | torch.save((train_feats, train_labels, train_inds), cached_feats) 253 | 254 | cached_feats = '%s/val_feats.pth.tar' % args.save 255 | if args.load_cache and os.path.exists(cached_feats): 256 | print('load val feats from cache =>') 257 | val_feats, val_labels, val_inds = torch.load(cached_feats) 258 | else: 259 | print('get val feats =>') 260 | val_feats, val_labels, val_inds = get_feats(val_loader, model, args.print_freq) 261 | torch.save((val_feats, val_labels, val_inds), cached_feats) 262 | 263 | # ------------------------------------------------------------------------ 264 | # Calculate NN accuracy on validation set 265 | # ------------------------------------------------------------------------ 266 | 267 | train_feats = normalize(train_feats) 268 | val_feats = normalize(val_feats) 269 | acc = faiss_knn(train_feats, train_labels, val_feats, val_labels, args.k) 270 | nn_time = time.time() - start 271 | print('=> time : 
{:.2f}s'.format(nn_time)) 272 | print(' * Acc {:.2f}'.format(acc)) 273 | 274 | 275 | def normalize(x): 276 | return x / x.norm(2, dim=1, keepdim=True) 277 | 278 | 279 | def faiss_knn(feats_train, targets_train, feats_val, targets_val, k): 280 | feats_train = feats_train.numpy() 281 | targets_train = targets_train.numpy() 282 | feats_val = feats_val.numpy() 283 | targets_val = targets_val.numpy() 284 | 285 | d = feats_train.shape[-1] 286 | 287 | index = faiss.IndexFlatL2(d) # build the index 288 | co = faiss.GpuMultipleClonerOptions() 289 | co.useFloat16 = True 290 | co.shard = True 291 | gpu_index = faiss.index_cpu_to_all_gpus(index, co) 292 | gpu_index.add(feats_train) 293 | 294 | D, I = gpu_index.search(feats_val, k) 295 | 296 | pred = np.zeros(I.shape[0]) 297 | for i in range(I.shape[0]): 298 | votes = list(Counter(targets_train[I[i]]).items()) 299 | shuffle(votes) 300 | pred[i] = max(votes, key=lambda x: x[1])[0] 301 | 302 | acc = 100.0 * (pred == targets_val).mean() 303 | 304 | return acc 305 | 306 | 307 | def get_feats(loader, model, print_freq): 308 | batch_time = AverageMeter('Time', ':6.3f') 309 | progress = ProgressMeter( 310 | len(loader), 311 | [batch_time], 312 | prefix='Test: ') 313 | 314 | # switch to evaluate mode 315 | model.eval() 316 | feats, labels, indices, ptr = None, None, None, 0 317 | 318 | with torch.no_grad(): 319 | end = time.time() 320 | for i, (index, images, target) in enumerate(loader): 321 | images = images.cuda(non_blocking=True) 322 | cur_targets = target.cpu() 323 | cur_feats = model(images).cpu() 324 | cur_indices = index.cpu() 325 | 326 | B, D = cur_feats.shape 327 | inds = torch.arange(B) + ptr 328 | 329 | if not ptr: 330 | feats = torch.zeros((len(loader.dataset), D)).float() 331 | labels = torch.zeros(len(loader.dataset)).long() 332 | indices = torch.zeros(len(loader.dataset)).long() 333 | 334 | feats.index_copy_(0, inds, cur_feats) 335 | labels.index_copy_(0, inds, cur_targets) 336 | indices.index_copy_(0, inds, cur_indices) 337 | ptr += B 338 | 339 | # measure elapsed time 340 | batch_time.update(time.time() - end) 341 | end = time.time() 342 | 343 | if i % print_freq == 0: 344 | print(progress.display(i)) 345 | 346 | return feats, labels, indices 347 | 348 | 349 | 350 | if __name__ == '__main__': 351 | main() 352 | 353 | -------------------------------------------------------------------------------- /eval_linear.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data 14 | from torch.utils.data import DataLoader 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | import torchvision.models as models 18 | import torch.nn.functional as F 19 | 20 | from tools import * 21 | from models.alexnet import AlexNet 22 | from models.mobilenet import MobileNetV2 23 | from models.resnet_swav import resnet50w5 24 | 25 | 26 | parser = argparse.ArgumentParser(description='Unsupervised distillation') 27 | parser.add_argument('data', metavar='DIR', 28 | help='path to dataset') 29 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 30 | help='number of data loading workers (default: 4)') 31 | parser.add_argument('-a', '--arch', default='resnet18', 32 | help='model architecture: ' + 33 | ' | 
'.join(model_names) + 34 | ' (default: resnet18)') 35 | parser.add_argument('--epochs', default=40, type=int, metavar='N', 36 | help='number of total epochs to run') 37 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 38 | help='manual epoch number (useful on restarts)') 39 | parser.add_argument('-b', '--batch-size', default=256, type=int, 40 | metavar='N', 41 | help='mini-batch size (default: 256), this is the total ' 42 | 'batch size of all GPUs on the current node when ' 43 | 'using Data Parallel or Distributed Data Parallel') 44 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 45 | metavar='LR', help='initial learning rate', dest='lr') 46 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 47 | help='momentum') 48 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 49 | metavar='W', help='weight decay (default: 1e-4)', 50 | dest='weight_decay') 51 | parser.add_argument('-p', '--print-freq', default=90, type=int, 52 | metavar='N', help='print frequency (default: 10)') 53 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 54 | help='path to latest checkpoint (default: none)') 55 | parser.add_argument('--seed', default=None, type=int, 56 | help='seed for initializing training. ') 57 | parser.add_argument('--save', default='./output/distill_1', type=str, 58 | help='experiment output directory') 59 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 60 | help='evaluate model on validation set') 61 | parser.add_argument('--weights', dest='weights', type=str, required=True, 62 | help='pre-trained model weights') 63 | parser.add_argument('--lr_schedule', type=str, default='15,30,40', 64 | help='lr drop schedule') 65 | 66 | best_acc1 = 0 67 | 68 | 69 | def main(): 70 | global logger 71 | 72 | args = parser.parse_args() 73 | makedirs(args.save) 74 | logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) 75 | logger.info(args) 76 | 77 | if args.seed is not None: 78 | random.seed(args.seed) 79 | torch.manual_seed(args.seed) 80 | cudnn.deterministic = True 81 | warnings.warn('You have chosen to seed training. ' 82 | 'This will turn on the CUDNN deterministic setting, ' 83 | 'which can slow down your training considerably! 
' 84 | 'You may see unexpected behavior when restarting ' 85 | 'from checkpoints.') 86 | 87 | main_worker(args) 88 | 89 | 90 | def load_weights(model, wts_path): 91 | wts = torch.load(wts_path) 92 | if 'state_dict' in wts: 93 | ckpt = wts['state_dict'] 94 | if 'model' in wts: 95 | ckpt = wts['model'] 96 | else: 97 | ckpt = wts 98 | 99 | ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()} 100 | state_dict = {} 101 | 102 | for m_key, m_val in model.state_dict().items(): 103 | if m_key in ckpt: 104 | state_dict[m_key] = ckpt[m_key] 105 | else: 106 | state_dict[m_key] = m_val 107 | print('not copied => ' + m_key) 108 | 109 | model.load_state_dict(state_dict) 110 | 111 | 112 | def get_model(arch, wts_path): 113 | if arch == 'alexnet': 114 | model = AlexNet() 115 | model.fc = nn.Sequential() 116 | load_weights(model, wts_path) 117 | elif arch == 'pt_alexnet': 118 | model = models.alexnet() 119 | classif = list(model.classifier.children())[:5] 120 | model.classifier = nn.Sequential(*classif) 121 | load_weights(model, wts_path) 122 | elif arch == 'mobilenet': 123 | model = MobileNetV2() 124 | model.fc = nn.Sequential() 125 | load_weights(model, wts_path) 126 | elif arch == 'resnet50x5_swav': 127 | model = resnet50w5() 128 | model.l2norm = None 129 | load_weights(model, wts_path) 130 | elif 'resnet' in arch: 131 | model = models.__dict__[arch]() 132 | model.fc = nn.Sequential() 133 | load_weights(model, wts_path) 134 | else: 135 | raise ValueError('arch not found: ' + arch) 136 | 137 | for p in model.parameters(): 138 | p.requires_grad = False 139 | 140 | return model 141 | 142 | 143 | def main_worker(args): 144 | global best_acc1 145 | 146 | # Data loading code 147 | traindir = os.path.join(args.data, 'train') 148 | valdir = os.path.join(args.data, 'val') 149 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 150 | std=[0.229, 0.224, 0.225]) 151 | 152 | train_transform = transforms.Compose([ 153 | transforms.RandomResizedCrop(224), 154 | transforms.RandomHorizontalFlip(), 155 | transforms.ToTensor(), 156 | normalize, 157 | ]) 158 | 159 | val_transform = transforms.Compose([ 160 | transforms.Resize(256), 161 | transforms.CenterCrop(224), 162 | transforms.ToTensor(), 163 | normalize, 164 | ]) 165 | 166 | 167 | train_dataset = datasets.ImageFolder(traindir, train_transform) 168 | train_loader = DataLoader( 169 | train_dataset, 170 | batch_size=args.batch_size, shuffle=True, 171 | num_workers=args.workers, pin_memory=True, 172 | ) 173 | 174 | val_loader = torch.utils.data.DataLoader( 175 | datasets.ImageFolder(valdir, val_transform), 176 | batch_size=args.batch_size, shuffle=False, 177 | num_workers=args.workers, pin_memory=True, 178 | ) 179 | 180 | train_val_loader = torch.utils.data.DataLoader( 181 | datasets.ImageFolder(traindir, val_transform), 182 | batch_size=args.batch_size, shuffle=False, 183 | num_workers=args.workers, pin_memory=True, 184 | ) 185 | 186 | backbone = get_model(args.arch, args.weights) 187 | backbone = nn.DataParallel(backbone).cuda() 188 | backbone.eval() 189 | 190 | 191 | cached_feats = '%s/var_mean.pth.tar' % args.save 192 | if not os.path.exists(cached_feats): 193 | train_feats, _ = get_feats(train_val_loader, backbone, args) 194 | train_var, train_mean = torch.var_mean(train_feats, dim=0) 195 | torch.save((train_var, train_mean), cached_feats) 196 | else: 197 | train_var, train_mean = torch.load(cached_feats) 198 | 199 | linear = nn.Sequential( 200 | Normalize(), 201 | FullBatchNorm(train_var, train_mean), 202 | nn.Linear(get_channels(args.arch), 
len(train_dataset.classes)), 203 | ) 204 | linear = linear.cuda() 205 | 206 | optimizer = torch.optim.SGD(linear.parameters(), 207 | args.lr, 208 | momentum=args.momentum, 209 | weight_decay=args.weight_decay) 210 | 211 | sched = [int(x) for x in args.lr_schedule.split(',')] 212 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 213 | optimizer, milestones=sched 214 | ) 215 | 216 | # optionally resume from a checkpoint 217 | if args.resume: 218 | if os.path.isfile(args.resume): 219 | logger.info("=> loading checkpoint '{}'".format(args.resume)) 220 | checkpoint = torch.load(args.resume) 221 | args.start_epoch = checkpoint['epoch'] 222 | linear.load_state_dict(checkpoint['state_dict']) 223 | optimizer.load_state_dict(checkpoint['optimizer']) 224 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 225 | logger.info("=> loaded checkpoint '{}' (epoch {})" 226 | .format(args.resume, checkpoint['epoch'])) 227 | else: 228 | logger.info("=> no checkpoint found at '{}'".format(args.resume)) 229 | 230 | cudnn.benchmark = True 231 | 232 | if args.evaluate: 233 | validate(val_loader, backbone, linear, args) 234 | return 235 | 236 | for epoch in range(args.start_epoch, args.epochs): 237 | # train for one epoch 238 | train(train_loader, backbone, linear, optimizer, epoch, args) 239 | 240 | # evaluate on validation set 241 | acc1 = validate(val_loader, backbone, linear, args) 242 | 243 | # modify lr 244 | lr_scheduler.step() 245 | logger.info('LR: {:f}'.format(lr_scheduler.get_last_lr()[-1])) 246 | 247 | # remember best acc@1 and save checkpoint 248 | is_best = acc1 > best_acc1 249 | best_acc1 = max(acc1, best_acc1) 250 | 251 | save_checkpoint({ 252 | 'epoch': epoch + 1, 253 | 'state_dict': linear.state_dict(), 254 | 'best_acc1': best_acc1, 255 | 'optimizer': optimizer.state_dict(), 256 | 'lr_scheduler': lr_scheduler.state_dict(), 257 | }, is_best, args.save) 258 | 259 | 260 | class Normalize(nn.Module): 261 | def forward(self, x): 262 | return x / x.norm(2, dim=1, keepdim=True) 263 | 264 | 265 | class FullBatchNorm(nn.Module): 266 | def __init__(self, var, mean): 267 | super(FullBatchNorm, self).__init__() 268 | self.register_buffer('inv_std', (1.0 / torch.sqrt(var + 1e-5))) 269 | self.register_buffer('mean', mean) 270 | 271 | def forward(self, x): 272 | return (x - self.mean) * self.inv_std 273 | 274 | 275 | def get_channels(arch): 276 | if arch == 'alexnet': 277 | c = 4096 278 | elif arch == 'pt_alexnet': 279 | c = 4096 280 | elif arch == 'resnet50': 281 | c = 2048 282 | elif arch == 'resnet18': 283 | c = 512 284 | elif arch == 'mobilenet': 285 | c = 1280 286 | elif arch == 'resnet50x5_swav': 287 | c = 10240 288 | else: 289 | raise ValueError('arch not found: ' + arch) 290 | return c 291 | 292 | 293 | def train(train_loader, backbone, linear, optimizer, epoch, args): 294 | batch_time = AverageMeter('Time', ':6.3f') 295 | data_time = AverageMeter('Data', ':6.3f') 296 | losses = AverageMeter('Loss', ':.4e') 297 | top1 = AverageMeter('Acc@1', ':6.2f') 298 | top5 = AverageMeter('Acc@5', ':6.2f') 299 | progress = ProgressMeter( 300 | len(train_loader), 301 | [batch_time, data_time, losses, top1, top5], 302 | prefix="Epoch: [{}]".format(epoch)) 303 | 304 | # switch to train mode 305 | backbone.eval() 306 | linear.train() 307 | 308 | end = time.time() 309 | for i, (images, target) in enumerate(train_loader): 310 | # measure data loading time 311 | data_time.update(time.time() - end) 312 | 313 | images = images.cuda(non_blocking=True) 314 | target = target.cuda(non_blocking=True) 315 | 316 | # 
compute output 317 | with torch.no_grad(): 318 | output = backbone(images) 319 | output = linear(output) 320 | loss = F.cross_entropy(output, target) 321 | 322 | # measure accuracy and record loss 323 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 324 | losses.update(loss.item(), images.size(0)) 325 | top1.update(acc1[0], images.size(0)) 326 | top5.update(acc5[0], images.size(0)) 327 | 328 | # compute gradient and do SGD step 329 | optimizer.zero_grad() 330 | loss.backward() 331 | optimizer.step() 332 | 333 | # measure elapsed time 334 | batch_time.update(time.time() - end) 335 | end = time.time() 336 | 337 | if i % args.print_freq == 0: 338 | logger.info(progress.display(i)) 339 | 340 | 341 | def validate(val_loader, backbone, linear, args): 342 | batch_time = AverageMeter('Time', ':6.3f') 343 | losses = AverageMeter('Loss', ':.4e') 344 | top1 = AverageMeter('Acc@1', ':6.2f') 345 | top5 = AverageMeter('Acc@5', ':6.2f') 346 | progress = ProgressMeter( 347 | len(val_loader), 348 | [batch_time, losses, top1, top5], 349 | prefix='Test: ') 350 | 351 | backbone.eval() 352 | linear.eval() 353 | 354 | with torch.no_grad(): 355 | end = time.time() 356 | for i, (images, target) in enumerate(val_loader): 357 | images = images.cuda(non_blocking=True) 358 | target = target.cuda(non_blocking=True) 359 | 360 | # compute output 361 | output = backbone(images) 362 | output = linear(output) 363 | loss = F.cross_entropy(output, target) 364 | 365 | # measure accuracy and record loss 366 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 367 | losses.update(loss.item(), images.size(0)) 368 | top1.update(acc1[0], images.size(0)) 369 | top5.update(acc5[0], images.size(0)) 370 | 371 | # measure elapsed time 372 | batch_time.update(time.time() - end) 373 | end = time.time() 374 | 375 | if i % args.print_freq == 0: 376 | logger.info(progress.display(i)) 377 | 378 | # TODO: this should also be done with the ProgressMeter 379 | logger.info(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 380 | .format(top1=top1, top5=top5)) 381 | 382 | return top1.avg 383 | 384 | 385 | def normalize(x): 386 | return x / x.norm(2, dim=1, keepdim=True) 387 | 388 | 389 | def get_feats(loader, model, args): 390 | batch_time = AverageMeter('Time', ':6.3f') 391 | progress = ProgressMeter( 392 | len(loader), 393 | [batch_time], 394 | prefix='Test: ') 395 | 396 | # switch to evaluate mode 397 | model.eval() 398 | feats, labels, ptr = None, None, 0 399 | 400 | with torch.no_grad(): 401 | end = time.time() 402 | for i, (images, target) in enumerate(loader): 403 | images = images.cuda(non_blocking=True) 404 | cur_targets = target.cpu() 405 | cur_feats = normalize(model(images)).cpu() 406 | B, D = cur_feats.shape 407 | inds = torch.arange(B) + ptr 408 | 409 | if not ptr: 410 | feats = torch.zeros((len(loader.dataset), D)).float() 411 | labels = torch.zeros(len(loader.dataset)).long() 412 | 413 | feats.index_copy_(0, inds, cur_feats) 414 | labels.index_copy_(0, inds, cur_targets) 415 | ptr += B 416 | 417 | # measure elapsed time 418 | batch_time.update(time.time() - end) 419 | end = time.time() 420 | 421 | if i % args.print_freq == 0: 422 | logger.info(progress.display(i)) 423 | 424 | return feats, labels 425 | 426 | 427 | if __name__ == '__main__': 428 | main() 429 | -------------------------------------------------------------------------------- /kmeans.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os 4 | 5 | import torch 6 | import torch.nn as nn 7 | 
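# Aside (illustrative sketch, not part of this repo): the same faiss clustering pattern
# used in faiss_kmeans() below, run CPU-only on random toy features so the flow is easy
# to follow. The real script clusters cached, L2-normalized ImageNet features and moves
# the index to all GPUs via faiss.index_cpu_to_all_gpus.
#
#     import numpy as np
#     import faiss
#
#     feats = np.random.rand(1000, 128).astype('float32')      # toy feature matrix
#     feats /= np.linalg.norm(feats, axis=1, keepdims=True)     # L2-normalize, as get_feats() does
#     clus = faiss.Clustering(128, 10)                          # 128-d features, 10 clusters
#     clus.niter = 20
#     index = faiss.IndexFlatL2(128)
#     clus.train(feats, index)                                  # centroids end up in the index
#     _, assign = index.search(feats, 1)                        # nearest centroid per sample
#     print(assign[:5, 0])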
import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.utils.data 11 | import faiss 12 | 13 | from tools import * 14 | from eval_knn import get_model, get_loaders, normalize 15 | 16 | 17 | parser = argparse.ArgumentParser(description='Unsupervised distillation') 18 | parser.add_argument('data', metavar='DIR', help='path to dataset') 19 | parser.add_argument('-j', '--workers', default=4, type=int, 20 | help='number of data loading workers (default: 4)') 21 | parser.add_argument('-a', '--arch', default='resnet18', 22 | help='model architecture: ' + 23 | ' | '.join(model_names) + 24 | ' (default: resnet18)') 25 | parser.add_argument('-b', '--batch-size', default=256, type=int, 26 | help='mini-batch size (default: 256), this is the total ' 27 | 'batch size of all GPUs on the current node when ' 28 | 'using Data Parallel or Distributed Data Parallel') 29 | parser.add_argument('-p', '--print-freq', default=90, type=int, 30 | help='print frequency (default: 10)') 31 | parser.add_argument('--save', default='./output/kmeans_1', type=str, 32 | help='experiment output directory') 33 | parser.add_argument('--weights', dest='weights', type=str, 34 | help='pre-trained model weights') 35 | parser.add_argument('--load_cache', action='store_true', 36 | help='should the features be recomputed or loaded from the cache') 37 | parser.add_argument('--clusters', default=2000, type=int, help='numbe of clusters') 38 | 39 | best_acc1 = 0 40 | 41 | 42 | def main(): 43 | global logger 44 | 45 | args = parser.parse_args() 46 | makedirs(args.save) 47 | logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) 48 | logger.info(args) 49 | 50 | main_worker(args) 51 | 52 | 53 | def main_worker(args): 54 | global best_acc1 55 | 56 | train_loader, val_loader = get_loaders(args.data, args.batch_size, args.workers) 57 | 58 | model = get_model(args) 59 | model = model.cuda() 60 | 61 | cudnn.benchmark = True 62 | 63 | cached_feats = '%s/train_feats.pth.tar' % args.save 64 | if args.load_cache and os.path.exists(cached_feats): 65 | logger.info('load train feats from cache =>') 66 | train_feats, _ = torch.load(cached_feats) 67 | else: 68 | logger.info('get train feats =>') 69 | train_feats, _ = get_feats(train_loader, model, args.print_freq) 70 | torch.save((train_feats, _), cached_feats) 71 | 72 | cached_feats = '%s/val_feats.pth.tar' % args.save 73 | if args.load_cache and os.path.exists(cached_feats): 74 | logger.info('load val feats from cache =>') 75 | val_feats, _ = torch.load(cached_feats) 76 | else: 77 | logger.info('get val feats =>') 78 | val_feats, _ = get_feats(val_loader, model, args.print_freq) 79 | torch.save((val_feats, _), cached_feats) 80 | 81 | start = time.time() 82 | train_a, val_a = faiss_kmeans(train_feats, val_feats, args.clusters) 83 | 84 | samples = list(s.replace(args.data + '/train/', '') for s, _ in train_loader.dataset.samples) 85 | train_s = list((s, a) for s, a in zip(samples, train_a)) 86 | train_d_path = os.path.join(args.save, 'train_clusters.txt') 87 | with open(train_d_path, 'w') as f: 88 | for pth, cls in train_s: 89 | f.write('{} {}\n'.format(pth, cls)) 90 | 91 | samples = list(s.replace(args.data + '/val/', '') for s, _ in val_loader.dataset.samples) 92 | val_s = list((s, a) for s, a in zip(samples, val_a)) 93 | val_d_path = os.path.join(args.save, 'val_clusters.txt') 94 | with open(val_d_path, 'w') as f: 95 | for pth, cls in val_s: 96 | f.write('{} {}\n'.format(pth, cls)) 97 | 98 | faiss_time = 
time.time() - start 99 | logger.info('=> faiss time : {:.2f}s'.format(faiss_time)) 100 | 101 | 102 | def faiss_kmeans(train_feats, val_feats, nmb_clusters): 103 | train_feats = train_feats.numpy() 104 | val_feats = val_feats.numpy() 105 | 106 | d = train_feats.shape[-1] 107 | 108 | clus = faiss.Clustering(d, nmb_clusters) 109 | clus.niter = 20 110 | clus.max_points_per_centroid = 10000000 111 | 112 | index = faiss.IndexFlatL2(d) 113 | co = faiss.GpuMultipleClonerOptions() 114 | co.useFloat16 = True 115 | co.shard = True 116 | index = faiss.index_cpu_to_all_gpus(index, co) 117 | 118 | # perform the training 119 | clus.train(train_feats, index) 120 | _, train_a = index.search(train_feats, 1) 121 | _, val_a = index.search(val_feats, 1) 122 | 123 | return list(train_a[:, 0]), list(val_a[:, 0]) 124 | 125 | 126 | def get_feats(loader, model, print_freq): 127 | batch_time = AverageMeter('Time', ':6.3f') 128 | progress = ProgressMeter( 129 | len(loader), 130 | [batch_time], 131 | prefix='Test: ') 132 | 133 | # switch to evaluate mode 134 | model.eval() 135 | feats, labels, ptr = None, None, 0 136 | 137 | with torch.no_grad(): 138 | end = time.time() 139 | for i, (images, target) in enumerate(loader): 140 | images = images.cuda(non_blocking=True) 141 | cur_targets = target.cpu() 142 | cur_feats = normalize(model(images)).cpu() 143 | B, D = cur_feats.shape 144 | inds = torch.arange(B) + ptr 145 | 146 | if not ptr: 147 | feats = torch.zeros((len(loader.dataset), D)).float() 148 | labels = torch.zeros(len(loader.dataset)).long() 149 | 150 | feats.index_copy_(0, inds, cur_feats) 151 | labels.index_copy_(0, inds, cur_targets) 152 | ptr += B 153 | 154 | # measure elapsed time 155 | batch_time.update(time.time() - end) 156 | end = time.time() 157 | 158 | if i % print_freq == 0: 159 | logger.info(progress.display(i)) 160 | 161 | return feats, labels 162 | 163 | 164 | 165 | if __name__ == '__main__': 166 | main() 167 | -------------------------------------------------------------------------------- /models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | class AlexNet(nn.Module): 8 | 9 | def __init__(self, num_classes=1000): 10 | super(AlexNet, self).__init__() 11 | self.features = nn.Sequential( 12 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 13 | nn.ReLU(inplace=True), 14 | nn.MaxPool2d(kernel_size=3, stride=2), 15 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 16 | nn.ReLU(inplace=True), 17 | nn.MaxPool2d(kernel_size=3, stride=2), 18 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 21 | nn.ReLU(inplace=True), 22 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 23 | nn.ReLU(inplace=True), 24 | nn.MaxPool2d(kernel_size=3, stride=2), 25 | ) 26 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 27 | self.feats = nn.Sequential( 28 | # nn.Dropout(), 29 | nn.Linear(256 * 6 * 6, 4096), 30 | nn.ReLU(inplace=True), 31 | # nn.Dropout(), 32 | nn.Linear(4096, 4096), 33 | 34 | ) 35 | 36 | self.fc = nn.Linear(4096, num_classes) 37 | 38 | 39 | def forward(self, x): 40 | x = self.features(x) 41 | x = self.avgpool(x) 42 | x = torch.flatten(x, 1) 43 | x = F.relu(self.feats(x)) 44 | x = self.fc(x) 45 | return x 46 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | 
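# Aside (hypothetical check, not part of this repo): the AlexNet defined above doubles as a
# frozen 4096-d feature extractor in the evaluation scripts -- get_model() replaces its fc
# with an empty nn.Sequential(), so forward() returns the penultimate features (matching
# get_channels('alexnet') == 4096).
#
#     import torch
#     import torch.nn as nn
#     from models.alexnet import AlexNet
#
#     net = AlexNet()
#     net.fc = nn.Sequential()                       # drop the classifier head
#     net.eval()
#     with torch.no_grad():
#         feats = net(torch.randn(2, 3, 224, 224))
#     print(feats.shape)                             # torch.Size([2, 4096])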
2 | from torch import nn 3 | 4 | 5 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 6 | 7 | 8 | model_urls = { 9 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 10 | } 11 | 12 | 13 | def _make_divisible(v, divisor, min_value=None): 14 | """ 15 | This function is taken from the original tf repo. 16 | It ensures that all layers have a channel number that is divisible by 8 17 | It can be seen here: 18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 19 | :param v: 20 | :param divisor: 21 | :param min_value: 22 | :return: 23 | """ 24 | if min_value is None: 25 | min_value = divisor 26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than 10%. 28 | if new_v < 0.9 * v: 29 | new_v += divisor 30 | return new_v 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | class ConvBNReLU(nn.Sequential): 42 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 43 | padding = (kernel_size - 1) // 2 44 | super(ConvBNReLU, self).__init__( 45 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 46 | nn.BatchNorm2d(out_planes), 47 | nn.ReLU6(inplace=True) 48 | ) 49 | 50 | 51 | class InvertedResidual(nn.Module): 52 | def __init__(self, inp, oup, stride, expand_ratio): 53 | super(InvertedResidual, self).__init__() 54 | self.stride = stride 55 | assert stride in [1, 2] 56 | 57 | hidden_dim = int(round(inp * expand_ratio)) 58 | self.use_res_connect = self.stride == 1 and inp == oup 59 | 60 | layers = [] 61 | if expand_ratio != 1: 62 | # pw 63 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 64 | layers.extend([ 65 | # dw 66 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 67 | # pw-linear 68 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 69 | nn.BatchNorm2d(oup), 70 | ]) 71 | self.conv = nn.Sequential(*layers) 72 | 73 | def forward(self, x): 74 | if self.use_res_connect: 75 | return x + self.conv(x) 76 | else: 77 | return self.conv(x) 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | class MobileNetV2(nn.Module): 90 | def __init__(self, 91 | num_classes=1000, 92 | width_mult=1.0, 93 | inverted_residual_setting=None, 94 | round_nearest=8, 95 | block=None): 96 | """ 97 | MobileNet V2 main class 98 | 99 | Args: 100 | num_classes (int): Number of classes 101 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 102 | inverted_residual_setting: Network structure 103 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 104 | Set to 1 to turn off rounding 105 | block: Module specifying inverted residual building block for mobilenet 106 | 107 | """ 108 | super(MobileNetV2, self).__init__() 109 | 110 | if block is None: 111 | block = InvertedResidual 112 | input_channel = 32 113 | last_channel = 1280 114 | 115 | if inverted_residual_setting is None: 116 | inverted_residual_setting = [ 117 | # t, c, n, s 118 | [1, 16, 1, 1], 119 | [6, 24, 2, 2], 120 | [6, 32, 3, 2], 121 | [6, 64, 4, 2], 122 | [6, 96, 3, 1], 123 | [6, 160, 3, 2], 124 | [6, 320, 1, 1], 125 | ] 126 | 127 | # only check the first element, assuming user knows t,c,n,s are required 128 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 129 | raise ValueError("inverted_residual_setting should be non-empty " 130 | "or a 4-element list, got {}".format(inverted_residual_setting)) 131 | 132 
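# each (t, c, n, s) row in inverted_residual_setting above expands into n InvertedResidual
# blocks: t is the expansion ratio, c the output channel count (width-scaled and rounded by
# _make_divisible), and s the stride of the first block in the group (the remaining n-1
# repeats use stride 1) -- see the loop over the settings below.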
| # building first layer 133 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 134 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 135 | features = [ConvBNReLU(3, input_channel, stride=2)] 136 | # building inverted residual blocks 137 | for t, c, n, s in inverted_residual_setting: 138 | output_channel = _make_divisible(c * width_mult, round_nearest) 139 | for i in range(n): 140 | stride = s if i == 0 else 1 141 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 142 | input_channel = output_channel 143 | # building last several layers 144 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 145 | # make it nn.Sequential 146 | self.features = nn.Sequential(*features) 147 | 148 | # building classifier 149 | self.fc = nn.Linear(self.last_channel, num_classes) 150 | 151 | # weight initialization 152 | for m in self.modules(): 153 | if isinstance(m, nn.Conv2d): 154 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 155 | if m.bias is not None: 156 | nn.init.zeros_(m.bias) 157 | elif isinstance(m, nn.BatchNorm2d): 158 | nn.init.ones_(m.weight) 159 | nn.init.zeros_(m.bias) 160 | elif isinstance(m, nn.Linear): 161 | nn.init.normal_(m.weight, 0, 0.01) 162 | nn.init.zeros_(m.bias) 163 | 164 | def _forward_impl(self, x): 165 | # This exists since TorchScript doesn't support inheritance, so the superclass method 166 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 167 | x = self.features(x) 168 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0] 169 | x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1) 170 | x = self.fc(x) 171 | return x 172 | 173 | def forward(self, x): 174 | return self._forward_impl(x) -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152'] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(BasicBlock, self).__init__() 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = nn.BatchNorm2d(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(planes, planes) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | 41 | def forward(self, x): 42 | residual = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is 
not None: 52 | residual = self.downsample(x) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None): 64 | super(Bottleneck, self).__init__() 65 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 66 | self.bn1 = nn.BatchNorm2d(planes) 67 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 68 | padding=1, bias=False) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 71 | self.bn3 = nn.BatchNorm2d(planes * 4) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | residual = self.downsample(x) 92 | 93 | out += residual 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | 101 | def __init__(self, block, layers, fc_dim=128, in_channel=3, width=1): 102 | self.inplanes = 64 103 | super(ResNet, self).__init__() 104 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, 105 | bias=False) 106 | self.bn1 = nn.BatchNorm2d(64) 107 | self.relu = nn.ReLU(inplace=True) 108 | 109 | self.base = int(64 * width) 110 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 111 | self.layer1 = self._make_layer(block, self.base, layers[0]) 112 | self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=2) 113 | self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2) 114 | self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2) 115 | self.avgpool = nn.AvgPool2d(7, stride=1) 116 | self.fc = nn.Linear(self.base * 8 * block.expansion, fc_dim) 117 | 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 121 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 122 | elif isinstance(m, nn.BatchNorm2d): 123 | m.weight.data.fill_(1) 124 | m.bias.data.zero_() 125 | 126 | def _make_layer(self, block, planes, blocks, stride=1): 127 | downsample = None 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = nn.Sequential( 130 | nn.Conv2d(self.inplanes, planes * block.expansion, 131 | kernel_size=1, stride=stride, bias=False), 132 | nn.BatchNorm2d(planes * block.expansion), 133 | ) 134 | 135 | layers = [] 136 | layers.append(block(self.inplanes, planes, stride, downsample)) 137 | self.inplanes = planes * block.expansion 138 | for i in range(1, blocks): 139 | layers.append(block(self.inplanes, planes)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x, layer=7): 144 | if layer <= 0: 145 | return x 146 | x = self.conv1(x) 147 | x = self.bn1(x) 148 | x = self.relu(x) 149 | x = self.maxpool(x) 150 | if layer == 1: 151 | return x 152 | x = self.layer1(x) 153 | if layer == 2: 154 | return x 155 | x = self.layer2(x) 156 | if layer == 3: 157 | return x 158 | x = self.layer3(x) 159 | if layer == 4: 160 | return x 161 | x = self.layer4(x) 162 | if layer == 5: 163 | return x 164 | x = self.avgpool(x) 165 | x = x.view(x.size(0), -1) 166 | 167 | if layer == 6: 168 | return x 169 | x = self.fc(x) 170 | 171 | return x 172 | 173 | 174 | 175 | def resnet18(fc_dim=128, pretrained=False, **kwargs): 176 | """Constructs a ResNet-18 model. 177 | Args: 178 | pretrained (bool): If True, returns a model pre-trained on ImageNet 179 | """ 180 | model = ResNet(BasicBlock, [2, 2, 2, 2], fc_dim = fc_dim , **kwargs) 181 | if pretrained: 182 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 183 | return model 184 | 185 | 186 | 187 | 188 | 189 | def resnet50(fc_dim=128,pretrained=False, **kwargs): 190 | """Constructs a ResNet-50 model. 
191 | Args: 192 | pretrained (bool): If True, returns a model pre-trained on ImageNet 193 | """ 194 | model = ResNet(Bottleneck, [3, 4, 6, 3], fc_dim = fc_dim , **kwargs) 195 | if pretrained: 196 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 197 | return model 198 | -------------------------------------------------------------------------------- /models/resnet50x4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pdb 4 | 5 | class Resnet50_X4(nn.Module) : 6 | 7 | def __init__(self): 8 | super(Resnet50_X4, self).__init__() 9 | norm_layer = nn.BatchNorm2d 10 | 11 | self.conv2d = nn.Conv2d(3, 256, kernel_size=7, stride=2, padding=3, 12 | bias=False) 13 | self.batch_normalization = norm_layer(256) 14 | self.relu = nn.ReLU(inplace=True) 15 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 16 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 17 | 18 | # layer 1 19 | 20 | # b1 21 | self.conv2d_1 = nn.Conv2d(256, 1024, kernel_size=1, stride=1, bias=False) 22 | self.batch_normalization_1 = norm_layer(1024) 23 | 24 | self.conv2d_2 = nn.Conv2d(256, 256, kernel_size=1, stride=1, bias=False) 25 | self.batch_normalization_2 = norm_layer(256) 26 | 27 | self.conv2d_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, bias=False, padding=1) 28 | self.batch_normalization_3 = norm_layer(256) 29 | 30 | self.conv2d_4 = nn.Conv2d(256, 1024, kernel_size=1, stride=1, bias=False) 31 | self.batch_normalization_4 = norm_layer(1024) 32 | 33 | # b2 34 | self.conv2d_5 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, bias=False) 35 | self.batch_normalization_5 = norm_layer(256) 36 | 37 | self.conv2d_6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, bias=False, padding=1) 38 | self.batch_normalization_6 = norm_layer(256) 39 | 40 | self.conv2d_7 = nn.Conv2d(256, 1024, kernel_size=1, stride=1, bias=False) 41 | self.batch_normalization_7 = norm_layer(1024) 42 | 43 | # b3 44 | self.conv2d_8 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, bias=False) 45 | self.batch_normalization_8 = norm_layer(256) 46 | 47 | self.conv2d_9 = nn.Conv2d(256, 256, kernel_size=3, stride=1, bias=False, padding=1) 48 | self.batch_normalization_9 = norm_layer(256) 49 | 50 | self.conv2d_10 = nn.Conv2d(256, 1024, kernel_size=1, stride=1, bias=False) 51 | self.batch_normalization_10 = norm_layer(1024) 52 | 53 | # layer 2 54 | 55 | # b1 56 | self.conv2d_11 = nn.Conv2d(1024, 2048, kernel_size=1, stride=2, bias=False) 57 | self.batch_normalization_11 = norm_layer(2048) 58 | 59 | self.conv2d_12 = nn.Conv2d(1024, 512, kernel_size=1, stride=1, bias=False) 60 | self.batch_normalization_12 = norm_layer(512) 61 | 62 | self.conv2d_13 = nn.Conv2d(512, 512, kernel_size=3, stride=2, bias=False, padding=1) 63 | self.batch_normalization_13 = norm_layer(512) 64 | 65 | self.conv2d_14 = nn.Conv2d(512, 2048, kernel_size=1, stride=1, bias=False) 66 | self.batch_normalization_14 = norm_layer(2048) 67 | 68 | # b2 69 | self.conv2d_15 = nn.Conv2d(2048, 512, kernel_size=1, stride=1, bias=False) 70 | self.batch_normalization_15 = norm_layer(512) 71 | 72 | self.conv2d_16 = nn.Conv2d(512, 512, kernel_size=3, stride=1, bias=False, padding=1) 73 | self.batch_normalization_16 = norm_layer(512) 74 | 75 | self.conv2d_17 = nn.Conv2d(512, 2048, kernel_size=1, stride=1, bias=False) 76 | self.batch_normalization_17 = norm_layer(2048) 77 | 78 | # b3 79 | self.conv2d_18 = nn.Conv2d(2048, 512, kernel_size=1, stride=1, bias=False) 80 | self.batch_normalization_18 = 
norm_layer(512) 81 | 82 | self.conv2d_19 = nn.Conv2d(512, 512, kernel_size=3, stride=1, bias=False, padding=1) 83 | self.batch_normalization_19 = norm_layer(512) 84 | 85 | self.conv2d_20 = nn.Conv2d(512, 2048, kernel_size=1, stride=1, bias=False) 86 | self.batch_normalization_20 = norm_layer(2048) 87 | 88 | # b4 89 | self.conv2d_21 = nn.Conv2d(2048, 512, kernel_size=1, stride=1, bias=False) 90 | self.batch_normalization_21 = norm_layer(512) 91 | 92 | self.conv2d_22 = nn.Conv2d(512, 512, kernel_size=3, stride=1, bias=False, padding=1) 93 | self.batch_normalization_22 = norm_layer(512) 94 | 95 | self.conv2d_23 = nn.Conv2d(512, 2048, kernel_size=1, stride=1, bias=False) 96 | self.batch_normalization_23 = norm_layer(2048) 97 | 98 | # layer 3 99 | 100 | # b1 101 | self.conv2d_24 = nn.Conv2d(2048, 4096, kernel_size=1, stride=2, bias=False) 102 | self.batch_normalization_24 = norm_layer(4096) 103 | 104 | self.conv2d_25 = nn.Conv2d(2048, 1024, kernel_size=1, stride=1, bias=False) 105 | self.batch_normalization_25 = norm_layer(1024) 106 | 107 | self.conv2d_26 = nn.Conv2d(1024, 1024, kernel_size=3, stride=2, bias=False, padding=1) 108 | self.batch_normalization_26 = norm_layer(1024) 109 | 110 | self.conv2d_27 = nn.Conv2d(1024, 4096, kernel_size=1, stride=1, bias=False) 111 | self.batch_normalization_27 = norm_layer(4096) 112 | 113 | # b2 114 | self.conv2d_28 = nn.Conv2d(4096, 1024, kernel_size=1, stride=1, bias=False) 115 | self.batch_normalization_28 = norm_layer(1024) 116 | 117 | self.conv2d_29 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, bias=False , padding=1) 118 | self.batch_normalization_29 = norm_layer(1024) 119 | 120 | self.conv2d_30 = nn.Conv2d(1024, 4096, kernel_size=1, stride=1, bias=False) 121 | self.batch_normalization_30 = norm_layer(4096) 122 | 123 | # b3 124 | self.conv2d_31 = nn.Conv2d(4096, 1024, kernel_size=1, stride=1, bias=False) 125 | self.batch_normalization_31 = norm_layer(1024) 126 | 127 | self.conv2d_32 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, bias=False, padding=1) 128 | self.batch_normalization_32 = norm_layer(1024) 129 | 130 | self.conv2d_33 = nn.Conv2d(1024, 4096, kernel_size=1, stride=1, bias=False) 131 | self.batch_normalization_33 = norm_layer(4096) 132 | 133 | # b4 134 | self.conv2d_34 = nn.Conv2d(4096, 1024, kernel_size=1, stride=1, bias=False) 135 | self.batch_normalization_34 = norm_layer(1024) 136 | 137 | self.conv2d_35 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, bias=False, padding=1) 138 | self.batch_normalization_35 = norm_layer(1024) 139 | 140 | self.conv2d_36 = nn.Conv2d(1024, 4096, kernel_size=1, stride=1, bias=False) 141 | self.batch_normalization_36 = norm_layer(4096) 142 | 143 | # b5 144 | self.conv2d_37 = nn.Conv2d(4096, 1024, kernel_size=1, stride=1, bias=False) 145 | self.batch_normalization_37 = norm_layer(1024) 146 | 147 | self.conv2d_38 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, bias=False, padding=1) 148 | self.batch_normalization_38 = norm_layer(1024) 149 | 150 | self.conv2d_39 = nn.Conv2d(1024, 4096, kernel_size=1, stride=1, bias=False) 151 | self.batch_normalization_39 = norm_layer(4096) 152 | 153 | # b6 154 | self.conv2d_40 = nn.Conv2d(4096, 1024, kernel_size=1, stride=1, bias=False) 155 | self.batch_normalization_40 = norm_layer(1024) 156 | 157 | self.conv2d_41 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, bias=False, padding=1) 158 | self.batch_normalization_41 = norm_layer(1024) 159 | 160 | self.conv2d_42 = nn.Conv2d(1024, 4096, kernel_size=1, stride=1, bias=False) 161 | self.batch_normalization_42 = 
norm_layer(4096) 162 | 163 | # layer 4 164 | 165 | # b1 166 | self.conv2d_43 = nn.Conv2d(4096, 8192, kernel_size=1, stride=2, bias=False) 167 | self.batch_normalization_43 = norm_layer(8192) 168 | 169 | self.conv2d_44 = nn.Conv2d(4096, 2048, kernel_size=1, stride=1, bias=False) 170 | self.batch_normalization_44 = norm_layer(2048) 171 | 172 | self.conv2d_45 = nn.Conv2d(2048, 2048, kernel_size=3, stride=2, bias=False, padding=1) 173 | self.batch_normalization_45 = norm_layer(2048) 174 | 175 | self.conv2d_46 = nn.Conv2d(2048, 8192, kernel_size=1, stride=1, bias=False) 176 | self.batch_normalization_46 = norm_layer(8192) 177 | 178 | # b2 179 | self.conv2d_47 = nn.Conv2d(8192, 2048, kernel_size=1, stride=1, bias=False) 180 | self.batch_normalization_47 = norm_layer(2048) 181 | 182 | self.conv2d_48 = nn.Conv2d(2048, 2048, kernel_size=3, stride=1, bias=False, padding=1) 183 | self.batch_normalization_48 = norm_layer(2048) 184 | 185 | self.conv2d_49 = nn.Conv2d(2048, 8192, kernel_size=1, stride=1, bias=False) 186 | self.batch_normalization_49 = norm_layer(8192) 187 | 188 | # b2 189 | self.conv2d_50 = nn.Conv2d(8192, 2048, kernel_size=1, stride=1, bias=False) 190 | self.batch_normalization_50 = norm_layer(2048) 191 | 192 | self.conv2d_51 = nn.Conv2d(2048, 2048, kernel_size=3, stride=1, bias=False, padding=1) 193 | self.batch_normalization_51 = norm_layer(2048) 194 | 195 | self.conv2d_52 = nn.Conv2d(2048, 8192, kernel_size=1, stride=1, bias=False) 196 | self.batch_normalization_52 = norm_layer(8192) 197 | 198 | self.fc = nn.Linear(8192 , 1000) 199 | 200 | for m in self.modules(): 201 | if isinstance(m, nn.Conv2d): 202 | nn.init.constant_(m.weight , 0.1) 203 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 204 | nn.init.constant_(m.weight, 0.2) 205 | nn.init.constant_(m.bias, 0.1) 206 | 207 | def forward(self, x): 208 | x = self.conv2d(x) 209 | x = self.relu(self.batch_normalization(x)) 210 | x = self.maxpool(x) 211 | 212 | # layer1 213 | # b1 214 | shortcut = self.batch_normalization_1(self.conv2d_1(x)) 215 | x = self.relu(self.batch_normalization_2(self.conv2d_2(x))) 216 | x = self.relu(self.batch_normalization_3(self.conv2d_3(x))) 217 | x = self.batch_normalization_4(self.conv2d_4(x)) 218 | x = self.relu(x + shortcut) 219 | 220 | # b2 221 | shortcut = x 222 | x = self.relu(self.batch_normalization_5(self.conv2d_5(x))) 223 | x = self.relu(self.batch_normalization_6(self.conv2d_6(x))) 224 | x = self.batch_normalization_7(self.conv2d_7(x)) 225 | x = self.relu(x + shortcut) 226 | 227 | # b3 228 | shortcut = x 229 | x = self.relu(self.batch_normalization_8(self.conv2d_8(x))) 230 | x = self.relu(self.batch_normalization_9(self.conv2d_9(x))) 231 | x = self.batch_normalization_10(self.conv2d_10(x)) 232 | x = self.relu(x + shortcut) 233 | 234 | # layer2 235 | # b1 236 | shortcut = self.batch_normalization_11(self.conv2d_11(x)) 237 | x = self.relu(self.batch_normalization_12(self.conv2d_12(x))) 238 | x = self.relu(self.batch_normalization_13(self.conv2d_13(x))) 239 | x = self.batch_normalization_14(self.conv2d_14(x)) 240 | x = self.relu(x + shortcut) 241 | 242 | # b2 243 | shortcut = x 244 | x = self.relu(self.batch_normalization_15(self.conv2d_15(x))) 245 | x = self.relu(self.batch_normalization_16(self.conv2d_16(x))) 246 | x = self.batch_normalization_17(self.conv2d_17(x)) 247 | x = self.relu(x + shortcut) 248 | 249 | # b3 250 | shortcut = x 251 | x = self.relu(self.batch_normalization_18(self.conv2d_18(x))) 252 | x = self.relu(self.batch_normalization_19(self.conv2d_19(x))) 253 | x = 
self.batch_normalization_20(self.conv2d_20(x)) 254 | x = self.relu(x + shortcut) 255 | 256 | # b4 257 | shortcut = x 258 | x = self.relu(self.batch_normalization_21(self.conv2d_21(x))) 259 | x = self.relu(self.batch_normalization_22(self.conv2d_22(x))) 260 | x = self.batch_normalization_23(self.conv2d_23(x)) 261 | x = self.relu(x + shortcut) 262 | 263 | # layer3 264 | # b1 265 | shortcut = self.batch_normalization_24(self.conv2d_24(x)) 266 | x = self.relu(self.batch_normalization_25(self.conv2d_25(x))) 267 | x = self.relu(self.batch_normalization_26(self.conv2d_26(x))) 268 | x = self.batch_normalization_27(self.conv2d_27(x)) 269 | x = self.relu(x + shortcut) 270 | 271 | # b2 272 | shortcut = x 273 | x = self.relu(self.batch_normalization_28(self.conv2d_28(x))) 274 | x = self.relu(self.batch_normalization_29(self.conv2d_29(x))) 275 | x = self.batch_normalization_30(self.conv2d_30(x)) 276 | x = self.relu(x + shortcut) 277 | 278 | # b3 279 | shortcut = x 280 | x = self.relu(self.batch_normalization_31(self.conv2d_31(x))) 281 | x = self.relu(self.batch_normalization_32(self.conv2d_32(x))) 282 | x = self.batch_normalization_33(self.conv2d_33(x)) 283 | x = self.relu(x + shortcut) 284 | 285 | # b4 286 | shortcut = x 287 | x = self.relu(self.batch_normalization_34(self.conv2d_34(x))) 288 | x = self.relu(self.batch_normalization_35(self.conv2d_35(x))) 289 | x = self.batch_normalization_36(self.conv2d_36(x)) 290 | x = self.relu(x + shortcut) 291 | 292 | # b5 293 | shortcut = x 294 | x = self.relu(self.batch_normalization_37(self.conv2d_37(x))) 295 | x = self.relu(self.batch_normalization_38(self.conv2d_38(x))) 296 | x = self.batch_normalization_39(self.conv2d_39(x)) 297 | x = self.relu(x + shortcut) 298 | 299 | # b6 300 | shortcut = x 301 | x = self.relu(self.batch_normalization_40(self.conv2d_40(x))) 302 | x = self.relu(self.batch_normalization_41(self.conv2d_41(x))) 303 | x = self.batch_normalization_42(self.conv2d_42(x)) 304 | x = self.relu(x + shortcut) 305 | 306 | 307 | # layer4 308 | # b1 309 | shortcut = self.batch_normalization_43(self.conv2d_43(x)) 310 | x = self.relu(self.batch_normalization_44(self.conv2d_44(x))) 311 | x = self.relu(self.batch_normalization_45(self.conv2d_45(x))) 312 | x = self.batch_normalization_46(self.conv2d_46(x)) 313 | x = self.relu(x + shortcut) 314 | 315 | # b2 316 | shortcut = x 317 | x = self.relu(self.batch_normalization_47(self.conv2d_47(x))) 318 | x = self.relu(self.batch_normalization_48(self.conv2d_48(x))) 319 | x = self.batch_normalization_49(self.conv2d_49(x)) 320 | x = self.relu(x + shortcut) 321 | 322 | # b3 323 | shortcut = x 324 | x = self.relu(self.batch_normalization_50(self.conv2d_50(x))) 325 | x = self.relu(self.batch_normalization_51(self.conv2d_51(x))) 326 | x = self.batch_normalization_52(self.conv2d_52(x)) 327 | x = self.relu(x + shortcut) 328 | 329 | x = self.avgpool(x) 330 | x = torch.flatten(x , 1) 331 | 332 | x = self.fc(x) 333 | 334 | return x 335 | 336 | -------------------------------------------------------------------------------- /nn/compress_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | import torch.nn.functional as F 5 | 6 | 7 | class Teacher: 8 | 9 | def __init__(self, cached=True, cached_feats=None, model=None): 10 | self.model = model 11 | self.cached_feats = cached_feats 12 | self.cached = cached 13 | 14 | def eval(self): 15 | if not self.cached: 16 | self.model.eval() 17 | 18 | def gpu(self): 19 | if not self.cached: 20 | 
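# only a live (non-cached) teacher holds a network to wrap in DataParallel and move to the
# GPU; when cached features are used, forward() below simply indexes into self.cached_feats.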
self.model = torch.nn.DataParallel(self.model).cuda() 21 | 22 | def forward(self, x, indices): 23 | if self.cached: 24 | feats = self.cached_feats[indices] 25 | feats = feats.cuda() 26 | else: 27 | with torch.no_grad(): 28 | feats = self.model(x) 29 | 30 | feats = feats.detach() 31 | return feats 32 | 33 | 34 | class SampleSimilarities(nn.Module): 35 | 36 | def __init__(self, feats_dim, queueSize, T): 37 | super(SampleSimilarities, self).__init__() 38 | self.inputSize = feats_dim 39 | self.queueSize = queueSize 40 | self.T = T 41 | self.index = 0 42 | stdv = 1. / math.sqrt(feats_dim / 3) 43 | self.register_buffer('memory', torch.rand(self.queueSize, feats_dim).mul_(2 * stdv).add_(-stdv)) 44 | print('using queue shape: ({},{})'.format(self.queueSize, feats_dim)) 45 | 46 | def forward(self, q, update=True): 47 | batchSize = q.shape[0] 48 | queue = self.memory.clone() 49 | out = torch.mm(queue.detach(), q.transpose(1, 0)) 50 | out = out.transpose(0, 1) 51 | out = torch.div(out, self.T) 52 | out = out.squeeze().contiguous() 53 | 54 | if update: 55 | # update memory bank 56 | with torch.no_grad(): 57 | out_ids = torch.arange(batchSize).cuda() 58 | out_ids += self.index 59 | out_ids = torch.fmod(out_ids, self.queueSize) 60 | out_ids = out_ids.long() 61 | self.memory.index_copy_(0, out_ids, q) 62 | self.index = (self.index + batchSize) % self.queueSize 63 | 64 | return out 65 | 66 | 67 | class CompReSS(nn.Module): 68 | 69 | def __init__(self , teacher_feats_dim, student_feats_dim, queue_size=128000, T=0.04): 70 | super(CompReSS, self).__init__() 71 | 72 | self.l2norm = Normalize(2).cuda() 73 | self.criterion = KLD().cuda() 74 | self.student_sample_similarities = SampleSimilarities(student_feats_dim , queue_size , T).cuda() 75 | self.teacher_sample_similarities = SampleSimilarities(teacher_feats_dim , queue_size , T).cuda() 76 | 77 | def forward(self, teacher_feats, student_feats): 78 | 79 | teacher_feats = self.l2norm(teacher_feats) 80 | student_feats = self.l2norm(student_feats) 81 | 82 | similarities_student = self.student_sample_similarities(student_feats) 83 | similarities_teacher = self.teacher_sample_similarities(teacher_feats) 84 | 85 | loss = self.criterion(similarities_teacher , similarities_student) 86 | return loss 87 | 88 | 89 | class CompReSSA(nn.Module): 90 | 91 | def __init__(self, teacher_feats_dim, queue_size=128000, T=0.04): 92 | super(CompReSSA, self).__init__() 93 | 94 | self.l2norm = Normalize(2).cuda() 95 | self.criterion = KLD().cuda() 96 | self.teacher_sample_similarities = SampleSimilarities(teacher_feats_dim, queue_size, T).cuda() 97 | 98 | def forward(self, teacher_feats, student_feats): 99 | 100 | teacher_feats = self.l2norm(teacher_feats) 101 | student_feats = self.l2norm(student_feats) 102 | 103 | similarities_student = self.teacher_sample_similarities(student_feats, update=False) 104 | similarities_teacher = self.teacher_sample_similarities(teacher_feats) 105 | 106 | loss = self.criterion(similarities_teacher, similarities_student) 107 | return loss 108 | 109 | 110 | 111 | class SampleSimilaritiesMomentum(nn.Module): 112 | 113 | def __init__(self, feats_dim, queueSize, T): 114 | super(SampleSimilaritiesMomentum, self).__init__() 115 | self.inputSize = feats_dim 116 | self.queueSize = queueSize 117 | self.T = T 118 | self.index = 0 119 | stdv = 1. 
/ math.sqrt(feats_dim / 3) 120 | self.register_buffer('memory', torch.rand(self.queueSize, feats_dim).mul_(2 * stdv).add_(-stdv)) 121 | print('using queue shape: ({},{})'.format(self.queueSize, feats_dim)) 122 | 123 | def forward(self, q , q_key): 124 | batchSize = q.shape[0] 125 | queue = self.memory.clone() 126 | out = torch.mm(queue.detach(), q.transpose(1, 0)) 127 | out = out.transpose(0, 1) 128 | out = torch.div(out, self.T) 129 | out = out.squeeze().contiguous() 130 | 131 | # update memory bank 132 | with torch.no_grad(): 133 | out_ids = torch.arange(batchSize).cuda() 134 | out_ids += self.index 135 | out_ids = torch.fmod(out_ids, self.queueSize) 136 | out_ids = out_ids.long() 137 | self.memory.index_copy_(0, out_ids, q_key) 138 | self.index = (self.index + batchSize) % self.queueSize 139 | 140 | return out 141 | 142 | 143 | class CompReSSMomentum(nn.Module): 144 | 145 | def __init__(self , teacher_feats_dim, student_feats_dim, queue_size=128000, T=0.04): 146 | super(CompReSSMomentum, self).__init__() 147 | 148 | self.l2norm = Normalize(2).cuda() 149 | self.criterion = KLD().cuda() 150 | self.student_sample_similarities = SampleSimilaritiesMomentum(student_feats_dim, queue_size, T).cuda() 151 | self.teacher_sample_similarities = SampleSimilarities(teacher_feats_dim, queue_size, T).cuda() 152 | 153 | def forward(self, teacher_feats, student_feats, student_feats_key): 154 | 155 | teacher_feats = self.l2norm(teacher_feats) 156 | student_feats = self.l2norm(student_feats) 157 | student_feats_key = self.l2norm(student_feats_key) 158 | 159 | similarities_student = self.student_sample_similarities(student_feats, student_feats_key) 160 | similarities_teacher = self.teacher_sample_similarities(teacher_feats) 161 | 162 | loss = self.criterion(similarities_teacher, similarities_student) 163 | return loss 164 | 165 | 166 | class Normalize(nn.Module): 167 | 168 | def __init__(self, power=2): 169 | super(Normalize, self).__init__() 170 | self.power = power 171 | 172 | def forward(self, x): 173 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 174 | out = x.div(norm) 175 | return out 176 | 177 | 178 | class KLD(nn.Module): 179 | 180 | def forward(self, targets, inputs): 181 | targets = F.softmax(targets, dim=1) 182 | inputs = F.log_softmax(inputs, dim=1) 183 | return F.kl_div(inputs, targets, reduction='batchmean') 184 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import logging 4 | import os 5 | 6 | import torch 7 | from torch import nn 8 | from torchvision import models 9 | 10 | 11 | def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False): 12 | logger = logging.getLogger() 13 | if debug: 14 | level = logging.DEBUG 15 | else: 16 | level = logging.INFO 17 | logger.setLevel(level) 18 | if saving: 19 | info_file_handler = logging.FileHandler(logpath, mode="a") 20 | info_file_handler.setLevel(level) 21 | logger.addHandler(info_file_handler) 22 | if displaying: 23 | console_handler = logging.StreamHandler() 24 | console_handler.setLevel(level) 25 | logger.addHandler(console_handler) 26 | logger.info(filepath) 27 | with open(filepath, "r") as f: 28 | logger.info(f.read()) 29 | 30 | for f in package_files: 31 | logger.info(f) 32 | with open(f, "r") as package_f: 33 | logger.info(package_f.read()) 34 | 35 | return logger 36 | 37 | 38 | def makedirs(dirname): 39 | if not os.path.exists(dirname): 40 | os.makedirs(dirname) 41 | 42 | 43 | def save_each_checkpoint(state, epoch, save_dir): 44 | ckpt_path = os.path.join(save_dir, 'ckpt_%d.pth.tar' % epoch) 45 | torch.save(state, ckpt_path) 46 | 47 | 48 | def save_checkpoint(state, is_best, save_dir): 49 | ckpt_path = os.path.join(save_dir, 'checkpoint.pth.tar') 50 | torch.save(state, ckpt_path) 51 | if is_best: 52 | best_ckpt_path = os.path.join(save_dir, 'model_best.pth.tar') 53 | shutil.copyfile(ckpt_path, best_ckpt_path) 54 | 55 | 56 | class AverageMeter(object): 57 | """Computes and stores the average and current value""" 58 | def __init__(self, name, fmt=':f'): 59 | self.name = name 60 | self.fmt = fmt 61 | self.reset() 62 | 63 | def reset(self): 64 | self.val = 0 65 | self.avg = 0 66 | self.sum = 0 67 | self.count = 0 68 | 69 | def update(self, val, n=1): 70 | self.val = val 71 | self.sum += val * n 72 | self.count += n 73 | self.avg = self.sum / self.count 74 | 75 | def __str__(self): 76 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 77 | return fmtstr.format(**self.__dict__) 78 | 79 | 80 | class ProgressMeter(object): 81 | def __init__(self, num_batches, meters, prefix=""): 82 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 83 | self.meters = meters 84 | self.prefix = prefix 85 | 86 | def display(self, batch): 87 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 88 | entries += [str(meter) for meter in self.meters] 89 | return '\t'.join(entries) 90 | 91 | def _get_batch_fmtstr(self, num_batches): 92 | num_digits = len(str(num_batches // 1)) 93 | fmt = '{:' + str(num_digits) + 'd}' 94 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 95 | 96 | 97 | def accuracy(output, target, topk=(1,)): 98 | """Computes the accuracy over the k top predictions for the specified values of k""" 99 | with torch.no_grad(): 100 | maxk = max(topk) 101 | batch_size = target.size(0) 102 | 103 | _, pred = output.topk(maxk, 1, True, True) 104 | pred = pred.t() 105 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 106 | 107 | res = [] 108 | for k in topk: 
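# for each requested k, count the targets that appear among the top-k predictions and
# rescale the count to a percentage of the batch.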
109 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 110 | res.append(correct_k.mul_(100.0 / batch_size)) 111 | return res 112 | 113 | 114 | arch_to_key = { 115 | 'alexnet': 'alexnet', 116 | 'alexnet_moco': 'alexnet', 117 | 'resnet18': 'resnet18', 118 | 'resnet50': 'resnet50', 119 | 'rotnet_r50': 'resnet50', 120 | 'rotnet_r18': 'resnet18', 121 | 'resnet18_moco': 'resnet18', 122 | 'resnet_moco': 'resnet50', 123 | } 124 | 125 | model_names = list(arch_to_key.keys()) 126 | 127 | 128 | def remove_dropout(model): 129 | classif = model.classifier.children() 130 | classif = [nn.Sequential() if isinstance(m, nn.Dropout) else m for m in classif] 131 | model.classifier = nn.Sequential(*classif) 132 | 133 | 134 | # 1. stores a list of models to ensemble 135 | # 2. forward through each model and save the output 136 | # 3. return mean of the outputs along the class dimension 137 | class EnsembleNet(nn.ModuleList): 138 | def forward(self, x): 139 | out = [m(x) for m in self] 140 | out = torch.stack(out, dim=-1) 141 | out = out.mean(dim=-1) 142 | return out 143 | -------------------------------------------------------------------------------- /train_kmeans.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os.path 3 | import argparse 4 | import os 5 | import random 6 | import shutil 7 | import time 8 | import warnings 9 | import sys 10 | from collections import OrderedDict 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.parallel 15 | import torch.backends.cudnn as cudnn 16 | import torch.optim 17 | import torch.utils.data 18 | from torch.utils.data import DataLoader 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | import torchvision.models as models 22 | import torch.nn.functional as F 23 | 24 | from PIL import Image 25 | 26 | from tools import * 27 | 28 | 29 | model_names = sorted(name for name in models.__dict__ 30 | if name.islower() and not name.startswith("__") 31 | and callable(models.__dict__[name])) 32 | 33 | 34 | parser = argparse.ArgumentParser(description='Unsupervised distillation') 35 | parser.add_argument('data', metavar='DIR', 36 | help='path to dataset') 37 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 38 | help='number of data loading workers (default: 4)') 39 | parser.add_argument('-a', '--arch', default='resnet18', 40 | help='model architecture: ' + 41 | ' | '.join(model_names) + 42 | ' (default: resnet18)') 43 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 44 | help='number of total epochs to run') 45 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 46 | help='manual epoch number (useful on restarts)') 47 | parser.add_argument('-b', '--batch-size', default=256, type=int, 48 | metavar='N', 49 | help='mini-batch size (default: 256), this is the total ' 50 | 'batch size of all GPUs on the current node when ' 51 | 'using Data Parallel or Distributed Data Parallel') 52 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 53 | help='initial learning rate', dest='lr') 54 | parser.add_argument('--cos', action='store_true', 55 | help='whether to cosine learning rate or not') 56 | parser.add_argument('--schedule', type=int, nargs='*', 57 | default=[30,60,80,90], 58 | help='lr drop schedule') 59 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 60 | help='momentum') 61 | parser.add_argument('--wd', '--weight-decay', default=1e-4, 
type=float, 62 | metavar='W', help='weight decay (default: 1e-4)', 63 | dest='weight_decay') 64 | parser.add_argument('-p', '--print-freq', default=90, type=int, 65 | metavar='N', help='print frequency (default: 10)') 66 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 67 | help='path to latest checkpoint (default: none)') 68 | parser.add_argument('--seed', default=None, type=int, 69 | help='seed for initializing training. ') 70 | parser.add_argument('--save', default='./output/distill_1', type=str, 71 | help='experiment output directory') 72 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 73 | help='evaluate model on validation set') 74 | parser.add_argument('--clusters', type=str, required=True, 75 | help='location containing cluster assignment files') 76 | 77 | best_acc1 = 0 78 | 79 | 80 | class ImageFileDataset(datasets.VisionDataset): 81 | def __init__(self, root, f_path, transform=None): 82 | super(ImageFileDataset, self).__init__(root, transform=transform) 83 | with open(f_path, 'r') as f: 84 | lines = [line.strip().split(' ') for line in f.readlines()] 85 | lines = [(os.path.join(root, pth), int(cid)) for pth, cid in lines] 86 | self.samples = lines 87 | self.classes = sorted(set(s[1] for s in self.samples)) 88 | 89 | def __getitem__(self, index): 90 | path, target = self.samples[index] 91 | sample = default_loader(path) 92 | sample = self.transform(sample) 93 | 94 | return sample, target 95 | 96 | def __len__(self): 97 | return len(self.samples) 98 | 99 | 100 | def pil_loader(path): 101 | with open(path, 'rb') as f: 102 | img = Image.open(f) 103 | return img.convert('RGB') 104 | 105 | 106 | def accimage_loader(path): 107 | import accimage 108 | try: 109 | return accimage.Image(path) 110 | except IOError: 111 | # Potentially a decoding problem, fall back to PIL.Image 112 | return pil_loader(path) 113 | 114 | 115 | def default_loader(path): 116 | from torchvision import get_image_backend 117 | if get_image_backend() == 'accimage': 118 | return accimage_loader(path) 119 | else: 120 | return pil_loader(path) 121 | 122 | 123 | def main(): 124 | global logger 125 | 126 | args = parser.parse_args() 127 | makedirs(args.save) 128 | logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) 129 | logger.info(args) 130 | 131 | if args.seed is not None: 132 | random.seed(args.seed) 133 | torch.manual_seed(args.seed) 134 | cudnn.deterministic = True 135 | warnings.warn('You have chosen to seed training. ' 136 | 'This will turn on the CUDNN deterministic setting, ' 137 | 'which can slow down your training considerably! 
' 138 | 'You may see unexpected behavior when restarting ' 139 | 'from checkpoints.') 140 | 141 | main_worker(args) 142 | 143 | 144 | def main_worker(args): 145 | global best_acc1 146 | 147 | # Data loading code 148 | traindir = os.path.join(args.data, 'train') 149 | valdir = os.path.join(args.data, 'val') 150 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 151 | std=[0.229, 0.224, 0.225]) 152 | train_clusters = os.path.join(args.clusters, 'train_clusters.txt') 153 | val_clusters = os.path.join(args.clusters, 'val_clusters.txt') 154 | 155 | train_dataset = ImageFileDataset( 156 | traindir, 157 | train_clusters, 158 | transforms.Compose([ 159 | transforms.RandomResizedCrop(224), 160 | transforms.RandomHorizontalFlip(), 161 | transforms.ToTensor(), 162 | normalize, 163 | ])) 164 | 165 | train_loader = DataLoader( 166 | train_dataset, batch_size=args.batch_size, shuffle=True, 167 | num_workers=args.workers, pin_memory=True) 168 | 169 | val_loader = torch.utils.data.DataLoader( 170 | ImageFileDataset( 171 | valdir, 172 | val_clusters, 173 | transforms.Compose([ 174 | transforms.Resize(256), 175 | transforms.CenterCrop(224), 176 | transforms.ToTensor(), 177 | normalize, 178 | ]) 179 | ), 180 | batch_size=args.batch_size, shuffle=False, 181 | num_workers=args.workers, pin_memory=True) 182 | 183 | model = models.__dict__[args.arch](num_classes=len(train_dataset.classes)) 184 | model = nn.DataParallel(model).cuda() 185 | 186 | bn_params = OrderedDict() 187 | for n, m in model.named_modules(): 188 | if isinstance(m, nn.BatchNorm2d): 189 | for p in m.parameters(): 190 | bn_params[p] = True 191 | params = OrderedDict() 192 | for p in model.parameters(): 193 | if p not in bn_params: 194 | params[p] = True 195 | param_groups = [ 196 | {'params': list(params.keys())}, 197 | {'params': list(bn_params.keys()), 'weight_decay': 0}, 198 | ] 199 | 200 | optimizer = torch.optim.SGD(param_groups, 201 | args.lr, 202 | momentum=args.momentum, 203 | weight_decay=args.weight_decay, 204 | nesterov=True) 205 | 206 | # optionally resume from a checkpoint 207 | if args.resume: 208 | if os.path.isfile(args.resume): 209 | logger.info("=> loading checkpoint '{}'".format(args.resume)) 210 | checkpoint = torch.load(args.resume) 211 | args.start_epoch = checkpoint['epoch'] 212 | model.load_state_dict(checkpoint['state_dict']) 213 | optimizer.load_state_dict(checkpoint['optimizer']) 214 | logger.info("=> loaded checkpoint '{}' (epoch {})" 215 | .format(args.resume, checkpoint['epoch'])) 216 | else: 217 | logger.info("=> no checkpoint found at '{}'".format(args.resume)) 218 | 219 | cudnn.benchmark = True 220 | 221 | if args.evaluate: 222 | validate(val_loader, model, args) 223 | return 224 | 225 | for epoch in range(args.start_epoch, args.epochs): 226 | adjust_learning_rate(optimizer, epoch, args) 227 | print_lr(optimizer) 228 | 229 | # train for one epoch 230 | train(train_loader, model, optimizer, epoch, args) 231 | 232 | # evaluate on validation set 233 | acc1 = validate(val_loader, model, args) 234 | 235 | # remember best acc@1 and save checkpoint 236 | is_best = acc1 > best_acc1 237 | best_acc1 = max(acc1, best_acc1) 238 | 239 | save_checkpoint({ 240 | 'epoch': epoch + 1, 241 | 'state_dict': model.state_dict(), 242 | 'best_acc1': best_acc1, 243 | 'optimizer': optimizer.state_dict(), 244 | }, is_best, args.save) 245 | 246 | 247 | def train(train_loader, model, optimizer, epoch, args): 248 | batch_time = AverageMeter('Time', ':6.3f') 249 | data_time = AverageMeter('Data', ':6.3f') 250 | losses = 
AverageMeter('Loss', ':.4e') 251 | top1 = AverageMeter('Acc@1', ':6.2f') 252 | top5 = AverageMeter('Acc@5', ':6.2f') 253 | progress = ProgressMeter( 254 | len(train_loader), 255 | [batch_time, data_time, losses, top1, top5], 256 | prefix="Epoch: [{}]".format(epoch)) 257 | 258 | # switch to train mode 259 | model.train() 260 | 261 | end = time.time() 262 | for i, (images, target) in enumerate(train_loader): 263 | # measure data loading time 264 | data_time.update(time.time() - end) 265 | 266 | images = images.cuda(non_blocking=True) 267 | target = target.cuda(non_blocking=True) 268 | 269 | # compute output 270 | output = model(images) 271 | loss = F.cross_entropy(output, target) 272 | 273 | # measure accuracy and record loss 274 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 275 | losses.update(loss.item(), images.size(0)) 276 | top1.update(acc1[0], images.size(0)) 277 | top5.update(acc5[0], images.size(0)) 278 | 279 | # compute gradient and do SGD step 280 | optimizer.zero_grad() 281 | loss.backward() 282 | optimizer.step() 283 | 284 | # measure elapsed time 285 | batch_time.update(time.time() - end) 286 | end = time.time() 287 | 288 | if i % args.print_freq == 0: 289 | logger.info(progress.display(i)) 290 | 291 | 292 | def validate(val_loader, model, args): 293 | batch_time = AverageMeter('Time', ':6.3f') 294 | losses = AverageMeter('Loss', ':.4e') 295 | top1 = AverageMeter('Acc@1', ':6.2f') 296 | top5 = AverageMeter('Acc@5', ':6.2f') 297 | progress = ProgressMeter( 298 | len(val_loader), 299 | [batch_time, losses, top1, top5], 300 | prefix='Test: ') 301 | 302 | # switch to evaluate mode 303 | model.eval() 304 | 305 | with torch.no_grad(): 306 | end = time.time() 307 | for i, (images, target) in enumerate(val_loader): 308 | images = images.cuda(non_blocking=True) 309 | target = target.cuda(non_blocking=True) 310 | 311 | # compute output 312 | output = model(images) 313 | loss = F.cross_entropy(output, target) 314 | 315 | # measure accuracy and record loss 316 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 317 | losses.update(loss.item(), images.size(0)) 318 | top1.update(acc1[0], images.size(0)) 319 | top5.update(acc5[0], images.size(0)) 320 | 321 | # measure elapsed time 322 | batch_time.update(time.time() - end) 323 | end = time.time() 324 | 325 | if i % args.print_freq == 0: 326 | logger.info(progress.display(i)) 327 | 328 | # TODO: this should also be done with the ProgressMeter 329 | logger.info(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 330 | .format(top1=top1, top5=top5)) 331 | 332 | return top1.avg 333 | 334 | 335 | def adjust_learning_rate(optimizer, epoch, args): 336 | """Decay the learning rate based on schedule""" 337 | lr = args.lr 338 | if args.cos: # cosine lr schedule 339 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 340 | else: # stepwise lr schedule 341 | for milestone in args.schedule: 342 | lr *= 0.1 if epoch >= milestone else 1. 
343 | for param_group in optimizer.param_groups: 344 | param_group['lr'] = lr 345 | 346 | 347 | def print_lr(optimizer): 348 | lrs = [param_group['lr'] for param_group in optimizer.param_groups] 349 | lrs = ' '.join('{:f}'.format(l) for l in lrs) 350 | logger.info('LR: ' + lrs) 351 | 352 | 353 | if __name__ == '__main__': 354 | main() 355 | 356 | -------------------------------------------------------------------------------- /train_student.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import time 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import argparse 8 | import socket 9 | 10 | from torchvision import transforms, datasets 11 | import torch.nn as nn 12 | 13 | from util import adjust_learning_rate, AverageMeter 14 | from models.resnet import resnet18,resnet50 15 | from models.alexnet import AlexNet as alexnet 16 | from models.mobilenet import MobileNetV2 as mobilenet 17 | from nn.compress_loss import CompReSSMomentum, Teacher 18 | 19 | from collections import OrderedDict 20 | 21 | def parse_option(): 22 | 23 | parser = argparse.ArgumentParser('argument for training') 24 | 25 | parser.add_argument('data', type=str, help='path to dataset') 26 | 27 | parser.add_argument('--print_freq', type=int, default=100, help='print frequency') 28 | parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency') 29 | parser.add_argument('--save_freq', type=int, default=2, help='save frequency') 30 | parser.add_argument('--batch_size', type=int, default=256, help='batch_size') 31 | parser.add_argument('--num_workers', type=int, default=12, help='num of workers to use') 32 | parser.add_argument('--epochs', type=int, default=130, help='number of training epochs') 33 | 34 | # optimization 35 | parser.add_argument('--learning_rate', type=float, default=0.01, help='learning rate') 36 | parser.add_argument('--lr_decay_epochs', type=str, default='90,120', help='where to decay lr, can be a list') 37 | parser.add_argument('--lr_decay_rate', type=float, default=0.2, help='decay rate for learning rate') 38 | parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay') 39 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 40 | 41 | 42 | # model definition 43 | parser.add_argument('--student_arch', type=str, default='alexnet', 44 | choices=['alexnet' , 'resnet18' , 'resnet50', 'mobilenet']) 45 | parser.add_argument('--teacher_arch', type=str, default='resnet50', 46 | choices=['resnet50x4', 'resnet50']) 47 | parser.add_argument('--cache_teacher', action='store_true', 48 | help='use cached teacher') 49 | 50 | # CompReSS loss function 51 | parser.add_argument('--compress_memory_size', type=int, default=128000) 52 | parser.add_argument('--compress_t', type=float, default=0.04) 53 | 54 | # GPU setting 55 | parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') 56 | 57 | parser.add_argument('--teacher', type=str, help='teacher weights/feats') 58 | 59 | parser.add_argument('--checkpoint_path', default='output/', type=str, 60 | help='where to save checkpoints. 
') 61 | 62 | parser.add_argument('--alpha', type=float, default=0.999, help='exponential moving average weight') 63 | 64 | opt = parser.parse_args() 65 | 66 | iterations = opt.lr_decay_epochs.split(',') 67 | opt.lr_decay_epochs = list([]) 68 | for it in iterations: 69 | opt.lr_decay_epochs.append(int(it)) 70 | 71 | return opt 72 | 73 | 74 | # Extended version of ImageFolder to return index of image too. 75 | class ImageFolderEx(datasets.ImageFolder) : 76 | 77 | def __getitem__(self, index): 78 | sample, target = super(ImageFolderEx, self).__getitem__(index) 79 | return index , sample, target 80 | 81 | 82 | # Create teacher model and load weights. For cached teacher load cahced features instead. 83 | def get_teacher_model(opt): 84 | teacher = None 85 | if opt.cache_teacher : 86 | train_feats, train_labels, indices = torch.load(opt.teacher) 87 | teacher = Teacher(cached=True , cached_feats=train_feats) 88 | 89 | elif opt.teacher_arch == 'resnet50': 90 | model_t = resnet50() 91 | model_t.fc = nn.Sequential() 92 | model_t = nn.Sequential(OrderedDict([('encoder_q', model_t)])) 93 | model_t = torch.nn.DataParallel(model_t).cuda() 94 | checkpoint = torch.load(opt.teacher) 95 | model_t.load_state_dict(checkpoint['state_dict'], strict=False) 96 | model_t = model_t.module.cpu() 97 | 98 | for p in model_t.parameters(): 99 | p.requires_grad = False 100 | teacher = Teacher(cached=False, model=model_t) 101 | 102 | return teacher 103 | 104 | 105 | # Create student query/key model 106 | def get_student_model(opt): 107 | student = None 108 | student_key = None 109 | if opt.student_arch == 'alexnet': 110 | student = alexnet() 111 | student.fc = nn.Sequential() 112 | student_key = alexnet() 113 | student_key.fc = nn.Sequential() 114 | 115 | elif opt.student_arch == 'mobilenet': 116 | student = mobilenet() 117 | student.fc = nn.Sequential() 118 | student_key = mobilenet() 119 | student_key.fc = nn.Sequential() 120 | 121 | elif opt.student_arch == 'resnet18': 122 | student = resnet18() 123 | student.fc = nn.Sequential() 124 | student_key = resnet18() 125 | student_key.fc = nn.Sequential() 126 | 127 | elif opt.student_arch == 'resnet50': 128 | student = resnet50(fc_dim=8192) 129 | student_key = resnet50(fc_dim=8192) 130 | 131 | return student , student_key 132 | 133 | 134 | # Create train loader 135 | def get_train_loader(opt): 136 | data_folder = os.path.join(opt.data, 'train') 137 | image_size = 224 138 | mean = [0.485, 0.456, 0.406] 139 | std = [0.229, 0.224, 0.225] 140 | normalize = transforms.Normalize(mean=mean, std=std) 141 | 142 | train_dataset = ImageFolderEx( 143 | data_folder, 144 | transforms.Compose([ 145 | transforms.RandomResizedCrop(224), 146 | transforms.RandomHorizontalFlip(), 147 | transforms.ToTensor(), 148 | normalize, 149 | ])) 150 | 151 | train_loader = torch.utils.data.DataLoader( 152 | train_dataset, batch_size=opt.batch_size, shuffle=True, 153 | num_workers=opt.num_workers, pin_memory=True) 154 | 155 | return train_loader 156 | 157 | 158 | # Update Key model from Query model 159 | def moment_update(query_model, key_model, m): 160 | """ key_model = m * key_model + (1 - m) query_model """ 161 | for p1, p2 in zip(query_model.parameters(), key_model.parameters()): 162 | p2.data.mul_(m).add_(1-m, p1.detach().data) 163 | 164 | 165 | def main(): 166 | 167 | args = parse_option() 168 | os.makedirs(args.checkpoint_path, exist_ok=True) 169 | 170 | if args.gpu is not None: 171 | print("Use GPU: {} for training".format(args.gpu)) 172 | 173 | train_loader = get_train_loader(args) 174 | 175 | 
teacher = get_teacher_model(args) 176 | student, student_key = get_student_model(args) 177 | 178 | # Calculate feature dimension of student and teacher 179 | teacher.eval() 180 | student.eval() 181 | tmp_input = torch.randn(2, 3, 224, 224) 182 | feat_t = teacher.forward(tmp_input, 0) 183 | feat_s = student(tmp_input) 184 | student_feats_dim = feat_s.shape[-1] 185 | teacher_feats_dim = feat_t.shape[-1] 186 | 187 | compress = CompReSSMomentum(teacher_feats_dim, student_feats_dim, args.compress_memory_size, args.compress_t) 188 | 189 | student = torch.nn.DataParallel(student).cuda() 190 | student_key = torch.nn.DataParallel(student_key).cuda() 191 | teacher.gpu() 192 | 193 | optimizer = torch.optim.SGD(student.parameters(), 194 | lr=args.learning_rate, 195 | momentum=args.momentum, 196 | weight_decay=args.weight_decay) 197 | 198 | cudnn.benchmark = True 199 | 200 | args.start_epoch = 1 201 | moment_update(student, student_key, 0) 202 | 203 | # routine 204 | for epoch in range(args.start_epoch, args.epochs + 1): 205 | 206 | adjust_learning_rate(epoch, args, optimizer) 207 | print("==> training...") 208 | 209 | time1 = time.time() 210 | loss = train_student(epoch, train_loader, teacher, student, student_key, compress, optimizer, args) 211 | 212 | time2 = time.time() 213 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 214 | 215 | # saving the model 216 | if epoch % args.save_freq == 0: 217 | print('==> Saving...') 218 | state = { 219 | 'opt': args, 220 | 'model': student.state_dict(), 221 | 'optimizer': optimizer.state_dict(), 222 | 'epoch': epoch, 223 | } 224 | 225 | save_file = os.path.join(args.checkpoint_path, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) 226 | torch.save(state, save_file) 227 | 228 | # help release GPU memory 229 | del state 230 | torch.cuda.empty_cache() 231 | 232 | 233 | def train_student(epoch, train_loader, teacher, student, student_key, compress, optimizer, opt): 234 | """ 235 | one epoch training for CompReSS 236 | """ 237 | student_key.eval() 238 | student.train() 239 | 240 | def set_bn_train(m): 241 | classname = m.__class__.__name__ 242 | if classname.find('BatchNorm') != -1: 243 | m.train() 244 | student_key.apply(set_bn_train) 245 | 246 | batch_time = AverageMeter() 247 | data_time = AverageMeter() 248 | loss_meter = AverageMeter() 249 | 250 | end = time.time() 251 | for idx, (index, inputs, _) in enumerate(train_loader): 252 | data_time.update(time.time() - end) 253 | 254 | bsz = inputs.size(0) 255 | 256 | inputs = inputs.float() 257 | if opt.gpu is not None: 258 | inputs = inputs.cuda(opt.gpu, non_blocking=True) 259 | else: 260 | inputs = inputs.cuda() 261 | 262 | # ===================forward===================== 263 | teacher_feats = teacher.forward(inputs , index) 264 | student_feats = student(inputs) 265 | 266 | with torch.no_grad(): 267 | student_feats_key = student_key(inputs) 268 | student_feats_key = student_feats_key.detach() 269 | 270 | loss = compress(teacher_feats , student_feats , student_feats_key) 271 | 272 | # ===================backward===================== 273 | optimizer.zero_grad() 274 | loss.backward() 275 | optimizer.step() 276 | 277 | # ===================meters===================== 278 | loss_meter.update(loss.item(), bsz) 279 | 280 | moment_update(student, student_key, opt.alpha) 281 | 282 | torch.cuda.synchronize() 283 | batch_time.update(time.time() - end) 284 | end = time.time() 285 | 286 | # print info 287 | if (idx + 1) % opt.print_freq == 0: 288 | print('Train: [{0}][{1}/{2}]\t' 289 | 'BT {batch_time.val:.3f} 
({batch_time.avg:.3f})\t' 290 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 291 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t'.format( 292 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 293 | data_time=data_time, loss=loss_meter)) 294 | sys.stdout.flush() 295 | 296 | return loss_meter.avg 297 | 298 | 299 | if __name__ == '__main__': 300 | main() 301 | -------------------------------------------------------------------------------- /train_student_one_queue.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | import argparse 7 | import socket 8 | 9 | from torchvision import transforms, datasets 10 | import torch.nn as nn 11 | 12 | from util import adjust_learning_rate, AverageMeter 13 | from models.resnet import resnet18, resnet50 14 | from models.alexnet import AlexNet as alexnet 15 | from models.mobilenet import MobileNetV2 as mobilenet 16 | from nn.compress_loss import CompReSSA, Teacher 17 | 18 | from collections import OrderedDict 19 | 20 | def parse_option(): 21 | 22 | parser = argparse.ArgumentParser('argument for training') 23 | 24 | parser.add_argument('data', type=str, help='path to dataset') 25 | 26 | parser.add_argument('--print_freq', type=int, default=100, help='print frequency') 27 | parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency') 28 | parser.add_argument('--save_freq', type=int, default=2, help='save frequency') 29 | parser.add_argument('--batch_size', type=int, default=256, help='batch_size') 30 | parser.add_argument('--num_workers', type=int, default=12, help='num of workers to use') 31 | parser.add_argument('--epochs', type=int, default=130, help='number of training epochs') 32 | 33 | # optimization 34 | parser.add_argument('--learning_rate', type=float, default=0.01, help='learning rate') 35 | parser.add_argument('--lr_decay_epochs', type=str, default='90,120', help='where to decay lr, can be a list') 36 | parser.add_argument('--lr_decay_rate', type=float, default=0.2, help='decay rate for learning rate') 37 | parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay') 38 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 39 | 40 | 41 | # model definition 42 | parser.add_argument('--student_arch', type=str, default='alexnet', 43 | choices=['alexnet', 'resnet18', 'resnet50', 'mobilenet']) 44 | parser.add_argument('--teacher_arch', type=str, default='resnet50', 45 | choices=['resnet50x4', 'resnet50']) 46 | parser.add_argument('--cache_teacher', action='store_true', 47 | help='use cached teacher') 48 | 49 | # CompReSS loss function 50 | parser.add_argument('--compress_memory_size', type=int, default=128000) 51 | parser.add_argument('--compress_t', type=float, default=0.04) 52 | 53 | # GPU setting 54 | parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') 55 | 56 | parser.add_argument('--teacher', type=str, help='teacher weights/feats') 57 | 58 | parser.add_argument('--checkpoint_path', default='output/', type=str, 59 | help='where to save checkpoints. ') 60 | 61 | opt = parser.parse_args() 62 | 63 | iterations = opt.lr_decay_epochs.split(',') 64 | opt.lr_decay_epochs = list([]) 65 | for it in iterations: 66 | opt.lr_decay_epochs.append(int(it)) 67 | 68 | return opt 69 | 70 | 71 | # Extended version of ImageFolder to return index of image too. 
72 | class ImageFolderEx(datasets.ImageFolder): 73 | 74 | def __getitem__(self, index): 75 | sample, target = super(ImageFolderEx, self).__getitem__(index) 76 | return index, sample, target 77 | 78 | 79 | # Create teacher model and load weights. For cached teacher load cahced features instead. 80 | def get_teacher_model(opt): 81 | teacher = None 82 | if opt.cache_teacher: 83 | print('==> cached teacher') 84 | train_feats, train_labels, indices = torch.load(opt.teacher) 85 | teacher = Teacher(cached=True, cached_feats=train_feats) 86 | 87 | elif opt.teacher_arch == 'resnet50': 88 | print('==> online teacher') 89 | model_t = resnet50() 90 | model_t.fc = nn.Sequential() 91 | model_t = nn.Sequential(OrderedDict([('encoder_q', model_t)])) 92 | model_t = torch.nn.DataParallel(model_t).cuda() 93 | checkpoint = torch.load(opt.teacher) 94 | msg = model_t.load_state_dict(checkpoint['state_dict'], strict=False) 95 | print('==> loading teacher weights') 96 | print(msg) 97 | model_t = model_t.module.cpu() 98 | 99 | for p in model_t.parameters(): 100 | p.requires_grad = False 101 | teacher = Teacher(cached=False, model=model_t) 102 | 103 | return teacher 104 | 105 | 106 | # Create student query/key model 107 | def get_student_model(opt): 108 | student = None 109 | if opt.student_arch == 'alexnet': 110 | student = alexnet() 111 | student.fc = nn.Sequential() 112 | 113 | elif opt.student_arch == 'mobilenet': 114 | student = mobilenet() 115 | student.fc = nn.Sequential() 116 | 117 | elif opt.student_arch == 'resnet18': 118 | student = resnet18() 119 | student.fc = nn.Sequential() 120 | 121 | elif opt.student_arch == 'resnet50': 122 | student = resnet50(fc_dim=8192) 123 | 124 | return student 125 | 126 | 127 | # Create train loader 128 | def get_train_loader(opt): 129 | data_folder = os.path.join(opt.data, 'train') 130 | mean = [0.485, 0.456, 0.406] 131 | std = [0.229, 0.224, 0.225] 132 | normalize = transforms.Normalize(mean=mean, std=std) 133 | 134 | train_dataset = ImageFolderEx( 135 | data_folder, 136 | transforms.Compose([ 137 | transforms.RandomResizedCrop(224), 138 | transforms.RandomHorizontalFlip(), 139 | transforms.ToTensor(), 140 | normalize, 141 | ])) 142 | 143 | train_loader = torch.utils.data.DataLoader( 144 | train_dataset, batch_size=opt.batch_size, shuffle=True, 145 | num_workers=opt.num_workers, pin_memory=True) 146 | 147 | return train_loader 148 | 149 | 150 | def main(): 151 | 152 | args = parse_option() 153 | os.makedirs(args.checkpoint_path, exist_ok=True) 154 | 155 | if args.gpu is not None: 156 | print("Use GPU: {} for training".format(args.gpu)) 157 | 158 | train_loader = get_train_loader(args) 159 | 160 | teacher = get_teacher_model(args) 161 | student = get_student_model(args) 162 | 163 | # calculate feature dimension of student and teacher 164 | teacher.eval() 165 | student.eval() 166 | tmp_input = torch.randn(2, 3, 224, 224) 167 | feat_t = teacher.forward(tmp_input, 0) 168 | feat_s = student(tmp_input) 169 | student_feats_dim = feat_s.shape[-1] 170 | teacher_feats_dim = feat_t.shape[-1] 171 | student.fc = nn.Linear(student_feats_dim, teacher_feats_dim) 172 | 173 | compress = CompReSSA(teacher_feats_dim, args.compress_memory_size, args.compress_t) 174 | 175 | student = torch.nn.DataParallel(student).cuda() 176 | teacher.gpu() 177 | 178 | 179 | optimizer = torch.optim.SGD(student.parameters(), 180 | lr=args.learning_rate, 181 | momentum=args.momentum, 182 | weight_decay=args.weight_decay) 183 | 184 | cudnn.benchmark = True 185 | 186 | args.start_epoch = 1 187 | 188 | # routine 189 
| for epoch in range(args.start_epoch, args.epochs + 1): 190 | 191 | adjust_learning_rate(epoch, args, optimizer) 192 | print("==> training...") 193 | 194 | time1 = time.time() 195 | loss = train_student(epoch, train_loader, teacher, student, compress, optimizer, args) 196 | 197 | time2 = time.time() 198 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 199 | 200 | # saving the model 201 | if epoch % args.save_freq == 0: 202 | print('==> Saving...') 203 | state = { 204 | 'opt': args, 205 | 'model': student.state_dict(), 206 | 'optimizer': optimizer.state_dict(), 207 | 'epoch': epoch, 208 | } 209 | 210 | save_file = os.path.join(args.checkpoint_path, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) 211 | torch.save(state, save_file) 212 | 213 | # help release GPU memory 214 | del state 215 | torch.cuda.empty_cache() 216 | 217 | 218 | def train_student(epoch, train_loader, teacher, student, compress, optimizer, opt): 219 | """ 220 | one epoch training for CompReSS 221 | """ 222 | student.train() 223 | 224 | batch_time = AverageMeter() 225 | data_time = AverageMeter() 226 | loss_meter = AverageMeter() 227 | 228 | end = time.time() 229 | for idx, (index, inputs, _) in enumerate(train_loader): 230 | data_time.update(time.time() - end) 231 | 232 | bsz = inputs.size(0) 233 | 234 | inputs = inputs.float() 235 | if opt.gpu is not None: 236 | inputs = inputs.cuda(opt.gpu, non_blocking=True) 237 | else: 238 | inputs = inputs.cuda() 239 | 240 | # ===================forward===================== 241 | teacher_feats = teacher.forward(inputs, index) 242 | student_feats = student(inputs) 243 | 244 | loss = compress(teacher_feats, student_feats) 245 | 246 | # ===================backward===================== 247 | optimizer.zero_grad() 248 | loss.backward() 249 | optimizer.step() 250 | 251 | # ===================meters===================== 252 | loss_meter.update(loss.item(), bsz) 253 | 254 | torch.cuda.synchronize() 255 | batch_time.update(time.time() - end) 256 | end = time.time() 257 | 258 | # print info 259 | if (idx + 1) % opt.print_freq == 0: 260 | print('Train: [{0}][{1}/{2}]\t' 261 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 262 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 263 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t'.format( 264 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 265 | data_time=data_time, loss=loss_meter)) 266 | sys.stdout.flush() 267 | 268 | return loss_meter.avg 269 | 270 | 271 | if __name__ == '__main__': 272 | main() 273 | -------------------------------------------------------------------------------- /train_student_without_momentum.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | import argparse 7 | import socket 8 | 9 | from torchvision import transforms, datasets 10 | import torch.nn as nn 11 | 12 | from util import adjust_learning_rate, AverageMeter 13 | from models.resnet import resnet18, resnet50 14 | from models.alexnet import AlexNet as alexnet 15 | from models.mobilenet import MobileNetV2 as mobilenet 16 | from nn.compress_loss import CompReSS, Teacher 17 | 18 | from collections import OrderedDict 19 | 20 | 21 | def parse_option(): 22 | parser = argparse.ArgumentParser('argument for training') 23 | 24 | parser.add_argument('--print_freq', type=int, default=100, help='print frequency') 25 | parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency') 26 | 
parser.add_argument('--save_freq', type=int, default=2, help='save frequency') 27 | parser.add_argument('--batch_size', type=int, default=256, help='batch_size') 28 | parser.add_argument('--num_workers', type=int, default=12, help='num of workers to use') 29 | parser.add_argument('--epochs', type=int, default=130, help='number of training epochs') 30 | 31 | # optimization 32 | parser.add_argument('--learning_rate', type=float, default=0.01, help='learning rate') 33 | parser.add_argument('--lr_decay_epochs', type=str, default='90,120', help='where to decay lr, can be a list') 34 | parser.add_argument('--lr_decay_rate', type=float, default=0.2, help='decay rate for learning rate') 35 | parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay') 36 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 37 | 38 | # model definition 39 | parser.add_argument('--student_arch', type=str, default='alexnet', 40 | choices=['alexnet', 'resnet18', 'resnet50', 'mobilenet']) 41 | parser.add_argument('--teacher_arch', type=str, default='resnet50', 42 | choices=['resnet50x4', 'resnet50']) 43 | parser.add_argument('--cache_teacher', action='store_true', 44 | help='use cached teacher') 45 | 46 | # CompReSS loss function 47 | parser.add_argument('--compress_memory_size', type=int, default=128000) 48 | parser.add_argument('--compress_t', type=float, default=0.04) 49 | 50 | # GPU setting 51 | parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') 52 | 53 | parser.add_argument('--teacher', type=str, help='teacher weights/feats') 54 | 55 | parser.add_argument('--data', type=str, help='path to dataset') 56 | 57 | parser.add_argument('--checkpoint_path', default='output/', type=str, 58 | help='where to save checkpoints. ') 59 | 60 | opt = parser.parse_args() 61 | 62 | iterations = opt.lr_decay_epochs.split(',') 63 | opt.lr_decay_epochs = list([]) 64 | for it in iterations: 65 | opt.lr_decay_epochs.append(int(it)) 66 | 67 | return opt 68 | 69 | 70 | # Extended version of ImageFolder to return index of image too. 71 | class ImageFolderEx(datasets.ImageFolder): 72 | 73 | def __getitem__(self, index): 74 | sample, target = super(ImageFolderEx, self).__getitem__(index) 75 | return index, sample, target 76 | 77 | 78 | # Create teacher model and load weights. For cached teacher load cached features instead. 
79 | def get_teacher_model(opt): 80 | teacher = None 81 | if opt.cache_teacher: 82 | train_feats, train_labels, indices = torch.load(opt.teacher) 83 | teacher = Teacher(cached=True, cached_feats=train_feats) 84 | 85 | elif opt.teacher_arch == 'resnet50': 86 | model_t = resnet50() 87 | model_t.fc = nn.Sequential() 88 | model_t = nn.Sequential(OrderedDict([('encoder_q', model_t)])) 89 | model_t = torch.nn.DataParallel(model_t).cuda() 90 | checkpoint = torch.load(opt.teacher) 91 | model_t.load_state_dict(checkpoint['state_dict'], strict=False) 92 | model_t = model_t.module.cpu() 93 | 94 | for p in model_t.parameters(): 95 | p.requires_grad = False 96 | teacher = Teacher(cached=False, model=model_t) 97 | 98 | return teacher 99 | 100 | 101 | # Create student query/key model 102 | def get_student_model(opt): 103 | student = None 104 | if opt.student_arch == 'alexnet': 105 | student = alexnet() 106 | student.fc = nn.Sequential() 107 | 108 | elif opt.student_arch == 'mobilenet': 109 | student = mobilenet() 110 | student.fc = nn.Sequential() 111 | 112 | elif opt.student_arch == 'resnet18': 113 | student = resnet18() 114 | student.fc = nn.Sequential() 115 | 116 | elif opt.student_arch == 'resnet50': 117 | student = resnet50(fc_dim=8192) 118 | 119 | return student 120 | 121 | 122 | # Create train loader 123 | def get_train_loader(opt): 124 | data_folder = os.path.join(opt.data, 'train') 125 | mean = [0.485, 0.456, 0.406] 126 | std = [0.229, 0.224, 0.225] 127 | normalize = transforms.Normalize(mean=mean, std=std) 128 | 129 | train_dataset = ImageFolderEx( 130 | data_folder, 131 | transforms.Compose([ 132 | transforms.RandomResizedCrop(224), 133 | transforms.RandomHorizontalFlip(), 134 | transforms.ToTensor(), 135 | normalize, 136 | ])) 137 | 138 | train_loader = torch.utils.data.DataLoader( 139 | train_dataset, batch_size=opt.batch_size, shuffle=True, 140 | num_workers=opt.num_workers, pin_memory=True) 141 | 142 | return train_loader 143 | 144 | 145 | def main(): 146 | 147 | args = parse_option() 148 | os.makedirs(args.checkpoint_path, exist_ok=True) 149 | 150 | if args.gpu is not None: 151 | print("Use GPU: {} for training".format(args.gpu)) 152 | 153 | train_loader = get_train_loader(args) 154 | 155 | teacher = get_teacher_model(args) 156 | student = get_student_model(args) 157 | 158 | # Calculate feature dimension of student and teacher 159 | teacher.eval() 160 | student.eval() 161 | tmp_input = torch.randn(2, 3, 224, 224) 162 | feat_t = teacher.forward(tmp_input, 0) 163 | feat_s = student(tmp_input) 164 | student_feats_dim = feat_s.shape[-1] 165 | teacher_feats_dim = feat_t.shape[-1] 166 | 167 | compress = CompReSS(teacher_feats_dim, student_feats_dim, args.compress_memory_size, args.compress_t) 168 | 169 | student = torch.nn.DataParallel(student).cuda() 170 | teacher.gpu() 171 | 172 | optimizer = torch.optim.SGD(student.parameters(), 173 | lr=args.learning_rate, 174 | momentum=args.momentum, 175 | weight_decay=args.weight_decay) 176 | 177 | cudnn.benchmark = True 178 | 179 | args.start_epoch = 1 180 | # routine 181 | for epoch in range(args.start_epoch, args.epochs + 1): 182 | 183 | adjust_learning_rate(epoch, args, optimizer) 184 | print("==> training...") 185 | 186 | time1 = time.time() 187 | loss = train_student(epoch, train_loader, teacher, student, compress, optimizer, args) 188 | 189 | time2 = time.time() 190 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 191 | 192 | 193 | 194 | # saving the model 195 | if epoch % args.save_freq == 0: 196 | print('==> Saving...') 197 | 
state = { 198 | 'opt': args, 199 | 'model': student.state_dict(), 200 | 'optimizer': optimizer.state_dict(), 201 | 'epoch': epoch, 202 | } 203 | 204 | save_file = os.path.join(args.checkpoint_path, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) 205 | torch.save(state, save_file) 206 | 207 | # help release GPU memory 208 | del state 209 | torch.cuda.empty_cache() 210 | 211 | 212 | 213 | def train_student(epoch, train_loader, teacher, student, compress, optimizer, opt): 214 | """ 215 | one epoch training for CompReSS 216 | """ 217 | student.train() 218 | 219 | batch_time = AverageMeter() 220 | data_time = AverageMeter() 221 | loss_meter = AverageMeter() 222 | 223 | end = time.time() 224 | for idx, (index, inputs, _) in enumerate(train_loader): 225 | data_time.update(time.time() - end) 226 | 227 | bsz = inputs.size(0) 228 | 229 | inputs = inputs.float() 230 | if opt.gpu is not None: 231 | inputs = inputs.cuda(opt.gpu, non_blocking=True) 232 | else: 233 | inputs = inputs.cuda() 234 | 235 | # ===================forward===================== 236 | 237 | teacher_feats = teacher.forward(inputs, index) 238 | student_feats = student(inputs) 239 | 240 | loss = compress(teacher_feats, student_feats) 241 | 242 | # ===================backward===================== 243 | optimizer.zero_grad() 244 | loss.backward() 245 | optimizer.step() 246 | 247 | # ===================meters===================== 248 | loss_meter.update(loss.item(), bsz) 249 | 250 | torch.cuda.synchronize() 251 | batch_time.update(time.time() - end) 252 | end = time.time() 253 | 254 | # print info 255 | if (idx + 1) % opt.print_freq == 0: 256 | print('Train: [{0}][{1}/{2}]\t' 257 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 258 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 259 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t'.format( 260 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 261 | data_time=data_time, loss=loss_meter)) 262 | sys.stdout.flush() 263 | 264 | return loss_meter.avg 265 | 266 | 267 | if __name__ == '__main__': 268 | main() 269 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def adjust_learning_rate(epoch, opt, optimizer): 8 | """Sets the learning rate to the initial LR decayed by 0.2 every steep step""" 9 | steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs)) 10 | if steps > 0: 11 | new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps) 12 | print(new_lr) 13 | for param_group in optimizer.param_groups: 14 | param_group['lr'] = new_lr 15 | 16 | 17 | class AverageMeter(object): 18 | """Computes and stores the average and current value""" 19 | def __init__(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | self.reset() 25 | 26 | def reset(self): 27 | self.val = 0 28 | self.avg = 0 29 | self.sum = 0 30 | self.count = 0 31 | 32 | def update(self, val, n=1): 33 | self.val = val 34 | self.sum += val * n 35 | self.count += n 36 | self.avg = self.sum / self.count 37 | 38 | 39 | def accuracy(output, target, topk=(1,)): 40 | """Computes the accuracy over the k top predictions for the specified values of k""" 41 | with torch.no_grad(): 42 | maxk = max(topk) 43 | batch_size = target.size(0) 44 | 45 | _, pred = output.topk(maxk, 1, True, True) 46 | pred = pred.t() 47 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 48 | 49 | res = [] 50 | for k in 
topk: 51 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 52 | res.append(correct_k.mul_(100.0 / batch_size)) 53 | return res 54 | 55 | 56 | if __name__ == '__main__': 57 | meter = AverageMeter() 58 | --------------------------------------------------------------------------------
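Note: the train_student*.py scripts above share one skeleton: a frozen (or cached) teacher embeds each batch, the student embeds the same batch (in train_student.py a momentum-updated key copy of the student embeds it as well), and a CompReSS loss from nn/compress_loss.py compares the two sets of features. The sketch below is a minimal, self-contained illustration of that loop, in particular the moment_update EMA step; ToyEncoder, the feature dimensions, and the cosine-similarity objective are illustrative stand-ins only, not the repository's models or its CompReSSMomentum loss.

# Standalone sketch of the distillation loop in train_student.py.
# ToyEncoder and the cosine objective are placeholders; the real encoders
# live in models/ and the real loss (CompReSSMomentum) in nn/compress_loss.py.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyEncoder(nn.Module):
    def __init__(self, in_dim=32, out_dim=16):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.fc(x)


def moment_update(query_model, key_model, m):
    """key_model = m * key_model + (1 - m) * query_model, as in train_student.py."""
    for p_q, p_k in zip(query_model.parameters(), key_model.parameters()):
        p_k.data.mul_(m).add_(p_q.detach().data, alpha=1 - m)


teacher = ToyEncoder()       # frozen teacher (or cached features in the real script)
student = ToyEncoder()       # query encoder, receives gradients
student_key = ToyEncoder()   # key encoder, updated only through the EMA
moment_update(student, student_key, 0)   # initialise key = query, as main() does

for p in teacher.parameters():
    p.requires_grad = False

optimizer = torch.optim.SGD(student.parameters(), lr=0.01, momentum=0.9)

inputs = torch.randn(8, 32)                  # one toy batch
with torch.no_grad():
    teacher_feats = teacher(inputs)
student_feats = student(inputs)
with torch.no_grad():
    # In the real script these key features are passed to CompReSSMomentum,
    # which uses them to populate its memory queue; the toy loss ignores them.
    student_feats_key = student_key(inputs)

# Stand-in objective: pull student features toward the teacher's. The actual
# CompReSS loss instead matches similarity distributions over a queue of anchors.
loss = 1 - F.cosine_similarity(student_feats, teacher_feats).mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()
moment_update(student, student_key, 0.999)   # EMA step each iteration (opt.alpha)

In train_student_one_queue.py and train_student_without_momentum.py the key encoder and both moment_update calls are simply dropped, and the loss is called with only the teacher and student features.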