├── LICENSE
├── README.md
├── api.py
├── args.py
├── async_api.py
├── baseline.py
├── data.py
├── dgl_main.py
├── ego_graph.py
├── few_shot_samples.py
├── gcn_lib
│   ├── __init__.py
│   ├── dense
│   │   ├── __init__.py
│   │   ├── torch_edge.py
│   │   ├── torch_nn.py
│   │   └── torch_vertex.py
│   └── sparse
│       ├── __init__.py
│       ├── torch_edge.py
│       ├── torch_message.py
│       ├── torch_nn.py
│       └── torch_vertex.py
├── generate_pyg_data.py
├── hyper.py
├── imgs
│   ├── README.md
│   ├── llm_as_enhancer.png
│   └── llm_as_predictor.png
├── lmfinetune.py
├── models.py
├── ogbn_products.py
├── ood.py
├── rev
│   ├── __init__.py
│   ├── gcn_revop.py
│   ├── memgcn.py
│   └── rev_layer.py
├── secret.yaml
├── train_utils.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 1998czk
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Exploring the Potential of Large Language Models (LLMs) in Learning on Graphs
2 | 
3 | 
4 | **UPDATE**: The pt file for Citeseer has some problems. Please use the latest version `citeseer2` instead of the version inside `small_data.zip`. We use [GraphCleaner](https://github.com/lywww/GraphCleaner/tree/master/case_studies) to fix wrong labels.
5 | 
6 | This is the official code repository for our paper [Exploring the Potential of Large Language Models (LLMs) in Learning on Graphs](https://arxiv.org/abs/2307.03393)
7 | 
8 | Follow-up works:
9 | 
10 | [LLM-GNN](https://github.com/CurryTang/LLMGNN)
11 | [TSGFM](https://github.com/CurryTang/TSGFM)
12 | 
13 | 
14 | ## Introduction
15 | Learning on Graphs has attracted immense attention due to its wide real-world applications. The most popular pipeline for learning on graphs with textual node attributes primarily relies on Graph Neural Networks (GNNs), and utilizes shallow text embeddings as initial node representations, which are limited in general knowledge and deep semantic understanding. In recent years, Large Language Models (LLMs) have been proven to possess extensive common knowledge and powerful semantic comprehension abilities that have revolutionized existing workflows for handling text data. 
In this paper, we aim to explore the potential of LLMs in graph machine learning, especially for the node classification task, and investigate two possible pipelines: LLMs-as-Enhancers and LLMs-as-Predictors. The former leverages LLMs to enhance nodes' text attributes with their massive knowledge and then generates predictions through GNNs. The latter attempts to directly employ LLMs as standalone predictors. We conduct comprehensive and systematic studies on these two pipelines under various settings. From these empirical results, we make original observations and find new insights that open new possibilities and suggest promising directions for leveraging LLMs for learning on graphs.
16 | 
17 | We provide the implementation of the following pipelines.
18 | ### LLMs-as-Predictors
19 | Check `ego_graph.py`, which directly uses `ChatGPT` for zero-shot/few-shot predictions.
20 | ![LLMs-as-Predictors](https://github.com/CurryTang/Graph-LLM/blob/master/imgs/llm_as_predictor.png)
21 | 
22 | ### LLMs-as-Enhancers
23 | Check `baseline.py`; various kinds of embedding-visible LLMs (such as LLaMA, Sentence-BERT, or text-embedding-ada-002) can be used to generate embeddings as node features.
24 | ![LLMs-as-Enhancers](https://github.com/CurryTang/Graph-LLM/blob/master/imgs/llm_as_enhancer.png)
25 | 
26 | ### (New project) LLMs-as-Annotators
27 | Check out our new project here: [Label-free Node Classification on Graphs with Large Language Models (LLMs)](https://github.com/CurryTang/LLMGNN)
28 | 
29 | 
30 | 
31 | ## Citation
32 | ```
33 | @article{Chen2023ExploringTP,
34 |   title={Exploring the Potential of Large Language Models (LLMs) in Learning on Graphs},
35 |   author={Zhikai Chen and Haitao Mao and Hang Li and Wei Jin and Haifang Wen and Xiaochi Wei and Shuaiqiang Wang and Dawei Yin and Wenqi Fan and Hui Liu and Jiliang Tang},
36 |   journal={ArXiv},
37 |   year={2023},
38 |   volume={abs/2307.03393}
39 | }
40 | ```
41 | 
42 | ## 0. Environment Setup
43 | 
44 | ### Package Installation
45 | Assuming your CUDA version is 11.8:
46 | ```
47 | conda create --name LLMGNN python=3.10
48 | conda activate LLMGNN
49 | 
50 | conda install pytorch==2.0.0 cudatoolkit=11.8 -c pytorch
51 | conda install -c pyg pytorch-sparse
52 | conda install -c pyg pytorch-scatter
53 | conda install -c pyg pytorch-cluster
54 | conda install -c pyg pyg
55 | pip install ogb
56 | conda install -c dglteam/label/cu118 dgl
57 | pip install transformers
58 | pip install --upgrade accelerate
59 | pip install openai
60 | pip install langchain
61 | pip install gensim
62 | pip install google-generativeai
63 | pip install -U sentence-transformers
64 | pip install editdistance
65 | pip install InstructorEmbedding
66 | pip install optuna
67 | pip install tiktoken
68 | pip install pytorch_warmup
69 | ```
70 | 
71 | ### Dataset
72 | We have provided the processed datasets via the following [Google Drive link](https://drive.google.com/drive/folders/1_laNA6eSQ6M5td2LvsEp3IL9qF6BC1KV?usp=sharing)
73 | 
74 | To prepare the files, you need to:
75 | 1. Unzip `small_data.zip` into `preprocessed_data/new`
76 | 2. If you want to use ogbn-products, unzip `big_data.zip` into `preprocessed_data/new`
77 | 3. Download and move `*_explanation.pt` and `*_pl.pt` into `preprocessed_data/new`. These files are related to TAPE.
78 | 4. Unzip `ada.zip` into `./`
79 | 5. Move `*_entity.pt` into `./`
80 | 6. Put `ogb_arxiv.csv` into `./preprocessed_data`
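
As a quick sanity check, each processed file should load as a PyG-style `Data` object, which is the same object `baseline.py` consumes. A minimal loading sketch (the path below follows the `{dataset}_{split}_{format}.pt` naming used by the RevGAT command later in this README):
``` python
import torch

data = torch.load("preprocessed_data/new/arxiv_fixed_sbert.pt")
print(data)          # PyG Data object with x, y, edge_index, and split masks
print(data.x.shape)  # node features produced by the chosen embedding model
```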
Put `ogb_arxiv.csv` into `./preprocessed_data` 81 | 82 | ### Get ft and no-ft LM embeddings 83 | 84 | Refer to the following scripts 85 | ``` bash 86 | for setting in "random" 87 | do 88 | for data in "cora" "pubmed" 89 | do 90 | WANDB_DISABLED=True CUDA_VISIBLE_DEVICES=3 python3 lmfinetune.py --dataset $data --split $setting --batch_size=9 --label_smoothing 0.3 --seed_num 5 91 | WANDB_DISABLED=True CUDA_VISIBLE_DEVICES=3 python3 lmfinetune.py --dataset $data --split $setting --batch_size=9 --label_smoothing 0.3 --seed_num 5 --use_explanation 1 92 | done 93 | done 94 | ``` 95 | 96 | ### Generate pt files for all data formats 97 | Run 98 | ``` python 99 | python3 generate_pyg_data.py 100 | ``` 101 | 102 | 103 | 104 | ## 1. Experiments for **LLM-as-Enhancers** 105 | 106 | For feature-level, **LLM-as-Enhancers**, you may replicate the experiments using files **baseline.py** and **lmfinetune.py** 107 | 108 | For example, you may run param sweep with the following script 109 | ``` bash 110 | for model in "GCN" "GAT" "MLP" 111 | do 112 | for data in "cora" "pubmed" 113 | do 114 | for setting in "random" 115 | do 116 | # Add more formats here 117 | for format in "ft" 118 | do 119 | CUDA_VISIBLE_DEVICES=1 python3 baseline.py --model_name $model --seed_num 5 --sweep_round 40 --mode sweep --dataset $data --split $setting --data_format $format 120 | echo "$model $data $setting $format done" 121 | done 122 | done 123 | done 124 | done 125 | ``` 126 | 127 | Run with a specific group of hyperparameters 128 | ``` bash 129 | python3 baseline.py --data_format sbert --split random --dataset pubmed --lr 0.01 --seed_num 5 130 | ``` 131 | 132 | Feature ensemble, separate each ensemble format with "\;" 133 | ``` bash 134 | CUDA_VISIBLE_DEVICES=1 python3 baseline.py --model_name GCN --num_split 1 --seed_num 5 --sweep_split 1 --sweep_round 5 --mode sweep --dataset pubmed --split random --ensemble_string sbert\;know_sep_sb\;ft\;pl\;know_exp_ft 135 | ``` 136 | 137 | Batch version for ogbn-products 138 | ``` bash 139 | CUDA_VISIBLE_DEVICES=7 python3 baseline.py --model_name SAGE --epochs 10 --num_split 1 --batchify 1 --dataset products --split fixed --data_format ft --normalize 1 --norm BatchNorm --mode main --lr 0.003 --dropout 0.5 --weight_decay 0 --hidden_dimension 256 --num_layers 3 140 | ``` 141 | 142 | To replicate the results for RevGAT (You need to first run once with the default features to generate the dgl data) 143 | ``` bash 144 | python dgl_main.py --data_root_dir ./dgldata \ 145 | --pretrain_path ./preprocessed_data/new/arxiv_fixed_sbert.pt \ 146 | --use-norm --use-labels --n-label-iters=1 --no-attn-dst --edge-drop=0.3 --input-drop=0.25 --n-layers 2 --dropout 0.75 --n-hidden 256 --save kd --backbone rev --group 2 --mode teacher 147 | 148 | 149 | python dgl_main.py --data_root_dir ./dgldata \ 150 | --pretrain_path ./preprocessed_data/new/arxiv_fixed_sbert.pt \ 151 | --use-norm --use-labels --n-label-iters=1 --no-attn-dst --edge-drop=0.3 --input-drop=0.25 --n-layers 2 --dropout 0.75 --n-hidden 256 --save kd --backbone rev --group 2 --mode student --alpha 0.95 --temp 0.7 152 | ``` 153 | 154 | To replicate the results for [SAGN](https://github.com/THUDM/SCR/tree/main/ogbn-products) and [GLEM](https://github.com/AndyJZhao/GLEM), you may check their repositories and put the processed pt file into their pipelines. 155 | 156 | 157 | 158 | 159 | 160 | 161 | ## 2. Experiments for **LLM-as-Predictors** 162 | 163 | Just run 164 | ``` bash 165 | python3 ego_graph.py 166 | ``` 167 | 168 | 169 | 170 | ## 3. 
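For reference, here is a minimal sketch of what a single zero-shot query looks like; the actual prompt templates live in `ego_graph.py`, and both the prompt wording and the key below are illustrative placeholders:
``` python
from api import openai_text_api

title, abstract = "...", "..."  # a node's raw text attributes
prompt = (
    f"Title: {title}\nAbstract: {abstract}\n"
    "Question: Which arXiv CS sub-category does this paper belong to? "
    "Answer with the most likely category."
)
# openai_text_api (defined in api.py) wraps openai.ChatCompletion.create
response = openai_text_api(prompt, api_key="sk-...")  # hypothetical key
print(response['choices'][0]['message']['content'])
```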
170 | ## 3. (UPDATE) Further Experiments on OOD & Prompts
171 | 
172 | In two recent studies, [CAN LLMS EFFECTIVELY LEVERAGE GRAPH STRUCTURAL INFORMATION: WHEN AND WHY](https://arxiv.org/pdf/2309.16595.pdf) and [Explanations as Features: LLM-Based Features for Text-Attributed Graphs](https://arxiv.org/pdf/2305.19523), researchers probed a prompt tailored for the Arxiv dataset on papers from after 2023, data which ChatGPT's pre-training corpus does not cover. Notably, the results showed no decline in performance compared to the original dataset. This intriguing outcome prompts us to delve deeper into crafting effective prompts across varied domains.
173 | 
174 | Out-of-distribution (OOD) generalization on graphs, commonly known as Graph OOD, is an active area of discussion. Recent benchmarks, such as [GOOD](https://github.com/divelab/GOOD/tree/GOODv1/GOOD), indicate that GNNs do not fare well under structural and feature shifts. We ran an experiment on the Arxiv dataset to assess the potential of LLMs-as-Predictors, leveraging a prompt from [Explanations as Features: LLM-Based Features for Text-Attributed Graphs](https://arxiv.org/pdf/2305.19523), which exhibited superior performance.
175 | 
176 | |                   | All avg       | Val    | Test      | Best baseline (test)  |
177 | |------------------ |-------------- |------- |---------- |---------------------- |
178 | | concept degree    | 73.91 ± 0.63  | 73.01  | 72.79     | 63.00                 |
179 | | covariate degree  | 75.75 ± 3.6   | 70.23  | 68.21     | 59.08                 |
180 | | concept time      | 74.29 ± 0.96  | 72.66  | 71.98     | 67.45                 |
181 | | covariate time    | 72.69 ± 1.53  | 74.28  | 74.37     | 71.34                 |
182 | 
183 | * **Concept shift**: P(Y|X) varies across environments; the splits are constructed from the same domains as the covariate shift by adjusting the label ratios in each domain.
184 | * **Covariate shift**: P(X) shifts while P(Y|X) remains consistent.
185 | 
186 | For the covariate shift, the configuration is 10/1/1 environments (train/val/test); for the concept shift, it is 3/1/1 (train/val/test). `All avg` denotes the mean performance across all environments.
187 | 
188 | One discernible merit of using LLMs-as-Predictors is their heightened resilience to OOD shifts.
189 | 
190 | 
--------------------------------------------------------------------------------
/api.py:
--------------------------------------------------------------------------------
1 | from yaml import load, dump
2 | try:
3 |     from yaml import CLoader as Loader, CDumper as Dumper
4 | except ImportError:
5 |     from yaml import Loader, Dumper
6 | import openai
7 | import math
8 | from tqdm import tqdm
9 | from langchain.chat_models import ChatOpenAI
10 | from langchain.callbacks import get_openai_callback
11 | from langchain.schema import SystemMessage, HumanMessage
12 | import asyncio
13 | import langchain
14 | from langchain.cache import SQLiteCache
15 | import os.path as osp
16 | import json
17 | langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
18 | import os
19 | import logging
20 | from async_api import process_api_requests_from_file
21 | import tiktoken
22 | import pickle
23 | 
24 | import google.generativeai as palm
25 | from utils import print_str_to_file
26 | import torch
27 | import time
28 | 
29 | 
30 | 
31 | 
32 | def persist_cache_to_disk(filename):
33 |     def decorator(original_func):
34 |         try:
35 |             cache = pickle.load(open(filename, 'rb'))
36 |         except (IOError, ValueError):
37 |             cache = {}
38 | 
39 | 
40 |         def new_func(*args, **kwargs):
41 |             str_repr = json.dumps([args, kwargs], sort_keys=True)
42 |             if str_repr not in cache:
43 |                 cache[str_repr] = original_func(*args, **kwargs)
44 |                 pickle.dump(cache, open(filename, "wb"))
45 |             return cache[str_repr]
46 | 
47 |         return new_func
48 | 
49 |     return decorator
50 | 
51 | def load_yaml_file(filename = 'config.yaml'):
52 |     with open(filename, 'r') as stream:
53 |         data = load(stream=stream, Loader=Loader)
54 |     return data
55 | 
56 | 
57 | def get_embedding(text, model="text-embedding-ada-002"):
58 |     text = text.replace("\n", " ")
59 |     res = openai.Embedding.create(input = [text], model=model)['data'][0]['embedding']
60 |     return res
61 | 
62 | def openai_ada_api(input_list, model_name = 'text-embedding-ada-002', max_len = 8190, max_batch = 1024):
63 |     if len(input_list) < max_batch:
64 |         input_list = [x[:max_len] for x in input_list]
65 |         res = openai.Embedding.create(input = input_list, model=model_name)['data']
66 |         res = [x['embedding'] for x in res]
67 |         return res
68 |     else:
69 |         input_list = [x[:max_len] for x in input_list]
70 |         total_res = []
71 |         total_batch_num = math.ceil(len(input_list) / max_batch)
72 |         for i in tqdm(range(total_batch_num)):
73 |             sub_input_list = input_list[i * max_batch: (i + 1) * max_batch]
74 |             res = openai.Embedding.create(input = sub_input_list, model=model_name)['data']
75 |             res = [x['embedding'] for x in res]
76 |             total_res.extend(res)
77 |         return total_res
78 | 
79 | def openai_text_davinci_003(prompt, api_key):
80 |     response = openai.Completion.create(
81 |         model='text-davinci-003',
82 |         prompt=prompt,
83 |         temperature=0,
84 |         max_tokens=1500,
85 |         top_p=1,
86 |         frequency_penalty=0,
87 |         presence_penalty=0,
88 |         api_key=api_key
89 |     )
90 |     return response['choices'][0]['text']
91 | 
92 | 
93 | def openai_text_api(input_text, api_key, model_name = "gpt-3.5-turbo", temperature = 0):
94 |     response = openai.ChatCompletion.create(
95 |         model=model_name,
96 |         messages=[{"role": "user", "content": input_text}],
97 |         temperature=temperature,
98 |         api_key=api_key)
99 |     return response
100 | 
101 | @persist_cache_to_disk("./ogb/res_chat.pkl")
102 | def openai_text_api_list(input_texts, api_key):
103 |     out = []
104 |     for x in tqdm(input_texts):
105 |         resp = openai_text_api(x, api_key)  # openai_text_api requires the api key
106 |         out.append(resp)
107 |     return out
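# Illustrative usage of the cached helper above (the key below is a placeholder):
#   responses = openai_text_api_list(["Summarize this abstract ..."], api_key="sk-...")
# Note: persist_cache_to_disk serializes (args, kwargs) via json.dumps as the cache key,
# so identical calls are answered from ./ogb/res_chat.pkl instead of hitting the API.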
108 | 
109 | #df['ada_embedding'] = df.combined.apply(lambda x: get_embedding(x, model='text-embedding-ada-002'))
110 | #df.to_csv('output/embedded_1k_reviews.csv', index=False)
111 | 
112 | async def openai_query_with_cost(instructions, generate_prompt):
113 |     """
114 |     Given a list of instructions, query the corresponding openai api, and estimate the token
115 |     usage and cost
116 |     """
117 |     results = []
118 |     total_price = 0
119 |     with get_openai_callback() as cb:
120 |         llm = ChatOpenAI(model_name="gpt-3.5-turbo")
121 |         prompts = generate_prompt(instructions)
122 |         tasks = [chat_generate(llm, instruction, message) for instruction, message in tqdm(prompts)]
123 |         results = await asyncio.gather(*tasks)
124 |         total_price = cb.total_cost
125 |     return results, total_price
126 | 
127 | 
128 | async def chat_generate(agent, instruction, message):
129 |     if instruction:
130 |         res = await agent([instruction, message])
131 |     else:
132 |         res = await agent([message])
133 |     print("Generate")
134 |     return res
135 | 
136 | def generate_prompt_for_ogb_arxiv(instructions):
137 |     generate_prompts = []
138 |     for line in instructions:
139 |         title, abstract, category = line['title'], line['abstract'], line['category_name']
140 |         # instruction = SystemMessage(content="")
141 |         message = HumanMessage(content=f"Given the title and abstract of a paper from arxiv.\n Title: {title}\nAbstract: {abstract}\n Summarize the key points of this paper which can best represent its category.")
142 |         generate_prompts.append((None, message))
143 |     return generate_prompts
144 | 
145 | def generate_prompt_for_correct(texts, max_tokens = 768):
146 |     generate_prompts = []
147 |     for line in texts:
148 |         line = line[:max_tokens]
149 |         message = HumanMessage(content=f"Most of the words in the following text are misspelled; correct them.\n{line}")
150 |         generate_prompts.append((None, message))
151 |     return generate_prompts
152 | 
153 | def generate_request_json_file_correct(texts, max_tokens = 300, filename = 'correct.jsonl'):
154 |     filename = osp.join("./ogb/data", filename)
155 |     jobs = [{"model": "gpt-3.5-turbo", "messages": [{'role': 'user', 'content': f"It seems that many words in the following paragraph are missing letters at the end; can you help me correct them?\n{line[:max_tokens]}"}]} for line in texts]
156 |     with open(filename, "w+") as f:
157 |         for job in jobs:
158 |             json_string = json.dumps(job)
159 |             f.write(json_string + "\n")
160 | 
161 | 
162 | async def call_async_api(request_filepath, save_filepath, request_url, api_key, max_request_per_minute, max_tokens_per_minute, sp, ss):
163 |     await process_api_requests_from_file(
164 |         requests_filepath=request_filepath,
165 |         save_filepath=save_filepath,
166 |         request_url=request_url,
167 |         api_key=api_key,
168 |         max_requests_per_minute=float(max_request_per_minute),
169 |         max_tokens_per_minute=float(max_tokens_per_minute),
170 |         token_encoding_name='cl100k_base',
171 |         max_attempts=int(2),
172 |         logging_level=int(logging.INFO),
173 |         seconds_to_pause=sp,
174 |         seconds_to_sleep=ss
175 |     )
176 | 
177 | 
178 | def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
179 |     """Returns the number of tokens used by a list of messages."""
180 |     try:
181 |         encoding = tiktoken.encoding_for_model(model)
182 |     except KeyError:
183 |         print("Warning: model not found. Using cl100k_base encoding.")
184 |         encoding = tiktoken.get_encoding("cl100k_base")
185 |     if model == "gpt-3.5-turbo":
186 |         print("Warning: gpt-3.5-turbo may change over time. 
Returning num tokens assuming gpt-3.5-turbo-0301.") 187 | return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301") 188 | elif model == "gpt-4": 189 | print("Warning: gpt-4 may change over time. Returning num tokens assuming gpt-4-0314.") 190 | return num_tokens_from_messages(messages, model="gpt-4-0314") 191 | elif model == "gpt-3.5-turbo-0301": 192 | tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n 193 | tokens_per_name = -1 # if there's a name, the role is omitted 194 | elif model == "gpt-4-0314": 195 | tokens_per_message = 3 196 | tokens_per_name = 1 197 | else: 198 | raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""") 199 | num_tokens = 0 200 | for message in messages: 201 | num_tokens += tokens_per_message 202 | for key, value in message.items(): 203 | num_tokens += len(encoding.encode(value)) 204 | if key == "name": 205 | num_tokens += tokens_per_name 206 | num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> 207 | return num_tokens 208 | 209 | 210 | def num_tokens_from_string(string: str, model = "text-davinci-003") -> int: 211 | """Returns the number of tokens in a text string.""" 212 | encoding = tiktoken.encoding_for_model(model) 213 | num_tokens = len(encoding.encode(string)) 214 | return num_tokens 215 | 216 | def generate_chat_input_file(input_text, model_name = 'gpt-3.5-turbo'): 217 | jobs = [] 218 | for i, text in enumerate(input_text): 219 | obj = {} 220 | obj['model'] = model_name 221 | obj['messages'] = [ 222 | { 223 | 'role': 'user', 224 | 'content': text 225 | } 226 | ] 227 | jobs.append(obj) 228 | return jobs 229 | 230 | 231 | @persist_cache_to_disk("./async_req_davinci.pkl") 232 | def generate_davinci_003_input_file(input_text, model_name = 'text-davinci-003', max_token = 4096, temperature = 0.7, log_probs = None): 233 | jobs = [] 234 | for text in input_text: 235 | obj = {} 236 | obj['model'] = model_name 237 | obj['messages'] = [ 238 | { 239 | "model": "text-davinci-003", 240 | "prompt": text, 241 | "max_tokens": max_token, 242 | "temperature": temperature, 243 | "stream": False, 244 | "logprobs": log_probs 245 | } 246 | ] 247 | jobs.append(obj) 248 | return jobs 249 | 250 | 251 | def load_result_from_jsonline(json_file_name): 252 | openai_result = [] 253 | with open(json_file_name, 'r') as f: 254 | for line in f: 255 | json_obj = json.loads(line.strip()) 256 | openai_result.append(json_obj[1]['choices'][0]['message']['content']) 257 | return openai_result 258 | 259 | 260 | 261 | async def async_openai_text_api(input_text, api_key, model_name = "gpt-3.5-turbo"): 262 | response = await openai.ChatCompletion.acreate( 263 | model=model_name, 264 | messages=[{"role": "user", "content": input_text}], 265 | temperature=0.7, 266 | api_key=api_key) 267 | return response['choices'][0]['message']['content'] 268 | 269 | 270 | 271 | 272 | def efficient_openai_text_api(input_text, filename, savepath, sp, ss, api_key="change_this_to_your_key", rewrite = True): 273 | if not osp.exists(savepath) or rewrite: 274 | jobs = generate_chat_input_file(input_text) 275 | with open(filename, "w") as f: 276 | for job in jobs: 277 | json_string = json.dumps(job) 278 | f.write(json_string + "\n") 279 | asyncio.run( 280 | call_async_api( 281 | filename, save_filepath=savepath, 282 | request_url="https://api.openai.com/v1/chat/completions", 283 | 
api_key=api_key,
284 |                 max_request_per_minute=1000,
285 |                 max_tokens_per_minute=90000,
286 |                 sp=sp,
287 |                 ss=ss
288 |             )
289 |         )
290 |     openai_result = []
291 |     with open(savepath, 'r') as f:
292 |         for line in f:
293 |             json_obj = json.loads(line.strip())
294 |             idx = json_obj[-1]
295 |             if isinstance(idx, int):
296 |                 openai_result.append((json_obj[1]['choices'][0]['message']['content'], idx))
297 |             else:
298 |                 idx = json_obj[-2]
299 |                 new_result = openai_text_api(json_obj[0]['messages'][0]['content'], api_key)  # retry failed requests synchronously; openai_text_api needs the key
300 |                 openai_result.append((new_result['choices'][0]['message']['content'], idx))
301 |     openai_result = sorted(openai_result, key=lambda x:x[-1])
302 |     return openai_result
303 | 
304 | 
305 | 
306 | 
307 | def google_text_generate_api(output_path, prompts, api_key = "change_this_to_your_key", model = 'models/text-bison-001', max_out_tokens = 512):
308 |     if os.path.exists(osp.join(output_path, 'total.pt')):
309 |         return torch.load(osp.join(output_path, 'total.pt'))
310 |     palm.configure(api_key=api_key)
311 |     results = []
312 |     for i, prompt in enumerate(tqdm(prompts)):
313 |         completion = palm.generate_text(
314 |             model=model,
315 |             prompt=prompt,
316 |             temperature=0,
317 |             # The maximum length of the response
318 |             max_output_tokens=max_out_tokens
319 |         )
320 |         time.sleep(2)
321 |         # import ipdb; ipdb.set_trace()
322 |         results.append(completion.result)
323 |         output_file_path = osp.join(output_path, f"{i}.txt")
324 |         print_str_to_file(completion.result, output_file_path)
325 |     torch.save(results, osp.join(output_path, 'total.pt'))
326 |     return results
327 | 
328 | 
329 | 
330 | 
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | 
4 | def replace_args_with_dict_values(args, dictionary):
5 |     for key, value in dictionary.items():
6 |         if hasattr(args, key):
7 |             setattr(args, key, value)
8 |     return args
9 | 
10 | def get_command_line_args():
11 |     parser = argparse.ArgumentParser(description='LLM Graph')
12 |     parser.add_argument('--dataset', default='arxiv', type=str)
13 |     parser.add_argument('--normalize', default=0, type=int)
14 |     parser.add_argument('--epochs', type=int, default=300)
15 |     parser.add_argument('--early_stopping', type=int, default=10)
16 |     parser.add_argument('--model_name', type=str, default='MLP')
17 |     parser.add_argument('--norm', type=str, default=None)
18 |     parser.add_argument('--seed_num', type=int, default=5)
19 |     parser.add_argument('--return_embeds', type=int, default=1)
20 |     parser.add_argument('--lr', type=float, default=0.01)
21 |     parser.add_argument('--weight_decay', type=float, default=5e-4)
22 |     parser.add_argument('--num_split', type=int, default=1)
23 |     parser.add_argument('--sweep_split', type=int, default=1)
24 |     parser.add_argument('--output_intermediate', type=int, default=0)
25 |     parser.add_argument('--num_layers', type=int, default=2)
26 |     parser.add_argument('--hidden_dimension', type=int, default=256)
27 |     parser.add_argument('--dropout', type=float, default=0.5)
28 |     parser.add_argument('--optim', type=str, default='adam')
29 |     parser.add_argument('--warmup', default=10, type=int)
30 |     parser.add_argument('--lr_gamma', default=0.998, type=float)
31 |     # parser.add_argument('--subgraph_test', default=1, type=int)
32 |     # parser.add_argument('--subgraph_node_number', default=10000, type=int)
33 |     # parser.add_argument('--subgraph_train_number', default=3000, type=int)
34 |     # parser.add_argument('--subgraph_val_number', 
default=2000, type=int) 35 | parser.add_argument('--data_format', type=str, default='sbert') 36 | parser.add_argument('--early_stop_start', type=int, default=400) 37 | # parser.add_argument('--sample_type', type=str, default='mlp') 38 | # parser.add_argument('--instruction_aware', type=int, default=0) 39 | # parser.add_argument('--need_cat', type=int, default=0) 40 | # parser.add_argument('--style_num', type=int, default=1) 41 | # parser.add_argument('--need_related', type=int, default=0) 42 | parser.add_argument('--alpha', type=float, default=0.9) 43 | parser.add_argument('--low_label_test', type=int, default=0) 44 | parser.add_argument('--few_shot_test', type=int, default=0) 45 | parser.add_argument('--split', type=str, default='fixed') 46 | parser.add_argument("--sweep_round", type=int, default=50) 47 | parser.add_argument('--mode', type=str, default="main") 48 | parser.add_argument('--inductive', type=int, default = 0) 49 | parser.add_argument('--batchify', type=int, default = 0) 50 | parser.add_argument('--num_of_heads', type=int, default = 8) 51 | parser.add_argument('--num_of_out_heads', type=int, default = 1) 52 | parser.add_argument("--save_logits", type=int, default=0) 53 | parser.add_argument("--ensemble", nargs='+', type=str, default=[]) 54 | parser.add_argument("--formats", nargs='+', type=str, default=[]) 55 | parser.add_argument("--ensemble_string", type=str, default="") 56 | parser.add_argument("--llm_pl", type=int, default=0) 57 | # parser.add_argument("--llm_pl_num", type=int, default=1) 58 | args = parser.parse_args() 59 | return args -------------------------------------------------------------------------------- /async_api.py: -------------------------------------------------------------------------------- 1 | """ 2 | API REQUEST PARALLEL PROCESSOR 3 | 4 | Using the OpenAI API to process lots of text quickly takes some care. 5 | If you trickle in a million API requests one by one, they'll take days to complete. 6 | If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors. 7 | To maximize throughput, parallel requests need to be throttled to stay under rate limits. 8 | 9 | This script parallelizes requests to the OpenAI API while throttling to stay under rate limits. 
10 | 
11 | Features:
12 | - Streams requests from file, to avoid running out of memory for giant jobs
13 | - Makes requests concurrently, to maximize throughput
14 | - Throttles request and token usage, to stay under rate limits
15 | - Retries failed requests up to {max_attempts} times, to avoid missing data
16 | - Logs errors, to diagnose problems with requests
17 | 
18 | Example command to call script:
19 | ```
20 | python examples/api_request_parallel_processor.py \
21 |   --requests_filepath examples/data/example_requests_to_parallel_process.jsonl \
22 |   --save_filepath examples/data/example_requests_to_parallel_process_results.jsonl \
23 |   --request_url https://api.openai.com/v1/embeddings \
24 |   --max_requests_per_minute 1500 \
25 |   --max_tokens_per_minute 6250000 \
26 |   --token_encoding_name cl100k_base \
27 |   --max_attempts 5 \
28 |   --logging_level 20
29 | ```
30 | 
31 | Inputs:
32 | - requests_filepath : str
33 |     - path to the file containing the requests to be processed
34 |     - file should be a jsonl file, where each line is a json object with API parameters
35 |     - e.g., {"model": "text-embedding-ada-002", "input": "embed me"}
36 |     - as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically)
37 |     - an example file is provided at examples/data/example_requests_to_parallel_process.jsonl
38 |     - the code to generate the example file is appended to the bottom of this script
39 | - save_filepath : str, optional
40 |     - path to the file where the results will be saved
41 |     - file will be a jsonl file, where each line is an array with the original request plus the API response
42 |     - e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}]
43 |     - if omitted, results will be saved to {requests_filename}_results.jsonl
44 | - request_url : str, optional
45 |     - URL of the API endpoint to call
46 |     - if omitted, will default to "https://api.openai.com/v1/embeddings"
47 | - api_key : str, optional
48 |     - API key to use
49 |     - if omitted, the script will attempt to read it from an environment variable {os.getenv("OPENAI_API_KEY")}
50 | - max_requests_per_minute : float, optional
51 |     - target number of requests to make per minute (will make fewer if limited by tokens)
52 |     - leave headroom by setting this to 50% or 75% of your limit
53 |     - if requests are limiting you, try batching multiple embeddings or completions into one request
54 |     - if omitted, will default to 1,500
55 | - max_tokens_per_minute : float, optional
56 |     - target number of tokens to use per minute (will use fewer if limited by requests)
57 |     - leave headroom by setting this to 50% or 75% of your limit
58 |     - if omitted, will default to 125,000
59 | - token_encoding_name : str, optional
60 |     - name of the token encoding used, as defined in the `tiktoken` package
61 |     - if omitted, will default to "cl100k_base" (used by `text-embedding-ada-002`)
62 | - max_attempts : int, optional
63 |     - number of times to retry a failed request before giving up
64 |     - if omitted, will default to 5
65 | - logging_level : int, optional
66 |     - level of logging to use; higher numbers will log fewer messages
67 |     - 40 = ERROR; will log only when requests fail after all retries
68 |     - 30 = WARNING; will log when requests hit rate limits or other errors
69 |     - 20 = INFO; will log when requests start and the status at finish
70 |     - 10 = DEBUG; will log various things as the loop runs to see when they occur
71 |     - if omitted, will default to 20 (INFO). 
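As a rough worked example of the throttling math: with max_tokens_per_minute = 90,000 and requests averaging ~500 tokens each (prompt plus expected completion), the token budget alone sustains about 90,000 / 500 = 180 requests per minute, so a higher max_requests_per_minute adds no throughput in that regime.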
72 | 73 | The script is structured as follows: 74 | - Imports 75 | - Define main() 76 | - Initialize things 77 | - In main loop: 78 | - Get next request if one is not already waiting for capacity 79 | - Update available token & request capacity 80 | - If enough capacity available, call API 81 | - The loop pauses if a rate limit error is hit 82 | - The loop breaks when no tasks remain 83 | - Define dataclasses 84 | - StatusTracker (stores script metadata counters; only one instance is created) 85 | - APIRequest (stores API inputs, outputs, metadata; one method to call API) 86 | - Define functions 87 | - api_endpoint_from_url (extracts API endpoint from request URL) 88 | - append_to_jsonl (writes to results file) 89 | - num_tokens_consumed_from_request (bigger function to infer token usage from request) 90 | - task_id_generator_function (yields 1, 2, 3, ...) 91 | - Run main() 92 | """ 93 | 94 | # imports 95 | import aiohttp # for making API calls concurrently 96 | import argparse # for running script from command line 97 | import asyncio # for running API calls concurrently 98 | import json # for saving results to a jsonl file 99 | import logging # for logging rate limit warnings and other messages 100 | import os # for reading API key 101 | import re # for matching endpoint from request URL 102 | import tiktoken # for counting tokens 103 | import time # for sleeping after rate limit is hit 104 | from dataclasses import dataclass # for storing API inputs, outputs, and metadata 105 | 106 | 107 | async def process_api_requests_from_file( 108 | requests_filepath: str, 109 | save_filepath: str, 110 | request_url: str, 111 | api_key: str, 112 | max_requests_per_minute: float, 113 | max_tokens_per_minute: float, 114 | token_encoding_name: str, 115 | max_attempts: int, 116 | logging_level: int, 117 | seconds_to_pause: int = 60, 118 | seconds_to_sleep: int = 0.1 119 | ): 120 | """Processes API requests in parallel, throttling to stay under rate limits.""" 121 | # constants 122 | seconds_to_pause_after_rate_limit_error = seconds_to_pause 123 | seconds_to_sleep_each_loop = seconds_to_sleep # 1 ms limits max throughput to 1,000 requests per second 124 | 125 | # initialize logging 126 | logging.basicConfig(level=logging_level) 127 | logging.debug(f"Logging initialized at level {logging_level}") 128 | 129 | # infer API endpoint and construct request header 130 | api_endpoint = api_endpoint_from_url(request_url) 131 | request_header = {"Authorization": f"Bearer {api_key}"} 132 | 133 | # initialize trackers 134 | queue_of_requests_to_retry = asyncio.Queue() 135 | task_id_generator = task_id_generator_function() # generates integer IDs of 1, 2, 3, ... 136 | status_tracker = StatusTracker() # single instance to track a collection of variables 137 | next_request = None # variable to hold the next request to call 138 | 139 | # initialize available capacity counts 140 | available_request_capacity = max_requests_per_minute 141 | available_token_capacity = max_tokens_per_minute 142 | last_update_time = time.time() 143 | 144 | # initialize flags 145 | file_not_finished = True # after file is empty, we'll skip reading it 146 | logging.debug(f"Initialization complete.") 147 | 148 | with open(save_filepath, 'w') as f: 149 | pass 150 | 151 | # initialize file reading 152 | with open(requests_filepath) as file: 153 | # `requests` will provide requests one at a time 154 | requests = file.__iter__() 155 | logging.debug(f"File opened. 
Entering main loop") 156 | 157 | while True: 158 | # get next request (if one is not already waiting for capacity) 159 | if next_request is None: 160 | if not queue_of_requests_to_retry.empty(): 161 | next_request = queue_of_requests_to_retry.get_nowait() 162 | logging.debug(f"Retrying request {next_request.task_id}: {next_request}") 163 | elif file_not_finished: 164 | try: 165 | # get new request 166 | request_json = json.loads(next(requests)) 167 | next_request = APIRequest( 168 | task_id=next(task_id_generator), 169 | request_json=request_json, 170 | token_consumption=num_tokens_consumed_from_request(request_json, api_endpoint, token_encoding_name), 171 | attempts_left=max_attempts, 172 | ) 173 | status_tracker.num_tasks_started += 1 174 | status_tracker.num_tasks_in_progress += 1 175 | logging.debug(f"Reading request {next_request.task_id}: {next_request}") 176 | except StopIteration: 177 | # if file runs out, set flag to stop reading it 178 | logging.debug("Read file exhausted") 179 | file_not_finished = False 180 | 181 | # update available capacity 182 | current_time = time.time() 183 | seconds_since_update = current_time - last_update_time 184 | available_request_capacity = min( 185 | available_request_capacity + max_requests_per_minute * seconds_since_update / 60.0, 186 | max_requests_per_minute, 187 | ) 188 | available_token_capacity = min( 189 | available_token_capacity + max_tokens_per_minute * seconds_since_update / 60.0, 190 | max_tokens_per_minute, 191 | ) 192 | last_update_time = current_time 193 | 194 | # if enough capacity available, call API 195 | if next_request: 196 | next_request_tokens = next_request.token_consumption 197 | if ( 198 | available_request_capacity >= 1 199 | and available_token_capacity >= next_request_tokens 200 | ): 201 | # update counters 202 | available_request_capacity -= 1 203 | available_token_capacity -= next_request_tokens 204 | next_request.attempts_left -= 1 205 | 206 | # call API 207 | asyncio.create_task( 208 | next_request.call_api( 209 | request_url=request_url, 210 | request_header=request_header, 211 | retry_queue=queue_of_requests_to_retry, 212 | save_filepath=save_filepath, 213 | status_tracker=status_tracker, 214 | ) 215 | ) 216 | next_request = None # reset next_request to empty 217 | 218 | # if all tasks are finished, break 219 | if status_tracker.num_tasks_in_progress == 0: 220 | break 221 | 222 | # main loop sleeps briefly so concurrent tasks can run 223 | await asyncio.sleep(seconds_to_sleep_each_loop) 224 | 225 | # if a rate limit error was hit recently, pause to cool down 226 | seconds_since_rate_limit_error = (time.time() - status_tracker.time_of_last_rate_limit_error) 227 | if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error: 228 | remaining_seconds_to_pause = (seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error) 229 | await asyncio.sleep(remaining_seconds_to_pause) 230 | # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago 231 | logging.warn(f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}") 232 | 233 | # after finishing, log final status 234 | logging.info(f"""Parallel processing complete. Results saved to {save_filepath}""") 235 | if status_tracker.num_tasks_failed > 0: 236 | logging.warning(f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. 
Errors logged to {save_filepath}.") 237 | if status_tracker.num_rate_limit_errors > 0: 238 | logging.warning(f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate.") 239 | 240 | 241 | # dataclasses 242 | 243 | 244 | @dataclass 245 | class StatusTracker: 246 | """Stores metadata about the script's progress. Only one instance is created.""" 247 | 248 | num_tasks_started: int = 0 249 | num_tasks_in_progress: int = 0 # script ends when this reaches 0 250 | num_tasks_succeeded: int = 0 251 | num_tasks_failed: int = 0 252 | num_rate_limit_errors: int = 0 253 | num_api_errors: int = 0 # excluding rate limit errors, counted above 254 | num_other_errors: int = 0 255 | time_of_last_rate_limit_error: int = 0 # used to cool off after hitting rate limits 256 | 257 | 258 | @dataclass 259 | class APIRequest: 260 | """Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API call.""" 261 | 262 | task_id: int 263 | request_json: dict 264 | token_consumption: int 265 | attempts_left: int 266 | result = [] 267 | 268 | async def call_api( 269 | self, 270 | request_url: str, 271 | request_header: dict, 272 | retry_queue: asyncio.Queue, 273 | save_filepath: str, 274 | status_tracker: StatusTracker, 275 | ): 276 | """Calls the OpenAI API and saves results.""" 277 | logging.info(f"Starting request #{self.task_id}") 278 | error = None 279 | try: 280 | async with aiohttp.ClientSession() as session: 281 | async with session.post( 282 | url=request_url, headers=request_header, json=self.request_json 283 | ) as response: 284 | response = await response.json() 285 | if "error" in response: 286 | logging.warning( 287 | f"Request {self.task_id} failed with error {response['error']}" 288 | ) 289 | status_tracker.num_api_errors += 1 290 | error = response 291 | if "Rate limit" in response["error"].get("message", ""): 292 | status_tracker.time_of_last_rate_limit_error = time.time() 293 | status_tracker.num_rate_limit_errors += 1 294 | status_tracker.num_api_errors -= 1 # rate limit errors are counted separately 295 | 296 | except Exception as e: # catching naked exceptions is bad practice, but in this case we'll log & save them 297 | logging.warning(f"Request {self.task_id} failed with Exception {e}") 298 | status_tracker.num_other_errors += 1 299 | error = e 300 | if error: 301 | self.result.append(error) 302 | if self.attempts_left: 303 | retry_queue.put_nowait(self) 304 | else: 305 | logging.error(f"Request {self.request_json} failed after all attempts. 
Saving errors: {self.result}") 306 | append_to_jsonl([self.request_json, [str(e) for e in self.result], self.task_id, 'error'], save_filepath) 307 | status_tracker.num_tasks_in_progress -= 1 308 | status_tracker.num_tasks_failed += 1 309 | else: 310 | append_to_jsonl([self.request_json, response, self.task_id], save_filepath) 311 | status_tracker.num_tasks_in_progress -= 1 312 | status_tracker.num_tasks_succeeded += 1 313 | logging.debug(f"Request {self.task_id} saved to {save_filepath}") 314 | 315 | 316 | # functions 317 | 318 | 319 | def api_endpoint_from_url(request_url): 320 | """Extract the API endpoint from the request URL.""" 321 | match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url) 322 | return match[1] 323 | 324 | 325 | def append_to_jsonl(data, filename: str) -> None: 326 | """Append a json payload to the end of a jsonl file.""" 327 | json_string = json.dumps(data) 328 | with open(filename, "a") as f: 329 | f.write(json_string + "\n") 330 | 331 | 332 | def num_tokens_consumed_from_request( 333 | request_json: dict, 334 | api_endpoint: str, 335 | token_encoding_name: str, 336 | ): 337 | """Count the number of tokens in the request. Only supports completion and embedding requests.""" 338 | encoding = tiktoken.get_encoding(token_encoding_name) 339 | # if completions request, tokens = prompt + n * max_tokens 340 | if api_endpoint.endswith("completions"): 341 | max_tokens = request_json.get("max_tokens", 15) 342 | n = request_json.get("n", 1) 343 | completion_tokens = n * max_tokens 344 | 345 | # chat completions 346 | if api_endpoint.startswith("chat/"): 347 | num_tokens = 0 348 | for message in request_json["messages"]: 349 | num_tokens += 4 # every message follows {role/name}\n{content}\n 350 | for key, value in message.items(): 351 | num_tokens += len(encoding.encode(value)) 352 | if key == "name": # if there's a name, the role is omitted 353 | num_tokens -= 1 # role is always required and always 1 token 354 | num_tokens += 2 # every reply is primed with assistant 355 | return num_tokens + completion_tokens 356 | # normal completions 357 | else: 358 | prompt = request_json["prompt"] 359 | if isinstance(prompt, str): # single prompt 360 | prompt_tokens = len(encoding.encode(prompt)) 361 | num_tokens = prompt_tokens + completion_tokens 362 | return num_tokens 363 | elif isinstance(prompt, list): # multiple prompts 364 | prompt_tokens = sum([len(encoding.encode(p)) for p in prompt]) 365 | num_tokens = prompt_tokens + completion_tokens * len(prompt) 366 | return num_tokens 367 | else: 368 | raise TypeError('Expecting either string or list of strings for "prompt" field in completion request') 369 | # if embeddings request, tokens = input tokens 370 | elif api_endpoint == "embeddings": 371 | input = request_json["input"] 372 | if isinstance(input, str): # single input 373 | num_tokens = len(encoding.encode(input)) 374 | return num_tokens 375 | elif isinstance(input, list): # multiple inputs 376 | num_tokens = sum([len(encoding.encode(i)) for i in input]) 377 | return num_tokens 378 | else: 379 | raise TypeError('Expecting either string or list of strings for "inputs" field in embedding request') 380 | # more logic needed to support other API calls (e.g., edits, inserts, DALL-E) 381 | else: 382 | raise NotImplementedError(f'API endpoint "{api_endpoint}" not implemented in this script') 383 | 384 | 385 | def task_id_generator_function(): 386 | """Generate integers 0, 1, 2, and so on.""" 387 | task_id = 0 388 | while True: 389 | yield task_id 390 | task_id += 1 391 | 392 | 393 | # 
run script 394 | 395 | 396 | if __name__ == "__main__": 397 | # parse command line arguments 398 | parser = argparse.ArgumentParser() 399 | parser.add_argument("--requests_filepath") 400 | parser.add_argument("--save_filepath", default=None) 401 | parser.add_argument("--request_url", default="https://api.openai.com/v1/embeddings") 402 | parser.add_argument("--api_key", default=os.getenv("OPENAI_API_KEY")) 403 | parser.add_argument("--max_requests_per_minute", type=int, default=3_000 * 0.5) 404 | parser.add_argument("--max_tokens_per_minute", type=int, default=90000) 405 | parser.add_argument("--token_encoding_name", default="cl100k_base") 406 | parser.add_argument("--max_attempts", type=int, default=2) 407 | parser.add_argument("--logging_level", default=logging.INFO) 408 | args = parser.parse_args() 409 | 410 | if args.save_filepath is None: 411 | args.save_filepath = args.requests_filepath.replace(".jsonl", "_results.jsonl") 412 | 413 | # run script 414 | asyncio.run( 415 | process_api_requests_from_file( 416 | requests_filepath=args.requests_filepath, 417 | save_filepath=args.save_filepath, 418 | request_url=args.request_url, 419 | api_key=args.api_key, 420 | max_requests_per_minute=float(args.max_requests_per_minute), 421 | max_tokens_per_minute=float(args.max_tokens_per_minute), 422 | token_encoding_name=args.token_encoding_name, 423 | max_attempts=int(args.max_attempts), 424 | logging_level=int(args.logging_level), 425 | ) 426 | ) 427 | 428 | 429 | """ 430 | APPENDIX 431 | 432 | The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002. 433 | 434 | It was generated with the following code: 435 | 436 | ```python 437 | import json 438 | 439 | filename = "data/example_requests_to_parallel_process.jsonl" 440 | n_requests = 10_000 441 | jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)] 442 | with open(filename, "w") as f: 443 | for job in jobs: 444 | json_string = json.dumps(job) 445 | f.write(json_string + "\n") 446 | ``` 447 | 448 | As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically). 
449 | """
--------------------------------------------------------------------------------
/baseline.py:
--------------------------------------------------------------------------------
1 | from args import *
2 | from data import get_dataset, set_seed_config, set_api_key, pkl_and_write, get_tf_idf_by_texts
3 | import torch
4 | from train_utils import train, test, get_optimizer, confidence_test, topk_test, to_inductive, batch_train, batch_test
5 | from models import get_model
6 | import numpy as np
7 | import ipdb
8 | import optuna
9 | from torch.utils.tensorboard import SummaryWriter
10 | import openai
11 | from copy import deepcopy
12 | import logging
13 | import time
14 | from torch_geometric.utils import index_to_mask
15 | 
16 | import sys
17 | from hyper import hyper_search
18 | import os.path as osp
19 | import torch.nn.functional as F
20 | from torch_geometric.loader import NeighborLoader
21 | from utils import delete_non_tensor_attributes
22 | from ogb.nodeproppred import Evaluator
23 | from collections import defaultdict
24 | 
25 | 
26 | 
27 | def train_pipeline_batch(seeds, args, epoch, data, writer, need_train, mode="main"):
28 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29 |     test_result_acc = []
30 |     early_stop_accum = 0
31 |     val_result_acc = []
32 |     out_res = []
33 |     best_val = 0
34 |     evaluator = Evaluator(name='ogbn-products')
35 |     if args.inductive:
36 |         data = to_inductive(data)
37 |     if mode == "main":
38 |         split_num = args.num_split
39 |     else:
40 |         split_num = args.sweep_split
41 |     split = 0
42 |     data.train_mask = data.train_masks[split]
43 |     data.val_mask = data.val_masks[split]
44 |     data.test_mask = data.test_masks[split]
45 |     data = delete_non_tensor_attributes(data)
46 |     assert split_num == 1
47 |     for seed in seeds:
48 |         set_seed_config(seed)
49 |         model = get_model(args).to(device)
50 |         optimizer, scheduler = get_optimizer(args, model)
51 |         loss_fn = torch.nn.CrossEntropyLoss()
52 |         best_val, best_model = 0, model  # track the best-validation checkpoint explicitly
53 |         for split in range(split_num):
54 |             if args.normalize:
55 |                 data.x = F.normalize(data.x, dim = -1)
56 |             input_nodes = torch.arange(data.x.shape[0])[data.train_mask]
57 |             # import ipdb; ipdb.set_trace()
58 |             data = data.to(device, 'x', 'y')
59 |             subgraph_loader = NeighborLoader(data, input_nodes=input_nodes,
60 |                                              num_neighbors=[15, 10, 5],
61 |                                              batch_size=1024, shuffle=True,
62 |                                              num_workers=4)
63 |             val_loader = NeighborLoader(data, input_nodes=None, batch_size=4096, shuffle=False,
64 |                                         num_neighbors=[-1], num_workers=1, persistent_workers=True)
65 |             # import ipdb; ipdb.set_trace()
66 |             for epoch in range(1, args.epochs + 1):
67 |                 train_loss = batch_train(model, subgraph_loader, optimizer, device)
68 |                 if scheduler:
69 |                     scheduler.step()
70 |                 val_acc = batch_test(model, data, evaluator, val_loader, device, data.val_mask)
71 |                 print(f"Epoch {epoch}: Train loss: {train_loss}, Val acc: {val_acc}")
72 |                 if val_acc > best_val:
73 |                     best_val = val_acc
74 |                     best_model = deepcopy(model)
75 |                     early_stop_accum = 0
76 |                 else:
77 |                     if epoch >= args.early_stop_start:
78 |                         early_stop_accum += 1
79 |                     if early_stop_accum > args.early_stopping and epoch >= args.early_stop_start:
80 |                         break
81 |             test_acc = batch_test(best_model, data, evaluator, val_loader, device, data.test_mask)  # evaluate the best-val checkpoint, not the last epoch
82 |             val_result_acc.append(val_acc)
83 |             test_result_acc.append(test_acc)
84 |     return test_result_acc, val_result_acc
85 | 
86 | 
87 | 
88 | 
89 | 
90 | 
91 | 
92 | 
93 | def train_pipeline(seeds, args, epoch, data, writer, need_train, mode="main"):
94 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
95 |     test_result_acc = []
96 |     early_stop_accum = 0
97 |     val_result_acc = []
98 |     out_res = []
99 | 
100 |     if args.inductive:
101 |         data = to_inductive(data)
102 |     if mode == "main":
103 |         split_num = args.num_split
104 |     else:
105 |         split_num = args.sweep_split
106 |     for i, seed in enumerate(seeds):
107 |         best_val, early_stop_accum = 0, 0  # reset per seed
108 |         set_seed_config(seed)
109 |         model = get_model(args).to(device)
110 |         optimizer, scheduler = get_optimizer(args, model)
111 |         loss_fn = torch.nn.CrossEntropyLoss()  # if hasattr(data, "xs"):
112 |         # data.x = data.xs[0]
113 |         if args.normalize:
114 |             data.x = F.normalize(data.x, dim = -1)
115 |         data = data.to(device)
116 |         data.train_mask = data.train_masks[i]
117 |         data.val_mask = data.val_masks[i]
118 |         data.test_mask = data.test_masks[i]
119 |         if 'ft' in args.data_format:
120 |             data.x = data.xs[i]
121 |             data.train_mask = data.train_masks[i]
122 |             data.val_mask = data.val_masks[i]
123 |             data.test_mask = data.test_masks[i]
124 |         if args.split == 'pl_fixed' or args.split == 'pl_random':
125 |             data.train_mask = data.train_masks[i]
126 |             data.val_mask = data.val_masks[i]
127 |             data.test_mask = data.test_masks[i]
128 |             data.backup_y = data.y
129 |             data.y = data.ys[i]
130 |         # import ipdb; ipdb.set_trace()
131 |         for ep in range(epoch):  # `ep` avoids shadowing the seed index `i`
132 |             # ipdb.set_trace()
133 |             train_mask = data.train_mask
134 |             val_mask = data.val_mask
135 |             if need_train:
136 |                 train_loss, val_loss, val_acc = train(model, data, optimizer, loss_fn, train_mask, val_mask)
137 |                 if writer != None:
138 |                     writer.add_scalar('Loss/train', train_loss, ep)
139 |                     writer.add_scalar('Loss/val', val_loss, ep)
140 |                     writer.add_scalar('Acc/val', val_acc[0], ep)
141 |                 if scheduler:
142 |                     scheduler.step()
143 |                 if args.output_intermediate:
144 |                     print(f"Epoch {ep}: Train loss: {train_loss}, Val loss: {val_loss}, Val acc: {val_acc[0]}")
145 |                 if val_acc[0] > best_val:
146 |                     best_val = val_acc[0]
147 |                     best_model = deepcopy(model)
148 |                     early_stop_accum = 0
149 |                 else:
150 |                     if ep >= args.early_stop_start:
151 |                         early_stop_accum += 1
152 |                     if early_stop_accum > args.early_stopping and ep >= args.early_stop_start:
153 |                         break
154 |             else:
155 |                 best_model = model
156 |         if 'pl' in args.split:
157 |             data.y = data.backup_y
158 |         test_acc, res = test(best_model, data, args.return_embeds, data.test_mask)
159 |         test_result_acc.append(test_acc)
160 |         val_result_acc.append(best_val)
161 |         out_res.append(res)
162 |         # del data
163 |         # del best_model
164 |     return test_result_acc, val_result_acc, out_res
165 | 
166 | 
167 | 
168 | 
169 | 
170 | def main(args = None, custom_args = None, save_best = False):
171 | 
172 |     writer = SummaryWriter()
173 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
174 |     if custom_args != None:
175 |         args = replace_args_with_dict_values(args, custom_args)
176 |     data = get_dataset(args.seed_num, args.dataset, args.split, args.data_format, args.low_label_test)
177 |     seeds = [i for i in range(args.seed_num)]  # computed after custom_args may have updated seed_num
178 |     best_model = None
179 |     best_val = 0
180 |     epoch = args.epochs
181 |     vars(args)['input_dim'] = data.x.shape[1]
182 |     vars(args)['num_classes'] = data.y.max().item() + 1
183 |     if args.model_name == 'LP':
184 |         need_train = False
185 |     else:
186 |         need_train = True
187 |     if not args.batchify and args.ensemble_string == "":
188 |         data.x = data.x.to(torch.float32)
189 |         test_result_acc, val_result_acc, out_res = train_pipeline(seeds, args, epoch, data, writer, need_train)
190 |         mean_test_acc = 
np.mean(test_result_acc) * 100 191 | std_test_acc = np.std(test_result_acc) * 100 192 | print(f"Test Accuracy: {mean_test_acc:.2f} ± {std_test_acc:.2f}") 193 | print("Test acc: {}".format(test_result_acc)) 194 | pkl_and_write(out_res, f'./output/{args.model_name}_{args.dataset}_{args.data_format}.pkl') 195 | elif args.ensemble_string != "": 196 | feats = args.ensemble_string.split(";") 197 | res = [] 198 | sep_test_acc = defaultdict(list) 199 | labels = data.y 200 | test_masks = data.test_masks 201 | for feat in feats: 202 | vars(args)['data_format'] = feat 203 | data = get_dataset(args.seed_num, args.dataset, args.split, args.data_format, args.low_label_test) 204 | vars(args)['input_dim'] = data.x.shape[1] 205 | vars(args)['num_classes'] = data.y.max().item() + 1 206 | # model = get_model(args).to(device) 207 | # optimizer, scheduler = get_optimizer(args, model) 208 | data.x = data.x.to(torch.float32) 209 | test_result_acc, val_result_acc, out_res = train_pipeline(seeds, args, epoch, data, writer, need_train) 210 | res.append(out_res) 211 | sep_test_acc[feat] = test_result_acc 212 | for key, value in sep_test_acc.items(): 213 | mean = np.mean(value) * 100 214 | std = np.std(value) * 100 215 | print(f"{key}: {mean:.2f} ± {std:.2f}") 216 | ensemble_input = [[res[i][j] for i in range(len(feats))] for j in range(len(seeds))] 217 | ensemble_helper(ensemble_input, labels, test_masks) 218 | else: 219 | test_result_acc, val_result_acc = train_pipeline_batch(seeds, args, epoch, data, writer, need_train) 220 | mean_test_acc = np.mean(test_result_acc) * 100.0 221 | std_test_acc = np.std(test_result_acc) * 100.0 222 | print(f"Test Accuracy: {mean_test_acc:.4f} ± {std_test_acc:.4f}") 223 | print("Test acc: {}".format(test_result_acc)) 224 | if save_best: 225 | pkl_and_write(args, osp.join("./bestargs", f"{args.model_name}_{args.dataset}_{args.data_format}.pkl")) 226 | writer.close() 227 | 228 | 229 | 230 | def max_trial_callback(study, trial, max_try): 231 | n_complete = len([t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE or t.state == optuna.trial.TrialState.RUNNING]) 232 | n_total_complete = len([t for t in study.trials]) 233 | if n_complete >= max_try or n_total_complete >= 2 * max_try: 234 | study.stop() 235 | torch.cuda.empty_cache() 236 | 237 | 238 | def sweep(args = None): 239 | # test_seeds = [i for i in range(args.seed_num)] 240 | sweep_seeds = [0, 1, 2, 3, 4] 241 | ## get default command line args 242 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 243 | optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout)) 244 | study_name = f"{args.dataset}_{args.model_name}_{args.data_format}_{args.split}" 245 | study = optuna.create_study(study_name=study_name, storage=None, direction='maximize', load_if_exists=True) 246 | param_f = hyper_search 247 | sweep_round = args.sweep_round 248 | study.optimize(lambda trial: sweep_run(trial, args, sweep_seeds, param_f, device), catch=(RuntimeError,), n_trials=sweep_round, callbacks=[lambda study, trial: max_trial_callback(study, trial, sweep_round)], show_progress_bar=True, gc_after_trial=True) 249 | main(args=args, custom_args = study.best_trial.params, save_best = True) 250 | print(study.best_trial.params) 251 | 252 | 253 | 254 | def sweep_run(trial, args, sweep_seeds, param_f, device): 255 | params = param_f(trial, args.data_format, args.model_name, args.dataset) 256 | args = replace_args_with_dict_values(args, params) 257 | data = get_dataset(args.seed_num, args.dataset, args.split, 
args.data_format, args.low_label_test).to(device) 258 | best_model = None 259 | best_val = 0 260 | epoch = args.epochs 261 | vars(args)['input_dim'] = data.x.shape[1] 262 | vars(args)['num_classes'] = data.y.max().item() + 1 263 | # model = get_model(args).to(device) 264 | # optimizer, scheduler = get_optimizer(args, model) 265 | # loss_fn = torch.nn.CrossEntropyLoss() 266 | if args.model_name == 'LP': 267 | need_train = False 268 | else: 269 | need_train = True 270 | if not args.batchify and args.ensemble_string == "": 271 | data.x = data.x.to(torch.float32) 272 | test_result_acc, val_result_acc, out_res = train_pipeline(sweep_seeds, args, epoch, data, None, need_train, mode="sweep") 273 | elif args.ensemble_string != "": 274 | feats = args.ensemble_string.split(";") 275 | res = [] 276 | sep_test_acc = defaultdict(list) 277 | labels = data.y 278 | test_masks = data.test_masks 279 | for feat in feats: 280 | vars(args)['data_format'] = feat 281 | data = get_dataset(args.seed_num, args.dataset, args.split, args.data_format, args.low_label_test) 282 | vars(args)['input_dim'] = data.x.shape[1] 283 | vars(args)['num_classes'] = data.y.max().item() + 1 284 | # model = get_model(args).to(device) 285 | # optimizer, scheduler = get_optimizer(args, model) 286 | data.x = data.x.to(torch.float32) 287 | test_result_acc, val_result_acc, out_res = train_pipeline(sweep_seeds, args, epoch, data, None, need_train, mode="sweep") 288 | res.append(out_res) 289 | sep_test_acc[feat] = test_result_acc 290 | for key, value in sep_test_acc.items(): 291 | print(f"{key}: {np.mean(value):.4f} ± {np.std(value):.4f}") 292 | ensemble_input = [[res[i][j] for i in range(len(feats))] for j in range(len(sweep_seeds))] 293 | mean_test_acc, _ = ensemble_helper(ensemble_input, labels, test_masks) 294 | return mean_test_acc 295 | else: 296 | test_result_acc, val_result_acc = train_pipeline_batch(seeds, args, epoch, data, writer, need_train, mode="sweep") 297 | mean_test_acc = np.mean(test_result_acc) 298 | std_test_acc = np.std(test_result_acc) 299 | print(f"Test Accuracy: {mean_test_acc} ± {std_test_acc}") 300 | # mean_val_acc = np.mean(val_result_acc) 301 | # std_val_acc = np.std(val_result_acc) 302 | # print(f"Val Accuracy: {mean_val_acc} ± {std_val_acc}") 303 | return mean_test_acc 304 | 305 | 306 | 307 | 308 | @torch.no_grad() 309 | def ensemble_helper(logits, labels, test_masks): 310 | seeds_num = len(logits) 311 | accs = [] 312 | for i in range(seeds_num): 313 | test_mask = test_masks[i].cpu() 314 | this_seed_logits = logits[i] 315 | avg_logits = sum(this_seed_logits) / len(this_seed_logits) 316 | pred = torch.argmax(avg_logits, dim=1).cpu() 317 | labels = labels.cpu() 318 | acc = torch.sum(pred[test_mask] == labels[test_mask]).item() / len(labels[test_mask]) 319 | accs.append(acc) 320 | mean_test_acc = np.mean(accs) * 100.0 321 | std_test_acc = np.std(accs) * 100.0 322 | print(f"Ensemble Accuracy: {mean_test_acc:.2f} ± {std_test_acc:.2f}") 323 | return mean_test_acc, std_test_acc 324 | 325 | 326 | 327 | 328 | 329 | if __name__ == '__main__': 330 | current_time = int(time.time()) 331 | logging.basicConfig(filename='./logs/{}.log'.format(current_time), 332 | filemode='a', 333 | format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s', 334 | datefmt='%H:%M:%S', 335 | level=logging.INFO) 336 | # set_seed_config(42) 337 | ## get mode: sweep or main 338 | args = get_command_line_args() 339 | set_api_key() 340 | # param_search() 341 | if args.mode == "main": 342 | main(args = args) 343 | else: 344 | sweep(args = args) 
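
The core of `ensemble_helper` is plain logit averaging across feature types before the argmax. Below is a minimal, self-contained sketch of that idea; the tensor sizes, the two feature types, and the toy mask are made up purely for illustration and are not part of the repository:

```python
import torch

# Two hypothetical feature types produce per-node class logits for the same split.
num_nodes, num_classes = 6, 3
logits_per_feat = [torch.randn(num_nodes, num_classes) for _ in range(2)]
labels = torch.randint(0, num_classes, (num_nodes,))
test_mask = torch.tensor([True, True, False, True, False, True])

# Average logits across feature types, then evaluate only on test nodes.
avg_logits = sum(logits_per_feat) / len(logits_per_feat)
pred = avg_logits.argmax(dim=1)
acc = (pred[test_mask] == labels[test_mask]).float().mean().item()
print(f"ensemble accuracy on this toy split: {acc:.2f}")
```

Averaging logits (rather than hard votes) lets a confident model outweigh an uncertain one, which is why `ensemble_helper` sums the raw outputs instead of the predicted labels.
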
345 | 346 | 347 | 348 | -------------------------------------------------------------------------------- /gcn_lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryTang/Graph-LLM/344f7c6b7786c7f8293c24ce5b90f141c777aeec/gcn_lib/__init__.py -------------------------------------------------------------------------------- /gcn_lib/dense/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch_nn import * 2 | from .torch_edge import * 3 | from .torch_vertex import * 4 | 5 | -------------------------------------------------------------------------------- /gcn_lib/dense/torch_edge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_cluster import knn_graph 4 | 5 | 6 | class DenseDilated(nn.Module): 7 | """ 8 | Find dilated neighbor from neighbor list 9 | 10 | edge_index: (2, batch_size, num_points, k) 11 | """ 12 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): 13 | super(DenseDilated, self).__init__() 14 | self.dilation = dilation 15 | self.stochastic = stochastic 16 | self.epsilon = epsilon 17 | self.k = k 18 | 19 | def forward(self, edge_index): 20 | if self.stochastic: 21 | if torch.rand(1) < self.epsilon and self.training: 22 | num = self.k * self.dilation 23 | randnum = torch.randperm(num)[:self.k] 24 | edge_index = edge_index[:, :, :, randnum] 25 | else: 26 | edge_index = edge_index[:, :, :, ::self.dilation] 27 | else: 28 | edge_index = edge_index[:, :, :, ::self.dilation] 29 | return edge_index 30 | 31 | 32 | def pairwise_distance(x): 33 | """ 34 | Compute pairwise distance of a point cloud. 35 | Args: 36 | x: tensor (batch_size, num_points, num_dims) 37 | Returns: 38 | pairwise distance: (batch_size, num_points, num_points) 39 | """ 40 | x_inner = -2*torch.matmul(x, x.transpose(2, 1)) 41 | x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) 42 | return x_square + x_inner + x_square.transpose(2, 1) 43 | 44 | 45 | def dense_knn_matrix(x, k=16): 46 | """Get KNN based on the pairwise distance. 
47 | Args: 48 | x: (batch_size, num_dims, num_points, 1) 49 | k: int 50 | Returns: 51 | nearest neighbors: (batch_size, num_points ,k) (batch_size, num_points, k) 52 | """ 53 | with torch.no_grad(): 54 | x = x.transpose(2, 1).squeeze(-1) 55 | batch_size, n_points, n_dims = x.shape 56 | _, nn_idx = torch.topk(-pairwise_distance(x.detach()), k=k) 57 | center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1) 58 | return torch.stack((nn_idx, center_idx), dim=0) 59 | 60 | 61 | class DenseDilatedKnnGraph(nn.Module): 62 | """ 63 | Find the neighbors' indices based on dilated knn 64 | """ 65 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): 66 | super(DenseDilatedKnnGraph, self).__init__() 67 | self.dilation = dilation 68 | self.stochastic = stochastic 69 | self.epsilon = epsilon 70 | self.k = k 71 | self._dilated = DenseDilated(k, dilation, stochastic, epsilon) 72 | self.knn = dense_knn_matrix 73 | 74 | def forward(self, x): 75 | edge_index = self.knn(x, self.k * self.dilation) 76 | return self._dilated(edge_index) 77 | 78 | 79 | class DilatedKnnGraph(nn.Module): 80 | """ 81 | Find the neighbors' indices based on dilated knn 82 | """ 83 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): 84 | super(DilatedKnnGraph, self).__init__() 85 | self.dilation = dilation 86 | self.stochastic = stochastic 87 | self.epsilon = epsilon 88 | self.k = k 89 | self._dilated = DenseDilated(k, dilation, stochastic, epsilon) 90 | self.knn = knn_graph 91 | 92 | def forward(self, x): 93 | x = x.squeeze(-1) 94 | B, C, N = x.shape 95 | edge_index = [] 96 | for i in range(B): 97 | edgeindex = self.knn(x[i].contiguous().transpose(1, 0).contiguous(), self.k * self.dilation) 98 | edgeindex = edgeindex.view(2, N, self.k * self.dilation) 99 | edge_index.append(edgeindex) 100 | edge_index = torch.stack(edge_index, dim=1) 101 | return self._dilated(edge_index) 102 | -------------------------------------------------------------------------------- /gcn_lib/dense/torch_nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Sequential as Seq, Linear as Lin, Conv2d 4 | 5 | 6 | ############################## 7 | # Basic layers 8 | ############################## 9 | def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1): 10 | # activation layer 11 | 12 | act = act.lower() 13 | if act == 'relu': 14 | layer = nn.ReLU(inplace) 15 | elif act == 'leakyrelu': 16 | layer = nn.LeakyReLU(neg_slope, inplace) 17 | elif act == 'prelu': 18 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 19 | else: 20 | raise NotImplementedError('activation layer [%s] is not found' % act) 21 | return layer 22 | 23 | 24 | def norm_layer(norm, nc): 25 | # normalization layer 2d 26 | norm = norm.lower() 27 | if norm == 'batch': 28 | layer = nn.BatchNorm2d(nc, affine=True) 29 | elif norm == 'instance': 30 | layer = nn.InstanceNorm2d(nc, affine=False) 31 | else: 32 | raise NotImplementedError('normalization layer [%s] is not found' % norm) 33 | return layer 34 | 35 | 36 | class MLP(Seq): 37 | def __init__(self, channels, act='relu', norm=None, bias=True): 38 | m = [] 39 | for i in range(1, len(channels)): 40 | m.append(Lin(channels[i - 1], channels[i], bias)) 41 | if act is not None and act.lower() != 'none': 42 | m.append(act_layer(act)) 43 | if norm is not None and norm.lower() != 'none': 44 | m.append(norm_layer(norm, channels[-1])) 45 | super(MLP, self).__init__(*m) 46 | 47 | 
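
A quick smoke test for the `MLP` helper above — a sketch only, assuming the repository root is on `PYTHONPATH` and `torch_cluster` is installed (importing `gcn_lib.dense` pulls in `torch_edge`, which needs it). Note that the loop appends an activation after every `Lin`, including the last one, and that `norm_layer` in this file builds 2-D norms, so `norm` is left at its default `None` for plain 2-D inputs:

```python
import torch
from gcn_lib.dense import MLP  # re-exported via torch_nn in dense/__init__.py

# channels=[64, 128, 7] -> Lin(64, 128) + ReLU + Lin(128, 7) + ReLU
mlp = MLP([64, 128, 7], act='relu')
y = mlp(torch.randn(32, 64))
print(y.shape)  # torch.Size([32, 7])
```
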
48 | class BasicConv(Seq): 49 | def __init__(self, channels, act='relu', norm=None, bias=True, drop=0.): 50 | m = [] 51 | for i in range(1, len(channels)): 52 | m.append(Conv2d(channels[i - 1], channels[i], 1, bias=bias)) 53 | if act is not None and act.lower() != 'none': 54 | m.append(act_layer(act)) 55 | if norm is not None and norm.lower() != 'none': 56 | m.append(norm_layer(norm, channels[-1])) 57 | if drop > 0: 58 | m.append(nn.Dropout2d(drop)) 59 | 60 | super(BasicConv, self).__init__(*m) 61 | 62 | self.reset_parameters() 63 | 64 | def reset_parameters(self): 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv2d): 67 | nn.init.kaiming_normal_(m.weight) 68 | if m.bias is not None: 69 | nn.init.zeros_(m.bias) 70 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): 71 | m.weight.data.fill_(1) 72 | m.bias.data.zero_() 73 | 74 | 75 | def batched_index_select(x, idx): 76 | r"""fetches neighbors features from a given neighbor idx 77 | 78 | Args: 79 | x (Tensor): input feature Tensor 80 | :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`. 81 | idx (Tensor): edge_idx 82 | :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`. 83 | Returns: 84 | Tensor: output neighbors features 85 | :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`. 86 | """ 87 | batch_size, num_dims, num_vertices = x.shape[:3] 88 | k = idx.shape[-1] 89 | idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices 90 | idx = idx + idx_base 91 | idx = idx.contiguous().view(-1) 92 | 93 | x = x.transpose(2, 1) 94 | feature = x.contiguous().view(batch_size * num_vertices, -1)[idx, :] 95 | feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous() 96 | return feature 97 | -------------------------------------------------------------------------------- /gcn_lib/dense/torch_vertex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .torch_nn import BasicConv, batched_index_select 4 | from .torch_edge import DenseDilatedKnnGraph, DilatedKnnGraph 5 | import torch.nn.functional as F 6 | 7 | 8 | class MRConv2d(nn.Module): 9 | """ 10 | Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type 11 | """ 12 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True): 13 | super(MRConv2d, self).__init__() 14 | self.nn = BasicConv([in_channels*2, out_channels], act, norm, bias) 15 | 16 | def forward(self, x, edge_index): 17 | x_i = batched_index_select(x, edge_index[1]) 18 | x_j = batched_index_select(x, edge_index[0]) 19 | x_j, _ = torch.max(x_j - x_i, -1, keepdim=True) 20 | return self.nn(torch.cat([x, x_j], dim=1)) 21 | 22 | 23 | class EdgeConv2d(nn.Module): 24 | """ 25 | Edge convolution layer (with activation, batch normalization) for dense data type 26 | """ 27 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True): 28 | super(EdgeConv2d, self).__init__() 29 | self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias) 30 | 31 | def forward(self, x, edge_index): 32 | x_i = batched_index_select(x, edge_index[1]) 33 | x_j = batched_index_select(x, edge_index[0]) 34 | max_value, _ = torch.max(self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True) 35 | return max_value 36 | 37 | 38 | class GraphConv2d(nn.Module): 39 | """ 40 | Static graph convolution layer 41 | """ 42 | def __init__(self, in_channels, out_channels, conv='edge', 
act='relu', norm=None, bias=True): 43 | super(GraphConv2d, self).__init__() 44 | if conv == 'edge': 45 | self.gconv = EdgeConv2d(in_channels, out_channels, act, norm, bias) 46 | elif conv == 'mr': 47 | self.gconv = MRConv2d(in_channels, out_channels, act, norm, bias) 48 | else: 49 | raise NotImplementedError('conv:{} is not supported'.format(conv)) 50 | 51 | def forward(self, x, edge_index): 52 | return self.gconv(x, edge_index) 53 | 54 | 55 | class DynConv2d(GraphConv2d): 56 | """ 57 | Dynamic graph convolution layer 58 | """ 59 | def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu', 60 | norm=None, bias=True, stochastic=False, epsilon=0.0, knn='matrix'): 61 | super(DynConv2d, self).__init__(in_channels, out_channels, conv, act, norm, bias) 62 | self.k = kernel_size 63 | self.d = dilation 64 | if knn == 'matrix': 65 | self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon) 66 | else: 67 | self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, stochastic, epsilon) 68 | 69 | def forward(self, x): 70 | edge_index = self.dilated_knn_graph(x) 71 | return super(DynConv2d, self).forward(x, edge_index) 72 | 73 | 74 | class PlainDynBlock2d(nn.Module): 75 | """ 76 | Plain Dynamic graph convolution block 77 | """ 78 | def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None, 79 | bias=True, stochastic=False, epsilon=0.0, knn='matrix'): 80 | super(PlainDynBlock2d, self).__init__() 81 | self.body = DynConv2d(in_channels, in_channels, kernel_size, dilation, conv, 82 | act, norm, bias, stochastic, epsilon, knn) 83 | 84 | def forward(self, x): 85 | return self.body(x) 86 | 87 | 88 | class ResDynBlock2d(nn.Module): 89 | """ 90 | Residual Dynamic graph convolution block 91 | """ 92 | def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None, 93 | bias=True, stochastic=False, epsilon=0.0, knn='matrix', res_scale=1): 94 | super(ResDynBlock2d, self).__init__() 95 | self.body = DynConv2d(in_channels, in_channels, kernel_size, dilation, conv, 96 | act, norm, bias, stochastic, epsilon, knn) 97 | self.res_scale = res_scale 98 | 99 | def forward(self, x): 100 | return self.body(x) + x*self.res_scale 101 | 102 | 103 | class DenseDynBlock2d(nn.Module): 104 | """ 105 | Dense Dynamic graph convolution block 106 | """ 107 | def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv='edge', 108 | act='relu', norm=None,bias=True, stochastic=False, epsilon=0.0, knn='matrix'): 109 | super(DenseDynBlock2d, self).__init__() 110 | self.body = DynConv2d(in_channels, out_channels, kernel_size, dilation, conv, 111 | act, norm, bias, stochastic, epsilon, knn) 112 | 113 | def forward(self, x): 114 | dense = self.body(x) 115 | return torch.cat((x, dense), 1) 116 | -------------------------------------------------------------------------------- /gcn_lib/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch_nn import * 2 | from .torch_edge import * 3 | from .torch_vertex import * 4 | 5 | -------------------------------------------------------------------------------- /gcn_lib/sparse/torch_edge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_cluster import knn_graph 4 | 5 | 6 | class Dilated(nn.Module): 7 | """ 8 | Find dilated neighbor from neighbor list 9 | """ 10 | def __init__(self, k=9, 
dilation=1, stochastic=False, epsilon=0.0): 11 | super(Dilated, self).__init__() 12 | self.dilation = dilation 13 | self.stochastic = stochastic 14 | self.epsilon = epsilon 15 | self.k = k 16 | 17 | def forward(self, edge_index, batch=None): 18 | if self.stochastic: 19 | if torch.rand(1) < self.epsilon and self.training: 20 | num = self.k * self.dilation 21 | randnum = torch.randperm(num)[:self.k] 22 | edge_index = edge_index.view(2, -1, num) 23 | edge_index = edge_index[:, :, randnum] 24 | return edge_index.view(2, -1) 25 | else: 26 | edge_index = edge_index[:, ::self.dilation] 27 | else: 28 | edge_index = edge_index[:, ::self.dilation] 29 | return edge_index 30 | 31 | 32 | class DilatedKnnGraph(nn.Module): 33 | """ 34 | Find the neighbors' indices based on dilated knn 35 | """ 36 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0, knn='matrix'): 37 | super(DilatedKnnGraph, self).__init__() 38 | self.dilation = dilation 39 | self.stochastic = stochastic 40 | self.epsilon = epsilon 41 | self.k = k 42 | self._dilated = Dilated(k, dilation, stochastic, epsilon) 43 | if knn == 'matrix': 44 | self.knn = knn_graph_matrix 45 | else: 46 | self.knn = knn_graph 47 | 48 | def forward(self, x, batch): 49 | edge_index = self.knn(x, self.k * self.dilation, batch) 50 | return self._dilated(edge_index, batch) 51 | 52 | 53 | def pairwise_distance(x): 54 | """ 55 | Compute pairwise distance of a point cloud. 56 | Args: 57 | x: tensor (batch_size, num_points, num_dims) 58 | Returns: 59 | pairwise distance: (batch_size, num_points, num_points) 60 | """ 61 | x_inner = -2*torch.matmul(x, x.transpose(2, 1)) 62 | x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) 63 | return x_square + x_inner + x_square.transpose(2, 1) 64 | 65 | 66 | def knn_matrix(x, k=16, batch=None): 67 | """Get KNN based on the pairwise distance. 
68 | Args: 69 | pairwise distance: (num_points, num_points) 70 | k: int 71 | Returns: 72 | nearest neighbors: (num_points*k ,1) (num_points, k) 73 | """ 74 | with torch.no_grad(): 75 | if batch is None: 76 | batch_size = 1 77 | else: 78 | batch_size = batch[-1] + 1 79 | x = x.view(batch_size, -1, x.shape[-1]) 80 | 81 | neg_adj = -pairwise_distance(x.detach()) 82 | _, nn_idx = torch.topk(neg_adj, k=k) 83 | del neg_adj 84 | 85 | n_points = x.shape[1] 86 | start_idx = torch.arange(0, n_points*batch_size, n_points).long().view(batch_size, 1, 1) 87 | if x.is_cuda: 88 | start_idx = start_idx.cuda() 89 | nn_idx += start_idx 90 | del start_idx 91 | 92 | if x.is_cuda: 93 | torch.cuda.empty_cache() 94 | 95 | nn_idx = nn_idx.view(1, -1) 96 | center_idx = torch.arange(0, n_points*batch_size).repeat(k, 1).transpose(1, 0).contiguous().view(1, -1) 97 | if x.is_cuda: 98 | center_idx = center_idx.cuda() 99 | return nn_idx, center_idx 100 | 101 | 102 | def knn_graph_matrix(x, k=16, batch=None): 103 | """Construct edge feature for each point 104 | Args: 105 | x: (num_points, num_dims) 106 | batch: (num_points, ) 107 | k: int 108 | Returns: 109 | edge_index: (2, num_points*k) 110 | """ 111 | nn_idx, center_idx = knn_matrix(x, k, batch) 112 | return torch.cat((nn_idx, center_idx), dim=0) 113 | 114 | -------------------------------------------------------------------------------- /gcn_lib/sparse/torch_message.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch_geometric.nn import MessagePassing 4 | from torch_scatter import scatter, scatter_softmax 5 | from torch_geometric.utils import degree 6 | 7 | 8 | class GenMessagePassing(MessagePassing): 9 | def __init__(self, aggr='softmax', 10 | t=1.0, learn_t=False, 11 | p=1.0, learn_p=False, 12 | y=0.0, learn_y=False): 13 | 14 | if aggr in ['softmax_sg', 'softmax', 'softmax_sum']: 15 | 16 | super(GenMessagePassing, self).__init__(aggr=None) 17 | self.aggr = aggr 18 | 19 | if learn_t and (aggr == 'softmax' or aggr == 'softmax_sum'): 20 | self.learn_t = True 21 | self.t = torch.nn.Parameter(torch.Tensor([t]), requires_grad=True) 22 | else: 23 | self.learn_t = False 24 | self.t = t 25 | 26 | if aggr == 'softmax_sum': 27 | self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y) 28 | 29 | elif aggr in ['power', 'power_sum']: 30 | 31 | super(GenMessagePassing, self).__init__(aggr=None) 32 | self.aggr = aggr 33 | 34 | if learn_p: 35 | self.p = torch.nn.Parameter(torch.Tensor([p]), requires_grad=True) 36 | else: 37 | self.p = p 38 | 39 | if aggr == 'power_sum': 40 | self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y) 41 | else: 42 | super(GenMessagePassing, self).__init__(aggr=aggr) 43 | 44 | def aggregate(self, inputs, index, ptr=None, dim_size=None): 45 | 46 | if self.aggr in ['add', 'mean', 'max', None]: 47 | return super(GenMessagePassing, self).aggregate(inputs, index, ptr, dim_size) 48 | 49 | elif self.aggr in ['softmax_sg', 'softmax', 'softmax_sum']: 50 | 51 | if self.learn_t: 52 | out = scatter_softmax(inputs*self.t, index, dim=self.node_dim) 53 | else: 54 | with torch.no_grad(): 55 | out = scatter_softmax(inputs*self.t, index, dim=self.node_dim) 56 | 57 | out = scatter(inputs*out, index, dim=self.node_dim, 58 | dim_size=dim_size, reduce='sum') 59 | 60 | if self.aggr == 'softmax_sum': 61 | self.sigmoid_y = torch.sigmoid(self.y) 62 | degrees = degree(index, num_nodes=dim_size).unsqueeze(1) 63 | out = torch.pow(degrees, self.sigmoid_y) * 
out 64 | 65 | return out 66 | 67 | 68 | elif self.aggr in ['power', 'power_sum']: 69 | min_value, max_value = 1e-7, 1e1 70 | torch.clamp_(inputs, min_value, max_value) 71 | out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim, 72 | dim_size=dim_size, reduce='mean') 73 | torch.clamp_(out, min_value, max_value) 74 | out = torch.pow(out, 1/self.p) 75 | # torch.clamp(out, min_value, max_value) 76 | 77 | if self.aggr == 'power_sum': 78 | self.sigmoid_y = torch.sigmoid(self.y) 79 | degrees = degree(index, num_nodes=dim_size).unsqueeze(1) 80 | out = torch.pow(degrees, self.sigmoid_y) * out 81 | 82 | return out 83 | 84 | else: 85 | raise NotImplementedError('To be implemented') 86 | 87 | 88 | class MsgNorm(torch.nn.Module): 89 | def __init__(self, learn_msg_scale=False): 90 | super(MsgNorm, self).__init__() 91 | 92 | self.msg_scale = torch.nn.Parameter(torch.Tensor([1.0]), 93 | requires_grad=learn_msg_scale) 94 | 95 | def forward(self, x, msg, p=2): 96 | msg = F.normalize(msg, p=p, dim=1) 97 | x_norm = x.norm(p=p, dim=1, keepdim=True) 98 | msg = msg * x_norm * self.msg_scale 99 | return msg 100 | -------------------------------------------------------------------------------- /gcn_lib/sparse/torch_nn.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import Sequential as Seq, Linear as Lin 3 | # from utils.data_util import get_atom_feature_dims, get_bond_feature_dims 4 | import numpy as np 5 | import h5py 6 | import os 7 | import os.path as osp 8 | import shutil 9 | from glob import glob 10 | import torch 11 | from torch_scatter import scatter 12 | from torch_geometric.data import InMemoryDataset, Data, extract_zip 13 | from tqdm import tqdm 14 | import torch_geometric as tg 15 | 16 | 17 | def intersection(lst1, lst2): 18 | return list(set(lst1) & set(lst2)) 19 | 20 | 21 | def process_indexes(idx_list): 22 | idx_dict = {} 23 | for i, idx in enumerate(idx_list): 24 | idx_dict[idx] = i 25 | 26 | return [idx_dict[i] for i in sorted(idx_dict.keys())] 27 | 28 | 29 | def add_zeros(data): 30 | data.x = torch.zeros(data.num_nodes, dtype=torch.long) 31 | return data 32 | 33 | 34 | def extract_node_feature(data, reduce='add'): 35 | if reduce in ['mean', 'max', 'add']: 36 | data.x = scatter(data.edge_attr, 37 | data.edge_index[0], 38 | dim=0, 39 | dim_size=data.num_nodes, 40 | reduce=reduce) 41 | else: 42 | raise Exception('Unknown Aggregation Type') 43 | return data 44 | 45 | # random partition graph 46 | def random_partition_graph(num_nodes, cluster_number=10): 47 | parts = np.random.randint(cluster_number, size=num_nodes) 48 | return parts 49 | 50 | 51 | def generate_sub_graphs(adj, parts, cluster_number=10, batch_size=1): 52 | # convert sparse tensor to scipy csr 53 | adj = adj.to_scipy(layout='csr') 54 | 55 | num_batches = cluster_number // batch_size 56 | 57 | sg_nodes = [[] for _ in range(num_batches)] 58 | sg_edges = [[] for _ in range(num_batches)] 59 | 60 | for cluster in range(num_batches): 61 | sg_nodes[cluster] = np.where(parts == cluster)[0] 62 | sg_edges[cluster] = tg.utils.from_scipy_sparse_matrix(adj[sg_nodes[cluster], :][:, sg_nodes[cluster]])[0] 63 | 64 | return sg_nodes, sg_edges 65 | 66 | def random_rotate(points): 67 | theta = np.random.uniform(0, np.pi * 2) 68 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) 69 | rotation_matrix = torch.from_numpy(rotation_matrix).float() 70 | points[:, 0:2] = torch.matmul(points[:, [0, 1]].transpose(1, 3), 
rotation_matrix).transpose(1, 3) 71 | return points 72 | 73 | 74 | def random_translate(points, mean=0, std=0.02): 75 | points += torch.randn(points.shape)*std + mean 76 | return points 77 | 78 | 79 | def random_points_augmentation(points, rotate=False, translate=False, **kwargs): 80 | if rotate: 81 | points = random_rotate(points) 82 | if translate: 83 | points = random_translate(points, **kwargs) 84 | 85 | return points 86 | 87 | 88 | def scale_translate_pointcloud(pointcloud, shift=[-0.2, 0.2], scale=[2. / 3., 3. /2.]): 89 | """ 90 | for scaling and shifting the point cloud 91 | :param pointcloud: 92 | :return: 93 | """ 94 | B, C, N = pointcloud.shape[0:3] 95 | scale = scale[0] + torch.rand([B, C, 1, 1])*(scale[1]-scale[0]) 96 | shift = shift[0] + torch.rand([B, C, 1, 1]) * (shift[1]-shift[0]) 97 | translated_pointcloud = torch.mul(pointcloud, scale) + shift 98 | return translated_pointcloud 99 | 100 | 101 | class PartNet(InMemoryDataset): 102 | r"""The PartNet dataset from 103 | the `"PartNet: A Large-scale Benchmark for Fine-grained and Hierarchical Part-level 3D Object Understanding" 104 | `_ 105 | paper, containing 3D objects annotated with fine-grained, instance-level, and hierarchical 3D part information. 106 | 107 | Args: 108 | root (string): Root directory where the dataset should be saved. 109 | dataset (str, optional): Which dataset to use (ins_seg_h5, or sem_seg_h5). 110 | (default: :obj:`sem_seg_h5`) 111 | obj_category (str, optional): which category to load. 112 | (default: :obj:`Bed`) 113 | level (str, optional): Which level of part semantic segmentation to use. 114 | (default: :obj:`3`) 115 | phase (str, optional): If :obj:`test`, loads the testing dataset, 116 | If :obj:`val`, loads the validation dataset, 117 | otherwise the training dataset. (default: :obj:`train`) 118 | transform (callable, optional): A function/transform that takes in an 119 | :obj:`torch_geometric.data.Data` object and returns a transformed 120 | version. The data object will be transformed before every access. 121 | (default: :obj:`None`) 122 | pre_transform (callable, optional): A function/transform that takes in 123 | an :obj:`torch_geometric.data.Data` object and returns a 124 | transformed version. The data object will be transformed before 125 | being saved to disk. (default: :obj:`None`) 126 | pre_filter (callable, optional): A function that takes in an 127 | :obj:`torch_geometric.data.Data` object and returns a boolean 128 | value, indicating whether the data object should be included in the 129 | final dataset. 
(default: :obj:`None`) 130 | """ 131 | # the dataset we use for our paper is pre-released version 132 | def __init__(self, 133 | root, 134 | dataset='sem_seg_h5', 135 | obj_category='Bed', 136 | level=3, 137 | phase='train', 138 | transform=None, 139 | pre_transform=None, 140 | pre_filter=None): 141 | self.dataset = dataset 142 | self.level = level 143 | self.obj_category = obj_category 144 | self.object = '-'.join([self.obj_category, str(self.level)]) 145 | self.level_folder = 'level_'+str(self.level) 146 | self.processed_file_folder = osp.join(self.dataset, self.level_folder, self.object) 147 | super(PartNet, self).__init__(root, transform, pre_transform, pre_filter) 148 | if phase == 'test': 149 | path = self.processed_paths[1] 150 | elif phase == 'val': 151 | path = self.processed_paths[2] 152 | else: 153 | path = self.processed_paths[0] 154 | self.data, self.slices = torch.load(path) 155 | 156 | @property 157 | def raw_file_names(self): 158 | return [self.dataset] 159 | 160 | @property 161 | def processed_file_names(self): 162 | return osp.join(self.processed_file_folder, 'train.pt'), osp.join(self.processed_file_folder, 'test.pt'), \ 163 | osp.join(self.processed_file_folder, 'val.pt') 164 | 165 | def download(self): 166 | path = osp.join(self.raw_dir, self.dataset) 167 | if not osp.exists(path): 168 | raise FileExistsError('PartNet can only downloaded via application. ' 169 | 'See details in https://cs.stanford.edu/~kaichun/partnet/') 170 | # path = download_url(self.url, self.root) 171 | extract_zip(path, self.root) 172 | os.unlink(path) 173 | shutil.rmtree(self.raw_dir) 174 | name = self.url.split(os.sep)[-1].split('.')[0] 175 | os.rename(osp.join(self.root, name), self.raw_dir) 176 | 177 | def process(self): 178 | # save to processed_paths 179 | processed_path = osp.join(self.processed_dir, self.processed_file_folder) 180 | if not osp.exists(processed_path): 181 | os.makedirs(osp.join(processed_path)) 182 | torch.save(self.process_set('train'), self.processed_paths[0]) 183 | torch.save(self.process_set('test'), self.processed_paths[1]) 184 | torch.save(self.process_set('val'), self.processed_paths[2]) 185 | 186 | def process_set(self, dataset): 187 | if self.dataset == 'ins_seg_h5': 188 | raw_path = osp.join(self.raw_dir, 'ins_seg_h5_for_sgpn', self.dataset) 189 | categories = glob(osp.join(raw_path, '*')) 190 | categories = sorted([x.split(os.sep)[-1] for x in categories]) 191 | 192 | data_list = [] 193 | for target, category in enumerate(tqdm(categories)): 194 | folder = osp.join(raw_path, category) 195 | paths = glob('{}/{}-*.h5'.format(folder, dataset)) 196 | labels, nors, opacitys, pts, rgbs = [], [], [], [], [] 197 | for path in paths: 198 | f = h5py.File(path) 199 | pts += torch.from_numpy(f['pts'][:]).unbind(0) 200 | labels += torch.from_numpy(f['label'][:]).to(torch.long).unbind(0) 201 | nors += torch.from_numpy(f['nor'][:]).unbind(0) 202 | opacitys += torch.from_numpy(f['opacity'][:]).unbind(0) 203 | rgbs += torch.from_numpy(f['rgb'][:]).to(torch.float32).unbind(0) 204 | 205 | for i, (pt, label, nor, opacity, rgb) in enumerate(zip(pts, labels, nors, opacitys, rgbs)): 206 | data = Data(pos=pt[:, :3], y=label, norm=nor[:, :3], x=torch.cat((opacity.unsqueeze(-1), rgb/255.), 1)) 207 | 208 | if self.pre_filter is not None and not self.pre_filter(data): 209 | continue 210 | if self.pre_transform is not None: 211 | data = self.pre_transform(data) 212 | data_list.append(data) 213 | else: 214 | raw_path = osp.join(self.raw_dir, self.dataset) 215 | categories = 
glob(osp.join(raw_path, self.object)) 216 | categories = sorted([x.split(os.sep)[-1] for x in categories]) 217 | data_list = [] 218 | # class_name = [] 219 | for target, category in enumerate(tqdm(categories)): 220 | folder = osp.join(raw_path, category) 221 | paths = glob('{}/{}-*.h5'.format(folder, dataset)) 222 | labels, pts = [], [] 223 | # clss = category.split('-')[0] 224 | 225 | for path in paths: 226 | f = h5py.File(path) 227 | pts += torch.from_numpy(f['data'][:].astype(np.float32)).unbind(0) 228 | labels += torch.from_numpy(f['label_seg'][:].astype(np.float32)).to(torch.long).unbind(0) 229 | for i, (pt, label) in enumerate(zip(pts, labels)): 230 | data = Data(pos=pt[:, :3], y=label) 231 | # data = PartData(pos=pt[:, :3], y=label, clss=clss) 232 | if self.pre_filter is not None and not self.pre_filter(data): 233 | continue 234 | if self.pre_transform is not None: 235 | data = self.pre_transform(data) 236 | data_list.append(data) 237 | return self.collate(data_list) 238 | 239 | 240 | class PartData(Data): 241 | def __init__(self, 242 | y=None, 243 | pos=None, 244 | clss=None): 245 | super(PartData).__init__(pos=pos, y=y) 246 | self.clss = clss 247 | 248 | 249 | # allowable multiple choice node and edge features 250 | # code from https://github.com/snap-stanford/ogb/blob/master/ogb/utils/features.py 251 | allowable_features = { 252 | 'possible_atomic_num_list' : list(range(1, 119)) + ['misc'], 253 | 'possible_chirality_list' : [ 254 | 'CHI_UNSPECIFIED', 255 | 'CHI_TETRAHEDRAL_CW', 256 | 'CHI_TETRAHEDRAL_CCW', 257 | 'CHI_OTHER' 258 | ], 259 | 'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], 260 | 'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'], 261 | 'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], 262 | 'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'], 263 | 'possible_hybridization_list' : [ 264 | 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc' 265 | ], 266 | 'possible_is_aromatic_list': [False, True], 267 | 'possible_is_in_ring_list': [False, True], 268 | 'possible_bond_type_list' : [ 269 | 'SINGLE', 270 | 'DOUBLE', 271 | 'TRIPLE', 272 | 'AROMATIC', 273 | 'misc' 274 | ], 275 | 'possible_bond_stereo_list': [ 276 | 'STEREONONE', 277 | 'STEREOZ', 278 | 'STEREOE', 279 | 'STEREOCIS', 280 | 'STEREOTRANS', 281 | 'STEREOANY', 282 | ], 283 | 'possible_is_conjugated_list': [False, True], 284 | } 285 | 286 | 287 | def safe_index(l, e): 288 | """ 289 | Return index of element e in list l. 
If e is not present, return the last index 290 | """ 291 | try: 292 | return l.index(e) 293 | except: 294 | return len(l) - 1 295 | 296 | 297 | def atom_to_feature_vector(atom): 298 | """ 299 | Converts rdkit atom object to feature list of indices 300 | :param mol: rdkit atom object 301 | :return: list 302 | """ 303 | atom_feature = [ 304 | safe_index(allowable_features['possible_atomic_num_list'], atom.GetAtomicNum()), 305 | allowable_features['possible_chirality_list'].index(str(atom.GetChiralTag())), 306 | safe_index(allowable_features['possible_degree_list'], atom.GetTotalDegree()), 307 | safe_index(allowable_features['possible_formal_charge_list'], atom.GetFormalCharge()), 308 | safe_index(allowable_features['possible_numH_list'], atom.GetTotalNumHs()), 309 | safe_index(allowable_features['possible_number_radical_e_list'], atom.GetNumRadicalElectrons()), 310 | safe_index(allowable_features['possible_hybridization_list'], str(atom.GetHybridization())), 311 | allowable_features['possible_is_aromatic_list'].index(atom.GetIsAromatic()), 312 | allowable_features['possible_is_in_ring_list'].index(atom.IsInRing()), 313 | ] 314 | return atom_feature 315 | 316 | 317 | def get_atom_feature_dims(): 318 | return list(map(len, [ 319 | allowable_features['possible_atomic_num_list'], 320 | allowable_features['possible_chirality_list'], 321 | allowable_features['possible_degree_list'], 322 | allowable_features['possible_formal_charge_list'], 323 | allowable_features['possible_numH_list'], 324 | allowable_features['possible_number_radical_e_list'], 325 | allowable_features['possible_hybridization_list'], 326 | allowable_features['possible_is_aromatic_list'], 327 | allowable_features['possible_is_in_ring_list'] 328 | ])) 329 | 330 | 331 | def bond_to_feature_vector(bond): 332 | """ 333 | Converts rdkit bond object to feature list of indices 334 | :param mol: rdkit bond object 335 | :return: list 336 | """ 337 | bond_feature = [ 338 | safe_index(allowable_features['possible_bond_type_list'], str(bond.GetBondType())), 339 | allowable_features['possible_bond_stereo_list'].index(str(bond.GetStereo())), 340 | allowable_features['possible_is_conjugated_list'].index(bond.GetIsConjugated()), 341 | ] 342 | return bond_feature 343 | 344 | 345 | def get_bond_feature_dims(): 346 | return list(map(len, [ 347 | allowable_features['possible_bond_type_list'], 348 | allowable_features['possible_bond_stereo_list'], 349 | allowable_features['possible_is_conjugated_list'] 350 | ])) 351 | 352 | 353 | def atom_feature_vector_to_dict(atom_feature): 354 | [atomic_num_idx, 355 | chirality_idx, 356 | degree_idx, 357 | formal_charge_idx, 358 | num_h_idx, 359 | number_radical_e_idx, 360 | hybridization_idx, 361 | is_aromatic_idx, 362 | is_in_ring_idx] = atom_feature 363 | 364 | feature_dict = { 365 | 'atomic_num': allowable_features['possible_atomic_num_list'][atomic_num_idx], 366 | 'chirality': allowable_features['possible_chirality_list'][chirality_idx], 367 | 'degree': allowable_features['possible_degree_list'][degree_idx], 368 | 'formal_charge': allowable_features['possible_formal_charge_list'][formal_charge_idx], 369 | 'num_h': allowable_features['possible_numH_list'][num_h_idx], 370 | 'num_rad_e': allowable_features['possible_number_radical_e_list'][number_radical_e_idx], 371 | 'hybridization': allowable_features['possible_hybridization_list'][hybridization_idx], 372 | 'is_aromatic': allowable_features['possible_is_aromatic_list'][is_aromatic_idx], 373 | 'is_in_ring': 
allowable_features['possible_is_in_ring_list'][is_in_ring_idx] 374 | } 375 | 376 | return feature_dict 377 | 378 | 379 | def bond_feature_vector_to_dict(bond_feature): 380 | [bond_type_idx, 381 | bond_stereo_idx, 382 | is_conjugated_idx] = bond_feature 383 | 384 | feature_dict = { 385 | 'bond_type': allowable_features['possible_bond_type_list'][bond_type_idx], 386 | 'bond_stereo': allowable_features['possible_bond_stereo_list'][bond_stereo_idx], 387 | 'is_conjugated': allowable_features['possible_is_conjugated_list'][is_conjugated_idx] 388 | } 389 | 390 | return feature_dict 391 | 392 | 393 | ############################## 394 | # Basic layers 395 | ############################## 396 | def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1): 397 | # activation layer 398 | act = act_type.lower() 399 | if act == 'relu': 400 | layer = nn.ReLU(inplace) 401 | elif act == 'leakyrelu': 402 | layer = nn.LeakyReLU(neg_slope, inplace) 403 | elif act == 'prelu': 404 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 405 | else: 406 | raise NotImplementedError('activation layer [%s] is not found' % act) 407 | return layer 408 | 409 | 410 | def norm_layer(norm_type, nc): 411 | # normalization layer 1d 412 | norm = norm_type.lower() 413 | if norm == 'batch': 414 | layer = nn.BatchNorm1d(nc, affine=True) 415 | elif norm == 'layer': 416 | layer = nn.LayerNorm(nc, elementwise_affine=True) 417 | elif norm == 'instance': 418 | layer = nn.InstanceNorm1d(nc, affine=False) 419 | else: 420 | raise NotImplementedError('normalization layer [%s] is not found' % norm) 421 | return layer 422 | 423 | 424 | class MultiSeq(Seq): 425 | def __init__(self, *args): 426 | super(MultiSeq, self).__init__(*args) 427 | 428 | def forward(self, *inputs): 429 | for module in self._modules.values(): 430 | if type(inputs) == tuple: 431 | inputs = module(*inputs) 432 | else: 433 | inputs = module(inputs) 434 | return inputs 435 | 436 | 437 | class MLP(Seq): 438 | def __init__(self, channels, act='relu', 439 | norm=None, bias=True, 440 | drop=0., last_lin=False): 441 | m = [] 442 | 443 | for i in range(1, len(channels)): 444 | 445 | m.append(Lin(channels[i - 1], channels[i], bias)) 446 | 447 | if (i == len(channels) - 1) and last_lin: 448 | pass 449 | else: 450 | if norm is not None and norm.lower() != 'none': 451 | m.append(norm_layer(norm, channels[i])) 452 | if act is not None and act.lower() != 'none': 453 | m.append(act_layer(act)) 454 | if drop > 0: 455 | m.append(nn.Dropout2d(drop)) 456 | 457 | self.m = m 458 | super(MLP, self).__init__(*self.m) 459 | 460 | 461 | class AtomEncoder(nn.Module): 462 | 463 | def __init__(self, emb_dim): 464 | super(AtomEncoder, self).__init__() 465 | 466 | self.atom_embedding_list = nn.ModuleList() 467 | full_atom_feature_dims = get_atom_feature_dims() 468 | 469 | for i, dim in enumerate(full_atom_feature_dims): 470 | emb = nn.Embedding(dim, emb_dim) 471 | nn.init.xavier_uniform_(emb.weight.data) 472 | self.atom_embedding_list.append(emb) 473 | 474 | def forward(self, x): 475 | x_embedding = 0 476 | for i in range(x.shape[1]): 477 | x_embedding += self.atom_embedding_list[i](x[:, i]) 478 | 479 | return x_embedding 480 | 481 | 482 | class BondEncoder(nn.Module): 483 | 484 | def __init__(self, emb_dim): 485 | super(BondEncoder, self).__init__() 486 | 487 | self.bond_embedding_list = nn.ModuleList() 488 | full_bond_feature_dims = get_bond_feature_dims() 489 | 490 | for i, dim in enumerate(full_bond_feature_dims): 491 | emb = nn.Embedding(dim, emb_dim) 492 | 
nn.init.xavier_uniform_(emb.weight.data) 493 | self.bond_embedding_list.append(emb) 494 | 495 | def forward(self, edge_attr): 496 | bond_embedding = 0 497 | for i in range(edge_attr.shape[1]): 498 | bond_embedding += self.bond_embedding_list[i](edge_attr[:, i]) 499 | 500 | return bond_embedding 501 | 502 | 503 | -------------------------------------------------------------------------------- /gcn_lib/sparse/torch_vertex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import torch_geometric as tg 5 | from .torch_nn import MLP, act_layer, norm_layer, BondEncoder 6 | from .torch_edge import DilatedKnnGraph 7 | from .torch_message import GenMessagePassing, MsgNorm 8 | from torch_geometric.utils import remove_self_loops, add_self_loops 9 | 10 | 11 | class GENConv(GenMessagePassing): 12 | """ 13 | GENeralized Graph Convolution (GENConv): https://arxiv.org/pdf/2006.07739.pdf 14 | SoftMax & PowerMean Aggregation 15 | """ 16 | def __init__(self, in_dim, emb_dim, 17 | aggr='softmax', 18 | t=1.0, learn_t=False, 19 | p=1.0, learn_p=False, 20 | y=0.0, learn_y=False, 21 | msg_norm=False, learn_msg_scale=True, 22 | encode_edge=False, bond_encoder=False, 23 | edge_feat_dim=None, 24 | norm='batch', mlp_layers=2, 25 | eps=1e-7): 26 | 27 | super(GENConv, self).__init__(aggr=aggr, 28 | t=t, learn_t=learn_t, 29 | p=p, learn_p=learn_p, 30 | y=y, learn_y=learn_y) 31 | 32 | channels_list = [in_dim] 33 | 34 | for i in range(mlp_layers-1): 35 | channels_list.append(in_dim*2) 36 | 37 | channels_list.append(emb_dim) 38 | 39 | self.mlp = MLP(channels=channels_list, 40 | norm=norm, 41 | last_lin=True) 42 | 43 | self.msg_encoder = torch.nn.ReLU() 44 | self.eps = eps 45 | 46 | self.msg_norm = msg_norm 47 | self.encode_edge = encode_edge 48 | self.bond_encoder = bond_encoder 49 | 50 | if msg_norm: 51 | self.msg_norm = MsgNorm(learn_msg_scale=learn_msg_scale) 52 | else: 53 | self.msg_norm = None 54 | 55 | if self.encode_edge: 56 | if self.bond_encoder: 57 | self.edge_encoder = BondEncoder(emb_dim=in_dim) 58 | else: 59 | self.edge_encoder = torch.nn.Linear(edge_feat_dim, in_dim) 60 | 61 | def forward(self, x, edge_index, edge_attr=None): 62 | if self.encode_edge and edge_attr is not None: 63 | edge_emb = self.edge_encoder(edge_attr) 64 | else: 65 | edge_emb = edge_attr 66 | 67 | m = self.propagate(edge_index, x=x, edge_attr=edge_emb) 68 | 69 | if self.msg_norm is not None: 70 | m = self.msg_norm(x, m) 71 | 72 | h = x + m 73 | out = self.mlp(h) 74 | 75 | return out 76 | 77 | def message(self, x_j, edge_attr=None): 78 | 79 | if edge_attr is not None: 80 | msg = x_j + edge_attr 81 | else: 82 | msg = x_j 83 | 84 | return self.msg_encoder(msg) + self.eps 85 | 86 | def update(self, aggr_out): 87 | return aggr_out 88 | 89 | 90 | class MRConv(nn.Module): 91 | """ 92 | Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) 93 | """ 94 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='max'): 95 | super(MRConv, self).__init__() 96 | self.nn = MLP([in_channels*2, out_channels], act, norm, bias) 97 | self.aggr = aggr 98 | 99 | def forward(self, x, edge_index): 100 | """""" 101 | x_j = tg.utils.scatter_(self.aggr, torch.index_select(x, 0, edge_index[0]) - torch.index_select(x, 0, edge_index[1]), edge_index[1], dim_size=x.shape[0]) 102 | return self.nn(torch.cat([x, x_j], dim=1)) 103 | 104 | 105 | class EdgConv(tg.nn.EdgeConv): 106 | """ 107 | Edge convolution layer 
(with activation, batch normalization) 108 | """ 109 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='max'): 110 | super(EdgConv, self).__init__(MLP([in_channels*2, out_channels], act, norm, bias), aggr) 111 | 112 | def forward(self, x, edge_index): 113 | return super(EdgConv, self).forward(x, edge_index) 114 | 115 | 116 | class GATConv(nn.Module): 117 | """ 118 | Graph Attention Convolution layer (with activation, batch normalization) 119 | """ 120 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, heads=8): 121 | super(GATConv, self).__init__() 122 | self.gconv = tg.nn.GATConv(in_channels, out_channels, heads, bias=bias) 123 | m =[] 124 | if act: 125 | m.append(act_layer(act)) 126 | if norm: 127 | m.append(norm_layer(norm, out_channels)) 128 | self.unlinear = nn.Sequential(*m) 129 | 130 | def forward(self, x, edge_index): 131 | out = self.unlinear(self.gconv(x, edge_index)) 132 | return out 133 | 134 | 135 | class SAGEConv(tg.nn.SAGEConv): 136 | r"""The GraphSAGE operator from the `"Inductive Representation Learning on 137 | Large Graphs" `_ paper 138 | 139 | .. math:: 140 | \mathbf{\hat{x}}_i &= \mathbf{\Theta} \cdot 141 | \mathrm{mean}_{j \in \mathcal{N(i) \cup \{ i \}}}(\mathbf{x}_j) 142 | 143 | \mathbf{x}^{\prime}_i &= \frac{\mathbf{\hat{x}}_i} 144 | {\| \mathbf{\hat{x}}_i \|_2}. 145 | 146 | Args: 147 | in_channels (int): Size of each input sample. 148 | out_channels (int): Size of each output sample. 149 | normalize (bool, optional): If set to :obj:`False`, output features 150 | will not be :math:`\ell_2`-normalized. (default: :obj:`True`) 151 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 152 | an additive bias. (default: :obj:`True`) 153 | **kwargs (optional): Additional arguments of 154 | :class:`torch_geometric.nn.conv.MessagePassing`. 
155 | """ 156 | 157 | def __init__(self, 158 | in_channels, 159 | out_channels, 160 | nn, 161 | norm=True, 162 | bias=True, 163 | relative=False, 164 | **kwargs): 165 | self.relative = relative 166 | if norm is not None: 167 | super(SAGEConv, self).__init__(in_channels, out_channels, True, bias, **kwargs) 168 | else: 169 | super(SAGEConv, self).__init__(in_channels, out_channels, False, bias, **kwargs) 170 | self.nn = nn 171 | 172 | def forward(self, x, edge_index, size=None): 173 | """""" 174 | if size is None: 175 | edge_index, _ = remove_self_loops(edge_index) 176 | edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) 177 | 178 | x = x.unsqueeze(-1) if x.dim() == 1 else x 179 | return self.propagate(edge_index, size=size, x=x) 180 | 181 | def message(self, x_i, x_j): 182 | if self.relative: 183 | x = torch.matmul(x_j - x_i, self.weight) 184 | else: 185 | x = torch.matmul(x_j, self.weight) 186 | return x 187 | 188 | def update(self, aggr_out, x): 189 | out = self.nn(torch.cat((x, aggr_out), dim=1)) 190 | if self.bias is not None: 191 | out = out + self.bias 192 | if self.normalize: 193 | out = F.normalize(out, p=2, dim=-1) 194 | return out 195 | 196 | 197 | class RSAGEConv(SAGEConv): 198 | """ 199 | Residual SAGE convolution layer (with activation, batch normalization) 200 | """ 201 | 202 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, relative=False): 203 | nn = MLP([out_channels + in_channels, out_channels], act, norm, bias) 204 | super(RSAGEConv, self).__init__(in_channels, out_channels, nn, norm, bias, relative) 205 | 206 | 207 | class SemiGCNConv(nn.Module): 208 | """ 209 | SemiGCN convolution layer (with activation, batch normalization) 210 | """ 211 | 212 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True): 213 | super(SemiGCNConv, self).__init__() 214 | self.gconv = tg.nn.GCNConv(in_channels, out_channels, bias=bias) 215 | m = [] 216 | if act: 217 | m.append(act_layer(act)) 218 | if norm: 219 | m.append(norm_layer(norm, out_channels)) 220 | self.unlinear = nn.Sequential(*m) 221 | 222 | def forward(self, x, edge_index): 223 | out = self.unlinear(self.gconv(x, edge_index)) 224 | return out 225 | 226 | 227 | class GinConv(tg.nn.GINConv): 228 | """ 229 | GINConv layer (with activation, batch normalization) 230 | """ 231 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='add'): 232 | super(GinConv, self).__init__(MLP([in_channels, out_channels], act, norm, bias)) 233 | 234 | def forward(self, x, edge_index): 235 | return super(GinConv, self).forward(x, edge_index) 236 | 237 | 238 | class GraphConv(nn.Module): 239 | """ 240 | Static graph convolution layer 241 | """ 242 | def __init__(self, in_channels, out_channels, conv='edge', 243 | act='relu', norm=None, bias=True, heads=8): 244 | super(GraphConv, self).__init__() 245 | if conv.lower() == 'edge': 246 | self.gconv = EdgConv(in_channels, out_channels, act, norm, bias) 247 | elif conv.lower() == 'mr': 248 | self.gconv = MRConv(in_channels, out_channels, act, norm, bias) 249 | elif conv.lower() == 'gat': 250 | self.gconv = GATConv(in_channels, out_channels//heads, act, norm, bias, heads) 251 | elif conv.lower() == 'gcn': 252 | self.gconv = SemiGCNConv(in_channels, out_channels, act, norm, bias) 253 | elif conv.lower() == 'gin': 254 | self.gconv = GinConv(in_channels, out_channels, act, norm, bias) 255 | elif conv.lower() == 'sage': 256 | self.gconv = RSAGEConv(in_channels, out_channels, act, norm, bias, False) 257 | elif 
conv.lower() == 'rsage': 258 | self.gconv = RSAGEConv(in_channels, out_channels, act, norm, bias, True) 259 | else: 260 | raise NotImplementedError('conv {} is not implemented'.format(conv)) 261 | 262 | def forward(self, x, edge_index): 263 | return self.gconv(x, edge_index) 264 | 265 | 266 | class DynConv(GraphConv): 267 | """ 268 | Dynamic graph convolution layer 269 | """ 270 | def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu', 271 | norm=None, bias=True, heads=8, **kwargs): 272 | super(DynConv, self).__init__(in_channels, out_channels, conv, act, norm, bias, heads) 273 | self.k = kernel_size 274 | self.d = dilation 275 | self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, **kwargs) 276 | 277 | def forward(self, x, batch=None): 278 | edge_index = self.dilated_knn_graph(x, batch) 279 | return super(DynConv, self).forward(x, edge_index) 280 | 281 | 282 | class PlainDynBlock(nn.Module): 283 | """ 284 | Plain Dynamic graph convolution block 285 | """ 286 | def __init__(self, channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None, 287 | bias=True, res_scale=1, **kwargs): 288 | super(PlainDynBlock, self).__init__() 289 | self.body = DynConv(channels, channels, kernel_size, dilation, conv, 290 | act, norm, bias, **kwargs) 291 | self.res_scale = res_scale 292 | 293 | def forward(self, x, batch=None): 294 | return self.body(x, batch), batch 295 | 296 | 297 | class ResDynBlock(nn.Module): 298 | """ 299 | Residual Dynamic graph convolution block 300 | """ 301 | def __init__(self, channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None, 302 | bias=True, res_scale=1, **kwargs): 303 | super(ResDynBlock, self).__init__() 304 | self.body = DynConv(channels, channels, kernel_size, dilation, conv, 305 | act, norm, bias, **kwargs) 306 | self.res_scale = res_scale 307 | 308 | def forward(self, x, batch=None): 309 | return self.body(x, batch) + x*self.res_scale, batch 310 | 311 | 312 | class DenseDynBlock(nn.Module): 313 | """ 314 | Dense Dynamic graph convolution block 315 | """ 316 | def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None, bias=True, **kwargs): 317 | super(DenseDynBlock, self).__init__() 318 | self.body = DynConv(in_channels, out_channels, kernel_size, dilation, conv, 319 | act, norm, bias, **kwargs) 320 | 321 | def forward(self, x, batch=None): 322 | dense = self.body(x, batch) 323 | return torch.cat((x, dense), 1), batch 324 | 325 | 326 | class ResGraphBlock(nn.Module): 327 | """ 328 | Residual Static graph convolution block 329 | """ 330 | def __init__(self, channels, conv='edge', act='relu', norm=None, bias=True, heads=8, res_scale=1): 331 | super(ResGraphBlock, self).__init__() 332 | self.body = GraphConv(channels, channels, conv, act, norm, bias, heads) 333 | self.res_scale = res_scale 334 | 335 | def forward(self, x, edge_index): 336 | return self.body(x, edge_index) + x*self.res_scale, edge_index 337 | 338 | 339 | class DenseGraphBlock(nn.Module): 340 | """ 341 | Dense Static graph convolution block 342 | """ 343 | def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True, heads=8): 344 | super(DenseGraphBlock, self).__init__() 345 | self.body = GraphConv(in_channels, out_channels, conv, act, norm, bias, heads) 346 | 347 | def forward(self, x, edge_index): 348 | dense = self.body(x, edge_index) 349 | return torch.cat((x, dense), 1), edge_index 350 | 351 | 
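
To see how the `GraphConv` wrapper above dispatches to a concrete operator, here is a hypothetical smoke test — a sketch assuming `torch_geometric` is installed and the repository root is importable; the toy 6-cycle graph is invented for illustration:

```python
import torch
from gcn_lib.sparse import GraphConv  # re-exported via torch_vertex in sparse/__init__.py

x = torch.randn(6, 16)                           # 6 nodes, 16-dim features
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 3, 4, 5, 0]])  # a directed 6-cycle

# conv='gcn' routes to SemiGCNConv: tg.nn.GCNConv followed by ReLU + BatchNorm1d.
conv = GraphConv(16, 32, conv='gcn', act='relu', norm='batch')
out = conv(x, edge_index)
print(out.shape)  # torch.Size([6, 32])
```

Swapping the `conv` string ('edge', 'mr', 'gat', 'gin', 'sage', 'rsage') changes only the wrapped operator; the calling convention `conv(x, edge_index)` stays the same, which is what lets the dynamic blocks above treat all of them uniformly.
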
-------------------------------------------------------------------------------- /generate_pyg_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os.path as osp 3 | from data import get_tf_idf_by_texts, get_llama_embedding, get_word2vec, get_sbert_embedding, set_api_key, get_ogbn_dataset, get_e5_large_embedding 4 | from api import openai_ada_api 5 | import h5py 6 | import numpy as np 7 | from torch_geometric.utils import index_to_mask 8 | from torch_geometric.data import GraphSAINTRandomWalkSampler, NeighborSampler 9 | from data import set_seed_config, LabelPerClassSplit, generate_random_mask 10 | from utils import knowledge_augmentation, compute_pca_with_whitening, bert_whitening 11 | 12 | 13 | def main(): 14 | dataset = ['cora', 'pubmed'] 15 | split = ['random', 'fixed'] 16 | ogb_dataset = ['arxiv', 'products'] 17 | embedding = ["tfidf"] 18 | # knowledge = ["cora", "pubmed"] 19 | data_path = "./preprocessed_data" 20 | ## if match default, just skip 21 | default = { 22 | 'cora': 'tfidf', 23 | "citeseer": 'tfidf', 24 | "pubmed": 'tfidf', 25 | "arxiv": 'word2vec', 26 | "products": 'bow' 27 | } 28 | split_seeds = [i for i in range(10)] 29 | rewrite = True 30 | ## load raw text data 31 | ## handle mask issue 32 | data_obj = None 33 | for name in dataset: 34 | for setting in split: 35 | if name in ogb_dataset and setting == 'random': continue 36 | if name == "cora" and setting == 'random': 37 | data_obj = torch.load("./preprocessed_data/new/cora_random_sbert.pt", map_location="cpu") 38 | data_obj.raw_texts = data_obj.raw_text 39 | data_obj.category_names = [data_obj.label_names[i] for i in data_obj.y.tolist()] 40 | elif name == "cora" and setting == 'fixed': 41 | data_obj = torch.load("./preprocessed_data/new/cora_fixed_sbert.pt", map_location="cpu") 42 | data_obj.raw_texts = data_obj.raw_text 43 | data_obj.category_names = [data_obj.label_names[i] for i in data_obj.y.tolist()] 44 | elif name == "citeseer" and setting == 'random': 45 | data_obj = torch.load("./preprocessed_data/new/citeseer_random_sbert.pt", map_location="cpu") 46 | elif name == "citeseer" and setting == 'fixed': 47 | data_obj = torch.load("./preprocessed_data/new/citeseer_fixed_sbert.pt", map_location="cpu") 48 | elif name == "pubmed" and setting == 'random': 49 | data_obj = torch.load("./preprocessed_data/new/pubmed_random_sbert.pt", map_location="cpu") 50 | elif name == "pubmed" and setting == 'fixed': 51 | data_obj = torch.load("./preprocessed_data/new/pubmed_fixed_sbert.pt", map_location="cpu") 52 | elif name == "arxiv": 53 | data_obj = torch.load("./preprocessed_data/new/arxiv_fixed_sbert.pt", map_location="cpu") 54 | elif name == "products": 55 | data_obj = torch.load("./preprocessed_data/new/products_fixed_sbert.pt", map_location="cpu") 56 | # old_products = get_ogbn_dataset("ogbn-products", normalize_features=False) 57 | # splits = old_products.get_idx_split() 58 | # data_obj.train_masks = [index_to_mask(splits['train'], size = data_obj.x.shape[0])] 59 | # data_obj.val_masks = [index_to_mask(splits['valid'], size = data_obj.x.shape[0])] 60 | # data_obj.test_masks = [index_to_mask(splits['test'], size = data_obj.x.shape[0])] 61 | ## set embedding typ 62 | if name == 'cora' or name == 'pubmed': 63 | #d_name = name.split("_")[0] 64 | d_name = name 65 | entity_pt = torch.load(f"{d_name}_entity.pt", map_location="cpu") 66 | data_obj = torch.load(osp.join(data_path, "new", f"{d_name}_fixed_sbert.pt"), map_location="cpu") 67 | data_obj.entity = entity_pt 68 | 
num_nodes = len(data_obj.raw_texts) 69 | hidden_dim = 768 70 | for typ in embedding: 71 | # if typ != "ft": continue 72 | # if typ != "sbert" or name != "arxiv": continue 73 | # if typ == "know_exp_ft" and typ != "cora" and typ != "pubmed": continue 74 | if osp.exists(osp.join(data_path, "new", f"{name}_{setting}_{typ}.pt")) and not rewrite: 75 | data_obj = torch.load(osp.join(data_path, "new", f"{name}_{setting}_{typ}.pt"), map_location="cpu") 76 | continue 77 | 78 | # if "know" in typ and name != "cora" and name != "pubmed": continue 79 | 80 | # if default[name] != typ: 81 | if typ == 'tfidf': 82 | if name == 'cora': 83 | max_features = 1433 84 | elif name == 'citeseer': 85 | max_features = 3703 86 | elif name == 'pubmed': 87 | max_features = 500 88 | else: 89 | max_features = 1000 90 | data_obj.x, _ = get_tf_idf_by_texts(data_obj.raw_texts, None, None, max_features=max_features, use_tokenizer=False) 91 | elif typ == 'know_tf': 92 | if name == 'cora': 93 | max_features = 1433 94 | elif name == 'citeseer': 95 | max_features = 3703 96 | elif name == 'pubmed': 97 | max_features = 500 98 | texts, knowledge = knowledge_augmentation(data_obj.raw_texts, data_obj.entity, strategy='back') 99 | data_obj.x, _ = get_tf_idf_by_texts(texts, None, None, max_features=max_features, use_tokenizer=False) 100 | # if name in knowledge: 101 | # entity_pt = torch.load(f"{name}_entity.pt", map_location="cpu") 102 | # data_obj.entity = entity_pt 103 | elif typ == 'word2vec': 104 | data_obj.x = get_word2vec(data_obj.raw_texts) 105 | elif typ == 'sbert': 106 | #if "know" not in name: 107 | data_obj.x = get_sbert_embedding(data_obj.raw_texts) 108 | elif typ == 'know_inp_sb': 109 | texts_inp, _ = knowledge_augmentation(data_obj.raw_texts, data_obj.entity, strategy='inplace') 110 | data_obj.x = get_e5_large_embedding(texts_inp, 'cuda', name + 'knowinp', batch_size=16) 111 | elif typ == "know_sep_sb": 112 | _, knowledge = knowledge_augmentation(data_obj.raw_texts, data_obj.entity, strategy='separate') 113 | data_obj.x = get_e5_large_embedding(knowledge, 'cuda', name + 'knowsep', batch_size=16) 114 | elif typ == 'ada': 115 | if name in ['cora', 'citeseer', 'pubmed']: 116 | data_obj.x = torch.tensor(openai_ada_api(data_obj.raw_texts)) 117 | elif name == 'arxiv': 118 | data_obj.x = torch.load("./ogb_node_features.pt", map_location = 'cpu') 119 | elif name == 'products': 120 | with h5py.File('ogbn_products.h5', 'r') as hf: 121 | numpy_array = np.array(hf['products']) 122 | # convert the numpy array to a torch tensor 123 | tensor = torch.from_numpy(numpy_array) 124 | data_obj.x = tensor 125 | elif typ == 'llama': 126 | if name == "pubmed" and setting == "random": 127 | llama_obj = torch.load(osp.join(data_path, "new", "pubmed_fixed_llama.pt"), map_location="cpu") 128 | data_obj.x = llama_obj.x 129 | else: 130 | data_obj.x = get_llama_embedding(data_obj.raw_texts) 131 | elif typ == "ft": 132 | if name == 'pubmed' or name == 'cora': 133 | data_obj.xs = [] 134 | for i in range(5): 135 | emb = np.memmap(f"./lmoutput/{name}_finetune_{setting}_{i}.emb", dtype=np.float16, mode='r', 136 | shape=(num_nodes, hidden_dim)) 137 | x = torch.tensor(emb, dtype=torch.float32) 138 | data_obj.xs.append(x) 139 | data_obj.x = data_obj.xs[0] 140 | else: 141 | # elif 'know' not in name: 142 | emb = np.memmap(f"./lmoutput/{name}_finetune_{setting}_0.emb", dtype=np.float16, mode='r', 143 | shape=(num_nodes, hidden_dim)) 144 | data_obj.x = torch.tensor(emb, dtype=torch.float32) 145 | elif typ == "noft": 146 | if name == 'pubmed' or name == 'cora': 
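# These memmap reads mirror lmfinetune.py's save(): each ./lmoutput/*.emb file stores float16 CLS embeddings of shape (num_nodes, hidden_dim).
# For cora/pubmed one embedding matrix per seed is collected into data_obj.xs, and data_obj.x falls back to the seed-0 matrix.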
147 | data_obj.xs = [] 148 | for i in range(5): 149 | emb = np.memmap(f"./lmoutput/{name}_no_finetune_{setting}_{i}.emb", dtype=np.float16, mode='r', 150 | shape=(num_nodes, hidden_dim)) 151 | x = torch.tensor(emb, dtype=torch.float32) 152 | data_obj.xs.append(x) 153 | data_obj.x = data_obj.xs[0] 154 | else: 155 | emb = np.memmap(f"./lmoutput/{name}_no_finetune_{setting}.emb", dtype=np.float16, mode='r', 156 | shape=(num_nodes, hidden_dim)) 157 | data_obj.x = torch.tensor(emb, dtype=torch.float32) 158 | elif typ == 'avg': 159 | emb = np.memmap(f"./lmoutput/{name}_no_finetune_{setting}_0.emb", dtype=np.float16, mode='r', shape=(num_nodes, hidden_dim)) 160 | data_obj.x = torch.tensor(emb, dtype=torch.float32) 161 | elif typ == 'avg_white': 162 | emb = np.memmap(f"./lmoutput/{name}_no_finetune_{setting}_0.emb", dtype=np.float16, mode='r', shape=(num_nodes, hidden_dim)) 163 | emb = torch.tensor(emb, dtype=torch.float32) 164 | emb_white = bert_whitening(emb) 165 | data_obj.x = emb_white 166 | elif typ == 'e5': 167 | emb = torch.load(f"./openai_out/{name}_e5_embedding.pt") 168 | data_obj.x = emb 169 | elif typ == 'google': 170 | if name in ['arxiv', 'products']: 171 | continue 172 | emb = torch.load(f"./openai_out/{name}_google_embedding.pt") 173 | emb = emb.reshape(num_nodes, -1) 174 | data_obj.x = emb 175 | elif typ == "know_exp_ft": 176 | xs = [] 177 | for i in range(5): 178 | emb = np.memmap(f"./lmoutput/{name}_finetune_{setting}_{i}_exp.emb", dtype=np.float16, mode='r', 179 | shape=(num_nodes, hidden_dim)) 180 | x = torch.tensor(emb, dtype=torch.float32) 181 | xs.append(x) 182 | data_obj.xs = xs 183 | data_obj.x = xs[0] 184 | elif typ == "know_inp_ft": 185 | xs = [] 186 | for i in range(5): 187 | emb = np.memmap(f"./lmoutput/{name}_inp_finetune_{setting}_{i}.emb", dtype=np.float16, mode='r', 188 | shape=(num_nodes, hidden_dim)) 189 | x = torch.tensor(emb, dtype=torch.float32) 190 | xs.append(x) 191 | data_obj.xs = xs 192 | data_obj.x = xs[0] 193 | elif typ == "know_sep_ft": 194 | xs = [] 195 | for i in range(5): 196 | emb = np.memmap(f"./lmoutput/{name}_sep_finetune_{setting}_{i}.emb", dtype=np.float16, mode='r', 197 | shape=(num_nodes, hidden_dim)) 198 | x = torch.tensor(emb, dtype=torch.float32) 199 | xs.append(x) 200 | data_obj.xs = xs 201 | data_obj.x = xs[0] 202 | elif typ == "know_exp_sb": 203 | exp = torch.load(f"./preprocessed_data/new/{name}_explanation.pt") 204 | data_obj.x = get_sbert_embedding(exp) 205 | elif typ == "pl": 206 | pl = torch.load(f"./preprocessed_data/new/{name}_pred.pt") 207 | data_obj.x = pl 208 | elif "white" in typ: 209 | if typ == "ft_white": 210 | if not osp.exists(f"./preprocessed_data/new/{name}_{setting}_ft.pt"): 211 | print("You must first generate ft object before generating white object") 212 | continue 213 | ft_obj = torch.load(f"./preprocessed_data/new/{name}_{setting}_ft.pt") 214 | elif typ == 'no_ft_white' or typ == "no_ft_whitening": 215 | if not osp.exists(f"./preprocessed_data/new/{name}_{setting}_noft.pt"): 216 | print("You must first generate noft object before generating white object") 217 | continue 218 | ft_obj = torch.load(f"./preprocessed_data/new/{name}_{setting}_noft.pt") 219 | elif typ == 'llama_white': 220 | if not osp.exists(f"./preprocessed_data/new/{name}_{setting}_llama.pt"): 221 | print("You must first generate llama object before generating white object") 222 | continue 223 | ft_obj = torch.load(f"./preprocessed_data/new/{name}_{setting}_llama.pt") 224 | if (name == 'pubmed' or name == 'cora') and typ == 'ft_white': 225 | newxs = 
[] 226 | for i in range(5): 227 | x = ft_obj.xs[i] 228 | newx = torch.zeros(x.shape[0], 16) 229 | train_mask = ft_obj.train_masks[i] 230 | val_mask = ft_obj.val_masks[i] 231 | test_mask = ft_obj.test_masks[i] 232 | visible_mask = train_mask | val_mask 233 | X_tr_embeds = x[visible_mask] 234 | X_ts_embeds = x[test_mask] 235 | X_tr_pca, X_ts_pca = compute_pca_with_whitening(X_tr_embeds, X_ts_embeds) 236 | newx[visible_mask] = X_tr_pca 237 | newx[test_mask] = X_ts_pca 238 | newxs.append(newx) 239 | data_obj.xs = newxs 240 | else: 241 | x = ft_obj.x 242 | newx = torch.zeros(x.shape[0], 16) 243 | train_mask = ft_obj.train_masks[0] 244 | val_mask = ft_obj.val_masks[0] 245 | test_mask = ft_obj.test_masks[0] 246 | visible_mask = train_mask | val_mask 247 | X_tr_embeds = x[visible_mask] 248 | X_ts_embeds = x[test_mask] 249 | X_tr_pca, X_ts_pca = compute_pca_with_whitening(X_tr_embeds, X_ts_embeds) 250 | newx[visible_mask] = X_tr_pca 251 | newx[test_mask] = X_ts_pca 252 | data_obj.x = newx 253 | torch.save(data_obj, osp.join(data_path, "new", f"{name}_{setting}_{typ}.pt")) 254 | print("Save object {}".format(osp.join(data_path, "new", f"{name}_{setting}_{typ}.pt"))) 255 | continue 256 | 257 | if name in ['cora', 'citeseer', 'pubmed']: 258 | new_train_masks = [] 259 | new_val_masks = [] 260 | new_test_masks = [] 261 | for k in range(num_split := 10): 262 | set_seed_config(split_seeds[k]) 263 | if setting == 'fixed': 264 | ## 20 per class 265 | fixed_split = LabelPerClassSplit(num_labels_per_class=20, num_valid = 500, num_test=1000) 266 | t_mask, val_mask, te_mask = fixed_split(data_obj, data_obj.x.shape[0]) 267 | new_train_masks.append(t_mask) 268 | new_val_masks.append(val_mask) 269 | new_test_masks.append(te_mask) 270 | else: 271 | total_num = data_obj.x.shape[0] 272 | train_num = int(0.6 * total_num) 273 | val_num = int(0.2 * total_num) 274 | t_mask, val_mask, te_mask = generate_random_mask(data_obj.x.shape[0], train_num, val_num) 275 | new_train_masks.append(t_mask) 276 | new_val_masks.append(val_mask) 277 | new_test_masks.append(te_mask) 278 | data_obj.train_masks = new_train_masks 279 | data_obj.val_masks = new_val_masks 280 | data_obj.test_masks = new_test_masks 281 | 282 | 283 | torch.save(data_obj, osp.join(data_path, "new", f"{name}_{setting}_{typ}.pt")) 284 | print("Save object {}".format(osp.join(data_path, "new", f"{name}_{setting}_{typ}.pt"))) 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | if __name__ == '__main__': 293 | set_api_key() 294 | main() 295 | -------------------------------------------------------------------------------- /hyper.py: -------------------------------------------------------------------------------- 1 | def hyper_search(trial, embedding_type, gnn_model, dataset): 2 | """ 3 | an example, set the range here 4 | """ 5 | return { 6 | 'lr': trial.suggest_categorical('lr', [1e-2, 5e-2, 5e-3, 1e-3]), 7 | 'weight_decay': trial.suggest_categorical('weight_decay', [1e-5, 5e-5, 5e-4, 0]), 8 | 'hidden_dimension': trial.suggest_categorical('hidden_dimension', [16, 32, 64, 8, 128, 256]), 9 | 'dropout': trial.suggest_categorical('dropout', [0., .1, .2, .3, .5, .8]), 10 | 'num_layers': trial.suggest_categorical('num_layers', [2,3]), 11 | 'normalize_features': trial.suggest_categorical('normalize_features', [0, 1]) 12 | } -------------------------------------------------------------------------------- /imgs/README.md: -------------------------------------------------------------------------------- 1 | intro png 2 | 
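Stepping back to `hyper.py` above: `hyper_search` only defines the search space, so it has to be wrapped in an Optuna objective to be used. Below is a minimal sketch of that wiring (Optuna is part of the environment setup); the `train_and_eval` helper is a hypothetical stand-in for this repo's training loop and is stubbed out here.

```
# Minimal sketch: wiring hyper_search into an Optuna study.
# `train_and_eval` is a hypothetical placeholder, not part of this repo.
import optuna

from hyper import hyper_search


def train_and_eval(**params):
    # Stub: build a GNN with `params` and return its validation accuracy.
    return 0.0


def objective(trial):
    params = hyper_search(trial, embedding_type="sbert", gnn_model="GCN", dataset="cora")
    return train_and_eval(**params)


study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)
print(study.best_trial.params)
```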
-------------------------------------------------------------------------------- /imgs/llm_as_enhancer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryTang/Graph-LLM/344f7c6b7786c7f8293c24ce5b90f141c777aeec/imgs/llm_as_enhancer.png -------------------------------------------------------------------------------- /imgs/llm_as_predictor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CurryTang/Graph-LLM/344f7c6b7786c7f8293c24ce5b90f141c777aeec/imgs/llm_as_predictor.png -------------------------------------------------------------------------------- /lmfinetune.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | from transformers import PreTrainedModel 4 | from transformers.modeling_outputs import TokenClassifierOutput 5 | import os.path as osp 6 | from utils import init_random_state, init_path, eval 7 | from utils import compute_loss 8 | from data import set_seed_config 9 | import numpy as np 10 | import torch.nn.functional as F 11 | from transformers.models.auto import AutoModel, AutoTokenizer, AutoModelForSequenceClassification 12 | from transformers.trainer import Trainer, TrainingArguments, IntervalStrategy 13 | import argparse 14 | from torch_geometric.utils import mask_to_index 15 | import ipdb 16 | from ogb.nodeproppred import Evaluator 17 | import os 18 | from data import get_dataset 19 | from utils import knowledge_augmentation 20 | 21 | ### Adapted from GLEM 22 | 23 | 24 | class BertClassifier(PreTrainedModel): 25 | def __init__(self, model, n_labels, loss_func, dropout=0.0, seed=0, cla_bias=True, feat_shrink=''): 26 | super().__init__(model.config) 27 | self.bert_encoder, self.loss_func = model, loss_func 28 | self.dropout = nn.Dropout(dropout) 29 | self.feat_shrink = feat_shrink 30 | hidden_dim = model.config.hidden_size 31 | if feat_shrink: 32 | self.feat_shrink_layer = nn.Linear(model.config.hidden_size, int(feat_shrink), bias=cla_bias) 33 | hidden_dim = int(feat_shrink) 34 | self.classifier = nn.Linear(hidden_dim, n_labels, bias=cla_bias) 35 | self.loss_func = loss_func 36 | init_random_state(seed) 37 | 38 | def forward(self, input_ids, attention_mask, labels = None, return_dict = None): 39 | outputs = self.bert_encoder(input_ids=input_ids, 40 | attention_mask=attention_mask, 41 | return_dict=return_dict, 42 | output_hidden_states=True) 43 | emb = self.dropout(outputs['hidden_states'][-1]) # outputs[0]=last hidden state 44 | # Use CLS Emb as sentence emb. 45 | cls_token_emb = emb.permute(1, 0, 2)[0] 46 | if self.feat_shrink: 47 | cls_token_emb = self.feat_shrink_layer(cls_token_emb) 48 | logits = self.classifier(cls_token_emb) 49 | 50 | if labels.shape[-1] == 1: 51 | labels = labels.squeeze() 52 | # print(f'{sum(is_gold)} gold, {sum(~is_gold)} pseudo') 53 | # import ipdb; ipdb.set_trace() 54 | loss = self.loss_func(logits, labels) 55 | return TokenClassifierOutput(loss=loss, logits=logits) 56 | 57 | 58 | class BertEmbInfModel(PreTrainedModel): 59 | def __init__(self, model): 60 | super().__init__(model.config) 61 | self.bert_encoder = model 62 | 63 | @th.no_grad() 64 | def forward(self, **input): 65 | # Extract outputs from the model 66 | outputs = self.bert_encoder(**input, output_hidden_states=True) 67 | emb = outputs['hidden_states'][-1] # Last layer 68 | # Use CLS Emb as sentence emb. 
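# `emb` is the last hidden state of shape (batch, seq_len, hidden); permute(1, 0, 2)[0] picks position 0 of every sequence,
# i.e. the [CLS] token (equivalent to emb[:, 0]), so each node's text is summarized by one hidden-size vector.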
69 | node_cls_emb = emb.permute(1, 0, 2)[0] 70 | return TokenClassifierOutput(logits=node_cls_emb) 71 | 72 | 73 | class BertClaInfModel(PreTrainedModel): 74 | def __init__(self, model, emb, pred, loss_func, feat_shrink=''): 75 | super().__init__(model.config) 76 | self.bert_classifier = model 77 | self.feat_shrink = feat_shrink 78 | self.emb = emb 79 | self.pred = pred 80 | self.loss_func = loss_func 81 | 82 | 83 | @th.no_grad() 84 | def forward(self, input_ids, attention_mask, labels = None, return_dict = None, node_id = None): 85 | # Extract outputs from the model 86 | batch_nodes = node_id.cpu().numpy() 87 | outputs = self.bert_classifier.bert_encoder(input_ids=input_ids, 88 | attention_mask=attention_mask, 89 | return_dict=return_dict, 90 | output_hidden_states=True) 91 | emb = outputs['hidden_states'][-1] # outputs[0]=last hidden state 92 | # Use CLS Emb as sentence emb. 93 | cls_token_emb = emb.permute(1, 0, 2)[0] 94 | if self.feat_shrink: 95 | cls_token_emb = self.bert_classifier.feat_shrink_layer(cls_token_emb) 96 | logits = self.bert_classifier.classifier(cls_token_emb) 97 | # Save prediction and embeddings to disk (memmap) 98 | self.emb[batch_nodes] = cls_token_emb.cpu().numpy().astype(np.float16) 99 | self.pred[batch_nodes] = logits.cpu().numpy().astype(np.float16) 100 | # Output empty to fit the Huggingface trainer pipeline 101 | # loss = self.loss_func(logits, labels) 102 | empty = th.zeros((len(node_id), 1)).cuda() 103 | return TokenClassifierOutput(loss=empty, logits=logits) 104 | 105 | 106 | class Config(): 107 | def __init__(self, args) -> None: 108 | self.model_name = args.model 109 | self.dataset_name = args.dataset 110 | self.seed = args.seed 111 | self.seed_num = args.seed_num 112 | 113 | self.feat_shrink = args.feat_shrink 114 | self.weight_decay = args.weight_decay 115 | self.dropout = args.dropout 116 | self.att_dropout = args.att_dropout 117 | self.cla_dropout = args.cla_dropout 118 | 119 | self.batch_size = args.batch_size 120 | self.epochs = args.epochs 121 | self.warmup_epochs = args.warmup_epochs 122 | self.eval_patience = args.eval_patience 123 | self.grad_acc_steps = args.grad_acc_steps 124 | self.lr = args.lr 125 | 126 | self.output_dir = args.output_dir 127 | self.checkpoint_dir = args.checkpoint_dir 128 | self.label_smoothing = args.label_smoothing 129 | self.split_id = args.split_id 130 | self.eq_batch_size = args.eq_batch_size 131 | self.split = args.split 132 | self.local_rank = os.getenv('LOCAL_RANK', -1) 133 | self.n_gpus = args.n_gpus 134 | self.use_explanation = args.use_explanation 135 | if self.model_name == 'deberta-large': 136 | self.hidden_dim = 1024 137 | else: 138 | self.hidden_dim = 768 139 | 140 | 141 | def get_model_name_mapping(model_name): 142 | mapping = { 143 | "deberta-base": "microsoft/deberta-base", 144 | "deberta-large": "microsoft/deberta-large", 145 | "bert": "bert-base-uncased" 146 | } 147 | return mapping[model_name] 148 | 149 | 150 | 151 | class TextDataset(th.utils.data.Dataset): 152 | def __init__(self, encodings, raw_texts, pyg_data, labels=None): 153 | self.encodings = encodings 154 | self.labels = labels 155 | self.raw_texts = raw_texts 156 | self.data_obj = pyg_data 157 | 158 | 159 | def __getitem__(self, idx): 160 | item = { 161 | 'input_ids': self.encodings['input_ids'][idx].flatten(), 162 | 'attention_mask': self.encodings['attention_mask'][idx].flatten(), 163 | } 164 | # ipdb.set_trace() 165 | ## for inference model to save 166 | item['node_id'] = idx 167 | if self.labels != None: 168 | item["labels"] = 
self.labels[idx].to(th.long) 169 | #item["raw_text"] = self.raw_texts[idx] 170 | return item 171 | 172 | def __len__(self): 173 | return len(self.raw_texts) 174 | 175 | 176 | def compute_metrics(eval_pred): 177 | logits, labels = eval_pred 178 | import evaluate 179 | metric = evaluate.load("accuracy") 180 | logits = th.tensor(logits).to('cuda') 181 | labels = th.tensor(labels).to('cuda') 182 | predictions = th.argmax(logits, dim=-1) 183 | return metric.compute(predictions=predictions, references=labels) 184 | 185 | 186 | class LMTrainer(): 187 | def __init__(self, config, data, metrics, loss_func) -> None: 188 | self.config = config 189 | set_seed_config(self.config.seed) 190 | self.name = get_model_name_mapping(self.config.model_name) 191 | self.total_data = data 192 | train_steps = self.total_data.x.shape[0] // self.config.batch_size + 1 193 | eval_steps = self.config.eval_patience // self.config.batch_size 194 | warmup_step = int(self.config.warmup_epochs * train_steps) 195 | # total_steps = self.config.epochs * len(self.total_data.raw_text) // self.config.batch_size 196 | self.n_labels = self.total_data.y.max().item() + 1 197 | self.training_args = TrainingArguments( 198 | output_dir=self.config.output_dir, 199 | do_train=True, 200 | do_eval=True, 201 | evaluation_strategy=IntervalStrategy.STEPS, 202 | eval_steps=eval_steps, 203 | save_steps=eval_steps, 204 | learning_rate=self.config.lr, 205 | weight_decay=self.config.weight_decay, 206 | load_best_model_at_end=True, 207 | gradient_accumulation_steps=self.config.grad_acc_steps, 208 | per_device_train_batch_size=self.config.batch_size, 209 | per_device_eval_batch_size=self.config.batch_size * 4, 210 | warmup_steps=warmup_step, 211 | num_train_epochs=self.config.epochs, 212 | dataloader_num_workers=1, 213 | fp16=True, 214 | dataloader_drop_last=True, 215 | local_rank=self.config.local_rank, 216 | report_to='none' 217 | ) 218 | self.loss_func = loss_func 219 | pretrained_model = AutoModel.from_pretrained(self.name, cache_dir = "/localscratch/czk") 220 | self.model = BertClassifier(pretrained_model, 221 | n_labels=self.n_labels, 222 | loss_func=self.loss_func, 223 | feat_shrink=self.config.feat_shrink) 224 | # self.model = AutoModelForSequenceClassification.from_pretrained(self.name, num_labels=7, cache_dir="/localscratch/czk") 225 | 226 | self.tokenizer = AutoTokenizer.from_pretrained(self.name, use_fast = False, cache_dir = "/localscratch/czk") 227 | if 'inp' in self.config.dataset_name: 228 | prev = self.config.dataset_name.split('_')[0] 229 | data_obj = th.load(f"./preprocessed_data/new/{prev}_{self.config.split}_know_inp_sb.pt", map_location='cpu') 230 | texts_inp, _ = knowledge_augmentation(data_obj.raw_texts, data_obj.entity, strategy='inplace') 231 | X = self.tokenizer(texts_inp, padding=True, truncation=True, max_length=512, return_tensors='pt') 232 | elif 'sep' in self.config.dataset_name: 233 | prev = self.config.dataset_name.split('_')[0] 234 | data_obj = th.load(f"./preprocessed_data/new/{prev}_{self.config.split}_know_sep_sb.pt", map_location='cpu') 235 | texts_inp, knowledge = knowledge_augmentation(data_obj.raw_texts, data_obj.entity, strategy='separate') 236 | X = self.tokenizer(knowledge, padding=True, truncation=True, max_length=512, return_tensors='pt') 237 | else: 238 | if not self.config.use_explanation: 239 | X = self.tokenizer(self.total_data.raw_texts, padding=True, truncation=True, max_length=512, return_tensors='pt') 240 | else: 241 | explanation = 
th.load(f"./preprocessed_data/new/{self.config.dataset_name}_explanation.pt") 242 | X = self.tokenizer(explanation, padding=True, truncation=True, max_length=512, return_tensors='pt') 243 | self.num_of_nodes = len(self.total_data.raw_texts) 244 | self.text_dataset = TextDataset(X, self.total_data.raw_texts, self.total_data, self.total_data.y) 245 | 246 | self.train_dataset = th.utils.data.Subset( 247 | self.text_dataset, mask_to_index(self.total_data.train_mask)) 248 | self.val_dataset = th.utils.data.Subset( 249 | self.text_dataset, mask_to_index(self.total_data.val_mask)) 250 | self.test_dataset = th.utils.data.Subset( 251 | self.text_dataset, mask_to_index(self.total_data.test_mask)) 252 | 253 | # ipdb.set_trace() 254 | 255 | self.trainer = Trainer( 256 | self.model, 257 | args = self.training_args, 258 | train_dataset=self.train_dataset, 259 | eval_dataset=self.val_dataset, 260 | compute_metrics=metrics) 261 | 262 | 263 | self.model.config.dropout = self.config.dropout 264 | self.model.config.attention_dropout = self.config.att_dropout 265 | self.best_model = None 266 | self.metrics = metrics 267 | # self.trainer.train() 268 | 269 | 270 | def train(self): 271 | self.trainer.train() 272 | self.best_model = self.trainer.model 273 | th.save(self.trainer.model.state_dict(), init_path(osp.join(self.config.checkpoint_dir, f"{self.config.dataset_name}-{self.config.model_name}.pt"))) 274 | 275 | def save(self, finetune = True): 276 | finetune_str = "finetune" if finetune else "no_finetune" 277 | if not self.config.use_explanation: 278 | emb_path = osp.join(self.config.output_dir, f"{self.config.dataset_name}_{finetune_str}_{self.config.split}_{self.config.seed}.emb") 279 | pred_path = osp.join(self.config.output_dir, f"{self.config.dataset_name}_{finetune_str}_{self.config.split}_{self.config.seed}.pred") 280 | else: 281 | emb_path = osp.join(self.config.output_dir, f"{self.config.dataset_name}_{finetune_str}_{self.config.split}_{self.config.seed}_exp.emb") 282 | pred_path = osp.join(self.config.output_dir, f"{self.config.dataset_name}_{finetune_str}_{self.config.split}_{self.config.seed}_exp.pred") 283 | self.emb = np.memmap(init_path(emb_path), dtype=np.float16, mode='w+', 284 | shape=(self.num_of_nodes, self.config.hidden_dim)) 285 | self.pred = np.memmap(init_path(pred_path), dtype=np.float16, mode='w+', 286 | shape=(self.num_of_nodes, self.n_labels)) 287 | 288 | if finetune: 289 | emb_save_model = BertClaInfModel(self.best_model, self.emb, self.pred, self.loss_func, self.config.feat_shrink) 290 | else: 291 | pretrained_model = AutoModel.from_pretrained(self.name, cache_dir = "/localscratch/czk") 292 | no_ft_model = BertClassifier(pretrained_model, 293 | n_labels=self.n_labels, 294 | loss_func=self.loss_func, 295 | feat_shrink=self.config.feat_shrink) 296 | emb_save_model = BertClaInfModel(no_ft_model, self.emb, self.pred, self.loss_func, self.config.feat_shrink) 297 | 298 | emb_save_model.eval() 299 | save_args = TrainingArguments( 300 | output_dir=self.config.output_dir, 301 | overwrite_output_dir=False, 302 | do_train=False, 303 | do_eval=True, 304 | per_device_eval_batch_size=self.config.batch_size, 305 | dataloader_drop_last=False, 306 | dataloader_num_workers=1, 307 | fp16_full_eval=True, 308 | local_rank=self.config.local_rank, 309 | report_to='none' 310 | ) 311 | 312 | saver = Trainer(model=emb_save_model, args=save_args) 313 | saver.predict(self.text_dataset) 314 | 315 | ## evaluate the output 316 | mapping = { 317 | "cora": "cora", 318 | "pubmed": "pubmed", 319 | "citeseer": 
"citeseer", 320 | 'products': "ogbn-products", 321 | "arxiv": "ogbn-arxiv" 322 | } 323 | if "inp" in config.dataset_name or "sep" in config.dataset_name: 324 | data_name = config.dataset_name.split("_")[0] 325 | else: 326 | data_name = config.dataset_name 327 | dataset_name = mapping[data_name] 328 | total_pred = th.tensor(self.pred) 329 | res = evaluate(total_pred, self.total_data, dataset_name, 0) 330 | print(res) 331 | 332 | 333 | 334 | def evaluate(total_pred, total_data, dataset_name, split_id = 0): 335 | total_pred = th.argmax(total_pred, dim=-1) 336 | train_mask = total_data.train_mask 337 | val_mask = total_data.val_mask 338 | test_mask = total_data.test_mask 339 | train_input_dict = { 340 | "y_true": total_pred[train_mask].reshape(-1, 1), 341 | "y_pred": total_data.y[train_mask].reshape(-1, 1) 342 | } 343 | val_input_dict = { 344 | "y_true": total_pred[val_mask].reshape(-1, 1), 345 | "y_pred": total_data.y[val_mask].reshape(-1, 1) 346 | } 347 | test_input_dict = { 348 | "y_true": total_pred[test_mask].reshape(-1, 1), 349 | "y_pred": total_data.y[test_mask].reshape(-1, 1) 350 | } 351 | if "ogb" in dataset_name: 352 | evaluator = Evaluator(name = dataset_name) 353 | train_acc = evaluator.eval(train_input_dict)['acc'] 354 | val_acc = evaluator.eval(val_input_dict)['acc'] 355 | test_acc = evaluator.eval(test_input_dict)['acc'] 356 | res = { 357 | "train_acc": train_acc.item() if isinstance(train_acc, th.Tensor) else train_acc, 358 | "val_acc": val_acc.item() if isinstance(val_acc, th.Tensor) else val_acc, 359 | "test_acc": test_acc.item() if isinstance(test_acc, th.Tensor) else test_acc 360 | } 361 | else: 362 | train_acc = eval(train_input_dict) 363 | val_acc = eval(val_input_dict) 364 | test_acc = eval(test_input_dict) 365 | res = { 366 | "train_acc": train_acc.item() if isinstance(train_acc, th.Tensor) else train_acc, 367 | "val_acc": val_acc.item() if isinstance(val_acc, th.Tensor) else val_acc, 368 | "test_acc": test_acc.item() if isinstance(test_acc, th.Tensor) else test_acc 369 | } 370 | return res 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 | def parse_args(): 379 | parser = argparse.ArgumentParser(description='LM training') 380 | parser.add_argument('--device', type=int, default=0) 381 | parser.add_argument('--seed_num', type=int, default=5) 382 | parser.add_argument('--seed', type=int, default=0) 383 | parser.add_argument('--dataset', type=str, default="cora") 384 | parser.add_argument('--model', type=str, default="deberta-base") 385 | parser.add_argument('--feat_shrink', type=str, default="") 386 | ## follow GLEM 387 | parser.add_argument('--batch_size', type=int, default=36) 388 | parser.add_argument('--grad_acc_steps', type=int, default=1) 389 | parser.add_argument('--lr', type=float, default=2e-5) 390 | parser.add_argument('--epochs', type=int, default=4) 391 | parser.add_argument('--warmup_epochs', type=float, default=0.6) 392 | parser.add_argument('--eval_patience', type=int, default=50000) 393 | parser.add_argument('--weight_decay', type=float, default=0.00) 394 | parser.add_argument('--dropout', type=float, default=0.3) 395 | parser.add_argument('--att_dropout', type=float, default=0.1) 396 | parser.add_argument('--cla_dropout', type=float, default=0.4) 397 | parser.add_argument("--split", type=str, default="fixed") 398 | parser.add_argument("--output_dir", type=str, default="./lmoutput") 399 | parser.add_argument('--checkpoint_dir', type=str, default="./lmcheckpoint") 400 | parser.add_argument("--label_smoothing", type=float, default=0) 401 | 
parser.add_argument("--split_id", type=int, default = 0) 402 | parser.add_argument("--eq_batch_size", type=int, default = 36) 403 | parser.add_argument("--n_gpus", type=int, default = 1) 404 | parser.add_argument("--local-rank", type=int, default=0) 405 | parser.add_argument("--use_explanation", type=int, default=0) 406 | parser.add_argument("--use_knowledge", type=int, default=0) 407 | # parser.add_argument("--") 408 | args = parser.parse_args() 409 | return args 410 | 411 | 412 | if __name__ == '__main__': 413 | command_line_args = parse_args() 414 | num_of_seeds = [i for i in range(command_line_args.seed_num)] 415 | config = Config(command_line_args) 416 | if "inp" in config.dataset_name or "sep" in config.dataset_name: 417 | data_name = config.dataset_name.split("_")[0] 418 | else: 419 | data_name = config.dataset_name 420 | data_obj = get_dataset(config.seed_num, data_name, config.split, "sbert", 0) 421 | for i in num_of_seeds: 422 | # import ipdb; ipdb.set_trace() 423 | n_labels = data_obj.y.max().item() + 1 424 | data_obj.train_mask = data_obj.train_masks[i] 425 | data_obj.val_mask = data_obj.val_masks[i] 426 | data_obj.test_mask = data_obj.test_masks[i] 427 | # import ipdb; ipdb.set_trace() 428 | config.seed = i 429 | loss_func = th.nn.CrossEntropyLoss(label_smoothing=config.label_smoothing, reduction='mean') 430 | # model = BertClassifier(config.model_name, n_labels, loss_func) 431 | trainer = LMTrainer(config, data_obj, compute_metrics, loss_func) 432 | trainer.train() 433 | trainer.save(finetune = True) 434 | th.cuda.empty_cache() 435 | # trainer.save(finetune = False) 436 | 437 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.nn.models import MLP 2 | from torch_geometric.nn.conv import GCNConv, GINConv, SAGEConv 3 | import torch 4 | import torch.nn.functional as F 5 | from sentence_transformers import SentenceTransformer 6 | import torch.nn as nn 7 | from torch_geometric.nn import LabelPropagation 8 | from torch_geometric.nn.models import GAT 9 | import dgl.nn.pytorch as dglnn 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch_geometric.nn import GATConv as PYGGATConv 14 | import rev.memgcn as memgcn 15 | from rev.rev_layer import SharedDropout 16 | import copy 17 | from torch_sparse import SparseTensor, matmul 18 | from torch_geometric.utils import degree 19 | import numpy as np 20 | import math 21 | import tqdm 22 | from dgl import function as fn 23 | from dgl._ffi.base import DGLError 24 | from dgl.nn.pytorch.utils import Identity 25 | from dgl.ops import edge_softmax 26 | from dgl.utils import expand_as_pair 27 | 28 | 29 | 30 | BIG_CONSTANT = 1e8 31 | 32 | 33 | def get_model(args): 34 | if args.model_name == 'MLP': 35 | return UniversalMLP(args.num_layers, args.input_dim, args.hidden_dimension, args.num_classes, args.dropout, args.norm, args.return_embeds) 36 | elif args.model_name == 'GCN': 37 | return GCN(args.num_layers, args.input_dim, args.hidden_dimension, args.num_classes, args.dropout, args.norm) 38 | elif args.model_name == 'MLP2': 39 | return DeepMLP(args.input_dim, args.num_classes) 40 | elif args.model_name == 'LP': 41 | return LP(args.num_layers, args.alpha) 42 | elif args.model_name == 'SAGE': 43 | return SAGE(args.input_dim, args.hidden_dimension, args.num_classes, args.num_layers, args.dropout) 44 | elif args.model_name == 'GAT': 45 | return GAT2(args.input_dim, 
args.hidden_dimension, args.num_layers, args.num_classes, args.dropout, args.dropout, args.num_of_heads, args.num_of_out_heads, args.norm) 46 | 47 | 48 | class GAT2(torch.nn.Module): 49 | def __init__(self, num_feat, hidden_dimension, num_layers, num_class, dropout, attn_drop, num_of_heads = 1, num_of_out_heads = 1, norm = None): 50 | super().__init__() 51 | self.layers = [] 52 | self.bns = [] 53 | if num_layers == 1: 54 | self.conv1 = PYGGATConv(num_feat, hidden_dimension, num_of_heads, concat = False, dropout=attn_drop) 55 | else: 56 | self.conv1 = PYGGATConv(num_feat, hidden_dimension, num_of_heads, concat = True, dropout=attn_drop) 57 | self.bns.append(torch.nn.BatchNorm1d(hidden_dimension * num_of_heads)) 58 | self.layers.append(self.conv1) 59 | for _ in range(num_layers - 2): 60 | self.layers.append( 61 | PYGGATConv(hidden_dimension * num_of_heads, hidden_dimension, num_of_heads, concat = True, dropout = dropout) 62 | ) 63 | self.bns.append(torch.nn.BatchNorm1d(hidden_dimension * num_of_heads)) 64 | 65 | # On the Pubmed dataset, use `heads` output heads in `conv2`. 66 | if num_layers > 1: 67 | self.layers.append(PYGGATConv(hidden_dimension * num_of_heads, num_class, heads=num_of_out_heads, 68 | concat=False, dropout=attn_drop).cuda()) 69 | self.layers = torch.nn.ModuleList(self.layers) 70 | self.bns = torch.nn.ModuleList(self.bns) 71 | self.norm = norm 72 | self.num_layers = num_layers 73 | self.with_bn = True if self.norm == 'BatchNorm' else False 74 | self.dropout = dropout 75 | 76 | def forward(self, data): 77 | x, edge_index = data.x, data.edge_index 78 | for i in range(self.num_layers): 79 | x = F.dropout(x, self.dropout, training=self.training) 80 | x = self.layers[i](x, edge_index) 81 | if i != self.num_layers - 1: 82 | if self.with_bn: 83 | x = self.bns[i](x) 84 | x = F.elu(x) 85 | return x 86 | 87 | 88 | 89 | class GATWrapper(torch.nn.Module): 90 | def __init__(self, in_size, hidden_size, num_layers, out_size, dropout): 91 | super().__init__() 92 | self.gat = GAT(in_size, hidden_size, num_layers, out_size, dropout) 93 | 94 | def forward(self, data): 95 | x, edge_index = data.x, data.edge_index 96 | return self.gat(x, edge_index) 97 | 98 | 99 | 100 | class UniversalMLP(torch.nn.Module): 101 | def __init__(self, num_layers, input_dim, hidden_dimension, num_classes, dropout, norm=None, return_embeds = False) -> None: 102 | super().__init__() 103 | hidden_dimensions = [hidden_dimension] * (num_layers - 1) 104 | self.hidden_dimensions = [input_dim] + hidden_dimensions + [num_classes] 105 | self.mlp = MLP(channel_list=self.hidden_dimensions, dropout=dropout, norm=norm) 106 | self.return_embeds = return_embeds 107 | 108 | def forward(self, data): 109 | x = data.x 110 | return self.mlp(x) 111 | 112 | def inference(self, x_all, subgraph_loader, device): 113 | xs = [] 114 | for batch in tqdm.tqdm(subgraph_loader): 115 | edge_index, n_id, size = batch.edge_index, batch.n_id, batch.batch_size 116 | edge_index = edge_index.to(device) 117 | # import ipdb; ipdb.set_trace() 118 | x = x_all[n_id][:batch.batch_size].to(device) 119 | x = self.mlp(x) 120 | xs.append(x.cpu()) 121 | x_all = torch.cat(xs, dim=0) 122 | return x_all 123 | 124 | 125 | class DeepMLP(torch.nn.Module): 126 | def __init__(self, in_size, out_size) -> None: 127 | super().__init__() 128 | self.mlp = nn.Sequential(nn.Linear(in_size, 1024), 129 | nn.SELU(), 130 | nn.Dropout(0.5), 131 | nn.LayerNorm(1024), 132 | nn.Linear(1024, 512), 133 | nn.SELU(), 134 | nn.Dropout(0.5), 135 | nn.LayerNorm(512), 136 | nn.Linear(512, out_size), 137 | 
) 138 | 139 | def forward(self, data): 140 | x = data.x 141 | return self.mlp(x) 142 | 143 | 144 | 145 | 146 | class GCN(torch.nn.Module): 147 | def __init__(self, num_layers, input_dim, hidden_dimension, num_classes, dropout, norm=None) -> None: 148 | super().__init__() 149 | self.convs = torch.nn.ModuleList() 150 | self.norms = torch.nn.ModuleList() 151 | self.num_layers = num_layers 152 | self.dropout = dropout 153 | if num_layers == 1: 154 | self.convs.append(GCNConv(input_dim, num_classes, cached=False, 155 | normalize=True)) 156 | else: 157 | self.convs.append(GCNConv(input_dim, hidden_dimension, cached=False, 158 | normalize=True)) 159 | if norm: 160 | self.norms.append(torch.nn.BatchNorm1d(hidden_dimension)) 161 | else: 162 | self.norms.append(torch.nn.Identity()) 163 | 164 | for _ in range(num_layers - 2): 165 | self.convs.append(GCNConv(hidden_dimension, hidden_dimension, cached=False, 166 | normalize=True)) 167 | if norm: 168 | self.norms.append(torch.nn.BatchNorm1d(hidden_dimension)) 169 | else: 170 | self.norms.append(torch.nn.Identity()) 171 | 172 | self.convs.append(GCNConv(hidden_dimension, num_classes, cached=False, normalize=True)) 173 | 174 | def forward(self, data): 175 | x, edge_index, edge_weight= data.x, data.edge_index, data.edge_weight 176 | for i in range(self.num_layers): 177 | x = F.dropout(x, p=self.dropout, training=self.training) 178 | x = self.convs[i](x, edge_index) 179 | if i != self.num_layers - 1: 180 | x = self.norms[i](x) 181 | x = F.relu(x) 182 | return x 183 | 184 | 185 | 186 | class LP(torch.nn.Module): 187 | def __init__(self, num_layers, alpha) -> None: 188 | super().__init__() 189 | self.lp = LabelPropagation(num_layers, alpha) 190 | 191 | def forward(self, data): 192 | y= data.y 193 | train_mask = data.train_mask 194 | return self.lp(y, data.adj_t, train_mask) 195 | 196 | 197 | def sbert(device): 198 | model = SentenceTransformer('all-MiniLM-L6-v2', cache_folder='/localscratch/czk/huggingface', device=device).to(device) 199 | return model 200 | 201 | 202 | def mpnet(device): 203 | model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', cache_folder='/localscratch/czk/huggingface', device=device).to(device) 204 | return model 205 | 206 | 207 | 208 | class SAGE(torch.nn.Module): 209 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 210 | dropout): 211 | super(SAGE, self).__init__() 212 | 213 | self.convs = torch.nn.ModuleList() 214 | self.convs.append(SAGEConv(in_channels, hidden_channels)) 215 | for _ in range(num_layers - 2): 216 | self.convs.append(SAGEConv(hidden_channels, hidden_channels)) 217 | self.convs.append(SAGEConv(hidden_channels, out_channels)) 218 | 219 | self.dropout = dropout 220 | 221 | def reset_parameters(self): 222 | for conv in self.convs: 223 | conv.reset_parameters() 224 | 225 | def forward(self, data): 226 | x, edge_index = data.x, data.edge_index 227 | edge_weight = None 228 | for conv in self.convs[:-1]: 229 | x = conv(x, edge_index, edge_weight) 230 | x = F.relu(x) 231 | x = F.dropout(x, p=self.dropout, training=self.training) 232 | x = self.convs[-1](x, edge_index, edge_weight) 233 | return x 234 | 235 | def inference(self, x_all, subgraph_loader, device): 236 | for i, conv in enumerate(self.convs): 237 | xs = [] 238 | for batch in subgraph_loader: 239 | edge_index, n_id, size = batch.edge_index, batch.n_id, batch.batch_size 240 | edge_index = edge_index.to(device) 241 | x = x_all[n_id].to(device) 242 | x_target = x[:size] 243 | x = conv((x, x_target), edge_index) 244 | if i != 
len(self.convs) - 1: 245 | x = F.relu(x) 246 | xs.append(x.cpu()) 247 | x_all = torch.cat(xs, dim=0) 248 | return x_all 249 | 250 | 251 | 252 | 253 | class ElementWiseLinear(nn.Module): 254 | def __init__(self, size, weight=True, bias=True, inplace=False): 255 | super().__init__() 256 | if weight: 257 | self.weight = nn.Parameter(torch.Tensor(size)) 258 | else: 259 | self.weight = None 260 | if bias: 261 | self.bias = nn.Parameter(torch.Tensor(size)) 262 | else: 263 | self.bias = None 264 | self.inplace = inplace 265 | 266 | self.reset_parameters() 267 | 268 | def reset_parameters(self): 269 | if self.weight is not None: 270 | nn.init.ones_(self.weight) 271 | if self.bias is not None: 272 | nn.init.zeros_(self.bias) 273 | 274 | def forward(self, x): 275 | if self.inplace: 276 | if self.weight is not None: 277 | x.mul_(self.weight) 278 | if self.bias is not None: 279 | x.add_(self.bias) 280 | else: 281 | if self.weight is not None: 282 | x = x * self.weight 283 | if self.bias is not None: 284 | x = x + self.bias 285 | return x 286 | 287 | 288 | class GATConv(nn.Module): 289 | def __init__( 290 | self, 291 | in_feats, 292 | out_feats, 293 | num_heads=1, 294 | feat_drop=0.0, 295 | attn_drop=0.0, 296 | edge_drop=0.0, 297 | negative_slope=0.2, 298 | use_attn_dst=True, 299 | residual=False, 300 | activation=None, 301 | allow_zero_in_degree=False, 302 | use_symmetric_norm=False, 303 | ): 304 | super(GATConv, self).__init__() 305 | self._num_heads = num_heads 306 | self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) 307 | self._out_feats = out_feats 308 | self._allow_zero_in_degree = allow_zero_in_degree 309 | self._use_symmetric_norm = use_symmetric_norm 310 | if isinstance(in_feats, tuple): 311 | self.fc_src = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False) 312 | self.fc_dst = nn.Linear(self._in_dst_feats, out_feats * num_heads, bias=False) 313 | else: 314 | self.fc = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False) 315 | self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats))) 316 | if use_attn_dst: 317 | self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats))) 318 | else: 319 | self.register_buffer("attn_r", None) 320 | self.feat_drop = nn.Dropout(feat_drop) 321 | assert feat_drop == 0.0 # not implemented 322 | self.attn_drop = nn.Dropout(attn_drop) 323 | assert attn_drop == 0.0 # not implemented 324 | self.edge_drop = edge_drop 325 | self.leaky_relu = nn.LeakyReLU(negative_slope) 326 | if residual: 327 | self.res_fc = nn.Linear(self._in_dst_feats, num_heads * out_feats, bias=False) 328 | else: 329 | self.register_buffer("res_fc", None) 330 | self.reset_parameters() 331 | self._activation = activation 332 | 333 | def reset_parameters(self): 334 | gain = nn.init.calculate_gain("relu") 335 | if hasattr(self, "fc"): 336 | nn.init.xavier_normal_(self.fc.weight, gain=gain) 337 | else: 338 | nn.init.xavier_normal_(self.fc_src.weight, gain=gain) 339 | nn.init.xavier_normal_(self.fc_dst.weight, gain=gain) 340 | nn.init.xavier_normal_(self.attn_l, gain=gain) 341 | if isinstance(self.attn_r, nn.Parameter): 342 | nn.init.xavier_normal_(self.attn_r, gain=gain) 343 | if isinstance(self.res_fc, nn.Linear): 344 | nn.init.xavier_normal_(self.res_fc.weight, gain=gain) 345 | 346 | def set_allow_zero_in_degree(self, set_value): 347 | self._allow_zero_in_degree = set_value 348 | 349 | def forward(self, graph, feat, perm=None): 350 | with graph.local_scope(): 351 | if not self._allow_zero_in_degree: 352 | if 
(graph.in_degrees() == 0).any(): 353 | raise DGLError("There are 0-in-degree nodes in the graph, so the output for those nodes will be invalid. Set allow_zero_in_degree=True or add self-loops to suppress this error.") 354 | 355 | if isinstance(feat, tuple): 356 | h_src = self.feat_drop(feat[0]) 357 | h_dst = self.feat_drop(feat[1]) 358 | if not hasattr(self, "fc_src"): 359 | self.fc_src, self.fc_dst = self.fc, self.fc 360 | feat_src, feat_dst = h_src, h_dst 361 | feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats) 362 | feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats) 363 | else: 364 | h_src = self.feat_drop(feat) 365 | feat_src = h_src 366 | feat_src = self.fc(h_src).view(-1, self._num_heads, self._out_feats) 367 | if graph.is_block: 368 | h_dst = h_src[: graph.number_of_dst_nodes()] 369 | feat_dst = feat_src[: graph.number_of_dst_nodes()] 370 | else: 371 | h_dst = h_src 372 | feat_dst = feat_src 373 | 374 | if self._use_symmetric_norm: 375 | degs = graph.out_degrees().float().clamp(min=1) 376 | norm = torch.pow(degs, -0.5) 377 | shp = norm.shape + (1,) * (feat_src.dim() - 1) 378 | norm = torch.reshape(norm, shp) 379 | feat_src = feat_src * norm 380 | 381 | # NOTE: GAT paper uses "first concatenation then linear projection" 382 | # to compute attention scores, while ours is "first projection then 383 | # addition"; the two approaches are mathematically equivalent: 384 | # We decompose the weight vector a mentioned in the paper into 385 | # [a_l || a_r], then 386 | # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j 387 | # Our implementation is much more efficient because we do not need to 388 | # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus, 389 | # addition could be optimized with DGL's built-in function u_add_v, 390 | # which further speeds up computation and saves memory footprint. 391 | el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1) 392 | graph.srcdata.update({"ft": feat_src, "el": el}) 393 | # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively. 
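# Concretely (illustrative numbers, single head, out_feats = 2): with a_l = [1, 0], a_r = [0, 1], Wh_i = [2, 3] and Wh_j = [5, 7],
# the paper's form gives a^T [Wh_i || Wh_j] = [1, 0, 0, 1] . [2, 3, 5, 7] = 9, while ours gives a_l Wh_i + a_r Wh_j = 2 + 7 = 9,
# the same score without ever materializing the concatenated vector on each edge.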
394 | if self.attn_r is not None: 395 | er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1) 396 | graph.dstdata.update({"er": er}) 397 | graph.apply_edges(fn.u_add_v("el", "er", "e")) 398 | else: 399 | graph.apply_edges(fn.copy_u("el", "e")) 400 | e = self.leaky_relu(graph.edata.pop("e")) 401 | 402 | if self.training and self.edge_drop > 0: 403 | if perm is None: 404 | perm = torch.randperm(graph.number_of_edges(), device=e.device) 405 | bound = int(graph.number_of_edges() * self.edge_drop) 406 | eids = perm[bound:] 407 | graph.edata["a"] = torch.zeros_like(e) 408 | graph.edata["a"][eids] = self.attn_drop(edge_softmax(graph, e[eids], eids=eids)) 409 | else: 410 | graph.edata["a"] = self.attn_drop(edge_softmax(graph, e)) 411 | 412 | # message passing 413 | graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft")) 414 | rst = graph.dstdata["ft"] 415 | 416 | if self._use_symmetric_norm: 417 | degs = graph.in_degrees().float().clamp(min=1) 418 | norm = torch.pow(degs, 0.5) 419 | shp = norm.shape + (1,) * (feat_dst.dim() - 1) 420 | norm = torch.reshape(norm, shp) 421 | rst = rst * norm 422 | 423 | # residual 424 | if self.res_fc is not None: 425 | resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats) 426 | rst = rst + resval 427 | 428 | # activation 429 | if self._activation is not None: 430 | rst = self._activation(rst) 431 | return rst 432 | 433 | 434 | class RevGATBlock(nn.Module): 435 | def __init__( 436 | self, 437 | node_feats, 438 | edge_feats, 439 | edge_emb, 440 | out_feats, 441 | n_heads=1, 442 | attn_drop=0.0, 443 | edge_drop=0.0, 444 | negative_slope=0.2, 445 | residual=True, 446 | activation=None, 447 | use_attn_dst=True, 448 | allow_zero_in_degree=True, 449 | use_symmetric_norm=False, 450 | ): 451 | super(RevGATBlock, self).__init__() 452 | 453 | self.norm = nn.BatchNorm1d(n_heads * out_feats) 454 | self.conv = GATConv( 455 | node_feats, 456 | out_feats, 457 | num_heads=n_heads, 458 | attn_drop=attn_drop, 459 | edge_drop=edge_drop, 460 | negative_slope=negative_slope, 461 | residual=residual, 462 | activation=activation, 463 | use_attn_dst=use_attn_dst, 464 | allow_zero_in_degree=allow_zero_in_degree, 465 | use_symmetric_norm=use_symmetric_norm, 466 | ) 467 | self.dropout = SharedDropout() 468 | if edge_emb > 0: 469 | self.edge_encoder = nn.Linear(edge_feats, edge_emb) 470 | else: 471 | self.edge_encoder = None 472 | 473 | def forward(self, x, graph, dropout_mask=None, perm=None, efeat=None): 474 | if perm is not None: 475 | perm = perm.squeeze() 476 | out = self.norm(x) 477 | out = F.relu(out, inplace=True) 478 | if isinstance(self.dropout, SharedDropout): 479 | self.dropout.set_mask(dropout_mask) 480 | out = self.dropout(out) 481 | 482 | if self.edge_encoder is not None: 483 | if efeat is None: 484 | efeat = graph.edata["feat"] 485 | efeat_emb = self.edge_encoder(efeat) 486 | efeat_emb = F.relu(efeat_emb, inplace=True) 487 | else: 488 | efeat_emb = None 489 | 490 | out = self.conv(graph, out, perm).flatten(1, -1) 491 | return out 492 | 493 | 494 | class RevGAT(nn.Module): 495 | def __init__( 496 | self, 497 | in_feats, 498 | n_classes, 499 | n_hidden, 500 | n_layers, 501 | n_heads, 502 | activation, 503 | dropout=0.0, 504 | input_drop=0.0, 505 | attn_drop=0.0, 506 | edge_drop=0.0, 507 | use_attn_dst=True, 508 | use_symmetric_norm=False, 509 | group=2, 510 | ): 511 | super().__init__() 512 | self.in_feats = in_feats 513 | self.n_hidden = n_hidden 514 | self.n_classes = n_classes 515 | self.n_layers = n_layers 516 | self.num_heads = n_heads 517 | 
self.group = group 518 | 519 | self.convs = nn.ModuleList() 520 | self.norm = nn.BatchNorm1d(n_heads * n_hidden) 521 | 522 | for i in range(n_layers): 523 | in_hidden = n_heads * n_hidden if i > 0 else in_feats 524 | out_hidden = n_hidden if i < n_layers - 1 else n_classes 525 | num_heads = n_heads if i < n_layers - 1 else 1 526 | out_channels = n_heads 527 | 528 | if i == 0 or i == n_layers -1: 529 | self.convs.append( 530 | GATConv( 531 | in_hidden, 532 | out_hidden, 533 | num_heads=num_heads, 534 | attn_drop=attn_drop, 535 | edge_drop=edge_drop, 536 | use_attn_dst=use_attn_dst, 537 | use_symmetric_norm=use_symmetric_norm, 538 | residual=True, 539 | ) 540 | ) 541 | else: 542 | Fms = nn.ModuleList() 543 | fm = RevGATBlock( 544 | in_hidden // group, 545 | 0, 546 | 0, 547 | out_hidden // group, 548 | n_heads=num_heads, 549 | attn_drop=attn_drop, 550 | edge_drop=edge_drop, 551 | use_attn_dst=use_attn_dst, 552 | use_symmetric_norm=use_symmetric_norm, 553 | residual=True, 554 | ) 555 | for i in range(self.group): 556 | if i == 0: 557 | Fms.append(fm) 558 | else: 559 | Fms.append(copy.deepcopy(fm)) 560 | 561 | invertible_module = memgcn.GroupAdditiveCoupling(Fms, 562 | group=self.group) 563 | 564 | conv = memgcn.InvertibleModuleWrapper(fn=invertible_module, 565 | keep_input=False) 566 | 567 | self.convs.append(conv) 568 | 569 | self.bias_last = ElementWiseLinear(n_classes, weight=False, bias=True, inplace=True) 570 | 571 | self.input_drop = nn.Dropout(input_drop) 572 | self.dropout = dropout 573 | self.dp_last = nn.Dropout(dropout) 574 | self.activation = activation 575 | 576 | def forward(self, graph, feat): 577 | h = feat 578 | h = self.input_drop(h) 579 | 580 | self.perms = [] 581 | for i in range(self.n_layers): 582 | perm = torch.randperm(graph.number_of_edges(), 583 | device=graph.device) 584 | self.perms.append(perm) 585 | 586 | h = self.convs[0](graph, h, self.perms[0]).flatten(1, -1) 587 | 588 | m = torch.zeros_like(h).bernoulli_(1 - self.dropout) 589 | mask = m.requires_grad_(False) / (1 - self.dropout) 590 | 591 | for i in range(1, self.n_layers-1): 592 | graph.requires_grad = False 593 | perm = torch.stack([self.perms[i]]*self.group, dim=1) 594 | h = self.convs[i](h, graph, mask, perm) 595 | 596 | h = self.norm(h) 597 | h = self.activation(h, inplace=True) 598 | h = self.dp_last(h) 599 | h = self.convs[-1](graph, h, self.perms[-1]) 600 | 601 | h = h.mean(1) 602 | h = self.bias_last(h) 603 | 604 | return h 605 | 606 | -------------------------------------------------------------------------------- /ogbn_products.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import pandas as pd 3 | from ogb.nodeproppred import PygNodePropPredDataset 4 | import os.path as osp 5 | import torch_geometric.transforms as T 6 | import openai 7 | from tqdm import tqdm 8 | import torch 9 | import time 10 | import pyarrow as pa 11 | import os 12 | import re 13 | from tqdm import tqdm 14 | 15 | def set_api_key(): 16 | openai.api_key = "XXX" 17 | 18 | 19 | 20 | def get_transform(normalize_features, transform): 21 | # import ipdb; ipdb.set_trace() 22 | if transform is not None and normalize_features: 23 | transform = T.Compose([T.NormalizeFeatures(), transform]) 24 | elif normalize_features: 25 | transform = T.NormalizeFeatures() 26 | elif transform is not None: 27 | transform = transform 28 | return transform 29 | 30 | 31 | def get_ogbn_dataset(name, normalize_features=True, transform=None): 32 | path = 
osp.join(osp.dirname(osp.realpath(__file__)), 'data', name) 33 | dataset = PygNodePropPredDataset(name, path) 34 | dataset.transform = get_transform(normalize_features, transform) 35 | return dataset 36 | 37 | 38 | def ogb_data(normalize_features = False, transform = None): 39 | dataset = get_ogbn_dataset("ogbn-products", normalize_features, transform=transform) 40 | data = dataset[0] 41 | return data 42 | 43 | 44 | def compress_embeddings(embedding_list, name = 'ogb_product_features.pt'): 45 | arr = pa.array(embedding_list) 46 | torch.save(arr, name) 47 | 48 | 49 | 50 | def get_raw_dataset(raw_train = "raw_data/Amazon-3M.raw/trn.json", raw_test = "raw_data/Amazon-3M.raw/tst.json", 51 | label2cat = "raw_data/ogbn_products/mapping/labelidx2productcategory.csv", 52 | idx2asin = "raw_data/ogbn_products/mapping/nodeidx2asin.csv" 53 | ): 54 | train_part = load_dataset("json", data_files=raw_train) 55 | test_part = load_dataset("json", data_files=raw_test) 56 | train_df = train_part['train'].to_pandas() 57 | test_df = test_part['train'].to_pandas() 58 | combine_df = pd.concat([train_df, test_df], ignore_index=True) 59 | label2cat_df = pd.read_csv(label2cat) 60 | idx2asin = pd.read_csv(idx2asin) 61 | idx_mapping = {row[0]: row[1] for row in idx2asin.values} 62 | content_mapping = {row[0]: (row[1], row[2]) for row in combine_df.values} 63 | return idx_mapping, content_mapping 64 | 65 | 66 | def openai_ada_api(input_list, model_name = 'text-embedding-ada-002', max_len = 8190): 67 | input_list = [x[:max_len] for x in input_list] 68 | res = openai.Embedding.create(input = input_list, model=model_name)['data'] 69 | res = [x['embedding'] for x in res] 70 | return res 71 | 72 | def save_large_features(large_list, chunk_num = 10): 73 | chunk_size = len(large_list) // chunk_num 74 | for i in tqdm(range(chunk_num)): 75 | if osp.exists(f'ogbn_product_features_{i}.pt'): 76 | continue 77 | part = large_list[i * chunk_size: (i + 1) * chunk_size] 78 | torch.save(part, f'ogbn_product_features_{i}.pt') 79 | 80 | 81 | def load_backup(path = "ogb/backup/backup.pt"): 82 | initial = torch.load(path) 83 | ogb_path = "ogb/backup" 84 | scatter_filenames = [osp.join(ogb_path, x) for x in os.listdir(ogb_path) if 'backup_' in x and 'compress' not in x] 85 | sort_filenames = sorted(scatter_filenames, key=lambda x:int(re.findall(r'\d+', x)[-1])) 86 | for filename in tqdm(sort_filenames): 87 | size = int(re.findall(r'\d+', filename)[-1]) 88 | intermediate_file = torch.load(filename) 89 | initial.extend(intermediate_file) 90 | assert len(initial) <= size 91 | return initial 92 | 93 | 94 | 95 | 96 | def generate_embeddings(cache_size = 1024): 97 | if not osp.exists('prompt.pt'): 98 | idx_mapping, content_mapping = get_raw_dataset() 99 | idx_mapping_list = idx_mapping.items() 100 | idx_mapping_list = sorted(idx_mapping_list, key=lambda x:x[0]) 101 | prompt_list = [] 102 | for key, value in idx_mapping_list: 103 | content = content_mapping[value] 104 | title, abstract = content 105 | title = title.strip() 106 | abstract = abstract.strip() 107 | prompt = f"{title}: {abstract}" 108 | prompt_list.append(prompt) 109 | torch.save(prompt_list, 'prompt.pt') 110 | else: 111 | prompt_list = torch.load('prompt.pt') 112 | 113 | print("prompt loaded") 114 | cache_num = 0 115 | backup_num = 1 116 | result = [] 117 | total_num = 0 118 | ogb_products = ogb_data() 119 | if osp.exists('backup.pt'): 120 | result = load_backup() 121 | cache_num = len(result) 122 | total_num = len(result) 123 | print("backup loaded") 124 | while cache_num < 
len(prompt_list): 125 | prompt_input = prompt_list[cache_num :cache_num + cache_size] 126 | if osp.exists(osp.join('ogb', f'backup_{total_num}.pt')): 127 | res = torch.load(osp.join('ogb', f'backup_{total_num}.pt')) 128 | else: 129 | res = openai_ada_api(prompt_input) 130 | cache_num += cache_size 131 | total_num += len(prompt_input) ## count the batch actually processed; the final batch may be shorter than cache_size 132 | print(total_num) 133 | result.extend(res) 134 | torch.save(res, osp.join('ogb/backup', f'backup_{total_num}.pt')) 135 | compress_embeddings(res, osp.join(f'compress_backup_{total_num}.pt')) 136 | print(f"Current number done: {total_num}") 137 | save_large_features(result, chunk_num=10) 138 | compress_embeddings(result) 139 | assert total_num == ogb_products.x.shape[0] 140 | 141 | 142 | 143 | 144 | 145 | def generate_ogb_products_pd_df(): 146 | idx_mapping, content_mapping = get_raw_dataset() 147 | idx_mapping_list = idx_mapping.items() 148 | idx_mapping_list = sorted(idx_mapping_list, key=lambda x:x[0]) 149 | titles = [] 150 | contents = [] 151 | for _, value in tqdm(idx_mapping_list): 152 | content = content_mapping[value] 153 | title, abstract = content 154 | title = title.strip() 155 | abstract = abstract.strip() 156 | titles.append(title) 157 | contents.append(abstract) 158 | df = pd.DataFrame({'title': titles, 'content': contents}) 159 | df.to_csv('ogb_products.csv', index=False) 160 | 161 | 162 | 163 | 164 | 165 | if __name__ == '__main__': 166 | generate_ogb_products_pd_df() 167 | -------------------------------------------------------------------------------- /ood.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import ipdb 4 | import os.path as osp 5 | 6 | ## change this to your path 7 | GOOD_ARXIV = "/egr/research-dselab/chenzh85/toy_experiments/GOODArxiv" 8 | data_path = "/egr/research-dselab/chenzh85/toy_experiments/ogb/preprocessed_data/new/" 9 | 10 | def ood_main(): 11 | arxiv_ood_split_degree_concept = torch.load(osp.join(GOOD_ARXIV, "degree/processed/concept.pt")) 12 | arxiv_ood_split_degree_concept = arxiv_ood_split_degree_concept[0] 13 | arxiv_ood_split_degree_covariate = torch.load(osp.join(GOOD_ARXIV, "degree/processed/covariate.pt")) 14 | arxiv_ood_split_degree_covariate = arxiv_ood_split_degree_covariate[0] 15 | 16 | arxiv_ood_split_time_concept = torch.load(osp.join(GOOD_ARXIV, "time/processed/concept.pt")) 17 | arxiv_ood_split_time_concept = arxiv_ood_split_time_concept[0] 18 | arxiv_ood_split_time_covariate = torch.load(osp.join(GOOD_ARXIV, "time/processed/covariate.pt")) 19 | arxiv_ood_split_time_covariate = arxiv_ood_split_time_covariate[0] 20 | 21 | llm_y = torch.load(osp.join(data_path, "arxiv_fixed_pl.pt")) 22 | pseudo_labels = llm_y.x[:, 0][:] 23 | pseudo_labels -= 1 24 | 25 | gt = llm_y.y 26 | 27 | ood_type = ['concept_degree', 'covariate_degree', 'concept_time', 'covariate_time'] 28 | ## evaluate the accuracy according to the environment id 29 | for i, data in enumerate([arxiv_ood_split_degree_concept, arxiv_ood_split_degree_covariate, arxiv_ood_split_time_concept, arxiv_ood_split_time_covariate]): 30 | name = ood_type[i] 31 | env_id_max = data.env_id.max() 32 | avg = [] 33 | for k in range(env_id_max + 1): 34 | mask = (data.env_id == k) 35 | pseudo_label = pseudo_labels[mask] 36 | gt_label = gt[mask] 37 | acc = (pseudo_label == gt_label).float().mean() 38 | print(f"{name} {k} {acc}") 39 | avg.append(acc) 40 | 41 | mask = data.val_mask 42 | pseudo_label = pseudo_labels[mask] 43 | gt_label = gt[mask] 44 | acc = 
--------------------------------------------------------------------------------
/ood.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import ipdb
4 | import os.path as osp
5 | 
6 | ## change this to your path
7 | GOOD_ARXIV = "/egr/research-dselab/chenzh85/toy_experiments/GOODArxiv"
8 | data_path = "/egr/research-dselab/chenzh85/toy_experiments/ogb/preprocessed_data/new/"
9 | 
10 | def ood_main():
11 |     arxiv_ood_split_degree_concept = torch.load(osp.join(GOOD_ARXIV, "degree/processed/concept.pt"))
12 |     arxiv_ood_split_degree_concept = arxiv_ood_split_degree_concept[0]
13 |     arxiv_ood_split_degree_covariate = torch.load(osp.join(GOOD_ARXIV, "degree/processed/covariate.pt"))
14 |     arxiv_ood_split_degree_covariate = arxiv_ood_split_degree_covariate[0]
15 | 
16 |     arxiv_ood_split_time_concept = torch.load(osp.join(GOOD_ARXIV, "time/processed/concept.pt"))
17 |     arxiv_ood_split_time_concept = arxiv_ood_split_time_concept[0]
18 |     arxiv_ood_split_time_covariate = torch.load(osp.join(GOOD_ARXIV, "time/processed/covariate.pt"))
19 |     arxiv_ood_split_time_covariate = arxiv_ood_split_time_covariate[0]
20 | 
21 |     llm_y = torch.load(osp.join(data_path, "arxiv_fixed_pl.pt"))
22 |     pseudo_labels = llm_y.x[:, 0].clone()
23 |     pseudo_labels -= 1
24 | 
25 |     gt = llm_y.y
26 | 
27 |     ood_type = ['concept_degree', 'covariate_degree', 'concept_time', 'covariate_time']
28 |     ## evaluate the accuracy according to the environment id
29 |     for i, data in enumerate([arxiv_ood_split_degree_concept, arxiv_ood_split_degree_covariate, arxiv_ood_split_time_concept, arxiv_ood_split_time_covariate]):
30 |         name = ood_type[i]
31 |         env_id_max = data.env_id.max()
32 |         avg = []
33 |         for k in range(env_id_max + 1):
34 |             mask = (data.env_id == k)
35 |             pseudo_label = pseudo_labels[mask]
36 |             gt_label = gt[mask]
37 |             acc = (pseudo_label == gt_label).float().mean()
38 |             print(f"{name} {k} {acc}")
39 |             avg.append(acc)
40 | 
41 |         mask = data.val_mask
42 |         pseudo_label = pseudo_labels[mask]
43 |         gt_label = gt[mask]
44 |         acc = (pseudo_label == gt_label).float().mean()
45 |         print(f"{name} val {acc}")
46 |         avg.append(acc)
47 |         mask = data.test_mask
48 |         pseudo_label = pseudo_labels[mask]
49 |         gt_label = gt[mask]
50 |         acc = (pseudo_label == gt_label).float().mean()
51 |         print(f"{name} test {acc}")
52 |         avg.append(acc)
53 |         avg_acc = np.mean(avg)
54 |         std_acc = np.std(avg)
55 |         print(f"{name} avg {avg_acc} std {std_acc}")
56 | 
57 | 
58 | 
59 | 
60 | if __name__ == '__main__':
61 |     ood_main()
62 | 
--------------------------------------------------------------------------------
/rev/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CurryTang/Graph-LLM/344f7c6b7786c7f8293c24ce5b90f141c777aeec/rev/__init__.py
--------------------------------------------------------------------------------
/rev/gcn_revop.py:
--------------------------------------------------------------------------------
1 | """This module is implemented by Guohao Li based on MemCNN @ Copyright (c) 2018 Sil C. van de Leemput under MIT license."""
2 | 
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | 
7 | use_context_mans = True
8 | 
9 | try:
10 |     pytorch_version_one_and_above = int(torch.__version__[0]) > 0
11 | except TypeError:
12 |     pytorch_version_one_and_above = True
13 | 
14 | 
15 | class InvertibleCheckpointFunction(torch.autograd.Function):
16 |     @staticmethod
17 |     def forward(ctx, fn, fn_inverse, keep_input, num_bwd_passes, preserve_rng_state, num_inputs, *inputs_and_weights):
18 |         # store in context
19 |         ctx.fn = fn
20 |         ctx.fn_inverse = fn_inverse
21 |         ctx.keep_input = keep_input
22 |         ctx.weights = inputs_and_weights[num_inputs:]
23 |         ctx.num_bwd_passes = num_bwd_passes
24 |         ctx.preserve_rng_state = preserve_rng_state
25 |         ctx.num_inputs = num_inputs
26 |         inputs = inputs_and_weights[:num_inputs]
27 | 
28 |         if preserve_rng_state:
29 |             ctx.fwd_cpu_state = torch.get_rng_state()
30 |             # Don't eagerly initialize the cuda context by accident.
31 |             # (If the user intends that the context is initialized later, within their
32 |             # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
33 |             # we have no way to anticipate this will happen before we run the function.)
34 |             ctx.had_cuda_in_fwd = False
35 |             if torch.cuda._initialized:
36 |                 ctx.had_cuda_in_fwd = True
37 |                 ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*inputs)
38 | 
39 |         ctx.input_requires_grad = [element.requires_grad for element in inputs]
40 | 
41 |         with torch.no_grad():
42 |             # Makes a detached copy which shares the storage
43 |             x = []
44 |             for element in inputs:
45 |                 if isinstance(element, torch.Tensor):
46 |                     x.append(element.detach())
47 |                 else:
48 |                     x.append(element)
49 |             outputs = ctx.fn(*x)
50 | 
51 |         if not isinstance(outputs, tuple):
52 |             outputs = (outputs,)
53 | 
54 |         # Detaches y in-place (inbetween computations can now be discarded)
55 |         detached_outputs = tuple([element.detach_() for element in outputs])
56 | 
57 |         # clear memory from inputs
58 |         # only clear memory of node features
59 |         if not ctx.keep_input:
60 |             if not pytorch_version_one_and_above:
61 |                 # PyTorch 0.4 way to clear storage for node features
62 |                 inputs[0].data.set_()
63 |             else:
64 |                 # PyTorch 1.0+ way to clear storage for node features
65 |                 inputs[0].storage().resize_(0)
66 | 
67 |         # store these tensor nodes for backward pass
68 |         ctx.inputs = [inputs] * num_bwd_passes
69 |         ctx.outputs = [detached_outputs] * num_bwd_passes
70 | 
71 |         return detached_outputs
72 | 
73 |     @staticmethod
74 |     def backward(ctx, *grad_outputs):  # pragma: no cover
75 |         if not torch.autograd._is_checkpoint_valid():
76 |             raise RuntimeError("InvertibleCheckpointFunction is not compatible with .grad(), please use .backward() if possible")
77 |         # retrieve input and output tensor nodes
78 |         if len(ctx.outputs) == 0:
79 |             raise RuntimeError("Trying to perform backward on the InvertibleCheckpointFunction for more than "
80 |                                "{} times! Try raising `num_bwd_passes` by one.".format(ctx.num_bwd_passes))
81 |         inputs = ctx.inputs.pop()
82 |         outputs = ctx.outputs.pop()
83 | 
84 |         # recompute input if necessary
85 |         if not ctx.keep_input:
86 |             # Stash the surrounding rng state, and mimic the state that was
87 |             # present at this time during forward. Restore the surrounding state
88 |             # when we're done.
89 |             rng_devices = []
90 |             if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
91 |                 rng_devices = ctx.fwd_gpu_devices
92 |             with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
93 |                 if ctx.preserve_rng_state:
94 |                     torch.set_rng_state(ctx.fwd_cpu_state)
95 |                     if ctx.had_cuda_in_fwd:
96 |                         set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
97 |                 # recompute input
98 |                 with torch.no_grad():
99 |                     # edge_index and edge_emb
100 |                     inputs_inverted = ctx.fn_inverse(*(outputs+inputs[1:]))
101 |                     # clear memory from outputs
102 |                     if not pytorch_version_one_and_above:
103 |                         # PyTorch 0.4 way to clear storage
104 |                         for element in outputs:
105 |                             element.data.set_()
106 |                     else:
107 |                         # PyTorch 1.0+ way to clear storage
108 |                         for element in outputs:
109 |                             element.storage().resize_(0)
110 | 
111 |                     if not isinstance(inputs_inverted, tuple):
112 |                         inputs_inverted = (inputs_inverted,)
113 |                     if pytorch_version_one_and_above:
114 |                         for element_original, element_inverted in zip(inputs, inputs_inverted):
115 |                             element_original.storage().resize_(int(np.prod(element_original.size())))
116 |                             element_original.set_(element_inverted)
117 |                     else:
118 |                         for element_original, element_inverted in zip(inputs, inputs_inverted):
119 |                             element_original.set_(element_inverted)
120 | 
121 |         # compute gradients
122 |         with torch.set_grad_enabled(True):
123 |             detached_inputs = []
124 |             for element in inputs:
125 |                 if isinstance(element, torch.Tensor):
126 |                     detached_inputs.append(element.detach())
127 |                 else:
128 |                     detached_inputs.append(element)
129 |             detached_inputs = tuple(detached_inputs)
130 |             for det_input, requires_grad in zip(detached_inputs, ctx.input_requires_grad):
131 |                 det_input.requires_grad = requires_grad
132 |             temp_output = ctx.fn(*detached_inputs)
133 |             if not isinstance(temp_output, tuple):
134 |                 temp_output = (temp_output,)
135 | 
136 |             filtered_detached_inputs = tuple(filter(lambda x: x.requires_grad,
137 |                                                     detached_inputs))
138 |             gradients = torch.autograd.grad(outputs=temp_output,
139 |                                             inputs=filtered_detached_inputs + ctx.weights,
140 |                                             grad_outputs=grad_outputs)
141 | 
142 |         # Setting the gradients manually on the inputs and outputs (mimic backwards)
143 |         filtered_inputs = list(filter(lambda x: x.requires_grad,
144 |                                       inputs))
145 | 
146 |         input_gradients = []
147 |         i = 0
148 |         for rg in ctx.input_requires_grad:
149 |             if rg:
150 |                 input_gradients.append(gradients[i])
151 |                 i += 1
152 |             else:
153 |                 input_gradients.append(None)
154 | 
155 |         gradients = tuple(input_gradients) + gradients[-len(ctx.weights):]
156 | 
157 |         return (None, None, None, None, None, None) + gradients
158 | 
159 | 
160 | class InvertibleModuleWrapper(nn.Module):
161 |     def __init__(self, fn, keep_input=False, keep_input_inverse=False, num_bwd_passes=1,
162 |                  disable=False, preserve_rng_state=False):
163 |         """
164 |         The InvertibleModuleWrapper which enables memory savings during training by exploiting
165 |         the invertible properties of the wrapped module.
166 | 
167 |         Parameters
168 |         ----------
169 |         fn : :obj:`torch.nn.Module`
170 |             A torch.nn.Module which has a forward and an inverse function implemented with
171 |             :math:`x == m.inverse(m.forward(x))`
172 | 
173 |         keep_input : :obj:`bool`, optional
174 |             Set to retain the input information on forward, by default it can be discarded since it will be
175 |             reconstructed upon the backward pass.
176 | 
177 |         keep_input_inverse : :obj:`bool`, optional
178 |             Set to retain the input information on inverse, by default it can be discarded since it will be
179 |             reconstructed upon the backward pass.
180 | 
181 |         num_bwd_passes : :obj:`int`, optional
182 |             Number of backward passes to retain a link with the output. After the last backward pass the output
183 |             is discarded and memory is freed.
184 |             Warning: if this value is raised higher than the number of required passes memory will not be freed
185 |             correctly anymore and the training process can quickly run out of memory.
186 |             Hence, the typical use case is to keep this at 1, and raise it only if an error asks for it.
187 | 
188 |         disable : :obj:`bool`, optional
189 |             This will disable using the InvertibleCheckpointFunction altogether.
190 |             Essentially this renders the function as `y = fn(x)` without any of the memory savings.
191 |             Setting this to true will also ignore the keep_input and keep_input_inverse properties.
192 | 
193 |         preserve_rng_state : :obj:`bool`, optional
194 |             Setting this will ensure that the same RNG state is used during reconstruction of the inputs.
195 |             I.e. if keep_input = False on forward or keep_input_inverse = False on inverse. By default
196 |             this is False since most invertible modules should have a valid inverse and hence are
197 |             deterministic.
198 | 
199 |         Attributes
200 |         ----------
201 |         keep_input : :obj:`bool`, optional
202 |             Set to retain the input information on forward, by default it can be discarded since it will be
203 |             reconstructed upon the backward pass.
204 | 
205 |         keep_input_inverse : :obj:`bool`, optional
206 |             Set to retain the input information on inverse, by default it can be discarded since it will be
207 |             reconstructed upon the backward pass.
208 | 
209 |         """
210 |         super(InvertibleModuleWrapper, self).__init__()
211 |         self.disable = disable
212 |         self.keep_input = keep_input
213 |         self.keep_input_inverse = keep_input_inverse
214 |         self.num_bwd_passes = num_bwd_passes
215 |         self.preserve_rng_state = preserve_rng_state
216 |         self._fn = fn
217 | 
218 |     def forward(self, *xin):
219 |         """Forward operation :math:`R(x) = y`
220 | 
221 |         Parameters
222 |         ----------
223 |         *xin : :obj:`torch.Tensor` tuple
224 |             Input torch tensor(s).
225 | 
226 |         Returns
227 |         -------
228 |         :obj:`torch.Tensor` tuple
229 |             Output torch tensor(s) *y.
230 | 
231 |         """
232 |         if not self.disable:
233 |             y = InvertibleCheckpointFunction.apply(
234 |                 self._fn.forward,
235 |                 self._fn.inverse,
236 |                 self.keep_input,
237 |                 self.num_bwd_passes,
238 |                 self.preserve_rng_state,
239 |                 len(xin),
240 |                 *(xin + tuple([p for p in self._fn.parameters() if p.requires_grad])))
241 |         else:
242 |             y = self._fn(*xin)
243 | 
244 |         # If the layer only has one input, we unpack the tuple again
245 |         if isinstance(y, tuple) and len(y) == 1:
246 |             return y[0]
247 |         return y
248 | 
249 |     def inverse(self, *yin):
250 |         """Inverse operation :math:`R^{-1}(y) = x`
251 | 
252 |         Parameters
253 |         ----------
254 |         *yin : :obj:`torch.Tensor` tuple
255 |             Input torch tensor(s).
256 | 
257 |         Returns
258 |         -------
259 |         :obj:`torch.Tensor` tuple
260 |             Output torch tensor(s) *x.
261 | 
262 |         """
263 |         if not self.disable:
264 |             x = InvertibleCheckpointFunction.apply(
265 |                 self._fn.inverse,
266 |                 self._fn.forward,
267 |                 self.keep_input_inverse,
268 |                 self.num_bwd_passes,
269 |                 self.preserve_rng_state,
270 |                 len(yin),
271 |                 *(yin + tuple([p for p in self._fn.parameters() if p.requires_grad])))
272 |         else:
273 |             x = self._fn.inverse(*yin)
274 | 
275 |         # If the layer only has one input, we unpack the tuple again
276 |         if isinstance(x, tuple) and len(x) == 1:
277 |             return x[0]
278 |         return x
279 | 
280 | # To consider: maybe get_device_states and set_device_states should reside in
281 | # torch/random.py?
282 | #
283 | # get_device_states and set_device_states cannot be imported from
284 | # torch.utils.checkpoint, since it was not
285 | # present in older versions, so we include a copy here.
286 | def get_device_states(*args):
287 |     # This will not error out if "arg" is a CPU tensor or a non-tensor type
288 |     # because
289 |     # the conditionals short-circuit.
290 |     fwd_gpu_devices = list(set(arg.get_device() for arg in args
291 |                                if isinstance(arg, torch.Tensor) and arg.is_cuda))
292 | 
293 |     fwd_gpu_states = []
294 |     for device in fwd_gpu_devices:
295 |         with torch.cuda.device(device):
296 |             fwd_gpu_states.append(torch.cuda.get_rng_state())
297 | 
298 |     return fwd_gpu_devices, fwd_gpu_states
299 | 
300 | 
301 | def set_device_states(devices, states):
302 |     for device, state in zip(devices, states):
303 |         with torch.cuda.device(device):
304 |             torch.cuda.set_rng_state(state)
--------------------------------------------------------------------------------
/rev/memgcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import copy
4 | try:
5 |     from .gcn_revop import InvertibleModuleWrapper
6 | except ImportError:
7 |     from gcn_revop import InvertibleModuleWrapper
8 | 
9 | class GroupAdditiveCoupling(torch.nn.Module):
10 |     def __init__(self, Fms, split_dim=-1, group=2):
11 |         super(GroupAdditiveCoupling, self).__init__()
12 | 
13 |         self.Fms = Fms
14 |         self.split_dim = split_dim
15 |         self.group = group
16 | 
17 |     def forward(self, x, edge_index, *args):
18 |         xs = torch.chunk(x, self.group, dim=self.split_dim)
19 |         chunked_args = list(map(lambda arg: torch.chunk(arg, self.group, dim=self.split_dim), args))
20 |         args_chunks = list(zip(*chunked_args))
21 |         y_in = sum(xs[1:])
22 | 
23 |         ys = []
24 |         for i in range(self.group):
25 |             Fmd = self.Fms[i].forward(y_in, edge_index, *args_chunks[i])
26 |             y = xs[i] + Fmd
27 |             y_in = y
28 |             ys.append(y)
29 | 
30 |         out = torch.cat(ys, dim=self.split_dim)
31 | 
32 |         return out
33 | 
34 |     def inverse(self, y, edge_index, *args):
35 |         ys = torch.chunk(y, self.group, dim=self.split_dim)
36 |         chunked_args = list(map(lambda arg: torch.chunk(arg, self.group, dim=self.split_dim), args))
37 |         args_chunks = list(zip(*chunked_args))
38 | 
39 |         xs = []
40 |         for i in range(self.group-1, -1, -1):
41 |             if i != 0:
42 |                 y_in = ys[i-1]
43 |             else:
44 |                 y_in = sum(xs)
45 | 
46 |             Fmd = self.Fms[i].forward(y_in, edge_index, *args_chunks[i])
47 |             x = ys[i] - Fmd
48 |             xs.append(x)
49 | 
50 |         x = torch.cat(xs[::-1], dim=self.split_dim)
51 | 
52 |         return x
--------------------------------------------------------------------------------
/rev/rev_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch_geometric.nn import GCNConv, SAGEConv, GATConv
5 | from gcn_lib.sparse.torch_vertex import GENConv
6 | from gcn_lib.sparse.torch_nn import norm_layer
7 | 
8 | 
9 | class SharedDropout(nn.Module):
10 |     def __init__(self):
11 |         super(SharedDropout, self).__init__()
12 |         self.mask = None
13 | 
14 |     def set_mask(self, mask):
15 |         self.mask = mask
16 | 
17 |     def forward(self, x):
18 |         if self.training:
19 |             assert self.mask is not None
20 |             out = x * self.mask
21 |             return out
22 |         else:
23 |             return x
24 | 
25 | 
26 | class BasicBlock(nn.Module):
27 |     def __init__(self, norm, in_channels):
28 |         super(BasicBlock, self).__init__()
29 |         self.norm = norm_layer(norm, in_channels)
30 |         self.dropout = SharedDropout()
31 | 
32 |     def forward(self, x, edge_index, dropout_mask=None, edge_emb=None):
33 |         # dropout_mask = kwargs.get('dropout_mask', None)
34 |         # edge_emb = kwargs.get('edge_emb', None)
35 |         out = self.norm(x)
36 |         out = F.relu(out)
37 | 
38 |         if isinstance(self.dropout, SharedDropout):
39 |             if dropout_mask is not None:
40 |                 self.dropout.set_mask(dropout_mask)
41 |         out = self.dropout(out)
42 | 
43 |         if edge_emb is not None:
44 |             out = self.gcn(out, edge_index, edge_emb)
45 |         else:
46 |             out = self.gcn(out, edge_index)
47 | 
48 |         return out
49 | 
50 | 
51 | class GENBlock(BasicBlock):
52 |     def __init__(self, in_channels, out_channels,
53 |                  aggr='max',
54 |                  t=1.0, learn_t=False,
55 |                  p=1.0, learn_p=False,
56 |                  y=0.0, learn_y=False,
57 |                  msg_norm=False,
58 |                  learn_msg_scale=False,
59 |                  encode_edge=False,
60 |                  edge_feat_dim=0,
61 |                  norm='layer', mlp_layers=1):
62 |         super(GENBlock, self).__init__(norm, in_channels)
63 | 
64 |         self.gcn = GENConv(in_channels, out_channels,
65 |                            aggr=aggr,
66 |                            t=t, learn_t=learn_t,
67 |                            p=p, learn_p=learn_p,
68 |                            y=y, learn_y=learn_y,
69 |                            msg_norm=msg_norm,
70 |                            learn_msg_scale=learn_msg_scale,
71 |                            encode_edge=encode_edge,
72 |                            edge_feat_dim=edge_feat_dim,
73 |                            norm=norm,
74 |                            mlp_layers=mlp_layers)
75 | 
76 | 
77 | class GCNBlock(BasicBlock):
78 |     def __init__(self, in_channels, out_channels,
79 |                  norm='layer'):
80 |         super(GCNBlock, self).__init__(norm, in_channels)
81 | 
82 |         self.gcn = GCNConv(in_channels, out_channels)
83 | 
84 | 
85 | class SAGEBlock(BasicBlock):
86 |     def __init__(self, in_channels, out_channels,
87 |                  norm='layer',
88 |                  dropout=0.0):
89 |         super(SAGEBlock, self).__init__(norm, in_channels)
90 | 
91 |         self.gcn = SAGEConv(in_channels, out_channels)
92 | 
93 | 
94 | class GATBlock(BasicBlock):
95 |     def __init__(self, in_channels, out_channels,
96 |                  heads=1,
97 |                  norm='layer',
98 |                  att_dropout=0.0,
99 |                  dropout=0.0):
100 |         super(GATBlock, self).__init__(norm, in_channels)
101 | 
102 |         self.gcn = GATConv(in_channels, out_channels,
103 |                            heads=heads,
104 |                            concat=False,
105 |                            dropout=att_dropout,
106 |                            add_self_loops=False)
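These blocks only become memory-efficient once they are wired through `GroupAdditiveCoupling` and `InvertibleModuleWrapper` from the other two files in `rev/`. A minimal sketch of that composition (the layer sizes, toy graph, and dropout keep-probability are illustrative assumptions, not repo defaults):

```
# Sketch: one reversible GNN layer built from the pieces in rev/.
import copy
import torch
import torch.nn as nn
from rev.rev_layer import GCNBlock
from rev.memgcn import GroupAdditiveCoupling
from rev.gcn_revop import InvertibleModuleWrapper

hidden, group = 64, 2
# each coupling function operates on hidden // group channels
fm = GCNBlock(hidden // group, hidden // group, norm='layer')
Fms = nn.ModuleList([fm, copy.deepcopy(fm)])
rev_layer = InvertibleModuleWrapper(
    fn=GroupAdditiveCoupling(Fms, group=group),
    keep_input=False)  # inputs freed on forward, rebuilt via inverse() on backward

x = torch.randn(100, hidden)                   # toy node features
edge_index = torch.randint(0, 100, (2, 400))   # toy edges
# SharedDropout takes an explicit mask so forward and inverse recomputation agree
dropout_mask = torch.bernoulli(torch.full((100, hidden), 0.9))
out = rev_layer(x, edge_index, dropout_mask)   # [100, 64]
out.sum().backward()                           # grads reach the GCN weights
```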
--------------------------------------------------------------------------------
/secret.yaml:
--------------------------------------------------------------------------------
1 | openai:
2 |   secret: "sk-xxxxxxx"
3 | google:
4 |   secret: "xxxxxxx"
5 | word2vec:
6 |   path: "path_to_your_word2vec_file"
7 | llama:
8 |   path: "./ggml-xxxx.bin"
--------------------------------------------------------------------------------
/train_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ogb.nodeproppred import Evaluator
3 | from torch.optim.lr_scheduler import ExponentialLR, _LRScheduler
4 | import warnings
5 | import pytorch_warmup as warmup
6 | from torch_geometric.utils import index_to_mask, subgraph
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import os.path as osp
10 | from data import read_and_unpkl
11 | from utils import norm_entropy
12 | import numpy as np
13 | import editdistance
14 | import ast
15 | 
16 | 
17 | 
18 | class WarmupExpLR(_LRScheduler):
19 |     def __init__(self, optimizer, warmup_epochs, total_epochs, gamma=0.1, last_epoch=-1):
20 |         self.warmup_epochs = warmup_epochs
21 |         self.total_epochs = total_epochs
22 |         self.gamma = gamma
23 |         super(WarmupExpLR, self).__init__(optimizer, last_epoch)
24 | 
25 |     def get_lr(self):
26 |         if self.last_epoch < self.warmup_epochs:
27 |             return [group['lr'] for group in self.optimizer.param_groups]
28 |         else:
29 |             return [group['lr'] * self.gamma
30 |                     for group in self.optimizer.param_groups]
31 | 
32 |     def _get_closed_form_lr(self):
33 |         return [base_lr * self.gamma ** self.last_epoch
34 |                 for base_lr in self.base_lrs]
35 | 
36 | 
37 | 
38 | 
39 | 
40 | def get_optimizer(args, model):
41 |     if args.model_name == 'LP':
42 |         return None, None
43 |     if args.optim == 'adam':
44 |         optimizer = torch.optim.Adam(model.parameters(), lr = args.lr, weight_decay=args.weight_decay)
45 |         scheduler = None
46 |     elif args.optim == 'radam':
47 |         optimizer = torch.optim.RAdam(model.parameters(), lr = args.lr, weight_decay=args.weight_decay)
48 |         scheduler = WarmupExpLR(optimizer, args.warmup, total_epochs=args.epochs, gamma=args.lr_gamma)
49 |     return optimizer, scheduler
50 | 
51 | 
52 | def train(model, data, optimizer, loss_fn, train_mask, val_mask):
53 |     model.train()
54 |     optimizer.zero_grad()
55 |     preds = model(data)
56 |     if len(data.y.shape) != 1:
57 |         y = data.y.squeeze(1)
58 |     else:
59 |         y = data.y
60 |     train_loss = loss_fn(preds[train_mask], y[train_mask])
61 |     train_loss.backward()
62 |     optimizer.step()
63 |     val_loss = loss_fn(preds[val_mask], y[val_mask])
64 |     val_acc, _ = test(model, data, False, val_mask)
65 |     return train_loss, val_loss, val_acc
66 | 
67 | 
68 | def batch_train(model, loader, optimizer, device):
69 |     model.train()
70 |     total_loss = 0
71 |     for batch in loader:
72 |         batch_size, n_id, edge_index = batch.batch_size, batch.n_id, batch.edge_index
73 |         # data = data.to(device)
74 |         optimizer.zero_grad()
75 |         batch.edge_index = batch.edge_index.to(device)
76 |         out = model(batch)[:batch_size]
77 |         y = batch.y[:batch_size].squeeze()
78 |         loss = F.cross_entropy(out, y)
79 |         loss.backward()
80 |         optimizer.step()
81 |         total_loss += loss.item()
82 |     return total_loss / len(loader)
83 | 
84 | 
85 | def to_inductive(data, msk_index = 0):
86 |     data = data.clone()
87 |     mask = data.train_masks[msk_index]
88 |     data.x = data.x[mask]
89 |     data.y = data.y[mask]
90 |     data.train_mask = mask[mask]
91 |     data.test_masks = None
92 |     data.edge_index, _ = subgraph(mask, data.edge_index, None,
93 |                                   relabel_nodes=True, num_nodes=data.num_nodes)
94 |     data.num_nodes = mask.sum().item()
95 |     return data
96 | 
97 | 
98 | 
99 | @torch.no_grad()
100 | def batch_test(model, data, evaluator, subgraph_loader, device, mask):
101 |     model.eval()
102 | 
103 |     out = model.inference(data.x, subgraph_loader, device)
104 | 
105 |     y_pred = out.argmax(dim=-1, keepdim=True)
106 | 
107 |     # import ipdb; ipdb.set_trace()
108 |     if len(data.y.shape) == 1:
109 |         y_true = data.y.unsqueeze(dim=1)  # for non ogb datas
110 |     else:
111 |         y_true = data.y
112 | 
113 |     test_acc = evaluator.eval({
114 |         'y_true': y_true[mask],
115 |         'y_pred': y_pred[mask]
116 |     })['acc']
117 | 
118 |     return test_acc
119 | 
120 | 
121 | 
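A minimal sketch of how `get_optimizer`, `train`, and `test` fit together in a full-batch loop (the `args` fields mirror the attributes read above; the helper name, masks, and the best-val selection rule are our assumptions):

```
# Sketch: full-batch training with the helpers above. `args` carries
# model_name/optim/lr/weight_decay/epochs (plus warmup/lr_gamma for radam).
import torch.nn.functional as F

def run_full_batch(args, model, data, train_mask, val_mask, test_mask):
    optimizer, scheduler = get_optimizer(args, model)
    best_val = best_test = 0.0
    for epoch in range(args.epochs):
        train_loss, val_loss, val_acc = train(model, data, optimizer,
                                              F.cross_entropy, train_mask, val_mask)
        if scheduler is not None:
            scheduler.step()
        test_acc, _ = test(model, data, False, test_mask)
        if val_acc > best_val:  # report the test score at the best val epoch
            best_val, best_test = val_acc, test_acc
    return best_val, best_test
```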
122 | @torch.no_grad()
123 | def topk_test(model, data, mask, topk = 3, need_batch = False, subgraph_loader = None):
124 |     model.eval()
125 |     # model.model.initialized = False
126 |     if not need_batch:
127 |         out = model(data)
128 |         y_pred = out.argmax(dim=-1, keepdim=True)
129 |     else:
130 |         out = model.inference(data.x, subgraph_loader, data.x.device)
131 |         y_true = data.y
132 |         y_pred = out.argmax(dim=-1, keepdim=True)
133 |     r_y_pred = y_pred.reshape(-1)
134 |     confidence = out.gather(1, r_y_pred.unsqueeze(1)).reshape(-1)
135 |     data.confidence = confidence
136 |     sorted_conf_idx = torch.argsort(data.confidence)
137 |     full_length = data.x.shape[0]
138 |     com_res = data.y.view(-1, 1).expand_as(out.topk(topk, 1).values).eq(out.topk(topk, 1).indices).sum(-1).to(torch.bool)
139 |     low_confidence_sorted_conf_mask = index_to_mask(sorted_conf_idx[:full_length // 3], size=full_length)
140 |     med_confidence_sorted_conf_mask = index_to_mask(sorted_conf_idx[full_length // 3 : full_length * 2 // 3], size=full_length)
141 |     high_confidence_sorted_conf_mask = index_to_mask(sorted_conf_idx[full_length * 2 // 3:], size=full_length)
142 | 
143 |     y_1 = y_pred.reshape(-1)
144 |     true_mask = (y_1 == data.y)
145 |     false_mask = ~true_mask
146 | 
147 |     evaluator = Evaluator(name='ogbn-arxiv')
148 |     top3_low_acc = torch.sum(com_res[mask & low_confidence_sorted_conf_mask]) / com_res[mask & low_confidence_sorted_conf_mask].shape[0]
149 |     top3_med_acc = torch.sum(com_res[mask & med_confidence_sorted_conf_mask]) / com_res[mask & med_confidence_sorted_conf_mask].shape[0]
150 |     top3_high_acc = torch.sum(com_res[mask & high_confidence_sorted_conf_mask]) / com_res[mask & high_confidence_sorted_conf_mask].shape[0]
151 |     # true_acc = torch.sum(com_res[mask & true_mask]) / com_res[mask & true_mask].shape[0]
152 | 
153 |     res = data.y.view(-1).eq(r_y_pred)
154 |     top1_low_acc = torch.sum(res[mask & low_confidence_sorted_conf_mask]) / res[mask & low_confidence_sorted_conf_mask].shape[0]
155 |     top1_med_acc = torch.sum(res[mask & med_confidence_sorted_conf_mask]) / res[mask & med_confidence_sorted_conf_mask].shape[0]
156 |     top1_high_acc = torch.sum(res[mask & high_confidence_sorted_conf_mask]) / res[mask & high_confidence_sorted_conf_mask].shape[0]
157 |     # top1_low_acc = torch.sum()
158 |     top3_false_acc = torch.sum(com_res[mask & false_mask]) / com_res[mask & false_mask].shape[0]
159 |     total_acc = torch.sum(com_res[mask]) / com_res[mask].shape[0]
160 |     print("Top3 Accuracy on low confidence nodes: {}\n".format(top3_low_acc.item()))
161 |     print("Top3 Accuracy on medium confidence nodes: {}\n".format(top3_med_acc.item()))
162 |     print("Top3 Accuracy on high confidence nodes: {}\n".format(top3_high_acc.item()))
163 |     print("Top1 Accuracy on low confidence nodes: {}\n".format(top1_low_acc.item()))
164 |     print("Top1 Accuracy on medium confidence nodes: {}\n".format(top1_med_acc.item()))
165 |     print("Top1 Accuracy on high confidence nodes: {}\n".format(top1_high_acc.item()))
166 |     print("Top3 Accuracy on gnn false nodes: {}\n".format(top3_false_acc.item()))
167 |     return top3_low_acc.item(), top3_med_acc.item(), top3_high_acc.item(), total_acc.item()
168 | 
169 | 
170 | 
171 | @torch.no_grad()
172 | def confidence_test(model, data, mask):
173 |     model.eval()
174 |     # model.model.initialized = False
175 |     out = model(data)
176 |     y_pred = out.argmax(dim=-1, keepdim=True)
177 |     r_y_pred = y_pred.reshape(-1)
178 |     confidence = out.gather(1, r_y_pred.unsqueeze(1)).reshape(-1)
179 |     data.confidence = confidence
180 |     sorted_conf_idx = torch.argsort(data.confidence)
181 |     full_length = data.x.shape[0]
182 |     low_confidence_sorted_conf_mask = index_to_mask(sorted_conf_idx[:full_length // 3], size=full_length)
183 |     med_confidence_sorted_conf_mask = index_to_mask(sorted_conf_idx[full_length // 3 : full_length * 2 // 3], size=full_length)
184 |     high_confidence_sorted_conf_mask = index_to_mask(sorted_conf_idx[full_length * 2 // 3:], size=full_length)
185 |     # ground_truth = data.y.cpu()
186 |     # true_mask = data.y.cpu() == y_pred.cpu()
187 |     # false_mask = data.y.cpu() != y_pred.cpu()
188 | 
189 |     if len(data.y.shape) == 1:
190 |         y = data.y.unsqueeze(dim=1)  # for non ogb datas
191 |     else:
192 |         y = data.y
193 | 
194 |     y_1 = y_pred.reshape(-1)
195 |     true_mask = (y_1 == data.y)
196 |     false_mask = ~true_mask
197 | 
198 |     evaluator = Evaluator(name='ogbn-arxiv')
199 |     low_acc = evaluator.eval({
200 |         'y_true': y[mask & low_confidence_sorted_conf_mask],
201 |         'y_pred': y_pred[mask & low_confidence_sorted_conf_mask],
202 |     })['acc']
203 | 
204 |     med_acc = evaluator.eval({
205 |         'y_true': y[mask & med_confidence_sorted_conf_mask],
206 |         'y_pred': y_pred[mask & med_confidence_sorted_conf_mask],
207 |     })['acc']
208 | 
209 |     high_acc = evaluator.eval({
210 |         'y_true': y[mask & high_confidence_sorted_conf_mask],
211 |         'y_pred': y_pred[mask & high_confidence_sorted_conf_mask],
212 |     })['acc']
213 | 
214 | 
215 |     true_acc = evaluator.eval({
216 |         'y_true': y[mask | true_mask],
217 |         'y_pred': y_pred[mask | true_mask],
218 |     })['acc']
219 | 
220 | 
221 |     false_acc = evaluator.eval({
222 |         'y_true': y[mask | false_mask],
223 |         'y_pred': y_pred[mask | false_mask],
224 |     })['acc']
225 | 
226 |     print(true_acc, false_acc)
227 | 
228 |     return low_acc, med_acc, high_acc
229 | 
230 | @torch.no_grad()
231 | def test(model, data, return_embeds, mask):
232 |     model.eval()
233 |     # model.model.initialized = False
234 |     out = model(data)
235 |     y_pred = out.argmax(dim=-1, keepdim=True)
236 | 
237 |     if len(data.y.shape) == 1:
238 |         y = data.y.unsqueeze(dim=1)  # for non ogb datas
239 |     else:
240 |         y = data.y
241 | 
242 |     evaluator = Evaluator(name='ogbn-arxiv')
243 |     acc = evaluator.eval({
244 |         'y_true': y[mask],
245 |         'y_pred': y_pred[mask],
246 |     })['acc']
247 | 
248 | 
249 |     if not return_embeds:
250 |         return acc, None
251 |     else:
252 |         return acc, out
253 | 
254 | 
255 | def loss_kd(all_out, teacher_all_out, outputs, labels, teacher_outputs,
256 |             alpha, temperature):
257 |     """
258 |     loss function for Knowledge Distillation (KD)
259 |     """
260 | 
261 |     T = temperature
262 | 
263 |     loss_CE = F.cross_entropy(outputs, labels)
264 |     D_KL = nn.KLDivLoss()(F.log_softmax(all_out / T, dim=1),
265 |                           F.softmax(teacher_all_out / T, dim=1)) * (T * T)
266 |     KD_loss = (1. - alpha) * loss_CE + alpha * D_KL
267 | 
268 |     return KD_loss
269 | 
270 | def loss_kd_only(all_out, teacher_all_out, temperature):
271 |     T = temperature
272 | 
273 |     D_KL = nn.KLDivLoss()(F.log_softmax(all_out / T, dim=1),
274 |                           F.softmax(teacher_all_out / T, dim=1)) * (T * T)
275 | 
276 |     return D_KL
277 | 
278 | 
279 | 
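A small sketch of how `loss_kd` would be called when distilling a teacher's soft predictions into a student (the model names, mask, and the alpha/temperature values are illustrative, not repo defaults):

```
# Sketch: mixing hard-label CE with a soft KL term from a teacher.
student_out = student(data)          # [N, C] logits
with torch.no_grad():
    teacher_out = teacher(data)      # [N, C] logits
labels = data.y.squeeze()
loss = loss_kd(student_out, teacher_out,    # soft targets over all nodes
               student_out[train_mask],     # hard-label CE on train nodes
               labels[train_mask],
               teacher_out[train_mask],     # unused by loss_kd, kept for the signature
               alpha=0.5, temperature=2.0)  # alpha weights the KL term
loss.backward()
```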
280 | def seed_everything(seed: int):
281 |     import random, os
282 |     import numpy as np
283 |     import torch
284 | 
285 |     random.seed(seed)
286 |     os.environ['PYTHONHASHSEED'] = str(seed)
287 |     np.random.seed(seed)
288 |     torch.manual_seed(seed)
289 |     torch.cuda.manual_seed(seed)
290 |     torch.backends.cudnn.deterministic = True
291 |     torch.backends.cudnn.benchmark = False
292 | 
293 | 
294 | 
295 | 
296 | # def glem(num_nodes, hidden_num, n_labels, dataset = 'cora', gnn_model_name = 'GCN', feature_name = 'sbert', lm_output_path = './lmoutput', gnn_output_path = './output', setting = 'fixed'):
297 | #     ## pretrain lm and get the embeddings
298 | #     lm_output_embedding_path = osp.join(lm_output_path, f"{dataset}_finetune_{setting}.emb")
299 | #     lm_output_pred_path = osp.join(lm_output_path, f"{dataset}_finetune_{setting}.pred")
300 | #     lm_emb_np = np.memmap(lm_output_embedding_path, dtype=np.float16, mode='r',
301 | #                           shape=(num_nodes, hidden_dim))
302 | #     lm_pred = np.memmap(lm_output_pred_path, dtype=np.float16, mode='r',
303 | #                         shape=(num_nodes, n_labels))
304 | #     lm_emb = torch.tensor(emb, dtype=torch.float32)
305 | #     lm_pred = torch.tensor(pred, dtype=torch.float32)
306 | #     ## pretrain gnn and get the embeddings
307 | #     gnn_pred_path = osp.join(gnn_output_path, f"{gnn_model_name}_{dataset}_{feature_name}.pkl")
308 | #     gnn_pred = read_and_unpkl(gnn_pred_path)
309 | 
310 | 
311 | def tensor_intersection(tensor1, tensor2):
312 |     set1 = set(tensor1.numpy().flatten())
313 |     set2 = set(tensor2.numpy().flatten())
314 | 
315 |     intersection = set1 & set2
316 | 
317 |     return torch.tensor(list(intersection))
318 | 
319 | @torch.no_grad()
320 | def llm_pseudo_label(data, logits, budget = 100, train_val_ratio = 3, strategy = 1):
321 |     """
322 |     train_val_ratio: new train : new val
323 |     Strategy 1: totally random
324 |     Strategy 2: each class random
325 |     Strategy 3: confidence based
326 |     Strategy 4: class confidence based
327 |     Strategy 5: use prompt to test llm's confidence
328 |     """
329 |     ## data is the low labeling rate data
330 |     node_idx = torch.arange(data.x.shape[0])
331 |     test_mask = data.test_masks[0].cpu()
332 |     data = data.cpu()
333 |     test_idx = node_idx[test_mask]
334 |     if strategy == 1:
335 |         selected_test_idx = test_idx[torch.randperm(test_idx.shape[0])[:budget]]
336 |     elif strategy == 2:
337 |         num_of_class = data.y.max().item() + 1
338 |         per_class = budget // num_of_class
339 |         selected_test_idx = []
340 |         count = [0 for _ in range(num_of_class)]
341 |         rand_node_idx = test_idx[torch.randperm(test_idx.shape[0])]
342 |         for i in rand_node_idx:
343 |             if i not in test_idx: continue
344 |             lbl = data.y[i].item()
345 |             if count[lbl] < per_class:
346 |                 selected_test_idx.append(i.item())
347 |                 count[lbl] += 1
348 |             if min(count) == per_class: break
349 |         selected_test_idx = torch.LongTensor(selected_test_idx)
350 |     elif strategy == 3:
351 |         norm_entro = norm_entropy(logits)
352 |         test_idx_set = set(test_idx.tolist())
353 |         sorted_idx = torch.argsort(norm_entro).tolist()
354 |         intersection = [i for i in sorted_idx if i in test_idx_set]
355 |         selected_test_idx = torch.LongTensor(intersection[:budget])
356 |     elif strategy == 4:
357 |         num_of_class = data.y.max().item() + 1
358 |         per_class = budget // num_of_class
359 |         selected_test_idx = []
360 |         count = [0 for _ in range(num_of_class)]
361 |         norm_entro = norm_entropy(logits)
362 |         test_idx_set = set(test_idx.tolist())
363 |         sorted_idx = torch.argsort(norm_entro).tolist()
364 |         for i in sorted_idx:
365 |             if i not in test_idx_set: continue
366 |             lbl = data.y[i].item()
367 |             if count[lbl] < per_class:
368 |                 selected_test_idx.append(i)
369 |                 count[lbl] += 1
370 |             if min(count) == per_class: break
371 |         selected_test_idx = torch.LongTensor(selected_test_idx)
372 |     return selected_test_idx
373 | 
374 | 
375 | 
376 | def top1_label_getter(pred_texts, label_names):
377 |     preds = []
378 |     label_names = [l.lower() for l in label_names]
379 |     for i, t in enumerate(pred_texts):
380 |         match = False
381 |         clean_t = t.replace('.', ' ')
382 |         clean_t = clean_t.lower()
383 |         try:
384 |             start = clean_t.find('[')
385 |             end = clean_t.find(']', start) + 1  # +1 to include the closing bracket
386 |             list_str = clean_t[start:end]
387 |             result = ast.literal_eval(list_str)
388 |             res = result[0]
389 |             if res in label_names:
390 |                 this = label_names.index(res)
391 |                 preds.append(this)
392 |                 match = True
393 |             else:
394 |                 edits = np.array([editdistance.eval(res, l) for l in label_names])
395 |                 this = np.argmin(edits)
396 |                 preds.append(this)
397 |                 match = True
398 |         except Exception:
399 |             for j, l in enumerate(label_names):
400 |                 if l.lower() in clean_t:
401 |                     preds.append(j)
402 |                     match = True
403 |                     break
404 |             if not match:
405 |                 edits = np.array([editdistance.eval(clean_t, l) for l in label_names])
406 |                 this = np.argmin(edits)
407 |                 preds.append(this)
408 | 
409 |     preds = torch.LongTensor(preds)
410 |     return preds
411 | 
412 | 
413 | 
414 | 
415 | def annotator(pred_texts, label_names):
416 |     label_names = [l.lower() for l in label_names]
417 |     anno = []
418 |     conf = []
419 |     for i, t in enumerate(pred_texts):
420 |         match = False
421 |         # clean_t = t.replace('.', ' ')
422 |         clean_t = t.lower()
423 |         try:
424 |             start = clean_t.find('{')
425 |             end = clean_t.find('}', start) + 1  # +1 to include the closing bracket
426 |             list_str = clean_t[start:end]
427 |             # import ipdb; ipdb.set_trace()
428 |             result = ast.literal_eval(list_str)
429 |             # import ipdb; ipdb.set_trace()
430 |             label = ast.literal_eval(result['category'])
431 |             confidence = result['confidence level']
432 |             l = label_names.index(label[0])
433 |             anno.append(l)
434 |             conf.append(confidence)
435 |             # import ipdb; ipdb.set_trace()
436 |         except Exception:
437 |             anno.append(-1)
438 |             conf.append(0)
439 | 
440 |     anno = torch.LongTensor(anno)
441 |     return anno, conf
442 | 
--------------------------------------------------------------------------------
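For reference, a usage sketch of the two answer-parsing helpers in `train_utils.py` (the example LLM replies are illustrative of the formats the parsers expect):

```
labels = ['theory', 'reinforcement learning', 'genetic algorithms']

# top1_label_getter expects a bracketed ranked list somewhere in the reply
preds = top1_label_getter(["Answer: ['Theory', 'Genetic Algorithms']"], labels)
# -> tensor([0])

# annotator expects a dict literal with a stringified category list and a confidence level
anno, conf = annotator(
    ["{'category': \"['theory']\", 'confidence level': 0.9}"], labels)
# -> (tensor([0]), [0.9])
```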