├── images ├── readme.md └── biolama.png ├── requirements.txt ├── .gitignore ├── preprocessing ├── utils.py ├── README.md ├── get_stats_triples.py ├── process_umls.py ├── process_wikidata_triples.py ├── filter_length.py ├── aggregate_data.py ├── process_wikidata_entities.py └── process_ctd.py ├── BioLAMA ├── utils.py ├── data_loader.py ├── cli_demo.py ├── evaluator.py ├── run_manual.py ├── preprocessor.py ├── run_ie.py ├── best.py ├── run_optiprompt.py └── decoder.py ├── README.md └── LICENSE /images/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /images/biolama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/BioLAMA/HEAD/images/biolama.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.4 2 | transformers==4.4.1 3 | nltk 4 | tqdm 5 | wikidataintegrator 6 | stanza -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | data 3 | data.tar.gz 4 | RoBERTa-base-PM-Voc 5 | output 6 | BioLAMA.egg-info 7 | analysis -------------------------------------------------------------------------------- /preprocessing/utils.py: -------------------------------------------------------------------------------- 1 | from subprocess import check_output 2 | import re 3 | import string 4 | 5 | def find_sub_list(sl,l): 6 | results=[] 7 | sll=len(sl) 8 | for ind in (i for i,e in enumerate(l) if e==sl[0]): 9 | if l[ind:ind+sll]==sl: 10 | results.append(ind) 11 | return results 12 | 13 | def is_obj_in_sbj(sbj, objs): 14 | objs = [obj.lower().split() for obj in objs] 15 | sbj = sbj.lower().split() 16 | 17 | for obj in objs: 18 | result = find_sub_list(sl=obj, l=sbj) 19 | if len(result) >0: 20 | return True, ' '.join(sbj), ' '.join(obj) 21 | 22 | return False, '', '' 23 | 24 | def wc(filename): 25 | return int(check_output(["wc", "-l", filename]).split()[0]) 26 | 27 | # https://github.com/huggingface/transformers/blob/758ed3332b219dd3529a1d3639fa30aa4954e0f3/src/transformers/data/metrics/squad_metrics.py 28 | def normalize_answer(s): 29 | """Lower text and remove punctuation, articles and extra whitespace.""" 30 | 31 | def remove_articles(text): 32 | regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) 33 | return re.sub(regex, " ", text) 34 | 35 | def white_space_fix(text): 36 | return " ".join(text.split()) 37 | 38 | def remove_punc(text): 39 | exclude = set(string.punctuation) 40 | return "".join(ch for ch in text if ch not in exclude) 41 | 42 | def lower(text): 43 | return text.lower() 44 | 45 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 46 | -------------------------------------------------------------------------------- /preprocessing/README.md: -------------------------------------------------------------------------------- 1 | # Data construction 2 | 3 | ## Quick Link 4 | * [UMLS](#umls) 5 | 6 | ## UMLS 7 | Before starting pre-processing UMLS, you need to download raw data from [2020AB UMLS Metathesaurus Files](https://www.nlm.nih.gov/research/umls/licensedcontent/umlsarchives04.html#2020AB). 
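The Metathesaurus files used below (`MRREL.RRF`, `MRCONSO.RRF`, `MRSTY.RRF`) are plain pipe-delimited text. As a rough, self-contained illustration of what `process_umls.py` reads from each `MRCONSO.RRF` row (the sample row is made up; the column indices follow the script in this repository):

```
# Illustration only: a made-up MRCONSO.RRF-style row. process_umls.py keeps
# column 0 (CUI), column 1 (language) and column 14 (concept name).
fields = ["C0004096", "ENG"] + [""] * 12 + ["Asthma"]
line = "|".join(fields) + "|"

row = line.strip().split("|")
cui, lang, name = row[0], row[1], row[14]
if lang == "ENG":
    print(cui, name)  # -> C0004096 Asthma
```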
8 | ``` 9 | # Extract bio triples from UMLS 10 | python ./process_umls.py \ 11 | --rel_path 2020AB/META/MRREL.RRF \ 12 | --conso_path 2020AB/META/MRCONSO.RRF \ 13 | --sty_path 2020AB/META/MRSTY.RRF \ 14 | --output_dir ../data/umls/triples 15 | 16 | # Filter triples based on max length 17 | python ./filter_length.py \ 18 | --input_dir "../data/umls/triples/*.jsonl" \ 19 | --output_dir "../data/umls/triples_10sw" \ 20 | --model_name bert-base-cased \ 21 | --max_length 10 \ 22 | --pids UR116,UR124,UR173,UR180,UR211,UR214,UR221,UR254,UR256,UR44,UR45,UR48,UR49,UR50,UR588,UR625 23 | 24 | # Aggregate data 25 | python ./aggregate_data.py \ 26 | --input_path "data/umls/triples_10sw/*.jsonl" \ 27 | --model_name_or_path bert-base-cased \ 28 | --output_dir ../data/umls/triples_processed \ 29 | --min_count 500 \ 30 | --max_count 2000 31 | 32 | # get triple stats 33 | python ./get_stats_triples.py \ 34 | --data_dir '../data/umls/triples_processed/*' \ 35 | --property_path '../data/umls/meta/properties.tsv' 36 | ``` 37 | 38 | Statistics 39 | ``` 40 | PID TRAIN DEV TEST 41 | UR44 452 113 566 42 | UR221 241 61 302 43 | UR45 772 193 965 44 | UR48 700 176 876 45 | UR211 650 162 813 46 | UR214 459 115 574 47 | UR256 244 62 306 48 | UR588 621 156 777 49 | UR254 672 169 841 50 | UR180 346 87 434 51 | UR116 668 167 835 52 | UR625 381 96 477 53 | UR173 512 128 640 54 | UR49 615 154 769 55 | UR50 663 166 829 56 | UR124 463 116 580 57 | ================================ 58 | TOTAL 8459 2121 10584 59 | ``` -------------------------------------------------------------------------------- /BioLAMA/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | import collections 4 | 5 | # https://github.com/huggingface/transformers/blob/758ed3332b219dd3529a1d3639fa30aa4954e0f3/src/transformers/data/metrics/squad_metrics.py 6 | def normalize_answer(s): 7 | """Lower text and remove punctuation, articles and extra whitespace.""" 8 | 9 | def remove_articles(text): 10 | regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) 11 | return re.sub(regex, " ", text) 12 | 13 | def white_space_fix(text): 14 | return " ".join(text.split()) 15 | 16 | def remove_punc(text): 17 | exclude = set(string.punctuation) 18 | return "".join(ch for ch in text if ch not in exclude) 19 | 20 | def lower(text): 21 | return text.lower() 22 | 23 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 24 | 25 | def get_tokens(s): 26 | if not s: 27 | return [] 28 | return normalize_answer(s).split() 29 | 30 | def compute_exact(a_gold, a_pred): 31 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 32 | 33 | def compute_f1(a_gold, a_pred): 34 | gold_toks = get_tokens(a_gold) 35 | pred_toks = get_tokens(a_pred) 36 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 37 | num_same = sum(common.values()) 38 | if len(gold_toks) == 0 or len(pred_toks) == 0: 39 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 40 | return int(gold_toks == pred_toks) 41 | if num_same == 0: 42 | return 0 43 | precision = 1.0 * num_same / len(pred_toks) 44 | recall = 1.0 * num_same / len(gold_toks) 45 | f1 = (2 * precision * recall) / (precision + recall) 46 | return f1 47 | 48 | def find_sub_list(sl,l): 49 | results=[] 50 | sll=len(sl) 51 | for ind in (i for i,e in enumerate(l) if e==sl[0]): 52 | if l[ind:ind+sll]==sl: 53 | results.append(ind) 54 | return results 55 | 56 | def convert_2d_list_to_1d(l): 57 | return [j for sub in l for j in sub] 58 | 59 | def 
convert_1d_list_to_2d(l, n): 60 | return [l[i:i+n] for i in range(0, len(l), n)] -------------------------------------------------------------------------------- /BioLAMA/data_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json 3 | from torch.utils.data import Dataset 4 | import torch 5 | import random 6 | from transformers import ( 7 | BertTokenizer, 8 | RobertaTokenizer 9 | ) 10 | 11 | random.seed(0) 12 | class FactDataset(Dataset): 13 | def __init__(self, input_file, prompt_token_len, tokenizer, template): 14 | print(f"FactDataset! input_file={input_file} prompt_token_len={prompt_token_len}") 15 | self.tokenizer = tokenizer 16 | self.prompt_token_len = prompt_token_len 17 | self.template = template 18 | 19 | self.data = self.load_data( 20 | input_file=input_file 21 | ) 22 | # shuffle data 23 | random.shuffle(self.data) 24 | 25 | if isinstance(tokenizer, BertTokenizer): 26 | self.mask_token = '[MASK]' 27 | self.prompt_token = '[unused1]' 28 | elif isinstance(tokenizer, RobertaTokenizer): 29 | self.mask_token = '' 30 | self.prompt_token = '¤' # hotfix for biolm 31 | else: 32 | print(f"tokenizer type = {type(tokenizer)}") 33 | assert 0 34 | 35 | def __len__(self): 36 | return len(self.data) 37 | 38 | def __getitem__(self, idx): 39 | sub, obj = self.data[idx] 40 | 41 | mask_idx = self.tokenizer.convert_tokens_to_ids(self.mask_token) 42 | prompt_idx = self.tokenizer.convert_tokens_to_ids(self.prompt_token) 43 | 44 | assert mask_idx != prompt_idx 45 | 46 | input_sentence, tokenized_obj = self.convert_template_to_input(sub, obj) 47 | input_ids = torch.tensor(self.tokenizer.encode(input_sentence)) 48 | 49 | if idx <5: 50 | print(f"{input_sentence}, {input_ids}") 51 | 52 | # get template tokens 53 | # replace 54 | 55 | mask_ind = input_ids.eq(mask_idx) 56 | input_ids[mask_ind] = torch.tensor(tokenized_obj) 57 | mask_ind = mask_ind.long() 58 | 59 | return input_ids, mask_ind 60 | 61 | def convert_template_to_input(self, sub, obj): 62 | tokenized_obj = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(obj)) 63 | obj_len = len(tokenized_obj) 64 | 65 | # fill subject and mask 66 | input_sentence = self.template.replace('[X]', sub) 67 | input_sentence = input_sentence.replace('[Y]', f' {self.mask_token} ' * obj_len) 68 | 69 | return input_sentence, tokenized_obj 70 | 71 | def load_data(self, input_file): 72 | data = [] 73 | 74 | with open(input_file) as f: 75 | for line in f: 76 | sample = json.loads(line) 77 | sub_label = sample['sub_label'] 78 | 79 | for obj_label, obj_aliases in zip(sample['obj_labels'],sample['obj_aliases']): 80 | data.append((sub_label,obj_label)) 81 | 82 | return data -------------------------------------------------------------------------------- /BioLAMA/cli_demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import ( 3 | AutoTokenizer, 4 | AutoModelWithLMHead 5 | ) 6 | 7 | from preprocessor import Preprocessor 8 | from decoder import Decoder 9 | import argparse 10 | import os 11 | 12 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 13 | 14 | def print_predictions(sentence, preds_probs): 15 | k = min(len(preds_probs),10) 16 | # print(f"Top {k} predictions") 17 | print("-------------------------") 18 | print(f"Rank\tProb\tPred") 19 | print("-------------------------") 20 | for i in range(k): 21 | preds_prob = preds_probs[i] 22 | print(f"{i+1}\t{round(preds_prob[1],3)}\t{preds_prob[0]}") 23 | 24 
| print("-------------------------") 25 | # print("\n") 26 | print("Top1 prediction sentence:") 27 | print(f"\"{sentence.replace('[Y]',preds_probs[0][0])}\"") 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser() 31 | # parser.add_argument("--text", required=True) 32 | parser.add_argument("--model_name_or_path", default='bert-base-uncased') 33 | parser.add_argument("--num_mask", type=int, default=10) 34 | parser.add_argument("--init_method", choices=['independent','order','confidence'], default='confidence') 35 | parser.add_argument("--iter_method", choices=['none','order','confidence'], default='none') 36 | parser.add_argument("--max_iter", type=int, default=10) 37 | parser.add_argument("--beam_size", type=int, default=5) 38 | parser.add_argument("--batch_size", type=int, default=10) 39 | args = parser.parse_args() 40 | 41 | print(f'load model {args.model_name_or_path}') 42 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False) 43 | lm_model = AutoModelWithLMHead.from_pretrained(args.model_name_or_path) 44 | if torch.cuda.is_available(): 45 | lm_model = lm_model.cuda() 46 | 47 | # make sure this is only an evaluation 48 | lm_model.eval() 49 | for param in lm_model.parameters(): 50 | param.grad = None 51 | 52 | preprocessor = Preprocessor(tokenizer=tokenizer, num_mask=args.num_mask) 53 | decoder = Decoder( 54 | model=lm_model, 55 | tokenizer=tokenizer, 56 | init_method=args.init_method, 57 | iter_method=args.iter_method, 58 | MAX_ITER=args.max_iter, 59 | BEAM_SIZE=args.beam_size, 60 | NUM_MASK=args.num_mask, 61 | BATCH_SIZE=args.batch_size, 62 | verbose=False 63 | ) 64 | 65 | while True: 66 | text = input("Please enter input (e.g., Flu has symptom such as [Y].):\n") 67 | if "[Y]" not in text: 68 | print("[Warning] Please type in the proper format.\n") 69 | continue 70 | 71 | # sentences = preprocessor.preprocess_single_sent(sentence=args.text) 72 | sentences = preprocessor.preprocess_single_sent(sentence=text) 73 | # print(sentences) 74 | all_preds_probs = decoder.decode([sentences], batch_size=args.batch_size, verbose=False) # topk predictions 75 | preds_probs = all_preds_probs[0] 76 | 77 | # print_predictions(args.text, preds_probs) 78 | print_predictions(text, preds_probs) 79 | 80 | print("\n") 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /preprocessing/get_stats_triples.py: -------------------------------------------------------------------------------- 1 | import json 2 | import csv 3 | import os 4 | import argparse 5 | from glob import glob 6 | 7 | 8 | def get_statistics_of_file(path): 9 | count = 0 10 | sub_types = [] 11 | obj_types = [] 12 | with open(path) as f: 13 | for line in f: 14 | data = json.loads(line) 15 | count+= 1 16 | 17 | if ('sub_type' in data) and ('obj_types' in data): 18 | local_sub_type = data['sub_type'] 19 | local_obj_types = data['obj_types'] 20 | 21 | sub_types.append(local_sub_type) 22 | obj_types += local_obj_types 23 | 24 | return count, sub_types, obj_types 25 | 26 | 27 | def get_obj_counts(path): 28 | obj_counts = [] 29 | with open(file=path) as f: 30 | for line in f: 31 | data = json.loads(line) 32 | obj_labels = data['obj_labels'] 33 | obj_counts.append(len(obj_labels)) 34 | 35 | return obj_counts 36 | 37 | 38 | """ 39 | Input: Triples 40 | Output: PID/Property Name/Count/Pair Type 41 | """ 42 | def main(args): 43 | input_dirs = glob(args.data_dir) 44 | if args.pids: # P1050 45 | new_dirs = [] 46 | pids = 
list(dict.fromkeys(args.pids.split(","))) 47 | for file in input_dirs: 48 | if file.split("/")[-1] in pids: 49 | new_dirs.append(file) 50 | input_dirs = new_dirs 51 | 52 | # pid2name 53 | if args.property_path: 54 | pid2name = {} 55 | with open(args.property_path) as f: 56 | rdr = csv.reader(f, delimiter='\t') 57 | r = list(rdr) 58 | for pid, name in r: 59 | pid2name[pid] = name 60 | 61 | # qid2type 62 | if args.type_path: 63 | with open(args.type_path) as f: 64 | qid2type = json.load(f) 65 | else: 66 | qid2type = None 67 | 68 | all_obj_counts = [] 69 | print(f"PID\tTRAIN\tDEV\tTEST") 70 | all_trains = 0 71 | all_devs = 0 72 | all_tests = 0 73 | for input_dir in input_dirs: 74 | pid = input_dir.split("/")[-1] 75 | train_file = os.path.join(input_dir, 'train.jsonl') 76 | dev_file = os.path.join(input_dir, 'dev.jsonl') 77 | test_file = os.path.join(input_dir, 'test.jsonl') 78 | 79 | sub_types = [] 80 | obj_types = [] 81 | train_count, train_sub_types, train_obj_types = get_statistics_of_file(train_file) 82 | dev_count, dev_sub_types, dev_obj_types = get_statistics_of_file(dev_file) 83 | test_count, test_sub_types, test_obj_types = get_statistics_of_file(test_file) 84 | 85 | if qid2type: 86 | sub_types = train_sub_types + dev_sub_types + test_sub_types 87 | obj_types = train_obj_types + dev_obj_types + test_obj_types 88 | sub_types = [qid2type[st] for st in sub_types] 89 | obj_types = [qid2type[ot] for ot in obj_types] 90 | 91 | sub_types = list(set(sub_types)) 92 | obj_types = list(set(obj_types)) 93 | 94 | print(f"{pid}\t{train_count}\t{dev_count}\t{test_count}") 95 | 96 | all_trains += train_count 97 | all_devs += dev_count 98 | all_tests += test_count 99 | 100 | print("================================") 101 | print(f"TOTAL\t{all_trains}\t{all_devs}\t{all_tests}") 102 | 103 | if __name__ == '__main__': 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument("--data_dir", 106 | default='../data/wikidata/triples_processed' 107 | ) 108 | parser.add_argument("--pids", 109 | default=None 110 | ) 111 | parser.add_argument("--property_path", 112 | default=None 113 | ) 114 | parser.add_argument("--type_path", 115 | default=None 116 | ) 117 | 118 | args = parser.parse_args() 119 | 120 | main(args) 121 | -------------------------------------------------------------------------------- /preprocessing/process_umls.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import traceback 5 | 6 | from tqdm import tqdm 7 | 8 | 9 | def main(args): 10 | rel_path = args.rel_path 11 | conso_path = args.conso_path 12 | sty_path = args.sty_path 13 | 14 | # load cui_dict 15 | cui2names = {} 16 | with open(file=conso_path) as f: 17 | for line in tqdm(f, desc="Creating cui2names"): 18 | row = line.strip().split('|') 19 | cui = row[0] 20 | lang = row[1] 21 | name = row[14] 22 | 23 | if lang != 'ENG': 24 | continue 25 | 26 | if cui not in cui2names: 27 | cui2names[cui] = [] 28 | 29 | cui2names[cui].append(name) 30 | 31 | # load cui2types 32 | cui2types = {} 33 | with open(file=sty_path) as f: 34 | for line in tqdm(f, desc="Creating cui2types"): 35 | row = line.strip().split('|') 36 | cui = row[0] 37 | type_ = row[3] 38 | 39 | if cui not in cui2types: 40 | cui2types[cui] = [] 41 | 42 | cui2types[cui].append(type_) 43 | 44 | # load triples 45 | relation2id = {} 46 | relation2triples = {} 47 | relation_idx = 0 48 | with open(file=rel_path) as f: 49 | for line in tqdm(f, desc="Creating data"): 50 | row = line.strip().split('|') 
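            # (Added note.) `row` is a pipe-split MRREL.RRF record; the code below reads
            # row[0] as the subject CUI, row[4] as the object CUI, and row[7] (the RELA
            # column) as the relation name used to group triples.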
51 | 52 | try: 53 | subj_cui = row[0] 54 | subj_label = cui2names[subj_cui][0] 55 | subj_aliases = cui2names[subj_cui][1:] 56 | subj_types = cui2types[subj_cui] 57 | 58 | obj_cui = row[4] 59 | obj_label = cui2names[obj_cui][0] 60 | obj_aliases = cui2names[obj_cui][1:] 61 | obj_types = cui2types[obj_cui] 62 | except KeyError: 63 | continue 64 | except Exception as e: 65 | print(e) 66 | traceback.print_exc() 67 | raise e 68 | 69 | 70 | relation = row[7] 71 | 72 | if relation == '': 73 | continue 74 | 75 | if relation not in relation2id: 76 | relation2id[relation] = f'UR{relation_idx}' 77 | relation_idx += 1 78 | 79 | relation_id = relation2id[relation] 80 | 81 | if relation_id not in relation2triples: 82 | relation2triples[relation2id[relation]] = [] 83 | 84 | template = {'predicate_id': relation_id, 85 | 'predicate_name': relation, 86 | 'sub_uri': subj_cui, 87 | 'sub_type': subj_types, 88 | 'sub_label': subj_label, 89 | 'sub_aliases': subj_aliases, 90 | 'obj_uri': obj_cui, 91 | 'obj_type': obj_types, 92 | 'obj_label': obj_label, 93 | 'obj_aliases': obj_aliases} 94 | 95 | relation2triples[relation_id].append(template) 96 | 97 | rel_counts = {} 98 | for relid, triples in relation2triples.items(): 99 | assert relid not in rel_counts 100 | rel_counts[relid] = len(triples) 101 | 102 | for relation_id, triples in tqdm(relation2triples.items(), desc=f"Saving data in {args.output_dir}"): 103 | save_filename = os.path.join(args.output_dir, relation_id + '.jsonl') 104 | with open(file=save_filename, mode='w') as f: 105 | for triple in triples: 106 | json.dump(obj=triple, fp=f) 107 | f.write('\n') 108 | 109 | 110 | if __name__ == '__main__': 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument('--rel_path', default='2020AB/META/MRREL.RRF', type=str) 113 | parser.add_argument('--conso_path', default='2020AB/META/MRCONSO.RRF', type=str) 114 | parser.add_argument('--sty_path', default='2020AB/META/MRSTY.RRF', type=str) 115 | parser.add_argument('--output_dir', default='data/umls/triples', type=str) 116 | 117 | args = parser.parse_args() 118 | 119 | main(args) 120 | -------------------------------------------------------------------------------- /BioLAMA/evaluator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | import os 5 | 6 | import numpy as np 7 | from tqdm import tqdm 8 | from transformers import ( 9 | AutoTokenizer, 10 | BertTokenizer, 11 | RobertaTokenizer 12 | ) 13 | import glob 14 | from utils import ( 15 | compute_exact 16 | ) 17 | 18 | def flatten_list(l): 19 | new_list = [] 20 | 21 | for element in l: 22 | if isinstance(element, str): 23 | new_list.append(element) 24 | elif isinstance(element, list): 25 | if element != []: 26 | new_list.extend(element) 27 | 28 | return new_list 29 | 30 | 31 | def get_raw_score(pred, golds): 32 | em = max(compute_exact(a, pred) for a in golds) 33 | 34 | return em 35 | 36 | class Evaluator(): 37 | def __init__(self, tokenizer=None): 38 | self.tokenizer = tokenizer 39 | 40 | if isinstance(tokenizer, BertTokenizer): 41 | self.mask_token = '[MASK]' 42 | elif isinstance(tokenizer, RobertaTokenizer): 43 | self.mask_token = '' 44 | elif tokenizer == None: 45 | self.mask_token = '' 46 | else: 47 | print(f"tokenizer type = {type(tokenizer)}") 48 | assert 0 49 | 50 | def check_multi(self, golds): 51 | token_nums = [len(self.tokenizer.tokenize(gold)) for gold in golds] 52 | if 1 in token_nums: 53 | return False 54 | else: 55 | return True 56 | 57 | def 
evaluate_preds_for_single_sample(self, preds, golds): 58 | """ 59 | input: prediction strings, gold strings 60 | output: em score, f1 score 61 | """ 62 | 63 | ems = [] 64 | for pred in preds: 65 | em = get_raw_score(pred, golds) 66 | ems.append(em) 67 | 68 | return max(ems) 69 | 70 | def evaluate(self, all_preds_probs, all_golds, subjects, prompts, inputs, uuids): 71 | """ 72 | input: prediction strings for all samples, gold strings for all samples 73 | output: accuracy 74 | """ 75 | result = [] 76 | 77 | topk = len(all_preds_probs[0]) 78 | 79 | print(f"topk={topk}") 80 | 81 | hits = [0]*topk # for topk 82 | 83 | total = 0 84 | 85 | if not subjects: 86 | subjects = ['']*len(all_preds_probs) 87 | if not prompts: 88 | prompts = ['']*len(all_preds_probs) 89 | 90 | assert len(all_preds_probs) == len(all_golds) 91 | for i, preds_probs in tqdm(enumerate(all_preds_probs), total=len(all_preds_probs)): 92 | # probs = all_probs[i] 93 | golds = all_golds[i] 94 | subject = subjects[i] 95 | prompt = prompts[i] 96 | _input = inputs[i] 97 | uuid = uuids[i] 98 | 99 | temp={ 100 | 'uuid': uuid, 101 | 'subject': subject, 102 | 'prompt': prompt, 103 | 'input': _input[0].replace(f'{self.mask_token}','[Y]'), 104 | 'golds': golds, 105 | } 106 | 107 | max_hit = 0 108 | 109 | topk_preds_probs = preds_probs # see all for logging 110 | topk_preds = [t[0] for t in topk_preds_probs] 111 | 112 | temp['corrected_preds'] = [] 113 | for k, kth_pred in enumerate(topk_preds): 114 | _hit = self.evaluate_preds_for_single_sample(preds=[kth_pred], golds=golds) 115 | if _hit: 116 | temp['corrected_preds'].append((kth_pred, k)) 117 | if k < topk: 118 | max_hit = max(_hit, max_hit) # if previous max em is 1, follow it in this k 119 | hits[k] += max_hit 120 | temp['preds'] = preds_probs 121 | 122 | total += 1 123 | result.append(temp) 124 | 125 | final_accs = [] 126 | 127 | for hit in hits: 128 | final_accs.append(round(hit/(total+1e-7),5)) 129 | 130 | performance = { 131 | 'acc@k': final_accs, 132 | } 133 | 134 | result = {'result': result, 'performance': performance} 135 | 136 | return result -------------------------------------------------------------------------------- /BioLAMA/run_manual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import ( 3 | AutoTokenizer, 4 | AutoModelWithLMHead 5 | ) 6 | 7 | import json 8 | from preprocessor import Preprocessor 9 | from decoder import Decoder 10 | from evaluator import Evaluator 11 | import argparse 12 | from glob import glob 13 | import os 14 | import numpy as np 15 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--model_name_or_path", default='bert-base-uncased') 20 | parser.add_argument("--prompt_path", default='./data/wikidata/prompts/manual.jsonl') 21 | parser.add_argument("--test_path", required=True) 22 | parser.add_argument("--num_mask", type=int, default=10) 23 | parser.add_argument("--init_method", choices=['independent','order','confidence'], default='confidence') 24 | parser.add_argument("--iter_method", choices=['none','order','confidence'], default='none') 25 | parser.add_argument("--max_iter", type=int, default=10) 26 | parser.add_argument("--beam_size", type=int, default=5) 27 | parser.add_argument("--batch_size", type=int, default=10) 28 | parser.add_argument("--pids", default=None) 29 | parser.add_argument("--output_dir", default=None) 30 | parser.add_argument("--draft", action="store_true") 31 | 32 | 
args = parser.parse_args() 33 | if args.draft: 34 | args.output_dir = args.output_dir + "_draft" 35 | os.makedirs(args.output_dir, exist_ok=True) 36 | 37 | print(f'load model {args.model_name_or_path}') 38 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False) 39 | lm_model = AutoModelWithLMHead.from_pretrained(args.model_name_or_path) 40 | if torch.cuda.is_available(): 41 | lm_model = lm_model.cuda() 42 | 43 | # make sure this is only an evaluation 44 | lm_model.eval() 45 | for param in lm_model.parameters(): 46 | param.grad = None 47 | 48 | print('prompt map loading') 49 | pid2prompt_meta = {} 50 | with open(args.prompt_path) as f: 51 | for line in f: 52 | l = json.loads(line) 53 | pid2prompt_meta[l['relation']] = { 54 | 'template':l['template'], 55 | } 56 | 57 | print('load modules') 58 | preprocessor = Preprocessor(tokenizer=tokenizer, num_mask=args.num_mask) 59 | decoder = Decoder( 60 | model=lm_model, 61 | tokenizer=tokenizer, 62 | init_method=args.init_method, 63 | iter_method=args.iter_method, 64 | MAX_ITER=args.max_iter, 65 | BEAM_SIZE=args.beam_size, 66 | NUM_MASK=args.num_mask, 67 | BATCH_SIZE=args.batch_size, 68 | ) 69 | evaluator = Evaluator( 70 | tokenizer=tokenizer, 71 | ) 72 | 73 | files = glob(args.test_path) 74 | if args.pids: # e.g., P1050 75 | new_files = [] 76 | pids = list(dict.fromkeys(args.pids.split(","))) 77 | for file in files: 78 | if file.split("/")[-2] in pids: 79 | new_files.append(file) 80 | files = new_files 81 | 82 | total_relation = 0 83 | pid2performance={} 84 | for file_path in files: 85 | pid = file_path.split("/")[-2] 86 | 87 | template = pid2prompt_meta[pid]['template'] 88 | 89 | print(f'preprocess {file_path}') 90 | sentences, all_gold_objects, subjects, prompts, uuids = preprocessor.preprocess( 91 | file_path, 92 | template = template, 93 | draft=args.draft) 94 | 95 | print(f'decode {file_path}') 96 | all_preds_probs = decoder.decode(sentences, batch_size=args.batch_size) # topk predictions 97 | 98 | print(f'evaluate {file_path}') 99 | result = evaluator.evaluate( 100 | all_preds_probs = all_preds_probs, 101 | all_golds = all_gold_objects, 102 | subjects=subjects, 103 | prompts=prompts, 104 | inputs=sentences, 105 | uuids=uuids 106 | ) 107 | 108 | # saving log 109 | with open(os.path.join(args.output_dir, pid + ".json"), 'w') as f: 110 | json.dump(result, f) 111 | 112 | if len(result['result']) == 0: 113 | print("nothing to print") 114 | continue 115 | 116 | total_relation += 1 117 | 118 | performance = result['performance'] 119 | local_acc = performance['acc@k'] 120 | 121 | logging_data ={} 122 | pid2performance[pid] = {} 123 | for k in range(args.beam_size): 124 | if k+1 in [1,5]: 125 | acc = local_acc[k] 126 | logging_data[f"{pid}_acc@{k+1}"] = acc * 100 127 | pid2performance[pid][f'acc@{k+1}'] = acc * 100 128 | 129 | print(f'performance of {pid}') 130 | print(logging_data) 131 | 132 | print("PID\tAcc@1\tAcc@5") 133 | print("-------------------------") 134 | acc1s = [] 135 | acc5s = [] 136 | for pid, performance in pid2performance.items(): 137 | acc1 = performance['acc@1'] 138 | acc5 = performance['acc@5'] 139 | acc1s.append(acc1) 140 | acc5s.append(acc5) 141 | print(f"{pid}\t{round(acc1,2)}\t{round(acc5,2)}") 142 | 143 | print("-------------------------") 144 | print(f"MACRO\t{round(np.mean(acc1s),2)}\t{round(np.mean(acc5s),2)}") 145 | 146 | if __name__ == '__main__': 147 | main() -------------------------------------------------------------------------------- /preprocessing/process_wikidata_triples.py: 
-------------------------------------------------------------------------------- 1 | import functools 2 | import json 3 | import os 4 | 5 | from tqdm import tqdm 6 | from wikidataintegrator.wdi_core import WDItemEngine 7 | import traceback 8 | from wikidataintegrator.wdi_config import config 9 | from glob import glob 10 | import csv 11 | import argparse 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--property_path", required=True) 15 | parser.add_argument("--entity_dir",required=True) 16 | parser.add_argument("--output_dir",required=True) 17 | args = parser.parse_args() 18 | 19 | # Sometimes cacheing allows a failed sparql query to finish on on subsequent attems 20 | # For this reason we ill run 3 times 21 | config['BACKOFF_MAX_TRIES'] = 3 22 | 23 | execute_sparql_query = WDItemEngine.execute_sparql_query 24 | # comment this ---v out to use official wikidata endpoint 25 | execute_sparql_query = functools.partial(execute_sparql_query, 26 | endpoint="http://163.152.163.168:9999/bigdata/namespace/wdq/sparql") 27 | 28 | def get_triples_given_property(pid, entity2meta): 29 | # get count first 30 | s = """ 31 | SELECT (COUNT(?item) as ?c) WHERE { 32 | { 33 | SELECT DISTINCT ?item ?value{ 34 | ?item wdt:{p} ?value 35 | } 36 | } 37 | } 38 | """.replace("{p}", pid) 39 | try: 40 | print(s) 41 | d = execute_sparql_query(s)['results']['bindings'] 42 | except Exception as e: 43 | print(e) 44 | traceback.print_exc() 45 | raise e 46 | 47 | count = [int(x['c']['value']) for x in d][0] 48 | print(f"{pid}:{count}") 49 | 50 | # query one batch at a time 51 | LIMIT = 3000 52 | max_iter = int(count/LIMIT) 53 | 54 | triples = [] 55 | for i in tqdm(range(max_iter + 1)): 56 | s = """ 57 | SELECT ?item ?value WHERE{ 58 | ?item wdt:{p} ?value. 59 | } 60 | LIMIT {limit} 61 | OFFSET {offset} 62 | """.replace("{p}", pid).replace("{limit}", str(LIMIT)).replace("{offset}", str(LIMIT*i)) 63 | try: 64 | print(f"get_triples_given_property={s}") 65 | d = execute_sparql_query(s)['results']['bindings'] 66 | except Exception as e: 67 | print(e) 68 | traceback.print_exc() 69 | raise e 70 | 71 | for sample in d: 72 | sbj_type = sample['item']['type'] 73 | obj_type = sample['value']['type'] 74 | 75 | # extract only entity to entity triple 76 | if (sbj_type != 'uri') or (obj_type != 'uri'): 77 | continue 78 | 79 | sub_uri = sample['item']['value'].split("/")[-1] 80 | obj_uri = sample['value']['value'].split("/")[-1] 81 | 82 | # extract only bioentity to bioentity triple 83 | if (sub_uri not in entity2meta): 84 | print(f"WARN! {sub_uri} is not bio entity") 85 | continue 86 | if (obj_uri not in entity2meta): 87 | print(f"WARN! 
{obj_uri} is not bio entity") 88 | continue 89 | 90 | triples.append({ 91 | "predicate_id": pid, 92 | "sub_uri": sub_uri, 93 | "sub_type": entity2meta[sub_uri]['type'], 94 | "sub_label": entity2meta[sub_uri]['label'], 95 | "sub_aliases": entity2meta[sub_uri]['aliases'], 96 | "obj_uri": obj_uri, 97 | "obj_type": entity2meta[obj_uri]['type'], 98 | "obj_label": entity2meta[obj_uri]['label'], 99 | "obj_aliases": entity2meta[obj_uri]['aliases'], 100 | }) 101 | return triples 102 | 103 | if __name__ == "__main__": 104 | print("[Start]load entites of each type") 105 | entity_files = glob(os.path.join(args.entity_dir, "Q*.tsv")) 106 | entity2meta = {} 107 | for entity_file in entity_files: 108 | _type = entity_file.split("/")[-1].replace(".tsv","") 109 | with open(entity_file) as f: 110 | for line in f: 111 | rdr = csv.reader(f, delimiter='\t') 112 | r = list(rdr) 113 | for entity_id, label, description, aliases in r: 114 | entity2meta[entity_id] = { 115 | 'type': _type, 116 | 'label': label.strip(), 117 | 'aliases': [al.strip() for al in aliases.split("|")] 118 | } 119 | print(f"[Finish]total entities={len(entity2meta)}") 120 | 121 | print("[Start]load bio properties") 122 | properties = {} 123 | with open(args.property_path, 'r') as f: 124 | rdr = csv.reader(f, delimiter='\t') 125 | r = list(rdr) 126 | for pid, plabel in r: 127 | properties[pid] = plabel 128 | print(f"[Finish]total entities={len(r)}") 129 | 130 | print("[Start]query bio triples from wikidata") 131 | pid2triples = {} 132 | num_triples = 0 133 | for pid in tqdm(properties): 134 | pid2triples[pid] = get_triples_given_property(pid, entity2meta) 135 | num_triples += len(pid2triples[pid]) 136 | print(f"[Finish]total triples={num_triples}") 137 | 138 | print("[Start]save triples") 139 | MIN_COUNT = 200 140 | os.makedirs(args.output_dir, exist_ok=True) 141 | for pid in pid2triples: 142 | triples = pid2triples[pid] 143 | if len(triples) < MIN_COUNT: 144 | continue 145 | 146 | output_path = os.path.join(args.output_dir, f"{pid}.jsonl" ) 147 | with open(output_path, 'w') as fo: 148 | for triple in pid2triples[pid]: 149 | output = json.dumps(triple, ensure_ascii=False) 150 | fo.write(output + "\n") 151 | print("[Finish]") -------------------------------------------------------------------------------- /BioLAMA/preprocessor.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from transformers import ( 4 | BertTokenizer, 5 | RobertaTokenizer 6 | ) 7 | import random 8 | 9 | def find_sub_list(sl,l): 10 | results=[] 11 | sll=len(sl) 12 | for ind in (i for i,e in enumerate(l) if e==sl[0]): 13 | if l[ind:ind+sll]==sl: 14 | results.append(ind) 15 | return results 16 | 17 | class Preprocessor(): 18 | """ 19 | Object to perform dynamic masking of object tokens in given prompts. 20 | "[X] causes diseases such as [Y]." -> "[X] causes diseases such as [MASK] * len(Y)." 
21 | """ 22 | def __init__(self, tokenizer, num_mask): 23 | self.tokenizer = tokenizer # bert tokenizer 24 | 25 | self.MASK_IDX = self.tokenizer.mask_token_id 26 | self.PAD_IDX = self.tokenizer.pad_token_id 27 | self.UNK_IDX = self.tokenizer.unk_token_id 28 | 29 | if isinstance(tokenizer, BertTokenizer): 30 | self.mask_token = '[MASK]' 31 | self.pad_token = '[PAD]' 32 | self.unk_token = '[UNK]' 33 | assert self.tokenizer.convert_ids_to_tokens(self.MASK_IDX) == self.mask_token 34 | assert self.tokenizer.convert_ids_to_tokens(self.PAD_IDX) == self.pad_token 35 | assert self.tokenizer.convert_ids_to_tokens(self.UNK_IDX) == self.unk_token 36 | 37 | elif isinstance(tokenizer, RobertaTokenizer): 38 | self.mask_token = '' 39 | self.pad_token = '' 40 | self.unk_token = '' 41 | assert self.tokenizer.convert_ids_to_tokens(self.PAD_IDX) == self.pad_token 42 | assert self.tokenizer.convert_ids_to_tokens(self.UNK_IDX) == self.unk_token 43 | else: 44 | print(f"tokenizer type = {type(tokenizer)}") 45 | import pdb ; pdb.set_trace() # get num_mask as an argument 46 | 47 | self.num_mask = num_mask 48 | 49 | def tokenize(self, text): 50 | return self.tokenizer.tokenize(text) 51 | 52 | def preprocess_single(self, subject, prompt): 53 | sentences = [] 54 | 55 | for i in range(1, self.num_mask + 1): 56 | sentence = prompt 57 | 58 | # fill in subject 59 | sentence = sentence.replace('[X]', subject) 60 | mask_sequence = (f"{self.mask_token} " * i).strip() 61 | sentence = sentence.replace('[Y]', mask_sequence) 62 | sentences.append(sentence) 63 | 64 | return sentences 65 | 66 | # input as a sentence 67 | # e.g., Adamantinoma of Long Bones has symptoms such as [Y]. 68 | def preprocess_single_sent(self, sentence): 69 | original_sent = sentence 70 | sentences = [] 71 | 72 | for i in range(1, self.num_mask + 1): 73 | # fill in subject 74 | mask_sequence = (f"{self.mask_token} " * i).strip() 75 | sentence = original_sent.replace('[Y]', mask_sequence) 76 | sentences.append(sentence) 77 | 78 | return sentences 79 | 80 | def preprocess(self, data_path, template, draft=False, shuffle_subject=False, replace_sub_syn=False): 81 | """ 82 | Masks out tokens corresponding to objects. 83 | 84 | Example 85 | ------- 86 | "meprobamate cures diseases such as headache ." -> "meprobamate cures diseases such as [MASK] ." 
87 | """ 88 | 89 | all_masked_sentences = [] 90 | all_gold_objects = [] 91 | subjects=[] 92 | prompts=[] 93 | uuids=[] 94 | 95 | # load temp_subjects 96 | temp_subjects = [] 97 | with open(file=data_path, mode='r') as f: 98 | for line in f: 99 | sample = json.loads(line) 100 | temp_subjects.append(sample['sub_label']) 101 | random.shuffle(temp_subjects) 102 | 103 | # load data 104 | with open(file=data_path, mode='r') as f: 105 | index = 0 106 | for line in f: 107 | sample = json.loads(line) 108 | 109 | prompt = template 110 | if shuffle_subject: # for shuffle subject test 111 | subject = temp_subjects[index] 112 | else: 113 | subject = sample['sub_label'] 114 | 115 | if 'obj_labels' in sample: 116 | objects = sample['obj_labels'] 117 | elif 'obj_label' in sample: 118 | objects = [sample['obj_label']] 119 | else: 120 | assert 0 121 | 122 | # replace subject to synonym 123 | if replace_sub_syn: 124 | if len(sample['sub_aliases']): 125 | subject = random.sample(sample['sub_aliases'],k=1)[0] 126 | 127 | if 'obj_aliases' in sample: 128 | objects += [a for al in sample['obj_aliases'] for a in al] 129 | 130 | # lowercase 131 | lower_objects = list(dict.fromkeys([obj.lower() for obj in objects])) 132 | 133 | sentences = self.preprocess_single( 134 | subject=subject, 135 | prompt=prompt 136 | ) 137 | 138 | all_masked_sentences.append(sentences) 139 | all_gold_objects.append(lower_objects) 140 | subjects.append(subject) 141 | prompts.append(prompt) 142 | uuids.append(sample['uuid']) 143 | # print sentences with mask for debugging 144 | if index <= 2 - 1: 145 | print(sentences) 146 | 147 | if draft and index>=16 -1 : 148 | break 149 | index += 1 150 | 151 | return all_masked_sentences, all_gold_objects, subjects, prompts, uuids -------------------------------------------------------------------------------- /preprocessing/filter_length.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import os 4 | import argparse 5 | from glob import glob 6 | import re 7 | import string as STRING 8 | from utils import is_obj_in_sbj, wc 9 | from transformers import ( 10 | AutoTokenizer 11 | ) 12 | 13 | import string 14 | import re 15 | 16 | # https://github.com/huggingface/transformers/blob/758ed3332b219dd3529a1d3639fa30aa4954e0f3/src/transformers/data/metrics/squad_metrics.py 17 | def normalize_answer(s): 18 | """Lower text and remove punctuation, articles and extra whitespace.""" 19 | 20 | def remove_articles(text): 21 | regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) 22 | return re.sub(regex, " ", text) 23 | 24 | def white_space_fix(text): 25 | return " ".join(text.split()) 26 | 27 | def remove_punc(text): 28 | exclude = set(string.punctuation) 29 | return "".join(ch for ch in text if ch not in exclude) 30 | 31 | def lower(text): 32 | return text.lower() 33 | 34 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 35 | 36 | def validate_alphanumeric_space_punctuation(text): 37 | return bool(re.match(f'^[ a-zA-Z0-9{STRING.punctuation}]+$', text)) 38 | 39 | def validate_len(string, tokenizer, maxlength): 40 | """ 41 | return true if length of subwords is less or equal to maxlength 42 | return false otherwise 43 | """ 44 | 45 | # Filter if string has a character which is not a alphanumeric, punctuation or space 46 | if validate_alphanumeric_space_punctuation(string) == False: 47 | return False 48 | 49 | # Filter ID 50 | if ":" in string: 51 | return False 52 | 53 | return len(tokenizer.tokenize(string)) <= maxlength 54 | 
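# (Added note on validate_len above.) A string passes only if it consists of
# alphanumeric characters, spaces and punctuation, contains no ':' (which drops
# raw identifiers), and tokenizes into at most `maxlength` subwords.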
55 | def refine_aliases(label, aliases): 56 | new_aliases = [] 57 | for al in aliases: 58 | # normalize before comparision 59 | norm_al = normalize_answer(al) 60 | norm_label = normalize_answer(label) 61 | 62 | if norm_al not in norm_label: # not overlap 63 | new_aliases.append(al) 64 | # else: 65 | # print(f"refine_aliases {al} {label}") 66 | 67 | return new_aliases 68 | 69 | def main(args): 70 | input_dir = args.input_dir 71 | output_dir = args.output_dir 72 | os.makedirs(output_dir, exist_ok=True) 73 | 74 | assert input_dir != output_dir 75 | 76 | # load tokenizer 77 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False) 78 | 79 | input_files = glob(args.input_dir) 80 | 81 | # filter pids 82 | if args.pids: 83 | pids = list(dict.fromkeys(args.pids.split(","))) 84 | print(f"pids={pids}") 85 | input_files = [input_file for input_file in input_files if (input_file.split("/")[-1].split(".")[0] in pids)] 86 | 87 | org_num_triples = 0 88 | new_num_triples = 0 89 | MAX_LENGTH = args.max_length 90 | 91 | for input_file in tqdm(input_files): 92 | file_name = input_file.split("/")[-1] 93 | print(file_name) 94 | 95 | output_file = os.path.join(output_dir, file_name) 96 | 97 | # make tmp file for input 98 | with open(input_file) as f, open(output_file, 'w') as fo: 99 | num_lines = wc(input_file) 100 | 101 | for line in tqdm(f, total=num_lines): 102 | data = json.loads(line) 103 | org_num_triples += 1 104 | 105 | # filter when length of object is over max length 106 | sub_label = data['sub_label'] 107 | if not validate_len(sub_label,tokenizer, MAX_LENGTH): 108 | continue 109 | 110 | obj_label = data['obj_label'] 111 | if not validate_len(obj_label,tokenizer, MAX_LENGTH): 112 | continue 113 | 114 | # concat label and aliases 115 | data['sub_aliases'] = [al.strip() for al in data['sub_aliases'] if al.strip() != ''] 116 | data['sub_aliases'] = [alias for alias in data['sub_aliases'] if validate_len(alias,tokenizer, MAX_LENGTH)] 117 | 118 | # filter obj alias which is either empty or over max length 119 | data['obj_aliases'] = [al.strip() for al in data['obj_aliases'] if al.strip() != ''] 120 | data['obj_aliases'] = [al for al in data['obj_aliases'] if validate_len(al,tokenizer, MAX_LENGTH)] 121 | 122 | # filter overlap 123 | # 1) remove sbj_alias which overlaps with sbj 124 | sbj = data['sub_label'] 125 | sbj_aliases = data['sub_aliases'] 126 | sbj_aliases = refine_aliases(sbj, sbj_aliases) 127 | data['sub_aliases'] = sbj_aliases 128 | 129 | sbjs = [sbj] + sbj_aliases 130 | 131 | # 2) remove obj_alias which overlaps with obj 132 | obj = data['obj_label'] 133 | obj_aliases = data['obj_aliases'] 134 | obj_aliases = refine_aliases(obj, obj_aliases) 135 | data['obj_aliases'] = obj_aliases 136 | 137 | objs = [obj] + obj_aliases 138 | 139 | # 3) filter sbj-obj overlap 140 | is_overlap = False 141 | for sbj in sbjs: 142 | result, _sbj, _obj = is_obj_in_sbj(sbj=sbj, objs=objs) 143 | if result: 144 | is_overlap = True 145 | print(f"filter overlap! 
{_sbj}, {_obj}") 146 | break 147 | 148 | if is_overlap: # filter if overlapped 149 | continue 150 | 151 | # filter when sbj or obj has no name 152 | if data['sub_label'] == data['sub_uri']: 153 | print(f"filter no sub label {data['sub_label']}") 154 | continue 155 | 156 | if data['obj_label'] == data['obj_uri']: 157 | print(f"filter no obj label {data['obj_label']}") 158 | continue 159 | 160 | new_num_triples += 1 161 | output = json.dumps(data, ensure_ascii=False) 162 | fo.write(output + "\n") 163 | 164 | print(f"{new_num_triples}/{org_num_triples}={round(new_num_triples/org_num_triples,2)}") 165 | new_num_triples=0 166 | org_num_triples=0 167 | 168 | if __name__ == '__main__': 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument("--input_dir", 171 | default='./data/wikidata/triples' 172 | ) 173 | parser.add_argument("--output_dir", 174 | default='./data/wikidata/triples_10sw' 175 | ) 176 | parser.add_argument("--max_length", 177 | required=True, 178 | type=int 179 | ) 180 | parser.add_argument("--pids", 181 | default=None 182 | ) 183 | parser.add_argument("--model_name_or_path", 184 | default='bert-base-cased' 185 | ) 186 | args = parser.parse_args() 187 | 188 | main(args) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BioLAMA 2 | 3 |
4 | ![BioLAMA](images/biolama.png)
6 | 7 | BioLAMA is biomedical factual knowledge triples for probing biomedical LMs. The triples are collected and pre-processed from three sources: CTD, UMLS, and Wikidata. Please see our paper [ 8 | Can Language Models be Biomedical Knowledge Bases? (Sung et al., 2021)](http://arxiv.org/abs/2109.07154) for more details. 9 | 10 | ### Updates 11 | * \[**Mar 17, 2022**\] The BioLAMA probe with the CTD/UMLS/Wikidata triples are released [here](https://drive.google.com/file/d/1CcjpmNuAXavL3aMjwVqiiziMu3OGDyyG/view?usp=sharing). 12 | 13 | ## Getting Started 14 | After the [installation](#installation), you can easily try BioLAMA with manual prompts. When a subject is "flu" and you want to probe its symptoms from an LM, the input should be like "Flu has symptom such as \[Y\]." 15 | 16 | ``` 17 | # Set MODEL to bert-base-cased for BERT or dmis-lab/biobert-base-cased-v1.2 for BioBERT 18 | MODEL=./RoBERTa-base-PM-Voc/RoBERTa-base-PM-Voc-hf 19 | python ./BioLAMA/cli_demo.py \ 20 | --model_name_or_path ${MODEL} 21 | ``` 22 | 23 | Result: 24 | ``` 25 | Please enter input (e.g., Flu has symptoms such as [Y].): 26 | hepatocellular carcinoma has symptoms such as [Y]. 27 | ------------------------- 28 | Rank Prob Pred 29 | ------------------------- 30 | 1 0.648 jaundice 31 | 2 0.223 abdominal pain 32 | 3 0.127 jaundice and ascites 33 | 4 0.11 ascites 34 | 5 0.086 hepatomegaly 35 | 6 0.074 obstructive jaundice 36 | 7 0.06 abdominal pain and jaundice 37 | 8 0.059 ascites and jaundice 38 | 9 0.043 anorexia and jaundice 39 | 10 0.042 fever and jaundice 40 | ------------------------- 41 | Top1 prediction sentence: 42 | "hepatocellular carcinoma has symptoms such as jaundice." 43 | ``` 44 | 45 | ## Quick Link 46 | * [Installation](#installation) 47 | * [Resources](#resources) 48 | * [Experiments](#experiments) 49 | 50 | ## Installation 51 | 52 | ``` 53 | # Install torch with conda (please check your CUDA version) 54 | conda create -n BioLAMA python=3.7 55 | conda activate BioLAMA 56 | conda install pytorch=1.8.0 cudatoolkit=10.2 -c pytorch 57 | 58 | # Install BioLAMA 59 | git clone https://github.com/dmis-lab/BioLAMA.git 60 | cd BioLAMA 61 | pip install -r requirements.txt 62 | ``` 63 | 64 | ## Resources 65 | 66 | ### Models 67 | For BERT and BioBERT, we use checkpoints provided in the Huggingface Hub: 68 | - [best-base-cased](https://huggingface.co/bert-base-cased) (for BERT) 69 | - [dmis-lab/biobert-base-cased-v1.2](https://huggingface.co/dmis-lab/biobert-base-cased-v1.2) (for BioBERT) 70 | 71 | Bio-LM is not provided in the Huggingface Hub. Therefore, we use the Bio-LM checkpoint released in [link](https://github.com/facebookresearch/bio-lm). Among the various versions of Bio-LMs, we use `RoBERTa-base-PM-Voc-hf'. 72 | ``` 73 | wget https://dl.fbaipublicfiles.com/biolm/RoBERTa-base-PM-Voc-hf.tar.gz 74 | tar -xzvf RoBERTa-base-PM-Voc-hf.tar.gz 75 | rm -rf RoBERTa-base-PM-Voc-hf.tar.gz 76 | ``` 77 | 78 | ### Datasets 79 | 80 | The dataset will take about 85 MB of space. You can download the dataset [here](https://drive.google.com/file/d/1CcjpmNuAXavL3aMjwVqiiziMu3OGDyyG/view?usp=sharing). 
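Once the archive is extracted (the commands and the resulting directory tree are shown below), each `train.jsonl`/`dev.jsonl`/`test.jsonl` file stores one fact per line. A minimal sketch of inspecting a single relation's test split; the path is illustrative, and the fields printed are the ones consumed by the probing scripts (`uuid`, `sub_label`, `obj_labels`):

```
import json

# Illustrative path; any relation directory under data/*/triples_processed works.
path = "data/wikidata/triples_processed/P2175/test.jsonl"

with open(path) as f:
    sample = json.loads(next(f))

print(sample["uuid"], sample["sub_label"], "->", sample["obj_labels"])
```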
81 | 82 | ``` 83 | tar -xzvf data.tar.gz 84 | rm -rf data.tar.gz 85 | ``` 86 | 87 | The directory tree of the data is like: 88 | ``` 89 | data 90 | ├── ctd 91 | │ ├── entities 92 | │ ├── meta 93 | │ ├── prompts 94 | │ └── triples_processed 95 | │ └── CD1 96 | │ ├── dev.jsonl 97 | │ ├── test.jsonl 98 | │ └── train.jsonl 99 | ├── wikidata 100 | │ ├── entities 101 | │ ├── meta 102 | │ ├── prompts 103 | │ └── triples_processed 104 | │ └── P2175 105 | │ ├── dev.jsonl 106 | │ ├── test.jsonl 107 | │ └── train.jsonl 108 | └── umls 109 | ├── meta 110 | └── prompts 111 | └── triples_processed 112 | └── UR44 113 | ├── dev.jsonl 114 | ├── test.jsonl 115 | └── train.jsonl 116 | 117 | 118 | ``` 119 | 120 | ## Experiments 121 | 122 | We provide two ways of probing PLMs with BioLAMA: 123 | - [Manual Prompt](#manual-prompt) 124 | - [OptiPrompt](#optiprompt) 125 | 126 | ### Manual Prompt 127 | 128 | Manual Prompt probes PLMs using pre-defined manual prompts. The predictions and scores will be logged in '/output'. 129 | 130 | ``` 131 | # Set TASK to 'ctd' for CTD or 'umls' for UMLS 132 | # Set MODEL to 'bert-base-cased' for BERT or 'dmis-lab/biobert-base-cased-v1.2' for BioBERT 133 | TASK=wikidata 134 | MODEL=./RoBERTa-base-PM-Voc/RoBERTa-base-PM-Voc-hf 135 | PROMPT_PATH=./data/${TASK}/prompts/manual.jsonl 136 | TEST_PATH=./data/${TASK}/triples_processed/*/test.jsonl 137 | 138 | python ./BioLAMA/run_manual.py \ 139 | --model_name_or_path ${MODEL} \ 140 | --prompt_path ${PROMPT_PATH} \ 141 | --test_path "${TEST_PATH}" \ 142 | --init_method confidence \ 143 | --iter_method none \ 144 | --num_mask 10 \ 145 | --max_iter 10 \ 146 | --beam_size 5 \ 147 | --batch_size 16 \ 148 | --output_dir ./output/${TASK}_manual 149 | ``` 150 | 151 | Result: 152 | ``` 153 | PID Acc@1 Acc@5 154 | ------------------------- 155 | P2175 9.40 21.11 156 | P2176 22.46 39.75 157 | P2293 2.24 11.43 158 | P4044 9.47 19.47 159 | P780 16.30 37.85 160 | ------------------------- 161 | MACRO 11.97 25.92 162 | ``` 163 | 164 | ### OptiPrompt 165 | 166 | OptiPrompt probes PLMs using embedding-based prompts starting from embeddings of manual prompts. The predictions and scores will be logged in '/output'. 
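For intuition, OptiPrompt keeps the subject and the masked object but makes the prompt itself trainable: the template words between `[X]` and `[Y]` are treated as `--prompt_token_len` continuous vectors optimized on the train split (initialized from the manual prompt when `--init_manual_template` is set). The sketch below only illustrates what such a template looks like; it reuses the `[unused1]` placeholder that `BioLAMA/data_loader.py` uses for BERT-style vocabularies and is not the exact code of `run_optiprompt.py`:

```
prompt_token = "[unused1]"   # placeholder token for BERT-style vocabularies
prompt_token_len = 5
template = "[X] " + " ".join([prompt_token] * prompt_token_len) + " [Y] ."
print(template)
# -> [X] [unused1] [unused1] [unused1] [unused1] [unused1] [Y] .
```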
167 | 168 | ``` 169 | # Set TASK to 'ctd' for CTD or 'umls' for UMLS 170 | # Set MODEL to 'bert-base-cased' for BERT or 'dmis-lab/biobert-base-cased-v1.2' for BioBERT 171 | TASK=wikidata 172 | MODEL=./RoBERTa-base-PM-Voc/RoBERTa-base-PM-Voc-hf 173 | PROMPT_PATH=./data/${TASK}/prompts/manual.jsonl 174 | TRAIN_PATH=./data/${TASK}/triples_processed/*/train.jsonl 175 | DEV_PATH=./data/${TASK}/triples_processed/*/dev.jsonl 176 | TEST_PATH=./data/${TASK}/triples_processed/*/test.jsonl 177 | 178 | python ./BioLAMA/run_optiprompt.py \ 179 | --model_name_or_path ${MODEL} \ 180 | --train_path "${TRAIN_PATH}" \ 181 | --dev_path "${DEV_PATH}" \ 182 | --test_path "${TEST_PATH}" \ 183 | --prompt_path ${PROMPT_PATH} \ 184 | --num_mask 10 \ 185 | --init_method confidence \ 186 | --iter_method none \ 187 | --max_iter 10 \ 188 | --beam_size 5 \ 189 | --batch_size 16 \ 190 | --lr 3e-3 \ 191 | --epochs 10 \ 192 | --seed 0 \ 193 | --prompt_token_len 5 \ 194 | --init_manual_template \ 195 | --output_dir ./output/${TASK}_optiprompt 196 | ``` 197 | 198 | Result: 199 | ``` 200 | PID Acc@1 Acc@5 201 | ------------------------- 202 | P2175 9.47 24.94 203 | P2176 20.14 39.57 204 | P2293 2.90 9.21 205 | P4044 7.53 18.58 206 | P780 12.98 33.43 207 | ------------------------- 208 | MACRO 7.28 18.51 209 | ``` 210 | 211 | ### IE Baseline (BEST) 212 | 213 | BEST (Biomedical Entity Search Tool) is a returns relevant biomedical entity given a query. By constructing the query We used BEST as an information extraction baseline. 214 | 215 | ``` 216 | TASK=wikidata 217 | TEST_PATH=./data/${TASK}/triples_processed/*/test.jsonl 218 | CUDA_VISIBLE_DEVICES=0 python ./BioLAMA/run_ie.py \ 219 | --test_path "${TEST_PATH}" \ 220 | --output_dir ./output/${TASK}_ie 221 | ``` 222 | 223 | 224 | ## Acknowledgement 225 | Parts of the code are modified from [genewikiworld](https://github.com/SuLab/genewikiworld), [X-FACTR](https://github.com/jzbjyb/X-FACTR), and [OptiPrompt](https://github.com/princeton-nlp/OptiPrompt). We appreciate the authors for making their projects open-sourced. 
226 | 227 | ## Citations 228 | ```bibtex 229 | @inproceedings{sung2021can, 230 | title={Can Language Models be Biomedical Knowledge Bases}, 231 | author={Sung, Mujeen and Lee, Jinhyuk and Yi, Sean and Jeon, Minji and Kim, Sungdong and Kang, Jaewoo}, 232 | booktitle={Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP)}, 233 | year={2021}, 234 | } 235 | ``` 236 | -------------------------------------------------------------------------------- /preprocessing/aggregate_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import os 4 | import argparse 5 | from glob import glob 6 | from transformers import ( 7 | AutoTokenizer 8 | ) 9 | import random 10 | from nltk.tokenize import sent_tokenize 11 | import copy 12 | from utils import wc 13 | from collections import Counter 14 | 15 | random.seed(0) 16 | 17 | def flatten_list(l): 18 | new_list = [] 19 | for element in l: 20 | if isinstance(element, str): 21 | new_list.append(element) 22 | elif isinstance(element, list): 23 | new_list.extend(element) 24 | 25 | return new_list 26 | 27 | def save_to_jsonl(save_path, data): 28 | print(f"Saving to {save_path}") 29 | with open(file=save_path, mode='w') as f: 30 | for uuid in data: 31 | output = data[uuid] 32 | output = json.dumps(output, ensure_ascii=False) 33 | f.write(output + '\n') 34 | 35 | def shuffle_and_truncate_triples(triples_to_probe, max_count=None): 36 | total = len(triples_to_probe) 37 | 38 | # shuffle 39 | l = list(triples_to_probe.items()) 40 | random.shuffle(l) 41 | 42 | if max_count and (total > max_count): 43 | l = random.sample(l, k=max_count) 44 | 45 | triples_to_probe = dict(l) 46 | return triples_to_probe 47 | 48 | def undersample(triples_to_probe, k=5): 49 | all_obj_uris = [] 50 | samples = [] 51 | for _, sample in triples_to_probe.items(): 52 | obj_uris = sample['obj_uris'] 53 | all_obj_uris += obj_uris 54 | samples.append(sample) 55 | 56 | # count per class (obj2count) 57 | obj2count = {k: v for k, v in sorted(Counter(all_obj_uris).items(), key=lambda item: item[1],reverse=True)} 58 | 59 | if len(list(obj2count.items())) < k: 60 | return {} 61 | 62 | # topk count to undersample 63 | try: 64 | topk_count = list(obj2count.items())[k-1][1] 65 | except IndexError: 66 | assert 0 67 | 68 | # undersample 69 | while True: 70 | # init for iteration 71 | obj2count = {k: v for k, v in sorted(obj2count.items(), key=lambda item: item[1],reverse=True)} 72 | majority_obj,majority_count = list(obj2count.items())[0] # majority object 73 | if majority_count <= topk_count: 74 | break 75 | 76 | random.shuffle(samples) 77 | 78 | for i, sample in enumerate(samples): 79 | obj_uris = sample['obj_uris'] 80 | if majority_obj in obj_uris: 81 | for obj_uri in obj_uris: 82 | obj2count[obj_uri] -= 1 83 | del samples[i] # remove this sample 84 | break 85 | 86 | # restore samples to triples_to_probe 87 | new_triples_to_probe = {} 88 | for sample in samples: 89 | uuid = sample['uuid'] 90 | new_triples_to_probe[uuid] = sample 91 | print(f"undersample {len(triples_to_probe)}->{len(new_triples_to_probe)}") 92 | return new_triples_to_probe 93 | 94 | def split_train_dev_test(triples_to_probe): 95 | total = len(triples_to_probe) 96 | 97 | # 4:1:5 = train:dev:test 98 | trainset = dict(list(triples_to_probe.items())[:int(total*0.4)]) 99 | devset = dict(list(triples_to_probe.items())[int(total*0.4):int(total*0.5)]) 100 | testset = dict(list(triples_to_probe.items())[int(total*0.5):]) 101 | 
print(list(trainset.items())[0]) 102 | print(list(devset.items())[0]) 103 | 104 | print(list(testset.items())[0]) 105 | print(f"total len:{len(triples_to_probe)}") 106 | print(f"trainset={len(trainset)} devset={len(devset)} testset={len(testset)}") 107 | return trainset, devset, testset 108 | 109 | def main(args): 110 | input_path = args.input_path 111 | output_dir = args.output_dir 112 | os.makedirs(output_dir, exist_ok=True) 113 | 114 | assert input_path != output_dir 115 | 116 | if args.sub_obj_type_path: 117 | with open(args.sub_obj_type_path) as f: 118 | sub_obj_types = json.load(f) 119 | else: 120 | sub_obj_types ={} 121 | 122 | if args.model_name_or_path: 123 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False) 124 | input_files = glob(args.input_path) 125 | unique_objs = set() 126 | 127 | for input_file in tqdm(input_files): 128 | file_name = input_file.split("/")[-1] 129 | property_name = file_name.replace(".jsonl","") 130 | 131 | print(file_name) 132 | triples_to_probe = {} 133 | 134 | # make tmp file for input 135 | with open(input_file) as f: 136 | num_lines = wc(input_file) 137 | for line in tqdm(f, total=num_lines): 138 | data = json.loads(line) 139 | pid = data['predicate_id'] 140 | uuid = '-'.join([data['sub_uri'], pid]) 141 | 142 | # filter noisy type sample 143 | if pid in sub_obj_types: 144 | sub_types = sub_obj_types[pid]['sub_types'] 145 | obj_types = sub_obj_types[pid]['obj_types'] 146 | if (data['sub_type'] not in sub_types) or (data['obj_type'] not in obj_types): 147 | continue 148 | 149 | if uuid not in triples_to_probe: 150 | triples_to_probe[uuid] = { 151 | 'uuid': uuid, 152 | 'predicate_id': pid, 153 | 'sub_uri': data['sub_uri'], 154 | 'sub_label': data['sub_label'], 155 | 'sub_type': data['sub_type'] if 'sub_type' in data else '', 156 | 'sub_aliases':data['sub_aliases'], 157 | 'obj_uris': [], 158 | 'obj_labels': [], 159 | 'obj_types': [], 160 | 'obj_aliases':[], 161 | } 162 | 163 | if data['obj_uri'] in triples_to_probe[uuid]['obj_uris']: 164 | continue 165 | 166 | # for multiple answers 167 | triples_to_probe[uuid]['obj_uris'].append(data['obj_uri']) 168 | triples_to_probe[uuid]['obj_labels'].append(data['obj_label']) 169 | triples_to_probe[uuid]['obj_aliases'].append(data['obj_aliases']) 170 | if 'obj_type' in data: 171 | triples_to_probe[uuid]['obj_types'].append(data['obj_type']) 172 | 173 | # split train/dev/test 174 | triples_to_probe = shuffle_and_truncate_triples(triples_to_probe, max_count=args.max_count) 175 | 176 | # filter min_count 177 | if args.min_count and len(triples_to_probe) < args.min_count: 178 | print(f"filter this cause {len(triples_to_probe)} < {args.min_count}") 179 | continue 180 | 181 | # undersample for balancing 182 | triples_to_probe = undersample(triples_to_probe) 183 | 184 | # filter min_count 185 | if args.min_count and len(triples_to_probe) < args.min_count: 186 | print(f"filter this cause {len(triples_to_probe)} < {args.min_count}") 187 | continue 188 | 189 | trainset, devset, testset = split_train_dev_test(triples_to_probe) 190 | 191 | sub_output_dir = os.path.join(output_dir, property_name) 192 | os.makedirs(sub_output_dir, exist_ok=True) 193 | 194 | # for length stat 195 | for value in triples_to_probe.values(): 196 | objs = copy.deepcopy(value['obj_labels']) 197 | objs.extend(copy.deepcopy(value['obj_aliases'])) 198 | objs = flatten_list(objs) 199 | 200 | for obj in objs: 201 | unique_objs.add(obj) 202 | 203 | # save train,dev,test 204 | train_path = os.path.join(sub_output_dir, 
'train.jsonl') 205 | dev_path = os.path.join(sub_output_dir, 'dev.jsonl') 206 | test_path = os.path.join(sub_output_dir, 'test.jsonl') 207 | 208 | save_to_jsonl(save_path=train_path, data=trainset) 209 | save_to_jsonl(save_path=dev_path, data=devset) 210 | save_to_jsonl(save_path=test_path, data=testset) 211 | 212 | # do this with new triples_to_probe (after truncate, undersample) 213 | obj_lengths = {} 214 | for obj in tqdm(iterable=unique_objs, desc="Getting obj lengths", total=len(unique_objs)): 215 | tokenized_obj = tokenizer.tokenize(obj) 216 | obj_len = len(tokenized_obj) 217 | 218 | try: 219 | obj_lengths[obj_len] += 1 220 | except KeyError: 221 | obj_lengths[obj_len] = 1 222 | 223 | obj_lengths = sorted([(length, count) for length, count in obj_lengths.items()], key=lambda x: x[0]) 224 | print() 225 | print("Obj Length: Count") 226 | for pair in obj_lengths: 227 | print(f'{pair[0]}, {pair[1]}') 228 | print() 229 | 230 | 231 | if __name__ == '__main__': 232 | parser = argparse.ArgumentParser() 233 | parser.add_argument("--input_path") 234 | parser.add_argument("--entity_dir") 235 | parser.add_argument("--output_dir") 236 | parser.add_argument("--min_count", type=int, default=500) 237 | parser.add_argument("--max_count", type=int, default=2000) 238 | parser.add_argument("--sub_obj_type_path", default=None) 239 | parser.add_argument('--model_name_or_path') 240 | args = parser.parse_args() 241 | 242 | main(args) 243 | -------------------------------------------------------------------------------- /BioLAMA/run_ie.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import json 4 | 5 | import numpy as np 6 | import stanza 7 | from tqdm import tqdm 8 | 9 | import best 10 | from evaluator import Evaluator 11 | import os 12 | 13 | def flatten_list(l): 14 | new_list = [] 15 | 16 | for element in l: 17 | if isinstance(element, str): 18 | new_list.append(element) 19 | elif isinstance(element, list) and (element != []): 20 | new_list.extend(element) 21 | 22 | return new_list 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser() 26 | 27 | parser.add_argument('--test_path', required=True) 28 | parser.add_argument('--output_dir', required=True) 29 | parser.add_argument('--topN', type=int, default=50) 30 | parser.add_argument('--draft', action='store_true') 31 | 32 | args = parser.parse_args() 33 | 34 | os.makedirs(args.output_dir, exist_ok=True) 35 | 36 | task = args.test_path.split("/")[-4] 37 | 38 | evaluator = Evaluator() 39 | 40 | stanza.download('en', package='craft') 41 | stanza_pipeline = stanza.Pipeline('en', package='craft') 42 | 43 | if task == 'ctd': 44 | PROPERTY_NAMES = {'CD1': "therapeutic", 45 | 'CD2': "marker mechanism", 46 | 'CG1': "decreases expresion", 47 | 'CG17': "increases expression", 48 | 'CG18': "increases expression", 49 | 'CG2': "decreases activity", 50 | 'CG21': "increases phosphorylation", 51 | 'CG4': "increases activity", 52 | 'CG6': "decreases expression", 53 | 'CG9': "affects binding", 54 | 'CP1': "decreases", 55 | 'CP2': "increases", 56 | 'CP3': "affects", 57 | 'GD1': "marker mechanism", 58 | 'GP1': "association"} 59 | 60 | PROPERTY_OBJ_TYPES = {'CD1': 'disease', 61 | 'CD2': 'disease', 62 | 'CG1': 'gene', 63 | 'CG17': 'gene', 64 | 'CG18': 'gene', 65 | 'CG2': 'gene', 66 | 'CG21': 'gene', 67 | 'CG4': 'gene', 68 | 'CG6': 'gene', 69 | 'CG9': 'gene', 70 | 'CP1': 'disease', 71 | 'CP2': 'disease', 72 | 'CP3': 'disease', 73 | 'GD1': 'disease', 74 | 'GP1': 'pathway'} 75 | elif task == 'umls': 
76 | PROPERTY_NAMES = {'UR44': "may be prevented by", 77 | 'UR45': "may be treated by", 78 | 'UR48': "physiologic effect of", 79 | 'UR49': "mechanism of action of", 80 | 'UR50': "therapeutic class of", 81 | 'UR116': "clinically associated with", 82 | 'UR124': "may treat", 83 | 'UR173': "causative agent of", 84 | 'UR180': "is finding of disease", 85 | 'UR211': "biological process involves gene product", 86 | 'UR214': "cause of", 87 | 'UR221': "gene mapped to disease", 88 | 'UR254': "may be molecular abnormality of disease", 89 | 'UR256': "may be molecular abnormality of disease", 90 | 'UR588': "process involves gene", 91 | 'UR625': "disease has associated gene"} 92 | 93 | PROPERTY_OBJ_TYPES = {'UR44': 'disease', 94 | 'UR45': 'disease', 95 | 'UR48': 'disease', 96 | 'UR49': 'all entity type', 97 | 'UR50': 'chemical compound', 98 | 'UR116': 'disease', 99 | 'UR124': 'chemical compound', 100 | 'UR173': 'all entity type', 101 | 'UR180': 'all entity type', 102 | 'UR211': 'all entity type', 103 | 'UR214': 'disease', 104 | 'UR221': 'gene', 105 | 'UR254': 'disease', 106 | 'UR256': 'all entity type', 107 | 'UR588': 'disease', 108 | 'UR625': 'disease'} 109 | elif task == 'wikidata': 110 | PROPERTY_NAMES = {'P2176': "drug used for treatment", 111 | 'P2175': "medical condition treated", 112 | 'P780': "symptoms", 113 | 'P2293': "genetic association", 114 | 'P4044': "therapeutic area"} 115 | 116 | PROPERTY_OBJ_TYPES = {'P2176': 'chemical compound', 117 | 'P2175': 'disease', 118 | 'P780': 'disease', 119 | 'P2293': 'disease', 120 | 'P4044': 'disease'} 121 | else: 122 | print(f"not supporting task:{task}") 123 | assert 0 124 | 125 | test_files = glob.glob(args.test_path) 126 | selected_data = {} 127 | for test_file in test_files: 128 | pid = test_file.split('/')[-2] 129 | 130 | with open(file=test_file) as f: 131 | data = [json.loads(line) for line in f] 132 | selected_data[pid] = data 133 | 134 | ie_result = {} 135 | pid2performance={} 136 | pbar = tqdm(iterable=selected_data.items(), total=len(selected_data)) 137 | for pid, data in pbar: 138 | pbar.set_description(desc=f"Running IE for {pid}") 139 | 140 | ie_result[pid] = {'acc@1': 0.0, 'acc@5': 0.0} 141 | 142 | all_preds_probs = [] 143 | all_golds = [] 144 | uuids = [] 145 | 146 | if args.draft: 147 | data = data[:5] 148 | 149 | for idx, row in tqdm(enumerate(data), total=len(data)): 150 | subj_label = row['sub_label'] 151 | uuid = row['uuid'] 152 | 153 | objs = row['obj_labels'] 154 | objs.extend(row['obj_aliases']) 155 | objs = flatten_list(l=objs) 156 | 157 | property_name = PROPERTY_NAMES[pid] 158 | doc = stanza_pipeline(property_name) 159 | stanza_words = doc.sentences[0].words 160 | valid_words = [x.lemma for x in stanza_words if x.pos in ['NOUN', 'VERB', 'ADJ']] 161 | assert len(valid_words) != 0 162 | 163 | query = f"({subj_label}) AND ({' '.join(valid_words)})" 164 | 165 | # if idx<5: 166 | # print(f"{pid}:{query}") 167 | ie_query = best.BESTQuery(query, topN=args.topN, filterObjectName=PROPERTY_OBJ_TYPES[pid]) 168 | result = best.getRelevantBioEntities(ie_query) 169 | 170 | try: 171 | result_names = [(x['entityName'], x['score']) for x in result] 172 | except (KeyError, TypeError): 173 | result_names = [] 174 | 175 | result_names.extend([('', 0.0)] * (10 - len(result_names))) 176 | 177 | all_preds_probs.append(result_names) 178 | all_golds.append(objs) 179 | uuids.append(uuid) 180 | 181 | inputs = [] 182 | for sample in all_preds_probs: 183 | inputs_ = [] 184 | 185 | for _ in sample: 186 | inputs_.append('') 187 | 188 | inputs.append(inputs_) 189 
| 190 | result = evaluator.evaluate(all_preds_probs=all_preds_probs, all_golds=all_golds, subjects=[''] * len(all_preds_probs), prompts='', uuids=uuids, inputs=inputs) 191 | 192 | # saving log 193 | log_file = os.path.join(args.output_dir, pid + ".json") 194 | print(f"save {log_file}") 195 | with open(log_file, 'w') as f: 196 | json.dump(result, f) 197 | 198 | 199 | acc_at_1 = result['performance']['acc@k'][0] 200 | acc_at_5 = result['performance']['acc@k'][4] 201 | 202 | ie_result[pid]['acc@1'] = acc_at_1 203 | ie_result[pid]['acc@5'] = acc_at_5 204 | 205 | performance = result['performance'] 206 | local_acc = performance['acc@k'] 207 | 208 | logging_data ={} 209 | pid2performance[pid] = {} 210 | for k in range(5): 211 | if k+1 in [1,5]: 212 | acc = local_acc[k] 213 | logging_data[f"{pid}_acc@{k+1}"] = acc * 100 214 | pid2performance[pid][f'acc@{k+1}'] = acc * 100 215 | print(f'performance of {pid}') 216 | print(logging_data) 217 | 218 | print("PID\tAcc@1\tAcc@5") 219 | print("-------------------------") 220 | acc1s = [] 221 | acc5s = [] 222 | for pid, performance in pid2performance.items(): 223 | acc1 = performance['acc@1'] 224 | acc5 = performance['acc@5'] 225 | acc1s.append(acc1) 226 | acc5s.append(acc5) 227 | print(f"{pid}\t{round(acc1,2)}\t{round(acc5,2)}") 228 | 229 | print("-------------------------") 230 | print(f"MACRO\t{round(np.mean(acc1s),2)}\t{round(np.mean(acc5s),2)}") 231 | 232 | if __name__ == '__main__': 233 | main() -------------------------------------------------------------------------------- /BioLAMA/best.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. module:: BEST 3 | :platform: Unix, linux, Windows 4 | .. moduleauthor:: Sunkyu Kim 5 | ================================ 6 | Biomedical Entity Query API v2 7 | ================================ 8 | API Description 9 | ================ 10 | This API is for use of BEST(Biomedical Entity Search Tool) in various purposes. 11 | All users can access BEST at : http://best.korea.ac.kr/ 12 | For bugs and inquiries, please contact: 13 | * Jaewoo Kang(kangj@korea.ac.kr) 14 | * Sunkyu Kim(sunkyu-kim@korea.ac.kr) 15 | Reference : https://doi.org/10.1371/journal.pone.0164680 16 | Usage Examples 17 | =============== 18 | To see ‘gene’s related ‘breast cancer’, use this sample code. 19 | >>> bestQuery = best.BESTQuery("breast cancer", 20 | filterObjectName="gene", 21 | noAbsTxt=False) 22 | >>> searchResult = best.getRelevantBioEntities(bestQuery) 23 | >>> print(searchResult) 24 | [{ 'entityname' : 'ERBB2', 25 | 'score' : 8098.43, 26 | 'PMIDs' : ['28427196', '28341751', '28199325'], 27 | 'abstracts' : [ 28 | 'Molecular-based cancer tests...', 29 | 'The molecular subtype of breast...' 30 | 'Breast cancer is the second leading cause of...'], 31 | 'numArticles':14537 32 | 'rank' : 1}, 33 | { 'entityname' : 'ESR1', 34 | 'score' : 7340.54, 35 | 'PMIDs' : ['27923387', '28274211', '26276891'], 36 | 'abstracts' : [ 37 | 'Several studies have shown that mammographic..', 38 | 'A shift towards less burdening and more...' 39 | 'The complete molecular basis of the organ-...'], 40 | 'numArticles':18084 41 | 'rank' : 2}, 42 | ... 43 | ] 44 | Changing noAbsTxt=True can make the process faster. 
45 | >>> bestQuery = best.BESTQuery("breast cancer", 46 | filterObjectName="gene", 47 | noAbsTxt=True) 48 | >>> searchResult = best.getRelevantBioEntities(bestQuery) 49 | >>> print(searchResult) 50 | [{ 'entityname' : 'ERBB2', 51 | 'score' : 8098.43, 52 | 'PMIDs' : [], 53 | 'abstracts' : [], 54 | 'numArticles':14537 55 | 'rank' : 1}, 56 | { 'entityname' : 'ESR1', 57 | 'score' : 7340.54, 58 | 'PMIDs' : [], 59 | 'abstracts' : [], 60 | 'numArticles':18084 61 | 'rank' : 2}, 62 | ... 63 | ] 64 | If you want to see other entity types, change filterObjectName. 65 | .. note:: A total of 10 filterObjects (entity types) are available. 66 | * gene 67 | * drug 68 | * chemical compound 69 | * target 70 | * disease 71 | * toxin 72 | * transcription factor 73 | * mirna 74 | * pathway 75 | * mutation 76 | >>> bestQuery = best.BESTQuery("breast cancer", 77 | filterObjectName="drug", 78 | noAbsTxt=True) 79 | >>> searchResult = best.getRelevantBioEntities(bestQuery) 80 | >>> print(searchResult) 81 | [{ 'entityname' : 'tamoxifen', 82 | 'score' : 3208.687, 83 | 'abstracts' : [], 84 | 'numArticles':10583 85 | 'rank' : 1}, 86 | { 'entityname' : 'doxorubicin', 87 | 'score' : 1639.867, 88 | 'abstracts' : [], 89 | 'numArticles':6074 90 | 'rank' : 2}, 91 | ... 92 | ] 93 | Class/Function Description 94 | =========================== 95 | """ 96 | import http 97 | #from http.client import HTTPException 98 | import socket 99 | 100 | class BESTQuery(): 101 | """ 102 | BESTQuery class is the basic query object for the BEST API. 103 | """ 104 | 105 | __besturl = "http://best.korea.ac.kr/s?" 106 | 107 | 108 | def __init__(self, querystr, filterObjectName="All Entity Type", topN=20, noAbsTxt=True): 109 | """BESTQuery 110 | :param querystr: query string, :param filterObjectName: entity type of the results, :param topN: number of results to retrieve, 111 | :param noAbsTxt: if True, the result doesn't include the abstract texts. 112 | >>> query = BESTQuery("lung cancer", filterObjectName="gene", topN=10, noAbsTxt=False) 113 | >>> # 10 genes related to lung cancer are searched, including the abstract texts. 114 | """ 115 | 116 | self.querystr = querystr 117 | self.filterObjectName = filterObjectName 118 | self.topN = topN 119 | self.noAbsTxt = noAbsTxt 120 | 121 | def setQuerystr (self, querystr): 122 | """Setting the query 123 | :param querystr: a string object 124 | >>> query.setQuerystr("cancer") 125 | """ 126 | if type(querystr) is not str: 127 | print ("Initialize error : invalid query.
It should be a string object.") 128 | print (querystr) 129 | return 130 | 131 | if len(querystr) == 0: 132 | return 133 | 134 | self.querystr = querystr 135 | 136 | def getQuerystr (self): 137 | """Getting the query string 138 | :return: A string 139 | >>> querystr = query.getQuerystr() 140 | >>> print (querystr) 141 | cancer 142 | """ 143 | return self.querystr 144 | 145 | def _isValid(self): 146 | if self.querystr is None or type(self.querystr) is not str: 147 | return False 148 | 149 | for keya in self.querystr : 150 | if type(keya) is not str : 151 | return False 152 | 153 | if self.topN <= 0: 154 | return False 155 | 156 | return True 157 | 158 | def setTopN (self, n): 159 | """ Setting the number of results retrieved by query 160 | :param n: the number of results to be retrieved 161 | >>> query.setTopN(100) 162 | """ 163 | self.topN = n 164 | 165 | def getTopN (self): 166 | """ Getting the number of results retrieved by query 167 | :return: the number of results to be retrieved 168 | >>> print (query.getTopN()) 169 | 100 170 | """ 171 | return self.topN 172 | 173 | def setFilterObjectName (self, oname): 174 | """ Setting the filtering object. 175 | Total 10 types are available. 176 | * gene 177 | * drug 178 | * chemical compound 179 | * target 180 | * disease 181 | * toxin 182 | * transcription factor 183 | * mirna 184 | * pathway 185 | * mutation 186 | >>> query.setFilterObjectName("gene") 187 | """ 188 | self.filterObjectName = oname 189 | 190 | def getFilterObjectName (self): 191 | """ Getting the filtering entity type. 192 | >>> print(query.getFilterObjectName()) 193 | "gene" 194 | """ 195 | return self.filterObjectName 196 | 197 | def makeQueryString(self): 198 | queryKeywords = self.querystr 199 | querytype = self.filterObjectName.lower() 200 | noAbsTxt = self.noAbsTxt 201 | 202 | import urllib.parse 203 | 204 | queryKeywords = "q=" + urllib.parse.quote(queryKeywords) 205 | 206 | otype = "" 207 | if querytype == "gene": 208 | otype = "8" 209 | elif querytype == "drug": 210 | otype = "5" 211 | elif querytype == "chemical compound": 212 | otype = "3" 213 | elif querytype == "target": 214 | otype = "14" 215 | elif querytype == "disease": 216 | otype = "4" 217 | elif querytype == "toxin": 218 | otype = "15" 219 | elif querytype == "transcription factor": 220 | otype = "16" 221 | elif querytype == "mirna": 222 | otype = "10" 223 | elif querytype == "pathway": 224 | otype = "12" 225 | elif querytype == "mutation": 226 | otype = "17" 227 | elif querytype == "all entity type": 228 | otype = "" 229 | else: 230 | print ("Invalid type! Object type : All Entity Type") 231 | otype = "" 232 | 233 | if noAbsTxt: 234 | strQuery = self.__besturl + "t=l&wt=xslt&tr=tmpl2.xsl" + "&otype=" + otype + "&rows=" + str(self.topN) + "&" + queryKeywords 235 | else: 236 | strQuery = self.__besturl + "t=l&wt=xslt&tr=tmpl_170602.xsl" + "&otype=" + otype + "&rows=" + str(self.topN) + "&" + queryKeywords 237 | 238 | return strQuery 239 | 240 | def toDataObj(self): 241 | return {"query":self.querystr, "filterObjectName":self.filterObjectName, "topN":self.topN} 242 | 243 | def getRelevantBioEntities(bestQuery): 244 | """ Function for retrieval from BEST 245 | :param bestQuery: BESTQuery 246 | :return: parsed objects (dict-BIOENTITY).
247 | * BIOENTITY (dict): {"entityName":str, "rank":int, "score":float, "numArticles":int, "abstracts":[str]} 248 | >>> bestQuery = BESTQuery( "lung cancer", 249 | filterObjectName="gene", 250 | topN=10, 251 | noAbsTxt=True ) 252 | >>> relevantEntities = getRelevantBioEntities(bestQuery) 253 | """ 254 | if not (type(bestQuery) is BESTQuery): 255 | print ("query is invalid! please check your query object.") 256 | return None 257 | 258 | if not bestQuery._isValid() : 259 | print ("Query object is invalid. Please check the query") 260 | print ("Query : ") 261 | print (" query: " + str(bestQuery.query)) 262 | print (" topN: " + str(bestQuery.topN)) 263 | 264 | return None 265 | 266 | urlquery = bestQuery.makeQueryString() 267 | 268 | import urllib 269 | 270 | resultStr = "" 271 | again = 0 272 | while (again < 5): 273 | try: 274 | request = urllib.request.Request(urlquery) 275 | request.add_header('User-Agent', 'Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1)') 276 | 277 | geneUrl = urllib.request.urlopen(request, timeout=5) 278 | resultStr = geneUrl.read().decode('utf-8') 279 | again = 10 280 | except http.client.BadStatusLine: 281 | again += 1 282 | except http.client.HTTPException: 283 | again += 1 284 | except socket.timeout: 285 | again += 1 286 | except socket.error: 287 | again += 1 288 | except urllib.error.URLError: 289 | again += 1 290 | except Exception: 291 | again += 1 292 | 293 | if again == 5: 294 | print("Network status is not good") 295 | return None 296 | 297 | result = __makeDataFromBestQueryResult(resultStr) 298 | 299 | return result 300 | 301 | def __makeDataFromBestQueryResult(resultStr): 302 | lines = resultStr.split('\n') 303 | linesCnt = len(lines) 304 | 305 | resultDataArr = [] 306 | curData = {"rank":0} 307 | for i in range(1, linesCnt) : 308 | line = lines[i] 309 | 310 | if line.startswith("@@@"): 311 | pmid, text = line[3:].strip().split("###") 312 | curData["abstracts"].append(text) 313 | curData["PMIDs"].append(pmid) 314 | else: 315 | if len(line.strip()) == 0 : 316 | continue 317 | 318 | if curData["rank"] != 0: 319 | resultDataArr.append(curData) 320 | 321 | dataResult = line.split(" | ") 322 | 323 | curData = {"rank":int(dataResult[0].strip()), "entityName":dataResult[1].strip(), "score":float(dataResult[2].strip()), "numArticles":int(dataResult[3].strip()), "abstracts":[], "PMIDs":[]} 324 | 325 | resultDataArr.append(curData) 326 | 327 | return resultDataArr -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /preprocessing/process_wikidata_entities.py: -------------------------------------------------------------------------------- 1 | import time 2 | import functools 3 | import numpy as np 4 | 5 | import requests 6 | from tqdm import tqdm 7 | from wikidataintegrator.wdi_core import WDItemEngine 8 | import traceback 9 | from wikidataintegrator.wdi_config import config 10 | import csv 11 | import argparse 12 | import os 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--output_dir",required=True) 16 | args = parser.parse_args() 17 | 18 | # Sometimes caching allows a failed SPARQL query to finish on subsequent attempts 19 | # For this reason we will run 3 times 20 | config['BACKOFF_MAX_TRIES'] = 3 21 | 22 | execute_sparql_query = WDItemEngine.execute_sparql_query 23 | # comment this ---v out to use official wikidata endpoint 24 | # execute_sparql_query = functools.partial(execute_sparql_query, 25 | # endpoint="http://163.152.163.168:9999/bigdata/namespace/wdq/sparql") 26 | 27 | # instance of subject, subclass of object 28 | special_edges = [('Q11173', 'P1542', 'Q21167512'), # chemical, cause of, chemical hazard 29 | ('Q12136', 'P780', 'Q169872'), # disease, symptom, symptom 30 | ('Q12136', 'P780', 'Q1441305'), # disease, symptom, medical sign 31 | ('Q21167512', 'P780', 'Q169872')] # chemical hazard, symptom, symptom 32 | 33 | special_starts = [q[:2] for q in special_edges] 34 | 35 | 36 | def change_endpoint(endpoint): 37 | global execute_sparql_query 38 | execute_sparql_query = functools.partial(execute_sparql_query, endpoint=endpoint) 39 | 40 | 41 | def chunks(l, n): 42 | """Yield successive n-sized chunks from l.""" 43 | for i in range(0, len(l), n): 44 | yield l[i:i + n] 45 | 46 | 47 | def getConceptLabel(qid): 48 | return getConceptLabels((qid,))[qid] 49 | 50 | 51 | def getConceptLabels(qids): 52 | out = dict() 53 | for chunk in chunks(list(set(qids)), 50): 54 | this_qids = {qid.replace("wd:", "") if qid.startswith("wd:") else qid for qid in chunk} 55 | # Found some results that begin with 't' and cause the request to return no results 56 | bad_ids = {qid for qid in this_qids if not qid.startswith('Q')} 57 | this_qids = '|'.join(this_qids - bad_ids) 58 | params = {'action': 'wbgetentities', 'ids': this_qids, 'languages': 'en', 'format': 'json', 'props': 'labels'} 59 | r = requests.get("https://www.wikidata.org/w/api.php", params=params) 60 | r.raise_for_status() 61 | wd = r.json()['entities'] 62 | # Use empty labels for the bad ids 63 | wd.update({bad_id: {'labels': {'en': {'value': ""}}} for bad_id in bad_ids}) 64 | out.update({k: v['labels'].get('en', dict()).get('value', '') for k, v in wd.items()}) 65 | return out 66 | 67 | 68 | def get_prop_labels(): 69 | """ returns a dict of labels for all properties in wikidata """ 70 | s = """ 71 | SELECT DISTINCT ?property ?propertyLabel 72 | WHERE { 73 | ?property a wikibase:Property .
74 | SERVICE wikibase:label { bd:serviceParam wikibase:language "en" } 75 | }""" 76 | try: 77 | d = execute_sparql_query(s)['results']['bindings'] 78 | except: 79 | print("***** FAILED SPARQL *****") 80 | d = [] 81 | d = {x['property']['value'].replace("http://www.wikidata.org/entity/", ""): 82 | x['propertyLabel']['value'] for x in d} 83 | return d 84 | 85 | 86 | def determine_p(use_subclass, extend=True): 87 | # p = "wdt:P279*" if use_subclass else "wdt:P31/wdt:P279*" 88 | p = "wdt:P279*" if use_subclass else "wdt:P31*" # HOTFIX. TEST 89 | # Option to not extend down 'subclass_of' edges (useful for highly populated node types) 90 | if not extend: 91 | p = p.replace('/wdt:P279*', '').replace('*', '') 92 | return p 93 | 94 | 95 | def is_subclass(qid, return_val=False): 96 | instance_count, instance_items = get_type_entities(qid, use_subclass=False, extend_subclass=False) 97 | subclass_count, subclass_items = get_type_entities(qid, use_subclass=True, extend_subclass=False) 98 | 99 | # If the numbers are close, we need to determine if its because some have both subclass and instance of values 100 | if instance_count != 0 and subclass_count != 0 and abs(np.log10(instance_count) - np.log10(subclass_count)) >= 1: 101 | 102 | p0 = "wdt:P31" 103 | p1 = "wdt:P279" 104 | 105 | s_both = """ 106 | SELECT (COUNT(DISTINCT ?item) as ?c) WHERE { 107 | ?item {p0} {wds} . 108 | ?item {p1} {wds} . 109 | } 110 | """ 111 | s_both = s_both.replace("{wds}", "wd:" + qid).replace("{p0}", p0).replace("{p1}", p1) 112 | # print(f"is_subclass={s_both}") 113 | both = execute_sparql_query(s_both)['results']['bindings'] 114 | both = {qid: int(x['c']['value']) for x in both}.popitem()[1] 115 | 116 | is_sub = subclass_count - both > instance_count 117 | else: 118 | is_sub = subclass_count > instance_count 119 | 120 | if return_val: 121 | if is_sub: 122 | return is_sub, subclass_count, subclass_items 123 | else: 124 | return is_sub, instance_count, instance_items 125 | # return is_sub, subclass_count if is_sub else instance_count 126 | return is_sub 127 | 128 | def get_type_entities(qid, use_subclass=False, extend_subclass=True): 129 | """ 130 | For each qid, get the number of items that are instance of (types) this qid 131 | """ 132 | p = determine_p(use_subclass, extend_subclass) 133 | # get count first 134 | s = """ 135 | SELECT (COUNT(DISTINCT ?item) as ?c) WHERE { 136 | ?item {p} {wds} 137 | } 138 | """.replace("{wds}", "wd:" + qid).replace("{p}", p) 139 | try: 140 | d = execute_sparql_query(s)['results']['bindings'] 141 | except Exception as e: 142 | print(e) 143 | traceback.print_exc() 144 | raise e 145 | 146 | count = {qid: int(x['c']['value']) for x in d}.popitem()[1] 147 | 148 | # get entity 149 | LIMIT = 5000 150 | max_iter = int(count/LIMIT) 151 | 152 | items = [] 153 | for i in tqdm(range(max_iter + 1),desc=f"{qid}"): 154 | s = """ 155 | SELECT ?item ?itemLabel ?itemDescription 156 | (GROUP_CONCAT(DISTINCT(?itemAlt); separator='| ') as ?itemAltLabel) 157 | WHERE{ 158 | { 159 | SELECT DISTINCT ?item WHERE { 160 | ?item {p} {wds} 161 | } 162 | LIMIT {limit} 163 | OFFSET {offset} 164 | } 165 | OPTIONAL { 166 | ?item skos:altLabel ?itemAlt . 167 | FILTER (lang(?itemAlt)='en') 168 | } 169 | SERVICE wikibase:label { bd:serviceParam wikibase:language "en". 
} 170 | } GROUP BY ?item ?itemLabel ?itemDescription 171 | """.replace("{wds}", "wd:" + qid).replace("{p}", p).replace("{limit}", str(LIMIT)).replace("{offset}", str(LIMIT*i)) 172 | try: 173 | # print(f"get_type_entities={s}") 174 | d = execute_sparql_query(s)['results']['bindings'] 175 | except Exception as e: 176 | print(e) 177 | traceback.print_exc() 178 | raise e 179 | 180 | for row in d: 181 | item = row['item']['value'].split("/")[-1] 182 | itemLabel = row['itemLabel']['value'] if 'itemLabel' in row else '' 183 | itemAltLabel = row['itemAltLabel']['value'] if 'itemAltLabel' in row else '' 184 | itemDescription = row['itemDescription']['value'] if 'itemDescription' in row else '' 185 | 186 | 187 | items.append((item, itemLabel, itemDescription, itemAltLabel)) 188 | 189 | # assert count == len(items) 190 | if count != len(items): 191 | print(f"WARN! {count} != {len(items)}") 192 | count = len(items) 193 | 194 | return count, items 195 | 196 | def determine_node_type_and_get_counts(node_ids, name_map=dict(), max_size_for_expansion=200000): 197 | # get all node counts for my special types 198 | subclass_nodes = dict() 199 | expand_nodes = dict() 200 | type_count = dict() 201 | # type_items = dict() 202 | 203 | # These nodes we've seeded and are all 'instance_of', or they are very large and we shouldn't waste time expanding 204 | # down subclasses. 205 | # Q11173: chemical compound 206 | # Q2996394: biological process 207 | # Q14860489: molecular function 208 | # Q5058355: cellular component 209 | # Q13442814: scholarly article 210 | # Q16521: taxon 211 | expand_nodes = {q: False for q in ['Q11173', 'Q2996394', 'Q14860489', 'Q5058355', 'Q13442814', 'Q16521']} 212 | 213 | time.sleep(0.5) # Sometimes TQDM prints early, so sleep will ensure messages are printed before TQDM starts 214 | t = tqdm(node_ids) 215 | for qid in t: 216 | t.set_description(name_map[qid]) 217 | t.refresh() 218 | is_sub, count, items = is_subclass(qid, True) 219 | 220 | subclass_nodes[qid] = is_sub 221 | if qid not in expand_nodes: 222 | expand_nodes[qid] = count <= max_size_for_expansion 223 | # if expand_nodes[qid]: 224 | # print(f"{qid} is newly added for expand nodes") 225 | if expand_nodes[qid]: 226 | # Small number of nodes, so expand the subclass...
227 | # count_ext = get_type_count(qid, use_subclass=is_sub, extend_subclass=True) 228 | # count_ext, items_ext = get_type_entities(qid, use_subclass=is_sub, extend_subclass=True) 229 | count_ext, sub_items_ext = get_type_entities(qid, use_subclass=True, extend_subclass=True) 230 | count_ext, ins_items_ext = get_type_entities(qid, use_subclass=False, extend_subclass=True) 231 | 232 | # aggregate items from subclass_of and instances_of 233 | items_ext = list(set(sub_items_ext + ins_items_ext)) 234 | count_ext = len(items_ext) 235 | 236 | # Ensure this is still ok (some will baloon like chemical Compound) 237 | expand_nodes[qid] = count_ext <= max_size_for_expansion 238 | # If its still ok, update the counts 239 | if expand_nodes[qid]: 240 | count = count_ext 241 | items = items_ext 242 | 243 | type_count[qid] = count 244 | 245 | os.makedirs(args.output_dir, exist_ok = True) 246 | 247 | # save type items as files 248 | output_path = os.path.join(args.output_dir, f'{qid}.tsv') 249 | with open(output_path, 'w') as f: 250 | wr = csv.writer(f, delimiter="\t") 251 | for item in items: 252 | wr.writerow(item) 253 | 254 | return type_count, subclass_nodes, expand_nodes 255 | 256 | # Q13442814: scholarly article 257 | # Q16521: taxon 258 | def search_metagraph_from_seeds(seed_nodes, skip_types=('Q13442814', 'Q16521'), min_counts=200, 259 | max_size_for_expansion=200000): 260 | # Make set for easy operations 261 | skip_types = set(skip_types) 262 | 263 | print("Getting type counts") 264 | time.sleep(0.5) # Sometimes TQDM prints early, so sleep will endure messages are printed before TQDM starts 265 | determine_node_type_and_get_counts(seed_nodes.keys(), 266 | seed_nodes, 267 | max_size_for_expansion) 268 | 269 | if __name__ == "__main__": 270 | min_counts = 200 271 | 272 | # these are the special nodes that will have their external ID counts displayed, 273 | # the labels aren't, outputted, only used for monitoring status 274 | seed_nodes = { 275 | 'Q12136': 'disease', 276 | 'Q7187': 'gene', 277 | 'Q8054': 'protein', 278 | 'Q37748': 'chromosome', 279 | 'Q215980': 'ribosomal RNA', 280 | 'Q11173': 'chemical_compound', 281 | 'Q12140': 'medication', 282 | 'Q28885102': 'pharmaceutical_product', 283 | 'Q417841': 'protein_family', 284 | 'Q898273': 'protein_domain', 285 | 'Q2996394': 'biological_process', 286 | 'Q14860489': 'molecular_function', 287 | 'Q5058355': 'cellular_component', 288 | 'Q3273544': 'structural_motif', 289 | 'Q7644128': 'supersecondary_structure', 290 | 'Q616005': 'binding_site', 291 | 'Q423026': 'active_site', 292 | 'Q4936952': 'anatomical structure', 293 | 'Q169872': 'symptom', 294 | 'Q15304597': 'sequence variant', 295 | 'Q4915012': 'biological pathway', 296 | 'Q50377224': 'pharmacologic action', # Subclass 297 | 'Q50379781': 'therapeutic use', 298 | 'Q3271540': 'mechanism of action', # Subclass 299 | } 300 | 301 | # skip edge searches 302 | skip_types = {'Q13442814', 'Q16521'} 303 | 304 | # get entities 305 | search_metagraph_from_seeds(seed_nodes, skip_types, min_counts, 200000) -------------------------------------------------------------------------------- /BioLAMA/run_optiprompt.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | AdamW, 4 | get_linear_schedule_with_warmup, 5 | BertTokenizer, 6 | RobertaTokenizer, 7 | BertForMaskedLM, 8 | RobertaForMaskedLM 9 | ) 10 | from torch.utils.data import DataLoader 11 | from torch.nn.utils.rnn import pad_sequence 12 | 13 | import torch 14 | from 
data_loader import FactDataset 15 | from tqdm import tqdm 16 | from preprocessor import Preprocessor 17 | from decoder import Decoder 18 | from evaluator import Evaluator 19 | import argparse 20 | import glob 21 | import os 22 | import copy 23 | import numpy as np 24 | import json 25 | import random 26 | 27 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 28 | 29 | MAX_NUM_VECTORS = 20 30 | 31 | def init_template(base_model, tokenizer, prompt_token_len, init_manual_template, manual_template=''): 32 | if init_manual_template: 33 | # hotfix for roberta 34 | manual_template = manual_template.replace("[X]", " [X] ").replace("[Y]", " [Y] ") 35 | manual_template = " ".join(manual_template.split()) 36 | 37 | template = convert_manual_to_dense(manual_template, base_model, tokenizer) 38 | else: 39 | template = '[X] ' + ' '.join(['[V%d]'%(i+1) for i in range(prompt_token_len)]) + ' [Y] .' 40 | return template 41 | 42 | def get_new_token(vid): 43 | assert(vid > 0 and vid <= MAX_NUM_VECTORS) 44 | return '[V%d]'%(vid) 45 | 46 | def prepare_for_dense_prompt(lm_model, tokenizer): 47 | new_tokens = [get_new_token(i+1) for i in range(MAX_NUM_VECTORS)] 48 | tokenizer.add_tokens(new_tokens) 49 | ebd = lm_model.resize_token_embeddings(len(tokenizer)) 50 | print('# vocab after adding new tokens: %d'%len(tokenizer)) 51 | 52 | def set_seed(seed): 53 | """ 54 | Set the random seed. 55 | """ 56 | np.random.seed(seed) 57 | torch.manual_seed(seed) 58 | torch.cuda.manual_seed(seed) 59 | random.seed(seed) 60 | torch.backends.cudnn.deterministic = True 61 | 62 | def convert_manual_to_dense(manual_template, base_model, tokenizer): 63 | def assign_embedding(new_token, token): 64 | """ 65 | assign the embedding of token to new_token 66 | """ 67 | print('Tie embeddings of tokens: (%s, %s)'%(new_token, token)) 68 | id_a = tokenizer.convert_tokens_to_ids([new_token])[0] 69 | id_b = tokenizer.convert_tokens_to_ids([token])[0] 70 | with torch.no_grad(): 71 | base_model.embeddings.word_embeddings.weight[id_a] = base_model.embeddings.word_embeddings.weight[id_b].detach().clone() 72 | 73 | new_token_id = 0 74 | template = [] 75 | for word in manual_template.split(): 76 | if word in ['[X]', '[Y]']: 77 | template.append(word) 78 | else: 79 | tokens = tokenizer.tokenize(' ' + word) 80 | for token in tokens: 81 | new_token_id += 1 82 | template.append(get_new_token(new_token_id)) 83 | assign_embedding(get_new_token(new_token_id), token) 84 | 85 | template = ' '.join(template) 86 | return template 87 | 88 | def load_model(model_name, tokenizer, random_init='none'): 89 | if isinstance(tokenizer, BertTokenizer): 90 | lm_model = BertForMaskedLM.from_pretrained( 91 | model_name 92 | ).cuda() 93 | base_model = lm_model.bert 94 | 95 | elif isinstance(tokenizer, RobertaTokenizer): 96 | lm_model = RobertaForMaskedLM.from_pretrained( 97 | model_name 98 | ).cuda() 99 | base_model = lm_model.roberta 100 | 101 | else: 102 | print(f"tokenizer type = {type(tokenizer)}") 103 | assert 0 104 | 105 | return lm_model, base_model 106 | 107 | def save_optiprompt(path, lm_model, tokenizer, original_vocab_size): 108 | if isinstance(tokenizer, BertTokenizer): 109 | base_model = lm_model.bert 110 | elif isinstance(tokenizer, RobertaTokenizer): 111 | base_model = lm_model.roberta 112 | 113 | print(f"Saving OptiPrompt's [V]s.. 
{path}") 114 | vs = base_model.embeddings.word_embeddings.weight[original_vocab_size:].detach().cpu().numpy() 115 | with open(path, 'wb') as f: 116 | np.save(f, vs) 117 | 118 | def load_optiprompt(path, lm_model, tokenizer, original_vocab_size): 119 | if isinstance(tokenizer, BertTokenizer): 120 | base_model = lm_model.bert 121 | elif isinstance(tokenizer, RobertaTokenizer): 122 | base_model = lm_model.roberta 123 | 124 | print("Loading OptiPrompt's [V]s..") 125 | with open(path, 'rb') as f: 126 | vs = np.load(f) 127 | 128 | # copy fine-tuned new_tokens to the pre-trained model 129 | with torch.no_grad(): 130 | # base_model.embeddings.word_embeddings.weight[original_vocab_size:] = torch.Tensor(vs) 131 | base_model.embeddings.word_embeddings.weight[original_vocab_size:original_vocab_size+len(vs)] = torch.Tensor(vs) 132 | 133 | return lm_model, base_model 134 | 135 | # for dev or test 136 | def validate(file, lm_model, preprocessor, decoder, evaluator, template, batch_size, draft=False): 137 | print(f'validate {file}') 138 | lm_model.eval() 139 | sentences, all_gold_objects, subjects, prompts, uuids = preprocessor.preprocess( 140 | file, 141 | template = template, 142 | draft=draft 143 | ) 144 | 145 | decoder.set_model(lm_model) 146 | all_preds_probs = decoder.decode( 147 | sentences, 148 | batch_size=batch_size 149 | ) 150 | 151 | result = evaluator.evaluate( 152 | all_preds_probs = all_preds_probs, 153 | all_golds = all_gold_objects, 154 | subjects=subjects, 155 | prompts=prompts, 156 | uuids=uuids, 157 | inputs=sentences, 158 | ) 159 | 160 | return result 161 | 162 | def main(): 163 | parser = argparse.ArgumentParser() 164 | parser.add_argument("--model_name_or_path", default='bert-base-uncased') 165 | parser.add_argument("--train_path", required=True) 166 | parser.add_argument("--dev_path", required=True) 167 | parser.add_argument("--test_path", required=True) 168 | parser.add_argument("--num_mask", type=int, required=True) 169 | parser.add_argument("--draft", action="store_true") 170 | parser.add_argument("--init_method", choices=['independent','order','confidence'], default='independent') 171 | parser.add_argument("--iter_method", choices=['none','order','confidence'], default='none') 172 | parser.add_argument("--max_iter", type=int, default=1) 173 | parser.add_argument("--beam_size", type=int, default=1) 174 | parser.add_argument("--batch_size", type=int, default=16) 175 | parser.add_argument("--epochs", type=int, default=10) 176 | parser.add_argument("--lr", type=float, default=3e-3) 177 | parser.add_argument("--warmup_proportion", type=float, default=0.1) 178 | parser.add_argument("--prompt_token_len", type=int, default=5) 179 | parser.add_argument("--prompt_path") 180 | parser.add_argument("--prompt_vector_dir", default=None) 181 | parser.add_argument("--init_manual_template", action='store_true') 182 | parser.add_argument("--pids", default=None) 183 | parser.add_argument("--output_dir", default=None) 184 | parser.add_argument("--seed", type=int, default=0) 185 | 186 | args = parser.parse_args() 187 | 188 | train_files = sorted(glob.glob(args.train_path)) 189 | dev_files = sorted(glob.glob(args.dev_path)) 190 | test_files = sorted(glob.glob(args.test_path)) 191 | 192 | if args.pids == None: 193 | pids = [f.split("/")[-2] for f in test_files] 194 | else: 195 | pids = args.pids.split(",") 196 | 197 | if args.init_manual_template: 198 | args.output_dir = args.output_dir + "_imt" 199 | 200 | if args.draft: 201 | args.output_dir = args.output_dir + "_draft" 202 | 
os.makedirs(args.output_dir, exist_ok=True) 203 | 204 | pid2prompt = {} 205 | if args.prompt_path: 206 | with open(args.prompt_path) as f: 207 | for line in f: 208 | row = json.loads(line) 209 | pid = row['relation'] 210 | prompt = row['template'] 211 | pid2prompt[pid] = prompt 212 | # check 213 | for pid in pids: 214 | assert pid in pid2prompt 215 | 216 | set_seed(args.seed) 217 | 218 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False) 219 | 220 | original_vocab_size = len(tokenizer) 221 | print('Original vocab size: %d'%original_vocab_size) 222 | 223 | def collate(examples): 224 | inputs = [ex[0] for ex in examples] 225 | obj_ind = [ex[1] for ex in examples] 226 | 227 | if tokenizer._pad_token is None: 228 | return pad_sequence(inputs, batch_first=True), pad_sequence(obj_ind, batch_first=True) 229 | return pad_sequence(inputs, batch_first=True, padding_value=tokenizer.pad_token_id), pad_sequence(obj_ind, batch_first=True, padding_value=tokenizer.pad_token_id) 230 | 231 | def mask_tokens(inputs: torch.Tensor, obj_ind: torch.Tensor, tokenizer, mask_token): 232 | labels= inputs.clone() 233 | masked_inputs = inputs.clone() 234 | 235 | mask_idx = tokenizer.convert_tokens_to_ids(mask_token) 236 | masked_indices = obj_ind.bool() 237 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 238 | 239 | masked_inputs[masked_indices] = mask_idx 240 | return masked_inputs, labels 241 | 242 | total_relation = 0 243 | 244 | pid2performance = {} 245 | for train_file, dev_file, test_file in zip(train_files, dev_files, test_files): 246 | assert train_file.split("/")[-2] == dev_file.split("/")[-2] == test_file.split("/")[-2] 247 | pid = train_file.split("/")[-2] 248 | if pid not in pids: 249 | continue 250 | 251 | if args.prompt_vector_dir: 252 | prompt_vector_dir = args.prompt_vector_dir 253 | load_prompt_vector = True 254 | else: 255 | prompt_vector_dir = args.output_dir 256 | load_prompt_vector = False 257 | 258 | optiprompt_path = os.path.join(prompt_vector_dir, f"{pid}_optiprompt.np") 259 | print(optiprompt_path) 260 | 261 | lm_model, base_model = load_model( 262 | model_name=args.model_name_or_path, 263 | tokenizer=tokenizer, 264 | ) 265 | # load_optiprompt if the checkpoint exists 266 | prepare_for_dense_prompt(lm_model, tokenizer) 267 | 268 | template = init_template( 269 | base_model=base_model, 270 | tokenizer=tokenizer, 271 | prompt_token_len=args.prompt_token_len, 272 | init_manual_template=args.init_manual_template, 273 | manual_template=pid2prompt[pid] if args.init_manual_template else '' 274 | ) 275 | print('Template: %s'%template) 276 | 277 | do_train = True 278 | if load_prompt_vector: 279 | print(f"load prompt vector {optiprompt_path}") 280 | lm_model, base_model = load_optiprompt(optiprompt_path, lm_model, tokenizer, original_vocab_size) 281 | do_train = False 282 | 283 | train_dataset = FactDataset( 284 | input_file=train_file, 285 | prompt_token_len=args.prompt_token_len, 286 | tokenizer=tokenizer, 287 | template = template, 288 | ) 289 | 290 | train_dataloader = DataLoader( 291 | train_dataset, batch_size=args.batch_size, collate_fn=collate 292 | ) 293 | 294 | epochs = args.epochs if not args.draft else 1 295 | t_total = len(train_dataloader) * epochs 296 | optimizer = AdamW([{'params': base_model.embeddings.word_embeddings.parameters()}], lr=args.lr, correct_bias=False) 297 | scheduler = get_linear_schedule_with_warmup( 298 | optimizer, num_warmup_steps=int(t_total/10), num_training_steps=t_total 299 | ) 300 | preprocessor = 
Preprocessor( 301 | tokenizer=tokenizer, 302 | num_mask=args.num_mask, 303 | ) 304 | 305 | decoder = Decoder( 306 | model=lm_model, 307 | tokenizer=tokenizer, 308 | init_method=args.init_method, 309 | iter_method=args.iter_method, 310 | MAX_ITER=args.max_iter, 311 | BEAM_SIZE=args.beam_size, 312 | NUM_MASK=args.num_mask, 313 | BATCH_SIZE=args.batch_size, 314 | ) 315 | evaluator = Evaluator( 316 | tokenizer=tokenizer 317 | ) 318 | 319 | if do_train: 320 | 321 | # initalize best_acc with evaluation at epoch 0 322 | result = validate( 323 | file=dev_file, 324 | lm_model=lm_model, 325 | preprocessor=preprocessor, 326 | decoder=decoder, 327 | evaluator=evaluator, 328 | template=template, 329 | batch_size=args.batch_size, 330 | draft=args.draft) 331 | best_acc = result['performance']['acc@k'][0] 332 | best_epoch = 0 333 | best_model = copy.deepcopy(lm_model) 334 | 335 | global_step = 0 336 | for epoch in range(1, epochs+1): 337 | for batch in tqdm(train_dataloader): 338 | lm_model.train() 339 | global_step += 1 340 | inputs, labels = mask_tokens( 341 | inputs = batch[0], 342 | obj_ind = batch[1], 343 | tokenizer = tokenizer, 344 | mask_token = train_dataset.mask_token, 345 | ) 346 | 347 | output = lm_model(input_ids=inputs.cuda(), labels=labels.cuda()) 348 | loss = output[0] 349 | loss = loss.mean() 350 | 351 | if (global_step+1) % 20 == 0: 352 | print(f"step={global_step} loss={round(loss.item(),5)}") 353 | 354 | loss.backward() 355 | 356 | # set normal tokens' gradients to be zero 357 | for p in base_model.embeddings.word_embeddings.parameters(): 358 | # only update new tokens 359 | p.grad[:original_vocab_size, :] = 0.0 360 | 361 | optimizer.step() 362 | scheduler.step() 363 | lm_model.zero_grad() 364 | 365 | 366 | result = validate(file=dev_file, 367 | lm_model=lm_model, 368 | preprocessor=preprocessor, 369 | decoder=decoder, 370 | evaluator=evaluator, 371 | template=template, 372 | batch_size=args.batch_size, 373 | draft=args.draft) 374 | 375 | acc_1 = result['performance']['acc@k'][0] 376 | if best_acc < acc_1: 377 | best_acc = acc_1 378 | best_epoch = epoch 379 | best_model = copy.deepcopy(lm_model) 380 | print(f"{pid} updated best acc={best_acc} epoch={best_epoch}") 381 | 382 | print(f"{pid} overall best acc={best_acc} epoch={best_epoch}") 383 | 384 | lm_model = best_model 385 | 386 | # save best optiprompt 387 | save_optiprompt(optiprompt_path, lm_model, tokenizer, original_vocab_size) 388 | 389 | # test 390 | result = validate( 391 | file=test_file, 392 | lm_model=lm_model, 393 | preprocessor=preprocessor, 394 | decoder=decoder, 395 | evaluator=evaluator, 396 | template=template, 397 | batch_size=args.batch_size, 398 | draft=args.draft) 399 | 400 | # saving log 401 | log_file = os.path.join(args.output_dir, pid + ".json") 402 | print(f"save {log_file}") 403 | with open(log_file, 'w') as f: 404 | json.dump(result, f) 405 | 406 | total_relation += 1 407 | 408 | performance = result['performance'] 409 | local_acc = performance['acc@k'] 410 | 411 | logging_data ={} 412 | pid2performance[pid] = {} 413 | for k in range(args.beam_size): 414 | if k+1 in [1,5]: 415 | acc = local_acc[k] 416 | logging_data[f"{pid}_acc@{k+1}"] = acc * 100 417 | pid2performance[pid][f'acc@{k+1}'] = acc * 100 418 | 419 | print(f'performance of {pid}') 420 | print(logging_data) 421 | 422 | print("PID\tAcc@1\tAcc@5") 423 | print("-------------------------") 424 | acc1s = [] 425 | acc5s = [] 426 | for pid, performance in pid2performance.items(): 427 | acc1 = performance['acc@1'] 428 | acc5 = performance['acc@5'] 429 | 
acc1s.append(acc1) 430 | acc5s.append(acc5) 431 | print(f"{pid}\t{round(acc1,2)}\t{round(acc5,2)}") 432 | 433 | print("-------------------------") 434 | print(f"MACRO\t{round(np.mean(acc1s),2)}\t{round(np.mean(acc5s),2)}") 435 | 436 | if __name__ == '__main__': 437 | main() -------------------------------------------------------------------------------- /preprocessing/process_ctd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | import traceback 5 | import sys 6 | 7 | csv.field_size_limit(sys.maxsize) 8 | 9 | import os 10 | 11 | def save_pid2data(pid2data, output_dir): 12 | for pid in pid2data: 13 | data = pid2data[pid] 14 | save_path = os.path.join(output_dir, f"{pid}.jsonl") 15 | with open(save_path, 'w') as fo: 16 | for sample in data: 17 | new_sample ={ 18 | 'predicate_id':pid, 19 | 'sub_uri': sample['sub_uri'], 20 | 'sub_label': sample['sub_label'], 21 | 'sub_aliases': sample['sub_aliases'], 22 | 'obj_uri': sample['obj_uri'], 23 | 'obj_label': sample['obj_label'], 24 | 'obj_aliases': sample['obj_aliases'], 25 | } 26 | new_sample = json.dumps(new_sample, ensure_ascii=False) 27 | fo.write(new_sample + "\n") 28 | 29 | def process_chemicals_genes_ixns(data_file, output_dir, chemicals_dict, genes_dict): 30 | print(f"Opening {data_file}") 31 | data = [] 32 | properties2meta = {} 33 | with open(file=data_file) as f: 34 | tsv_reader = csv.reader(f, delimiter='\t', quotechar='"') 35 | CG = 1 36 | for line in tsv_reader: 37 | if line[0].startswith("#"): 38 | continue 39 | ChemicalName, ChemicalID, CasRN, GeneSymbol, GeneID, GeneForms, Organism, OrganismID, Interaction, InteractionActions, PubMedIDs = line 40 | 41 | if Organism != 'Homo sapiens': 42 | continue 43 | # use unique interaction 44 | if '|' in InteractionActions: 45 | continue 46 | # use unique gene type 47 | if ('|' in GeneForms) or (GeneForms == ''): 48 | continue 49 | 50 | if chemicals_dict[ChemicalID]['name'] != ChemicalName: 51 | import pdb ; pdb.set_trace() 52 | if GeneID not in genes_dict: # not human gene 53 | # print(f"{GeneID} not in dict => Skip") 54 | continue 55 | elif genes_dict[GeneID]['symbol'].upper() != GeneSymbol.upper(): 56 | import pdb ; pdb.set_trace() 57 | else: 58 | gene_synonyms = [genes_dict[GeneID]['name']] + genes_dict[GeneID]['synonyms'] 59 | 60 | InteractionActions = InteractionActions.replace("^"," ") 61 | 62 | property_label = InteractionActions + "_" + GeneForms 63 | data.append({ 64 | 'property_label':property_label, 65 | 'evidence': Interaction, 66 | 'sub_uri': ChemicalID, 67 | 'sub_label': ChemicalName, 68 | 'sub_aliases': chemicals_dict[ChemicalID]['synonyms'], 69 | 'obj_uri': GeneID, 70 | 'obj_label': GeneSymbol, 71 | 'obj_aliases': gene_synonyms 72 | }) 73 | 74 | if property_label not in properties2meta: 75 | properties2meta[property_label] = { 76 | 'count':0, 77 | 'pid':f'CG{CG}' 78 | } 79 | CG += 1 80 | 81 | properties2meta[property_label]['property'] = InteractionActions 82 | properties2meta[property_label]['obj_type'] = GeneForms 83 | properties2meta[property_label]['example'] = Interaction 84 | 85 | properties2meta[property_label]['count'] += 1 86 | 87 | pid2data = {} 88 | for sample in data: 89 | _prop = sample['property_label'] 90 | meta = properties2meta[_prop] 91 | pid = meta['pid'] 92 | count = meta['count'] 93 | 94 | # filter less than 2000 95 | if count < 2000: 96 | continue 97 | if pid not in pid2data: 98 | pid2data[pid] = [] 99 | 100 | pid2data[pid].append(sample) 101 | 102 | 
save_pid2data(pid2data, output_dir) 103 | 104 | def process_chemicals_diseases(data_file, output_dir, chemicals_dict, diseases_dict): 105 | print(f"Opening {data_file}") 106 | data = [] 107 | properties2meta = {} 108 | with open(file=data_file) as f: 109 | tsv_reader = csv.reader(f, delimiter='\t', quotechar='"') 110 | CD = 1 111 | for line in tsv_reader: 112 | if line[0].startswith("#"): 113 | continue 114 | ChemicalName, ChemicalID, CasRN, DiseaseName, DiseaseID, DirectEvidence, InferenceGeneSymbol, InferenceScore,OmimIDs, PubMedIDs = line 115 | 116 | DiseaseID = DiseaseID.replace("MESH:", "") 117 | # filter infered 118 | if not DirectEvidence: 119 | continue 120 | 121 | property_label = DirectEvidence 122 | # print(property_label) 123 | data.append({ 124 | 'property_label':property_label, 125 | 'sub_uri': ChemicalID, 126 | 'sub_label': ChemicalName, 127 | 'sub_type': 'chemical', 128 | 'sub_aliases': chemicals_dict[ChemicalID]['synonyms'], 129 | 'obj_uri': DiseaseID, 130 | 'obj_label': DiseaseName, 131 | 'obj_type': 'disease', 132 | 'obj_aliases': diseases_dict[DiseaseID]['synonyms'] 133 | }) 134 | 135 | if property_label not in properties2meta: 136 | properties2meta[property_label] = { 137 | 'count':0, 138 | 'pid':f'CD{CD}' 139 | } 140 | CD += 1 141 | 142 | properties2meta[property_label]['property'] = property_label 143 | 144 | properties2meta[property_label]['count'] += 1 145 | 146 | pid2data = {} 147 | for sample in data: 148 | _prop = sample['property_label'] 149 | meta = properties2meta[_prop] 150 | pid = meta['pid'] 151 | count = meta['count'] 152 | 153 | # filter less than 2000 154 | if count < 2000: 155 | continue 156 | if pid not in pid2data: 157 | pid2data[pid] = [] 158 | 159 | pid2data[pid].append(sample) 160 | 161 | save_pid2data(pid2data, output_dir) 162 | 163 | def process_genes_diseases(data_file, output_dir, genes_dict, diseases_dict): 164 | print(f"Opening {data_file}") 165 | data = [] 166 | properties2meta = {} 167 | with open(file=data_file) as f: 168 | tsv_reader = csv.reader(f, delimiter='\t', quotechar='"') 169 | GD = 1 170 | for line in tsv_reader: 171 | if line[0].startswith("#"): 172 | continue 173 | GeneSymbol, GeneID, DiseaseName, DiseaseID, DirectEvidence, InferenceChemicalName, InferenceScore, OmimIDs, PubMedIDs = line 174 | 175 | DiseaseID = DiseaseID.replace("MESH:", "") 176 | # filter infered 177 | if not DirectEvidence: 178 | continue 179 | if GeneID not in genes_dict: # not human gene 180 | # print(f"{GeneID} not in dict => Skip") 181 | continue 182 | property_label = DirectEvidence 183 | # print(property_label) 184 | data.append({ 185 | 'property_label':property_label, 186 | 'sub_uri': GeneID, 187 | 'sub_label': GeneSymbol, 188 | 'sub_type': 'gene', 189 | 'sub_aliases': [genes_dict[GeneID]['name']] + genes_dict[GeneID]['synonyms'], 190 | 'obj_uri': DiseaseID, 191 | 'obj_label': DiseaseName, 192 | 'obj_type': 'disease', 193 | 'obj_aliases': diseases_dict[DiseaseID]['synonyms'] 194 | }) 195 | 196 | if property_label not in properties2meta: 197 | properties2meta[property_label] = { 198 | 'count':0, 199 | 'pid':f'GD{GD}' 200 | } 201 | GD += 1 202 | 203 | properties2meta[property_label]['property'] = property_label 204 | 205 | properties2meta[property_label]['count'] += 1 206 | 207 | pid2data = {} 208 | for sample in data: 209 | _prop = sample['property_label'] 210 | meta = properties2meta[_prop] 211 | pid = meta['pid'] 212 | count = meta['count'] 213 | 214 | # filter less than 2000 215 | if count < 2000: 216 | continue 217 | if pid not in pid2data: 218 | 
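# Second-pass filter (the same pattern appears in every process_* function in this file):
# `count` is the total number of triples collected for this property during the first pass,
# so `if count < 2000: continue` keeps only properties with at least 2000 triples; each
# surviving property already received a numbered PID (CG*/CD*/GD*/GP*/CP*) when first seen.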
pid2data[pid] = [] 219 | 220 | pid2data[pid].append(sample) 221 | 222 | save_pid2data(pid2data, output_dir) 223 | 224 | def process_genes_pathways(data_file, output_dir, genes_dict): 225 | print(f"Opening {data_file}") 226 | data = [] 227 | properties2meta = {} 228 | with open(file=data_file) as f: 229 | tsv_reader = csv.reader(f, delimiter='\t', quotechar='"') 230 | GP = 1 231 | for line in tsv_reader: 232 | if line[0].startswith("#"): 233 | continue 234 | GeneSymbol, GeneID, PathwayName, PathwayID = line 235 | 236 | property_label = 'association' 237 | if GeneID not in genes_dict: # not human gene 238 | # print(f"{GeneID} not in dict => Skip") 239 | continue 240 | 241 | data.append({ 242 | 'property_label':property_label, 243 | 'sub_uri': GeneID, 244 | 'sub_label': GeneSymbol, 245 | 'sub_type': 'gene', 246 | 'sub_aliases': [genes_dict[GeneID]['name']] + genes_dict[GeneID]['synonyms'], 247 | 'obj_uri': PathwayID, 248 | 'obj_label': PathwayName, 249 | 'obj_type':'pathway', 250 | 'obj_aliases': [] 251 | }) 252 | 253 | if property_label not in properties2meta: 254 | properties2meta[property_label] = { 255 | 'count':0, 256 | 'pid':f'GP{GP}' 257 | } 258 | GP += 1 259 | 260 | properties2meta[property_label]['property'] = property_label 261 | 262 | properties2meta[property_label]['count'] += 1 263 | 264 | pid2data = {} 265 | for sample in data: 266 | _prop = sample['property_label'] 267 | meta = properties2meta[_prop] 268 | pid = meta['pid'] 269 | count = meta['count'] 270 | 271 | # filter less than 2000 272 | if count < 2000: 273 | continue 274 | if pid not in pid2data: 275 | pid2data[pid] = [] 276 | 277 | pid2data[pid].append(sample) 278 | 279 | save_pid2data(pid2data, output_dir) 280 | 281 | def process_chemicals_phenotypestype(data_file, output_dir, chemicals_dict): 282 | print(f"Opening {data_file}") 283 | data = [] 284 | properties2meta = {} 285 | with open(file=data_file) as f: 286 | tsv_reader = csv.reader(f, delimiter='\t', quotechar='"') 287 | CP = 1 288 | for line in tsv_reader: 289 | if line[0].startswith("#"): 290 | continue 291 | chemicalname, chemicalid,casrn,phenotypename,phenotypeid,comentionedterms,organism,organismid,interaction,interactionactions,anatomyterms,inferencegenesymbols,pubmedids, _ = line 292 | if organism != 'Homo sapiens': 293 | continue 294 | # use unique interaction 295 | if '|' in interactionactions: 296 | continue 297 | 298 | assert chemicals_dict[chemicalid]['name'] == chemicalname 299 | 300 | interactionactions = interactionactions.replace("^"," ") 301 | 302 | property_label = interactionactions 303 | 304 | data.append({ 305 | 'property_label':property_label, 306 | 'sub_uri': chemicalid, 307 | 'sub_label': chemicalname, 308 | 'sub_type': 'chemical', 309 | 'sub_aliases': chemicals_dict[chemicalid]['synonyms'], 310 | 'obj_uri': phenotypeid, 311 | 'obj_label': phenotypename, 312 | 'obj_type': 'phenotype', 313 | 'obj_aliases': [] 314 | }) 315 | 316 | if property_label not in properties2meta: 317 | properties2meta[property_label] = { 318 | 'count':0, 319 | 'pid':f'CP{CP}' 320 | } 321 | CP += 1 322 | 323 | properties2meta[property_label]['property'] = property_label 324 | properties2meta[property_label]['example'] = interaction 325 | 326 | properties2meta[property_label]['count'] += 1 327 | 328 | pid2data = {} 329 | for sample in data: 330 | _prop = sample['property_label'] 331 | meta = properties2meta[_prop] 332 | pid = meta['pid'] 333 | count = meta['count'] 334 | 335 | # filter less than 2000 336 | if count < 2000: 337 | continue 338 | if pid not in pid2data: 
339 | pid2data[pid] = [] 340 | 341 | pid2data[pid].append(sample) 342 | 343 | save_pid2data(pid2data, output_dir) 344 | 345 | def process_genes_dict(data_file): 346 | print(f"Opening {data_file}") 347 | geneid2meta = {} 348 | with open(file=data_file) as f: 349 | tsv_reader = csv.reader(f, delimiter='\t', quotechar='"') 350 | for line in tsv_reader: 351 | if line[0].startswith("#"): 352 | continue 353 | 354 | GeneSymbol, GeneName, GeneID, AltGeneIDs, Synonyms, BioGRIDIDs, PharmGKBIDs, UniProtIDs = line 355 | geneid2meta[GeneID] = { 356 | 'symbol': GeneSymbol, 357 | 'name': GeneName, 358 | 'synonyms': Synonyms.split("|") 359 | } 360 | 361 | print(f"len(geneid2meta)={len(geneid2meta)}") 362 | return geneid2meta 363 | 364 | def process_chemicals_dict(data_file): 365 | print(f"Opening {data_file}") 366 | chemid2meta = {} 367 | with open(file=data_file) as f: 368 | tsv_reader = csv.reader(f, delimiter='\t', quotechar='"') 369 | for line in tsv_reader: 370 | if line[0].startswith("#"): 371 | continue 372 | 373 | ChemicalName,ChemicalID,CasRN,Definition,ParentIDs,TreeNumbers,ParentTreeNumbers,Synonyms = line 374 | ChemicalID = ChemicalID.replace("MESH:","") 375 | try: 376 | chemid2meta[ChemicalID] = { 377 | 'name': ChemicalName, 378 | 'synonyms': Synonyms.split("|") 379 | } 380 | except Exception as e: 381 | print(e) 382 | traceback.print_exc() 383 | raise e 384 | 385 | 386 | print(f"len(chemid2meta)={len(chemid2meta)}") 387 | return chemid2meta 388 | 389 | def process_diseases_dict(data_file): 390 | print(f"Opening {data_file}") 391 | diseaseid2meta = {} 392 | with open(file=data_file) as f: 393 | tsv_reader = csv.reader(f, delimiter='\t', quotechar='"') 394 | for line in tsv_reader: 395 | if line[0].startswith("#"): 396 | continue 397 | 398 | DiseaseName, DiseaseID, AltDiseaseIDs, Definition, ParentIDs, TreeNumbers, ParentTreeNumbers, Synonyms, SlimMappings = line 399 | DiseaseID = DiseaseID.replace("MESH:","") 400 | try: 401 | diseaseid2meta[DiseaseID] = { 402 | 'name': DiseaseName, 403 | 'synonyms': Synonyms.split("|") 404 | } 405 | except Exception as e: 406 | print(e) 407 | traceback.print_exc() 408 | import pdb ; pdb.set_trace() 409 | 410 | print(f"len(diseaseid2meta)={len(diseaseid2meta)}") 411 | return diseaseid2meta 412 | 413 | def main(args): 414 | os.makedirs(args.output_dir, exist_ok=True) 415 | 416 | if args.chemicals_dict: 417 | chemicals_dict = process_chemicals_dict(data_file=args.chemicals_dict) 418 | else: 419 | print("[WARN] no args.chemicals_dict") 420 | 421 | if args.diseases_dict: 422 | diseases_dict = process_diseases_dict(data_file=args.diseases_dict) 423 | else: 424 | print("[WARN] no args.diseases_dict") 425 | 426 | # use preprocessed human gene dictionary 427 | if args.genes_dict and args.genes_dict.endswith(".json"): 428 | with open(args.genes_dict) as f: 429 | genes_dict = json.load(f) 430 | else: 431 | print("[WARN] no args.genes_dict") 432 | 433 | if args.chemicals_genes_file: 434 | process_chemicals_genes_ixns( 435 | data_file=args.chemicals_genes_file, 436 | output_dir=args.output_dir, 437 | chemicals_dict=chemicals_dict, 438 | genes_dict=genes_dict 439 | ) 440 | else: 441 | print("[WARN] no args.chemicals_genes_file") 442 | 443 | if args.chemicals_diseases_file: 444 | process_chemicals_diseases( 445 | data_file=args.chemicals_diseases_file, 446 | output_dir=args.output_dir, 447 | chemicals_dict=chemicals_dict, 448 | diseases_dict=diseases_dict 449 | ) 450 | else: 451 | print("[WARN] no args.chemicals_diseases_file") 452 | 453 | if args.genes_diseases_file: 454 | 
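# Example invocation of this script. The flag names match the argparse definitions below; the
# input file names are the usual CTD download names plus a preprocessed human-gene dictionary,
# and are assumptions rather than paths referenced elsewhere in this repository:
#
#   python ./process_ctd.py \
#       --chemicals_dict CTD_chemicals.tsv \
#       --diseases_dict CTD_diseases.tsv \
#       --genes_dict human_genes.json \
#       --chemicals_genes_file CTD_chem_gene_ixns.tsv \
#       --chemicals_diseases_file CTD_chemicals_diseases.tsv \
#       --genes_diseases_file CTD_genes_diseases.tsv \
#       --genes_pathways_file CTD_genes_pathways.tsv \
#       --chemicals_phenotypes_file CTD_pheno_term_ixns.tsv \
#       --output_dir ../data/ctd/triples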
process_genes_diseases( 455 | data_file=args.genes_diseases_file, 456 | output_dir=args.output_dir, 457 | genes_dict=genes_dict, 458 | diseases_dict=diseases_dict 459 | ) 460 | 461 | if args.genes_pathways_file: 462 | process_genes_pathways( 463 | data_file=args.genes_pathways_file, 464 | output_dir=args.output_dir, 465 | genes_dict=genes_dict 466 | ) 467 | 468 | if args.chemicals_phenotypes_file: 469 | process_chemicals_phenotypestype( 470 | data_file=args.chemicals_phenotypes_file, 471 | output_dir=args.output_dir, 472 | chemicals_dict=chemicals_dict 473 | ) 474 | 475 | if __name__ == '__main__': 476 | parser = argparse.ArgumentParser() 477 | parser.add_argument('--chemicals_dict', type=str) 478 | parser.add_argument('--diseases_dict', type=str) 479 | parser.add_argument('--genes_dict', type=str) 480 | parser.add_argument('--chemicals_genes_file', type=str) 481 | parser.add_argument('--chemicals_diseases_file', type=str) 482 | parser.add_argument('--chemicals_phenotypes_file', type=str) 483 | parser.add_argument('--genes_diseases_file', type=str) 484 | parser.add_argument('--genes_pathways_file', type=str) 485 | parser.add_argument('--output_dir', type=str) 486 | 487 | args = parser.parse_args() 488 | 489 | main(args) 490 | -------------------------------------------------------------------------------- /BioLAMA/decoder.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | from typing import List, Dict, Tuple, Union 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from transformers import ( 7 | BertTokenizer, 8 | RobertaTokenizer 9 | ) 10 | from utils import ( 11 | normalize_answer, 12 | convert_2d_list_to_1d, 13 | convert_1d_list_to_2d 14 | ) 15 | 16 | # We remove duplicates of the predictions before evaluation. 
17 | # We consider it to be a duplicate when the normalized predictions are the same 18 | def remove_duplicate_preds_probs(preds_probs): 19 | tmp = {} 20 | normalized_preds = [] 21 | for pred, prob in preds_probs: 22 | normalized_pred = normalize_answer(pred) 23 | 24 | # this is how to remove duplicates 25 | if normalized_pred not in normalized_preds: 26 | tmp[pred] = prob 27 | normalized_preds.append(normalized_pred) 28 | 29 | # initialize new_preds_probs 30 | new_preds_probs = [('',0)] * len(preds_probs) 31 | 32 | # fill in the preds and probs 33 | for i, (pred, prob) in enumerate(tmp.items()): 34 | new_preds_probs[i] = (pred, prob) 35 | 36 | return new_preds_probs 37 | 38 | class Decoder(): 39 | def __init__(self, model, tokenizer, init_method, iter_method, MAX_ITER, BEAM_SIZE, NUM_MASK, BATCH_SIZE, verbose=True): 40 | print(f"init_method={init_method} iter_method={iter_method} MAX_ITER={MAX_ITER} BEAM_SIZE={BEAM_SIZE} NUM_MASK={NUM_MASK} BATCH_SIZE={BATCH_SIZE}") 41 | self.model = model # bert model 42 | self.tokenizer = tokenizer # bert tokenizer 43 | 44 | self.MASK_IDX = self.tokenizer.mask_token_id 45 | self.PAD_IDX = self.tokenizer.pad_token_id 46 | self.UNK_IDX = self.tokenizer.unk_token_id 47 | 48 | if isinstance(tokenizer, BertTokenizer): 49 | self.mask_token = '[MASK]' 50 | self.pad_token = '[PAD]' 51 | self.unk_token = '[UNK]' 52 | assert self.tokenizer.convert_ids_to_tokens(self.MASK_IDX) == self.mask_token 53 | assert self.tokenizer.convert_ids_to_tokens(self.PAD_IDX) == self.pad_token 54 | assert self.tokenizer.convert_ids_to_tokens(self.UNK_IDX) == self.unk_token 55 | 56 | elif isinstance(tokenizer, RobertaTokenizer): 57 | self.mask_token = '<mask>' 58 | self.pad_token = '<pad>' 59 | self.unk_token = '<unk>' 60 | assert self.tokenizer.convert_ids_to_tokens(self.MASK_IDX) == self.mask_token 61 | assert self.tokenizer.convert_ids_to_tokens(self.PAD_IDX) == self.pad_token 62 | assert self.tokenizer.convert_ids_to_tokens(self.UNK_IDX) == self.unk_token 63 | 64 | else: 65 | print(f"tokenizer type = {type(tokenizer)}") 66 | assert 0 67 | 68 | self.init_method = init_method 69 | self.iter_method = iter_method 70 | self.MAX_ITER = MAX_ITER 71 | self.BEAM_SIZE = BEAM_SIZE 72 | self.NUM_MASK = NUM_MASK 73 | self.BATCH_SIZE = BATCH_SIZE 74 | 75 | self.sentence_printed = not verbose 76 | 77 | def set_model(self,model): 78 | self.model = model 79 | 80 | def append_paddings(self, sentences): 81 | # Append [PAD]s next to [SEP] by max length in the batch 82 | tokenized_sentences = [] 83 | for sentence in sentences: 84 | tokenized_sentence = self.tokenizer.encode(sentence) 85 | tokenized_sentences.append(tokenized_sentence) 86 | 87 | len_max = len(max(tokenized_sentences, key=lambda x:len(x))) 88 | for tokenized_sentence in tokenized_sentences: 89 | tokenized_sentence += [self.PAD_IDX] * (len_max - len(tokenized_sentence)) 90 | 91 | return tokenized_sentences 92 | 93 | def decode_sentences(self, sentences): 94 | # Encode input using tokenizer 95 | # Append [PAD]s next to [SEP] by max length in the batch 96 | sentences = self.append_paddings(sentences) 97 | 98 | # for printing only once 99 | if self.sentence_printed == False: 100 | print(sentences[:5]) 101 | self.sentence_printed = True 102 | 103 | inp_tensor = torch.tensor(sentences) 104 | attention_mask = inp_tensor.ne(self.PAD_IDX).long() 105 | mask_ind = inp_tensor.eq(self.MASK_IDX).long() 106 | 107 | if torch.cuda.is_available(): 108 | inp_tensor = inp_tensor.cuda() 109 | attention_mask = attention_mask.cuda() 110 | mask_ind = mask_ind.cuda() 111 | 112
| batch_size = int(len(sentences)/self.NUM_MASK) 113 | # SHAPE: (batch_size, num_mask, seq_len) 114 | inp_tensor = inp_tensor.view(batch_size,self.NUM_MASK,-1) 115 | attention_mask = attention_mask.view(batch_size,self.NUM_MASK,-1) 116 | mask_ind = mask_ind.view(batch_size,self.NUM_MASK,-1) 117 | 118 | out_tensors=[] 119 | logprobs=[] 120 | for nm in range(self.NUM_MASK): 121 | # decode 122 | # SHAPE: (beam_size, batch_size, seq_len) 123 | b_out_tensor, b_logprob, iter = iter_decode_beam_search( 124 | self.model, inp_tensor[:, nm, :], mask_ind[:, nm, :], attention_mask[:, nm, :], 125 | restrict_vocab=[], mask_value=self.MASK_IDX, 126 | init_method=self.init_method, iter_method=self.iter_method, 127 | max_iter=self.MAX_ITER, tokenizer=self.tokenizer, 128 | reprob=False, beam_size=self.BEAM_SIZE) 129 | 130 | # SHAPE: (batch_size, beam_size, seq_len) 131 | b_out_tensor = b_out_tensor.permute(1,0,2) 132 | b_logprob = b_logprob.permute(1,0,2) 133 | 134 | out_tensors.append(b_out_tensor) 135 | logprobs.append(b_logprob) 136 | 137 | # SHAPE: (batch_size, beam_size, num_mask, seq_len) 138 | logprob = torch.stack(logprobs, 2) 139 | out_tensor = torch.stack(out_tensors, 2) 140 | 141 | # predict with topk (beamsize) 142 | all_preds = [] 143 | all_probs = [] 144 | 145 | for b_out_tensor, b_logprob, b_mask_ind in zip(out_tensor, logprob, mask_ind): 146 | for i in range(self.NUM_MASK): 147 | mask_len = i + 1 148 | preds = [] 149 | probs = [] 150 | 151 | for j in range(self.BEAM_SIZE): 152 | pred: np.ndarray = b_out_tensor[j][i].masked_select( 153 | b_mask_ind[i].eq(1)).detach().cpu().numpy().reshape(-1) 154 | log_prob = b_logprob[j][i].masked_select( 155 | b_mask_ind[i].eq(1)).detach().cpu().numpy().reshape(-1).sum(-1) 156 | 157 | # length normalization 158 | # 0.0 length_norm_coeff => mask_len_norm == 1. In this case, shorter predictions are favored. 159 | # 1.0 length_norm_coeff => mask_len_norm == mask_len. In this case, models are adjusted to favor longer predictions. 
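# Worked example of the length normalization below (illustrative numbers): if the summed
# log-probability of a 3-token span is -3.0, then with length_norm_coeff = 0.0 the score is
# exp(-3.0) ~= 0.050, while with length_norm_coeff = 1.0 it becomes exp(-3.0 / 3) ~= 0.368,
# so longer spans are no longer penalized simply for containing more subword tokens.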
160 | length_norm_coeff = 0.0 161 | mask_len_norm = np.power(mask_len,length_norm_coeff) 162 | prob = np.exp(log_prob / mask_len_norm) 163 | 164 | pred = merge_subwords(pred, self.tokenizer, merge=True) 165 | 166 | preds.append(pred) 167 | probs.append(prob) 168 | 169 | all_preds.append(preds) 170 | all_probs.append(probs) 171 | 172 | return all_preds, all_probs 173 | 174 | def decode(self, input, batch_size=None, verbose=True): 175 | """ 176 | input: a list of lists of sentences with [MASK] 177 | output: a list of lists of predictions 178 | """ 179 | if batch_size == None: 180 | batch_size = self.BATCH_SIZE 181 | 182 | all_preds = [] 183 | all_probs = [] 184 | for b in tqdm(range(0, len(input), batch_size),desc="decode", disable=not verbose): 185 | query_batch = input[b:b + batch_size] 186 | max_length = len(query_batch[0]) 187 | 188 | # flat for batch processing 189 | flat_query_batch = convert_2d_list_to_1d(query_batch) 190 | flat_preds, flat_probs = self.decode_sentences(flat_query_batch) 191 | preds = convert_1d_list_to_2d(flat_preds, max_length) 192 | probs = convert_1d_list_to_2d(flat_probs, max_length) 193 | 194 | all_preds += preds 195 | all_probs += probs 196 | 197 | # Sort preds based on probs 198 | # The output format will be like [[('pain', 0.37236136611128684)]] 199 | all_preds_probs = [] 200 | for preds, probs in zip(all_preds, all_probs): 201 | flat_preds = convert_2d_list_to_1d(preds) 202 | flat_probs = convert_2d_list_to_1d(probs) 203 | 204 | preds_probs = list(zip(flat_preds, flat_probs)) 205 | preds_probs = sorted(preds_probs, key=lambda x: x[1], reverse=True) 206 | 207 | # Some predictions are decoded into the same output 208 | # These duplicates should be removed 209 | preds_probs = remove_duplicate_preds_probs(preds_probs) 210 | all_preds_probs.append(preds_probs) 211 | 212 | return all_preds_probs 213 | 214 | # https://github.com/jzbjyb/X-FACTR 215 | def merge_subwords(ids: Union[np.ndarray, List[int]], tokenizer, merge: bool=False) -> str: 216 | subwords = list(tokenizer.convert_ids_to_tokens(ids)) 217 | if not merge: 218 | return subwords 219 | else: 220 | merged_subword = "" 221 | for subword in subwords: 222 | if isinstance(tokenizer, BertTokenizer): 223 | if subword.startswith('##'): 224 | subword = subword.replace('##', '') 225 | merged_subword += subword 226 | else: 227 | merged_subword += ' ' + subword 228 | elif isinstance(tokenizer, RobertaTokenizer): 229 | if subword.startswith('Ġ'): 230 | subword = subword.replace('Ġ', ' ') 231 | merged_subword += subword 232 | else: 233 | merged_subword += '' + subword 234 | else: 235 | print('need to check tokenizer!') 236 | assert 0 237 | 238 | merged_subword = merged_subword.strip() 239 | return merged_subword 240 | 241 | def model_prediction_wrap(model, inp_tensor, attention_mask): 242 | with torch.no_grad(): 243 | logit = model(inp_tensor, attention_mask=attention_mask)[0] 244 | 245 | if hasattr(model, 'cls'): # bert 246 | bias = model.cls.predictions.bias 247 | elif hasattr(model, 'lm_head'): # roberta 248 | bias = model.lm_head.bias 249 | elif hasattr(model, 'pred_layer'): # xlm 250 | bias = 0.0 251 | else: 252 | raise Exception('not sure whether the bias is correct') 253 | logit = logit - bias 254 | 255 | return logit 256 | 257 | def iter_decode_beam_search(model, 258 | inp_tensor: torch.LongTensor, # SHAPE: (batch_size, seq_len) 259 | raw_mask: torch.LongTensor, # SHAPE: (batch_size, seq_len) 260 | attention_mask: torch.LongTensor, # SHAPE: (batch_size, seq_len) 261 | restrict_vocab: List[int] = None, 262 | 
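# Allowed values for the decoding options further down in this signature, matching the
# asserts at the top of the function body: init_method is one of 'independent', 'order' or
# 'confidence'; iter_method is one of 'none', 'order', 'confidence' or 'confidence-multi'.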
mask_value: int = 0, # indicate which value is used for mask 263 | max_iter: int = None, # max number of iteration 264 | tokenizer = None, 265 | init_method: str='independent', 266 | iter_method: str='none', 267 | reprob: bool = False, # recompute the prob finally 268 | beam_size: int = 5, 269 | ) -> Tuple[torch.LongTensor, torch.Tensor, int]: # HAPE: (batch_size, seq_len) 270 | ''' 271 | Masks must be consecutive. 272 | ''' 273 | assert init_method in {'independent', 'order', 'confidence'} 274 | assert iter_method in {'none', 'order', 'confidence', 'confidence-multi'} 275 | bs, sl = inp_tensor.size(0), inp_tensor.size(1) 276 | init_mask = inp_tensor.eq(mask_value).long() # SHAPE: (batch_size, seq_len) 277 | init_has_mask = init_mask.sum().item() > 0 278 | 279 | if iter_method == 'confidence-multi': 280 | number_to_mask = torch.unique(init_mask.sum(-1)) 281 | assert number_to_mask.size(0) == 1, 'this batch has different numbers of mask tokens' 282 | number_to_mask = number_to_mask[0].item() - 1 283 | assert max_iter == 0, 'do not need to set max_iter in confidence-multi setting' 284 | elif iter_method == 'order': 285 | leftmost_mask = init_mask * torch.cat([init_mask.new_ones((bs, 1)), 1 - init_mask], 1)[:, :-1] 286 | number_to_mask = torch.unique(init_mask.sum(-1)) 287 | assert number_to_mask.size(0) == 1, 'this batch has different numbers of mask tokens' 288 | number_to_mask: int = number_to_mask[0].item() 289 | mask_offset: int = 0 290 | has_modified: bool = False # track wether modification happens during a left-to-right pass 291 | 292 | # SHAPE: (<=beam_size, batch_size, seq_len) 293 | out_tensors: List[torch.LongTensor] = inp_tensor.unsqueeze(0) 294 | # tokens not considered have log prob of zero 295 | out_logprobs: List[torch.Tensor] = torch.zeros_like(inp_tensor).float().unsqueeze(0) 296 | iter: int = 0 297 | stop: bool = False 298 | model_call = 0 299 | 300 | while True and init_has_mask: # skip when there is not mask initially 301 | next_out_tensors = [] 302 | next_out_logprobs = [] 303 | 304 | # enumerate over all previous result 305 | for out_tensor, out_logprob in zip(out_tensors, out_logprobs): 306 | # get input 307 | if iter > 0: 308 | if iter_method == 'none': 309 | inp_tensor = out_tensor 310 | if inp_tensor.eq(mask_value).long().sum().item() == 0: # no mask 311 | stop = True 312 | break 313 | elif iter_method == 'confidence': 314 | has_mask = out_tensor.eq(mask_value).any(-1).unsqueeze(-1).long() # SHAPE: (batch_size, 1) 315 | inp_tensor = out_tensor.scatter(1, out_logprob.min(-1)[1].unsqueeze(-1), mask_value) 316 | # no need to insert mask when there are masks 317 | inp_tensor = out_tensor * has_mask + inp_tensor * (1 - has_mask) 318 | elif iter_method == 'confidence-multi': 319 | has_mask = out_tensor.eq(mask_value).any(-1).unsqueeze(-1) # SHAPE: (batch_size, 1) 320 | all_has_mask = has_mask.all().item() 321 | assert all_has_mask == has_mask.any().item(), 'some samples have masks while the others do not' 322 | if not all_has_mask: 323 | if number_to_mask <= 0: 324 | stop = True 325 | break 326 | inp_tensor = out_tensor.scatter(1, (-out_logprob).topk(number_to_mask, dim=-1)[1], mask_value) 327 | init_method = 'independent' 328 | number_to_mask -= 1 329 | else: 330 | inp_tensor = out_tensor 331 | elif iter_method == 'order': 332 | has_mask = out_tensor.eq(mask_value).any(-1).unsqueeze(-1) # SHAPE: (batch_size, 1) 333 | all_has_mask = has_mask.all().item() 334 | any_has_mask = has_mask.any().item() 335 | assert all_has_mask == any_has_mask, \ 336 | 'some samples have masks 
while the others do not' 337 | if not all_has_mask: # no mask, should do refinement 338 | if mask_offset >= number_to_mask: 339 | mask_offset = 0 340 | if mask_offset == 0: # restart when starting from the beginning 341 | has_modified = False 342 | cur_mask = torch.cat([leftmost_mask.new_zeros((bs, mask_offset)), leftmost_mask], 1)[:, :sl] 343 | cur_mask = cur_mask * init_mask 344 | inp_tensor = out_tensor * (1 - cur_mask) + mask_value * cur_mask 345 | mask_offset += 1 346 | else: 347 | inp_tensor = out_tensor 348 | else: 349 | raise NotImplementedError 350 | 351 | # predict 352 | # SHAPE: (batch_size, seq_len) 353 | mask_mask = inp_tensor.eq(mask_value).long() 354 | model_call += 1 355 | logit = model_prediction_wrap(model, inp_tensor, attention_mask) 356 | if restrict_vocab is not None: 357 | logit[:, :, restrict_vocab] = float('-inf') 358 | # SHAPE: (batch_size, seq_len, beam_size) 359 | new_out_logprobs, new_out_tensors = logit.log_softmax(-1).topk(beam_size, dim=-1) 360 | 361 | if init_method == 'confidence': 362 | # mask out non-mask positions 363 | new_out_logprobs = new_out_logprobs + mask_mask.unsqueeze(-1).float().log() 364 | new_out_logprobs = new_out_logprobs.view(-1, sl * beam_size) 365 | new_out_tensors = new_out_tensors.view(-1, sl * beam_size) 366 | 367 | for b in range(beam_size): 368 | if init_method == 'independent': 369 | new_out_logprob = new_out_logprobs[:, :, b] 370 | new_out_tensor = new_out_tensors[:, :, b] 371 | # SHAPE: (batch_size, seq_len) 372 | changes = (out_tensor * mask_mask).ne(new_out_tensor * mask_mask) 373 | elif init_method == 'order': # only modify the left-most one. 374 | new_out_logprob = new_out_logprobs[:, :, b] 375 | new_out_tensor = new_out_tensors[:, :, b] 376 | # SHAPE: (batch_size, seq_len) 377 | changes = (out_tensor * mask_mask).ne(new_out_tensor * mask_mask) 378 | changes = changes & torch.cat([changes.new_ones((bs, 1)), ~changes], 1)[:, :-1] 379 | elif init_method == 'confidence': # only modify the most confident one. 
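# 'confidence' initialization, step by step: take the (masked position, candidate) pair with
# the highest log-probability across the whole sequence, write that candidate into the
# output, and then set its score to -inf so the next pass over the beam picks the runner-up
# rather than selecting the same pair again.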
380 | # SHAPE: (batch_size,) 381 | raw_lp, raw_ind = new_out_logprobs.max(-1) 382 | # SHAPE: (batch_size, 1) 383 | raw_lp, raw_ind = raw_lp.unsqueeze(-1), raw_ind.unsqueeze(-1) 384 | seq_ind = raw_ind // beam_size 385 | changes = mask_mask & torch.zeros_like(mask_mask).scatter(1, seq_ind, True) 386 | new_out_tensor = torch.zeros_like(out_tensor).scatter(1, seq_ind, new_out_tensors.gather(1, raw_ind)) 387 | new_out_logprob = torch.zeros_like(out_logprob).scatter(1, seq_ind, raw_lp) 388 | changes = (out_tensor * changes.long()).ne(new_out_tensor * changes.long()) 389 | # max for the next max in beam search 390 | new_out_logprobs = new_out_logprobs.scatter(1, raw_ind, float('-inf')) 391 | else: 392 | raise NotImplementedError 393 | 394 | # only modify tokens that have changes 395 | changes = changes.long() 396 | _out_tensor = out_tensor * (1 - changes) + new_out_tensor * changes 397 | _out_logprob = out_logprob * (1 - changes.float()) + new_out_logprob.detach() * changes.float() 398 | 399 | # involves heavy computation, where we re-compute probabilities for beam_size * beam_size samples 400 | if reprob: 401 | _out_logprob = compute_likelihood( 402 | model, _out_tensor, _out_logprob, 403 | init_mask, attention_mask, restrict_vocab, mask_value=mask_value) 404 | _out_logprob = _out_logprob * (1 - _out_tensor.eq(mask_value).float()) # skip mask tokens 405 | 406 | next_out_tensors.append(_out_tensor) 407 | next_out_logprobs.append(_out_logprob) 408 | 409 | ''' 410 | for i in range(bs): 411 | print(tokenizer.convert_ids_to_tokens(inp_tensor[i].cpu().numpy())) 412 | print(tokenizer.convert_ids_to_tokens(_out_tensor[i].cpu().numpy())) 413 | input() 414 | ''' 415 | 416 | if stop: 417 | break 418 | 419 | next_out_tensors = torch.stack(next_out_tensors, 0) 420 | next_out_logprobs = torch.stack(next_out_logprobs, 0) 421 | # tie breaking 422 | next_out_logprobs = next_out_logprobs + \ 423 | get_tie_breaking(int(next_out_logprobs.size(0))).view(-1, 1, 1).to(next_out_logprobs.device) 424 | 425 | # dedup 426 | not_dups = [] 427 | for i in range(bs): 428 | abs = next_out_tensors.size(0) 429 | # SHAPE: (all_beam_size, seq_len) 430 | one_sample = next_out_tensors[:, i, :] 431 | # SHAPE: (all_beam_size,) 432 | inv = torch.unique(one_sample, dim=0, return_inverse=True)[1] 433 | # SHAPE: (all_beam_size, all_beam_size) 434 | not_dup = inv.unsqueeze(-1).ne(inv.unsqueeze(0)) | \ 435 | (torch.arange(abs).unsqueeze(-1) <= torch.arange(abs).unsqueeze(0)).to(inv.device) 436 | # SHAPE: (all_beam_size,) 437 | not_dup = not_dup.all(-1) 438 | not_dups.append(not_dup) 439 | # SHAPE: (all_beam_size, batch_size) 440 | not_dups = torch.stack(not_dups, -1) 441 | 442 | # select top 443 | # SHAPE: (all_beam_size, batch_size) 444 | beam_score = (next_out_logprobs * init_mask.unsqueeze(0).float() + 445 | not_dups.unsqueeze(-1).float().log()).sum(-1) 446 | # SHAPE: (beam_size, batch_size, seq_len) 447 | beam_top = beam_score.topk(beam_size, dim=0)[1].view(-1, bs, 1).repeat(1, 1, sl) 448 | next_out_logprobs = torch.gather(next_out_logprobs, 0, beam_top) 449 | next_out_tensors = torch.gather(next_out_tensors, 0, beam_top) 450 | 451 | # stop condition for other type of iter 452 | if next_out_tensors.size(0) == out_tensors.size(0) and next_out_tensors.eq(out_tensors).all(): 453 | if iter_method != 'order': 454 | stop = True 455 | else: 456 | if iter_method == 'order': 457 | has_modified = True 458 | # stop condition for 'order' iter 459 | if iter_method == 'order' and not has_modified and mask_offset == number_to_mask: 460 | # reach the 
last position and no modification happens during this iteration 461 | stop = True 462 | 463 | #print(next_out_tensors.ne(out_tensors).any(-1).any(0).nonzero()) 464 | 465 | out_tensors = next_out_tensors 466 | out_logprobs = next_out_logprobs 467 | 468 | iter += 1 469 | if max_iter and iter >= max_iter: # max_iter can be zero 470 | stop = True 471 | if stop: 472 | break 473 | 474 | return out_tensors, out_logprobs, iter 475 | 476 | def compute_likelihood(model, 477 | inp_tensor: torch.LongTensor, # SHAPE: (batch_size, seq_len) 478 | lp_tensor: torch.Tensor, # SHAPE: (batch_size, seq_len) 479 | mask_tensor: torch.LongTensor, # SHAPE: (batch_size, seq_len) 480 | attention_mask: torch.LongTensor, # SHAPE: (batch_size, seq_len)) 481 | restrict_vocab: List[int] = None, 482 | mask_value: int=0, # indicate which value is used for mask 483 | ) -> torch.Tensor: # SHAPE: (batch_size, seq_len) 484 | ''' 485 | Masks must be consecutive. 486 | ''' 487 | bs, seq_len = inp_tensor.size(0), inp_tensor.size(1) 488 | max_num_masks = mask_tensor.sum(-1).max().item() 489 | leftmost_mask = mask_tensor * torch.cat([mask_tensor.new_ones((bs, 1)), 1 - mask_tensor], 1)[:, :-1] 490 | logits = None 491 | for i in range(max_num_masks): 492 | # SHAPE: (batch_size, seq_len) 493 | cur_mask = torch.cat([leftmost_mask.new_zeros((bs, i)), leftmost_mask], 1)[:, :seq_len] * mask_tensor 494 | inp_tensor_ = (1 - cur_mask) * inp_tensor + cur_mask * mask_value 495 | print(f"model call in compute_likelihood {i}") 496 | logit = model_prediction_wrap(model, inp_tensor_, attention_mask) 497 | cur_mask = cur_mask.unsqueeze(-1).float() 498 | if logits is None: 499 | logits = (logit * cur_mask).detach() 500 | else: 501 | logits = (logits * (1 - cur_mask) + logit * cur_mask).detach() 502 | if restrict_vocab is not None: 503 | logits[:, :, restrict_vocab] = float('-inf') 504 | lp = logits.log_softmax(-1) 505 | lp = torch.gather(lp.view(-1, lp.size(-1)), 1, inp_tensor.view(-1).unsqueeze(-1)).view(bs, seq_len) 506 | lp_tensor = (1 - mask_tensor).float() * lp_tensor + mask_tensor.float() * lp 507 | return lp_tensor.detach() 508 | 509 | _tie_breaking: Dict[int, torch.Tensor] = {} 510 | def get_tie_breaking(dim: int): 511 | if dim not in _tie_breaking: 512 | _tie_breaking[dim] = torch.zeros(dim).uniform_(0, 1e-5) 513 | return _tie_breaking[dim] 514 | --------------------------------------------------------------------------------
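For reference, the snippet below is a hedged usage sketch of the `Decoder` class defined in `BioLAMA/decoder.py`. The model name, prompt, and hyperparameter values are illustrative assumptions, and it is assumed to be run from inside the `BioLAMA/` directory so that `decoder.py` and `utils.py` are importable.

```
import torch
from transformers import BertTokenizer, BertForMaskedLM

from decoder import Decoder

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
model = BertForMaskedLM.from_pretrained("bert-base-cased")
model.eval()
if torch.cuda.is_available():
    model = model.cuda()

decoder = Decoder(
    model=model,
    tokenizer=tokenizer,
    init_method="independent",   # one of: independent, order, confidence
    iter_method="none",          # one of: none, order, confidence, confidence-multi
    MAX_ITER=10,
    BEAM_SIZE=5,
    NUM_MASK=2,
    BATCH_SIZE=4,
)

# decode() expects a list of queries; each query is a list of NUM_MASK variants of the same
# prompt containing 1..NUM_MASK consecutive [MASK] tokens.
queries = [[
    "Aspirin is used to treat [MASK] .",
    "Aspirin is used to treat [MASK] [MASK] .",
]]

preds_probs = decoder.decode(queries)
print(preds_probs[0][:5])  # top (prediction, probability) pairs for the first query
```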