├── imgs └── results.png ├── data └── index_list │ ├── cifar100 │ ├── session_2.txt │ ├── session_3.txt │ ├── session_4.txt │ ├── session_5.txt │ ├── session_6.txt │ ├── session_7.txt │ ├── session_8.txt │ └── session_9.txt │ ├── README.md │ ├── mini_imagenet │ ├── session_2.txt │ ├── session_3.txt │ ├── session_4.txt │ ├── session_5.txt │ ├── session_6.txt │ ├── session_7.txt │ ├── session_8.txt │ └── session_9.txt │ └── cub200 │ ├── session_6.txt │ ├── session_11.txt │ ├── session_2.txt │ ├── session_5.txt │ ├── session_3.txt │ ├── session_9.txt │ ├── session_4.txt │ ├── session_8.txt │ ├── session_7.txt │ └── session_10.txt ├── scripts ├── mini_imagenet.sh ├── cifar.sh └── cub.sh ├── base.py ├── postprocess_path.py ├── README.md ├── models ├── resnet20_cifar.py ├── teen │ ├── Network.py │ ├── helper.py │ └── fscil_trainer.py └── resnet18_encoder.py ├── train.py ├── dataloader ├── sampler.py ├── data_utils.py ├── miniimagenet │ ├── miniimagenet.py │ └── autoaugment.py ├── cub200 │ ├── cub200.py │ └── autoaugment.py └── cifar100 │ ├── cifar.py │ └── autoaugment.py └── utils.py /imgs/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangkiw/TEEN/HEAD/imgs/results.png -------------------------------------------------------------------------------- /data/index_list/cifar100/session_2.txt: -------------------------------------------------------------------------------- 1 | 29774 2 | 33344 3 | 4815 4 | 6772 5 | 48317 6 | 29918 7 | 33262 8 | 5138 9 | 7342 10 | 47874 11 | 28864 12 | 32471 13 | 4316 14 | 6436 15 | 47498 16 | 29802 17 | 33159 18 | 3730 19 | 5093 20 | 47740 21 | 30548 22 | 34549 23 | 2845 24 | 4996 25 | 47866 26 | -------------------------------------------------------------------------------- /data/index_list/cifar100/session_3.txt: -------------------------------------------------------------------------------- 1 | 28855 2 | 32834 3 | 4603 4 | 6914 5 | 48126 6 | 29932 7 | 33300 
8 | 3860 9 | 5424 10 | 47055 11 | 29434 12 | 32604 13 | 4609 14 | 6380 15 | 47844 16 | 30456 17 | 34217 18 | 4361 19 | 6550 20 | 46896 21 | 29664 22 | 32857 23 | 4923 24 | 7502 25 | 47270 26 | -------------------------------------------------------------------------------- /data/index_list/cifar100/session_4.txt: -------------------------------------------------------------------------------- 1 | 31267 2 | 34427 3 | 4799 4 | 6611 5 | 47404 6 | 28509 7 | 31687 8 | 3477 9 | 5563 10 | 48003 11 | 29545 12 | 33412 13 | 5114 14 | 6808 15 | 47692 16 | 29209 17 | 33265 18 | 4131 19 | 6401 20 | 48102 21 | 31290 22 | 34432 23 | 6060 24 | 8451 25 | 48279 26 | -------------------------------------------------------------------------------- /data/index_list/cifar100/session_5.txt: -------------------------------------------------------------------------------- 1 | 32337 2 | 35646 3 | 6022 4 | 9048 5 | 48584 6 | 30768 7 | 34394 8 | 5091 9 | 6510 10 | 48023 11 | 30310 12 | 33230 13 | 5098 14 | 6671 15 | 48349 16 | 29690 17 | 33490 18 | 4260 19 | 5916 20 | 47371 21 | 31173 22 | 34943 23 | 4517 24 | 6494 25 | 47689 26 | -------------------------------------------------------------------------------- /data/index_list/cifar100/session_6.txt: -------------------------------------------------------------------------------- 1 | 30281 2 | 33894 3 | 3768 4 | 6113 5 | 48095 6 | 28913 7 | 32821 8 | 6172 9 | 8276 10 | 48004 11 | 31249 12 | 34088 13 | 5257 14 | 6961 15 | 47534 16 | 30404 17 | 34101 18 | 4985 19 | 6899 20 | 48115 21 | 31823 22 | 35148 23 | 3922 24 | 6548 25 | 48127 26 | -------------------------------------------------------------------------------- /data/index_list/cifar100/session_7.txt: -------------------------------------------------------------------------------- 1 | 30815 2 | 34450 3 | 3481 4 | 5089 5 | 47913 6 | 31683 7 | 34591 8 | 5251 9 | 7608 10 | 47984 11 | 29837 12 | 33823 13 | 4615 14 | 6448 15 | 47752 16 | 31222 17 | 34079 18 | 5686 19 | 7919 20 | 48675 21 | 
28567 22 | 32964 23 | 5009 24 | 6201 25 | 47039 26 | -------------------------------------------------------------------------------- /data/index_list/cifar100/session_8.txt: -------------------------------------------------------------------------------- 1 | 29355 2 | 33909 3 | 3982 4 | 5389 5 | 47166 6 | 31058 7 | 35180 8 | 5177 9 | 6890 10 | 48032 11 | 31176 12 | 35098 13 | 5235 14 | 7861 15 | 47830 16 | 30874 17 | 34639 18 | 5266 19 | 7489 20 | 47323 21 | 29960 22 | 34050 23 | 4988 24 | 7434 25 | 48208 26 | -------------------------------------------------------------------------------- /data/index_list/cifar100/session_9.txt: -------------------------------------------------------------------------------- 1 | 30463 2 | 34580 3 | 5230 4 | 6813 5 | 48605 6 | 31702 7 | 35249 8 | 5854 9 | 7765 10 | 48444 11 | 30380 12 | 34028 13 | 5211 14 | 7433 15 | 47988 16 | 31348 17 | 34021 18 | 4929 19 | 7033 20 | 47904 21 | 30627 22 | 33728 23 | 4895 24 | 6299 25 | 47507 26 | -------------------------------------------------------------------------------- /scripts/mini_imagenet.sh: -------------------------------------------------------------------------------- 1 | python train.py teen \ 2 | -project teen \ 3 | -dataset mini_imagenet \ 4 | -dataroot MINI_DATA_DIR \ 5 | -base_mode 'ft_cos' \ 6 | -new_mode 'avg_cos' \ 7 | -gamma 0.1 \ 8 | -lr_base 0.1 \ 9 | -decay 0.0005 \ 10 | -epochs_base 1000 \ 11 | -schedule Cosine \ 12 | -tmax 1000 \ 13 | -gpu '2' \ 14 | -temperature 32 \ 15 | -batch_size_base 128 16 | -------------------------------------------------------------------------------- /scripts/cifar.sh: -------------------------------------------------------------------------------- 1 | python train.py teen \ 2 | -project teen \ 3 | -dataset cifar100 \ 4 | -dataroot CIFAR_DATA_DIR \ 5 | -base_mode 'ft_cos' \ 6 | -new_mode 'avg_cos' \ 7 | -lr_base 0.1 \ 8 | -decay 0.0005 \ 9 | -epochs_base 600 \ 10 | -batch_size_base 256 \ 11 | -schedule Cosine \ 12 | -tmax 600 \ 13 | -gpu 
'2' \ 14 | -temperature 16 \ 15 | -softmax_t 16 \ 16 | -shift_weight 0.1 17 | -------------------------------------------------------------------------------- /scripts/cub.sh: -------------------------------------------------------------------------------- 1 | python train.py teen \ 2 | -project teen \ 3 | -dataset cub200 \ 4 | -dataroot CUB_DATA_DIR \ 5 | -base_mode 'ft_cos' \ 6 | -new_mode 'avg_cos' \ 7 | -gamma 0.25 \ 8 | -lr_base 0.004 \ 9 | -lr_new 0.1 \ 10 | -decay 0.0005 \ 11 | -epochs_base 400 \ 12 | -schedule Milestone \ 13 | -milestones 50 100 150 200 250 300 \ 14 | -gpu '3' \ 15 | -temperature 32 \ 16 | -batch_size_base 128 \ 17 | -softmax_t 16 \ 18 | -shift_weight 0.5 19 | -------------------------------------------------------------------------------- /data/index_list/README.md: -------------------------------------------------------------------------------- 1 | ### How to use the index files for the experiments? 2 | 3 | The index files are named like "session_x.txt", where x indicates the session number. Each index file lists the images selected for that session: integer image indices for CIFAR100, and relative image paths for miniImageNet and CUB200. 4 | "session_1.txt" stores all the base class training images. Each "session_t.txt" (t>1) stores the few-shot new-class training images: 25 images (5 classes with 5 shots per class) for CIFAR100 and miniImageNet, and 50 images (10 classes with 5 shots per class) for CUB200. 5 | You may adopt the following steps to perform the experiments. 6 | 7 | First, at session 1, train a base model using the images in session_1.txt; 8 | 9 | Then, at session t (t>1), fine-tune the model trained at the previous session (t-1), using only the images in session_t.txt. 10 | 11 | For evaluating the model at session t, first join all the encountered test sets into a single test set. Then test the current model on all the test images and compute the recognition accuracy.
import abc
from dataloader.data_utils import *
from utils import Averager, Timer


class Trainer(metaclass=abc.ABCMeta):
    """Abstract base class shared by the training pipelines.

    Construction passes ``args`` through ``set_up_datasets`` and keeps the
    result, creates four ``Averager`` meters and a ``Timer``, and initialises
    the ``trlog`` statistics dictionary via :meth:`init_log`. Concrete
    trainers must implement :meth:`train`.
    """

    def __init__(self, args):
        self.args = args
        # set_up_datasets presumably comes from the star import of
        # dataloader.data_utils; its return value replaces self.args.
        self.args = set_up_datasets(self.args)
        # Four running-average meters. The names suggest per-phase timings
        # (data / forward / backward / optimizer?) -- TODO confirm in subclasses.
        self.dt = Averager()
        self.ft = Averager()
        self.bt = Averager()
        self.ot = Averager()
        self.timer = Timer()
        self.init_log()

    @abc.abstractmethod
    def train(self):
        """Run the full training procedure (implemented by subclasses)."""

    def init_log(self):
        """(Re)create ``self.trlog``, the dictionary of training statistics.

        The loss and accuracy histories start empty. ``max_acc_epoch`` starts
        at 0, and ``max_acc`` holds one 0.0 entry per session
        (``self.args.sessions`` entries).
        """
        self.trlog = {
            'train_loss': [],
            'val_loss': [],
            'test_loss': [],
            'train_acc': [],
            'val_acc': [],
            'test_acc': [],
            'max_acc_epoch': 0,
            'max_acc': [0.0] * self.args.sessions,
            'seen_acc': [],
            'unseen_acc': [],
        }
MINI-ImageNet/train/n03676483/n0367648300000108.jpg 14 | MINI-ImageNet/train/n03676483/n0367648300000164.jpg 15 | MINI-ImageNet/train/n03676483/n0367648300001236.jpg 16 | MINI-ImageNet/train/n03770439/n0377043900000738.jpg 17 | MINI-ImageNet/train/n03770439/n0377043900000835.jpg 18 | MINI-ImageNet/train/n03770439/n0377043900000111.jpg 19 | MINI-ImageNet/train/n03770439/n0377043900000154.jpg 20 | MINI-ImageNet/train/n03770439/n0377043900001238.jpg 21 | MINI-ImageNet/train/n03773504/n0377350400000781.jpg 22 | MINI-ImageNet/train/n03773504/n0377350400000880.jpg 23 | MINI-ImageNet/train/n03773504/n0377350400000132.jpg 24 | MINI-ImageNet/train/n03773504/n0377350400000177.jpg 25 | MINI-ImageNet/train/n03773504/n0377350400001237.jpg 26 | -------------------------------------------------------------------------------- /data/index_list/mini_imagenet/session_3.txt: -------------------------------------------------------------------------------- 1 | MINI-ImageNet/train/n03775546/n0377554600000808.jpg 2 | MINI-ImageNet/train/n03775546/n0377554600000891.jpg 3 | MINI-ImageNet/train/n03775546/n0377554600000123.jpg 4 | MINI-ImageNet/train/n03775546/n0377554600000179.jpg 5 | MINI-ImageNet/train/n03775546/n0377554600001241.jpg 6 | MINI-ImageNet/train/n03838899/n0383889900000800.jpg 7 | MINI-ImageNet/train/n03838899/n0383889900000881.jpg 8 | MINI-ImageNet/train/n03838899/n0383889900000124.jpg 9 | MINI-ImageNet/train/n03838899/n0383889900000165.jpg 10 | MINI-ImageNet/train/n03838899/n0383889900001242.jpg 11 | MINI-ImageNet/train/n03854065/n0385406500000796.jpg 12 | MINI-ImageNet/train/n03854065/n0385406500000899.jpg 13 | MINI-ImageNet/train/n03854065/n0385406500000103.jpg 14 | MINI-ImageNet/train/n03854065/n0385406500000146.jpg 15 | MINI-ImageNet/train/n03854065/n0385406500001249.jpg 16 | MINI-ImageNet/train/n03888605/n0388860500000783.jpg 17 | MINI-ImageNet/train/n03888605/n0388860500000875.jpg 18 | MINI-ImageNet/train/n03888605/n0388860500000124.jpg 19 | 
MINI-ImageNet/train/n03888605/n0388860500000184.jpg 20 | MINI-ImageNet/train/n03888605/n0388860500001245.jpg 21 | MINI-ImageNet/train/n03908618/n0390861800000778.jpg 22 | MINI-ImageNet/train/n03908618/n0390861800000855.jpg 23 | MINI-ImageNet/train/n03908618/n0390861800000127.jpg 24 | MINI-ImageNet/train/n03908618/n0390861800000175.jpg 25 | MINI-ImageNet/train/n03908618/n0390861800001243.jpg 26 | -------------------------------------------------------------------------------- /data/index_list/mini_imagenet/session_4.txt: -------------------------------------------------------------------------------- 1 | MINI-ImageNet/train/n03924679/n0392467900000785.jpg 2 | MINI-ImageNet/train/n03924679/n0392467900000871.jpg 3 | MINI-ImageNet/train/n03924679/n0392467900000115.jpg 4 | MINI-ImageNet/train/n03924679/n0392467900000164.jpg 5 | MINI-ImageNet/train/n03924679/n0392467900001240.jpg 6 | MINI-ImageNet/train/n03980874/n0398087400000794.jpg 7 | MINI-ImageNet/train/n03980874/n0398087400000911.jpg 8 | MINI-ImageNet/train/n03980874/n0398087400000124.jpg 9 | MINI-ImageNet/train/n03980874/n0398087400000184.jpg 10 | MINI-ImageNet/train/n03980874/n0398087400001229.jpg 11 | MINI-ImageNet/train/n03998194/n0399819400000790.jpg 12 | MINI-ImageNet/train/n03998194/n0399819400000876.jpg 13 | MINI-ImageNet/train/n03998194/n0399819400000116.jpg 14 | MINI-ImageNet/train/n03998194/n0399819400000164.jpg 15 | MINI-ImageNet/train/n03998194/n0399819400001248.jpg 16 | MINI-ImageNet/train/n04067472/n0406747200000802.jpg 17 | MINI-ImageNet/train/n04067472/n0406747200000904.jpg 18 | MINI-ImageNet/train/n04067472/n0406747200000130.jpg 19 | MINI-ImageNet/train/n04067472/n0406747200000175.jpg 20 | MINI-ImageNet/train/n04067472/n0406747200001257.jpg 21 | MINI-ImageNet/train/n04146614/n0414661400000809.jpg 22 | MINI-ImageNet/train/n04146614/n0414661400000921.jpg 23 | MINI-ImageNet/train/n04146614/n0414661400000132.jpg 24 | MINI-ImageNet/train/n04146614/n0414661400000204.jpg 25 | 
MINI-ImageNet/train/n04146614/n0414661400001243.jpg 26 | -------------------------------------------------------------------------------- /data/index_list/mini_imagenet/session_5.txt: -------------------------------------------------------------------------------- 1 | MINI-ImageNet/train/n04149813/n0414981300000789.jpg 2 | MINI-ImageNet/train/n04149813/n0414981300000879.jpg 3 | MINI-ImageNet/train/n04149813/n0414981300000152.jpg 4 | MINI-ImageNet/train/n04149813/n0414981300000196.jpg 5 | MINI-ImageNet/train/n04149813/n0414981300001229.jpg 6 | MINI-ImageNet/train/n04243546/n0424354600000785.jpg 7 | MINI-ImageNet/train/n04243546/n0424354600000868.jpg 8 | MINI-ImageNet/train/n04243546/n0424354600000109.jpg 9 | MINI-ImageNet/train/n04243546/n0424354600000159.jpg 10 | MINI-ImageNet/train/n04243546/n0424354600001243.jpg 11 | MINI-ImageNet/train/n04251144/n0425114400000797.jpg 12 | MINI-ImageNet/train/n04251144/n0425114400000891.jpg 13 | MINI-ImageNet/train/n04251144/n0425114400000123.jpg 14 | MINI-ImageNet/train/n04251144/n0425114400000170.jpg 15 | MINI-ImageNet/train/n04251144/n0425114400001244.jpg 16 | MINI-ImageNet/train/n04258138/n0425813800000807.jpg 17 | MINI-ImageNet/train/n04258138/n0425813800000900.jpg 18 | MINI-ImageNet/train/n04258138/n0425813800000135.jpg 19 | MINI-ImageNet/train/n04258138/n0425813800000193.jpg 20 | MINI-ImageNet/train/n04258138/n0425813800001252.jpg 21 | MINI-ImageNet/train/n04275548/n0427554800000755.jpg 22 | MINI-ImageNet/train/n04275548/n0427554800000854.jpg 23 | MINI-ImageNet/train/n04275548/n0427554800000127.jpg 24 | MINI-ImageNet/train/n04275548/n0427554800000168.jpg 25 | MINI-ImageNet/train/n04275548/n0427554800001238.jpg 26 | -------------------------------------------------------------------------------- /data/index_list/mini_imagenet/session_6.txt: -------------------------------------------------------------------------------- 1 | MINI-ImageNet/train/n04296562/n0429656200000772.jpg 2 | 
MINI-ImageNet/train/n04296562/n0429656200000862.jpg 3 | MINI-ImageNet/train/n04296562/n0429656200000119.jpg 4 | MINI-ImageNet/train/n04296562/n0429656200000158.jpg 5 | MINI-ImageNet/train/n04296562/n0429656200001223.jpg 6 | MINI-ImageNet/train/n04389033/n0438903300000802.jpg 7 | MINI-ImageNet/train/n04389033/n0438903300000912.jpg 8 | MINI-ImageNet/train/n04389033/n0438903300000157.jpg 9 | MINI-ImageNet/train/n04389033/n0438903300000202.jpg 10 | MINI-ImageNet/train/n04389033/n0438903300001261.jpg 11 | MINI-ImageNet/train/n04418357/n0441835700000746.jpg 12 | MINI-ImageNet/train/n04418357/n0441835700000848.jpg 13 | MINI-ImageNet/train/n04418357/n0441835700000111.jpg 14 | MINI-ImageNet/train/n04418357/n0441835700000163.jpg 15 | MINI-ImageNet/train/n04418357/n0441835700001226.jpg 16 | MINI-ImageNet/train/n04435653/n0443565300000828.jpg 17 | MINI-ImageNet/train/n04435653/n0443565300000932.jpg 18 | MINI-ImageNet/train/n04435653/n0443565300000139.jpg 19 | MINI-ImageNet/train/n04435653/n0443565300000203.jpg 20 | MINI-ImageNet/train/n04435653/n0443565300001245.jpg 21 | MINI-ImageNet/train/n04443257/n0444325700000764.jpg 22 | MINI-ImageNet/train/n04443257/n0444325700000852.jpg 23 | MINI-ImageNet/train/n04443257/n0444325700000125.jpg 24 | MINI-ImageNet/train/n04443257/n0444325700000183.jpg 25 | MINI-ImageNet/train/n04443257/n0444325700001216.jpg 26 | -------------------------------------------------------------------------------- /data/index_list/mini_imagenet/session_7.txt: -------------------------------------------------------------------------------- 1 | MINI-ImageNet/train/n04509417/n0450941700000801.jpg 2 | MINI-ImageNet/train/n04509417/n0450941700000882.jpg 3 | MINI-ImageNet/train/n04509417/n0450941700000149.jpg 4 | MINI-ImageNet/train/n04509417/n0450941700000201.jpg 5 | MINI-ImageNet/train/n04509417/n0450941700001242.jpg 6 | MINI-ImageNet/train/n04515003/n0451500300000791.jpg 7 | MINI-ImageNet/train/n04515003/n0451500300000893.jpg 8 | 
MINI-ImageNet/train/n04515003/n0451500300000112.jpg 9 | MINI-ImageNet/train/n04515003/n0451500300000161.jpg 10 | MINI-ImageNet/train/n04515003/n0451500300001259.jpg 11 | MINI-ImageNet/train/n04522168/n0452216800000790.jpg 12 | MINI-ImageNet/train/n04522168/n0452216800000894.jpg 13 | MINI-ImageNet/train/n04522168/n0452216800000124.jpg 14 | MINI-ImageNet/train/n04522168/n0452216800000180.jpg 15 | MINI-ImageNet/train/n04522168/n0452216800001258.jpg 16 | MINI-ImageNet/train/n04596742/n0459674200000809.jpg 17 | MINI-ImageNet/train/n04596742/n0459674200000897.jpg 18 | MINI-ImageNet/train/n04596742/n0459674200000132.jpg 19 | MINI-ImageNet/train/n04596742/n0459674200000189.jpg 20 | MINI-ImageNet/train/n04596742/n0459674200001241.jpg 21 | MINI-ImageNet/train/n04604644/n0460464400000828.jpg 22 | MINI-ImageNet/train/n04604644/n0460464400000904.jpg 23 | MINI-ImageNet/train/n04604644/n0460464400000124.jpg 24 | MINI-ImageNet/train/n04604644/n0460464400000175.jpg 25 | MINI-ImageNet/train/n04604644/n0460464400001256.jpg 26 | -------------------------------------------------------------------------------- /data/index_list/mini_imagenet/session_8.txt: -------------------------------------------------------------------------------- 1 | MINI-ImageNet/train/n04612504/n0461250400000737.jpg 2 | MINI-ImageNet/train/n04612504/n0461250400000810.jpg 3 | MINI-ImageNet/train/n04612504/n0461250400000149.jpg 4 | MINI-ImageNet/train/n04612504/n0461250400000194.jpg 5 | MINI-ImageNet/train/n04612504/n0461250400001160.jpg 6 | MINI-ImageNet/train/n06794110/n0679411000000773.jpg 7 | MINI-ImageNet/train/n06794110/n0679411000000882.jpg 8 | MINI-ImageNet/train/n06794110/n0679411000000124.jpg 9 | MINI-ImageNet/train/n06794110/n0679411000000199.jpg 10 | MINI-ImageNet/train/n06794110/n0679411000001256.jpg 11 | MINI-ImageNet/train/n07584110/n0758411000000764.jpg 12 | MINI-ImageNet/train/n07584110/n0758411000000855.jpg 13 | MINI-ImageNet/train/n07584110/n0758411000000133.jpg 14 | 
MINI-ImageNet/train/n07584110/n0758411000000180.jpg 15 | MINI-ImageNet/train/n07584110/n0758411000001154.jpg 16 | MINI-ImageNet/train/n07613480/n0761348000000770.jpg 17 | MINI-ImageNet/train/n07613480/n0761348000000868.jpg 18 | MINI-ImageNet/train/n07613480/n0761348000000140.jpg 19 | MINI-ImageNet/train/n07613480/n0761348000000183.jpg 20 | MINI-ImageNet/train/n07613480/n0761348000001254.jpg 21 | MINI-ImageNet/train/n07697537/n0769753700000774.jpg 22 | MINI-ImageNet/train/n07697537/n0769753700000862.jpg 23 | MINI-ImageNet/train/n07697537/n0769753700000142.jpg 24 | MINI-ImageNet/train/n07697537/n0769753700000181.jpg 25 | MINI-ImageNet/train/n07697537/n0769753700001231.jpg 26 | -------------------------------------------------------------------------------- /data/index_list/mini_imagenet/session_9.txt: -------------------------------------------------------------------------------- 1 | MINI-ImageNet/train/n07747607/n0774760700000787.jpg 2 | MINI-ImageNet/train/n07747607/n0774760700000894.jpg 3 | MINI-ImageNet/train/n07747607/n0774760700000140.jpg 4 | MINI-ImageNet/train/n07747607/n0774760700000190.jpg 5 | MINI-ImageNet/train/n07747607/n0774760700001253.jpg 6 | MINI-ImageNet/train/n09246464/n0924646400000794.jpg 7 | MINI-ImageNet/train/n09246464/n0924646400000885.jpg 8 | MINI-ImageNet/train/n09246464/n0924646400000107.jpg 9 | MINI-ImageNet/train/n09246464/n0924646400000145.jpg 10 | MINI-ImageNet/train/n09246464/n0924646400001250.jpg 11 | MINI-ImageNet/train/n09256479/n0925647900000823.jpg 12 | MINI-ImageNet/train/n09256479/n0925647900000899.jpg 13 | MINI-ImageNet/train/n09256479/n0925647900000154.jpg 14 | MINI-ImageNet/train/n09256479/n0925647900000208.jpg 15 | MINI-ImageNet/train/n09256479/n0925647900001246.jpg 16 | MINI-ImageNet/train/n13054560/n1305456000000771.jpg 17 | MINI-ImageNet/train/n13054560/n1305456000000856.jpg 18 | MINI-ImageNet/train/n13054560/n1305456000000102.jpg 19 | MINI-ImageNet/train/n13054560/n1305456000000159.jpg 20 | 
from utils import ensure_path
import os
import datetime


