├── LICENSE
├── README.md
├── datasets
│   ├── __init__.py
│   ├── image_folder.py
│   └── imagenet.py
├── evaluate_awa2.py
├── evaluate_imagenet.py
├── materials
│   ├── awa2-split.json
│   ├── glove.py
│   ├── imagenet-split.json
│   ├── imagenet-testsets.json
│   ├── imagenet-xml-wnids.json
│   ├── make_dense_graph.py
│   ├── make_dense_grouped_graph.py
│   ├── make_induced_graph.py
│   └── process_resnet.py
├── models
│   ├── __init__.py
│   ├── gcn.py
│   ├── gcn_dense.py
│   ├── gcn_dense_att.py
│   └── resnet.py
├── train_gcn_basic.py
├── train_gcn_dense.py
├── train_gcn_dense_att.py
├── train_resnet_fit.py
└── utils.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Yinbo Chen
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dense Graph Propagation
2 | 
3 | The code for the paper [Rethinking Knowledge Graph Propagation for Zero-Shot Learning](https://arxiv.org/abs/1805.11724).
4 | 
5 | ### Citation
6 | ```
7 | @inproceedings{kampffmeyer2019rethinking,
8 |   title={Rethinking knowledge graph propagation for zero-shot learning},
9 |   author={Kampffmeyer, Michael and Chen, Yinbo and Liang, Xiaodan and Wang, Hao and Zhang, Yujia and Xing, Eric P},
10 |   booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
11 |   pages={11487--11496},
12 |   year={2019}
13 | }
14 | ```
15 | 
16 | ## Requirements
17 | 
18 | * python 3
19 | * pytorch 0.4.0
20 | * nltk
21 | 
22 | ## Instructions
23 | 
24 | ### Materials Preparation
25 | 
26 | The folder `materials/` already contains the required metadata and preparation scripts.
27 | 
28 | #### GloVe Word Embedding
29 | 1. Download: http://nlp.stanford.edu/data/glove.6B.zip
30 | 2. Unzip it and put `glove.6B.300d.txt` into `materials/`.
31 | 
32 | #### Graphs
33 | 1. `cd materials/`
34 | 2. Run `python make_induced_graph.py` to get `imagenet-induced-graph.json`
35 | 3. Run `python make_dense_graph.py` to get `imagenet-dense-graph.json`
36 | 4. Run `python make_dense_grouped_graph.py` to get `imagenet-dense-grouped-graph.json` (a quick sanity check for the generated files is sketched below)
37 | 
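A minimal sanity check for the three generated files (illustrative, not part of the repo; run it inside `materials/`) is to confirm that each graph stores one word vector per node:

```python
import json

# 'edges' holds (u, v) index pairs; the grouped file stores them
# per BFS distance under 'edges_set' instead.
for name, key in [('imagenet-induced-graph.json', 'edges'),
                  ('imagenet-dense-graph.json', 'edges'),
                  ('imagenet-dense-grouped-graph.json', 'edges_set')]:
    g = json.load(open(name, 'r'))
    assert len(g['wnids']) == len(g['vectors'])  # one 300-d vector per wnid
    print(name, '-', len(g['wnids']), 'nodes')
```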
38 | #### Pretrained ResNet50
39 | 1. Download: https://download.pytorch.org/models/resnet50-19c8e357.pth
40 | 2. Rename it to `resnet50-raw.pth` and put it in `materials/`
41 | 3. `cd materials/`, then run `python process_resnet.py` to get `fc-weights.json` and `resnet50-base.pth`
42 | 
43 | #### ImageNet and AwA2
44 | 
45 | Download ImageNet and AwA2, then create softlinks (with `ln -s`) `materials/datasets/imagenet` and `materials/datasets/awa2` pointing to the root directories of the two datasets.
46 | 
47 | An ImageNet root directory should contain one image folder per class, each named by the WordNet ID of its class.
48 | 
49 | An AwA2 root directory should contain the folder `JPEGImages`.
50 | 
51 | ### Training
52 | 
53 | Make a directory `save/` for saving models.
54 | 
55 | In most programs, use `--gpu` to specify the device(s) to run the code on (default: GPU 0).
56 | 
57 | #### Train Graph Networks
58 | * SGCN: Run `python train_gcn_basic.py` to get results in `save/gcn-basic`
59 | * DGP: Run `python train_gcn_dense_att.py` to get results in `save/gcn-dense-att`
60 | 
61 | In the results folder:
62 | * `*.pth` is the state dict of the graph network
63 | * `*.pred` is the prediction file, which can be loaded with `torch.load()`. It is a Python dict with two keys: `wnids`, the WordNet IDs of the predicted classes, and `pred`, the predicted fc weights (see the sketch below).
64 | 
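A minimal sketch of loading a `.pred` file (the path is an example; in the default setting each row of `pred` is a 2049-d vector, the last entry being the classifier bias):

```python
import torch

# Tensors were saved from GPU; pass map_location='cpu' on a CPU-only machine.
pred_file = torch.load('save/gcn-basic/epoch-3000.pred')  # example path
wnids = pred_file['wnids']  # wordnet ids, ordered as in the graph
pred = pred_file['pred']    # predicted fc weights, one row per wnid
print(len(wnids), tuple(pred.shape))
```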
65 | #### Finetune ResNet
66 | Run `python train_resnet_fit.py` with the args:
67 | * `--pred`: the `.pred` file used for finetuning
68 | * `--train-dir`: the directory containing the 1K ImageNet training classes, one folder per class, named by its WordNet ID
69 | * `--save-path`: the folder in which to save the results, e.g. `save/resnet-fit-xxx`
70 | 
71 | (In the paper's setting, `--train-dir` is composed of the 1K classes from fall2011.tar, with the missing class "teddy bear" taken from ILSVRC2012.)
72 | 
73 | ### Testing
74 | 
75 | #### ImageNet
76 | Run `python evaluate_imagenet.py` with the args:
77 | * `--cnn`: path to the resnet50 weights, e.g. `materials/resnet50-base.pth` or `save/resnet-fit-xxx/x.pth`
78 | * `--pred`: the `.pred` file used for testing
79 | * `--test-set`: the test set to load from `materials/imagenet-testsets.json`, choices: `[2-hops, 3-hops, all]`
80 | * (optional) `--keep-ratio` sets the ratio of testing data to use, `--consider-trains` includes the training classes' classifiers, `--test-train` tests with training-class images only.
81 | 
82 | #### AwA2
83 | Run `python evaluate_awa2.py` with the args:
84 | * `--cnn`: path to the resnet50 weights, e.g. `materials/resnet50-base.pth` or `save/resnet-fit-xxx/x.pth`
85 | * `--pred`: the `.pred` file used for testing
86 | * (optional) `--consider-trains` includes the training classes' classifiers
87 | 
88 | 
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/datasets/image_folder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | 
4 | from PIL import Image
5 | import torch
6 | import torchvision.transforms as transforms
7 | from torch.utils.data import Dataset
8 | 
9 | 
10 | class ImageFolder(Dataset):
11 | 
12 |     def __init__(self, path, classes, stage='train'):
13 |         self.data = []
14 |         for i, c in enumerate(classes):
15 |             cls_path = osp.join(path, c)
16 |             images = os.listdir(cls_path)
17 |             for image in images:
18 |                 self.data.append((osp.join(cls_path, image), i))
19 | 
20 |         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
21 |                                          std=[0.229, 0.224, 0.225])
22 | 
23 |         if stage == 'train':
24 |             self.transforms = transforms.Compose([transforms.RandomResizedCrop(224),
25 |                                                   transforms.RandomHorizontalFlip(),
26 |                                                   transforms.ToTensor(),
27 |                                                   normalize])
28 |         elif stage == 'test':
29 |             self.transforms = transforms.Compose([transforms.Resize(256),
30 |                                                   transforms.CenterCrop(224),
31 |                                                   transforms.ToTensor(),
32 |                                                   normalize])
33 | 
34 |     def __len__(self):
35 |         return len(self.data)
36 | 
37 |     def __getitem__(self, i):
38 |         path, label = self.data[i]
39 |         image = Image.open(path).convert('RGB')
40 |         image = self.transforms(image)
41 |         if image.shape[0] != 3 or image.shape[1] != 224 or image.shape[2] != 224:
42 |             print('you should delete this guy:', path)  # unexpected shape, likely a corrupted image
43 |         return image, label
44 | 
--------------------------------------------------------------------------------
/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import os.path as osp
4 | import random
5 | 
6 | import torch
7 | from torch.utils.data import Dataset
8 | import torchvision.transforms as transforms
9 | 
10 | from PIL import Image
11 | from torchvision import get_image_backend
12 | 
13 | 
14 | class ImageNet():
15 | 
16 |     def __init__(self, path):
17 |         self.path = path
18 |         self.keep_ratio = 1.0
19 | 
20 |     def get_subset(self, wnid):
21 |         path = osp.join(self.path, wnid)
22 |         return ImageNetSubset(path, wnid, keep_ratio=self.keep_ratio)
23 | 
24 |     def set_keep_ratio(self, r):
25 |         self.keep_ratio = r
26 | 
27 | 
28 | class ImageNetSubset(Dataset):
29 | 
30 |     def __init__(self, path, wnid, keep_ratio=1.0):
31 |         self.wnid = wnid
32 | 
33 |         def pil_loader(path):
34 |             with open(path, 'rb') as f:
35 |                 try:
36 |                     img = Image.open(f)
37 |                 except OSError:
38 |                     return None
39 |                 return img.convert('RGB')
40 | 
41 |         def accimage_loader(path):
42 |             import accimage
43 |             try:
44 |                 return accimage.Image(path)
45 |             except IOError:
46 |                 return pil_loader(path)
47 | 
48 |         def default_loader(path):
49 |             if get_image_backend() == 'accimage':
50 |                 return accimage_loader(path)
51 |             else:
52 |                 return pil_loader(path)
53 | 
54 |         # get file list
55 |         all_files = os.listdir(path)
56 |         files = []
57 |         for f in all_files:
58 |             if f.endswith('.JPEG'):
59 |                 files.append(f)
60 |         random.shuffle(files)
61 |         files = files[:max(1, round(len(files) * keep_ratio))]
62 | 
63 |         # read images
64 |         data = []
65 |         for 
filename in files: 66 | image = default_loader(osp.join(path, filename)) 67 | if image is None: 68 | continue 69 | # pytorch model-zoo pre-process 70 | preprocess = transforms.Compose([ 71 | transforms.Resize(256), 72 | transforms.CenterCrop(224), 73 | transforms.ToTensor(), 74 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 75 | std=[0.229, 0.224, 0.225]) 76 | ]) 77 | data.append(preprocess(image)) 78 | if data != []: 79 | self.data = torch.stack(data) 80 | else: 81 | self.data = [] 82 | 83 | def __len__(self): 84 | return len(self.data) 85 | 86 | def __getitem__(self, idx): 87 | return self.data[idx], self.wnid 88 | 89 | -------------------------------------------------------------------------------- /evaluate_awa2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os.path as osp 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | 8 | from models.resnet import make_resnet50_base 9 | from datasets.image_folder import ImageFolder 10 | from utils import set_gpu, pick_vectors 11 | 12 | 13 | def test_on_subset(dataset, cnn, n, pred_vectors, all_label, 14 | consider_trains): 15 | hit = 0 16 | tot = 0 17 | 18 | loader = DataLoader(dataset=dataset, batch_size=32, 19 | shuffle=False, num_workers=2) 20 | 21 | for batch_id, batch in enumerate(loader, 1): 22 | data, label = batch 23 | data = data.cuda() 24 | 25 | feat = cnn(data) # (batch_size, d) 26 | feat = torch.cat([feat, torch.ones(len(feat)).view(-1, 1).cuda()], dim=1) 27 | 28 | fcs = pred_vectors.t() 29 | 30 | table = torch.matmul(feat, fcs) 31 | if not consider_trains: 32 | table[:, :n] = -1e18 33 | 34 | pred = torch.argmax(table, dim=1) 35 | hit += (pred == all_label).sum().item() 36 | tot += len(data) 37 | 38 | return hit, tot 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--cnn') 44 | parser.add_argument('--pred') 45 | 46 | parser.add_argument('--gpu', default='0') 47 | parser.add_argument('--consider-trains', action='store_true') 48 | 49 | parser.add_argument('--output', default=None) 50 | args = parser.parse_args() 51 | 52 | set_gpu(args.gpu) 53 | 54 | awa2_split = json.load(open('materials/awa2-split.json', 'r')) 55 | train_wnids = awa2_split['train'] 56 | test_wnids = awa2_split['test'] 57 | 58 | print('train: {}, test: {}'.format(len(train_wnids), len(test_wnids))) 59 | print('consider train classifiers: {}'.format(args.consider_trains)) 60 | 61 | pred_file = torch.load(args.pred) 62 | pred_wnids = pred_file['wnids'] 63 | pred_vectors = pred_file['pred'] 64 | pred_dic = dict(zip(pred_wnids, pred_vectors)) 65 | pred_vectors = pick_vectors(pred_dic, train_wnids + test_wnids, is_tensor=True).cuda() 66 | pred_vectors = pred_vectors.cuda() 67 | 68 | n = len(train_wnids) 69 | m = len(test_wnids) 70 | 71 | cnn = make_resnet50_base() 72 | cnn.load_state_dict(torch.load(args.cnn)) 73 | cnn = cnn.cuda() 74 | cnn.eval() 75 | 76 | test_names = awa2_split['test_names'] 77 | 78 | ave_acc = 0; ave_acc_n = 0 79 | 80 | results = {} 81 | 82 | awa2_path = 'materials/datasets/awa2' 83 | 84 | for i, name in enumerate(test_names, 1): 85 | dataset = ImageFolder(osp.join(awa2_path, 'JPEGImages'), [name], 'test') 86 | hit, tot = test_on_subset(dataset, cnn, n, pred_vectors, n + i - 1, 87 | args.consider_trains) 88 | acc = hit / tot 89 | ave_acc += acc 90 | ave_acc_n += 1 91 | 92 | print('{} {}: {:.2f}%'.format(i, name.replace('+', ' '), acc * 100)) 93 | 94 | results[name] = acc 95 | 96 | 
print('summary: {:.2f}%'.format(ave_acc / ave_acc_n * 100)) 97 | 98 | if args.output is not None: 99 | json.dump(results, open(args.output, 'w')) 100 | -------------------------------------------------------------------------------- /evaluate_imagenet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os.path as osp 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader 8 | 9 | from models.resnet import make_resnet50_base 10 | from datasets.imagenet import ImageNet 11 | from utils import set_gpu, pick_vectors 12 | 13 | 14 | def test_on_subset(dataset, cnn, n, pred_vectors, all_label, 15 | consider_trains): 16 | top = [1, 2, 5, 10, 20] 17 | hits = torch.zeros(len(top)).cuda() 18 | tot = 0 19 | 20 | loader = DataLoader(dataset=dataset, batch_size=32, 21 | shuffle=False, num_workers=2) 22 | 23 | for batch_id, batch in enumerate(loader, 1): 24 | data, label = batch 25 | data = data.cuda() 26 | 27 | feat = cnn(data) # (batch_size, d) 28 | feat = torch.cat([feat, torch.ones(len(feat)).view(-1, 1).cuda()], dim=1) 29 | 30 | fcs = pred_vectors.t() 31 | 32 | table = torch.matmul(feat, fcs) 33 | if not consider_trains: 34 | table[:, :n] = -1e18 35 | 36 | gth_score = table[:, all_label].repeat(table.shape[1], 1).t() 37 | rks = (table >= gth_score).sum(dim=1) 38 | 39 | assert (table[:, all_label] == gth_score[:, all_label]).min() == 1 40 | 41 | for i, k in enumerate(top): 42 | hits[i] += (rks <= k).sum().item() 43 | tot += len(data) 44 | 45 | return hits, tot 46 | 47 | 48 | if __name__ == '__main__': 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('--cnn') 51 | parser.add_argument('--pred') 52 | 53 | parser.add_argument('--test-set') 54 | 55 | parser.add_argument('--output', default=None) 56 | 57 | parser.add_argument('--gpu', default='0') 58 | 59 | parser.add_argument('--keep-ratio', type=float, default=0.1) 60 | parser.add_argument('--consider-trains', action='store_true') 61 | parser.add_argument('--test-train', action='store_true') 62 | 63 | args = parser.parse_args() 64 | 65 | set_gpu(args.gpu) 66 | 67 | test_sets = json.load(open('materials/imagenet-testsets.json', 'r')) 68 | train_wnids = test_sets['train'] 69 | test_wnids = test_sets[args.test_set] 70 | 71 | print('test set: {}, {} classes, ratio={}' 72 | .format(args.test_set, len(test_wnids), args.keep_ratio)) 73 | print('consider train classifiers: {}'.format(args.consider_trains)) 74 | 75 | pred_file = torch.load(args.pred) 76 | pred_wnids = pred_file['wnids'] 77 | pred_vectors = pred_file['pred'] 78 | pred_dic = dict(zip(pred_wnids, pred_vectors)) 79 | pred_vectors = pick_vectors(pred_dic, train_wnids + test_wnids, is_tensor=True).cuda() 80 | 81 | pred_vectors = pred_vectors.cuda() 82 | 83 | n = len(train_wnids) 84 | m = len(test_wnids) 85 | 86 | cnn = make_resnet50_base() 87 | cnn.load_state_dict(torch.load(args.cnn)) 88 | cnn = cnn.cuda() 89 | cnn.eval() 90 | 91 | TEST_TRAIN = args.test_train 92 | 93 | imagenet_path = 'materials/datasets/imagenet' 94 | dataset = ImageNet(imagenet_path) 95 | dataset.set_keep_ratio(args.keep_ratio) 96 | 97 | s_hits = torch.FloatTensor([0, 0, 0, 0, 0]).cuda() # top 1 2 5 10 20 98 | s_tot = 0 99 | 100 | results = {} 101 | 102 | if TEST_TRAIN: 103 | for i, wnid in enumerate(train_wnids, 1): 104 | subset = dataset.get_subset(wnid) 105 | hits, tot = test_on_subset(subset, cnn, n, pred_vectors, i - 1, 106 | consider_trains=args.consider_trains) 107 | results[wnid] = 
(hits / tot).tolist() 108 | 109 | s_hits += hits 110 | s_tot += tot 111 | 112 | print('{}/{}, {}:'.format(i, len(train_wnids), wnid), end=' ') 113 | for i in range(len(hits)): 114 | print('{:.0f}%({:.2f}%)' 115 | .format(hits[i] / tot * 100, s_hits[i] / s_tot * 100), end=' ') 116 | print('x{}({})'.format(tot, s_tot)) 117 | else: 118 | for i, wnid in enumerate(test_wnids, 1): 119 | subset = dataset.get_subset(wnid) 120 | hits, tot = test_on_subset(subset, cnn, n, pred_vectors, n + i - 1, 121 | consider_trains=args.consider_trains) 122 | results[wnid] = (hits / tot).tolist() 123 | 124 | s_hits += hits 125 | s_tot += tot 126 | 127 | print('{}/{}, {}:'.format(i, len(test_wnids), wnid), end=' ') 128 | for i in range(len(hits)): 129 | print('{:.0f}%({:.2f}%)' 130 | .format(hits[i] / tot * 100, s_hits[i] / s_tot * 100), end=' ') 131 | print('x{}({})'.format(tot, s_tot)) 132 | 133 | print('summary:', end=' ') 134 | for s_hit in s_hits: 135 | print('{:.2f}%'.format(s_hit / s_tot * 100), end=' ') 136 | print('total {}'.format(s_tot)) 137 | 138 | if args.output is not None: 139 | json.dump(results, open(args.output, 'w')) 140 | 141 | -------------------------------------------------------------------------------- /materials/awa2-split.json: -------------------------------------------------------------------------------- 1 | {"train": ["n02071294", "n02363005", "n02110341", "n02123394", "n02106662", "n02123597", "n02445715", "n01889520", "n02129604", "n02398521", "n02128385", "n02493793", "n02503517", "n02480855", "n02403003", "n02481823", "n02342885", "n02118333", "n02355227", "n02324045", "n02114100", "n02085620", "n02441942", "n02444819", "n02410702", "n02391049", "n02510455", "n02395406", "n02129165", "n02134084", "n02106030", "n02403454", "n02430045", "n02330245", "n02065726", "n02419796", "n02132580", "n02391994", "n02508021", "n02432983"], "test": ["n02411705", "n02068974", "n02139199", "n02076196", "n02064816", "n02331046", "n02374451", "n02081571", "n02439033", "n02127482"], "train_names": ["killer+whale", "beaver", "dalmatian", "persian+cat", "german+shepherd", "siamese+cat", "skunk", "mole", "tiger", "hippopotamus", "leopard", "spider+monkey", "elephant", "gorilla", "ox", "chimpanzee", "hamster", "fox", "squirrel", "rabbit", "wolf", "chihuahua", "weasel", "otter", "buffalo", "zebra", "giant+panda", "pig", "lion", "polar+bear", "collie", "cow", "deer", "mouse", "humpback+whale", "antelope", "grizzly+bear", "rhinoceros", "raccoon", "moose"], "test_names": ["sheep", "dolphin", "bat", "seal", "blue+whale", "rat", "horse", "walrus", "giraffe", "bobcat"]} -------------------------------------------------------------------------------- /materials/glove.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class GloVe(): 5 | 6 | def __init__(self, file_path): 7 | self.dimension = None 8 | self.embedding = dict() 9 | with open(file_path, 'r') as f: 10 | for line in f.readlines(): 11 | strs = line.strip().split() 12 | word = strs[0] 13 | vector = torch.FloatTensor(list(map(float, strs[1:]))) 14 | self.embedding[word] = vector 15 | if self.dimension is None: 16 | self.dimension = len(vector) 17 | 18 | def _fix_word(self, word): 19 | terms = word.replace('_', ' ').split(' ') 20 | ret = self.zeros() 21 | cnt = 0 22 | for term in terms: 23 | v = self.embedding.get(term) 24 | if v is None: 25 | subterms = term.split('-') 26 | subterm_sum = self.zeros() 27 | subterm_cnt = 0 28 | for subterm in subterms: 29 | subv = self.embedding.get(subterm) 30 | 
if subv is not None: 31 | subterm_sum += subv 32 | subterm_cnt += 1 33 | if subterm_cnt > 0: 34 | v = subterm_sum / subterm_cnt 35 | if v is not None: 36 | ret += v 37 | cnt += 1 38 | return ret / cnt if cnt > 0 else None 39 | 40 | def __getitem__(self, words): 41 | if type(words) is str: 42 | words = [words] 43 | ret = self.zeros() 44 | cnt = 0 45 | for word in words: 46 | v = self.embedding.get(word) 47 | if v is None: 48 | v = self._fix_word(word) 49 | if v is not None: 50 | ret += v 51 | cnt += 1 52 | if cnt > 0: 53 | return ret / cnt 54 | else: 55 | return self.zeros() 56 | 57 | def zeros(self): 58 | return torch.zeros(self.dimension) 59 | 60 | -------------------------------------------------------------------------------- /materials/make_dense_graph.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--input', default='imagenet-induced-graph.json') 6 | parser.add_argument('--output', default='imagenet-dense-graph.json') 7 | args = parser.parse_args() 8 | 9 | js = json.load(open(args.input, 'r')) 10 | wnids = js['wnids'] 11 | vectors = js['vectors'] 12 | edges = js['edges'] 13 | 14 | n = len(wnids) 15 | adjs = {} 16 | for i in range(n): 17 | adjs[i] = [] 18 | for u, v in edges: 19 | adjs[u].append(v) 20 | 21 | new_edges = [] 22 | 23 | for u, wnid in enumerate(wnids): 24 | q = [u] 25 | l = 0 26 | d = {} 27 | d[u] = 0 28 | while l < len(q): 29 | x = q[l] 30 | l += 1 31 | for y in adjs[x]: 32 | if d.get(y) is None: 33 | d[y] = d[x] + 1 34 | q.append(y) 35 | for x, dis in d.items(): 36 | new_edges.append((u, x)) 37 | 38 | json.dump({'wnids': wnids, 'vectors': vectors, 'edges': new_edges}, 39 | open(args.output, 'w')) 40 | 41 | -------------------------------------------------------------------------------- /materials/make_dense_grouped_graph.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--input', default='imagenet-induced-graph.json') 6 | parser.add_argument('--output', default='imagenet-dense-grouped-graph.json') 7 | args = parser.parse_args() 8 | 9 | js = json.load(open(args.input, 'r')) 10 | wnids = js['wnids'] 11 | vectors = js['vectors'] 12 | edges = js['edges'] 13 | 14 | n = len(wnids) 15 | adjs = {} 16 | for i in range(n): 17 | adjs[i] = [] 18 | for u, v in edges: 19 | adjs[u].append(v) 20 | 21 | new_edges = [[] for i in range(99)] 22 | 23 | for u, wnid in enumerate(wnids): 24 | q = [u] 25 | l = 0 26 | d = {} 27 | d[u] = 0 28 | while l < len(q): 29 | x = q[l] 30 | l += 1 31 | for y in adjs[x]: 32 | if d.get(y) is None: 33 | d[y] = d[x] + 1 34 | q.append(y) 35 | for x, dis in d.items(): 36 | new_edges[dis].append((u, x)) 37 | 38 | while new_edges[-1] == []: 39 | new_edges.pop() 40 | 41 | json.dump({'wnids': wnids, 'vectors': vectors, 'edges_set': new_edges}, 42 | open(args.output, 'w')) 43 | 44 | -------------------------------------------------------------------------------- /materials/make_induced_graph.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | from nltk.corpus import wordnet as wn 5 | import torch 6 | 7 | from glove import GloVe 8 | 9 | 10 | def getnode(x): 11 | return wn.synset_from_pos_and_offset('n', int(x[1:])) 12 | 13 | 14 | def getwnid(u): 15 | s = str(u.offset()) 16 | return 'n' + (8 - len(s)) * '0' + s 17 | 18 
| 19 | def getedges(s): 20 | dic = {x: i for i, x in enumerate(s)} 21 | edges = [] 22 | for i, u in enumerate(s): 23 | for v in u.hypernyms(): 24 | j = dic.get(v) 25 | if j is not None: 26 | edges.append((i, j)) 27 | return edges 28 | 29 | 30 | def induce_parents(s, stop_set): 31 | q = s 32 | vis = set(s) 33 | l = 0 34 | while l < len(q): 35 | u = q[l] 36 | l += 1 37 | if u in stop_set: 38 | continue 39 | for p in u.hypernyms(): 40 | if p not in vis: 41 | vis.add(p) 42 | q.append(p) 43 | 44 | 45 | if __name__ == '__main__': 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('--input', default='imagenet-split.json') 48 | parser.add_argument('--output', default='imagenet-induced-graph.json') 49 | args = parser.parse_args() 50 | 51 | print('making graph ...') 52 | 53 | xml_wnids = json.load(open('imagenet-xml-wnids.json', 'r')) 54 | xml_nodes = list(map(getnode, xml_wnids)) 55 | xml_set = set(xml_nodes) 56 | 57 | js = json.load(open(args.input, 'r')) 58 | train_wnids = js['train'] 59 | test_wnids = js['test'] 60 | 61 | key_wnids = train_wnids + test_wnids 62 | 63 | s = list(map(getnode, key_wnids)) 64 | induce_parents(s, xml_set) 65 | 66 | s_set = set(s) 67 | for u in xml_nodes: 68 | if u not in s_set: 69 | s.append(u) 70 | 71 | wnids = list(map(getwnid, s)) 72 | edges = getedges(s) 73 | 74 | print('making glove embedding ...') 75 | 76 | glove = GloVe('glove.6B.300d.txt') 77 | vectors = [] 78 | for wnid in wnids: 79 | vectors.append(glove[getnode(wnid).lemma_names()]) 80 | vectors = torch.stack(vectors) 81 | 82 | print('dumping ...') 83 | 84 | obj = {} 85 | obj['wnids'] = wnids 86 | obj['vectors'] = vectors.tolist() 87 | obj['edges'] = edges 88 | json.dump(obj, open(args.output, 'w')) 89 | 90 | -------------------------------------------------------------------------------- /materials/process_resnet.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | p = torch.load('resnet50-raw.pth') 6 | w = p['fc.weight'].data 7 | b = p['fc.bias'].data 8 | 9 | p.pop('fc.weight') 10 | p.pop('fc.bias') 11 | torch.save(p, 'resnet50-base.pth') 12 | 13 | v = torch.cat([w, b.unsqueeze(1)], dim=1).tolist() 14 | wnids = json.load(open('imagenet-split.json', 'r'))['train'] 15 | wnids = sorted(wnids) 16 | obj = [] 17 | for i in range(len(wnids)): 18 | obj.append((wnids[i], v[i])) 19 | json.dump(obj, open('fc-weights.json', 'w')) 20 | 21 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/gcn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn.init import xavier_uniform_ 8 | 9 | from utils import normt_spm, spm_to_tensor 10 | 11 | 12 | class GraphConv(nn.Module): 13 | 14 | def __init__(self, in_channels, out_channels, dropout=False, relu=True): 15 | super().__init__() 16 | 17 | if dropout: 18 | self.dropout = nn.Dropout(p=0.5) 19 | else: 20 | self.dropout = None 21 | 22 | self.w = nn.Parameter(torch.empty(in_channels, out_channels)) 23 | self.b = nn.Parameter(torch.zeros(out_channels)) 24 | xavier_uniform_(self.w) 25 | 26 | if relu: 27 | self.relu = nn.LeakyReLU(negative_slope=0.2) 28 | 
else: 29 | self.relu = None 30 | 31 | def forward(self, inputs, adj): 32 | if self.dropout is not None: 33 | inputs = self.dropout(inputs) 34 | 35 | outputs = torch.mm(adj, torch.mm(inputs, self.w)) + self.b 36 | 37 | if self.relu is not None: 38 | outputs = self.relu(outputs) 39 | return outputs 40 | 41 | 42 | class GCN(nn.Module): 43 | 44 | def __init__(self, n, edges, in_channels, out_channels, hidden_layers): 45 | super().__init__() 46 | 47 | edges = np.array(edges) 48 | adj = sp.coo_matrix((np.ones(len(edges)), (edges[:, 0], edges[:, 1])), 49 | shape=(n, n), dtype='float32') 50 | adj = normt_spm(adj, method='in') 51 | adj = spm_to_tensor(adj) 52 | self.adj = adj.cuda() 53 | 54 | hl = hidden_layers.split(',') 55 | if hl[-1] == 'd': 56 | dropout_last = True 57 | hl = hl[:-1] 58 | else: 59 | dropout_last = False 60 | 61 | i = 0 62 | layers = [] 63 | last_c = in_channels 64 | for c in hl: 65 | if c[0] == 'd': 66 | dropout = True 67 | c = c[1:] 68 | else: 69 | dropout = False 70 | c = int(c) 71 | 72 | i += 1 73 | conv = GraphConv(last_c, c, dropout=dropout) 74 | self.add_module('conv{}'.format(i), conv) 75 | layers.append(conv) 76 | 77 | last_c = c 78 | 79 | conv = GraphConv(last_c, out_channels, relu=False, dropout=dropout_last) 80 | self.add_module('conv-last', conv) 81 | layers.append(conv) 82 | 83 | self.layers = layers 84 | 85 | def forward(self, x): 86 | for conv in self.layers: 87 | x = conv(x, self.adj) 88 | return F.normalize(x) 89 | 90 | -------------------------------------------------------------------------------- /models/gcn_dense.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn.init import xavier_uniform_ 8 | 9 | from utils import normt_spm, spm_to_tensor 10 | 11 | 12 | class GraphConv(nn.Module): 13 | 14 | def __init__(self, in_channels, out_channels, dropout=False, relu=True): 15 | super().__init__() 16 | 17 | if dropout: 18 | self.dropout = nn.Dropout(p=0.5) 19 | else: 20 | self.dropout = None 21 | 22 | self.w = nn.Parameter(torch.empty(in_channels, out_channels)) 23 | self.b = nn.Parameter(torch.zeros(out_channels)) 24 | xavier_uniform_(self.w) 25 | 26 | if relu: 27 | self.relu = nn.LeakyReLU(negative_slope=0.2) 28 | else: 29 | self.relu = None 30 | 31 | def forward(self, inputs, adj): 32 | if self.dropout is not None: 33 | inputs = self.dropout(inputs) 34 | 35 | outputs = torch.mm(adj, torch.mm(inputs, self.w)) + self.b 36 | 37 | if self.relu is not None: 38 | outputs = self.relu(outputs) 39 | return outputs 40 | 41 | 42 | class GCN_Dense(nn.Module): 43 | 44 | def __init__(self, n, edges, in_channels, out_channels, hidden_layers): 45 | super().__init__() 46 | 47 | edges = np.array(edges) 48 | adj = sp.coo_matrix((np.ones(len(edges)), (edges[:, 0], edges[:, 1])), 49 | shape=(n, n), dtype='float32') 50 | self.adj = spm_to_tensor(normt_spm(adj, method='in')).cuda() 51 | self.r_adj = spm_to_tensor(normt_spm(adj.transpose(), method='in')).cuda() 52 | 53 | hl = hidden_layers.split(',') 54 | if hl[-1] == 'd': 55 | dropout_last = True 56 | hl = hl[:-1] 57 | else: 58 | dropout_last = False 59 | 60 | i = 0 61 | layers = [] 62 | last_c = in_channels 63 | for c in hl: 64 | if c[0] == 'd': 65 | dropout = True 66 | c = c[1:] 67 | else: 68 | dropout = False 69 | c = int(c) 70 | 71 | i += 1 72 | conv = GraphConv(last_c, c, dropout=dropout) 73 | self.add_module('conv{}'.format(i), conv) 74 | 
layers.append(conv) 75 | 76 | last_c = c 77 | 78 | conv = GraphConv(last_c, out_channels, relu=False, dropout=dropout_last) 79 | self.add_module('conv-last', conv) 80 | layers.append(conv) 81 | 82 | self.layers = layers 83 | 84 | def forward(self, x): 85 | graph_side = True 86 | for conv in self.layers: 87 | if graph_side: 88 | x = conv(x, self.adj) 89 | else: 90 | x = conv(x, self.r_adj) 91 | graph_side = not graph_side 92 | return F.normalize(x) 93 | 94 | -------------------------------------------------------------------------------- /models/gcn_dense_att.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn.init import xavier_uniform_ 8 | 9 | from utils import normt_spm, spm_to_tensor 10 | 11 | 12 | class GraphConv(nn.Module): 13 | 14 | def __init__(self, in_channels, out_channels, dropout=False, relu=True): 15 | super().__init__() 16 | 17 | if dropout: 18 | self.dropout = nn.Dropout(p=0.5) 19 | else: 20 | self.dropout = None 21 | 22 | self.w = nn.Parameter(torch.empty(in_channels, out_channels)) 23 | self.b = nn.Parameter(torch.zeros(out_channels)) 24 | xavier_uniform_(self.w) 25 | 26 | if relu: 27 | self.relu = nn.LeakyReLU(negative_slope=0.2) 28 | else: 29 | self.relu = None 30 | 31 | def forward(self, inputs, adj_set, att): 32 | if self.dropout is not None: 33 | inputs = self.dropout(inputs) 34 | 35 | support = torch.mm(inputs, self.w) + self.b 36 | outputs = None 37 | for i, adj in enumerate(adj_set): 38 | y = torch.mm(adj, support) * att[i] 39 | if outputs is None: 40 | outputs = y 41 | else: 42 | outputs = outputs + y 43 | 44 | if self.relu is not None: 45 | outputs = self.relu(outputs) 46 | return outputs 47 | 48 | 49 | class GCN_Dense_Att(nn.Module): 50 | 51 | def __init__(self, n, edges_set, in_channels, out_channels, hidden_layers): 52 | super().__init__() 53 | 54 | self.n = n 55 | self.d = len(edges_set) 56 | 57 | self.a_adj_set = [] 58 | self.r_adj_set = [] 59 | 60 | for edges in edges_set: 61 | edges = np.array(edges) 62 | adj = sp.coo_matrix((np.ones(len(edges)), (edges[:, 0], edges[:, 1])), 63 | shape=(n, n), dtype='float32') 64 | a_adj = spm_to_tensor(normt_spm(adj, method='in')).cuda() 65 | r_adj = spm_to_tensor(normt_spm(adj.transpose(), method='in')).cuda() 66 | self.a_adj_set.append(a_adj) 67 | self.r_adj_set.append(r_adj) 68 | 69 | hl = hidden_layers.split(',') 70 | if hl[-1] == 'd': 71 | dropout_last = True 72 | hl = hl[:-1] 73 | else: 74 | dropout_last = False 75 | 76 | self.a_att = nn.Parameter(torch.ones(self.d)) 77 | self.r_att = nn.Parameter(torch.ones(self.d)) 78 | 79 | i = 0 80 | layers = [] 81 | last_c = in_channels 82 | for c in hl: 83 | if c[0] == 'd': 84 | dropout = True 85 | c = c[1:] 86 | else: 87 | dropout = False 88 | c = int(c) 89 | 90 | i += 1 91 | conv = GraphConv(last_c, c, dropout=dropout) 92 | self.add_module('conv{}'.format(i), conv) 93 | layers.append(conv) 94 | 95 | last_c = c 96 | 97 | conv = GraphConv(last_c, out_channels, relu=False, dropout=dropout_last) 98 | self.add_module('conv-last', conv) 99 | layers.append(conv) 100 | 101 | self.layers = layers 102 | 103 | def forward(self, x): 104 | graph_side = True 105 | for conv in self.layers: 106 | if graph_side: 107 | adj_set = self.a_adj_set 108 | att = self.a_att 109 | else: 110 | adj_set = self.r_adj_set 111 | att = self.r_att 112 | att = F.softmax(att, dim=0) 113 | x = conv(x, adj_set, att) 114 | 
graph_side = not graph_side 115 | 116 | return F.normalize(x) 117 | 118 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | __all__ = ['ResNetBase', 'make_resnet_base', 'ResNet'] 8 | 9 | 10 | ''' 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | ''' 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | """3x3 convolution with padding""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=1, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None): 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = nn.BatchNorm2d(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 67 | padding=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 70 | self.bn3 = nn.BatchNorm2d(planes * 4) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | out = self.relu(out) 94 | return out 95 | 96 | 97 | class ResNetBase(nn.Module): 98 | 99 | def __init__(self, block, layers): 100 | self.inplanes = 64 101 | super(ResNetBase, self).__init__() 102 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 103 | bias=False) 104 | self.bn1 = nn.BatchNorm2d(64) 105 | self.relu = nn.ReLU(inplace=True) 106 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 107 | self.layer1 = self._make_layer(block, 64, layers[0]) 108 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 109 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 110 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 111 | 
112 | self.out_channels = 512 * block.expansion 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. / n)) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | 122 | def _make_layer(self, block, planes, blocks, stride=1): 123 | downsample = None 124 | if stride != 1 or self.inplanes != planes * block.expansion: 125 | downsample = nn.Sequential( 126 | nn.Conv2d(self.inplanes, planes * block.expansion, 127 | kernel_size=1, stride=stride, bias=False), 128 | nn.BatchNorm2d(planes * block.expansion), 129 | ) 130 | 131 | layers = [] 132 | layers.append(block(self.inplanes, planes, stride, downsample)) 133 | self.inplanes = planes * block.expansion 134 | for i in range(1, blocks): 135 | layers.append(block(self.inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | x = self.conv1(x) 141 | x = self.bn1(x) 142 | x = self.relu(x) 143 | x = self.maxpool(x) 144 | 145 | x = self.layer1(x) 146 | x = self.layer2(x) 147 | x = self.layer3(x) 148 | x = self.layer4(x) 149 | 150 | x = x.view(x.shape[0], x.shape[1], -1).mean(dim=2) 151 | return x 152 | 153 | 154 | def make_resnet18_base(**kwargs): 155 | """Constructs a ResNet-18 model. 156 | """ 157 | model = ResNetBase(BasicBlock, [2, 2, 2, 2], **kwargs) 158 | return model 159 | 160 | 161 | def make_resnet34_base(**kwargs): 162 | """Constructs a ResNet-34 model. 163 | """ 164 | model = ResNetBase(BasicBlock, [3, 4, 6, 3], **kwargs) 165 | return model 166 | 167 | 168 | def make_resnet50_base(**kwargs): 169 | """Constructs a ResNet-50 model. 170 | """ 171 | model = ResNetBase(Bottleneck, [3, 4, 6, 3], **kwargs) 172 | return model 173 | 174 | 175 | def make_resnet101_base(**kwargs): 176 | """Constructs a ResNet-101 model. 177 | """ 178 | model = ResNetBase(Bottleneck, [3, 4, 23, 3], **kwargs) 179 | return model 180 | 181 | 182 | def make_resnet152_base(**kwargs): 183 | """Constructs a ResNet-152 model. 
184 | """ 185 | model = ResNetBase(Bottleneck, [3, 8, 36, 3], **kwargs) 186 | return model 187 | 188 | 189 | def make_resnet_base(version, pretrained=None): 190 | maker = { 191 | 'resnet18': make_resnet18_base, 192 | 'resnet34': make_resnet34_base, 193 | 'resnet50': make_resnet50_base, 194 | 'resnet101': make_resnet101_base, 195 | 'resnet152': make_resnet152_base 196 | } 197 | resnet = maker[version]() 198 | if pretrained is not None: 199 | sd = torch.load(pretrained) 200 | sd.pop('fc.weight') 201 | sd.pop('fc.bias') 202 | resnet.load_state_dict(sd) 203 | return resnet 204 | 205 | 206 | class ResNet(nn.Module): 207 | 208 | def __init__(self, version, num_classes, pretrained=None): 209 | super().__init__() 210 | self.resnet_base = make_resnet_base(version, pretrained=pretrained) 211 | self.fc = nn.Linear(self.resnet_base.out_channels, num_classes) 212 | 213 | def forward(self, x, need_features=False): 214 | x = self.resnet_base(x) 215 | feat = x 216 | x = self.fc(x) 217 | if need_features: 218 | return x, feat 219 | else: 220 | return x 221 | 222 | -------------------------------------------------------------------------------- /train_gcn_basic.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import random 4 | import os.path as osp 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from utils import ensure_path, set_gpu, l2_loss 10 | from models.gcn import GCN 11 | 12 | 13 | def save_checkpoint(name): 14 | torch.save(gcn.state_dict(), osp.join(save_path, name + '.pth')) 15 | torch.save(pred_obj, osp.join(save_path, name + '.pred')) 16 | 17 | 18 | def mask_l2_loss(a, b, mask): 19 | return l2_loss(a[mask], b[mask]) 20 | 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument('--max-epoch', type=int, default=3000) 26 | parser.add_argument('--trainval', default='10,0') 27 | parser.add_argument('--lr', type=float, default=0.001) 28 | parser.add_argument('--weight-decay', type=float, default=0.0005) 29 | parser.add_argument('--save-epoch', type=int, default=300) 30 | parser.add_argument('--save-path', default='save/gcn-basic') 31 | 32 | parser.add_argument('--gpu', default='0') 33 | 34 | parser.add_argument('--no-pred', action='store_true') 35 | args = parser.parse_args() 36 | 37 | set_gpu(args.gpu) 38 | 39 | save_path = args.save_path 40 | ensure_path(save_path) 41 | 42 | graph = json.load(open('materials/imagenet-induced-graph.json', 'r')) 43 | wnids = graph['wnids'] 44 | n = len(wnids) 45 | edges = graph['edges'] 46 | 47 | edges = edges + [(v, u) for (u, v) in edges] 48 | edges = edges + [(u, u) for u in range(n)] 49 | 50 | word_vectors = torch.tensor(graph['vectors']).cuda() 51 | word_vectors = F.normalize(word_vectors) 52 | 53 | fcfile = json.load(open('materials/fc-weights.json', 'r')) 54 | train_wnids = [x[0] for x in fcfile] 55 | fc_vectors = [x[1] for x in fcfile] 56 | assert train_wnids == wnids[:len(train_wnids)] 57 | fc_vectors = torch.tensor(fc_vectors).cuda() 58 | fc_vectors = F.normalize(fc_vectors) 59 | 60 | hidden_layers = 'd2048,d' 61 | gcn = GCN(n, edges, word_vectors.shape[1], fc_vectors.shape[1], hidden_layers).cuda() 62 | 63 | print('{} nodes, {} edges'.format(n, len(edges))) 64 | print('word vectors:', word_vectors.shape) 65 | print('fc vectors:', fc_vectors.shape) 66 | print('hidden layers:', hidden_layers) 67 | 68 | optimizer = torch.optim.Adam(gcn.parameters(), lr=args.lr, weight_decay=args.weight_decay) 69 | 70 | v_train, v_val = 
map(float, args.trainval.split(',')) 71 | n_trainval = len(fc_vectors) 72 | n_train = round(n_trainval * (v_train / (v_train + v_val))) 73 | print('num train: {}, num val: {}'.format(n_train, n_trainval - n_train)) 74 | tlist = list(range(len(fc_vectors))) 75 | random.shuffle(tlist) 76 | 77 | min_loss = 1e18 78 | 79 | trlog = {} 80 | trlog['train_loss'] = [] 81 | trlog['val_loss'] = [] 82 | trlog['min_loss'] = 0 83 | 84 | for epoch in range(1, args.max_epoch + 1): 85 | gcn.train() 86 | output_vectors = gcn(word_vectors) 87 | loss = mask_l2_loss(output_vectors, fc_vectors, tlist[:n_train]) 88 | optimizer.zero_grad() 89 | loss.backward() 90 | optimizer.step() 91 | 92 | gcn.eval() 93 | output_vectors = gcn(word_vectors) 94 | train_loss = mask_l2_loss(output_vectors, fc_vectors, tlist[:n_train]).item() 95 | if v_val > 0: 96 | val_loss = mask_l2_loss(output_vectors, fc_vectors, tlist[n_train:]).item() 97 | loss = val_loss 98 | else: 99 | val_loss = 0 100 | loss = train_loss 101 | print('epoch {}, train_loss={:.4f}, val_loss={:.4f}' 102 | .format(epoch, train_loss, val_loss)) 103 | 104 | trlog['train_loss'].append(train_loss) 105 | trlog['val_loss'].append(val_loss) 106 | trlog['min_loss'] = min_loss 107 | torch.save(trlog, osp.join(save_path, 'trlog')) 108 | 109 | if (epoch % args.save_epoch == 0): 110 | if args.no_pred: 111 | pred_obj = None 112 | else: 113 | pred_obj = { 114 | 'wnids': wnids, 115 | 'pred': output_vectors 116 | } 117 | 118 | if epoch % args.save_epoch == 0: 119 | save_checkpoint('epoch-{}'.format(epoch)) 120 | 121 | pred_obj = None 122 | 123 | -------------------------------------------------------------------------------- /train_gcn_dense.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import random 4 | import os.path as osp 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from utils import ensure_path, set_gpu, l2_loss 10 | from models.gcn_dense import GCN_Dense 11 | 12 | 13 | def save_checkpoint(name): 14 | torch.save(gcn.state_dict(), osp.join(save_path, name + '.pth')) 15 | torch.save(pred_obj, osp.join(save_path, name + '.pred')) 16 | 17 | 18 | def mask_l2_loss(a, b, mask): 19 | return l2_loss(a[mask], b[mask]) 20 | 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument('--max-epoch', type=int, default=3000) 26 | parser.add_argument('--trainval', default='10,0') 27 | parser.add_argument('--lr', type=float, default=0.001) 28 | parser.add_argument('--weight-decay', type=float, default=0.0005) 29 | parser.add_argument('--save-epoch', type=int, default=300) 30 | parser.add_argument('--save-path', default='save/gcn-dense') 31 | 32 | parser.add_argument('--gpu', default='0') 33 | 34 | parser.add_argument('--no-pred', action='store_true') 35 | args = parser.parse_args() 36 | 37 | set_gpu(args.gpu) 38 | 39 | save_path = args.save_path 40 | ensure_path(save_path) 41 | 42 | graph = json.load(open('materials/imagenet-dense-graph.json', 'r')) 43 | wnids = graph['wnids'] 44 | n = len(wnids) 45 | edges = graph['edges'] 46 | 47 | word_vectors = torch.tensor(graph['vectors']).cuda() 48 | word_vectors = F.normalize(word_vectors) 49 | 50 | fcfile = json.load(open('materials/fc-weights.json', 'r')) 51 | train_wnids = [x[0] for x in fcfile] 52 | fc_vectors = [x[1] for x in fcfile] 53 | assert train_wnids == wnids[:len(train_wnids)] 54 | fc_vectors = torch.tensor(fc_vectors).cuda() 55 | fc_vectors = F.normalize(fc_vectors) 56 | 57 | hidden_layers = 
'd2048,d' 58 | gcn = GCN_Dense(n, edges, word_vectors.shape[1], fc_vectors.shape[1], hidden_layers).cuda() 59 | 60 | print('{} nodes, {} edges'.format(n, len(edges))) 61 | print('word vectors:', word_vectors.shape) 62 | print('fc vectors:', fc_vectors.shape) 63 | print('hidden layers:', hidden_layers) 64 | 65 | optimizer = torch.optim.Adam(gcn.parameters(), lr=args.lr, weight_decay=args.weight_decay) 66 | 67 | v_train, v_val = map(float, args.trainval.split(',')) 68 | n_trainval = len(fc_vectors) 69 | n_train = round(n_trainval * (v_train / (v_train + v_val))) 70 | print('num train: {}, num val: {}'.format(n_train, n_trainval - n_train)) 71 | tlist = list(range(len(fc_vectors))) 72 | random.shuffle(tlist) 73 | 74 | min_loss = 1e18 75 | 76 | trlog = {} 77 | trlog['train_loss'] = [] 78 | trlog['val_loss'] = [] 79 | trlog['min_loss'] = 0 80 | 81 | for epoch in range(1, args.max_epoch + 1): 82 | gcn.train() 83 | output_vectors = gcn(word_vectors) 84 | loss = mask_l2_loss(output_vectors, fc_vectors, tlist[:n_train]) 85 | optimizer.zero_grad() 86 | loss.backward() 87 | optimizer.step() 88 | 89 | gcn.eval() 90 | output_vectors = gcn(word_vectors) 91 | train_loss = mask_l2_loss(output_vectors, fc_vectors, tlist[:n_train]).item() 92 | if v_val > 0: 93 | val_loss = mask_l2_loss(output_vectors, fc_vectors, tlist[n_train:]).item() 94 | loss = val_loss 95 | else: 96 | val_loss = 0 97 | loss = train_loss 98 | print('epoch {}, train_loss={:.4f}, val_loss={:.4f}' 99 | .format(epoch, train_loss, val_loss)) 100 | 101 | trlog['train_loss'].append(train_loss) 102 | trlog['val_loss'].append(val_loss) 103 | trlog['min_loss'] = min_loss 104 | torch.save(trlog, osp.join(save_path, 'trlog')) 105 | 106 | if (epoch % args.save_epoch == 0): 107 | if args.no_pred: 108 | pred_obj = None 109 | else: 110 | pred_obj = { 111 | 'wnids': wnids, 112 | 'pred': output_vectors 113 | } 114 | 115 | if epoch % args.save_epoch == 0: 116 | save_checkpoint('epoch-{}'.format(epoch)) 117 | 118 | pred_obj = None 119 | 120 | -------------------------------------------------------------------------------- /train_gcn_dense_att.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import random 4 | import os.path as osp 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from utils import ensure_path, set_gpu, l2_loss 10 | from models.gcn_dense_att import GCN_Dense_Att 11 | 12 | 13 | def save_checkpoint(name): 14 | torch.save(gcn.state_dict(), osp.join(save_path, name + '.pth')) 15 | torch.save(pred_obj, osp.join(save_path, name + '.pred')) 16 | 17 | 18 | def mask_l2_loss(a, b, mask): 19 | return l2_loss(a[mask], b[mask]) 20 | 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument('--max-epoch', type=int, default=3000) 26 | parser.add_argument('--trainval', default='10,0') 27 | parser.add_argument('--lr', type=float, default=0.001) 28 | parser.add_argument('--weight-decay', type=float, default=0.0005) 29 | parser.add_argument('--save-epoch', type=int, default=300) 30 | parser.add_argument('--save-path', default='save/gcn-dense-att') 31 | 32 | parser.add_argument('--gpu', default='0') 33 | 34 | parser.add_argument('--no-pred', action='store_true') 35 | args = parser.parse_args() 36 | 37 | set_gpu(args.gpu) 38 | 39 | save_path = args.save_path 40 | ensure_path(save_path) 41 | 42 | graph = json.load(open('materials/imagenet-dense-grouped-graph.json', 'r')) 43 | wnids = graph['wnids'] 44 | n = len(wnids) 45 | 46 
| edges_set = graph['edges_set'] 47 | print('edges_set', [len(l) for l in edges_set]) 48 | 49 | lim = 4 50 | for i in range(lim + 1, len(edges_set)): 51 | edges_set[lim].extend(edges_set[i]) 52 | edges_set = edges_set[:lim + 1] 53 | print('edges_set', [len(l) for l in edges_set]) 54 | 55 | word_vectors = torch.tensor(graph['vectors']).cuda() 56 | word_vectors = F.normalize(word_vectors) 57 | 58 | fcfile = json.load(open('materials/fc-weights.json', 'r')) 59 | train_wnids = [x[0] for x in fcfile] 60 | fc_vectors = [x[1] for x in fcfile] 61 | assert train_wnids == wnids[:len(train_wnids)] 62 | fc_vectors = torch.tensor(fc_vectors).cuda() 63 | fc_vectors = F.normalize(fc_vectors) 64 | 65 | hidden_layers = 'd2048,d' 66 | gcn = GCN_Dense_Att(n, edges_set, 67 | word_vectors.shape[1], fc_vectors.shape[1], hidden_layers).cuda() 68 | 69 | print('word vectors:', word_vectors.shape) 70 | print('fc vectors:', fc_vectors.shape) 71 | print('hidden layers:', hidden_layers) 72 | 73 | optimizer = torch.optim.Adam(gcn.parameters(), lr=args.lr, weight_decay=args.weight_decay) 74 | 75 | v_train, v_val = map(float, args.trainval.split(',')) 76 | n_trainval = len(fc_vectors) 77 | n_train = round(n_trainval * (v_train / (v_train + v_val))) 78 | print('num train: {}, num val: {}'.format(n_train, n_trainval - n_train)) 79 | tlist = list(range(len(fc_vectors))) 80 | random.shuffle(tlist) 81 | 82 | min_loss = 1e18 83 | 84 | trlog = {} 85 | trlog['train_loss'] = [] 86 | trlog['val_loss'] = [] 87 | trlog['min_loss'] = 0 88 | 89 | for epoch in range(1, args.max_epoch + 1): 90 | gcn.train() 91 | output_vectors = gcn(word_vectors) 92 | loss = mask_l2_loss(output_vectors, fc_vectors, tlist[:n_train]) 93 | optimizer.zero_grad() 94 | loss.backward() 95 | optimizer.step() 96 | 97 | gcn.eval() 98 | output_vectors = gcn(word_vectors) 99 | train_loss = mask_l2_loss(output_vectors, fc_vectors, tlist[:n_train]).item() 100 | if v_val > 0: 101 | val_loss = mask_l2_loss(output_vectors, fc_vectors, tlist[n_train:]).item() 102 | loss = val_loss 103 | else: 104 | val_loss = 0 105 | loss = train_loss 106 | print('epoch {}, train_loss={:.4f}, val_loss={:.4f}' 107 | .format(epoch, train_loss, val_loss)) 108 | 109 | trlog['train_loss'].append(train_loss) 110 | trlog['val_loss'].append(val_loss) 111 | trlog['min_loss'] = min_loss 112 | torch.save(trlog, osp.join(save_path, 'trlog')) 113 | 114 | if (epoch % args.save_epoch == 0): 115 | if args.no_pred: 116 | pred_obj = None 117 | else: 118 | pred_obj = { 119 | 'wnids': wnids, 120 | 'pred': output_vectors 121 | } 122 | 123 | if epoch % args.save_epoch == 0: 124 | save_checkpoint('epoch-{}'.format(epoch)) 125 | 126 | pred_obj = None 127 | 128 | -------------------------------------------------------------------------------- /train_resnet_fit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import os.path as osp 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torchvision.datasets as datasets 9 | import torchvision.transforms as transforms 10 | 11 | from utils import set_gpu, ensure_path 12 | from models.resnet import ResNet 13 | from datasets.image_folder import ImageFolder 14 | 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--pred') 19 | parser.add_argument('--train-dir') 20 | parser.add_argument('--save-path', default='save/resnet-fit') 21 | parser.add_argument('--gpu', default='0') 22 | args = parser.parse_args() 23 | 24 | 
set_gpu(args.gpu) 25 | save_path = args.save_path 26 | ensure_path(save_path) 27 | 28 | pred = torch.load(args.pred) 29 | train_wnids = sorted(os.listdir(args.train_dir)) 30 | 31 | train_dir = args.train_dir 32 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 33 | std=[0.229, 0.224, 0.225]) 34 | train_dataset = datasets.ImageFolder(train_dir, transforms.Compose([ 35 | transforms.RandomResizedCrop(224), 36 | transforms.RandomHorizontalFlip(), 37 | transforms.ToTensor(), 38 | normalize])) 39 | loader = torch.utils.data.DataLoader( 40 | train_dataset, batch_size=64, shuffle=True, 41 | num_workers=4, pin_memory=True, sampler=None) 42 | 43 | assert pred['wnids'][:1000] == train_wnids 44 | 45 | model = ResNet('resnet50', 1000) 46 | sd = model.resnet_base.state_dict() 47 | sd.update(torch.load('materials/resnet50-base.pth')) 48 | model.resnet_base.load_state_dict(sd) 49 | 50 | fcw = pred['pred'][:1000].cpu() 51 | model.fc.weight = nn.Parameter(fcw[:, :-1]) 52 | model.fc.bias = nn.Parameter(fcw[:, -1]) 53 | 54 | model = model.cuda() 55 | model.train() 56 | 57 | optimizer = torch.optim.SGD(model.resnet_base.parameters(), lr=0.0001, momentum=0.9) 58 | loss_fn = nn.CrossEntropyLoss().cuda() 59 | 60 | keep_ratio = 0.9975 61 | trlog = {} 62 | trlog['loss'] = [] 63 | trlog['acc'] = [] 64 | 65 | for epoch in range(1, 9999): 66 | 67 | ave_loss = None 68 | ave_acc = None 69 | 70 | for i, (data, label) in enumerate(loader, 1): 71 | data = data.cuda() 72 | label = label.cuda() 73 | 74 | logits = model(data) 75 | loss = loss_fn(logits, label) 76 | 77 | _, pred = torch.max(logits, dim=1) 78 | acc = torch.eq(pred, label).type(torch.FloatTensor).mean().item() 79 | 80 | if i == 1: 81 | ave_loss = loss.item() 82 | ave_acc = acc 83 | else: 84 | ave_loss = ave_loss * keep_ratio + loss.item() * (1 - keep_ratio) 85 | ave_acc = ave_acc * keep_ratio + acc * (1 - keep_ratio) 86 | 87 | print('epoch {}, {}/{}, loss={:.4f} ({:.4f}), acc={:.4f} ({:.4f})' 88 | .format(epoch, i, len(loader), loss.item(), ave_loss, acc, ave_acc)) 89 | 90 | optimizer.zero_grad() 91 | loss.backward() 92 | optimizer.step() 93 | 94 | trlog['loss'].append(ave_loss) 95 | trlog['acc'].append(ave_acc) 96 | 97 | torch.save(trlog, osp.join(save_path, 'trlog')) 98 | 99 | torch.save(model.resnet_base.state_dict(), 100 | osp.join(save_path, 'epoch-{}.pth'.format(epoch))) 101 | 102 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | 5 | import numpy as np 6 | import scipy.sparse as sp 7 | import torch 8 | 9 | 10 | def ensure_path(path): 11 | if osp.exists(path): 12 | if input('{} exists, remove? 
([y]/n)'.format(path)) != 'n':
13 |             shutil.rmtree(path)
14 |             os.mkdir(path)
15 |     else:
16 |         os.mkdir(path)
17 | 
18 | 
19 | def set_gpu(gpu):
20 |     os.environ['CUDA_VISIBLE_DEVICES'] = gpu
21 |     print('using gpu {}'.format(gpu))
22 | 
23 | 
24 | def pick_vectors(dic, wnids, is_tensor=False):
25 |     o = next(iter(dic.values()))
26 |     dim = len(o)
27 |     ret = []
28 |     for wnid in wnids:
29 |         v = dic.get(wnid)
30 |         if v is None:
31 |             if not is_tensor:
32 |                 v = [0] * dim
33 |             else:
34 |                 v = torch.zeros(dim)
35 |         ret.append(v)
36 |     if not is_tensor:
37 |         return torch.FloatTensor(ret)
38 |     else:
39 |         return torch.stack(ret)
40 | 
41 | 
42 | def l2_loss(a, b):
43 |     return ((a - b)**2).sum() / (len(a) * 2)
44 | 
45 | 
46 | def normt_spm(mx, method='in'):
47 |     if method == 'in':
48 |         mx = mx.transpose()
49 |         rowsum = np.array(mx.sum(1))
50 |         r_inv = np.power(rowsum, -1).flatten()
51 |         r_inv[np.isinf(r_inv)] = 0.
52 |         r_mat_inv = sp.diags(r_inv)
53 |         mx = r_mat_inv.dot(mx)
54 |         return mx
55 | 
56 |     if method == 'sym':
57 |         rowsum = np.array(mx.sum(1))
58 |         r_inv = np.power(rowsum, -0.5).flatten()
59 |         r_inv[np.isinf(r_inv)] = 0.
60 |         r_mat_inv = sp.diags(r_inv)
61 |         mx = mx.dot(r_mat_inv).transpose().dot(r_mat_inv)
62 |         return mx
63 | 
64 | 
65 | def spm_to_tensor(sparse_mx):
66 |     sparse_mx = sparse_mx.tocoo().astype(np.float32)
67 |     indices = torch.from_numpy(np.vstack(
68 |         (sparse_mx.row, sparse_mx.col))).long()
69 |     values = torch.from_numpy(sparse_mx.data)
70 |     shape = torch.Size(sparse_mx.shape)
71 |     return torch.sparse.FloatTensor(indices, values, shape)
72 | 
--------------------------------------------------------------------------------
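For reference, a minimal sketch (not part of the repo) of how `normt_spm` and `spm_to_tensor` are combined throughout the models — `models/gcn.py` builds its adjacency exactly this way; the toy edge list below is illustrative:

```python
import numpy as np
import scipy.sparse as sp

from utils import normt_spm, spm_to_tensor

# Toy graph: 3 nodes, two directed edges, plus self-loops as in training.
n = 3
edges = np.array([(0, 1), (1, 2), (0, 0), (1, 1), (2, 2)])
adj = sp.coo_matrix((np.ones(len(edges)), (edges[:, 0], edges[:, 1])),
                    shape=(n, n), dtype='float32')
# 'in' normalization transposes first, i.e. rows are scaled by in-degree.
adj = spm_to_tensor(normt_spm(adj, method='in'))  # sparse torch tensor
print(adj.to_dense())
```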