├── LICENSE ├── README.md ├── bpe_simple_vocab_16e6.txt.gz ├── data.py ├── data ├── ini.txt ├── lab_list.txt ├── mapped_edges.txt ├── node_f.npy └── train_text.txt ├── data_graph.py ├── main_test.py ├── main_test_amazon.py ├── main_train.py ├── main_train_amazon.py ├── meta_net ├── init.txt ├── main_cog2p2_amazon.py ├── main_cog2p2_cora.py ├── model_cocoop.py ├── task_amazon.py └── task_cora.py ├── model.py ├── model_g_coop.py ├── multitask.py ├── multitask_amazon.py ├── requirements.txt ├── res └── cora │ └── init.txt ├── simple_tokenizer.py └── zero-shot ├── datahelper.py ├── zero-shot-amazon.py └── zero-shot-cora.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 123456-abc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prompt Tuning on Graph-augmented Low-resource Text Classification 2 | We provide the implementation of the G2P2 and G2P2* models, which is the source code for the TKDE journal paper 3 | 4 | "Prompt Tuning on Graph-augmented Low-resource Text Classification", available at https://ieeexplore.ieee.org/abstract/document/10633805. 5 | 6 | The repository is organised as follows: 7 | - dataset/: the directory of data sets. Currently it only contains the Cora dataset; if you want the three processed Amazon datasets, download them from https://drive.google.com/drive/folders/1IzuYNIYDxr63GteBKeva-8KnAIhvqjMZ?usp=sharing and put them under this directory. 8 | - res/: the directory of saved models. 9 | - meta_net/: the directory of our G2P2* model. 10 | - task_cora.py, task_amazon.py: data preprocessing for the Cora and Amazon datasets. 11 | - model_cocoop.py: the model of G2P2*. 12 | - main_cog2p2_cora.py, main_cog2p2_amazon.py: tuning and testing entrance for Cora, tuning and testing entrance for the Amazon datasets. 13 | - bpe_simple_vocab_16e6.txt.gz: vocabulary for simple tokenization. 14 | - data.py, data_graph.py: data loading utilities. 15 | - main_test.py, main_test_amazon.py: testing entrance for Cora, testing entrance for the Amazon datasets. 16 | - main_train.py, main_train_amazon.py: pre-training entrance for Cora, pre-training entrance for the Amazon datasets. 17 | - model.py, model_g_coop.py: model for pre-training, model for prompt tuning.
18 | - multitask.py, multitask_amazon.py: task generators for the Cora and Amazon datasets. 19 | - requirements.txt: the required packages. 20 | - simple_tokenizer.py: a simple tokenizer. 21 | 22 | 23 | # For pre-training: 24 | On the Cora dataset, 25 | 26 | python main_train.py 27 | 28 | On the Amazon datasets, it should be: 29 | 30 | python main_train_amazon.py 31 | 32 | # For testing: 33 | (1) For G2P2, 34 | On the Cora dataset, 35 | 36 | python main_test.py 37 | 38 | On the Amazon datasets, it should be: 39 | 40 | python main_test_amazon.py 41 | 42 | (2) For G2P2*, 43 | 44 | cd meta_net 45 | 46 | On the Cora dataset, 47 | 48 | python main_cog2p2_cora.py 49 | 50 | On the Amazon datasets, it should be: 51 | 52 | python main_cog2p2_amazon.py 53 | 54 | 55 | -------------------------------------------------------------------------------- /bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenZhihao666/G2P2-conditional/b93bc1b6c36c52c05cf1bb7795ba9cf64451615a/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | 5 | class DataHelper(Dataset): 6 | def __init__(self, edge_index, args, directed=False, transform=None): 7 | # self.num_nodes = len(node_list) 8 | self.transform = transform 9 | 10 | self.degrees = dict() 11 | self.node_set = set() 12 | self.neighs = dict() 13 | self.args = args 14 | 15 | idx, degree = np.unique(edge_index, return_counts=True) 16 | for i in range(idx.shape[0]): 17 | self.degrees[idx[i]] = degree[i].item() 18 | 19 | self.node_dim = idx.shape[0] 20 | print('length of dataset', self.node_dim) 21 | 22 | train_edge_index = edge_index 23 | self.final_edge_index = train_edge_index.T 24 | 25 | for i in range(self.final_edge_index.shape[0]): 26 | s_node = self.final_edge_index[i][0].item() 27 | t_node = self.final_edge_index[i][1].item() 28 | 29 | if s_node not in self.neighs: 30 | self.neighs[s_node] = [] 31 | if t_node not in self.neighs: 32 | self.neighs[t_node] = [] 33 | 34 | self.neighs[s_node].append(t_node) 35 | if not directed: 36 | self.neighs[t_node].append(s_node) 37 | 38 | # self.neighs = sorted(self.neighs) 39 | self.idx = idx 40 | 41 | 42 | def __len__(self): 43 | return self.node_dim 44 | 45 | def __getitem__(self, idx): 46 | 47 | s_n = self.idx[idx].item() 48 | t_n = [np.random.choice(self.neighs[s_n], replace=True).item() for _ in range(self.args.neigh_num)] 49 | t_n = np.array(t_n) 50 | 51 | sample = { 52 | 's_n': s_n, # e.g., 5424 53 | 't_n': t_n, # e.g., 5427 54 | # 'neg_n': neg_n 55 | } 56 | 57 | if self.transform: 58 | sample = self.transform(sample) 59 | 60 | return sample 61 | -------------------------------------------------------------------------------- /data/ini.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/lab_list.txt: -------------------------------------------------------------------------------- 1 | artificial intelligence, agents artificial intelligence, data mining artificial intelligence, expert systems artificial intelligence, games and search artificial intelligence, knowledge representation artificial intelligence, machine learning, case-based artificial intelligence, machine
learning, genetic algorithms artificial intelligence, machine learning, neural networks artificial intelligence, machine learning, probabilistic methods artificial intelligence, machine learning, reinforcement learning artificial intelligence, machine learning, rule learning artificial intelligence, machine learning, theory artificial intelligence, nlp artificial intelligence, planning artificial intelligence, robotics artificial intelligence, speech artificial intelligence, theorem proving artificial intelligence, vision and pattern recognition data structures algorithms and theory, computational complexity data structures algorithms and theory, computational geometry data structures algorithms and theory, formal languages data structures algorithms and theory, hashing data structures algorithms and theory, logic data structures algorithms and theory, parallel data structures algorithms and theory, quantum computing data structures algorithms and theory, randomized data structures algorithms and theory, sorting databases, concurrency databases, deductive databases, object oriented databases, performance databases, query evaluation databases, relational databases, temporal encryption and compression, compression encryption and compression, encryption encryption and compression, security hardware and architecture, distributed architectures hardware and architecture, high performance computing hardware and architecture, input output and storage hardware and architecture, logic design hardware and architecture, memory structures hardware and architecture, microprogramming hardware and architecture, vlsi human computer interaction, cooperative human computer interaction, graphics and virtual reality human computer interaction, interface design human computer interaction, multimedia human computer interaction, wearable computers information retrieval, digital library information retrieval, extraction information retrieval, filtering information retrieval, retrieval nan networking, internet networking, protocols networking, routing networking, wireless operating systems, distributed operating systems, fault tolerance operating systems, memory management operating systems, realtime programming, compiler design programming, debugging programming, functional programming, garbage collection programming, java programming, logic programming, object oriented programming, semantics programming, software development -------------------------------------------------------------------------------- /data/node_f.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenZhihao666/G2P2-conditional/b93bc1b6c36c52c05cf1bb7795ba9cf64451615a/data/node_f.npy -------------------------------------------------------------------------------- /data_graph.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | 5 | class DataHelper(Dataset): 6 | def __init__(self, edge_index, args, the_nodes, directed=False, transform=None): 7 | # self.num_nodes = len(node_list) 8 | self.transform = transform 9 | self.degrees = dict() 10 | self.node_set = set() 11 | self.neighs = dict() 12 | self.args = args 13 | 14 | idx, degree = np.unique(edge_index, return_counts=True) 15 | # for (a, b) in (idx, degree): 16 | for i in range(idx.shape[0]): 17 | self.degrees[idx[i]] = degree[i].item() 18 | 19 | self.node_dim = idx.shape[0] 20 | 21 | 
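# NOTE: the block below builds an undirected adjacency list from edge_index; each edge is recorded in both directions (unless directed=True), and __getitem__ then samples args.neigh_num neighbours per centre node, with replacement only when a node has fewer neighbours than requested.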
train_edge_index = edge_index # .T#[:int(0.8 * edge_nums)] 22 | 23 | self.final_edge_index = train_edge_index.T 24 | 25 | for i in range(self.final_edge_index.shape[0]): 26 | s_node = self.final_edge_index[i][0].item() 27 | t_node = self.final_edge_index[i][1].item() 28 | 29 | if s_node not in self.neighs: 30 | self.neighs[s_node] = [] 31 | if t_node not in self.neighs: 32 | self.neighs[t_node] = [] 33 | 34 | self.neighs[s_node].append(t_node) 35 | if not directed: 36 | self.neighs[t_node].append(s_node) 37 | 38 | self.idx = idx 39 | self.the_nodes = the_nodes 40 | 41 | def __len__(self): 42 | return len(self.the_nodes) 43 | 44 | def __getitem__(self, idx): 45 | 46 | s_n = self.the_nodes[idx]#.item() 47 | if len(self.neighs[s_n]) > self.args.neigh_num: 48 | t_n = np.random.choice(self.neighs[s_n], self.args.neigh_num, replace=False) 49 | else: 50 | t_n = np.random.choice(self.neighs[s_n], self.args.neigh_num, replace=True) 51 | # t_n = np.array(t_n) 52 | 53 | sample = { 54 | 's_n': s_n, # e.g., 5424 55 | 't_n': t_n, # e.g., 5427 56 | # 'neg_n': neg_n 57 | } 58 | 59 | if self.transform: 60 | sample = self.transform(sample) 61 | 62 | return sample 63 | -------------------------------------------------------------------------------- /main_test.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import argparse 4 | import torch 5 | from random import sample 6 | import random 7 | import math 8 | import time 9 | from model import CLIP, tokenize 10 | from torch import nn, optim 11 | from sklearn import preprocessing 12 | from sklearn.metrics import accuracy_score, f1_score 13 | # from multitask_2 import multitask_data_generator 14 | from multitask import multitask_data_generator 15 | from model_g_coop import CoOp 16 | import json 17 | from data_graph import DataHelper 18 | from torch.utils.data import DataLoader 19 | 20 | 21 | 22 | 23 | def setup_seed(seed): 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | torch.backends.cudnn.deterministic = True 30 | 31 | 32 | def main(args): 33 | setup_seed(seed) 34 | 35 | clip_model = CLIP(args) 36 | clip_model.load_state_dict(torch.load('./res/{}/node_ttgt_8&12_0.1.pkl'.format(data_name), map_location=device)) 37 | 38 | task_list, train_idx, val_idx, test_idx = multitask_data_generator(lab_list, labeled_ids, labels, args.k_spt, 39 | args.k_val, args.k_qry, args.n_way) 40 | all_acc = [] 41 | f1_list = [] 42 | for j in range(len(task_list)): 43 | 44 | train_idx_ts = torch.from_numpy(np.array(train_idx[j])).to(device) 45 | val_idx_ts = torch.from_numpy(np.array(val_idx[j])).to(device) 46 | test_idx_ts = torch.from_numpy(np.array(test_idx[j])).to(device) 47 | 48 | train_truth = np.array(lab_list)[np.array(train_idx[j])] 49 | val_truth = np.array(lab_list)[np.array(val_idx[j])] 50 | test_truth = np.array(lab_list)[np.array(test_idx[j])] 51 | 52 | task_lables_arr = np.array(labels)[task_list[j]] 53 | task_labels_dict = dict() 54 | for i in range(task_lables_arr.shape[0]): 55 | task_labels_dict[task_lables_arr[i]] = i 56 | 57 | train_truth_ts = [task_labels_dict[train_truth[i]] for i in range(len(train_truth))] 58 | train_truth_ts = torch.from_numpy(np.array(train_truth_ts)).to(device) 59 | 60 | val_truth_ts = [task_labels_dict[val_truth[i]] for i in range(len(val_truth))] 61 | val_truth_ts = torch.from_numpy(np.array(val_truth_ts)).to(device) 62 | 63 | test_truth_ts = 
[task_labels_dict[test_truth[i]] for i in range(len(test_truth))] 64 | test_truth_ts = torch.from_numpy(np.array(test_truth_ts)).to(device) 65 | 66 | task_lables = task_lables_arr.tolist() 67 | Data = DataHelper(arr_edge_index, args, train_idx[j]) 68 | loader = DataLoader(Data, batch_size=args.batch_size, shuffle=False, num_workers=0) 69 | for i_batch, sample_batched in enumerate(loader): 70 | s_n = sample_batched['s_n'].numpy() 71 | t_n = sample_batched['t_n'].numpy() 72 | s_n = s_n.reshape(args.num_labels, args.k_spt) 73 | t_n = t_n.reshape(args.num_labels, args.k_spt * args.neigh_num) 74 | temp = [] 75 | for i in range(args.num_labels): 76 | temp.append(np.concatenate((s_n[i], t_n[i]))) 77 | g_texts = [] 78 | for i in range(len(temp)): 79 | g_text = [tit_list[a] for a in temp[i]] 80 | g_texts.append(g_text) 81 | 82 | model = CoOp(args, task_lables, clip_model, g_texts, device) 83 | 84 | best_val = 0 85 | patience = 10 86 | counter = 0 87 | 88 | for epoch in range(1, args.ft_epoch + 1): 89 | # print('----epoch:' + str(epoch)) 90 | model.train() 91 | train_logits = model.forward(train_idx_ts, node_f, edge_index, train_truth_ts) 92 | 93 | model.eval() 94 | with torch.no_grad(): 95 | res = model.forward(val_idx_ts, node_f, edge_index, val_truth_ts, training=False) 96 | val_acc = accuracy_score(val_truth_ts.cpu(), res.argmax(dim=1).cpu()) 97 | if val_acc <= best_val: 98 | counter += 1 99 | if counter >= patience: 100 | break 101 | else: 102 | best_val = val_acc 103 | torch.save(model, './res/{}/g_coop.pkl'.format(data_name)) 104 | counter = 0 105 | # print('{}th_task_best_val'.format(j), round(best_val, 4)) 106 | 107 | best_model = torch.load('./res/{}/g_coop.pkl'.format(data_name)) 108 | best_model.eval() 109 | with torch.no_grad(): 110 | res = best_model.forward(test_idx_ts, node_f, edge_index, test_truth_ts, training=False) 111 | test_acc = accuracy_score(test_truth_ts.cpu(), res.argmax(dim=1).cpu()) 112 | all_acc.append(test_acc) 113 | f1 = f1_score(test_truth_ts.cpu(), res.argmax(dim=1).cpu(), average='macro') 114 | f1_list.append(f1) 115 | 116 | ans = round(np.mean(all_acc).item(), 4) 117 | print('acc', ans) 118 | 119 | ans = round(np.mean(f1_list).item(), 4) 120 | print('macro f1', ans) 121 | 122 | 123 | if __name__ == '__main__': 124 | parser = argparse.ArgumentParser() 125 | 126 | parser.add_argument('--aggregation_times', type=int, default=2, help='Aggregation times') 127 | parser.add_argument('--ft_epoch', type=int, default=50, help='fine-tune epoch') 128 | parser.add_argument('--lr', type=float, default=2e-5) 129 | 130 | parser.add_argument('--batch_size', type=int, default=64) 131 | parser.add_argument('--gnn_input', type=int, default=128) 132 | parser.add_argument('--gnn_hid', type=int, default=128) 133 | parser.add_argument('--gnn_output', type=int, default=128) 134 | 135 | parser.add_argument('--edge_coef', type=float, default=0.1) 136 | parser.add_argument('--neigh_num', type=int, default=3) 137 | 138 | parser.add_argument('--num_labels', type=int, default=5) 139 | parser.add_argument('--k_spt', type=int, default=5) 140 | parser.add_argument('--k_val', type=int, default=5) 141 | parser.add_argument('--k_qry', type=int, default=50) 142 | parser.add_argument('--n_way', type=int, default=5) 143 | 144 | parser.add_argument('--context_length', type=int, default=128) 145 | parser.add_argument('--coop_n_ctx', type=int, default=4) 146 | parser.add_argument('--prompt_lr', type=float, default=0.01) 147 | 148 | parser.add_argument('--position', type=str, default='end') 149 |
parser.add_argument('--class_specific', type=bool, default=False) 150 | parser.add_argument('--ctx_init', type=bool, default=True) 151 | 152 | parser.add_argument('--embed_dim', type=int, default=128) 153 | parser.add_argument('--transformer_heads', type=int, default=8) 154 | parser.add_argument('--transformer_layers', type=int, default=12) 155 | parser.add_argument('--transformer_width', type=int, default=512) 156 | parser.add_argument('--vocab_size', type=int, default=49408) 157 | parser.add_argument('--gpu', type=int, default=0) 158 | 159 | args = parser.parse_args() 160 | 161 | data_name = 'cora' 162 | device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu") 163 | print('device:', device) 164 | # device = torch.device("cpu") 165 | FType = torch.FloatTensor 166 | LType = torch.LongTensor 167 | 168 | num_nodes = 0 169 | tit_list = [] 170 | lab_list = [] 171 | with open('./data/train_text.txt', 'r') as f: 172 | lines = f.readlines() 173 | for line in lines: 174 | line = line.strip().split('\t') 175 | tit_list.append(line[2]) 176 | lab_list.append(line[3]) 177 | num_nodes += 1 178 | 179 | print('num_nodes', num_nodes) 180 | 181 | labeled_ids = [] 182 | for i in range(len(lab_list)): 183 | if lab_list[i] != 'nan': 184 | labeled_ids.append(i) 185 | 186 | print('{} nodes having labels'.format(len(labeled_ids))) 187 | 188 | raw_edge_index = [[], []] 189 | with open('./data/mapped_edges.txt', 'r') as f: 190 | lines = f.readlines() 191 | for line in lines: 192 | line = line.strip().split() 193 | raw_edge_index[0].append(int(line[0])) 194 | raw_edge_index[1].append(int(line[1])) 195 | 196 | edge_index = [raw_edge_index[0] + raw_edge_index[1], raw_edge_index[1] + raw_edge_index[0]] 197 | arr_edge_index = np.array(edge_index) 198 | edge_index = np.array(edge_index) 199 | edge_index = torch.from_numpy(edge_index).to(device) 200 | 201 | node_f = np.load('./data/node_f.npy') 202 | node_f = preprocessing.StandardScaler().fit_transform(node_f) 203 | node_f = torch.from_numpy(node_f).to(device) 204 | 205 | # label_texts = [] 206 | with open('./data/lab_list.txt', 'r') as f: 207 | line = f.readline().strip().split('\t') 208 | label_texts = line 209 | 210 | labels = [] 211 | for i in label_texts: 212 | if i != 'nan': 213 | labels.append(i) 214 | 215 | start = time.perf_counter() 216 | all_acc_list = [] 217 | all_macf1_list = [] 218 | 219 | seed = 1 220 | print('seed', seed) 221 | main(args) 222 | end = time.perf_counter() 223 | print("time consuming {:.2f}".format(end - start)) 224 | -------------------------------------------------------------------------------- /main_test_amazon.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import torch 4 | from random import sample 5 | import random 6 | import math 7 | import time 8 | from model import CLIP, tokenize 9 | from torch import nn, optim 10 | from sklearn import preprocessing 11 | from sklearn.metrics import accuracy_score, f1_score 12 | from multitask_amazon import multitask_data_generator 13 | from model_g_coop import CoOp 14 | import json 15 | from data_graph import DataHelper 16 | from torch.utils.data import DataLoader 17 | 18 | 19 | 20 | 21 | def setup_seed(seed): 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | random.seed(seed) 26 | np.random.seed(seed) 27 | torch.backends.cudnn.deterministic = True 28 | 29 | 30 | def main(args): 31 | setup_seed(seed) 32 | 33 | clip_model =
CLIP(args) 34 | clip_model.load_state_dict(torch.load('./res/{}/node_ttgt_8&12_10.pkl'.format(args.data_name), map_location=device)) 35 | 36 | task_list, train_idx, val_idx, test_idx = multitask_data_generator(lab_list, labeled_ids, labels, args.k_spt, 37 | args.k_val, args.k_qry, args.n_way) 38 | all_acc = [] 39 | f1_list = [] 40 | for j in range(len(task_list)): 41 | 42 | train_idx_ts = torch.from_numpy(np.array(train_idx[j])).to(device) 43 | val_idx_ts = torch.from_numpy(np.array(val_idx[j])).to(device) 44 | test_idx_ts = torch.from_numpy(np.array(test_idx[j])).to(device) 45 | 46 | train_truth = [] 47 | for a in train_idx[j]: 48 | train_truth.append(id_lab_dict[str(a)]) 49 | 50 | val_truth = [] 51 | for a in val_idx[j]: 52 | val_truth.append(id_lab_dict[str(a)]) 53 | 54 | test_truth = [] 55 | for a in test_idx[j]: 56 | test_truth.append(id_lab_dict[str(a)]) 57 | 58 | task_lables_arr = np.array(labels)[task_list[j]] 59 | task_labels_dict = dict() 60 | for i in range(task_lables_arr.shape[0]): 61 | task_labels_dict[task_lables_arr[i]] = i 62 | 63 | train_truth_ts = [task_labels_dict[train_truth[i]] for i in range(len(train_truth))] 64 | train_truth_ts = torch.from_numpy(np.array(train_truth_ts)).to(device) 65 | 66 | val_truth_ts = [task_labels_dict[val_truth[i]] for i in range(len(val_truth))] 67 | val_truth_ts = torch.from_numpy(np.array(val_truth_ts)).to(device) 68 | 69 | test_truth_ts = [task_labels_dict[test_truth[i]] for i in range(len(test_truth))] 70 | test_truth_ts = torch.from_numpy(np.array(test_truth_ts)).to(device) 71 | 72 | task_lables = task_lables_arr.tolist() 73 | Data = DataHelper(arr_edge_index, args, train_idx[j]) 74 | loader = DataLoader(Data, batch_size=args.batch_size, shuffle=False, num_workers=0) 75 | for i_batch, sample_batched in enumerate(loader): 76 | s_n = sample_batched['s_n'].numpy() 77 | t_n = sample_batched['t_n'].numpy() 78 | s_n = s_n.reshape(args.num_labels, args.k_spt) 79 | t_n = t_n.reshape(args.num_labels, args.k_spt * args.neigh_num) 80 | temp = [] 81 | for i in range(args.num_labels): 82 | temp.append(np.concatenate((s_n[i], t_n[i]))) 83 | g_texts = [] 84 | for i in range(len(temp)): 85 | g_text = [new_dict[a] for a in temp[i]] 86 | g_texts.append(g_text) 87 | 88 | model = CoOp(args, task_lables, clip_model, g_texts, device) 89 | 90 | best_val = 0 91 | patience = 10 92 | counter = 0 93 | 94 | for epoch in range(1, args.ft_epoch + 1): 95 | # print('----epoch:' + str(epoch)) 96 | model.train() 97 | train_logits = model.forward(train_idx_ts, node_f, edge_index, train_truth_ts) 98 | 99 | model.eval() 100 | with torch.no_grad(): 101 | res = model.forward(val_idx_ts, node_f, edge_index, val_truth_ts, training=False) 102 | val_acc = accuracy_score(val_truth_ts.cpu(), res.argmax(dim=1).cpu()) 103 | if val_acc <= best_val: 104 | counter += 1 105 | if counter >= patience: 106 | break 107 | else: 108 | best_val = val_acc 109 | torch.save(model, './res/{}/g_coop.pkl'.format(args.data_name)) 110 | counter = 0 111 | # print('{}th_task_best_val'.format(j), round(best_val, 4)) 112 | 113 | best_model = torch.load('./res/{}/g_coop.pkl'.format(args.data_name)) 114 | best_model.eval() 115 | with torch.no_grad(): 116 | res = best_model.forward(test_idx_ts, node_f, edge_index, test_truth_ts, training=False) 117 | test_acc = accuracy_score(test_truth_ts.cpu(), res.argmax(dim=1).cpu()) 118 | all_acc.append(test_acc) 119 | f1 = f1_score(test_truth_ts.cpu(), res.argmax(dim=1).cpu(), average='macro') 120 | f1_list.append(f1) 121 | 122 | ans = round(np.mean(all_acc).item(), 4) 123
| print('acc', ans) 124 | 125 | ans = round(np.mean(f1_list).item(), 4) 126 | print('macro f1', ans) 127 | 128 | if __name__ == '__main__': 129 | parser = argparse.ArgumentParser() 130 | 131 | parser.add_argument('--aggregation_times', type=int, default=2, help='Aggregation times') 132 | parser.add_argument('--ft_epoch', type=int, default=50, help='fine-tune epoch') 133 | parser.add_argument('--lr', type=float, default=2e-5) 134 | 135 | parser.add_argument('--batch_size', type=int, default=1000) 136 | parser.add_argument('--gnn_input', type=int, default=128) 137 | parser.add_argument('--gnn_hid', type=int, default=128) 138 | parser.add_argument('--gnn_output', type=int, default=128) 139 | 140 | parser.add_argument('--edge_coef', type=float, default=0.1) 141 | parser.add_argument('--neigh_num', type=int, default=3) 142 | 143 | parser.add_argument('--num_labels', type=int, default=5) 144 | parser.add_argument('--k_spt', type=int, default=5) 145 | parser.add_argument('--k_val', type=int, default=5) 146 | parser.add_argument('--k_qry', type=int, default=50) 147 | parser.add_argument('--n_way', type=int, default=5) 148 | 149 | parser.add_argument('--context_length', type=int, default=128) 150 | parser.add_argument('--coop_n_ctx', type=int, default=4) 151 | parser.add_argument('--prompt_lr', type=float, default=0.01) 152 | 153 | parser.add_argument('--position', type=str, default='end') 154 | parser.add_argument('--class_specific', type=bool, default=False) 155 | parser.add_argument('--ctx_init', type=bool, default=True) 156 | 157 | parser.add_argument('--embed_dim', type=int, default=128) 158 | parser.add_argument('--transformer_heads', type=int, default=8) 159 | parser.add_argument('--transformer_layers', type=int, default=12) 160 | parser.add_argument('--transformer_width', type=int, default=512) 161 | parser.add_argument('--vocab_size', type=int, default=49408) 162 | parser.add_argument('--data_name', type=str, default="Musical_Instruments") 163 | parser.add_argument('--gpu', type=int, default=0) 164 | 165 | args = parser.parse_args() 166 | 167 | device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu") 168 | print('device:', device) 169 | 170 | num_nodes = 0 171 | tit_list = [] 172 | tit_dict = json.load(open('./data/{}_text.json'.format(args.data_name))) 173 | new_dict = {} 174 | 175 | for i in range(len(tit_dict)): 176 | num_nodes += 1 177 | new_dict[i] = tit_dict[str(i)] 178 | 179 | print('num_nodes', num_nodes) 180 | 181 | edge_index = np.load('./data/{}_edge.npy'.format(args.data_name)) 182 | 183 | arr_edge_index = edge_index 184 | 185 | edge_index = torch.from_numpy(edge_index).to(device) 186 | 187 | node_f = np.load('./data/{}_f_m.npy'.format(args.data_name)) 188 | node_f = preprocessing.StandardScaler().fit_transform(node_f) 189 | node_f = torch.from_numpy(node_f).to(device) 190 | 191 | id_lab_dict = json.load(open('./data/{}_id_labels.json'.format(args.data_name))) 192 | id_lab_list = sorted(id_lab_dict.items(), key=lambda d: int(d[0])) 193 | 194 | labeled_ids = [] 195 | lab_list = [] 196 | for i in id_lab_list: 197 | if i[1] != 'nan' and i[1] != '' and i[1] != ' ': 198 | labeled_ids.append(int(i[0])) 199 | lab_list.append(i[1]) 200 | 201 | labels = sorted(list(set(lab_list))) 202 | 203 | start = time.perf_counter() 204 | all_acc_list = [] 205 | all_macf1_list = [] 206 | 207 | seed = 1 208 | print('seed', seed) 209 | main(args) 210 | end = time.perf_counter() 211 | print("time consuming {:.2f}".format(end - start)) 212 |
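# Usage sketch: both test scripts share the argparse flags defined above, so the few-shot settings can be overridden from the command line; the values below simply restate the defaults shown above, not tuned settings.
#   python main_test.py --ft_epoch 50 --k_spt 5 --k_qry 50 --n_way 5 --gpu 0
#   python main_test_amazon.py --data_name Musical_Instruments --batch_size 1000 --gpu 0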
-------------------------------------------------------------------------------- /main_train.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.data import DataLoader 3 | from sklearn import preprocessing 4 | import numpy as np 5 | import argparse 6 | import torch 7 | from random import sample 8 | import random 9 | import math 10 | import time 11 | from model import CLIP, tokenize 12 | from data import DataHelper 13 | 14 | 15 | 16 | def setup_seed(seed): 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | random.seed(seed) 21 | np.random.seed(seed) 22 | torch.backends.cudnn.deterministic = True 23 | 24 | 25 | def main(args): 26 | setup_seed(seed) 27 | 28 | 29 | model = CLIP(args).to(device) 30 | Data = DataHelper(arr_edge_index, args) 31 | model.train() 32 | 33 | for j in range(args.epoch_num): 34 | loader = DataLoader(Data, batch_size=args.batch_size, shuffle=True, num_workers=10) 35 | for i_batch, sample_batched in enumerate(loader): 36 | s_n, t_n = sample_batched['s_n'], sample_batched['t_n'] 37 | s_n_arr, t_n_arr = s_n.numpy(), t_n.numpy().reshape(-1) # .reshape((1, -1)) 38 | s_n_text, t_n_text = np.array(tit_list)[s_n_arr].tolist(), np.array(tit_list)[t_n_arr].tolist() 39 | s_n_text, t_n_text = tokenize(s_n_text, context_length=args.context_length).to(device), tokenize(t_n_text, context_length=args.context_length).to(device) 40 | 41 | s_n, t_n = s_n.type(LType).to(device), t_n.type(LType).to(device) 42 | loss = model.forward(node_f, edge_index, s_n, t_n, s_n_text, t_n_text, device) 43 | # if i_batch >2 : 44 | # break 45 | if j == 0 and i_batch % 100 == 0: 46 | print('{}th loss in the first epoch:{}'.format(i_batch, loss)) 47 | 48 | # break 49 | print('{}th epoch loss:{}'.format(j, loss)) 50 | 51 | torch.save(model.state_dict(), './res/{}/node_ttgt_8&12_0.1.pkl'.format(data_name)) 52 | 53 | if __name__ == '__main__': 54 | parser = argparse.ArgumentParser() 55 | 56 | parser.add_argument('--aggregation_times', type=int, default=2, help='Aggregation times') 57 | parser.add_argument('--epoch_num', type=int, default=2, help='epoch number') 58 | parser.add_argument('--batch_size', type=int, default=64) 59 | parser.add_argument('--lr', type=float, default=2e-5) 60 | parser.add_argument('--edge_coef', type=float, default=10) 61 | parser.add_argument('--neigh_num', type=int, default=3) 62 | 63 | parser.add_argument('--gnn_input', type=int, default=128) 64 | parser.add_argument('--gnn_hid', type=int, default=128) 65 | parser.add_argument('--gnn_output', type=int, default=128) 66 | 67 | parser.add_argument('--context_length', type=int, default=128) 68 | 69 | parser.add_argument('--embed_dim', type=int, default=128) 70 | parser.add_argument('--transformer_heads', type=int, default=8) 71 | parser.add_argument('--transformer_layers', type=int, default=12) 72 | parser.add_argument('--transformer_width', type=int, default=512) 73 | parser.add_argument('--vocab_size', type=int, default=49408) # 49408 74 | parser.add_argument('--gpu', type=int, default=0) 75 | 76 | args = parser.parse_args() 77 | 78 | data_name = 'cora' 79 | device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu") 80 | print('device:', device) 81 | LType = torch.LongTensor 82 | 83 | num_nodes = 0 84 | tit_list = [] 85 | with open('./data/train_text.txt', 'r') as f: 86 | lines = f.readlines() 87 | for line in lines: 88 | line = line.strip().split('\t') 89 | tit_list.append(line[2]) 90
| num_nodes += 1 91 | 92 | print('num_nodes', num_nodes) 93 | 94 | raw_edge_index = [[], []] 95 | with open('./data/mapped_edges.txt', 'r') as f: 96 | lines = f.readlines() 97 | for line in lines: 98 | line = line.strip().split() 99 | raw_edge_index[0].append(int(line[0])) 100 | raw_edge_index[1].append(int(line[1])) 101 | 102 | print('num of edges', len(raw_edge_index[0] + raw_edge_index[1])) 103 | 104 | edge_index = [raw_edge_index[0] + raw_edge_index[1], raw_edge_index[1] + raw_edge_index[0]] 105 | arr_edge_index = np.array(edge_index) 106 | edge_index = np.array(edge_index) 107 | edge_index = torch.from_numpy(edge_index).to(device) 108 | 109 | node_f = np.load('./data/node_f.npy') 110 | node_f = preprocessing.StandardScaler().fit_transform(node_f) 111 | node_f = torch.from_numpy(node_f).to(device) 112 | 113 | start = time.perf_counter() 114 | 115 | seed = 1 116 | main(args) 117 | 118 | end = time.perf_counter() 119 | print("time consuming {:.2f}".format(end - start)) 120 | -------------------------------------------------------------------------------- /main_train_amazon.py: -------------------------------------------------------------------------------- 1 | 2 | import os.path as osp 3 | from torch.utils.data import DataLoader 4 | from sklearn import preprocessing 5 | import numpy as np 6 | import argparse 7 | import torch 8 | from random import sample 9 | import random 10 | import math 11 | import time 12 | from model import CLIP, tokenize 13 | from data import DataHelper 14 | from sklearn import preprocessing 15 | import json 16 | 17 | 18 | 19 | def setup_seed(seed): 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | torch.backends.cudnn.deterministic = True 26 | 27 | 28 | def main(args): 29 | setup_seed(seed) 30 | 31 | model = CLIP(args).to(device) 32 | Data = DataHelper(arr_edge_index, args) 33 | model.train() 34 | 35 | for j in range(args.epoch_num): 36 | loader = DataLoader(Data, batch_size=args.batch_size, shuffle=True, num_workers=10) 37 | for i_batch, sample_batched in enumerate(loader): 38 | s_n, t_n = sample_batched['s_n'], sample_batched['t_n'] 39 | s_n_arr = s_n.numpy() # .reshape((1, -1)) 40 | t_n_arr = t_n.numpy().reshape(-1) 41 | s_n_text, t_n_text = [new_dict[i] for i in s_n_arr], [new_dict[j] for j in t_n_arr] 42 | s_n_text, t_n_text = tokenize(s_n_text, context_length=args.context_length).to(device), tokenize(t_n_text, context_length=args.context_length).to(device) 43 | 44 | s_n, t_n = s_n.type(LType).to(device), t_n.type(LType).to(device) 45 | loss = model.forward(node_f, edge_index, s_n, t_n, s_n_text, t_n_text, device) 46 | if j == 0 and i_batch % 100 == 0: 47 | print('{}th loss in the first epoch:{}'.format(i_batch, loss)) 48 | # break 49 | print('{}th epoch loss:{}'.format(j + 1, loss)) 50 | torch.save(model.state_dict(), './res/{}/node_ttgt_8&12_10.pkl'.format(args.data_name)) 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser() 54 | 55 | parser.add_argument('--aggregation_times', type=int, default=2, help='Aggregation times') 56 | parser.add_argument('--epoch_num', type=int, default=2, help='epoch number') 57 | parser.add_argument('--batch_size', type=int, default=64) 58 | parser.add_argument('--lr', type=float, default=2e-5) 59 | parser.add_argument('--edge_coef', type=float, default=10) 60 | parser.add_argument('--neigh_num', type=int, default=3) 61 | 62 | parser.add_argument('--gnn_input', type=int, default=128) 63 | 
parser.add_argument('--gnn_hid', type=int, default=128) 64 | parser.add_argument('--gnn_output', type=int, default=128) 65 | 66 | parser.add_argument('--context_length', type=int, default=128) 67 | 68 | parser.add_argument('--embed_dim', type=int, default=128) 69 | parser.add_argument('--transformer_heads', type=int, default=8) 70 | parser.add_argument('--transformer_layers', type=int, default=12) 71 | parser.add_argument('--transformer_width', type=int, default=512) 72 | parser.add_argument('--vocab_size', type=int, default=49408) # 49408 73 | parser.add_argument('--data_name', type=str, default="Musical_Instruments") 74 | parser.add_argument('--gpu', type=int, default=0) 75 | 76 | args = parser.parse_args() 77 | 78 | device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu") 79 | print('device:', device) 80 | LType = torch.LongTensor 81 | num_nodes = 0 82 | tit_list = [] 83 | tit_dict = json.load(open('./data/{}_text.json'.format(args.data_name))) 84 | new_dict = {} 85 | 86 | for i in range(len(tit_dict)): 87 | num_nodes += 1 88 | new_dict[i] = tit_dict[str(i)] 89 | 90 | print('num_nodes', num_nodes) 91 | 92 | edge_index = np.load('./data/{}_edge.npy'.format(args.data_name)) 93 | 94 | arr_edge_index = edge_index 95 | 96 | edge_index = torch.from_numpy(edge_index).to(device) 97 | 98 | node_f = np.load('./data/{}_f_m.npy'.format(args.data_name)) 99 | node_f = preprocessing.StandardScaler().fit_transform(node_f) 100 | node_f = torch.from_numpy(node_f).to(device) 101 | 102 | start = time.perf_counter() 103 | 104 | seed = 1 105 | main(args) 106 | 107 | end = time.perf_counter() 108 | print("time consuming {:.2f}".format(end - start)) 109 | -------------------------------------------------------------------------------- /meta_net/init.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /meta_net/main_cog2p2_amazon.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../') 4 | import os.path as osp 5 | import numpy as np 6 | import argparse 7 | import torch 8 | from random import sample 9 | import random 10 | import math 11 | import time 12 | from model import CLIP, tokenize 13 | from torch import nn, optim 14 | from sklearn import preprocessing 15 | from sklearn.metrics import accuracy_score, f1_score 16 | import torch.nn.functional as F 17 | from task_amazon import multitask_data_generator 18 | # from model_cocoop import CoOp 19 | from model_cocoop import CoOp 20 | # from model_node_coop import CoOp 21 | import json 22 | from data_graph import DataHelper 23 | from torch.utils.data import DataLoader 24 | 25 | def setup_seed(seed): 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.backends.cudnn.deterministic = True 32 | 33 | def main(args): 34 | 35 | setup_seed(args.seed) 36 | # gnn = CLIP(args).gnn.to(device) 37 | clip_model = CLIP(args) # .to(device) 38 | # clip_model.load_state_dict(torch.load('../res/amazon/{}/clip.pkl'.format(args.data_name), map_location=device)) 39 | clip_model.load_state_dict( 40 | torch.load('./res/{}/node_ttgt_8&12_10.pkl'.format(args.data_name), map_location=device)) 41 | 42 | # model.load_state_dict(torch.load('../res/cora/h1l1.pkl')) 43 | 44 | task_list, train_idx, val_idx, test_idx = multitask_data_generator(lab_list, labeled_ids, labels, args.k_spt, 45 | args.k_val, args.k_qry,
args.n_way) 46 | 47 | all_acc = [] 48 | f1_list = [] 49 | # for j in range(len(task_list)): 50 | 51 | train_idx_ts = torch.from_numpy(np.array(train_idx[0])).to(device) 52 | val_idx_ts = torch.from_numpy(np.array(val_idx[0])).to(device) 53 | test_idx_ts = torch.from_numpy(np.array(test_idx[0])).to(device) 54 | 55 | train_truth = [] 56 | for a in train_idx[0]: 57 | train_truth.append(id_lab_dict[str(a)]) 58 | 59 | val_truth = [] 60 | for a in val_idx[0]: 61 | val_truth.append(id_lab_dict[str(a)]) 62 | 63 | test_truth = [] 64 | for a in test_idx[0]: 65 | test_truth.append(id_lab_dict[str(a)]) 66 | 67 | task_lables_arr = np.array(labels)[task_list[0]] 68 | task_labels_dict = dict() 69 | for i in range(task_lables_arr.shape[0]): 70 | task_labels_dict[task_lables_arr[i]] = i 71 | 72 | train_truth_ts = [task_labels_dict[train_truth[i]] for i in range(len(train_truth))] 73 | train_truth_ts = torch.from_numpy(np.array(train_truth_ts)).to(device) 74 | 75 | val_truth_ts = [task_labels_dict[val_truth[i]] for i in range(len(val_truth))] 76 | val_truth_ts = torch.from_numpy(np.array(val_truth_ts)).to(device) 77 | 78 | test_truth_ts = [task_labels_dict[test_truth[i]] for i in range(len(test_truth))] 79 | test_truth_ts = torch.from_numpy(np.array(test_truth_ts)).to(device) 80 | 81 | task_lables = task_lables_arr.tolist() 82 | # print('task_lables', task_lables) 83 | model = CoOp(args, task_lables, clip_model, device) 84 | # for param in model.state_dict(): 85 | # print(param) 86 | 87 | best_val = float('inf')  # track the lowest validation loss 88 | patience = 2 89 | counter = 0 90 | num_train_samples = train_idx_ts.size(0) 91 | batch_num = num_train_samples // args.batch_size if num_train_samples % args.batch_size == 0 else num_train_samples // args.batch_size + 1 92 | epoch_train_orders = np.arange(num_train_samples) 93 | 94 | for epoch in range(1, args.ft_epoch + 1): 95 | # print('----epoch:' + str(epoch)) 96 | random.shuffle(epoch_train_orders) 97 | # if epoch % 1 == 0: 98 | # print('epoch_train_orders: ', epoch_train_orders[:10]) 99 | 100 | for i in range(batch_num): 101 | start = i * args.batch_size 102 | end = min((i + 1) * args.batch_size, num_train_samples) 103 | the_idx = epoch_train_orders[start:end] 104 | model.train() 105 | train_logits, train_loss = model.forward(train_idx_ts[the_idx], node_f, edge_index, train_truth_ts[the_idx]) 106 | 107 | # break 108 | 109 | print("---evaluating---") 110 | model.eval() 111 | with torch.no_grad(): 112 | val_loss = 0 113 | eval_batch_num = val_idx_ts.size(0) // args.eval_batch_size if val_idx_ts.size(0) % args.eval_batch_size == 0 else val_idx_ts.size(0) // args.eval_batch_size + 1 114 | for i in range(eval_batch_num): 115 | start = i * args.eval_batch_size 116 | end = min((i + 1) * args.eval_batch_size, val_idx_ts.size(0)) 117 | res, batch_val_loss = model.forward(val_idx_ts[start:end], node_f, edge_index, val_truth_ts[start:end], 118 | training=False) 119 | val_loss += batch_val_loss 120 | # val_acc = accuracy_score(val_truth_ts.cpu(), res.argmax(dim=1).cpu()) 121 | if val_loss >= best_val: 122 | counter += 1 123 | if counter >= patience: 124 | break 125 | else: 126 | best_val = val_loss 127 | # torch.save(model, '../res/amazon/{}/g_coop_node.pkl'.format(data_name)) 128 | best_model = model 129 | counter = 0 130 | # print('{}th_task_best_val'.format(j), round(best_val, 4)) 131 | 132 | # best_model = torch.load('../res/amazon/{}/g_coop_node.pkl'.format(data_name)) 133 | if val_loss >= best_val: 134 | best_model = model 135 | print("num of test examples= ", test_idx_ts.size(0)) 136 | best_model.eval() 137
| with torch.no_grad(): 138 | res_list = [] 139 | test_batch_num = test_idx_ts.size(0) // args.eval_batch_size if test_idx_ts.size(0) % args.eval_batch_size == 0 else test_idx_ts.size(0) // args.eval_batch_size + 1 140 | for i in range(test_batch_num): 141 | start = i * args.eval_batch_size 142 | end = min((i + 1) * args.eval_batch_size, test_idx_ts.size(0)) 143 | batch_res, _ = best_model.forward(test_idx_ts[start:end], node_f, edge_index, test_truth_ts[start:end], 144 | training=False) 145 | res_list.append(batch_res) 146 | 147 | # print('res_list', res_list) 148 | res = torch.cat(res_list, dim=0) 149 | test_acc = accuracy_score(test_truth_ts.cpu(), res.argmax(dim=1).cpu()) 150 | # print('{}_task_test_acc'.format(j), round(test_acc, 4)) 151 | all_acc.append(test_acc) 152 | f1 = f1_score(test_truth_ts.cpu(), res.argmax(dim=1).cpu(), average='macro') 153 | f1_list.append(f1) 154 | 155 | ans = round(np.mean(all_acc).item(), 4) 156 | print('base acc', ans) 157 | 158 | ans = round(np.mean(f1_list).item(), 4) 159 | print('base macro f1', ans) 160 | 161 | print("\n\n") 162 | print("----------begin testing new class----------") 163 | print("\n\n") 164 | 165 | test_idx_ts = torch.from_numpy(np.array(test_idx[1])).to(device) 166 | test_truth = [] 167 | for a in test_idx[1]: 168 | test_truth.append(id_lab_dict[str(a)]) 169 | 170 | task_lables_arr = np.array(labels)[task_list[1]] 171 | task_labels_dict = dict() 172 | for i in range(task_lables_arr.shape[0]): 173 | task_labels_dict[task_lables_arr[i]] = i 174 | 175 | test_truth_ts = [task_labels_dict[test_truth[i]] for i in range(len(test_truth))] 176 | test_truth_ts = torch.from_numpy(np.array(test_truth_ts)).to(device) 177 | 178 | test_task_lables = task_lables_arr.tolist() 179 | print('test_task_lables', test_task_lables[:10]) 180 | test_model = CoOp(args, test_task_lables, clip_model, device) 181 | # test_model.load_state_dict(best_model.state_dict()) 182 | base_dict = best_model.state_dict() 183 | # base_dict = model.state_dict() 184 | new_dict = test_model.state_dict() 185 | 186 | # for param in new_dict: 187 | # print(param) 188 | with torch.no_grad(): 189 | new_dict["model.prompt_learner.ctx"] = base_dict["model.prompt_learner.ctx"] 190 | new_dict["model.prompt_learner.meta_net.linear1.weight"] = base_dict["model.prompt_learner.meta_net.linear1.weight"] 191 | new_dict["model.prompt_learner.meta_net.linear1.bias"] = base_dict["model.prompt_learner.meta_net.linear1.bias"] 192 | new_dict["model.prompt_learner.meta_net.linear2.weight"] = base_dict["model.prompt_learner.meta_net.linear2.weight"] 193 | new_dict["model.prompt_learner.meta_net.linear2.bias"] = base_dict["model.prompt_learner.meta_net.linear2.bias"] 194 | test_model.load_state_dict(new_dict) 195 | 196 | test_model.eval() 197 | with torch.no_grad(): 198 | res_list = [] 199 | test_batch_num = test_idx_ts.size(0) // args.eval_batch_size if test_idx_ts.size(0) % args.eval_batch_size == 0 else test_idx_ts.size(0) // args.eval_batch_size + 1 200 | for i in range(test_batch_num): 201 | start = i * args.eval_batch_size 202 | end = min((i + 1) * args.eval_batch_size, test_idx_ts.size(0)) 203 | batch_res, _ = test_model.forward(test_idx_ts[start:end], node_f, edge_index, test_truth_ts[start:end], 204 | training=False) 205 | res_list.append(batch_res) 206 | 207 | # print('res_list', res_list) 208 | res = torch.cat(res_list, dim=0) 209 | test_acc = accuracy_score(test_truth_ts.cpu(), res.argmax(dim=1).cpu()) 210 | # print('{}_task_test_acc'.format(j), round(test_acc, 4)) 211 | all_acc.append(test_acc)
212 | f1 = f1_score(test_truth_ts.cpu(), res.argmax(dim=1).cpu(), average='macro') 213 | f1_list.append(f1) 214 | 215 | ans = round(np.mean(all_acc).item(), 4) 216 | print('new acc', ans) 217 | 218 | ans = round(np.mean(f1_list).item(), 4) 219 | print('new macro f1', ans) 220 | 221 | 222 | if __name__ == '__main__': 223 | parser = argparse.ArgumentParser() 224 | 225 | parser.add_argument('--ft_epoch', type=int, default=10, help='fine-tune epoch') 226 | # parser.add_argument('--ft_epoch', type=int, default=1, help='fine-tune epoch') 227 | 228 | parser.add_argument('--batch_size', type=int, default=5) 229 | parser.add_argument('--eval_batch_size', type=int, default=512) 230 | parser.add_argument('--lr', type=float, default=0.0001) 231 | parser.add_argument('--ft_lr', type=float, default=0.01) 232 | parser.add_argument('--gnn_input', type=int, default=128) 233 | # parser.add_argument('--gnn_hid', type=int, default=16) 234 | parser.add_argument('--gnn_hid', type=int, default=128) 235 | parser.add_argument('--gnn_output', type=int, default=128) 236 | 237 | parser.add_argument('--edge_coef', type=float, default=0.1) 238 | parser.add_argument('--neigh_num', type=int, default=3) 239 | 240 | parser.add_argument('--num_labels', type=int, default=5) 241 | parser.add_argument('--k_spt', type=int, default=5) 242 | parser.add_argument('--k_val', type=int, default=5) 243 | parser.add_argument('--k_qry', type=int, default=50) 244 | parser.add_argument('--n_way', type=int, default=5) 245 | 246 | # parser.add_argument('--context_length', type=int, default=120) 247 | parser.add_argument('--context_length', type=int, default=128) 248 | parser.add_argument('--coop_n_ctx', type=int, default=4) 249 | parser.add_argument('--prompt_lr', type=float, default=0.005) 250 | # parser.add_argument('--prompt_lr', type=float, default=0.001) 251 | 252 | parser.add_argument('--position', type=str, default='end') 253 | parser.add_argument('--class_specific', type=bool, default=False) 254 | # parser.add_argument('--class_specific', type=bool, default=True) 255 | parser.add_argument('--ctx_init', type=bool, default=True) 256 | 257 | parser.add_argument('--embed_dim', type=int, default=128) 258 | parser.add_argument('--transformer_heads', type=int, default=8) 259 | parser.add_argument('--transformer_layers', type=int, default=12) 260 | parser.add_argument('--transformer_width', type=int, default=512) 261 | parser.add_argument('--vocab_size', type=int, default=49408) 262 | 263 | parser.add_argument('--gpu', type=int, default=0) 264 | parser.add_argument('--seed', type=int, default=1) 265 | 266 | # parser.add_argument('--data_name', type=str, default='Arts_Crafts_and_Sewing') 267 | # parser.add_argument('--data_name', type=str, default='Industrial_and_Scientific') 268 | parser.add_argument('--data_name', type=str, default="Musical_Instruments") 269 | 270 | args = parser.parse_args() 271 | 272 | # print('args.class_specific= ', args.class_specific) 273 | device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu") 274 | # device = torch.device("cuda:1") 275 | print('device', device) 276 | # data_name = 'cora' 277 | criterion = nn.BCEWithLogitsLoss() 278 | 279 | # device = torch.device("cpu") 280 | FType = torch.FloatTensor 281 | LType = torch.LongTensor 282 | 283 | num_nodes = 0 284 | tit_dict = json.load(open('./data/amazon/{}_text.json'.format(args.data_name))) 285 | new_dict = {} 286 | for i in range(len(tit_dict)): 287 | new_dict[i] = tit_dict[str(i)] 288 | 289 | id_lab_dict = 
json.load(open('./data/amazon/{}_id_labels.json'.format(args.data_name))) 290 | id_lab_list = sorted(id_lab_dict.items(), key=lambda d: int(d[0])) 291 | 292 | labeled_ids = [] 293 | lab_list = [] 294 | for i in id_lab_list: 295 | # if lab_list[i] == 'nan': 296 | # print(lab_list[i]) 297 | if i[1] != 'nan' and i[1] != '' and i[1] != ' ': 298 | labeled_ids.append(int(i[0])) 299 | lab_list.append(i[1]) 300 | 301 | print('number of labels', len(lab_list)) 302 | 303 | edge_index = np.load('./data/amazon/{}_edge.npy'.format(args.data_name)) 304 | 305 | arr_edge_index = edge_index 306 | 307 | edge_index = torch.from_numpy(edge_index).to(device) 308 | 309 | # node_f = np.load('../cora/node_f_title.npy').astype(np.float32) 310 | node_f = np.load('./data/amazon/{}_f_m.npy'.format(args.data_name)) 311 | node_f = preprocessing.StandardScaler().fit_transform(node_f) 312 | node_f = torch.from_numpy(node_f).to(device) 313 | 314 | labels = sorted(list(set(lab_list))) 315 | 316 | start = time.perf_counter() 317 | base_acc_list = [] 318 | base_macf1_list = [] 319 | 320 | new_acc_list = [] 321 | new_macf1_list = [] 322 | 323 | main(args) 324 | 325 | end = time.perf_counter() 326 | print("time consuming {:.2f}".format(end - start)) 327 | -------------------------------------------------------------------------------- /meta_net/main_cog2p2_cora.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../') 4 | import os.path as osp 5 | import numpy as np 6 | import argparse 7 | import torch 8 | from random import sample 9 | import random 10 | import math 11 | import time 12 | from model import CLIP, tokenize 13 | from torch import nn, optim 14 | from sklearn import preprocessing 15 | from sklearn.metrics import accuracy_score, f1_score 16 | import torch.nn.functional as F 17 | from task_cora import multitask_data_generator 18 | # from model_cocoop import CoOp 19 | from model_cocoop import CoOp 20 | # from model_node_coop import CoOp 21 | import json 22 | from data_graph import DataHelper 23 | from torch.utils.data import DataLoader 24 | 25 | 26 | def setup_seed(seed): 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed(seed) 29 | torch.cuda.manual_seed_all(seed) 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.backends.cudnn.deterministic = True 33 | 34 | 35 | def main(args): 36 | setup_seed(args.seed) 37 | 38 | # gnn = CLIP(args).gnn.to(device) 39 | clip_model = CLIP(args) # .to(device) 40 | # clip_model.load_state_dict(torch.load('../res/amazon/{}/clip.pkl'.format(dataset_name), map_location=device)) 41 | clip_model.load_state_dict( 42 | torch.load('../res/{}/node_ttgt_8&12_0.1.pkl'.format(dataset_name), map_location=device)) 43 | 44 | # model.load_state_dict(torch.load('../res/cora/h1l1.pkl')) 45 | 46 | task_list, train_idx, val_idx, test_idx = multitask_data_generator(lab_list, labeled_ids, labels, args.k_spt, 47 | args.k_val, args.k_qry, args.n_way) 48 | 49 | all_acc = [] 50 | f1_list = [] 51 | # for j in range(len(task_list)): 52 | 53 | train_idx_ts = torch.from_numpy(np.array(train_idx[0])).to(device) 54 | val_idx_ts = torch.from_numpy(np.array(val_idx[0])).to(device) 55 | test_idx_ts = torch.from_numpy(np.array(test_idx[0])).to(device) 56 | 57 | train_truth = np.array(lab_list)[np.array(train_idx[0])] 58 | val_truth = np.array(lab_list)[np.array(val_idx[0])] 59 | test_truth = np.array(lab_list)[np.array(test_idx[0])] 60 | 61 | task_lables_arr = np.array(labels)[task_list[0]] 62 | task_labels_dict =
dict() 63 | for i in range(task_lables_arr.shape[0]): 64 | task_labels_dict[task_lables_arr[i]] = i 65 | 66 | train_truth_ts = [task_labels_dict[train_truth[i]] for i in range(train_truth.shape[0])] 67 | train_truth_ts = torch.from_numpy(np.array(train_truth_ts)).to(device) 68 | 69 | val_truth_ts = [task_labels_dict[val_truth[i]] for i in range(val_truth.shape[0])] 70 | val_truth_ts = torch.from_numpy(np.array(val_truth_ts)).to(device) 71 | 72 | test_truth_ts = [task_labels_dict[test_truth[i]] for i in range(test_truth.shape[0])] 73 | test_truth_ts = torch.from_numpy(np.array(test_truth_ts)).to(device) 74 | 75 | task_lables = task_lables_arr.tolist() 76 | # print('task_lables', task_lables) 77 | model = CoOp(args, task_lables, clip_model, device) 78 | # for param in model.state_dict(): 79 | # print(param) 80 | 81 | best_val = float('inf')  # track the lowest validation loss 82 | patience = 2 83 | counter = 0 84 | num_train_samples = train_idx_ts.size(0) 85 | batch_num = num_train_samples // args.batch_size if num_train_samples % args.batch_size == 0 else num_train_samples // args.batch_size + 1 86 | epoch_train_orders = np.arange(num_train_samples) 87 | 88 | for epoch in range(1, args.ft_epoch + 1): 89 | # print('----epoch:' + str(epoch)) 90 | random.shuffle(epoch_train_orders) 91 | if epoch % 1 == 0: 92 | print('epoch_train_orders: ', epoch_train_orders[:10]) 93 | 94 | for i in range(batch_num): 95 | start = i * args.batch_size 96 | end = min((i + 1) * args.batch_size, num_train_samples) 97 | the_idx = epoch_train_orders[start:end] 98 | model.train() 99 | train_logits, train_loss = model.forward(train_idx_ts[the_idx], node_f, edge_index, train_truth_ts[the_idx]) 100 | # break 101 | 102 | model.eval() 103 | with torch.no_grad(): 104 | val_loss = 0 105 | eval_batch_num = val_idx_ts.size(0) // args.eval_batch_size if val_idx_ts.size(0) % args.eval_batch_size == 0 else val_idx_ts.size(0) // args.eval_batch_size + 1 106 | for i in range(eval_batch_num): 107 | start = i * args.eval_batch_size 108 | end = min((i + 1) * args.eval_batch_size, val_idx_ts.size(0)) 109 | res, batch_val_loss = model.forward(val_idx_ts[start:end], node_f, edge_index, val_truth_ts[start:end], 110 | training=False) 111 | val_loss += batch_val_loss 112 | # val_acc = accuracy_score(val_truth_ts.cpu(), res.argmax(dim=1).cpu()) 113 | if val_loss >= best_val: 114 | counter += 1 115 | if counter >= patience: 116 | break 117 | else: 118 | best_val = val_loss 119 | # torch.save(model, '../res/amazon/{}/g_coop_node.pkl'.format(data_name)) 120 | best_model = model 121 | counter = 0 122 | # print('{}th_task_best_val'.format(j), round(best_val, 4)) 123 | 124 | # best_model = torch.load('../res/amazon/{}/g_coop_node.pkl'.format(data_name)) 125 | if val_loss >= best_val: 126 | best_model = model 127 | print("num of test examples= ", test_idx_ts.size(0)) 128 | best_model.eval() 129 | with torch.no_grad(): 130 | res_list = [] 131 | test_batch_num = test_idx_ts.size(0) // args.eval_batch_size if test_idx_ts.size(0) % args.eval_batch_size == 0 else test_idx_ts.size(0) // args.eval_batch_size + 1 132 | for i in range(test_batch_num): 133 | start = i * args.eval_batch_size 134 | end = min((i + 1) * args.eval_batch_size, test_idx_ts.size(0)) 135 | batch_res, _ = best_model.forward(test_idx_ts[start:end], node_f, edge_index, test_truth_ts[start:end], 136 | training=False) 137 | res_list.append(batch_res) 138 | 139 | # print('res_list', res_list) 140 | res = torch.cat(res_list, dim=0) 141 | test_acc = accuracy_score(test_truth_ts.cpu(), res.argmax(dim=1).cpu()) 142 | #
print('{}_task_test_acc'.format(j), round(test_acc, 4)) 143 | all_acc.append(test_acc) 144 | f1 = f1_score(test_truth_ts.cpu(), res.argmax(dim=1).cpu(), average='macro') 145 | f1_list.append(f1) 146 | 147 | ans = round(np.mean(all_acc).item(), 4) 148 | print('base acc', ans) 149 | base_acc_list.append(ans) 150 | 151 | ans = round(np.mean(f1_list).item(), 4) 152 | print('base macro f1', ans) 153 | base_macf1_list.append(ans) 154 | 155 | print("\n\n") 156 | print("----------begin testing new class----------") 157 | print("\n\n") 158 | 159 | test_idx_ts = torch.from_numpy(np.array(test_idx[1])).to(device) 160 | test_truth = np.array(lab_list)[np.array(test_idx[1])] 161 | task_lables_arr = np.array(labels)[task_list[1]] 162 | task_labels_dict = dict() 163 | for i in range(task_lables_arr.shape[0]): 164 | task_labels_dict[task_lables_arr[i]] = i 165 | 166 | test_truth_ts = [task_labels_dict[test_truth[i]] for i in range(test_truth.shape[0])] 167 | test_truth_ts = torch.from_numpy(np.array(test_truth_ts)).to(device) 168 | 169 | test_task_lables = task_lables_arr.tolist() 170 | print('test_task_lables', test_task_lables[:10]) 171 | test_model = CoOp(args, test_task_lables, clip_model, device) 172 | # test_model.load_state_dict(best_model.state_dict()) 173 | base_dict = best_model.state_dict() 174 | # base_dict = model.state_dict() 175 | new_dict = test_model.state_dict() 176 | 177 | # for param in new_dict: 178 | # print(param) 179 | with torch.no_grad(): 180 | new_dict["model.prompt_learner.ctx"] = base_dict["model.prompt_learner.ctx"] 181 | new_dict["model.prompt_learner.meta_net.linear1.weight"] = base_dict["model.prompt_learner.meta_net.linear1.weight"] 182 | new_dict["model.prompt_learner.meta_net.linear1.bias"] = base_dict["model.prompt_learner.meta_net.linear1.bias"] 183 | new_dict["model.prompt_learner.meta_net.linear2.weight"] = base_dict["model.prompt_learner.meta_net.linear2.weight"] 184 | new_dict["model.prompt_learner.meta_net.linear2.bias"] = base_dict["model.prompt_learner.meta_net.linear2.bias"] 185 | test_model.load_state_dict(new_dict) 186 | 187 | test_model.eval() 188 | with torch.no_grad(): 189 | res_list = [] 190 | test_batch_num = test_idx_ts.size(0) // args.eval_batch_size if test_idx_ts.size(0) % args.eval_batch_size == 0 else test_idx_ts.size(0) // args.eval_batch_size + 1 191 | for i in range(test_batch_num): 192 | start = i * args.eval_batch_size 193 | end = min((i + 1) * args.eval_batch_size, test_idx_ts.size(0)) 194 | batch_res, _ = test_model.forward(test_idx_ts[start:end], node_f, edge_index, test_truth_ts[start:end], 195 | training=False) 196 | res_list.append(batch_res) 197 | 198 | # print('res_list', res_list) 199 | res = torch.cat(res_list, dim=0) 200 | test_acc = accuracy_score(test_truth_ts.cpu(), res.argmax(dim=1).cpu()) 201 | # print('{}_task_test_acc'.format(j), round(test_acc, 4)) 202 | all_acc.append(test_acc) 203 | f1 = f1_score(test_truth_ts.cpu(), res.argmax(dim=1).cpu(), average='macro') 204 | f1_list.append(f1) 205 | 206 | ans = round(np.mean(all_acc).item(), 4) 207 | print('new acc', ans) 208 | new_acc_list.append(ans) 209 | 210 | ans = round(np.mean(f1_list).item(), 4) 211 | print('new macro f1', ans) 212 | new_macf1_list.append(ans) 213 | 214 | print("\n\n") 215 | print("\n\n") 216 | print("\n\n") 217 | 218 | 219 | if __name__ == '__main__': 220 | parser = argparse.ArgumentParser() 221 | 222 | parser.add_argument('--aggregation_times', type=int, default=2, help='Aggregation times') 223 | parser.add_argument('--hidden', type=int,
default=16, help='number of hidden neurons') 224 | parser.add_argument('--ft_epoch', type=int, default=20, help='fine-tune epoch') 225 | # parser.add_argument('--ft_epoch', type=int, default=1, help='fine-tune epoch') 226 | 227 | parser.add_argument('--batch_size', type=int, default=10) 228 | parser.add_argument('--eval_batch_size', type=int, default=1000) 229 | parser.add_argument('--lr', type=float, default=0.0001) 230 | parser.add_argument('--ft_lr', type=float, default=0.01) 231 | parser.add_argument('--gnn_input', type=int, default=128) 232 | # parser.add_argument('--gnn_hid', type=int, default=16) 233 | parser.add_argument('--gnn_hid', type=int, default=128) 234 | parser.add_argument('--gnn_output', type=int, default=128) 235 | 236 | parser.add_argument('--edge_coef', type=float, default=0.1) 237 | parser.add_argument('--neigh_num', type=int, default=3) 238 | 239 | parser.add_argument('--num_labels', type=int, default=5) 240 | parser.add_argument('--k_spt', type=int, default=5) 241 | parser.add_argument('--k_val', type=int, default=5) 242 | parser.add_argument('--k_qry', type=int, default=50) 243 | parser.add_argument('--n_way', type=int, default=5) 244 | 245 | # parser.add_argument('--context_length', type=int, default=77) 246 | parser.add_argument('--context_length', type=int, default=128) 247 | parser.add_argument('--coop_n_ctx', type=int, default=4) 248 | parser.add_argument('--prompt_lr', type=float, default=0.005) 249 | # parser.add_argument('--prompt_lr', type=float, default=0.001) 250 | 251 | parser.add_argument('--position', type=str, default='end') 252 | parser.add_argument('--class_specific', type=bool, default=False) 253 | # parser.add_argument('--class_specific', type=bool, default=True) 254 | parser.add_argument('--ctx_init', type=bool, default=True) 255 | 256 | parser.add_argument('--embed_dim', type=int, default=128) 257 | parser.add_argument('--transformer_heads', type=int, default=8) 258 | parser.add_argument('--transformer_layers', type=int, default=12) 259 | parser.add_argument('--transformer_width', type=int, default=512) 260 | parser.add_argument('--vocab_size', type=int, default=49408) 261 | 262 | parser.add_argument('--gpu', type=int, default=1) 263 | parser.add_argument('--seed', type=int, default=1) 264 | 265 | args = parser.parse_args() 266 | 267 | # print('args.class_specific= ', args.class_specific) 268 | device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu") 269 | # device = torch.device("cuda:1") 270 | print('device', device) 271 | dataset_name = 'cora' 272 | criterion = nn.BCEWithLogitsLoss() 273 | 274 | # device = torch.device("cpu") 275 | FType = torch.FloatTensor 276 | LType = torch.LongTensor 277 | 278 | num_nodes = 0 279 | tit_list = [] 280 | lab_list = [] 281 | with open('../cora/train_text.txt', 'r') as f: 282 | lines = f.readlines() 283 | for line in lines: 284 | line = line.strip().split('\t') 285 | # tit_list.append(line[1]) # the title has been processed 286 | tit_list.append(line[2]) 287 | lab_list.append(line[3]) 288 | num_nodes += 1 289 | 290 | print('num_nodes', num_nodes) 291 | 292 | labeled_ids = [] 293 | for i in range(len(lab_list)): 294 | if lab_list[i] != 'nan': 295 | labeled_ids.append(i) 296 | 297 | print('{} nodes having labels'.format(len(labeled_ids))) 298 | 299 | raw_edge_index = [[], []] 300 | with open('../cora/mapped_edges.txt', 'r') as f: 301 | lines = f.readlines() 302 | for line in lines: 303 | line = line.strip().split() 304 | raw_edge_index[0].append(int(line[0])) 305 | 
raw_edge_index[1].append(int(line[1])) 306 | 307 | edge_index = [raw_edge_index[0] + raw_edge_index[1], raw_edge_index[1] + raw_edge_index[0]] 308 | arr_edge_index = np.array(edge_index) 309 | edge_index = np.array(edge_index) 310 | edge_index = torch.from_numpy(edge_index).to(device) 311 | 312 | node_f = np.load('../cora/node_f.npy') 313 | node_f = preprocessing.StandardScaler().fit_transform(node_f) 314 | node_f = torch.from_numpy(node_f).to(device) 315 | 316 | # label_texts = [] 317 | with open('../cora/lab_list.txt', 'r') as f: 318 | line = f.readline().strip().split('\t') 319 | label_texts = line 320 | 321 | labels = [] 322 | for i in label_texts: 323 | if i != 'nan': 324 | labels.append(i) 325 | 326 | start = time.perf_counter() 327 | base_acc_list = [] 328 | base_macf1_list = [] 329 | 330 | new_acc_list = [] 331 | new_macf1_list = [] 332 | 333 | main(args) 334 | 335 | end = time.perf_counter() 336 | print("time consuming {:.2f}".format(end - start)) 337 | 338 | 339 | -------------------------------------------------------------------------------- /meta_net/model_cocoop.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from torch.cuda.amp import GradScaler, autocast 7 | from torch import optim 8 | import model 9 | from simple_tokenizer import SimpleTokenizer as _Tokenizer 10 | from collections import OrderedDict 11 | 12 | _tokenizer = _Tokenizer() 13 | 14 | 15 | # def load_clip(args, device): 16 | # 17 | # model = clip(args).to(device) 18 | # model.load_state_dict(torch.load('./res/cora/node_ce.pkl')) 19 | # 20 | # return model 21 | 22 | 23 | class TextEncoder(nn.Module): 24 | def __init__(self, clip_model): 25 | super().__init__() 26 | self.transformer = clip_model.transformer 27 | self.positional_embedding = clip_model.positional_embedding 28 | self.ln_final = clip_model.ln_final 29 | self.text_projection = clip_model.text_projection 30 | self.dtype = clip_model.dtype 31 | 32 | def forward(self, prompts, tokenized_prompts): 33 | x = prompts + self.positional_embedding.type(self.dtype) 34 | x = x.permute(1, 0, 2) # NLD -> LND 35 | x = self.transformer(x) 36 | x = x.permute(1, 0, 2) # LND -> NLD 37 | x = self.ln_final(x).type(self.dtype) 38 | 39 | # x.shape = [batch_size, n_ctx, transformer.width] 40 | # take features from the eot embedding (eot_token is the highest number in each sequence) 41 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection 42 | 43 | return x 44 | 45 | class PromptLearner(nn.Module): 46 | def __init__(self, args, classnames, clip_model): 47 | super().__init__() 48 | self.vars = nn.ParameterList() 49 | n_cls = len(classnames) 50 | n_ctx = args.coop_n_ctx 51 | dtype = clip_model.dtype 52 | ctx_dim = clip_model.ln_final.weight.shape[0] 53 | 54 | # random initialization 55 | if args.class_specific: 56 | # print("Initializing class-specific contexts") 57 | ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype) 58 | else: 59 | # print("Initializing a generic context") 60 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 61 | nn.init.normal_(ctx_vectors, std=0.02) 62 | prompt_prefix = " ".join(["X"] * n_ctx) 63 | 64 | # print(f'Initial context: "{prompt_prefix}"') 65 | # print(f"Number of context words (tokens): {n_ctx}") 66 | 67 | # self.ctx = nn.Parameter(ctx_vectors) # to be optimized 68 | # self.vars.append(self.ctx) 69 | self.ctx = nn.Parameter(ctx_vectors) 70 | 
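# NOTE: the meta_net defined just below is the instance-conditioning module of the
# CoCoOp-style prompt learner used for G2P2*: a two-layer bottleneck MLP that maps
# a node's GNN embedding (args.gnn_output dims) down to args.gnn_output // 16 and
# back up to args.transformer_width. In PromptLearner.forward() its output is added
# as a per-instance bias to the shared context vectors self.ctx, so each node
# conditions its own prompt.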
71 | self.meta_net = nn.Sequential(OrderedDict([ 72 | ("linear1", nn.Linear(args.gnn_output, args.gnn_output // 16)), 73 | ("relu", nn.ReLU(inplace=True)), 74 | ("linear2", nn.Linear(args.gnn_output // 16, args.transformer_width)) 75 | ])) 76 | 77 | classnames = [name.replace("_", " ") for name in classnames] 78 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 79 | prompts = [prompt_prefix + " " + name + "." for name in classnames] 80 | 81 | tokenized_prompts = torch.cat([model.tokenize(p, context_length=args.context_length) for p in prompts]) 82 | with torch.no_grad(): 83 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 84 | 85 | # These token vectors will be saved when in save_model(), 86 | # but they should be ignored in load_model() as we want to use 87 | # those computed using the current class names 88 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 89 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :]) # CLS, EOS 90 | 91 | self.n_cls = n_cls 92 | self.n_ctx = n_ctx 93 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 94 | self.name_lens = name_lens 95 | self.class_token_position = args.position 96 | 97 | def construct_prompts(self, ctx, prefix, suffix, label=None): 98 | # dim0 is either batch_size (during training) or n_cls (during testing) 99 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim) 100 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim) 101 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim) 102 | 103 | if label is not None: 104 | prefix = prefix[label] 105 | suffix = suffix[label] 106 | 107 | prompts = torch.cat( 108 | [ 109 | prefix, # (dim0, 1, dim) 110 | ctx, # (dim0, n_ctx, dim) 111 | suffix, # (dim0, *, dim) 112 | ], 113 | dim=1, 114 | ) 115 | 116 | return prompts 117 | 118 | def forward(self, im_features): 119 | prefix = self.token_prefix 120 | suffix = self.token_suffix 121 | ctx = self.ctx # (n_ctx, ctx_dim) 122 | bias = self.meta_net(im_features) # (batch, ctx_dim) 123 | bias = bias.unsqueeze(1) # (batch, 1, ctx_dim) 124 | ctx = ctx.unsqueeze(0) # (1, n_ctx, ctx_dim) 125 | ctx_shifted = ctx + bias # (batch, n_ctx, ctx_dim) 126 | 127 | # Use instance-conditioned context tokens for all classes 128 | prompts = [] 129 | for ctx_shifted_i in ctx_shifted: 130 | ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1) 131 | pts_i = self.construct_prompts(ctx_i, prefix, suffix) # (n_cls, n_tkn, ctx_dim) 132 | prompts.append(pts_i) 133 | prompts = torch.stack(prompts) 134 | 135 | return prompts 136 | 137 | 138 | 139 | class CustomCLIP(nn.Module): 140 | def __init__(self, args, classnames, clip_model): 141 | super().__init__() 142 | self.prompt_learner = PromptLearner(args, classnames, clip_model) 143 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 144 | self.image_encoder = clip_model.gnn 145 | self.text_encoder = TextEncoder(clip_model) 146 | self.logit_scale = clip_model.logit_scale 147 | self.dtype = clip_model.dtype 148 | 149 | def forward(self, s_n, x, adj): 150 | tokenized_prompts = self.tokenized_prompts 151 | logit_scale = self.logit_scale.exp() 152 | 153 | image_features = self.image_encoder(x, adj) 154 | image_features = image_features[s_n] 155 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 156 | 157 | prompts = self.prompt_learner(image_features) 158 | 159 | 160 | logits = [] 161 | for pts_i, imf_i in zip(prompts, image_features): 162 | text_features = self.text_encoder(pts_i, 
tokenized_prompts) 163 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 164 | l_i = logit_scale * imf_i @ text_features.t() 165 | logits.append(l_i) 166 | logits = torch.stack(logits) 167 | 168 | # text_features = self.text_encoder(prompts, tokenized_prompts) 169 | # image_features = image_features / image_features.norm(dim=-1, keepdim=True) 170 | # text_features = text_features / text_features.norm(dim=-1, keepdim=True) 171 | # logits = logit_scale * image_features @ text_features.t() 172 | 173 | return logits 174 | 175 | class CoOp(nn.Module): 176 | 177 | """Context Optimization (CoOp). 178 | Learning to Prompt for Vision-Language Models 179 | https://arxiv.org/abs/2109.01134 180 | """ 181 | 182 | def __init__(self, args, classnames, clip_model, device): 183 | super().__init__() 184 | self.args = args 185 | self.classnames = classnames 186 | self.model = CustomCLIP(args, classnames, clip_model) 187 | 188 | # name_to_update = "prompt_learner" 189 | # for name, param in self.model.named_parameters(): 190 | # if name_to_update not in name: 191 | # param.requires_grad_(False) 192 | 193 | # print("Turning off gradients in both the image and the text encoder") 194 | for name, param in self.model.named_parameters(): 195 | if "prompt_learner" not in name: 196 | param.requires_grad_(False) 197 | 198 | # NOTE: only give prompt_learner to the optimizer 199 | # self.optim = build_optimizer(self.model.prompt_learner, args.OPTIM) 200 | self.model.to(device) 201 | 202 | self.optim = optim.Adam(self.model.prompt_learner.parameters(), lr=args.prompt_lr) 203 | 204 | # for name, param in self.model.prompt_learner.named_parameters(): 205 | # print("name:", name) 206 | 207 | def forward(self, s_n, x, adj, label, training=True): 208 | 209 | logits = self.model(s_n, x, adj) 210 | loss = F.cross_entropy(logits, label) 211 | if training: 212 | self.optim.zero_grad() 213 | torch.cuda.empty_cache() 214 | loss.backward() 215 | self.optim.step() 216 | 217 | return logits, round(loss.clone().detach().cpu().item(), 4) 218 | 219 | -------------------------------------------------------------------------------- /meta_net/task_amazon.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | 5 | def multitask_data_generator(labels, labeled_node_list, select_array, k_spt, k_val, k_qry, n_way): 6 | labels_local = labels # .clone().detach() 7 | # random.shuffle(select_array_index) 8 | 9 | class_idx_list = [] 10 | train_class_list = [] 11 | val_class_list = [] 12 | test_class_list = [] 13 | for i in range(len(select_array)): 14 | class_idx_list.append([]) 15 | train_class_list.append([]) 16 | val_class_list.append([]) 17 | test_class_list.append([]) 18 | 19 | # for j in labeled_node_list: 20 | # for i in range(len(select_array)): 21 | # if (labels_local[j] == select_array[i]): 22 | # class_idx_list[i].append(j) 23 | for j in range(len(labeled_node_list)): 24 | for i in range(len(select_array)): 25 | if (labels_local[j] == select_array[i]): 26 | class_idx_list[i].append(labeled_node_list[j]) 27 | 28 | 29 | usable_labels = [] 30 | for i in range(len(class_idx_list)): 31 | if len(class_idx_list[i]) > 320: 32 | # if len(class_idx_list[i]) >= 100: 33 | usable_labels.append(i) 34 | # elif 15 > len(class_idx_list[i]) > 0: 35 | # new_classes.append(i) 36 | 37 | len_usable_labels = len(usable_labels) 38 | print('len_usable_labels', len_usable_labels) 39 | random.shuffle(usable_labels) 40 | 41 | base_classes = 
np.random.choice(usable_labels, len_usable_labels//2, replace=False).tolist() 42 | new_classes = list(set(usable_labels)-set(base_classes)) 43 | 44 | for i in range(len(select_array)): 45 | if i not in set(usable_labels): 46 | continue 47 | # train_class_list[i] = random.sample(class_idx_list[i], k_spt) 48 | train_class_list[i] = np.random.choice(class_idx_list[i], k_spt, replace=False).tolist() 49 | val_class_temp = [n1 for n1 in class_idx_list[i] if n1 not in train_class_list[i]] 50 | # print('val_class_temp', val_class_temp) 51 | # test_class_list[i] = random.sample(test_class_list[i], k_qry) 52 | val_class_list[i] = np.random.choice(val_class_temp, k_val, replace=False).tolist() 53 | # test_class_temp = [n1 for n1 in class_idx_list[i] if 54 | # (n1 not in train_class_list[i]) and (n1 not in val_class_list[i])] 55 | test_class_temp = [n1 for n1 in class_idx_list[i] if 56 | (n1 not in train_class_list[i]) and (n1 not in val_class_list[i])] 57 | # test_class_list[i] = [np.random.choice(test_class_temp, replace=False).item() for _ in range(k_qry)] 58 | test_class_list[i] = test_class_temp 59 | 60 | train_idx = [[], []] 61 | test_idx = [[], []] 62 | val_idx = [[], []] 63 | 64 | for j in base_classes: 65 | train_idx[0] += train_class_list[j] 66 | val_idx[0] += val_class_list[j] 67 | test_idx[0] += test_class_list[j] 68 | 69 | for j in new_classes: 70 | train_idx[1] += train_class_list[j] 71 | val_idx[1] += val_class_list[j] 72 | test_idx[1] += test_class_list[j] 73 | 74 | task_list = [base_classes, new_classes] 75 | 76 | return task_list, train_idx, val_idx, test_idx 77 | -------------------------------------------------------------------------------- /meta_net/task_cora.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | 5 | def multitask_data_generator(labels, labeled_node_list, select_array, k_spt, k_val, k_qry, n_way): 6 | labels_local = labels # .clone().detach() 7 | # random.shuffle(select_array_index) 8 | 9 | class_idx_list = [] 10 | train_class_list = [] 11 | val_class_list = [] 12 | test_class_list = [] 13 | for i in range(len(select_array)): 14 | class_idx_list.append([]) 15 | train_class_list.append([]) 16 | val_class_list.append([]) 17 | test_class_list.append([]) 18 | 19 | for j in labeled_node_list: 20 | for i in range(len(select_array)): 21 | if (labels_local[j] == select_array[i]): 22 | class_idx_list[i].append(j) 23 | 24 | usable_labels = [] 25 | for i in range(len(class_idx_list)): 26 | if len(class_idx_list[i]) > 10: 27 | # if len(class_idx_list[i]) >= 100: 28 | usable_labels.append(i) 29 | # elif 15 > len(class_idx_list[i]) > 0: 30 | # new_classes.append(i) 31 | 32 | len_usable_labels = len(usable_labels) 33 | print('len_usable_labels', len_usable_labels) 34 | random.shuffle(usable_labels) 35 | 36 | base_classes = np.random.choice(usable_labels, len_usable_labels//2, replace=False).tolist() 37 | new_classes = list(set(usable_labels)-set(base_classes)) 38 | 39 | for i in range(len(select_array)): 40 | if i not in set(usable_labels): 41 | continue 42 | # train_class_list[i] = random.sample(class_idx_list[i], k_spt) 43 | train_class_list[i] = np.random.choice(class_idx_list[i], k_spt, replace=False).tolist() 44 | val_class_temp = [n1 for n1 in class_idx_list[i] if n1 not in train_class_list[i]] 45 | # print('val_class_temp', val_class_temp) 46 | # test_class_list[i] = random.sample(test_class_list[i], k_qry) 47 | val_class_list[i] = np.random.choice(val_class_temp, k_val, 
replace=False).tolist() 48 | # test_class_temp = [n1 for n1 in class_idx_list[i] if 49 | # (n1 not in train_class_list[i]) and (n1 not in val_class_list[i])] 50 | test_class_temp = [n1 for n1 in class_idx_list[i] if 51 | (n1 not in train_class_list[i]) and (n1 not in val_class_list[i])] 52 | # test_class_list[i] = [np.random.choice(test_class_temp, replace=False).item() for _ in range(k_qry)] 53 | test_class_list[i] = test_class_temp 54 | 55 | train_idx = [[], []] 56 | test_idx = [[], []] 57 | val_idx = [[], []] 58 | 59 | for j in base_classes: 60 | train_idx[0] += train_class_list[j] 61 | val_idx[0] += val_class_list[j] 62 | test_idx[0] += test_class_list[j] 63 | 64 | for j in new_classes: 65 | train_idx[1] += train_class_list[j] 66 | val_idx[1] += val_class_list[j] 67 | test_idx[1] += test_class_list[j] 68 | 69 | task_list = [base_classes, new_classes] 70 | 71 | return task_list, train_idx, val_idx, test_idx 72 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | from typing import Any, Union, List 9 | from pkg_resources import packaging 10 | from simple_tokenizer import SimpleTokenizer as _Tokenizer 11 | from torch_geometric.nn.conv import MessagePassing 12 | from torch_scatter import scatter_add 13 | from torch_geometric.utils import add_remaining_self_loops 14 | from torch.nn import Parameter 15 | from torch import nn, optim 16 | 17 | _tokenizer = _Tokenizer() 18 | 19 | 20 | class LayerNorm(nn.LayerNorm): 21 | """Subclass torch's LayerNorm to handle fp16.""" 22 | 23 | def forward(self, x: torch.Tensor): 24 | orig_type = x.dtype 25 | ret = super().forward(x.type(torch.float32)) 26 | return ret.type(orig_type) 27 | 28 | 29 | class QuickGELU(nn.Module): 30 | def forward(self, x: torch.Tensor): 31 | return x * torch.sigmoid(1.702 * x) 32 | 33 | 34 | class ResidualAttentionBlock(nn.Module): 35 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 36 | super().__init__() 37 | 38 | self.attn = nn.MultiheadAttention(d_model, n_head) 39 | self.ln_1 = LayerNorm(d_model) 40 | self.mlp = nn.Sequential(OrderedDict([ 41 | ("c_fc", nn.Linear(d_model, d_model * 4)), 42 | ("gelu", QuickGELU()), 43 | ("c_proj", nn.Linear(d_model * 4, d_model)) 44 | ])) 45 | self.ln_2 = LayerNorm(d_model) 46 | self.attn_mask = attn_mask 47 | 48 | def attention(self, x: torch.Tensor): 49 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 50 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 51 | 52 | def forward(self, x: torch.Tensor): 53 | x = x + self.attention(self.ln_1(x)) 54 | x = x + self.mlp(self.ln_2(x)) 55 | return x 56 | 57 | 58 | class Transformer(nn.Module): 59 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 60 | super().__init__() 61 | self.width = width 62 | self.layers = layers 63 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 64 | 65 | def forward(self, x: torch.Tensor): 66 | return self.resblocks(x) 67 | 68 | 69 | class GNN(MessagePassing): 70 | def __init__(self, args, **kwargs): 71 | super(GNN, self).__init__(aggr='add', **kwargs) 72 | self.vars = nn.ParameterList() 73 | 74 | w = 
nn.Parameter(torch.ones([args.gnn_hid, args.gnn_input])) 75 | torch.nn.init.xavier_uniform_(w) 76 | self.vars.append(w) 77 | self.vars.append(nn.Parameter(torch.zeros(args.gnn_hid))) 78 | 79 | w = nn.Parameter(torch.ones([args.gnn_output, args.gnn_hid])) 80 | torch.nn.init.xavier_uniform_(w) 81 | self.vars.append(w) 82 | self.vars.append(nn.Parameter(torch.zeros(args.gnn_output))) 83 | 84 | @staticmethod 85 | def norm(edge_index, num_nodes, improved=False, dtype=None): 86 | edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, 87 | device=edge_index.device) 88 | 89 | fill_value = 1.0 if not improved else 2.0 90 | edge_index, edge_weight = add_remaining_self_loops( 91 | edge_index, edge_weight, fill_value, num_nodes) 92 | 93 | row, col = edge_index 94 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 95 | deg_inv_sqrt = deg.pow(-0.5) 96 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 97 | 98 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 99 | 100 | def forward(self, x, edge_index, vars=None): 101 | if vars is None: 102 | vars = self.vars 103 | improved = False 104 | 105 | w, b = vars[0], vars[1] 106 | edge_index, norm = self.norm(edge_index, x.size(self.node_dim), improved, x.dtype) 107 | x = self.propagate(edge_index, x=x, norm=norm) 108 | x = F.linear(x, w, b) 109 | x = F.leaky_relu(x) 110 | 111 | w, b = vars[2], vars[3] 112 | edge_index, norm = self.norm(edge_index, x.size(self.node_dim), improved, x.dtype) 113 | x = self.propagate(edge_index, x=x, norm=norm) 114 | x = F.linear(x, w, b) 115 | 116 | return x 117 | 118 | def parameters(self): 119 | return self.vars 120 | 121 | 122 | class CLIP(nn.Module): 123 | def __init__(self, 124 | args 125 | ): 126 | super().__init__() 127 | 128 | self.context_length = args.context_length 129 | self.args = args 130 | self.edge_coef = args.edge_coef 131 | 132 | self.gnn = GNN(args) 133 | self.transformer = Transformer( 134 | width=args.transformer_width, 135 | layers=args.transformer_layers, 136 | heads=args.transformer_heads, 137 | attn_mask=self.build_attention_mask() 138 | ) 139 | 140 | self.vocab_size = args.vocab_size 141 | self.token_embedding = nn.Embedding(args.vocab_size, 142 | args.transformer_width) # the embedding for all possible tokens 143 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, args.transformer_width)) 144 | self.ln_final = LayerNorm(args.transformer_width) 145 | 146 | self.text_projection = nn.Parameter(torch.empty(args.transformer_width, args.embed_dim)) 147 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 148 | 149 | self.dtype = self.gnn.vars[0].dtype 150 | 151 | self.optim = optim.Adam([{'params': self.token_embedding.weight}, 152 | {'params': self.positional_embedding}, 153 | {'params': self.transformer.parameters()}, 154 | {'params': self.text_projection}, 155 | {'params': self.gnn.parameters()} 156 | ], lr=args.lr) 157 | 158 | self.initialize_parameters() 159 | 160 | def initialize_parameters(self): 161 | nn.init.normal_(self.token_embedding.weight, std=0.02) 162 | nn.init.normal_(self.positional_embedding, std=0.01) 163 | 164 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 165 | attn_std = self.transformer.width ** -0.5 166 | fc_std = (2 * self.transformer.width) ** -0.5 167 | for block in self.transformer.resblocks: 168 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 169 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 170 | nn.init.normal_(block.mlp.c_fc.weight, 
std=fc_std) 171 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 172 | 173 | if self.text_projection is not None: 174 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 175 | 176 | def build_attention_mask(self): 177 | # lazily create causal attention mask, with full attention between the vision tokens 178 | # pytorch uses additive attention mask; fill with -inf 179 | mask = torch.empty(self.context_length, self.context_length) 180 | mask.fill_(float("-inf")) 181 | mask.triu_(1) # zero out the lower diagonal 182 | return mask 183 | 184 | def encode_image(self, idx_train, x, adj): 185 | embs = self.gnn(x, adj) 186 | train_embs = embs[idx_train] 187 | return train_embs 188 | 189 | def encode_text(self, text): 190 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 191 | 192 | x = x + self.positional_embedding.type(self.dtype) 193 | x = x.permute(1, 0, 194 | 2) # NLD -> LND, batch_size * context_length *emb_dim -> context_length * batch_size *emb_dim 195 | x = self.transformer(x) 196 | x = x.permute(1, 0, 197 | 2) # LND -> NLD, context_length * batch_size *emb_dim -> batch_size * context_length *emb_dim 198 | x = self.ln_final(x).type(self.dtype) 199 | # x.shape = [batch_size, n_ctx, transformer.width] 200 | # take features from the eot (end of token) embedding (eot_token is the highest number in each sequence) 201 | # so there is no need to shorten the context length 202 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] # 203 | x = x @ self.text_projection 204 | return x 205 | 206 | def forward(self, x, adj, s_n, t_n, s_n_text, t_n_text, device, training=True): 207 | 208 | s_image_features = self.encode_image(s_n, x, adj) 209 | 210 | s_text_features = self.encode_text(s_n_text) 211 | 212 | t_text_features = self.encode_text(t_n_text) 213 | t_text_features = t_text_features.reshape(s_image_features.shape[0], self.args.neigh_num, self.args.gnn_output) 214 | t_text_features = torch.mean(t_text_features, dim=1, keepdim=False) 215 | # normalized features 216 | s_image_features = s_image_features / s_image_features.norm(dim=-1, keepdim=True) 217 | s_text_features = s_text_features / s_text_features.norm(dim=-1, keepdim=True) 218 | t_text_features = t_text_features / t_text_features.norm(dim=-1, keepdim=True) 219 | # cosine similarity as logits 220 | 221 | labels = torch.arange(s_image_features.shape[0]).to(device) 222 | 223 | logit_scale = self.logit_scale.exp() # the temperature hyperparameter 224 | logits = logit_scale * s_image_features @ s_text_features.t() 225 | loss_i = F.cross_entropy(logits, labels) 226 | loss_t = F.cross_entropy(logits.T, labels) 227 | node_loss = (loss_i + loss_t) / 2 228 | 229 | logits = logit_scale * s_image_features @ t_text_features.t() 230 | loss_i = F.cross_entropy(logits, labels) 231 | loss_t = F.cross_entropy(logits.T, labels) 232 | gt_loss = (loss_i + loss_t)/2 233 | 234 | logits = logit_scale * s_text_features @ t_text_features.t() 235 | loss_i = F.cross_entropy(logits, labels) 236 | loss_t = F.cross_entropy(logits.T, labels) 237 | tt_loss = (loss_i + loss_t)/2 238 | 239 | all_loss = node_loss + self.edge_coef * gt_loss + self.edge_coef * tt_loss 240 | 241 | if training: 242 | self.optim.zero_grad() 243 | torch.cuda.empty_cache() 244 | all_loss.backward() 245 | self.optim.step() 246 | 247 | # shape = [global_batch_size, global_batch_size] 248 | return round((all_loss.detach().clone()).cpu().item(), 4) 249 | 250 | 251 | def tokenize(texts: Union[str, List[str]], context_length: int = 
128, truncate: bool = True) -> torch.LongTensor: 252 | 253 | """ 254 | Returns the tokenized representation of given input string(s) 255 | 256 | Parameters 257 | ---------- 258 | texts : Union[str, List[str]] 259 | An input string or a list of input strings to tokenize 260 | 261 | context_length : int 262 | The context length to use; all CLIP models use 77 as the context length 263 | 264 | truncate: bool 265 | Whether to truncate the text in case its encoding is longer than the context length 266 | 267 | Returns 268 | ------- 269 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 270 | """ 271 | if isinstance(texts, str): 272 | texts = [texts] 273 | 274 | sot_token = _tokenizer.encoder["<|startoftext|>"] 275 | eot_token = _tokenizer.encoder["<|endoftext|>"] 276 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 277 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 278 | 279 | for i, tokens in enumerate(all_tokens): 280 | if len(tokens) > context_length: 281 | if truncate: 282 | tokens = tokens[:context_length] 283 | tokens[-1] = eot_token 284 | else: 285 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 286 | result[i, :len(tokens)] = torch.tensor(tokens) 287 | 288 | return result 289 | 290 | -------------------------------------------------------------------------------- /model_g_coop.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from torch import optim 7 | import model 8 | from simple_tokenizer import SimpleTokenizer as _Tokenizer 9 | 10 | _tokenizer = _Tokenizer() 11 | 12 | 13 | class TextEncoder(nn.Module): 14 | def __init__(self, clip_model): 15 | super().__init__() 16 | self.transformer = clip_model.transformer 17 | self.positional_embedding = clip_model.positional_embedding 18 | self.ln_final = clip_model.ln_final 19 | self.text_projection = clip_model.text_projection 20 | self.dtype = clip_model.dtype 21 | 22 | def forward(self, prompts, tokenized_prompts): 23 | x = prompts + self.positional_embedding.type(self.dtype) 24 | x = x.permute(1, 0, 2) # NLD -> LND 25 | x = self.transformer(x) 26 | x = x.permute(1, 0, 2) # LND -> NLD 27 | x = self.ln_final(x).type(self.dtype) 28 | 29 | # x.shape = [batch_size, n_ctx, transformer.width] 30 | # take features from the eot embedding (eot_token is the highest number in each sequence) 31 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection 32 | 33 | return x 34 | 35 | 36 | class PromptLearner(nn.Module): 37 | def __init__(self, args, classnames, clip_model, g_texts): 38 | super().__init__() 39 | self.vars = nn.ParameterList() 40 | n_cls = len(classnames) 41 | n_ctx = args.coop_n_ctx 42 | dtype = clip_model.dtype 43 | ctx_dim = clip_model.ln_final.weight.shape[0] 44 | 45 | # random initialization 46 | if args.ctx_init: 47 | # use given words to initialize context vectors 48 | if args.class_specific: 49 | ctx_vectors = [] 50 | for ctx_list in g_texts: 51 | prompt = model.tokenize(ctx_list, context_length=args.context_length) 52 | with torch.no_grad(): 53 | embedding = clip_model.token_embedding(prompt).type(dtype) 54 | ctx_vector = embedding[:, 1: 1 + n_ctx, :] 55 | ctx_vector = torch.mean(ctx_vector, dim=0) 56 | ctx_vectors.append(ctx_vector) 57 | ctx_vectors = torch.stack(ctx_vectors) 58 | else: 
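# Unified (non-class-specific) context initialization: pool all graph-derived
# texts in g_texts together, embed them with the frozen token embedding, keep
# the first n_ctx token vectors of each prompt, and average over all prompts to
# obtain a single shared context tensor ctx_vectors.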
59 | temp = [] 60 | for ctx_list in g_texts: 61 | temp += ctx_list 62 | prompt = model.tokenize(temp, context_length=args.context_length) 63 | with torch.no_grad(): 64 | embedding = clip_model.token_embedding(prompt).type(dtype) 65 | ctx_vector = embedding[:, 1: 1 + n_ctx, :] 66 | ctx_vectors = torch.mean(ctx_vector, dim=0) 67 | # print('ctx_vectors.shape', ctx_vectors.shape) 68 | else: 69 | if args.class_specific: 70 | # print("Initializing class-specific contexts") 71 | ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype) 72 | else: 73 | # print("Initializing a generic context") 74 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 75 | nn.init.normal_(ctx_vectors, std=0.02) 76 | prompt_prefix = " ".join(["X"] * n_ctx) 77 | 78 | # print(f'Initial context: "{prompt_prefix}"') 79 | # print(f"Number of context words (tokens): {n_ctx}") 80 | 81 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized 82 | self.vars.append(self.ctx) 83 | 84 | classnames = [name.replace("_", " ") for name in classnames] 85 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 86 | prompts = [prompt_prefix + " " + name + "." for name in classnames] 87 | 88 | tokenized_prompts = torch.cat( 89 | [model.tokenize(p, context_length=args.context_length) for p in prompts]) 90 | with torch.no_grad(): 91 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 92 | 93 | # These token vectors will be saved when in save_model(), 94 | # but they should be ignored in load_model() as we want to use 95 | # those computed using the current class names 96 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 97 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :]) # CLS, EOS 98 | 99 | self.n_cls = n_cls 100 | self.n_ctx = n_ctx 101 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 102 | self.name_lens = name_lens 103 | self.class_token_position = args.position 104 | 105 | def forward(self): 106 | ctx = self.ctx 107 | if ctx.dim() == 2: 108 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) 109 | 110 | prefix = self.token_prefix 111 | suffix = self.token_suffix 112 | 113 | if self.class_token_position == "end": 114 | prompts = torch.cat( 115 | [ 116 | prefix, # (n_cls, 1, dim) 117 | ctx, # (n_cls, n_ctx, dim) 118 | suffix, # (n_cls, *, dim) 119 | ], 120 | dim=1, 121 | ) 122 | 123 | elif self.class_token_position == "middle": 124 | half_n_ctx = self.n_ctx // 2 125 | prompts = [] 126 | for i in range(self.n_cls): 127 | name_len = self.name_lens[i] 128 | prefix_i = prefix[i: i + 1, :, :] 129 | class_i = suffix[i: i + 1, :name_len, :] 130 | suffix_i = suffix[i: i + 1, name_len:, :] 131 | ctx_i_half1 = ctx[i: i + 1, :half_n_ctx, :] 132 | ctx_i_half2 = ctx[i: i + 1, half_n_ctx:, :] 133 | prompt = torch.cat( 134 | [ 135 | prefix_i, # (1, 1, dim) 136 | ctx_i_half1, # (1, n_ctx//2, dim) 137 | class_i, # (1, name_len, dim) 138 | ctx_i_half2, # (1, n_ctx//2, dim) 139 | suffix_i, # (1, *, dim) 140 | ], 141 | dim=1, 142 | ) 143 | prompts.append(prompt) 144 | prompts = torch.cat(prompts, dim=0) 145 | 146 | elif self.class_token_position == "front": 147 | prompts = [] 148 | for i in range(self.n_cls): 149 | name_len = self.name_lens[i] 150 | prefix_i = prefix[i: i + 1, :, :] 151 | class_i = suffix[i: i + 1, :name_len, :] 152 | suffix_i = suffix[i: i + 1, name_len:, :] 153 | ctx_i = ctx[i: i + 1, :, :] 154 | prompt = torch.cat( 155 | [ 156 | prefix_i, # (1, 1, dim) 157 | class_i, # (1, name_len, dim) 158 | ctx_i, # (1, n_ctx, dim) 159 | suffix_i, # (1, *, dim) 160 
| ], 161 | dim=1, 162 | ) 163 | prompts.append(prompt) 164 | prompts = torch.cat(prompts, dim=0) 165 | 166 | else: 167 | raise ValueError 168 | 169 | return prompts 170 | 171 | def parameters(self): 172 | return self.vars 173 | 174 | 175 | class CustomCLIP(nn.Module): 176 | def __init__(self, args, classnames, clip_model, g_texts): 177 | super().__init__() 178 | self.prompt_learner = PromptLearner(args, classnames, clip_model, g_texts) 179 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 180 | self.image_encoder = clip_model.gnn 181 | self.text_encoder = TextEncoder(clip_model) 182 | self.logit_scale = clip_model.logit_scale 183 | self.dtype = clip_model.dtype 184 | 185 | def forward(self, s_n, x, adj): 186 | image_features = self.image_encoder(x, adj) 187 | image_features = image_features[s_n] 188 | prompts = self.prompt_learner() 189 | tokenized_prompts = self.tokenized_prompts 190 | text_features = self.text_encoder(prompts, tokenized_prompts) 191 | 192 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 193 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 194 | 195 | logit_scale = self.logit_scale.exp() 196 | logits = logit_scale * image_features @ text_features.t() 197 | 198 | return logits 199 | 200 | 201 | class CoOp(nn.Module): 202 | """Context Optimization (CoOp). 203 | Learning to Prompt for Vision-Language Models 204 | https://arxiv.org/abs/2109.01134 205 | """ 206 | 207 | def __init__(self, args, classnames, clip_model, g_texts, device): 208 | super().__init__() 209 | self.args = args 210 | self.classnames = classnames 211 | self.model = CustomCLIP(args, classnames, clip_model, g_texts) 212 | 213 | # print("Turning off gradients in both the image and the text encoder") 214 | for name, param in self.model.named_parameters(): 215 | if "prompt_learner" not in name: 216 | param.requires_grad_(False) 217 | 218 | # NOTE: only give prompt_learner to the optimizer 219 | # self.optim = build_optimizer(self.model.prompt_learner, args.OPTIM) 220 | self.model.to(device) 221 | 222 | self.optim = optim.Adam(self.model.prompt_learner.parameters(), lr=args.prompt_lr) 223 | 224 | def forward(self, s_n, x, adj, label, training=True): 225 | 226 | logits = self.model(s_n, x, adj) 227 | if training: 228 | loss = F.cross_entropy(logits, label) 229 | self.optim.zero_grad() 230 | torch.cuda.empty_cache() 231 | loss.backward() 232 | self.optim.step() 233 | 234 | return logits 235 | -------------------------------------------------------------------------------- /multitask.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | def multitask_data_generator(labels, labeled_node_list, select_array, k_spt, k_val, k_qry, n_way): 5 | labels_local = labels 6 | 7 | class_idx_list = [] 8 | train_class_list = [] 9 | val_class_list = [] 10 | test_class_list = [] 11 | for i in range(len(select_array)): 12 | class_idx_list.append([]) 13 | train_class_list.append([]) 14 | val_class_list.append([]) 15 | test_class_list.append([]) 16 | 17 | for j in labeled_node_list: 18 | for i in range(len(select_array)): 19 | if (labels_local[j] == select_array[i]): 20 | class_idx_list[i].append(j) 21 | 22 | usable_labels = [] 23 | for i in range(len(class_idx_list)): 24 | if len(class_idx_list[i]) >= 30: 25 | usable_labels.append(i) 26 | 27 | random.shuffle(usable_labels) 28 | task_list = [] 29 | for i in range(len(usable_labels) // n_way): 30 | task_idx = usable_labels[i * n_way:(i + 1) 
* n_way] 31 | task_list.append(task_idx) 32 | 33 | for i in range(len(select_array)): 34 | if i not in set(usable_labels): 35 | continue 36 | train_class_list[i] = np.random.choice(class_idx_list[i], k_spt, replace=False).tolist() 37 | val_class_temp = [n1 for n1 in class_idx_list[i] if n1 not in train_class_list[i]] 38 | val_class_list[i] = np.random.choice(val_class_temp, k_val, replace=False).tolist() 39 | test_class_temp = [n1 for n1 in class_idx_list[i] if 40 | (n1 not in train_class_list[i]) and (n1 not in val_class_list[i])] 41 | test_class_list[i] = test_class_temp 42 | 43 | train_idx = [] 44 | test_idx = [] 45 | val_idx = [] 46 | 47 | for i in range(len(task_list)): 48 | train_idx.append([]) 49 | test_idx.append([]) 50 | val_idx.append([]) 51 | # print(task_list[i]) 52 | for j in task_list[i]: 53 | train_idx[i] += train_class_list[j] 54 | val_idx[i] += val_class_list[j] 55 | test_idx[i] += test_class_list[j] 56 | 57 | return task_list, train_idx, val_idx, test_idx 58 | 59 | 60 | -------------------------------------------------------------------------------- /multitask_amazon.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | 5 | def multitask_data_generator(labels, labeled_node_list, select_array, k_spt, k_val, k_qry, n_way): 6 | labels_local = labels 7 | 8 | class_idx_list = [] 9 | train_class_list = [] 10 | val_class_list = [] 11 | test_class_list = [] 12 | for i in range(len(select_array)): 13 | class_idx_list.append([]) 14 | train_class_list.append([]) 15 | val_class_list.append([]) 16 | test_class_list.append([]) 17 | 18 | for j in range(len(labeled_node_list)): 19 | for i in range(len(select_array)): 20 | if (labels_local[j] == select_array[i]): 21 | class_idx_list[i].append(labeled_node_list[j]) 22 | # labels_local[j] = 0 23 | 24 | usable_labels = [] 25 | for i in range(len(class_idx_list)): 26 | if len(class_idx_list[i]) >= 30: 27 | # if len(class_idx_list[i]) >= 100: 28 | usable_labels.append(i) 29 | 30 | random.shuffle(usable_labels) 31 | task_list = [] 32 | for i in range(len(usable_labels) // n_way): 33 | task_idx = usable_labels[i * n_way:(i + 1) * n_way] 34 | task_list.append(task_idx) 35 | 36 | for i in range(len(select_array)): 37 | if i not in set(usable_labels): 38 | continue 39 | train_class_list[i] = np.random.choice(class_idx_list[i], k_spt, replace=False).tolist() 40 | val_class_temp = [n1 for n1 in class_idx_list[i] if n1 not in train_class_list[i]] 41 | val_class_list[i] = np.random.choice(val_class_temp, k_val, replace=False).tolist() 42 | test_class_temp = [n1 for n1 in class_idx_list[i] if 43 | (n1 not in train_class_list[i]) and (n1 not in val_class_list[i])] 44 | test_class_list[i] = test_class_temp 45 | 46 | train_idx = [] 47 | test_idx = [] 48 | val_idx = [] 49 | 50 | for i in range(len(task_list)): 51 | train_idx.append([]) 52 | test_idx.append([]) 53 | val_idx.append([]) 54 | # print(task_list[i]) 55 | for j in task_list[i]: 56 | train_idx[i] += train_class_list[j] 57 | val_idx[i] += val_class_list[j] 58 | test_idx[i] += test_class_list[j] 59 | 60 | return task_list, train_idx, val_idx, test_idx 61 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0+cu113 2 | torch-cluster==1.6.0 3 | torch-geometric==2.0.4 4 | torch-scatter==2.0.9 5 | torch-sparse==0.6.13 6 | torch-spline-conv==1.2.1 7 | torchaudio==0.11.0+cu113 8 | 
torchvision==0.12.0+cu113 9 | transformers==4.19.2 10 | -------------------------------------------------------------------------------- /res/cora/init.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns a list of utf-8 bytes and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a significant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | It also avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /zero-shot/datahelper.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import sys 5 | import random 6 | import copy 7 | import torch 8 | from sklearn import preprocessing 9 | # np.random.seed(1) 10 | 11 | class DataHelper(Dataset): 12 | def __init__(self, node_list, transform=None): 13 | self.num_nodes = len(node_list) 14 | 
self.transform = transform 15 | self.node_list = node_list 16 | 17 | def __len__(self): 18 | return self.num_nodes 19 | 20 | def __getitem__(self, idx): 21 | 22 | # node_idx = np.arange(self.num_nodes)[idx] 23 | node_idx = self.node_list[idx] 24 | 25 | sample = { 26 | 'node_idx': node_idx, 27 | } 28 | 29 | if self.transform: 30 | sample = self.transform(sample) 31 | 32 | return sample 33 | -------------------------------------------------------------------------------- /zero-shot/zero-shot-amazon.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import torch 4 | from random import sample 5 | import random 6 | import math 7 | import time 8 | from model import CLIP, tokenize 9 | from torch import nn, optim 10 | from sklearn import preprocessing 11 | from sklearn.metrics import accuracy_score, f1_score 12 | from multitask_amazon import multitask_data_generator 13 | from model_g_coop import CoOp 14 | import json 15 | from datahelper import DataHelper 16 | from torch.utils.data import DataLoader 17 | 18 | 19 | 20 | 21 | def setup_seed(seed): 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | random.seed(seed) 26 | np.random.seed(seed) 27 | torch.backends.cudnn.deterministic = True 28 | 29 | 30 | def main(args): 31 | setup_seed(seed) 32 | 33 | model = CLIP(args) 34 | model.load_state_dict(torch.load('./res/{}/node_ttgt_8&12_10.pkl'.format(args.data_name), map_location=device)) 35 | 36 | task_list, train_idx, val_idx, test_idx = multitask_data_generator(lab_list, labeled_ids, labels, args.k_spt, 37 | args.k_val, args.k_qry, args.n_way) 38 | all_acc = [] 39 | f1_list = [] 40 | for j in range(len(task_list)): 41 | 42 | test_gt = [] 43 | for a in test_idx[j]: 44 | test_gt.append(id_lab_dict[str(a)]) 45 | model.eval() 46 | task_lables_arr = np.array(labels)[task_list[j]] 47 | task_lables = task_lables_arr.tolist() 48 | 49 | task_prompt = [] 50 | for a in range(len(task_lables)): 51 | prompt = the_template + ' ' + task_lables[a] 52 | task_prompt.append(prompt) 53 | # print('task_prompt', task_prompt) 54 | test_labels = tokenize(task_prompt, context_length=args.context_length).to(device) 55 | with torch.no_grad(): 56 | syn_class = model.encode_text(test_labels) 57 | 58 | Data = DataHelper(test_idx[j]) 59 | loader = DataLoader(Data, batch_size=args.batch_size, shuffle=False, num_workers=0) 60 | node_feas = [] 61 | for i_batch, sample_batched in enumerate(loader): 62 | idx_train = sample_batched['node_idx'].to(device) 63 | with torch.no_grad(): 64 | node_fea = model.encode_image(idx_train, node_f, edge_index) 65 | node_feas.append(node_fea) 66 | 67 | node_feas = torch.cat(node_feas, dim=0) 68 | 69 | # node_feas = torch.cat(node_feas, dim=0) 70 | syn_class /= syn_class.norm(dim=-1, keepdim=True) 71 | node_feas /= node_feas.norm(dim=-1, keepdim=True) 72 | similarity = (100.0 * node_feas @ syn_class.T).softmax(dim=-1) 73 | pred = similarity.argmax(dim=-1) 74 | pred = pred.cpu().numpy().reshape(-1) 75 | y_pred = task_lables_arr[pred] 76 | acc = accuracy_score(test_gt, y_pred) 77 | all_acc.append(acc) 78 | f1 = f1_score(test_gt, y_pred, average='macro') 79 | f1_list.append(f1) 80 | 81 | ans = round(np.mean(all_acc).item(), 4) 82 | print('zero shot acc', ans) 83 | 84 | ans = round(np.mean(f1_list).item(), 4) 85 | print('macro f1', ans) 86 | 87 | 88 | if __name__ == '__main__': 89 | parser = argparse.ArgumentParser() 90 | 91 | parser.add_argument('--aggregation_times', 
type=int, default=2, help='Aggregation times') 92 | parser.add_argument('--ft_epoch', type=int, default=50, help='fine-tune epoch') 93 | parser.add_argument('--lr', type=float, default=2e-5) 94 | 95 | parser.add_argument('--batch_size', type=int, default=1000) 96 | parser.add_argument('--gnn_input', type=int, default=128) 97 | parser.add_argument('--gnn_hid', type=int, default=128) 98 | parser.add_argument('--gnn_output', type=int, default=128) 99 | 100 | parser.add_argument('--edge_coef', type=float, default=0.1) 101 | parser.add_argument('--neigh_num', type=int, default=3) 102 | 103 | parser.add_argument('--num_labels', type=int, default=5) 104 | parser.add_argument('--k_spt', type=int, default=5) 105 | parser.add_argument('--k_val', type=int, default=5) 106 | parser.add_argument('--k_qry', type=int, default=50) 107 | parser.add_argument('--n_way', type=int, default=5) 108 | 109 | parser.add_argument('--context_length', type=int, default=128) 110 | parser.add_argument('--coop_n_ctx', type=int, default=4) 111 | parser.add_argument('--prompt_lr', type=float, default=0.01) 112 | 113 | parser.add_argument('--position', type=str, default='end') 114 | parser.add_argument('--class_specific', type=bool, default=False) 115 | parser.add_argument('--ctx_init', type=bool, default=True) 116 | 117 | parser.add_argument('--embed_dim', type=int, default=128) 118 | parser.add_argument('--transformer_heads', type=int, default=8) 119 | parser.add_argument('--transformer_layers', type=int, default=12) 120 | parser.add_argument('--transformer_width', type=int, default=512) 121 | parser.add_argument('--vocab_size', type=int, default=49408) 122 | # parser.add_argument('--data_name', type=str, default="Arts_Crafts_and_Sewing") 123 | # parser.add_argument('--data_name', type=str, default="Industrial_and_Scientific") 124 | parser.add_argument('--data_name', type=str, default="Musical_Instruments") 125 | parser.add_argument('--gpu', type=int, default=0) 126 | 127 | args = parser.parse_args() 128 | 129 | device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu") 130 | print('device:', device) 131 | 132 | num_nodes = 0 133 | tit_list = [] 134 | tit_dict = json.load(open('./data/{}_text.json'.format(args.data_name))) 135 | new_dict = {} 136 | 137 | for i in range(len(tit_dict)): 138 | num_nodes += 1 139 | new_dict[i] = tit_dict[str(i)] 140 | 141 | print('num_nodes', num_nodes) 142 | 143 | edge_index = np.load('./data/{}_edge.npy'.format(args.data_name)) 144 | 145 | arr_edge_index = edge_index 146 | 147 | edge_index = torch.from_numpy(edge_index).to(device) 148 | 149 | node_f = np.load('./data/{}_f_m.npy'.format(args.data_name)) 150 | node_f = preprocessing.StandardScaler().fit_transform(node_f) 151 | node_f = torch.from_numpy(node_f).to(device) 152 | 153 | id_lab_dict = json.load(open('./data/{}_id_labels.json'.format(args.data_name))) 154 | id_lab_list = sorted(id_lab_dict.items(), key=lambda d: int(d[0])) 155 | 156 | labeled_ids = [] 157 | lab_list = [] 158 | for i in id_lab_list: 159 | if i[1] != 'nan' and i[1] != '' and i[1] != ' ': # keep only nodes with a real label 160 | labeled_ids.append(int(i[0])) 161 | lab_list.append(i[1]) 162 | 163 | labels = sorted(list(set(lab_list))) 164 | 165 | start = time.perf_counter() 166 | 167 | the_list = ['', 'a ', 'an ', 'of ', 'art ', 'sewing ', 'art of ', 'sewing of ', 'arts crafts of ', 'arts crafts or sewing of ', 'an arts crafts or sewing of '] 168 | the_template = the_list[0] 169 | seed = 1 170 | print('seed', seed) 171 | main(args) 172 | end = time.perf_counter() 173 | 
print("time consuming {:.2f}".format(end - start)) 174 | 175 | -------------------------------------------------------------------------------- /zero-shot/zero-shot-cora.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import torch 4 | from random import sample 5 | import random 6 | import math 7 | import time 8 | from model import CLIP, tokenize 9 | from torch import nn, optim 10 | from sklearn import preprocessing 11 | from sklearn.metrics import accuracy_score, f1_score 12 | # from multitask_2 import multitask_data_generator 13 | from multitask import multitask_data_generator 14 | from model_g_coop import CoOp 15 | import json 16 | from datahelper import DataHelper 17 | from torch.utils.data import DataLoader 18 | 19 | 20 | 21 | 22 | def setup_seed(seed): 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | random.seed(seed) 27 | np.random.seed(seed) 28 | torch.backends.cudnn.deterministic = True 29 | 30 | 31 | def main(args): 32 | setup_seed(seed) 33 | 34 | model = CLIP(args) 35 | model.load_state_dict(torch.load('./res/{}/node_ttgt_8&12_0.1.pkl'.format(data_name), map_location=device)) 36 | 37 | task_list, train_idx, val_idx, test_idx = multitask_data_generator(lab_list, labeled_ids, labels, args.k_spt, 38 | args.k_val, args.k_qry, args.n_way) 39 | all_acc = [] 40 | f1_list = [] 41 | for j in range(len(task_list)): 42 | # test_gt = np.array(lab_list)[np.array(test_idx[i])] 43 | 44 | test_gt = np.array(lab_list)[np.array(test_idx[j])] 45 | model.eval() 46 | task_lables_arr = np.array(labels)[task_list[j]] 47 | task_lables = task_lables_arr.tolist() 48 | 49 | task_prompt = [] 50 | for a in range(len(task_lables)): 51 | prompt = the_template + task_lables[a] 52 | task_prompt.append(prompt) 53 | # print('task_prompt', task_prompt) 54 | test_labels = tokenize(task_prompt, context_length=args.context_length).to(device) 55 | with torch.no_grad(): 56 | syn_class = model.encode_text(test_labels) 57 | 58 | Data = DataHelper(test_idx[j]) 59 | loader = DataLoader(Data, batch_size=args.batch_size, shuffle=False, num_workers=0) 60 | node_feas = [] 61 | for i_batch, sample_batched in enumerate(loader): 62 | idx_train = sample_batched['node_idx'].to(device) 63 | with torch.no_grad(): 64 | node_fea = model.encode_image(idx_train, node_f, edge_index) 65 | node_feas.append(node_fea) 66 | 67 | node_feas = torch.cat(node_feas, dim=0) 68 | 69 | syn_class /= syn_class.norm(dim=-1, keepdim=True) 70 | node_feas /= node_feas.norm(dim=-1, keepdim=True) 71 | similarity = (100.0 * node_feas @ syn_class.T).softmax(dim=-1) 72 | pred = similarity.argmax(dim=-1) 73 | pred = pred.cpu().numpy().reshape(-1) 74 | y_pred = task_lables_arr[pred] 75 | acc = accuracy_score(test_gt, y_pred) 76 | all_acc.append(acc) 77 | f1 = f1_score(test_gt, y_pred, average='macro') 78 | f1_list.append(f1) 79 | 80 | ans = round(np.mean(all_acc).item(), 4) 81 | print('zero shot acc', ans) 82 | 83 | ans = round(np.mean(f1_list).item(), 4) 84 | print('macro f1', ans) 85 | 86 | 87 | if __name__ == '__main__': 88 | parser = argparse.ArgumentParser() 89 | 90 | parser.add_argument('--aggregation_times', type=int, default=2, help='Aggregation times') 91 | parser.add_argument('--ft_epoch', type=int, default=50, help='fine-tune epoch') 92 | parser.add_argument('--lr', type=float, default=2e-5) 93 | 94 | parser.add_argument('--batch_size', type=int, default=64) 95 | parser.add_argument('--gnn_input', type=int, 
default=128) 96 | parser.add_argument('--gnn_hid', type=int, default=128) 97 | parser.add_argument('--gnn_output', type=int, default=128) 98 | 99 | parser.add_argument('--edge_coef', type=float, default=0.1) 100 | parser.add_argument('--neigh_num', type=int, default=3) 101 | 102 | parser.add_argument('--num_labels', type=int, default=5) 103 | parser.add_argument('--k_spt', type=int, default=5) 104 | parser.add_argument('--k_val', type=int, default=5) 105 | parser.add_argument('--k_qry', type=int, default=50) 106 | parser.add_argument('--n_way', type=int, default=5) 107 | 108 | parser.add_argument('--context_length', type=int, default=128) 109 | parser.add_argument('--coop_n_ctx', type=int, default=4) 110 | parser.add_argument('--prompt_lr', type=float, default=0.01) 111 | 112 | parser.add_argument('--position', type=str, default='end') 113 | parser.add_argument('--class_specific', type=bool, default=False) 114 | parser.add_argument('--ctx_init', type=bool, default=True) 115 | 116 | parser.add_argument('--embed_dim', type=int, default=128) 117 | parser.add_argument('--transformer_heads', type=int, default=8) 118 | parser.add_argument('--transformer_layers', type=int, default=12) 119 | parser.add_argument('--transformer_width', type=int, default=512) 120 | parser.add_argument('--vocab_size', type=int, default=49408) 121 | parser.add_argument('--gpu', type=int, default=0) 122 | 123 | args = parser.parse_args() 124 | 125 | data_name = 'cora' 126 | device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu") 127 | print('device:', device) 128 | 129 | num_nodes = 0 130 | tit_list = [] 131 | lab_list = [] 132 | with open('./data/train_text.txt', 'r') as f: 133 | lines = f.readlines() 134 | for line in lines: 135 | line = line.strip().split('\t') 136 | tit_list.append(line[2]) 137 | lab_list.append(line[3]) 138 | num_nodes += 1 139 | 140 | print('num_nodes', num_nodes) 141 | 142 | labeled_ids = [] 143 | for i in range(len(lab_list)): 144 | if lab_list[i] != 'nan': 145 | labeled_ids.append(i) 146 | 147 | print('{} nodes having labels'.format(len(labeled_ids))) 148 | 149 | raw_edge_index = [[], []] 150 | with open('./data/mapped_edges.txt', 'r') as f: 151 | lines = f.readlines() 152 | for line in lines: 153 | line = line.strip().split() 154 | raw_edge_index[0].append(int(line[0])) 155 | raw_edge_index[1].append(int(line[1])) 156 | 157 | edge_index = [raw_edge_index[0] + raw_edge_index[1], raw_edge_index[1] + raw_edge_index[0]] 158 | arr_edge_index = np.array(edge_index) 159 | edge_index = np.array(edge_index) 160 | edge_index = torch.from_numpy(edge_index).to(device) 161 | 162 | node_f = np.load('./data/node_f.npy') 163 | node_f = preprocessing.StandardScaler().fit_transform(node_f) 164 | node_f = torch.from_numpy(node_f).to(device) 165 | 166 | # label_texts = [] 167 | with open('./data/lab_list.txt', 'r') as f: 168 | line = f.readline().strip().split('\t') 169 | label_texts = line 170 | 171 | labels = [] 172 | for i in label_texts: 173 | if i != 'nan': 174 | labels.append(i) 175 | 176 | start = time.perf_counter() 177 | 178 | the_list = ['', 'a ', 'an ', 'of ', 'paper of ', 'research of ', 'a paper of ', 'a research of ', 'a model of ', 179 | 'research paper of ', 'a research paper of '] 180 | 181 | 182 | the_template = the_list[0] 183 | seed = 1 184 | print('seed', seed) 185 | main(args) 186 | end = time.perf_counter() 187 | print("time consuming {:.2f}".format(end - start)) 188 | --------------------------------------------------------------------------------