def sub_set_save_path(args):
    """Append the project-specific hyper-parameter tag to ``args.save_path``.

    For the ``'teen'`` project the tag is
    ``-tw_<softmax_t>-<shift_weight>-<soft_mode>``.

    Args:
        args: Namespace-like object with ``project``, ``save_path``,
            ``softmax_t``, ``shift_weight`` and ``soft_mode`` attributes.

    Returns:
        The same ``args`` object, mutated in place.

    Raises:
        NotImplementedError: if ``args.project`` is anything but ``'teen'``.
    """
    if args.project != 'teen':
        raise NotImplementedError
    args.save_path += f"-tw_{args.softmax_t}-{args.shift_weight}-{args.soft_mode}"
    return args


def _schedule_tag(args):
    """Return the learning-rate-schedule part of the run name.

    Raises:
        NotImplementedError: if ``args.schedule`` is not one of
            ``'Milestone'``, ``'Step'`` or ``'Cosine'``.
    """
    schedule = args.schedule
    if schedule == 'Milestone':
        # For a list such as [50, 100, 150] this yields "50_100_150": spaces
        # are dropped, commas become underscores, and the brackets are stripped.
        milestones = str(args.milestones).replace(" ", "").replace(',', '_')[1:-1]
        return f'-MS_{milestones}-Gam_{args.gamma}'
    if schedule == 'Step':
        return f'-Step_{args.step}-Gam_{args.gamma}'
    if schedule == 'Cosine':
        return f'-Max_{args.tmax}'
    raise NotImplementedError


def set_save_path(args):
    """Compose the checkpoint directory for this run, create it, and return args.

    Resulting layout::

        ./checkpoint/[debug/]<dataset>/<project>/<mode>-start_<start_session>/<run>

    ``<mode>`` is ``<base_mode>-<new_mode>``, with ``-data_init`` appended
    unless ``args.not_data_init`` is set. ``<run>`` concatenates, in order:

    - the timestamp and the base-session optimiser settings;
    - the schedule tag;
    - the feature-norm flag;
    - the cosine temperature, when ``'cos'`` appears in the mode;
    - the new-session fine-tuning settings, when ``'ft'`` appears in
      ``new_mode``;
    - the project tag from :func:`sub_set_save_path`.

    Side effects:
        Sets ``args.time_str`` and ``args.save_path``; calls ``ensure_path``
        on the final directory.

    Returns:
        The updated ``args``.

    Raises:
        NotImplementedError: for an unsupported schedule or project.
    """
    # Timestamp trimmed from microseconds to milliseconds (drop 3 digits).
    args.time_str = datetime.datetime.now().strftime('%m%d-%H-%M-%S-%f')[:-3]

    mode = '-'.join((args.base_mode, args.new_mode))
    if not args.not_data_init:
        mode += '-data_init'

    args.save_path = '%s/%s/%s-start_%d/' % (
        args.dataset, args.project, mode, args.start_session)

    # Optimiser / scheduler settings. The schedule tag is resolved first, so an
    # unknown schedule raises before anything else is appended.
    schedule_tag = _schedule_tag(args)
    optim_part = (
        f'{args.time_str}-Epo_{args.epochs_base}-Bs_{args.batch_size_base}'
        f'-{args.optim}-Lr_{args.lr_base}-decay{args.decay}-Mom_{args.momentum}'
    )
    args.save_path += optim_part + schedule_tag

    # Feature normalisation flag.
    args.save_path += '-NormT' if args.feat_norm else '-NormF'

    # Train-mode specific settings.
    if 'cos' in mode:
        args.save_path += '-T_%.2f' % (args.temperature)
    if 'ft' in args.new_mode:
        args.save_path += '-ftLR_%.3f-ftEpoch_%d' % (args.lr_new, args.epochs_new)

    # Project-specific parameters.
    args = sub_set_save_path(args)

    if args.debug:
        args.save_path = os.path.join('debug', args.save_path)
    args.save_path = os.path.join('./checkpoint', args.save_path)
    ensure_path(args.save_path)
    return args
