├── imgs
│   └── results.png
├── data
│   └── index_list
│       ├── cifar100
│       │   ├── session_2.txt
│       │   ├── session_3.txt
│       │   ├── session_4.txt
│       │   ├── session_5.txt
│       │   ├── session_6.txt
│       │   ├── session_7.txt
│       │   ├── session_8.txt
│       │   └── session_9.txt
│       ├── README.md
│       ├── mini_imagenet
│       │   ├── session_2.txt
│       │   ├── session_3.txt
│       │   ├── session_4.txt
│       │   ├── session_5.txt
│       │   ├── session_6.txt
│       │   ├── session_7.txt
│       │   ├── session_8.txt
│       │   └── session_9.txt
│       └── cub200
│           ├── session_2.txt
│           ├── session_3.txt
│           ├── session_4.txt
│           ├── session_5.txt
│           ├── session_6.txt
│           ├── session_7.txt
│           ├── session_8.txt
│           ├── session_9.txt
│           ├── session_10.txt
│           └── session_11.txt
├── scripts
│   ├── mini_imagenet.sh
│   ├── cifar.sh
│   └── cub.sh
├── base.py
├── postprocess_path.py
├── README.md
├── models
│   ├── resnet20_cifar.py
│   ├── teen
│   │   ├── Network.py
│   │   ├── helper.py
│   │   └── fscil_trainer.py
│   └── resnet18_encoder.py
├── train.py
├── dataloader
│   ├── sampler.py
│   ├── data_utils.py
│   ├── miniimagenet
│   │   ├── miniimagenet.py
│   │   └── autoaugment.py
│   ├── cub200
│   │   ├── cub200.py
│   │   └── autoaugment.py
│   └── cifar100
│       ├── cifar.py
│       └── autoaugment.py
└── utils.py
/imgs/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangkiw/TEEN/HEAD/imgs/results.png
--------------------------------------------------------------------------------
/data/index_list/cifar100/session_2.txt:
--------------------------------------------------------------------------------
1 | 29774
2 | 33344
3 | 4815
4 | 6772
5 | 48317
6 | 29918
7 | 33262
8 | 5138
9 | 7342
10 | 47874
11 | 28864
12 | 32471
13 | 4316
14 | 6436
15 | 47498
16 | 29802
17 | 33159
18 | 3730
19 | 5093
20 | 47740
21 | 30548
22 | 34549
23 | 2845
24 | 4996
25 | 47866
26 |
--------------------------------------------------------------------------------
/data/index_list/cifar100/session_3.txt:
--------------------------------------------------------------------------------
1 | 28855
2 | 32834
3 | 4603
4 | 6914
5 | 48126
6 | 29932
7 | 33300
8 | 3860
9 | 5424
10 | 47055
11 | 29434
12 | 32604
13 | 4609
14 | 6380
15 | 47844
16 | 30456
17 | 34217
18 | 4361
19 | 6550
20 | 46896
21 | 29664
22 | 32857
23 | 4923
24 | 7502
25 | 47270
26 |
--------------------------------------------------------------------------------
/data/index_list/cifar100/session_4.txt:
--------------------------------------------------------------------------------
1 | 31267
2 | 34427
3 | 4799
4 | 6611
5 | 47404
6 | 28509
7 | 31687
8 | 3477
9 | 5563
10 | 48003
11 | 29545
12 | 33412
13 | 5114
14 | 6808
15 | 47692
16 | 29209
17 | 33265
18 | 4131
19 | 6401
20 | 48102
21 | 31290
22 | 34432
23 | 6060
24 | 8451
25 | 48279
26 |
--------------------------------------------------------------------------------
/data/index_list/cifar100/session_5.txt:
--------------------------------------------------------------------------------
1 | 32337
2 | 35646
3 | 6022
4 | 9048
5 | 48584
6 | 30768
7 | 34394
8 | 5091
9 | 6510
10 | 48023
11 | 30310
12 | 33230
13 | 5098
14 | 6671
15 | 48349
16 | 29690
17 | 33490
18 | 4260
19 | 5916
20 | 47371
21 | 31173
22 | 34943
23 | 4517
24 | 6494
25 | 47689
26 |
--------------------------------------------------------------------------------
/data/index_list/cifar100/session_6.txt:
--------------------------------------------------------------------------------
1 | 30281
2 | 33894
3 | 3768
4 | 6113
5 | 48095
6 | 28913
7 | 32821
8 | 6172
9 | 8276
10 | 48004
11 | 31249
12 | 34088
13 | 5257
14 | 6961
15 | 47534
16 | 30404
17 | 34101
18 | 4985
19 | 6899
20 | 48115
21 | 31823
22 | 35148
23 | 3922
24 | 6548
25 | 48127
26 |
--------------------------------------------------------------------------------
/data/index_list/cifar100/session_7.txt:
--------------------------------------------------------------------------------
1 | 30815
2 | 34450
3 | 3481
4 | 5089
5 | 47913
6 | 31683
7 | 34591
8 | 5251
9 | 7608
10 | 47984
11 | 29837
12 | 33823
13 | 4615
14 | 6448
15 | 47752
16 | 31222
17 | 34079
18 | 5686
19 | 7919
20 | 48675
21 | 28567
22 | 32964
23 | 5009
24 | 6201
25 | 47039
26 |
--------------------------------------------------------------------------------
/data/index_list/cifar100/session_8.txt:
--------------------------------------------------------------------------------
1 | 29355
2 | 33909
3 | 3982
4 | 5389
5 | 47166
6 | 31058
7 | 35180
8 | 5177
9 | 6890
10 | 48032
11 | 31176
12 | 35098
13 | 5235
14 | 7861
15 | 47830
16 | 30874
17 | 34639
18 | 5266
19 | 7489
20 | 47323
21 | 29960
22 | 34050
23 | 4988
24 | 7434
25 | 48208
26 |
--------------------------------------------------------------------------------
/data/index_list/cifar100/session_9.txt:
--------------------------------------------------------------------------------
1 | 30463
2 | 34580
3 | 5230
4 | 6813
5 | 48605
6 | 31702
7 | 35249
8 | 5854
9 | 7765
10 | 48444
11 | 30380
12 | 34028
13 | 5211
14 | 7433
15 | 47988
16 | 31348
17 | 34021
18 | 4929
19 | 7033
20 | 47904
21 | 30627
22 | 33728
23 | 4895
24 | 6299
25 | 47507
26 |
--------------------------------------------------------------------------------
/scripts/mini_imagenet.sh:
--------------------------------------------------------------------------------
1 | python train.py teen \
2 | -project teen \
3 | -dataset mini_imagenet \
4 | -dataroot MINI_DATA_DIR \
5 | -base_mode 'ft_cos' \
6 | -new_mode 'avg_cos' \
7 | -gamma 0.1 \
8 | -lr_base 0.1 \
9 | -decay 0.0005 \
10 | -epochs_base 1000 \
11 | -schedule Cosine \
12 | -tmax 1000 \
13 | -gpu '2' \
14 | -temperature 32 \
15 | -batch_size_base 128
16 |
--------------------------------------------------------------------------------
/scripts/cifar.sh:
--------------------------------------------------------------------------------
1 | python train.py teen \
2 | -project teen \
3 | -dataset cifar100 \
4 | -dataroot CIFAR_DATA_DIR \
5 | -base_mode 'ft_cos' \
6 | -new_mode 'avg_cos' \
7 | -lr_base 0.1 \
8 | -decay 0.0005 \
9 | -epochs_base 600 \
10 | -batch_size_base 256 \
11 | -schedule Cosine \
12 | -tmax 600 \
13 | -gpu '2' \
14 | -temperature 16 \
15 | -softmax_t 16 \
16 | -shift_weight 0.1
17 |
--------------------------------------------------------------------------------
/scripts/cub.sh:
--------------------------------------------------------------------------------
1 | python train.py teen \
2 | -project teen \
3 | -dataset cub200 \
4 | -dataroot CUB_DATA_DIR \
5 | -base_mode 'ft_cos' \
6 | -new_mode 'avg_cos' \
7 | -gamma 0.25 \
8 | -lr_base 0.004 \
9 | -lr_new 0.1 \
10 | -decay 0.0005 \
11 | -epochs_base 400 \
12 | -schedule Milestone \
13 | -milestones 50 100 150 200 250 300 \
14 | -gpu '3' \
15 | -temperature 32 \
16 | -batch_size_base 128 \
17 | -softmax_t 16 \
18 | -shift_weight 0.5
19 |
--------------------------------------------------------------------------------
/data/index_list/README.md:
--------------------------------------------------------------------------------
1 | ### How to use the index files for the experiments?
2 |
3 | The index files are named like "session_x.txt", where x indicates the session number. Each index file stores the indexes (or relative paths) of the images selected for that session.
4 | "session_1.txt" stores all the base class training images. Each "session_t.txt" (t>1) stores the few-shot new class training images of that session: 5 shots per new class, i.e., 25 images for the 5-way sessions of CIFAR100 and miniImageNet, and 50 images for the 10-way sessions of CUB200.
5 | You may adopt the following steps to perform the experiments.
6 |
7 | First, at session 1, train a base model using the images in session_1.txt;
8 |
9 | Then, at session t (t>1), finetune the model trained at the previous session (t-1), only using the images in session_t.txt.
10 |
11 | For evaluating the model at session t, first join all the encountered test sets into a single test set. Then test the current model on all of these test images and compute the recognition accuracy.
12 |
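As a quick reference, the sketch below shows one minimal, illustrative way (not the repo's actual dataloader; the function name is assumed) to read a session file: the CIFAR100 files list one integer training-set index per line, while the miniImageNet and CUB200 files list one image path per line relative to the dataset root.

```python
# Illustrative sketch only; the real loading logic lives in the `dataloader` package.
from pathlib import Path

def load_session_index(txt_path):
    """Return the entries of a session_t.txt file.

    CIFAR100 files contain integer indexes into the training set;
    miniImageNet / CUB200 files contain relative image paths.
    """
    lines = [line.strip() for line in Path(txt_path).read_text().splitlines() if line.strip()]
    if all(line.isdigit() for line in lines):   # CIFAR100-style integer index list
        return [int(line) for line in lines]
    return lines                                # path list for miniImageNet / CUB200

# e.g. 25 entries (5 classes x 5 shots) for a CIFAR100 incremental session:
# indexes = load_session_index('data/index_list/cifar100/session_2.txt')
```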
--------------------------------------------------------------------------------
/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from dataloader.data_utils import *
3 | from utils import (
4 | Averager, Timer
5 | )
6 |
7 | class Trainer(object, metaclass=abc.ABCMeta):
8 | def __init__(self, args):
9 | self.args = args
10 | self.args = set_up_datasets(self.args)
11 | self.dt, self.ft = Averager(), Averager()
12 | self.bt, self.ot = Averager(), Averager()
13 | self.timer = Timer()
14 | self.init_log()
15 |
16 | @abc.abstractmethod
17 | def train(self):
18 | pass
19 |
20 | def init_log(self):
21 | # train statistics
22 | self.trlog = {}
23 | self.trlog['train_loss'] = []
24 | self.trlog['val_loss'] = []
25 | self.trlog['test_loss'] = []
26 | self.trlog['train_acc'] = []
27 | self.trlog['val_acc'] = []
28 | self.trlog['test_acc'] = []
29 | self.trlog['max_acc_epoch'] = 0
30 | self.trlog['max_acc'] = [0.0] * self.args.sessions
31 |
32 | self.trlog['seen_acc'] = []
33 | self.trlog['unseen_acc'] = []
--------------------------------------------------------------------------------
/data/index_list/mini_imagenet/session_2.txt:
--------------------------------------------------------------------------------
1 | MINI-ImageNet/train/n03544143/n0354414300000811.jpg
2 | MINI-ImageNet/train/n03544143/n0354414300000906.jpg
3 | MINI-ImageNet/train/n03544143/n0354414300000122.jpg
4 | MINI-ImageNet/train/n03544143/n0354414300000185.jpg
5 | MINI-ImageNet/train/n03544143/n0354414300001258.jpg
6 | MINI-ImageNet/train/n03584254/n0358425400000818.jpg
7 | MINI-ImageNet/train/n03584254/n0358425400000891.jpg
8 | MINI-ImageNet/train/n03584254/n0358425400000145.jpg
9 | MINI-ImageNet/train/n03584254/n0358425400000186.jpg
10 | MINI-ImageNet/train/n03584254/n0358425400001254.jpg
11 | MINI-ImageNet/train/n03676483/n0367648300000772.jpg
12 | MINI-ImageNet/train/n03676483/n0367648300000864.jpg
13 | MINI-ImageNet/train/n03676483/n0367648300000108.jpg
14 | MINI-ImageNet/train/n03676483/n0367648300000164.jpg
15 | MINI-ImageNet/train/n03676483/n0367648300001236.jpg
16 | MINI-ImageNet/train/n03770439/n0377043900000738.jpg
17 | MINI-ImageNet/train/n03770439/n0377043900000835.jpg
18 | MINI-ImageNet/train/n03770439/n0377043900000111.jpg
19 | MINI-ImageNet/train/n03770439/n0377043900000154.jpg
20 | MINI-ImageNet/train/n03770439/n0377043900001238.jpg
21 | MINI-ImageNet/train/n03773504/n0377350400000781.jpg
22 | MINI-ImageNet/train/n03773504/n0377350400000880.jpg
23 | MINI-ImageNet/train/n03773504/n0377350400000132.jpg
24 | MINI-ImageNet/train/n03773504/n0377350400000177.jpg
25 | MINI-ImageNet/train/n03773504/n0377350400001237.jpg
26 |
--------------------------------------------------------------------------------
/data/index_list/mini_imagenet/session_3.txt:
--------------------------------------------------------------------------------
1 | MINI-ImageNet/train/n03775546/n0377554600000808.jpg
2 | MINI-ImageNet/train/n03775546/n0377554600000891.jpg
3 | MINI-ImageNet/train/n03775546/n0377554600000123.jpg
4 | MINI-ImageNet/train/n03775546/n0377554600000179.jpg
5 | MINI-ImageNet/train/n03775546/n0377554600001241.jpg
6 | MINI-ImageNet/train/n03838899/n0383889900000800.jpg
7 | MINI-ImageNet/train/n03838899/n0383889900000881.jpg
8 | MINI-ImageNet/train/n03838899/n0383889900000124.jpg
9 | MINI-ImageNet/train/n03838899/n0383889900000165.jpg
10 | MINI-ImageNet/train/n03838899/n0383889900001242.jpg
11 | MINI-ImageNet/train/n03854065/n0385406500000796.jpg
12 | MINI-ImageNet/train/n03854065/n0385406500000899.jpg
13 | MINI-ImageNet/train/n03854065/n0385406500000103.jpg
14 | MINI-ImageNet/train/n03854065/n0385406500000146.jpg
15 | MINI-ImageNet/train/n03854065/n0385406500001249.jpg
16 | MINI-ImageNet/train/n03888605/n0388860500000783.jpg
17 | MINI-ImageNet/train/n03888605/n0388860500000875.jpg
18 | MINI-ImageNet/train/n03888605/n0388860500000124.jpg
19 | MINI-ImageNet/train/n03888605/n0388860500000184.jpg
20 | MINI-ImageNet/train/n03888605/n0388860500001245.jpg
21 | MINI-ImageNet/train/n03908618/n0390861800000778.jpg
22 | MINI-ImageNet/train/n03908618/n0390861800000855.jpg
23 | MINI-ImageNet/train/n03908618/n0390861800000127.jpg
24 | MINI-ImageNet/train/n03908618/n0390861800000175.jpg
25 | MINI-ImageNet/train/n03908618/n0390861800001243.jpg
26 |
--------------------------------------------------------------------------------
/data/index_list/mini_imagenet/session_4.txt:
--------------------------------------------------------------------------------
1 | MINI-ImageNet/train/n03924679/n0392467900000785.jpg
2 | MINI-ImageNet/train/n03924679/n0392467900000871.jpg
3 | MINI-ImageNet/train/n03924679/n0392467900000115.jpg
4 | MINI-ImageNet/train/n03924679/n0392467900000164.jpg
5 | MINI-ImageNet/train/n03924679/n0392467900001240.jpg
6 | MINI-ImageNet/train/n03980874/n0398087400000794.jpg
7 | MINI-ImageNet/train/n03980874/n0398087400000911.jpg
8 | MINI-ImageNet/train/n03980874/n0398087400000124.jpg
9 | MINI-ImageNet/train/n03980874/n0398087400000184.jpg
10 | MINI-ImageNet/train/n03980874/n0398087400001229.jpg
11 | MINI-ImageNet/train/n03998194/n0399819400000790.jpg
12 | MINI-ImageNet/train/n03998194/n0399819400000876.jpg
13 | MINI-ImageNet/train/n03998194/n0399819400000116.jpg
14 | MINI-ImageNet/train/n03998194/n0399819400000164.jpg
15 | MINI-ImageNet/train/n03998194/n0399819400001248.jpg
16 | MINI-ImageNet/train/n04067472/n0406747200000802.jpg
17 | MINI-ImageNet/train/n04067472/n0406747200000904.jpg
18 | MINI-ImageNet/train/n04067472/n0406747200000130.jpg
19 | MINI-ImageNet/train/n04067472/n0406747200000175.jpg
20 | MINI-ImageNet/train/n04067472/n0406747200001257.jpg
21 | MINI-ImageNet/train/n04146614/n0414661400000809.jpg
22 | MINI-ImageNet/train/n04146614/n0414661400000921.jpg
23 | MINI-ImageNet/train/n04146614/n0414661400000132.jpg
24 | MINI-ImageNet/train/n04146614/n0414661400000204.jpg
25 | MINI-ImageNet/train/n04146614/n0414661400001243.jpg
26 |
--------------------------------------------------------------------------------
/data/index_list/mini_imagenet/session_5.txt:
--------------------------------------------------------------------------------
1 | MINI-ImageNet/train/n04149813/n0414981300000789.jpg
2 | MINI-ImageNet/train/n04149813/n0414981300000879.jpg
3 | MINI-ImageNet/train/n04149813/n0414981300000152.jpg
4 | MINI-ImageNet/train/n04149813/n0414981300000196.jpg
5 | MINI-ImageNet/train/n04149813/n0414981300001229.jpg
6 | MINI-ImageNet/train/n04243546/n0424354600000785.jpg
7 | MINI-ImageNet/train/n04243546/n0424354600000868.jpg
8 | MINI-ImageNet/train/n04243546/n0424354600000109.jpg
9 | MINI-ImageNet/train/n04243546/n0424354600000159.jpg
10 | MINI-ImageNet/train/n04243546/n0424354600001243.jpg
11 | MINI-ImageNet/train/n04251144/n0425114400000797.jpg
12 | MINI-ImageNet/train/n04251144/n0425114400000891.jpg
13 | MINI-ImageNet/train/n04251144/n0425114400000123.jpg
14 | MINI-ImageNet/train/n04251144/n0425114400000170.jpg
15 | MINI-ImageNet/train/n04251144/n0425114400001244.jpg
16 | MINI-ImageNet/train/n04258138/n0425813800000807.jpg
17 | MINI-ImageNet/train/n04258138/n0425813800000900.jpg
18 | MINI-ImageNet/train/n04258138/n0425813800000135.jpg
19 | MINI-ImageNet/train/n04258138/n0425813800000193.jpg
20 | MINI-ImageNet/train/n04258138/n0425813800001252.jpg
21 | MINI-ImageNet/train/n04275548/n0427554800000755.jpg
22 | MINI-ImageNet/train/n04275548/n0427554800000854.jpg
23 | MINI-ImageNet/train/n04275548/n0427554800000127.jpg
24 | MINI-ImageNet/train/n04275548/n0427554800000168.jpg
25 | MINI-ImageNet/train/n04275548/n0427554800001238.jpg
26 |
--------------------------------------------------------------------------------
/data/index_list/mini_imagenet/session_6.txt:
--------------------------------------------------------------------------------
1 | MINI-ImageNet/train/n04296562/n0429656200000772.jpg
2 | MINI-ImageNet/train/n04296562/n0429656200000862.jpg
3 | MINI-ImageNet/train/n04296562/n0429656200000119.jpg
4 | MINI-ImageNet/train/n04296562/n0429656200000158.jpg
5 | MINI-ImageNet/train/n04296562/n0429656200001223.jpg
6 | MINI-ImageNet/train/n04389033/n0438903300000802.jpg
7 | MINI-ImageNet/train/n04389033/n0438903300000912.jpg
8 | MINI-ImageNet/train/n04389033/n0438903300000157.jpg
9 | MINI-ImageNet/train/n04389033/n0438903300000202.jpg
10 | MINI-ImageNet/train/n04389033/n0438903300001261.jpg
11 | MINI-ImageNet/train/n04418357/n0441835700000746.jpg
12 | MINI-ImageNet/train/n04418357/n0441835700000848.jpg
13 | MINI-ImageNet/train/n04418357/n0441835700000111.jpg
14 | MINI-ImageNet/train/n04418357/n0441835700000163.jpg
15 | MINI-ImageNet/train/n04418357/n0441835700001226.jpg
16 | MINI-ImageNet/train/n04435653/n0443565300000828.jpg
17 | MINI-ImageNet/train/n04435653/n0443565300000932.jpg
18 | MINI-ImageNet/train/n04435653/n0443565300000139.jpg
19 | MINI-ImageNet/train/n04435653/n0443565300000203.jpg
20 | MINI-ImageNet/train/n04435653/n0443565300001245.jpg
21 | MINI-ImageNet/train/n04443257/n0444325700000764.jpg
22 | MINI-ImageNet/train/n04443257/n0444325700000852.jpg
23 | MINI-ImageNet/train/n04443257/n0444325700000125.jpg
24 | MINI-ImageNet/train/n04443257/n0444325700000183.jpg
25 | MINI-ImageNet/train/n04443257/n0444325700001216.jpg
26 |
--------------------------------------------------------------------------------
/data/index_list/mini_imagenet/session_7.txt:
--------------------------------------------------------------------------------
1 | MINI-ImageNet/train/n04509417/n0450941700000801.jpg
2 | MINI-ImageNet/train/n04509417/n0450941700000882.jpg
3 | MINI-ImageNet/train/n04509417/n0450941700000149.jpg
4 | MINI-ImageNet/train/n04509417/n0450941700000201.jpg
5 | MINI-ImageNet/train/n04509417/n0450941700001242.jpg
6 | MINI-ImageNet/train/n04515003/n0451500300000791.jpg
7 | MINI-ImageNet/train/n04515003/n0451500300000893.jpg
8 | MINI-ImageNet/train/n04515003/n0451500300000112.jpg
9 | MINI-ImageNet/train/n04515003/n0451500300000161.jpg
10 | MINI-ImageNet/train/n04515003/n0451500300001259.jpg
11 | MINI-ImageNet/train/n04522168/n0452216800000790.jpg
12 | MINI-ImageNet/train/n04522168/n0452216800000894.jpg
13 | MINI-ImageNet/train/n04522168/n0452216800000124.jpg
14 | MINI-ImageNet/train/n04522168/n0452216800000180.jpg
15 | MINI-ImageNet/train/n04522168/n0452216800001258.jpg
16 | MINI-ImageNet/train/n04596742/n0459674200000809.jpg
17 | MINI-ImageNet/train/n04596742/n0459674200000897.jpg
18 | MINI-ImageNet/train/n04596742/n0459674200000132.jpg
19 | MINI-ImageNet/train/n04596742/n0459674200000189.jpg
20 | MINI-ImageNet/train/n04596742/n0459674200001241.jpg
21 | MINI-ImageNet/train/n04604644/n0460464400000828.jpg
22 | MINI-ImageNet/train/n04604644/n0460464400000904.jpg
23 | MINI-ImageNet/train/n04604644/n0460464400000124.jpg
24 | MINI-ImageNet/train/n04604644/n0460464400000175.jpg
25 | MINI-ImageNet/train/n04604644/n0460464400001256.jpg
26 |
--------------------------------------------------------------------------------
/data/index_list/mini_imagenet/session_8.txt:
--------------------------------------------------------------------------------
1 | MINI-ImageNet/train/n04612504/n0461250400000737.jpg
2 | MINI-ImageNet/train/n04612504/n0461250400000810.jpg
3 | MINI-ImageNet/train/n04612504/n0461250400000149.jpg
4 | MINI-ImageNet/train/n04612504/n0461250400000194.jpg
5 | MINI-ImageNet/train/n04612504/n0461250400001160.jpg
6 | MINI-ImageNet/train/n06794110/n0679411000000773.jpg
7 | MINI-ImageNet/train/n06794110/n0679411000000882.jpg
8 | MINI-ImageNet/train/n06794110/n0679411000000124.jpg
9 | MINI-ImageNet/train/n06794110/n0679411000000199.jpg
10 | MINI-ImageNet/train/n06794110/n0679411000001256.jpg
11 | MINI-ImageNet/train/n07584110/n0758411000000764.jpg
12 | MINI-ImageNet/train/n07584110/n0758411000000855.jpg
13 | MINI-ImageNet/train/n07584110/n0758411000000133.jpg
14 | MINI-ImageNet/train/n07584110/n0758411000000180.jpg
15 | MINI-ImageNet/train/n07584110/n0758411000001154.jpg
16 | MINI-ImageNet/train/n07613480/n0761348000000770.jpg
17 | MINI-ImageNet/train/n07613480/n0761348000000868.jpg
18 | MINI-ImageNet/train/n07613480/n0761348000000140.jpg
19 | MINI-ImageNet/train/n07613480/n0761348000000183.jpg
20 | MINI-ImageNet/train/n07613480/n0761348000001254.jpg
21 | MINI-ImageNet/train/n07697537/n0769753700000774.jpg
22 | MINI-ImageNet/train/n07697537/n0769753700000862.jpg
23 | MINI-ImageNet/train/n07697537/n0769753700000142.jpg
24 | MINI-ImageNet/train/n07697537/n0769753700000181.jpg
25 | MINI-ImageNet/train/n07697537/n0769753700001231.jpg
26 |
--------------------------------------------------------------------------------
/data/index_list/mini_imagenet/session_9.txt:
--------------------------------------------------------------------------------
1 | MINI-ImageNet/train/n07747607/n0774760700000787.jpg
2 | MINI-ImageNet/train/n07747607/n0774760700000894.jpg
3 | MINI-ImageNet/train/n07747607/n0774760700000140.jpg
4 | MINI-ImageNet/train/n07747607/n0774760700000190.jpg
5 | MINI-ImageNet/train/n07747607/n0774760700001253.jpg
6 | MINI-ImageNet/train/n09246464/n0924646400000794.jpg
7 | MINI-ImageNet/train/n09246464/n0924646400000885.jpg
8 | MINI-ImageNet/train/n09246464/n0924646400000107.jpg
9 | MINI-ImageNet/train/n09246464/n0924646400000145.jpg
10 | MINI-ImageNet/train/n09246464/n0924646400001250.jpg
11 | MINI-ImageNet/train/n09256479/n0925647900000823.jpg
12 | MINI-ImageNet/train/n09256479/n0925647900000899.jpg
13 | MINI-ImageNet/train/n09256479/n0925647900000154.jpg
14 | MINI-ImageNet/train/n09256479/n0925647900000208.jpg
15 | MINI-ImageNet/train/n09256479/n0925647900001246.jpg
16 | MINI-ImageNet/train/n13054560/n1305456000000771.jpg
17 | MINI-ImageNet/train/n13054560/n1305456000000856.jpg
18 | MINI-ImageNet/train/n13054560/n1305456000000102.jpg
19 | MINI-ImageNet/train/n13054560/n1305456000000159.jpg
20 | MINI-ImageNet/train/n13054560/n1305456000001230.jpg
21 | MINI-ImageNet/train/n13133613/n1313361300000758.jpg
22 | MINI-ImageNet/train/n13133613/n1313361300000871.jpg
23 | MINI-ImageNet/train/n13133613/n1313361300000106.jpg
24 | MINI-ImageNet/train/n13133613/n1313361300000160.jpg
25 | MINI-ImageNet/train/n13133613/n1313361300001231.jpg
26 |
--------------------------------------------------------------------------------
/postprocess_path.py:
--------------------------------------------------------------------------------
1 | from utils import ensure_path
2 | import os
3 | import datetime
4 | def sub_set_save_path(args):
5 | if args.project == 'teen':
6 | args.save_path = args.save_path +\
7 | f"-tw_{args.softmax_t}-{args.shift_weight}-{args.soft_mode}"
8 | else:
9 | raise NotImplementedError
10 | return args
11 |
12 | def set_save_path(args):
13 | # base info
14 | time_str = datetime.datetime.now().strftime('%m%d-%H-%M-%S-%f')[:-3]
15 | args.time_str = time_str
16 | mode = args.base_mode + '-' + args.new_mode
17 | if not args.not_data_init:
18 | mode = mode + '-' + 'data_init'
19 | args.save_path = '%s/' % args.dataset
20 | args.save_path = args.save_path + '%s/' % args.project
21 | args.save_path = args.save_path + '%s-start_%d/' % (mode, args.start_session)
22 |
23 | # optimizer & scheduler
24 | if args.schedule == 'Milestone':
25 | mile_stone = str(args.milestones).replace(" ", "").replace(',', '_')[1:-1]
26 | args.save_path = args.save_path +\
27 | f'{args.time_str}-Epo_{args.epochs_base}-Bs_{args.batch_size_base}'\
28 | f'-{args.optim}-Lr_{args.lr_base}-decay{args.decay}-Mom_{args.momentum}'\
29 | f'-MS_{mile_stone}-Gam_{args.gamma}'
30 |
31 | elif args.schedule == 'Step':
32 | args.save_path = args.save_path +\
33 | f'{args.time_str}-Epo_{args.epochs_base}-Bs_{args.batch_size_base}'\
34 | f'-{args.optim}-Lr_{args.lr_base}-decay{args.decay}-Mom_{args.momentum}'\
35 | f'-Step_{args.step}-Gam_{args.gamma}'
36 |
37 | elif args.schedule == 'Cosine':
38 | args.save_path = args.save_path +\
39 | f'{args.time_str}-Epo_{args.epochs_base}-Bs_{args.batch_size_base}'\
40 | f'-{args.optim}-Lr_{args.lr_base}-decay{args.decay}-Mom_{args.momentum}'\
41 | f'-Max_{args.tmax}'
42 | else:
43 | raise NotImplementedError
44 |
45 | # feature normalize
46 | if args.feat_norm:
47 | args.save_path = args.save_path + '-NormT'
48 | else:
49 | args.save_path = args.save_path + '-NormF'
50 |
51 | # train mode
52 | if 'cos' in mode:
53 | args.save_path = args.save_path + '-T_%.2f' % (args.temperature)
54 | if 'ft' in args.new_mode:
55 | args.save_path = args.save_path + '-ftLR_%.3f-ftEpoch_%d' % (
56 | args.lr_new, args.epochs_new)
57 |
58 | # specific parameters
59 | args = sub_set_save_path(args)
60 |
61 | if args.debug:
62 | args.save_path = os.path.join('debug', args.save_path)
63 |
64 | args.save_path = os.path.join('./checkpoint', args.save_path)
65 | ensure_path(args.save_path)
66 | return args
67 |
--------------------------------------------------------------------------------
/data/index_list/cub200/session_6.txt:
--------------------------------------------------------------------------------
1 | CUB_200_2011/images/141.Artic_Tern/Artic_Tern_0055_141524.jpg
2 | CUB_200_2011/images/141.Artic_Tern/Artic_Tern_0124_142121.jpg
3 | CUB_200_2011/images/141.Artic_Tern/Artic_Tern_0133_141069.jpg
4 | CUB_200_2011/images/141.Artic_Tern/Artic_Tern_0111_143101.jpg
5 | CUB_200_2011/images/141.Artic_Tern/Artic_Tern_0107_141181.jpg
6 | CUB_200_2011/images/142.Black_Tern/Black_Tern_0079_143998.jpg
7 | CUB_200_2011/images/142.Black_Tern/Black_Tern_0082_144372.jpg
8 | CUB_200_2011/images/142.Black_Tern/Black_Tern_0029_144140.jpg
9 | CUB_200_2011/images/142.Black_Tern/Black_Tern_0066_144541.jpg
10 | CUB_200_2011/images/142.Black_Tern/Black_Tern_0046_144229.jpg
11 | CUB_200_2011/images/143.Caspian_Tern/Caspian_Tern_0009_145057.jpg
12 | CUB_200_2011/images/143.Caspian_Tern/Caspian_Tern_0116_145607.jpg
13 | CUB_200_2011/images/143.Caspian_Tern/Caspian_Tern_0123_145774.jpg
14 | CUB_200_2011/images/143.Caspian_Tern/Caspian_Tern_0006_145594.jpg
15 | CUB_200_2011/images/143.Caspian_Tern/Caspian_Tern_0013_145553.jpg
16 | CUB_200_2011/images/144.Common_Tern/Common_Tern_0071_148796.jpg
17 | CUB_200_2011/images/144.Common_Tern/Common_Tern_0077_149196.jpg
18 | CUB_200_2011/images/144.Common_Tern/Common_Tern_0030_147825.jpg
19 | CUB_200_2011/images/144.Common_Tern/Common_Tern_0095_149960.jpg
20 | CUB_200_2011/images/144.Common_Tern/Common_Tern_0083_148096.jpg
21 | CUB_200_2011/images/145.Elegant_Tern/Elegant_Tern_0009_150954.jpg
22 | CUB_200_2011/images/145.Elegant_Tern/Elegant_Tern_0045_150752.jpg
23 | CUB_200_2011/images/145.Elegant_Tern/Elegant_Tern_0046_150905.jpg
24 | CUB_200_2011/images/145.Elegant_Tern/Elegant_Tern_0103_150493.jpg
25 | CUB_200_2011/images/145.Elegant_Tern/Elegant_Tern_0004_150948.jpg
26 | CUB_200_2011/images/146.Forsters_Tern/Forsters_Tern_0027_151456.jpg
27 | CUB_200_2011/images/146.Forsters_Tern/Forsters_Tern_0077_152255.jpg
28 | CUB_200_2011/images/146.Forsters_Tern/Forsters_Tern_0125_151399.jpg
29 | CUB_200_2011/images/146.Forsters_Tern/Forsters_Tern_0045_151227.jpg
30 | CUB_200_2011/images/146.Forsters_Tern/Forsters_Tern_0119_152709.jpg
31 | CUB_200_2011/images/147.Least_Tern/Least_Tern_0092_153361.jpg
32 | CUB_200_2011/images/147.Least_Tern/Least_Tern_0020_153458.jpg
33 | CUB_200_2011/images/147.Least_Tern/Least_Tern_0060_153190.jpg
34 | CUB_200_2011/images/147.Least_Tern/Least_Tern_0119_153950.jpg
35 | CUB_200_2011/images/147.Least_Tern/Least_Tern_0037_153637.jpg
36 | CUB_200_2011/images/148.Green_tailed_Towhee/Green_Tailed_Towhee_0018_154825.jpg
37 | CUB_200_2011/images/148.Green_tailed_Towhee/Green_Tailed_Towhee_0070_154844.jpg
38 | CUB_200_2011/images/148.Green_tailed_Towhee/Green_Tailed_Towhee_0064_154771.jpg
39 | CUB_200_2011/images/148.Green_tailed_Towhee/Green_Tailed_Towhee_0058_797399.jpg
40 | CUB_200_2011/images/148.Green_tailed_Towhee/Green_Tailed_Towhee_0060_154820.jpg
41 | CUB_200_2011/images/149.Brown_Thrasher/Brown_Thrasher_0013_155329.jpg
42 | CUB_200_2011/images/149.Brown_Thrasher/Brown_Thrasher_0079_155394.jpg
43 | CUB_200_2011/images/149.Brown_Thrasher/Brown_Thrasher_0019_155216.jpg
44 | CUB_200_2011/images/149.Brown_Thrasher/Brown_Thrasher_0051_155344.jpg
45 | CUB_200_2011/images/149.Brown_Thrasher/Brown_Thrasher_0081_155256.jpg
46 | CUB_200_2011/images/150.Sage_Thrasher/Sage_Thrasher_0033_155511.jpg
47 | CUB_200_2011/images/150.Sage_Thrasher/Sage_Thrasher_0069_155544.jpg
48 | CUB_200_2011/images/150.Sage_Thrasher/Sage_Thrasher_0096_155449.jpg
49 | CUB_200_2011/images/150.Sage_Thrasher/Sage_Thrasher_0104_155529.jpg
50 | CUB_200_2011/images/150.Sage_Thrasher/Sage_Thrasher_0070_155732.jpg
51 |
--------------------------------------------------------------------------------
/data/index_list/cub200/session_11.txt:
--------------------------------------------------------------------------------
1 | CUB_200_2011/images/191.Red_headed_Woodpecker/Red_Headed_Woodpecker_0020_183255.jpg
2 | CUB_200_2011/images/191.Red_headed_Woodpecker/Red_Headed_Woodpecker_0005_183414.jpg
3 | CUB_200_2011/images/191.Red_headed_Woodpecker/Red_Headed_Woodpecker_0068_183662.jpg
4 | CUB_200_2011/images/191.Red_headed_Woodpecker/Red_Headed_Woodpecker_0013_182721.jpg
5 | CUB_200_2011/images/191.Red_headed_Woodpecker/Red_Headed_Woodpecker_0095_183688.jpg
6 | CUB_200_2011/images/192.Downy_Woodpecker/Downy_Woodpecker_0040_184061.jpg
7 | CUB_200_2011/images/192.Downy_Woodpecker/Downy_Woodpecker_0031_184120.jpg
8 | CUB_200_2011/images/192.Downy_Woodpecker/Downy_Woodpecker_0090_183964.jpg
9 | CUB_200_2011/images/192.Downy_Woodpecker/Downy_Woodpecker_0005_184098.jpg
10 | CUB_200_2011/images/192.Downy_Woodpecker/Downy_Woodpecker_0136_184534.jpg
11 | CUB_200_2011/images/193.Bewick_Wren/Bewick_Wren_0083_185190.jpg
12 | CUB_200_2011/images/193.Bewick_Wren/Bewick_Wren_0084_184715.jpg
13 | CUB_200_2011/images/193.Bewick_Wren/Bewick_Wren_0015_184981.jpg
14 | CUB_200_2011/images/193.Bewick_Wren/Bewick_Wren_0110_185216.jpg
15 | CUB_200_2011/images/193.Bewick_Wren/Bewick_Wren_0081_185080.jpg
16 | CUB_200_2011/images/194.Cactus_Wren/Cactus_Wren_0089_186023.jpg
17 | CUB_200_2011/images/194.Cactus_Wren/Cactus_Wren_0097_186015.jpg
18 | CUB_200_2011/images/194.Cactus_Wren/Cactus_Wren_0025_185696.jpg
19 | CUB_200_2011/images/194.Cactus_Wren/Cactus_Wren_0066_186028.jpg
20 | CUB_200_2011/images/194.Cactus_Wren/Cactus_Wren_0033_186014.jpg
21 | CUB_200_2011/images/195.Carolina_Wren/Carolina_Wren_0113_186675.jpg
22 | CUB_200_2011/images/195.Carolina_Wren/Carolina_Wren_0099_186237.jpg
23 | CUB_200_2011/images/195.Carolina_Wren/Carolina_Wren_0014_186525.jpg
24 | CUB_200_2011/images/195.Carolina_Wren/Carolina_Wren_0020_186702.jpg
25 | CUB_200_2011/images/195.Carolina_Wren/Carolina_Wren_0128_186581.jpg
26 | CUB_200_2011/images/196.House_Wren/House_Wren_0108_187102.jpg
27 | CUB_200_2011/images/196.House_Wren/House_Wren_0107_187230.jpg
28 | CUB_200_2011/images/196.House_Wren/House_Wren_0035_187708.jpg
29 | CUB_200_2011/images/196.House_Wren/House_Wren_0094_187226.jpg
30 | CUB_200_2011/images/196.House_Wren/House_Wren_0122_187331.jpg
31 | CUB_200_2011/images/197.Marsh_Wren/Marsh_Wren_0056_188241.jpg
32 | CUB_200_2011/images/197.Marsh_Wren/Marsh_Wren_0141_188796.jpg
33 | CUB_200_2011/images/197.Marsh_Wren/Marsh_Wren_0006_188126.jpg
34 | CUB_200_2011/images/197.Marsh_Wren/Marsh_Wren_0044_188270.jpg
35 | CUB_200_2011/images/197.Marsh_Wren/Marsh_Wren_0039_188201.jpg
36 | CUB_200_2011/images/198.Rock_Wren/Rock_Wren_0122_189042.jpg
37 | CUB_200_2011/images/198.Rock_Wren/Rock_Wren_0063_189121.jpg
38 | CUB_200_2011/images/198.Rock_Wren/Rock_Wren_0069_188969.jpg
39 | CUB_200_2011/images/198.Rock_Wren/Rock_Wren_0111_189443.jpg
40 | CUB_200_2011/images/198.Rock_Wren/Rock_Wren_0027_189331.jpg
41 | CUB_200_2011/images/199.Winter_Wren/Winter_Wren_0066_189637.jpg
42 | CUB_200_2011/images/199.Winter_Wren/Winter_Wren_0030_190311.jpg
43 | CUB_200_2011/images/199.Winter_Wren/Winter_Wren_0075_189578.jpg
44 | CUB_200_2011/images/199.Winter_Wren/Winter_Wren_0065_189675.jpg
45 | CUB_200_2011/images/199.Winter_Wren/Winter_Wren_0037_190123.jpg
46 | CUB_200_2011/images/200.Common_Yellowthroat/Common_Yellowthroat_0004_190606.jpg
47 | CUB_200_2011/images/200.Common_Yellowthroat/Common_Yellowthroat_0054_190398.jpg
48 | CUB_200_2011/images/200.Common_Yellowthroat/Common_Yellowthroat_0010_190572.jpg
49 | CUB_200_2011/images/200.Common_Yellowthroat/Common_Yellowthroat_0126_190407.jpg
50 | CUB_200_2011/images/200.Common_Yellowthroat/Common_Yellowthroat_0032_190592.jpg
51 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Few-Shot Class-Incremental Learning via Training-Free Prototype Calibration (TEEN)
2 |
3 | The code repository for "Few-Shot Class-Incremental Learning via Training-Free Prototype Calibration" [[paper]](https://arxiv.org/abs/2312.05229) (NeurIPS'23) in PyTorch. If you use any content of this repo for your work, please cite the following bib entry:
4 |
5 | @inproceedings{
6 | wang2023teen,
7 | title={Few-Shot Class-Incremental Learning via Training-Free Prototype Calibration},
8 | author={Wang, Qi-Wei and Zhou, Da-Wei and Zhang, Yi-Kai and Zhan, De-Chuan and Ye, Han-Jia},
9 | booktitle={NeurIPS},
10 | year={2023}
11 | }
12 |
13 | ## Few-Shot Class-Incremental Learning via Training-Free Prototype Calibration
14 |
15 | Real-world scenarios are usually accompanied by continually appearing classes with scarce labeled samples, which requires a machine learning model to incrementally learn new classes while maintaining its knowledge of the base classes. In this Few-Shot Class-Incremental Learning (FSCIL) scenario, existing methods either introduce extra learnable components or rely on a frozen feature extractor to mitigate catastrophic forgetting and overfitting. However, we find that existing methods tend to misclassify samples of new classes into base classes, which leads to poor performance on the new classes. In other words, the strong discriminability of the base classes distracts the classification of the new classes. To investigate this intriguing phenomenon, we observe that although the feature extractor is trained only on base classes, it can surprisingly represent the semantic
16 | similarity between the base classes and unseen new classes. Building upon these analyses, we propose a simple yet effective Training-frEE calibratioN (TEEN) strategy to enhance the discriminability of new classes by fusing the new prototypes (i.e., the mean features of a class) with weighted base prototypes. In addition to the standard FSCIL benchmarks, TEEN demonstrates remarkable performance and consistent improvements over baseline methods in the few-shot learning scenario.
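For intuition, here is a minimal sketch of the calibration idea (illustrative only; the actual implementation lives in `models/teen`). Each new-class prototype is shifted toward the base prototypes that are most similar to it; the similarity weights are sharpened by a softmax temperature (cf. `-softmax_t` in the scripts) and the fusion strength is controlled by a weight (cf. `-shift_weight`). The function name and exact formula below are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def calibrate_prototypes(base_protos, new_protos, softmax_t=16.0, shift_weight=0.1):
    # base_protos: (num_base, d) mean features of the base classes
    # new_protos:  (num_new,  d) mean features of the few-shot new classes
    base_n = F.normalize(base_protos, dim=-1)
    new_n = F.normalize(new_protos, dim=-1)
    sim = new_n @ base_n.t()                      # cosine similarity to every base prototype
    weights = F.softmax(softmax_t * sim, dim=-1)  # sharpened similarity weights
    shift = weights @ base_protos                 # similarity-weighted average of base prototypes
    # Fuse: keep most of the (biased) new prototype and shift it toward similar base classes.
    return (1.0 - shift_weight) * new_protos + shift_weight * shift

# usage sketch: protos = calibrate_prototypes(torch.randn(60, 64), torch.randn(5, 64))
```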
17 |
18 |
19 | ## Results
20 |
21 |
22 | Please refer to our [paper](https://arxiv.org/abs/2312.05229) for detailed values.
23 |
24 | ## Prerequisites
25 |
26 | The following packages are required to run the scripts:
27 |
28 | - [PyTorch-1.4 and torchvision](https://pytorch.org)
29 |
30 | - tqdm
31 |
32 | ## Dataset
33 | We provide the source code on three benchmark datasets, i.e., CIFAR100, CUB200 and miniImageNet. Please follow the guidelines in [CEC](https://github.com/icoz69/CEC-CVPR2021) to prepare them.
34 |
35 |
36 | ## Code Structures and details
37 | The code is organized into the following parts:
38 | - `models`: the backbone networks and training protocols for the experiments.
39 | - `data`: image index lists and class splits for the datasets.
40 | - `dataloader`: dataloaders for the different datasets.
41 |
42 | ## Training scripts
43 |
44 | Please see the `scripts` folder.
45 |
46 |
47 | ## Acknowledgment
48 | We thank the following repositories for providing helpful components and functions used in our work.
49 |
50 | - [Awesome Few-Shot Class-Incremental Learning](https://github.com/zhoudw-zdw/Awesome-Few-Shot-Class-Incremental-Learning)
51 | - [PyCIL: A Python Toolbox for Class-Incremental Learning](https://github.com/G-U-N/PyCIL)
52 | - [CEC](https://github.com/icoz69/CEC-CVPR2021)
53 | - [FACT](https://github.com/zhoudw-zdw/CVPR22-Fact)
54 |
55 |
56 |
57 | ## Contact
58 | If you have any questions, please feel free to contact the author: Qi-Wei Wang (wangqiwei@lamda.nju.edu.cn). Enjoy the code.
--------------------------------------------------------------------------------
/data/index_list/cub200/session_2.txt:
--------------------------------------------------------------------------------
1 | CUB_200_2011/images/101.White_Pelican/White_Pelican_0081_96148.jpg
2 | CUB_200_2011/images/101.White_Pelican/White_Pelican_0075_96422.jpg
3 | CUB_200_2011/images/101.White_Pelican/White_Pelican_0026_95832.jpg
4 | CUB_200_2011/images/101.White_Pelican/White_Pelican_0022_95897.jpg
5 | CUB_200_2011/images/101.White_Pelican/White_Pelican_0044_96028.jpg
6 | CUB_200_2011/images/102.Western_Wood_Pewee/Western_Wood_Pewee_0072_98035.jpg
7 | CUB_200_2011/images/102.Western_Wood_Pewee/Western_Wood_Pewee_0004_98257.jpg
8 | CUB_200_2011/images/102.Western_Wood_Pewee/Western_Wood_Pewee_0060_795045.jpg
9 | CUB_200_2011/images/102.Western_Wood_Pewee/Western_Wood_Pewee_0039_795063.jpg
10 | CUB_200_2011/images/102.Western_Wood_Pewee/Western_Wood_Pewee_0040_795051.jpg
11 | CUB_200_2011/images/103.Sayornis/Sayornis_0099_98593.jpg
12 | CUB_200_2011/images/103.Sayornis/Sayornis_0133_99129.jpg
13 | CUB_200_2011/images/103.Sayornis/Sayornis_0098_98419.jpg
14 | CUB_200_2011/images/103.Sayornis/Sayornis_0011_98610.jpg
15 | CUB_200_2011/images/103.Sayornis/Sayornis_0114_98976.jpg
16 | CUB_200_2011/images/104.American_Pipit/American_Pipit_0037_99954.jpg
17 | CUB_200_2011/images/104.American_Pipit/American_Pipit_0067_100237.jpg
18 | CUB_200_2011/images/104.American_Pipit/American_Pipit_0019_99810.jpg
19 | CUB_200_2011/images/104.American_Pipit/American_Pipit_0058_100218.jpg
20 | CUB_200_2011/images/104.American_Pipit/American_Pipit_0113_99939.jpg
21 | CUB_200_2011/images/105.Whip_poor_Will/Whip_Poor_Will_0038_100443.jpg
22 | CUB_200_2011/images/105.Whip_poor_Will/Whip_Poor_Will_0018_796403.jpg
23 | CUB_200_2011/images/105.Whip_poor_Will/Whip_Poor_Will_0013_796439.jpg
24 | CUB_200_2011/images/105.Whip_poor_Will/Whip_Poor_Will_0026_100456.jpg
25 | CUB_200_2011/images/105.Whip_poor_Will/Whip_Poor_Will_0004_100479.jpg
26 | CUB_200_2011/images/106.Horned_Puffin/Horned_Puffin_0004_100733.jpg
27 | CUB_200_2011/images/106.Horned_Puffin/Horned_Puffin_0028_100765.jpg
28 | CUB_200_2011/images/106.Horned_Puffin/Horned_Puffin_0062_100693.jpg
29 | CUB_200_2011/images/106.Horned_Puffin/Horned_Puffin_0042_100760.jpg
30 | CUB_200_2011/images/106.Horned_Puffin/Horned_Puffin_0030_100725.jpg
31 | CUB_200_2011/images/107.Common_Raven/Common_Raven_0009_102112.jpg
32 | CUB_200_2011/images/107.Common_Raven/Common_Raven_0068_101216.jpg
33 | CUB_200_2011/images/107.Common_Raven/Common_Raven_0099_102534.jpg
34 | CUB_200_2011/images/107.Common_Raven/Common_Raven_0001_101213.jpg
35 | CUB_200_2011/images/107.Common_Raven/Common_Raven_0095_101831.jpg
36 | CUB_200_2011/images/108.White_necked_Raven/White_Necked_Raven_0063_797361.jpg
37 | CUB_200_2011/images/108.White_necked_Raven/White_Necked_Raven_0050_797374.jpg
38 | CUB_200_2011/images/108.White_necked_Raven/White_Necked_Raven_0010_797350.jpg
39 | CUB_200_2011/images/108.White_necked_Raven/White_Necked_Raven_0002_797370.jpg
40 | CUB_200_2011/images/108.White_necked_Raven/White_Necked_Raven_0026_797357.jpg
41 | CUB_200_2011/images/109.American_Redstart/American_Redstart_0036_103231.jpg
42 | CUB_200_2011/images/109.American_Redstart/American_Redstart_0071_103266.jpg
43 | CUB_200_2011/images/109.American_Redstart/American_Redstart_0085_103155.jpg
44 | CUB_200_2011/images/109.American_Redstart/American_Redstart_0056_103241.jpg
45 | CUB_200_2011/images/109.American_Redstart/American_Redstart_0049_103176.jpg
46 | CUB_200_2011/images/110.Geococcyx/Geococcyx_0106_104216.jpg
47 | CUB_200_2011/images/110.Geococcyx/Geococcyx_0086_104755.jpg
48 | CUB_200_2011/images/110.Geococcyx/Geococcyx_0124_104141.jpg
49 | CUB_200_2011/images/110.Geococcyx/Geococcyx_0117_104227.jpg
50 | CUB_200_2011/images/110.Geococcyx/Geococcyx_0036_104173.jpg
51 |
--------------------------------------------------------------------------------
/data/index_list/cub200/session_5.txt:
--------------------------------------------------------------------------------
1 | CUB_200_2011/images/131.Vesper_Sparrow/Vesper_Sparrow_0079_125579.jpg
2 | CUB_200_2011/images/131.Vesper_Sparrow/Vesper_Sparrow_0080_125606.jpg
3 | CUB_200_2011/images/131.Vesper_Sparrow/Vesper_Sparrow_0084_125532.jpg
4 | CUB_200_2011/images/131.Vesper_Sparrow/Vesper_Sparrow_0094_125602.jpg
5 | CUB_200_2011/images/131.Vesper_Sparrow/Vesper_Sparrow_0019_125558.jpg
6 | CUB_200_2011/images/132.White_crowned_Sparrow/White_Crowned_Sparrow_0068_126156.jpg
7 | CUB_200_2011/images/132.White_crowned_Sparrow/White_Crowned_Sparrow_0100_126267.jpg
8 | CUB_200_2011/images/132.White_crowned_Sparrow/White_Crowned_Sparrow_0072_127080.jpg
9 | CUB_200_2011/images/132.White_crowned_Sparrow/White_Crowned_Sparrow_0033_127728.jpg
10 | CUB_200_2011/images/132.White_crowned_Sparrow/White_Crowned_Sparrow_0095_127118.jpg
11 | CUB_200_2011/images/133.White_throated_Sparrow/White_Throated_Sparrow_0125_128832.jpg
12 | CUB_200_2011/images/133.White_throated_Sparrow/White_Throated_Sparrow_0056_128906.jpg
13 | CUB_200_2011/images/133.White_throated_Sparrow/White_Throated_Sparrow_0085_129180.jpg
14 | CUB_200_2011/images/133.White_throated_Sparrow/White_Throated_Sparrow_0042_128899.jpg
15 | CUB_200_2011/images/133.White_throated_Sparrow/White_Throated_Sparrow_0021_128804.jpg
16 | CUB_200_2011/images/134.Cape_Glossy_Starling/Cape_Glossy_Starling_0096_129388.jpg
17 | CUB_200_2011/images/134.Cape_Glossy_Starling/Cape_Glossy_Starling_0046_129434.jpg
18 | CUB_200_2011/images/134.Cape_Glossy_Starling/Cape_Glossy_Starling_0043_129358.jpg
19 | CUB_200_2011/images/134.Cape_Glossy_Starling/Cape_Glossy_Starling_0019_129407.jpg
20 | CUB_200_2011/images/134.Cape_Glossy_Starling/Cape_Glossy_Starling_0067_129380.jpg
21 | CUB_200_2011/images/135.Bank_Swallow/Bank_Swallow_0003_129623.jpg
22 | CUB_200_2011/images/135.Bank_Swallow/Bank_Swallow_0045_129483.jpg
23 | CUB_200_2011/images/135.Bank_Swallow/Bank_Swallow_0020_129747.jpg
24 | CUB_200_2011/images/135.Bank_Swallow/Bank_Swallow_0067_129959.jpg
25 | CUB_200_2011/images/135.Bank_Swallow/Bank_Swallow_0053_129501.jpg
26 | CUB_200_2011/images/136.Barn_Swallow/Barn_Swallow_0018_130709.jpg
27 | CUB_200_2011/images/136.Barn_Swallow/Barn_Swallow_0048_132793.jpg
28 | CUB_200_2011/images/136.Barn_Swallow/Barn_Swallow_0070_130127.jpg
29 | CUB_200_2011/images/136.Barn_Swallow/Barn_Swallow_0066_130214.jpg
30 | CUB_200_2011/images/136.Barn_Swallow/Barn_Swallow_0049_130181.jpg
31 | CUB_200_2011/images/137.Cliff_Swallow/Cliff_Swallow_0018_132974.jpg
32 | CUB_200_2011/images/137.Cliff_Swallow/Cliff_Swallow_0023_134314.jpg
33 | CUB_200_2011/images/137.Cliff_Swallow/Cliff_Swallow_0066_133206.jpg
34 | CUB_200_2011/images/137.Cliff_Swallow/Cliff_Swallow_0050_134054.jpg
35 | CUB_200_2011/images/137.Cliff_Swallow/Cliff_Swallow_0075_134516.jpg
36 | CUB_200_2011/images/138.Tree_Swallow/Tree_Swallow_0087_137354.jpg
37 | CUB_200_2011/images/138.Tree_Swallow/Tree_Swallow_0043_136878.jpg
38 | CUB_200_2011/images/138.Tree_Swallow/Tree_Swallow_0111_135253.jpg
39 | CUB_200_2011/images/138.Tree_Swallow/Tree_Swallow_0108_135068.jpg
40 | CUB_200_2011/images/138.Tree_Swallow/Tree_Swallow_0064_136322.jpg
41 | CUB_200_2011/images/139.Scarlet_Tanager/Scarlet_Tanager_0107_138577.jpg
42 | CUB_200_2011/images/139.Scarlet_Tanager/Scarlet_Tanager_0077_137626.jpg
43 | CUB_200_2011/images/139.Scarlet_Tanager/Scarlet_Tanager_0040_137885.jpg
44 | CUB_200_2011/images/139.Scarlet_Tanager/Scarlet_Tanager_0033_137603.jpg
45 | CUB_200_2011/images/139.Scarlet_Tanager/Scarlet_Tanager_0132_138001.jpg
46 | CUB_200_2011/images/140.Summer_Tanager/Summer_Tanager_0032_140425.jpg
47 | CUB_200_2011/images/140.Summer_Tanager/Summer_Tanager_0046_139802.jpg
48 | CUB_200_2011/images/140.Summer_Tanager/Summer_Tanager_0111_139605.jpg
49 | CUB_200_2011/images/140.Summer_Tanager/Summer_Tanager_0116_139923.jpg
50 | CUB_200_2011/images/140.Summer_Tanager/Summer_Tanager_0095_139882.jpg
51 |
--------------------------------------------------------------------------------
/data/index_list/cub200/session_3.txt:
--------------------------------------------------------------------------------
1 | CUB_200_2011/images/111.Loggerhead_Shrike/Loggerhead_Shrike_0127_105742.jpg
2 | CUB_200_2011/images/111.Loggerhead_Shrike/Loggerhead_Shrike_0018_26407.jpg
3 | CUB_200_2011/images/111.Loggerhead_Shrike/Loggerhead_Shrike_0019_106132.jpg
4 | CUB_200_2011/images/111.Loggerhead_Shrike/Loggerhead_Shrike_0011_104921.jpg
5 | CUB_200_2011/images/111.Loggerhead_Shrike/Loggerhead_Shrike_0033_105686.jpg
6 | CUB_200_2011/images/112.Great_Grey_Shrike/Great_Grey_Shrike_0092_797048.jpg
7 | CUB_200_2011/images/112.Great_Grey_Shrike/Great_Grey_Shrike_0042_797056.jpg
8 | CUB_200_2011/images/112.Great_Grey_Shrike/Great_Grey_Shrike_0049_797025.jpg
9 | CUB_200_2011/images/112.Great_Grey_Shrike/Great_Grey_Shrike_0083_797051.jpg
10 | CUB_200_2011/images/112.Great_Grey_Shrike/Great_Grey_Shrike_0063_797042.jpg
11 | CUB_200_2011/images/113.Baird_Sparrow/Baird_Sparrow_0021_794576.jpg
12 | CUB_200_2011/images/113.Baird_Sparrow/Baird_Sparrow_0018_794584.jpg
13 | CUB_200_2011/images/113.Baird_Sparrow/Baird_Sparrow_0025_794564.jpg
14 | CUB_200_2011/images/113.Baird_Sparrow/Baird_Sparrow_0041_794582.jpg
15 | CUB_200_2011/images/113.Baird_Sparrow/Baird_Sparrow_0036_794572.jpg
16 | CUB_200_2011/images/114.Black_throated_Sparrow/Black_Throated_Sparrow_0019_107192.jpg
17 | CUB_200_2011/images/114.Black_throated_Sparrow/Black_Throated_Sparrow_0088_107220.jpg
18 | CUB_200_2011/images/114.Black_throated_Sparrow/Black_Throated_Sparrow_0097_106935.jpg
19 | CUB_200_2011/images/114.Black_throated_Sparrow/Black_Throated_Sparrow_0055_107213.jpg
20 | CUB_200_2011/images/114.Black_throated_Sparrow/Black_Throated_Sparrow_0010_107375.jpg
21 | CUB_200_2011/images/115.Brewer_Sparrow/Brewer_Sparrow_0068_107422.jpg
22 | CUB_200_2011/images/115.Brewer_Sparrow/Brewer_Sparrow_0036_107451.jpg
23 | CUB_200_2011/images/115.Brewer_Sparrow/Brewer_Sparrow_0041_796711.jpg
24 | CUB_200_2011/images/115.Brewer_Sparrow/Brewer_Sparrow_0014_107435.jpg
25 | CUB_200_2011/images/115.Brewer_Sparrow/Brewer_Sparrow_0076_107393.jpg
26 | CUB_200_2011/images/116.Chipping_Sparrow/Chipping_Sparrow_0064_108204.jpg
27 | CUB_200_2011/images/116.Chipping_Sparrow/Chipping_Sparrow_0038_109234.jpg
28 | CUB_200_2011/images/116.Chipping_Sparrow/Chipping_Sparrow_0098_108644.jpg
29 | CUB_200_2011/images/116.Chipping_Sparrow/Chipping_Sparrow_0110_108974.jpg
30 | CUB_200_2011/images/116.Chipping_Sparrow/Chipping_Sparrow_0023_108684.jpg
31 | CUB_200_2011/images/117.Clay_colored_Sparrow/Clay_Colored_Sparrow_0104_110699.jpg
32 | CUB_200_2011/images/117.Clay_colored_Sparrow/Clay_Colored_Sparrow_0098_110735.jpg
33 | CUB_200_2011/images/117.Clay_colored_Sparrow/Clay_Colored_Sparrow_0003_110672.jpg
34 | CUB_200_2011/images/117.Clay_colored_Sparrow/Clay_Colored_Sparrow_0029_110720.jpg
35 | CUB_200_2011/images/117.Clay_colored_Sparrow/Clay_Colored_Sparrow_0087_110946.jpg
36 | CUB_200_2011/images/118.House_Sparrow/House_Sparrow_0092_111413.jpg
37 | CUB_200_2011/images/118.House_Sparrow/House_Sparrow_0111_112968.jpg
38 | CUB_200_2011/images/118.House_Sparrow/House_Sparrow_0080_111099.jpg
39 | CUB_200_2011/images/118.House_Sparrow/House_Sparrow_0130_110985.jpg
40 | CUB_200_2011/images/118.House_Sparrow/House_Sparrow_0053_111388.jpg
41 | CUB_200_2011/images/119.Field_Sparrow/Field_Sparrow_0069_113827.jpg
42 | CUB_200_2011/images/119.Field_Sparrow/Field_Sparrow_0130_113846.jpg
43 | CUB_200_2011/images/119.Field_Sparrow/Field_Sparrow_0091_113486.jpg
44 | CUB_200_2011/images/119.Field_Sparrow/Field_Sparrow_0043_113607.jpg
45 | CUB_200_2011/images/119.Field_Sparrow/Field_Sparrow_0108_114154.jpg
46 | CUB_200_2011/images/120.Fox_Sparrow/Fox_Sparrow_0104_114908.jpg
47 | CUB_200_2011/images/120.Fox_Sparrow/Fox_Sparrow_0086_115484.jpg
48 | CUB_200_2011/images/120.Fox_Sparrow/Fox_Sparrow_0055_114809.jpg
49 | CUB_200_2011/images/120.Fox_Sparrow/Fox_Sparrow_0012_115324.jpg
50 | CUB_200_2011/images/120.Fox_Sparrow/Fox_Sparrow_0035_114866.jpg
51 |
--------------------------------------------------------------------------------
/data/index_list/cub200/session_9.txt:
--------------------------------------------------------------------------------
1 | CUB_200_2011/images/171.Myrtle_Warbler/Myrtle_Warbler_0023_166764.jpg
2 | CUB_200_2011/images/171.Myrtle_Warbler/Myrtle_Warbler_0050_166820.jpg
3 | CUB_200_2011/images/171.Myrtle_Warbler/Myrtle_Warbler_0043_166708.jpg
4 | CUB_200_2011/images/171.Myrtle_Warbler/Myrtle_Warbler_0098_166794.jpg
5 | CUB_200_2011/images/171.Myrtle_Warbler/Myrtle_Warbler_0015_166713.jpg
6 | CUB_200_2011/images/172.Nashville_Warbler/Nashville_Warbler_0108_167259.jpg
7 | CUB_200_2011/images/172.Nashville_Warbler/Nashville_Warbler_0098_167293.jpg
8 | CUB_200_2011/images/172.Nashville_Warbler/Nashville_Warbler_0104_167096.jpg
9 | CUB_200_2011/images/172.Nashville_Warbler/Nashville_Warbler_0110_167268.jpg
10 | CUB_200_2011/images/172.Nashville_Warbler/Nashville_Warbler_0081_167234.jpg
11 | CUB_200_2011/images/173.Orange_crowned_Warbler/Orange_Crowned_Warbler_0062_168119.jpg
12 | CUB_200_2011/images/173.Orange_crowned_Warbler/Orange_Crowned_Warbler_0050_168166.jpg
13 | CUB_200_2011/images/173.Orange_crowned_Warbler/Orange_Crowned_Warbler_0055_168600.jpg
14 | CUB_200_2011/images/173.Orange_crowned_Warbler/Orange_Crowned_Warbler_0118_167640.jpg
15 | CUB_200_2011/images/173.Orange_crowned_Warbler/Orange_Crowned_Warbler_0067_167588.jpg
16 | CUB_200_2011/images/174.Palm_Warbler/Palm_Warbler_0083_170281.jpg
17 | CUB_200_2011/images/174.Palm_Warbler/Palm_Warbler_0012_170857.jpg
18 | CUB_200_2011/images/174.Palm_Warbler/Palm_Warbler_0015_169626.jpg
19 | CUB_200_2011/images/174.Palm_Warbler/Palm_Warbler_0126_170311.jpg
20 | CUB_200_2011/images/174.Palm_Warbler/Palm_Warbler_0136_170276.jpg
21 | CUB_200_2011/images/175.Pine_Warbler/Pine_Warbler_0017_171678.jpg
22 | CUB_200_2011/images/175.Pine_Warbler/Pine_Warbler_0127_171742.jpg
23 | CUB_200_2011/images/175.Pine_Warbler/Pine_Warbler_0060_171635.jpg
24 | CUB_200_2011/images/175.Pine_Warbler/Pine_Warbler_0056_172064.jpg
25 | CUB_200_2011/images/175.Pine_Warbler/Pine_Warbler_0102_171147.jpg
26 | CUB_200_2011/images/176.Prairie_Warbler/Prairie_Warbler_0073_172771.jpg
27 | CUB_200_2011/images/176.Prairie_Warbler/Prairie_Warbler_0120_173097.jpg
28 | CUB_200_2011/images/176.Prairie_Warbler/Prairie_Warbler_0063_172682.jpg
29 | CUB_200_2011/images/176.Prairie_Warbler/Prairie_Warbler_0053_173290.jpg
30 | CUB_200_2011/images/176.Prairie_Warbler/Prairie_Warbler_0080_172724.jpg
31 | CUB_200_2011/images/177.Prothonotary_Warbler/Prothonotary_Warbler_0062_174412.jpg
32 | CUB_200_2011/images/177.Prothonotary_Warbler/Prothonotary_Warbler_0037_173418.jpg
33 | CUB_200_2011/images/177.Prothonotary_Warbler/Prothonotary_Warbler_0076_174118.jpg
34 | CUB_200_2011/images/177.Prothonotary_Warbler/Prothonotary_Warbler_0070_174650.jpg
35 | CUB_200_2011/images/177.Prothonotary_Warbler/Prothonotary_Warbler_0110_173857.jpg
36 | CUB_200_2011/images/178.Swainson_Warbler/Swainson_Warbler_0017_174685.jpg
37 | CUB_200_2011/images/178.Swainson_Warbler/Swainson_Warbler_0039_794859.jpg
38 | CUB_200_2011/images/178.Swainson_Warbler/Swainson_Warbler_0051_794900.jpg
39 | CUB_200_2011/images/178.Swainson_Warbler/Swainson_Warbler_0037_174691.jpg
40 | CUB_200_2011/images/178.Swainson_Warbler/Swainson_Warbler_0018_174715.jpg
41 | CUB_200_2011/images/179.Tennessee_Warbler/Tennessee_Warbler_0051_175015.jpg
42 | CUB_200_2011/images/179.Tennessee_Warbler/Tennessee_Warbler_0019_174786.jpg
43 | CUB_200_2011/images/179.Tennessee_Warbler/Tennessee_Warbler_0023_174977.jpg
44 | CUB_200_2011/images/179.Tennessee_Warbler/Tennessee_Warbler_0033_174772.jpg
45 | CUB_200_2011/images/179.Tennessee_Warbler/Tennessee_Warbler_0004_174997.jpg
46 | CUB_200_2011/images/180.Wilson_Warbler/Wilson_Warbler_0107_175320.jpg
47 | CUB_200_2011/images/180.Wilson_Warbler/Wilson_Warbler_0065_175924.jpg
48 | CUB_200_2011/images/180.Wilson_Warbler/Wilson_Warbler_0129_175256.jpg
49 | CUB_200_2011/images/180.Wilson_Warbler/Wilson_Warbler_0126_175368.jpg
50 | CUB_200_2011/images/180.Wilson_Warbler/Wilson_Warbler_0054_175285.jpg
51 |
--------------------------------------------------------------------------------
/data/index_list/cub200/session_4.txt:
--------------------------------------------------------------------------------
1 | CUB_200_2011/images/121.Grasshopper_Sparrow/Grasshopper_Sparrow_0014_116129.jpg
2 | CUB_200_2011/images/121.Grasshopper_Sparrow/Grasshopper_Sparrow_0114_116160.jpg
3 | CUB_200_2011/images/121.Grasshopper_Sparrow/Grasshopper_Sparrow_0068_115799.jpg
4 | CUB_200_2011/images/121.Grasshopper_Sparrow/Grasshopper_Sparrow_0110_115644.jpg
5 | CUB_200_2011/images/121.Grasshopper_Sparrow/Grasshopper_Sparrow_0042_115638.jpg
6 | CUB_200_2011/images/122.Harris_Sparrow/Harris_Sparrow_0006_116364.jpg
7 | CUB_200_2011/images/122.Harris_Sparrow/Harris_Sparrow_0018_116402.jpg
8 | CUB_200_2011/images/122.Harris_Sparrow/Harris_Sparrow_0026_116620.jpg
9 | CUB_200_2011/images/122.Harris_Sparrow/Harris_Sparrow_0020_116379.jpg
10 | CUB_200_2011/images/122.Harris_Sparrow/Harris_Sparrow_0011_116597.jpg
11 | CUB_200_2011/images/123.Henslow_Sparrow/Henslow_Sparrow_0023_796582.jpg
12 | CUB_200_2011/images/123.Henslow_Sparrow/Henslow_Sparrow_0052_796599.jpg
13 | CUB_200_2011/images/123.Henslow_Sparrow/Henslow_Sparrow_0054_116850.jpg
14 | CUB_200_2011/images/123.Henslow_Sparrow/Henslow_Sparrow_0064_796573.jpg
15 | CUB_200_2011/images/123.Henslow_Sparrow/Henslow_Sparrow_0070_796571.jpg
16 | CUB_200_2011/images/124.Le_Conte_Sparrow/Le_Conte_Sparrow_0040_117088.jpg
17 | CUB_200_2011/images/124.Le_Conte_Sparrow/Le_Conte_Sparrow_0072_795230.jpg
18 | CUB_200_2011/images/124.Le_Conte_Sparrow/Le_Conte_Sparrow_0068_795180.jpg
19 | CUB_200_2011/images/124.Le_Conte_Sparrow/Le_Conte_Sparrow_0081_795215.jpg
20 | CUB_200_2011/images/124.Le_Conte_Sparrow/Le_Conte_Sparrow_0032_795186.jpg
21 | CUB_200_2011/images/125.Lincoln_Sparrow/Lincoln_Sparrow_0084_117492.jpg
22 | CUB_200_2011/images/125.Lincoln_Sparrow/Lincoln_Sparrow_0009_117535.jpg
23 | CUB_200_2011/images/125.Lincoln_Sparrow/Lincoln_Sparrow_0014_117883.jpg
24 | CUB_200_2011/images/125.Lincoln_Sparrow/Lincoln_Sparrow_0042_117507.jpg
25 | CUB_200_2011/images/125.Lincoln_Sparrow/Lincoln_Sparrow_0072_117951.jpg
26 | CUB_200_2011/images/126.Nelson_Sharp_tailed_Sparrow/Nelson_Sharp_Tailed_Sparrow_0056_117974.jpg
27 | CUB_200_2011/images/126.Nelson_Sharp_tailed_Sparrow/Nelson_Sharp_Tailed_Sparrow_0002_796908.jpg
28 | CUB_200_2011/images/126.Nelson_Sharp_tailed_Sparrow/Nelson_Sharp_Tailed_Sparrow_0051_796902.jpg
29 | CUB_200_2011/images/126.Nelson_Sharp_tailed_Sparrow/Nelson_Sharp_Tailed_Sparrow_0014_796906.jpg
30 | CUB_200_2011/images/126.Nelson_Sharp_tailed_Sparrow/Nelson_Sharp_Tailed_Sparrow_0077_796913.jpg
31 | CUB_200_2011/images/127.Savannah_Sparrow/Savannah_Sparrow_0049_119596.jpg
32 | CUB_200_2011/images/127.Savannah_Sparrow/Savannah_Sparrow_0118_118603.jpg
33 | CUB_200_2011/images/127.Savannah_Sparrow/Savannah_Sparrow_0068_119972.jpg
34 | CUB_200_2011/images/127.Savannah_Sparrow/Savannah_Sparrow_0052_118583.jpg
35 | CUB_200_2011/images/127.Savannah_Sparrow/Savannah_Sparrow_0054_120057.jpg
36 | CUB_200_2011/images/128.Seaside_Sparrow/Seaside_Sparrow_0001_120720.jpg
37 | CUB_200_2011/images/128.Seaside_Sparrow/Seaside_Sparrow_0048_120758.jpg
38 | CUB_200_2011/images/128.Seaside_Sparrow/Seaside_Sparrow_0042_796528.jpg
39 | CUB_200_2011/images/128.Seaside_Sparrow/Seaside_Sparrow_0049_120735.jpg
40 | CUB_200_2011/images/128.Seaside_Sparrow/Seaside_Sparrow_0035_796533.jpg
41 | CUB_200_2011/images/129.Song_Sparrow/Song_Sparrow_0046_121903.jpg
42 | CUB_200_2011/images/129.Song_Sparrow/Song_Sparrow_0055_121158.jpg
43 | CUB_200_2011/images/129.Song_Sparrow/Song_Sparrow_0107_120990.jpg
44 | CUB_200_2011/images/129.Song_Sparrow/Song_Sparrow_0091_121651.jpg
45 | CUB_200_2011/images/129.Song_Sparrow/Song_Sparrow_0087_121062.jpg
46 | CUB_200_2011/images/130.Tree_Sparrow/Tree_Sparrow_0094_124974.jpg
47 | CUB_200_2011/images/130.Tree_Sparrow/Tree_Sparrow_0123_125324.jpg
48 | CUB_200_2011/images/130.Tree_Sparrow/Tree_Sparrow_0041_123497.jpg
49 | CUB_200_2011/images/130.Tree_Sparrow/Tree_Sparrow_0086_123751.jpg
50 | CUB_200_2011/images/130.Tree_Sparrow/Tree_Sparrow_0119_124114.jpg
51 |
--------------------------------------------------------------------------------
/models/resnet20_cifar.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | def conv3x3(in_planes, out_planes, stride=1):
6 | """3x3 convolution with padding"""
7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
8 | padding=1, bias=False)
9 |
10 | class BasicBlock(nn.Module):
11 | expansion = 1
12 |
13 | def __init__(self, inplanes, planes, stride=1, downsample=None, last=False):
14 | super(BasicBlock, self).__init__()
15 | self.conv1 = conv3x3(inplanes, planes, stride)
16 | self.bn1 = nn.BatchNorm2d(planes)
17 | self.relu = nn.ReLU(inplace=True)
18 | self.conv2 = conv3x3(planes, planes)
19 | self.bn2 = nn.BatchNorm2d(planes)
20 | self.downsample = downsample
21 | self.stride = stride
22 | self.last = last
23 |
24 | def forward(self, x, train=True):
25 | residual = x
26 |
27 | out = self.conv1(x)
28 | out = self.bn1(out)
29 | out = self.relu(out)
30 |
31 | out = self.conv2(out)
32 | out = self.bn2(out)
33 |
34 | if self.downsample is not None:
35 | residual = self.downsample(x)
36 |
37 | out += residual
38 |
39 | out = self.relu(out)
40 |
41 | return out
42 |
43 | class ResNet(nn.Module):
44 |
45 | def __init__(self, block, layers, num_classes=10):
46 | self.inplanes = 16
47 | super(ResNet, self).__init__()
48 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
49 | bias=False)
50 | self.bn1 = nn.BatchNorm2d(16)
51 | self.relu = nn.ReLU(inplace=True)
52 | self.layer1 = self._make_layer(block, 16, layers[0])
53 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
54 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2, last_phase=True)
55 | # self.avgpool = nn.AvgPool2d(8, stride=1)
56 |
57 | for m in self.modules():
58 | if isinstance(m, nn.Conv2d):
59 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
60 | elif isinstance(m, nn.BatchNorm2d):
61 | nn.init.constant_(m.weight, 1)
62 | nn.init.constant_(m.bias, 0)
63 |
64 | def _make_layer(self, block, planes, blocks, stride=1, last_phase=False):
65 | downsample = None
66 | if stride != 1 or self.inplanes != planes * block.expansion:
67 | downsample = nn.Sequential(
68 | nn.Conv2d(self.inplanes, planes * block.expansion,
69 | kernel_size=1, stride=stride, bias=False),
70 | nn.BatchNorm2d(planes * block.expansion),
71 | )
72 |
73 | layers = []
74 | layers.append(block(self.inplanes, planes, stride, downsample))
75 | self.inplanes = planes * block.expansion
76 | if last_phase:
77 | for i in range(1, blocks-1):
78 | layers.append(block(self.inplanes, planes))
79 | layers.append(block(self.inplanes, planes, last=True))
80 | else:
81 | for i in range(1, blocks):
82 | layers.append(block(self.inplanes, planes))
83 |
84 | return nn.Sequential(*layers)
85 |
86 | def forward(self, x):
87 | x = self.conv1(x)
88 | x = self.bn1(x)
89 | x = self.relu(x)
90 |
91 | x = self.layer1(x)
92 | x = self.layer2(x)
93 | x = self.layer3(x)
94 |
95 | # x = self.avgpool(x)
96 | # x = x.view(x.size(0), -1)
97 | # x = self.fc(x)
98 |
99 | return x
100 |
101 | def resnet20(**kwargs):
102 | n = 3
103 | model = ResNet(BasicBlock, [n, n, n], **kwargs)
104 | return model
105 |
--------------------------------------------------------------------------------
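Note on /models/resnet20_cifar.py above: the average-pooling and fully-connected layers are commented out, so resnet20() is used purely as an encoder that returns a spatial feature map; pooling is applied later in MYNET.encode. A minimal sketch to confirm the output shape, assuming a standard 32x32 CIFAR input:

    import torch
    from models.resnet20_cifar import resnet20

    encoder = resnet20()
    x = torch.randn(2, 3, 32, 32)   # dummy CIFAR-sized batch
    feat = encoder(x)               # three stages: stride 1, 2, 2 -> 64 channels at 8x8
    print(feat.shape)               # expected: torch.Size([2, 64, 8, 8])
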
/data/index_list/cub200/session_8.txt:
--------------------------------------------------------------------------------
1 | CUB_200_2011/images/161.Blue_winged_Warbler/Blue_Winged_Warbler_0071_161900.jpg
2 | CUB_200_2011/images/161.Blue_winged_Warbler/Blue_Winged_Warbler_0035_161741.jpg
3 | CUB_200_2011/images/161.Blue_winged_Warbler/Blue_Winged_Warbler_0054_161862.jpg
4 | CUB_200_2011/images/161.Blue_winged_Warbler/Blue_Winged_Warbler_0023_161774.jpg
5 | CUB_200_2011/images/161.Blue_winged_Warbler/Blue_Winged_Warbler_0040_161883.jpg
6 | CUB_200_2011/images/162.Canada_Warbler/Canada_Warbler_0113_162403.jpg
7 | CUB_200_2011/images/162.Canada_Warbler/Canada_Warbler_0064_162417.jpg
8 | CUB_200_2011/images/162.Canada_Warbler/Canada_Warbler_0091_162378.jpg
9 | CUB_200_2011/images/162.Canada_Warbler/Canada_Warbler_0016_162411.jpg
10 | CUB_200_2011/images/162.Canada_Warbler/Canada_Warbler_0080_162392.jpg
11 | CUB_200_2011/images/163.Cape_May_Warbler/Cape_May_Warbler_0012_162701.jpg
12 | CUB_200_2011/images/163.Cape_May_Warbler/Cape_May_Warbler_0103_162972.jpg
13 | CUB_200_2011/images/163.Cape_May_Warbler/Cape_May_Warbler_0022_162912.jpg
14 | CUB_200_2011/images/163.Cape_May_Warbler/Cape_May_Warbler_0005_163197.jpg
15 | CUB_200_2011/images/163.Cape_May_Warbler/Cape_May_Warbler_0032_162659.jpg
16 | CUB_200_2011/images/164.Cerulean_Warbler/Cerulean_Warbler_0039_163420.jpg
17 | CUB_200_2011/images/164.Cerulean_Warbler/Cerulean_Warbler_0020_163353.jpg
18 | CUB_200_2011/images/164.Cerulean_Warbler/Cerulean_Warbler_0014_797226.jpg
19 | CUB_200_2011/images/164.Cerulean_Warbler/Cerulean_Warbler_0072_163200.jpg
20 | CUB_200_2011/images/164.Cerulean_Warbler/Cerulean_Warbler_0080_163399.jpg
21 | CUB_200_2011/images/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0128_163696.jpg
22 | CUB_200_2011/images/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0097_163750.jpg
23 | CUB_200_2011/images/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0094_164152.jpg
24 | CUB_200_2011/images/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0105_163996.jpg
25 | CUB_200_2011/images/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0101_164324.jpg
26 | CUB_200_2011/images/166.Golden_winged_Warbler/Golden_Winged_Warbler_0079_794820.jpg
27 | CUB_200_2011/images/166.Golden_winged_Warbler/Golden_Winged_Warbler_0046_794828.jpg
28 | CUB_200_2011/images/166.Golden_winged_Warbler/Golden_Winged_Warbler_0061_164516.jpg
29 | CUB_200_2011/images/166.Golden_winged_Warbler/Golden_Winged_Warbler_0068_794825.jpg
30 | CUB_200_2011/images/166.Golden_winged_Warbler/Golden_Winged_Warbler_0011_794812.jpg
31 | CUB_200_2011/images/167.Hooded_Warbler/Hooded_Warbler_0040_165173.jpg
32 | CUB_200_2011/images/167.Hooded_Warbler/Hooded_Warbler_0001_164704.jpg
33 | CUB_200_2011/images/167.Hooded_Warbler/Hooded_Warbler_0021_165057.jpg
34 | CUB_200_2011/images/167.Hooded_Warbler/Hooded_Warbler_0058_164674.jpg
35 | CUB_200_2011/images/167.Hooded_Warbler/Hooded_Warbler_0053_164631.jpg
36 | CUB_200_2011/images/168.Kentucky_Warbler/Kentucky_Warbler_0008_165369.jpg
37 | CUB_200_2011/images/168.Kentucky_Warbler/Kentucky_Warbler_0035_795878.jpg
38 | CUB_200_2011/images/168.Kentucky_Warbler/Kentucky_Warbler_0050_165278.jpg
39 | CUB_200_2011/images/168.Kentucky_Warbler/Kentucky_Warbler_0071_165342.jpg
40 | CUB_200_2011/images/168.Kentucky_Warbler/Kentucky_Warbler_0072_165305.jpg
41 | CUB_200_2011/images/169.Magnolia_Warbler/Magnolia_Warbler_0041_165709.jpg
42 | CUB_200_2011/images/169.Magnolia_Warbler/Magnolia_Warbler_0092_165807.jpg
43 | CUB_200_2011/images/169.Magnolia_Warbler/Magnolia_Warbler_0029_165567.jpg
44 | CUB_200_2011/images/169.Magnolia_Warbler/Magnolia_Warbler_0030_165782.jpg
45 | CUB_200_2011/images/169.Magnolia_Warbler/Magnolia_Warbler_0053_165682.jpg
46 | CUB_200_2011/images/170.Mourning_Warbler/Mourning_Warbler_0069_166559.jpg
47 | CUB_200_2011/images/170.Mourning_Warbler/Mourning_Warbler_0035_166586.jpg
48 | CUB_200_2011/images/170.Mourning_Warbler/Mourning_Warbler_0002_166520.jpg
49 | CUB_200_2011/images/170.Mourning_Warbler/Mourning_Warbler_0079_166564.jpg
50 | CUB_200_2011/images/170.Mourning_Warbler/Mourning_Warbler_0015_166535.jpg
51 |
--------------------------------------------------------------------------------
/data/index_list/cub200/session_7.txt:
--------------------------------------------------------------------------------
1 | CUB_200_2011/images/151.Black_capped_Vireo/Black_Capped_Vireo_0012_797473.jpg
2 | CUB_200_2011/images/151.Black_capped_Vireo/Black_Capped_Vireo_0007_797481.jpg
3 | CUB_200_2011/images/151.Black_capped_Vireo/Black_Capped_Vireo_0020_797461.jpg
4 | CUB_200_2011/images/151.Black_capped_Vireo/Black_Capped_Vireo_0053_797478.jpg
5 | CUB_200_2011/images/151.Black_capped_Vireo/Black_Capped_Vireo_0003_797467.jpg
6 | CUB_200_2011/images/152.Blue_headed_Vireo/Blue_Headed_Vireo_0097_156272.jpg
7 | CUB_200_2011/images/152.Blue_headed_Vireo/Blue_Headed_Vireo_0019_156311.jpg
8 | CUB_200_2011/images/152.Blue_headed_Vireo/Blue_Headed_Vireo_0121_156233.jpg
9 | CUB_200_2011/images/152.Blue_headed_Vireo/Blue_Headed_Vireo_0011_156276.jpg
10 | CUB_200_2011/images/152.Blue_headed_Vireo/Blue_Headed_Vireo_0119_156259.jpg
11 | CUB_200_2011/images/153.Philadelphia_Vireo/Philadelphia_Vireo_0078_794776.jpg
12 | CUB_200_2011/images/153.Philadelphia_Vireo/Philadelphia_Vireo_0039_794794.jpg
13 | CUB_200_2011/images/153.Philadelphia_Vireo/Philadelphia_Vireo_0068_794763.jpg
14 | CUB_200_2011/images/153.Philadelphia_Vireo/Philadelphia_Vireo_0012_794785.jpg
15 | CUB_200_2011/images/153.Philadelphia_Vireo/Philadelphia_Vireo_0013_794772.jpg
16 | CUB_200_2011/images/154.Red_eyed_Vireo/Red_Eyed_Vireo_0101_156988.jpg
17 | CUB_200_2011/images/154.Red_eyed_Vireo/Red_Eyed_Vireo_0006_157025.jpg
18 | CUB_200_2011/images/154.Red_eyed_Vireo/Red_Eyed_Vireo_0041_156954.jpg
19 | CUB_200_2011/images/154.Red_eyed_Vireo/Red_Eyed_Vireo_0115_157004.jpg
20 | CUB_200_2011/images/154.Red_eyed_Vireo/Red_Eyed_Vireo_0056_156968.jpg
21 | CUB_200_2011/images/155.Warbling_Vireo/Warbling_Vireo_0075_158480.jpg
22 | CUB_200_2011/images/155.Warbling_Vireo/Warbling_Vireo_0061_158494.jpg
23 | CUB_200_2011/images/155.Warbling_Vireo/Warbling_Vireo_0004_158376.jpg
24 | CUB_200_2011/images/155.Warbling_Vireo/Warbling_Vireo_0030_158488.jpg
25 | CUB_200_2011/images/155.Warbling_Vireo/Warbling_Vireo_0077_158427.jpg
26 | CUB_200_2011/images/156.White_eyed_Vireo/White_Eyed_Vireo_0042_159012.jpg
27 | CUB_200_2011/images/156.White_eyed_Vireo/White_Eyed_Vireo_0033_159079.jpg
28 | CUB_200_2011/images/156.White_eyed_Vireo/White_Eyed_Vireo_0126_159341.jpg
29 | CUB_200_2011/images/156.White_eyed_Vireo/White_Eyed_Vireo_0071_159072.jpg
30 | CUB_200_2011/images/156.White_eyed_Vireo/White_Eyed_Vireo_0016_158978.jpg
31 | CUB_200_2011/images/157.Yellow_throated_Vireo/Yellow_Throated_Vireo_0066_795007.jpg
32 | CUB_200_2011/images/157.Yellow_throated_Vireo/Yellow_Throated_Vireo_0032_159632.jpg
33 | CUB_200_2011/images/157.Yellow_throated_Vireo/Yellow_Throated_Vireo_0017_794988.jpg
34 | CUB_200_2011/images/157.Yellow_throated_Vireo/Yellow_Throated_Vireo_0025_795009.jpg
35 | CUB_200_2011/images/157.Yellow_throated_Vireo/Yellow_Throated_Vireo_0058_794994.jpg
36 | CUB_200_2011/images/158.Bay_breasted_Warbler/Bay_Breasted_Warbler_0073_797138.jpg
37 | CUB_200_2011/images/158.Bay_breasted_Warbler/Bay_Breasted_Warbler_0081_159963.jpg
38 | CUB_200_2011/images/158.Bay_breasted_Warbler/Bay_Breasted_Warbler_0071_797108.jpg
39 | CUB_200_2011/images/158.Bay_breasted_Warbler/Bay_Breasted_Warbler_0105_797143.jpg
40 | CUB_200_2011/images/158.Bay_breasted_Warbler/Bay_Breasted_Warbler_0052_797125.jpg
41 | CUB_200_2011/images/159.Black_and_white_Warbler/Black_And_White_Warbler_0057_160037.jpg
42 | CUB_200_2011/images/159.Black_and_white_Warbler/Black_And_White_Warbler_0035_160102.jpg
43 | CUB_200_2011/images/159.Black_and_white_Warbler/Black_And_White_Warbler_0119_160898.jpg
44 | CUB_200_2011/images/159.Black_and_white_Warbler/Black_And_White_Warbler_0102_160073.jpg
45 | CUB_200_2011/images/159.Black_and_white_Warbler/Black_And_White_Warbler_0022_160512.jpg
46 | CUB_200_2011/images/160.Black_throated_Blue_Warbler/Black_Throated_Blue_Warbler_0050_161154.jpg
47 | CUB_200_2011/images/160.Black_throated_Blue_Warbler/Black_Throated_Blue_Warbler_0130_161682.jpg
48 | CUB_200_2011/images/160.Black_throated_Blue_Warbler/Black_Throated_Blue_Warbler_0133_161539.jpg
49 | CUB_200_2011/images/160.Black_throated_Blue_Warbler/Black_Throated_Blue_Warbler_0054_161158.jpg
50 | CUB_200_2011/images/160.Black_throated_Blue_Warbler/Black_Throated_Blue_Warbler_0024_161619.jpg
51 |
--------------------------------------------------------------------------------
/data/index_list/cub200/session_10.txt:
--------------------------------------------------------------------------------
1 | CUB_200_2011/images/181.Worm_eating_Warbler/Worm_Eating_Warbler_0063_795553.jpg
2 | CUB_200_2011/images/181.Worm_eating_Warbler/Worm_Eating_Warbler_0011_795566.jpg
3 | CUB_200_2011/images/181.Worm_eating_Warbler/Worm_Eating_Warbler_0092_795524.jpg
4 | CUB_200_2011/images/181.Worm_eating_Warbler/Worm_Eating_Warbler_0006_176037.jpg
5 | CUB_200_2011/images/181.Worm_eating_Warbler/Worm_Eating_Warbler_0018_795546.jpg
6 | CUB_200_2011/images/182.Yellow_Warbler/Yellow_Warbler_0083_176292.jpg
7 | CUB_200_2011/images/182.Yellow_Warbler/Yellow_Warbler_0096_176586.jpg
8 | CUB_200_2011/images/182.Yellow_Warbler/Yellow_Warbler_0119_176485.jpg
9 | CUB_200_2011/images/182.Yellow_Warbler/Yellow_Warbler_0102_176821.jpg
10 | CUB_200_2011/images/182.Yellow_Warbler/Yellow_Warbler_0049_176526.jpg
11 | CUB_200_2011/images/183.Northern_Waterthrush/Northern_Waterthrush_0043_177070.jpg
12 | CUB_200_2011/images/183.Northern_Waterthrush/Northern_Waterthrush_0080_177080.jpg
13 | CUB_200_2011/images/183.Northern_Waterthrush/Northern_Waterthrush_0022_177003.jpg
14 | CUB_200_2011/images/183.Northern_Waterthrush/Northern_Waterthrush_0050_177331.jpg
15 | CUB_200_2011/images/183.Northern_Waterthrush/Northern_Waterthrush_0014_177305.jpg
16 | CUB_200_2011/images/184.Louisiana_Waterthrush/Louisiana_Waterthrush_0087_795261.jpg
17 | CUB_200_2011/images/184.Louisiana_Waterthrush/Louisiana_Waterthrush_0001_795271.jpg
18 | CUB_200_2011/images/184.Louisiana_Waterthrush/Louisiana_Waterthrush_0034_795242.jpg
19 | CUB_200_2011/images/184.Louisiana_Waterthrush/Louisiana_Waterthrush_0020_795265.jpg
20 | CUB_200_2011/images/184.Louisiana_Waterthrush/Louisiana_Waterthrush_0077_795247.jpg
21 | CUB_200_2011/images/185.Bohemian_Waxwing/Bohemian_Waxwing_0046_177864.jpg
22 | CUB_200_2011/images/185.Bohemian_Waxwing/Bohemian_Waxwing_0042_177887.jpg
23 | CUB_200_2011/images/185.Bohemian_Waxwing/Bohemian_Waxwing_0024_177661.jpg
24 | CUB_200_2011/images/185.Bohemian_Waxwing/Bohemian_Waxwing_0031_796633.jpg
25 | CUB_200_2011/images/185.Bohemian_Waxwing/Bohemian_Waxwing_0048_177821.jpg
26 | CUB_200_2011/images/186.Cedar_Waxwing/Cedar_Waxwing_0094_178049.jpg
27 | CUB_200_2011/images/186.Cedar_Waxwing/Cedar_Waxwing_0016_178629.jpg
28 | CUB_200_2011/images/186.Cedar_Waxwing/Cedar_Waxwing_0125_178921.jpg
29 | CUB_200_2011/images/186.Cedar_Waxwing/Cedar_Waxwing_0004_179215.jpg
30 | CUB_200_2011/images/186.Cedar_Waxwing/Cedar_Waxwing_0065_179017.jpg
31 | CUB_200_2011/images/187.American_Three_toed_Woodpecker/American_Three_Toed_Woodpecker_0018_179831.jpg
32 | CUB_200_2011/images/187.American_Three_toed_Woodpecker/American_Three_Toed_Woodpecker_0007_179932.jpg
33 | CUB_200_2011/images/187.American_Three_toed_Woodpecker/American_Three_Toed_Woodpecker_0024_179876.jpg
34 | CUB_200_2011/images/187.American_Three_toed_Woodpecker/American_Three_Toed_Woodpecker_0009_179919.jpg
35 | CUB_200_2011/images/187.American_Three_toed_Woodpecker/American_Three_Toed_Woodpecker_0012_179905.jpg
36 | CUB_200_2011/images/188.Pileated_Woodpecker/Pileated_Woodpecker_0056_180094.jpg
37 | CUB_200_2011/images/188.Pileated_Woodpecker/Pileated_Woodpecker_0034_180419.jpg
38 | CUB_200_2011/images/188.Pileated_Woodpecker/Pileated_Woodpecker_0110_180521.jpg
39 | CUB_200_2011/images/188.Pileated_Woodpecker/Pileated_Woodpecker_0079_180388.jpg
40 | CUB_200_2011/images/188.Pileated_Woodpecker/Pileated_Woodpecker_0088_180054.jpg
41 | CUB_200_2011/images/189.Red_bellied_Woodpecker/Red_Bellied_Woodpecker_0112_180827.jpg
42 | CUB_200_2011/images/189.Red_bellied_Woodpecker/Red_Bellied_Woodpecker_0017_181131.jpg
43 | CUB_200_2011/images/189.Red_bellied_Woodpecker/Red_Bellied_Woodpecker_0125_180780.jpg
44 | CUB_200_2011/images/189.Red_bellied_Woodpecker/Red_Bellied_Woodpecker_0086_181891.jpg
45 | CUB_200_2011/images/189.Red_bellied_Woodpecker/Red_Bellied_Woodpecker_0020_182335.jpg
46 | CUB_200_2011/images/190.Red_cockaded_Woodpecker/Red_Cockaded_Woodpecker_0023_794701.jpg
47 | CUB_200_2011/images/190.Red_cockaded_Woodpecker/Red_Cockaded_Woodpecker_0033_794721.jpg
48 | CUB_200_2011/images/190.Red_cockaded_Woodpecker/Red_Cockaded_Woodpecker_0027_794713.jpg
49 | CUB_200_2011/images/190.Red_cockaded_Woodpecker/Red_Cockaded_Woodpecker_0029_794724.jpg
50 | CUB_200_2011/images/190.Red_cockaded_Woodpecker/Red_Cockaded_Woodpecker_0039_794736.jpg
51 |
--------------------------------------------------------------------------------
/models/teen/Network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.resnet18_encoder import *
5 | from models.resnet20_cifar import *
6 |
7 |
8 | class MYNET(nn.Module):
9 |
10 | def __init__(self, args, mode=None):
11 | super().__init__()
12 |
13 | self.mode = mode
14 | self.args = args
15 | if self.args.dataset in ['cifar100','manyshotcifar']:
16 | self.encoder = resnet20()
17 | self.num_features = 64
18 | if self.args.dataset in ['mini_imagenet','manyshotmini','imagenet100','imagenet1000', 'mini_imagenet_withpath']:
19 | self.encoder = resnet18(False, args) # pretrained=False
20 | self.num_features = 512
21 | if self.args.dataset in ['cub200','manyshotcub']:
22 | self.encoder = resnet18(True, args)  # pretrained=True, following TOPIC: the CUB encoder is ImageNet pre-trained. https://github.com/xyutao/fscil/issues/11#issuecomment-687548790
23 | self.num_features = 512
24 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
25 |
26 | self.fc = nn.Linear(self.num_features, self.args.num_classes, bias=False)
27 |
28 | def forward_metric(self, x):
29 | x = self.encode(x)
30 | if 'cos' in self.mode:
31 | x = F.linear(F.normalize(x, p=2, dim=-1), F.normalize(self.fc.weight, p=2, dim=-1))
32 | x = self.args.temperature * x
33 |
34 | elif 'dot' in self.mode:
35 | x = self.fc(x)
36 | x = self.args.temperature * x
37 | return x
38 |
39 | def encode(self, x):
40 | x = self.encoder(x)
41 | x = F.adaptive_avg_pool2d(x, 1)
42 | x = x.squeeze(-1).squeeze(-1)
43 | return x
44 |
45 | def forward(self, input):
46 | if self.mode != 'encoder':
47 | input = self.forward_metric(input)
48 | return input
49 | elif self.mode == 'encoder':
50 | input = self.encode(input)
51 | return input
52 | else:
53 | raise ValueError('Unknown mode')
54 |
55 | def update_fc(self,dataloader,class_list,session):
56 | for batch in dataloader:
57 | data, label = [_.cuda() for _ in batch]
58 | data=self.encode(data).detach()
59 |
60 | if self.args.not_data_init:
61 | new_fc = nn.Parameter(
62 | torch.rand(len(class_list), self.num_features, device="cuda"),
63 | requires_grad=True)
64 | nn.init.kaiming_uniform_(new_fc, a=math.sqrt(5))
65 | else:
66 | new_fc = self.update_fc_avg(data, label, class_list)
67 |
68 | def update_fc_avg(self,data,label,class_list):
69 | new_fc=[]
70 | for class_index in class_list:
71 | data_index=(label==class_index).nonzero().squeeze(-1)
72 | embedding=data[data_index]
73 | proto=embedding.mean(0)
74 | new_fc.append(proto)
75 | self.fc.weight.data[class_index]=proto
76 | new_fc=torch.stack(new_fc,dim=0)
77 | return new_fc
78 |
79 | def get_logits(self,x,fc):
80 | if 'dot' in self.args.new_mode:
81 | return F.linear(x,fc)
82 | elif 'cos' in self.args.new_mode:
83 | return self.args.temperature * F.linear(F.normalize(x, p=2, dim=-1), F.normalize(fc, p=2, dim=-1))
84 |
85 | def soft_calibration(self, args, session):
86 | base_protos = self.fc.weight.data[:args.base_class].detach().cpu().data
87 | base_protos = F.normalize(base_protos, p=2, dim=-1)
88 |
89 | cur_protos = self.fc.weight.data[args.base_class + (session-1) * args.way : args.base_class + session * args.way].detach().cpu().data
90 | cur_protos = F.normalize(cur_protos, p=2, dim=-1)
91 |
92 | weights = torch.mm(cur_protos, base_protos.T) * args.softmax_t
93 | norm_weights = torch.softmax(weights, dim=1)
94 | delta_protos = torch.matmul(norm_weights, base_protos)
95 |
96 | delta_protos = F.normalize(delta_protos, p=2, dim=-1)
97 |
98 | updated_protos = (1-args.shift_weight) * cur_protos + args.shift_weight * delta_protos
99 |
100 | self.fc.weight.data[args.base_class + (session-1) * args.way : args.base_class + session * args.way] = updated_protos
--------------------------------------------------------------------------------
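Note on /models/teen/Network.py above: soft_calibration is the core TEEN step. Each new-class prototype is pulled toward a softmax-weighted mixture of the normalized base-class prototypes, with softmax_t acting as a temperature and shift_weight as the interpolation coefficient; the calibrated rows then overwrite fc.weight for the current session's classes. A self-contained sketch of the same arithmetic on random tensors (the shapes of 60 base classes, 5 new classes and 64-dim features are illustrative only):

    import torch
    import torch.nn.functional as F

    base_protos = F.normalize(torch.randn(60, 64), p=2, dim=-1)  # normalized base-class prototypes
    cur_protos = F.normalize(torch.randn(5, 64), p=2, dim=-1)    # normalized new-class prototypes
    softmax_t, shift_weight = 16.0, 0.5                          # defaults exposed in train.py

    weights = torch.softmax(cur_protos @ base_protos.T * softmax_t, dim=1)    # similarity over base classes
    delta_protos = F.normalize(weights @ base_protos, p=2, dim=-1)            # calibration direction
    calibrated = (1 - shift_weight) * cur_protos + shift_weight * delta_protos
    print(calibrated.shape)  # torch.Size([5, 64])
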
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import importlib
3 | import logging
4 | import sys
5 |
6 | from postprocess_path import set_save_path
7 | from utils import Logger, pprint, set_gpu, set_logging, set_seed
8 |
9 |
10 | def get_command_line_parser():
11 | parser = argparse.ArgumentParser()
12 | # about dataset and network
13 | parser.add_argument('-project', type=str, default='base', choices=['teen'])
14 | parser.add_argument('-dataset', type=str, default='cifar100',
15 | choices=['mini_imagenet', 'cub200', 'cifar100'])
16 | parser.add_argument('-dataroot', type=str, default='')
17 | parser.add_argument('-temperature', type=float, default=16)
18 | parser.add_argument('-feat_norm', action='store_true', help='If True, normalize the feature.')
19 |
20 | # about pre-training
21 | parser.add_argument('-epochs_base', type=int, default=100)
22 | parser.add_argument('-epochs_new', type=int, default=100)
23 | parser.add_argument('-lr_base', type=float, default=0.1)
24 | parser.add_argument('-lr_new', type=float, default=0.1)
25 |
26 | ## optimizer & scheduler
27 | parser.add_argument('-optim', type=str, default='sgd', choices=['sgd', 'adam'])
28 | parser.add_argument('-schedule', type=str, default='Step', choices=['Step', 'Milestone','Cosine'])
29 | parser.add_argument('-milestones', nargs='+', type=int, default=[60, 70])
30 | parser.add_argument('-step', type=int, default=20)
31 | parser.add_argument('-decay', type=float, default=0.0005)
32 | parser.add_argument('-momentum', type=float, default=0.9)
33 | parser.add_argument('-gamma', type=float, default=0.1)
34 | parser.add_argument('-tmax', type=int, default=600)  # T_max for the cosine scheduler
35 |
36 | parser.add_argument('-not_data_init', action='store_true', help='if set, skip initializing the classifier with the average data embedding')
37 | parser.add_argument('-batch_size_base', type=int, default=128)
38 | parser.add_argument('-batch_size_new', type=int, default=0, help='0 means use all available training images of the new session in one batch')
39 | parser.add_argument('-test_batch_size', type=int, default=100)
40 | parser.add_argument('-base_mode', type=str, default='ft_cos',
41 | choices=['ft_dot', 'ft_cos']) # ft_dot means using linear classifier, ft_cos means using cosine classifier
42 | parser.add_argument('-new_mode', type=str, default='avg_cos',
43 | choices=['ft_dot', 'ft_cos', 'avg_cos']) # ft_dot means using linear classifier, ft_cos means using cosine classifier, avg_cos means using average data embedding and cosine classifier
44 |
45 | parser.add_argument('-start_session', type=int, default=0)
46 | parser.add_argument('-model_dir', type=str, default=None, help='loading model parameter from a specific dir')
47 | parser.add_argument('-only_do_incre', action='store_true', help='load a trained base model and run only the incremental sessions')
48 |
49 | # about training
50 | parser.add_argument('-gpu', default='0,1,2,3')
51 | parser.add_argument('-num_workers', type=int, default=8)
52 | parser.add_argument('-seed', type=int, default=1)
53 | parser.add_argument('-debug', action='store_true')
54 |
55 | return parser
56 |
57 | def add_commond_line_parser(params):
58 | project = params[1]
59 | # base parser
60 | parser = get_command_line_parser()
61 |
62 | if project == 'base':
63 | args = parser.parse_args(params[2:])
64 | return args
65 |
66 | elif project == 'teen':
67 | parser.add_argument('-softmax_t', type=float, default=16)
68 | parser.add_argument('-shift_weight', type=float, default=0.5, help='weights of delta prototypes')
69 | parser.add_argument('-soft_mode', type=str, default='soft_proto', choices=['soft_proto', 'soft_embed', 'hard_proto'])
70 | args = parser.parse_args(params[2:])
71 | return args
72 | else:
73 | raise NotImplementedError
74 |
75 | if __name__ == '__main__':
76 | args = add_commond_line_parser(sys.argv)
77 |
78 | set_seed(args.seed)
79 | pprint(vars(args))
80 | args.num_gpu = set_gpu(args)
81 |
82 | set_save_path(args)
83 |
84 | logger = Logger(args, args.save_path)
85 | set_logging('INFO', args.save_path)
86 | logging.info(f"save_path: {args.save_path}")
87 | trainer = importlib.import_module('models.%s.fscil_trainer' % (args.project)).FSCILTrainer(args)
88 | trainer.train()
89 |
--------------------------------------------------------------------------------
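Note on /train.py above: the script reads the project name twice, once positionally from sys.argv[1] to decide which extra flags to register, and once from the -project flag when the trainer module is imported, so both should be given. The exact hyper-parameters used per dataset live in scripts/*.sh; the following is only a hedged sketch of driving the base parser programmatically (it assumes the repository root is on the Python path and the dependencies are installed):

    from train import get_command_line_parser

    # rough equivalent of: python train.py teen -project teen -dataset cifar100 -gpu 0
    parser = get_command_line_parser()
    args = parser.parse_args(['-project', 'teen', '-dataset', 'cifar100', '-gpu', '0'])
    print(args.dataset, args.epochs_base, args.base_mode)  # cifar100 100 ft_cos

The TEEN-specific flags (-softmax_t, -shift_weight, -soft_mode) are only registered when sys.argv[1] == 'teen', i.e. when going through add_commond_line_parser.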
/models/teen/helper.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from tqdm import tqdm
3 | import torch.nn.functional as F
4 | import logging
5 |
6 |
7 | def base_train(model, trainloader, optimizer, scheduler, epoch, args):
8 | tl = Averager()
9 | ta = Averager()
10 | model = model.train()
11 | # standard classification for pretrain
12 | tqdm_gen = tqdm(trainloader)
13 | for i, batch in enumerate(tqdm_gen, 1):
14 | data, train_label = [_.cuda() for _ in batch]
15 |
16 | logits = model(data)
17 | logits = logits[:, :args.base_class]
18 | loss = F.cross_entropy(logits, train_label)
19 | acc = count_acc(logits, train_label)
20 |
21 | total_loss = loss
22 |
23 | lrc = scheduler.get_last_lr()[0]
24 | tqdm_gen.set_description(
25 | 'Session 0, epo {}, lrc={:.4f},total loss={:.4f} acc={:.4f}'.format(epoch, lrc, total_loss.item(), acc))
26 | tl.add(total_loss.item())
27 | ta.add(acc)
28 |
29 | optimizer.zero_grad()
30 | loss.backward()
31 | optimizer.step()
32 | tl = tl.item()
33 | ta = ta.item()
34 | return tl, ta
35 |
36 | def replace_base_fc(trainset, transform, model, args):
37 | # replace fc.weight with the embedding average of train data
38 | model = model.eval()
39 |
40 | trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=128,
41 | num_workers=8, pin_memory=True, shuffle=False)
42 | trainloader.dataset.transform = transform
43 | embedding_list = []
44 | label_list = []
45 | with torch.no_grad():
46 | for i, batch in enumerate(trainloader):
47 | data, label = [_.cuda() for _ in batch]
48 | model.module.mode = 'encoder'
49 | embedding = model(data)
50 |
51 | embedding_list.append(embedding.cpu())
52 | label_list.append(label.cpu())
53 | embedding_list = torch.cat(embedding_list, dim=0)
54 | label_list = torch.cat(label_list, dim=0)
55 |
56 | proto_list = []
57 |
58 | for class_index in range(args.base_class):
59 | data_index = (label_list == class_index).nonzero()
60 | embedding_this = embedding_list[data_index.squeeze(-1)]
61 | embedding_this = embedding_this.mean(0)
62 | proto_list.append(embedding_this)
63 |
64 | proto_list = torch.stack(proto_list, dim=0)
65 |
66 | model.module.fc.weight.data[:args.base_class] = proto_list
67 |
68 | return model
69 |
70 | def test(model, testloader, epoch, args, session, result_list=None):
71 | test_class = args.base_class + session * args.way
72 | model = model.eval()
73 | vl = Averager()
74 | va = Averager()
75 | va5= Averager()
76 | lgt=torch.tensor([])
77 | lbs=torch.tensor([])
78 | with torch.no_grad():
79 | for i, batch in enumerate(testloader, 1):
80 | data, test_label = [_.cuda() for _ in batch]
81 | logits = model(data)
82 | logits = logits[:, :test_class]
83 | loss = F.cross_entropy(logits, test_label)
84 | acc = count_acc(logits, test_label)
85 | top5acc = count_acc_topk(logits, test_label)
86 |
87 | vl.add(loss.item())
88 | va.add(acc)
89 | va5.add(top5acc)
90 |
91 | lgt=torch.cat([lgt,logits.cpu()])
92 | lbs=torch.cat([lbs,test_label.cpu()])
93 | vl = vl.item()
94 | va = va.item()
95 | va5= va5.item()
96 |
97 | logging.info('epo {}, test, loss={:.4f} acc={:.4f}, acc@5={:.4f}'.format(epoch, vl, va, va5))
98 |
99 | lgt=lgt.view(-1, test_class)
100 | lbs=lbs.view(-1)
101 |
102 | # if session > 0:
103 | # _preds = torch.argmax(lgt, dim=1)
104 | # torch.save(_preds, f"pred_labels/{args.project}_{args.dataset}_{session}_preds.pt")
105 | # torch.save(lbs, f"pred_labels/{args.project}_{args.dataset}_{session}_labels.pt")
106 | # torch.save(model.module.fc.weight.data.cpu()[:test_class], f"pred_labels/{args.project}_{args.dataset}_{session}_weights.pt")
107 |
108 | if session > 0:
109 | save_model_dir = os.path.join(args.save_path, 'session' + str(session) + 'confusion_matrix')
110 | cm = confmatrix(lgt,lbs,save_model_dir)
111 | perclassacc = cm.diagonal()
112 | seenac = np.mean(perclassacc[:args.base_class])
113 | unseenac = np.mean(perclassacc[args.base_class:])
114 |
115 | result_list.append(f"Seen Acc:{seenac} Unseen Acc:{unseenac}")
116 | return vl, (seenac, unseenac, va)
117 | else:
118 | return vl, va
119 |
--------------------------------------------------------------------------------
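Note on /models/teen/helper.py above: for incremental sessions, test() splits the confusion-matrix diagonal into base ("seen") and novel ("unseen") per-class accuracies. The summary statistic combining the two comes from postprocess_results in the trainer; the sketch below assumes it is the standard harmonic mean of the seen and unseen means:

    import numpy as np

    def seen_unseen_summary(perclass_acc, base_class):
        """Mean seen/unseen accuracy and their harmonic mean."""
        seen = float(np.mean(perclass_acc[:base_class]))
        unseen = float(np.mean(perclass_acc[base_class:]))
        hmean = 2 * seen * unseen / (seen + unseen + 1e-12)
        return seen, unseen, hmean

    # toy example: 60 base classes at 0.80, 5 novel classes at 0.50
    acc = np.concatenate([np.full(60, 0.80), np.full(5, 0.50)])
    print(seen_unseen_summary(acc, base_class=60))  # (0.8, 0.5, ~0.615)
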
/dataloader/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import copy
4 |
5 |
6 | class CategoriesSampler():
7 |
8 | def __init__(self, label, n_batch, n_cls, n_per, ):
9 | self.n_batch = n_batch # the number of iterations in the dataloader
10 | self.n_cls = n_cls
11 | self.n_per = n_per
12 |
13 | label = np.array(label) # all data label
14 | self.m_ind = [] # the data index of each class
15 | for i in range(max(label) + 1):
16 | ind = np.argwhere(label == i).reshape(-1) # all data index of this class
17 | ind = torch.from_numpy(ind)
18 | self.m_ind.append(ind)
19 |
20 | def __len__(self):
21 | return self.n_batch
22 |
23 | def __iter__(self):
24 | for i_batch in range(self.n_batch):
25 | batch = []
26 | classes = torch.randperm(len(self.m_ind))[:self.n_cls] # sample n_cls classes from total classes.
27 | for c in classes:
28 | l = self.m_ind[c]  # all data indices of this class
29 | pos = torch.randperm(len(l))[:self.n_per]  # sample n_per data indices of this class
30 | batch.append(l[pos])
31 | batch = torch.stack(batch).t().reshape(-1)
32 | # .t() transposes, so after the reshape the labels are interleaved
33 | # (abcdabcd...) rather than grouped by class (aaaabbbbcccc...),
34 | # i.e. each batch cycles through the sampled classes.
35 | yield batch
36 | # in total, the sampler yields n_batch batches of n_cls (way) * n_per (shot) instances each.
37 |
38 |
39 |
40 | class BasePreserverCategoriesSampler():
41 | def __init__(self, label, n_batch, n_cls, n_per, ):
42 | self.n_batch = n_batch # the number of iterations in the dataloader
43 | self.n_cls = n_cls
44 | self.n_per = n_per
45 |
46 | label = np.array(label) # all data label
47 | self.m_ind = [] # the data index of each class
48 | for i in range(max(label) + 1):
49 | ind = np.argwhere(label == i).reshape(-1) # all data index of this class
50 | ind = torch.from_numpy(ind)
51 | self.m_ind.append(ind)
52 |
53 | def __len__(self):
54 | return self.n_batch
55 |
56 | def __iter__(self):
57 |
58 | for i_batch in range(self.n_batch):
59 | batch = []
60 | #classes = torch.randperm(len(self.m_ind))[:self.n_cls] # sample n_cls classes from total classes.
61 | classes=torch.arange(len(self.m_ind))
62 | for c in classes:
63 | l = self.m_ind[c]  # all data indices of this class
64 | pos = torch.randperm(len(l))[:self.n_per]  # sample n_per data indices of this class
65 | batch.append(l[pos])
66 | batch = torch.stack(batch).t().reshape(-1)
67 | # .t() transposes, so after the reshape the labels are interleaved
68 | # (abcdabcd...) rather than grouped by class (aaaabbbbcccc...),
69 | # i.e. each batch cycles through the classes.
70 | yield batch
71 | # in total, the sampler yields n_batch batches of n_cls (way) * n_per (shot) instances each.
72 |
73 | class NewCategoriesSampler():
74 |
75 | def __init__(self, label, n_batch, n_cls, n_per,):
76 | self.n_batch = n_batch # the number of iterations in the dataloader
77 | self.n_cls = n_cls
78 | self.n_per = n_per
79 |
80 | label = np.array(label) # all data label
81 | self.m_ind = [] # the data index of each class
82 | for i in range(max(label) + 1):
83 | ind = np.argwhere(label == i).reshape(-1) # all data index of this class
84 | ind = torch.from_numpy(ind)
85 | self.m_ind.append(ind)
86 |
87 | self.classlist=np.arange(np.min(label),np.max(label)+1)
88 | #print(self.classlist)
89 |
90 | def __len__(self):
91 | return self.n_batch
92 |
93 | def __iter__(self):
94 | for i_batch in range(self.n_batch):
95 | batch = []
96 | for c in self.classlist:
97 | l = self.m_ind[c]  # all data indices of this class
98 | pos = torch.randperm(len(l))[:self.n_per]  # sample n_per data indices of this class
99 | batch.append(l[pos])
100 | batch = torch.stack(batch).t().reshape(-1)
101 | yield batch
102 |
103 |
104 | if __name__ == '__main__':
105 | q=np.arange(5,10)
106 | print(q)
107 | y=torch.tensor([5,6,7,8,9,5,6,7,8,9,5,6,7,8,9,5,5,5,55,])
108 | label = np.array(y) # all data label
109 | m_ind = [] # the data index of each class
110 | for i in range(max(label) + 1):
111 | ind = np.argwhere(label == i).reshape(-1) # all data index of this class
112 | ind = torch.from_numpy(ind)
113 | m_ind.append(ind)
114 | print(m_ind, len(m_ind))
115 |
--------------------------------------------------------------------------------
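Note on /dataloader/sampler.py above: each sampler yields one flat tensor of dataset indices per iteration, so it is meant to be passed as a DataLoader batch_sampler rather than a per-sample sampler. A hedged sketch of one way to wire it up for 5-way 5-shot episodes (trainset is assumed to be any dataset exposing a targets list, e.g. the MiniImageNet class further below):

    from torch.utils.data import DataLoader
    from dataloader.sampler import CategoriesSampler

    sampler = CategoriesSampler(trainset.targets, n_batch=100, n_cls=5, n_per=5)
    loader = DataLoader(trainset, batch_sampler=sampler, num_workers=8, pin_memory=True)

    for data, label in loader:
        # each batch holds 5 classes x 5 shots = 25 samples, class-interleaved
        assert data.size(0) == 25
        break
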
/dataloader/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from dataloader.sampler import CategoriesSampler
4 | from torch.utils.data import Dataset
5 | def set_up_datasets(args):
6 | if args.dataset == 'cifar100':
7 | import dataloader.cifar100.cifar as Dataset
8 | args.base_class = 60
9 | args.num_classes=100
10 | args.way = 5
11 | args.shot = 5
12 | args.sessions = 9
13 |
14 | if args.dataset == 'cub200':
15 | import dataloader.cub200.cub200 as Dataset
16 | args.base_class = 100
17 | args.num_classes = 200
18 | args.way = 10
19 | args.shot = 5
20 | args.sessions = 11
21 |
22 | if args.dataset == 'mini_imagenet':
23 | import dataloader.miniimagenet.miniimagenet as Dataset
24 | args.base_class = 60
25 | args.num_classes=100
26 | args.way = 5
27 | args.shot = 5
28 | args.sessions = 9
29 |
30 |
31 | args.Dataset=Dataset
32 | return args
33 |
34 | def get_dataloader(args, session):
35 | if session == 0:
36 | trainset, trainloader, testloader = get_base_dataloader(args)
37 | else:
38 | trainset, trainloader, testloader = get_new_dataloader(args, session)
39 | return trainset, trainloader, testloader
40 |
41 | def get_base_dataloader(args):
42 | txt_path = "data/index_list/" + args.dataset + "/session_" + str(0 + 1) + '.txt'
43 | class_index = np.arange(args.base_class)
44 | if args.dataset == 'cifar100':
45 |
46 | trainset = args.Dataset.CIFAR100(root=args.dataroot, train=True, download=True,
47 | index=class_index, base_sess=True)
48 | testset = args.Dataset.CIFAR100(root=args.dataroot, train=False, download=False,
49 | index=class_index, base_sess=True)
50 |
51 | if args.dataset == 'cub200':
52 | trainset = args.Dataset.CUB200(root=args.dataroot, train=True,
53 | index=class_index, base_sess=True)
54 | testset = args.Dataset.CUB200(root=args.dataroot, train=False, index=class_index)
55 |
56 | if args.dataset == 'mini_imagenet':
57 | trainset = args.Dataset.MiniImageNet(root=args.dataroot, train=True,
58 | index=class_index, base_sess=True)
59 | testset = args.Dataset.MiniImageNet(root=args.dataroot, train=False, index=class_index)
60 |
61 | if args.dataset == 'imagenet100' or args.dataset == 'imagenet1000':
62 | trainset = args.Dataset.ImageNet(root=args.dataroot, train=True,
63 | index=class_index, base_sess=True)
64 | testset = args.Dataset.ImageNet(root=args.dataroot, train=False, index=class_index)
65 |
66 | trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=args.batch_size_base, shuffle=True,
67 | num_workers=8, pin_memory=True)
68 | testloader = torch.utils.data.DataLoader(
69 | dataset=testset, batch_size=args.test_batch_size, shuffle=False, num_workers=8, pin_memory=True)
70 |
71 | return trainset, trainloader, testloader
72 |
73 |
74 | def get_new_dataloader(args, session):
75 | txt_path = "data/index_list/" + args.dataset + "/session_" + str(session + 1) + '.txt'
76 |
77 | if args.dataset == 'cifar100':
78 | class_index = open(txt_path).read().splitlines()
79 | trainset = args.Dataset.CIFAR100(root=args.dataroot, train=True, download=False,
80 | index=class_index, base_sess=False)
81 | if args.dataset == 'cub200':
82 | trainset = args.Dataset.CUB200(root=args.dataroot, train=True,
83 | index_path=txt_path)
84 | if args.dataset == 'mini_imagenet':
85 | trainset = args.Dataset.MiniImageNet(root=args.dataroot, train=True,
86 | index_path=txt_path)
87 |
88 | if args.batch_size_new == 0:
89 | batch_size_new = trainset.__len__()
90 | trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size_new, shuffle=False,
91 | num_workers=args.num_workers, pin_memory=True)
92 | else:
93 | trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=args.batch_size_new, shuffle=True,
94 | num_workers=args.num_workers, pin_memory=True)
95 |
96 | # test on all encountered classes
97 | class_new = get_session_classes(args, session)
98 |
99 | if args.dataset == 'cifar100':
100 | testset = args.Dataset.CIFAR100(root=args.dataroot, train=False, download=False,
101 | index=class_new, base_sess=False)
102 | if args.dataset == 'cub200':
103 | testset = args.Dataset.CUB200(root=args.dataroot, train=False,
104 | index=class_new)
105 | if args.dataset == 'mini_imagenet':
106 | testset = args.Dataset.MiniImageNet(root=args.dataroot, train=False,
107 | index=class_new)
108 |
109 | testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=args.test_batch_size, shuffle=False,
110 | num_workers=args.num_workers, pin_memory=True)
111 |
112 | return trainset, trainloader, testloader
113 |
114 | def get_session_classes(args,session):
115 | class_list=np.arange(args.base_class + session * args.way)
116 | return class_list
117 |
118 |
--------------------------------------------------------------------------------
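Note on /dataloader/data_utils.py above: set_up_datasets hard-codes the FSCIL splits (cifar100 and mini_imagenet: 60 base classes plus 8 incremental 5-way 5-shot sessions; cub200: 100 base classes plus 10 incremental 10-way 5-shot sessions), and get_session_classes returns every class encountered so far, which is what each session's test loader evaluates on. A small sketch of the class ranges these settings imply:

    import numpy as np

    def session_class_ranges(base_class, way, sessions):
        """Yield (session, new_classes, all_seen_classes) for an FSCIL split."""
        for session in range(sessions):
            if session == 0:
                new = np.arange(base_class)
            else:
                start = base_class + (session - 1) * way
                new = np.arange(start, start + way)
            yield session, new, np.arange(base_class + session * way)

    # cifar100 / mini_imagenet: 60 + 8 * 5 = 100 classes in total
    for s, new, seen in session_class_ranges(base_class=60, way=5, sessions=9):
        print(s, new.min(), new.max(), len(seen))
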
/dataloader/miniimagenet/miniimagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 | from .autoaugment import AutoAugImageNetPolicy
10 |
11 | class MiniImageNet(Dataset):
12 |
13 | def __init__(self, root='', train=True,
14 | transform=None,
15 | index_path=None, index=None, base_sess=None, autoaug=True):
16 | if train:
17 | setname = 'train'
18 | else:
19 | setname = 'test'
20 | self.root = os.path.expanduser(root)
21 | self.transform = transform
22 | self.train = train # training set or test set
23 | self.IMAGE_PATH = os.path.join(root, 'miniimagenet/images')
24 | self.SPLIT_PATH = os.path.join(root, 'miniimagenet/split')
25 |
26 | csv_path = osp.join(self.SPLIT_PATH, setname + '.csv')
27 | lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:]
28 |
29 | self.data = []
30 | self.targets = []
31 | self.data2label = {}
32 | lb = -1
33 |
34 | self.wnids = []
35 |
36 | for l in lines:
37 | name, wnid = l.split(',')
38 | path = osp.join(self.IMAGE_PATH, name)
39 | if wnid not in self.wnids:
40 | self.wnids.append(wnid)
41 | lb += 1
42 | self.data.append(path)
43 | self.targets.append(lb)
44 | self.data2label[path] = lb
45 |
46 | if autoaug is False:
47 | #do not use autoaug.
48 | if train:
49 | image_size = 84
50 | self.transform = transforms.Compose([
51 | transforms.RandomResizedCrop(image_size),
52 | # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
53 | transforms.RandomHorizontalFlip(),
54 |
55 | transforms.ToTensor(),
56 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
57 | std=[0.229, 0.224, 0.225])])
58 | if base_sess:
59 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
60 | else:
61 | self.data, self.targets = self.SelectfromTxt(self.data2label, index_path)
62 | else:
63 | image_size = 84
64 | self.transform = transforms.Compose([
65 | transforms.Resize([92, 92]),
66 | transforms.CenterCrop(image_size),
67 | transforms.ToTensor(),
68 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
69 | std=[0.229, 0.224, 0.225])])
70 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
71 | else:
72 | #use autoaug.
73 | if train:
74 | image_size = 84
75 | self.transform = transforms.Compose([
76 | transforms.RandomResizedCrop(image_size),
77 | # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
78 | transforms.RandomHorizontalFlip(),
79 | #add autoaug
80 | AutoAugImageNetPolicy(),
81 | transforms.ToTensor(),
82 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
83 | std=[0.229, 0.224, 0.225])])
84 | if base_sess:
85 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
86 | else:
87 | self.data, self.targets = self.SelectfromTxt(self.data2label, index_path)
88 | else:
89 | image_size = 84
90 | self.transform = transforms.Compose([
91 | transforms.Resize([92, 92]),
92 | transforms.CenterCrop(image_size),
93 | transforms.ToTensor(),
94 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
95 | std=[0.229, 0.224, 0.225])])
96 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
97 |
98 | def SelectfromTxt(self, data2label, index_path):
99 | # select from the txt file and build the corresponding path-to-label mapping.
100 | index=[]
101 | lines = [x.strip() for x in open(index_path, 'r').readlines()]
102 | for line in lines:
103 | index.append(line.split('/')[3])
104 | data_tmp = []
105 | targets_tmp = []
106 | for i in index:
107 | img_path = os.path.join(self.IMAGE_PATH, i)
108 | data_tmp.append(img_path)
109 | targets_tmp.append(data2label[img_path])
110 |
111 | return data_tmp, targets_tmp
112 |
113 | def SelectfromClasses(self, data, targets, index):
114 | # select whole classes from the csv data: keep all instances of every class id in index.
115 | data_tmp = []
116 | targets_tmp = []
117 | for i in index:
118 | ind_cl = np.where(i == targets)[0]
119 | for j in ind_cl:
120 | data_tmp.append(data[j])
121 | targets_tmp.append(targets[j])
122 |
123 | return data_tmp, targets_tmp
124 |
125 | def __len__(self):
126 | return len(self.data)
127 |
128 | def __getitem__(self, i):
129 |
130 | path, targets = self.data[i], self.targets[i]
131 | image = self.transform(Image.open(path).convert('RGB'))
132 | return image, targets
133 |
134 |
135 | class MiniImageNet_concate(Dataset):
136 | def __init__(self, train,x1,y1,x2,y2):
137 |
138 | if train:
139 | image_size = 84
140 | self.transform = transforms.Compose([
141 | transforms.RandomResizedCrop(image_size),
142 | # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
143 | transforms.RandomHorizontalFlip(),
144 |
145 | transforms.ToTensor(),
146 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
147 | std=[0.229, 0.224, 0.225])])
148 |
149 | else:
150 | image_size = 84
151 | self.transform = transforms.Compose([
152 | transforms.Resize([92, 92]),
153 | transforms.CenterCrop(image_size),
154 | transforms.ToTensor(),
155 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
156 | std=[0.229, 0.224, 0.225])])
157 |
158 |
159 | self.data=x1+x2
160 | self.targets=y1+y2
161 | print(len(self.data),len(self.targets))
162 |
163 | def __len__(self):
164 | return len(self.data)
165 |
166 | def __getitem__(self, i):
167 |
168 | path, targets = self.data[i], self.targets[i]
169 | image = self.transform(Image.open(path).convert('RGB'))
170 | return image, targets
171 |
172 | if __name__ == '__main__':
173 | txt_path = "../../data/index_list/mini_imagenet/session_1.txt"
174 | # class_index = open(txt_path).read().splitlines()
175 | base_class = 100
176 | class_index = np.arange(base_class)
177 | dataroot = '/data/wangqw/datasets/FSCIL'
178 | batch_size_base = 400
179 | trainset = MiniImageNet(root=dataroot, train=False, transform=None, index=class_index)
180 | cls = np.unique(trainset.targets)
181 | print(cls)
182 | # trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size_base, shuffle=True, num_workers=8,
183 | # pin_memory=True)
184 | # print(trainloader.dataset.data.shape)
185 |
--------------------------------------------------------------------------------
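Note on /dataloader/miniimagenet/miniimagenet.py above: the class reads its split from miniimagenet/split/{train,test}.csv, resolves image paths under miniimagenet/images, and then keeps either whole classes (SelectfromClasses, used for the base session and for all test sets) or the individual files listed in an index txt (SelectfromTxt, used for incremental training sessions). A hedged sketch of building the base-session datasets the same way get_base_dataloader does (the data root is a placeholder):

    import numpy as np
    from torch.utils.data import DataLoader
    from dataloader.miniimagenet.miniimagenet import MiniImageNet

    dataroot = '/path/to/data'        # placeholder; must contain miniimagenet/{images,split}
    base_classes = np.arange(60)      # 60 base classes, as set in set_up_datasets

    trainset = MiniImageNet(root=dataroot, train=True, index=base_classes, base_sess=True)
    testset = MiniImageNet(root=dataroot, train=False, index=base_classes)
    trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8, pin_memory=True)
    testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=8, pin_memory=True)
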
/models/teen/fscil_trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from copy import deepcopy
3 |
4 | import torch.nn as nn
5 | from base import Trainer
6 | from dataloader.data_utils import get_dataloader
7 | from utils import *
8 |
9 | from .helper import *
10 | from .Network import MYNET
11 |
12 |
13 | class FSCILTrainer(Trainer):
14 | def __init__(self, args):
15 | super().__init__(args)
16 | self.args = args
17 | self.set_up_model()
18 |
19 | def set_up_model(self):
20 | self.model = MYNET(self.args, mode=self.args.base_mode)
21 | self.model = nn.DataParallel(self.model, list(range(self.args.num_gpu)))
22 | self.model = self.model.cuda()
23 |
24 | if self.args.model_dir is not None:
25 | logging.info('Loading init parameters from: %s' % self.args.model_dir)
26 | self.best_model_dict = torch.load(self.args.model_dir,
27 | map_location={'cuda:3':'cuda:0'})['params']
28 | else:
29 | logging.info('random init params')
30 | if self.args.start_session > 0:
31 | logging.info('WARNING: Random init weights for new sessions!')
32 | self.best_model_dict = deepcopy(self.model.state_dict())
33 |
34 | def train(self,):
35 | args = self.args
36 | t_start_time = time.time()
37 | # init train statistics
38 | result_list = [args]
39 | for session in range(args.start_session, args.sessions):
40 | train_set, trainloader, testloader = get_dataloader(args, session)
41 | self.model.load_state_dict(self.best_model_dict)
42 | if session == 0: # load base class train img label
43 | if not args.only_do_incre:
44 | logging.info(f'new classes for this session:{np.unique(train_set.targets)}')
45 | optimizer, scheduler = get_optimizer(args, self.model)
46 | for epoch in range(args.epochs_base):
47 | start_time = time.time()
48 |
49 | tl, ta = base_train(self.model, trainloader, optimizer, scheduler, epoch, args)
50 | tsl, tsa = test(self.model, testloader, epoch, args, session, result_list=result_list)
51 |
52 | # save better model
53 | if (tsa * 100) >= self.trlog['max_acc'][session]:
54 | self.trlog['max_acc'][session] = float('%.3f' % (tsa * 100))
55 | self.trlog['max_acc_epoch'] = epoch
56 | save_model_dir = os.path.join(args.save_path, 'session' + str(session) + '_max_acc.pth')
57 | torch.save(dict(params=self.model.state_dict()), save_model_dir)
58 | torch.save(optimizer.state_dict(), os.path.join(args.save_path, 'optimizer_best.pth'))
59 | self.best_model_dict = deepcopy(self.model.state_dict())
60 | logging.info('********A better model is found!!**********')
61 | logging.info('Saving model to :%s' % save_model_dir)
62 | logging.info('best epoch {}, best test acc={:.3f}'.format(
63 | self.trlog['max_acc_epoch'], self.trlog['max_acc'][session]))
64 |
65 | self.trlog['train_loss'].append(tl)
66 | self.trlog['train_acc'].append(ta)
67 | self.trlog['test_loss'].append(tsl)
68 | self.trlog['test_acc'].append(tsa)
69 | lrc = scheduler.get_last_lr()[0]
70 |
71 | logging.info(
72 | 'epoch:%03d,lr:%.4f,training_loss:%.5f,training_acc:%.5f,test_loss:%.5f,test_acc:%.5f' % (
73 | epoch, lrc, tl, ta, tsl, tsa))
74 | print('This epoch takes %d seconds' % (time.time() - start_time),
75 | '\n still need around %.2f mins to finish this session' % (
76 | (time.time() - start_time) * (args.epochs_base - epoch) / 60))
77 | scheduler.step()
78 |
79 | # Finish base train
80 | logging.info('>>> Finish Base Train <<<')
81 | result_list.append('Session {}, Test Best Epoch {},\nbest test Acc {:.4f}\n'.format(
82 | session, self.trlog['max_acc_epoch'], self.trlog['max_acc'][session]))
83 | else:
84 | logging.info('>>> Loaded model, skipping base training...')
85 | assert args.model_dir is not None
86 |
87 | if not args.not_data_init:
88 | self.model.load_state_dict(self.best_model_dict)
89 | self.model = replace_base_fc(train_set, testloader.dataset.transform, self.model, args)
90 | best_model_dir = os.path.join(args.save_path, 'session' + str(session) + '_max_acc.pth')
91 | logging.info('Replace the fc with average embedding, and save it to :%s' % best_model_dir)
92 | self.best_model_dict = deepcopy(self.model.state_dict())
93 | torch.save(dict(params=self.model.state_dict()), best_model_dir)
94 |
95 | self.model.module.mode = 'avg_cos'
96 | tsl, tsa = test(self.model, testloader, 0, args, session, result_list=result_list)
97 | if (tsa * 100) >= self.trlog['max_acc'][session]:
98 | self.trlog['max_acc'][session] = float('%.3f' % (tsa * 100))
99 | logging.info('The new best test acc of base session={:.3f}'.format(
100 | self.trlog['max_acc'][session]))
101 |
102 | # incremental learning sessions
103 | else:
104 | logging.info("training session: [%d]" % session)
105 | self.model.module.mode = self.args.new_mode
106 | self.model.eval()
107 | trainloader.dataset.transform = testloader.dataset.transform
108 |
109 | if args.soft_mode == 'soft_proto':
110 | self.model.module.update_fc(trainloader, np.unique(train_set.targets), session)
111 | self.model.module.soft_calibration(args, session)
112 | else:
113 | raise NotImplementedError
114 |
115 | tsl, (seenac, unseenac, avgac) = test(self.model, testloader, 0, args, session, result_list=result_list)
116 |
117 | # update results and save model
118 | self.trlog['seen_acc'].append(float('%.3f' % (seenac * 100)))
119 | self.trlog['unseen_acc'].append(float('%.3f' % (unseenac * 100)))
120 | self.trlog['max_acc'][session] = float('%.3f' % (avgac * 100))
121 | self.best_model_dict = deepcopy(self.model.state_dict())
122 |
123 | logging.info(f"Session {session} ==> Seen Acc:{self.trlog['seen_acc'][-1]} "
124 | f"Unseen Acc:{self.trlog['unseen_acc'][-1]} Avg Acc:{self.trlog['max_acc'][session]}")
125 | result_list.append('Session {}, test Acc {:.3f}\n'.format(session, self.trlog['max_acc'][session]))
126 |
127 | # Finish all incremental sessions, save results.
128 | result_list, hmeans = postprocess_results(result_list, self.trlog)
129 | save_list_to_txt(os.path.join(args.save_path, 'results.txt'), result_list)
130 | if not self.args.debug:
131 | save_result(args, self.trlog, hmeans)
132 |
133 | t_end_time = time.time()
134 | total_time = (t_end_time - t_start_time) / 60
135 | logging.info(f"Base Session Best epoch:{self.trlog['max_acc_epoch']}")
136 | logging.info('Total time used %.2f mins' % total_time)
137 | logging.info(self.args.time_str)
138 |
--------------------------------------------------------------------------------
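Note on /models/teen/fscil_trainer.py above: checkpoints are written as dict(params=state_dict) (e.g. session0_max_acc.pth under the run's save_path) and reloaded through the 'params' key, which is also how -model_dir / -only_do_incre resume from a pre-trained base model. A minimal sketch of restoring such a checkpoint outside the trainer (the path is a placeholder and args is assumed to be the namespace parsed by train.py):

    import torch
    import torch.nn as nn
    from models.teen.Network import MYNET

    model = nn.DataParallel(MYNET(args, mode=args.base_mode)).cuda()
    ckpt = torch.load('checkpoint/session0_max_acc.pth', map_location='cuda:0')  # placeholder path
    model.load_state_dict(ckpt['params'])
    model.eval()
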
/dataloader/cub200/cub200.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 | from .autoaugment import AutoAugImageNetPolicy
10 |
11 | class CUB200(Dataset):
12 |
13 | def __init__(self, root='', train=True,
14 | index_path=None, index=None, base_sess=None, autoaug=False):
15 | self.root = os.path.expanduser(root)
16 | self.train = train # training set or test set
17 | self._pre_operate(self.root)
18 |
19 | if autoaug is False:
20 | #do not use autoaug
21 | if train:
22 | self.transform = transforms.Compose([
23 | transforms.Resize(256),
24 | # transforms.CenterCrop(224),
25 | transforms.RandomResizedCrop(224),
26 | transforms.RandomHorizontalFlip(),
27 | #add autoaug
28 | #AutoAugImageNetPolicy(),
29 | transforms.ToTensor(),
30 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
31 | ])
32 | # self.data, self.targets = self.SelectfromTxt(self.data2label, index_path)
33 | if base_sess:
34 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
35 | else:
36 | self.data, self.targets = self.SelectfromTxt(self.data2label, index_path)
37 | else:
38 | self.transform = transforms.Compose([
39 | transforms.Resize(256),
40 | transforms.CenterCrop(224),
41 | transforms.ToTensor(),
42 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
43 | ])
44 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
45 | else:
46 | #use autoaug
47 | if train:
48 | self.transform = transforms.Compose([
49 | transforms.Resize(256),
50 | # transforms.CenterCrop(224),
51 | transforms.RandomResizedCrop(224),
52 | transforms.RandomHorizontalFlip(),
53 | #add autoaug
54 | AutoAugImageNetPolicy(),
55 | transforms.ToTensor(),
56 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
57 | ])
58 | # self.data, self.targets = self.SelectfromTxt(self.data2label, index_path)
59 | if base_sess:
60 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
61 | else:
62 | self.data, self.targets = self.SelectfromTxt(self.data2label, index_path)
63 | else:
64 | self.transform = transforms.Compose([
65 | transforms.Resize(256),
66 | transforms.CenterCrop(224),
67 | transforms.ToTensor(),
68 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
69 | ])
70 | self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
71 | def text_read(self, file):
72 | with open(file, 'r') as f:
73 | lines = f.readlines()
74 | for i, line in enumerate(lines):
75 | lines[i] = line.strip('\n')
76 | return lines
77 |
78 | def list2dict(self, list):
79 | dict = {}
80 | for l in list:
81 | s = l.split(' ')
82 | id = int(s[0])
83 | cls = s[1]
84 | if id not in dict.keys():
85 | dict[id] = cls
86 | else:
87 | raise ValueError('The same ID can only appear once')
88 | return dict
89 |
90 | def _pre_operate(self, root):
91 | image_file = os.path.join(root, 'CUB_200_2011/images.txt')
92 | split_file = os.path.join(root, 'CUB_200_2011/train_test_split.txt')
93 | class_file = os.path.join(root, 'CUB_200_2011/image_class_labels.txt')
94 | id2image = self.list2dict(self.text_read(image_file))
95 | id2train = self.list2dict(self.text_read(split_file))  # 1: train images; 0: test images
96 | id2class = self.list2dict(self.text_read(class_file))
97 | train_idx = []
98 | test_idx = []
99 | for k in sorted(id2train.keys()):
100 | if id2train[k] == '1':
101 | train_idx.append(k)
102 | else:
103 | test_idx.append(k)
104 |
105 | self.data = []
106 | self.targets = []
107 | self.data2label = {}
108 | if self.train:
109 | for k in train_idx:
110 | image_path = os.path.join(root, 'CUB_200_2011/images', id2image[k])
111 | self.data.append(image_path)
112 | self.targets.append(int(id2class[k]) - 1)
113 | self.data2label[image_path] = (int(id2class[k]) - 1)
114 |
115 | else:
116 | for k in test_idx:
117 | image_path = os.path.join(root, 'CUB_200_2011/images', id2image[k])
118 | self.data.append(image_path)
119 | self.targets.append(int(id2class[k]) - 1)
120 | self.data2label[image_path] = (int(id2class[k]) - 1)
121 |
122 | def SelectfromTxt(self, data2label, index_path):
123 | index = open(index_path).read().splitlines()
124 | data_tmp = []
125 | targets_tmp = []
126 | for i in index:
127 | img_path = os.path.join(self.root, i)
128 | data_tmp.append(img_path)
129 | targets_tmp.append(data2label[img_path])
130 |
131 | return data_tmp, targets_tmp
132 |
133 | def SelectfromClasses(self, data, targets, index):
134 | data_tmp = []
135 | targets_tmp = []
136 | for i in index:
137 | ind_cl = np.where(i == targets)[0]
138 | for j in ind_cl:
139 | data_tmp.append(data[j])
140 | targets_tmp.append(targets[j])
141 |
142 | return data_tmp, targets_tmp
143 |
144 | def __len__(self):
145 | return len(self.data)
146 |
147 | def __getitem__(self, i):
148 | path, targets = self.data[i], self.targets[i]
149 | image = self.transform(Image.open(path).convert('RGB'))
150 | return image, targets
151 |
152 |
153 | class CUB200_concate(Dataset):
154 | def __init__(self, train,x1,y1,x2,y2):
155 |
156 | if train:
157 | self.transform = transforms.Compose([
158 | transforms.Resize(256),
159 | # transforms.CenterCrop(224),
160 | transforms.RandomResizedCrop(224),
161 | transforms.RandomHorizontalFlip(),
162 | transforms.ToTensor(),
163 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
164 | ])
165 | else:
166 | self.transform = transforms.Compose([
167 | transforms.Resize(256),
168 | transforms.CenterCrop(224),
169 | transforms.ToTensor(),
170 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
171 | ])
172 | self.data=x1+x2
173 | self.targets=y1+y2
174 | print(len(self.data),len(self.targets))
175 |
176 |
177 |
178 | def __len__(self):
179 | return len(self.data)
180 |
181 | def __getitem__(self, i):
182 | path, targets = self.data[i], self.targets[i]
183 | image = self.transform(Image.open(path).convert('RGB'))
184 | return image, targets
185 |
186 | if __name__ == '__main__':
187 | txt_path = "../../data/index_list/cub200/session_1.txt"
188 | # class_index = open(txt_path).read().splitlines()
189 | base_class = 100
190 | class_index = np.arange(base_class)
191 | dataroot = '~/dataloader/data'
192 | batch_size_base = 400
193 | trainset = CUB200(root=dataroot, train=False, index=class_index,
194 | base_sess=True)
195 | cls = np.unique(trainset.targets)
196 | trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size_base, shuffle=True, num_workers=8,
197 | pin_memory=True)
198 |
199 |
--------------------------------------------------------------------------------
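Note on /dataloader/cub200/cub200.py above: the incremental index files shown earlier in this listing (e.g. data/index_list/cub200/session_8.txt) store paths relative to the dataset root, such as CUB_200_2011/images/<class>/<image>.jpg; SelectfromTxt joins each line with root and looks the label up in data2label, so root must be the directory that contains CUB_200_2011. A hedged sketch of loading one incremental session the way get_new_dataloader does (the data root is a placeholder):

    from torch.utils.data import DataLoader
    from dataloader.cub200.cub200 import CUB200

    dataroot = '/path/to/data'                          # placeholder; must contain CUB_200_2011/
    txt_path = 'data/index_list/cub200/session_2.txt'   # 10 new classes x 5 shots

    trainset = CUB200(root=dataroot, train=True, index_path=txt_path)
    trainloader = DataLoader(trainset, batch_size=len(trainset), shuffle=False,
                             num_workers=8, pin_memory=True)
    print(len(trainset))  # expected: 50 images (10-way x 5-shot)
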
/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pprint as pprint
4 | import random
5 | import shutil
6 | import time
7 | from collections import OrderedDict, defaultdict
8 | from pathlib import Path
9 | from enum import Enum  # used by ConfigEncoder below
10 | import matplotlib
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import torch
14 | from sklearn.metrics import confusion_matrix
15 | import logging
16 | from logging.config import dictConfig
17 | from dataloader.data_utils import *
18 |
19 | _utils_pp = pprint.PrettyPrinter()
20 |
21 | def set_logging(level, work_dir):
22 | LOGGING = {
23 | "version": 1,
24 | "disable_existing_loggers": False,
25 | "formatters": {
26 | "simple": {
27 | "format": f"%(message)s"
28 | },
29 | },
30 | "handlers": {
31 | "console": {
32 | "level": f"{level}",
33 | "class": "logging.StreamHandler",
34 | 'formatter': 'simple',
35 | },
36 | 'file': {
37 | 'level': f"{level}",
38 | 'formatter': 'simple',
39 | 'class': 'logging.FileHandler',
40 | 'filename': f'{work_dir if work_dir is not None else "."}/train.log',
41 | 'mode': 'a',
42 | },
43 | },
44 | "loggers": {
45 | "": {
46 | "level": f"{level}",
47 | "handlers": ["console", "file"] if work_dir is not None else ["console"],
48 | },
49 | },
50 | }
51 | dictConfig(LOGGING)
52 | logging.info(f"Log level set to: {level}")
53 |
54 | def pprint(x):
55 | _utils_pp.pprint(x)
56 | class ConfigEncoder(json.JSONEncoder):
57 | def default(self, o):
58 | if isinstance(o, type):
59 | return {'$class': o.__module__ + "." + o.__name__}
60 | elif isinstance(o, Enum):
61 | return {
62 | '$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name
63 | }
64 | elif callable(o):
65 | return {
66 | '$function': o.__module__ + "." + o.__name__
67 | }
68 | return json.JSONEncoder.default(self, o)
69 | class Logger(object):
70 | def __init__(self, args, log_dir, **kwargs):
71 | self.logger_path = os.path.join(log_dir, 'scalars.json')
72 | # self.tb_logger = SummaryWriter(
73 | # logdir=osp.join(log_dir, 'tflogger'),
74 | # **kwargs,
75 | # )
76 | self.log_config(vars(args))
77 |
78 | self.scalars = defaultdict(OrderedDict)
79 |
80 | # def add_scalar(self, key, value, counter):
81 | def add_scalar(self, key, value, counter):
82 | assert self.scalars[key].get(counter, None) is None, 'counter should be distinct'
83 | self.scalars[key][counter] = value
84 | # self.tb_logger.add_scalar(key, value, counter)
85 |
86 | def log_config(self, variant_data):
87 | config_filepath = os.path.join(os.path.dirname(self.logger_path), 'configs.json')
88 | with open(config_filepath, "w") as fd:
89 | json.dump(variant_data, fd, indent=2, sort_keys=True, cls=ConfigEncoder)
90 |
91 | def dump(self):
92 | with open(self.logger_path, 'w') as fd:
93 | json.dump(self.scalars, fd, indent=2)
94 |
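# Added usage sketch (not part of the original file; values are placeholders):
# Logger writes configs.json on construction, accumulates named scalars keyed by
# a unique counter, and dump() persists them to scalars.json in the same folder.
def _logger_demo(work_dir):
    import argparse
    demo_args = argparse.Namespace(project='teen', dataset='cub200')
    logger = Logger(demo_args, work_dir)      # writes configs.json immediately
    logger.add_scalar('test_acc', 0.731, 1)   # counter must be distinct per key
    logger.add_scalar('test_acc', 0.742, 2)
    logger.dump()                             # writes scalars.json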
95 | def set_seed(seed):
96 | if seed == 0:
97 | logging.info(' random seed')
98 | torch.backends.cudnn.benchmark = True
99 | else:
100 | logging.info(f'manual seed: {seed}')
101 | random.seed(seed)
102 | np.random.seed(seed)
103 | torch.manual_seed(seed)
104 | torch.cuda.manual_seed_all(seed)
105 | torch.backends.cudnn.deterministic = True
106 | torch.backends.cudnn.benchmark = False
107 |
108 | def set_gpu(args):
109 | gpu_list = [int(x) for x in args.gpu.split(',')]
110 | logging.info(f'use gpu: {gpu_list}')
111 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
112 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
113 | return len(gpu_list)
114 |
115 | def ensure_path(path):
116 | if os.path.exists(path):
117 | pass
118 | else:
119 | logging.info(f'create folder: {path}')
120 | os.makedirs(path)
121 |
122 | class Averager():
123 |
124 | def __init__(self):
125 | self.n = 0
126 | self.v = 0
127 |
128 | def add(self, x):
129 | self.v = (self.v * self.n + x) / (self.n + 1)
130 | self.n += 1
131 |
132 | def item(self):
133 | return self.v
134 |
135 |
136 | class Timer():
137 |
138 | def __init__(self):
139 | self.o = time.time()
140 |
141 | def measure(self, p=1):
142 | x = (time.time() - self.o) / p
143 | x = int(x)
144 | if x >= 3600:
145 | return '{:.1f}h'.format(x / 3600)
146 | if x >= 60:
147 | return '{}m'.format(round(x / 60))
148 | return '{}s'.format(x)
149 |
150 |
151 | def count_acc(logits, label):
152 | pred = torch.argmax(logits, dim=1)
153 | if torch.cuda.is_available():
154 | return (pred == label).type(torch.cuda.FloatTensor).mean().item()
155 | else:
156 | return (pred == label).type(torch.FloatTensor).mean().item()
157 |
158 | def count_acc_topk(x,y,k=5):
159 | _,maxk = torch.topk(x,k,dim=-1)
160 | total = y.size(0)
161 | test_labels = y.view(-1,1)
162 | #top1=(test_labels == maxk[:,0:1]).sum().item()
163 | topk=(test_labels == maxk).sum().item()
164 | return float(topk/total)
165 |
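# Worked example for count_acc_topk (illustrative values): with
#   x = torch.tensor([[0.1, 0.5, 0.4], [0.8, 0.15, 0.05]]) and y = torch.tensor([2, 2]),
# the top-2 indices per row are {1, 2} and {0, 1}, so only the first row contains
# its label and count_acc_topk(x, y, k=2) == 0.5.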
166 | def count_acc_taskIL(logits, label,args):
167 | basenum=args.base_class
168 | incrementnum=(args.num_classes-args.base_class)/args.way
169 | for i in range(len(label)):
170 | currentlabel=label[i]
171 | if currentlabel < basenum:
230 | logging.info(f">>> {params_info}-Avg_acc:{trlog['max_acc']} \n Seen_acc:{trlog['seen_acc']} \n Unseen_acc:{trlog['unseen_acc']} \n HMean_acc:{hmeans} \n")
231 |
232 | def harm_mean(seen, unseen):
233 | # compute from session1
234 | assert len(seen) == len(unseen)
235 | harm_means = []
236 | for _seen, _unseen in zip(seen, unseen):
237 | _hmean = (2 * _seen * _unseen) / (_seen + _unseen + 1e-12)
238 | _hmean = float('%.3f' % (_hmean))
239 | harm_means.append(_hmean)
240 | return harm_means
241 |
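# Worked example for harm_mean (illustrative values): 2*80*40/(80+40) = 53.333 and
# 2*78*52/(78+52) = 62.4, so harm_mean([80.0, 78.0], [40.0, 52.0]) -> [53.333, 62.4].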
242 | def get_optimizer(args, model, **kwargs):
243 | # prepare optimizer
244 | if args.project in ['teen']:
245 | if args.optim == 'sgd':
246 | optimizer = torch.optim.SGD(model.parameters(), args.lr_base,
247 | momentum=args.momentum, nesterov=True,
248 | weight_decay=args.decay)
249 | elif args.optim == 'adam':
250 | optimizer = torch.optim.Adam(model.parameters(),
251 | lr=args.lr_base, weight_decay=args.decay)
252 |
253 |
254 | # prepare scheduler
255 | if args.schedule == 'Step':
256 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step,
257 | gamma=args.gamma)
258 | elif args.schedule == 'Milestone':
259 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones,
260 | gamma=args.gamma)
261 | elif args.schedule == 'Cosine':
262 | assert args.tmax > 0, "args.tmax should be greater than 0"
263 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.tmax)
264 | return optimizer, scheduler
265 |
266 |
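# Added usage sketch (not part of the original file): the argparse fields consumed
# by get_optimizer above, filled with placeholder values for an SGD + Milestone setup.
def _optimizer_demo():
    import argparse
    args = argparse.Namespace(project='teen', optim='sgd', lr_base=0.1, momentum=0.9,
                              decay=5e-4, schedule='Milestone', milestones=[60, 70], gamma=0.1)
    model = torch.nn.Linear(8, 4)  # stand-in for the real network
    optimizer, scheduler = get_optimizer(args, model)
    return optimizer, scheduler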
--------------------------------------------------------------------------------
/dataloader/cub200/autoaugment.py:
--------------------------------------------------------------------------------
1 | """
2 | Copied from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
3 | """
4 |
5 | from PIL import Image, ImageEnhance, ImageOps
6 | import numpy as np
7 | import random
8 |
9 | __all__ = ['AutoAugImageNetPolicy', 'AutoAugCIFAR10Policy', 'AutoAugSVHNPolicy']
10 |
11 |
12 | class AutoAugImageNetPolicy(object):
13 | def __init__(self, fillcolor=(128, 128, 128)):
14 | self.policies = [
15 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
16 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
17 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
18 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
19 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
20 |
21 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
22 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
23 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
24 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
25 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
26 |
27 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
28 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
29 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
30 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
31 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
32 |
33 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
34 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
35 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
36 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
37 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
38 |
39 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
40 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
41 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
42 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor)
43 | ]
44 |
45 | def __call__(self, img):
46 | policy_idx = random.randint(0, len(self.policies) - 1)
47 | return self.policies[policy_idx](img)
48 |
49 | def __repr__(self):
50 | return "AutoAugment ImageNet Policy"
51 |
52 |
53 | class AutoAugCIFAR10Policy(object):
54 | def __init__(self, fillcolor=(128, 128, 128)):
55 | self.policies = [
56 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
57 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
58 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
59 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
60 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
61 |
62 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
63 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
64 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
65 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
66 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
67 |
68 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
69 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
70 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
71 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
72 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
73 |
74 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
75 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor),
76 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
77 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
78 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
79 |
80 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
81 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
82 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
83 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
84 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
85 | ]
86 |
87 | def __call__(self, img):
88 | policy_idx = random.randint(0, len(self.policies) - 1)
89 | return self.policies[policy_idx](img)
90 |
91 | def __repr__(self):
92 | return "AutoAugment CIFAR10 Policy"
93 |
94 |
95 | class AutoAugSVHNPolicy(object):
96 | def __init__(self, fillcolor=(128, 128, 128)):
97 | self.policies = [
98 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
99 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
100 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
101 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
102 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
103 |
104 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
105 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
106 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
107 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
108 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
109 |
110 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
111 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
112 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
113 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
114 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
115 |
116 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
117 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
118 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
119 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
120 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
121 |
122 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
123 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
124 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
125 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
126 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
127 | ]
128 |
129 | def __call__(self, img):
130 | policy_idx = random.randint(0, len(self.policies) - 1)
131 | return self.policies[policy_idx](img)
132 |
133 | def __repr__(self):
134 | return "AutoAugment SVHN Policy"
135 |
136 |
137 | class SubPolicy(object):
138 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
139 | ranges = {
140 | "shearX": np.linspace(0, 0.3, 10),
141 | "shearY": np.linspace(0, 0.3, 10),
142 | "translateX": np.linspace(0, 150 / 331, 10),
143 | "translateY": np.linspace(0, 150 / 331, 10),
144 | "rotate": np.linspace(0, 30, 10),
145 | "color": np.linspace(0.0, 0.9, 10),
146 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
147 | "solarize": np.linspace(256, 0, 10),
148 | "contrast": np.linspace(0.0, 0.9, 10),
149 | "sharpness": np.linspace(0.0, 0.9, 10),
150 | "brightness": np.linspace(0.0, 0.9, 10),
151 | "autocontrast": [0] * 10,
152 | "equalize": [0] * 10,
153 | "invert": [0] * 10
154 | }
155 |
156 | def rotate_with_fill(img, magnitude):
157 | rot = img.convert("RGBA").rotate(magnitude)
158 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
159 |
160 | func = {
161 | "shearX": lambda img, magnitude: img.transform(
162 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
163 | Image.BICUBIC, fillcolor=fillcolor),
164 | "shearY": lambda img, magnitude: img.transform(
165 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
166 | Image.BICUBIC, fillcolor=fillcolor),
167 | "translateX": lambda img, magnitude: img.transform(
168 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
169 | fillcolor=fillcolor),
170 | "translateY": lambda img, magnitude: img.transform(
171 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
172 | fillcolor=fillcolor),
173 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
174 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
175 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
176 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
177 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
178 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
179 | 1 + magnitude * random.choice([-1, 1])),
180 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
181 | 1 + magnitude * random.choice([-1, 1])),
182 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
183 | 1 + magnitude * random.choice([-1, 1])),
184 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
185 | "equalize": lambda img, magnitude: ImageOps.equalize(img),
186 | "invert": lambda img, magnitude: ImageOps.invert(img)
187 | }
188 |
189 | # self.name = "{}_{:.2f}_and_{}_{:.2f}".format(
190 | # operation1, ranges[operation1][magnitude_idx1],
191 | # operation2, ranges[operation2][magnitude_idx2])
192 | self.p1 = p1
193 | self.operation1 = func[operation1]
194 | self.magnitude1 = ranges[operation1][magnitude_idx1]
195 | self.p2 = p2
196 | self.operation2 = func[operation2]
197 | self.magnitude2 = ranges[operation2][magnitude_idx2]
198 |
199 | def __call__(self, img):
200 | if random.random() < self.p1:
201 | img = self.operation1(img, self.magnitude1)
202 | if random.random() < self.p2:
203 | img = self.operation2(img, self.magnitude2)
204 | return img
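if __name__ == '__main__':
    # Added usage sketch (not part of the original file): the policies are
    # PIL-in/PIL-out callables, so they slot into a torchvision pipeline before
    # ToTensor, the same pattern used for CIFAR10Policy in dataloader/cifar100/cifar.py.
    from torchvision import transforms
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        AutoAugImageNetPolicy(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    print(train_transform)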
--------------------------------------------------------------------------------
/dataloader/miniimagenet/autoaugment.py:
--------------------------------------------------------------------------------
1 | """
2 | Copied from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
3 | """
4 |
5 | from PIL import Image, ImageEnhance, ImageOps
6 | import numpy as np
7 | import random
8 |
9 | __all__ = ['AutoAugImageNetPolicy', 'AutoAugCIFAR10Policy', 'AutoAugSVHNPolicy']
10 |
11 |
12 | class AutoAugImageNetPolicy(object):
13 | def __init__(self, fillcolor=(128, 128, 128)):
14 | self.policies = [
15 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
16 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
17 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
18 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
19 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
20 |
21 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
22 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
23 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
24 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
25 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
26 |
27 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
28 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
29 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
30 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
31 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
32 |
33 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
34 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
35 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
36 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
37 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
38 |
39 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
40 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
41 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
42 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor)
43 | ]
44 |
45 | def __call__(self, img):
46 | policy_idx = random.randint(0, len(self.policies) - 1)
47 | return self.policies[policy_idx](img)
48 |
49 | def __repr__(self):
50 | return "AutoAugment ImageNet Policy"
51 |
52 |
53 | class AutoAugCIFAR10Policy(object):
54 | def __init__(self, fillcolor=(128, 128, 128)):
55 | self.policies = [
56 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
57 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
58 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
59 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
60 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
61 |
62 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
63 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
64 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
65 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
66 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
67 |
68 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
69 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
70 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
71 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
72 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
73 |
74 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
75 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor),
76 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
77 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
78 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
79 |
80 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
81 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
82 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
83 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
84 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
85 | ]
86 |
87 | def __call__(self, img):
88 | policy_idx = random.randint(0, len(self.policies) - 1)
89 | return self.policies[policy_idx](img)
90 |
91 | def __repr__(self):
92 | return "AutoAugment CIFAR10 Policy"
93 |
94 |
95 | class AutoAugSVHNPolicy(object):
96 | def __init__(self, fillcolor=(128, 128, 128)):
97 | self.policies = [
98 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
99 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
100 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
101 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
102 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
103 |
104 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
105 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
106 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
107 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
108 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
109 |
110 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
111 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
112 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
113 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
114 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
115 |
116 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
117 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
118 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
119 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
120 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
121 |
122 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
123 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
124 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
125 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
126 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
127 | ]
128 |
129 | def __call__(self, img):
130 | policy_idx = random.randint(0, len(self.policies) - 1)
131 | return self.policies[policy_idx](img)
132 |
133 | def __repr__(self):
134 | return "AutoAugment SVHN Policy"
135 |
136 |
137 | class SubPolicy(object):
138 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
139 | ranges = {
140 | "shearX": np.linspace(0, 0.3, 10),
141 | "shearY": np.linspace(0, 0.3, 10),
142 | "translateX": np.linspace(0, 150 / 331, 10),
143 | "translateY": np.linspace(0, 150 / 331, 10),
144 | "rotate": np.linspace(0, 30, 10),
145 | "color": np.linspace(0.0, 0.9, 10),
146 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
147 | "solarize": np.linspace(256, 0, 10),
148 | "contrast": np.linspace(0.0, 0.9, 10),
149 | "sharpness": np.linspace(0.0, 0.9, 10),
150 | "brightness": np.linspace(0.0, 0.9, 10),
151 | "autocontrast": [0] * 10,
152 | "equalize": [0] * 10,
153 | "invert": [0] * 10
154 | }
155 |
156 | def rotate_with_fill(img, magnitude):
157 | rot = img.convert("RGBA").rotate(magnitude)
158 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
159 |
160 | func = {
161 | "shearX": lambda img, magnitude: img.transform(
162 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
163 | Image.BICUBIC, fillcolor=fillcolor),
164 | "shearY": lambda img, magnitude: img.transform(
165 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
166 | Image.BICUBIC, fillcolor=fillcolor),
167 | "translateX": lambda img, magnitude: img.transform(
168 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
169 | fillcolor=fillcolor),
170 | "translateY": lambda img, magnitude: img.transform(
171 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
172 | fillcolor=fillcolor),
173 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
174 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
175 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
176 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
177 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
178 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
179 | 1 + magnitude * random.choice([-1, 1])),
180 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
181 | 1 + magnitude * random.choice([-1, 1])),
182 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
183 | 1 + magnitude * random.choice([-1, 1])),
184 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
185 | "equalize": lambda img, magnitude: ImageOps.equalize(img),
186 | "invert": lambda img, magnitude: ImageOps.invert(img)
187 | }
188 |
189 | # self.name = "{}_{:.2f}_and_{}_{:.2f}".format(
190 | # operation1, ranges[operation1][magnitude_idx1],
191 | # operation2, ranges[operation2][magnitude_idx2])
192 | self.p1 = p1
193 | self.operation1 = func[operation1]
194 | self.magnitude1 = ranges[operation1][magnitude_idx1]
195 | self.p2 = p2
196 | self.operation2 = func[operation2]
197 | self.magnitude2 = ranges[operation2][magnitude_idx2]
198 |
199 | def __call__(self, img):
200 | if random.random() < self.p1:
201 | img = self.operation1(img, self.magnitude1)
202 | if random.random() < self.p2:
203 | img = self.operation2(img, self.magnitude2)
204 | return img
--------------------------------------------------------------------------------
/dataloader/cifar100/cifar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import os
4 | import os.path
5 | import numpy as np
6 | import pickle
7 |
8 | import torchvision.transforms as transforms
9 |
10 | from torchvision.datasets.vision import VisionDataset
11 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive
12 | from .autoaugment import CIFAR10Policy, Cutout
13 |
14 | class CIFAR10(VisionDataset):
15 | """`CIFAR10 `_ Dataset.
16 |
17 | Args:
18 | root (string): Root directory of dataset where directory
19 | ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
20 | train (bool, optional): If True, creates dataset from training set, otherwise
21 | creates from test set.
22 | transform (callable, optional): A function/transform that takes in a PIL image
23 | and returns a transformed version. E.g, ``transforms.RandomCrop``
24 | target_transform (callable, optional): A function/transform that takes in the
25 | target and transforms it.
26 | download (bool, optional): If true, downloads the dataset from the internet and
27 | puts it in root directory. If dataset is already downloaded, it is not
28 | downloaded again.
29 |
30 | """
31 | base_folder = 'cifar-10-batches-py'
32 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
33 | filename = "cifar-10-python.tar.gz"
34 | tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
35 | train_list = [
36 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
37 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
38 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
39 | ['data_batch_4', '634d18415352ddfa80567beed471001a'],
40 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
41 | ]
42 |
43 | test_list = [
44 | ['test_batch', '40351d587109b95175f43aff81a1287e'],
45 | ]
46 | meta = {
47 | 'filename': 'batches.meta',
48 | 'key': 'label_names',
49 | 'md5': '5ff9c542aee3614f3951f8cda6e48888',
50 | }
51 |
52 | def __init__(self, root, train=True, transform=None, target_transform=None,
53 | download=False, index=None, base_sess=None, autoaug=True):
54 |
55 | super(CIFAR10, self).__init__(root, transform=transform,
56 | target_transform=target_transform)
57 | self.root = os.path.expanduser(root)
58 | self.train = train # training set or test set
59 |
60 | if download:
61 | self.download()
62 |
63 | if not self._check_integrity():
64 | raise RuntimeError('Dataset not found or corrupted.' +
65 | ' You can use download=True to download it')
66 |
67 | # if self.train:
68 | # downloaded_list = self.train_list
69 | # else:
70 | # downloaded_list = self.test_list
71 |
72 | if autoaug is False:
73 | if self.train:
74 | downloaded_list = self.train_list
75 | self.transform = transforms.Compose([
76 | transforms.RandomCrop(32, padding=4),
77 | transforms.RandomHorizontalFlip(),
78 | transforms.ToTensor(),
79 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
80 | ])
81 | else:
82 | downloaded_list = self.test_list
83 | self.transform = transforms.Compose([
84 | transforms.ToTensor(),
85 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
86 | ])
87 | else:
88 | if self.train:
89 | downloaded_list = self.train_list
90 | self.transform = transforms.Compose([
91 | transforms.RandomCrop(32, padding=4),
92 | transforms.RandomHorizontalFlip(),
93 | CIFAR10Policy(), # add AutoAug
94 | transforms.ToTensor(),
95 | Cutout(n_holes=1, length=16),
96 | transforms.Normalize(
97 | (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
98 | ])
99 | else:
100 | downloaded_list = self.test_list
101 | self.transform = transforms.Compose([
102 | transforms.ToTensor(),
103 | transforms.Normalize(
104 | (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
105 | ])
106 |
107 | self.data = []
108 | self.targets = []
109 |
110 | # now load the pickled numpy arrays
111 | for file_name, checksum in downloaded_list:
112 | file_path = os.path.join(self.root, self.base_folder, file_name)
113 | with open(file_path, 'rb') as f:
114 | entry = pickle.load(f, encoding='latin1')
115 | self.data.append(entry['data'])
116 | if 'labels' in entry:
117 | self.targets.extend(entry['labels'])
118 | else:
119 | self.targets.extend(entry['fine_labels'])
120 |
121 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
122 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
123 |
124 | self.targets = np.asarray(self.targets)
125 |
126 | if base_sess:
127 | self.data, self.targets = self.SelectfromDefault(self.data, self.targets, index)
128 | else: # new Class session
129 | if train:
130 | self.data, self.targets = self.NewClassSelector(self.data, self.targets, index)
131 | else:
132 | self.data, self.targets = self.SelectfromDefault(self.data, self.targets, index)
133 |
134 | self._load_meta()
135 |
136 | def SelectfromDefault(self, data, targets, index):
137 | data_tmp = []
138 | targets_tmp = []
139 | for i in index:
140 | ind_cl = np.where(i == targets)[0]
141 | if len(data_tmp) == 0:
142 | data_tmp = data[ind_cl]
143 | targets_tmp = targets[ind_cl]
144 | else:
145 | data_tmp = np.vstack((data_tmp, data[ind_cl]))
146 | targets_tmp = np.hstack((targets_tmp, targets[ind_cl]))
147 |
148 | return data_tmp, targets_tmp
149 |
150 | def NewClassSelector(self, data, targets, index):
151 | data_tmp = []
152 | targets_tmp = []
153 | ind_list = [int(i) for i in index]
154 | ind_np = np.array(ind_list)
155 | index = ind_np.reshape((5, 5))  # 25 sample indices -> 5 groups of 5 (5-way, 5-shot session)
156 | for i in index:
157 | ind_cl = i
158 | if len(data_tmp) == 0:
159 | data_tmp = data[ind_cl]
160 | targets_tmp = targets[ind_cl]
161 | else:
162 | data_tmp = np.vstack((data_tmp, data[ind_cl]))
163 | targets_tmp = np.hstack((targets_tmp, targets[ind_cl]))
164 |
165 | return data_tmp, targets_tmp
166 |
167 | def _load_meta(self):
168 | path = os.path.join(self.root, self.base_folder, self.meta['filename'])
169 | if not check_integrity(path, self.meta['md5']):
170 | raise RuntimeError('Dataset metadata file not found or corrupted.' +
171 | ' You can use download=True to download it')
172 | with open(path, 'rb') as infile:
173 | data = pickle.load(infile, encoding='latin1')
174 | self.classes = data[self.meta['key']]
175 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
176 |
177 | def __getitem__(self, index):
178 | """
179 | Args:
180 | index (int): Index
181 |
182 | Returns:
183 | tuple: (image, target) where target is index of the target class.
184 | """
185 | img, target = self.data[index], self.targets[index]
186 |
187 | # doing this so that it is consistent with all other datasets
188 | # to return a PIL Image
189 | img = Image.fromarray(img)
190 |
191 | if self.transform is not None:
192 | img = self.transform(img)
193 |
194 | if self.target_transform is not None:
195 | target = self.target_transform(target)
196 |
197 | return img, target
198 |
199 | def __len__(self):
200 | return len(self.data)
201 |
202 | def _check_integrity(self):
203 | root = self.root
204 | for fentry in (self.train_list + self.test_list):
205 | filename, md5 = fentry[0], fentry[1]
206 | fpath = os.path.join(root, self.base_folder, filename)
207 | if not check_integrity(fpath, md5):
208 | return False
209 | return True
210 |
211 | def download(self):
212 | if self._check_integrity():
213 | print('Files already downloaded and verified')
214 | return
215 | download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
216 |
217 | def extra_repr(self):
218 | return "Split: {}".format("Train" if self.train is True else "Test")
219 |
220 |
221 | class CIFAR_concate(VisionDataset):
222 | def __init__(self, train,x1,y1,x2,y2):
223 |
224 | self.train = True  # note: the train argument is ignored; training transforms are always applied
225 | if self.train:
226 | self.transform = transforms.Compose([
227 | transforms.RandomCrop(32, padding=4),
228 | transforms.RandomHorizontalFlip(),
229 | transforms.ToTensor(),
230 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
231 | ])
232 | else:
233 | self.transform = transforms.Compose([
234 | transforms.ToTensor(),
235 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
236 | ])
237 |
238 |
239 | self.data=np.vstack([x1,x2])
240 | self.targets=np.hstack([y1,y2])
241 | print(len(self.data),len(self.targets))
242 |
243 | def __getitem__(self, index):
244 |
245 | img, target = self.data[index], self.targets[index]
246 |
247 | img = Image.fromarray(img)
248 |
249 | if self.transform is not None:
250 | img = self.transform(img)
251 |
252 | return img, target
253 |
254 | def __len__(self):
255 | return len(self.data)
256 |
257 | class CIFAR100(CIFAR10):
258 | """`CIFAR100 `_ Dataset.
259 |
260 | This is a subclass of the `CIFAR10` Dataset.
261 | """
262 | base_folder = 'cifar-100-python'
263 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
264 | filename = "cifar-100-python.tar.gz"
265 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
266 | train_list = [
267 | ['train', '16019d7e3df5f24257cddd939b257f8d'],
268 | ]
269 |
270 | test_list = [
271 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
272 | ]
273 | meta = {
274 | 'filename': 'meta',
275 | 'key': 'fine_label_names',
276 | 'md5': '7973b15100ade9c7d40fb424638fde48',
277 | }
278 |
279 | if __name__ == "__main__":
280 |
281 |
282 | dataroot = '../../data/'
283 | batch_size_base = 128
284 | txt_path = "../../data/index_list/cifar100/session_2.txt"
285 | # class_index = open(txt_path).read().splitlines()
286 | class_index = np.arange(60)
287 | class_index_val = np.arange(60,76)
288 | class_index_test= np.arange(76,100)
289 |
290 | trainset = CIFAR100(root=dataroot, train=True, download=True, transform=None, index=class_index_test,
291 | base_sess=True)
292 | testset = CIFAR100(root=dataroot, train=False, download=False,index=class_index, base_sess=True)
293 |
294 |
295 | import pickle
296 | print(trainset.data.shape)
297 | print(trainset.targets.shape)
298 | cls = np.unique(trainset.targets)
299 | print(cls)
300 | data={'data':trainset.data,'labels':trainset.targets}
301 | with open('CIFAR100_test.pickle', 'wb') as handle:
302 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
303 |
304 |
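# Added usage sketch (not part of the original file; paths are placeholders): the
# base session selects whole classes via SelectfromDefault, while an incremental
# session passes the 25 raw sample indices from a session txt file, which
# NewClassSelector groups into 5 new classes with 5 shots each.
def _session_datasets_demo(dataroot='../../data/'):
    base_train = CIFAR100(root=dataroot, train=True, download=True,
                          index=np.arange(60), base_sess=True)
    session_index = open('../../data/index_list/cifar100/session_2.txt').read().splitlines()
    novel_train = CIFAR100(root=dataroot, train=True, download=False,
                           index=session_index, base_sess=False)
    return base_train, novel_train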
--------------------------------------------------------------------------------
/dataloader/cifar100/autoaugment.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageEnhance, ImageOps
2 | import numpy as np
3 | import random
4 | import torch
5 |
6 |
7 | class Cutout(object):
8 | def __init__(self, n_holes, length):
9 | self.n_holes = n_holes
10 | self.length = length
11 |
12 | def __call__(self, img):
13 | h = img.size(1)
14 | w = img.size(2)
15 |
16 | mask = np.ones((h, w), np.float32)
17 |
18 | for n in range(self.n_holes):
19 | y = np.random.randint(h)
20 | x = np.random.randint(w)
21 |
22 | y1 = np.clip(y - self.length // 2, 0, h)
23 | y2 = np.clip(y + self.length // 2, 0, h)
24 | x1 = np.clip(x - self.length // 2, 0, w)
25 | x2 = np.clip(x + self.length // 2, 0, w)
26 |
27 | mask[y1: y2, x1: x2] = 0.
28 |
29 | mask = torch.from_numpy(mask)
30 | mask = mask.expand_as(img)
31 | img = img * mask
32 |
33 | return img
34 |
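# Added demo (not part of the original file): Cutout expects a (C, H, W) tensor,
# so it is applied after ToTensor (see the training transform in cifar.py); each
# hole zeroes one random length x length square, clipped at the image border.
def _cutout_demo():
    img = torch.ones(3, 32, 32)
    out = Cutout(n_holes=1, length=16)(img)
    return float(out.sum())  # strictly less than 3 * 32 * 32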
35 | class ImageNetPolicy(object):
36 | """ Randomly choose one of the best 24 Sub-policies on ImageNet.
37 |
38 | Example:
39 | >>> policy = ImageNetPolicy()
40 | >>> transformed = policy(image)
41 |
42 | Example as a PyTorch Transform:
43 | >>> transform=transforms.Compose([
44 | >>> transforms.Resize(256),
45 | >>> ImageNetPolicy(),
46 | >>> transforms.ToTensor()])
47 | """
48 | def __init__(self, fillcolor=(128, 128, 128)):
49 | self.policies = [
50 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
51 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
52 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
53 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
54 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
55 |
56 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
57 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
58 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
59 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
60 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
61 |
62 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
63 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
64 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
65 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
66 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
67 |
68 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
69 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
70 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
71 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
72 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
73 |
74 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
75 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
76 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
77 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
78 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
79 | ]
80 |
81 |
82 | def __call__(self, img):
83 | policy_idx = random.randint(0, len(self.policies) - 1)
84 | return self.policies[policy_idx](img)
85 |
86 | def __repr__(self):
87 | return "AutoAugment ImageNet Policy"
88 |
89 |
90 | class CIFAR10Policy(object):
91 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10.
92 |
93 | Example:
94 | >>> policy = CIFAR10Policy()
95 | >>> transformed = policy(image)
96 |
97 | Example as a PyTorch Transform:
98 | >>> transform=transforms.Compose([
99 | >>> transforms.Resize(256),
100 | >>> CIFAR10Policy(),
101 | >>> transforms.ToTensor()])
102 | """
103 | def __init__(self, fillcolor=(128, 128, 128)):
104 | self.policies = [
105 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
106 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
107 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
108 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
109 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
110 |
111 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
112 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
113 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
114 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
115 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
116 |
117 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
118 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
119 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
120 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
121 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
122 |
123 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
124 | SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
125 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
126 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
127 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
128 |
129 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
130 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
131 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
132 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
133 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
134 | ]
135 |
136 |
137 | def __call__(self, img):
138 | policy_idx = random.randint(0, len(self.policies) - 1)
139 | return self.policies[policy_idx](img)
140 |
141 | def __repr__(self):
142 | return "AutoAugment CIFAR10 Policy"
143 |
144 |
145 | class SVHNPolicy(object):
146 | """ Randomly choose one of the best 25 Sub-policies on SVHN.
147 |
148 | Example:
149 | >>> policy = SVHNPolicy()
150 | >>> transformed = policy(image)
151 |
152 | Example as a PyTorch Transform:
153 | >>> transform=transforms.Compose([
154 | >>> transforms.Resize(256),
155 | >>> SVHNPolicy(),
156 | >>> transforms.ToTensor()])
157 | """
158 | def __init__(self, fillcolor=(128, 128, 128)):
159 | self.policies = [
160 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
161 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
162 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
163 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
164 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
165 |
166 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
167 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
168 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
169 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
170 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
171 |
172 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
173 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
174 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
175 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
176 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
177 |
178 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
179 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
180 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
181 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
182 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
183 |
184 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
185 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
186 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
187 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
188 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
189 | ]
190 |
191 |
192 | def __call__(self, img):
193 | policy_idx = random.randint(0, len(self.policies) - 1)
194 | return self.policies[policy_idx](img)
195 |
196 | def __repr__(self):
197 | return "AutoAugment SVHN Policy"
198 |
199 |
200 | class SubPolicy(object):
201 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
202 | ranges = {
203 | "shearX": np.linspace(0, 0.3, 10),
204 | "shearY": np.linspace(0, 0.3, 10),
205 | "translateX": np.linspace(0, 150 / 331, 10),
206 | "translateY": np.linspace(0, 150 / 331, 10),
207 | "rotate": np.linspace(0, 30, 10),
208 | "color": np.linspace(0.0, 0.9, 10),
209 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
210 | "solarize": np.linspace(256, 0, 10),
211 | "contrast": np.linspace(0.0, 0.9, 10),
212 | "sharpness": np.linspace(0.0, 0.9, 10),
213 | "brightness": np.linspace(0.0, 0.9, 10),
214 | "autocontrast": [0] * 10,
215 | "equalize": [0] * 10,
216 | "invert": [0] * 10
217 | }
218 |
219 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
220 | def rotate_with_fill(img, magnitude):
221 | rot = img.convert("RGBA").rotate(magnitude)
222 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
223 |
224 | func = {
225 | "shearX": lambda img, magnitude: img.transform(
226 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
227 | Image.BICUBIC, fillcolor=fillcolor),
228 | "shearY": lambda img, magnitude: img.transform(
229 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
230 | Image.BICUBIC, fillcolor=fillcolor),
231 | "translateX": lambda img, magnitude: img.transform(
232 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
233 | fillcolor=fillcolor),
234 | "translateY": lambda img, magnitude: img.transform(
235 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
236 | fillcolor=fillcolor),
237 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
238 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
239 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
240 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
241 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
242 | 1 + magnitude * random.choice([-1, 1])),
243 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
244 | 1 + magnitude * random.choice([-1, 1])),
245 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
246 | 1 + magnitude * random.choice([-1, 1])),
247 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
248 | "equalize": lambda img, magnitude: ImageOps.equalize(img),
249 | "invert": lambda img, magnitude: ImageOps.invert(img)
250 | }
251 |
252 | self.p1 = p1
253 | self.operation1 = func[operation1]
254 | self.magnitude1 = ranges[operation1][magnitude_idx1]
255 | self.p2 = p2
256 | self.operation2 = func[operation2]
257 | self.magnitude2 = ranges[operation2][magnitude_idx2]
258 |
259 |
260 | def __call__(self, img):
261 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
262 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
263 | return img
--------------------------------------------------------------------------------
/models/resnet18_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import errno
6 | import hashlib
7 | import os
8 | import warnings
9 | import re
10 | import shutil
11 | import sys
12 | import tempfile
13 | from tqdm import tqdm
14 | from urllib.request import urlopen
15 | from urllib.parse import urlparse # noqa: F401
16 |
17 |
18 | def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True):
19 | r"""Loads the Torch serialized object at the given URL.
20 |
21 | If the object is already present in `model_dir`, it's deserialized and
22 | returned. The filename part of the URL should follow the naming convention
23 | ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
24 | digits of the SHA256 hash of the contents of the file. The hash is used to
25 | ensure unique names and to verify the contents of the file.
26 |
27 | The default value of `model_dir` is ``$TORCH_HOME/checkpoints`` where
28 | environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
29 | ``$XDG_CACHE_HOME`` follows the XDG Base Directory specification of the Linux
30 | filesystem layout, with a default value ``~/.cache`` if not set.
31 |
32 | Args:
33 | url (string): URL of the object to download
34 | model_dir (string, optional): directory in which to save the object
35 | map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
36 | progress (bool, optional): whether or not to display a progress bar to stderr
37 |
38 | Example:
39 | >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
40 |
41 | """
42 | # Issue warning to move data if old env is set
43 | if os.getenv('TORCH_MODEL_ZOO'):
44 | warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
45 |
46 | if model_dir is None:
47 | torch_home = _get_torch_home()
48 | model_dir = os.path.join(torch_home, 'checkpoints')
49 |
50 | try:
51 | os.makedirs(model_dir)
52 | except OSError as e:
53 | if e.errno == errno.EEXIST:
54 | # Directory already exists, ignore.
55 | pass
56 | else:
57 | # Unexpected OSError, re-raise.
58 | raise
59 |
60 | parts = urlparse(url)
61 | filename = os.path.basename(parts.path)
62 | cached_file = os.path.join(model_dir, filename)
63 | if not os.path.exists(cached_file):
64 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
65 | hash_prefix = HASH_REGEX.search(filename).group(1)
66 | _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
67 | return torch.load(cached_file, map_location=map_location)
68 |
69 |
70 | def _download_url_to_file(url, dst, hash_prefix, progress):
71 | file_size = None
72 | u = urlopen(url)
73 | meta = u.info()
74 | if hasattr(meta, 'getheaders'):
75 | content_length = meta.getheaders("Content-Length")
76 | else:
77 | content_length = meta.get_all("Content-Length")
78 | if content_length is not None and len(content_length) > 0:
79 | file_size = int(content_length[0])
80 |
81 | # We deliberately save it in a temp file and move it after
82 | # download is complete. This prevents a local working checkpoint
83 | # being overridden by a broken download.
84 | dst_dir = os.path.dirname(dst)
85 | f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
86 |
87 | try:
88 | if hash_prefix is not None:
89 | sha256 = hashlib.sha256()
90 | with tqdm(total=file_size, disable=not progress,
91 | unit='B', unit_scale=True, unit_divisor=1024) as pbar:
92 | while True:
93 | buffer = u.read(8192)
94 | if len(buffer) == 0:
95 | break
96 | f.write(buffer)
97 | if hash_prefix is not None:
98 | sha256.update(buffer)
99 | pbar.update(len(buffer))
100 |
101 | f.close()
102 | if hash_prefix is not None:
103 | digest = sha256.hexdigest()
104 | if digest[:len(hash_prefix)] != hash_prefix:
105 | raise RuntimeError('invalid hash value (expected "{}", got "{}")'
106 | .format(hash_prefix, digest))
107 | shutil.move(f.name, dst)
108 | finally:
109 | f.close()
110 | if os.path.exists(f.name):
111 | os.remove(f.name)
112 |
113 |
114 | ENV_TORCH_HOME = 'TORCH_HOME'
115 | ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
116 | DEFAULT_CACHE_DIR = '~/.cache'
117 | HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
118 |
119 |
120 | def _get_torch_home():
121 | torch_home = os.path.expanduser(
122 | os.getenv(ENV_TORCH_HOME,
123 | os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch')))
124 | return torch_home
125 |
126 |
127 |
128 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
129 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
130 | 'wide_resnet50_2', 'wide_resnet101_2']
131 |
132 |
133 | model_urls = {
134 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
135 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
136 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
137 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
138 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
139 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
140 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
141 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
142 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
143 | }
144 |
145 |
146 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
147 | """3x3 convolution with padding"""
148 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
149 | padding=dilation, groups=groups, bias=False, dilation=dilation)
150 |
151 |
152 | def conv1x1(in_planes, out_planes, stride=1):
153 | """1x1 convolution"""
154 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
155 |
156 |
157 | class BasicBlock(nn.Module):
158 | expansion = 1
159 |
160 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
161 | base_width=64, dilation=1, norm_layer=None):
162 | super(BasicBlock, self).__init__()
163 | if norm_layer is None:
164 | norm_layer = nn.BatchNorm2d
165 | if groups != 1 or base_width != 64:
166 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
167 | if dilation > 1:
168 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
169 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
170 | self.conv1 = conv3x3(inplanes, planes, stride)
171 | self.bn1 = norm_layer(planes)
172 | self.relu = nn.ReLU(inplace=True)
173 | self.conv2 = conv3x3(planes, planes)
174 | self.bn2 = norm_layer(planes)
175 | self.downsample = downsample
176 | self.stride = stride
177 |
178 | def forward(self, x):
179 | identity = x
180 |
181 | out = self.conv1(x)
182 | out = self.bn1(out)
183 | out = self.relu(out)
184 |
185 | out = self.conv2(out)
186 | out = self.bn2(out)
187 |
188 | if self.downsample is not None:
189 | identity = self.downsample(x)
190 |
191 | out += identity
192 | out = self.relu(out)
193 |
194 | return out
195 |
196 |
197 | class Bottleneck(nn.Module):
198 | expansion = 4
199 |
200 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
201 | base_width=64, dilation=1, norm_layer=None):
202 | super(Bottleneck, self).__init__()
203 | if norm_layer is None:
204 | norm_layer = nn.BatchNorm2d
205 | width = int(planes * (base_width / 64.)) * groups
206 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
207 | self.conv1 = conv1x1(inplanes, width)
208 | self.bn1 = norm_layer(width)
209 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
210 | self.bn2 = norm_layer(width)
211 | self.conv3 = conv1x1(width, planes * self.expansion)
212 | self.bn3 = norm_layer(planes * self.expansion)
213 | self.relu = nn.ReLU(inplace=True)
214 | self.downsample = downsample
215 | self.stride = stride
216 |
217 | def forward(self, x):
218 | identity = x
219 |
220 | out = self.conv1(x)
221 | out = self.bn1(out)
222 | out = self.relu(out)
223 |
224 | out = self.conv2(out)
225 | out = self.bn2(out)
226 | out = self.relu(out)
227 |
228 | out = self.conv3(out)
229 | out = self.bn3(out)
230 |
231 | if self.downsample is not None:
232 | identity = self.downsample(x)
233 |
234 | out += identity
235 | out = self.relu(out)
236 |
237 | return out
238 |
239 |
240 | class ResNet(nn.Module):
241 |
242 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
243 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
244 | norm_layer=None):
245 | super(ResNet, self).__init__()
246 | if norm_layer is None:
247 | norm_layer = nn.BatchNorm2d
248 | self._norm_layer = norm_layer
249 |
250 | self.inplanes = 64
251 | self.dilation = 1
252 | if replace_stride_with_dilation is None:
253 | # each element in the tuple indicates if we should replace
254 | # the 2x2 stride with a dilated convolution instead
255 | replace_stride_with_dilation = [False, False, False]
256 | if len(replace_stride_with_dilation) != 3:
257 | raise ValueError("replace_stride_with_dilation should be None "
258 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
259 | self.groups = groups
260 | self.base_width = width_per_group
261 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
262 | bias=False)
263 | self.bn1 = norm_layer(self.inplanes)
264 | self.relu = nn.ReLU(inplace=True)
265 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
266 | self.layer1 = self._make_layer(block, 64, layers[0])
267 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
268 | dilate=replace_stride_with_dilation[0])
269 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
270 | dilate=replace_stride_with_dilation[1])
271 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
272 | dilate=replace_stride_with_dilation[2])
273 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
274 | # self.fc = nn.Linear(512 * block.expansion, num_classes,bias=False)
275 |
276 | for m in self.modules():
277 | if isinstance(m, nn.Conv2d):
278 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
279 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
280 | nn.init.constant_(m.weight, 1)
281 | nn.init.constant_(m.bias, 0)
282 |
283 | # Zero-initialize the last BN in each residual branch,
284 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
285 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
286 | if zero_init_residual:
287 | for m in self.modules():
288 | if isinstance(m, Bottleneck):
289 | nn.init.constant_(m.bn3.weight, 0)
290 | elif isinstance(m, BasicBlock):
291 | nn.init.constant_(m.bn2.weight, 0)
292 |
293 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
294 | norm_layer = self._norm_layer
295 | downsample = None
296 | previous_dilation = self.dilation
297 | if dilate:
298 | self.dilation *= stride
299 | stride = 1
300 | if stride != 1 or self.inplanes != planes * block.expansion:
301 | downsample = nn.Sequential(
302 | conv1x1(self.inplanes, planes * block.expansion, stride),
303 | norm_layer(planes * block.expansion),
304 | )
305 |
306 | layers = []
307 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
308 | self.base_width, previous_dilation, norm_layer))
309 | self.inplanes = planes * block.expansion
310 | for _ in range(1, blocks):
311 | layers.append(block(self.inplanes, planes, groups=self.groups,
312 | base_width=self.base_width, dilation=self.dilation,
313 | norm_layer=norm_layer))
314 |
315 | return nn.Sequential(*layers)
316 |
317 | def forward(self, x):
318 | x = self.conv1(x)
319 | x = self.bn1(x)
320 | x = self.relu(x)
321 | x = self.maxpool(x)
322 |
323 | x = self.layer1(x)
324 | x = self.layer2(x)
325 | x = self.layer3(x)
326 | x = self.layer4(x)
327 |
328 | # x = self.avgpool(x)
329 | # x = torch.flatten(x, 1)
330 | # x = self.fc(x)
331 | # x = F.linear(F.normalize(x, p=2, dim=-1), F.normalize(self.fc.weight, p=2, dim=-1))
332 | # x = temperature * x
333 |
334 | return x
335 |
336 |
337 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
338 | model = ResNet(block, layers, **kwargs)
339 | if pretrained:
340 | model_dict = model.state_dict()
341 | state_dict = load_state_dict_from_url(model_urls[arch],
342 | progress=progress)
343 | state_dict = {k: v for k, v in state_dict.items() if k not in ['fc.weight', 'fc.bias']}
344 | model_dict.update(state_dict)
345 | model.load_state_dict(model_dict)
346 | return model
347 |
348 |
349 | def resnet18(pretrained=False, progress=True, **kwargs):
350 | r"""ResNet-18 model from
351 | `"Deep Residual Learning for Image Recognition" `_
352 |
353 | Args:
354 | pretrained (bool): If True, returns a model pre-trained on ImageNet
355 | progress (bool): If True, displays a progress bar of the download to stderr
356 | """
357 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
358 | **kwargs)
359 |
360 |
361 | def resnet34(pretrained=False, progress=True, **kwargs):
362 | r"""ResNet-34 model from
363 | `"Deep Residual Learning for Image Recognition" `_
364 |
365 | Args:
366 | pretrained (bool): If True, returns a model pre-trained on ImageNet
367 | progress (bool): If True, displays a progress bar of the download to stderr
368 | """
369 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
370 | **kwargs)
371 |
372 |
373 | def resnet50(pretrained=False, progress=True, **kwargs):
374 | r"""ResNet-50 model from
375 | `"Deep Residual Learning for Image Recognition" `_
376 |
377 | Args:
378 | pretrained (bool): If True, returns a model pre-trained on ImageNet
379 | progress (bool): If True, displays a progress bar of the download to stderr
380 | """
381 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
382 | **kwargs)
383 |
384 |
385 | def resnet101(pretrained=False, progress=True, **kwargs):
386 | r"""ResNet-101 model from
387 | `"Deep Residual Learning for Image Recognition" `_
388 |
389 | Args:
390 | pretrained (bool): If True, returns a model pre-trained on ImageNet
391 | progress (bool): If True, displays a progress bar of the download to stderr
392 | """
393 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
394 | **kwargs)
395 |
396 |
397 | def resnet152(pretrained=False, progress=True, **kwargs):
398 | r"""ResNet-152 model from
399 | `"Deep Residual Learning for Image Recognition" `_
400 |
401 | Args:
402 | pretrained (bool): If True, returns a model pre-trained on ImageNet
403 | progress (bool): If True, displays a progress bar of the download to stderr
404 | """
405 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
406 | **kwargs)
407 |
408 |
409 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
410 | r"""ResNeXt-50 32x4d model from
411 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
412 |
413 | Args:
414 | pretrained (bool): If True, returns a model pre-trained on ImageNet
415 | progress (bool): If True, displays a progress bar of the download to stderr
416 | """
417 | kwargs['groups'] = 32
418 | kwargs['width_per_group'] = 4
419 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
420 | pretrained, progress, **kwargs)
421 |
422 |
423 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
424 | r"""ResNeXt-101 32x8d model from
425 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
426 |
427 | Args:
428 | pretrained (bool): If True, returns a model pre-trained on ImageNet
429 | progress (bool): If True, displays a progress bar of the download to stderr
430 | """
431 | kwargs['groups'] = 32
432 | kwargs['width_per_group'] = 8
433 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
434 | pretrained, progress, **kwargs)
435 |
436 |
437 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
438 | r"""Wide ResNet-50-2 model from
439 | `"Wide Residual Networks" `_
440 |
441 | The model is the same as ResNet except that the number of bottleneck channels
442 | is twice as large in every block. The number of channels in the outer 1x1
443 | convolutions stays the same, e.g. the last block in ResNet-50 has 2048-512-2048
444 | channels, whereas in Wide ResNet-50-2 it has 2048-1024-2048.
445 |
446 | Args:
447 | pretrained (bool): If True, returns a model pre-trained on ImageNet
448 | progress (bool): If True, displays a progress bar of the download to stderr
449 | """
450 | kwargs['width_per_group'] = 64 * 2
451 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
452 | pretrained, progress, **kwargs)
453 |
454 |
455 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
456 | r"""Wide ResNet-101-2 model from
457 | `"Wide Residual Networks" `_
458 |
459 | The model is the same as ResNet except that the number of bottleneck channels
460 | is twice as large in every block. The number of channels in the outer 1x1
461 | convolutions stays the same, e.g. the last block in ResNet-50 has 2048-512-2048
462 | channels, whereas in Wide ResNet-50-2 it has 2048-1024-2048.
463 |
464 | Args:
465 | pretrained (bool): If True, returns a model pre-trained on ImageNet
466 | progress (bool): If True, displays a progress bar of the download to stderr
467 | """
468 | kwargs['width_per_group'] = 64 * 2
469 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
470 | pretrained, progress, **kwargs)
471 |
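# --- Editor's usage sketch (not part of the original file) ---
# A minimal, hedged smoke test for this encoder; the 224x224 input size and the
# resulting shapes below are illustrative assumptions. Because avgpool/fc are
# commented out in ResNet.forward(), the encoder returns the raw layer4 feature
# map and the caller is expected to pool and flatten it; _resnet() drops
# fc.weight/fc.bias from pretrained checkpoints for the same reason.
if __name__ == "__main__":
    import torch
    import torch.nn.functional as F

    encoder = resnet18(pretrained=False)  # pretrained=True would download ImageNet weights
    x = torch.randn(2, 3, 224, 224)       # dummy batch of RGB images
    fmap = encoder(x)                     # [2, 512, 7, 7]: spatial features, no pooling/fc
    feat = F.adaptive_avg_pool2d(fmap, 1).flatten(1)  # [2, 512], mirroring the
                                                      # commented-out avgpool/flatten
    print(fmap.shape, feat.shape)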
--------------------------------------------------------------------------------