├── .gitignore ├── CUB ├── dynamic+PN │ ├── Conv4.sh │ ├── ResNet18.sh │ ├── train_stage_1.py │ └── train_stage_2.py ├── dynamic+PN_gt │ ├── Conv4.sh │ ├── ResNet18.sh │ ├── train_stage_1.py │ └── train_stage_2.py ├── dynamic │ ├── Conv4.sh │ ├── ResNet18.sh │ ├── train_stage_1.py │ └── train_stage_2.py ├── proto+BP │ ├── Conv4.sh │ ├── ResNet18.sh │ └── train.py ├── proto+FSL │ ├── Conv4.sh │ ├── ResNet18.sh │ └── train.py ├── proto+MT │ ├── Conv4.sh │ ├── ResNet18.sh │ └── train.py ├── proto+PN │ ├── Conv4.sh │ ├── ResNet18.sh │ └── train.py ├── proto+PN_gt │ ├── Conv4.sh │ ├── ResNet18.sh │ └── train.py ├── proto+PN_less_annot │ ├── Conv4.sh │ ├── ResNet18.sh │ └── train.py ├── proto+bbN │ ├── Conv4.sh │ ├── ResNet18.sh │ └── train.py ├── proto+uPN │ ├── Conv4.sh │ ├── ResNet18.sh │ └── train.py ├── proto │ ├── Conv4.sh │ ├── ResNet18.sh │ └── train.py ├── transfer+PN │ ├── Conv4.sh │ ├── ResNet18.sh │ ├── finetune_cub.py │ ├── finetune_na.py │ └── train_base.py ├── transfer+PN_gt │ ├── Conv4.sh │ ├── ResNet18.sh │ ├── finetune_cub.py │ └── train_base.py └── transfer │ ├── Conv4.sh │ ├── ResNet18.sh │ ├── finetune_cub.py │ ├── finetune_na.py │ └── train_base.py ├── FGVC ├── proto+PN │ ├── Conv4.sh │ ├── ResNet18.sh │ └── train.py └── proto │ ├── Conv4.sh │ ├── ResNet18.sh │ └── train.py ├── LICENSE ├── README.md ├── dataset ├── download.sh ├── exclude_na_id_list.pth ├── init.sh ├── init_cub.py ├── init_fgvc.py ├── init_na.py └── init_oid.py └── utils ├── dataloader.py ├── dynamic_eval.py ├── dynamic_train.py ├── models.py ├── networks.py ├── proto_eval.py ├── proto_train.py ├── sampler.py ├── transfer_eval.py ├── transfer_train.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore all the model and log files 2 | model_*.pth 3 | events.* 4 | *.log 5 | 6 | # ignore python cache files 7 | *.pyc 8 | 9 | # OS generated files 10 | .DS_Store 11 | 
-------------------------------------------------------------------------------- /CUB/dynamic+PN/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_stage_1.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 100 \ 7 | --stage 2 \ 8 | --weight_decay 5e-4 \ 9 | --num_part 15 \ 10 | --alpha 100 \ 11 | --batch_size 64 \ 12 | --gpu 0 13 | 14 | python train_stage_2.py \ 15 | --opt adam \ 16 | --lr 1e-3 \ 17 | --gamma 1e-1 \ 18 | --epoch 200 \ 19 | --stage 1 \ 20 | --weight_decay 0 \ 21 | --num_part 15 \ 22 | --load_path model_Conv4-stage_1.pth \ 23 | --gpu 0 -------------------------------------------------------------------------------- /CUB/dynamic+PN/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_stage_1.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 25 \ 7 | --stage 3 \ 8 | --weight_decay 1e-3 \ 9 | --num_part 15 \ 10 | --alpha 200 \ 11 | --resnet \ 12 | --batch_size 64 \ 13 | --gpu 0 14 | 15 | python train_stage_2.py \ 16 | --opt adam \ 17 | --lr 1e-3 \ 18 | --gamma 1e-1 \ 19 | --epoch 200 \ 20 | --stage 1 \ 21 | --weight_decay 0 \ 22 | --num_part 15 \ 23 | --resnet \ 24 | --load_path model_ResNet18-stage_1.pth \ 25 | --gpu 0 -------------------------------------------------------------------------------- /CUB/dynamic+PN/train_stage_1.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import dynamic_train,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args=args, 13 | name=name, 14 | suffix='stage_1', 15 | train_annot='part') 16 | 17 | train_loader = 
dataloader.normal_train_dataloader(data_path=pm.support, 18 | batch_size=args.batch_size, 19 | annot=config.train_annot, 20 | annot_path=pm.annot_path) 21 | 22 | num_class = len(train_loader.dataset.classes) 23 | 24 | model = networks.Dynamic_PN(num_class=num_class, 25 | num_part=args.num_part, 26 | resnet=args.resnet) 27 | 28 | model.cuda() 29 | 30 | train_func = partial(dynamic_train.train_PN_stage_1, 31 | train_loader=train_loader, 32 | alpha=args.alpha) 33 | 34 | tm = util.Train_Manager(args,pm,config, 35 | train_func=train_func) 36 | 37 | tm.train(model) -------------------------------------------------------------------------------- /CUB/dynamic+PN/train_stage_2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import dynamic_train,dynamic_eval,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args=args, 13 | name=name, 14 | suffix='stage_2', 15 | shots=[20]) 16 | 17 | train_loader = dataloader.meta_train_dataloader(data_path=pm.support, 18 | way=config.way, 19 | shots=config.shots) 20 | 21 | num_class = len(train_loader.dataset.classes) 22 | 23 | model = networks.Dynamic_PN(num_class=num_class, 24 | num_part=args.num_part, 25 | way=config.way, 26 | shots=config.shots, 27 | resnet=args.resnet) 28 | model.cuda() 29 | 30 | model.load_state_dict(torch.load(args.load_path)) 31 | 32 | train_func = partial(dynamic_train.train_stage_2,train_loader=train_loader) 33 | eval_func = dynamic_eval.default_eval 34 | 35 | tm = util.TM_dynamic_PN_stage_2(args,pm,config, 36 | train_func=train_func, 37 | eval_func=eval_func) 38 | 39 | tm.train(model) 40 | 41 | pm_na = util.Path_Manager_NA('../../dataset/na_fewshot',args=args) 42 | dynamic_eval.eval_test(model,pm,config,pm_na=pm_na) 
-------------------------------------------------------------------------------- /CUB/dynamic+PN_gt/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_stage_1.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 50 \ 7 | --stage 2 \ 8 | --weight_decay 5e-4 \ 9 | --num_part 15 \ 10 | --batch_size 64 \ 11 | --gpu 0 12 | 13 | python train_stage_2.py \ 14 | --opt adam \ 15 | --lr 1e-3 \ 16 | --gamma 1e-1 \ 17 | --epoch 200 \ 18 | --stage 1 \ 19 | --weight_decay 0 \ 20 | --num_part 15 \ 21 | --load_path model_Conv4-stage_1.pth \ 22 | --gpu 0 -------------------------------------------------------------------------------- /CUB/dynamic+PN_gt/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_stage_1.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 25 \ 7 | --stage 3 \ 8 | --weight_decay 1e-3 \ 9 | --resnet \ 10 | --num_part 15 \ 11 | --batch_size 64 \ 12 | --gpu 0 13 | 14 | python train_stage_2.py \ 15 | --opt adam \ 16 | --lr 1e-3 \ 17 | --gamma 1e-1 \ 18 | --epoch 200 \ 19 | --stage 1 \ 20 | --weight_decay 0 \ 21 | --resnet \ 22 | --num_part 15 \ 23 | --load_path model_ResNet18-stage_1.pth \ 24 | --gpu 0 -------------------------------------------------------------------------------- /CUB/dynamic+PN_gt/train_stage_1.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import dynamic_train,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args=args, 13 | name=name, 14 | suffix='stage_1', 15 | train_annot='part') 16 | 17 | train_loader = dataloader.normal_train_dataloader(data_path=pm.support, 18 | 
batch_size=args.batch_size, 19 | annot=config.train_annot, 20 | annot_path=pm.annot_path) 21 | 22 | num_class = len(train_loader.dataset.classes) 23 | 24 | model = networks.Dynamic_PN_gt(num_class=num_class, 25 | num_part=args.num_part, 26 | resnet=args.resnet) 27 | 28 | model.cuda() 29 | 30 | train_func = partial(dynamic_train.train_stage_1,train_loader=train_loader) 31 | 32 | tm = util.Train_Manager(args,pm,config, 33 | train_func=train_func) 34 | 35 | tm.train(model) -------------------------------------------------------------------------------- /CUB/dynamic+PN_gt/train_stage_2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import dynamic_train,dynamic_eval,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args=args, 13 | name=name, 14 | suffix='stage_2', 15 | shots=[20], 16 | train_annot='part', 17 | eval_annot='part') 18 | 19 | train_loader = dataloader.meta_train_dataloader(data_path=pm.support, 20 | way=config.way, 21 | shots=config.shots, 22 | annot=config.train_annot, 23 | annot_path=pm.annot_path) 24 | 25 | num_class = len(train_loader.dataset.classes) 26 | 27 | model = networks.Dynamic_PN_gt(num_class=num_class, 28 | num_part=args.num_part, 29 | way=config.way, 30 | shots=config.shots, 31 | resnet=args.resnet) 32 | model.cuda() 33 | 34 | model.load_state_dict(torch.load(args.load_path)) 35 | 36 | train_func = partial(dynamic_train.train_stage_2,train_loader=train_loader) 37 | eval_func = dynamic_eval.default_eval 38 | 39 | tm = util.TM_dynamic_stage_2(args,pm,config, 40 | train_func=train_func, 41 | eval_func=eval_func) 42 | 43 | tm.train(model) 44 | 45 | dynamic_eval.eval_test(model,pm,config) 
-------------------------------------------------------------------------------- /CUB/dynamic/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_stage_1.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 200 \ 7 | --stage 2 \ 8 | --weight_decay 5e-4 \ 9 | --batch_size 64 \ 10 | --gpu 0 11 | 12 | python train_stage_2.py \ 13 | --opt adam \ 14 | --lr 1e-3 \ 15 | --gamma 1e-1 \ 16 | --epoch 200 \ 17 | --stage 1 \ 18 | --weight_decay 0 \ 19 | --load_path model_Conv4-stage_1.pth \ 20 | --gpu 0 -------------------------------------------------------------------------------- /CUB/dynamic/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_stage_1.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 100 \ 7 | --stage 2 \ 8 | --weight_decay 1e-3 \ 9 | --resnet \ 10 | --batch_size 64 \ 11 | --gpu 0 12 | 13 | python train_stage_2.py \ 14 | --opt adam \ 15 | --lr 1e-3 \ 16 | --gamma 1e-1 \ 17 | --epoch 200 \ 18 | --stage 1 \ 19 | --weight_decay 0 \ 20 | --resnet \ 21 | --load_path model_ResNet18-stage_1.pth \ 22 | --gpu 0 -------------------------------------------------------------------------------- /CUB/dynamic/train_stage_1.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import dynamic_train,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args=args, 13 | name=name, 14 | suffix='stage_1') 15 | 16 | train_loader = dataloader.normal_train_dataloader(data_path=pm.support, 17 | batch_size=args.batch_size) 18 | 19 | num_class = len(train_loader.dataset.classes) 20 | 21 | model = 
networks.Dynamic(num_class=num_class,resnet=args.resnet) 22 | 23 | model.cuda() 24 | 25 | train_func = partial(dynamic_train.train_stage_1,train_loader=train_loader) 26 | 27 | tm = util.Train_Manager(args,pm,config, 28 | train_func=train_func) 29 | 30 | tm.train(model) -------------------------------------------------------------------------------- /CUB/dynamic/train_stage_2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import dynamic_train,dynamic_eval,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args=args, 13 | name=name, 14 | suffix='stage_2', 15 | shots=[20]) 16 | 17 | train_loader = dataloader.meta_train_dataloader(data_path=pm.support, 18 | way=config.way, 19 | shots=config.shots) 20 | 21 | num_class = len(train_loader.dataset.classes) 22 | 23 | model = networks.Dynamic(num_class=num_class, 24 | way=config.way, 25 | shots=config.shots, 26 | resnet=args.resnet) 27 | model.cuda() 28 | 29 | model.load_state_dict(torch.load(args.load_path)) 30 | 31 | train_func = partial(dynamic_train.train_stage_2,train_loader=train_loader) 32 | eval_func = dynamic_eval.default_eval 33 | 34 | tm = util.TM_dynamic_stage_2(args,pm,config, 35 | train_func=train_func, 36 | eval_func=eval_func) 37 | 38 | tm.train(model) 39 | 40 | pm_na = util.Path_Manager_NA('../../dataset/na_fewshot',args=args) 41 | dynamic_eval.eval_test(model,pm,config,pm_na=pm_na) -------------------------------------------------------------------------------- /CUB/proto+BP/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt adam \ 4 | --lr 1e-3 \ 5 | --gamma 1e-1 \ 6 | --epoch 800 \ 7 | --stage 1 \ 8 | --weight_decay 0 \ 9 | --gpu 0 
-------------------------------------------------------------------------------- /CUB/proto+BP/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt adam \ 4 | --lr 1e-3 \ 5 | --gamma 1e-1 \ 6 | --epoch 600 \ 7 | --stage 1 \ 8 | --weight_decay 1e-3 \ 9 | --resnet \ 10 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto+BP/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import proto_train,proto_eval,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args,name) 13 | 14 | train_loader = dataloader.meta_train_dataloader(data_path=pm.support, 15 | shots=config.shots, 16 | way=config.way) 17 | 18 | model = networks.Proto_BP(way=config.way, 19 | shots=config.shots, 20 | resnet=args.resnet) 21 | 22 | model.cuda() 23 | 24 | train_func = partial(proto_train.default_train,train_loader=train_loader) 25 | eval_func = proto_eval.default_eval 26 | 27 | tm = util.Train_Manager(args,pm,config, 28 | train_func=train_func, 29 | eval_func=eval_func) 30 | 31 | tm.train(model) 32 | 33 | pm_na = util.Path_Manager_NA('../../dataset/na_fewshot',args=args) 34 | proto_eval.eval_test(model,pm,config,pm_na=pm_na) -------------------------------------------------------------------------------- /CUB/proto+FSL/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-2 \ 5 | --gamma 1e-1 \ 6 | --epoch 400 \ 7 | --stage 2 \ 8 | --weight_decay 5e-4 \ 9 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto+FSL/ResNet18.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 300 \ 7 | --stage 2 \ 8 | --weight_decay 1e-3 \ 9 | --resnet \ 10 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto+FSL/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import proto_train,proto_eval,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args,name, 13 | shots=[5,5,10], 14 | train_annot='bbx', 15 | eval_annot='bbx') 16 | 17 | train_loader = dataloader.meta_train_dataloader(data_path=pm.support, 18 | shots=config.shots, 19 | way=config.way, 20 | annot=config.train_annot, 21 | annot_path=pm.annot_path) 22 | 23 | model = networks.Proto_FSL(way=config.way, 24 | shots=config.shots, 25 | resnet=args.resnet) 26 | 27 | model.cuda() 28 | 29 | train_func = partial(proto_train.default_train,train_loader=train_loader) 30 | eval_func = proto_eval.default_eval 31 | 32 | tm = util.Train_Manager(args,pm,config, 33 | train_func=train_func, 34 | eval_func=eval_func) 35 | 36 | tm.train(model) 37 | 38 | pm_na = util.Path_Manager_NA('../../dataset/na_fewshot',args=args) 39 | proto_eval.eval_test(model,pm,config,pm_na=pm_na) -------------------------------------------------------------------------------- /CUB/proto+MT/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 600 \ 7 | --stage 2 \ 8 | --weight_decay 1e-3 \ 9 | --num_part 15 \ 10 | --alpha 100 \ 11 | --gpu 0 
-------------------------------------------------------------------------------- /CUB/proto+MT/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 300 \ 7 | --stage 2 \ 8 | --weight_decay 5e-3 \ 9 | --num_part 15 \ 10 | --alpha 200 \ 11 | --resnet \ 12 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto+MT/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import proto_train,proto_eval,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args,name, 13 | train_annot='part') 14 | 15 | train_loader = dataloader.meta_train_dataloader(data_path=pm.support, 16 | shots=config.shots, 17 | way=config.way, 18 | annot=config.train_annot, 19 | annot_path=pm.annot_path) 20 | 21 | model = networks.Proto_MT(num_part=args.num_part, 22 | way=config.way, 23 | shots=config.shots, 24 | resnet=args.resnet) 25 | 26 | model.cuda() 27 | 28 | train_func = partial(proto_train.PN_train,train_loader=train_loader,alpha=args.alpha) 29 | eval_func = proto_eval.default_eval 30 | 31 | tm = util.Train_Manager(args,pm,config, 32 | train_func=train_func, 33 | eval_func=eval_func) 34 | 35 | tm.train(model) 36 | 37 | pm_na = util.Path_Manager_NA('../../dataset/na_fewshot',args=args) 38 | proto_eval.eval_test(model,pm,config,pm_na=pm_na) -------------------------------------------------------------------------------- /CUB/proto+PN/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 600 \ 7 | 
--stage 2 \ 8 | --weight_decay 1e-3 \ 9 | --num_part 15 \ 10 | --alpha 100 \ 11 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto+PN/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 300 \ 7 | --stage 2 \ 8 | --weight_decay 5e-3 \ 9 | --num_part 15 \ 10 | --resnet \ 11 | --alpha 200 \ 12 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto+PN/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import proto_train,proto_eval,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args,name, 13 | train_annot='part') 14 | 15 | train_loader = dataloader.meta_train_dataloader(data_path=pm.support, 16 | shots=config.shots, 17 | way=config.way, 18 | annot=config.train_annot, 19 | annot_path=pm.annot_path) 20 | 21 | model = networks.Proto_PN(num_part=args.num_part, 22 | way=config.way, 23 | shots=config.shots, 24 | resnet=args.resnet) 25 | 26 | model.cuda() 27 | 28 | train_func = partial(proto_train.PN_train,train_loader=train_loader,alpha=args.alpha) 29 | eval_func = proto_eval.default_eval 30 | 31 | tm = util.Train_Manager(args,pm,config, 32 | train_func=train_func, 33 | eval_func=eval_func) 34 | 35 | tm.train(model) 36 | 37 | pm_na = util.Path_Manager_NA('../../dataset/na_fewshot',args=args) 38 | proto_eval.eval_test(model,pm,config,pm_na=pm_na) -------------------------------------------------------------------------------- /CUB/proto+PN_gt/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 
python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 400 \ 7 | --stage 2 \ 8 | --weight_decay 5e-4 \ 9 | --num_part 15 \ 10 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto+PN_gt/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 300 \ 7 | --stage 2 \ 8 | --weight_decay 5e-3 \ 9 | --num_part 15 \ 10 | --resnet \ 11 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto+PN_gt/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import proto_train,proto_eval,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args,name, 13 | train_annot='part', 14 | eval_annot='part') 15 | 16 | train_loader = dataloader.meta_train_dataloader(data_path=pm.support, 17 | shots=config.shots, 18 | way=config.way, 19 | annot=config.train_annot, 20 | annot_path=pm.annot_path) 21 | 22 | model = networks.Proto_PN_gt(num_part=args.num_part, 23 | way=config.way, 24 | shots=config.shots, 25 | resnet=args.resnet) 26 | 27 | model.cuda() 28 | 29 | train_func = partial(proto_train.default_train,train_loader=train_loader) 30 | eval_func = proto_eval.default_eval 31 | 32 | tm = util.Train_Manager(args,pm,config, 33 | train_func=train_func, 34 | eval_func=eval_func) 35 | 36 | tm.train(model) 37 | proto_eval.eval_test(model,pm,config) -------------------------------------------------------------------------------- /CUB/proto+PN_less_annot/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python 
train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 600 \ 7 | --stage 2 \ 8 | --weight_decay 1e-3 \ 9 | --num_part 15 \ 10 | --alpha 100 \ 11 | --percent 20 \ 12 | --batch_size 7 \ 13 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto+PN_less_annot/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 300 \ 7 | --stage 2 \ 8 | --weight_decay 5e-3 \ 9 | --num_part 15 \ 10 | --alpha 200 \ 11 | --percent 20 \ 12 | --batch_size 7 \ 13 | --resnet \ 14 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto+PN_less_annot/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import proto_train,proto_eval,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args,name, 13 | train_annot='part', 14 | suffix="percent_%d-bz_%d"%(args.percent,args.batch_size)) 15 | 16 | train_loader = dataloader.proto_train_less_annot_dataloader(data_path=pm.support, 17 | shots=config.shots, 18 | way=config.way, 19 | annot_path=pm.annot_path, 20 | percent=args.percent, 21 | batch_size=args.batch_size) 22 | 23 | model = networks.Proto_PN_less_annot(num_part=args.num_part, 24 | way=config.way, 25 | shots=config.shots, 26 | resnet=args.resnet) 27 | 28 | model.cuda() 29 | 30 | train_func = partial(proto_train.PN_train_less_annot, 31 | train_loader=train_loader, 32 | alpha=args.alpha, 33 | batch_size=args.batch_size) 34 | 35 | eval_func = proto_eval.default_eval 36 | 37 | tm = util.Train_Manager(args,pm,config, 38 | train_func=train_func, 39 | eval_func=eval_func) 
40 | 41 | tm.train(model) 42 | 43 | pm_na = util.Path_Manager_NA('../../dataset/na_fewshot',args=args) 44 | proto_eval.eval_test(model,pm,config,pm_na=pm_na) -------------------------------------------------------------------------------- /CUB/proto+bbN/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-2 \ 5 | --gamma 1e-1 \ 6 | --epoch 400 \ 7 | --stage 2 \ 8 | --weight_decay 5e-4 \ 9 | --num_part 2 \ 10 | --alpha 10 \ 11 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto+bbN/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt adam \ 4 | --lr 1e-1 \ 5 | --gamma 5e-1 \ 6 | --epoch 160 \ 7 | --stage 5 \ 8 | --weight_decay 0 \ 9 | --num_part 2 \ 10 | --alpha 10 \ 11 | --resnet \ 12 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto+bbN/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import proto_train,proto_eval,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args,name, 13 | train_annot='bbx') 14 | 15 | train_loader = dataloader.meta_train_dataloader(data_path=pm.support, 16 | shots=config.shots, 17 | way=config.way, 18 | annot=config.train_annot, 19 | annot_path=pm.annot_path) 20 | 21 | model = networks.Proto_bbN(num_part=args.num_part, 22 | way=config.way, 23 | shots=config.shots, 24 | resnet=args.resnet) 25 | 26 | model.cuda() 27 | 28 | train_func = partial(proto_train.bbN_train,train_loader=train_loader,alpha=args.alpha) 29 | eval_func = proto_eval.default_eval 30 | 31 | tm = 
util.Train_Manager(args,pm,config, 32 | train_func=train_func, 33 | eval_func=eval_func) 34 | 35 | tm.train(model) 36 | 37 | pm_na = util.Path_Manager_NA('../../dataset/na_fewshot',args=args) 38 | proto_eval.eval_test(model,pm,config,pm_na=pm_na) -------------------------------------------------------------------------------- /CUB/proto+uPN/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 600 \ 7 | --stage 2 \ 8 | --weight_decay 1e-3 \ 9 | --num_part 15 \ 10 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto+uPN/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 200 \ 7 | --stage 2 \ 8 | --weight_decay 5e-3 \ 9 | --num_part 15 \ 10 | --resnet \ 11 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto+uPN/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import proto_train,proto_eval,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args,name) 13 | 14 | train_loader = dataloader.meta_train_dataloader(data_path=pm.support, 15 | shots=config.shots, 16 | way=config.way) 17 | 18 | model = networks.Proto_uPN(num_part=args.num_part, 19 | way=config.way, 20 | shots=config.shots, 21 | resnet=args.resnet) 22 | 23 | model.cuda() 24 | 25 | train_func = partial(proto_train.default_train,train_loader=train_loader) 26 | eval_func = proto_eval.default_eval 27 | 28 | tm = util.Train_Manager(args,pm,config, 29 | 
train_func=train_func, 30 | eval_func=eval_func) 31 | 32 | tm.train(model) 33 | 34 | pm_na = util.Path_Manager_NA('../../dataset/na_fewshot',args=args) 35 | proto_eval.eval_test(model,pm,config,pm_na=pm_na) -------------------------------------------------------------------------------- /CUB/proto/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 400 \ 7 | --stage 2 \ 8 | --weight_decay 5e-4 \ 9 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 300 \ 7 | --stage 2 \ 8 | --weight_decay 1e-3 \ 9 | --resnet \ 10 | --gpu 0 -------------------------------------------------------------------------------- /CUB/proto/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import proto_train,proto_eval,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args,name) 13 | 14 | train_loader = dataloader.meta_train_dataloader(data_path=pm.support, 15 | shots=config.shots, 16 | way=config.way) 17 | 18 | model = networks.Proto(way=config.way, 19 | shots=config.shots, 20 | resnet=args.resnet) 21 | 22 | model.cuda() 23 | 24 | train_func = partial(proto_train.default_train,train_loader=train_loader) 25 | eval_func = proto_eval.default_eval 26 | 27 | tm = util.Train_Manager(args,pm,config, 28 | train_func=train_func, 29 | eval_func=eval_func) 30 | 31 | tm.train(model) 32 | 33 | pm_na = 
util.Path_Manager_NA('../../dataset/na_fewshot',args=args) 34 | proto_eval.eval_test(model,pm,config,pm_na=pm_na) -------------------------------------------------------------------------------- /CUB/transfer+PN/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_base.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 200 \ 7 | --stage 2 \ 8 | --weight_decay 5e-4 \ 9 | --num_part 15 \ 10 | --alpha 100 \ 11 | --batch_size 64 \ 12 | --gpu 0 13 | 14 | python finetune_cub.py \ 15 | --opt adam \ 16 | --lr 1e-3 \ 17 | --gamma 1e-1 \ 18 | --epoch 40 \ 19 | --stage 1 \ 20 | --weight_decay 0 \ 21 | --num_part 15 \ 22 | --batch_size 16 \ 23 | --load_path model_Conv4-base.pth \ 24 | --gpu 0 25 | 26 | python finetune_na.py \ 27 | --opt adam \ 28 | --lr 1e-3 \ 29 | --gamma 1e-1 \ 30 | --epoch 20 \ 31 | --stage 1 \ 32 | --weight_decay 0 \ 33 | --num_part 15 \ 34 | --batch_size 16 \ 35 | --load_path model_Conv4-base.pth \ 36 | --gpu 0 -------------------------------------------------------------------------------- /CUB/transfer+PN/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_base.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 100 \ 7 | --stage 2 \ 8 | --weight_decay 1e-3 \ 9 | --num_part 15 \ 10 | --alpha 100 \ 11 | --batch_size 64 \ 12 | --resnet \ 13 | --gpu 0 14 | 15 | python finetune_cub.py \ 16 | --opt adam \ 17 | --lr 1e-3 \ 18 | --gamma 1e-1 \ 19 | --epoch 40 \ 20 | --stage 1 \ 21 | --weight_decay 0 \ 22 | --num_part 15 \ 23 | --batch_size 16 \ 24 | --resnet \ 25 | --load_path model_ResNet18-base.pth \ 26 | --gpu 0 27 | 28 | python finetune_na.py \ 29 | --opt adam \ 30 | --lr 1e-3 \ 31 | --gamma 1e-1 \ 32 | --epoch 20 \ 33 | --stage 1 \ 34 | --weight_decay 0 \ 35 | --num_part 15 \ 36 | --batch_size 16 \ 37 | --resnet \ 38 | --load_path model_ResNet18-base.pth \ 39 | --gpu 0 
-------------------------------------------------------------------------------- /CUB/transfer+PN/finetune_cub.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | import torch.nn as nn 6 | sys.path.append('../../') 7 | from utils import transfer_train,transfer_eval,networks,dataloader,util 8 | 9 | args,name = util.train_parser() 10 | 11 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 12 | 13 | config = util.Config(args=args, 14 | name=name, 15 | suffix='cub') 16 | 17 | train_loader = dataloader.normal_train_dataloader(data_path=pm.test_refer, 18 | batch_size=args.batch_size) 19 | num_class = len(train_loader.dataset.classes) 20 | 21 | model = networks.Transfer_PN(num_part=args.num_part, 22 | resnet=args.resnet) 23 | model.cuda() 24 | model.load_state_dict(torch.load(args.load_path)) 25 | model.linear_classifier = nn.Linear(model.dim,num_class).cuda() 26 | 27 | train_func = partial(transfer_train.default_train, 28 | train_loader=train_loader) 29 | 30 | tm = util.TM_transfer_PN_finetune(args,pm,config, 31 | train_func=train_func) 32 | 33 | tm.train(model) 34 | 35 | transfer_eval.eval_test(model,pm,config) -------------------------------------------------------------------------------- /CUB/transfer+PN/finetune_na.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | from functools import partial 6 | sys.path.append('../../') 7 | from utils import transfer_train,transfer_eval,networks,dataloader,util 8 | 9 | args,name = util.train_parser() 10 | 11 | pm = util.Path_Manager_NA('../../dataset/na_fewshot',args=args) 12 | 13 | config = util.Config(args=args, 14 | name=name, 15 | suffix='na') 16 | 17 | train_loader = dataloader.normal_train_dataloader(data_path=pm.test_refer, 18 | batch_size=args.batch_size) 19 | 
num_class = len(train_loader.dataset.classes) 20 | 21 | model = networks.Transfer_PN(num_part=args.num_part, 22 | resnet=args.resnet) 23 | model.cuda() 24 | model.load_state_dict(torch.load(args.load_path)) 25 | model.linear_classifier = nn.Linear(model.dim,num_class).cuda() 26 | 27 | train_func = partial(transfer_train.default_train,train_loader=train_loader) 28 | 29 | tm = util.TM_transfer_PN_finetune(args,pm,config, 30 | train_func=train_func) 31 | 32 | tm.train(model) 33 | 34 | transfer_eval.eval_test(model,pm,config) -------------------------------------------------------------------------------- /CUB/transfer+PN/train_base.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import transfer_train,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args=args, 13 | name=name, 14 | suffix='base', 15 | train_annot='part') 16 | 17 | train_loader = dataloader.normal_train_dataloader(data_path=pm.support, 18 | batch_size=args.batch_size, 19 | annot=config.train_annot, 20 | annot_path=pm.annot_path) 21 | 22 | model = networks.Transfer_PN(num_part=args.num_part, 23 | resnet=args.resnet) 24 | 25 | model.cuda() 26 | 27 | train_func = partial(transfer_train.PN_train, 28 | train_loader=train_loader, 29 | alpha=args.alpha) 30 | 31 | tm = util.Train_Manager(args,pm,config, 32 | train_func=train_func) 33 | 34 | tm.train(model) -------------------------------------------------------------------------------- /CUB/transfer+PN_gt/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_base.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 200 \ 7 | --stage 2 \ 8 | --weight_decay 5e-4 \ 9 | --num_part 15 \ 10 | 
--batch_size 64 \ 11 | --gpu 0 12 | 13 | python finetune_cub.py \ 14 | --opt adam \ 15 | --lr 1e-3 \ 16 | --gamma 1e-1 \ 17 | --epoch 40 \ 18 | --stage 1 \ 19 | --weight_decay 0 \ 20 | --num_part 15 \ 21 | --batch_size 16 \ 22 | --load_path model_Conv4-base.pth \ 23 | --gpu 0 -------------------------------------------------------------------------------- /CUB/transfer+PN_gt/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_base.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 100 \ 7 | --stage 2 \ 8 | --weight_decay 1e-3 \ 9 | --num_part 15 \ 10 | --batch_size 64 \ 11 | --resnet \ 12 | --gpu 0 13 | 14 | python finetune_cub.py \ 15 | --opt adam \ 16 | --lr 1e-3 \ 17 | --gamma 1e-1 \ 18 | --epoch 40 \ 19 | --stage 1 \ 20 | --weight_decay 0 \ 21 | --num_part 15 \ 22 | --batch_size 16 \ 23 | --resnet \ 24 | --load_path model_ResNet18-base.pth \ 25 | --gpu 0 -------------------------------------------------------------------------------- /CUB/transfer+PN_gt/finetune_cub.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | import torch.nn as nn 6 | sys.path.append('../../') 7 | from utils import transfer_train,transfer_eval,networks,dataloader,util 8 | 9 | args,name = util.train_parser() 10 | 11 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 12 | 13 | config = util.Config(args=args, 14 | name=name, 15 | suffix='cub', 16 | train_annot='part', 17 | eval_annot='part') 18 | 19 | train_loader = dataloader.normal_train_dataloader(data_path=pm.test_refer, 20 | batch_size=args.batch_size, 21 | annot=config.train_annot, 22 | annot_path=pm.annot_path) 23 | num_class = len(train_loader.dataset.classes) 24 | 25 | model = networks.Transfer_PN_gt(num_part=args.num_part, 26 | resnet=args.resnet) 27 | model.cuda() 28 | 
model.load_state_dict(torch.load(args.load_path)) 29 | model.linear_classifier = nn.Linear(model.dim,num_class).cuda() 30 | 31 | train_func = partial(transfer_train.default_train, 32 | train_loader=train_loader) 33 | 34 | tm = util.TM_transfer_finetune(args,pm,config, 35 | train_func=train_func) 36 | 37 | tm.train(model) 38 | 39 | transfer_eval.eval_test(model,pm,config) -------------------------------------------------------------------------------- /CUB/transfer+PN_gt/train_base.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import transfer_train,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args=args, 13 | name=name, 14 | suffix='base', 15 | train_annot='part') 16 | 17 | train_loader = dataloader.normal_train_dataloader(data_path=pm.support, 18 | batch_size=args.batch_size, 19 | annot=config.train_annot, 20 | annot_path=pm.annot_path) 21 | 22 | model = networks.Transfer_PN_gt(num_part=args.num_part, 23 | resnet=args.resnet) 24 | 25 | model.cuda() 26 | 27 | train_func = partial(transfer_train.default_train,train_loader=train_loader) 28 | 29 | tm = util.Train_Manager(args,pm,config, 30 | train_func=train_func) 31 | 32 | tm.train(model) -------------------------------------------------------------------------------- /CUB/transfer/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_base.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 200 \ 7 | --stage 2 \ 8 | --weight_decay 5e-4 \ 9 | --batch_size 64 \ 10 | --gpu 0 11 | 12 | python finetune_cub.py \ 13 | --opt adam \ 14 | --lr 1e-3 \ 15 | --gamma 1e-1 \ 16 | --epoch 40 \ 17 | --stage 1 \ 18 | --weight_decay 0 \ 19 | --batch_size 16 \ 20 | 
--load_path model_Conv4-base.pth \ 21 | --gpu 0 22 | 23 | python finetune_na.py \ 24 | --opt adam \ 25 | --lr 1e-3 \ 26 | --gamma 1e-1 \ 27 | --epoch 20 \ 28 | --stage 1 \ 29 | --weight_decay 0 \ 30 | --batch_size 16 \ 31 | --load_path model_Conv4-base.pth \ 32 | --gpu 0 -------------------------------------------------------------------------------- /CUB/transfer/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_base.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 100 \ 7 | --stage 2 \ 8 | --weight_decay 1e-3 \ 9 | --resnet \ 10 | --batch_size 64 \ 11 | --gpu 0 12 | 13 | python finetune_cub.py \ 14 | --opt adam \ 15 | --lr 1e-3 \ 16 | --gamma 1e-1 \ 17 | --epoch 40 \ 18 | --stage 1 \ 19 | --weight_decay 0 \ 20 | --resnet \ 21 | --batch_size 16 \ 22 | --load_path model_ResNet18-base.pth \ 23 | --gpu 0 24 | 25 | python finetune_na.py \ 26 | --opt adam \ 27 | --lr 1e-3 \ 28 | --gamma 1e-1 \ 29 | --epoch 20 \ 30 | --stage 1 \ 31 | --weight_decay 0 \ 32 | --resnet \ 33 | --batch_size 16 \ 34 | --load_path model_ResNet18-base.pth \ 35 | --gpu 0 -------------------------------------------------------------------------------- /CUB/transfer/finetune_cub.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | import torch.nn as nn 6 | sys.path.append('../../') 7 | from utils import transfer_train,transfer_eval,networks,dataloader,util 8 | 9 | args,name = util.train_parser() 10 | 11 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 12 | 13 | config = util.Config(args=args, 14 | name=name, 15 | suffix='cub') 16 | 17 | train_loader = dataloader.normal_train_dataloader(data_path=pm.test_refer, 18 | batch_size=args.batch_size) 19 | num_class = len(train_loader.dataset.classes) 20 | 21 | model = networks.Transfer(resnet=args.resnet) 22 | model.cuda() 
23 | model.load_state_dict(torch.load(args.load_path)) 24 | model.linear_classifier = nn.Linear(model.dim,num_class).cuda() 25 | 26 | train_func = partial(transfer_train.default_train,train_loader=train_loader) 27 | 28 | tm = util.TM_transfer_finetune(args,pm,config, 29 | train_func=train_func) 30 | 31 | tm.train(model) 32 | 33 | transfer_eval.eval_test(model,pm,config) -------------------------------------------------------------------------------- /CUB/transfer/finetune_na.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | from functools import partial 6 | sys.path.append('../../') 7 | from utils import transfer_train,transfer_eval,networks,dataloader,util 8 | 9 | args,name = util.train_parser() 10 | 11 | pm = util.Path_Manager_NA('../../dataset/na_fewshot',args=args) 12 | 13 | config = util.Config(args=args, 14 | name=name, 15 | suffix='na') 16 | 17 | train_loader = dataloader.normal_train_dataloader(data_path=pm.test_refer, 18 | batch_size=args.batch_size) 19 | num_class = len(train_loader.dataset.classes) 20 | 21 | model = networks.Transfer(resnet=args.resnet) 22 | model.cuda() 23 | model.load_state_dict(torch.load(args.load_path)) 24 | model.linear_classifier = nn.Linear(model.dim,num_class).cuda() 25 | 26 | train_func = partial(transfer_train.default_train,train_loader=train_loader) 27 | 28 | tm = util.TM_transfer_finetune(args,pm,config, 29 | train_func=train_func) 30 | 31 | tm.train(model) 32 | 33 | transfer_eval.eval_test(model,pm,config) -------------------------------------------------------------------------------- /CUB/transfer/train_base.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import transfer_train,networks,dataloader,util 7 | 8 | args,name = 
util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/cub_fewshot',args=args) 11 | 12 | config = util.Config(args=args, 13 | name=name, 14 | suffix='base') 15 | 16 | train_loader = dataloader.normal_train_dataloader(data_path=pm.support, 17 | batch_size=args.batch_size) 18 | 19 | model = networks.Transfer(resnet=args.resnet) 20 | 21 | model.cuda() 22 | 23 | train_func = partial(transfer_train.default_train,train_loader=train_loader) 24 | 25 | tm = util.Train_Manager(args,pm,config, 26 | train_func=train_func) 27 | 28 | tm.train(model) -------------------------------------------------------------------------------- /FGVC/proto+PN/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 500 \ 7 | --stage 2 \ 8 | --weight_decay 1e-3 \ 9 | --num_part 5 \ 10 | --alpha 50 \ 11 | --val_epoch 40 \ 12 | --batch_size 400 \ 13 | --gpu 0 -------------------------------------------------------------------------------- /FGVC/proto+PN/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 300 \ 7 | --stage 2 \ 8 | --weight_decay 5e-3 \ 9 | --num_part 5 \ 10 | --alpha 50 \ 11 | --resnet \ 12 | --val_epoch 40 \ 13 | --batch_size 400 \ 14 | --gpu 0 -------------------------------------------------------------------------------- /FGVC/proto+PN/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import proto_train,proto_eval,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = util.Path_Manager('../../dataset/fgvc_fewshot',args=args) 11 | 12 | config = util.Config(args,name) 13 | 14 | train_loader = 
dataloader.meta_train_dataloader(data_path=pm.support, 15 | shots=config.shots, 16 | way=config.way) 17 | 18 | if args.resnet: 19 | res=224 20 | else: 21 | res=84 22 | oid_path = '../../dataset/oid_fewshot/res_%d'%(res) 23 | oid_loader = dataloader.oid_dataloader(oid_path,args.batch_size) 24 | 25 | model = networks.Proto_PN_less_annot(num_part=args.num_part, 26 | way=config.way, 27 | shots=config.shots, 28 | resnet=args.resnet) 29 | 30 | model.cuda() 31 | 32 | train_func = partial(proto_train.fgvc_PN_train, 33 | train_loader=train_loader, 34 | oid_loader=oid_loader, 35 | alpha=args.alpha) 36 | eval_func = proto_eval.default_eval 37 | 38 | tm = util.Train_Manager(args,pm,config, 39 | train_func=train_func, 40 | eval_func=eval_func) 41 | 42 | tm.train(model) 43 | 44 | proto_eval.eval_test(model,pm,config) -------------------------------------------------------------------------------- /FGVC/proto/Conv4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 500 \ 7 | --stage 2 \ 8 | --weight_decay 1e-3 \ 9 | --val_epoch 40 \ 10 | --gpu 0 -------------------------------------------------------------------------------- /FGVC/proto/ResNet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --opt sgd \ 4 | --lr 1e-1 \ 5 | --gamma 1e-1 \ 6 | --epoch 300 \ 7 | --stage 2 \ 8 | --weight_decay 1e-3 \ 9 | --resnet \ 10 | --val_epoch 40 \ 11 | --gpu 0 -------------------------------------------------------------------------------- /FGVC/proto/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | sys.path.append('../../') 6 | from utils import proto_train,proto_eval,networks,dataloader,util 7 | 8 | args,name = util.train_parser() 9 | 10 | pm = 
util.Path_Manager('../../dataset/fgvc_fewshot',args=args) 11 | 12 | config = util.Config(args,name) 13 | 14 | train_loader = dataloader.meta_train_dataloader(data_path=pm.support, 15 | shots=config.shots, 16 | way=config.way) 17 | 18 | model = networks.Proto(way=config.way, 19 | shots=config.shots, 20 | resnet=args.resnet) 21 | 22 | model.cuda() 23 | 24 | train_func = partial(proto_train.default_train,train_loader=train_loader) 25 | eval_func = proto_eval.default_eval 26 | 27 | tm = util.Train_Manager(args,pm,config, 28 | train_func=train_func, 29 | eval_func=eval_func) 30 | 31 | tm.train(model) 32 | 33 | proto_eval.eval_test(model,pm,config) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Luming Tang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Revisiting Pose-Normalization for Fine-Grained Few-Shot Recognition 2 | 3 | This repo contains the reference source code for our CVPR 2020 paper [Revisiting Pose-Normalization for Fine-Grained Few-Shot Recognition](https://arxiv.org/abs/2004.00705). 4 | 5 | ## Environment 6 | 7 | Python 3.7 8 | 9 | PyTorch 1.1.0 with CUDA 9.0 10 | 11 | tensorboardX 12 | 13 | ## Set up dataset 14 | 15 | In our experiments, we use four datasets: [CUB-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), [NABirds](https://dl.allaboutbirds.org/nabirds), [FGVC-Aircraft](http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/) and [OID-Aircraft](http://www.robots.ox.ac.uk/~vgg/data/oid/). 16 | 17 | You have two options to download this data: 18 | 19 | - Manually download them using the hyperlinks provided above, and then extract them into the `dataset` folder. 20 | - If you don't want to download and extract them one by one, you can also go into the `dataset` folder and execute `download.sh`. Before doing that, though, please go to the official [NABirds](https://dl.allaboutbirds.org/nabirds) website and register using your name and email address, and also accept their terms of use. 21 | 22 | After the download is finished, navigate to the `dataset` folder and execute `init.sh` to generate the dataset we use for training and evaluation. More details about the dataset split can be found in the paper. 23 | 24 | ## Train and test 25 | 26 | For experiments on CUB and FGVC, each model has its own individual folder in the `CUB` or `FGVC` directory respectively. 27 | 28 | For training, go to the folder of the model you wish to run, and execute `Conv4.sh` or `ResNet18.sh` for the 4-layer ConvNet or ResNet18 backbone respectively.
The hyper-parameters have already been set to the values given in the supplementary materials. We set the default GPU device to 0. If you want to use a different one, just change the `--gpu` argument in the `.sh` script. 29 | 30 | The training and validation accuracies are displayed in both the standard output and the generated `*.log` file. The training history, including losses, train/validation accuracy and heatmap visualization, can also be displayed via TensorBoard. The TensorBoard summary is located in the `log_*` folder. During the training process, the model snapshot with the best validation performance will be saved in `model_*.pth`. 31 | 32 | After training is complete, the script will automatically evaluate the final model on the corresponding test set, and output the test accuracy numbers in both the standard output and the `*.log` file. 33 | 34 | ## Train with less part annotation 35 | 36 | For the ablation study on training on the CUB dataset with less part annotation, navigate to `proto+PN_less_annot` in the `CUB` directory. The `.sh` scripts are set to train with 20% part annotation and batch size 7. If you want to train with other percentages of annotation, change both the `--percent` and `--batch_size` arguments in the scripts as described in the supplementary materials. 37 | 38 | 39 | ## Citation 40 | If you find our code or paper useful, please consider citing our work using the following BibTeX: 41 | ``` 42 | @inproceedings{tang2020revisiting, 43 | title={Revisiting Pose-Normalization for Fine-Grained Few-Shot Recognition}, 44 | author={Tang, Luming and Wertheimer, Davis and Hariharan, Bharath}, 45 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 46 | pages={14352--14361}, 47 | year={2020} 48 | } 49 | ``` 50 | 51 | 52 | ## Updates 53 | 05/27/2021: As pointed out in this [issue](https://github.com/Tsingularity/PoseNorm_Fewshot/issues/2), the website for OID-Aircraft seems to be down for now.
As an alternative for downloading the dataset, I have uploaded a copy to this Google Drive [link](https://drive.google.com/file/d/10vKcoS6-JFEpioD_FStJZbWknFCfPOnU/view?usp=sharing). More details about the dataset can be found in its original [paper](https://www.robots.ox.ac.uk/~karen/pdf/vedaldi14understanding.pdf). If you want to download the dataset from Google Drive directly via the command line, please refer to this [line](https://github.com/Tsingularity/PoseNorm_Fewshot/blob/master/dataset/download.sh#L3) of code in `download.sh`, or refer to our recent [FRN](https://github.com/Tsingularity/FRN) repository for more details. 54 | -------------------------------------------------------------------------------- /dataset/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "downloading CUB..." 3 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45" -O CUB_200_2011.tgz && rm -rf /tmp/cookies.txt 4 | tar -xzf CUB_200_2011.tgz 5 | rm CUB_200_2011.tgz 6 | rm attributes.txt 7 | echo "CUB downloaded" 8 | 9 | echo "downloading NABird..." 10 | wget https://www.dropbox.com/s/nf78cbxq6bxpcfc/nabirds.tar.gz 11 | tar -xzf nabirds.tar.gz 12 | rm nabirds.tar.gz 13 | echo "NABird downloaded" 14 | 15 | echo "downloading FGVC-Aircraft..." 16 | wget http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz 17 | tar -xzf fgvc-aircraft-2013b.tar.gz 18 | rm fgvc-aircraft-2013b.tar.gz 19 | echo "FGVC-Aircraft downloaded" 20 | 21 | echo "downloading OID..."
22 | wget http://www.robots.ox.ac.uk/~vgg/data/oid/archives/oid-aircraft-beta-1.tar.gz 23 | tar -xzf oid-aircraft-beta-1.tar.gz 24 | rm oid-aircraft-beta-1.tar.gz 25 | echo "OID downloaded" 26 | -------------------------------------------------------------------------------- /dataset/exclude_na_id_list.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tsingularity/PoseNorm_Fewshot/139ee43e8b4f4343c860562d151f125db8a7fb49/dataset/exclude_na_id_list.pth -------------------------------------------------------------------------------- /dataset/init.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "initializing CUB..." 3 | python init_cub.py --origin_path ./CUB_200_2011 4 | echo "CUB finished" 5 | 6 | echo "initializing NABird..." 7 | python init_na.py --origin_path ./nabirds 8 | echo "NABird finished" 9 | 10 | echo "initializing FGVC-Aircraft..." 11 | python init_fgvc.py --origin_path ./fgvc-aircraft-2013b 12 | echo "FGVC-Aircraft finished" 13 | 14 | echo "initializing OID..." 
15 | python init_oid.py \ 16 | --oid_origin_path ./oid-aircraft-beta-1 \ 17 | --fgvc_origin_path ./fgvc-aircraft-2013b 18 | echo "OID finished" -------------------------------------------------------------------------------- /dataset/init_cub.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from PIL import Image 5 | import sys 6 | sys.path.append('..') 7 | from utils import util 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--origin_path",help="directory of the original CUB dataset you download and extract",type=str) 11 | args = parser.parse_args() 12 | 13 | origin_path = args.origin_path 14 | target_path = os.path.abspath('./cub_fewshot') 15 | resolution = [84,224] 16 | 17 | util.mkdir(target_path) 18 | 19 | id2path = {} 20 | 21 | with open(os.path.join(origin_path,'images.txt')) as f: 22 | lines = f.readlines() 23 | for line in lines: 24 | index, path = line.strip().split() 25 | index = int(index) 26 | id2path[index] = path 27 | 28 | cat2name = {} 29 | 30 | with open(os.path.join(origin_path,'classes.txt')) as f: 31 | lines = f.readlines() 32 | for line in lines: 33 | cat, name = line.strip().split() 34 | cat = int(cat) 35 | cat2name[cat] = name 36 | 37 | cat2img = {} 38 | 39 | with open(os.path.join(origin_path,'image_class_labels.txt')) as f: 40 | lines = f.readlines() 41 | for line in lines: 42 | image_id, class_id = line.strip().split() 43 | image_id = int(image_id) 44 | class_id = int(class_id) 45 | 46 | if class_id not in cat2img: 47 | cat2img[class_id]=[] 48 | cat2img[class_id].append(image_id) 49 | 50 | support = [] 51 | val_ref = [] 52 | val_query = [] 53 | test_ref = [] 54 | test_query = [] 55 | 56 | support_cat = [] 57 | val_cat = [] 58 | test_cat = [] 59 | 60 | for i in range(1,201): 61 | img_list = cat2img[i] 62 | img_num = len(img_list) 63 | name = cat2name[i] 64 | 65 | if i%2 == 0: 66 | support_cat.append(name) 67 | 
support.extend(img_list) 68 | elif i%4 == 1: 69 | val_cat.append(name) 70 | val_ref.extend(img_list[:img_num//5]) 71 | val_query.extend(img_list[img_num//5:]) 72 | elif i%4 ==3: 73 | test_cat.append(name) 74 | test_ref.extend(img_list[:img_num//5]) 75 | test_query.extend(img_list[img_num//5:]) 76 | 77 | id2bbx={} 78 | 79 | with open(os.path.join(origin_path,'bounding_boxes.txt')) as f: 80 | lines = f.readlines() 81 | for line in lines: 82 | index,x,y,width,height = line.strip().split() 83 | index = int(index) 84 | x = float(x) 85 | y = float(y) 86 | width = float(width) 87 | height = float(height) 88 | id2bbx[index] = [x,y,width,height] 89 | 90 | id2part={} 91 | 92 | with open(os.path.join(origin_path,'parts','part_locs.txt')) as f: 93 | lines = f.readlines() 94 | for line in lines: 95 | index,part_id,x,y,visible = line.strip().split() 96 | index = int(index) 97 | x = float(x) 98 | y = float(y) 99 | visible = int(visible) 100 | if index not in id2part: 101 | id2part[index]=[] 102 | id2part[index].append([x,y,visible]) 103 | 104 | split = ['support','val/refer','val/query','test/refer','test/query'] 105 | split_cat = [support_cat,val_cat,test_cat] 106 | split_img = [support,val_ref,val_query,test_ref,test_query] 107 | 108 | for res in resolution: 109 | 110 | path2annot = {} 111 | 112 | res_dir = os.path.join(target_path,'res_'+str(res)) 113 | util.mkdir(res_dir) 114 | 115 | for folder_name in ['support','val','test','val/refer','val/query','test/refer','test/query']: 116 | util.mkdir(os.path.join(res_dir,folder_name)) 117 | 118 | for i in range(3): 119 | if i: 120 | for j in [2*i-1,2*i]: 121 | temp_path = os.path.join(res_dir,split[j]) 122 | for cat_name in split_cat[i]: 123 | util.mkdir(os.path.join(temp_path,cat_name)) 124 | else: 125 | temp_path = os.path.join(res_dir,split[i]) 126 | for cat_name in split_cat[i]: 127 | util.mkdir(os.path.join(temp_path,cat_name)) 128 | 129 | for i in range(5): 130 | temp_path = os.path.join(res_dir,split[i]) 131 | for index in 
split_img[i]: 132 | img_path = id2path[index] 133 | origin_img = os.path.join(origin_path,'images',img_path) 134 | target_img = os.path.join(temp_path,img_path[:-3]+'bmp') 135 | 136 | p = Image.open(origin_img) 137 | w,h = p.size 138 | p = p.resize((res,res),Image.BILINEAR) 139 | p.save(target_img) 140 | 141 | x,y,width,height = id2bbx[index] 142 | x_min = x/w 143 | x_max = (x+width)/w 144 | y_min = y/h 145 | y_max = (y+height)/h 146 | 147 | parts = id2part[index] 148 | new_parts = [] 149 | for part in parts: 150 | x = part[0]/w 151 | y = part[1]/h 152 | new_parts.append([x,y,part[2]]) 153 | 154 | path2annot[target_img] = {} 155 | path2annot[target_img]['bbx'] = [x_min,x_max,y_min,y_max] 156 | path2annot[target_img]['part'] = new_parts 157 | 158 | torch.save(path2annot,os.path.join(res_dir,'path2annot.pth')) 159 | 160 | 161 | for res in resolution: 162 | 163 | res_dir = os.path.join(target_path,'res_%d'%(res)) 164 | 165 | origin_dict = torch.load(os.path.join(res_dir,'path2annot.pth')) 166 | 167 | tar_dir = os.path.join(res_dir,'eval_k_shot') 168 | util.mkdir(tar_dir) 169 | 170 | cat_name = os.listdir(os.path.join(res_dir,'test/refer')) 171 | for cat in cat_name: 172 | util.mkdir(os.path.join(tar_dir,cat)) 173 | 174 | tar_dict = {} 175 | 176 | for filename in origin_dict: 177 | 178 | if 'test/refer' in filename: 179 | tar_filename = filename.replace('test/refer','eval_k_shot') 180 | elif 'test/query' in filename: 181 | tar_filename = filename.replace('test/query','eval_k_shot') 182 | else: 183 | continue 184 | 185 | os.symlink(filename,tar_filename) 186 | 187 | tar_dict[tar_filename] = origin_dict[filename] 188 | 189 | torch.save(tar_dict,os.path.join(res_dir,'path2annot_eval_k_shot.pth')) -------------------------------------------------------------------------------- /dataset/init_fgvc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import math 4 | import argparse 5 | from PIL import Image 6 
import numpy as np
import sys
sys.path.append('..')
from utils import util

# NOTE(review): this region was recovered from a flattened dump that lost all
# indentation; the nesting below is inferred from data flow (per-resolution
# names such as res_dir / path2annot must live inside the `for res` loops).

# Seed NumPy's global RNG: the np.random.shuffle calls below decide the
# val/test refer/query splits, so their order must not change.
np.random.seed(42)

parser = argparse.ArgumentParser()
parser.add_argument("--origin_path",help="directory of the original FGVC dataset you download and extract",type=str)
args = parser.parse_args()

home_dir = os.path.join(args.origin_path,'data')
target_dir = os.path.abspath('./fgvc_fewshot')
# Every image is written once per output resolution (square, res x res).
resolution = [84,224]

util.mkdir(target_dir)

# Variant (class) name <-> integer id; ids follow line order in variants.txt.
cat_id2name={}
cat_name2id={}
with open(os.path.join(home_dir,'variants.txt')) as f:
    content = f.readlines()
    for i in range(len(content)):
        name=content[i].strip()
        cat_id2name[i]=name
        cat_name2id[name]=i

# image id -> class id, and class id -> [image ids], over trainval + test.
# Each line is parsed as a 7-character image id, one separator character,
# then the variant name.
img2cat={}
cat2img={}
for i in ['images_variant_trainval.txt','images_variant_test.txt']:
    with open(os.path.join(home_dir,i)) as f:
        content = f.readlines()
        for line in content:
            line = line.strip()
            img = line[:7]
            cat = line[8:]
            cat_id = cat_name2id[cat]
            img2cat[img] = cat_id
            if cat_id not in cat2img:
                cat2img[cat_id]=[]
            cat2img[cat_id].append(img)

# image id -> bounding box [x_min, x_max, y_min, y_max] normalised by image
# width / (height - 20). The 20 px matches the bottom crop in resize_img
# (presumably the banner at the bottom of FGVC-Aircraft images -- confirm).
# Boxes reaching into the cropped strip can end up > 1; the loader clamps.
img2bbx={}
with open(os.path.join(home_dir,'images_box.txt')) as f:
    content = f.readlines()
    for line in content:
        line = line.strip()
        img,xmin,ymin,xmax,ymax = line.split()
        xmin = float(xmin)
        ymin = float(ymin)
        xmax = float(xmax)
        ymax = float(ymax)
        with Image.open(os.path.join(home_dir,'images',img+'.jpg')) as temp:
            width,height = temp.size
        height = height-20
        img2bbx[img] = [xmin/width,xmax/width,ymin/height,ymax/height]

# Class split over ids 0..99: even -> support (training) classes,
# 1 (mod 4) -> validation, 3 (mod 4) -> test, i.e. 50 / 25 / 25 classes.
support_cat=[]
val_cat=[]
test_cat=[]
for i in range(100):
    if i%2==0:
        support_cat.append(i)
    elif i%4==1:
        val_cat.append(i)
    elif i%4==3:
        test_cat.append(i)

# Create <res_dir>/{support,val/refer,val/query,test/refer,test/query}/<class id>/.
# math.ceil(i/2) maps dir_name index 0..4 onto cat_list index 0,1,1,2,2.
for res in resolution:
    res_dir = os.path.join(target_dir,'res_'+str(res))
    util.mkdir(res_dir)
    for i in ['support','val','test','val/refer','val/query','test/refer','test/query']:
        util.mkdir(os.path.join(res_dir,i))
    dir_name = ['support','val/refer','val/query','test/refer','test/query']
    cat_list = [support_cat,val_cat,test_cat]
    for i in range(5):
        index = math.ceil(i/2)
        tar_cat_list = cat_list[index]
        for j in tar_cat_list:
            util.mkdir(os.path.join(res_dir,dir_name[i],str(j)))

def resize_img(path,res):
    # Crop the bottom 20 px, then resize to res x res (aspect ratio is not kept).
    img = Image.open(path)
    width = img.size[0]
    height = img.size[1]-20
    img = img.crop((0,0,width,height)).resize((res,res),Image.BILINEAR)
    return img


for res in resolution:

    # path of written .bmp -> {'bbx': [...]} for this resolution.
    path2annot = {}
    res_dir = os.path.join(target_dir,'res_'+str(res))
    for i in support_cat:
        for img in cat2img[i]:
            tar_img = resize_img(os.path.join(home_dir,'images',img+'.jpg'),res)
            target_path = os.path.join(res_dir,'support',str(i),img+'.bmp')
            tar_img.save(target_path)
            path2annot[target_path]={'bbx':img2bbx[img]}

    cat_list = [val_cat,test_cat]
    dir_name_1 = ['val','test']
    dir_name_2 = ['refer','query']

    for i in range(2):
        for j in cat_list[i]:
            # NOTE(review): img_list aliases cat2img[j], and this in-place
            # shuffle runs again (with new RNG draws) for every resolution, so
            # the 84 and 224 refer/query splits differ. Fixing it would change
            # the generated dataset, so it is only flagged here.
            img_list = cat2img[j]
            img_num = len(img_list)
            np.random.shuffle(img_list)

            # First floor(20%) of the shuffled images -> refer, rest -> query.
            refer_list = img_list[:img_num//5]
            query_list = img_list[img_num//5:]
            temp_list = [refer_list,query_list]
            for k in range(2):
                for img in temp_list[k]:
                    tar_img = resize_img(os.path.join(home_dir,'images',img+'.jpg'),res)
                    target_path = os.path.join(res_dir,dir_name_1[i],dir_name_2[k],str(j),img+'.bmp')
                    tar_img.save(target_path)
                    path2annot[target_path]={'bbx':img2bbx[img]}

    torch.save(path2annot,os.path.join(res_dir,'path2annot.pth'))

    ###eval k shot
    # Pool every test refer + query image of a class into eval_k_shot/<class>/
    # via symlinks, and save the matching annotation dict.

    origin_dict = path2annot

    tar_dir = os.path.join(res_dir,'eval_k_shot')
    util.mkdir(tar_dir)

    cat_name = os.listdir(os.path.join(res_dir,'test/refer'))
    for cat in cat_name:
        util.mkdir(os.path.join(tar_dir,cat))

    tar_dict = {}

    for filename in origin_dict:

        if 'test/refer' in filename:
            tar_filename = filename.replace('test/refer','eval_k_shot')
        elif 'test/query' in filename:
            tar_filename = filename.replace('test/query','eval_k_shot')
        else:
            continue

        # NOTE(review): os.symlink raises FileExistsError if the link exists,
        # so re-running requires removing eval_k_shot first.
        os.symlink(filename,tar_filename)

        tar_dict[tar_filename] = origin_dict[filename]

    torch.save(tar_dict,os.path.join(res_dir,'path2annot_eval_k_shot.pth'))
# ---- /dataset/init_na.py ----
# Build na_fewshot/res_<res>/{refer,query}/<class>/ from NABirds, with
# normalised bbox / part annotations in path2annot.pth.
from PIL import Image
import torch
import os
import numpy as np
import sys
import argparse
sys.path.append('..')
from utils import util

parser = argparse.ArgumentParser()
parser.add_argument("--origin_path",help="directory of the original nabirds dataset you download and extract",type=str)
args = parser.parse_args()

origin_path = args.origin_path
target_path = os.path.abspath('./na_fewshot')
# Class folder ids (as int) that are skipped entirely.
exclude_na_id_list = torch.load('exclude_na_id_list.pth')
resolution = [84,224]

np.random.seed(42)
util.mkdir(target_path)

# image name (dashes removed) -> [width, height]; reading stops at the first
# empty line.
name2size={}

with open(os.path.join(origin_path,'sizes.txt')) as f:
    while True:
        content = f.readline().strip()
        if content == '':
            break
        content = content.split()
        name = content[0].replace('-','')
        width = int(content[1])
        height = int(content[2])
        name2size[name] = [width,height]

# image name -> [x_min, x_max, y_min, y_max] normalised by image size.
name2bbx={}

with open(os.path.join(origin_path,'bounding_boxes.txt')) as f:
    while True:
        content = f.readline().strip()
        if content == '':
            break
        content = content.split()
        name = content[0].replace('-','')
        x = int(content[1])
        y = int(content[2])
        width = int(content[3])
        height = int(content[4])

        [w,h] = name2size[name]
        x_min = x/w
        x_max = (x+width)/w
        y_min = y/h
        y_max = (y+height)/h
        name2bbx[name] = [x_min,x_max,y_min,y_max]

# image name -> list of [x, y, visible] with x, y normalised by image size,
# in file order (column 1 of part_locs.txt is not read here).
name2part={}

with open(os.path.join(origin_path,'parts/part_locs.txt')) as f:
    while True:
        content = f.readline().strip()
        if content == '':
            break
        content = content.split()
        name = content[0].replace('-','')
        x = int(content[2])
        y = int(content[3])
        visible = int(content[4])

        if name not in name2part:
            name2part[name] = []

        [w,h] = name2size[name]
        x = x/w
        y = y/h
        name2part[name].append([x,y,visible])

name2annotation = {}
for i in name2bbx:
    name2annotation[i] = {}
    name2annotation[i]['bbx'] = name2bbx[i]
    name2annotation[i]['part'] = name2part[i]

for res in resolution:

    res_dir = os.path.join(target_path,'res_'+str(res))
    util.mkdir(res_dir)
    util.mkdir(os.path.join(res_dir,'refer'))
    util.mkdir(os.path.join(res_dir,'query'))

    path2annot = {}

    for i in os.listdir(os.path.join(origin_path,'images')):

        if int(i) in exclude_na_id_list:
            continue
        util.mkdir(os.path.join(res_dir,'refer',i))
        util.mkdir(os.path.join(res_dir,'query',i))

        # NOTE(review): os.listdir order is arbitrary and the RNG keeps
        # advancing across resolutions, so splits depend on filesystem order
        # and differ between 84 and 224 despite the fixed seed.
        image_list = os.listdir(os.path.join(origin_path,'images',i))
        np.random.shuffle(image_list)

        # First int(20%) -> refer, rest -> query.
        num = len(image_list)
        refer_num = int(num/5)
        refer_list = image_list[:refer_num]
        query_list = image_list[refer_num:]

        img_list = [refer_list,query_list]
        folder_name = ['refer','query']

        for index in range(2):
            for j in img_list[index]:
                p = Image.open(os.path.join(origin_path,'images',i,j))
                p = p.convert('RGB')

                p = p.resize((res,res),Image.BILINEAR)
                target_img = os.path.join(res_dir,folder_name[index],i,j[:-3]+'bmp')
                p.save(target_img)
                # j[:-4] drops a 4-char extension; presumably file names equal
                # the annotation ids with dashes removed -- confirm.
                path2annot[target_img] = name2annotation[j[:-4]]

    torch.save(path2annot,os.path.join(res_dir,'path2annot.pth'))
# ---- /dataset/init_oid.py ----
# Build oid_fewshot/res_<res>/ from OID aeroplanes: resized images plus one
# low-resolution binary mask per part.
import os
import numpy as np
from PIL import Image,ImageDraw
import scipy.io as scio
import argparse
import sys
sys.path.append('..')
from utils import util

parser = argparse.ArgumentParser()
parser.add_argument("--oid_origin_path",help="directory of the original OID dataset you download and extract",type=str)
parser.add_argument("--fgvc_origin_path",help="directory of the original FGVC dataset you download and extract",type=str)
args = parser.parse_args()

origin_path = os.path.join(args.oid_origin_path,'data/images/aeroplane')
target_path = './oid_fewshot'
# feature_map[k] is the mask side length used for resolution[k].
resolution = [84,224]
feature_map = [10,14]

util.mkdir(target_path)

# File names in FGVC's image folder; images also present there are excluded later.
fg_list = os.listdir(os.path.join(args.fgvc_origin_path,'data/images'))
mat_file = os.path.join(args.oid_origin_path,'data/annotations/anno.mat')

data = scio.loadmat(mat_file,struct_as_record=False, squeeze_me=True)
anno = data['anno']

# image file name -> [aeroplane ids]; parentId is used as a 1-based index
# into anno.image.name.
path2id={}
for i in range(len(anno.aeroplane.id)):
    path=anno.image.name[anno.aeroplane.parentId[i]-1]
    if path not in path2id:
        path2id[path]=[]
    path2id[path].append(anno.aeroplane.id[i])

# aeroplane id -> {part name: [polygons]}; starts with the whole-plane polygon.
id2part={}
for i in range(len(anno.aeroplane.id)):
    id2part[anno.aeroplane.id[i]]={}
    id2part[anno.aeroplane.id[i]]['aero']=[anno.aeroplane.polygon[i]]

for i in range(len(anno.wing.id)):
    par_id = anno.wing.parentId[i]
    if 'wing' not in id2part[par_id]:
        id2part[par_id]['wing']=[]
    id2part[par_id]['wing'].append(anno.wing.polygon[i])
# Attach the remaining part polygons to their parent aeroplane. Driving every
# part from one table guarantees each part reads its *own* polygon list
# (the hand-copied 'nose' loop used to append anno.wheel.polygon[i]).
for part_name,part_anno in [('wheel',anno.wheel),
                            ('vertical',anno.verticalStabilizer),
                            ('nose',anno.nose)]:
    for i in range(len(part_anno.id)):
        par_id = part_anno.parentId[i]
        id2part[par_id].setdefault(part_name,[]).append(part_anno.polygon[i])

# Keep images that are not in the FGVC image folder and that contain exactly
# one annotated aeroplane (so the part masks are unambiguous).
fg_set = set(fg_list)
valid_path = [i for i in path2id if i not in fg_set and len(path2id[i])==1]


for i in range(2):

    res = resolution[i]
    fm_size = feature_map[i]
    par_dir = os.path.join(target_path,'res_%d'%(res))
    util.mkdir(par_dir)

    for name in ['origin','aero','wing','wheel','vertical','nose']:
        util.mkdir(os.path.join(par_dir,name))

    for j in valid_path:
        # `with` closes the source file; crop()/resize() return new images.
        with Image.open(os.path.join(origin_path,j)) as origin_img:
            width = origin_img.size[0]
            # Drop the bottom 20 px, matching the FGVC preprocessing.
            height = origin_img.size[1]-20
            tar_img = origin_img.crop((0,0,width,height)).resize((res,res),Image.BILINEAR)
        tar_img.save(os.path.join(par_dir,'origin',j[:-3]+'bmp'))

        _id = path2id[j][0]

        # Scale polygon coordinates from the cropped image to the fm_size grid;
        # assumes each polygon is a 2 x N array (x row, y row) -- confirm.
        scalar = np.array([[fm_size/width],[fm_size/height]])
        for part in ['aero','wing','wheel','vertical','nose']:
            # Binary fm_size x fm_size mask (0 / 255); all-zero if the part is absent.
            temp_img=Image.new('L',(fm_size,fm_size),0)
            if part in id2part[_id]:
                for poly in id2part[_id][part]:
                    my_poly = poly*scalar
                    my_poly = my_poly.T.flatten().tolist()
                    ImageDraw.Draw(temp_img).polygon(my_poly,outline=255,fill=255)
            temp_img.save(os.path.join(par_dir,part,j[:-3]+'bmp'))


# --------------------------------------------------------------------------------
/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torchvision.transforms as transforms 5 | import torchvision.datasets as datasets 6 | import numpy as np 7 | from copy import deepcopy 8 | from PIL import Image 9 | from . import sampler 10 | 11 | 12 | mean=[0.485,0.456,0.406] 13 | std=[0.229,0.224,0.225] 14 | 15 | transform = transforms.Compose([ 16 | transforms.ToTensor(), 17 | transforms.Normalize(mean=mean,std=std)]) 18 | 19 | 20 | def get_bird_dataset(data_path,annot,annot_path,flip): 21 | 22 | annot_dict = None 23 | if annot is not None: 24 | annot_dict = torch.load(annot_path) 25 | 26 | dataset = datasets.ImageFolder( 27 | data_path, 28 | loader = lambda x: bird_loader(path=x,flip=flip, 29 | annot_dict=annot_dict,annot=annot)) 30 | 31 | return dataset 32 | 33 | 34 | 35 | def meta_train_dataloader(data_path,shots,way,annot=None,annot_path=None): 36 | 37 | dataset = get_bird_dataset(data_path,annot=annot,annot_path=annot_path,flip=True) 38 | 39 | loader = torch.utils.data.DataLoader( 40 | dataset, 41 | batch_sampler = sampler.meta_batchsampler(data_source=dataset,way=way,shots=shots), 42 | num_workers = 3, 43 | pin_memory = False) 44 | 45 | return loader 46 | 47 | 48 | 49 | def eval_dataloader(data_path,annot=None,annot_path=None): 50 | 51 | dataset = get_bird_dataset(data_path,annot=annot,annot_path=annot_path,flip=False) 52 | 53 | loader = torch.utils.data.DataLoader( 54 | dataset, 55 | batch_sampler = sampler.ordered_sampler(data_source=dataset), 56 | num_workers = 3, 57 | pin_memory = False) 58 | 59 | return loader 60 | 61 | 62 | 63 | def eval_k_shot_dataloader(data_path,way,shot,annot=None,annot_path=None): 64 | 65 | dataset = get_bird_dataset(data_path,annot=annot,annot_path=annot_path,flip=True) 66 | 67 | loader = torch.utils.data.DataLoader( 68 | dataset, 69 | batch_sampler = sampler.random_sampler(data_source=dataset,way=way,shot=shot), 70 | 
num_workers = 3, 71 | pin_memory = False) 72 | 73 | return loader 74 | 75 | 76 | def normal_train_dataloader(data_path,batch_size,annot=None,annot_path=None): 77 | 78 | dataset = get_bird_dataset(data_path,annot=annot,annot_path=annot_path,flip=True) 79 | 80 | loader = torch.utils.data.DataLoader( 81 | dataset, 82 | batch_size = batch_size, 83 | shuffle = True, 84 | num_workers = 3, 85 | pin_memory = False, 86 | drop_last=True) 87 | 88 | return loader 89 | 90 | 91 | def oid_dataloader(data_path,batch_size,flip=True): 92 | 93 | oid_dataset = OidDataset(data_path,flip) 94 | 95 | dataloader = torch.utils.data.DataLoader(oid_dataset, 96 | batch_size=batch_size,shuffle=True,num_workers=5) 97 | 98 | return dataloader 99 | 100 | 101 | 102 | def proto_train_less_annot_dataloader(data_path,shots,way, 103 | percent,annot_path,batch_size): 104 | 105 | dataset = get_bird_dataset(data_path,annot='part',annot_path=annot_path,flip=True) 106 | 107 | loader = torch.utils.data.DataLoader( 108 | dataset, 109 | batch_sampler = sampler.proto_less_annot_batchsampler(data_source=dataset, 110 | way=way,shots=shots,percent=percent,batch_size=batch_size), 111 | num_workers = 3, 112 | pin_memory = False) 113 | 114 | return loader 115 | 116 | 117 | 118 | def bird_loader(path,flip=False,annot=None,annot_dict=None): 119 | 120 | p = Image.open(path) 121 | 122 | flip = flip and np.random.choice([True,False]) 123 | 124 | if flip: 125 | p = p.transpose(Image.FLIP_LEFT_RIGHT) 126 | 127 | p = p.convert('RGB') 128 | 129 | p = transform(p) 130 | 131 | if annot is None: 132 | return p 133 | 134 | else: 135 | 136 | p_size = p.size(-1) 137 | if p_size == 224: 138 | fm_size = 14 139 | elif p_size == 84: 140 | fm_size = 10 141 | 142 | if annot=='bbx': 143 | 144 | mask = np.zeros((fm_size,fm_size)) 145 | 146 | box = annot_dict[path]['bbx'] 147 | 148 | for i in range(4): 149 | if box[i]>1: 150 | box[i]=1 151 | 152 | x_min = fm_size*box[0] 153 | x_max = fm_size*box[1] 154 | y_min = fm_size*box[2] 155 | y_max = 
fm_size*box[3] 156 | 157 | x_min_int = int(x_min) 158 | x_max_int = int(x_max-0.0000001)+1 159 | y_min_int = int(y_min) 160 | y_max_int = int(y_max-0.0000001)+1 161 | 162 | if flip: 163 | 164 | mask[y_min_int:y_max_int,fm_size-x_max_int:fm_size-x_min_int] = 1 165 | 166 | # fade out 167 | mask[:, fm_size-x_min_int-1] *= 1-(x_min-x_min_int) 168 | mask[:, fm_size-x_max_int] *= 1-(x_max_int-x_max) 169 | 170 | else: 171 | 172 | mask[y_min_int:y_max_int,x_min_int:x_max_int] = 1 173 | 174 | # fade out 175 | mask[:,x_min_int] *= 1-(x_min-x_min_int) 176 | mask[:,x_max_int-1] *= 1-(x_max_int-x_max) 177 | 178 | mask[y_min_int,:] *= 1-(y_min-y_min_int) 179 | 180 | mask[y_max_int-1,:] *= 1-(y_max_int-y_max) 181 | 182 | mask = torch.FloatTensor(mask).unsqueeze(0) 183 | 184 | return [p,mask] 185 | 186 | elif annot=='part': 187 | 188 | num_part = 15 189 | 190 | mask = np.zeros((num_part,fm_size,fm_size)) 191 | 192 | part_loc = np.array(annot_dict[path]['part']) 193 | 194 | if flip: 195 | 196 | part_loc[[6,10]] = part_loc[[10,6]] 197 | part_loc[[7,11]] = part_loc[[11,7]] 198 | part_loc[[8,12]] = part_loc[[12,8]] 199 | 200 | for i in range(15): 201 | 202 | if part_loc[i][2]==0: 203 | continue 204 | 205 | if part_loc[i][0]>=1: 206 | part_loc[i][0]=0.99999999 207 | if part_loc[i][1]>=1: 208 | part_loc[i][1]=0.99999999 209 | 210 | x_int = int(fm_size*part_loc[i][0]) 211 | y_int = int(fm_size*part_loc[i][1]) 212 | 213 | if flip: 214 | mask[i][y_int][fm_size-1-x_int] = 1 215 | 216 | else: 217 | mask[i][y_int][x_int] = 1 218 | 219 | mask = torch.FloatTensor(mask) 220 | 221 | return [p,mask] 222 | 223 | 224 | 225 | class OidDataset(torch.utils.data.Dataset): 226 | 227 | def __init__(self,root_dir,flip=True): 228 | 229 | img_list = os.listdir(os.path.join(root_dir,'origin')) 230 | length = len(img_list) 231 | 232 | self.length = length 233 | self.img_list = img_list 234 | self.flip = flip 235 | self.root_dir = root_dir 236 | 237 | def __len__(self): 238 | return self.length 239 | 240 | def 
__getitem__(self,idx): 241 | 242 | img_list = self.img_list 243 | root_dir = self.root_dir 244 | 245 | flip = np.random.choice([True,False]) and self.flip 246 | origin_img = Image.open(os.path.join(root_dir,'origin',img_list[idx])) 247 | if flip: 248 | origin_img = origin_img.transpose(Image.FLIP_LEFT_RIGHT) 249 | origin_img = origin_img.convert('RGB') 250 | img_tensor = transform(origin_img) 251 | 252 | part_arr = [] 253 | for part in ['aero','wing','wheel','vertical','nose']: 254 | part_img = Image.open(os.path.join(root_dir,part,img_list[idx])) 255 | if flip: 256 | part_img = part_img.transpose(Image.FLIP_LEFT_RIGHT) 257 | part_arr.append(np.array(part_img)) 258 | part_arr = np.stack(part_arr,axis=0)/255 259 | 260 | part_tensor = torch.FloatTensor(part_arr) 261 | 262 | return [img_tensor,part_tensor] -------------------------------------------------------------------------------- /utils/dynamic_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from . 
import util,dataloader 6 | 7 | def default_eval(refer_loader,query_loader,model,class_acc=False): 8 | 9 | class_weight = get_class_weight(refer_loader,model) 10 | 11 | acc = get_prediction(query_loader,model,class_weight,class_acc) 12 | 13 | return acc 14 | 15 | 16 | 17 | def get_class_weight(loader,model): 18 | 19 | dim = model.dim 20 | way = len(loader.dataset.classes) 21 | class_weight = torch.zeros(way,dim).cuda() 22 | 23 | for i, (inp,target) in enumerate(loader): 24 | 25 | current_class_id = target[0] 26 | 27 | with torch.no_grad(): 28 | 29 | if isinstance(inp,list): 30 | (image_inp,mask) = inp 31 | image_inp = image_inp.cuda() 32 | mask = mask.cuda() 33 | feature_vector = model.get_feature_vector(image_inp,mask) 34 | 35 | elif isinstance(inp,torch.Tensor): 36 | inp = inp.cuda() 37 | feature_vector = model.get_feature_vector(inp) 38 | 39 | class_weight[current_class_id] = model.get_single_class_weight(feature_vector) 40 | 41 | return class_weight 42 | 43 | 44 | 45 | 46 | def get_prediction(loader,model,class_weight,class_acc): 47 | 48 | data_source = loader.dataset 49 | 50 | way = len(data_source.classes) 51 | 52 | correct_count = torch.zeros(way).cuda() 53 | 54 | counts = torch.zeros(way).cuda() 55 | 56 | for class_id in data_source.targets: 57 | counts[class_id] += 1 58 | 59 | for i, (inp,target) in enumerate(loader): 60 | 61 | current_class_id = target[0] 62 | target = target.cuda() 63 | batch_size = target.size(0) 64 | 65 | if isinstance(inp,list): 66 | (image_inp,mask) = inp 67 | 68 | image_inp = image_inp.cuda() 69 | mask = mask.cuda() 70 | feature_vector = model.get_feature_vector(image_inp,mask) 71 | 72 | elif isinstance(inp,torch.Tensor): 73 | inp = inp.cuda() 74 | feature_vector = model.get_feature_vector(inp) 75 | 76 | prediction = model.get_prediction(feature_vector,class_weight) 77 | 78 | _, top1_pred = prediction.topk(1) 79 | 80 | correct_count[current_class_id] = torch.sum(torch.eq(top1_pred,target.view(batch_size,1))) 81 | 82 | acc = 
(torch.sum(correct_count)/torch.sum(counts)).item()*100 83 | 84 | if not class_acc: 85 | return acc 86 | else: 87 | class_acc = torch.mean(correct_count/counts).item()*100 88 | return acc,class_acc 89 | 90 | 91 | 92 | 93 | def k_shot_eval(eval_loader,model,way,shot): 94 | 95 | test_shot = 16 96 | target = torch.LongTensor([i//test_shot for i in range(test_shot*way)]).cuda() 97 | 98 | acc_list = [] 99 | 100 | for i, (inp,_) in enumerate(eval_loader): 101 | 102 | if isinstance(inp,list): 103 | (image_inp,mask) = inp 104 | image_inp = image_inp.cuda() 105 | mask = mask.cuda() 106 | feature_vector = model.get_feature_vector(image_inp,mask) 107 | 108 | elif isinstance(inp,torch.Tensor): 109 | inp = inp.cuda() 110 | feature_vector = model.get_feature_vector(inp) 111 | 112 | max_index = model.eval_k_shot(feature_vector,way,shot) 113 | 114 | acc = 100*torch.sum(torch.eq(max_index,target)).item()/test_shot/way 115 | acc_list.append(acc) 116 | 117 | mean,interval = util.eval(acc_list) 118 | 119 | return mean,interval 120 | 121 | 122 | 123 | def eval_test(model,pm,config,pm_na=None): 124 | 125 | logger = config.logger 126 | annot = config.eval_annot 127 | 128 | logger.info('------------------------') 129 | logger.info('evaluating:') 130 | 131 | with torch.no_grad(): 132 | 133 | model.load_state_dict(torch.load(config.save_path)) 134 | model.eval() 135 | 136 | refer_loader = dataloader.eval_dataloader(pm.test_refer, 137 | annot=annot,annot_path=pm.annot_path) 138 | query_loader = dataloader.eval_dataloader(pm.test_query, 139 | annot=annot,annot_path=pm.annot_path) 140 | 141 | test_acc = default_eval(refer_loader,query_loader,model=model) 142 | logger.info(('the final test acc is %.3f') % (test_acc)) 143 | 144 | way = len(refer_loader.dataset.classes) 145 | for shot in [1,5]: 146 | eval_loader = dataloader.eval_k_shot_dataloader(pm.k_shot, 147 | way=way,shot=shot,annot=annot,annot_path=pm.k_shot_annot_path) 148 | mean,interval = k_shot_eval(eval_loader,model,way,shot) 149 | 
logger.info('%d-way-%d-shot acc: %.2f\t%.2f'%(way,shot,mean,interval)) 150 | 151 | if pm_na is not None: 152 | 153 | logger.info('------------------------') 154 | logger.info('evaluating on NA:') 155 | 156 | refer_loader = dataloader.eval_dataloader(pm_na.test_refer, 157 | annot=annot,annot_path=pm_na.annot_path) 158 | query_loader = dataloader.eval_dataloader(pm_na.test_query, 159 | annot=annot,annot_path=pm_na.annot_path) 160 | 161 | mean_acc,class_acc = default_eval(refer_loader,query_loader, 162 | model=model,class_acc=True) 163 | 164 | logger.info(('mean_acc is %.3f') % (mean_acc)) 165 | logger.info(('class_acc is %.3f') % (class_acc)) 166 | 167 | -------------------------------------------------------------------------------- /utils/dynamic_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tensorboardX import SummaryWriter 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | from torchvision.utils import make_grid 7 | from torch.nn import NLLLoss,BCEWithLogitsLoss 8 | from . 
import util 9 | 10 | def train_stage_1(train_loader,model, 11 | optimizer,writer,iter_counter): 12 | 13 | lr = optimizer.param_groups[0]['lr'] 14 | writer.add_scalar('lr',lr,iter_counter) 15 | criterion = NLLLoss().cuda() 16 | 17 | avg_loss = 0 18 | avg_acc = 0 19 | 20 | for i, (inp,target) in enumerate(train_loader): 21 | 22 | iter_counter += 1 23 | batch_size = target.size(0) 24 | target = target.cuda() 25 | 26 | if isinstance(inp,list): 27 | (image_inp,mask) = inp 28 | image_inp = image_inp.cuda() 29 | mask = mask.cuda() 30 | log_prediction = model.forward_stage_1(image_inp,mask) 31 | 32 | elif isinstance(inp,torch.Tensor): 33 | inp = inp.cuda() 34 | log_prediction = model.forward_stage_1(inp) 35 | 36 | loss = criterion(log_prediction,target) 37 | 38 | optimizer.zero_grad() 39 | loss.backward() 40 | optimizer.step() 41 | 42 | _,max_index = torch.max(log_prediction,1) 43 | acc = 100*(torch.sum(torch.eq(max_index,target)).float()/batch_size).item() 44 | 45 | avg_acc += acc 46 | avg_loss += loss.item() 47 | 48 | avg_loss = avg_loss/(i+1) 49 | avg_acc = avg_acc/(i+1) 50 | 51 | writer.add_scalar('dynamic_loss',avg_loss,iter_counter) 52 | writer.add_scalar('train_acc',avg_acc,iter_counter) 53 | 54 | return iter_counter,avg_acc 55 | 56 | 57 | def train_PN_stage_1(train_loader,model, 58 | optimizer,writer,iter_counter,alpha): 59 | 60 | lr = optimizer.param_groups[0]['lr'] 61 | writer.add_scalar('lr',lr,iter_counter) 62 | criterion = NLLLoss().cuda() 63 | criterion_part = BCEWithLogitsLoss().cuda() 64 | 65 | avg_dynamic_loss = 0 66 | avg_heatmap_loss = 0 67 | avg_total_loss = 0 68 | avg_acc = 0 69 | 70 | for i, ((inp,mask),target) in enumerate(train_loader): 71 | 72 | iter_counter += 1 73 | batch_size = target.size(0) 74 | 75 | inp = inp.cuda() 76 | mask = mask.cuda() 77 | 78 | target = target.cuda() 79 | 80 | if iter_counter%1000==0: 81 | model.eval() 82 | util.visualize(model,writer,iter_counter,inp[:9],mask[:9]) 83 | model.train() 84 | 85 | 
log_prediction,heatmap_logits = model.forward_stage_1(inp,mask) 86 | 87 | loss_heatmap = criterion_part(heatmap_logits,mask) 88 | loss_dynamic = criterion(log_prediction,target) 89 | loss = alpha*loss_heatmap+loss_dynamic 90 | 91 | optimizer.zero_grad() 92 | loss.backward() 93 | optimizer.step() 94 | 95 | _,max_index = torch.max(log_prediction,1) 96 | acc = 100*(torch.sum(torch.eq(max_index,target)).float()/batch_size).item() 97 | 98 | avg_acc += acc 99 | avg_total_loss += loss.item() 100 | avg_dynamic_loss += loss_dynamic.item() 101 | avg_heatmap_loss += loss_heatmap.item() 102 | 103 | avg_total_loss = avg_total_loss/(i+1) 104 | avg_dynamic_loss = avg_dynamic_loss/(i+1) 105 | avg_heatmap_loss = avg_heatmap_loss/(i+1) 106 | avg_acc = avg_acc/(i+1) 107 | 108 | writer.add_scalar('total_loss',avg_total_loss,iter_counter) 109 | writer.add_scalar('dynamic_loss',avg_dynamic_loss,iter_counter) 110 | writer.add_scalar('heatmap_loss',avg_heatmap_loss,iter_counter) 111 | 112 | writer.add_scalar('train_acc',avg_acc,iter_counter) 113 | 114 | return iter_counter,avg_acc 115 | 116 | 117 | def train_stage_2(train_loader,model, 118 | optimizer,writer,iter_counter): 119 | 120 | lr = optimizer.param_groups[0]['lr'] 121 | writer.add_scalar('lr',lr,iter_counter) 122 | criterion = NLLLoss().cuda() 123 | 124 | num_fake_novel_class = model.num_fake_novel_class 125 | shots = model.shots[0] 126 | num_class = model.num_class 127 | way = model.way 128 | 129 | avg_loss = 0 130 | avg_acc = 0 131 | 132 | for i, (inp,target) in enumerate(train_loader): 133 | 134 | iter_counter += 1 135 | 136 | fake_novel_class_id = target.view(way,shots)[:num_fake_novel_class,0] 137 | fake_novel_class_id_list = fake_novel_class_id.tolist() 138 | fake_novel_class_id = fake_novel_class_id.cuda() 139 | 140 | fake_base_class_id_list = [] 141 | 142 | for j in range(num_class): 143 | if j not in fake_novel_class_id_list: 144 | fake_base_class_id_list.append(j) 145 | fake_base_class_id = 
torch.tensor(fake_base_class_id_list).long().cuda() 146 | 147 | target_for_loss = target.view(way,shots)[:,5:].cuda() 148 | target_for_loss = target_for_loss.view(-1) 149 | 150 | if isinstance(inp,list): 151 | (image_inp,mask) = inp 152 | image_inp = image_inp.cuda() 153 | mask = mask.cuda() 154 | feature_vector = model.get_feature_vector(image_inp,mask) 155 | 156 | elif isinstance(inp,torch.Tensor): 157 | inp = inp.cuda() 158 | feature_vector = model.get_feature_vector(inp) 159 | 160 | log_prediction = model.forward_stage_2(feature_vector,fake_novel_class_id,fake_base_class_id) 161 | loss = criterion(log_prediction,target_for_loss) 162 | 163 | optimizer.zero_grad() 164 | loss.backward() 165 | optimizer.step() 166 | 167 | loss_value = loss.item() 168 | _,max_index = torch.max(log_prediction,1) 169 | acc = 100*(torch.sum(torch.eq(max_index,target_for_loss)).float()/target_for_loss.size(0)).item() 170 | 171 | avg_acc += acc 172 | avg_loss += loss_value 173 | 174 | avg_loss = avg_loss/(i+1) 175 | avg_acc = avg_acc/(i+1) 176 | 177 | writer.add_scalar('dynamic_loss',avg_loss,iter_counter) 178 | writer.add_scalar('train_acc',avg_acc,iter_counter) 179 | 180 | return iter_counter,avg_acc -------------------------------------------------------------------------------- /utils/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as torch_models 5 | import numpy as np 6 | 7 | class ConvBlock(nn.Module): 8 | 9 | def __init__(self,input_channel,output_channel): 10 | super().__init__() 11 | 12 | self.layers = nn.Sequential( 13 | nn.Conv2d(input_channel,output_channel,kernel_size=3,padding=1), 14 | nn.BatchNorm2d(output_channel)) 15 | 16 | def forward(self,inp): 17 | return self.layers(inp) 18 | 19 | 20 | class BackBone(nn.Module): 21 | 22 | def __init__(self,num_channel=64): 23 | super().__init__() 24 | 25 | self.layers = 
nn.Sequential( 26 | ConvBlock(3,num_channel), 27 | nn.ReLU(inplace=True), 28 | nn.MaxPool2d(2), 29 | ConvBlock(num_channel,num_channel), 30 | nn.ReLU(inplace=True), 31 | nn.MaxPool2d(2), 32 | ConvBlock(num_channel,num_channel), 33 | nn.ReLU(inplace=True), 34 | nn.MaxPool2d(2), 35 | ConvBlock(num_channel,num_channel)) 36 | 37 | def forward(self,inp): 38 | 39 | return self.layers(inp) 40 | 41 | 42 | class BackBone_ResNet(nn.Module): 43 | 44 | def __init__(self,num_channel=32): 45 | super().__init__() 46 | 47 | resnet18 = torch_models.resnet18() 48 | conv1 = resnet18.conv1 49 | bn1 = resnet18.bn1 50 | relu = resnet18.relu 51 | maxpool = resnet18.maxpool 52 | layer1 = resnet18.layer1 53 | layer2 = resnet18.layer2 54 | layer3 = resnet18.layer3 55 | layer4 = resnet18.layer4 56 | 57 | layer4[0].conv1 = nn.Conv2d(256,512,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False) 58 | layer4[0].downsample[0] = nn.Conv2d(256,512,kernel_size=(1,1),stride=(1,1),bias=False) 59 | 60 | layer5 = nn.Sequential( 61 | nn.Conv2d(512,num_channel,kernel_size=1,stride=1,padding=0), 62 | nn.BatchNorm2d(num_channel)) 63 | 64 | self.layers = nn.Sequential(conv1,bn1,relu,maxpool,layer1,layer2,layer3,layer4,layer5) 65 | 66 | del resnet18 67 | 68 | def forward(self,inp): 69 | 70 | return self.layers(inp) 71 | 72 | 73 | 74 | def proto_eval_k_shot(feature_vector,way,shot,dim): 75 | 76 | support = feature_vector[:way*shot].view(way,shot,dim) 77 | centroid = torch.mean(support,1).unsqueeze(0) 78 | query = feature_vector[way*shot:].unsqueeze(1) 79 | 80 | neg_l2_distance = torch.sum((centroid-query)**2,-1).neg().view(way*16,way) 81 | _,max_index = torch.max(neg_l2_distance,1) 82 | 83 | return max_index 84 | 85 | 86 | def proto_forward_log_pred(feature_vector,train_shot,test_shot,dim,way): 87 | 88 | support = feature_vector[:way*train_shot].view(way,train_shot,dim) 89 | centroid = torch.mean(support,1).unsqueeze(0) 90 | query = feature_vector[way*train_shot:].unsqueeze(1) 91 | 92 | neg_l2_distance = 
torch.sum((centroid-query)**2,-1).neg().view(way*test_shot,way) 93 | log_prediction = F.log_softmax(neg_l2_distance,dim=1) 94 | 95 | return log_prediction 96 | 97 | 98 | 99 | class Proto_Model(nn.Module): 100 | 101 | def __init__(self,way=None,shots=None,resnet=False): 102 | 103 | super().__init__() 104 | if resnet: 105 | num_channel = 32 106 | self.feature_extractor = BackBone_ResNet(num_channel) 107 | else: 108 | num_channel = 64 109 | self.feature_extractor = BackBone(num_channel) 110 | self.shots = shots 111 | self.way = way 112 | self.num_channel = num_channel 113 | self.dim = num_channel 114 | 115 | 116 | def get_feature_vector(self,inp): 117 | pass 118 | 119 | 120 | def eval_k_shot(self,inp,way,shot): 121 | 122 | feature_vector = self.get_feature_vector(inp) 123 | max_index = proto_eval_k_shot(feature_vector, 124 | way = way, 125 | shot = shot, 126 | dim = self.dim) 127 | 128 | return max_index 129 | 130 | 131 | def forward(self,inp): 132 | 133 | feature_vector = self.get_feature_vector(inp) 134 | log_prediction = proto_forward_log_pred(feature_vector, 135 | train_shot = self.shots[0], 136 | test_shot = self.shots[1], 137 | dim = self.dim, 138 | way = self.way) 139 | 140 | return log_prediction 141 | 142 | 143 | 144 | class PN_Model(nn.Module): 145 | 146 | def __init__(self,num_part,resnet=False): 147 | super().__init__() 148 | 149 | if resnet: 150 | 151 | num_channel = 32 152 | 153 | resnet18 = torch_models.resnet18() 154 | conv1 = resnet18.conv1 155 | bn1 = resnet18.bn1 156 | relu = resnet18.relu 157 | maxpool = resnet18.maxpool 158 | layer1 = resnet18.layer1 159 | layer2 = resnet18.layer2 160 | layer3 = resnet18.layer3 161 | layer4 = resnet18.layer4 162 | 163 | layer4[0].conv1 = nn.Conv2d(256,512,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False) 164 | layer4[0].downsample[0] = nn.Conv2d(256,512,kernel_size=(1,1),stride=(1,1),bias=False) 165 | 166 | layer5 = nn.Sequential( 167 | nn.Conv2d(512,num_channel,kernel_size=1,stride=1,padding=0), 168 | 
nn.BatchNorm2d(num_channel)) 169 | 170 | self.shared_layers = nn.Sequential(conv1, 171 | bn1,relu,maxpool,layer1,layer2,layer3) 172 | 173 | self.class_branch = nn.Sequential(layer4,layer5) 174 | 175 | self.part_branch = nn.Sequential( 176 | nn.Conv2d(256,64,kernel_size=3,stride=1,padding=1), 177 | nn.BatchNorm2d(64), 178 | nn.ReLU(inplace=True), 179 | nn.Conv2d(64,num_part,kernel_size=3,stride=1,padding=1)) 180 | 181 | del resnet18 182 | 183 | else: 184 | 185 | num_channel = 64 186 | 187 | self.shared_layers = nn.Sequential( 188 | ConvBlock(3,num_channel), 189 | nn.ReLU(inplace=True), 190 | nn.MaxPool2d(2), 191 | ConvBlock(num_channel,num_channel), 192 | nn.ReLU(inplace=True), 193 | nn.MaxPool2d(2)) 194 | 195 | self.part_branch = nn.Sequential( 196 | nn.Conv2d(num_channel,30,kernel_size=3,stride=2,padding=0), 197 | nn.BatchNorm2d(30), 198 | nn.ReLU(inplace=True), 199 | nn.Conv2d(30,num_part,kernel_size=3,stride=1,padding=1)) 200 | 201 | self.class_branch = nn.Sequential( 202 | ConvBlock(num_channel,num_channel), 203 | nn.ReLU(inplace=True), 204 | nn.MaxPool2d(2), 205 | ConvBlock(num_channel,num_channel)) 206 | 207 | self.num_channel = num_channel 208 | self.num_part = num_part 209 | self.dim = num_channel*num_part 210 | 211 | 212 | def get_heatmap(self,inp): 213 | 214 | logits = self.part_branch(self.shared_layers(inp)) 215 | heat_map = nn.Sigmoid()(logits) 216 | return heat_map 217 | 218 | def eval_k_shot(self,inp,way,shot): 219 | 220 | feature_vector = self.get_feature_vector(inp) 221 | max_index = proto_eval_k_shot(feature_vector, 222 | way = way, 223 | shot = shot, 224 | dim = self.dim) 225 | 226 | return max_index 227 | 228 | 229 | 230 | class Dynamic_Model(nn.Module): 231 | 232 | def __init__(self,dim,num_class,way=None,shots=None,num_fake_novel_class=16): 233 | 234 | super().__init__() 235 | 236 | weight_base = torch.FloatTensor(num_class,dim).normal_(0.0, np.sqrt(2.0/(dim))) 237 | self.weight_base = nn.Parameter(weight_base,requires_grad=True) 238 | 239 | 
scale_cls = 10.0 240 | self.scale_cls = nn.Parameter(torch.FloatTensor(1).fill_(scale_cls),requires_grad=True) 241 | self.scale_cls_att = nn.Parameter(torch.FloatTensor(1).fill_(scale_cls),requires_grad=True) 242 | 243 | self.phi_avg = nn.Parameter(torch.FloatTensor(dim).fill_(1),requires_grad=True) 244 | self.phi_att = nn.Parameter(torch.FloatTensor(dim).fill_(1),requires_grad=True) 245 | 246 | self.phi_q = nn.Linear(dim,dim) 247 | self.phi_q.weight.data.copy_(torch.eye(dim,dim)+torch.randn(dim,dim)*0.001) 248 | self.phi_q.bias.data.zero_() 249 | 250 | weight_keys = torch.FloatTensor(num_class,dim).normal_(0.0, np.sqrt(2.0/(dim))) 251 | self.weight_keys = nn.Parameter(weight_keys, requires_grad=True) 252 | 253 | self.dim = dim 254 | self.num_class = num_class 255 | self.num_fake_novel_class = num_fake_novel_class 256 | self.way = way 257 | self.shots = shots 258 | 259 | def get_feature_vector(self,inp): 260 | pass 261 | 262 | def weight_generator(self,fake_base_weight,fake_novel_feature_vector,fake_base_class_id): 263 | 264 | dim = self.dim 265 | num_fake_novel_class = self.num_fake_novel_class 266 | 267 | avg_feature_vector = torch.mean(fake_novel_feature_vector,dim=1) # 5class,channel 268 | avg_weight = self.phi_avg.unsqueeze(0)*avg_feature_vector # 5class,channel 269 | 270 | fake_base_weight = F.normalize(fake_base_weight,p=2,dim=1,eps=1e-12) # 155,channel 271 | 272 | query = self.phi_q(fake_novel_feature_vector.contiguous().view(num_fake_novel_class*5,dim)) # 25, channel 273 | query = F.normalize(query,p=2,dim=1,eps=1e-12) # 25,channel 274 | 275 | weight_keys = self.weight_keys[fake_base_class_id] # the keys of the base categoreis 276 | weight_keys = F.normalize(weight_keys,p=2,dim=1,eps=1e-12) # 155,channel 277 | 278 | logits = self.scale_cls_att*torch.matmul(query,weight_keys.transpose(0,1)) # 25,155 279 | att_score = F.softmax(logits,dim=1) # 25,155 280 | 281 | att_scored_fake_base_weight = torch.matmul(att_score,fake_base_weight) # 25,channel 282 | 
att_weight = self.phi_att*torch.mean(att_scored_fake_base_weight.view(num_fake_novel_class,5,dim),dim=1) # 5,channel 283 | 284 | fake_novel_weight = avg_weight+att_weight 285 | 286 | return fake_novel_weight 287 | 288 | def forward_stage_2(self,feature_vector,fake_novel_class_id,fake_base_class_id): 289 | 290 | dim = self.dim 291 | way = self.way 292 | shots = self.shots[0] 293 | weight_base = self.weight_base 294 | num_fake_novel_class = self.num_fake_novel_class 295 | 296 | feature_vector = feature_vector.view(way,shots,dim) 297 | feature_vector = F.normalize(feature_vector,p=2,dim=2,eps=1e-12) 298 | 299 | fake_novel_feature_vector = feature_vector[:num_fake_novel_class,:5,:] # 5 class,5 shot,channel 300 | 301 | fake_base_weight = weight_base[fake_base_class_id] # 155,channel 302 | fake_novel_weight = self.weight_generator(fake_base_weight,fake_novel_feature_vector,fake_base_class_id) # 5,channel 303 | 304 | weight_base_clone = weight_base.clone() 305 | weight_base_clone[fake_novel_class_id] = fake_novel_weight #160,channel 306 | 307 | feature_vector_test = feature_vector[:,5:,:].contiguous().view(way*15,dim)# 15,15,channel 308 | 309 | norm_weight = F.normalize(weight_base_clone,p=2,dim=1,eps=1e-12) 310 | 311 | logits = self.scale_cls*torch.matmul(feature_vector_test,norm_weight.transpose(0,1)) 312 | 313 | log_prediction = F.log_softmax(logits,dim=1) 314 | 315 | return log_prediction 316 | 317 | def get_prediction(self,feature_vector,class_weight): 318 | 319 | batch_size = feature_vector.size(0) 320 | dim = self.dim 321 | 322 | feature_vector = feature_vector.view(batch_size,dim) 323 | feature_vector = F.normalize(feature_vector,p=2,dim=1,eps=1e-12) 324 | 325 | norm_weight = F.normalize(class_weight,p=2,dim=1,eps=1e-12) 326 | logits = self.scale_cls*torch.matmul(feature_vector,norm_weight.transpose(0,1)) 327 | 328 | log_prediction = F.log_softmax(logits,dim=1) 329 | 330 | return log_prediction 331 | 332 | def get_single_class_weight(self,feature_vector): 333 | 
334 | batch_size = feature_vector.size(0) 335 | dim = self.dim 336 | 337 | feature_vector = feature_vector.view(batch_size,dim) 338 | feature_vector = F.normalize(feature_vector,p=2,dim=1,eps=1e-12) # batch,channel 339 | 340 | avg_feature_vector = torch.mean(feature_vector,dim=0) # channel 341 | avg_weight = self.phi_avg*avg_feature_vector # channel 342 | 343 | norm_base_weight = F.normalize(self.weight_base,p=2,dim=1,eps=1e-12) # 160,channel 344 | 345 | query = self.phi_q(feature_vector) # batch, channel 346 | query = F.normalize(query,p=2,dim=1,eps=1e-12) # batch,channel 347 | 348 | weight_keys = self.weight_keys # the keys of the base categoreis 349 | weight_keys = F.normalize(weight_keys,p=2,dim=1,eps=1e-12) # 160,channel 350 | 351 | logits = self.scale_cls_att*torch.matmul(query,weight_keys.transpose(0,1)) # batch,160 352 | att_score = F.softmax(logits,dim=1) # batch,160 353 | 354 | att_scored_base_weight = torch.matmul(att_score,norm_base_weight) # batch,channel 355 | att_weight = self.phi_att*torch.mean(att_scored_base_weight,dim=0) # channel 356 | 357 | novel_weight = avg_weight+att_weight 358 | 359 | return novel_weight 360 | 361 | def eval_k_shot(self,feature_vector,way,shot): 362 | 363 | dim = self.dim 364 | support = feature_vector[:way*shot].view(way,shot,dim) 365 | class_weight = torch.zeros(way,dim).cuda() 366 | 367 | for i in range(way): 368 | class_weight[i] = self.get_single_class_weight(support[i]) 369 | 370 | log_prediction = self.get_prediction(feature_vector[way*shot:],class_weight) 371 | _,max_index = torch.max(log_prediction,1) 372 | 373 | return max_index 374 | -------------------------------------------------------------------------------- /utils/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as torch_models 5 | import numpy as np 6 | from . 
import models 7 | 8 | EPS=0.00001 9 | 10 | 11 | def feature_map2vec(feature_map,mask): 12 | 13 | batch_size = feature_map.size(0) 14 | num_channel = feature_map.size(1) 15 | num_part = mask.size(1) 16 | 17 | feature_map = feature_map.unsqueeze(2) 18 | sum_of_weight = mask.view(batch_size,num_part,-1).sum(-1)+EPS 19 | mask = mask.unsqueeze(1) 20 | 21 | vec = (feature_map*mask).view(batch_size,num_channel,num_part,-1).sum(-1) 22 | vec = vec/sum_of_weight.unsqueeze(1) 23 | vec = vec.view(batch_size,num_channel*num_part) 24 | 25 | return vec 26 | 27 | 28 | 29 | class Proto(models.Proto_Model): 30 | 31 | def get_feature_vector(self,inp): 32 | 33 | batch_size = inp.size(0) 34 | feature_map = self.feature_extractor(inp) 35 | feature_vector = F.avg_pool2d(input=feature_map,kernel_size=feature_map.size(-1)) 36 | feature_vector = feature_vector.view(batch_size,-1) 37 | 38 | return feature_vector 39 | 40 | 41 | class Proto_BP(models.Proto_Model): 42 | 43 | def __init__(self,way=None,shots=None,resnet=False): 44 | 45 | super().__init__(way=way,shots=shots,resnet=resnet) 46 | self.dim = self.num_channel**2 47 | 48 | def get_feature_vector(self,inp): 49 | 50 | feature_map = self.feature_extractor(inp) 51 | feature_vector = self.covariance_pool(feature_map) 52 | return feature_vector 53 | 54 | def covariance_pool(self,inp): 55 | 56 | batch_size = inp.size(0) 57 | channel = inp.size(1) 58 | height = inp.size(2) 59 | width = inp.size(3) 60 | 61 | out = inp.view(batch_size,channel,height*width) 62 | out = torch.bmm(out,torch.transpose(out,1,2))/(height*width) 63 | out = out.view(batch_size,channel**2) 64 | out = out.sign().float()*(out.abs()+EPS).sqrt() 65 | 66 | return out 67 | 68 | 69 | class Proto_FSL(models.Proto_Model): 70 | 71 | def __init__(self,way=None,shots=None,resnet=False): 72 | 73 | super().__init__(way=way,shots=shots,resnet=resnet) 74 | self.dim = self.num_channel*2 75 | 76 | 77 | def get_fb_vector(self,inp,mask): 78 | 79 | batch_size = inp.size(0) 80 | num_channel = 
self.num_channel 81 | 82 | feature_map = self.feature_extractor(inp).unsqueeze(2) 83 | mask = torch.cat([mask,1-mask],dim=1).view(batch_size,1,2,mask.size(2),mask.size(3)) 84 | sum_of_weight = mask.view(*mask.size()[:-2],-1).sum(-1)+EPS 85 | fb_vector = (feature_map*mask).view(batch_size,num_channel,2,-1).sum(-1)/sum_of_weight 86 | 87 | return fb_vector 88 | 89 | 90 | def get_feature_vector(self,inp,fb_vector): 91 | 92 | feature_map = self.feature_extractor(inp) 93 | mask = self.feature_map2mask(feature_map,fb_vector) 94 | feature_vector = feature_map2vec(feature_map,mask) 95 | 96 | return feature_vector 97 | 98 | def feature_map2mask(self,feature_map,fb_vector): 99 | 100 | num_channel = self.num_channel 101 | 102 | feature_map = feature_map.unsqueeze(2) 103 | fb_vector = fb_vector.view(1,num_channel,2,1,1) 104 | 105 | mask = torch.sum((feature_map-fb_vector)**2,1).neg() 106 | mask = F.softmax(mask,dim=1) 107 | 108 | return mask 109 | 110 | 111 | def eval_k_shot(self,inp,mask,way,shot): 112 | 113 | fb_vector = self.get_fb_vector(inp[:way*shot],mask[:way*shot]).mean(0) 114 | feature_vector = self.get_feature_vector(inp,fb_vector) 115 | 116 | max_index = models.proto_eval_k_shot(feature_vector, 117 | way = way, 118 | shot = shot, 119 | dim = self.dim) 120 | 121 | return max_index 122 | 123 | 124 | def forward(self,inp,mask): 125 | 126 | feature_map = self.feature_extractor(inp) 127 | 128 | refer_shot = self.shots[0] 129 | train_shot = self.shots[1] 130 | test_shot = self.shots[2] 131 | 132 | way = self.way 133 | num_channel = self.num_channel 134 | refer_batch_size = way*refer_shot 135 | 136 | refer_mask = mask[:refer_batch_size] 137 | refer_feature_map = feature_map[:refer_batch_size] 138 | refer_mask = torch.cat([refer_mask,1-refer_mask],dim=1).unsqueeze(1) 139 | sum_of_weight = refer_mask.view(*refer_mask.size()[:-2],-1).sum(-1)+EPS 140 | 141 | fb_vector = (refer_feature_map.unsqueeze(2)*refer_mask).view(refer_batch_size,num_channel,2,-1).sum(-1)/sum_of_weight 142 
| fb_vector = fb_vector.mean(0) 143 | 144 | feature_map = feature_map[refer_batch_size:] 145 | fb_mask = self.feature_map2mask(feature_map,fb_vector) 146 | feature_vector = feature_map2vec(feature_map,fb_mask) 147 | 148 | log_prediction = models.proto_forward_log_pred(feature_vector, 149 | train_shot = train_shot, 150 | test_shot = test_shot, 151 | dim = self.dim, 152 | way = way) 153 | 154 | return log_prediction 155 | 156 | 157 | class Proto_bbN(models.PN_Model): 158 | 159 | def __init__(self,num_part=2,way=None,shots=None,resnet=False): 160 | 161 | super().__init__(num_part=num_part,resnet=resnet) 162 | self.shots = shots 163 | self.way = way 164 | 165 | def get_feature_vector(self,inp): 166 | 167 | temp = self.shared_layers(inp) 168 | feature_map = self.class_branch(temp) 169 | mask = nn.Softmax(dim=1)(self.part_branch(temp)) 170 | 171 | vec = feature_map2vec(feature_map,mask) 172 | 173 | return vec 174 | 175 | def get_heatmap(self,inp): 176 | 177 | logits = self.part_branch(self.shared_layers(inp)) 178 | heat_map = nn.Softmax(dim=1)(logits) 179 | return heat_map 180 | 181 | 182 | def forward(self,inp,mask): 183 | 184 | temp = self.shared_layers(inp) 185 | feature_map = self.class_branch(temp) 186 | heatmap = nn.Softmax(dim=1)(self.part_branch(temp)) 187 | 188 | feature_vector = feature_map2vec(feature_map,mask) 189 | 190 | log_prediction = models.proto_forward_log_pred(feature_vector, 191 | train_shot = self.shots[0], 192 | test_shot = self.shots[1], 193 | dim = self.dim, 194 | way = self.way) 195 | 196 | return log_prediction,heatmap 197 | 198 | 199 | 200 | class Proto_MT(models.PN_Model): 201 | 202 | def __init__(self,num_part=15,way=None,shots=None,resnet=False): 203 | 204 | super().__init__(num_part=num_part,resnet=resnet) 205 | self.shots = shots 206 | self.way = way 207 | self.dim = self.num_channel 208 | 209 | def get_feature_vector(self,inp): 210 | 211 | batch_size = inp.size(0) 212 | temp = self.shared_layers(inp) 213 | feature_map = 
self.class_branch(temp) 214 | 215 | feature_vector = F.avg_pool2d(input=feature_map,kernel_size=feature_map.size(-1)) 216 | feature_vector = feature_vector.view(batch_size,-1) 217 | 218 | return feature_vector 219 | 220 | 221 | def forward(self,inp,mask): 222 | 223 | temp = self.shared_layers(inp) 224 | feature_map = self.class_branch(temp) 225 | heatmap_logits = self.part_branch(temp) 226 | 227 | feature_vector = F.avg_pool2d(input=feature_map,kernel_size=feature_map.size(-1)) 228 | feature_vector = feature_vector.view(feature_vector.size(0),-1) 229 | 230 | log_prediction = models.proto_forward_log_pred(feature_vector, 231 | train_shot = self.shots[0], 232 | test_shot = self.shots[1], 233 | dim = self.dim, 234 | way = self.way) 235 | 236 | return log_prediction,heatmap_logits 237 | 238 | 239 | 240 | class Proto_uPN(models.Proto_Model): 241 | 242 | def __init__(self,num_part=15,way=None,shots=None,resnet=False): 243 | 244 | super().__init__(way=way,shots=shots,resnet=resnet) 245 | num_channel = self.num_channel 246 | self.dim = num_channel*num_part 247 | self.num_part = num_part 248 | self.part_vector = nn.Parameter(torch.randn(1,num_channel,num_part,1,1)) 249 | 250 | def get_feature_vector(self,inp): 251 | 252 | num_channel = self.num_channel 253 | num_part = self.num_part 254 | dim = self.dim 255 | 256 | batch_size = inp.size(0) 257 | 258 | feature_map = self.feature_extractor(inp) 259 | fm_size = feature_map.size(-1) 260 | feature_map = feature_map.unsqueeze(2) 261 | 262 | mask = torch.sum((feature_map-self.part_vector)**2,1).neg().view(batch_size,num_part,-1) 263 | mask = F.softmax(mask,dim=2).view(batch_size,1,num_part,fm_size,fm_size) 264 | 265 | feature_vector = (feature_map*mask).view(batch_size,num_channel,num_part,-1).sum(-1) 266 | feature_vector = feature_vector.view(batch_size,dim) 267 | 268 | return feature_vector 269 | 270 | def get_heatmap(self,inp): 271 | 272 | num_part = self.num_part 273 | 274 | batch_size = inp.size(0) 275 | 276 | feature_map = 
self.feature_extractor(inp) 277 | fm_size = feature_map.size(-1) 278 | feature_map = feature_map.unsqueeze(2) 279 | 280 | mask = torch.sum((feature_map-self.part_vector)**2,1).neg().view(batch_size,num_part,-1) 281 | mask = F.softmax(mask,dim=2).view(batch_size,num_part,fm_size,fm_size) 282 | 283 | return mask 284 | 285 | 286 | 287 | class Proto_PN(models.PN_Model): 288 | 289 | def __init__(self,num_part=15,way=None,shots=None,resnet=False): 290 | 291 | super().__init__(num_part=num_part,resnet=resnet) 292 | self.shots = shots 293 | self.way = way 294 | 295 | def get_feature_vector(self,inp): 296 | 297 | temp = self.shared_layers(inp) 298 | feature_map = self.class_branch(temp) 299 | mask = nn.Sigmoid()(self.part_branch(temp)) 300 | 301 | vec = feature_map2vec(feature_map,mask) 302 | 303 | return vec 304 | 305 | def forward(self,inp,mask): 306 | 307 | temp = self.shared_layers(inp) 308 | feature_map = self.class_branch(temp) 309 | heatmap_logits = self.part_branch(temp) 310 | 311 | feature_vector = feature_map2vec(feature_map,mask) 312 | 313 | log_prediction = models.proto_forward_log_pred(feature_vector, 314 | train_shot = self.shots[0], 315 | test_shot = self.shots[1], 316 | dim = self.dim, 317 | way = self.way) 318 | 319 | return log_prediction,heatmap_logits 320 | 321 | 322 | 323 | class Proto_PN_less_annot(Proto_PN): 324 | 325 | def forward_class(self,inp): 326 | 327 | feature_vector = self.get_feature_vector(inp) 328 | 329 | log_prediction = models.proto_forward_log_pred(feature_vector, 330 | train_shot = self.shots[0], 331 | test_shot = self.shots[1], 332 | dim = self.dim, 333 | way = self.way) 334 | 335 | return log_prediction 336 | 337 | def forward_part(self,inp): 338 | 339 | temp = self.shared_layers(inp) 340 | heatmap_logits = self.part_branch(temp) 341 | 342 | return heatmap_logits 343 | 344 | 345 | 346 | 347 | class Proto_PN_gt(models.Proto_Model): 348 | 349 | def __init__(self,num_part=15,way=None,shots=None,resnet=False): 350 | 351 | 
super().__init__(way=way,shots=shots,resnet=resnet) 352 | self.dim = self.num_channel*num_part 353 | self.num_part = num_part 354 | 355 | def get_feature_vector(self,inp,mask): 356 | 357 | batch_size = inp.size(0) 358 | num_channel = self.num_channel 359 | num_part = self.num_part 360 | dim = self.dim 361 | 362 | feature_map = self.feature_extractor(inp).unsqueeze(2) 363 | mask = mask.unsqueeze(1) 364 | 365 | feature_vector = (feature_map*mask).view(batch_size,num_channel,num_part,-1).sum(-1) 366 | feature_vector = feature_vector.view(batch_size,dim) 367 | 368 | return feature_vector 369 | 370 | def eval_k_shot(self,inp,mask,way,shot): 371 | 372 | feature_vector = self.get_feature_vector(inp,mask) 373 | 374 | max_index = models.proto_eval_k_shot(feature_vector, 375 | way = way, 376 | shot = shot, 377 | dim = self.dim) 378 | 379 | return max_index 380 | 381 | def forward(self,inp,mask): 382 | 383 | feature_vector = self.get_feature_vector(inp,mask) 384 | 385 | log_prediction = models.proto_forward_log_pred(feature_vector, 386 | train_shot = self.shots[0], 387 | test_shot = self.shots[1], 388 | dim = self.dim, 389 | way = self.way) 390 | 391 | return log_prediction 392 | 393 | 394 | 395 | class Transfer(nn.Module): 396 | 397 | def __init__(self,num_class=100,resnet=False): 398 | 399 | super().__init__() 400 | if resnet: 401 | num_channel = 32 402 | self.feature_extractor = models.BackBone_ResNet(num_channel) 403 | else: 404 | num_channel = 64 405 | self.feature_extractor = models.BackBone(num_channel) 406 | 407 | self.linear_classifier = nn.Linear(num_channel,num_class) 408 | self.num_channel = num_channel 409 | self.num_class = num_class 410 | self.dim = num_channel 411 | 412 | def get_feature_vector(self,inp): 413 | 414 | batch_size = inp.size(0) 415 | feature_map = self.feature_extractor(inp) 416 | feature_vector = F.avg_pool2d(input=feature_map,kernel_size=feature_map.size(-1)) 417 | feature_vector = feature_vector.view(batch_size,-1) 418 | 419 | return 
feature_vector 420 | 421 | def forward(self,inp): 422 | 423 | feature_vector = self.get_feature_vector(inp) 424 | logits = self.linear_classifier(feature_vector) 425 | log_prediction = F.log_softmax(logits,dim=1) 426 | 427 | return log_prediction 428 | 429 | 430 | class Transfer_PN(models.PN_Model): 431 | 432 | def __init__(self,num_class=100,num_part=15,resnet=False): 433 | 434 | super().__init__(num_part=num_part,resnet=resnet) 435 | self.linear_classifier = nn.Linear(self.dim,num_class) 436 | self.num_class = num_class 437 | 438 | def forward(self,inp,mask=None): 439 | 440 | batch_size = inp.size(0) 441 | num_channel = self.num_channel 442 | num_part = self.num_part 443 | 444 | temp = self.shared_layers(inp) 445 | feature_map = self.class_branch(temp) 446 | heatmap_logits = self.part_branch(temp) 447 | 448 | is_training = True 449 | 450 | if mask is None: 451 | mask = nn.Sigmoid()(heatmap_logits) 452 | is_training = False 453 | 454 | feature_vector = feature_map2vec(feature_map,mask) 455 | 456 | logits = self.linear_classifier(feature_vector) 457 | log_prediction = F.log_softmax(logits,dim=1) 458 | 459 | if is_training: 460 | return log_prediction,heatmap_logits 461 | else: 462 | return log_prediction 463 | 464 | 465 | class Transfer_PN_gt(nn.Module): 466 | 467 | def __init__(self,num_class=100,num_part=15,resnet=False): 468 | 469 | super().__init__() 470 | if resnet: 471 | num_channel = 32 472 | self.feature_extractor = models.BackBone_ResNet(num_channel) 473 | else: 474 | num_channel = 64 475 | self.feature_extractor = models.BackBone(num_channel) 476 | 477 | self.num_channel = num_channel 478 | self.num_part = num_part 479 | self.num_class = num_class 480 | self.dim = num_channel*num_part 481 | self.linear_classifier = nn.Linear(self.dim,num_class) 482 | 483 | def get_feature_vector(self,inp,mask): 484 | 485 | batch_size = inp.size(0) 486 | num_channel = self.num_channel 487 | num_part = self.num_part 488 | dim = self.dim 489 | 490 | feature_map = 
self.feature_extractor(inp).unsqueeze(2) 491 | mask = mask.unsqueeze(1) 492 | 493 | feature_vector = (feature_map*mask).view(batch_size,num_channel,num_part,-1).sum(-1) 494 | feature_vector = feature_vector.view(batch_size,dim) 495 | 496 | return feature_vector 497 | 498 | def forward(self,inp,mask): 499 | 500 | feature_vector = self.get_feature_vector(inp,mask) 501 | logits = self.linear_classifier(feature_vector) 502 | log_prediction = F.log_softmax(logits,dim=1) 503 | 504 | return log_prediction 505 | 506 | 507 | 508 | class Dynamic(models.Dynamic_Model): 509 | 510 | def __init__(self,num_class=100,resnet=False,way=None,shots=None): 511 | 512 | if resnet: 513 | num_channel = 32 514 | else: 515 | num_channel = 64 516 | 517 | super().__init__(num_class=num_class,dim=num_channel,way=way,shots=shots) 518 | 519 | if resnet: 520 | self.feature_extractor = models.BackBone_ResNet(num_channel) 521 | else: 522 | self.feature_extractor = models.BackBone(num_channel) 523 | 524 | 525 | def get_feature_vector(self,inp): 526 | 527 | batch_size = inp.size(0) 528 | feature_map = self.feature_extractor(inp) 529 | feature_vector = F.avg_pool2d(input=feature_map,kernel_size=feature_map.size(-1)) 530 | feature_vector = feature_vector.view(batch_size,-1) 531 | 532 | return feature_vector 533 | 534 | def forward_stage_1(self,inp): 535 | 536 | feature_vector = self.get_feature_vector(inp) 537 | log_prediction = self.get_prediction(feature_vector,self.weight_base) 538 | 539 | return log_prediction 540 | 541 | 542 | class Dynamic_PN(models.Dynamic_Model): 543 | 544 | def __init__(self,num_class=100,num_part=15,resnet=False,way=None,shots=None): 545 | 546 | if resnet: 547 | num_channel = 32 548 | else: 549 | num_channel = 64 550 | 551 | super().__init__(num_class=num_class,dim=num_channel*num_part,way=way,shots=shots) 552 | 553 | self.PN_Model = models.PN_Model(num_part=num_part,resnet=resnet) 554 | self.num_part = num_part 555 | 556 | def get_feature_vector(self,inp): 557 | 558 | temp = 
self.PN_Model.shared_layers(inp) 559 | feature_map = self.PN_Model.class_branch(temp) 560 | mask = nn.Sigmoid()(self.PN_Model.part_branch(temp)) 561 | 562 | vec = feature_map2vec(feature_map,mask) 563 | 564 | return vec 565 | 566 | def get_heatmap(self,inp): 567 | 568 | return self.PN_Model.get_heatmap(inp) 569 | 570 | def forward_stage_1(self,inp,mask): 571 | 572 | dim = self.dim 573 | batch_size = inp.size(0) 574 | 575 | temp = self.PN_Model.shared_layers(inp) 576 | feature_map = self.PN_Model.class_branch(temp) 577 | heatmap_logits = self.PN_Model.part_branch(temp) 578 | 579 | feature_vector = feature_map2vec(feature_map,mask).view(batch_size,dim) 580 | log_prediction = self.get_prediction(feature_vector,self.weight_base) 581 | 582 | return log_prediction,heatmap_logits 583 | 584 | 585 | class Dynamic_PN_gt(models.Dynamic_Model): 586 | 587 | def __init__(self,num_class=100,num_part=15,resnet=False,way=None,shots=None): 588 | 589 | if resnet: 590 | num_channel = 32 591 | else: 592 | num_channel = 64 593 | 594 | super().__init__(num_class=num_class,dim=num_channel*num_part,way=way,shots=shots) 595 | 596 | self.num_part = num_part 597 | self.num_channel = num_channel 598 | 599 | if resnet: 600 | self.feature_extractor = models.BackBone_ResNet(num_channel) 601 | else: 602 | self.feature_extractor = models.BackBone(num_channel) 603 | 604 | def get_feature_vector(self,inp,mask): 605 | 606 | batch_size = inp.size(0) 607 | num_channel = self.num_channel 608 | num_part = self.num_part 609 | dim = self.dim 610 | 611 | feature_map = self.feature_extractor(inp).unsqueeze(2) 612 | mask = mask.unsqueeze(1) 613 | 614 | feature_vector = (feature_map*mask).view(batch_size,num_channel,num_part,-1).sum(-1) 615 | feature_vector = feature_vector.view(batch_size,dim) 616 | 617 | return feature_vector 618 | 619 | def forward_stage_1(self,inp,mask): 620 | 621 | feature_vector = self.get_feature_vector(inp,mask) 622 | log_prediction = self.get_prediction(feature_vector,self.weight_base) 
623 | 624 | return log_prediction -------------------------------------------------------------------------------- /utils/proto_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from . import util,dataloader 6 | 7 | def default_eval(refer_loader,query_loader,model,class_acc=False): 8 | 9 | fb_vector = None 10 | if hasattr(model,'get_fb_vector'): 11 | fb_vector = get_fb_vector(refer_loader,model) 12 | 13 | centroid = get_class_centroid(refer_loader,model,fb_vector) 14 | 15 | acc = get_prediction(query_loader,model,centroid,fb_vector,class_acc) 16 | 17 | return acc 18 | 19 | 20 | def get_class_centroid(loader,model,fb_vector=None): 21 | 22 | way = len(loader.dataset.classes) 23 | dim = model.dim 24 | 25 | centroid = torch.zeros(way,dim).cuda() 26 | 27 | for i, (inp,target) in enumerate(loader): 28 | 29 | current_class_id = target[0] 30 | 31 | if fb_vector is not None: 32 | (image,_) = inp 33 | image = image.cuda() 34 | vectors = model.get_feature_vector(image,fb_vector) 35 | 36 | elif isinstance(inp,list): 37 | (image,mask) = inp 38 | image = image.cuda() 39 | mask = mask.cuda() 40 | vectors = model.get_feature_vector(image,mask) 41 | 42 | elif isinstance(inp,torch.Tensor): 43 | inp = inp.cuda() 44 | vectors = model.get_feature_vector(inp) 45 | 46 | centroid[current_class_id] = vectors.mean(0).view(dim) 47 | 48 | return centroid 49 | 50 | 51 | def get_prediction(loader,model,centroid,fb_vector=None,class_acc=False): 52 | 53 | data_source = loader.dataset 54 | centroid = centroid.unsqueeze(0) 55 | 56 | way = len(data_source.classes) 57 | 58 | correct_count = torch.zeros(way).cuda() 59 | 60 | counts = torch.zeros(way).cuda() 61 | 62 | for class_id in data_source.targets: 63 | counts[class_id] += 1 64 | 65 | for i, (inp,target) in enumerate(loader): 66 | 67 | current_class_id = target[0] 68 | batch_size = target.size(0) 69 | 
target = target.cuda() 70 | 71 | if fb_vector is not None: 72 | (image,mask) = inp 73 | image = image.cuda() 74 | out = model.get_feature_vector(image,fb_vector) 75 | 76 | elif isinstance(inp,list): 77 | (image,mask) = inp 78 | image = image.cuda() 79 | mask = mask.cuda() 80 | out = model.get_feature_vector(image,mask) 81 | 82 | elif isinstance(inp,torch.Tensor): 83 | inp = inp.cuda() 84 | out = model.get_feature_vector(inp) 85 | 86 | out = out.unsqueeze(1) 87 | neg_l2_distance = torch.sum((centroid-out)**2,2).neg().view(batch_size,way) 88 | 89 | _, top1_pred = neg_l2_distance.topk(1) 90 | 91 | correct_count[current_class_id] = torch.sum(torch.eq(top1_pred,target.view(batch_size,1))) 92 | 93 | acc = (torch.sum(correct_count)/torch.sum(counts)).item()*100 94 | 95 | if not class_acc: 96 | return acc 97 | else: 98 | class_acc = torch.mean(correct_count/counts).item()*100 99 | return acc,class_acc 100 | 101 | 102 | 103 | def get_fb_vector(loader,model): 104 | 105 | num_channel = model.num_channel 106 | sum_fb_vector = torch.zeros(num_channel,2).cuda() 107 | total_num = 0 108 | 109 | for i,((inp,mask),class_id) in enumerate(loader): 110 | 111 | total_num += inp.size(0) 112 | inp=inp.cuda() 113 | mask=mask.cuda() 114 | 115 | fb_vector = model.get_fb_vector(inp,mask) 116 | sum_fb_vector += fb_vector.sum(0) 117 | 118 | fb_vector = sum_fb_vector/total_num 119 | 120 | return fb_vector 121 | 122 | 123 | def k_shot_eval(eval_loader,model,way,shot): 124 | 125 | test_shot = 16 126 | target = torch.LongTensor([i//test_shot for i in range(test_shot*way)]).cuda() 127 | 128 | acc_list = [] 129 | 130 | for i, (inp,_) in enumerate(eval_loader): 131 | 132 | if isinstance(inp,list): 133 | (image_inp,mask) = inp 134 | image_inp = image_inp.cuda() 135 | mask = mask.cuda() 136 | max_index = model.eval_k_shot(image_inp,mask,way,shot) 137 | 138 | elif isinstance(inp,torch.Tensor): 139 | inp = inp.cuda() 140 | max_index = model.eval_k_shot(inp,way,shot) 141 | 142 | acc = 
100*torch.sum(torch.eq(max_index,target)).item()/test_shot/way 143 | acc_list.append(acc) 144 | 145 | mean,interval = util.eval(acc_list) 146 | 147 | return mean,interval 148 | 149 | 150 | def eval_test(model,pm,config,pm_na=None): 151 | 152 | logger = config.logger 153 | annot = config.eval_annot 154 | 155 | logger.info('------------------------') 156 | logger.info('evaluating:') 157 | 158 | with torch.no_grad(): 159 | 160 | model.load_state_dict(torch.load(config.save_path)) 161 | model.eval() 162 | 163 | refer_loader = dataloader.eval_dataloader(pm.test_refer, 164 | annot=annot,annot_path=pm.annot_path) 165 | query_loader = dataloader.eval_dataloader(pm.test_query, 166 | annot=annot,annot_path=pm.annot_path) 167 | 168 | test_acc = default_eval(refer_loader,query_loader,model=model) 169 | logger.info(('the final test acc is %.3f') % (test_acc)) 170 | 171 | way = len(refer_loader.dataset.classes) 172 | for shot in [1,5]: 173 | eval_loader = dataloader.eval_k_shot_dataloader(pm.k_shot, 174 | way=way,shot=shot,annot=annot,annot_path=pm.k_shot_annot_path) 175 | mean,interval = k_shot_eval(eval_loader,model,way,shot) 176 | logger.info('%d-way-%d-shot acc: %.2f\t%.2f'%(way,shot,mean,interval)) 177 | 178 | if pm_na is not None: 179 | 180 | logger.info('------------------------') 181 | logger.info('evaluating on NA:') 182 | 183 | refer_loader = dataloader.eval_dataloader(pm_na.test_refer, 184 | annot=annot,annot_path=pm_na.annot_path) 185 | query_loader = dataloader.eval_dataloader(pm_na.test_query, 186 | annot=annot,annot_path=pm_na.annot_path) 187 | 188 | mean_acc,class_acc = default_eval(refer_loader,query_loader, 189 | model=model,class_acc=True) 190 | 191 | logger.info(('mean_acc is %.3f') % (mean_acc)) 192 | logger.info(('class_acc is %.3f') % (class_acc)) 193 | 194 | -------------------------------------------------------------------------------- /utils/proto_train.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from tensorboardX import SummaryWriter 6 | from torchvision.utils import make_grid 7 | from torch.nn import NLLLoss,BCEWithLogitsLoss,BCELoss 8 | from . import util 9 | 10 | def default_train(train_loader,model, 11 | optimizer,writer,iter_counter): 12 | 13 | way = model.way 14 | test_shot = model.shots[-1] 15 | target = torch.LongTensor([i//test_shot for i in range(test_shot*way)]).cuda() 16 | criterion = NLLLoss().cuda() 17 | 18 | lr = optimizer.param_groups[0]['lr'] 19 | 20 | writer.add_scalar('lr',lr,iter_counter) 21 | 22 | avg_loss = 0 23 | avg_acc = 0 24 | 25 | for i, (inp,_) in enumerate(train_loader): 26 | 27 | iter_counter += 1 28 | 29 | if isinstance(inp,list): 30 | (image_inp,mask) = inp 31 | image_inp = image_inp.cuda() 32 | mask = mask.cuda() 33 | log_prediction = model(image_inp,mask) 34 | 35 | elif isinstance(inp,torch.Tensor): 36 | inp = inp.cuda() 37 | log_prediction = model(inp) 38 | 39 | loss = criterion(log_prediction,target) 40 | 41 | optimizer.zero_grad() 42 | loss.backward() 43 | optimizer.step() 44 | 45 | loss_value = loss.item() 46 | _,max_index = torch.max(log_prediction,1) 47 | acc = 100*torch.sum(torch.eq(max_index,target)).item()/test_shot/way 48 | 49 | avg_acc += acc 50 | avg_loss += loss_value 51 | 52 | avg_acc = avg_acc/(i+1) 53 | avg_loss = avg_loss/(i+1) 54 | 55 | writer.add_scalar('proto_loss',avg_loss,iter_counter) 56 | writer.add_scalar('train_acc',avg_acc,iter_counter) 57 | 58 | return iter_counter,avg_acc 59 | 60 | 61 | def PN_train(train_loader,model, 62 | optimizer,writer,iter_counter,alpha): 63 | 64 | test_shot = model.shots[-1] 65 | way = model.way 66 | 67 | target = torch.LongTensor([i//test_shot for i in range(test_shot*way)]).cuda() 68 | criterion = NLLLoss().cuda() 69 | criterion_part = BCEWithLogitsLoss().cuda() 70 | 71 | lr = optimizer.param_groups[0]['lr'] 72 | 73 | writer.add_scalar('lr',lr,iter_counter) 74 | 75 | 
avg_proto_loss = 0 76 | avg_heatmap_loss = 0 77 | avg_total_loss = 0 78 | avg_acc = 0 79 | 80 | for i, ((inp,mask),_) in enumerate(train_loader): 81 | 82 | iter_counter += 1 83 | inp = inp.cuda() 84 | mask = mask.cuda() 85 | 86 | if iter_counter%1000==0: 87 | model.eval() 88 | util.visualize(model,writer,iter_counter,inp[:9],mask[:9]) 89 | model.train() 90 | 91 | log_prediction,heatmap_logits = model(inp,mask) 92 | 93 | loss_heatmap = criterion_part(heatmap_logits,mask) 94 | loss_proto = criterion(log_prediction,target) 95 | loss = alpha*loss_heatmap+loss_proto 96 | 97 | optimizer.zero_grad() 98 | loss.backward() 99 | optimizer.step() 100 | 101 | _,max_index = torch.max(log_prediction,1) 102 | acc = 100*torch.sum(torch.eq(max_index,target)).item()/test_shot/way 103 | 104 | avg_acc += acc 105 | avg_total_loss += loss.item() 106 | avg_proto_loss += loss_proto.item() 107 | avg_heatmap_loss += loss_heatmap.item() 108 | 109 | avg_total_loss = avg_total_loss/(i+1) 110 | avg_proto_loss = avg_proto_loss/(i+1) 111 | avg_heatmap_loss = avg_heatmap_loss/(i+1) 112 | avg_acc = avg_acc/(i+1) 113 | 114 | writer.add_scalar('total_loss',avg_total_loss,iter_counter) 115 | writer.add_scalar('proto_loss',avg_proto_loss,iter_counter) 116 | writer.add_scalar('heatmap_loss',avg_heatmap_loss,iter_counter) 117 | 118 | writer.add_scalar('train_acc',avg_acc,iter_counter) 119 | 120 | return iter_counter,avg_acc 121 | 122 | 123 | 124 | def PN_train_less_annot(train_loader,model, 125 | optimizer,writer,iter_counter,alpha,batch_size): 126 | 127 | test_shot = model.shots[-1] 128 | way = model.way 129 | 130 | target = torch.LongTensor([i//test_shot for i in range(test_shot*way)]).cuda() 131 | criterion = NLLLoss().cuda() 132 | criterion_part = BCEWithLogitsLoss().cuda() 133 | 134 | lr = optimizer.param_groups[0]['lr'] 135 | 136 | writer.add_scalar('lr',lr,iter_counter) 137 | 138 | avg_proto_loss = 0 139 | avg_heatmap_loss = 0 140 | avg_total_loss = 0 141 | avg_acc = 0 142 | 143 | for i, 
((inp,mask),_) in enumerate(train_loader): 144 | 145 | iter_counter += 1 146 | mask = mask[:way*batch_size] 147 | 148 | optimizer.zero_grad() 149 | 150 | log_prediction = model.forward_class(inp[way*batch_size:].cuda()) 151 | loss_proto = criterion(log_prediction,target) 152 | loss_proto.backward() 153 | 154 | heatmap_logits = model.forward_part(inp[:batch_size*way].cuda()) 155 | loss_heatmap = alpha*criterion_part(heatmap_logits,mask.cuda()) 156 | loss_heatmap.backward() 157 | 158 | optimizer.step() 159 | 160 | _,max_index = torch.max(log_prediction,1) 161 | loss = loss_proto+loss_heatmap 162 | acc = 100*torch.sum(torch.eq(max_index,target)).item()/test_shot/way 163 | 164 | avg_acc += acc 165 | avg_total_loss += loss.item() 166 | avg_proto_loss += loss_proto.item() 167 | avg_heatmap_loss += (loss_heatmap/alpha).item() 168 | 169 | if iter_counter%1000==0: 170 | model.eval() 171 | util.visualize(model,writer,iter_counter,inp[:9].cuda(),mask[:9]) 172 | model.train() 173 | 174 | avg_total_loss = avg_total_loss/(i+1) 175 | avg_proto_loss = avg_proto_loss/(i+1) 176 | avg_heatmap_loss = avg_heatmap_loss/(i+1) 177 | avg_acc = avg_acc/(i+1) 178 | 179 | writer.add_scalar('total_loss',avg_total_loss,iter_counter) 180 | writer.add_scalar('proto_loss',avg_proto_loss,iter_counter) 181 | writer.add_scalar('heatmap_loss',avg_heatmap_loss,iter_counter) 182 | 183 | writer.add_scalar('train_acc',avg_acc,iter_counter) 184 | 185 | return iter_counter,avg_acc 186 | 187 | 188 | 189 | def bbN_train(train_loader,model, 190 | optimizer,writer,iter_counter,alpha): 191 | 192 | test_shot = model.shots[-1] 193 | way = model.way 194 | 195 | target = torch.LongTensor([i//test_shot for i in range(test_shot*way)]).cuda() 196 | criterion = NLLLoss().cuda() 197 | criterion_local = BCELoss().cuda() 198 | 199 | lr = optimizer.param_groups[0]['lr'] 200 | 201 | writer.add_scalar('lr',lr,iter_counter) 202 | 203 | avg_proto_loss = 0 204 | avg_heatmap_loss = 0 205 | avg_total_loss = 0 206 | avg_acc = 0 207 
| 208 | for i, ((inp,mask),_) in enumerate(train_loader): 209 | 210 | iter_counter += 1 211 | inp = inp.cuda() 212 | mask = mask.cuda() 213 | mask = torch.cat((mask,1.0-mask),1) 214 | 215 | if iter_counter%1000==0: 216 | model.eval() 217 | util.visualize(model,writer,iter_counter,inp[:9],mask[:9]) 218 | model.train() 219 | 220 | log_prediction,heatmap = model(inp,mask) 221 | 222 | loss_heatmap = criterion_local(heatmap,mask) 223 | loss_proto = criterion(log_prediction,target) 224 | loss = alpha*loss_heatmap+loss_proto 225 | 226 | optimizer.zero_grad() 227 | loss.backward() 228 | optimizer.step() 229 | 230 | _,max_index = torch.max(log_prediction,1) 231 | acc = 100*torch.sum(torch.eq(max_index,target)).item()/test_shot/way 232 | 233 | avg_acc += acc 234 | avg_total_loss += loss.item() 235 | avg_proto_loss += loss_proto.item() 236 | avg_heatmap_loss += loss_heatmap.item() 237 | 238 | avg_total_loss = avg_total_loss/(i+1) 239 | avg_proto_loss = avg_proto_loss/(i+1) 240 | avg_heatmap_loss = avg_heatmap_loss/(i+1) 241 | avg_acc = avg_acc/(i+1) 242 | 243 | writer.add_scalar('total_loss',avg_total_loss,iter_counter) 244 | writer.add_scalar('proto_loss',avg_proto_loss,iter_counter) 245 | writer.add_scalar('heatmap_loss',avg_heatmap_loss,iter_counter) 246 | 247 | writer.add_scalar('train_acc',avg_acc,iter_counter) 248 | 249 | return iter_counter,avg_acc 250 | 251 | 252 | def fgvc_PN_train(train_loader,oid_loader,model, 253 | optimizer,writer,iter_counter,alpha): 254 | 255 | way = model.way 256 | shots = model.shots 257 | 258 | test_shot = shots[-1] 259 | target = torch.LongTensor([i//test_shot for i in range(test_shot*way)]).cuda() 260 | criterion = NLLLoss().cuda() 261 | criterion_part = BCEWithLogitsLoss().cuda() 262 | 263 | lr = optimizer.param_groups[0]['lr'] 264 | 265 | writer.add_scalar('lr',lr,iter_counter) 266 | 267 | avg_proto_loss = 0 268 | avg_heatmap_loss = 0 269 | avg_total_loss = 0 270 | avg_acc = 0 271 | 272 | for i, ((inp,_),(oid_img,mask)) in 
enumerate(zip(train_loader,oid_loader)): 273 | 274 | iter_counter += 1 275 | 276 | optimizer.zero_grad() 277 | 278 | log_prediction = model.forward_class(inp.cuda()) 279 | loss_proto = criterion(log_prediction,target) 280 | loss_proto.backward() 281 | 282 | heatmap_logits = model.forward_part(oid_img.cuda()) 283 | loss_heatmap = alpha*criterion_part(heatmap_logits,mask.cuda()) 284 | loss_heatmap.backward() 285 | 286 | optimizer.step() 287 | 288 | _,max_index = torch.max(log_prediction,1) 289 | loss = loss_proto+loss_heatmap 290 | acc = 100*torch.sum(torch.eq(max_index,target)).item()/test_shot/way 291 | 292 | avg_acc += acc 293 | avg_total_loss += loss.item() 294 | avg_proto_loss += loss_proto.item() 295 | avg_heatmap_loss += (loss_heatmap/alpha).item() 296 | 297 | if iter_counter%1000==0: 298 | model.eval() 299 | util.visualize(model,writer,iter_counter,oid_img[:9].cuda(),mask[:9]) 300 | model.train() 301 | 302 | avg_total_loss = avg_total_loss/(i+1) 303 | avg_proto_loss = avg_proto_loss/(i+1) 304 | avg_heatmap_loss = avg_heatmap_loss/(i+1) 305 | avg_acc = avg_acc/(i+1) 306 | 307 | writer.add_scalar('total_loss',avg_total_loss,iter_counter) 308 | writer.add_scalar('proto_loss',avg_proto_loss,iter_counter) 309 | writer.add_scalar('heatmap_loss',avg_heatmap_loss,iter_counter) 310 | 311 | writer.add_scalar('train_acc',avg_acc,iter_counter) 312 | 313 | return iter_counter,avg_acc 314 | -------------------------------------------------------------------------------- /utils/sampler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import math 4 | import numpy as np 5 | from copy import deepcopy 6 | from torch.utils.data import Sampler 7 | 8 | 9 | class meta_batchsampler(Sampler): 10 | 11 | def __init__(self,data_source,way,shots): 12 | 13 | self.way = way 14 | self.shots = shots 15 | 16 | class2id = {} 17 | 18 | for i,(image_path,class_id) in enumerate(data_source.imgs): 19 | if class_id not in 
class2id: 20 | class2id[class_id]=[] 21 | class2id[class_id].append(i) 22 | 23 | self.class2id = class2id 24 | 25 | 26 | def __iter__(self): 27 | 28 | temp_class2id = deepcopy(self.class2id) 29 | for class_id in temp_class2id: 30 | np.random.shuffle(temp_class2id[class_id]) 31 | 32 | while len(temp_class2id) >= self.way: 33 | 34 | id_list = [] 35 | 36 | list_class_id = list(temp_class2id.keys()) 37 | 38 | pcount = np.array([len(temp_class2id[class_id]) for class_id in list_class_id]) 39 | 40 | batch_class_id = np.random.choice(list_class_id,size=self.way,replace=False,p=pcount/sum(pcount)) 41 | 42 | for shot in self.shots: 43 | for class_id in batch_class_id: 44 | for _ in range(shot): 45 | id_list.append(temp_class2id[class_id].pop()) 46 | 47 | for class_id in batch_class_id: 48 | if len(temp_class2id[class_id])= self.way: 167 | 168 | idlist = [] 169 | 170 | pcount = np.array([len(trackdict2[k]) for k in list(trackdict2.keys())]) 171 | cats = np.random.choice(list(trackdict2.keys()),size=self.way,replace=False,p=pcount/sum(pcount)) 172 | 173 | for cat in cats: 174 | idlist.extend(np.random.choice(trackdict1[cat],size=self.batch_size,replace=False)) 175 | 176 | for shot in self.shots: 177 | for cat in cats: 178 | for _ in range(shot): 179 | idlist.append(trackdict2[cat].pop()) 180 | 181 | for cat in cats: 182 | if len(trackdict2[cat]) best_val_acc: 247 | best_val_acc = val_acc 248 | best_epoch = e+1 249 | torch.save(model.state_dict(),save_path) 250 | logger.info('BEST!') 251 | 252 | self.set_train_mode(model) 253 | 254 | scheduler.step() 255 | 256 | logger.info('training finished!') 257 | 258 | if validation: 259 | logger.info('------------------------') 260 | logger.info(('the best epoch is %d/%d') % (best_epoch,total_epoch)) 261 | logger.info(('the best val acc is %.3f') % (best_val_acc)) 262 | 263 | else: 264 | torch.save(model.state_dict(),save_path) 265 | 266 | def set_train_mode(self,model): 267 | model.train() 268 | 269 | 270 | 271 | 272 | class 
TM_dynamic_stage_2(Train_Manager): 273 | 274 | def set_train_mode(self,model): 275 | 276 | model.train() 277 | 278 | model.feature_extractor.eval() 279 | for param in model.feature_extractor.parameters(): 280 | param.requires_grad = False 281 | 282 | 283 | 284 | class TM_dynamic_PN_stage_2(Train_Manager): 285 | 286 | def set_train_mode(self,model): 287 | 288 | model.train() 289 | 290 | model.PN_Model.eval() 291 | for param in model.PN_Model.parameters(): 292 | param.requires_grad = False 293 | 294 | 295 | 296 | class TM_transfer_finetune(Train_Manager): 297 | 298 | def set_train_mode(self,model): 299 | 300 | model.feature_extractor.eval() 301 | 302 | for param in model.feature_extractor.parameters(): 303 | param.requires_grad = False 304 | 305 | model.linear_classifier.train() 306 | 307 | 308 | 309 | class TM_transfer_PN_finetune(Train_Manager): 310 | 311 | def set_train_mode(self,model): 312 | 313 | model.shared_layers.eval() 314 | for param in model.shared_layers.parameters(): 315 | param.requires_grad = False 316 | 317 | model.class_branch.eval() 318 | for param in model.class_branch.parameters(): 319 | param.requires_grad = False 320 | 321 | model.part_branch.eval() 322 | for param in model.part_branch.parameters(): 323 | param.requires_grad = False 324 | 325 | model.linear_classifier.train() 326 | 327 | 328 | --------------------------------------------------------------------------------