CUB_200_2011/images/143.Caspian_Tern/Caspian_Tern_0123_145774.jpg 14 | CUB_200_2011/images/143.Caspian_Tern/Caspian_Tern_0006_145594.jpg 15 | CUB_200_2011/images/143.Caspian_Tern/Caspian_Tern_0013_145553.jpg 16 | CUB_200_2011/images/144.Common_Tern/Common_Tern_0071_148796.jpg 17 | CUB_200_2011/images/144.Common_Tern/Common_Tern_0077_149196.jpg 18 | CUB_200_2011/images/144.Common_Tern/Common_Tern_0030_147825.jpg 19 | CUB_200_2011/images/144.Common_Tern/Common_Tern_0095_149960.jpg 20 | CUB_200_2011/images/144.Common_Tern/Common_Tern_0083_148096.jpg 21 | CUB_200_2011/images/145.Elegant_Tern/Elegant_Tern_0009_150954.jpg 22 | CUB_200_2011/images/145.Elegant_Tern/Elegant_Tern_0045_150752.jpg 23 | CUB_200_2011/images/145.Elegant_Tern/Elegant_Tern_0046_150905.jpg 24 | CUB_200_2011/images/145.Elegant_Tern/Elegant_Tern_0103_150493.jpg 25 | CUB_200_2011/images/145.Elegant_Tern/Elegant_Tern_0004_150948.jpg 26 | CUB_200_2011/images/146.Forsters_Tern/Forsters_Tern_0027_151456.jpg 27 | CUB_200_2011/images/146.Forsters_Tern/Forsters_Tern_0077_152255.jpg 28 | CUB_200_2011/images/146.Forsters_Tern/Forsters_Tern_0125_151399.jpg 29 | CUB_200_2011/images/146.Forsters_Tern/Forsters_Tern_0045_151227.jpg 30 | CUB_200_2011/images/146.Forsters_Tern/Forsters_Tern_0119_152709.jpg 31 | CUB_200_2011/images/147.Least_Tern/Least_Tern_0092_153361.jpg 32 | CUB_200_2011/images/147.Least_Tern/Least_Tern_0020_153458.jpg 33 | CUB_200_2011/images/147.Least_Tern/Least_Tern_0060_153190.jpg 34 | CUB_200_2011/images/147.Least_Tern/Least_Tern_0119_153950.jpg 35 | CUB_200_2011/images/147.Least_Tern/Least_Tern_0037_153637.jpg 36 | CUB_200_2011/images/148.Green_tailed_Towhee/Green_Tailed_Towhee_0018_154825.jpg 37 | CUB_200_2011/images/148.Green_tailed_Towhee/Green_Tailed_Towhee_0070_154844.jpg 38 | CUB_200_2011/images/148.Green_tailed_Towhee/Green_Tailed_Towhee_0064_154771.jpg 39 | CUB_200_2011/images/148.Green_tailed_Towhee/Green_Tailed_Towhee_0058_797399.jpg 40 | 
CUB_200_2011/images/148.Green_tailed_Towhee/Green_Tailed_Towhee_0060_154820.jpg 41 | CUB_200_2011/images/149.Brown_Thrasher/Brown_Thrasher_0013_155329.jpg 42 | CUB_200_2011/images/149.Brown_Thrasher/Brown_Thrasher_0079_155394.jpg 43 | CUB_200_2011/images/149.Brown_Thrasher/Brown_Thrasher_0019_155216.jpg 44 | CUB_200_2011/images/149.Brown_Thrasher/Brown_Thrasher_0051_155344.jpg 45 | CUB_200_2011/images/149.Brown_Thrasher/Brown_Thrasher_0081_155256.jpg 46 | CUB_200_2011/images/150.Sage_Thrasher/Sage_Thrasher_0033_155511.jpg 47 | CUB_200_2011/images/150.Sage_Thrasher/Sage_Thrasher_0069_155544.jpg 48 | CUB_200_2011/images/150.Sage_Thrasher/Sage_Thrasher_0096_155449.jpg 49 | CUB_200_2011/images/150.Sage_Thrasher/Sage_Thrasher_0104_155529.jpg 50 | CUB_200_2011/images/150.Sage_Thrasher/Sage_Thrasher_0070_155732.jpg 51 | -------------------------------------------------------------------------------- /data/index_list/cub200/session_11.txt: -------------------------------------------------------------------------------- 1 | CUB_200_2011/images/191.Red_headed_Woodpecker/Red_Headed_Woodpecker_0020_183255.jpg 2 | CUB_200_2011/images/191.Red_headed_Woodpecker/Red_Headed_Woodpecker_0005_183414.jpg 3 | CUB_200_2011/images/191.Red_headed_Woodpecker/Red_Headed_Woodpecker_0068_183662.jpg 4 | CUB_200_2011/images/191.Red_headed_Woodpecker/Red_Headed_Woodpecker_0013_182721.jpg 5 | CUB_200_2011/images/191.Red_headed_Woodpecker/Red_Headed_Woodpecker_0095_183688.jpg 6 | CUB_200_2011/images/192.Downy_Woodpecker/Downy_Woodpecker_0040_184061.jpg 7 | CUB_200_2011/images/192.Downy_Woodpecker/Downy_Woodpecker_0031_184120.jpg 8 | CUB_200_2011/images/192.Downy_Woodpecker/Downy_Woodpecker_0090_183964.jpg 9 | CUB_200_2011/images/192.Downy_Woodpecker/Downy_Woodpecker_0005_184098.jpg 10 | CUB_200_2011/images/192.Downy_Woodpecker/Downy_Woodpecker_0136_184534.jpg 11 | CUB_200_2011/images/193.Bewick_Wren/Bewick_Wren_0083_185190.jpg 12 | CUB_200_2011/images/193.Bewick_Wren/Bewick_Wren_0084_184715.jpg 13 
| CUB_200_2011/images/193.Bewick_Wren/Bewick_Wren_0015_184981.jpg 14 | CUB_200_2011/images/193.Bewick_Wren/Bewick_Wren_0110_185216.jpg 15 | CUB_200_2011/images/193.Bewick_Wren/Bewick_Wren_0081_185080.jpg 16 | CUB_200_2011/images/194.Cactus_Wren/Cactus_Wren_0089_186023.jpg 17 | CUB_200_2011/images/194.Cactus_Wren/Cactus_Wren_0097_186015.jpg 18 | CUB_200_2011/images/194.Cactus_Wren/Cactus_Wren_0025_185696.jpg 19 | CUB_200_2011/images/194.Cactus_Wren/Cactus_Wren_0066_186028.jpg 20 | CUB_200_2011/images/194.Cactus_Wren/Cactus_Wren_0033_186014.jpg 21 | CUB_200_2011/images/195.Carolina_Wren/Carolina_Wren_0113_186675.jpg 22 | CUB_200_2011/images/195.Carolina_Wren/Carolina_Wren_0099_186237.jpg 23 | CUB_200_2011/images/195.Carolina_Wren/Carolina_Wren_0014_186525.jpg 24 | CUB_200_2011/images/195.Carolina_Wren/Carolina_Wren_0020_186702.jpg 25 | CUB_200_2011/images/195.Carolina_Wren/Carolina_Wren_0128_186581.jpg 26 | CUB_200_2011/images/196.House_Wren/House_Wren_0108_187102.jpg 27 | CUB_200_2011/images/196.House_Wren/House_Wren_0107_187230.jpg 28 | CUB_200_2011/images/196.House_Wren/House_Wren_0035_187708.jpg 29 | CUB_200_2011/images/196.House_Wren/House_Wren_0094_187226.jpg 30 | CUB_200_2011/images/196.House_Wren/House_Wren_0122_187331.jpg 31 | CUB_200_2011/images/197.Marsh_Wren/Marsh_Wren_0056_188241.jpg 32 | CUB_200_2011/images/197.Marsh_Wren/Marsh_Wren_0141_188796.jpg 33 | CUB_200_2011/images/197.Marsh_Wren/Marsh_Wren_0006_188126.jpg 34 | CUB_200_2011/images/197.Marsh_Wren/Marsh_Wren_0044_188270.jpg 35 | CUB_200_2011/images/197.Marsh_Wren/Marsh_Wren_0039_188201.jpg 36 | CUB_200_2011/images/198.Rock_Wren/Rock_Wren_0122_189042.jpg 37 | CUB_200_2011/images/198.Rock_Wren/Rock_Wren_0063_189121.jpg 38 | CUB_200_2011/images/198.Rock_Wren/Rock_Wren_0069_188969.jpg 39 | CUB_200_2011/images/198.Rock_Wren/Rock_Wren_0111_189443.jpg 40 | CUB_200_2011/images/198.Rock_Wren/Rock_Wren_0027_189331.jpg 41 | CUB_200_2011/images/199.Winter_Wren/Winter_Wren_0066_189637.jpg 42 | 
CUB_200_2011/images/199.Winter_Wren/Winter_Wren_0030_190311.jpg 43 | CUB_200_2011/images/199.Winter_Wren/Winter_Wren_0075_189578.jpg 44 | CUB_200_2011/images/199.Winter_Wren/Winter_Wren_0065_189675.jpg 45 | CUB_200_2011/images/199.Winter_Wren/Winter_Wren_0037_190123.jpg 46 | CUB_200_2011/images/200.Common_Yellowthroat/Common_Yellowthroat_0004_190606.jpg 47 | CUB_200_2011/images/200.Common_Yellowthroat/Common_Yellowthroat_0054_190398.jpg 48 | CUB_200_2011/images/200.Common_Yellowthroat/Common_Yellowthroat_0010_190572.jpg 49 | CUB_200_2011/images/200.Common_Yellowthroat/Common_Yellowthroat_0126_190407.jpg 50 | CUB_200_2011/images/200.Common_Yellowthroat/Common_Yellowthroat_0032_190592.jpg 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Few-Shot Class-Incremental Learning via Training-Free Prototype Calibration (TEEN) 2 | 3 | The code repository for "Few-Shot Class-Incremental Learning via Training-Free Prototype Calibration" [[paper]](https://arxiv.org/abs/2312.05229) (NeurIPS'23) in PyTorch. If you use any content of this repo for your work, please cite the following bib entry: 4 | 5 | @inproceedings{ 6 | wang2023teen, 7 | title={Few-Shot Class-Incremental Learning via Training-Free Prototype Calibration}, 8 | author={Wang, Qi-Wei and Zhou, Da-Wei and Zhang, Yi-Kai and Zhan, De-Chuan and Ye, Han-Jia}, 9 | booktitle={NeurIPS}, 10 | year={2023} 11 | } 12 | 13 | ## Few-Shot Class-Incremental Learning via Training-Free Prototype Calibration 14 | 15 | Real-world scenarios are usually accompanied by continuously appearing classes with scarce labeled samples, which require the machine learning model to incrementally learn new classes and maintain the knowledge of base classes.
In this Few-Shot Class-Incremental Learning (FSCIL) scenario, existing methods either introduce extra learnable components or rely on a frozen feature extractor to mitigate catastrophic forgetting and overfitting problems. However, we find a tendency for existing methods to misclassify the samples of new classes into base classes, which leads to poor performance on new classes. In other words, the strong discriminability of base classes distracts the classification of new classes. To figure out this intriguing phenomenon, we observe that although the feature extractor is only trained on base classes, it can surprisingly represent the semantic 16 | similarity between the base and unseen new classes. Building upon these analyses, we propose a simple yet effective Training-frEE calibratioN (TEEN) strategy to enhance the discriminability of new classes by fusing the new prototypes (i.e., mean features of a class) with weighted base prototypes. In addition to standard benchmarks in FSCIL, TEEN demonstrates remarkable performance and consistent improvements over baseline methods in the few-shot learning scenario. 17 | 18 | 19 | ## Results 20 | 21 | 22 | Please refer to our [paper](https://arxiv.org/abs/2312.05229) for detailed values. 23 | 24 | ## Prerequisites 25 | 26 | The following packages are required to run the scripts: 27 | 28 | - [PyTorch-1.4 and torchvision](https://pytorch.org) 29 | 30 | - tqdm 31 | 32 | ## Dataset 33 | We provide the source code for three benchmark datasets, i.e., CIFAR100, CUB200 and miniImageNet. Please follow the guidelines in [CEC](https://github.com/icoz69/CEC-CVPR2021) to prepare them. 34 | 35 | 36 | ## Code Structure and Details 37 | The code is organized into the following parts: 38 | - `models`: It contains the backbone network and training protocols for the experiment. 39 | - `data`: Images and splits for the data sets. 40 | - `dataloader`: Dataloaders for the different datasets. 41 | 42 | ## Training scripts 43 | 44 | Please see the `scripts` folder.
45 | 46 | 47 | ## Acknowledgment 48 | We thank the following repos for providing helpful components/functions used in our work. 49 | 50 | - [Awesome Few-Shot Class-Incremental Learning](https://github.com/zhoudw-zdw/Awesome-Few-Shot-Class-Incremental-Learning) 51 | - [PyCIL: A Python Toolbox for Class-Incremental Learning](https://github.com/G-U-N/PyCIL) 52 | - [CEC](https://github.com/icoz69/CEC-CVPR2021) 53 | - [FACT](https://github.com/zhoudw-zdw/CVPR22-Fact) 54 | 55 | 56 | 57 | ## Contact 58 | If there are any questions, please feel free to contact the author: Qi-Wei Wang (wangqiwei@lamda.nju.edu.cn). Enjoy the code. -------------------------------------------------------------------------------- /data/index_list/cub200/session_2.txt: -------------------------------------------------------------------------------- 1 | CUB_200_2011/images/101.White_Pelican/White_Pelican_0081_96148.jpg 2 | CUB_200_2011/images/101.White_Pelican/White_Pelican_0075_96422.jpg 3 | CUB_200_2011/images/101.White_Pelican/White_Pelican_0026_95832.jpg 4 | CUB_200_2011/images/101.White_Pelican/White_Pelican_0022_95897.jpg 5 | CUB_200_2011/images/101.White_Pelican/White_Pelican_0044_96028.jpg 6 | CUB_200_2011/images/102.Western_Wood_Pewee/Western_Wood_Pewee_0072_98035.jpg 7 | CUB_200_2011/images/102.Western_Wood_Pewee/Western_Wood_Pewee_0004_98257.jpg 8 | CUB_200_2011/images/102.Western_Wood_Pewee/Western_Wood_Pewee_0060_795045.jpg 9 | CUB_200_2011/images/102.Western_Wood_Pewee/Western_Wood_Pewee_0039_795063.jpg 10 | CUB_200_2011/images/102.Western_Wood_Pewee/Western_Wood_Pewee_0040_795051.jpg 11 | CUB_200_2011/images/103.Sayornis/Sayornis_0099_98593.jpg 12 | CUB_200_2011/images/103.Sayornis/Sayornis_0133_99129.jpg 13 | CUB_200_2011/images/103.Sayornis/Sayornis_0098_98419.jpg 14 | CUB_200_2011/images/103.Sayornis/Sayornis_0011_98610.jpg 15 | CUB_200_2011/images/103.Sayornis/Sayornis_0114_98976.jpg 16 | CUB_200_2011/images/104.American_Pipit/American_Pipit_0037_99954.jpg 17 |
CUB_200_2011/images/104.American_Pipit/American_Pipit_0067_100237.jpg 18 | CUB_200_2011/images/104.American_Pipit/American_Pipit_0019_99810.jpg 19 | CUB_200_2011/images/104.American_Pipit/American_Pipit_0058_100218.jpg 20 | CUB_200_2011/images/104.American_Pipit/American_Pipit_0113_99939.jpg 21 | CUB_200_2011/images/105.Whip_poor_Will/Whip_Poor_Will_0038_100443.jpg 22 | CUB_200_2011/images/105.Whip_poor_Will/Whip_Poor_Will_0018_796403.jpg 23 | CUB_200_2011/images/105.Whip_poor_Will/Whip_Poor_Will_0013_796439.jpg 24 | CUB_200_2011/images/105.Whip_poor_Will/Whip_Poor_Will_0026_100456.jpg 25 | CUB_200_2011/images/105.Whip_poor_Will/Whip_Poor_Will_0004_100479.jpg 26 | CUB_200_2011/images/106.Horned_Puffin/Horned_Puffin_0004_100733.jpg 27 | CUB_200_2011/images/106.Horned_Puffin/Horned_Puffin_0028_100765.jpg 28 | CUB_200_2011/images/106.Horned_Puffin/Horned_Puffin_0062_100693.jpg 29 | CUB_200_2011/images/106.Horned_Puffin/Horned_Puffin_0042_100760.jpg 30 | CUB_200_2011/images/106.Horned_Puffin/Horned_Puffin_0030_100725.jpg 31 | CUB_200_2011/images/107.Common_Raven/Common_Raven_0009_102112.jpg 32 | CUB_200_2011/images/107.Common_Raven/Common_Raven_0068_101216.jpg 33 | CUB_200_2011/images/107.Common_Raven/Common_Raven_0099_102534.jpg 34 | CUB_200_2011/images/107.Common_Raven/Common_Raven_0001_101213.jpg 35 | CUB_200_2011/images/107.Common_Raven/Common_Raven_0095_101831.jpg 36 | CUB_200_2011/images/108.White_necked_Raven/White_Necked_Raven_0063_797361.jpg 37 | CUB_200_2011/images/108.White_necked_Raven/White_Necked_Raven_0050_797374.jpg 38 | CUB_200_2011/images/108.White_necked_Raven/White_Necked_Raven_0010_797350.jpg 39 | CUB_200_2011/images/108.White_necked_Raven/White_Necked_Raven_0002_797370.jpg 40 | CUB_200_2011/images/108.White_necked_Raven/White_Necked_Raven_0026_797357.jpg 41 | CUB_200_2011/images/109.American_Redstart/American_Redstart_0036_103231.jpg 42 | CUB_200_2011/images/109.American_Redstart/American_Redstart_0071_103266.jpg 43 | 
CUB_200_2011/images/109.American_Redstart/American_Redstart_0085_103155.jpg 44 | CUB_200_2011/images/109.American_Redstart/American_Redstart_0056_103241.jpg 45 | CUB_200_2011/images/109.American_Redstart/American_Redstart_0049_103176.jpg 46 | CUB_200_2011/images/110.Geococcyx/Geococcyx_0106_104216.jpg 47 | CUB_200_2011/images/110.Geococcyx/Geococcyx_0086_104755.jpg 48 | CUB_200_2011/images/110.Geococcyx/Geococcyx_0124_104141.jpg 49 | CUB_200_2011/images/110.Geococcyx/Geococcyx_0117_104227.jpg 50 | CUB_200_2011/images/110.Geococcyx/Geococcyx_0036_104173.jpg 51 | -------------------------------------------------------------------------------- /data/index_list/cub200/session_5.txt: -------------------------------------------------------------------------------- 1 | CUB_200_2011/images/131.Vesper_Sparrow/Vesper_Sparrow_0079_125579.jpg 2 | CUB_200_2011/images/131.Vesper_Sparrow/Vesper_Sparrow_0080_125606.jpg 3 | CUB_200_2011/images/131.Vesper_Sparrow/Vesper_Sparrow_0084_125532.jpg 4 | CUB_200_2011/images/131.Vesper_Sparrow/Vesper_Sparrow_0094_125602.jpg 5 | CUB_200_2011/images/131.Vesper_Sparrow/Vesper_Sparrow_0019_125558.jpg 6 | CUB_200_2011/images/132.White_crowned_Sparrow/White_Crowned_Sparrow_0068_126156.jpg 7 | CUB_200_2011/images/132.White_crowned_Sparrow/White_Crowned_Sparrow_0100_126267.jpg 8 | CUB_200_2011/images/132.White_crowned_Sparrow/White_Crowned_Sparrow_0072_127080.jpg 9 | CUB_200_2011/images/132.White_crowned_Sparrow/White_Crowned_Sparrow_0033_127728.jpg 10 | CUB_200_2011/images/132.White_crowned_Sparrow/White_Crowned_Sparrow_0095_127118.jpg 11 | CUB_200_2011/images/133.White_throated_Sparrow/White_Throated_Sparrow_0125_128832.jpg 12 | CUB_200_2011/images/133.White_throated_Sparrow/White_Throated_Sparrow_0056_128906.jpg 13 | CUB_200_2011/images/133.White_throated_Sparrow/White_Throated_Sparrow_0085_129180.jpg 14 | CUB_200_2011/images/133.White_throated_Sparrow/White_Throated_Sparrow_0042_128899.jpg 15 | 
CUB_200_2011/images/133.White_throated_Sparrow/White_Throated_Sparrow_0021_128804.jpg 16 | CUB_200_2011/images/134.Cape_Glossy_Starling/Cape_Glossy_Starling_0096_129388.jpg 17 | CUB_200_2011/images/134.Cape_Glossy_Starling/Cape_Glossy_Starling_0046_129434.jpg 18 | CUB_200_2011/images/134.Cape_Glossy_Starling/Cape_Glossy_Starling_0043_129358.jpg 19 | CUB_200_2011/images/134.Cape_Glossy_Starling/Cape_Glossy_Starling_0019_129407.jpg 20 | CUB_200_2011/images/134.Cape_Glossy_Starling/Cape_Glossy_Starling_0067_129380.jpg 21 | CUB_200_2011/images/135.Bank_Swallow/Bank_Swallow_0003_129623.jpg 22 | CUB_200_2011/images/135.Bank_Swallow/Bank_Swallow_0045_129483.jpg 23 | CUB_200_2011/images/135.Bank_Swallow/Bank_Swallow_0020_129747.jpg 24 | CUB_200_2011/images/135.Bank_Swallow/Bank_Swallow_0067_129959.jpg 25 | CUB_200_2011/images/135.Bank_Swallow/Bank_Swallow_0053_129501.jpg 26 | CUB_200_2011/images/136.Barn_Swallow/Barn_Swallow_0018_130709.jpg 27 | CUB_200_2011/images/136.Barn_Swallow/Barn_Swallow_0048_132793.jpg 28 | CUB_200_2011/images/136.Barn_Swallow/Barn_Swallow_0070_130127.jpg 29 | CUB_200_2011/images/136.Barn_Swallow/Barn_Swallow_0066_130214.jpg 30 | CUB_200_2011/images/136.Barn_Swallow/Barn_Swallow_0049_130181.jpg 31 | CUB_200_2011/images/137.Cliff_Swallow/Cliff_Swallow_0018_132974.jpg 32 | CUB_200_2011/images/137.Cliff_Swallow/Cliff_Swallow_0023_134314.jpg 33 | CUB_200_2011/images/137.Cliff_Swallow/Cliff_Swallow_0066_133206.jpg 34 | CUB_200_2011/images/137.Cliff_Swallow/Cliff_Swallow_0050_134054.jpg 35 | CUB_200_2011/images/137.Cliff_Swallow/Cliff_Swallow_0075_134516.jpg 36 | CUB_200_2011/images/138.Tree_Swallow/Tree_Swallow_0087_137354.jpg 37 | CUB_200_2011/images/138.Tree_Swallow/Tree_Swallow_0043_136878.jpg 38 | CUB_200_2011/images/138.Tree_Swallow/Tree_Swallow_0111_135253.jpg 39 | CUB_200_2011/images/138.Tree_Swallow/Tree_Swallow_0108_135068.jpg 40 | CUB_200_2011/images/138.Tree_Swallow/Tree_Swallow_0064_136322.jpg 41 | 
CUB_200_2011/images/139.Scarlet_Tanager/Scarlet_Tanager_0107_138577.jpg 42 | CUB_200_2011/images/139.Scarlet_Tanager/Scarlet_Tanager_0077_137626.jpg 43 | CUB_200_2011/images/139.Scarlet_Tanager/Scarlet_Tanager_0040_137885.jpg 44 | CUB_200_2011/images/139.Scarlet_Tanager/Scarlet_Tanager_0033_137603.jpg 45 | CUB_200_2011/images/139.Scarlet_Tanager/Scarlet_Tanager_0132_138001.jpg 46 | CUB_200_2011/images/140.Summer_Tanager/Summer_Tanager_0032_140425.jpg 47 | CUB_200_2011/images/140.Summer_Tanager/Summer_Tanager_0046_139802.jpg 48 | CUB_200_2011/images/140.Summer_Tanager/Summer_Tanager_0111_139605.jpg 49 | CUB_200_2011/images/140.Summer_Tanager/Summer_Tanager_0116_139923.jpg 50 | CUB_200_2011/images/140.Summer_Tanager/Summer_Tanager_0095_139882.jpg 51 | -------------------------------------------------------------------------------- /data/index_list/cub200/session_3.txt: -------------------------------------------------------------------------------- 1 | CUB_200_2011/images/111.Loggerhead_Shrike/Loggerhead_Shrike_0127_105742.jpg 2 | CUB_200_2011/images/111.Loggerhead_Shrike/Loggerhead_Shrike_0018_26407.jpg 3 | CUB_200_2011/images/111.Loggerhead_Shrike/Loggerhead_Shrike_0019_106132.jpg 4 | CUB_200_2011/images/111.Loggerhead_Shrike/Loggerhead_Shrike_0011_104921.jpg 5 | CUB_200_2011/images/111.Loggerhead_Shrike/Loggerhead_Shrike_0033_105686.jpg 6 | CUB_200_2011/images/112.Great_Grey_Shrike/Great_Grey_Shrike_0092_797048.jpg 7 | CUB_200_2011/images/112.Great_Grey_Shrike/Great_Grey_Shrike_0042_797056.jpg 8 | CUB_200_2011/images/112.Great_Grey_Shrike/Great_Grey_Shrike_0049_797025.jpg 9 | CUB_200_2011/images/112.Great_Grey_Shrike/Great_Grey_Shrike_0083_797051.jpg 10 | CUB_200_2011/images/112.Great_Grey_Shrike/Great_Grey_Shrike_0063_797042.jpg 11 | CUB_200_2011/images/113.Baird_Sparrow/Baird_Sparrow_0021_794576.jpg 12 | CUB_200_2011/images/113.Baird_Sparrow/Baird_Sparrow_0018_794584.jpg 13 | CUB_200_2011/images/113.Baird_Sparrow/Baird_Sparrow_0025_794564.jpg 14 | 
CUB_200_2011/images/113.Baird_Sparrow/Baird_Sparrow_0041_794582.jpg 15 | CUB_200_2011/images/113.Baird_Sparrow/Baird_Sparrow_0036_794572.jpg 16 | CUB_200_2011/images/114.Black_throated_Sparrow/Black_Throated_Sparrow_0019_107192.jpg 17 | CUB_200_2011/images/114.Black_throated_Sparrow/Black_Throated_Sparrow_0088_107220.jpg 18 | CUB_200_2011/images/114.Black_throated_Sparrow/Black_Throated_Sparrow_0097_106935.jpg 19 | CUB_200_2011/images/114.Black_throated_Sparrow/Black_Throated_Sparrow_0055_107213.jpg 20 | CUB_200_2011/images/114.Black_throated_Sparrow/Black_Throated_Sparrow_0010_107375.jpg 21 | CUB_200_2011/images/115.Brewer_Sparrow/Brewer_Sparrow_0068_107422.jpg 22 | CUB_200_2011/images/115.Brewer_Sparrow/Brewer_Sparrow_0036_107451.jpg 23 | CUB_200_2011/images/115.Brewer_Sparrow/Brewer_Sparrow_0041_796711.jpg 24 | CUB_200_2011/images/115.Brewer_Sparrow/Brewer_Sparrow_0014_107435.jpg 25 | CUB_200_2011/images/115.Brewer_Sparrow/Brewer_Sparrow_0076_107393.jpg 26 | CUB_200_2011/images/116.Chipping_Sparrow/Chipping_Sparrow_0064_108204.jpg 27 | CUB_200_2011/images/116.Chipping_Sparrow/Chipping_Sparrow_0038_109234.jpg 28 | CUB_200_2011/images/116.Chipping_Sparrow/Chipping_Sparrow_0098_108644.jpg 29 | CUB_200_2011/images/116.Chipping_Sparrow/Chipping_Sparrow_0110_108974.jpg 30 | CUB_200_2011/images/116.Chipping_Sparrow/Chipping_Sparrow_0023_108684.jpg 31 | CUB_200_2011/images/117.Clay_colored_Sparrow/Clay_Colored_Sparrow_0104_110699.jpg 32 | CUB_200_2011/images/117.Clay_colored_Sparrow/Clay_Colored_Sparrow_0098_110735.jpg 33 | CUB_200_2011/images/117.Clay_colored_Sparrow/Clay_Colored_Sparrow_0003_110672.jpg 34 | CUB_200_2011/images/117.Clay_colored_Sparrow/Clay_Colored_Sparrow_0029_110720.jpg 35 | CUB_200_2011/images/117.Clay_colored_Sparrow/Clay_Colored_Sparrow_0087_110946.jpg 36 | CUB_200_2011/images/118.House_Sparrow/House_Sparrow_0092_111413.jpg 37 | CUB_200_2011/images/118.House_Sparrow/House_Sparrow_0111_112968.jpg 38 | 
CUB_200_2011/images/118.House_Sparrow/House_Sparrow_0080_111099.jpg 39 | CUB_200_2011/images/118.House_Sparrow/House_Sparrow_0130_110985.jpg 40 | CUB_200_2011/images/118.House_Sparrow/House_Sparrow_0053_111388.jpg 41 | CUB_200_2011/images/119.Field_Sparrow/Field_Sparrow_0069_113827.jpg 42 | CUB_200_2011/images/119.Field_Sparrow/Field_Sparrow_0130_113846.jpg 43 | CUB_200_2011/images/119.Field_Sparrow/Field_Sparrow_0091_113486.jpg 44 | CUB_200_2011/images/119.Field_Sparrow/Field_Sparrow_0043_113607.jpg 45 | CUB_200_2011/images/119.Field_Sparrow/Field_Sparrow_0108_114154.jpg 46 | CUB_200_2011/images/120.Fox_Sparrow/Fox_Sparrow_0104_114908.jpg 47 | CUB_200_2011/images/120.Fox_Sparrow/Fox_Sparrow_0086_115484.jpg 48 | CUB_200_2011/images/120.Fox_Sparrow/Fox_Sparrow_0055_114809.jpg 49 | CUB_200_2011/images/120.Fox_Sparrow/Fox_Sparrow_0012_115324.jpg 50 | CUB_200_2011/images/120.Fox_Sparrow/Fox_Sparrow_0035_114866.jpg 51 | -------------------------------------------------------------------------------- /data/index_list/cub200/session_9.txt: -------------------------------------------------------------------------------- 1 | CUB_200_2011/images/171.Myrtle_Warbler/Myrtle_Warbler_0023_166764.jpg 2 | CUB_200_2011/images/171.Myrtle_Warbler/Myrtle_Warbler_0050_166820.jpg 3 | CUB_200_2011/images/171.Myrtle_Warbler/Myrtle_Warbler_0043_166708.jpg 4 | CUB_200_2011/images/171.Myrtle_Warbler/Myrtle_Warbler_0098_166794.jpg 5 | CUB_200_2011/images/171.Myrtle_Warbler/Myrtle_Warbler_0015_166713.jpg 6 | CUB_200_2011/images/172.Nashville_Warbler/Nashville_Warbler_0108_167259.jpg 7 | CUB_200_2011/images/172.Nashville_Warbler/Nashville_Warbler_0098_167293.jpg 8 | CUB_200_2011/images/172.Nashville_Warbler/Nashville_Warbler_0104_167096.jpg 9 | CUB_200_2011/images/172.Nashville_Warbler/Nashville_Warbler_0110_167268.jpg 10 | CUB_200_2011/images/172.Nashville_Warbler/Nashville_Warbler_0081_167234.jpg 11 | CUB_200_2011/images/173.Orange_crowned_Warbler/Orange_Crowned_Warbler_0062_168119.jpg 12 | 
CUB_200_2011/images/173.Orange_crowned_Warbler/Orange_Crowned_Warbler_0050_168166.jpg 13 | CUB_200_2011/images/173.Orange_crowned_Warbler/Orange_Crowned_Warbler_0055_168600.jpg 14 | CUB_200_2011/images/173.Orange_crowned_Warbler/Orange_Crowned_Warbler_0118_167640.jpg 15 | CUB_200_2011/images/173.Orange_crowned_Warbler/Orange_Crowned_Warbler_0067_167588.jpg 16 | CUB_200_2011/images/174.Palm_Warbler/Palm_Warbler_0083_170281.jpg 17 | CUB_200_2011/images/174.Palm_Warbler/Palm_Warbler_0012_170857.jpg 18 | CUB_200_2011/images/174.Palm_Warbler/Palm_Warbler_0015_169626.jpg 19 | CUB_200_2011/images/174.Palm_Warbler/Palm_Warbler_0126_170311.jpg 20 | CUB_200_2011/images/174.Palm_Warbler/Palm_Warbler_0136_170276.jpg 21 | CUB_200_2011/images/175.Pine_Warbler/Pine_Warbler_0017_171678.jpg 22 | CUB_200_2011/images/175.Pine_Warbler/Pine_Warbler_0127_171742.jpg 23 | CUB_200_2011/images/175.Pine_Warbler/Pine_Warbler_0060_171635.jpg 24 | CUB_200_2011/images/175.Pine_Warbler/Pine_Warbler_0056_172064.jpg 25 | CUB_200_2011/images/175.Pine_Warbler/Pine_Warbler_0102_171147.jpg 26 | CUB_200_2011/images/176.Prairie_Warbler/Prairie_Warbler_0073_172771.jpg 27 | CUB_200_2011/images/176.Prairie_Warbler/Prairie_Warbler_0120_173097.jpg 28 | CUB_200_2011/images/176.Prairie_Warbler/Prairie_Warbler_0063_172682.jpg 29 | CUB_200_2011/images/176.Prairie_Warbler/Prairie_Warbler_0053_173290.jpg 30 | CUB_200_2011/images/176.Prairie_Warbler/Prairie_Warbler_0080_172724.jpg 31 | CUB_200_2011/images/177.Prothonotary_Warbler/Prothonotary_Warbler_0062_174412.jpg 32 | CUB_200_2011/images/177.Prothonotary_Warbler/Prothonotary_Warbler_0037_173418.jpg 33 | CUB_200_2011/images/177.Prothonotary_Warbler/Prothonotary_Warbler_0076_174118.jpg 34 | CUB_200_2011/images/177.Prothonotary_Warbler/Prothonotary_Warbler_0070_174650.jpg 35 | CUB_200_2011/images/177.Prothonotary_Warbler/Prothonotary_Warbler_0110_173857.jpg 36 | CUB_200_2011/images/178.Swainson_Warbler/Swainson_Warbler_0017_174685.jpg 37 | 
CUB_200_2011/images/178.Swainson_Warbler/Swainson_Warbler_0039_794859.jpg 38 | CUB_200_2011/images/178.Swainson_Warbler/Swainson_Warbler_0051_794900.jpg 39 | CUB_200_2011/images/178.Swainson_Warbler/Swainson_Warbler_0037_174691.jpg 40 | CUB_200_2011/images/178.Swainson_Warbler/Swainson_Warbler_0018_174715.jpg 41 | CUB_200_2011/images/179.Tennessee_Warbler/Tennessee_Warbler_0051_175015.jpg 42 | CUB_200_2011/images/179.Tennessee_Warbler/Tennessee_Warbler_0019_174786.jpg 43 | CUB_200_2011/images/179.Tennessee_Warbler/Tennessee_Warbler_0023_174977.jpg 44 | CUB_200_2011/images/179.Tennessee_Warbler/Tennessee_Warbler_0033_174772.jpg 45 | CUB_200_2011/images/179.Tennessee_Warbler/Tennessee_Warbler_0004_174997.jpg 46 | CUB_200_2011/images/180.Wilson_Warbler/Wilson_Warbler_0107_175320.jpg 47 | CUB_200_2011/images/180.Wilson_Warbler/Wilson_Warbler_0065_175924.jpg 48 | CUB_200_2011/images/180.Wilson_Warbler/Wilson_Warbler_0129_175256.jpg 49 | CUB_200_2011/images/180.Wilson_Warbler/Wilson_Warbler_0126_175368.jpg 50 | CUB_200_2011/images/180.Wilson_Warbler/Wilson_Warbler_0054_175285.jpg 51 | -------------------------------------------------------------------------------- /data/index_list/cub200/session_4.txt: -------------------------------------------------------------------------------- 1 | CUB_200_2011/images/121.Grasshopper_Sparrow/Grasshopper_Sparrow_0014_116129.jpg 2 | CUB_200_2011/images/121.Grasshopper_Sparrow/Grasshopper_Sparrow_0114_116160.jpg 3 | CUB_200_2011/images/121.Grasshopper_Sparrow/Grasshopper_Sparrow_0068_115799.jpg 4 | CUB_200_2011/images/121.Grasshopper_Sparrow/Grasshopper_Sparrow_0110_115644.jpg 5 | CUB_200_2011/images/121.Grasshopper_Sparrow/Grasshopper_Sparrow_0042_115638.jpg 6 | CUB_200_2011/images/122.Harris_Sparrow/Harris_Sparrow_0006_116364.jpg 7 | CUB_200_2011/images/122.Harris_Sparrow/Harris_Sparrow_0018_116402.jpg 8 | CUB_200_2011/images/122.Harris_Sparrow/Harris_Sparrow_0026_116620.jpg 9 | 
CUB_200_2011/images/122.Harris_Sparrow/Harris_Sparrow_0020_116379.jpg 10 | CUB_200_2011/images/122.Harris_Sparrow/Harris_Sparrow_0011_116597.jpg 11 | CUB_200_2011/images/123.Henslow_Sparrow/Henslow_Sparrow_0023_796582.jpg 12 | CUB_200_2011/images/123.Henslow_Sparrow/Henslow_Sparrow_0052_796599.jpg 13 | CUB_200_2011/images/123.Henslow_Sparrow/Henslow_Sparrow_0054_116850.jpg 14 | CUB_200_2011/images/123.Henslow_Sparrow/Henslow_Sparrow_0064_796573.jpg 15 | CUB_200_2011/images/123.Henslow_Sparrow/Henslow_Sparrow_0070_796571.jpg 16 | CUB_200_2011/images/124.Le_Conte_Sparrow/Le_Conte_Sparrow_0040_117088.jpg 17 | CUB_200_2011/images/124.Le_Conte_Sparrow/Le_Conte_Sparrow_0072_795230.jpg 18 | CUB_200_2011/images/124.Le_Conte_Sparrow/Le_Conte_Sparrow_0068_795180.jpg 19 | CUB_200_2011/images/124.Le_Conte_Sparrow/Le_Conte_Sparrow_0081_795215.jpg 20 | CUB_200_2011/images/124.Le_Conte_Sparrow/Le_Conte_Sparrow_0032_795186.jpg 21 | CUB_200_2011/images/125.Lincoln_Sparrow/Lincoln_Sparrow_0084_117492.jpg 22 | CUB_200_2011/images/125.Lincoln_Sparrow/Lincoln_Sparrow_0009_117535.jpg 23 | CUB_200_2011/images/125.Lincoln_Sparrow/Lincoln_Sparrow_0014_117883.jpg 24 | CUB_200_2011/images/125.Lincoln_Sparrow/Lincoln_Sparrow_0042_117507.jpg 25 | CUB_200_2011/images/125.Lincoln_Sparrow/Lincoln_Sparrow_0072_117951.jpg 26 | CUB_200_2011/images/126.Nelson_Sharp_tailed_Sparrow/Nelson_Sharp_Tailed_Sparrow_0056_117974.jpg 27 | CUB_200_2011/images/126.Nelson_Sharp_tailed_Sparrow/Nelson_Sharp_Tailed_Sparrow_0002_796908.jpg 28 | CUB_200_2011/images/126.Nelson_Sharp_tailed_Sparrow/Nelson_Sharp_Tailed_Sparrow_0051_796902.jpg 29 | CUB_200_2011/images/126.Nelson_Sharp_tailed_Sparrow/Nelson_Sharp_Tailed_Sparrow_0014_796906.jpg 30 | CUB_200_2011/images/126.Nelson_Sharp_tailed_Sparrow/Nelson_Sharp_Tailed_Sparrow_0077_796913.jpg 31 | CUB_200_2011/images/127.Savannah_Sparrow/Savannah_Sparrow_0049_119596.jpg 32 | CUB_200_2011/images/127.Savannah_Sparrow/Savannah_Sparrow_0118_118603.jpg 33 | 
CUB_200_2011/images/127.Savannah_Sparrow/Savannah_Sparrow_0068_119972.jpg 34 | CUB_200_2011/images/127.Savannah_Sparrow/Savannah_Sparrow_0052_118583.jpg 35 | CUB_200_2011/images/127.Savannah_Sparrow/Savannah_Sparrow_0054_120057.jpg 36 | CUB_200_2011/images/128.Seaside_Sparrow/Seaside_Sparrow_0001_120720.jpg 37 | CUB_200_2011/images/128.Seaside_Sparrow/Seaside_Sparrow_0048_120758.jpg 38 | CUB_200_2011/images/128.Seaside_Sparrow/Seaside_Sparrow_0042_796528.jpg 39 | CUB_200_2011/images/128.Seaside_Sparrow/Seaside_Sparrow_0049_120735.jpg 40 | CUB_200_2011/images/128.Seaside_Sparrow/Seaside_Sparrow_0035_796533.jpg 41 | CUB_200_2011/images/129.Song_Sparrow/Song_Sparrow_0046_121903.jpg 42 | CUB_200_2011/images/129.Song_Sparrow/Song_Sparrow_0055_121158.jpg 43 | CUB_200_2011/images/129.Song_Sparrow/Song_Sparrow_0107_120990.jpg 44 | CUB_200_2011/images/129.Song_Sparrow/Song_Sparrow_0091_121651.jpg 45 | CUB_200_2011/images/129.Song_Sparrow/Song_Sparrow_0087_121062.jpg 46 | CUB_200_2011/images/130.Tree_Sparrow/Tree_Sparrow_0094_124974.jpg 47 | CUB_200_2011/images/130.Tree_Sparrow/Tree_Sparrow_0123_125324.jpg 48 | CUB_200_2011/images/130.Tree_Sparrow/Tree_Sparrow_0041_123497.jpg 49 | CUB_200_2011/images/130.Tree_Sparrow/Tree_Sparrow_0086_123751.jpg 50 | CUB_200_2011/images/130.Tree_Sparrow/Tree_Sparrow_0119_124114.jpg 51 | -------------------------------------------------------------------------------- /models/resnet20_cifar.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 8 | padding=1, bias=False) 9 | 10 | class BasicBlock(nn.Module): 11 | expansion = 1 12 | 13 | def __init__(self, inplanes, planes, stride=1, downsample=None, last=False): 14 | super(BasicBlock, self).__init__() 15 | 
self.conv1 = conv3x3(inplanes, planes, stride) 16 | self.bn1 = nn.BatchNorm2d(planes) 17 | self.relu = nn.ReLU(inplace=True) 18 | self.conv2 = conv3x3(planes, planes) 19 | self.bn2 = nn.BatchNorm2d(planes) 20 | self.downsample = downsample 21 | self.stride = stride 22 | self.last = last 23 | 24 | def forward(self, x, train=True): 25 | residual = x 26 | 27 | out = self.conv1(x) 28 | out = self.bn1(out) 29 | out = self.relu(out) 30 | 31 | out = self.conv2(out) 32 | out = self.bn2(out) 33 | 34 | if self.downsample is not None: 35 | residual = self.downsample(x) 36 | 37 | out += residual 38 | 39 | out = self.relu(out) 40 | 41 | return out 42 | 43 | class ResNet(nn.Module): 44 | 45 | def __init__(self, block, layers, num_classes=10): 46 | self.inplanes = 16 47 | super(ResNet, self).__init__() 48 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, 49 | bias=False) 50 | self.bn1 = nn.BatchNorm2d(16) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.layer1 = self._make_layer(block, 16, layers[0]) 53 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 54 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2, last_phase=True) 55 | # self.avgpool = nn.AvgPool2d(8, stride=1) 56 | 57 | for m in self.modules(): 58 | if isinstance(m, nn.Conv2d): 59 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 60 | elif isinstance(m, nn.BatchNorm2d): 61 | nn.init.constant_(m.weight, 1) 62 | nn.init.constant_(m.bias, 0) 63 | 64 | def _make_layer(self, block, planes, blocks, stride=1, last_phase=False): 65 | downsample = None 66 | if stride != 1 or self.inplanes != planes * block.expansion: 67 | downsample = nn.Sequential( 68 | nn.Conv2d(self.inplanes, planes * block.expansion, 69 | kernel_size=1, stride=stride, bias=False), 70 | nn.BatchNorm2d(planes * block.expansion), 71 | ) 72 | 73 | layers = [] 74 | layers.append(block(self.inplanes, planes, stride, downsample)) 75 | self.inplanes = planes * block.expansion 76 | if last_phase: 77 
| for i in range(1, blocks-1): 78 | layers.append(block(self.inplanes, planes)) 79 | layers.append(block(self.inplanes, planes, last=True)) 80 | else: 81 | for i in range(1, blocks): 82 | layers.append(block(self.inplanes, planes)) 83 | 84 | return nn.Sequential(*layers) 85 | 86 | def forward(self, x): 87 | x = self.conv1(x) 88 | x = self.bn1(x) 89 | x = self.relu(x) 90 | 91 | x = self.layer1(x) 92 | x = self.layer2(x) 93 | x = self.layer3(x) 94 | 95 | # x = self.avgpool(x) 96 | # x = x.view(x.size(0), -1) 97 | # x = self.fc(x) 98 | 99 | return x 100 | 101 | def resnet20(**kwargs): 102 | n = 3 103 | model = ResNet(BasicBlock, [n, n, n], **kwargs) 104 | return model 105 | -------------------------------------------------------------------------------- /data/index_list/cub200/session_8.txt: -------------------------------------------------------------------------------- 1 | CUB_200_2011/images/161.Blue_winged_Warbler/Blue_Winged_Warbler_0071_161900.jpg 2 | CUB_200_2011/images/161.Blue_winged_Warbler/Blue_Winged_Warbler_0035_161741.jpg 3 | CUB_200_2011/images/161.Blue_winged_Warbler/Blue_Winged_Warbler_0054_161862.jpg 4 | CUB_200_2011/images/161.Blue_winged_Warbler/Blue_Winged_Warbler_0023_161774.jpg 5 | CUB_200_2011/images/161.Blue_winged_Warbler/Blue_Winged_Warbler_0040_161883.jpg 6 | CUB_200_2011/images/162.Canada_Warbler/Canada_Warbler_0113_162403.jpg 7 | CUB_200_2011/images/162.Canada_Warbler/Canada_Warbler_0064_162417.jpg 8 | CUB_200_2011/images/162.Canada_Warbler/Canada_Warbler_0091_162378.jpg 9 | CUB_200_2011/images/162.Canada_Warbler/Canada_Warbler_0016_162411.jpg 10 | CUB_200_2011/images/162.Canada_Warbler/Canada_Warbler_0080_162392.jpg 11 | CUB_200_2011/images/163.Cape_May_Warbler/Cape_May_Warbler_0012_162701.jpg 12 | CUB_200_2011/images/163.Cape_May_Warbler/Cape_May_Warbler_0103_162972.jpg 13 | CUB_200_2011/images/163.Cape_May_Warbler/Cape_May_Warbler_0022_162912.jpg 14 | CUB_200_2011/images/163.Cape_May_Warbler/Cape_May_Warbler_0005_163197.jpg 15 | 
CUB_200_2011/images/163.Cape_May_Warbler/Cape_May_Warbler_0032_162659.jpg 16 | CUB_200_2011/images/164.Cerulean_Warbler/Cerulean_Warbler_0039_163420.jpg 17 | CUB_200_2011/images/164.Cerulean_Warbler/Cerulean_Warbler_0020_163353.jpg 18 | CUB_200_2011/images/164.Cerulean_Warbler/Cerulean_Warbler_0014_797226.jpg 19 | CUB_200_2011/images/164.Cerulean_Warbler/Cerulean_Warbler_0072_163200.jpg 20 | CUB_200_2011/images/164.Cerulean_Warbler/Cerulean_Warbler_0080_163399.jpg 21 | CUB_200_2011/images/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0128_163696.jpg 22 | CUB_200_2011/images/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0097_163750.jpg 23 | CUB_200_2011/images/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0094_164152.jpg 24 | CUB_200_2011/images/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0105_163996.jpg 25 | CUB_200_2011/images/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0101_164324.jpg 26 | CUB_200_2011/images/166.Golden_winged_Warbler/Golden_Winged_Warbler_0079_794820.jpg 27 | CUB_200_2011/images/166.Golden_winged_Warbler/Golden_Winged_Warbler_0046_794828.jpg 28 | CUB_200_2011/images/166.Golden_winged_Warbler/Golden_Winged_Warbler_0061_164516.jpg 29 | CUB_200_2011/images/166.Golden_winged_Warbler/Golden_Winged_Warbler_0068_794825.jpg 30 | CUB_200_2011/images/166.Golden_winged_Warbler/Golden_Winged_Warbler_0011_794812.jpg 31 | CUB_200_2011/images/167.Hooded_Warbler/Hooded_Warbler_0040_165173.jpg 32 | CUB_200_2011/images/167.Hooded_Warbler/Hooded_Warbler_0001_164704.jpg 33 | CUB_200_2011/images/167.Hooded_Warbler/Hooded_Warbler_0021_165057.jpg 34 | CUB_200_2011/images/167.Hooded_Warbler/Hooded_Warbler_0058_164674.jpg 35 | CUB_200_2011/images/167.Hooded_Warbler/Hooded_Warbler_0053_164631.jpg 36 | CUB_200_2011/images/168.Kentucky_Warbler/Kentucky_Warbler_0008_165369.jpg 37 | CUB_200_2011/images/168.Kentucky_Warbler/Kentucky_Warbler_0035_795878.jpg 38 | CUB_200_2011/images/168.Kentucky_Warbler/Kentucky_Warbler_0050_165278.jpg 39 | 
CUB_200_2011/images/168.Kentucky_Warbler/Kentucky_Warbler_0071_165342.jpg 40 | CUB_200_2011/images/168.Kentucky_Warbler/Kentucky_Warbler_0072_165305.jpg 41 | CUB_200_2011/images/169.Magnolia_Warbler/Magnolia_Warbler_0041_165709.jpg 42 | CUB_200_2011/images/169.Magnolia_Warbler/Magnolia_Warbler_0092_165807.jpg 43 | CUB_200_2011/images/169.Magnolia_Warbler/Magnolia_Warbler_0029_165567.jpg 44 | CUB_200_2011/images/169.Magnolia_Warbler/Magnolia_Warbler_0030_165782.jpg 45 | CUB_200_2011/images/169.Magnolia_Warbler/Magnolia_Warbler_0053_165682.jpg 46 | CUB_200_2011/images/170.Mourning_Warbler/Mourning_Warbler_0069_166559.jpg 47 | CUB_200_2011/images/170.Mourning_Warbler/Mourning_Warbler_0035_166586.jpg 48 | CUB_200_2011/images/170.Mourning_Warbler/Mourning_Warbler_0002_166520.jpg 49 | CUB_200_2011/images/170.Mourning_Warbler/Mourning_Warbler_0079_166564.jpg 50 | CUB_200_2011/images/170.Mourning_Warbler/Mourning_Warbler_0015_166535.jpg 51 | -------------------------------------------------------------------------------- /data/index_list/cub200/session_7.txt: -------------------------------------------------------------------------------- 1 | CUB_200_2011/images/151.Black_capped_Vireo/Black_Capped_Vireo_0012_797473.jpg 2 | CUB_200_2011/images/151.Black_capped_Vireo/Black_Capped_Vireo_0007_797481.jpg 3 | CUB_200_2011/images/151.Black_capped_Vireo/Black_Capped_Vireo_0020_797461.jpg 4 | CUB_200_2011/images/151.Black_capped_Vireo/Black_Capped_Vireo_0053_797478.jpg 5 | CUB_200_2011/images/151.Black_capped_Vireo/Black_Capped_Vireo_0003_797467.jpg 6 | CUB_200_2011/images/152.Blue_headed_Vireo/Blue_Headed_Vireo_0097_156272.jpg 7 | CUB_200_2011/images/152.Blue_headed_Vireo/Blue_Headed_Vireo_0019_156311.jpg 8 | CUB_200_2011/images/152.Blue_headed_Vireo/Blue_Headed_Vireo_0121_156233.jpg 9 | CUB_200_2011/images/152.Blue_headed_Vireo/Blue_Headed_Vireo_0011_156276.jpg 10 | CUB_200_2011/images/152.Blue_headed_Vireo/Blue_Headed_Vireo_0119_156259.jpg 11 | 
CUB_200_2011/images/153.Philadelphia_Vireo/Philadelphia_Vireo_0078_794776.jpg 12 | CUB_200_2011/images/153.Philadelphia_Vireo/Philadelphia_Vireo_0039_794794.jpg 13 | CUB_200_2011/images/153.Philadelphia_Vireo/Philadelphia_Vireo_0068_794763.jpg 14 | CUB_200_2011/images/153.Philadelphia_Vireo/Philadelphia_Vireo_0012_794785.jpg 15 | CUB_200_2011/images/153.Philadelphia_Vireo/Philadelphia_Vireo_0013_794772.jpg 16 | CUB_200_2011/images/154.Red_eyed_Vireo/Red_Eyed_Vireo_0101_156988.jpg 17 | CUB_200_2011/images/154.Red_eyed_Vireo/Red_Eyed_Vireo_0006_157025.jpg 18 | CUB_200_2011/images/154.Red_eyed_Vireo/Red_Eyed_Vireo_0041_156954.jpg 19 | CUB_200_2011/images/154.Red_eyed_Vireo/Red_Eyed_Vireo_0115_157004.jpg 20 | CUB_200_2011/images/154.Red_eyed_Vireo/Red_Eyed_Vireo_0056_156968.jpg 21 | CUB_200_2011/images/155.Warbling_Vireo/Warbling_Vireo_0075_158480.jpg 22 | CUB_200_2011/images/155.Warbling_Vireo/Warbling_Vireo_0061_158494.jpg 23 | CUB_200_2011/images/155.Warbling_Vireo/Warbling_Vireo_0004_158376.jpg 24 | CUB_200_2011/images/155.Warbling_Vireo/Warbling_Vireo_0030_158488.jpg 25 | CUB_200_2011/images/155.Warbling_Vireo/Warbling_Vireo_0077_158427.jpg 26 | CUB_200_2011/images/156.White_eyed_Vireo/White_Eyed_Vireo_0042_159012.jpg 27 | CUB_200_2011/images/156.White_eyed_Vireo/White_Eyed_Vireo_0033_159079.jpg 28 | CUB_200_2011/images/156.White_eyed_Vireo/White_Eyed_Vireo_0126_159341.jpg 29 | CUB_200_2011/images/156.White_eyed_Vireo/White_Eyed_Vireo_0071_159072.jpg 30 | CUB_200_2011/images/156.White_eyed_Vireo/White_Eyed_Vireo_0016_158978.jpg 31 | CUB_200_2011/images/157.Yellow_throated_Vireo/Yellow_Throated_Vireo_0066_795007.jpg 32 | CUB_200_2011/images/157.Yellow_throated_Vireo/Yellow_Throated_Vireo_0032_159632.jpg 33 | CUB_200_2011/images/157.Yellow_throated_Vireo/Yellow_Throated_Vireo_0017_794988.jpg 34 | CUB_200_2011/images/157.Yellow_throated_Vireo/Yellow_Throated_Vireo_0025_795009.jpg 35 | CUB_200_2011/images/157.Yellow_throated_Vireo/Yellow_Throated_Vireo_0058_794994.jpg 
36 | CUB_200_2011/images/158.Bay_breasted_Warbler/Bay_Breasted_Warbler_0073_797138.jpg 37 | CUB_200_2011/images/158.Bay_breasted_Warbler/Bay_Breasted_Warbler_0081_159963.jpg 38 | CUB_200_2011/images/158.Bay_breasted_Warbler/Bay_Breasted_Warbler_0071_797108.jpg 39 | CUB_200_2011/images/158.Bay_breasted_Warbler/Bay_Breasted_Warbler_0105_797143.jpg 40 | CUB_200_2011/images/158.Bay_breasted_Warbler/Bay_Breasted_Warbler_0052_797125.jpg 41 | CUB_200_2011/images/159.Black_and_white_Warbler/Black_And_White_Warbler_0057_160037.jpg 42 | CUB_200_2011/images/159.Black_and_white_Warbler/Black_And_White_Warbler_0035_160102.jpg 43 | CUB_200_2011/images/159.Black_and_white_Warbler/Black_And_White_Warbler_0119_160898.jpg 44 | CUB_200_2011/images/159.Black_and_white_Warbler/Black_And_White_Warbler_0102_160073.jpg 45 | CUB_200_2011/images/159.Black_and_white_Warbler/Black_And_White_Warbler_0022_160512.jpg 46 | CUB_200_2011/images/160.Black_throated_Blue_Warbler/Black_Throated_Blue_Warbler_0050_161154.jpg 47 | CUB_200_2011/images/160.Black_throated_Blue_Warbler/Black_Throated_Blue_Warbler_0130_161682.jpg 48 | CUB_200_2011/images/160.Black_throated_Blue_Warbler/Black_Throated_Blue_Warbler_0133_161539.jpg 49 | CUB_200_2011/images/160.Black_throated_Blue_Warbler/Black_Throated_Blue_Warbler_0054_161158.jpg 50 | CUB_200_2011/images/160.Black_throated_Blue_Warbler/Black_Throated_Blue_Warbler_0024_161619.jpg 51 | -------------------------------------------------------------------------------- /data/index_list/cub200/session_10.txt: -------------------------------------------------------------------------------- 1 | CUB_200_2011/images/181.Worm_eating_Warbler/Worm_Eating_Warbler_0063_795553.jpg 2 | CUB_200_2011/images/181.Worm_eating_Warbler/Worm_Eating_Warbler_0011_795566.jpg 3 | CUB_200_2011/images/181.Worm_eating_Warbler/Worm_Eating_Warbler_0092_795524.jpg 4 | CUB_200_2011/images/181.Worm_eating_Warbler/Worm_Eating_Warbler_0006_176037.jpg 5 | 
CUB_200_2011/images/181.Worm_eating_Warbler/Worm_Eating_Warbler_0018_795546.jpg 6 | CUB_200_2011/images/182.Yellow_Warbler/Yellow_Warbler_0083_176292.jpg 7 | CUB_200_2011/images/182.Yellow_Warbler/Yellow_Warbler_0096_176586.jpg 8 | CUB_200_2011/images/182.Yellow_Warbler/Yellow_Warbler_0119_176485.jpg 9 | CUB_200_2011/images/182.Yellow_Warbler/Yellow_Warbler_0102_176821.jpg 10 | CUB_200_2011/images/182.Yellow_Warbler/Yellow_Warbler_0049_176526.jpg 11 | CUB_200_2011/images/183.Northern_Waterthrush/Northern_Waterthrush_0043_177070.jpg 12 | CUB_200_2011/images/183.Northern_Waterthrush/Northern_Waterthrush_0080_177080.jpg 13 | CUB_200_2011/images/183.Northern_Waterthrush/Northern_Waterthrush_0022_177003.jpg 14 | CUB_200_2011/images/183.Northern_Waterthrush/Northern_Waterthrush_0050_177331.jpg 15 | CUB_200_2011/images/183.Northern_Waterthrush/Northern_Waterthrush_0014_177305.jpg 16 | CUB_200_2011/images/184.Louisiana_Waterthrush/Louisiana_Waterthrush_0087_795261.jpg 17 | CUB_200_2011/images/184.Louisiana_Waterthrush/Louisiana_Waterthrush_0001_795271.jpg 18 | CUB_200_2011/images/184.Louisiana_Waterthrush/Louisiana_Waterthrush_0034_795242.jpg 19 | CUB_200_2011/images/184.Louisiana_Waterthrush/Louisiana_Waterthrush_0020_795265.jpg 20 | CUB_200_2011/images/184.Louisiana_Waterthrush/Louisiana_Waterthrush_0077_795247.jpg 21 | CUB_200_2011/images/185.Bohemian_Waxwing/Bohemian_Waxwing_0046_177864.jpg 22 | CUB_200_2011/images/185.Bohemian_Waxwing/Bohemian_Waxwing_0042_177887.jpg 23 | CUB_200_2011/images/185.Bohemian_Waxwing/Bohemian_Waxwing_0024_177661.jpg 24 | CUB_200_2011/images/185.Bohemian_Waxwing/Bohemian_Waxwing_0031_796633.jpg 25 | CUB_200_2011/images/185.Bohemian_Waxwing/Bohemian_Waxwing_0048_177821.jpg 26 | CUB_200_2011/images/186.Cedar_Waxwing/Cedar_Waxwing_0094_178049.jpg 27 | CUB_200_2011/images/186.Cedar_Waxwing/Cedar_Waxwing_0016_178629.jpg 28 | CUB_200_2011/images/186.Cedar_Waxwing/Cedar_Waxwing_0125_178921.jpg 29 | 
CUB_200_2011/images/186.Cedar_Waxwing/Cedar_Waxwing_0004_179215.jpg 30 | CUB_200_2011/images/186.Cedar_Waxwing/Cedar_Waxwing_0065_179017.jpg 31 | CUB_200_2011/images/187.American_Three_toed_Woodpecker/American_Three_Toed_Woodpecker_0018_179831.jpg 32 | CUB_200_2011/images/187.American_Three_toed_Woodpecker/American_Three_Toed_Woodpecker_0007_179932.jpg 33 | CUB_200_2011/images/187.American_Three_toed_Woodpecker/American_Three_Toed_Woodpecker_0024_179876.jpg 34 | CUB_200_2011/images/187.American_Three_toed_Woodpecker/American_Three_Toed_Woodpecker_0009_179919.jpg 35 | CUB_200_2011/images/187.American_Three_toed_Woodpecker/American_Three_Toed_Woodpecker_0012_179905.jpg 36 | CUB_200_2011/images/188.Pileated_Woodpecker/Pileated_Woodpecker_0056_180094.jpg 37 | CUB_200_2011/images/188.Pileated_Woodpecker/Pileated_Woodpecker_0034_180419.jpg 38 | CUB_200_2011/images/188.Pileated_Woodpecker/Pileated_Woodpecker_0110_180521.jpg 39 | CUB_200_2011/images/188.Pileated_Woodpecker/Pileated_Woodpecker_0079_180388.jpg 40 | CUB_200_2011/images/188.Pileated_Woodpecker/Pileated_Woodpecker_0088_180054.jpg 41 | CUB_200_2011/images/189.Red_bellied_Woodpecker/Red_Bellied_Woodpecker_0112_180827.jpg 42 | CUB_200_2011/images/189.Red_bellied_Woodpecker/Red_Bellied_Woodpecker_0017_181131.jpg 43 | CUB_200_2011/images/189.Red_bellied_Woodpecker/Red_Bellied_Woodpecker_0125_180780.jpg 44 | CUB_200_2011/images/189.Red_bellied_Woodpecker/Red_Bellied_Woodpecker_0086_181891.jpg 45 | CUB_200_2011/images/189.Red_bellied_Woodpecker/Red_Bellied_Woodpecker_0020_182335.jpg 46 | CUB_200_2011/images/190.Red_cockaded_Woodpecker/Red_Cockaded_Woodpecker_0023_794701.jpg 47 | CUB_200_2011/images/190.Red_cockaded_Woodpecker/Red_Cockaded_Woodpecker_0033_794721.jpg 48 | CUB_200_2011/images/190.Red_cockaded_Woodpecker/Red_Cockaded_Woodpecker_0027_794713.jpg 49 | CUB_200_2011/images/190.Red_cockaded_Woodpecker/Red_Cockaded_Woodpecker_0029_794724.jpg 50 | 
CUB_200_2011/images/190.Red_cockaded_Woodpecker/Red_Cockaded_Woodpecker_0039_794736.jpg 51 | -------------------------------------------------------------------------------- /models/teen/Network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.resnet18_encoder import * 5 | from models.resnet20_cifar import * 6 | 7 | 8 | class MYNET(nn.Module): 9 | 10 | def __init__(self, args, mode=None): 11 | super().__init__() 12 | 13 | self.mode = mode 14 | self.args = args 15 | if self.args.dataset in ['cifar100','manyshotcifar']: 16 | self.encoder = resnet20() 17 | self.num_features = 64 18 | if self.args.dataset in ['mini_imagenet','manyshotmini','imagenet100','imagenet1000', 'mini_imagenet_withpath']: 19 | self.encoder = resnet18(False, args) # pretrained=False 20 | self.num_features = 512 21 | if self.args.dataset in ['cub200','manyshotcub']: 22 | self.encoder = resnet18(True, args) # pretrained=True follow TOPIC, models for cub is imagenet pre-trained. 
https://github.com/xyutao/fscil/issues/11#issuecomment-687548790 23 | self.num_features = 512 24 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 25 | 26 | self.fc = nn.Linear(self.num_features, self.args.num_classes, bias=False) 27 | 28 | def forward_metric(self, x): 29 | x = self.encode(x) 30 | if 'cos' in self.mode: 31 | x = F.linear(F.normalize(x, p=2, dim=-1), F.normalize(self.fc.weight, p=2, dim=-1)) 32 | x = self.args.temperature * x 33 | 34 | elif 'dot' in self.mode: 35 | x = self.fc(x) 36 | x = self.args.temperature * x 37 | return x 38 | 39 | def encode(self, x): 40 | x = self.encoder(x) 41 | x = F.adaptive_avg_pool2d(x, 1) 42 | x = x.squeeze(-1).squeeze(-1) 43 | return x 44 | 45 | def forward(self, input): 46 | if self.mode != 'encoder': 47 | input = self.forward_metric(input) 48 | return input 49 | elif self.mode == 'encoder': 50 | input = self.encode(input) 51 | return input 52 | else: 53 | raise ValueError('Unknown mode') 54 | 55 | def update_fc(self,dataloader,class_list,session): 56 | for batch in dataloader: 57 | data, label = [_.cuda() for _ in batch] 58 | data=self.encode(data).detach() 59 | 60 | if self.args.not_data_init: 61 | new_fc = nn.Parameter( 62 | torch.rand(len(class_list), self.num_features, device="cuda"), 63 | requires_grad=True) 64 | nn.init.kaiming_uniform_(new_fc, a=math.sqrt(5)) 65 | else: 66 | new_fc = self.update_fc_avg(data, label, class_list) 67 | 68 | def update_fc_avg(self,data,label,class_list): 69 | new_fc=[] 70 | for class_index in class_list: 71 | data_index=(label==class_index).nonzero().squeeze(-1) 72 | embedding=data[data_index] 73 | proto=embedding.mean(0) 74 | new_fc.append(proto) 75 | self.fc.weight.data[class_index]=proto 76 | new_fc=torch.stack(new_fc,dim=0) 77 | return new_fc 78 | 79 | def get_logits(self,x,fc): 80 | if 'dot' in self.args.new_mode: 81 | return F.linear(x,fc) 82 | elif 'cos' in self.args.new_mode: 83 | return self.args.temperature * F.linear(F.normalize(x, p=2, dim=-1), F.normalize(fc, p=2, 
dim=-1)) 84 | 85 | def soft_calibration(self, args, session): 86 | base_protos = self.fc.weight.data[:args.base_class].detach().cpu().data 87 | base_protos = F.normalize(base_protos, p=2, dim=-1) 88 | 89 | cur_protos = self.fc.weight.data[args.base_class + (session-1) * args.way : args.base_class + session * args.way].detach().cpu().data 90 | cur_protos = F.normalize(cur_protos, p=2, dim=-1) 91 | 92 | weights = torch.mm(cur_protos, base_protos.T) * args.softmax_t 93 | norm_weights = torch.softmax(weights, dim=1) 94 | delta_protos = torch.matmul(norm_weights, base_protos) 95 | 96 | delta_protos = F.normalize(delta_protos, p=2, dim=-1) 97 | 98 | updated_protos = (1-args.shift_weight) * cur_protos + args.shift_weight * delta_protos 99 | 100 | self.fc.weight.data[args.base_class + (session-1) * args.way : args.base_class + session * args.way] = updated_protos -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib 3 | import logging 4 | import sys 5 | 6 | from postprocess_path import set_save_path 7 | from utils import Logger, pprint, set_gpu, set_logging, set_seed 8 | 9 | 10 | def get_command_line_parser(): 11 | parser = argparse.ArgumentParser() 12 | # about dataset and network 13 | parser.add_argument('-project', type=str, default='base', choices=['teen']) 14 | parser.add_argument('-dataset', type=str, default='cifar100', 15 | choices=['mini_imagenet', 'cub200', 'cifar100']) 16 | parser.add_argument('-dataroot', type=str, default='') 17 | parser.add_argument('-temperature', type=float, default=16) 18 | parser.add_argument('-feat_norm', action='store_true', help='If True, normalize the feature.') 19 | 20 | # about pre-training 21 | parser.add_argument('-epochs_base', type=int, default=100) 22 | parser.add_argument('-epochs_new', type=int, default=100) 23 | parser.add_argument('-lr_base', type=float, default=0.1) 
24 | parser.add_argument('-lr_new', type=float, default=0.1) 25 | 26 | ## optimizer & scheduler 27 | parser.add_argument('-optim', type=str, default='sgd', choices=['sgd', 'adam']) 28 | parser.add_argument('-schedule', type=str, default='Step', choices=['Step', 'Milestone','Cosine']) 29 | parser.add_argument('-milestones', nargs='+', type=int, default=[60, 70]) 30 | parser.add_argument('-step', type=int, default=20) 31 | parser.add_argument('-decay', type=float, default=0.0005) 32 | parser.add_argument('-momentum', type=float, default=0.9) 33 | parser.add_argument('-gamma', type=float, default=0.1) 34 | parser.add_argument('-tmax', type=int, default=600) #consine scheduler 35 | 36 | parser.add_argument('-not_data_init', action='store_true', help='using average data embedding to init or not') 37 | parser.add_argument('-batch_size_base', type=int, default=128) 38 | parser.add_argument('-batch_size_new', type=int, default=0, help='set 0 will use all the availiable training image for new') 39 | parser.add_argument('-test_batch_size', type=int, default=100) 40 | parser.add_argument('-base_mode', type=str, default='ft_cos', 41 | choices=['ft_dot', 'ft_cos']) # ft_dot means using linear classifier, ft_cos means using cosine classifier 42 | parser.add_argument('-new_mode', type=str, default='avg_cos', 43 | choices=['ft_dot', 'ft_cos', 'avg_cos']) # ft_dot means using linear classifier, ft_cos means using cosine classifier, avg_cos means using average data embedding and cosine classifier 44 | 45 | parser.add_argument('-start_session', type=int, default=0) 46 | parser.add_argument('-model_dir', type=str, default=None, help='loading model parameter from a specific dir') 47 | parser.add_argument('-only_do_incre', action='store_true', help='Load model and incremental learning...') 48 | 49 | # about training 50 | parser.add_argument('-gpu', default='0,1,2,3') 51 | parser.add_argument('-num_workers', type=int, default=8) 52 | parser.add_argument('-seed', type=int, default=1) 53 | 
parser.add_argument('-debug', action='store_true') 54 | 55 | return parser 56 | 57 | def add_commond_line_parser(params): 58 | project = params[1] 59 | # base parser 60 | parser = get_command_line_parser() 61 | 62 | if project == 'base': 63 | args = parser.parse_args(params[2:]) 64 | return args 65 | 66 | elif project == 'teen': 67 | parser.add_argument('-softmax_t', type=float, default=16) 68 | parser.add_argument('-shift_weight', type=float, default=0.5, help='weights of delta prototypes') 69 | parser.add_argument('-soft_mode', type=str, default='soft_proto', choices=['soft_proto', 'soft_embed', 'hard_proto']) 70 | args = parser.parse_args(params[2:]) 71 | return args 72 | else: 73 | raise NotImplementedError 74 | 75 | if __name__ == '__main__': 76 | args = add_commond_line_parser(sys.argv) 77 | 78 | set_seed(args.seed) 79 | pprint(vars(args)) 80 | args.num_gpu = set_gpu(args) 81 | 82 | set_save_path(args) 83 | 84 | logger = Logger(args, args.save_path) 85 | set_logging('INFO', args.save_path) 86 | logging.info(f"save_path: {args.save_path}") 87 | trainer = importlib.import_module('models.%s.fscil_trainer' % (args.project)).FSCILTrainer(args) 88 | trainer.train() 89 | -------------------------------------------------------------------------------- /models/teen/helper.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from tqdm import tqdm 3 | import torch.nn.functional as F 4 | import logging 5 | 6 | 7 | def base_train(model, trainloader, optimizer, scheduler, epoch, args): 8 | tl = Averager() 9 | ta = Averager() 10 | model = model.train() 11 | # standard classification for pretrain 12 | tqdm_gen = tqdm(trainloader) 13 | for i, batch in enumerate(tqdm_gen, 1): 14 | data, train_label = [_.cuda() for _ in batch] 15 | 16 | logits = model(data) 17 | logits = logits[:, :args.base_class] 18 | loss = F.cross_entropy(logits, train_label) 19 | acc = count_acc(logits, train_label) 20 | 21 | total_loss = loss 22 
| 23 | lrc = scheduler.get_last_lr()[0] 24 | tqdm_gen.set_description( 25 | 'Session 0, epo {}, lrc={:.4f},total loss={:.4f} acc={:.4f}'.format(epoch, lrc, total_loss.item(), acc)) 26 | tl.add(total_loss.item()) 27 | ta.add(acc) 28 | 29 | optimizer.zero_grad() 30 | loss.backward() 31 | optimizer.step() 32 | tl = tl.item() 33 | ta = ta.item() 34 | return tl, ta 35 | 36 | def replace_base_fc(trainset, transform, model, args): 37 | # replace fc.weight with the embedding average of train data 38 | model = model.eval() 39 | 40 | trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=128, 41 | num_workers=8, pin_memory=True, shuffle=False) 42 | trainloader.dataset.transform = transform 43 | embedding_list = [] 44 | label_list = [] 45 | with torch.no_grad(): 46 | for i, batch in enumerate(trainloader): 47 | data, label = [_.cuda() for _ in batch] 48 | model.module.mode = 'encoder' 49 | embedding = model(data) 50 | 51 | embedding_list.append(embedding.cpu()) 52 | label_list.append(label.cpu()) 53 | embedding_list = torch.cat(embedding_list, dim=0) 54 | label_list = torch.cat(label_list, dim=0) 55 | 56 | proto_list = [] 57 | 58 | for class_index in range(args.base_class): 59 | data_index = (label_list == class_index).nonzero() 60 | embedding_this = embedding_list[data_index.squeeze(-1)] 61 | embedding_this = embedding_this.mean(0) 62 | proto_list.append(embedding_this) 63 | 64 | proto_list = torch.stack(proto_list, dim=0) 65 | 66 | model.module.fc.weight.data[:args.base_class] = proto_list 67 | 68 | return model 69 | 70 | def test(model, testloader, epoch, args, session, result_list=None): 71 | test_class = args.base_class + session * args.way 72 | model = model.eval() 73 | vl = Averager() 74 | va = Averager() 75 | va5= Averager() 76 | lgt=torch.tensor([]) 77 | lbs=torch.tensor([]) 78 | with torch.no_grad(): 79 | for i, batch in enumerate(testloader, 1): 80 | data, test_label = [_.cuda() for _ in batch] 81 | logits = model(data) 82 | logits = logits[:, 
:test_class] 83 | loss = F.cross_entropy(logits, test_label) 84 | acc = count_acc(logits, test_label) 85 | top5acc = count_acc_topk(logits, test_label) 86 | 87 | vl.add(loss.item()) 88 | va.add(acc) 89 | va5.add(top5acc) 90 | 91 | lgt=torch.cat([lgt,logits.cpu()]) 92 | lbs=torch.cat([lbs,test_label.cpu()]) 93 | vl = vl.item() 94 | va = va.item() 95 | va5= va5.item() 96 | 97 | logging.info('epo {}, test, loss={:.4f} acc={:.4f}, acc@5={:.4f}'.format(epoch, vl, va, va5)) 98 | 99 | lgt=lgt.view(-1, test_class) 100 | lbs=lbs.view(-1) 101 | 102 | # if session > 0: 103 | # _preds = torch.argmax(lgt, dim=1) 104 | # torch.save(_preds, f"pred_labels/{args.project}_{args.dataset}_{session}_preds.pt") 105 | # torch.save(lbs, f"pred_labels/{args.project}_{args.dataset}_{session}_labels.pt") 106 | # torch.save(model.module.fc.weight.data.cpu()[:test_class], f"pred_labels/{args.project}_{args.dataset}_{session}_weights.pt") 107 | 108 | if session > 0: 109 | save_model_dir = os.path.join(args.save_path, 'session' + str(session) + 'confusion_matrix') 110 | cm = confmatrix(lgt,lbs,save_model_dir) 111 | perclassacc = cm.diagonal() 112 | seenac = np.mean(perclassacc[:args.base_class]) 113 | unseenac = np.mean(perclassacc[args.base_class:]) 114 | 115 | result_list.append(f"Seen Acc:{seenac} Unseen Acc:{unseenac}") 116 | return vl, (seenac, unseenac, va) 117 | else: 118 | return vl, va 119 | -------------------------------------------------------------------------------- /dataloader/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import copy 4 | 5 | 6 | class CategoriesSampler(): 7 | 8 | def __init__(self, label, n_batch, n_cls, n_per, ): 9 | self.n_batch = n_batch # the number of iterations in the dataloader 10 | self.n_cls = n_cls 11 | self.n_per = n_per 12 | 13 | label = np.array(label) # all data label 14 | self.m_ind = [] # the data index of each class 15 | for i in range(max(label) + 1): 16 | 
ind = np.argwhere(label == i).reshape(-1) # all data index of this class 17 | ind = torch.from_numpy(ind) 18 | self.m_ind.append(ind) 19 | 20 | def __len__(self): 21 | return self.n_batch 22 | 23 | def __iter__(self): 24 | for i_batch in range(self.n_batch): 25 | batch = [] 26 | classes = torch.randperm(len(self.m_ind))[:self.n_cls] # sample n_cls classes from total classes. 27 | for c in classes: 28 | l = self.m_ind[c] # all data indexs of this class 29 | pos = torch.randperm(len(l))[:self.n_per] # sample n_per data index of this class 30 | batch.append(l[pos]) 31 | batch = torch.stack(batch).t().reshape(-1) 32 | # .t() transpose, 33 | # due to it, the label is in the sequence of abcdabcdabcd form after reshape, 34 | # instead of aaaabbbbccccdddd 35 | yield batch 36 | # finally sample n_batch* n_cls(way)* n_per(shot) instances. per bacth. 37 | 38 | 39 | 40 | class BasePreserverCategoriesSampler(): 41 | def __init__(self, label, n_batch, n_cls, n_per, ): 42 | self.n_batch = n_batch # the number of iterations in the dataloader 43 | self.n_cls = n_cls 44 | self.n_per = n_per 45 | 46 | label = np.array(label) # all data label 47 | self.m_ind = [] # the data index of each class 48 | for i in range(max(label) + 1): 49 | ind = np.argwhere(label == i).reshape(-1) # all data index of this class 50 | ind = torch.from_numpy(ind) 51 | self.m_ind.append(ind) 52 | 53 | def __len__(self): 54 | return self.n_batch 55 | 56 | def __iter__(self): 57 | 58 | for i_batch in range(self.n_batch): 59 | batch = [] 60 | #classes = torch.randperm(len(self.m_ind))[:self.n_cls] # sample n_cls classes from total classes. 
61 | classes=torch.arange(len(self.m_ind)) 62 | for c in classes: 63 | l = self.m_ind[c] # all data indexs of this class 64 | pos = torch.randperm(len(l))[:self.n_per] # sample n_per data index of this class 65 | batch.append(l[pos]) 66 | batch = torch.stack(batch).t().reshape(-1) 67 | # .t() transpose, 68 | # due to it, the label is in the sequence of abcdabcdabcd form after reshape, 69 | # instead of aaaabbbbccccdddd 70 | yield batch 71 | # finally sample n_batch* n_cls(way)* n_per(shot) instances. per bacth. 72 | 73 | class NewCategoriesSampler(): 74 | 75 | def __init__(self, label, n_batch, n_cls, n_per,): 76 | self.n_batch = n_batch # the number of iterations in the dataloader 77 | self.n_cls = n_cls 78 | self.n_per = n_per 79 | 80 | label = np.array(label) # all data label 81 | self.m_ind = [] # the data index of each class 82 | for i in range(max(label) + 1): 83 | ind = np.argwhere(label == i).reshape(-1) # all data index of this class 84 | ind = torch.from_numpy(ind) 85 | self.m_ind.append(ind) 86 | 87 | self.classlist=np.arange(np.min(label),np.max(label)+1) 88 | #print(self.classlist) 89 | 90 | def __len__(self): 91 | return self.n_batch 92 | 93 | def __iter__(self): 94 | for i_batch in range(self.n_batch): 95 | batch = [] 96 | for c in self.classlist: 97 | l = self.m_ind[c] # all data indexs of this class 98 | pos = torch.randperm(len(l))[:self.n_per] # sample n_per data index of this class 99 | batch.append(l[pos]) 100 | batch = torch.stack(batch).t().reshape(-1) 101 | yield batch 102 | 103 | 104 | if __name__ == '__main__': 105 | q=np.arange(5,10) 106 | print(q) 107 | y=torch.tensor([5,6,7,8,9,5,6,7,8,9,5,6,7,8,9,5,5,5,55,]) 108 | label = np.array(y) # all data label 109 | m_ind = [] # the data index of each class 110 | for i in range(max(label) + 1): 111 | ind = np.argwhere(label == i).reshape(-1) # all data index of this class 112 | ind = torch.from_numpy(ind) 113 | m_ind.append(ind) 114 | print(m_ind, len(m_ind)) 115 | 
-------------------------------------------------------------------------------- /dataloader/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from dataloader.sampler import CategoriesSampler 4 | from torch.utils.data import Dataset 5 | def set_up_datasets(args): 6 | if args.dataset == 'cifar100': 7 | import dataloader.cifar100.cifar as Dataset 8 | args.base_class = 60 9 | args.num_classes=100 10 | args.way = 5 11 | args.shot = 5 12 | args.sessions = 9 13 | 14 | if args.dataset == 'cub200': 15 | import dataloader.cub200.cub200 as Dataset 16 | args.base_class = 100 17 | args.num_classes = 200 18 | args.way = 10 19 | args.shot = 5 20 | args.sessions = 11 21 | 22 | if args.dataset == 'mini_imagenet': 23 | import dataloader.miniimagenet.miniimagenet as Dataset 24 | args.base_class = 60 25 | args.num_classes=100 26 | args.way = 5 27 | args.shot = 5 28 | args.sessions = 9 29 | 30 | 31 | args.Dataset=Dataset 32 | return args 33 | 34 | def get_dataloader(args, session): 35 | if session == 0: 36 | trainset, trainloader, testloader = get_base_dataloader(args) 37 | else: 38 | trainset, trainloader, testloader = get_new_dataloader(args, session) 39 | return trainset, trainloader, testloader 40 | 41 | def get_base_dataloader(args): 42 | txt_path = "data/index_list/" + args.dataset + "/session_" + str(0 + 1) + '.txt' 43 | class_index = np.arange(args.base_class) 44 | if args.dataset == 'cifar100': 45 | 46 | trainset = args.Dataset.CIFAR100(root=args.dataroot, train=True, download=True, 47 | index=class_index, base_sess=True) 48 | testset = args.Dataset.CIFAR100(root=args.dataroot, train=False, download=False, 49 | index=class_index, base_sess=True) 50 | 51 | if args.dataset == 'cub200': 52 | trainset = args.Dataset.CUB200(root=args.dataroot, train=True, 53 | index=class_index, base_sess=True) 54 | testset = args.Dataset.CUB200(root=args.dataroot, train=False, index=class_index) 55 | 56 | if 
args.dataset == 'mini_imagenet': 57 | trainset = args.Dataset.MiniImageNet(root=args.dataroot, train=True, 58 | index=class_index, base_sess=True) 59 | testset = args.Dataset.MiniImageNet(root=args.dataroot, train=False, index=class_index) 60 | 61 | if args.dataset == 'imagenet100' or args.dataset == 'imagenet1000': 62 | trainset = args.Dataset.ImageNet(root=args.dataroot, train=True, 63 | index=class_index, base_sess=True) 64 | testset = args.Dataset.ImageNet(root=args.dataroot, train=False, index=class_index) 65 | 66 | trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=args.batch_size_base, shuffle=True, 67 | num_workers=8, pin_memory=True) 68 | testloader = torch.utils.data.DataLoader( 69 | dataset=testset, batch_size=args.test_batch_size, shuffle=False, num_workers=8, pin_memory=True) 70 | 71 | return trainset, trainloader, testloader 72 | 73 | 74 | def get_new_dataloader(args, session): 75 | txt_path = "data/index_list/" + args.dataset + "/session_" + str(session + 1) + '.txt' 76 | 77 | if args.dataset == 'cifar100': 78 | class_index = open(txt_path).read().splitlines() 79 | trainset = args.Dataset.CIFAR100(root=args.dataroot, train=True, download=False, 80 | index=class_index, base_sess=False) 81 | if args.dataset == 'cub200': 82 | trainset = args.Dataset.CUB200(root=args.dataroot, train=True, 83 | index_path=txt_path) 84 | if args.dataset == 'mini_imagenet': 85 | trainset = args.Dataset.MiniImageNet(root=args.dataroot, train=True, 86 | index_path=txt_path) 87 | 88 | if args.batch_size_new == 0: 89 | batch_size_new = trainset.__len__() 90 | trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size_new, shuffle=False, 91 | num_workers=args.num_workers, pin_memory=True) 92 | else: 93 | trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=args.batch_size_new, shuffle=True, 94 | num_workers=args.num_workers, pin_memory=True) 95 | 96 | # test on all encountered classes 97 | class_new = 
get_session_classes(args, session) 98 | 99 | if args.dataset == 'cifar100': 100 | testset = args.Dataset.CIFAR100(root=args.dataroot, train=False, download=False, 101 | index=class_new, base_sess=False) 102 | if args.dataset == 'cub200': 103 | testset = args.Dataset.CUB200(root=args.dataroot, train=False, 104 | index=class_new) 105 | if args.dataset == 'mini_imagenet': 106 | testset = args.Dataset.MiniImageNet(root=args.dataroot, train=False, 107 | index=class_new) 108 | 109 | testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=args.test_batch_size, shuffle=False, 110 | num_workers=args.num_workers, pin_memory=True) 111 | 112 | return trainset, trainloader, testloader 113 | 114 | def get_session_classes(args,session): 115 | class_list=np.arange(args.base_class + session * args.way) 116 | return class_list 117 | 118 | -------------------------------------------------------------------------------- /dataloader/miniimagenet/miniimagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | from .autoaugment import AutoAugImageNetPolicy 10 | 11 | class MiniImageNet(Dataset): 12 | 13 | def __init__(self, root='', train=True, 14 | transform=None, 15 | index_path=None, index=None, base_sess=None, autoaug=True): 16 | if train: 17 | setname = 'train' 18 | else: 19 | setname = 'test' 20 | self.root = os.path.expanduser(root) 21 | self.transform = transform 22 | self.train = train # training set or test set 23 | self.IMAGE_PATH = os.path.join(root, 'miniimagenet/images') 24 | self.SPLIT_PATH = os.path.join(root, 'miniimagenet/split') 25 | 26 | csv_path = osp.join(self.SPLIT_PATH, setname + '.csv') 27 | lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:] 28 | 29 | self.data = [] 30 | self.targets = [] 31 | self.data2label = {} 
32 | lb = -1 33 | 34 | self.wnids = [] 35 | 36 | for l in lines: 37 | name, wnid = l.split(',') 38 | path = osp.join(self.IMAGE_PATH, name) 39 | if wnid not in self.wnids: 40 | self.wnids.append(wnid) 41 | lb += 1 42 | self.data.append(path) 43 | self.targets.append(lb) 44 | self.data2label[path] = lb 45 | 46 | if autoaug is False: 47 | #do not use autoaug. 48 | if train: 49 | image_size = 84 50 | self.transform = transforms.Compose([ 51 | transforms.RandomResizedCrop(image_size), 52 | # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 53 | transforms.RandomHorizontalFlip(), 54 | 55 | transforms.ToTensor(), 56 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 57 | std=[0.229, 0.224, 0.225])]) 58 | if base_sess: 59 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index) 60 | else: 61 | self.data, self.targets = self.SelectfromTxt(self.data2label, index_path) 62 | else: 63 | image_size = 84 64 | self.transform = transforms.Compose([ 65 | transforms.Resize([92, 92]), 66 | transforms.CenterCrop(image_size), 67 | transforms.ToTensor(), 68 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 69 | std=[0.229, 0.224, 0.225])]) 70 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index) 71 | else: 72 | #use autoaug. 
73 | if train: 74 | image_size = 84 75 | self.transform = transforms.Compose([ 76 | transforms.RandomResizedCrop(image_size), 77 | # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 78 | transforms.RandomHorizontalFlip(), 79 | #add autoaug 80 | AutoAugImageNetPolicy(), 81 | transforms.ToTensor(), 82 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 83 | std=[0.229, 0.224, 0.225])]) 84 | if base_sess: 85 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index) 86 | else: 87 | self.data, self.targets = self.SelectfromTxt(self.data2label, index_path) 88 | else: 89 | image_size = 84 90 | self.transform = transforms.Compose([ 91 | transforms.Resize([92, 92]), 92 | transforms.CenterCrop(image_size), 93 | transforms.ToTensor(), 94 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 95 | std=[0.229, 0.224, 0.225])]) 96 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index) 97 | 98 | def SelectfromTxt(self, data2label, index_path): 99 | #select from txt file, and make cooresponding mampping. 100 | index=[] 101 | lines = [x.strip() for x in open(index_path, 'r').readlines()] 102 | for line in lines: 103 | index.append(line.split('/')[3]) 104 | data_tmp = [] 105 | targets_tmp = [] 106 | for i in index: 107 | img_path = os.path.join(self.IMAGE_PATH, i) 108 | data_tmp.append(img_path) 109 | targets_tmp.append(data2label[img_path]) 110 | 111 | return data_tmp, targets_tmp 112 | 113 | def SelectfromClasses(self, data, targets, index): 114 | #select from csv file, choose all instances from this class. 
115 | data_tmp = [] 116 | targets_tmp = [] 117 | for i in index: 118 | ind_cl = np.where(i == targets)[0] 119 | for j in ind_cl: 120 | data_tmp.append(data[j]) 121 | targets_tmp.append(targets[j]) 122 | 123 | return data_tmp, targets_tmp 124 | 125 | def __len__(self): 126 | return len(self.data) 127 | 128 | def __getitem__(self, i): 129 | 130 | path, targets = self.data[i], self.targets[i] 131 | image = self.transform(Image.open(path).convert('RGB')) 132 | return image, targets 133 | 134 | 135 | class MiniImageNet_concate(Dataset): 136 | def __init__(self, train,x1,y1,x2,y2): 137 | 138 | if train: 139 | image_size = 84 140 | self.transform = transforms.Compose([ 141 | transforms.RandomResizedCrop(image_size), 142 | # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 143 | transforms.RandomHorizontalFlip(), 144 | 145 | transforms.ToTensor(), 146 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 147 | std=[0.229, 0.224, 0.225])]) 148 | 149 | else: 150 | image_size = 84 151 | self.transform = transforms.Compose([ 152 | transforms.Resize([92, 92]), 153 | transforms.CenterCrop(image_size), 154 | transforms.ToTensor(), 155 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 156 | std=[0.229, 0.224, 0.225])]) 157 | 158 | 159 | self.data=x1+x2 160 | self.targets=y1+y2 161 | print(len(self.data),len(self.targets)) 162 | 163 | def __len__(self): 164 | return len(self.data) 165 | 166 | def __getitem__(self, i): 167 | 168 | path, targets = self.data[i], self.targets[i] 169 | image = self.transform(Image.open(path).convert('RGB')) 170 | return image, targets 171 | 172 | if __name__ == '__main__': 173 | txt_path = "../../data/index_list/mini_imagenet/session_1.txt" 174 | # class_index = open(txt_path).read().splitlines() 175 | base_class = 100 176 | class_index = np.arange(base_class) 177 | dataroot = '/data/wangqw/datasets/FSCIL' 178 | batch_size_base = 400 179 | trainset = MiniImageNet(root=dataroot, train=False, transform=None, index=False) 180 | cls = 
np.unique(trainset.targets) 181 | print(cls) 182 | # trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size_base, shuffle=True, num_workers=8, 183 | # pin_memory=True) 184 | # print(trainloader.dataset.data.shape) 185 | -------------------------------------------------------------------------------- /models/teen/fscil_trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from copy import deepcopy 3 | 4 | import torch.nn as nn 5 | from base import Trainer 6 | from dataloader.data_utils import get_dataloader 7 | from utils import * 8 | 9 | from .helper import * 10 | from .Network import MYNET 11 | 12 | 13 | class FSCILTrainer(Trainer): 14 | def __init__(self, args): 15 | super().__init__(args) 16 | self.args = args 17 | self.set_up_model() 18 | 19 | def set_up_model(self): 20 | self.model = MYNET(self.args, mode=self.args.base_mode) 21 | self.model = nn.DataParallel(self.model, list(range(self.args.num_gpu))) 22 | self.model = self.model.cuda() 23 | 24 | if self.args.model_dir is not None: 25 | logging.info('Loading init parameters from: %s' % self.args.model_dir) 26 | self.best_model_dict = torch.load(self.args.model_dir, 27 | map_location={'cuda:3':'cuda:0'})['params'] 28 | else: 29 | logging.info('random init params') 30 | if self.args.start_session > 0: 31 | logging.info('WARING: Random init weights for new sessions!') 32 | self.best_model_dict = deepcopy(self.model.state_dict()) 33 | 34 | def train(self,): 35 | args = self.args 36 | t_start_time = time.time() 37 | # init train statistics 38 | result_list = [args] 39 | for session in range(args.start_session, args.sessions): 40 | train_set, trainloader, testloader = get_dataloader(args, session) 41 | self.model.load_state_dict(self.best_model_dict) 42 | if session == 0: # load base class train img label 43 | if not args.only_do_incre: 44 | logging.info(f'new classes for this session:{np.unique(train_set.targets)}') 45 | 
optimizer, scheduler = get_optimizer(args, self.model) 46 | for epoch in range(args.epochs_base): 47 | start_time = time.time() 48 | 49 | tl, ta = base_train(self.model, trainloader, optimizer, scheduler, epoch, args) 50 | tsl, tsa = test(self.model, testloader, epoch, args, session, result_list=result_list) 51 | 52 | # save better model 53 | if (tsa * 100) >= self.trlog['max_acc'][session]: 54 | self.trlog['max_acc'][session] = float('%.3f' % (tsa * 100)) 55 | self.trlog['max_acc_epoch'] = epoch 56 | save_model_dir = os.path.join(args.save_path, 'session' + str(session) + '_max_acc.pth') 57 | torch.save(dict(params=self.model.state_dict()), save_model_dir) 58 | torch.save(optimizer.state_dict(), os.path.join(args.save_path, 'optimizer_best.pth')) 59 | self.best_model_dict = deepcopy(self.model.state_dict()) 60 | logging.info('********A better model is found!!**********') 61 | logging.info('Saving model to :%s' % save_model_dir) 62 | logging.info('best epoch {}, best test acc={:.3f}'.format( 63 | self.trlog['max_acc_epoch'], self.trlog['max_acc'][session])) 64 | 65 | self.trlog['train_loss'].append(tl) 66 | self.trlog['train_acc'].append(ta) 67 | self.trlog['test_loss'].append(tsl) 68 | self.trlog['test_acc'].append(tsa) 69 | lrc = scheduler.get_last_lr()[0] 70 | 71 | logging.info( 72 | 'epoch:%03d,lr:%.4f,training_loss:%.5f,training_acc:%.5f,test_loss:%.5f,test_acc:%.5f' % ( 73 | epoch, lrc, tl, ta, tsl, tsa)) 74 | print('This epoch takes %d seconds' % (time.time() - start_time), 75 | '\n still need around %.2f mins to finish this session' % ( 76 | (time.time() - start_time) * (args.epochs_base - epoch) / 60)) 77 | scheduler.step() 78 | 79 | # Finish base train 80 | logging.info('>>> Finish Base Train <<<') 81 | result_list.append('Session {}, Test Best Epoch {},\nbest test Acc {:.4f}\n'.format( 82 | session, self.trlog['max_acc_epoch'], self.trlog['max_acc'][session])) 83 | else: 84 | logging.info('>>> Load Model &&& Finish base train...') 85 | assert 
args.model_dir is not None 86 | 87 | if not args.not_data_init: 88 | self.model.load_state_dict(self.best_model_dict) 89 | self.model = replace_base_fc(train_set, testloader.dataset.transform, self.model, args) 90 | best_model_dir = os.path.join(args.save_path, 'session' + str(session) + '_max_acc.pth') 91 | logging.info('Replace the fc with average embedding, and save it to :%s' % best_model_dir) 92 | self.best_model_dict = deepcopy(self.model.state_dict()) 93 | torch.save(dict(params=self.model.state_dict()), best_model_dir) 94 | 95 | self.model.module.mode = 'avg_cos' 96 | tsl, tsa = test(self.model, testloader, 0, args, session, result_list=result_list) 97 | if (tsa * 100) >= self.trlog['max_acc'][session]: 98 | self.trlog['max_acc'][session] = float('%.3f' % (tsa * 100)) 99 | logging.info('The new best test acc of base session={:.3f}'.format( 100 | self.trlog['max_acc'][session])) 101 | 102 | # incremental learning sessions 103 | else: 104 | logging.info("training session: [%d]" % session) 105 | self.model.module.mode = self.args.new_mode 106 | self.model.eval() 107 | trainloader.dataset.transform = testloader.dataset.transform 108 | 109 | if args.soft_mode == 'soft_proto': 110 | self.model.module.update_fc(trainloader, np.unique(train_set.targets), session) 111 | self.model.module.soft_calibration(args, session) 112 | else: 113 | raise NotImplementedError 114 | 115 | tsl, (seenac, unseenac, avgac) = test(self.model, testloader, 0, args, session, result_list=result_list) 116 | 117 | # update results and save model 118 | self.trlog['seen_acc'].append(float('%.3f' % (seenac * 100))) 119 | self.trlog['unseen_acc'].append(float('%.3f' % (unseenac * 100))) 120 | self.trlog['max_acc'][session] = float('%.3f' % (avgac * 100)) 121 | self.best_model_dict = deepcopy(self.model.state_dict()) 122 | 123 | logging.info(f"Session {session} ==> Seen Acc:{self.trlog['seen_acc'][-1]} " 124 | f"Unseen Acc:{self.trlog['unseen_acc'][-1]} Avg Acc:{self.trlog['max_acc'][session]}") 
125 | result_list.append('Session {}, test Acc {:.3f}\n'.format(session, self.trlog['max_acc'][session])) 126 | 127 | # Finish all incremental sessions, save results. 128 | result_list, hmeans = postprocess_results(result_list, self.trlog) 129 | save_list_to_txt(os.path.join(args.save_path, 'results.txt'), result_list) 130 | if not self.args.debug: 131 | save_result(args, self.trlog, hmeans) 132 | 133 | t_end_time = time.time() 134 | total_time = (t_end_time - t_start_time) / 60 135 | logging.info(f"Base Session Best epoch:{self.trlog['max_acc_epoch']}") 136 | logging.info('Total time used %.2f mins' % total_time) 137 | logging.info(self.args.time_str) 138 | -------------------------------------------------------------------------------- /dataloader/cub200/cub200.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | from .autoaugment import AutoAugImageNetPolicy 10 | 11 | class CUB200(Dataset): 12 | 13 | def __init__(self, root='', train=True, 14 | index_path=None, index=None, base_sess=None, autoaug=False): 15 | self.root = os.path.expanduser(root) 16 | self.train = train # training set or test set 17 | self._pre_operate(self.root) 18 | 19 | if autoaug is False: 20 | #do not use autoaug 21 | if train: 22 | self.transform = transforms.Compose([ 23 | transforms.Resize(256), 24 | # transforms.CenterCrop(224), 25 | transforms.RandomResizedCrop(224), 26 | transforms.RandomHorizontalFlip(), 27 | #add autoaug 28 | #AutoAugImageNetPolicy(), 29 | transforms.ToTensor(), 30 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 31 | ]) 32 | # self.data, self.targets = self.SelectfromTxt(self.data2label, index_path) 33 | if base_sess: 34 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index) 35 | else: 36 
| self.data, self.targets = self.SelectfromTxt(self.data2label, index_path) 37 | else: 38 | self.transform = transforms.Compose([ 39 | transforms.Resize(256), 40 | transforms.CenterCrop(224), 41 | transforms.ToTensor(), 42 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 43 | ]) 44 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index) 45 | else: 46 | #use autoaug 47 | if train: 48 | self.transform = transforms.Compose([ 49 | transforms.Resize(256), 50 | # transforms.CenterCrop(224), 51 | transforms.RandomResizedCrop(224), 52 | transforms.RandomHorizontalFlip(), 53 | #add autoaug 54 | AutoAugImageNetPolicy(), 55 | transforms.ToTensor(), 56 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 57 | ]) 58 | # self.data, self.targets = self.SelectfromTxt(self.data2label, index_path) 59 | if base_sess: 60 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index) 61 | else: 62 | self.data, self.targets = self.SelectfromTxt(self.data2label, index_path) 63 | else: 64 | self.transform = transforms.Compose([ 65 | transforms.Resize(256), 66 | transforms.CenterCrop(224), 67 | transforms.ToTensor(), 68 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 69 | ]) 70 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index) 71 | def text_read(self, file): 72 | with open(file, 'r') as f: 73 | lines = f.readlines() 74 | for i, line in enumerate(lines): 75 | lines[i] = line.strip('\n') 76 | return lines 77 | 78 | def list2dict(self, list): 79 | dict = {} 80 | for l in list: 81 | s = l.split(' ') 82 | id = int(s[0]) 83 | cls = s[1] 84 | if id not in dict.keys(): 85 | dict[id] = cls 86 | else: 87 | raise EOFError('The same ID can only appear once') 88 | return dict 89 | 90 | def _pre_operate(self, root): 91 | image_file = os.path.join(root, 'CUB_200_2011/images.txt') 92 | split_file = os.path.join(root, 
'CUB_200_2011/train_test_split.txt') 93 | class_file = os.path.join(root, 'CUB_200_2011/image_class_labels.txt') 94 | id2image = self.list2dict(self.text_read(image_file)) 95 | id2train = self.list2dict(self.text_read(split_file)) # 1: train images; 0: test iamges 96 | id2class = self.list2dict(self.text_read(class_file)) 97 | train_idx = [] 98 | test_idx = [] 99 | for k in sorted(id2train.keys()): 100 | if id2train[k] == '1': 101 | train_idx.append(k) 102 | else: 103 | test_idx.append(k) 104 | 105 | self.data = [] 106 | self.targets = [] 107 | self.data2label = {} 108 | if self.train: 109 | for k in train_idx: 110 | image_path = os.path.join(root, 'CUB_200_2011/images', id2image[k]) 111 | self.data.append(image_path) 112 | self.targets.append(int(id2class[k]) - 1) 113 | self.data2label[image_path] = (int(id2class[k]) - 1) 114 | 115 | else: 116 | for k in test_idx: 117 | image_path = os.path.join(root, 'CUB_200_2011/images', id2image[k]) 118 | self.data.append(image_path) 119 | self.targets.append(int(id2class[k]) - 1) 120 | self.data2label[image_path] = (int(id2class[k]) - 1) 121 | 122 | def SelectfromTxt(self, data2label, index_path): 123 | index = open(index_path).read().splitlines() 124 | data_tmp = [] 125 | targets_tmp = [] 126 | for i in index: 127 | img_path = os.path.join(self.root, i) 128 | data_tmp.append(img_path) 129 | targets_tmp.append(data2label[img_path]) 130 | 131 | return data_tmp, targets_tmp 132 | 133 | def SelectfromClasses(self, data, targets, index): 134 | data_tmp = [] 135 | targets_tmp = [] 136 | for i in index: 137 | ind_cl = np.where(i == targets)[0] 138 | for j in ind_cl: 139 | data_tmp.append(data[j]) 140 | targets_tmp.append(targets[j]) 141 | 142 | return data_tmp, targets_tmp 143 | 144 | def __len__(self): 145 | return len(self.data) 146 | 147 | def __getitem__(self, i): 148 | path, targets = self.data[i], self.targets[i] 149 | image = self.transform(Image.open(path).convert('RGB')) 150 | return image, targets 151 | 152 | 153 | class 
CUB200_concate(Dataset): 154 | def __init__(self, train,x1,y1,x2,y2): 155 | 156 | if train: 157 | self.transform = transforms.Compose([ 158 | transforms.Resize(256), 159 | # transforms.CenterCrop(224), 160 | transforms.RandomResizedCrop(224), 161 | transforms.RandomHorizontalFlip(), 162 | transforms.ToTensor(), 163 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 164 | ]) 165 | else: 166 | self.transform = transforms.Compose([ 167 | transforms.Resize(256), 168 | transforms.CenterCrop(224), 169 | transforms.ToTensor(), 170 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 171 | ]) 172 | self.data=x1+x2 173 | self.targets=y1+y2 174 | print(len(self.data),len(self.targets)) 175 | 176 | 177 | 178 | def __len__(self): 179 | return len(self.data) 180 | 181 | def __getitem__(self, i): 182 | path, targets = self.data[i], self.targets[i] 183 | image = self.transform(Image.open(path).convert('RGB')) 184 | return image, targets 185 | 186 | if __name__ == '__main__': 187 | txt_path = "../../data/index_list/cub200/session_1.txt" 188 | # class_index = open(txt_path).read().splitlines() 189 | base_class = 100 190 | class_index = np.arange(base_class) 191 | dataroot = '~/dataloader/data' 192 | batch_size_base = 400 193 | trainset = CUB200(root=dataroot, train=False, index=class_index, 194 | base_sess=True) 195 | cls = np.unique(trainset.targets) 196 | trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size_base, shuffle=True, num_workers=8, 197 | pin_memory=True) 198 | 199 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pprint as pprint 4 | import random 5 | import shutil 6 | import time 7 | from collections import OrderedDict, defaultdict 8 | from pathlib import Path 9 | 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 
| import numpy as np 13 | import torch 14 | from sklearn.metrics import confusion_matrix 15 | import logging 16 | from logging.config import dictConfig 17 | from dataloader.data_utils import * 18 | 19 | _utils_pp = pprint.PrettyPrinter() 20 | 21 | def set_logging(level, work_dir): 22 | LOGGING = { 23 | "version": 1, 24 | "disable_existing_loggers": False, 25 | "formatters": { 26 | "simple": { 27 | "format": f"%(message)s" 28 | }, 29 | }, 30 | "handlers": { 31 | "console": { 32 | "level": f"{level}", 33 | "class": "logging.StreamHandler", 34 | 'formatter': 'simple', 35 | }, 36 | 'file': { 37 | 'level': f"{level}", 38 | 'formatter': 'simple', 39 | 'class': 'logging.FileHandler', 40 | 'filename': f'{work_dir if work_dir is not None else "."}/train.log', 41 | 'mode': 'a', 42 | }, 43 | }, 44 | "loggers": { 45 | "": { 46 | "level": f"{level}", 47 | "handlers": ["console", "file"] if work_dir is not None else ["console"], 48 | }, 49 | }, 50 | } 51 | dictConfig(LOGGING) 52 | logging.info(f"Log level set to: {level}") 53 | 54 | def pprint(x): 55 | _utils_pp.pprint(x) 56 | class ConfigEncoder(json.JSONEncoder): 57 | def default(self, o): 58 | if isinstance(o, type): 59 | return {'$class': o.__module__ + "." + o.__name__} 60 | elif isinstance(o, Enum): 61 | return { 62 | '$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name 63 | } 64 | elif callable(o): 65 | return { 66 | '$function': o.__module__ + "." 
+ o.__name__ 67 | } 68 | return json.JSONEncoder.default(self, o) 69 | class Logger(object): 70 | def __init__(self, args, log_dir, **kwargs): 71 | self.logger_path = os.path.join(log_dir, 'scalars.json') 72 | # self.tb_logger = SummaryWriter( 73 | # logdir=osp.join(log_dir, 'tflogger'), 74 | # **kwargs, 75 | # ) 76 | self.log_config(vars(args)) 77 | 78 | self.scalars = defaultdict(OrderedDict) 79 | 80 | # def add_scalar(self, key, value, counter): 81 | def add_scalar(self, key, value, counter): 82 | assert self.scalars[key].get(counter, None) is None, 'counter should be distinct' 83 | self.scalars[key][counter] = value 84 | # self.tb_logger.add_scalar(key, value, counter) 85 | 86 | def log_config(self, variant_data): 87 | config_filepath = os.path.join(os.path.dirname(self.logger_path), 'configs.json') 88 | with open(config_filepath, "w") as fd: 89 | json.dump(variant_data, fd, indent=2, sort_keys=True, cls=ConfigEncoder) 90 | 91 | def dump(self): 92 | with open(self.logger_path, 'w') as fd: 93 | json.dump(self.scalars, fd, indent=2) 94 | 95 | def set_seed(seed): 96 | if seed == 0: 97 | logging.info(' random seed') 98 | torch.backends.cudnn.benchmark = True 99 | else: 100 | logging.info('manual seed:', seed) 101 | random.seed(seed) 102 | np.random.seed(seed) 103 | torch.manual_seed(seed) 104 | torch.cuda.manual_seed_all(seed) 105 | torch.backends.cudnn.deterministic = True 106 | torch.backends.cudnn.benchmark = False 107 | 108 | def set_gpu(args): 109 | gpu_list = [int(x) for x in args.gpu.split(',')] 110 | logging.info('use gpu:', gpu_list) 111 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 112 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 113 | return gpu_list.__len__() 114 | 115 | def ensure_path(path): 116 | if os.path.exists(path): 117 | pass 118 | else: 119 | logging.info('create folder:', path) 120 | os.makedirs(path) 121 | 122 | class Averager(): 123 | 124 | def __init__(self): 125 | self.n = 0 126 | self.v = 0 127 | 128 | def add(self, x): 129 | 
self.v = (self.v * self.n + x) / (self.n + 1) 130 | self.n += 1 131 | 132 | def item(self): 133 | return self.v 134 | 135 | 136 | class Timer(): 137 | 138 | def __init__(self): 139 | self.o = time.time() 140 | 141 | def measure(self, p=1): 142 | x = (time.time() - self.o) / p 143 | x = int(x) 144 | if x >= 3600: 145 | return '{:.1f}h'.format(x / 3600) 146 | if x >= 60: 147 | return '{}m'.format(round(x / 60)) 148 | return '{}s'.format(x) 149 | 150 | 151 | def count_acc(logits, label): 152 | pred = torch.argmax(logits, dim=1) 153 | if torch.cuda.is_available(): 154 | return (pred == label).type(torch.cuda.FloatTensor).mean().item() 155 | else: 156 | return (pred == label).type(torch.FloatTensor).mean().item() 157 | 158 | def count_acc_topk(x,y,k=5): 159 | _,maxk = torch.topk(x,k,dim=-1) 160 | total = y.size(0) 161 | test_labels = y.view(-1,1) 162 | #top1=(test_labels == maxk[:,0:1]).sum().item() 163 | topk=(test_labels == maxk).sum().item() 164 | return float(topk/total) 165 | 166 | def count_acc_taskIL(logits, label,args): 167 | basenum=args.base_class 168 | incrementnum=(args.num_classes-args.base_class)/args.way 169 | for i in range(len(label)): 170 | currentlabel=label[i] 171 | if currentlabel>> {params_info}-Avg_acc:{trlog['max_acc']} \n Seen_acc:{trlog['seen_acc']} \n Unseen_acc:{trlog['unseen_acc']} \n HMean_acc:{hmeans} \n") 231 | 232 | def harm_mean(seen, unseen): 233 | # compute from session1 234 | assert len(seen) == len(unseen) 235 | harm_means = [] 236 | for _seen, _unseen in zip(seen, unseen): 237 | _hmean = (2 * _seen * _unseen) / (_seen + _unseen + 1e-12) 238 | _hmean = float('%.3f' % (_hmean)) 239 | harm_means.append(_hmean) 240 | return harm_means 241 | 242 | def get_optimizer(args, model, **kwargs): 243 | # prepare optimizer 244 | if args.project in ['teen']: 245 | if args.optim == 'sgd': 246 | optimizer = torch.optim.SGD(model.parameters(), args.lr_base, 247 | momentum=args.momentum, nesterov=True, 248 | weight_decay=args.decay) 249 | elif 
args.optim == 'adam': 250 | optimizer = torch.optim.Adam(model.parameters(), 251 | lr=args.lr_base, weight_decay=args.decay) 252 | 253 | 254 | # prepare scheduler 255 | if args.schedule == 'Step': 256 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step, 257 | gamma=args.gamma) 258 | elif args.schedule == 'Milestone': 259 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, 260 | gamma=args.gamma) 261 | elif args.schedule == 'Cosine': 262 | assert args.tmax >= 0 , "args.tmax should be greater than 0" 263 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.tmax) 264 | return optimizer, scheduler 265 | 266 | -------------------------------------------------------------------------------- /dataloader/cub200/autoaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copy from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py 3 | """ 4 | 5 | from PIL import Image, ImageEnhance, ImageOps 6 | import numpy as np 7 | import random 8 | 9 | __all__ = ['AutoAugImageNetPolicy', 'AutoAugCIFAR10Policy', 'AutoAugSVHNPolicy'] 10 | 11 | 12 | class AutoAugImageNetPolicy(object): 13 | def __init__(self, fillcolor=(128, 128, 128)): 14 | self.policies = [ 15 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 16 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 17 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 18 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 19 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 20 | 21 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 22 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 23 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 24 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 25 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 26 | 27 | 
SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 28 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 29 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 30 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 31 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 32 | 33 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 34 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 35 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 36 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 37 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 38 | 39 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 40 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 41 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 42 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor) 43 | ] 44 | 45 | def __call__(self, img): 46 | policy_idx = random.randint(0, len(self.policies) - 1) 47 | return self.policies[policy_idx](img) 48 | 49 | def __repr__(self): 50 | return "AutoAugment ImageNet Policy" 51 | 52 | 53 | class AutoAugCIFAR10Policy(object): 54 | def __init__(self, fillcolor=(128, 128, 128)): 55 | self.policies = [ 56 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 57 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 58 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 59 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 60 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 61 | 62 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 63 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 64 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 65 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 66 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 67 | 68 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, 
fillcolor), 69 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 70 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 71 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 72 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 73 | 74 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 75 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), 76 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 77 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 78 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 79 | 80 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 81 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 82 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 83 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 84 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 85 | ] 86 | 87 | def __call__(self, img): 88 | policy_idx = random.randint(0, len(self.policies) - 1) 89 | return self.policies[policy_idx](img) 90 | 91 | def __repr__(self): 92 | return "AutoAugment CIFAR10 Policy" 93 | 94 | 95 | class AutoAugSVHNPolicy(object): 96 | def __init__(self, fillcolor=(128, 128, 128)): 97 | self.policies = [ 98 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 99 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 100 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 101 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 102 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 103 | 104 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 105 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 106 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 107 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 108 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 109 | 110 
| SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 111 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 112 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 113 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 114 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 115 | 116 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 117 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 118 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 119 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 120 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 121 | 122 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 123 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 124 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 125 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 126 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 127 | ] 128 | 129 | def __call__(self, img): 130 | policy_idx = random.randint(0, len(self.policies) - 1) 131 | return self.policies[policy_idx](img) 132 | 133 | def __repr__(self): 134 | return "AutoAugment SVHN Policy" 135 | 136 | 137 | class SubPolicy(object): 138 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 139 | ranges = { 140 | "shearX": np.linspace(0, 0.3, 10), 141 | "shearY": np.linspace(0, 0.3, 10), 142 | "translateX": np.linspace(0, 150 / 331, 10), 143 | "translateY": np.linspace(0, 150 / 331, 10), 144 | "rotate": np.linspace(0, 30, 10), 145 | "color": np.linspace(0.0, 0.9, 10), 146 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 147 | "solarize": np.linspace(256, 0, 10), 148 | "contrast": np.linspace(0.0, 0.9, 10), 149 | "sharpness": np.linspace(0.0, 0.9, 10), 150 | "brightness": np.linspace(0.0, 0.9, 10), 151 | "autocontrast": [0] * 10, 152 | "equalize": [0] * 10, 153 | "invert": [0] 
* 10 154 | } 155 | 156 | def rotate_with_fill(img, magnitude): 157 | rot = img.convert("RGBA").rotate(magnitude) 158 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 159 | 160 | func = { 161 | "shearX": lambda img, magnitude: img.transform( 162 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 163 | Image.BICUBIC, fillcolor=fillcolor), 164 | "shearY": lambda img, magnitude: img.transform( 165 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 166 | Image.BICUBIC, fillcolor=fillcolor), 167 | "translateX": lambda img, magnitude: img.transform( 168 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 169 | fillcolor=fillcolor), 170 | "translateY": lambda img, magnitude: img.transform( 171 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 172 | fillcolor=fillcolor), 173 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 174 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), 175 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 176 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 177 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 178 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 179 | 1 + magnitude * random.choice([-1, 1])), 180 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 181 | 1 + magnitude * random.choice([-1, 1])), 182 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 183 | 1 + magnitude * random.choice([-1, 1])), 184 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 185 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 186 | "invert": lambda img, magnitude: ImageOps.invert(img) 187 | } 188 | 189 | # self.name = 
"{}_{:.2f}_and_{}_{:.2f}".format( 190 | # operation1, ranges[operation1][magnitude_idx1], 191 | # operation2, ranges[operation2][magnitude_idx2]) 192 | self.p1 = p1 193 | self.operation1 = func[operation1] 194 | self.magnitude1 = ranges[operation1][magnitude_idx1] 195 | self.p2 = p2 196 | self.operation2 = func[operation2] 197 | self.magnitude2 = ranges[operation2][magnitude_idx2] 198 | 199 | def __call__(self, img): 200 | if random.random() < self.p1: 201 | img = self.operation1(img, self.magnitude1) 202 | if random.random() < self.p2: 203 | img = self.operation2(img, self.magnitude2) 204 | return img -------------------------------------------------------------------------------- /dataloader/miniimagenet/autoaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copy from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py 3 | """ 4 | 5 | from PIL import Image, ImageEnhance, ImageOps 6 | import numpy as np 7 | import random 8 | 9 | __all__ = ['AutoAugImageNetPolicy', 'AutoAugCIFAR10Policy', 'AutoAugSVHNPolicy'] 10 | 11 | 12 | class AutoAugImageNetPolicy(object): 13 | def __init__(self, fillcolor=(128, 128, 128)): 14 | self.policies = [ 15 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 16 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 17 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 18 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 19 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 20 | 21 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 22 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 23 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 24 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 25 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 26 | 27 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 28 | SubPolicy(0.4, "rotate", 
9, 0.6, "equalize", 2, fillcolor), 29 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 30 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 31 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 32 | 33 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 34 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 35 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 36 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 37 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 38 | 39 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 40 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 41 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 42 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor) 43 | ] 44 | 45 | def __call__(self, img): 46 | policy_idx = random.randint(0, len(self.policies) - 1) 47 | return self.policies[policy_idx](img) 48 | 49 | def __repr__(self): 50 | return "AutoAugment ImageNet Policy" 51 | 52 | 53 | class AutoAugCIFAR10Policy(object): 54 | def __init__(self, fillcolor=(128, 128, 128)): 55 | self.policies = [ 56 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 57 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 58 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 59 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 60 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 61 | 62 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 63 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 64 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 65 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 66 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 67 | 68 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 69 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 70 | 
SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 71 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 72 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 73 | 74 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 75 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), 76 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 77 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 78 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 79 | 80 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 81 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 82 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 83 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 84 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 85 | ] 86 | 87 | def __call__(self, img): 88 | policy_idx = random.randint(0, len(self.policies) - 1) 89 | return self.policies[policy_idx](img) 90 | 91 | def __repr__(self): 92 | return "AutoAugment CIFAR10 Policy" 93 | 94 | 95 | class AutoAugSVHNPolicy(object): 96 | def __init__(self, fillcolor=(128, 128, 128)): 97 | self.policies = [ 98 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 99 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 100 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 101 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 102 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 103 | 104 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 105 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 106 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 107 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 108 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 109 | 110 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 111 | SubPolicy(0.8, 
"shearY", 8, 0.7, "invert", 4, fillcolor), 112 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 113 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 114 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 115 | 116 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 117 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 118 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 119 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 120 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 121 | 122 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 123 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 124 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 125 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 126 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 127 | ] 128 | 129 | def __call__(self, img): 130 | policy_idx = random.randint(0, len(self.policies) - 1) 131 | return self.policies[policy_idx](img) 132 | 133 | def __repr__(self): 134 | return "AutoAugment SVHN Policy" 135 | 136 | 137 | class SubPolicy(object): 138 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 139 | ranges = { 140 | "shearX": np.linspace(0, 0.3, 10), 141 | "shearY": np.linspace(0, 0.3, 10), 142 | "translateX": np.linspace(0, 150 / 331, 10), 143 | "translateY": np.linspace(0, 150 / 331, 10), 144 | "rotate": np.linspace(0, 30, 10), 145 | "color": np.linspace(0.0, 0.9, 10), 146 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 147 | "solarize": np.linspace(256, 0, 10), 148 | "contrast": np.linspace(0.0, 0.9, 10), 149 | "sharpness": np.linspace(0.0, 0.9, 10), 150 | "brightness": np.linspace(0.0, 0.9, 10), 151 | "autocontrast": [0] * 10, 152 | "equalize": [0] * 10, 153 | "invert": [0] * 10 154 | } 155 | 156 | def rotate_with_fill(img, magnitude): 157 | rot = 
img.convert("RGBA").rotate(magnitude) 158 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 159 | 160 | func = { 161 | "shearX": lambda img, magnitude: img.transform( 162 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 163 | Image.BICUBIC, fillcolor=fillcolor), 164 | "shearY": lambda img, magnitude: img.transform( 165 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 166 | Image.BICUBIC, fillcolor=fillcolor), 167 | "translateX": lambda img, magnitude: img.transform( 168 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 169 | fillcolor=fillcolor), 170 | "translateY": lambda img, magnitude: img.transform( 171 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 172 | fillcolor=fillcolor), 173 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 174 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), 175 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 176 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 177 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 178 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 179 | 1 + magnitude * random.choice([-1, 1])), 180 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 181 | 1 + magnitude * random.choice([-1, 1])), 182 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 183 | 1 + magnitude * random.choice([-1, 1])), 184 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 185 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 186 | "invert": lambda img, magnitude: ImageOps.invert(img) 187 | } 188 | 189 | # self.name = "{}_{:.2f}_and_{}_{:.2f}".format( 190 | # operation1, 
ranges[operation1][magnitude_idx1], 191 | # operation2, ranges[operation2][magnitude_idx2]) 192 | self.p1 = p1 193 | self.operation1 = func[operation1] 194 | self.magnitude1 = ranges[operation1][magnitude_idx1] 195 | self.p2 = p2 196 | self.operation2 = func[operation2] 197 | self.magnitude2 = ranges[operation2][magnitude_idx2] 198 | 199 | def __call__(self, img): 200 | if random.random() < self.p1: 201 | img = self.operation1(img, self.magnitude1) 202 | if random.random() < self.p2: 203 | img = self.operation2(img, self.magnitude2) 204 | return img -------------------------------------------------------------------------------- /dataloader/cifar100/cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | import pickle 7 | 8 | import torchvision.transforms as transforms 9 | 10 | from torchvision.datasets.vision import VisionDataset 11 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive 12 | from .autoaugment import CIFAR10Policy, Cutout 13 | 14 | class CIFAR10(VisionDataset): 15 | """`CIFAR10 `_ Dataset. 16 | 17 | Args: 18 | root (string): Root directory of dataset where directory 19 | ``cifar-10-batches-py`` exists or will be saved to if download is set to True. 20 | train (bool, optional): If True, creates dataset from training set, otherwise 21 | creates from test set. 22 | transform (callable, optional): A function/transform that takes in an PIL image 23 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 24 | target_transform (callable, optional): A function/transform that takes in the 25 | target and transforms it. 26 | download (bool, optional): If true, downloads the dataset from the internet and 27 | puts it in root directory. If dataset is already downloaded, it is not 28 | downloaded again. 
29 | 30 | """ 31 | base_folder = 'cifar-10-batches-py' 32 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 33 | filename = "cifar-10-python.tar.gz" 34 | tgz_md5 = 'c58f30108f718f92721af3b95e74349a' 35 | train_list = [ 36 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 37 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 38 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 39 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 40 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], 41 | ] 42 | 43 | test_list = [ 44 | ['test_batch', '40351d587109b95175f43aff81a1287e'], 45 | ] 46 | meta = { 47 | 'filename': 'batches.meta', 48 | 'key': 'label_names', 49 | 'md5': '5ff9c542aee3614f3951f8cda6e48888', 50 | } 51 | 52 | def __init__(self, root, train=True, transform=None, target_transform=None, 53 | download=False, index=None, base_sess=None, autoaug=True): 54 | 55 | super(CIFAR10, self).__init__(root, transform=transform, 56 | target_transform=target_transform) 57 | self.root = os.path.expanduser(root) 58 | self.train = train # training set or test set 59 | 60 | if download: 61 | self.download() 62 | 63 | if not self._check_integrity(): 64 | raise RuntimeError('Dataset not found or corrupted.' 
+ 65 | ' You can use download=True to download it') 66 | 67 | # if self.train: 68 | # downloaded_list = self.train_list 69 | # else: 70 | # downloaded_list = self.test_list 71 | 72 | if autoaug is False: 73 | if self.train: 74 | downloaded_list = self.train_list 75 | self.transform = transforms.Compose([ 76 | transforms.RandomCrop(32, padding=4), 77 | transforms.RandomHorizontalFlip(), 78 | transforms.ToTensor(), 79 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]) 80 | ]) 81 | else: 82 | downloaded_list = self.test_list 83 | self.transform = transforms.Compose([ 84 | transforms.ToTensor(), 85 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]) 86 | ]) 87 | else: 88 | if self.train: 89 | downloaded_list = self.train_list 90 | self.transform = transforms.Compose([ 91 | transforms.RandomCrop(32, padding=4), 92 | transforms.RandomHorizontalFlip(), 93 | CIFAR10Policy(), # add AutoAug 94 | transforms.ToTensor(), 95 | Cutout(n_holes=1, length=16), 96 | transforms.Normalize( 97 | (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 98 | ]) 99 | else: 100 | downloaded_list = self.test_list 101 | self.transform = transforms.Compose([ 102 | transforms.ToTensor(), 103 | transforms.Normalize( 104 | (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 105 | ]) 106 | 107 | self.data = [] 108 | self.targets = [] 109 | 110 | # now load the picked numpy arrays 111 | for file_name, checksum in downloaded_list: 112 | file_path = os.path.join(self.root, self.base_folder, file_name) 113 | with open(file_path, 'rb') as f: 114 | entry = pickle.load(f, encoding='latin1') 115 | self.data.append(entry['data']) 116 | if 'labels' in entry: 117 | self.targets.extend(entry['labels']) 118 | else: 119 | self.targets.extend(entry['fine_labels']) 120 | 121 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 122 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 123 | 124 | self.targets = np.asarray(self.targets) 125 | 126 | 
if base_sess: 127 | self.data, self.targets = self.SelectfromDefault(self.data, self.targets, index) 128 | else: # new Class session 129 | if train: 130 | self.data, self.targets = self.NewClassSelector(self.data, self.targets, index) 131 | else: 132 | self.data, self.targets = self.SelectfromDefault(self.data, self.targets, index) 133 | 134 | self._load_meta() 135 | 136 | def SelectfromDefault(self, data, targets, index): 137 | data_tmp = [] 138 | targets_tmp = [] 139 | for i in index: 140 | ind_cl = np.where(i == targets)[0] 141 | if data_tmp == []: 142 | data_tmp = data[ind_cl] 143 | targets_tmp = targets[ind_cl] 144 | else: 145 | data_tmp = np.vstack((data_tmp, data[ind_cl])) 146 | targets_tmp = np.hstack((targets_tmp, targets[ind_cl])) 147 | 148 | return data_tmp, targets_tmp 149 | 150 | def NewClassSelector(self, data, targets, index): 151 | data_tmp = [] 152 | targets_tmp = [] 153 | ind_list = [int(i) for i in index] 154 | ind_np = np.array(ind_list) 155 | index = ind_np.reshape((5,5)) 156 | for i in index: 157 | ind_cl = i 158 | if data_tmp == []: 159 | data_tmp = data[ind_cl] 160 | targets_tmp = targets[ind_cl] 161 | else: 162 | data_tmp = np.vstack((data_tmp, data[ind_cl])) 163 | targets_tmp = np.hstack((targets_tmp, targets[ind_cl])) 164 | 165 | return data_tmp, targets_tmp 166 | 167 | def _load_meta(self): 168 | path = os.path.join(self.root, self.base_folder, self.meta['filename']) 169 | if not check_integrity(path, self.meta['md5']): 170 | raise RuntimeError('Dataset metadata file not found or corrupted.' + 171 | ' You can use download=True to download it') 172 | with open(path, 'rb') as infile: 173 | data = pickle.load(infile, encoding='latin1') 174 | self.classes = data[self.meta['key']] 175 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} 176 | 177 | def __getitem__(self, index): 178 | """ 179 | Args: 180 | index (int): Index 181 | 182 | Returns: 183 | tuple: (image, target) where target is index of the target class. 
184 | """ 185 | img, target = self.data[index], self.targets[index] 186 | 187 | # doing this so that it is consistent with all other datasets 188 | # to return a PIL Image 189 | img = Image.fromarray(img) 190 | 191 | if self.transform is not None: 192 | img = self.transform(img) 193 | 194 | if self.target_transform is not None: 195 | target = self.target_transform(target) 196 | 197 | return img, target 198 | 199 | def __len__(self): 200 | return len(self.data) 201 | 202 | def _check_integrity(self): 203 | root = self.root 204 | for fentry in (self.train_list + self.test_list): 205 | filename, md5 = fentry[0], fentry[1] 206 | fpath = os.path.join(root, self.base_folder, filename) 207 | if not check_integrity(fpath, md5): 208 | return False 209 | return True 210 | 211 | def download(self): 212 | if self._check_integrity(): 213 | print('Files already downloaded and verified') 214 | return 215 | download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) 216 | 217 | def extra_repr(self): 218 | return "Split: {}".format("Train" if self.train is True else "Test") 219 | 220 | 221 | class CIFAR_concate(VisionDataset): 222 | def __init__(self, train,x1,y1,x2,y2): 223 | 224 | self.train=True 225 | if self.train: 226 | self.transform = transforms.Compose([ 227 | transforms.RandomCrop(32, padding=4), 228 | transforms.RandomHorizontalFlip(), 229 | transforms.ToTensor(), 230 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]) 231 | ]) 232 | else: 233 | self.transform = transforms.Compose([ 234 | transforms.ToTensor(), 235 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]) 236 | ]) 237 | 238 | 239 | self.data=np.vstack([x1,x2]) 240 | self.targets=np.hstack([y1,y2]) 241 | print(len(self.data),len(self.targets)) 242 | 243 | def __getitem__(self, index): 244 | 245 | img, target = self.data[index], self.targets[index] 246 | 247 | img = Image.fromarray(img) 248 | 249 | if self.transform is not None: 
250 | img = self.transform(img) 251 | 252 | return img, target 253 | 254 | def __len__(self): 255 | return len(self.data) 256 | 257 | class CIFAR100(CIFAR10): 258 | """`CIFAR100 `_ Dataset. 259 | 260 | This is a subclass of the `CIFAR10` Dataset. 261 | """ 262 | base_folder = 'cifar-100-python' 263 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 264 | filename = "cifar-100-python.tar.gz" 265 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 266 | train_list = [ 267 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 268 | ] 269 | 270 | test_list = [ 271 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 272 | ] 273 | meta = { 274 | 'filename': 'meta', 275 | 'key': 'fine_label_names', 276 | 'md5': '7973b15100ade9c7d40fb424638fde48', 277 | } 278 | 279 | if __name__ == "__main__": 280 | 281 | 282 | dataroot = '../../data/' 283 | batch_size_base = 128 284 | txt_path = "../../data/index_list/cifar100/session_2.txt" 285 | # class_index = open(txt_path).read().splitlines() 286 | class_index = np.arange(60) 287 | class_index_val = np.arange(60,76) 288 | class_index_test= np.arange(76,100) 289 | 290 | trainset = CIFAR100(root=dataroot, train=True, download=True, transform=None, index=class_index_test, 291 | base_sess=True) 292 | testset = CIFAR100(root=dataroot, train=False, download=False,index=class_index, base_sess=True) 293 | 294 | 295 | import pickle 296 | print(trainset.data.shape) 297 | print(trainset.targets.shape) 298 | cls = np.unique(trainset.targets) 299 | print(cls) 300 | data={'data':trainset.data,'labels':trainset.targets} 301 | with open('CIFAR100_test.pickle', 'wb') as handle: 302 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 303 | 304 | -------------------------------------------------------------------------------- /dataloader/cifar100/autoaugment.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageEnhance, ImageOps 2 | import numpy as np 3 | import random 4 | 
import torch


class Cutout(object):
    """Zero out ``n_holes`` square patches of a tensor image.

    The image is read as (C, H, W): height and width come from dims 1 and 2,
    and the 2-D mask is broadcast over the remaining dim with ``expand_as``.
    Each patch is centred at a uniformly random pixel and has side
    ``length`` (an even ``length`` gives exactly ``length`` pixels). Patches
    are clipped at the image border, so they may be smaller near the edges.

    Args:
        n_holes: number of patches to cut out.
        length: side length of each square patch, in pixels.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Return ``img`` multiplied by a 0/1 mask with the patches zeroed."""
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)

        for _ in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask

        return img


class ImageNetPolicy(object):
    """ Randomly apply one of the ImageNet AutoAugment sub-policies.

    The table holds 25 entries. The last entry repeats the third
    (equalize 8 / equalize 3), so that sub-policy is drawn twice as often.

    Example:
    >>> policy = ImageNetPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     ImageNetPolicy(),
    >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),

            SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
            SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
            SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
            SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),

            SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
            SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
            SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),

            SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
            SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
            SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
            SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
            SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),

            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
        ]

    def __call__(self, img):
        """Apply one sub-policy chosen uniformly at random from the table."""
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"


