├── datas └── HICO_meta_res18 ├── models └── 5-way 1-shot ├── HICO_Vector.mat ├── TUHOI_Vector.mat ├── TUHOI_class.txt ├── README.md ├── HICO_class.txt ├── task_generator_test.py ├── task_generator.py ├── SGAP_test.py ├── SGAP_train.py └── SGAP_tr_val_test.py /datas/HICO_meta_res18: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/5-way 1-shot: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /HICO_Vector.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuxiyao/SGAP-Net/HEAD/HICO_Vector.mat -------------------------------------------------------------------------------- /TUHOI_Vector.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liuxiyao/SGAP-Net/HEAD/TUHOI_Vector.mat -------------------------------------------------------------------------------- /TUHOI_class.txt: -------------------------------------------------------------------------------- 1 | popsicle 2 | camel 3 | laptop 4 | move 5 | spatula 6 | swimming_trunks 7 | snowmobile 8 | touch 9 | chair 10 | dumbbell 11 | blow 12 | cup 13 | tv 14 | backpack 15 | show 16 | sit_with 17 | bench 18 | balance_beam 19 | piano 20 | ski 21 | stove 22 | kick 23 | pingpong_ball 24 | hit 25 | ladle 26 | beat 27 | watch 28 | screwdriver 29 | sail_on 30 | snake 31 | vacuum 32 | cook 33 | walk 34 | cool 35 | croquet_ball 36 | power_drill 37 | sofa 38 | punching_bag 39 | baby_bed 40 | race 41 | saxophone 42 | motorcycle 43 | turtle 44 | bicycle 45 | sit_on 46 | carry 47 | cattle 48 | trombone 49 | hammer 50 | harp 51 | lipstick 52 | lead 53 | pet 54 | harmonica 55 | paddle 56 | lay_on 57 | unicycle 58 | bird 59 | stand_with 60 | drink 61 | drill 62 | rub 63 | electric_fan 64 | type_on 65 | use 66 | cart 67 | punch 68 | guitar 69 | spray 70 | watercraft 71 | baseball 72 | wear 73 | sing 74 | throw 75 | change 76 | dry 77 | dishwasher 78 | drum 79 | accordion 80 | iPod 81 | maraca 82 | golf_ball 83 | chop 84 | hair_spray 85 | stand 86 | swing 87 | feed 88 | repair 89 | stretcher 90 | puck 91 | play_on 92 | sit_in 93 | syringe 94 | table 95 | open 96 | speak 97 | horse 98 | basketball 99 | snowplow 100 | frog 101 | plastic_bag 102 | listen 103 | golfcart 104 | oboe 105 | bus 106 | park 107 | train 108 | stethoscope 109 | diaper 110 | rabbit 111 | french_horn 112 | hold 113 | eat 114 | shoot 115 | pull 116 | look 117 | car 118 | ride 119 | crutch 120 | lizard 121 | cello 122 | steer 123 | rugby_ball 124 | hamburger 125 | stick 126 | purse 127 | drive 128 | racket 129 | violin 130 | stand_on 131 | trumpet 132 | cream 133 | microphone 134 | sit 135 | perform 136 | put_on 137 | bowl 138 | elephant 139 | frying_pan 140 | horizontal_bar 141 | banjo 142 | pizza 143 | play 144 | water_bottle 145 | flute 146 | hang 147 | walk_to 148 | hair_dryer 149 | stab 150 | eat_with 151 | soccer_ball 152 | band_aid 153 | burrito 154 | dog 155 | volleyball 156 | lift 157 | clean 158 | axe 159 | skunk 160 | chain_saw 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # SGAP-Net: Semantic-Guided Attentive Prototypes Network for Few-Shot Human-Object 
Interaction Recognition, AAAI 2020.
3 | 
4 | ## Few-Shot Human-Object Interaction Recognition with Semantic-Guided Attentive Prototypes Network
5 | 
6 | We have resubmitted an extended version to TIP with four additions:
7 | + an alternative prototype-computation approach, Hallucinatory Graph Prototypes (HGP), which consists of a hallucinator and an HOI Graph Convolutional Network (GCN);
8 | + a new dataset split strategy and the corresponding experiments;
9 | + cross-domain experiments between different datasets;
10 | + an expanded discussion of related work and additional ablation studies.
11 | 
12 | ## Dependencies
13 | 
14 | This code requires the following:
15 | 
16 | Python 3.6+
17 | PyTorch 1.0+
18 | 
19 | ## Dataset
20 | [FS-HOI]
21 | code: 283w
22 | 
23 | ## Abstract
24 | 
25 | Extreme instance imbalance among categories and combinatorial explosion make Human-Object Interaction (HOI) recognition a challenging task, and few studies have addressed both challenges directly. Motivated by the success of few-shot learning in building robust models from only a few instances, we formulate HOI recognition as a few-shot task in a meta-learning framework to alleviate both challenges. Because the intrinsic characteristics of HOI are diverse and interactive, we propose a Semantic-Guided Attentive Prototypes Network (SGAP-Net) that learns a semantic-guided metric space in which HOI recognition is performed by computing distances to the attentive prototype of each class. Specifically, the model generates attentive prototypes guided by the category names of actions and objects, which highlight the commonalities of images from the same HOI class. In addition, we design a novel decision method to alleviate the biases produced by different patterns of the same action. Finally, to realize few-shot HOI recognition, we reorganize two HOI benchmark datasets, i.e., HICO-FS and TUHOI-FS. Extensive experiments on both datasets demonstrate the effectiveness of the proposed SGAP-Net. 
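## Decision rule (sketch)

At test time the scripts classify a query by computing squared Euclidean distances to the per-class prototypes (the mean of the attended support features) and taking a log-softmax over the negated distances. Below is a minimal, self-contained sketch of that prototype-matching rule with toy dimensions; `sq_euclidean` mirrors `euclidean_dist2` in the scripts, while the semantic-guided attention and task normalization applied before matching are omitted here.

```python
import torch
import torch.nn.functional as F

def sq_euclidean(x, y):
    # x: (n, d), y: (m, d) -> (n, m) squared Euclidean distances,
    # mirroring euclidean_dist2 in the released scripts
    n, m, d = x.size(0), y.size(0), x.size(1)
    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)
    return torch.pow(x - y, 2).sum(2)

def prototype_classify(support, query):
    # support: (way, shot, dim) features, query: (n_query, dim) features
    prototypes = support.mean(dim=1)                                  # one prototype per class
    log_p_y = F.log_softmax(-sq_euclidean(query, prototypes), dim=1)  # closer -> higher log-prob
    return log_p_y.argmax(dim=1)                                      # nearest prototype wins

support = torch.randn(5, 1, 512)   # a 5-way 1-shot episode (toy feature vectors)
query = torch.randn(10, 512)
print(prototype_classify(support, query))   # 10 predicted class indices
```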
26 | -------------------------------------------------------------------------------- /HICO_class.txt: -------------------------------------------------------------------------------- 1 | release 2 | row 3 | sail 4 | wave 5 | lasso 6 | milk 7 | eat_at 8 | sit_at 9 | dry 10 | scratch 11 | chase 12 | groom 13 | run 14 | train 15 | park 16 | push 17 | race 18 | straddle 19 | turn 20 | greet 21 | stab 22 | tag 23 | teach 24 | herd 25 | shear 26 | board 27 | exit 28 | break 29 | hunt 30 | blow 31 | light 32 | stir 33 | talk_on 34 | text_on 35 | drink_with 36 | pour 37 | hose 38 | hop_on 39 | walk 40 | lift 41 | assemble 42 | fly 43 | launch 44 | stick 45 | wield 46 | read 47 | type_on 48 | control 49 | peel 50 | pick 51 | squeeze 52 | check 53 | pay 54 | buy 55 | slide 56 | smell 57 | move 58 | point 59 | cook 60 | eat 61 | cut_with 62 | flip 63 | grind 64 | block 65 | catch 66 | dribble 67 | hit 68 | kick 69 | serve 70 | sign 71 | spin 72 | throw 73 | pack 74 | pick_up 75 | zip 76 | drag 77 | jump 78 | lie_on 79 | hug 80 | kiss 81 | swing 82 | adjust 83 | cut 84 | pull 85 | tie 86 | wear 87 | operate 88 | clean 89 | flush 90 | stand_on 91 | brush_with 92 | install 93 | stop_at 94 | direct 95 | drive 96 | inspect 97 | load 98 | ride 99 | sit_on 100 | carry 101 | lose 102 | open 103 | repair 104 | set 105 | stand_under 106 | make 107 | paint 108 | fill 109 | sip 110 | toast 111 | lick 112 | wash 113 | feed 114 | hold 115 | pet 116 | watch 117 | airplane 118 | bicycle 119 | bird 120 | boat 121 | bottle 122 | bus 123 | car 124 | cat 125 | chair 126 | couch 127 | cow 128 | dining_table 129 | dog 130 | horse 131 | motorcycle 132 | person 133 | potted_plant 134 | sheep 135 | tv 136 | apple 137 | backpack 138 | banana 139 | baseball_bat 140 | baseball_glove 141 | bear 142 | bed 143 | bench 144 | book 145 | bowl 146 | broccoli 147 | cake 148 | carrot 149 | cell_phone 150 | clock 151 | cup 152 | donut 153 | elephant 154 | fire_hydrant 155 | fork 156 | frisbee 157 | giraffe 158 | hair_drier 159 | handbag 160 | hot_dog 161 | keyboard 162 | kite 163 | knife 164 | laptop 165 | microwave 166 | mouse 167 | orange 168 | oven 169 | parking_meter 170 | pizza 171 | refrigerator 172 | remote 173 | sandwich 174 | scissors 175 | sink 176 | skateboard 177 | skis 178 | snowboard 179 | spoon 180 | sports_ball 181 | stop_sign 182 | suitcase 183 | surfboard 184 | teddy_bear 185 | tennis_racket 186 | toaster 187 | toilet 188 | toothbrush 189 | traffic_light 190 | truck 191 | umbrella 192 | vase 193 | wine_glass 194 | zebra 195 | -------------------------------------------------------------------------------- /task_generator_test.py: -------------------------------------------------------------------------------- 1 | # code is based on https://github.com/katerakelly/pytorch-maml 2 | import torchvision 3 | import torchvision.datasets as dset 4 | import torchvision.transforms as transforms 5 | import torch 6 | from torch.utils.data import DataLoader,Dataset 7 | import random 8 | import os 9 | from PIL import Image 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | from torch.utils.data.sampler import Sampler 13 | 14 | def imshow(img): 15 | npimg = img.numpy() 16 | plt.axis("off") 17 | plt.imshow(np.transpose(npimg,(1,2,0))) 18 | plt.show() 19 | 20 | class Rotate(object): 21 | def __init__(self, angle): 22 | self.angle = angle 23 | def __call__(self, x, mode="reflect"): 24 | x = x.rotate(self.angle) 25 | return x 26 | 27 | def dataset_folders(dataset): 28 | # train_folder = './datas/' + dataset + 
'_res18_proto/train' 29 | # test_folder = './datas/' + dataset + '_res18_proto/test' 30 | train_folder = './datas/' + dataset + '_res18_proto55/train' 31 | test_folder = './datas/' + dataset + '_res18_proto55/test' 32 | 33 | metatrain_folders = all_path(train_folder) 34 | metatest_folders = all_path(test_folder) 35 | 36 | # metatrain_folders_noun = [os.path.join(train_folder, label) \ 37 | # for label in os.listdir(train_folder) \ 38 | # if os.path.isdir(os.path.join(train_folder, label)) \ 39 | # ] 40 | # metatest_folders_noun = [os.path.join(test_folder, label) \ 41 | # for label in os.listdir(test_folder) \ 42 | # if os.path.isdir(os.path.join(test_folder, label)) \ 43 | # ] 44 | 45 | # random.seed(1) 46 | random.shuffle(metatrain_folders) 47 | random.shuffle(metatest_folders) 48 | 49 | return metatrain_folders, metatest_folders 50 | 51 | def all_path(dirname): 52 | result = [] 53 | for maindir, subdir, file_name_list in os.walk(dirname): 54 | for filename in file_name_list: 55 | apath = os.path.join(maindir, filename) 56 | result.append(apath) 57 | return result 58 | 59 | class MiniImagenetTask(object): 60 | 61 | def __init__(self, character_folders, num_classes, train_num, test_num): 62 | 63 | self.character_folders = character_folders 64 | self.num_classes = num_classes 65 | self.train_num = train_num 66 | self.test_num = test_num 67 | 68 | class_folders = random.sample(self.character_folders, self.num_classes) 69 | 70 | labels = list(range(len(class_folders))) 71 | self.n_folders =[item.split('/')[4] for item in class_folders] 72 | self.v_folders = [(item.split('/')[5]).split('.')[0] for item in class_folders] 73 | labels = dict(zip(class_folders, labels)) 74 | samples = dict() 75 | 76 | self.train_roots = [] 77 | self.test_roots = [] 78 | for c in class_folders: 79 | 80 | temp = np.load(c) 81 | #temp = [os.path.join(c, x) for x in os.listdir(c)] 82 | samples[c] = random.sample(list(temp), len(temp)) 83 | random.shuffle(samples[c]) 84 | 85 | self.train_roots += samples[c][:train_num] 86 | self.test_roots += samples[c][train_num:train_num+test_num] 87 | 88 | self.train_labels = [labels[x] for x in class_folders for i in range(train_num)] 89 | self.test_labels = [labels[x] for x in class_folders for i in range(test_num) ] 90 | self.label_name = labels 91 | 92 | def get_class(self, sample): 93 | return os.path.join(*sample.split('/')[:-1]) 94 | 95 | class FewShotDataset(Dataset): 96 | 97 | def __init__(self, task, split='train', transform=None, target_transform=None): 98 | self.transform = transform # Torch operations on the input image 99 | self.target_transform = target_transform 100 | self.task = task 101 | self.split = split 102 | self.image_roots = self.task.train_roots if self.split == 'train' else self.task.test_roots 103 | self.labels = self.task.train_labels if self.split == 'train' else self.task.test_labels 104 | 105 | def __len__(self): 106 | return len(self.image_roots) 107 | 108 | def __getitem__(self, idx): 109 | raise NotImplementedError("This is an abstract class. 
Subclass this class for your particular dataset.") 110 | 111 | class MiniImagenet(FewShotDataset): 112 | 113 | def __init__(self, *args, **kwargs): 114 | super(MiniImagenet, self).__init__(*args, **kwargs) 115 | 116 | def __getitem__(self, idx): 117 | image_root = self.image_roots[idx] 118 | image = image_root 119 | if self.transform is not None: 120 | image = self.transform(image) 121 | label = self.labels[idx] 122 | if self.target_transform is not None: 123 | label = self.target_transform(label) 124 | return image, label 125 | 126 | 127 | class ClassBalancedSampler(Sampler): 128 | ''' Samples 'num_inst' examples each from 'num_cl' pools 129 | of examples of size 'num_per_class' ''' 130 | 131 | def __init__(self, num_cl, num_inst,shuffle=False): 132 | 133 | self.num_cl = num_cl 134 | self.num_inst = num_inst 135 | self.shuffle = shuffle 136 | 137 | def __iter__(self): 138 | # return a single list of indices, assuming that items will be grouped by class 139 | if self.shuffle: 140 | batches = [[i+j*self.num_inst for i in torch.randperm(self.num_inst)] for j in range(self.num_cl)] 141 | else: 142 | batches = [[i+j*self.num_inst for i in range(self.num_inst)] for j in range(self.num_cl)] 143 | batches = [[batches[j][i] for j in range(self.num_cl)] for i in range(self.num_inst)] 144 | 145 | if self.shuffle: 146 | random.shuffle(batches) 147 | for sublist in batches: 148 | random.shuffle(sublist) 149 | batches = [item for sublist in batches for item in sublist] 150 | return iter(batches) 151 | 152 | def __len__(self): 153 | return 1 154 | 155 | class ClassBalancedSamplerOld(Sampler): 156 | ''' Samples 'num_inst' examples each from 'num_cl' pools 157 | of examples of size 'num_per_class' ''' 158 | 159 | def __init__(self, num_per_class, num_cl, num_inst,shuffle=False): 160 | self.num_per_class = num_per_class 161 | self.num_cl = num_cl 162 | self.num_inst = num_inst 163 | self.shuffle = shuffle 164 | 165 | def __iter__(self): 166 | # return a single list of indices, assuming that items will be grouped by class 167 | if self.shuffle: 168 | batch = [[i+j*self.num_inst for i in torch.randperm(self.num_inst)[:self.num_per_class]] for j in range(self.num_cl)] 169 | else: 170 | batch = [[i+j*self.num_inst for i in range(self.num_inst)[:self.num_per_class]] for j in range(self.num_cl)] 171 | batch = [item for sublist in batch for item in sublist] 172 | 173 | if self.shuffle: 174 | random.shuffle(batch) 175 | return iter(batch) 176 | 177 | def __len__(self): 178 | return 1 179 | 180 | 181 | def get_mini_imagenet_data_loader(task, num_per_class=1, split='train',shuffle = False): 182 | #normalize = transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426]) 183 | #normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])#ImageNet 184 | #normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 185 | 186 | #dataset = MiniImagenet(task,split=split,transform=transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),normalize]))#normalize,transforms.CenterCrop(224) 187 | dataset = MiniImagenet(task, split=split) 188 | if split == 'train': 189 | sampler = ClassBalancedSamplerOld(num_per_class,task.num_classes, task.train_num,shuffle=shuffle) 190 | 191 | else: 192 | sampler = ClassBalancedSampler(task.num_classes, task.test_num,shuffle=shuffle) 193 | 194 | loader = DataLoader(dataset, batch_size=num_per_class*task.num_classes, sampler=sampler) 195 | return loader 196 | 
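# ----------------------------------------------------------------------
# Usage sketch (illustrative): building one 5-way 1-shot episode with the
# API above. Assumes the precomputed per-class feature files (.npy) exist
# under ./datas/HICO_res18_proto55/{train,test}, as dataset_folders() and
# the np.load() call in MiniImagenetTask expect.
#
#   import task_generator_test as tg
#   metatrain_folders, metatest_folders = tg.dataset_folders('HICO')
#   task = tg.MiniImagenetTask(metatest_folders, num_classes=5, train_num=1, test_num=1)
#   support_loader = tg.get_mini_imagenet_data_loader(task, num_per_class=1, split='train')
#   query_loader = tg.get_mini_imagenet_data_loader(task, num_per_class=1, split='test')
#   support_feats, support_labels = next(iter(support_loader))  # 5 support features + labels
# ----------------------------------------------------------------------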
-------------------------------------------------------------------------------- /task_generator.py: -------------------------------------------------------------------------------- 1 | # code is based on https://github.com/katerakelly/pytorch-maml 2 | import torchvision 3 | import torchvision.datasets as dset 4 | import torchvision.transforms as transforms 5 | import torch 6 | from torch.utils.data import DataLoader,Dataset 7 | import random 8 | import os 9 | from PIL import Image 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | from torch.utils.data.sampler import Sampler 13 | 14 | def imshow(img): 15 | npimg = img.numpy() 16 | plt.axis("off") 17 | plt.imshow(np.transpose(npimg,(1,2,0))) 18 | plt.show() 19 | 20 | class Rotate(object): 21 | def __init__(self, angle): 22 | self.angle = angle 23 | def __call__(self, x, mode="reflect"): 24 | x = x.rotate(self.angle) 25 | return x 26 | 27 | def dataset_folders(dataset): 28 | # train_folder = './datas/' + dataset + '_meta_res18/train' 29 | # test_folder = './datas/' + dataset + '_meta_res18/val' 30 | 31 | train_folder = './datas/' + dataset + '_res18_proto55/train' 32 | test_folder = './datas/' + dataset + '_res18_proto55/val' 33 | 34 | metatrain_folders = all_path(train_folder) 35 | metatest_folders = all_path(test_folder) 36 | 37 | # metatrain_folders_noun = [os.path.join(train_folder, label) \ 38 | # for label in os.listdir(train_folder) \ 39 | # if os.path.isdir(os.path.join(train_folder, label)) \ 40 | # ] 41 | # metatest_folders_noun = [os.path.join(test_folder, label) \ 42 | # for label in os.listdir(test_folder) \ 43 | # if os.path.isdir(os.path.join(test_folder, label)) \ 44 | # ] 45 | 46 | #random.seed(1) 47 | random.shuffle(metatrain_folders) 48 | random.shuffle(metatest_folders) 49 | 50 | return metatrain_folders,metatest_folders 51 | 52 | 53 | def all_path(dirname): 54 | result = [] 55 | for maindir, subdir, file_name_list in os.walk(dirname): 56 | for filename in file_name_list: 57 | apath = os.path.join(maindir, filename) 58 | result.append(apath) 59 | return result 60 | 61 | class MiniImagenetTask(object): 62 | 63 | def __init__(self, character_folders, num_classes, train_num, test_num): 64 | 65 | self.character_folders = character_folders 66 | self.num_classes = num_classes 67 | self.train_num = train_num 68 | self.test_num = test_num 69 | 70 | 71 | class_folders = random.sample(self.character_folders,self.num_classes) 72 | 73 | labels = list(range(len(class_folders))) 74 | self.n_folders =[item.split('/')[4] for item in class_folders] 75 | self.v_folders = [(item.split('/')[5]).split('.')[0] for item in class_folders] 76 | labels = dict(zip(class_folders, labels)) 77 | samples = dict() 78 | 79 | self.train_roots = [] 80 | self.test_roots = [] 81 | for c in class_folders: 82 | 83 | temp = np.load(c) 84 | #temp = [os.path.join(c, x) for x in os.listdir(c)] 85 | samples[c] = random.sample(list(temp), len(temp)) 86 | random.shuffle(samples[c]) 87 | 88 | self.train_roots += samples[c][:train_num] 89 | self.test_roots += samples[c][train_num:train_num+test_num] 90 | 91 | self.train_labels = [labels[x] for x in class_folders for i in range(train_num)] 92 | self.test_labels = [labels[x] for x in class_folders for i in range(test_num) ] 93 | self.label_name = labels 94 | 95 | def get_class(self, sample): 96 | return os.path.join(*sample.split('/')[:-1]) 97 | 98 | class FewShotDataset(Dataset): 99 | 100 | def __init__(self, task, split='train', transform=None, target_transform=None): 101 | self.transform = transform # 
Torch operations on the input image 102 | self.target_transform = target_transform 103 | self.task = task 104 | self.split = split 105 | self.image_roots = self.task.train_roots if self.split == 'train' else self.task.test_roots 106 | self.labels = self.task.train_labels if self.split == 'train' else self.task.test_labels 107 | 108 | def __len__(self): 109 | return len(self.image_roots) 110 | 111 | def __getitem__(self, idx): 112 | raise NotImplementedError("This is an abstract class. Subclass this class for your particular dataset.") 113 | 114 | class MiniImagenet(FewShotDataset): 115 | 116 | def __init__(self, *args, **kwargs): 117 | super(MiniImagenet, self).__init__(*args, **kwargs) 118 | 119 | def __getitem__(self, idx): 120 | image_root = self.image_roots[idx] 121 | image = image_root 122 | if self.transform is not None: 123 | image = self.transform(image) 124 | label = self.labels[idx] 125 | if self.target_transform is not None: 126 | label = self.target_transform(label) 127 | return image, label 128 | 129 | 130 | class ClassBalancedSampler(Sampler): 131 | ''' Samples 'num_inst' examples each from 'num_cl' pools 132 | of examples of size 'num_per_class' ''' 133 | 134 | def __init__(self, num_cl, num_inst,shuffle=False): 135 | 136 | self.num_cl = num_cl 137 | self.num_inst = num_inst 138 | self.shuffle = shuffle 139 | 140 | def __iter__(self): 141 | # return a single list of indices, assuming that items will be grouped by class 142 | if self.shuffle: 143 | batches = [[i+j*self.num_inst for i in torch.randperm(self.num_inst)] for j in range(self.num_cl)] 144 | else: 145 | batches = [[i+j*self.num_inst for i in range(self.num_inst)] for j in range(self.num_cl)] 146 | batches = [[batches[j][i] for j in range(self.num_cl)] for i in range(self.num_inst)] 147 | 148 | if self.shuffle: 149 | random.shuffle(batches) 150 | for sublist in batches: 151 | random.shuffle(sublist) 152 | batches = [item for sublist in batches for item in sublist] 153 | return iter(batches) 154 | 155 | def __len__(self): 156 | return 1 157 | 158 | class ClassBalancedSamplerOld(Sampler): 159 | ''' Samples 'num_inst' examples each from 'num_cl' pools 160 | of examples of size 'num_per_class' ''' 161 | 162 | def __init__(self, num_per_class, num_cl, num_inst,shuffle=False): 163 | self.num_per_class = num_per_class 164 | self.num_cl = num_cl 165 | self.num_inst = num_inst 166 | self.shuffle = shuffle 167 | 168 | def __iter__(self): 169 | # return a single list of indices, assuming that items will be grouped by class 170 | if self.shuffle: 171 | batch = [[i+j*self.num_inst for i in torch.randperm(self.num_inst)[:self.num_per_class]] for j in range(self.num_cl)] 172 | else: 173 | batch = [[i+j*self.num_inst for i in range(self.num_inst)[:self.num_per_class]] for j in range(self.num_cl)] 174 | batch = [item for sublist in batch for item in sublist] 175 | 176 | if self.shuffle: 177 | random.shuffle(batch) 178 | return iter(batch) 179 | 180 | def __len__(self): 181 | return 1 182 | 183 | 184 | def get_mini_imagenet_data_loader(task, num_per_class=1, split='train',shuffle = False): 185 | #normalize = transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426]) 186 | #normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])#ImageNet 187 | #normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 188 | 189 | #dataset = 
MiniImagenet(task,split=split,transform=transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),normalize]))#normalize,transforms.CenterCrop(224) 190 | dataset = MiniImagenet(task, split=split) 191 | if split == 'train': 192 | sampler = ClassBalancedSamplerOld(num_per_class,task.num_classes, task.train_num,shuffle=shuffle) 193 | 194 | else: 195 | sampler = ClassBalancedSampler(task.num_classes, task.test_num,shuffle=shuffle) 196 | 197 | loader = DataLoader(dataset, batch_size=num_per_class*task.num_classes, sampler=sampler) 198 | return loader 199 | -------------------------------------------------------------------------------- /SGAP_test.py: -------------------------------------------------------------------------------- 1 | #------------------------------------- 2 | # Project: 3 | # Date: 4 | # Author: 5 | # All Rights Reserved 6 | #------------------------------------- 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | from torch.optim.lr_scheduler import StepLR 14 | import numpy as np 15 | import task_generator_test as tg 16 | import os 17 | import math 18 | import argparse 19 | import scipy as sp 20 | import scipy.stats 21 | import scipy.io as sio 22 | 23 | parser = argparse.ArgumentParser(description="HOI Recognition") 24 | parser.add_argument("-dataset","--dataset",default = 'HICO') 25 | parser.add_argument("-way","--class_num",type = int, default = 5) 26 | parser.add_argument("-shot","--sample_num_per_class",type = int, default = 5) 27 | parser.add_argument("-query_num","--query_num_per_class",type = int, default =1) 28 | parser.add_argument("-episode","--episode",type = int, default= 10) 29 | parser.add_argument("-test_episode","--test_episode", type = int, default = 600) 30 | parser.add_argument("-learning_rate","--learning_rate", type = float, default = 0.000001) 31 | parser.add_argument("-gpu","--gpu",type=int, default=0) 32 | parser.add_argument("-t","--tau",type=int, default=1.5) 33 | parser.add_argument("-a","--alpha",type=int, default=0.5) 34 | args = parser.parse_args() 35 | 36 | Feature_D = 512 37 | # Hyper Parameters 38 | Dataset = args.dataset 39 | CLASS_NUM = args.class_num 40 | SAMPLE_NUM_PER_CLASS = args.sample_num_per_class 41 | BATCH_NUM_PER_CLASS = args.query_num_per_class 42 | EPISODE = args.episode 43 | TEST_EPISODE = args.test_episode 44 | LEARNING_RATE = args.learning_rate 45 | GPU = args.gpu 46 | Tau = args.tau 47 | Alpha = args.alpha 48 | 49 | 50 | def mean_confidence_interval(data, confidence=0.95): 51 | a = 1.0*np.array(data) 52 | n = len(a) 53 | m, se = np.mean(a), scipy.stats.sem(a) 54 | h = se * sp.stats.t._ppf((1+confidence)/2., n-1) 55 | return m,h 56 | class TNetwork(nn.Module): 57 | """docstring for RelationNetwork""" 58 | def __init__(self): 59 | super(TNetwork, self).__init__() 60 | self.word_feature1 = nn.Linear(400,Feature_D) 61 | self.word_feature2 = nn.Linear(Feature_D, Feature_D) 62 | self.features=nn.Linear(Feature_D,Feature_D) 63 | self.map=nn.Linear(Feature_D,Feature_D) 64 | self.norm = nn.BatchNorm1d(Feature_D, momentum=1, affine=True) 65 | 66 | def forward(self,x,y): 67 | out1 = F.relu(self.word_feature1(x)) 68 | out1 = F.sigmoid(self.word_feature2(out1)) 69 | out2 = F.relu(self.features(y)) 70 | out = F.relu((out1 + 1) * out2) 71 | out = F.relu(self.map(out)) 72 | return out 73 | 74 | class Generate_word(nn.Module): 75 | """docstring for RelationNetwork""" 76 | def __init__(self): 77 | super(Generate_word, 
self).__init__() 78 | self.word_feature1 = nn.Linear(Feature_D,400) 79 | #self.word_feature2 = nn.Linear(1000,400) 80 | 81 | def forward(self,x): 82 | out1 = F.relu(self.word_feature1(x)) 83 | #out2 = F.relu(self.word_feature2(y)) 84 | return out1 85 | 86 | def euclidean_dist2(x,y): 87 | n = x.size(0) 88 | m = y.size(0) 89 | d = x.size(1) 90 | if d !=y.size(1): 91 | raise Exception 92 | x = x.unsqueeze(1).expand(n,m,d) 93 | y = y.unsqueeze(0).expand(n,m,d) 94 | return torch.pow(x-y,2).sum(2) 95 | 96 | def euclidean_dist(x,y): 97 | x1 = x.size(0) 98 | x2 = x.size(1) 99 | x3 = x.size(2) 100 | y1 = y.size(0) 101 | y2 = y.size(1) 102 | if x3 !=y2: 103 | raise Exception 104 | #x = x.unsqueeze(1).expand(n,m,d) 105 | y = y.unsqueeze(0).expand(x1,x2,x3) 106 | return torch.pow(x-y,2).sum(2) 107 | 108 | class Task_norm(nn.Module): 109 | """docstring for RelationNetwork""" 110 | def __init__(self): 111 | super(Task_norm, self).__init__() 112 | self.norm = nn.BatchNorm1d(Feature_D, momentum=1, affine=True) 113 | #self.word_feature2 = nn.Linear(1000,400) 114 | 115 | def forward(self,x): 116 | if len(x.size())==3: 117 | x = x.view(-1,Feature_D) 118 | out = self.norm(x) 119 | out = out.view(5,-1,Feature_D) 120 | else: 121 | out = self.norm(x) 122 | 123 | #out2 = F.relu(self.word_feature2(y)) 124 | return out 125 | 126 | def step_seq(a,x): 127 | x[x last_accuracy: 386 | 387 | # save networks 388 | torch.save(generate_word_noun.state_dict(), str("./models/" + Dataset + "_generate_word_noun_" + str(CLASS_NUM) + "way" + str(SAMPLE_NUM_PER_CLASS) + "shot" + "_max.pkl")) 389 | torch.save(generate_word_verb.state_dict(), str("./models/" + Dataset + "_generate_word_verb_" + str(CLASS_NUM) + "way" + str(SAMPLE_NUM_PER_CLASS) + "shot" + "_max.pkl")) 390 | torch.save(tnetwork.state_dict(), str("./models/" + Dataset + "_tnetwork_" + str(CLASS_NUM) + "way" + str(SAMPLE_NUM_PER_CLASS) + "shot" + "_max.pkl")) 391 | torch.save(tvnetwork.state_dict(), str("./models/" + Dataset + "_tvnetwork_" + str(CLASS_NUM) + "way" + str(SAMPLE_NUM_PER_CLASS) + "shot" + "_max.pkl")) 392 | torch.save(task_norm.state_dict(), str("./models/" + Dataset + "_task_norm_" + str(CLASS_NUM) + "way" + str(SAMPLE_NUM_PER_CLASS) + "shot" + "_max.pkl")) 393 | print("save networks for episode:",episode) 394 | best_episode = episode 395 | last_accuracy = test_accuracy 396 | print('best accuracy:',last_accuracy,"best networks from episode:",best_episode) 397 | 398 | 399 | if __name__ == '__main__': 400 | main() 401 | -------------------------------------------------------------------------------- /SGAP_tr_val_test.py: -------------------------------------------------------------------------------- 1 | #------------------------------------- 2 | # Project: 3 | # Date: 4 | # Author: 5 | # All Rights Reserved 6 | #------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | from torch.optim.lr_scheduler import StepLR 13 | import numpy as np 14 | import task_generator as tg 15 | import task_generator_test as tg_test 16 | import os 17 | import math 18 | import argparse 19 | import scipy as sp 20 | import scipy.stats 21 | import torchvision 22 | import scipy.io as sio 23 | 24 | parser = argparse.ArgumentParser(description="HOI Recognition") 25 | parser.add_argument("-dataset","--dataset",default = 'HICO') 26 | parser.add_argument("-way","--class_num",type = int, default = 5) 27 | parser.add_argument("-shot","--sample_num_per_class",type = int, default = 
1) 28 | parser.add_argument("-query_num","--query_num_per_class",type = int, default =5) 29 | parser.add_argument("-episode","--episode",type = int, default= 500000) 30 | parser.add_argument("-test_episode","--test_episode", type = int, default = 600) 31 | parser.add_argument("-test_round","--test_round",type = int, default= 10) 32 | parser.add_argument("-learning_rate","--learning_rate", type = float, default = 0.000001) 33 | parser.add_argument("-gpu","--gpu",type=int, default=0) 34 | parser.add_argument("-t","--tau",type=int, default=1.5) 35 | parser.add_argument("-a","--alpha",type=int, default=0.5) 36 | args = parser.parse_args() 37 | 38 | 39 | # Hyper Parameters 40 | Dataset = args.dataset 41 | CLASS_NUM = args.class_num 42 | SAMPLE_NUM_PER_CLASS = args.sample_num_per_class 43 | BATCH_NUM_PER_CLASS = args.query_num_per_class 44 | EPISODE = args.episode 45 | TEST_EPISODE = args.test_episode 46 | TEST_ROUNd = args.test_round 47 | LEARNING_RATE = args.learning_rate 48 | GPU = args.gpu 49 | Tau = args.tau 50 | Alpha = args.alpha 51 | 52 | 53 | 54 | 55 | def mean_confidence_interval(data, confidence=0.95): 56 | a = 1.0*np.array(data) 57 | n = len(a) 58 | m, se = np.mean(a), scipy.stats.sem(a) 59 | h = se * sp.stats.t._ppf((1+confidence)/2., n-1) 60 | return m,h 61 | 62 | 63 | class TNetwork(nn.Module): 64 | """docstring for RelationNetwork""" 65 | def __init__(self): 66 | super(TNetwork, self).__init__() 67 | self.word_feature1 = nn.Linear(400,1000) 68 | self.word_feature2 = nn.Linear(1000, 1000) 69 | self.features=nn.Linear(1000,1000) 70 | self.map=nn.Linear(1000,1000) 71 | self.norm = nn.BatchNorm1d(1000,momentum=1,affine=True) 72 | 73 | def forward(self,x,y): 74 | out1 = F.relu(self.word_feature1(x)) 75 | out1 = F.sigmoid(self.word_feature2(out1)) 76 | out2 = F.relu(self.features(y)) 77 | out = F.relu((out1 + 1) * out2) 78 | out = F.relu(self.map(out)) 79 | 80 | return out 81 | 82 | class Generate_word(nn.Module): 83 | """docstring for RelationNetwork""" 84 | def __init__(self): 85 | super(Generate_word, self).__init__() 86 | self.word_feature1 = nn.Linear(1000,400) 87 | #self.word_feature2 = nn.Linear(1000,400) 88 | 89 | def forward(self,x): 90 | out1 = F.relu(self.word_feature1(x)) 91 | #out2 = F.relu(self.word_feature2(y)) 92 | return out1 93 | 94 | class Task_norm(nn.Module): 95 | """docstring for RelationNetwork""" 96 | def __init__(self): 97 | super(Task_norm, self).__init__() 98 | self.norm = nn.BatchNorm1d(1000, momentum=1, affine=True) 99 | #self.word_feature2 = nn.Linear(1000,400) 100 | 101 | def forward(self,x): 102 | if len(x.size())==3: 103 | x = x.view(-1,1000) 104 | out = self.norm(x) 105 | out = out.view(5,-1,1000) 106 | else: 107 | out = self.norm(x) 108 | 109 | #out2 = F.relu(self.word_feature2(y)) 110 | return out 111 | 112 | def euclidean_dist2(x,y): 113 | n = x.size(0) 114 | m = y.size(0) 115 | d = x.size(1) 116 | if d !=y.size(1): 117 | raise Exception 118 | x = x.unsqueeze(1).expand(n,m,d) 119 | y = y.unsqueeze(0).expand(n,m,d) 120 | return torch.pow(x-y,2).sum(2) 121 | 122 | def euclidean_dist(x,y): 123 | x1 = x.size(0) 124 | x2 = x.size(1) 125 | x3 = x.size(2) 126 | y1 = y.size(0) 127 | y2 = y.size(1) 128 | if x3 !=y2: 129 | raise Exception 130 | #x = x.unsqueeze(1).expand(n,m,d) 131 | y = y.unsqueeze(0).expand(x1,x2,x3) 132 | return torch.pow(x-y,2).sum(2) 133 | 134 | def step_seq(a,x): 135 | x[x last_accuracy: 387 | 388 | # save networks 389 | 
                torch.save(generate_word_noun.state_dict(),str("./models/"+Dataset+"_generate_word_noun_"+str(CLASS_NUM)+"way"+str(SAMPLE_NUM_PER_CLASS)+"shot" + "_max.pkl"))
390 |                 torch.save(generate_word_verb.state_dict(),str("./models/"+Dataset+"_generate_word_verb_"+str(CLASS_NUM)+"way"+str(SAMPLE_NUM_PER_CLASS)+"shot" + "_max.pkl"))
391 |                 torch.save(tnetwork.state_dict(),str("./models/"+Dataset+"_tnetwork_"+str(CLASS_NUM)+"way"+str(SAMPLE_NUM_PER_CLASS)+"shot" + "_max.pkl"))
392 |                 torch.save(tvnetwork.state_dict(),str("./models/"+Dataset+"_tvnetwork_"+str(CLASS_NUM)+"way"+str(SAMPLE_NUM_PER_CLASS)+"shot" + "_max.pkl"))
393 |                 torch.save(task_norm.state_dict(), str("./models/"+Dataset+"_task_norm_"+str(CLASS_NUM)+"way"+str(SAMPLE_NUM_PER_CLASS)+"shot" + "_max.pkl"))
394 | 
395 |                 print("save networks for episode:",episode)
396 |                 best_episode = episode
397 |                 last_accuracy = test_accuracy
398 |             print('best accuracy:',last_accuracy,"best networks from episode:",best_episode)
399 | 
400 | 
401 | 
402 |     print("********************************")
403 |     print("Testing on the test split for",TEST_ROUNd,"rounds...")
404 |     if os.path.exists('./models/'+Dataset+'_tvnetwork_5way'+str(SAMPLE_NUM_PER_CLASS)+'shot_max.pkl'):
405 |         tvnetwork.load_state_dict(torch.load('./models/'+Dataset+'_tvnetwork_5way'+str(SAMPLE_NUM_PER_CLASS)+'shot_max.pkl'))
406 |         tnetwork.load_state_dict(torch.load('./models/'+Dataset+'_tnetwork_5way'+str(SAMPLE_NUM_PER_CLASS)+'shot_max.pkl'))
407 |         print("load noun-verb part success")
408 |     if os.path.exists('./models/'+Dataset+'_generate_word_verb_5way'+str(SAMPLE_NUM_PER_CLASS)+'shot_max.pkl'):
409 |         generate_word_noun.load_state_dict(torch.load('./models/'+Dataset+'_generate_word_noun_5way'+str(SAMPLE_NUM_PER_CLASS)+'shot_max.pkl'))
410 |         generate_word_verb.load_state_dict(torch.load('./models/'+Dataset+'_generate_word_verb_5way'+str(SAMPLE_NUM_PER_CLASS)+'shot_max.pkl'))
411 |         task_norm.load_state_dict(torch.load('./models/'+Dataset+'_task_norm_5way'+str(SAMPLE_NUM_PER_CLASS)+'shot_max.pkl'))
412 |         print("load generate_wordNet success")
413 | 
414 |     total_accuracy = 0.0
415 |     H_all = 0.0
416 |     metatrain_folders, metatest_folders = tg_test.dataset_folders(Dataset)
417 |     for episode in range(TEST_ROUNd):
418 |         # test
419 | 
420 | 
421 |         accuracies = []
422 |         for i in range(TEST_EPISODE):
423 |             total_rewards = 0
424 |             task = tg_test.MiniImagenetTask(metatest_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS, 1)
425 |             sample_dataloader = tg_test.get_mini_imagenet_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS,
426 |                                                                       split="train", shuffle=False)
427 |             num_per_class = 1
428 |             test_dataloader = tg_test.get_mini_imagenet_data_loader(task, num_per_class=num_per_class, split="test",
429 |                                                                     shuffle=False)
430 |             sample_images, sample_labels = next(iter(sample_dataloader))  # Python 3 iterator protocol, not the Python 2 .next()
431 |             for test_images, test_labels in test_dataloader:
432 |                 batch_size = test_labels.shape[0]
433 |                 # calculate features
434 |                 sample_features = sample_images.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, 1000).cuda(GPU)
435 |                 test_features = test_images.cuda(GPU)  # (CLASS_NUM * num_per_class) x 1000 query features
436 | 
437 |                 sample_features = task_norm(sample_features)
438 |                 #test_features = task_norm(test_features)
439 | 
440 |                 n_vector1 = generate_word_noun(sample_features)
441 |                 sample_noun = tnetwork(n_vector1, sample_features)
442 |                 sample_features1 = sample_features - step_seq(Tau,sample_noun)
443 |                 v_vector1 = generate_word_verb(sample_features1)
444 |                 sample_verb = tvnetwork(v_vector1, sample_features1)
445 |                 sample_features2 = sample_features1 + sample_noun + sample_verb
446 | 
447 |                 fake_n_vector = generate_word_noun(test_features)
448 |                 test_noun = tnetwork(fake_n_vector, test_features)
449 |                 test_features1 = test_features - step_seq(Tau,test_noun)
450 |                 fake_v_vector = generate_word_verb(test_features1)
451 |                 test_verb = tvnetwork(fake_v_vector, test_features1)
452 |                 test_features2 = test_features1 + test_noun + test_verb
453 | 
454 |                 prototypes = torch.mean(sample_features2, 1)
455 |                 dists = euclidean_dist2(test_features2.view(-1, 1000), prototypes)
456 |                 log_p_y = F.log_softmax(-dists, dim=1)
457 | 
458 |                 _, predict_labels = torch.max(log_p_y.data, 1)  # nearest prototype = largest log-probability (torch.min would pick the farthest class)
459 | 
460 |                 rewards = [1 if predict_labels[j] == test_labels.cuda()[j] else 0 for j in range(batch_size)]
461 | 
462 |                 total_rewards += np.sum(rewards)
463 | 
464 | 
465 |             accuracy = total_rewards/1.0/CLASS_NUM/num_per_class  # one query image per class in each test episode
466 |             accuracies.append(accuracy)
467 | 
468 |         test_accuracy,h = mean_confidence_interval(accuracies)
469 | 
470 |         print("test accuracy:",test_accuracy,"h:",h)
471 | 
472 |         total_accuracy += test_accuracy
473 |         H_all += h
474 | 
475 |     print("aver_accuracy:", total_accuracy / TEST_ROUNd, "aver_h:", H_all / TEST_ROUNd)  # average over the test rounds, not the training EPISODE count
476 | 
477 | 
478 | if __name__ == '__main__':
479 |     main()
480 | --------------------------------------------------------------------------------
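One portability note: `mean_confidence_interval`, defined in both SGAP_test.py and SGAP_tr_val_test.py, calls the private SciPy routine `sp.stats.t._ppf`, which newer SciPy releases do not guarantee. A drop-in sketch using only the public API (same inputs and return values):

```python
import numpy as np
import scipy.stats

def mean_confidence_interval(data, confidence=0.95):
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # t.ppf is the public equivalent of the private t._ppf used in the scripts
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, h
```

With the dependencies installed and the FS-HOI features in place, training plus the final multi-round test are launched through the argparse flags defined above, e.g. `python SGAP_tr_val_test.py -dataset HICO -way 5 -shot 1`.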