├── ecrim
│   ├── train.sh
│   ├── matrix_transformer.py
│   ├── graph_encoder.py
│   ├── pyg_graph.py
│   ├── sbert_wk.py
│   ├── topological_sort.py
│   ├── buffer.py
│   ├── data_helper.py
│   ├── sentence_reordering.py
│   ├── trainer.py
│   └── main_simp.py
├── README.md
├── data
│   ├── retrieve_data.py
│   ├── q_pos.py
│   ├── retrieval.py
│   ├── rawdata
│   │   └── relations.json
│   ├── redis_doc.py
│   └── load_data_doc.py
├── LICENSE
└── .gitignore

/ecrim/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | CUDA_VISIBLE_DEVICES=0 python main.py --train --dev --test --per_gpu_train_batch_size 1 --per_gpu_eval_batch_size 1 --learning_rate 3e-5 --epochs 10
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ECRIM
2 | 
3 | For data preparation and processing steps, please refer to https://github.com/thunlp/CodRED
4 | 
5 | Execute *train.sh* to run the code.
6 | 
7 | If you want to train a usable model in a limited time, please run the simplified version, *main_simp.py*.
8 | 
9 | We also provide a simple method to enhance the model; if you are interested, please run *main_enhance.py*.
10 | 
11 | [The collating work was done in a hurry, and there may be some errors and omissions. If you have any questions, please contact us by email.]
--------------------------------------------------------------------------------
/data/retrieve_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 | 
4 | 
5 | def main():
6 |     overlap = json.load(open('result-count.json'))  # doc-pair candidates per entity pair, produced by retrieval.py
7 |     dev_dataset = json.load(open('rawdata/dev_dataset.json'))
8 |     dev_ep2r = defaultdict(list)
9 |     for ep, _, _, r in dev_dataset:
10 |         dev_ep2r[ep].append(r)
11 |     dev_data = list()
12 |     for ep, r in dev_ep2r.items():
13 |         if 'n/a' in r and len(r) > 1:
14 |             r.remove('n/a')
15 |         for i, dps in enumerate(overlap[ep]):
16 |             dev_data.append([ep, dps[0], dps[1], r[i % len(r)]])
17 |     json.dump(dev_data, open('open_dev_data.json', 'w'))
18 | 
19 | 
20 | if __name__ == '__main__':
21 |     main()
22 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 MakiseKuurisu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/data/q_pos.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | import redis
4 | from tqdm import tqdm
5 | 
6 | 
7 | def main():
8 |     q2t = dict()
9 |     t2q = dict()
10 |     redisd = redis.Redis(host='localhost', port=6379, decode_responses=True)
11 |     titles = set()
12 |     popular_docs = json.load(open('popular_docs.json'))
13 |     for _, title in popular_docs:
14 |         titles.add(title)
15 |     all_docs = json.load(open('all_docs.json'))
16 |     for _, title in all_docs:
17 |         if title not in titles:
18 |             titles.add(title)
19 |     for title in tqdm(titles):
20 |         doc = json.loads(redisd.get(f'codred-doc-{title}'))
21 |         for entity in doc['entities']:
22 |             if 'Q' in entity:
23 |                 name = 'Q' + str(entity['Q'])
24 |                 if name not in q2t:
25 |                     q2t[name] = dict()
26 |                 q2t[name][title] = len(entity['spans'])  # weight a doc by the entity's mention count
27 |                 if title not in t2q:
28 |                     t2q[title] = dict()
29 |                 t2q[title][name] = len(entity['spans'])
30 |     json.dump(q2t, open('q2t.json', 'w'))
31 |     json.dump(t2q, open('t2q.json', 'w'))
32 | 
33 | 
34 | if __name__ == '__main__':
35 |     main()
36 | 
--------------------------------------------------------------------------------
/data/retrieval.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import random
4 | import sys
5 | from itertools import product
6 | 
7 | import redis
8 | from tqdm import tqdm
9 | 
10 | 
11 | def count_candidates(q2t, t2q, h, t):
12 |     d1 = q2t[h]
13 |     d2 = q2t[t]
14 |     candidates = list()
15 |     for _1, _2 in product(d1, d2):
16 |         e1s = t2q[_1]
17 |         e2s = t2q[_2]
18 |         if h in e1s and t in e2s:
19 |             candidates.append([_1, _2, e1s[h] * e2s[t]])  # score a doc pair by the product of head/tail mention counts
20 |     candidates.sort(key=lambda x: x[2], reverse=True)
21 |     return candidates
22 | 
23 | 
24 | def place_data(dataset):
25 |     epr2d = dict()
26 |     key2docs = dict()
27 |     for key, doc1, doc2, label in dataset:
28 |         if key not in epr2d:
29 |             epr2d[key] = set()
30 |             key2docs[key] = set()
31 |         epr2d[key].add(label)
32 |         if label != 'n/a':
33 |             key2docs[key].add((doc1, doc2))
34 |     bags = list()
35 |     for key, labels in epr2d.items():
36 |         rs = list(labels)
37 |         if 'n/a' in rs and len(rs) > 1:
38 |             rs.remove('n/a')
39 |         bags.append([key, rs, key2docs[key]])
40 |     return bags
41 | 
42 | 
43 | def main():
44 |     dev_dataset = json.load(open('rawdata/dev_dataset.json'))
45 |     q2t = json.load(open('q2t.json'))
46 |     t2q = json.load(open('t2q.json'))
47 |     dev_bags = place_data(dev_dataset)
48 |     ret = dict()
49 |     dev_ranks = list()
50 |     for key, rs, docs in tqdm(dev_bags):
51 |         ground_doc_pairs = set([tuple(c) for c in docs])
52 |         h, t = key.split('#')
53 |         docpairs = count_candidates(q2t, t2q, h, t)
54 |         if len(ground_doc_pairs) > 0:
55 |             rank = list()
56 |             for i, c in enumerate(docpairs):
57 |                 if (c[0], c[1]) in ground_doc_pairs:
58 |                     rank.append(i + 1)
59 |             while len(rank) < len(ground_doc_pairs):
60 |                 rank.append(1000000)
61 |             dev_ranks.append(rank)
62 |         docpairs = [[d[0], d[1]] for d in docpairs[0:16]]
63 |         ret[key] = docpairs
64 |     json.dump(ret, open(f'result-count.json', 'w'))
65 |     json.dump(dev_ranks, open(f'dev-rank-count.json', 'w'))
66 | 
67 | 
68 | if __name__ == '__main__':
69 |     main()
70 | 
--------------------------------------------------------------------------------
/data/rawdata/relations.json:
--------------------------------------------------------------------------------
1 | ["P6216", "P1389", "P1923", "P123", "P1056", "P414", "P3373", "P2500", "P1050", "P206", "P5826", "P205",
"P2541", "P4552", "P161", "P2852", "P398", "P1444", "P2860", "P366", "P122", "P126", "P102", "P1027", "P840", "P7047", "P1346", "P175", "P17", "P832", "P520", "P156", "P361", "P69", "P1066", "P2499", "P2389", "P880", "P1387", "P3342", "P57", "P1532", "P463", "P1366", "P2853", "P404", "P360", "P408", "P2743", "P5869", "P944", "P1454", "P1344", "P178", "P4000", "P155", "P4044", "P1990", "P915", "P527", "P1414", "P551", "P1142", "P144", "P3018", "P769", "P1336", "P924", "P136", "P937", "P1435", "P658", "P725", "P286", "P241", "P2546", "P108", "P1434", "P3320", "P511", "P460", "P3033", "P1716", "P1072", "P19", "P707", "P1433", "P400", "P1411", "P25", "P793", "P375", "P521", "P5658", "P22", "P974", "P39", "P193", "P1408", "P739", "P3966", "P1876", "P437", "P546", "P50", "P20", "P169", "P287", "P40", "P3137", "P425", "P170", "P1535", "P36", "P113", "P4743", "P282", "P516", "P112", "P115", "P264", "P3095", "P4791", "P751", "P1891", "P559", "P610", "P26", "P674", "P2094", "P1192", "P411", "P137", "P509", "P512", "P355", "P2522", "P3091", "P4387", "P749", "P2564", "P129", "P664", "P6942", "P611", "P176", "P750", "P1582", "P197", "P740", "P127", "P38", "P1445", "P180", "P1399", "P131", "P2935", "P703", "P2348", "P2321", "P150", "P58", "P5995", "P4614", "P1327", "P149", "P6885", "P462", "P1622", "P1308", "P488", "P2341", "P461", "P1995", "P119", "P118", "P138", "P2868", "P2175", "P495", "P859", "P2670", "P1303", "P1427", "P451", "P3494", "P607", "P199", "P767", "P485", "P598", "P2079", "P279", "P1038", "P852", "P1158", "P449", "P1830", "P5025", "P921", "P3448", "P1343", "P523", "P747", "P276", "P1080", "P1431", "P54", "P553", "P85", "P2925", "P179", "P86", "P403", "P59", "P2176", "P1071", "P135", "P177", "P2408", "P1079", "P483", "P162", "P800", "P629", "P1001", "P4446", "P2579", "P504", "P479", "P737", "P706", "P2416", "P7153", "P121", "P1416", "P2962", "P190", "P101", "P4132", "P1073", "P790", "P291", "P452", "P6275", "P2789", "P608", "P184", "P30", "P5096", "P2360", "P489", "P1877", "P159", "P676", "P1441", "P931", "P114", "P397", "P61", "P306", "P2283", "P6379", "P1365", "P140", "P53", "P272", "P941", "P1037", "P710", "P3075", "P16", "P1557", "P171", "P421", "P945"] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /data/redis_doc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from multiprocessing import Pool 4 | 5 | import redis 6 | from tqdm import tqdm 7 | from transformers import AutoTokenizer 8 | 9 | 10 | def process(line): 11 | line = line.strip() 12 | if len(line) == 0: 13 | return None 14 | article = json.loads(line) 15 | tokens = list() 16 | mapping = dict() 17 | doc_id = int(article['id']) 18 | for para_id, para in enumerate(article['tokens']): 19 | for sent_id, sentence in enumerate(para): 20 | for word_id, word in enumerate(sentence): 21 | subwords = tokenizer.tokenize(word) 22 | mapping[(para_id, sent_id, word_id)] = list(range(len(tokens), len(tokens) + len(subwords))) 23 | tokens.extend(subwords) 24 | qs = list() 25 | for entity in article['vertexSet']: 26 | assert len(entity) > 0 27 | spans = list() 28 | for mention in entity: 29 | subwords = list() 30 | for position in range(mention['pos'][2], mention['pos'][3]): 31 | k = (mention['pos'][0], mention['pos'][1], position) 32 | if k in mapping: 33 | subwords.extend(mapping[k]) 34 | if len(subwords) > 0: 35 | span = [min(subwords), max(subwords) + 1] 36 | spans.append(span) 37 | if len(spans) > 0: 38 | k = dict() 39 | for key in entity[0]: 40 | if key != 'pos': 41 | k[key] = entity[0][key] 42 | k['spans'] = spans 43 | qs.append(k) 44 | obj = dict() 45 | obj['tokens'] = tokens 46 | obj['entities'] = qs 47 | obj['id'] = article['id'] 48 | obj['title'] = article['title'] 49 | redisd.set(f'codred-doc-{obj["title"]}', json.dumps(obj)) 50 | return doc_id, article['title'] 51 | 52 | 53 | def 
initializer(base_model): 54 | global redisd 55 | global tokenizer 56 | redisd = redis.Redis(host='localhost', port=6379, decode_responses=True) 57 | tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True) 58 | 59 | 60 | def main(base_model): 61 | redisd = redis.Redis(host='localhost', port=6379, decode_responses=True) 62 | all_ids = list() 63 | with open('rawdata/wiki_ent_link.jsonl') as f: 64 | with Pool(48, initializer=initializer, initargs=(base_model,)) as p: 65 | for doc_id, title in tqdm(p.imap_unordered(process, f)): 66 | all_ids.append([doc_id, title]) 67 | json.dump(all_ids, open('all_docs.json', 'w')) 68 | popular_ids = list() 69 | with open('rawdata/popular_page_ent_link.jsonl') as f: 70 | with Pool(48, initializer=initializer, initargs=(base_model,)) as p: 71 | for doc_id, title in tqdm(p.imap_unordered(process, f)): 72 | popular_ids.append([doc_id, title]) 73 | json.dump(popular_ids, open('popular_docs.json', 'w')) 74 | 75 | 76 | if __name__ == '__main__': 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('--base_model', type=str, default='bert-base-cased') 79 | args = parser.parse_args() 80 | main(args.base_model) 81 | -------------------------------------------------------------------------------- /data/load_data_doc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import random 4 | from functools import partial 5 | from multiprocessing import Pool 6 | 7 | import redis 8 | from tqdm import tqdm 9 | from transformers import AutoConfig, AutoTokenizer 10 | 11 | #def process(line, redisd, relations, dev_docs): 12 | def process(line): 13 | #tokenizer = AutoTokenizer.from_pretrained('bert-base-cased', use_fast=True) 14 | article = json.loads(line) 15 | tokens = list() 16 | mapping = dict() 17 | doc_id = int(article['id']) 18 | for para_id, para in enumerate(article['tokens']): 19 | for sent_id, sentence in enumerate(para): 20 | for word_id, word in enumerate(sentence): 21 | subwords = tokenizer.tokenize(word) 22 | mapping[(para_id, sent_id, word_id)] = list(range(len(tokens), len(tokens) + len(subwords))) 23 | tokens.extend(subwords) 24 | qs = list() 25 | for entity in article['vertexSet']: 26 | spans = list() 27 | for mention in entity: 28 | if 'Q' in mention: 29 | subwords = list() 30 | for position in range(mention['pos'][2], mention['pos'][3]): 31 | subwords.extend(mapping[(mention['pos'][0], mention['pos'][1], position)]) 32 | span = [min(subwords), max(subwords) + 1] 33 | spans.append(span) 34 | if len(spans) == len(entity): 35 | qs.append({ 36 | 'Q': entity[0]['Q'], 37 | 'spans': spans 38 | }) 39 | else: 40 | qs.append(None) 41 | instances = list() 42 | kset = set() 43 | for edge in article['edgeSet']: 44 | h = edge['h'] 45 | t = edge['t'] 46 | kset.add((h, t)) 47 | if qs[h] is None or qs[t] is None: 48 | continue 49 | for r in edge['rs']: 50 | if 'P' + str(r) in relations: 51 | span_h = qs[h]['spans'][0] 52 | span_t = qs[t]['spans'][0] 53 | instances.append([doc_id, span_h[0], span_h[1], span_t[0], span_t[1], 'P' + str(r)]) 54 | no_relations = list() 55 | for i in range(len(qs)): 56 | if qs[i] is None: 57 | continue 58 | for j in range(len(qs)): 59 | if qs[j] is None: 60 | continue 61 | if i != j and (i, j) not in kset: 62 | no_relations.append((i, j)) 63 | if len(no_relations) > len(instances): 64 | no_relations = random.choices(no_relations, k=len(instances)) 65 | for i, j in no_relations: 66 | instances.append([doc_id, qs[i]['spans'][0][0], qs[i]['spans'][0][1], 
qs[j]['spans'][0][0], qs[j]['spans'][0][1], 'n/a']) 67 | redisd.set(f'dsre-doc-{doc_id}', json.dumps(tokens)) 68 | return instances, article['title'] in dev_docs 69 | 70 | 71 | def initializer(base_model, _relations, t_docs): 72 | global redisd 73 | global tokenizer 74 | global relations 75 | global dev_docs 76 | redisd = redis.Redis(host='localhost', port=6379, decode_responses=True) 77 | tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True) 78 | relations = set(_relations) 79 | dev_docs = t_docs 80 | 81 | 82 | def main(base_model): 83 | redisd = redis.Redis(host='localhost', port=6379, decode_responses=True) 84 | dev_dataset = json.load(open('rawdata/dev_dataset.json')) 85 | dev_docs = set(map(lambda x: x[1], dev_dataset)) | set(map(lambda x: x[2], dev_dataset)) 86 | relations = json.load(open('rawdata/relations.json')) 87 | lines = list() 88 | with open('rawdata/distant_documents.jsonl') as f: 89 | for line in tqdm(f): 90 | lines.append(line.strip()) 91 | train_examples = list() 92 | dev_examples = list() 93 | 94 | 95 | with Pool(16, initializer=initializer, initargs=(base_model, relations, dev_docs)) as p: 96 | for instances, is_dev in tqdm(p.imap(process, lines)): 97 | if is_dev: 98 | dev_examples.extend(instances) 99 | else: 100 | train_examples.extend(instances) 101 | """ 102 | for line in tqdm(lines): 103 | instances, is_dev = process(line, redisd, relations, dev_docs) 104 | if is_dev: 105 | dev_examples.extend(instances) 106 | else: 107 | train_examples.extend(instances) 108 | """ 109 | 110 | json.dump(train_examples, open('dsre_train_examples.json', 'w')) 111 | json.dump(dev_examples, open('dsre_dev_examples.json', 'w')) 112 | 113 | 114 | if __name__ == '__main__': 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument('--base_model', type=str, default='bert-base-cased') 117 | args = parser.parse_args() 118 | __spec__ = "ModuleSpec(name='builtins', loader=)" 119 | main(args.base_model) 120 | -------------------------------------------------------------------------------- /ecrim/matrix_transformer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch as tc 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import math 7 | import pdb 8 | 9 | 10 | class Attention(nn.Module): 11 | 12 | def __init__(self , h , d_model): 13 | super().__init__() 14 | 15 | assert d_model % h == 0 16 | self.reduced_dim = 128 17 | self.d_model = d_model 18 | self.h = h 19 | self.dk = d_model // h 20 | self.nemax = 20 21 | 22 | self.WQ = nn.Linear(self.dk , self.dk) 23 | self.WK = nn.Linear(self.dk , self.dk) 24 | self.WV = nn.Linear(self.dk , self.dk) 25 | 26 | self.relative_pos_emb = nn.Parameter( tc.zeros(2 * self.nemax, 2 * self.nemax, d_model) ) 27 | #self.reset_params() 28 | 29 | def reset_params(self): 30 | nn.init.xavier_normal_(self.WQ.weight.data) 31 | nn.init.xavier_normal_(self.WK.weight.data) 32 | nn.init.xavier_normal_(self.WV.weight.data) 33 | 34 | nn.init.constant_(self.WQ.bias.data , 0) 35 | nn.init.constant_(self.WK.bias.data , 0) 36 | nn.init.constant_(self.WV.bias.data , 0) 37 | 38 | def forward(self , R , R_mas): 39 | ''' 40 | R: (bs , ne , ne , d) 41 | R_mas: (bs , ne , ne , 1) 42 | ''' 43 | 44 | h , dk = self.h , self.dk 45 | bs , ne , ne , d = R.size() 46 | assert d == self.d_model or d == self.reduced_dim 47 | 48 | R = R.view(bs,ne,ne,h,dk).permute(0,3,1,2,4).contiguous() #(bs , h , ne , ne , dk) 49 | R_mas = R_mas.view(bs,1,ne,ne,1) 50 | 
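        # Matrix attention: each of the ne*ne relation cells is treated as one
        # token. The (ne, ne) grid is flattened to a length-ne*ne sequence per
        # head, scaled dot-product attention runs over that sequence, and the
        # pairwise mask (outer product of the flattened cell masks) pushes
        # padded cells to a large negative value before the softmax.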
51 | Q , K , V = self.WQ(R) , self.WK(R) , self.WV(R) 52 | 53 | Q = Q.view(bs,h,ne*ne,dk) 54 | K = K.view(bs,h,ne*ne,dk) 55 | V = V.view(bs,h,ne*ne,dk) 56 | mas = R_mas.view(bs,1,ne*ne,1) 57 | att_mas = mas.view(bs,1,ne*ne,1) * mas.view(bs,1,1,ne*ne) # (bs,1,ne*ne,ne*ne) 58 | 59 | alpha = tc.matmul(Q , K.transpose(-1,-2)) 60 | alpha = alpha - (1-att_mas)*100000 61 | alpha = tc.softmax(alpha / (dk ** 0.5) , dim = -1) 62 | 63 | R_Z = tc.matmul(alpha , V).view(bs,h,ne,ne,dk) 64 | 65 | R_Z = (R_Z * R_mas).permute(0,2,3,1,4).contiguous().view(bs,ne,ne,h*dk) 66 | 67 | return R_Z 68 | 69 | class FFN(nn.Module): 70 | def __init__(self , d_model , hidden_size = 1024): 71 | super().__init__() 72 | 73 | self.ln1 = nn.Linear(d_model , hidden_size) 74 | self.ln2 = nn.Linear(hidden_size , d_model) 75 | 76 | #self.reset_params() 77 | 78 | def reset_params(self): 79 | nn.init.xavier_normal_(self.ln1.weight.data) 80 | nn.init.xavier_normal_(self.ln2.weight.data) 81 | 82 | nn.init.constant_(self.ln1.bias.data , 0) 83 | nn.init.constant_(self.ln2.bias.data , 0) 84 | 85 | def forward(self , x , x_mas): 86 | x = F.relu(self.ln1(x)) 87 | x = self.ln2(x) 88 | 89 | return x * x_mas 90 | 91 | class Encoder_Layer(nn.Module): 92 | def __init__(self , h , d_model , hidden_size , dropout = 0.0): 93 | super().__init__() 94 | 95 | assert d_model % h == 0 96 | 97 | self.d_model = d_model 98 | self.hidden_size = hidden_size 99 | 100 | self.att = Attention(h , d_model) 101 | self.lnorm_1 = nn.LayerNorm(d_model) 102 | self.drop_1 = nn.Dropout(dropout) 103 | 104 | self.ffn = FFN(d_model , hidden_size) 105 | self.lnorm_2 = nn.LayerNorm(d_model) 106 | self.drop_2 = nn.Dropout(dropout) 107 | 108 | 109 | def forward(self , R , R_mas): 110 | ''' 111 | R: (bs , ne , ne , d) 112 | R_mas: (bs , ne , ne , 1) 113 | ''' 114 | 115 | #-----attention----- 116 | 117 | R_Z = self.att(R , R_mas) 118 | R = self.lnorm_1(self.drop_1(R_Z) + R) 119 | 120 | 121 | #-----FFN----- 122 | R_Z = self.ffn(R , R_mas) 123 | R = self.lnorm_2(self.drop_2(R_Z) + R) 124 | 125 | return R 126 | 127 | class Encoder(nn.Module): 128 | def __init__(self , h = 4 , d_model = 768 , hidden_size = 1024 , num_layers = 6 , dropout = 0.0, device = 0): 129 | super().__init__() 130 | 131 | self.nemax = 1000 132 | self.d_model = d_model 133 | self.hidden_size = hidden_size 134 | self.num_layers = num_layers 135 | 136 | self.layers = nn.ModuleList([ 137 | Encoder_Layer(h , d_model , hidden_size , dropout = dropout) 138 | for _ in range(num_layers) 139 | ]) 140 | 141 | self.row_emb = nn.Parameter( tc.zeros(self.nemax , device = device) ) 142 | self.col_emb = nn.Parameter( tc.zeros(self.nemax , device = device) ) 143 | self.reset_params() 144 | 145 | def reset_params(self): 146 | nn.init.normal_(self.row_emb.data , 0 , 1e-4) 147 | nn.init.normal_(self.col_emb.data , 0 , 1e-4) 148 | 149 | def forward(self , R , R_mas): 150 | ''' 151 | R: (bs , ne , ne , d) 152 | sent_enc: (bs , n , d) 153 | 154 | ''' 155 | #pdb.set_trace() 156 | bs , ne , ne , d = R.size() 157 | assert d == self.d_model 158 | 159 | R_mas = R_mas.view(bs,ne,ne,1).float() 160 | R = R + self.row_emb[:ne].view(1,ne,1,1) + self.col_emb[:ne].view(1,1,ne,1) 161 | for layer in self.layers: 162 | R = layer(R , R_mas) 163 | 164 | return R 165 | 166 | 167 | 168 | -------------------------------------------------------------------------------- /ecrim/graph_encoder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch as tc 3 | from torch import nn as nn 4 | import 
torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import math 7 | import pdb 8 | 9 | 10 | class Attention(nn.Module): 11 | def __init__(self , h , d_model): 12 | super().__init__() 13 | 14 | assert d_model % h == 0 15 | 16 | self.d_model = d_model 17 | self.h = h 18 | self.dk = d_model // h 19 | 20 | self.WQ = nn.Linear(self.dk , self.dk) 21 | self.WK = nn.Linear(self.dk , self.dk) 22 | self.WV = nn.Linear(self.dk , self.dk) 23 | 24 | #self.reset_params() 25 | 26 | def reset_params(self): 27 | nn.init.xavier_normal_(self.WQ.weight.data) 28 | nn.init.xavier_normal_(self.WK.weight.data) 29 | nn.init.xavier_normal_(self.WV.weight.data) 30 | 31 | nn.init.constant_(self.WQ.bias.data , 0) 32 | nn.init.constant_(self.WK.bias.data , 0) 33 | nn.init.constant_(self.WV.bias.data , 0) 34 | 35 | def forward(self , R , E , R_mas , E_mas): 36 | ''' 37 | R: (bs , ne , ne , d) 38 | E: (bs , ne , d) 39 | 40 | R_mas: (bs , ne , ne , 1) 41 | E_mas: (bs , ne , 1) 42 | ''' 43 | 44 | h , dk = self.h , self.dk 45 | bs , ne , d = E.size() 46 | assert d == self.d_model 47 | 48 | R = R.view(bs,ne,ne,h,dk).permute(0,3,1,2,4).contiguous() #(bs , h , ne , ne , dk) 49 | E = E.view(bs,ne,h,dk).permute(0,2,1,3).contiguous() #(bs , h , ne , dk) 50 | R_mas = R_mas.view(bs,1,ne,ne,1) 51 | E_mas = E_mas.view(bs,1,ne,1) 52 | 53 | R_Q , R_K , R_V = self.WQ(R) , self.WK(R) , self.WV(R) 54 | E_Q , E_K , E_V = self.WQ(E) , self.WK(E) , self.WV(E) 55 | 56 | #from R to E 57 | alpha = (E_Q.view(bs,h,ne,1,dk) * R_K).sum(-1) # (bs , h , ne , ne) 58 | #alpha_mask = (E_mas.view(bs,ne,1) * R_mas.view(bs,ne,ne)).bool() 59 | alpha_mask = (E_mas.view(bs,1,1,ne).expand(bs,h,ne,ne)).bool() 60 | alpha = alpha.masked_fill(~alpha_mask , float("-inf")) 61 | alpha = tc.softmax(alpha, dim = -1) 62 | E_Z = (alpha.view(bs,h,ne,ne,1) * R_V).sum(dim = 2) 63 | 64 | #from E to R 65 | beta_0 = (R_Q * E_K.view(bs,h,ne,1,dk)).sum(-1 , keepdim = True) 66 | beta_1 = (R_Q * E_K.view(bs,h,1,ne,dk)).sum(-1 , keepdim = True) 67 | 68 | #beta_0 = beta_0.masked_fill(~beta_mask , float("-inf")) 69 | #beta_1 = beta_1.masked_fill(~beta_mask , float("-inf")) 70 | 71 | betas = tc.cat([beta_0 , beta_1] , dim = -1) 72 | betas = tc.softmax(betas, dim = -1) 73 | beta_0 , beta_1 = betas[:,:,:,:,0] , betas[:,:,:,:,1] 74 | 75 | R_Z = E_V.view(bs,h,ne,1,dk) * beta_0.view(bs,h,ne,ne,1) + E_V.view(bs,h,1,ne,dk) * beta_1.view(bs,h,ne,ne,1) 76 | #R_Z = E_V.view(bs,h,ne,1,dk) * 0.5 + E_V.view(bs,h,1,ne,dk) * 0.5 77 | 78 | R_Z = R_Z.masked_fill(~R_mas.expand(R_Z.size()).bool() , 0) 79 | E_Z = E_Z.masked_fill(~E_mas.expand(E_Z.size()).bool() , 0) 80 | 81 | R_Z = R_Z.view(bs,h,ne,ne,dk).permute(0,2,3,1,4).contiguous().view(bs,ne,ne,h*dk) 82 | E_Z = E_Z.view(bs,h,ne,dk).permute(0,2,1,3).contiguous().view(bs,ne,h*dk) 83 | 84 | return R_Z , E_Z 85 | 86 | class FFN(nn.Module): 87 | def __init__(self , d_model , hidden_size = 1024): 88 | super().__init__() 89 | 90 | self.ln1 = nn.Linear(d_model , hidden_size) 91 | self.ln2 = nn.Linear(hidden_size , d_model) 92 | 93 | #self.reset_params() 94 | 95 | def reset_params(self): 96 | nn.init.xavier_normal_(self.ln1.weight.data) 97 | nn.init.xavier_normal_(self.ln2.weight.data) 98 | 99 | nn.init.constant_(self.ln1.bias.data , 0) 100 | nn.init.constant_(self.ln2.bias.data , 0) 101 | 102 | def forward(self , x , x_mas): 103 | x = F.relu(self.ln1(x)) 104 | x = self.ln2(x) 105 | 106 | return x * x_mas 107 | 108 | class Encoder_Layer(nn.Module): 109 | def __init__(self , h , d_model , hidden_size , dropout = 0.0): 110 | 
super().__init__() 111 | 112 | assert d_model % h == 0 113 | 114 | self.d_model = d_model 115 | self.hidden_size = hidden_size 116 | 117 | self.att = Attention(h , d_model) 118 | self.lnorm_r_1 = nn.LayerNorm(d_model) 119 | self.lnorm_e_1 = nn.LayerNorm(d_model) 120 | self.drop_1 = nn.Dropout(dropout) 121 | 122 | self.ffn = FFN(d_model , hidden_size) 123 | self.lnorm_r_2 = nn.LayerNorm(d_model) 124 | self.lnorm_e_2 = nn.LayerNorm(d_model) 125 | self.drop_2 = nn.Dropout(dropout) 126 | 127 | 128 | def forward(self , R , E , R_mas , E_mas , sent_enc = None , sent_mas = None): 129 | ''' 130 | R: (bs , ne , ne , d) 131 | E: (bs , ne , d) 132 | sent_enc: (bs , n , d) 133 | 134 | R_mas: (bs , ne , ne , 1) 135 | E_mas: (bs , ne , 1) 136 | sent_mas: (bs , ne , 1) 137 | ''' 138 | 139 | bs , ne , d = E.size() 140 | 141 | #-----attention----- 142 | 143 | R_Z , E_Z = self.att(R , E , R_mas , E_mas) 144 | R = self.lnorm_r_1(self.drop_1(R_Z) + R) 145 | E = self.lnorm_e_1(self.drop_1(E_Z) + E) 146 | 147 | 148 | #-----FFN----- 149 | R_Z , E_Z = self.ffn(R , R_mas) , self.ffn(E , E_mas) 150 | R = self.lnorm_r_2(self.drop_2(R_Z) + R) 151 | E = self.lnorm_e_2(self.drop_2(E_Z) + E) 152 | 153 | 154 | #-----extern-attention----- 155 | #alp = tc.matmul(R , sent_enc.transpose(-1,-2)) 156 | 157 | return R , E 158 | 159 | class Encoder(nn.Module): 160 | def __init__(self , h = 8 , d_model = 768 , hidden_size = 2048 , num_layers = 6 , dropout = 0.0): 161 | super().__init__() 162 | 163 | self.d_model = d_model 164 | self.hidden_size = hidden_size 165 | self.num_layers = num_layers 166 | 167 | self.layers = nn.ModuleList([ 168 | Encoder_Layer(h , d_model , hidden_size , dropout = dropout) 169 | for _ in range(num_layers) 170 | ]) 171 | 172 | def forward(self , R , E , R_mas , E_mas , sent_enc = None , sent_mas = None): 173 | ''' 174 | R: (bs , ne , ne , d) 175 | E: (bs , ne , d) 176 | sent_enc: (bs , n , d) 177 | 178 | ''' 179 | 180 | bs , ne , d = E.size() 181 | assert d == self.d_model 182 | 183 | R_mas = R_mas.view(bs,ne,ne,1).float() 184 | E_mas = E_mas.view(bs,ne,1).float() 185 | R , E = R*R_mas , E*E_mas 186 | if sent_mas is not None: 187 | _ , n , _ = sent_enc.size() 188 | sent_mas = sent_mas.view(bs,n,1).float() 189 | sent_enc = sent_enc * sent_mas 190 | 191 | 192 | for layer in self.layers: 193 | R , E = layer(R , E , R_mas , E_mas , sent_enc , sent_mas) 194 | 195 | return R , E 196 | 197 | 198 | 199 | -------------------------------------------------------------------------------- /ecrim/pyg_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.nn import Linear 6 | from torch_geometric.data import Data 7 | from torch_geometric.nn import GCNConv 8 | from torch_geometric.nn import global_mean_pool 9 | import pdb 10 | 11 | class Attention(nn.Module): 12 | def __init__(self , h , d_model): 13 | super().__init__() 14 | 15 | assert d_model % h == 0 16 | 17 | self.d_model = d_model 18 | self.h = h 19 | self.dk = d_model // h 20 | 21 | self.WQ = nn.Linear(self.dk , self.dk) 22 | self.WK = nn.Linear(self.dk , self.dk) 23 | self.WV = nn.Linear(self.dk , self.dk) 24 | 25 | self.reset_params() 26 | 27 | def reset_params(self): 28 | nn.init.xavier_normal_(self.WQ.weight.data) 29 | nn.init.xavier_normal_(self.WK.weight.data) 30 | nn.init.xavier_normal_(self.WV.weight.data) 31 | 32 | nn.init.constant_(self.WQ.bias.data , 0) 33 | nn.init.constant_(self.WK.bias.data , 0) 34 | 
nn.init.constant_(self.WV.bias.data , 0) 35 | 36 | def forward(self , S , R , S_mas , R_mas): 37 | ''' 38 | S: (dps, seq_len, d) 39 | R: (dps, 1, d) 40 | 41 | S_mas: (dps, seq_len, 1) 42 | R_mas: (dps, 1, 1) 43 | ''' 44 | 45 | h , dk = self.h , self.dk 46 | dps, seq_len, d = S.size() 47 | assert d == self.d_model or d == self.reduced_dim 48 | 49 | #pdb.set_trace() 50 | S = S.view(dps,seq_len,h,dk).permute(0,2,1,3).contiguous() #(dps, h, seq_len, dk) 51 | S_mas = S_mas.view(dps, 1, seq_len, 1) 52 | R = R.view(dps,1,h,dk).permute(0,2,1,3).contiguous() #(dps , h , 1 , dk) 53 | R_mas = R_mas.view(dps,1,1,1) 54 | 55 | S_Q , S_K , S_V = self.WQ(S) , self.WK(S) , self.WV(S) 56 | R_Q , R_K , R_V = self.WQ(R) , self.WK(R) , self.WV(R) 57 | 58 | 59 | #from E to R 60 | beta = (S_Q * R_K.view(dps,h,1,dk)).sum(-1 , keepdim = True) 61 | 62 | S_Z = R_V.view(dps,h,1,dk) * beta.view(dps,h,seq_len,1) 63 | 64 | 65 | S_Z = S_Z.masked_fill(~S_mas.expand(S_Z.size()).bool() , 0) 66 | 67 | S_Z = S_Z.view(dps,h,seq_len,dk).permute(0,2,1,3).contiguous().view(dps,seq_len,h*dk) 68 | 69 | return S_Z 70 | 71 | class GCN(torch.nn.Module): 72 | def __init__(self, hidden_channels, dim_node_features, num_classes): 73 | super(GCN, self).__init__() 74 | torch.manual_seed(12345) 75 | self.conv1 = GCNConv(dim_node_features, hidden_channels) 76 | self.conv2 = GCNConv(hidden_channels, hidden_channels) 77 | self.conv3 = GCNConv(hidden_channels, hidden_channels) 78 | self.clf = Linear(hidden_channels, num_classes) 79 | 80 | def forward(self, x, edge_index, batch): 81 | x = self.conv1(x, edge_index) 82 | x = x.relu() 83 | x = self.conv2(x, edge_index) 84 | x = x.relu() 85 | x = self.conv3(x, edge_index) 86 | 87 | return x 88 | def create_edges_sigle(r_node_lists, device): 89 | edge_list = [] 90 | p_node_list = {} 91 | p_num = len(r_node_lists) 92 | r_node_start = p_num 93 | for p_idx, r_nodes in enumerate(r_node_lists): 94 | p_node_list[p_idx] = [i for i in range(r_node_start, r_node_start+len(r_nodes))] 95 | r_node_start += len(r_nodes) 96 | # P-R and R-P 97 | for p, r_list in p_node_list.items(): 98 | source_node = [p] * len(r_list) 99 | destination_node = r_list 100 | pair = [[s, d] for s,d in zip(source_node, destination_node)] 101 | pair_reverse = [[d, s] for s,d in zip(source_node, destination_node)] 102 | edge_list.extend(pair+pair_reverse) 103 | # P-P 104 | p_p_edges = [[i, j] for i in range(p_num) for j in range(p_num) if i!=j] 105 | #print(p_p_edges) 106 | edge_list.extend(p_p_edges) 107 | edges = torch.tensor(np.array(edge_list)).t().to(device) 108 | #print(edges.shape) 109 | #print(edges) 110 | return edges 111 | 112 | 113 | def load_features_single(r_embs, p_embs, dk): 114 | node_features = torch.zeros(len(r_embs)+len(p_embs), dk).to(p_embs.device) 115 | p_num = p_embs.size()[0] 116 | node_features[:p_num] = p_embs 117 | node_features[p_num:] = r_embs 118 | node_features.to(p_embs.device) 119 | return node_features 120 | 121 | 122 | def create_graph_single(r_list, r_embs, p_embs): 123 | device = p_embs.device 124 | edge_index = create_edges_sigle(r_list, device) 125 | node_features = load_features_single(r_embs, p_embs, dk=p_embs.size()[-1]) 126 | data = Data(x=node_features, edge_index=edge_index) 127 | return data 128 | 129 | 130 | def create_edges(r_node_lists, device): 131 | def get_p_b_edges(n, b_start): 132 | edges = [] 133 | for i in range(n): 134 | up = (i+1)*n -1 135 | bottom = i*n 136 | for j in range(bottom, up+1): 137 | if j!= (i*n +i): 138 | edges.append([i, j+b_start]) 139 | edges.append([j+b_start, i]) 
140 |                 else:
141 |                     continue
142 |         return edges
143 | 
144 |     edge_list = []
145 |     p_node_list = {}
146 |     p_num = len(r_node_lists)
147 |     r_node_start = p_num
148 |     for p_idx, r_nodes in enumerate(r_node_lists):
149 |         p_node_list[p_idx] = [i for i in range(r_node_start, r_node_start+len(r_nodes))]
150 |         r_node_start += len(r_nodes)
151 |     # P-R and R-P
152 |     for p, r_list in p_node_list.items():
153 |         source_node = [p] * len(r_list)
154 |         destination_node = r_list
155 |         pair = [[s, d] for s,d in zip(source_node, destination_node)]
156 |         pair_reverse = [[d, s] for s,d in zip(source_node, destination_node)]
157 |         edge_list.extend(pair+pair_reverse)
158 |     # P-P
159 |     # p_p_edges = [[i, j] for i in range(p_num) for j in range(p_num) if i!=j]
160 |     # edge_list.extend(p_p_edges)
161 |     # print(p_p_edges)
162 | 
163 |     # P-B
164 |     b_node_start = r_node_start
165 |     #b_node_num = p_num*p_num
166 |     p_b_edges = get_p_b_edges(p_num, b_node_start)
167 |     edge_list.extend(p_b_edges)
168 | 
169 |     edges = torch.tensor(np.array(edge_list)).t().to(device)
170 | 
171 |     return edges
172 | 
173 | 
174 | def load_features(r_embs, p_embs, b_embs, dk):
175 |     #P-idx 0 ~ |P|-1
176 |     #R-idx |P| ~ |P|+|R|-1
177 |     #B-idx |P|+|R| ~ |P|+|R|+|B|-1
178 |     #pdb.set_trace()
179 |     node_features = torch.zeros(len(r_embs)+len(p_embs)+len(b_embs), dk).to(p_embs.device)
180 |     p_num = p_embs.size()[0]
181 |     r_num = r_embs.size()[0]
182 |     b_num = b_embs.size()[0]
183 |     node_features[:p_num] = p_embs
184 |     node_features[p_num:p_num+r_num] = r_embs
185 |     node_features[p_num+r_num:] = b_embs
186 |     node_features.to(p_embs.device)
187 |     return node_features
188 | 
189 | 
190 | def create_graph(r_list, r_embs, p_embs, b_embs):
191 |     device = p_embs.device
192 |     edge_index = create_edges(r_list, device)
193 |     node_features = load_features(r_embs, p_embs, b_embs, dk=p_embs.size()[-1])
194 |     data = Data(x=node_features, edge_index=edge_index)
195 |     return data
196 | 
197 | if __name__ == "__main__":
198 |     r_nodes = [[1,2,3,4]]
199 |     r_embs = torch.zeros(4, 128)
200 |     p_embs = torch.zeros(1, 128)
201 |     b_embs = torch.zeros(3, 128)
202 |     g = create_graph(r_nodes, r_embs, p_embs, b_embs)
203 |     print(g)
204 |     model = GCN(hidden_channels=64, dim_node_features=128, num_classes=10)
205 |     print(model)
--------------------------------------------------------------------------------
/ecrim/sbert_wk.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, unicode_literals
2 | 
3 | import sys
4 | import io
5 | import numpy as np
6 | import logging
7 | import argparse
8 | import torch
9 | import random
10 | 
11 | from transformers import *
12 | import utils as utils
13 | 
14 | 
15 | # -----------------------------------------------
16 | def set_seed(args):
17 |     random.seed(args.seed)
18 |     np.random.seed(args.seed)
19 |     torch.manual_seed(args.seed)
20 | 
21 | class sbert():
22 |     def __init__(self, device):
23 |         # -----------------------------------------------
24 |         # Set device
25 |         self.device = device
26 |         self.model_type = "bert-base-uncased"
27 |         self.config = AutoConfig.from_pretrained(self.model_type)
28 |         self.config.output_hidden_states = True
29 |         self.tokenizer = AutoTokenizer.from_pretrained(self.model_type)
30 |         self.model = AutoModelWithLMHead.from_pretrained(
31 |             self.model_type, config=self.config)
32 |         self.model.to(self.device)
33 |         self.max_seq_length = 128
34 |         self.embed_method = 'ave_last_hidden'
35 | 
36 |     def pair_sims(self, sentence1, sentence2):
37 |         sentences = [sentence1,
sentence2] 38 | sentences_index = [self.tokenizer.encode(s, add_special_tokens=True) for s in sentences] 39 | features_input_ids = [] 40 | features_mask = [] 41 | for sent_ids in sentences_index: 42 | # Truncate if too long 43 | if len(sent_ids) > self.max_seq_length: 44 | sent_ids = sent_ids[: self.max_seq_length] 45 | sent_mask = [1] * len(sent_ids) 46 | # Padding 47 | padding_length = self.max_seq_length - len(sent_ids) 48 | sent_ids += [0] * padding_length 49 | sent_mask += [0] * padding_length 50 | # Length Check 51 | assert len(sent_ids) == self.max_seq_length 52 | assert len(sent_mask) == self.max_seq_length 53 | 54 | features_input_ids.append(sent_ids) 55 | features_mask.append(sent_mask) 56 | 57 | features_mask = np.array(features_mask) 58 | 59 | batch_input_ids = torch.tensor(features_input_ids, dtype=torch.long) 60 | batch_input_mask = torch.tensor(features_mask, dtype=torch.long) 61 | batch = [batch_input_ids.to(self.device), batch_input_mask.to(self.device)] 62 | 63 | inputs = {"input_ids": batch[0], "attention_mask": batch[1]} 64 | self.model.zero_grad() 65 | 66 | with torch.no_grad(): 67 | features = self.model(**inputs)[1] 68 | 69 | # Reshape features from list of (batch_size, seq_len, hidden_dim) for each hidden state to list 70 | # of (num_hidden_states, seq_len, hidden_dim) for each element in the batch. 71 | all_layer_embedding = torch.stack(features).permute(1, 0, 2, 3).cpu().numpy() 72 | 73 | embed_method = utils.generate_embedding(self.embed_method, features_mask) 74 | embedding = embed_method.embed(self.embed_method, all_layer_embedding) 75 | 76 | similarity = ( 77 | embedding[0].dot(embedding[1]) 78 | / np.linalg.norm(embedding[0]) 79 | / np.linalg.norm(embedding[1]) 80 | ) 81 | #print("The similarity between these two sentences are (from 0-1):", similarity) 82 | return similarity 83 | 84 | 85 | 86 | 87 | # ----------------------------------------------- 88 | 89 | if __name__ == "__main__": 90 | # ----------------------------------------------- 91 | # Settings 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument( 94 | "--batch_size", default=64, type=int, help="batch size for extracting features." 95 | ) 96 | parser.add_argument( 97 | "--max_seq_length", 98 | default=128, 99 | type=int, 100 | help="The maximum total input sequence length after tokenization. Sequences longer " 101 | "than this will be truncated, sequences shorter will be padded.", 102 | ) 103 | parser.add_argument( 104 | "--seed", type=int, default=42, help="random seed for initialization" 105 | ) 106 | parser.add_argument( 107 | "--model_type", 108 | type=str, 109 | default="bert-base-uncased", 110 | help="Pre-trained language models. 
(default: 'bert-base-uncased')", 111 | ) 112 | parser.add_argument( 113 | "--embed_method", 114 | type=str, 115 | default="ave_last_hidden", 116 | help="Choice of method to obtain embeddings (default: 'ave_last_hidden')", 117 | ) 118 | parser.add_argument( 119 | "--context_window_size", 120 | type=int, 121 | default=2, 122 | help="Topological Embedding Context Window Size (default: 2)", 123 | ) 124 | parser.add_argument( 125 | "--layer_start", 126 | type=int, 127 | default=4, 128 | help="Starting layer for fusion (default: 4)", 129 | ) 130 | parser.add_argument( 131 | "--tasks", type=str, default="all", help="choice of tasks to evaluate on" 132 | ) 133 | args = parser.parse_args() 134 | 135 | # ----------------------------------------------- 136 | # Set device 137 | torch.cuda.set_device(-1) 138 | device = torch.device("cuda", 0) 139 | args.device = device 140 | 141 | # ----------------------------------------------- 142 | # Set seed 143 | set_seed(args) 144 | # Set up logger 145 | # logging.basicConfig(format="%(asctime)s : %(message)s", level=logging.DEBUG) 146 | 147 | # ----------------------------------------------- 148 | # Set Model 149 | params = vars(args) 150 | 151 | config = AutoConfig.from_pretrained(params["model_type"], cache_dir="./cache") 152 | config.output_hidden_states = True 153 | tokenizer = AutoTokenizer.from_pretrained(params["model_type"], cache_dir="./cache") 154 | model = AutoModelWithLMHead.from_pretrained( 155 | params["model_type"], config=config, cache_dir="./cache" 156 | ) 157 | model.to(params["device"]) 158 | 159 | # ----------------------------------------------- 160 | 161 | sentence1 = input("\nEnter the first sentence: ") 162 | sentence2 = input("Enter the second sentence: ") 163 | 164 | sentences = [sentence1, sentence2] 165 | 166 | print("The two sentences we have are:", sentences) 167 | 168 | # ----------------------------------------------- 169 | sentences_index = [tokenizer.encode(s, add_special_tokens=True) for s in sentences] 170 | features_input_ids = [] 171 | features_mask = [] 172 | for sent_ids in sentences_index: 173 | # Truncate if too long 174 | if len(sent_ids) > params["max_seq_length"]: 175 | sent_ids = sent_ids[: params["max_seq_length"]] 176 | sent_mask = [1] * len(sent_ids) 177 | # Padding 178 | padding_length = params["max_seq_length"] - len(sent_ids) 179 | sent_ids += [0] * padding_length 180 | sent_mask += [0] * padding_length 181 | # Length Check 182 | assert len(sent_ids) == params["max_seq_length"] 183 | assert len(sent_mask) == params["max_seq_length"] 184 | 185 | features_input_ids.append(sent_ids) 186 | features_mask.append(sent_mask) 187 | 188 | features_mask = np.array(features_mask) 189 | 190 | batch_input_ids = torch.tensor(features_input_ids, dtype=torch.long) 191 | batch_input_mask = torch.tensor(features_mask, dtype=torch.long) 192 | batch = [batch_input_ids.to(device), batch_input_mask.to(device)] 193 | 194 | inputs = {"input_ids": batch[0], "attention_mask": batch[1]} 195 | model.zero_grad() 196 | 197 | with torch.no_grad(): 198 | features = model(**inputs)[1] 199 | 200 | # Reshape features from list of (batch_size, seq_len, hidden_dim) for each hidden state to list 201 | # of (num_hidden_states, seq_len, hidden_dim) for each element in the batch. 
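    # torch.stack(features) has shape (num_hidden_states, batch_size, seq_len,
    # hidden_dim); permute(1, 0, 2, 3) moves the batch axis to the front so each
    # sentence carries its own stack of per-layer activations.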
202 | all_layer_embedding = torch.stack(features).permute(1, 0, 2, 3).cpu().numpy() 203 | 204 | embed_method = utils.generate_embedding(params["embed_method"], features_mask) 205 | embedding = embed_method.embed(params, all_layer_embedding) 206 | 207 | similarity = ( 208 | embedding[0].dot(embedding[1]) 209 | / np.linalg.norm(embedding[0]) 210 | / np.linalg.norm(embedding[1]) 211 | ) 212 | print("The similarity between these two sentences are (from 0-1):", similarity) 213 | -------------------------------------------------------------------------------- /ecrim/topological_sort.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import csv 3 | import ast 4 | import argparse 5 | import pdb 6 | 7 | 8 | class Graph: 9 | ''' 10 | The code for this class is based on geeksforgeeks.com 11 | ''' 12 | def __init__(self,vertices): 13 | self.graph = defaultdict(list) 14 | self.V = vertices 15 | 16 | def addEdge(self, u, v, w): 17 | self.graph[u].append([v, w]) 18 | 19 | def topologicalSortUtil(self, v, visited, stack): 20 | 21 | visited[v] = True 22 | 23 | for i in self.graph[v]: 24 | if visited[i[0]] == False: 25 | self.topologicalSortUtil(i[0], visited, stack) 26 | 27 | stack.insert(0,v) 28 | 29 | def topologicalSort(self): 30 | visited = [False]*self.V 31 | stack =[] 32 | for i in range(self.V): 33 | if visited[i] == False: 34 | self.topologicalSortUtil(i, visited, stack) 35 | 36 | return stack 37 | 38 | def isCyclicUtil(self, v, visited, recStack): 39 | 40 | visited[v] = True 41 | recStack[v] = True 42 | for neighbour in self.graph[v]: 43 | if visited[neighbour[0]] == False: 44 | if self.isCyclicUtil( 45 | neighbour[0], visited, recStack) == True: 46 | return True 47 | elif recStack[neighbour[0]] == True: 48 | self.graph[v].remove(neighbour) 49 | return True 50 | 51 | 52 | recStack[v] = False 53 | return False 54 | 55 | def isCyclic(self): 56 | visited = [False] * self.V 57 | recStack = [False] * self.V 58 | for node in range(self.V): 59 | if visited[node] == False: 60 | if self.isCyclicUtil(node, visited, recStack) == True: 61 | return True 62 | return False 63 | 64 | class Stats(object): 65 | 66 | def __init__(self): 67 | self.n_samp = 0 68 | self.n_sent = 0 69 | self.n_pair = 0 70 | self.corr_samp = 0 71 | self.corr_sent = 0 72 | self.corr_pair = 0 73 | self.lcs_seq = 0 74 | self.tau = 0 75 | self.dist_window = [1, 2, 3] 76 | self.min_dist = [0]*len(self.dist_window) 77 | 78 | def pairwise_metric(self, g): 79 | ''' 80 | This calculates the percentage of skip-bigrams for which the 81 | relative order is predicted correctly. Rouge-S metric. 82 | ''' 83 | common = 0 84 | for vert in range(g.V): 85 | to_nodes = g.graph[vert] 86 | to_nodes = [node[0] for node in to_nodes] 87 | gold_nodes = list(range(vert+1, g.V)) 88 | common += len(set(gold_nodes).intersection(set(to_nodes))) 89 | 90 | return common 91 | 92 | def kendall_tau(self, porder, gorder): 93 | ''' 94 | It calculates the number of inversions required by the predicted 95 | order to reach the correct order. 
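        For example, porder = [2, 0, 1] against gorder = [0, 1, 2] shares
        only the pair (0, 1) of the three gold skip-bigrams, so
        tau = 1 - 2 * (2 / 3) = -1/3.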
96 | ''' 97 | pred_pairs, gold_pairs = [], [] 98 | for i in range(len(porder)): 99 | for j in range(i+1, len(porder)): 100 | pred_pairs.append((porder[i], porder[j])) 101 | gold_pairs.append((gorder[i], gorder[j])) 102 | common = len(set(pred_pairs).intersection(set(gold_pairs))) 103 | uncommon = len(gold_pairs) - common 104 | tau = 1 - (2*(uncommon/len(gold_pairs))) 105 | 106 | return tau 107 | 108 | def min_dist_metric(self, porder, gorder): 109 | ''' 110 | It calculates the displacement of sentences within a given window. 111 | ''' 112 | count = [0]*len(self.dist_window) 113 | for i in range(len(porder)): 114 | pidx = i 115 | pval = porder[i] 116 | gidx = gorder.index(pval) 117 | for w, window in enumerate(self.dist_window): 118 | if abs(pidx-gidx) <= window: 119 | count[w] += 1 120 | return count 121 | 122 | def lcs(self, X , Y): 123 | m = len(X) 124 | n = len(Y) 125 | 126 | L = [[None]*(n+1) for i in range(m+1)] 127 | 128 | for i in range(m+1): 129 | for j in range(n+1): 130 | if i == 0 or j == 0 : 131 | L[i][j] = 0 132 | elif X[i-1] == Y[j-1]: 133 | L[i][j] = L[i-1][j-1]+1 134 | else: 135 | L[i][j] = max(L[i-1][j] , L[i][j-1]) 136 | 137 | return L[m][n] 138 | 139 | def sample_match(self, order, gold_order): 140 | ''' 141 | It calculates the percentage of samples for which the entire 142 | sequence was correctly predicted. (PMR) 143 | ''' 144 | return order == gold_order 145 | 146 | def sentence_match(self, order, gold_order): 147 | ''' 148 | It measures the percentage of sentences for which their absolute 149 | position was correctly predicted. (Acc) 150 | ''' 151 | return sum([1 for x in range(len(order)) if order[x] == gold_order[x]]) 152 | 153 | def update_stats(self, nvert, npairs, order, gold_order, g): 154 | self.n_samp += 1 155 | self.n_sent += nvert 156 | self.n_pair += npairs 157 | 158 | if self.sample_match(order, gold_order): 159 | self.corr_samp += 1 160 | self.corr_sent += self.sentence_match(order, gold_order) 161 | self.corr_pair += self.pairwise_metric(g) 162 | self.lcs_seq += self.lcs(order, gold_order) 163 | self.tau += self.kendall_tau(order, gold_order) 164 | window_counts = self.min_dist_metric(order, gold_order) 165 | for w, wc in enumerate(window_counts): 166 | self.min_dist[w] += wc 167 | 168 | def print_stats(self): 169 | print("Perfect Match: " + str(self.corr_samp*100/self.n_samp)) 170 | print("Sentence Accuracy: " + str(self.corr_sent*100/self.n_sent)) 171 | print("Rouge-S: " + str(self.corr_pair*100/self.n_pair)) 172 | print("LCS: " + str(self.lcs_seq*100/self.n_sent)) 173 | print("Kendall Tau Ratio: " + str(self.tau/self.n_samp)) 174 | for w, window in enumerate(self.dist_window): 175 | print("Min Dist Metric for window " + str(window) + ": " + \ 176 | str(self.min_dist[w]*100/self.n_sent)) 177 | 178 | def convert_to_graph(data): 179 | 180 | stats = Stats() 181 | i = 0 182 | no_docs, no_sents = 0, 0 183 | 184 | while i < len(data): 185 | ids = data[i][0] 186 | 187 | # get no vertices 188 | docid, nvert, npairs = ids.split('-') 189 | docid, nvert, npairs = int(docid), int(nvert), int(npairs) 190 | 191 | # create graph obj 192 | g = Graph(nvert) 193 | 194 | #read pred label 195 | for j in range(i, i+npairs): 196 | pred = int(data[j][8]) 197 | log0, log1 = float(data[j][6]), float(data[j][7]) 198 | pos_s1, pos_s2 = int(data[j][4]), int(data[j][5]) 199 | 200 | if pred == 0: 201 | g.addEdge(pos_s2, pos_s1, log0) 202 | elif pred == 1: 203 | g.addEdge(pos_s1, pos_s2, log1) 204 | 205 | i += npairs 206 | 207 | while g.isCyclic(): 208 | g.isCyclic() 209 | 210 | 
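        # Each call to isCyclic() lets isCyclicUtil() delete one back edge
        # when it detects a cycle, so the loop above keeps pruning edges until
        # the graph is a DAG and the topological sort below is well defined.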
order = g.topologicalSort() 211 | no_sents += nvert 212 | no_docs += 1 213 | gold_order = list(range(nvert)) 214 | stats.update_stats(nvert, npairs, order, gold_order, g) 215 | 216 | if len(order) != len(gold_order): 217 | print("yes") 218 | 219 | return stats 220 | 221 | def readf(filename): 222 | data = [] 223 | with open(filename, "r") as inp: 224 | spam = csv.reader(inp, delimiter='\t') 225 | for row in spam: 226 | data.append(row) 227 | return data 228 | 229 | def main(): 230 | parser = argparse.ArgumentParser() 231 | ## Required parameters 232 | parser.add_argument("--file_path", default=None, type=str, 233 | required=True, help="The input data dir.") 234 | args = parser.parse_args() 235 | 236 | data = readf(args.file_path) 237 | stats = convert_to_graph(data) 238 | stats.print_stats() 239 | 240 | if __name__ == "__main__": 241 | main() -------------------------------------------------------------------------------- /ecrim/buffer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import copy 3 | from transformers import AutoTokenizer 4 | from utils import CAPACITY, BLOCK_SIZE, DEFAULT_MODEL_NAME 5 | import random 6 | from bisect import bisect_left 7 | from itertools import chain 8 | import pdb 9 | class Block: 10 | """Similar to CogLTX(https://proceedings.neurips.cc/paper/2020/file/96671501524948bc3937b4b30d0e57b9-Paper.pdf). 11 | """ 12 | tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL_NAME) 13 | def __init__(self, ids, pos, blk_type=1, **kwargs): 14 | self.ids = ids 15 | self.pos = pos 16 | self.blk_type = blk_type 17 | self.relevance = 0 18 | self.estimation = 0 19 | self.entail_score = 0 20 | self.docid = 0 21 | self.h_flag = 0 22 | self.t_flag = 0 23 | self.__dict__.update(kwargs) 24 | def __lt__(self, rhs): 25 | return self.blk_type < rhs.blk_type or (self.blk_type == rhs.blk_type and self.pos < rhs.pos) 26 | def __ne__(self, rhs): 27 | return self.pos != rhs.pos or self.blk_type != rhs.blk_type 28 | def __len__(self): 29 | return len(self.ids) 30 | def __str__(self): 31 | return Block.tokenizer.convert_tokens_to_string(Block.tokenizer.convert_ids_to_tokens(self.ids)) 32 | 33 | class Buffer: 34 | @staticmethod 35 | def split_document_into_blocks(d, tokenizer, cnt=0, hard=True, properties=None, docid=0): 36 | ret = Buffer() 37 | updiv = lambda a,b: (a - 1) // b + 1 38 | if hard: 39 | for sid, tsen in enumerate(d): 40 | psen = properties[sid] if properties is not None else [] 41 | num = updiv(len(tsen), BLOCK_SIZE) # cls 42 | bsize = updiv(len(tsen), num) 43 | for i in range(num): 44 | st, en = i * bsize, min((i + 1) * bsize, len(tsen)) 45 | cnt += 1 46 | tmp = tsen[st: en] + [tokenizer.sep_token] 47 | # inject properties into blks 48 | tmp_kwargs = {} 49 | for p in psen: 50 | if len(p) == 2: 51 | tmp_kwargs[p[0]] = p[1] 52 | elif len(p) == 3: 53 | if st <= p[1] < en: 54 | tmp_kwargs[p[0]] = (p[1] - st, p[2]) 55 | else: 56 | raise ValueError('Invalid property {}'.format(p)) 57 | ret.insert(Block(tokenizer.convert_tokens_to_ids(tmp), cnt, **tmp_kwargs)) 58 | else: 59 | # d is only a list of tokens, not split. 60 | # properties are also a list of tuples. 
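            # Soft split: candidate break points are the positions of the
            # punctuation below, each with a cost ('\n' free, sentence enders
            # cheap, commas dearer), and a forced break with break_cost is
            # inserted whenever two candidates lie more than BLOCK_SIZE apart.
            # The DP over `best` picks the cheapest breaks so that no block
            # exceeds BLOCK_SIZE; `intervals` is recovered by walking back.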
61 | end_tokens = {'\n':0, '.':1, '?':1, '!':1, ',':2} 62 | for k, v in list(end_tokens.items()): 63 | end_tokens['Ġ' + k] = v 64 | sen_cost, break_cost = 4, 8 65 | poses = [(i, end_tokens[tok]) for i, tok in enumerate(d) if tok in end_tokens] 66 | poses.insert(0, (-1, 0)) 67 | if poses[-1][0] < len(d) - 1: 68 | poses.append((len(d) - 1, 0)) 69 | x = 0 70 | while x < len(poses) - 1: 71 | if poses[x + 1][0] - poses[x][0] > BLOCK_SIZE: 72 | poses.insert(x + 1, (poses[x][0] + BLOCK_SIZE, break_cost)) 73 | x += 1 74 | 75 | best = [(0, 0)] 76 | for i, (p, cost) in enumerate(poses): 77 | if i == 0: 78 | continue 79 | best.append((-1, 100000)) 80 | for j in range(i-1, -1, -1): 81 | if p - poses[j][0] > BLOCK_SIZE: 82 | break 83 | value = best[j][1] + cost + sen_cost 84 | if value < best[i][1]: 85 | best[i] = (j, value) 86 | assert best[i][0] >= 0 87 | intervals, x = [], len(poses) - 1 88 | while x > 0: 89 | l = poses[best[x][0]][0 ] 90 | intervals.append((l + 1, poses[x][0] + 1)) 91 | x = best[x][0] 92 | if properties is None: 93 | properties = [] 94 | for st, en in reversed(intervals): 95 | # copy from hard version 96 | cnt += 1 97 | tmp = d[st: en] + [tokenizer.sep_token] 98 | # inject properties into blks 99 | tmp_kwargs = {} 100 | for p in properties: 101 | if len(p) == 2: 102 | tmp_kwargs[p[0]] = p[1] 103 | elif len(p) == 3: 104 | if st <= p[1] < en: 105 | tmp_kwargs[p[0]] = (p[1] - st, p[2]) 106 | else: 107 | raise ValueError('Invalid property {}'.format(p)) 108 | ret.insert(Block(tokenizer.convert_tokens_to_ids(tmp), cnt, **tmp_kwargs)) 109 | for blk in ret.blocks: 110 | blk.docid = docid 111 | return ret, cnt 112 | 113 | def __init__(self): 114 | self.blocks = [] 115 | 116 | def __add__(self, buf): 117 | ret = Buffer() 118 | ret.blocks = self.blocks + buf.blocks 119 | return ret 120 | 121 | def __len__(self): 122 | return len(self.blocks) 123 | 124 | def __getitem__(self, key): 125 | return self.blocks[key] 126 | 127 | def __str__(self): 128 | return ''.join([str(b)+'\n' for b in self.blocks]) 129 | 130 | def clone(self): 131 | ret = Buffer() 132 | ret.blocks = self.blocks.copy() 133 | return ret 134 | 135 | def calc_size(self): 136 | return sum([len(b) for b in self.blocks]) 137 | 138 | def block_ends(self): 139 | t, ret = 0, [] 140 | for b in self.blocks: 141 | t += len(b) 142 | ret.append(t) 143 | return ret 144 | 145 | def insert(self, b, reverse=True): 146 | if not reverse: 147 | for index in range(len(self.blocks) + 1): 148 | if index >= len(self.blocks) or b < self.blocks[index]: 149 | self.blocks.insert(index, b) 150 | break 151 | else: 152 | for index in range(len(self.blocks), -1, -1): 153 | if index == 0 or self.blocks[index - 1] < b: 154 | self.blocks.insert(index, b) 155 | break 156 | 157 | def merge(self, buf): 158 | ret = Buffer() 159 | t1, t2 = 0, 0 160 | while t1 < len(self.blocks) or t2 < len(buf): 161 | if t1 < len(self.blocks) and (t2 >= len(buf) or self.blocks[t1] < buf.blocks[t2]): 162 | ret.blocks.append(self.blocks[t1]) 163 | t1 += 1 164 | else: 165 | ret.blocks.append(buf.blocks[t2]) 166 | t2 += 1 167 | return ret 168 | 169 | def filtered(self, fltr: 'function blk, index->bool', need_residue=False): 170 | ret, ret2 = Buffer(), Buffer() 171 | for i, blk in enumerate(self.blocks): 172 | if fltr(blk, i): 173 | ret.blocks.append(blk) 174 | else: 175 | ret2.blocks.append(blk) 176 | if need_residue: 177 | return ret, ret2 178 | else: 179 | return ret 180 | 181 | def random_sample(self, size): 182 | assert size <= len(self.blocks) 183 | index = 
sorted(random.sample(range(len(self.blocks)), size)) 184 | ret = Buffer() 185 | ret.blocks = [self.blocks[i] for i in index] 186 | return ret 187 | 188 | def sort_(self): 189 | self.blocks.sort() 190 | return self 191 | 192 | def fill(self, buf): 193 | ret, tmp_buf, tmp_size = [], self.clone(), self.calc_size() 194 | for blk in buf: 195 | if tmp_size + len(blk) > CAPACITY: 196 | ret.append(tmp_buf) 197 | tmp_buf, tmp_size = self.clone(), self.calc_size() 198 | tmp_buf.blocks.append(blk) 199 | tmp_size += len(blk) 200 | ret.append(tmp_buf) 201 | return ret 202 | 203 | def export(self, device=None, length=None, out=None): 204 | if out is None: 205 | if length is None: 206 | total_length = self.calc_size() 207 | if total_length > CAPACITY: 208 | raise ValueError('export inputs larger than capacity') 209 | else: 210 | total_length = length * len(self.blocks) 211 | ids, att_masks, type_ids = torch.zeros(3, total_length, dtype=torch.long, device=device) 212 | else: # must be zeros and big enough 213 | ids, att_masks, type_ids = out 214 | att_masks.zero_() 215 | t = 0 216 | for b in self.blocks: 217 | try: 218 | ids[t:t + len(b)] = torch.tensor(b.ids, dtype=torch.long, device=device) 219 | except: 220 | #pdb.set_trace() 221 | ids[t-1 :t-1 + len(b)] = torch.tensor(b.ids, dtype=torch.long, device=device) 222 | #pdb.set_trace() 223 | # if b.blk_type == 1: 224 | # type_ids[t:w] = 1 # sentence B 225 | att_masks[t:t + len(b)] = 1 # attention_mask 226 | t += len(b) if length is None else length 227 | return ids, att_masks, type_ids 228 | 229 | def export_01_turn(self, device=None, length=None, out=None): 230 | if out is None: 231 | if length is None: 232 | total_length = self.calc_size() 233 | if total_length > CAPACITY: 234 | raise ValueError('export inputs larger than capacity') 235 | else: 236 | total_length = length * len(self.blocks) 237 | ids, att_masks, type_ids = torch.zeros(3, total_length, dtype=torch.long, device=device) 238 | else: # must be zeros and big enough 239 | ids, att_masks, type_ids = out 240 | att_masks.zero_() 241 | t = 0 242 | for b in self.blocks: 243 | try: 244 | ids[t:t + len(b)] = torch.tensor(b.ids, dtype=torch.long, device=device) # id 245 | except: 246 | #pdb.set_trace() 247 | print("capacity:", 512 - t, "blk_length:", len(b)) 248 | try: 249 | ids[t-1 :t-1 + len(b)] = torch.tensor(b.ids, dtype=torch.long, device=device) 250 | except: 251 | 252 | print(ids, len(ids)) 253 | print(b.ids, len(b.ids)) 254 | try: 255 | ids[t : -1] = torch.tensor(b.ids, dtype=torch.long, device=device)[:512 - t - 1] 256 | ids[-1] = torch.tensor([102], dtype=torch.long, device=device) 257 | except Exception as e: 258 | print(e) 259 | pdb.set_trace() 260 | #pdb.set_trace() 261 | # if b.blk_type == 1: 262 | # type_ids[t:w] = 1 # sentence B 263 | att_masks[t:t + len(b)] = 1 # attention_mask 264 | t += len(b) if length is None else length 265 | sentences = [] 266 | sentences_with_sep = [] 267 | ptr = 0 268 | ids_list = ids.tolist() 269 | for i in range(len(ids_list)): 270 | if ids_list[i] == 102: 271 | sentences.append(ids_list[ptr:i]) 272 | sentences_with_sep.append(ids_list[ptr:i+1]) 273 | ptr = i+1 274 | sentences[-1].append(102) 275 | s_ptr = 0 276 | for s_idx in range(len(sentences_with_sep)): 277 | type_ids[s_ptr:s_ptr + len(sentences_with_sep[s_idx])] = torch.tensor([s_idx%2]* len(sentences_with_sep[s_idx]), dtype=torch.long, device=device) 278 | s_ptr += len(sentences_with_sep[s_idx]) 279 | return ids, att_masks, type_ids 280 | 281 | def export_01_doc(self, device=None, length=None, 
out=None): 282 | if out is None: 283 | if length is None: 284 | total_length = self.calc_size() 285 | if total_length > CAPACITY: 286 | raise ValueError('export inputs larger than capacity') 287 | else: 288 | total_length = length * len(self.blocks) 289 | ids, att_masks, type_ids = torch.zeros(3, total_length, dtype=torch.long, device=device) 290 | else: # must be zeros and big enough 291 | ids, att_masks, type_ids = out 292 | att_masks.zero_() 293 | t = 0 294 | 295 | #pdb.set_trace() 296 | doc0_ids = [] 297 | doc1_ids = [] 298 | for b in self.blocks: 299 | if b.docid == 0: 300 | doc0_ids.extend(b.ids) 301 | elif b.docid == 1: 302 | doc1_ids.extend(b.ids) 303 | #pdb.set_trace() 304 | ids[:len(doc0_ids)] = torch.tensor(doc0_ids, dtype=torch.long, device=device) 305 | try: 306 | ids[len(doc0_ids):len(doc0_ids)+len(doc1_ids)] = torch.tensor(doc1_ids, dtype=torch.long, device=device) 307 | except: 308 | pdb.set_trace() 309 | print(doc1_ids) 310 | doc1_ids.reverse() 311 | doc1_ids.remove(102) 312 | doc1_ids.remove(102) 313 | doc1_ids.reverse() 314 | doc1_ids.append(102) 315 | ids[len(doc0_ids):len(doc0_ids)+len(doc1_ids)] = torch.tensor(doc1_ids, dtype=torch.long, device=device) 316 | type_ids[:len(doc0_ids)] = torch.tensor([0]*len(doc0_ids), dtype=torch.long, device=device) 317 | type_ids[len(doc0_ids):len(doc0_ids)+len(doc1_ids)] = torch.tensor([1]*len(doc1_ids), dtype=torch.long, device=device) 318 | return ids, att_masks, type_ids 319 | 320 | def export_as_batch(self, device, length=BLOCK_SIZE+1, add_cls=False): 321 | ids, att_masks, type_ids = self.export(device, length, add_cls=add_cls) 322 | return ids.view(-1, length), att_masks.view(-1, length), type_ids.view(-1, length) 323 | 324 | def export_relevance(self, device, length=None, dtype=torch.long, out=None): 325 | if out is None: 326 | total_length = self.calc_size() if length is None else length * len(self.blocks) 327 | relevance = torch.zeros(total_length, dtype=dtype, device=device) 328 | else: 329 | relevance = out 330 | t = 0 331 | for b in self.blocks: 332 | w = t + (len(b) if length is None else length) 333 | if b.relevance >= 1: 334 | relevance[t: w] = 1 335 | t = w 336 | return relevance 337 | 338 | def buffer_collate(batch): 339 | return batch 340 | -------------------------------------------------------------------------------- /ecrim/data_helper.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | import os 4 | import re 5 | import logging 6 | import random 7 | import torch 8 | from torch.utils.data import Dataset 9 | from tqdm import tqdm 10 | import pdb 11 | from buffer import Buffer 12 | from utils import CAPACITY, BLOCK_SIZE, ForkedPdb 13 | import numpy as np 14 | import torch 15 | from tqdm import tqdm 16 | from transformers import ( 17 | AutoModelForSequenceClassification, 18 | AutoTokenizer, 19 | AutoConfig, 20 | T5ForConditionalGeneration, 21 | ) 22 | from typing import Dict, List 23 | from a2t.base import Classifier, np_softmax 24 | from collections import defaultdict 25 | from dataclasses import dataclass 26 | 27 | class SimpleListDataset(Dataset): 28 | def __init__(self, source): 29 | if isinstance(source, str): 30 | with open(source, 'rb') as fin: 31 | logging.info('Loading dataset...') 32 | self.dataset = pickle.load(fin) 33 | elif isinstance(source, list): 34 | self.dataset = source 35 | if not isinstance(self.dataset, list): 36 | raise ValueError('The source of SimpleListDataset is not a list.') 37 | def __getitem__(self, index): 38 | return 
self.dataset[index]
39 |     def __len__(self):
40 |         return len(self.dataset)
41 | 
42 | class BlkPosInterface:
43 |     def __init__(self, dataset):
44 |         assert isinstance(dataset, SimpleListDataset)
45 |         # pdb.set_trace()  # debug breakpoint disabled; it halted every run
46 |         self.d = {}
47 |         self.dataset = dataset
48 |         for bufs in dataset:
49 |             for buf in bufs:
50 |                 for blk in buf:
51 |                     assert blk.pos not in self.d
52 |                     self.d[blk.pos] = blk
53 |     def set_property(self, pos, key, value=None):
54 |         blk = self.d[pos]
55 |         if value is not None:
56 |             setattr(blk, key, value)
57 |         elif hasattr(blk, key):
58 |             delattr(blk, key)
59 |     def apply_changes_from_file(self, filename):
60 |         with open(filename, 'r') as fin:
61 |             for line in fin:
62 |                 tmp = [
63 |                     int(s) if s.isdigit() or s[0] == '-' and s[1:].isdigit() else s
64 |                     for s in line.split()
65 |                 ]
66 |                 self.set_property(*tmp)
67 |     def apply_changes_from_dir(self, tmp_dir):
68 |         for shortname in os.listdir(tmp_dir):
69 |             filename = os.path.join(tmp_dir, shortname)
70 |             if shortname.startswith('changes_'):
71 |                 self.apply_changes_from_file(filename)
72 |                 os.replace(filename, os.path.join(tmp_dir, 'backup_' + shortname))
73 | 
74 |     def collect_estimations_from_dir(self, tmp_dir):
75 |         ret = []
76 |         for shortname in os.listdir(tmp_dir):
77 |             filename = os.path.join(tmp_dir, shortname)
78 |             if shortname.startswith('estimations_'):
79 |                 with open(filename, 'r') as fin:
80 |                     for line in fin:
81 |                         l = line.split()
82 |                         pos, estimation = int(l[0]), float(l[1])
83 |                         self.d[pos].estimation = estimation
84 |                 os.replace(filename, os.path.join(tmp_dir, 'backup_' + shortname))
85 | 
86 |     def build_random_buffer(self, num_samples):
87 |         # ForkedPdb().set_trace()  # debug breakpoint disabled; it halted every run
88 |         n0, n1 = [int(s) for s in num_samples.split(',')][:2]
89 |         ret = []
90 |         max_blk_num = CAPACITY // (BLOCK_SIZE + 1)
91 |         logging.info('building buffers for introspection...')
92 |         for qbuf, dbuf in tqdm(self.dataset):
93 |             # 1. continuous
94 |             lb = max_blk_num - len(qbuf)
95 |             st = random.randint(0, max(0, len(dbuf) - lb * n0))
96 |             for i in range(n0):
97 |                 buf = Buffer()
98 |                 buf.blocks = qbuf.blocks + dbuf.blocks[st + i * lb:st + (i+1) * lb]
99 |                 ret.append(buf)
100 |             # 2. pos + neg
101 |             pbuf, nbuf = dbuf.filtered(lambda blk, idx: blk.relevance >= 1, need_residue=True)
102 |             for i in range(n1):
103 |                 selected_pblks = random.sample(pbuf.blocks, min(lb, len(pbuf)))
104 |                 selected_nblks = random.sample(nbuf.blocks, min(lb - len(selected_pblks), len(nbuf)))
105 |                 buf = Buffer()
106 |                 buf.blocks = qbuf.blocks + selected_pblks + selected_nblks
107 |                 ret.append(buf.sort_())
108 |         return SimpleListDataset(ret)
109 | 
110 |     def build_promising_buffer(self, num_samples):
111 |         n2, n3 = [int(x) for x in num_samples.split(',')][2:]
112 |         ret = []
113 |         max_blk_num = CAPACITY // (BLOCK_SIZE + 1)
114 |         logging.info('building buffers for reasoning...')
115 |         for qbuf, dbuf in tqdm(self.dataset):
116 |             #1. retrieve top n2*(max-len(pos)) estimations into buf 2.
cut 117 | pbuf, nbuf = dbuf.filtered(lambda blk, idx: blk.relevance >= 1, need_residue=True) 118 | if len(pbuf) >= max_blk_num - len(qbuf): 119 | pbuf = pbuf.random_sample(max_blk_num - len(qbuf) - 1) 120 | lb = max_blk_num - len(qbuf) - len(pbuf) 121 | estimations = torch.tensor([blk.estimation for blk in nbuf], dtype=torch.long) 122 | keeped_indices = estimations.argsort(descending=True)[:n2 * lb] 123 | selected_nblks = [blk for i, blk in enumerate(nbuf) if i in keeped_indices] 124 | while 0 < len(selected_nblks) < n2 * lb: 125 | selected_nblks = selected_nblks * (n2 * lb // len(selected_nblks) + 1) 126 | for i in range(n2): 127 | buf = Buffer() 128 | buf.blocks = qbuf.blocks + pbuf.blocks + selected_nblks[i * lb: (i+1) * lb] 129 | ret.append(buf.sort_()) 130 | for i in range(n3): 131 | buf = Buffer() 132 | buf.blocks = qbuf.blocks + pbuf.blocks + random.sample(nbuf.blocks, min(len(nbuf), lb)) 133 | ret.append(buf.sort_()) 134 | return SimpleListDataset(ret) 135 | 136 | def find_lastest_checkpoint(checkpoints_dir, epoch=False): 137 | lastest = (-1, '') 138 | if os.path.exists(checkpoints_dir): 139 | for shortname in os.listdir(checkpoints_dir): 140 | m = re.match(r'_ckpt_epoch_(\d+).+', shortname) 141 | if m is not None and int(m.group(1)) > lastest[0]: 142 | lastest = (int(m.group(1)), shortname) 143 | return os.path.join(checkpoints_dir, lastest[-1]) if not epoch else lastest[0] 144 | 145 | 146 | @dataclass 147 | class REInputFeatures: 148 | subj: str 149 | obj: str 150 | context: str 151 | pair_type: str = None 152 | label: str = None 153 | 154 | 155 | class _NLIRelationClassifier(Classifier): 156 | def __init__( 157 | self, 158 | labels: List[str], 159 | *args, 160 | pretrained_model: str = "roberta-large-mnli", 161 | use_cuda=True, 162 | half=False, 163 | verbose=True, 164 | negative_threshold=0.95, 165 | negative_idx=0, 166 | max_activations=np.inf, 167 | valid_conditions=None, 168 | **kwargs, 169 | ): 170 | super().__init__( 171 | labels, 172 | pretrained_model=pretrained_model, 173 | use_cuda=use_cuda, 174 | verbose=verbose, 175 | half=half, 176 | ) 177 | # self.ent_pos = entailment_position 178 | # self.cont_pos = -1 if self.ent_pos == 0 else 0 179 | self.negative_threshold = negative_threshold 180 | self.negative_idx = negative_idx 181 | self.max_activations = max_activations 182 | self.n_rel = len(labels) 183 | # for label in labels: 184 | # assert '{subj}' in label and '{obj}' in label 185 | 186 | if valid_conditions: 187 | self.valid_conditions = {} 188 | rel2id = {r: i for i, r in enumerate(labels)} 189 | self.n_rel = len(rel2id) 190 | for relation, conditions in valid_conditions.items(): 191 | if relation not in rel2id: 192 | continue 193 | for condition in conditions: 194 | if condition not in self.valid_conditions: 195 | self.valid_conditions[condition] = np.zeros(self.n_rel) 196 | self.valid_conditions[condition][rel2id["no_relation"]] = 1.0 197 | self.valid_conditions[condition][rel2id[relation]] = 1.0 198 | 199 | else: 200 | self.valid_conditions = None 201 | 202 | def idx2label(idx): 203 | return self.labels[idx] 204 | 205 | self.idx2label = np.vectorize(idx2label) 206 | 207 | def _initialize(self, pretrained_model): 208 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model) 209 | self.model = AutoModelForSequenceClassification.from_pretrained(pretrained_model) 210 | self.config = AutoConfig.from_pretrained(pretrained_model) 211 | self.ent_pos = self.config.label2id.get("ENTAILMENT", self.config.label2id.get("entailment", None)) 212 | if self.ent_pos is 
None: 213 | raise ValueError("The model config must contain ENTAILMENT label in the label2id dict.") 214 | else: 215 | self.ent_pos = int(self.ent_pos) 216 | 217 | def _run_batch(self, batch, multiclass=False): 218 | with torch.no_grad(): 219 | input_ids = self.tokenizer.batch_encode_plus(batch, padding=True, truncation=True) 220 | input_ids = torch.tensor(input_ids["input_ids"]).to(self.device) 221 | output = self.model(input_ids)[0].detach().cpu().numpy() 222 | if multiclass: 223 | output = np.exp(output) / np.exp(output).sum( 224 | -1, keepdims=True 225 | ) # np.exp(output[..., [self.cont_pos, self.ent_pos]]).sum(-1, keepdims=True) 226 | output = output[..., self.ent_pos].reshape(input_ids.shape[0] // len(self.labels), -1) 227 | 228 | return output 229 | 230 | def __call__( 231 | self, 232 | features: List[REInputFeatures], 233 | batch_size: int = 1, 234 | multiclass=False, 235 | ): 236 | if not isinstance(features, list): 237 | features = [features] 238 | 239 | batch, outputs = [], [] 240 | for i, feature in tqdm(enumerate(features), total=len(features)): 241 | sentences = [ 242 | f"{feature.context} {self.tokenizer.sep_token} {label_template.format(subj=feature.subj, obj=feature.obj)}." 243 | for label_template in self.labels 244 | ] 245 | batch.extend(sentences) 246 | 247 | if (i + 1) % batch_size == 0: 248 | output = self._run_batch(batch, multiclass=multiclass) 249 | outputs.append(output) 250 | batch = [] 251 | 252 | if len(batch) > 0: 253 | output = self._run_batch(batch, multiclass=multiclass) 254 | outputs.append(output) 255 | 256 | outputs = np.vstack(outputs) 257 | 258 | return outputs 259 | 260 | def _apply_negative_threshold(self, probs): 261 | activations = (probs >= self.negative_threshold).sum(-1).astype(np.int) 262 | idx = np.logical_or( 263 | activations == 0, activations >= self.max_activations 264 | ) # If there are no activations then is a negative example, if there are too many, then is a noisy example 265 | probs[idx, self.negative_idx] = 1.00 266 | return probs 267 | 268 | def _apply_valid_conditions(self, probs, features: List[REInputFeatures]): 269 | mask_matrix = np.stack( 270 | [self.valid_conditions.get(feature.pair_type, np.zeros(self.n_rel)) for feature in features], 271 | axis=0, 272 | ) 273 | probs = probs * mask_matrix 274 | 275 | return probs 276 | 277 | def predict( 278 | self, 279 | contexts: List[str], 280 | batch_size: int = 1, 281 | return_labels: bool = True, 282 | return_confidences: bool = False, 283 | topk: int = 1, 284 | ): 285 | output = self(contexts, batch_size) 286 | topics = np.argsort(output, -1)[:, ::-1][:, :topk] 287 | if return_labels: 288 | topics = self.idx2label(topics) 289 | if return_confidences: 290 | topics = np.stack((topics, np.sort(output, -1)[:, ::-1][:, :topk]), -1).tolist() 291 | topics = [ 292 | [(int(label), float(conf)) if not return_labels else (label, float(conf)) for label, conf in row] 293 | for row in topics 294 | ] 295 | else: 296 | topics = topics.tolist() 297 | if topk == 1: 298 | topics = [row[0] for row in topics] 299 | 300 | return topics 301 | 302 | 303 | class NLIRelationClassifierWithMappingHead(_NLIRelationClassifier): 304 | def __init__( 305 | self, 306 | labels: List[str], 307 | template_mapping: Dict[str, str], 308 | pretrained_model: str = "roberta-large-mnli", 309 | valid_conditions: Dict[str, list] = None, 310 | *args, 311 | **kwargs, 312 | ): 313 | 314 | self.template_mapping_reverse = defaultdict(list) 315 | for key, value in template_mapping.items(): 316 | for v in value: 317 | 
self.template_mapping_reverse[v].append(key)
318 |         # reverse map: each NLI template string points back to the relation(s)
319 |         # it verbalizes, so per-template entailment scores can be max-pooled
320 |         # into per-relation scores in __call__ below
321 |         self.new_topics = list(self.template_mapping_reverse.keys())
322 | 
323 |         self.target_labels = labels
324 |         self.new_labels2id = {t: i for i, t in enumerate(self.new_topics)}
325 |         self.mapping = defaultdict(list)
326 |         for key, value in template_mapping.items():
327 |             self.mapping[key].extend([self.new_labels2id[v] for v in value])
328 | 
329 |         super().__init__(
330 |             self.new_topics,
331 |             *args,
332 |             pretrained_model=pretrained_model,
333 |             valid_conditions=None,
334 |             **kwargs,
335 |         )
336 | 
337 |         if valid_conditions:
338 |             self.valid_conditions = {}
339 |             rel2id = {r: i for i, r in enumerate(labels)}
340 |             self.n_rel = len(rel2id)
341 |             for relation, conditions in valid_conditions.items():
342 |                 if relation not in rel2id:
343 |                     continue
344 |                 for condition in conditions:
345 |                     if condition not in self.valid_conditions:
346 |                         self.valid_conditions[condition] = np.zeros(self.n_rel)
347 |                         self.valid_conditions[condition][rel2id["no_relation"]] = 1.0
348 |                     self.valid_conditions[condition][rel2id[relation]] = 1.0
349 |         else:
350 |             self.valid_conditions = None
351 | 
352 |         def idx2label(idx):
353 |             return self.target_labels[idx]
354 | 
355 |         self.idx2label = np.vectorize(idx2label)
356 | 
357 |     def __call__(self, features: List[REInputFeatures], batch_size=1, multiclass=True):
358 |         outputs = super().__call__(features, batch_size, multiclass)
359 |         outputs = np.hstack(
360 |             [
361 |                 np.max(outputs[:, self.mapping[label]], axis=-1, keepdims=True)
362 |                 if label in self.mapping
363 |                 else np.zeros((outputs.shape[0], 1))
364 |                 for label in self.target_labels
365 |             ]
366 |         )
367 |         outputs = np_softmax(outputs) if not multiclass else outputs
368 |         if self.valid_conditions:
369 |             outputs = self._apply_valid_conditions(outputs, features)
370 |         outputs = self._apply_negative_threshold(outputs)
371 |         return outputs
--------------------------------------------------------------------------------
/ecrim/sentence_reordering.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from transformers import AutoTokenizer, BertModel
4 | import pdb
5 | from tqdm import tqdm
6 | from itertools import groupby
7 | from topological_sort import Graph
8 | import random
9 | class SentReOrdering():
10 |     def __init__(self, doc1_sentences, doc2_sentences, encoder, device, tokenizer, h, t, sbert_wk):
11 |         self.encoder = encoder
12 |         self.doc1_sentences = doc1_sentences
13 |         self.doc2_sentences = doc2_sentences
14 |         self.device = device
15 |         self.max_len = 512
16 |         self.tokenizer = tokenizer
17 |         self.h = h
18 |         self.t = t
19 |         self.sentences = self.doc1_sentences + self.doc2_sentences
20 |         self.sbert = sbert_wk
21 | 
22 |     def pair_encoding(self):
23 |         pairs = []
24 |         for i,sent_1 in tqdm(enumerate(self.sentences)):
25 |             for j,sent_2 in tqdm(enumerate(self.sentences)):
26 |                 if sent_1 != sent_2:
27 |                     pair_len = len(sent_1) + len(sent_2) + 4
28 |                     input_ids = torch.zeros(1, self.max_len, dtype=torch.long)
29 |                     token_type_ids = torch.zeros(1, self.max_len, dtype=torch.long)
30 |                     attention_mask = torch.zeros(1, self.max_len, dtype=torch.long)
31 |                     token_type_ids[0][:pair_len] = torch.tensor([0] * (len(sent_1)+2) + [1] * (len(sent_2)+2))
32 |                     input_ids[0][:pair_len] = torch.tensor(self.tokenizer.convert_tokens_to_ids(['[CLS]'] + sent_1 + ['[SEP]'] + ['[CLS]'] + sent_2 + ['[SEP]']))
33 |                     attention_mask[0][:pair_len] = torch.ones(pair_len, dtype=torch.long)
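                    # input layout: [CLS] sent_1 [SEP] [CLS] sent_2 [SEP], with
                    # segment ids 0 over the first span and 1 over the second
                    # (hence the +2 offsets); below, each sentence is represented
                    # by the hidden state at its leading [CLS] position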
34 |                     pair_encoding = self.encoder(input_ids, attention_mask, token_type_ids)[0]
35 |                     sent_1_embedding = pair_encoding[0, :len(sent_1)+2][0]
36 |                     sent_2_embedding = pair_encoding[0, len(sent_1)+2:len(sent_1)+len(sent_2)+4][0]
37 |                     similarity = F.cosine_similarity(sent_1_embedding, sent_2_embedding, dim=0)
38 |                     # F.cosine_similarity(sent_1_embedding.unsqueeze(0), sent_2_embedding.unsqueeze(0), dim=0)
39 |                     pairs.append((i, j, similarity.item()))
40 |                 else:
41 |                     continue
42 |         return pairs
43 | 
44 |     def half_pair_encoding(self):
45 |         pairs_1_with_2 = []
46 |         pairs_2_with_1 = []
47 |         for i,sent_1 in enumerate(self.doc1_sentences):
48 |             for j,sent_2 in enumerate(self.doc2_sentences):
49 |                 pair_len = len(sent_1) + len(sent_2) + 4
50 |                 input_ids = torch.zeros(1, self.max_len, dtype=torch.long)
51 |                 token_type_ids = torch.zeros(1, self.max_len, dtype=torch.long)
52 |                 attention_mask = torch.zeros(1, self.max_len, dtype=torch.long)
53 |                 token_type_ids[0][:pair_len] = torch.tensor([0] * (len(sent_1)+2) + [1] * (len(sent_2)+2))
54 |                 input_ids[0][:pair_len] = torch.tensor(self.tokenizer.convert_tokens_to_ids(['[CLS]'] + sent_1 + ['[SEP]'] + ['[CLS]'] + sent_2 + ['[SEP]']))
55 |                 attention_mask[0][:pair_len] = torch.ones(pair_len, dtype=torch.long)
56 |                 pair_encoding = self.encoder(input_ids, attention_mask, token_type_ids)[0]
57 |                 sent_1_embedding = pair_encoding[0, :len(sent_1)+2][0]
58 |                 sent_2_embedding = pair_encoding[0, len(sent_1)+2:len(sent_1)+len(sent_2)+4][0]
59 |                 similarity = F.cosine_similarity(sent_1_embedding, sent_2_embedding, dim=0)
60 |                 # F.cosine_similarity(sent_1_embedding.unsqueeze(0), sent_2_embedding.unsqueeze(0), dim=0)
61 |                 pairs_1_with_2.append((i, j, similarity.item()))
62 |         for i,sent_1 in enumerate(self.doc2_sentences):
63 |             for j,sent_2 in enumerate(self.doc1_sentences):
64 |                 pair_len = len(sent_1) + len(sent_2) + 4
65 |                 input_ids = torch.zeros(1, self.max_len, dtype=torch.long)
66 |                 token_type_ids = torch.zeros(1, self.max_len, dtype=torch.long)
67 |                 attention_mask = torch.zeros(1, self.max_len, dtype=torch.long)
68 |                 token_type_ids[0][:pair_len] = torch.tensor([0] * (len(sent_1)+2) + [1] * (len(sent_2)+2))
69 |                 input_ids[0][:pair_len] = torch.tensor(self.tokenizer.convert_tokens_to_ids(['[CLS]'] + sent_1 + ['[SEP]'] + ['[CLS]'] + sent_2 + ['[SEP]']))
70 |                 attention_mask[0][:pair_len] = torch.ones(pair_len, dtype=torch.long)
71 |                 pair_encoding = self.encoder(input_ids, attention_mask, token_type_ids)[0]
72 |                 sent_1_embedding = pair_encoding[0, :len(sent_1)+2][0]
73 |                 sent_2_embedding = pair_encoding[0, len(sent_1)+2:len(sent_1)+len(sent_2)+4][0]
74 |                 similarity = F.cosine_similarity(sent_1_embedding, sent_2_embedding, dim=0)
75 |                 # F.cosine_similarity(sent_1_embedding.unsqueeze(0), sent_2_embedding.unsqueeze(0), dim=0)
76 |                 pairs_2_with_1.append((i, j, similarity.item()))
77 | 
78 |         return pairs_1_with_2, pairs_2_with_1
79 | 
80 |     def sentence_ordering(self):
81 |         pairs = self.pair_encoding()
82 |         # start from the best-scoring successor of sentence 0
83 |         Selected = []
84 |         pair_start = sorted([p for p in pairs if p[0] == 0], key=lambda p: p[2], reverse=True)[0]
85 |         Selected.append(pair_start[1])
86 |         pair_next = sorted([p for p in pairs if p[0] == Selected[-1]], key=lambda p: p[2], reverse=True)
87 |         score_max = 0
88 |         # pdb.set_trace()
89 |         while len(Selected) <= 8:
90 |             pair_next = sorted([p for p in pairs if p[0] == Selected[-1]], key=lambda p: p[2], reverse=True)
91 |             score_max = 0
92 |             for p_n in pair_next:
93 |                 if p_n[2] > score_max:
94 |                     score_max = p_n[2]
95 |                     candidate = p_n[1]
96 |                 else:
97 |                     continue
98 |             Selected.append(candidate)
99 |         return Selected
100 | 
101
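    # Usage sketch (hypothetical driver, not part of the original file): the
    # pairwise similarity scores produced above feed every ordering strategy
    # defined below, e.g.
    #   sro = SentReOrdering(doc1_sents, doc2_sents, encoder, device, tok, h, t, sbert_wk)
    #   order = sro.topo_sort()                  # graph-based global ordering
    #   chains = sro.dynamic_sort(starts, ends)  # greedy chain between anchors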
| 102 | def half_ordering(self): 103 | Insert_2_to_1 = [] 104 | Insert_1_to_2 = [] 105 | pairs_1_with_2, pairs_2_with_1 = self.half_pair_encoding() 106 | doc1_num = len(self.doc1_sentences) 107 | doc2_num = len(self.doc2_sentences) 108 | Selected = [] 109 | for s_2_idx in range(doc2_num): 110 | s_2_map = list((filter(lambda pair: pair[1] == s_2_idx, pairs_1_with_2))) 111 | head_idx = sorted(s_2_map, key=lambda sims: sims[2], reverse=True)[0][0] 112 | Insert_2_to_1.append((s_2_idx, '->', head_idx)) 113 | for s_1_idx in range(doc1_num): 114 | s_1_map = list((filter(lambda pair: pair[1] == s_1_idx, pairs_2_with_1))) 115 | head_idx = sorted(s_1_map, key=lambda sims: sims[2], reverse=True)[0][0] 116 | Insert_1_to_2.append((s_1_idx, '->', head_idx)) 117 | to_be_removed_2_to_1 = [] 118 | to_be_removed_1_to_2 = [] 119 | for i_2_to_1 in Insert_2_to_1: 120 | for i_1_to_2 in Insert_1_to_2: 121 | if i_2_to_1[0]==i_1_to_2[2] and i_2_to_1[2]==i_1_to_2[0]: # symmetric 122 | Selected.append(self.doc1_sentences[i_2_to_1[2]]) 123 | Selected.append(self.doc2_sentences[i_1_to_2[2]]) 124 | to_be_removed_2_to_1.append(i_2_to_1) 125 | to_be_removed_1_to_2.append(i_1_to_2) 126 | 127 | for tb_r_2_1 in to_be_removed_2_to_1: 128 | Insert_2_to_1.remove(tb_r_2_1) 129 | for tb_r_1_2 in to_be_removed_1_to_2: 130 | Insert_1_to_2.remove(tb_r_1_2) 131 | pdb.set_trace() 132 | max_score = 0 133 | chain = [] 134 | for rest_pair in Insert_1_to_2: 135 | s_1_score = 0 136 | s_1_idx = rest_pair[0] 137 | s_1_map = list((filter(lambda pair: pair[0] == s_1_idx, pairs_1_with_2))) 138 | for mp in s_1_map: 139 | s_1_score += mp[2] 140 | if s_1_score > max_score: 141 | max_score = s_1_score 142 | chain_start = s_1_idx 143 | chain.append(chain_start) 144 | return Selected 145 | 146 | 147 | def half_sbert_encoding(self): 148 | pairs_1_with_2 = [] 149 | pairs_2_with_1 = [] 150 | for i, sent_prior in enumerate(self.doc1_sentences): 151 | for j, sent_later in enumerate(self.doc2_sentences): 152 | similarity = self.sbert.pair_sims(" ".join(sent_prior), " ".join(sent_later)) 153 | pairs_1_with_2.append((i, j, similarity.item())) 154 | for i, sent_prior in enumerate(self.doc2_sentences): 155 | for j, sent_later in enumerate(self.doc1_sentences): 156 | similarity = self.sbert.pair_sims(" ".join(sent_prior), " ".join(sent_later)) 157 | pairs_2_with_1.append((i, j, similarity.item())) 158 | return pairs_1_with_2, pairs_2_with_1 159 | 160 | def sbert_encoding(self): 161 | pairs = [] 162 | sentences = self.doc1_sentences + self.doc2_sentences 163 | for i, sent_prior in enumerate(sentences): 164 | for j, sent_later in enumerate(sentences): 165 | if i!=j: 166 | similarity = self.sbert.pair_sims(" ".join(sent_prior), " ".join(sent_later)) 167 | pairs.append((i, j, similarity.item())) 168 | pairs.append((j, i, similarity.item())) 169 | return pairs 170 | 171 | def peer_encoding(self, h_idx): 172 | pairs = [] 173 | sentences = self.doc1_sentences + self.doc2_sentences 174 | sent_prior = sentences[h_idx] 175 | for j, sent_later in tqdm(enumerate(sentences)): 176 | if j!=h_idx: 177 | similarity = self.sbert.pair_sims(" ".join(sent_prior), " ".join(sent_later)) 178 | pairs.append((h_idx, j, similarity.item())) 179 | else: 180 | pairs.append((h_idx, j, 0)) 181 | return pairs 182 | 183 | 184 | def generate_edges(self): 185 | Edge_2_to_1 = [] 186 | Edge_1_to_2 = [] 187 | pairs_1_with_2, pairs_2_with_1 = self.half_sbert_encoding() 188 | doc1_num = len(self.doc1_sentences) 189 | doc2_num = len(self.doc2_sentences) 190 | for s_2_idx in range(doc2_num): 191 | 
s_2_map = list((filter(lambda pair: pair[1] == s_2_idx, pairs_1_with_2))) 192 | head_idx = sorted(s_2_map, key=lambda sims: sims[2], reverse=True)[0][0] 193 | Edge_2_to_1.append((s_2_idx, '->', head_idx)) 194 | for s_1_idx in range(doc1_num): 195 | s_1_map = list((filter(lambda pair: pair[1] == s_1_idx, pairs_2_with_1))) 196 | head_idx = sorted(s_1_map, key=lambda sims: sims[2], reverse=True)[0][0] 197 | Edge_1_to_2.append((s_1_idx, '->', head_idx)) 198 | return Edge_2_to_1, Edge_1_to_2 199 | 200 | def topo_sort(self): 201 | doc1_num = len(self.doc1_sentences) 202 | doc2_num = len(self.doc2_sentences) 203 | Edge_2_to_1, Edge_1_to_2 = self.generate_edges() 204 | Edge_1_to_1 = [(i, '->', i+1) for i in range(doc1_num-1)] 205 | Edge_2_to_2 = [(i, '->', i+1) for i in range(doc2_num-1)] 206 | nvert = doc1_num + doc2_num 207 | g = Graph(nvert) 208 | for edge in Edge_2_to_1: 209 | pos_start = edge[2] 210 | pos_end = edge[0] + doc1_num 211 | g.addEdge(pos_start, pos_end, 1) 212 | for edge in Edge_1_to_2: 213 | pos_start = edge[2] + doc1_num 214 | pos_end = edge[0] 215 | g.addEdge(pos_start, pos_end, 1) 216 | for edge in Edge_1_to_1: 217 | pos_s2 = edge[0] 218 | pos_s1 = edge[2] 219 | g.addEdge(pos_s2, pos_s1, 1) 220 | for edge in Edge_2_to_2: 221 | pos_s1 = edge[0] + doc1_num 222 | pos_s2 = edge[2] + doc1_num 223 | g.addEdge(pos_s1, pos_s2, 1) 224 | while g.isCyclic(): 225 | g.isCyclic() 226 | order = g.topologicalSort() 227 | return order 228 | 229 | def all_sort(self, starts, ends): 230 | s_e_pairs = [] 231 | for start in starts: 232 | for end in ends: 233 | s_e_pairs.append((start, end)) 234 | pairs = self.sbert_encoding() 235 | chains = [] 236 | 237 | for s_e_pair in s_e_pairs: 238 | start = s_e_pair[0] 239 | end = s_e_pair[1] 240 | chain = [] 241 | chain.append(start) 242 | peers = list((filter(lambda pair: pair[0] == start and pair[1] not in chain, pairs))) 243 | next_blk = sorted(peers, key=lambda sims: sims[2], reverse=True)[0][1] 244 | while next_blk != end: 245 | chain.append(next_blk) 246 | peers = list((filter(lambda pair: pair[0] == next_blk and pair[1] not in chain, pairs))) 247 | next_blk = sorted(peers, key=lambda sims: sims[2], reverse=True)[0][1] 248 | chain.append(next_blk) 249 | chains.append(chain) 250 | return chains 251 | 252 | def dynamic_sort(self, starts, ends): 253 | s_e_pairs = [] 254 | for start in starts: 255 | for end in ends: 256 | s_e_pairs.append((start, end)) 257 | pairs = self.sbert_encoding() 258 | chains = [] 259 | 260 | for s_e_pair in s_e_pairs: 261 | start = s_e_pair[0] 262 | end = s_e_pair[1] 263 | chain = [] 264 | chain.append(start) 265 | pairs = self.peer_encoding(start) 266 | #pdb.set_trace() 267 | peers = list((filter(lambda pair: pair[0] == start and pair[1] not in chain, pairs))) 268 | next_blk = sorted(peers, key=lambda sims: sims[2], reverse=True)[0][1] 269 | while next_blk != end: 270 | chain.append(next_blk) 271 | pairs = self.peer_encoding(next_blk) 272 | peers = list((filter(lambda pair: pair[0] == next_blk and pair[1] not in chain, pairs))) 273 | next_blk = sorted(peers, key=lambda sims: sims[2], reverse=True)[0][1] 274 | chain.append(next_blk) 275 | if chain not in chains: 276 | chains.append(chain) 277 | else: 278 | continue 279 | return chains 280 | 281 | def semantic_based_sort(self, starts, ends): 282 | start = random.choice(starts) 283 | chain = [] 284 | chain.append(start) 285 | pairs = self.peer_encoding(start) 286 | peers = list((filter(lambda pair: pair[0] == start and pair[1] not in chain, pairs))) 287 | next_blk = sorted(peers, 
key=lambda sims: sims[2], reverse=True)[0][1] 288 | while next_blk not in ends: 289 | chain.append(next_blk) 290 | pairs = self.peer_encoding(next_blk) 291 | peers = list((filter(lambda pair: pair[0] == next_blk and pair[1] not in chain, pairs))) 292 | next_blk = sorted(peers, key=lambda sims: sims[2], reverse=True)[0][1] 293 | chain.append(next_blk) 294 | rest = list(set([i for i in range(len(self.sentences))]) - set(chain)) 295 | random.shuffle(rest) 296 | if len(chain) < 8: 297 | chain.extend(rest[:8-len(chain)]) 298 | else: 299 | pass 300 | return [chain] 301 | 302 | 303 | def unsemantic_based_sort(self, starts, ends): 304 | start = random.choice(starts) 305 | end = random.choice(ends) 306 | 307 | chain = [] 308 | chain.append(start) 309 | 310 | sentences = self.doc1_sentences + self.doc2_sentences 311 | s_ids = list(set([idx for idx,_ in enumerate(sentences)]).difference(set([start,end]))) 312 | random.shuffle(s_ids) 313 | others = min(6, len(s_ids)-2) 314 | chain.extend(s_ids[:others+1]) 315 | chain.append(end) 316 | return [chain] 317 | 318 | def bidirection_sort(self, starts, ends): 319 | s_e_pairs = [] 320 | for start in starts: 321 | for end in ends: 322 | s_e_pairs.append((start, end)) 323 | chains = [] 324 | 325 | for s_e_pair in s_e_pairs: 326 | start = s_e_pair[0] 327 | end = s_e_pair[1] 328 | chain_head = [] 329 | chain_tail = [] 330 | chain_head.append(start) 331 | chain_tail.append(end) 332 | 333 | pairs_head = self.peer_encoding(start) 334 | pairs_tail = self.peer_encoding(end) 335 | 336 | peers_head = list((filter(lambda pair: pair[0] == start and pair[1] not in chain_head, pairs_head))) 337 | next_blk_head = sorted(peers_head, key=lambda sims: sims[2], reverse=True)[0][1] 338 | peers_tail = list((filter(lambda pair: pair[0] == end and pair[1] not in chain_tail, pairs_tail))) 339 | next_blk_tail = sorted(peers_tail, key=lambda sims: sims[2], reverse=True)[0][1] 340 | while next_blk_head != end and len(chain_head)<4: 341 | chain_head.append(next_blk_head) 342 | pairs_head = self.peer_encoding(next_blk_head) 343 | peers_head = list((filter(lambda pair: pair[0] == next_blk_head and pair[1] not in chain_head, pairs_head))) 344 | next_blk_head = sorted(peers_head, key=lambda sims: sims[2], reverse=True)[0][1] 345 | while next_blk_tail != start and len(chain_tail)<4: 346 | chain_tail.append(next_blk_tail) 347 | pairs_tail = self.peer_encoding(next_blk_tail) 348 | peers_tail = list((filter(lambda pair: pair[0] == next_blk_tail and pair[1] not in chain_tail, pairs_tail))) 349 | next_blk_tail = sorted(peers_tail, key=lambda sims: sims[2], reverse=True)[0][1] 350 | if next_blk_head==end or next_blk_tail==start: 351 | if next_blk_head==end: 352 | chain_head.append(next_blk_head) 353 | chain = chain_head 354 | if next_blk_tail==start: 355 | chain_tail.append(next_blk_tail) 356 | chain_tail.reverse() 357 | chain = chain_tail 358 | else: 359 | chain_tail.reverse() 360 | chain = merge_chain(chain_head=chain_head, chain_tail=chain_tail) 361 | chains.append(chain) 362 | return chains 363 | 364 | def threeSent(self, starts, ends, co_occur): 365 | def consecutive_path(starts, ends): 366 | chain = [] 367 | for start in starts: 368 | chain.append(start) 369 | pairs = self.peer_encoding(start) 370 | peers = list((filter(lambda pair: pair[0] == start and pair[1] not in chain, pairs))) 371 | next_blk = sorted(peers, key=lambda sims: sims[2], reverse=True)[0][1] 372 | while next_blk not in ends and len(chain)<=2: 373 | chain.append(next_blk) 374 | pairs = self.peer_encoding(next_blk) 375 | peers 
= list((filter(lambda pair: pair[0] == next_blk and pair[1] not in chain, pairs)))
376 |                     next_blk = sorted(peers, key=lambda sims: sims[2], reverse=True)[0][1]
377 |                 chain.append(next_blk)
378 |                 if len(set(chain).intersection(set(ends))) > 0:
379 |                     break
380 |                 else:
381 |                     chain = []
382 |                     continue
383 |             return chain
384 |         def multihop_path(starts, ends, co_occur):
385 |             ori_co_occur = [i for i in co_occur]
386 |             path = []
387 |             edges_tuple = []
388 |             start_pos = {}
389 |             end_pos = {}
390 |             start_edges = list((filter(lambda co: co[0]==1 and co[2] in starts, co_occur)))
391 |             end_edges = list((filter(lambda co: co[0]==2 and co[2] in ends, co_occur)))
392 |             for s_ed in start_edges:
393 |                 start_pos[s_ed[2]]=s_ed[1]
394 |                 edges_tuple.append((1, s_ed[1]))
395 |             for e_ed in end_edges:
396 |                 end_pos[e_ed[2]]=e_ed[1]
397 |                 edges_tuple.append((2, e_ed[1]))
398 |             co_occur = list(set(co_occur).difference(set(start_edges)).difference(set(end_edges)))
399 |             if len(start_edges)>0 and len(end_edges)>0:
400 |                 next_set = []
401 |                 pre_set = []
402 |                 next_pos = {}
403 |                 pre_pos = {}
404 |                 for s_ed in start_edges:
405 |                     next_set.append(s_ed[1])
406 |                     next_pos[s_ed[2]]=s_ed[1]
407 |                     edges_tuple.append((s_ed[0], s_ed[1]))
408 |                 next_set = set(next_set)
409 |                 for e_ed in end_edges:
410 |                     pre_set.append(e_ed[1])
411 |                     pre_pos[e_ed[2]]=e_ed[1]
412 |                     edges_tuple.append((e_ed[0], e_ed[1]))
413 |                 pre_set = set(pre_set)
414 |                 while(len(next_set.intersection(pre_set))==0 and len(co_occur)>0):
415 |                     start_edges = list((filter(lambda co: co[0] in list(next_set), co_occur)))
416 |                     end_edges = list((filter(lambda co: co[0] in list(pre_set), co_occur)))
417 |                     co_occur = list(set(co_occur).difference(set(start_edges)).difference(set(end_edges)))
418 |                     next_set = list(next_set)
419 |                     pre_set = list(pre_set)
420 |                     for s_ed in start_edges:
421 |                         next_set.append(s_ed[1])
422 |                         next_pos[s_ed[2]]=s_ed[1]
423 |                         edges_tuple.append((s_ed[0], s_ed[1]))
424 |                     next_set = set(next_set)
425 |                     for e_ed in end_edges:
426 |                         pre_set.append(e_ed[1])
427 |                         pre_pos[e_ed[2]]=e_ed[1]
428 |                         edges_tuple.append((e_ed[0], e_ed[1]))
429 |                     pre_set = set(pre_set)
430 |                 entity_chain = merge_chain(list(next_set)+[1], [2]+list(pre_set))
431 |                 # print(entity_chain)
432 |                 edges_tuple = list(set(edges_tuple))
433 |                 path_edges = list((filter(lambda co: co[0] in entity_chain and co[1] in entity_chain, edges_tuple)))
434 |                 path_triplet = list((filter(lambda e: (e[0], e[1]) in path_edges, ori_co_occur)))
435 |                 path = [o[2] for o in path_triplet]
436 |             return path
437 | 
438 |         def default_path(starts, ends, co_occur, max_pos):
439 |             chain = starts + ends
440 |             if len(chain) >= 8:
441 |                 chain = [starts[0]] + [ends[0]] + random.choices(list(set(starts+ends).difference(set([starts[0]] + [ends[0]]))), k=6)
442 |             else:
443 |                 while(len(chain)<=7):
444 |                     offset = [-3,-2,-1,1,2,3]
445 |                     ex_blk = random.choices(chain,k=1)[0] + random.choices(offset, k=1)[0]
446 |                     if 0<=ex_blk<=max_pos-1 and ex_blk not in chain:
447 |                         chain.append(ex_blk)
448 |                     else:
449 |                         continue
450 |             return chain
451 | 
452 |         c_path = consecutive_path(starts, ends)
453 |         m_path = multihop_path(starts, ends, co_occur)  # returns a single path of sentence ids
454 |         d_path = default_path(starts, ends, co_occur, len(self.sentences))
455 |         if len(c_path) > 0:  # prefer the consecutive path, then multi-hop, then the fallback
456 |             path = c_path
457 |         elif len(m_path) > 0:
458 |             path = m_path
459 |         elif len(d_path) > 0:
460 |             path = d_path
461 |         else:
462 |             path = []
463 |         return [path]
464 | 
465 | def merge_chain(chain_head, chain_tail):
466 |     overlap_blk = list(set(chain_head).intersection(set(chain_tail)))
467 |     if len(overlap_blk) >= 1:
#pdb.set_trace() 468 | print(overlap_blk) 469 | for ov in overlap_blk: 470 | chain_tail.remove(ov) 471 | merged = chain_head + chain_tail 472 | return merged 473 | -------------------------------------------------------------------------------- /ecrim/trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import random 6 | import shutil 7 | import sys 8 | import pdb 9 | import apex 10 | import numpy as np 11 | import torch 12 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 13 | from torch.utils.data.distributed import DistributedSampler 14 | from torch.utils.tensorboard import SummaryWriter 15 | from tqdm import tqdm 16 | from transformers import AdamW, get_linear_schedule_with_warmup 17 | from data_helper import BlkPosInterface, SimpleListDataset 18 | import pickle 19 | class ContextError(Exception): 20 | def __init__(self): 21 | pass 22 | 23 | 24 | class Once: 25 | def __init__(self, rank): 26 | self.rank = rank 27 | 28 | def __enter__(self): 29 | if self.rank > 0: 30 | sys.settrace(lambda *args, **keys: None) 31 | frame = sys._getframe(1) 32 | frame.f_trace = self.trace 33 | return True 34 | 35 | def trace(self, frame, event, arg): 36 | raise ContextError 37 | 38 | def __exit__(self, type, value, traceback): 39 | if type == ContextError: 40 | return True 41 | else: 42 | return False 43 | 44 | 45 | class OnceBarrier: 46 | def __init__(self, rank): 47 | self.rank = rank 48 | 49 | def __enter__(self): 50 | if self.rank > 0: 51 | sys.settrace(lambda *args, **keys: None) 52 | frame = sys._getframe(1) 53 | frame.f_trace = self.trace 54 | return True 55 | 56 | def trace(self, frame, event, arg): 57 | raise ContextError 58 | 59 | def __exit__(self, type, value, traceback): 60 | if self.rank >= 0: 61 | torch.distributed.barrier() 62 | if type == ContextError: 63 | return True 64 | else: 65 | return False 66 | 67 | 68 | class Cache: 69 | def __init__(self, rank): 70 | self.rank = rank 71 | 72 | def __enter__(self): 73 | if self.rank not in [-1, 0]: 74 | torch.distributed.barrier() 75 | return True 76 | 77 | def __exit__(self, type, value, traceback): 78 | if self.rank == 0: 79 | torch.distributed.barrier() 80 | return False 81 | 82 | 83 | def set_seed(seed, n_gpu): 84 | random.seed(seed) 85 | np.random.seed(seed) 86 | torch.manual_seed(seed) 87 | if n_gpu > 0: 88 | torch.cuda.manual_seed_all(seed) 89 | 90 | 91 | class Prefetcher: 92 | def __init__(self, dataloader, stream): 93 | self.dataloader = dataloader 94 | self.stream = torch.cuda.Stream() 95 | 96 | def __iter__(self): 97 | self.iter = iter(self.dataloader) 98 | self.preload() 99 | return self 100 | 101 | def preload(self): 102 | try: 103 | self.next = next(self.iter) 104 | except StopIteration: 105 | self.next = None 106 | return 107 | with torch.cuda.stream(self.stream): 108 | next_list = list() 109 | for v in self.next: 110 | if type(v) == torch.Tensor: 111 | next_list.append(v.cuda(non_blocking=True)) 112 | else: 113 | next_list.append(v) 114 | self.next = tuple(next_list) 115 | 116 | def __next__(self): 117 | torch.cuda.current_stream().wait_stream(self.stream) 118 | if self.next is not None: 119 | result = self.next 120 | self.preload() 121 | return result 122 | else: 123 | raise StopIteration 124 | 125 | def __len__(self): 126 | return len(self.dataloader) 127 | 128 | 129 | class TrainerCallback: 130 | def __init__(self): 131 | pass 132 | 133 | def on_argument(self, parser): 134 | pass 135 | 136 | 
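    # TrainerCallback is the extension point for task-specific logic: Trainer owns
    # the optimization loop, while a subclass supplies the model, the datasets,
    # batch collation, and the per-step/per-epoch hooks below. A minimal subclass
    # sketch (the names used here are illustrative, not from this repo):
    #
    #   class MyCallback(TrainerCallback):
    #       def load_model(self):
    #           return BertModel.from_pretrained('bert-base-cased')
    #       def load_data(self):
    #           return train_set, dev_set, test_set
    #       def on_dev_epoch_end(self, epoch):
    #           return compute_f1(...)  # Trainer keeps the best-scoring step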
def load_model(self): 137 | pass 138 | 139 | def load_data(self): 140 | pass 141 | 142 | def collate_fn(self): 143 | return None, None, None 144 | 145 | def on_train_epoch_start(self, epoch): 146 | pass 147 | 148 | def on_train_step(self, step, train_step, inputs, extra, loss, outputs): 149 | pass 150 | 151 | def on_train_epoch_end(self, epoch): 152 | pass 153 | 154 | def on_dev_epoch_start(self, epoch): 155 | pass 156 | 157 | def on_dev_step(self, step, inputs, extra, outputs): 158 | pass 159 | 160 | def on_dev_epoch_end(self, epoch): 161 | pass 162 | 163 | def on_test_epoch_start(self, epoch): 164 | pass 165 | 166 | def on_test_step(self, step, inputs, extra, outputs): 167 | pass 168 | 169 | def on_test_epoch_end(self, epoch): 170 | pass 171 | 172 | def process_train_data(self, data): 173 | pass 174 | 175 | def process_dev_data(self, data): 176 | pass 177 | 178 | def process_test_data(self, data): 179 | pass 180 | 181 | def on_save(self, path): 182 | pass 183 | 184 | def on_load(self, path): 185 | pass 186 | 187 | 188 | class Trainer: 189 | def __init__(self, callback: TrainerCallback): 190 | self.callback = callback 191 | self.callback.trainer = self 192 | logging.basicConfig(level=logging.INFO) 193 | 194 | def parse_args(self): 195 | self.parser = argparse.ArgumentParser() 196 | self.parser.add_argument('--train', action='store_true') 197 | self.parser.add_argument('--dev', action='store_true') 198 | self.parser.add_argument('--test', action='store_true') 199 | self.parser.add_argument('--debug', action='store_true') 200 | self.parser.add_argument("--per_gpu_train_batch_size", default=1, type=int) 201 | self.parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int) 202 | self.parser.add_argument("--learning_rate", default=5e-5, type=float) 203 | self.parser.add_argument("--gradient_accumulation_steps", type=int, default=4) 204 | self.parser.add_argument("--weight_decay", default=0.0, type=float) 205 | self.parser.add_argument("--adam_epsilon", default=1e-8, type=float) 206 | self.parser.add_argument("--max_grad_norm", default=1.0, type=float) 207 | self.parser.add_argument("--epochs", default=10, type=int) 208 | self.parser.add_argument("--warmup_ratio", default=0.1, type=float) 209 | self.parser.add_argument("--logging_steps", type=int, default=500) 210 | self.parser.add_argument("--save_steps", type=int, default=10000) 211 | self.parser.add_argument("--seed", type=int, default=42) 212 | self.parser.add_argument("--num_workers", type=int, default=0) 213 | self.parser.add_argument("--local_rank", type=int, default=-1) 214 | self.parser.add_argument("--fp16", action="store_true") 215 | self.parser.add_argument("--fp16_opt_level", type=str, default="O1") 216 | self.parser.add_argument("--no_cuda", action="store_true") 217 | self.parser.add_argument("--load_checkpoint", default=None, type=str) 218 | self.parser.add_argument("--ignore_progress", action='store_true') 219 | self.parser.add_argument("--dataset_ratio", type=float, default=1.0) 220 | self.parser.add_argument("--no_save", action="store_true") 221 | self.parser.add_argument("--intro_save", default="../data/", type=str) 222 | #self.parser.add_argument("--model_name", default="bert", type=str) 223 | self.callback.on_argument(self.parser) 224 | self.args = self.parser.parse_args() 225 | keys = list(self.args.__dict__.keys()) 226 | for key in keys: 227 | value = getattr(self.args, key) 228 | if type(value) == str and os.path.exists(value): 229 | setattr(self.args, key, os.path.abspath(value)) 230 | if not self.args.train: 
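            # evaluation-only runs (no --train) still enter the epoch loop,
            # so cap it at a single pass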
231 | self.args.epochs = 1 232 | self.train = self.args.train 233 | self.dev = self.args.dev 234 | self.test = self.args.test 235 | self.debug = self.args.debug 236 | self.per_gpu_train_batch_size = self.args.per_gpu_train_batch_size 237 | self.per_gpu_eval_batch_size = self.args.per_gpu_eval_batch_size 238 | self.learning_rate = self.args.learning_rate 239 | self.gradient_accumulation_steps = self.args.gradient_accumulation_steps 240 | self.weight_decay = self.args.weight_decay 241 | self.adam_epsilon = self.args.adam_epsilon 242 | self.max_grad_norm = self.args.max_grad_norm 243 | self.epochs = self.args.epochs 244 | self.warmup_ratio = self.args.warmup_ratio 245 | self.logging_steps = self.args.logging_steps 246 | self.save_steps = self.args.save_steps 247 | self.seed = self.args.seed 248 | self.num_workers = self.args.num_workers 249 | self.local_rank = self.args.local_rank 250 | self.fp16 = self.args.fp16 251 | self.fp16_opt_level = self.args.fp16_opt_level 252 | self.no_cuda = self.args.no_cuda 253 | self.load_checkpoint = self.args.load_checkpoint 254 | self.ignore_progress = self.args.ignore_progress 255 | self.dataset_ratio = self.args.dataset_ratio 256 | self.no_save = self.args.no_save 257 | self.callback.args = self.args 258 | self.model_name = self.args.model_name 259 | self.intro_save = self.args.intro_save 260 | 261 | def set_env(self): 262 | if self.debug: 263 | sys.excepthook = IPython.core.ultratb.FormattedTB(mode='Verbose', color_scheme='Linux', call_pdb=1) 264 | if self.local_rank == -1 or self.no_cuda: 265 | device = torch.device("cuda" if torch.cuda.is_available() and not self.no_cuda else "cpu") 266 | self.n_gpu = 0 if self.no_cuda else torch.cuda.device_count() 267 | else: 268 | torch.cuda.set_device(self.local_rank) 269 | device = torch.device("cuda", self.local_rank) 270 | torch.distributed.init_process_group(backend="nccl") 271 | self.n_gpu = 1 272 | set_seed(self.seed, self.n_gpu) 273 | self.device = device 274 | with self.once_barrier(): 275 | if not os.path.exists('r'): 276 | os.mkdir('r') 277 | runs = os.listdir('r') 278 | i = max([int(c) for c in runs], default=-1) + 1 279 | os.mkdir(os.path.join('r', str(i))) 280 | src_names = [source for source in os.listdir() if source.endswith('.py')] 281 | for src_name in src_names: 282 | shutil.copy(src_name, os.path.join('r', str(i), src_name)) 283 | os.mkdir(os.path.join('r', str(i), 'output')) 284 | os.mkdir(os.path.join('r', str(i), 'tmp')) 285 | runs = os.listdir('r') 286 | i = max([int(c) for c in runs]) 287 | os.chdir(os.path.join('r', str(i))) 288 | with self.once_barrier(): 289 | json.dump(sys.argv, open('output/args.json', 'w')) 290 | logging.info("Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, 16-bits training: {}".format(self.local_rank, device, self.n_gpu, bool(self.local_rank != -1), self.fp16)) 291 | #self.train_batch_size = self.per_gpu_train_batch_size 292 | #self.eval_batch_size = self.per_gpu_eval_batch_size 293 | self.train_batch_size = self.per_gpu_train_batch_size * max(1, self.n_gpu) 294 | self.eval_batch_size = self.per_gpu_eval_batch_size * max(1, self.n_gpu) 295 | 296 | if self.fp16: 297 | apex.amp.register_half_function(torch, "einsum") 298 | self.stream = torch.cuda.Stream() 299 | 300 | def set_model(self): 301 | self.model= self.callback.load_model() 302 | self.model.to(self.device) 303 | no_decay = ["bias", "LayerNorm.weight"] 304 | optimizer_grouped_parameters = [ 305 | {"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in 
no_decay)],"weight_decay": self.weight_decay}, 306 | {"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 307 | ] 308 | self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=self.adam_epsilon) 309 | if self.fp16: 310 | self.model, self.optimizer = apex.amp.initialize(self.model, self.optimizer, opt_level=self.fp16_opt_level) 311 | if self.n_gpu > 1: 312 | self.model = torch.nn.DataParallel(self.model) 313 | if self.local_rank != -1: 314 | self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.local_rank], output_device=self.local_rank, find_unused_parameters=True) 315 | 316 | 317 | def once(self): 318 | return Once(self.local_rank) 319 | 320 | def once_barrier(self): 321 | return OnceBarrier(self.local_rank) 322 | 323 | def cache(self): 324 | return Cache(self.local_rank) 325 | 326 | def load_data(self): 327 | self.train_step = 1 328 | self.epochs_trained = 0 329 | self.steps_trained_in_current_epoch = 0 330 | self.intro_train_step = 1 331 | train_dataset, dev_dataset, test_dataset = self.callback.load_data() 332 | #train_dataset, dev_dataset = self.callback.load_data() 333 | train_fn, dev_fn, test_fn = self.callback.collate_fn() 334 | if train_dataset: 335 | if self.dataset_ratio < 1: 336 | train_dataset = torch.utils.data.Subset(train_dataset, list(range(int(len(train_dataset) * self.dataset_ratio)))) 337 | self.train_dataset = train_dataset 338 | self.train_sampler = RandomSampler(self.train_dataset) if self.local_rank == -1 else DistributedSampler(self.train_dataset) 339 | self.train_dataloader = Prefetcher(DataLoader(self.train_dataset, sampler=self.train_sampler, batch_size=self.train_batch_size, collate_fn=train_fn, num_workers=self.num_workers), self.stream) 340 | self.t_total = len(self.train_dataloader) // self.gradient_accumulation_steps * self.epochs 341 | self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=int(self.t_total * self.warmup_ratio), num_training_steps=self.t_total) 342 | if dev_dataset: 343 | if self.dataset_ratio < 1: 344 | dev_dataset = torch.utils.data.Subset(dev_dataset, list(range(int(len(dev_dataset) * self.dataset_ratio)))) 345 | self.dev_dataset = dev_dataset 346 | self.dev_sampler = SequentialSampler(self.dev_dataset) if self.local_rank == -1 else DistributedSampler(self.dev_dataset) 347 | self.dev_dataloader = Prefetcher(DataLoader(self.dev_dataset, sampler=self.dev_sampler, batch_size=self.eval_batch_size, collate_fn=dev_fn, num_workers=self.num_workers), self.stream) 348 | if test_dataset: 349 | if self.dataset_ratio < 1: 350 | test_dataset = torch.utils.data.Subset(test_dataset, list(range(int(len(test_dataset) * self.dataset_ratio)))) 351 | self.test_dataset = test_dataset 352 | self.test_sampler = SequentialSampler(self.test_dataset) if self.local_rank == -1 else DistributedSampler(self.test_dataset) 353 | self.test_dataloader = Prefetcher(DataLoader(self.test_dataset, sampler=self.test_sampler, batch_size=self.eval_batch_size, collate_fn=test_fn, num_workers=self.num_workers), self.stream) 354 | 355 | 356 | def restore_checkpoint(self, path, ignore_progress=False): 357 | if self.no_save: 358 | return 359 | model_to_load = self.model.module if hasattr(self.model, "module") else self.model 360 | model_to_load.load_state_dict(torch.load(os.path.join(path, 'pytorch_model.bin'), map_location=self.device)) 361 | self.optimizer.load_state_dict(torch.load(os.path.join(path, "optimizer.pt"), 
map_location=self.device)) 362 | self.scheduler.load_state_dict(torch.load(os.path.join(path, "scheduler.pt"), map_location=self.device)) 363 | self.callback.on_load(path) 364 | if not ignore_progress: 365 | self.train_step = int(path.split("-")[-1]) 366 | self.epochs_trained = self.train_step // (len(self.train_dataloader) // self.gradient_accumulation_steps) 367 | self.steps_trained_in_current_epoch = self.train_step % (len(self.train_dataloader) // self.gradient_accumulation_steps) 368 | logging.info(" Continuing training from checkpoint, will skip to saved train_step") 369 | logging.info(" Continuing training from epoch %d", self.epochs_trained) 370 | logging.info(" Continuing training from train step %d", self.train_step) 371 | logging.info(" Will skip the first %d steps in the first epoch", self.steps_trained_in_current_epoch) 372 | 373 | def save_checkpoint(self): 374 | if self.no_save: 375 | return 376 | output_dir = os.path.join('output', "checkpoint-{}".format(self.train_step)) 377 | if not os.path.exists(output_dir): 378 | os.mkdir(output_dir) 379 | model_to_save = self.model.module if hasattr(self.model, "module") else self.model 380 | torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch_model.bin')) 381 | torch.save(self.args, os.path.join(output_dir, "training_args.bin")) 382 | torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 383 | torch.save(self.scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 384 | self.callback.on_save(output_dir) 385 | 386 | def run(self): 387 | self.parse_args() 388 | self.set_env() 389 | with self.once(): 390 | self.writer = SummaryWriter() 391 | self.set_model() 392 | self.load_data() 393 | if self.load_checkpoint is not None: 394 | self.restore_checkpoint(self.load_checkpoint, self.ignore_progress) 395 | best_performance = 0 396 | best_step = -1 397 | for epoch in range(self.epochs): 398 | if epoch < self.epochs_trained: 399 | continue 400 | with self.once(): 401 | logging.info('epoch %d', epoch) 402 | if self.train: 403 | tr_loss, logging_loss = 0.0, 0.0 404 | self.model.zero_grad() 405 | self.model.train() 406 | self.callback.on_train_epoch_start(epoch) 407 | if self.local_rank >= 0: 408 | self.train_sampler.set_epoch(epoch) 409 | print("==========Training==========") 410 | for step, batch in enumerate(tqdm(self.train_dataloader, disable=self.local_rank > 0)): 411 | if step < self.steps_trained_in_current_epoch: 412 | continue 413 | extra, selected_inputs, selected_rets = self.callback.process_train_data(batch) 414 | outputs = self.model(**selected_inputs) 415 | #print(prof.key_averages().table(sort_by="cuda_time_total")) 416 | #prof.export_chrome_trace('./codred_profile.json') 417 | loss = outputs[0] 418 | if step%500 ==0: 419 | print(loss) 420 | if self.n_gpu > 1: 421 | loss = loss.mean() 422 | if self.gradient_accumulation_steps > 1: 423 | loss = loss / self.gradient_accumulation_steps 424 | if self.local_rank < 0 or (step + 1) % self.gradient_accumulation_steps == 0: 425 | if self.fp16: 426 | with apex.amp.scale_loss(loss, self.optimizer) as scaled_loss: 427 | scaled_loss.backward() 428 | else: 429 | loss.backward() 430 | else: 431 | with self.model.no_sync(): 432 | if self.fp16: 433 | with apex.amp.scale_loss(loss, self.optimizer) as scaled_loss: 434 | scaled_loss.backward() 435 | else: 436 | loss.backward() 437 | tr_loss += loss.item() 438 | if (step + 1) % self.gradient_accumulation_steps == 0: 439 | if self.fp16: 440 | 
torch.nn.utils.clip_grad_norm_(apex.amp.master_params(self.optimizer), self.max_grad_norm)  # clip the fp32 master weights apex maintains, not the fp16 copies
441 |                             else:
442 |                                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
443 |                             self.optimizer.step()
444 |                             self.scheduler.step()
445 |                             self.model.zero_grad()
446 |                             self.train_step += 1
447 |                             with self.once():
448 |                                 if self.train_step % self.logging_steps == 0:
449 |                                     self.writer.add_scalar("lr", self.scheduler.get_lr()[0], self.train_step)
450 |                                     self.writer.add_scalar("loss", (tr_loss - logging_loss) / self.logging_steps, self.train_step)
451 |                                     logging_loss = tr_loss
452 |                                 if self.train_step % self.save_steps == 0:
453 |                                     self.save_checkpoint()
454 |                         #torch.cuda.empty_cache()
455 |                         self.callback.on_train_step(step, self.train_step, selected_inputs, extra, loss.item(), outputs)
456 |                 with self.once():
457 |                     self.save_checkpoint()
458 |                 self.callback.on_train_epoch_end(epoch)
459 |             if self.dev:
460 |                 with torch.no_grad():
461 |                     self.model.eval()
462 |                     self.callback.on_dev_epoch_start(epoch)
463 |                     for step, batch in enumerate(tqdm(self.dev_dataloader, disable=self.local_rank > 0)):
464 |                         extra, selected_inputs, selected_rets = self.callback.process_dev_data(batch)
465 |                         #with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=True, profile_memory=True, with_stack=True, with_modules=True) as prof:
466 |                         outputs = self.model(**selected_inputs)
467 |                         #print(prof.key_averages().table(sort_by="cuda_time_total"))
468 |                         #prof.export_chrome_trace('./codred_profile.json')
469 |                         self.callback.on_dev_step(step, selected_inputs, extra, outputs)
470 |                     performance = self.callback.on_dev_epoch_end(epoch)
471 |                     if performance > best_performance:
472 |                         best_performance = performance
473 |                         best_step = self.train_step
474 |             # a second, verbatim copy of the dev-evaluation block above ran the
475 |             # dev pass twice per epoch; it is disabled here:
476 |             # if self.dev:
477 |             #     with torch.no_grad():
478 |             #         self.model.eval()
479 |             #         self.callback.on_dev_epoch_start(epoch)
480 |             #         for step, batch in enumerate(tqdm(self.dev_dataloader, disable=self.local_rank > 0)):
481 |             #             extra, selected_inputs, selected_rets = self.callback.process_dev_data(batch)
482 |             #             outputs = self.model(**selected_inputs)
483 |             #             self.callback.on_dev_step(step, selected_inputs, extra, outputs)
484 |             #         performance = self.callback.on_dev_epoch_end(epoch)
485 |             #         if performance > best_performance:
486 |             #             best_performance = performance
487 |             #             best_step = self.train_step
488 | 
489 |             if self.test:
490 |                 with torch.no_grad():
491 |                     if best_step > 0 and self.train:
492 |                         self.restore_checkpoint(os.path.join('output', "checkpoint-{}".format(best_step)))
493 |                     self.model.eval()
494 |                     self.callback.on_test_epoch_start(epoch)
495 |                     for step, batch in enumerate(tqdm(self.test_dataloader, disable=self.local_rank > 0)):
496 |                         extra, selected_inputs, selected_rets = self.callback.process_test_data(batch)
497 |                         outputs = self.model(**selected_inputs)
498 |                         self.callback.on_test_step(step, selected_inputs, extra, outputs)
499 |                 self.callback.on_test_epoch_end(epoch)
500 |         with self.once():
501 |             self.writer.close()
502 |             json.dump(True, open('output/f.json', 'w'))
503 | 
504 |     def distributed_broadcast(self, l):
505 |         assert type(l) == list or type(l) == dict
506 |         if self.local_rank < 0:
507 |             return l
508 |         else:
509 |             torch.distributed.barrier()
510 |             process_number = torch.distributed.get_world_size()
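            # file-based all-gather: each rank writes its shard to tmp/<rank>.json,
            # barriers until every rank has written, then reads all shards back and
            # concatenates lists / unions dicts (keys must be disjoint across ranks)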
511 |             json.dump(l, open(f'tmp/{self.local_rank}.json', 'w'))
512 |             torch.distributed.barrier()
513 |             objs = list()
514 |             for i in range(process_number):
515 |                 objs.append(json.load(open(f'tmp/{i}.json')))
516 |             if type(objs[0]) == list:
517 |                 ret = list()
518 |                 for i in range(process_number):
519 |                     ret.extend(objs[i])
520 |             else:
521 |                 ret = dict()
522 |                 for i in range(process_number):
523 |                     for k, v in objs[i].items():
524 |                         assert k not in ret
525 |                         ret[k] = v
526 |             torch.distributed.barrier()
527 |             return ret
528 | 
529 |     def distributed_merge(self, l):
530 |         assert type(l) == list or type(l) == dict
531 |         if self.local_rank < 0:
532 |             return l
533 |         else:
534 |             torch.distributed.barrier()
535 |             process_number = torch.distributed.get_world_size()
536 |             json.dump(l, open(f'tmp/{self.local_rank}.json', 'w'))
537 |             torch.distributed.barrier()
538 |             if self.local_rank == 0:
539 |                 objs = list()
540 |                 for i in range(process_number):
541 |                     objs.append(json.load(open(f'tmp/{i}.json')))
542 |                 if type(objs[0]) == list:
543 |                     ret = list()
544 |                     for i in range(process_number):
545 |                         ret.extend(objs[i])
546 |                 else:
547 |                     ret = dict()
548 |                     for i in range(process_number):
549 |                         for k, v in objs[i].items():
550 |                             assert k not in ret
551 |                             ret[k] = v
552 |             else:
553 |                 ret = None
554 |             torch.distributed.barrier()
555 |             return ret
556 | 
557 |     def distributed_get(self, v):
558 |         if self.local_rank < 0:
559 |             return v
560 |         else:
561 |             torch.distributed.barrier()
562 |             if self.local_rank == 0:
563 |                 json.dump(v, open('tmp/v.json', 'w'))
564 |             torch.distributed.barrier()
565 |             v = json.load(open('tmp/v.json'))
566 |             torch.distributed.barrier()
567 |             return v
568 | 
569 |     def _write_estimation(self, buf, relevance_blk, f):
570 |         for i, blk in enumerate(buf):
571 |             f.write(f'{blk.pos} {relevance_blk[i].item()}\n')
572 | 
573 |     def _score_blocks(self, qbuf, relevance_token):
574 |         ends = qbuf.block_ends()
575 |         relevance_blk = torch.ones(len(ends), device='cpu')
576 |         for i in range(len(ends)):
577 |             if qbuf[i].blk_type > 0:  # skip the query block (blk_type == 0); score document blocks
578 |                 relevance_blk[i] = (relevance_token[ends[i-1]:ends[i]]).mean()
579 |         return relevance_blk
580 | 
581 |     def _collect_estimations_from_dir(self, est_dir):
582 |         ret = {}
583 |         for shortname in os.listdir(est_dir):
584 |             filename = os.path.join(est_dir, shortname)
585 |             if shortname.startswith('estimations_'):
586 |                 with open(filename, 'r') as fin:
587 |                     for line in fin:
588 |                         l = line.split()
589 |                         pos, estimation = int(l[0]), float(l[1])
590 |                         ret[pos] = estimation  # map block position -> estimated relevance
591 |                 os.replace(filename, os.path.join(est_dir, 'backup_' + shortname))
592 |         return ret
--------------------------------------------------------------------------------
/ecrim/main_simp.py:
--------------------------------------------------------------------------------
2 | import json
3 | import random
4 | from functools import partial
5 | import pdb
7 | import numpy as np
8 | import redis
9 | import sklearn.metrics
10 | import torch
11 | from eveliver import (Logger, load_model, tensor_to_obj)
12 | from trainer import Trainer, TrainerCallback
13 | from transformers import AutoTokenizer, BertModel
14 | from matrix_transformer import Encoder as MatTransformer
15 | from graph_encoder import Encoder as GraphEncoder
16 | from torch import nn
17 | import torch.nn.functional as F
18 | import os
19 | from tqdm import tqdm
20 | from buffer import Buffer
21 | from utils import CAPACITY, BLOCK_SIZE, DEFAULT_MODEL_NAME, contrastive_pair, check_htb_debug, complete_h_t_debug
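# CAPACITY and BLOCK_SIZE appear to define a CogLTX-style block budget: a selected
# buffer can hold at most CAPACITY // (BLOCK_SIZE + 1) blocks, which is exactly the
# max_blk_num computed in `process` below.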
22 | from utils import complete_h_t, check_htb
23 | from utils import CLS_TOKEN_ID, SEP_TOKEN_ID, H_START_MARKER_ID, H_END_MARKER_ID, T_END_MARKER_ID, T_START_MARKER_ID
24 | import math
25 | from torch.nn import CrossEntropyLoss
26 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1"
27 | from itertools import groupby
28 | from pyg_graph import create_edges, create_graph, GCN, Attention, create_graph_single
29 | from utils import DotProductSimilarity
30 | from sentence_reordering import SentReOrdering
31 | from sbert_wk import sbert
32 | from itertools import product, combinations
33 | def eval_performance(facts, pred_result):
34 |     sorted_pred_result = sorted(pred_result, key=lambda x: x['score'], reverse=True)
35 |     prec = []
36 |     rec = []
37 |     correct = 0
38 |     total = len(facts)
39 |     #pdb.set_trace()
40 |     for i, item in enumerate(sorted_pred_result):
41 |         if (item['entpair'][0], item['entpair'][1], item['relation']) in facts:
42 |             correct += 1
43 |         prec.append(float(correct) / float(i + 1))
44 |         rec.append(float(correct) / float(total))
45 |     auc = sklearn.metrics.auc(x=rec, y=prec)
46 |     np_prec = np.array(prec)
47 |     np_rec = np.array(rec)
48 |     f1 = (2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).max()
49 |     mean_prec = np_prec.mean()
50 |     return {'prec': np_prec.tolist(), 'rec': np_rec.tolist(), 'mean_prec': mean_prec, 'f1': f1, 'auc': auc}
51 | 
52 | def expand(start, end, total_len, max_size):
53 |     e_size = max_size - (end - start)
54 |     _1 = start - (e_size // 2)
55 |     _2 = end + (e_size - e_size // 2)
56 |     if _2 - _1 <= total_len:
57 |         if _1 < 0:
58 |             _2 -= _1  # _1 is negative here: shift the window right by the overshoot
59 |             _1 = 0
60 |         elif _2 > total_len:
61 |             _1 -= (_2 - total_len)
62 |             _2 = total_len
63 |     else:
64 |         _1 = 0
65 |         _2 = total_len
66 |     return _1, _2
67 | 
68 | 
69 | def place_train_data(dataset):
70 |     ep2d = dict()
71 |     for key, doc1, doc2, label in dataset:
72 |         if key not in ep2d:
73 |             ep2d[key] = dict()
74 |         if label not in ep2d[key]:
75 |             ep2d[key][label] = list()
76 |         ep2d[key][label].append([doc1, doc2, label])
77 |     bags = list()
78 |     for key, l2docs in ep2d.items():
79 |         if len(l2docs) == 1 and 'n/a' in l2docs:
80 |             bags.append([key, 'n/a', l2docs['n/a'], 'o'])
81 |         else:
82 |             labels = list(l2docs.keys())
83 |             for label in labels:
84 |                 if label != 'n/a':
85 |                     ds = l2docs[label]
86 |                     if 'n/a' in l2docs:
87 |                         ds.extend(l2docs['n/a'])
88 |                     bags.append([key, label, ds, 'o'])
89 |     bags.sort(key=lambda x: x[0] + '#' + x[1])
90 |     return bags
91 | 
92 | 
93 | def place_dev_data(dataset, single_path):
94 |     ep2d = dict()
95 |     for key, doc1, doc2, label in dataset:
96 |         if key not in ep2d:
97 |             ep2d[key] = dict()
98 |         if label not in ep2d[key]:
99 |             ep2d[key][label] = list()
100 |         ep2d[key][label].append([doc1, doc2, label])
101 |     bags = list()
102 |     for key, l2docs in ep2d.items():
103 |         if len(l2docs) == 1 and 'n/a' in l2docs:
104 |             bags.append([key, ['n/a'], l2docs['n/a'], 'o'])
105 |         else:
106 |             labels = list(l2docs.keys())
107 |             ds = list()
108 |             for label in labels:
109 |                 if single_path and label != 'n/a':
110 |                     ds.append(random.choice(l2docs[label]))
111 |                 else:
112 |                     ds.extend(l2docs[label])
113 |             if 'n/a' in labels:
114 |                 labels.remove('n/a')
115 |             bags.append([key, labels, ds, 'o'])
116 |     bags.sort(key=lambda x: x[0] + '#' + '#'.join(x[1]))
117 |     return bags
118 | 
119 | def place_test_data(dataset, single_path):
120 |     ep2d = dict()
121 |     for data in dataset:
122 |         key = data['h_id'] + '#' + data['t_id']
123 |         doc1 = data['doc'][0]
124 |         doc2 = data['doc'][1]
125 |         label = 'n/a'
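# For orientation, place_train_data / place_dev_data above group raw
# [key, doc1, doc2, label] rows into bags keyed by entity pair. A toy run of
# place_train_data with hypothetical IDs:
#
#   dataset = [['Q1#Q2', 'docA', 'docB', 'P17'],
#              ['Q1#Q2', 'docC', 'docD', 'n/a']]
#   place_train_data(dataset)
#   # -> [['Q1#Q2', 'P17', [['docA', 'docB', 'P17'], ['docC', 'docD', 'n/a']], 'o']]
#
# i.e. 'n/a' document pairs ride along inside the positive bag, while entity pairs
# whose only label is 'n/a' form their own negative bag.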
if key not in ep2d: 127 | ep2d[key] = dict() 128 | if label not in ep2d[key]: 129 | ep2d[key][label] = list() 130 | ep2d[key][label].append([doc1, doc2, label]) 131 | bags = list() 132 | for key, l2docs in ep2d.items(): 133 | if len(l2docs) == 1 and 'n/a' in l2docs: 134 | bags.append([key, ['n/a'], l2docs['n/a'], 'o']) 135 | else: 136 | labels = list(l2docs.keys()) 137 | ds = list() 138 | for label in labels: 139 | if single_path and label != 'n/a': 140 | ds.append(random.choice(l2docs[label])) 141 | else: 142 | ds.extend(l2docs[label]) 143 | if 'n/a' in labels: 144 | labels.remove('n/a') 145 | bags.append([key, labels, ds, 'o']) 146 | bags.sort(key=lambda x: x[0] + '#' + '#'.join(x[1])) 147 | return bags 148 | 149 | 150 | def gen_c(tokenizer, passage, span, max_len, bound_tokens, d_start, d_end, no_additional_marker, mask_entity): 151 | ret = list() 152 | ret.append(bound_tokens[0]) 153 | for i in range(span[0], span[1]): 154 | if mask_entity: 155 | ret.append('[MASK]') 156 | else: 157 | ret.append(passage[i]) 158 | ret.append(bound_tokens[1]) 159 | prev = list() 160 | prev_ptr = span[0] - 1 161 | while len(prev) < max_len: 162 | if prev_ptr < 0: 163 | break 164 | if not no_additional_marker and prev_ptr in d_end: 165 | prev.append(f'[unused{(d_end[prev_ptr] + 2) * 2 + 2}]') 166 | prev.append(passage[prev_ptr]) 167 | if not no_additional_marker and prev_ptr in d_start: 168 | prev.append(f'[unused{(d_start[prev_ptr] + 2) * 2 + 1}]') 169 | prev_ptr -= 1 170 | nex = list() 171 | nex_ptr = span[1] 172 | while len(nex) < max_len: 173 | if nex_ptr >= len(passage): 174 | break 175 | if not no_additional_marker and nex_ptr in d_start: 176 | nex.append(f'[unused{(d_start[nex_ptr] + 2) * 2 + 1}]') 177 | nex.append(passage[nex_ptr]) 178 | if not no_additional_marker and nex_ptr in d_end: 179 | nex.append(f'[unused{(d_end[nex_ptr] + 2) * 2 + 2}]') 180 | nex_ptr += 1 181 | prev.reverse() 182 | ret = prev + ret + nex 183 | return ret 184 | 185 | def process(tokenizer, h, t, doc0, doc1): 186 | 187 | ht_markers = ["[unused" + str(i) + "]" for i in range(1, 5)] 188 | b_markers = ["[unused" + str(i) + "]" for i in range(5, 101)] 189 | max_blk_num = CAPACITY // (BLOCK_SIZE + 1) 190 | cnt, batches = 0, [] 191 | d = [] 192 | 193 | def fix_entity(doc, ht_markers, b_markers): 194 | markers = ht_markers + b_markers 195 | markers_pos = [] 196 | if list(set(doc).intersection(set(markers))): 197 | for marker in markers: 198 | try: 199 | pos = doc.index(marker) 200 | markers_pos.append((pos, marker)) 201 | except ValueError as e: 202 | continue 203 | 204 | idx = 0 205 | while idx <= len(markers_pos)-1: 206 | try: 207 | assert (int(markers_pos[idx][1].replace("[unused", "").replace("]", "")) % 2 == 1) and (int(markers_pos[idx][1].replace("[unused", "").replace("]", "")) - int(markers_pos[idx+1][1].replace("[unused", "").replace("]", "")) == -1) 208 | entity_name = doc[markers_pos[idx][0]+1: markers_pos[idx + 1][0]] 209 | while "." in entity_name: 210 | assert doc[markers_pos[idx][0] + entity_name.index(".") + 1] == "." 
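# The '.' -> '|' rewrite below masks periods that fall inside a marked entity name, so
# that the block splitter in Buffer.split_document_into_blocks cannot cut through a
# mention; collate_fn swaps the '|' tokens back to '.' after block selection, e.g.
#   ['[unused1]', 'J', '.', 'K', '.', 'Rowling', '[unused2]']
#   -> ['[unused1]', 'J', '|', 'K', '|', 'Rowling', '[unused2]']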
211 | doc[markers_pos[idx][0] + entity_name.index(".") + 1] = "|" 212 | entity_name = doc[markers_pos[idx][0]+1: markers_pos[idx + 1][0]] 213 | idx += 2 214 | except: 215 | #pdb.set_trace() 216 | idx += 1 217 | continue 218 | return doc 219 | 220 | d0 = fix_entity(doc0, ht_markers, b_markers) 221 | d1 = fix_entity(doc1, ht_markers, b_markers) 222 | 223 | for di in [d0, d1]: 224 | d.extend(di) 225 | d0_buf, cnt = Buffer.split_document_into_blocks(d0, tokenizer, cnt=cnt, hard=False, docid=0) 226 | d1_buf, cnt = Buffer.split_document_into_blocks(d1, tokenizer, cnt=cnt, hard=False, docid=1) 227 | dbuf = Buffer() 228 | dbuf.blocks = d0_buf.blocks + d1_buf.blocks 229 | for blk in dbuf: 230 | if list(set(tokenizer.convert_tokens_to_ids(ht_markers)).intersection(set(blk.ids))): 231 | blk.relevance = 2 232 | elif list(set(tokenizer.convert_tokens_to_ids(b_markers)).intersection(set(blk.ids))): 233 | blk.relevance = 1 234 | else: 235 | continue 236 | ret = [] 237 | 238 | n0 = 1 239 | pbuf_ht, nbuf_ht = dbuf.filtered(lambda blk, idx: blk.relevance >= 2, need_residue=True) 240 | pbuf_b, nbuf_b = nbuf_ht.filtered(lambda blk, idx: blk.relevance >= 1, need_residue=True) 241 | 242 | for i in range(n0): 243 | _selected_htblks = random.sample(pbuf_ht.blocks, min(max_blk_num, len(pbuf_ht))) 244 | _selected_pblks = random.sample(pbuf_b.blocks, min(max_blk_num - len(_selected_htblks), len(pbuf_b))) 245 | _selected_nblks = random.sample(nbuf_b.blocks, min(max_blk_num - len(_selected_pblks) - len(_selected_htblks), len(nbuf_b))) 246 | buf = Buffer() 247 | buf.blocks = _selected_htblks + _selected_pblks + _selected_nblks 248 | ret.append(buf.sort_()) 249 | ret[0][0].ids.insert(0, tokenizer.convert_tokens_to_ids(tokenizer.cls_token)) 250 | return ret[0] 251 | 252 | 253 | def process_example_simp(h, t, doc1, doc2, tokenizer, max_len, redisd, no_additional_marker, mask_entity): 254 | max_len = 99999 255 | bert_max_len = 512 256 | doc1 = json.loads(redisd.get('codred-doc-' + doc1)) 257 | doc2 = json.loads(redisd.get('codred-doc-' + doc2)) 258 | v_h = None 259 | for entity in doc1['entities']: 260 | if 'Q' in entity and 'Q' + str(entity['Q']) == h and v_h is None: 261 | v_h = entity 262 | assert v_h is not None 263 | v_t = None 264 | for entity in doc2['entities']: 265 | if 'Q' in entity and 'Q' + str(entity['Q']) == t and v_t is None: 266 | v_t = entity 267 | assert v_t is not None 268 | d1_v = dict() 269 | for entity in doc1['entities']: 270 | if 'Q' in entity: 271 | d1_v[entity['Q']] = entity 272 | d2_v = dict() 273 | for entity in doc2['entities']: 274 | if 'Q' in entity: 275 | d2_v[entity['Q']] = entity 276 | ov = set(d1_v.keys()) & set(d2_v.keys()) 277 | if len(ov) > 40: 278 | ov = set(random.choices(list(ov), k=40)) 279 | ov = list(ov) 280 | ma = dict() 281 | for e in ov: 282 | ma[e] = len(ma) 283 | d1_start = dict() 284 | d1_end = dict() 285 | for entity in doc1['entities']: 286 | if 'Q' in entity and entity['Q'] in ma: 287 | for span in entity['spans']: 288 | d1_start[span[0]] = ma[entity['Q']] 289 | d1_end[span[1] - 1] = ma[entity['Q']] 290 | d2_start = dict() 291 | d2_end = dict() 292 | for entity in doc2['entities']: 293 | if 'Q' in entity and entity['Q'] in ma: 294 | for span in entity['spans']: 295 | d2_start[span[0]] = ma[entity['Q']] 296 | d2_end[span[1] - 1] = ma[entity['Q']] 297 | k1 = gen_c(tokenizer, doc1['tokens'], v_h['spans'][0], max_len, ['[unused1]', '[unused2]'], d1_start, d1_end, no_additional_marker, mask_entity) 298 | k2 = gen_c(tokenizer, doc2['tokens'], v_t['spans'][0], max_len, 
['[unused3]', '[unused4]'], d2_start, d2_end, no_additional_marker, mask_entity) 299 | 300 | #pdb.set_trace() 301 | selected_rets = process(tokenizer, v_h['name'], v_t['name'], k1, k2) 302 | 303 | return selected_rets 304 | 305 | 306 | def collate_fn(batch, args, relation2id, tokenizer, redisd, encoder, sbert_wk): 307 | #assert len(batch) == 1 308 | if batch[0][-1] == 'o': 309 | batch = batch[0] 310 | h, t = batch[0].split('#') 311 | r = relation2id[batch[1]] 312 | dps = batch[2] 313 | if len(dps) > 8: 314 | dps = random.choices(dps, k=8) 315 | dplabel = list() 316 | selected_rets = list() 317 | for doc1, doc2, l in dps: 318 | 319 | selected_ret = process_example_simp(h, t, doc1, doc2, tokenizer, args.seq_len, redisd, args.no_additional_marker, args.mask_entity) 320 | 321 | for s_blk in selected_ret: 322 | while(tokenizer.convert_tokens_to_ids("|") in s_blk.ids): 323 | s_blk.ids[s_blk.ids.index(tokenizer.convert_tokens_to_ids("|"))] = tokenizer.convert_tokens_to_ids(".") 324 | dplabel.append(relation2id[l]) 325 | selected_rets.append(selected_ret) 326 | dplabel_t = torch.tensor(dplabel, dtype=torch.int64) 327 | rs_t = torch.tensor([r], dtype=torch.int64) 328 | 329 | 330 | 331 | selected_inputs = torch.zeros(4, len(dps), CAPACITY, dtype=torch.int64) 332 | for dp, buf in enumerate(selected_rets): 333 | buf.export_01_turn(out=(selected_inputs[0, dp], selected_inputs[1, dp], selected_inputs[2, dp])) 334 | 335 | selected_ids = selected_inputs[0] 336 | selected_att_mask = selected_inputs[1] 337 | selected_token_type = selected_inputs[2] 338 | selected_labels = selected_inputs[3] 339 | 340 | else: 341 | examples = batch[0] 342 | h_len = tokenizer.max_len_sentences_pair // 2 - 2 343 | t_len = tokenizer.max_len_sentences_pair - tokenizer.max_len_sentences_pair // 2 - 2 344 | _input_ids = list() 345 | _token_type_ids = list() 346 | _attention_mask = list() 347 | _rs = list() 348 | selected_rets = list() 349 | for idx, example in enumerate(examples): 350 | doc = json.loads(redisd.get(f'dsre-doc-{example[0]}')) 351 | _, h_start, h_end, t_start, t_end, r = example 352 | if r in relation2id: 353 | r = relation2id[r] 354 | else: 355 | r = 'n/a' 356 | h_1, h_2 = expand(h_start, h_end, len(doc), h_len) 357 | t_1, t_2 = expand(t_start, t_end, len(doc), t_len) 358 | h_tokens = doc[h_1:h_start] + ['[unused1]'] + doc[h_start:h_end] + ['[unused2]'] + doc[h_end:h_2] 359 | t_tokens = doc[t_1:t_start] + ['[unused3]'] + doc[t_start:t_end] + ['[unused4]'] + doc[t_end:t_2] 360 | h_name = doc[h_start:h_end] 361 | t_name = doc[t_start:t_end] 362 | h_token_ids = tokenizer.convert_tokens_to_ids(h_tokens) 363 | t_token_ids = tokenizer.convert_tokens_to_ids(t_tokens) 364 | selected_ret = process(tokenizer, " ".join(doc[h_start:h_end]), " ".join(doc[t_start:t_end]), h_tokens, t_tokens) 365 | for s_blk in selected_ret: 366 | while(tokenizer.convert_tokens_to_ids("|") in s_blk.ids): 367 | s_blk.ids[s_blk.ids.index(tokenizer.convert_tokens_to_ids("|"))] = tokenizer.convert_tokens_to_ids(".") 368 | input_ids = tokenizer.build_inputs_with_special_tokens(h_token_ids, t_token_ids) 369 | token_type_ids = tokenizer.create_token_type_ids_from_sequences(h_token_ids, t_token_ids) 370 | obj = tokenizer._pad({'input_ids': input_ids, 'token_type_ids': token_type_ids}, max_length=args.seq_len, padding_strategy='max_length') 371 | _input_ids.append(obj['input_ids']) 372 | _token_type_ids.append(obj['token_type_ids']) 373 | _attention_mask.append(obj['attention_mask']) 374 | _rs.append(r) 375 | selected_rets.append(selected_ret) 376 | 
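# Packing note: selected_inputs below is a [4, batch, CAPACITY] tensor; plane 0 carries
# token ids, plane 1 the attention mask, plane 2 token type ids, and plane 3 stays zero
# (returned as selected_labels). Buffer.export_01_turn is handed the first three planes
# per selected block buffer.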
dplabel_t = torch.tensor(_rs, dtype=torch.long) 377 | rs_t = None 378 | r = None 379 | selected_inputs = torch.zeros(4, len(examples), CAPACITY, dtype=torch.int64) 380 | for ex, buf in enumerate(selected_rets): 381 | buf.export_01_turn(out=(selected_inputs[0, ex], selected_inputs[1, ex], selected_inputs[2, ex])) 382 | selected_ids = selected_inputs[0] 383 | selected_att_mask = selected_inputs[1] 384 | selected_token_type = selected_inputs[2] 385 | selected_labels = selected_inputs[3] 386 | return dplabel_t, rs_t, [r], selected_ids, selected_att_mask, selected_token_type, selected_labels, selected_rets 387 | 388 | 389 | def collate_fn_infer(batch, args, relation2id, tokenizer, redisd, encoder, sbert_wk): 390 | #assert len(batch) == 1 391 | batch = batch[0] 392 | h, t = batch[0].split('#') 393 | rs = [relation2id[r] for r in batch[1]] 394 | dps = batch[2] 395 | selected_rets = list() 396 | for doc1, doc2, l in dps: 397 | selected_ret = process_example_simp(h, t, doc1, doc2, tokenizer, args.seq_len, redisd, args.no_additional_marker, args.mask_entity) 398 | for s_blk in selected_ret: 399 | while(tokenizer.convert_tokens_to_ids("|") in s_blk.ids): 400 | s_blk.ids[s_blk.ids.index(tokenizer.convert_tokens_to_ids("|"))] = tokenizer.convert_tokens_to_ids(".") 401 | selected_rets.append(selected_ret) 402 | 403 | selected_inputs = torch.zeros(4, len(dps), CAPACITY, dtype=torch.int64) 404 | for dp, buf in enumerate(selected_rets): 405 | buf.export_01_turn(out=(selected_inputs[0, dp], selected_inputs[1, dp], selected_inputs[2, dp])) 406 | 407 | selected_ids = selected_inputs[0] 408 | selected_att_mask = selected_inputs[1] 409 | selected_token_type = selected_inputs[2] 410 | selected_labels = selected_inputs[3] 411 | 412 | return h, rs, t, selected_ids, selected_att_mask, selected_token_type, selected_labels, selected_rets 413 | 414 | 415 | 416 | 417 | class Codred(torch.nn.Module): 418 | def __init__(self, args, num_relations): 419 | super().__init__() 420 | self.bert = BertModel.from_pretrained('bert-base-cased') 421 | self.predictor = torch.nn.Linear(self.bert.config.hidden_size, num_relations) 422 | weight = torch.ones(num_relations, dtype=torch.float32) 423 | weight[0] = 0.1 424 | self.d_model = 768 425 | self.reduced_dim = 256 426 | self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=weight, reduction='none') 427 | self.aggregator = args.aggregator 428 | self.no_doc_pair_supervision = args.no_doc_pair_supervision 429 | self.matt = MatTransformer(h = 8 , d_model = self.d_model , hidden_size = 1024 , num_layers = 4 , device = torch.device(0)) 430 | 431 | self.graph_enc = GraphEncoder(h = 8 , d_model = self.d_model , hidden_size = 1024 , num_layers = 4) 432 | self.wu = nn.Linear(self.d_model , self.d_model) 433 | self.wv = nn.Linear(self.d_model , self.d_model) 434 | self.wi = nn.Linear(self.d_model , self.d_model) 435 | self.ln1 = nn.Linear(self.d_model , self.d_model) 436 | self.ln1_gnn = nn.Linear(2* self.d_model , self.d_model) 437 | self.dim_reduction = nn.Linear(self.d_model, self.reduced_dim) 438 | self.reduced_predictor = torch.nn.Linear(self.reduced_dim, num_relations) 439 | self.gamma = 2 440 | self.alpha = 0.25 441 | self.beta = 0.01 442 | self.d_k = 64 443 | self.num_relations = num_relations 444 | self.ent_emb = nn.Parameter(torch.zeros(2 , self.d_model)) 445 | self.gnn = True 446 | self.norm = nn.LayerNorm(self.d_model) 447 | self.att_net = Attention(h=self.d_model, d_model=self.d_model) 448 | self.s_linear = torch.nn.Linear(self.d_model, 2) 449 | self.dotsim = 
DotProductSimilarity(scale_output=False) 450 | 451 | 452 | def forward(self, input_ids, token_type_ids, attention_mask, dplabel=None, rs=None, train=True): 453 | # input_ids: T(num_sentences, seq_len) 454 | # token_type_ids: T(num_sentences, seq_len) 455 | # attention_mask: T(num_sentences, seq_len) 456 | # rs: T(batch_size) 457 | # maps: T(batch_size, max_bag_size) 458 | # embedding: T(num_sentences, seq_len, embedding_size) 459 | bag_len, seq_len = input_ids.size() 460 | embedding, _ = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=False) 461 | # r_embedding: T(num_sentences, embedding_size) 462 | p_embedding = embedding[:, 0, :] 463 | if bag_len>8: 464 | print("bag_len:", bag_len) 465 | if rs is not None or not train: 466 | entity_mask, entity_span_list = self.get_htb(input_ids) 467 | h_embs = [] 468 | t_embs = [] 469 | b_embs = [] 470 | dp_embs = [] 471 | h_num = [] 472 | t_num = [] 473 | b_num = [] 474 | for dp in range(0,bag_len): 475 | b_embs_dp = [] 476 | try: 477 | h_span = entity_span_list[dp][0] 478 | t_span = entity_span_list[dp][1] 479 | b_span_chunks = entity_span_list[dp][2] 480 | h_emb = torch.max(embedding[dp, h_span[0]:h_span[-1]+1], dim=0)[0] 481 | t_emb = torch.max(embedding[dp, t_span[0]:t_span[-1]+1], dim=0)[0] 482 | h_embs.append(h_emb) 483 | t_embs.append(t_emb) 484 | for b_span in b_span_chunks: 485 | b_emb = torch.max(embedding[dp, b_span[0]:b_span[1]+1], dim=0)[0] 486 | b_embs_dp.append(b_emb) 487 | if bag_len >= 16: 488 | if len(b_embs_dp) > 3: 489 | b_embs_dp = random.choices(b_embs_dp, k=3) 490 | if bag_len >= 14: 491 | if len(b_embs_dp) > 4: 492 | b_embs_dp = random.choices(b_embs_dp, k=4) 493 | elif bag_len >= 10: 494 | if len(b_embs_dp) > 5: 495 | b_embs_dp = random.choices(b_embs_dp, k=5) 496 | else: 497 | if len(b_embs_dp) > 8: 498 | b_embs_dp = random.choices(b_embs_dp, k=8) 499 | else: 500 | b_embs_dp = b_embs_dp 501 | b_embs.append(b_embs_dp) 502 | h_num.append(1) 503 | t_num.append(1) 504 | b_num.append(len(b_embs_dp)) 505 | dp_embs.append(p_embedding[dp]) 506 | except IndexError as e: 507 | continue 508 | print(bag_len, b_num) 509 | htb_index = [] 510 | htb_embs = [] 511 | htb_start = [0] 512 | htb_end = [] 513 | for h_emb, t_emb, b_emb in zip(h_embs, t_embs, b_embs): 514 | htb_embs.extend([h_emb,t_emb]) 515 | htb_index.extend([1,2]) 516 | htb_embs.extend(b_emb) 517 | htb_index.extend([3]*len(b_emb)) 518 | htb_end.append(len(htb_index)-1) 519 | htb_start.append(len(htb_index)) 520 | htb_start = htb_start[:-1] 521 | 522 | 523 | rel_mask = torch.ones(1,len(htb_index), len(htb_index)).to(embedding.device) 524 | try: 525 | htb_embs_t = torch.stack(htb_embs, dim=0).unsqueeze(0) 526 | except: 527 | print(input_ids) 528 | 529 | 530 | u = self.wu(htb_embs_t) 531 | v = self.wv(htb_embs_t) 532 | 533 | alpha = u.view(1, len(htb_index), 1, htb_embs_t.size()[-1]) + v.view(1, 1, len(htb_index), htb_embs_t.size()[-1]) 534 | alpha = F.relu(alpha) 535 | 536 | rel_enco = F.relu(self.ln1(alpha)) 537 | bs,es,es,d = rel_enco.size() 538 | 539 | rel_mask = torch.ones(1,len(htb_index), len(htb_index)).to(embedding.device) 540 | 541 | rel_enco_m = self.matt(rel_enco, rel_mask) 542 | h_pos = [] 543 | t_pos = [] 544 | for i, e_type in enumerate(htb_index): 545 | if e_type == 1: 546 | h_pos.append(i) 547 | elif e_type == 2: 548 | t_pos.append(i) 549 | else: 550 | continue 551 | assert len(h_pos) == len(t_pos) 552 | rel_enco_m_ht = [] 553 | for i,j in zip(h_pos, t_pos): 554 | rel_enco_m_ht.append(rel_enco_m[0][i][j]) 555 | 
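# rel_enco_m is a [1, n_ent, n_ent, d] grid of pairwise features from the matrix
# transformer, with htb_index tagging each position as head (1), tail (2) or bridge (3).
# For every document pair the cell at (head_pos, tail_pos) is taken as that pair's
# head->tail relation feature, so the stack below yields one d-dimensional vector per
# path for the relation classifier.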
t_feature_m = torch.stack(rel_enco_m_ht) 556 | predict_logits = self.predictor(t_feature_m) 557 | ht_logits = predict_logits 558 | bag_logit = torch.max(ht_logits.transpose(0,1),dim=1)[0].unsqueeze(0) 559 | path_logit = ht_logits 560 | else: # Inner doc 561 | entity_mask, entity_span_list = self.get_htb(input_ids) 562 | path_logits = [] 563 | ht_logits_flatten_list = [] 564 | for dp in range(0,bag_len): 565 | h_embs = [] 566 | t_embs = [] 567 | b_embs = [] 568 | try: 569 | h_span = entity_span_list[dp][0] 570 | t_span = entity_span_list[dp][1] 571 | b_span_chunks = entity_span_list[dp][2] 572 | h_emb = torch.max(embedding[dp, h_span[0]:h_span[-1]+1], dim=0)[0] 573 | t_emb = torch.max(embedding[dp, t_span[0]:t_span[-1]+1], dim=0)[0] 574 | h_embs.append(h_emb) 575 | t_embs.append(t_emb) 576 | for b_span in b_span_chunks: 577 | b_emb = torch.max(embedding[dp, b_span[0]:b_span[1]+1], dim=0)[0] 578 | b_embs.append(b_emb) 579 | h_index = [1 for _ in h_embs] 580 | t_index = [2 for _ in t_embs] 581 | b_index = [3 for _ in b_embs] 582 | htb_index = [] 583 | htb_embs = [] 584 | for idx, embs in zip([h_index, t_index, b_index], [h_embs, t_embs, b_embs]): 585 | htb_index.extend(idx) 586 | htb_embs.extend(embs) 587 | rel_mask = torch.ones(1,len(htb_index), len(htb_index)).to(embedding.device) 588 | 589 | htb_embs_t = torch.stack(htb_embs, dim=0).unsqueeze(0) 590 | 591 | u = self.wu(htb_embs_t) 592 | v = self.wv(htb_embs_t) 593 | alpha = u.view(1, len(htb_index), 1, htb_embs_t.size()[-1]) + v.view(1, 1, len(htb_index), htb_embs_t.size()[-1]) 594 | alpha = F.relu(alpha) 595 | 596 | rel_enco = F.relu(self.ln1(alpha)) 597 | 598 | rel_enco_m = self.matt(rel_enco , rel_mask) 599 | 600 | t_feature = rel_enco_m 601 | bs,es,es,d = rel_enco.size() 602 | 603 | predict_logits = self.predictor(t_feature.reshape(bs,es,es,d)) 604 | ht_logits = predict_logits[0][:len(h_index), len(h_index):len(h_index)+len(t_index)] 605 | _ht_logits_flatten = ht_logits.reshape(1, -1, self.num_relations) 606 | ht_logits = predict_logits[0][:len(h_index), len(h_index):len(h_index)+len(t_index)] 607 | path_logits.append(ht_logits) 608 | ht_logits_flatten_list.append(_ht_logits_flatten) 609 | except Exception as e: 610 | print(e) 611 | pdb.set_trace() 612 | try: 613 | path_logit = torch.stack(path_logits).reshape(1, 1, -1, self.num_relations).squeeze(0).squeeze(0) 614 | except Exception as e: 615 | print(e) 616 | pdb.set_trace() 617 | 618 | 619 | if dplabel is not None and rs is None: 620 | ht_logits_flatten = torch.stack(ht_logits_flatten_list).squeeze(1) 621 | ht_fixed_low = (torch.ones_like(ht_logits_flatten)*8)[:,:,0].unsqueeze(-1) 622 | y_true = torch.zeros_like(ht_logits_flatten) 623 | for idx, dpl in enumerate(dplabel): 624 | y_true[idx, 0, dpl.item()] = 1 625 | bag_logit = path_logit 626 | loss = self._multilabel_categorical_crossentropy(ht_logits_flatten, y_true, ht_fixed_low+2, ht_fixed_low) 627 | elif rs is not None: 628 | _, prediction = torch.max(bag_logit, dim=1) 629 | if self.no_doc_pair_supervision: 630 | pass 631 | else: 632 | ht_logits_flatten = ht_logits.unsqueeze(1) 633 | y_true = torch.zeros_like(ht_logits_flatten) 634 | ht_fixed_low = (torch.ones_like(ht_logits_flatten)*8)[:,:,0].unsqueeze(-1) 635 | if rs.item() != 0: 636 | for idx, dpl in enumerate(dplabel): 637 | try: 638 | y_true[idx, : , dpl.item()] = torch.ones_like(y_true[idx, : , dpl.item()]) 639 | except: 640 | print("unmatched") 641 | #pdb.set_trace() 642 | loss = self._multilabel_categorical_crossentropy(ht_logits_flatten, y_true, ht_fixed_low+2, 
ht_fixed_low) 643 | 644 | else: 645 | ht_logits_flatten = ht_logits.unsqueeze(1) 646 | ht_fixed_low = (torch.ones_like(ht_logits_flatten)*8)[:,:,0].unsqueeze(-1) 647 | _, prediction = torch.max(bag_logit, dim=1) 648 | loss = None 649 | prediction = [] 650 | return loss, prediction, bag_logit, ht_logits_flatten.transpose(0,1), (ht_fixed_low+2).transpose(0,1) 651 | 652 | 653 | def _multilabel_categorical_crossentropy(self, y_pred, y_true, cr_ceil, cr_low, ghm=True, r_dropout=True): 654 | # cr_low + 2 = cr_ceil 655 | y_pred = (1 - 2 * y_true) * y_pred 656 | y_pred_neg = y_pred - y_true * 1e12 657 | y_pred_pos = y_pred - (1 - y_true) * 1e12 658 | y_pred_neg = torch.cat([y_pred_neg, cr_ceil], dim=-1) 659 | y_pred_pos = torch.cat([y_pred_pos, -cr_low], dim=-1) 660 | neg_loss = torch.logsumexp(y_pred_neg, dim=-1) 661 | pos_loss = torch.logsumexp(y_pred_pos, dim=-1) 662 | 663 | return ((neg_loss + pos_loss + cr_low.squeeze(-1) - cr_ceil.squeeze(-1))).mean() 664 | 665 | def graph_encode(self , ent_encode , rel_encode , ent_mask , rel_mask): 666 | bs , ne , d = ent_encode.size() 667 | ent_encode = ent_encode + self.ent_emb[0].view(1,1,d) 668 | rel_encode = rel_encode + self.ent_emb[1].view(1,1,1,d) 669 | rel_encode , ent_encode = self.graph_enc(rel_encode , ent_encode , rel_mask , ent_mask) 670 | return rel_encode 671 | 672 | 673 | def get_htb(self, input_ids): 674 | htb_mask_list = [] 675 | htb_list_batch = [] 676 | for pi in range(input_ids.size()[0]): 677 | #pdb.set_trace() 678 | tmp = torch.nonzero(input_ids[pi] - torch.full(([input_ids.size()[1]]), 1).to(input_ids.device)) 679 | if tmp.size()[0] < input_ids.size()[0]: 680 | print(input_ids) 681 | try: 682 | h_starts = [i[0] for i in (input_ids[pi]==H_START_MARKER_ID).nonzero().detach().tolist()] 683 | h_ends = [i[0] for i in (input_ids[pi]==H_END_MARKER_ID).nonzero().detach().tolist()] 684 | t_starts = [i[0] for i in (input_ids[pi]==T_START_MARKER_ID).nonzero().detach().tolist()] 685 | t_ends = [i[0] for i in (input_ids[pi]==T_END_MARKER_ID).nonzero().detach().tolist()] 686 | if len(h_starts) == len(h_ends): 687 | h_start = h_starts[0] 688 | h_end = h_ends[0] 689 | else: 690 | for h_s in h_starts: 691 | for h_e in h_ends: 692 | if 0 < h_e - h_s < 20: 693 | h_start = h_s 694 | h_end = h_e 695 | break 696 | if len(t_starts) == len(t_ends): 697 | t_start = t_starts[0] 698 | t_end = t_ends[0] 699 | else: 700 | for t_s in t_starts: 701 | for t_e in t_ends: 702 | if 0 < t_e - t_s < 20: 703 | t_start = t_s 704 | t_end = t_e 705 | break 706 | if h_end-h_start<=0 or t_end-t_start<=0: 707 | # print(h_starts) 708 | # print(h_ends) 709 | # print(t_starts) 710 | # print(t_ends) 711 | # pdb.set_trace() 712 | if h_end-h_start<=0: 713 | for h_s in h_starts: 714 | for h_e in h_ends: 715 | if 0 < h_e - h_s < 20: 716 | h_start = h_s 717 | h_end = h_e 718 | break 719 | if t_end-t_start<=0: 720 | for t_s in t_starts: 721 | for t_e in t_ends: 722 | if 0 < t_e - t_s < 20: 723 | t_start = t_s 724 | t_end = t_e 725 | break 726 | if h_end-h_start<=0 or t_end-t_start<=0: 727 | pdb.set_trace() 728 | 729 | b_spans = torch.nonzero(torch.gt(torch.full(([input_ids.size()[1]]), 99).to(input_ids.device), input_ids[pi])).squeeze(0).squeeze(1).detach().tolist() 730 | token_len = input_ids[pi].nonzero().size()[0] 731 | b_spans = [i for i in b_spans if i <= token_len-1] 732 | assert len(b_spans) >= 4 733 | for i in h_starts + h_ends + t_starts + t_ends: 734 | b_spans.remove(i) 735 | h_span = [h_pos for h_pos in range(h_start, h_end+1)] 736 | t_span = [t_pos for t_pos in 
range(t_start, t_end+1)] 737 | h_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(h_span).to(input_ids.device), 1) 738 | t_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(t_span).to(input_ids.device), 1) 739 | except:# dps<8 740 | #pdb.set_trace() 741 | h_span = [] 742 | t_span = [] 743 | h_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device) 744 | t_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device) 745 | b_spans = [] 746 | b_span_ = [] 747 | if len(b_spans) > 0 and len(b_spans)%2==0: 748 | b_span_chunks = [b_spans[i:i+2] for i in range(0,len(b_spans),2)] 749 | b_span = [] 750 | for span in b_span_chunks: 751 | b_span.extend([b_pos for b_pos in range(span[0], span[1]+1)]) 752 | b_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(b_span).to(input_ids.device), 1) 753 | b_span_.extend(b_span) 754 | elif len(b_spans) > 0 and len(b_spans)%2==1: 755 | b_span = [] 756 | ptr = 0 757 | #pdb.set_trace() 758 | while(ptr<=len(b_spans)-1): 759 | try: 760 | if input_ids[pi][b_spans[ptr+1]] - input_ids[pi][b_spans[ptr]] == 1: 761 | b_span.append([b_spans[ptr], b_spans[ptr+1]]) 762 | ptr += 2 763 | else: 764 | ptr += 1 765 | except IndexError as e: 766 | ptr += 1 767 | for bs in b_span: 768 | #pdb.set_trace() 769 | #ex_bs = range(bs[0], bs[1]) 770 | b_span_.extend(bs) 771 | if len(b_span_)%2 != 0: 772 | print(b_spans) 773 | b_span_chunks = [b_span_[i:i+2] for i in range(0,len(b_span_),2)] 774 | b_mask = torch.zeros_like(input_ids[pi]).to(input_ids.device).scatter(0, torch.tensor(b_span_).to(input_ids.device), 1) 775 | else: 776 | b_span_ = [] 777 | b_span_chunks = [] 778 | b_mask = torch.zeros_like(input_ids[pi]) 779 | htb_mask = torch.concat([h_mask.unsqueeze(0), t_mask.unsqueeze(0), b_mask.unsqueeze(0)], dim=0) #[3,512] 780 | htb_mask_list.append(htb_mask) 781 | htb_list_batch.append([h_span, t_span, b_span_chunks]) 782 | # pdb.set_trace() 783 | htb_mask_batch = torch.stack(htb_mask_list,dim=0) 784 | return htb_mask_batch, htb_list_batch # 785 | 786 | def get_doc_entities(h, t, tokenizer, redisd, no_additional_marker, mask_entity, collec_doc1_titles, collec_doc2_titles): 787 | max_len = 99999 788 | bert_max_len = 512 789 | Doc1_tokens = [] 790 | Doc2_tokens = [] 791 | B_entities = [] 792 | for doc1_title, doc2_title in zip(collec_doc1_titles, collec_doc2_titles): 793 | doc1 = json.loads(redisd.get('codred-doc-' + doc1_title)) 794 | doc2 = json.loads(redisd.get('codred-doc-' + doc2_title)) 795 | v_h = None 796 | for entity in doc1['entities']: 797 | if 'Q' in entity and 'Q' + str(entity['Q']) == h and v_h is None: 798 | v_h = entity 799 | assert v_h is not None 800 | v_t = None 801 | for entity in doc2['entities']: 802 | if 'Q' in entity and 'Q' + str(entity['Q']) == t and v_t is None: 803 | v_t = entity 804 | assert v_t is not None 805 | d1_v = dict() 806 | for entity in doc1['entities']: 807 | if 'Q' in entity: 808 | d1_v[entity['Q']] = entity 809 | d2_v = dict() 810 | for entity in doc2['entities']: 811 | if 'Q' in entity: 812 | d2_v[entity['Q']] = entity 813 | ov = set(d1_v.keys()) & set(d2_v.keys()) 814 | if len(ov) > 40: 815 | ov = set(random.choices(list(ov), k=40)) 816 | ov = list(ov) 817 | ma = dict() 818 | for e in ov: 819 | ma[e] = len(ma) 820 | B_entities.append(ma) 821 | 822 | # print(B_entities) 823 | 824 | return B_entities 825 | 826 | class CodredCallback(TrainerCallback): 827 | def __init__(self): 828 | super().__init__() 829 | 830 | def on_argument(self, parser): 831 
| parser.add_argument('--seq_len', type=int, default=512) 832 | parser.add_argument('--aggregator', type=str, default='attention') 833 | parser.add_argument('--positive_only', action='store_true') 834 | parser.add_argument('--positive_ep_only', action='store_true') 835 | parser.add_argument('--no_doc_pair_supervision', action='store_true') 836 | parser.add_argument('--no_additional_marker', action='store_true') 837 | parser.add_argument('--mask_entity', action='store_true') 838 | parser.add_argument('--single_path', action='store_true') 839 | parser.add_argument('--dsre_only', action='store_true') 840 | parser.add_argument('--raw_only', action='store_true') 841 | parser.add_argument('--load_model_path', type=str, default=None) 842 | parser.add_argument('--train_file', type=str, default='../data/rawdata/train_dataset.json') 843 | parser.add_argument('--dev_file', type=str, default='../data/rawdata/dev_dataset.json') 844 | parser.add_argument('--test_file', type=str, default='../data/rawdata/test_dataset.json') 845 | parser.add_argument('--dsre_file', type=str, default='../data/dsre_train_examples.json') 846 | parser.add_argument('--model_name', type=str, default='bert') 847 | 848 | 849 | def load_model(self): 850 | relations = json.load(open('../data/rawdata/relations.json')) 851 | relations.sort() 852 | self.relations = ['n/a'] + relations 853 | self.relation2id = dict() 854 | for index, relation in enumerate(self.relations): 855 | self.relation2id[relation] = index 856 | with self.trainer.cache(): 857 | reasoner = Codred(self.args, len(self.relations)) 858 | if self.args.load_model_path: 859 | load_model(reasoner, self.args.load_model_path) 860 | tokenizer = AutoTokenizer.from_pretrained('bert-base-cased', use_fast=True) 861 | self.tokenizer = tokenizer 862 | self.bert = BertModel.from_pretrained('bert-base-cased') 863 | self.sbert_wk = sbert(device='cuda') 864 | return reasoner 865 | 866 | def load_data(self): 867 | train_dataset = json.load(open(self.args.train_file)) 868 | dev_dataset = json.load(open(self.args.dev_file)) 869 | test_dataset = json.load(open(self.args.test_file)) 870 | if self.args.positive_only: 871 | train_dataset = [d for d in train_dataset if d[3] != 'n/a'] 872 | dev_dataset = [d for d in dev_dataset if d[3] != 'n/a'] 873 | test_dataset = [d for d in test_dataset if d[3] != 'n/a'] 874 | train_bags = place_train_data(train_dataset) 875 | dev_bags = place_dev_data(dev_dataset, self.args.single_path) 876 | test_bags = place_test_data(test_dataset, self.args.single_path) 877 | if self.args.positive_ep_only: 878 | train_bags = [b for b in train_bags if b[1] != 'n/a'] 879 | dev_bags = [b for b in dev_bags if 'n/a' not in b[1]] 880 | test_bags = [b for b in test_bags if 'n/a' not in b[1]] 881 | self.dsre_train_dataset = json.load(open(self.args.dsre_file)) 882 | self.dsre_train_dataset = [d for i, d in enumerate(self.dsre_train_dataset) if i % 10 == 0] 883 | d = list() 884 | for i in range(len(self.dsre_train_dataset) // 8): 885 | d.append(self.dsre_train_dataset[8 * i:8 * i + 8]) 886 | if self.args.raw_only: 887 | pass 888 | elif self.args.dsre_only: 889 | train_bags = d 890 | else: 891 | d.extend(train_bags) 892 | train_bags = d 893 | self.redisd = redis.Redis(host='localhost', port=6379, decode_responses=True, db=0) 894 | with self.trainer.once(): 895 | self.train_logger = Logger(['train_loss', 'train_acc', 'train_pos_acc', 'train_dsre_acc'], self.trainer.writer, self.args.logging_steps, self.args.local_rank) 896 | self.dev_logger = Logger(['dev_mean_prec', 'dev_f1', 
'dev_auc'], self.trainer.writer, 1, self.args.local_rank) 897 | self.test_logger = Logger(['test_mean_prec', 'test_f1', 'test_auc'], self.trainer.writer, 1, self.args.local_rank) 898 | return train_bags, dev_bags, test_bags 899 | 900 | def collate_fn(self): 901 | return partial(collate_fn, args=self.args, relation2id=self.relation2id, tokenizer=self.tokenizer, redisd=self.redisd, encoder = self.bert, sbert_wk=self.sbert_wk), partial(collate_fn_infer, args=self.args, relation2id=self.relation2id, tokenizer=self.tokenizer, redisd=self.redisd, encoder=self.bert, sbert_wk=self.sbert_wk), partial(collate_fn_infer, args=self.args, relation2id=self.relation2id, tokenizer=self.tokenizer, redisd=self.redisd, encoder=self.bert, sbert_wk=self.sbert_wk) 902 | 903 | def on_train_epoch_start(self, epoch): 904 | pass 905 | 906 | def on_train_step(self, step, train_step, inputs, extra, loss, outputs): 907 | with self.trainer.once(): 908 | self.train_logger.log(train_loss=loss) 909 | if inputs['rs'] is not None: 910 | _, prediction, logit, ht_logits_flatten, ht_threshold_flatten = outputs 911 | rs = extra['rs'] 912 | if ht_logits_flatten is not None: 913 | r_score, r_idx = torch.max(torch.max(ht_logits_flatten,dim=1)[0], dim=-1) 914 | if r_score>ht_threshold_flatten[0, 0, 0]: 915 | prediction = [r_idx.item()] 916 | else: 917 | prediction = [0] 918 | 919 | for p, score, gold in zip(prediction, logit, rs): 920 | self.train_logger.log(train_acc=1 if p == gold else 0) 921 | if gold > 0: 922 | self.train_logger.log(train_pos_acc=1 if p == gold else 0) 923 | else: 924 | _, prediction, logit, ht_logits_flatten, ht_threshold_flatten = outputs 925 | dplabel = inputs['dplabel'] 926 | logit, dplabel = tensor_to_obj(logit, dplabel) 927 | prediction = [] 928 | if ht_logits_flatten is not None: 929 | r_score, r_idx = torch.max(torch.max(ht_logits_flatten,dim=1)[0], dim=-1) 930 | for dp_i, (r_s, r_i) in enumerate(zip(r_score, r_idx)): 931 | if r_s > ht_threshold_flatten[dp_i, 0, 0]: 932 | prediction.append(r_i.item()) 933 | else: 934 | prediction.append(0) 935 | for p, l in zip(prediction, dplabel): 936 | self.train_logger.log(train_dsre_acc=1 if p == l else 0) 937 | 938 | def on_train_epoch_end(self, epoch): 939 | #for k,v in self.train_logger.d: 940 | print(epoch, self.train_logger.d) 941 | pass 942 | 943 | def on_dev_epoch_start(self, epoch): 944 | self._prediction = list() 945 | 946 | def on_dev_step(self, step, inputs, extra, outputs): 947 | _, prediction, logit, ht_logits_flatten, ht_threshold_flatten = outputs 948 | r_score, r_idx = torch.max(torch.max(ht_logits_flatten,dim=1)[0], dim=-1) 949 | eval_logit = torch.max(ht_logits_flatten,dim=1)[0] 950 | 951 | if r_score>ht_threshold_flatten[:, 0, 0]: 952 | prediction = [r_idx.item()] 953 | else: 954 | prediction = [0] 955 | h, t, rs = extra['h'], extra['t'], extra['rs'] 956 | logit = tensor_to_obj(logit) 957 | self._prediction.append([prediction[0], eval_logit[0], h, t, rs]) 958 | 959 | def on_dev_epoch_end(self, epoch): 960 | self._prediction = self.trainer.distributed_broadcast(self._prediction) 961 | results = list() 962 | pred_result = list() 963 | facts = dict() 964 | for p, score, h, t, rs in self._prediction: 965 | rs = [self.relations[r] for r in rs] 966 | for i in range(1, len(score)): 967 | pred_result.append({'entpair': [h, t], 'relation': self.relations[i], 'score': score[i]}) 968 | results.append([h, rs, t, self.relations[p]]) 969 | for r in rs: 970 | if r != 'n/a': 971 | facts[(h, t, r)] = 1 972 | stat = eval_performance(facts, pred_result) 973 | with 
self.trainer.once(): 974 | self.dev_logger.log(dev_mean_prec=stat['mean_prec'], dev_f1=stat['f1'], dev_auc=stat['auc']) 975 | json.dump(stat, open(f'output/dev-stat-dual-K1-{epoch}.json', 'w')) 976 | json.dump(results, open(f'output/dev-results-dual-K1-{epoch}.json', 'w')) 977 | return stat['f1'] 978 | 979 | def on_test_epoch_start(self, epoch): 980 | self._prediction = list() 981 | pass 982 | 983 | def on_test_step(self, step, inputs, extra, outputs): 984 | #_, prediction, logit = outputs 985 | #pdb.set_trace() 986 | _, prediction, logit, ht_logits_flatten, ht_threshold_flatten = outputs 987 | r_score, r_idx = torch.max(torch.max(ht_logits_flatten,dim=1)[0], dim=-1) 988 | eval_logit = torch.max(ht_logits_flatten,dim=1)[0] 989 | 990 | if r_score>ht_threshold_flatten[0, 0, 0]: 991 | #prediction = [r_idx.item() + 1] 992 | prediction = [r_idx.item()] 993 | else: 994 | prediction = [0] 995 | h, t, rs = extra['h'], extra['t'], extra['rs'] 996 | logit = tensor_to_obj(logit) 997 | self._prediction.append([prediction[0], eval_logit[0], h, t, rs]) 998 | #self._prediction.append([prediction[0], logit[0], h, t, rs]) 999 | 1000 | def on_test_epoch_end(self, epoch): 1001 | self._prediction = self.trainer.distributed_broadcast(self._prediction) 1002 | results = list() 1003 | pred_result = list() 1004 | facts = dict() 1005 | out_results = list() 1006 | coda_file = dict() 1007 | coda_file['setting'] = 'closed' 1008 | for p, score, h, t, rs in self._prediction: 1009 | rs = [self.relations[r] for r in rs] 1010 | for i in range(1, len(score)): 1011 | pred_result.append({'entpair': [h, t], 'relation': self.relations[i], 'score': score[i]}) 1012 | out_results.append({'h_id':str(h), "t_id":str(t), "relation": str(self.relations[i]), "score": float(score[i])}) 1013 | results.append([h, rs, t, self.relations[p]]) 1014 | for r in rs: 1015 | if r != 'n/a': 1016 | facts[(h, t, r)] = 1 1017 | #stat = eval_performance(facts, pred_result) 1018 | coda_file['predictions'] = out_results 1019 | with self.trainer.once(): 1020 | json.dump(results, open(f'output/test-results-{epoch}.json', 'w')) 1021 | json.dump(coda_file, open(f'output/test-codalab-results-{epoch}.json', 'w')) 1022 | return True 1023 | 1024 | def process_train_data(self, data): 1025 | selected_inputs = { 1026 | 'input_ids': data[3], 1027 | 'attention_mask': data[4], 1028 | 'token_type_ids': data[5], 1029 | 'rs':data[1], 1030 | 'dplabel': data[0], 1031 | 'train': True 1032 | } 1033 | return {'rs': data[2]}, selected_inputs, {'selected_rets': data[7]} 1034 | 1035 | def process_dev_data(self, data): 1036 | selected_inputs = { 1037 | 'input_ids': data[3], 1038 | 'attention_mask': data[4], 1039 | 'token_type_ids': data[5], 1040 | 'train': False 1041 | } 1042 | return {'h': data[0], 'rs': data[1], 't': data[2]}, selected_inputs, {'selected_rets': data[7]} 1043 | 1044 | def process_test_data(self, data): 1045 | 1046 | selected_inputs = { 1047 | 'input_ids': data[3], 1048 | 'attention_mask': data[4], 1049 | 'token_type_ids': data[5], 1050 | 'train': False 1051 | } 1052 | return {'h': data[0], 'rs': data[1], 't': data[2]}, selected_inputs, {'selected_rets': data[7]} 1053 | 1054 | 1055 | def main(): 1056 | trainer = Trainer(CodredCallback()) 1057 | trainer.run() 1058 | 1059 | 1060 | if __name__ == '__main__': 1061 | main() 1062 | --------------------------------------------------------------------------------
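The threshold-augmented multilabel loss in Codred._multilabel_categorical_crossentropy can be sanity-checked in isolation. A minimal sketch (the relation count and shapes are arbitrary placeholders, not values taken from the training pipeline):

    import torch

    def multilabel_ce(y_pred, y_true, cr_ceil, cr_low):
        # same arithmetic as Codred._multilabel_categorical_crossentropy
        y_pred = (1 - 2 * y_true) * y_pred
        y_pred_neg = y_pred - y_true * 1e12
        y_pred_pos = y_pred - (1 - y_true) * 1e12
        y_pred_neg = torch.cat([y_pred_neg, cr_ceil], dim=-1)
        y_pred_pos = torch.cat([y_pred_pos, -cr_low], dim=-1)
        neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
        pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
        return (neg_loss + pos_loss + cr_low.squeeze(-1) - cr_ceil.squeeze(-1)).mean()

    num_relations = 97                         # placeholder; the real count comes from relations.json
    logits = torch.randn(2, 1, num_relations)  # (paths, 1, num_relations), as in the inner-doc branch
    labels = torch.zeros_like(logits)
    labels[0, 0, 5] = 1                        # one gold relation on the first path
    low = torch.full((2, 1, 1), 8.0)           # fixed lower threshold; the ceiling is low + 2
    print(multilabel_ce(logits, labels, low + 2, low))

At prediction time the forward pass compares the best relation score against the same fixed band (ht_fixed_low + 2, returned as ht_threshold_flatten), so the loss and the decision rule share one threshold.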