class CIFAR10Policy(object):
    """ Randomly apply one of the 25 CIFAR10 AutoAugment sub-policies.

    Example:
    >>> policy = CIFAR10Policy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     CIFAR10Policy(),
    >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),

            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),

            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),

            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),

            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
        ]

    def __call__(self, img):
        """Apply one sub-policy chosen uniformly at random from the table."""
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"


class SVHNPolicy(object):
    """ Randomly apply one of the 25 SVHN AutoAugment sub-policies.

    Example:
    >>> policy = SVHNPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     SVHNPolicy(),
    >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
            SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
            SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
            SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
            SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
            SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),

            SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
            SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),

            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
        ]

    def __call__(self, img):
        """Apply one sub-policy chosen uniformly at random from the table."""
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"


class SubPolicy(object):
    """Two-step AutoAugment sub-policy.

    ``operation1`` is applied with probability ``p1``. Then ``operation2`` is
    applied, independently, with probability ``p2``.

    Args:
        p1, p2: probability of applying operation1 / operation2.
        operation1, operation2: operation names; must be keys of the
            ``ranges``/``func`` tables built in ``__init__``.
        magnitude_idx1, magnitude_idx2: index (0-9) into the operation's
            10-step magnitude range.
        fillcolor: fill value for pixels exposed by shear/translate.
            Rotation always fills with gray 128.
    """

    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            # Fraction of the image width/height; scaled by img.size below.
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            # Bit depths must be integers. ``np.int`` was removed in NumPy 1.24,
            # so use the builtin ``int``; it was what ``np.int`` aliased.
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            # Magnitude-free operations.
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            # Rotate through RGBA so uncovered corners can be composited onto
            # gray instead of black. Unlike shear/translate, the angle sign is
            # never randomised and ``fillcolor`` is not used here.
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)

        # Geometric and enhance ops pick a random sign per call, so one
        # magnitude covers both directions.
        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        """Apply each operation independently with its own probability."""
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img
-------------------------------------------------------------------------------- /models/resnet18_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import errno 6 | import hashlib 7 | import os 8 | import warnings 9 | import re 10 | import shutil 11 | import sys 12 | import tempfile 13 | from tqdm import tqdm 14 | from urllib.request import urlopen 15 | from urllib.parse import urlparse # noqa: F401 16 | 17 | 18 | def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True): 19 | r"""Loads the Torch serialized object at the given URL. 20 | 21 | If the object is already present in `model_dir`, it's deserialized and 22 | returned. The filename part of the URL should follow the naming convention 23 | ``filename-.ext`` where ```` is the first eight or more 24 | digits of the SHA256 hash of the contents of the file. The hash is used to 25 | ensure unique names and to verify the contents of the file. 26 | 27 | The default value of `model_dir` is ``$TORCH_HOME/checkpoints`` where 28 | environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``. 29 | ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux 30 | filesytem layout, with a default value ``~/.cache`` if not set. 
31 | 32 | Args: 33 | url (string): URL of the object to download 34 | model_dir (string, optional): directory in which to save the object 35 | map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) 36 | progress (bool, optional): whether or not to display a progress bar to stderr 37 | 38 | Example: 39 | >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') 40 | 41 | """ 42 | # Issue warning to move data if old env is set 43 | if os.getenv('TORCH_MODEL_ZOO'): 44 | warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') 45 | 46 | if model_dir is None: 47 | torch_home = _get_torch_home() 48 | model_dir = os.path.join(torch_home, 'checkpoints') 49 | 50 | try: 51 | os.makedirs(model_dir) 52 | except OSError as e: 53 | if e.errno == errno.EEXIST: 54 | # Directory already exists, ignore. 55 | pass 56 | else: 57 | # Unexpected OSError, re-raise. 58 | raise 59 | 60 | parts = urlparse(url) 61 | filename = os.path.basename(parts.path) 62 | cached_file = os.path.join(model_dir, filename) 63 | if not os.path.exists(cached_file): 64 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 65 | hash_prefix = HASH_REGEX.search(filename).group(1) 66 | _download_url_to_file(url, cached_file, hash_prefix, progress=progress) 67 | return torch.load(cached_file, map_location=map_location) 68 | 69 | 70 | def _download_url_to_file(url, dst, hash_prefix, progress): 71 | file_size = None 72 | u = urlopen(url) 73 | meta = u.info() 74 | if hasattr(meta, 'getheaders'): 75 | content_length = meta.getheaders("Content-Length") 76 | else: 77 | content_length = meta.get_all("Content-Length") 78 | if content_length is not None and len(content_length) > 0: 79 | file_size = int(content_length[0]) 80 | 81 | # We deliberately save it in a temp file and move it after 82 | # download is complete. 
This prevents a local working checkpoint 83 | # being overriden by a broken download. 84 | dst_dir = os.path.dirname(dst) 85 | f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) 86 | 87 | try: 88 | if hash_prefix is not None: 89 | sha256 = hashlib.sha256() 90 | with tqdm(total=file_size, disable=not progress, 91 | unit='B', unit_scale=True, unit_divisor=1024) as pbar: 92 | while True: 93 | buffer = u.read(8192) 94 | if len(buffer) == 0: 95 | break 96 | f.write(buffer) 97 | if hash_prefix is not None: 98 | sha256.update(buffer) 99 | pbar.update(len(buffer)) 100 | 101 | f.close() 102 | if hash_prefix is not None: 103 | digest = sha256.hexdigest() 104 | if digest[:len(hash_prefix)] != hash_prefix: 105 | raise RuntimeError('invalid hash value (expected "{}", got "{}")' 106 | .format(hash_prefix, digest)) 107 | shutil.move(f.name, dst) 108 | finally: 109 | f.close() 110 | if os.path.exists(f.name): 111 | os.remove(f.name) 112 | 113 | 114 | ENV_TORCH_HOME = 'TORCH_HOME' 115 | ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' 116 | DEFAULT_CACHE_DIR = '~/.cache' 117 | HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') 118 | 119 | 120 | def _get_torch_home(): 121 | torch_home = os.path.expanduser( 122 | os.getenv(ENV_TORCH_HOME, 123 | os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'))) 124 | return torch_home 125 | 126 | 127 | 128 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 129 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 130 | 'wide_resnet50_2', 'wide_resnet101_2'] 131 | 132 | 133 | model_urls = { 134 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 135 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 136 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 137 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 138 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 139 | 'resnext50_32x4d': 
'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 140 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 141 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 142 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 143 | } 144 | 145 | 146 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 147 | """3x3 convolution with padding""" 148 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 149 | padding=dilation, groups=groups, bias=False, dilation=dilation) 150 | 151 | 152 | def conv1x1(in_planes, out_planes, stride=1): 153 | """1x1 convolution""" 154 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 155 | 156 | 157 | class BasicBlock(nn.Module): 158 | expansion = 1 159 | 160 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 161 | base_width=64, dilation=1, norm_layer=None): 162 | super(BasicBlock, self).__init__() 163 | if norm_layer is None: 164 | norm_layer = nn.BatchNorm2d 165 | if groups != 1 or base_width != 64: 166 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 167 | if dilation > 1: 168 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 169 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 170 | self.conv1 = conv3x3(inplanes, planes, stride) 171 | self.bn1 = norm_layer(planes) 172 | self.relu = nn.ReLU(inplace=True) 173 | self.conv2 = conv3x3(planes, planes) 174 | self.bn2 = norm_layer(planes) 175 | self.downsample = downsample 176 | self.stride = stride 177 | 178 | def forward(self, x): 179 | identity = x 180 | 181 | out = self.conv1(x) 182 | out = self.bn1(out) 183 | out = self.relu(out) 184 | 185 | out = self.conv2(out) 186 | out = self.bn2(out) 187 | 188 | if self.downsample is not None: 189 | identity = self.downsample(x) 190 | 191 | out += 
identity 192 | out = self.relu(out) 193 | 194 | return out 195 | 196 | 197 | class Bottleneck(nn.Module): 198 | expansion = 4 199 | 200 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 201 | base_width=64, dilation=1, norm_layer=None): 202 | super(Bottleneck, self).__init__() 203 | if norm_layer is None: 204 | norm_layer = nn.BatchNorm2d 205 | width = int(planes * (base_width / 64.)) * groups 206 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 207 | self.conv1 = conv1x1(inplanes, width) 208 | self.bn1 = norm_layer(width) 209 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 210 | self.bn2 = norm_layer(width) 211 | self.conv3 = conv1x1(width, planes * self.expansion) 212 | self.bn3 = norm_layer(planes * self.expansion) 213 | self.relu = nn.ReLU(inplace=True) 214 | self.downsample = downsample 215 | self.stride = stride 216 | 217 | def forward(self, x): 218 | identity = x 219 | 220 | out = self.conv1(x) 221 | out = self.bn1(out) 222 | out = self.relu(out) 223 | 224 | out = self.conv2(out) 225 | out = self.bn2(out) 226 | out = self.relu(out) 227 | 228 | out = self.conv3(out) 229 | out = self.bn3(out) 230 | 231 | if self.downsample is not None: 232 | identity = self.downsample(x) 233 | 234 | out += identity 235 | out = self.relu(out) 236 | 237 | return out 238 | 239 | 240 | class ResNet(nn.Module): 241 | 242 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 243 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 244 | norm_layer=None): 245 | super(ResNet, self).__init__() 246 | if norm_layer is None: 247 | norm_layer = nn.BatchNorm2d 248 | self._norm_layer = norm_layer 249 | 250 | self.inplanes = 64 251 | self.dilation = 1 252 | if replace_stride_with_dilation is None: 253 | # each element in the tuple indicates if we should replace 254 | # the 2x2 stride with a dilated convolution instead 255 | replace_stride_with_dilation = [False, False, False] 
256 | if len(replace_stride_with_dilation) != 3: 257 | raise ValueError("replace_stride_with_dilation should be None " 258 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 259 | self.groups = groups 260 | self.base_width = width_per_group 261 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 262 | bias=False) 263 | self.bn1 = norm_layer(self.inplanes) 264 | self.relu = nn.ReLU(inplace=True) 265 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 266 | self.layer1 = self._make_layer(block, 64, layers[0]) 267 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 268 | dilate=replace_stride_with_dilation[0]) 269 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 270 | dilate=replace_stride_with_dilation[1]) 271 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 272 | dilate=replace_stride_with_dilation[2]) 273 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 274 | # self.fc = nn.Linear(512 * block.expansion, num_classes,bias=False) 275 | 276 | for m in self.modules(): 277 | if isinstance(m, nn.Conv2d): 278 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 279 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 280 | nn.init.constant_(m.weight, 1) 281 | nn.init.constant_(m.bias, 0) 282 | 283 | # Zero-initialize the last BN in each residual branch, 284 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
285 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 286 | if zero_init_residual: 287 | for m in self.modules(): 288 | if isinstance(m, Bottleneck): 289 | nn.init.constant_(m.bn3.weight, 0) 290 | elif isinstance(m, BasicBlock): 291 | nn.init.constant_(m.bn2.weight, 0) 292 | 293 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 294 | norm_layer = self._norm_layer 295 | downsample = None 296 | previous_dilation = self.dilation 297 | if dilate: 298 | self.dilation *= stride 299 | stride = 1 300 | if stride != 1 or self.inplanes != planes * block.expansion: 301 | downsample = nn.Sequential( 302 | conv1x1(self.inplanes, planes * block.expansion, stride), 303 | norm_layer(planes * block.expansion), 304 | ) 305 | 306 | layers = [] 307 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 308 | self.base_width, previous_dilation, norm_layer)) 309 | self.inplanes = planes * block.expansion 310 | for _ in range(1, blocks): 311 | layers.append(block(self.inplanes, planes, groups=self.groups, 312 | base_width=self.base_width, dilation=self.dilation, 313 | norm_layer=norm_layer)) 314 | 315 | return nn.Sequential(*layers) 316 | 317 | def forward(self, x): 318 | x = self.conv1(x) 319 | x = self.bn1(x) 320 | x = self.relu(x) 321 | x = self.maxpool(x) 322 | 323 | x = self.layer1(x) 324 | x = self.layer2(x) 325 | x = self.layer3(x) 326 | x = self.layer4(x) 327 | 328 | # x = self.avgpool(x) 329 | # x = torch.flatten(x, 1) 330 | # x = self.fc(x) 331 | # x = F.linear(F.normalize(x, p=2, dim=-1), F.normalize(self.fc.weight, p=2, dim=-1)) 332 | # x = temperature * x 333 | 334 | return x 335 | 336 | 337 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 338 | model = ResNet(block, layers, **kwargs) 339 | if pretrained: 340 | model_dict = model.state_dict() 341 | state_dict = load_state_dict_from_url(model_urls[arch], 342 | progress=progress) 343 | state_dict = {k: v for k, v in 
state_dict.items() if k not in ['fc.weight', 'fc.bias']} 344 | model_dict.update(state_dict) 345 | model.load_state_dict(model_dict) 346 | return model 347 | 348 | 349 | def resnet18(pretrained=False, progress=True, **kwargs): 350 | r"""ResNet-18 model from 351 | `"Deep Residual Learning for Image Recognition" `_ 352 | 353 | Args: 354 | pretrained (bool): If True, returns a model pre-trained on ImageNet 355 | progress (bool): If True, displays a progress bar of the download to stderr 356 | """ 357 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 358 | **kwargs) 359 | 360 | 361 | def resnet34(pretrained=False, progress=True, **kwargs): 362 | r"""ResNet-34 model from 363 | `"Deep Residual Learning for Image Recognition" `_ 364 | 365 | Args: 366 | pretrained (bool): If True, returns a model pre-trained on ImageNet 367 | progress (bool): If True, displays a progress bar of the download to stderr 368 | """ 369 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 370 | **kwargs) 371 | 372 | 373 | def resnet50(pretrained=False, progress=True, **kwargs): 374 | r"""ResNet-50 model from 375 | `"Deep Residual Learning for Image Recognition" `_ 376 | 377 | Args: 378 | pretrained (bool): If True, returns a model pre-trained on ImageNet 379 | progress (bool): If True, displays a progress bar of the download to stderr 380 | """ 381 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 382 | **kwargs) 383 | 384 | 385 | def resnet101(pretrained=False, progress=True, **kwargs): 386 | r"""ResNet-101 model from 387 | `"Deep Residual Learning for Image Recognition" `_ 388 | 389 | Args: 390 | pretrained (bool): If True, returns a model pre-trained on ImageNet 391 | progress (bool): If True, displays a progress bar of the download to stderr 392 | """ 393 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 394 | **kwargs) 395 | 396 | 397 | def resnet152(pretrained=False, progress=True, 
**kwargs): 398 | r"""ResNet-152 model from 399 | `"Deep Residual Learning for Image Recognition" `_ 400 | 401 | Args: 402 | pretrained (bool): If True, returns a model pre-trained on ImageNet 403 | progress (bool): If True, displays a progress bar of the download to stderr 404 | """ 405 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 406 | **kwargs) 407 | 408 | 409 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 410 | r"""ResNeXt-50 32x4d model from 411 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 412 | 413 | Args: 414 | pretrained (bool): If True, returns a model pre-trained on ImageNet 415 | progress (bool): If True, displays a progress bar of the download to stderr 416 | """ 417 | kwargs['groups'] = 32 418 | kwargs['width_per_group'] = 4 419 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 420 | pretrained, progress, **kwargs) 421 | 422 | 423 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 424 | r"""ResNeXt-101 32x8d model from 425 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 426 | 427 | Args: 428 | pretrained (bool): If True, returns a model pre-trained on ImageNet 429 | progress (bool): If True, displays a progress bar of the download to stderr 430 | """ 431 | kwargs['groups'] = 32 432 | kwargs['width_per_group'] = 8 433 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 434 | pretrained, progress, **kwargs) 435 | 436 | 437 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 438 | r"""Wide ResNet-50-2 model from 439 | `"Wide Residual Networks" `_ 440 | 441 | The model is the same as ResNet except for the bottleneck number of channels 442 | which is twice larger in every block. The number of channels in outer 1x1 443 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 444 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
445 | 446 | Args: 447 | pretrained (bool): If True, returns a model pre-trained on ImageNet 448 | progress (bool): If True, displays a progress bar of the download to stderr 449 | """ 450 | kwargs['width_per_group'] = 64 * 2 451 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 452 | pretrained, progress, **kwargs) 453 | 454 | 455 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 456 | r"""Wide ResNet-101-2 model from 457 | `"Wide Residual Networks" `_ 458 | 459 | The model is the same as ResNet except for the bottleneck number of channels 460 | which is twice larger in every block. The number of channels in outer 1x1 461 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 462 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 463 | 464 | Args: 465 | pretrained (bool): If True, returns a model pre-trained on ImageNet 466 | progress (bool): If True, displays a progress bar of the download to stderr 467 | """ 468 | kwargs['width_per_group'] = 64 * 2 469 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 470 | pretrained, progress, **kwargs) 471 | --------------------------------------------------------------------------------