├── src
│   ├── dataset_utils
│   │   ├── __pycache__
│   │   │   ├── const.cpython-38.pyc
│   │   │   ├── dataset_loader.cpython-38.pyc
│   │   │   ├── lgl_data_loader.cpython-38.pyc
│   │   │   ├── osm_sample_loader.cpython-38.pyc
│   │   │   ├── paired_sample_loader.cpython-38.pyc
│   │   │   └── wikipedia_sample_loader.cpython-38.pyc
│   │   ├── geowebnews_data_loader.py
│   │   ├── lgl_data_loader.py
│   │   ├── dataset_loader.py
│   │   ├── osm_sample_loader.py
│   │   ├── paired_sample_loader.py
│   │   └── const.py
│   ├── utils
│   │   ├── __pycache__
│   │   │   ├── baseline_utils.cpython-38.pyc
│   │   │   ├── common_utils.cpython-38.pyc
│   │   │   └── find_closest.cpython-38.pyc
│   │   ├── find_closest.py
│   │   ├── baseline_utils.py
│   │   └── common_utils.py
│   ├── models
│   │   └── __pycache__
│   │       └── spatial_bert_model.cpython-38.pyc
│   └── train_joint.py
├── experiments
│   ├── entity_linking
│   │   ├── utils
│   │   │   ├── __pycache__
│   │   │   │   ├── common_utils.cpython-38.pyc
│   │   │   │   ├── find_closest.cpython-38.pyc
│   │   │   │   └── baseline_utils.cpython-38.pyc
│   │   │   ├── find_closest.py
│   │   │   ├── baseline_utils.py
│   │   │   └── common_utils.py
│   │   ├── link_geonames.py
│   │   └── multi_link_geonames.py
│   ├── toponym_detection
│   │   ├── test_geobert_toponym.py
│   │   ├── test_baseline_toponym.py
│   │   ├── train_baseline_toponym.py
│   │   └── train_geobert_toponym.py
│   └── typing
│       ├── test_cls_joint.py
│       └── train_cls_joint.py
├── requirements.txt
└── README.md

/src/dataset_utils/__pycache__/const.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/knowledge-computing/geolm/HEAD/src/dataset_utils/__pycache__/const.cpython-38.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/baseline_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/knowledge-computing/geolm/HEAD/src/utils/__pycache__/baseline_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/common_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/knowledge-computing/geolm/HEAD/src/utils/__pycache__/common_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/find_closest.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/knowledge-computing/geolm/HEAD/src/utils/__pycache__/find_closest.cpython-38.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/spatial_bert_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/knowledge-computing/geolm/HEAD/src/models/__pycache__/spatial_bert_model.cpython-38.pyc
--------------------------------------------------------------------------------
/src/dataset_utils/__pycache__/dataset_loader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/knowledge-computing/geolm/HEAD/src/dataset_utils/__pycache__/dataset_loader.cpython-38.pyc
--------------------------------------------------------------------------------
/src/dataset_utils/__pycache__/lgl_data_loader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/knowledge-computing/geolm/HEAD/src/dataset_utils/__pycache__/lgl_data_loader.cpython-38.pyc -------------------------------------------------------------------------------- /src/dataset_utils/__pycache__/osm_sample_loader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/knowledge-computing/geolm/HEAD/src/dataset_utils/__pycache__/osm_sample_loader.cpython-38.pyc -------------------------------------------------------------------------------- /src/dataset_utils/__pycache__/paired_sample_loader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/knowledge-computing/geolm/HEAD/src/dataset_utils/__pycache__/paired_sample_loader.cpython-38.pyc -------------------------------------------------------------------------------- /src/dataset_utils/__pycache__/wikipedia_sample_loader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/knowledge-computing/geolm/HEAD/src/dataset_utils/__pycache__/wikipedia_sample_loader.cpython-38.pyc -------------------------------------------------------------------------------- /experiments/entity_linking/utils/__pycache__/common_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/knowledge-computing/geolm/HEAD/experiments/entity_linking/utils/__pycache__/common_utils.cpython-38.pyc -------------------------------------------------------------------------------- /experiments/entity_linking/utils/__pycache__/find_closest.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/knowledge-computing/geolm/HEAD/experiments/entity_linking/utils/__pycache__/find_closest.cpython-38.pyc -------------------------------------------------------------------------------- /experiments/entity_linking/utils/__pycache__/baseline_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/knowledge-computing/geolm/HEAD/experiments/entity_linking/utils/__pycache__/baseline_utils.cpython-38.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | geopandas==0.11.0 2 | haversine==2.6.0 3 | matplotlib==3.5.2 4 | numpy==1.22.3 5 | pandas==1.3.0 6 | pyproj==3.4.1 7 | pytorch-metric-learning==1.7.1 8 | scikit-learn==1.0.2 9 | scipy==1.10.0 10 | seaborn==0.11.2 11 | torch==1.11.0 12 | torchmetrics==0.11.4 13 | torchsummary==1.5.1 14 | tqdm==4.64.0 15 | transformers==4.18.0 16 | wandb==0.15.4 17 | -------------------------------------------------------------------------------- /src/utils/find_closest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics.pairwise import cosine_similarity 3 | 4 | 5 | def find_self_closest_match(sim_matrix, word_list): 6 | '''sim_matrix should be (n,n)''' 7 | n = sim_matrix.shape[0] 8 | sim_matrix[range(n), range(n)] = 0 9 | indices = np.argmax(sim_matrix, axis = -1) 10 | ret_list = [] 11 | for ind in indices: 12 | ret_list.append(word_list[ind]) 13 | return ret_list 14 | 15 | 16 | def find_ref_closest_match(sim_matrix, word_list): 17 | ''' 18 | sim_matrix 
should be (n_ref, n_query) 19 | word_list should be (n_ref,) 20 | ''' 21 | n_ref, n_query = sim_matrix.shape[0], sim_matrix.shape[1] 22 | indices = np.argmax(sim_matrix, axis = 0) # similarity matrix, take the maximum 23 | #print(indices) 24 | ret_list = [] 25 | for ind in indices: 26 | ret_list.append(word_list[ind]) 27 | return ret_list 28 | 29 | def sort_ref_closest_match(sim_matrix, word_list): 30 | ''' 31 | sim_matrix should be (n_ref, n_query) 32 | word_list should be (n_ref,) 33 | ''' 34 | n_ref, n_query = sim_matrix.shape[0], sim_matrix.shape[1] 35 | 36 | indices_list = np.argsort(sim_matrix, axis = 0)[::-1] # descending order 37 | 38 | #print(indices_list) 39 | ret_list = [] 40 | for indices in indices_list: 41 | word_sorted = [] 42 | for ind in indices: 43 | word_sorted.append(word_list[ind]) 44 | ret_list.append(word_sorted) 45 | return ret_list -------------------------------------------------------------------------------- /experiments/entity_linking/utils/find_closest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics.pairwise import cosine_similarity 3 | 4 | 5 | def find_self_closest_match(sim_matrix, word_list): 6 | '''sim_matrix should be (n,n)''' 7 | n = sim_matrix.shape[0] 8 | sim_matrix[range(n), range(n)] = 0 9 | indices = np.argmax(sim_matrix, axis = -1) 10 | ret_list = [] 11 | for ind in indices: 12 | ret_list.append(word_list[ind]) 13 | return ret_list 14 | 15 | 16 | def find_ref_closest_match(sim_matrix, word_list): 17 | ''' 18 | sim_matrix should be (n_ref, n_query) 19 | word_list should be (n_ref,) 20 | ''' 21 | n_ref, n_query = sim_matrix.shape[0], sim_matrix.shape[1] 22 | indices = np.argmax(sim_matrix, axis = 0) # similarity matrix, take the maximum 23 | #print(indices) 24 | ret_list = [] 25 | for ind in indices: 26 | ret_list.append(word_list[ind]) 27 | return ret_list 28 | 29 | def sort_ref_closest_match(sim_matrix, word_list): 30 | ''' 31 | sim_matrix should be (n_ref, n_query) 32 | word_list should be (n_ref,) 33 | ''' 34 | n_ref, n_query = sim_matrix.shape[0], sim_matrix.shape[1] 35 | 36 | indices_list = np.argsort(sim_matrix, axis = 0)[::-1] # descending order 37 | 38 | #print(indices_list) 39 | ret_list = [] 40 | for indices in indices_list: 41 | word_sorted = [] 42 | for ind in indices: 43 | word_sorted.append(word_list[ind]) 44 | ret_list.append(word_sorted) 45 | return ret_list -------------------------------------------------------------------------------- /src/utils/baseline_utils.py: -------------------------------------------------------------------------------- 1 | # from transformers import BertModel, BertTokenizerFast 2 | # from transformers import RobertaModel, RobertaTokenizer 3 | from transformers import AutoModel, AutoTokenizer 4 | # from transformers import LukeTokenizer, LukeModel 5 | 6 | 7 | def get_baseline_model(model_name): 8 | 9 | if model_name == 'bert-base': 10 | tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") 11 | model = AutoModel.from_pretrained('bert-base-cased') 12 | 13 | elif model_name == 'bert-large': 14 | tokenizer = AutoTokenizer.from_pretrained("bert-large-cased") 15 | model = AutoModel.from_pretrained('bert-large-cased') 16 | 17 | # config = BertConfig(hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24) 18 | elif model_name == 'roberta-base': 19 | name_str = 'roberta-base' 20 | tokenizer = AutoTokenizer.from_pretrained(name_str) 21 | model = AutoModel.from_pretrained(name_str) 22 | 23 | 
elif model_name == 'roberta-large': 24 | tokenizer = AutoTokenizer.from_pretrained('roberta-large') 25 | model = AutoModel.from_pretrained('roberta-large') 26 | 27 | elif model_name == 'spanbert-base': 28 | tokenizer = AutoTokenizer.from_pretrained('SpanBERT/spanbert-base-cased') 29 | model = AutoModel.from_pretrained('SpanBERT/spanbert-base-cased') 30 | 31 | elif model_name == 'spanbert-large': 32 | tokenizer = AutoTokenizer.from_pretrained('SpanBERT/spanbert-large-cased') 33 | model = AutoModel.from_pretrained('SpanBERT/spanbert-large-cased') 34 | 35 | 36 | elif model_name == 'simcse-bert-base': 37 | name_str = 'princeton-nlp/unsup-simcse-bert-base-uncased' # they don't have cased version for unsupervised 38 | tokenizer = AutoTokenizer.from_pretrained(name_str) 39 | model = AutoModel.from_pretrained(name_str) 40 | 41 | elif model_name == 'simcse-bert-large': 42 | name_str = 'princeton-nlp/unsup-simcse-bert-large-uncased' # they don't have cased version for unsupervised 43 | tokenizer = AutoTokenizer.from_pretrained(name_str) 44 | model = AutoModel.from_pretrained(name_str) 45 | 46 | 47 | elif model_name == 'simcse-roberta-base': 48 | name_str = 'princeton-nlp/unsup-simcse-roberta-base' 49 | tokenizer = AutoTokenizer.from_pretrained(name_str) 50 | model = AutoModel.from_pretrained(name_str) 51 | 52 | elif model_name == 'simcse-roberta-large': 53 | name_str = 'princeton-nlp/unsup-simcse-roberta-large' 54 | tokenizer = AutoTokenizer.from_pretrained(name_str) 55 | model = AutoModel.from_pretrained(name_str) 56 | 57 | 58 | else: 59 | raise NotImplementedError 60 | 61 | 62 | return model, tokenizer # , config 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GeoLM: Empowering Language Models for Geospatially Grounded Language Understanding 2 | 3 | [[Paper link](https://arxiv.org/pdf/2310.14478.pdf)] [[Toponym Detection Demo](https://huggingface.co/zekun-li/geolm-base-toponym-recognition)] [[CodeForHuggingFace](https://github.com/zekun-li/transformers/tree/geolm)] 4 | 5 | ## Install 6 | 7 | 1. Clone this repository: 8 | 9 | ```Shell 10 | git clone git@github.com:knowledge-computing/geolm.git 11 | cd geolm 12 | ``` 13 | 14 | 2. Install packages 15 | ```Shell 16 | conda create -n geolm_env python=3.8 -y 17 | conda activate geolm_env 18 | pip install --upgrade pip 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | ## Pre-Train 23 | 1. Change directory to the pre-training script folder 24 | ``` 25 | cd src 26 | ``` 27 | 2. Run `train_joint.py` 28 | ``` 29 | python3 train_joint.py --model_save_dir=OUTPUT_WEIGHT_DIR --pseudo_sentence_dir='../../datasets/osm_pseudo_sent/world/' --nl_sentence_dir='../../datasets/wikidata/world_georelation/joint_v2/' --batch_size=28 --lr=1e-5 --spatial_dist_fill=900 --placename_to_osmid_path='../../datasets/osm_pseudo_sent/name-osmid-dict/placename_to_osmid.json' 30 | ``` 31 | 32 | ## Downstream Tasks 33 | 34 | ### Toponym Detection (Supervised) 35 | 1. Train with in-domain dataset 36 | ``` 37 | cd experiments/toponym_detection/ 38 | 39 | python3 train_geobert_toponym.py --model_save_dir=OUTPUT_TOPONYM_WEIGHT_DIR --model_option='geobert-base' --model_checkpoint_path=PRETRAINED_MODEL_WEIGHT --lr=1e-5 --epochs=30 --input_file_path=DATASET_PATH 40 | ``` 41 | 42 | 2.
Test with in-domain dataset 43 | ``` 44 | cd experiments/toponym_detection/ 45 | 46 | python3 test_geobert_toponym.py --model_option='geobert-base' --model_save_path=TOPONYM_MODEL_PATH --input_file_path=DATASET_PATH --spatial_dist_fill=90000 47 | 48 | ``` 49 | 50 | ### Toponym Linking (Unsupervised) 51 | ``` 52 | python3 multi_link_geonames.py --model_name='joint-base' --query_dataset_path=DATASET_PATH --ref_dataset_path=CANDIDATES_FILE_PATH --distance_norm_factor=100 --spatial_dist_fill=90000 --spatial_bert_weight_dir=PRETRAINED_WEIGHT_DIR --spatial_bert_weight_name=PRETRAINED_WEIGHT_FILE --out_dir=OUTPUT_FOLDER 53 | ``` 54 | 55 | 56 | ### Geo-entity Typing (Supervised) 57 | 1. Train with in-domain dataset 58 | ``` 59 | python3 train_cls_joint.py --lr=1e-5 --sep_between_neighbors --bert_option='bert-base' --with_type --mlm_checkpoint_path=PRETRAINED_MODEL_PATH --epochs=30 --max_token_len=512 --model_save_dir=OUTPUT_TYPING_WEIGHT_DIR --spatial_dist_fill=90000 60 | ``` 61 | 2. Test with in-domain dataset 62 | ``` 63 | python3 test_cls_joint.py --sep_between_neighbors --bert_option='bert-base' --with_type --checkpoint_path=TYPING_WEIGHT_PATH 64 | ``` 65 | 66 | 67 | ## Cite 68 | ``` 69 | @article{li2023geolm, 70 | title={GeoLM: Empowering Language Models for Geospatially Grounded Language Understanding}, 71 | author={Li, Zekun and Zhou, Wenxuan and Chiang, Yao-Yi and Chen, Muhao}, 72 | journal={arXiv preprint arXiv:2310.14478}, 73 | year={2023} 74 | } 75 | ``` 76 | 77 | ## License 78 | CC BY-NC 4.0 79 | 80 | -------------------------------------------------------------------------------- /experiments/entity_linking/utils/baseline_utils.py: -------------------------------------------------------------------------------- 1 | # from transformers import BertModel, BertTokenizerFast 2 | # from transformers import RobertaModel, RobertaTokenizer 3 | from transformers import AutoModel, AutoTokenizer 4 | # from transformers import LukeTokenizer, LukeModel 5 | 6 | 7 | def get_baseline_model(model_name): 8 | 9 | if model_name == 'bert-base': 10 | tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") 11 | model = AutoModel.from_pretrained('bert-base-cased') 12 | 13 | elif model_name == 'bert-large': 14 | tokenizer = AutoTokenizer.from_pretrained("bert-large-cased") 15 | model = AutoModel.from_pretrained('bert-large-cased') 16 | 17 | # config = BertConfig(hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24) 18 | elif model_name == 'roberta-base': 19 | name_str = 'roberta-base' 20 | tokenizer = AutoTokenizer.from_pretrained(name_str) 21 | model = AutoModel.from_pretrained(name_str) 22 | 23 | elif model_name == 'roberta-large': 24 | tokenizer = AutoTokenizer.from_pretrained('roberta-large') 25 | model = AutoModel.from_pretrained('roberta-large') 26 | 27 | elif model_name == 'spanbert-base': 28 | tokenizer = AutoTokenizer.from_pretrained('SpanBERT/spanbert-base-cased') 29 | model = AutoModel.from_pretrained('SpanBERT/spanbert-base-cased') 30 | 31 | elif model_name == 'spanbert-large': 32 | tokenizer = AutoTokenizer.from_pretrained('SpanBERT/spanbert-large-cased') 33 | model = AutoModel.from_pretrained('SpanBERT/spanbert-large-cased') 34 | 35 | 36 | 37 | elif model_name == 'simcse-bert-base': 38 | name_str = 'princeton-nlp/unsup-simcse-bert-base-uncased' # they don't have cased version 39 | tokenizer = AutoTokenizer.from_pretrained(name_str) 40 | model = AutoModel.from_pretrained(name_str) 41 | 42 | elif model_name == 'simcse-bert-large': 43 | name_str = 
'princeton-nlp/unsup-simcse-bert-large-uncased' # they don't have cased version 44 | tokenizer = AutoTokenizer.from_pretrained(name_str) 45 | model = AutoModel.from_pretrained(name_str) 46 | 47 | 48 | elif model_name == 'simcse-roberta-base': 49 | name_str = 'princeton-nlp/unsup-simcse-roberta-base' 50 | tokenizer = AutoTokenizer.from_pretrained(name_str) 51 | model = AutoModel.from_pretrained(name_str) 52 | 53 | elif model_name == 'simcse-roberta-large': 54 | name_str = 'princeton-nlp/unsup-simcse-roberta-large' 55 | tokenizer = AutoTokenizer.from_pretrained(name_str) 56 | model = AutoModel.from_pretrained(name_str) 57 | elif model_name == 'sap-bert': 58 | tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext") 59 | model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext") 60 | 61 | elif model_name == 'mirror-bert': 62 | tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/mirror-bert-base-uncased-word") 63 | model = AutoModel.from_pretrained("cambridgeltl/mirror-bert-base-uncased-word") 64 | 65 | 66 | else: 67 | raise NotImplementedError 68 | 69 | 70 | return model, tokenizer # , config 71 | -------------------------------------------------------------------------------- /src/utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import pdb 5 | #from sklearn.metrics.pairwise import cosine_similarity 6 | 7 | 8 | import torch 9 | 10 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 11 | 12 | 13 | def write_to_csv(out_dir, map_name, match_list): 14 | out_path = os.path.join(out_dir, map_name + '.json') 15 | 16 | with open(out_path, 'w') as f: 17 | for match_dict in match_list: 18 | json.dump(match_dict, f) 19 | f.write('\n') 20 | 21 | 22 | 23 | def load_spatial_bert_pretrained_weights(model, weight_path): 24 | 25 | # load pretrained weights from SpatialBertLM to SpatialBertModel 26 | #pre_trained_model=torch.load(os.path.join(model_save_dir, weight_file_name)) 27 | pre_trained_model=torch.load(weight_path) 28 | 29 | if 'model' in pre_trained_model: 30 | pre_trained_model = pre_trained_model["model"] 31 | 32 | cnt_layers = 0 33 | cur_model_kvpair=model.state_dict() 34 | for key,value in cur_model_kvpair.items(): 35 | if 'bert.'+key in pre_trained_model: 36 | cur_model_kvpair[key]=pre_trained_model['bert.'+key] 37 | #print("weights loaded for", key) 38 | cnt_layers += 1 39 | else: 40 | print("No weight for", key) 41 | 42 | print(cnt_layers, 'layers loaded') 43 | 44 | model.load_state_dict(cur_model_kvpair) 45 | 46 | return model 47 | 48 | 49 | 50 | def get_spatialbert_embedding(entity, model, use_distance = True, agg = 'mean'): 51 | 52 | pseudo_sentence = entity['pseudo_sentence'][None,:].to(device) 53 | attention_mask = entity['attention_mask'][None,:].to(device) 54 | sent_position_ids = entity['sent_position_ids'][None,:].to(device) 55 | pivot_token_len = entity['pivot_token_len'] 56 | 57 | 58 | if 'norm_lng_list' in entity and use_distance: 59 | position_list_x = entity['norm_lng_list'][None,:].to(device) 60 | position_list_y = entity['norm_lat_list'][None,:].to(device) 61 | else: 62 | position_list_x = [] 63 | position_list_y = [] 64 | 65 | outputs = model(input_ids = pseudo_sentence, attention_mask = attention_mask, sent_position_ids = sent_position_ids, 66 | position_list_x = position_list_x, position_list_y = position_list_y) 67 | 68 | 69 | embeddings = outputs.last_hidden_state 70 |
71 | 72 | pivot_embed = embeddings[0][1:1+pivot_token_len] 73 | if agg == 'mean': 74 | pivot_embed = torch.mean(pivot_embed, axis = 0).detach().cpu().numpy() # (768, ) 75 | elif agg == 'sum': 76 | pivot_embed = torch.sum(pivot_embed, axis = 0).detach().cpu().numpy() # (768, ) 77 | else: 78 | raise NotImplementedError 79 | 80 | return pivot_embed 81 | 82 | def get_bert_embedding(entity, model, agg = 'mean'): 83 | 84 | pseudo_sentence = entity['pseudo_sentence'].unsqueeze(0).to(device) 85 | attention_mask = entity['attention_mask'].unsqueeze(0).to(device) 86 | pivot_token_len = entity['pivot_token_len'] 87 | 88 | 89 | outputs = model(input_ids = pseudo_sentence, attention_mask = attention_mask) 90 | 91 | 92 | embeddings = outputs.last_hidden_state 93 | 94 | 95 | pivot_embed = embeddings[0][1:1+pivot_token_len] 96 | if agg == 'mean': 97 | pivot_embed = torch.mean(pivot_embed, axis = 0).detach().cpu().numpy() # (768, ) 98 | elif agg == 'sum': 99 | pivot_embed = torch.sum(pivot_embed, axis = 0).detach().cpu().numpy() # (768, ) 100 | else: 101 | raise NotImplementedError 102 | 103 | 104 | return pivot_embed -------------------------------------------------------------------------------- /experiments/entity_linking/utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import pdb 5 | #from sklearn.metrics.pairwise import cosine_similarity 6 | 7 | 8 | import torch 9 | 10 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 11 | 12 | 13 | def write_to_csv(out_dir, map_name, match_list): 14 | out_path = os.path.join(out_dir, map_name + '.json') 15 | 16 | with open(out_path, 'w') as f: 17 | for match_dict in match_list: 18 | json.dump(match_dict, f) 19 | f.write('\n') 20 | 21 | 22 | 23 | def load_spatial_bert_pretrained_weights(model, weight_path): 24 | 25 | # load pretrained weights from SpatialBertLM to SpatialBertModel 26 | #pre_trained_model=torch.load(os.path.join(model_save_dir, weight_file_name)) 27 | pre_trained_model=torch.load(weight_path) 28 | 29 | if 'model' in pre_trained_model: 30 | pre_trained_model = pre_trained_model["model"] 31 | 32 | cnt_layers = 0 33 | cur_model_kvpair=model.state_dict() 34 | for key,value in cur_model_kvpair.items(): 35 | if 'bert.'+key in pre_trained_model: 36 | cur_model_kvpair[key]=pre_trained_model['bert.'+key] 37 | #print("weights loaded for", key) 38 | cnt_layers += 1 39 | else: 40 | print("No weight for", key) 41 | 42 | print(cnt_layers, 'layers loaded') 43 | 44 | model.load_state_dict(cur_model_kvpair) 45 | 46 | return model 47 | 48 | 49 | 50 | def get_spatialbert_embedding(entity, model, use_distance = True, agg = 'mean'): 51 | 52 | pseudo_sentence = entity['pseudo_sentence'][None,:].to(device) 53 | attention_mask = entity['attention_mask'][None,:].to(device) 54 | sent_position_ids = entity['sent_position_ids'][None,:].to(device) 55 | pivot_token_len = entity['pivot_token_len'] 56 | 57 | 58 | if 'norm_lng_list' in entity and use_distance: 59 | position_list_x = entity['norm_lng_list'][None,:].to(device) 60 | position_list_y = entity['norm_lat_list'][None,:].to(device) 61 | else: 62 | position_list_x = [] 63 | position_list_y = [] 64 | 65 | outputs = model(input_ids = pseudo_sentence, attention_mask = attention_mask, sent_position_ids = sent_position_ids, 66 | position_list_x = position_list_x, position_list_y = position_list_y) 67 | 68 | 69 | embeddings = outputs.last_hidden_state 70 | 71 | 72 | pivot_embed = 
embeddings[0][1:1+pivot_token_len] 73 | if agg == 'mean': 74 | pivot_embed = torch.mean(pivot_embed, axis = 0).detach().cpu().numpy() # (768, ) 75 | elif agg == 'sum': 76 | pivot_embed = torch.sum(pivot_embed, axis = 0).detach().cpu().numpy() # (768, ) 77 | else: 78 | raise NotImplementedError 79 | 80 | return pivot_embed 81 | 82 | def get_bert_embedding(entity, model, agg = 'mean'): 83 | 84 | pseudo_sentence = entity['pseudo_sentence'].unsqueeze(0).to(device) 85 | attention_mask = entity['attention_mask'].unsqueeze(0).to(device) 86 | pivot_token_len = entity['pivot_token_len'] 87 | 88 | 89 | outputs = model(input_ids = pseudo_sentence, attention_mask = attention_mask) 90 | 91 | 92 | embeddings = outputs.last_hidden_state 93 | 94 | 95 | pivot_embed = embeddings[0][1:1+pivot_token_len] 96 | if agg == 'mean': 97 | pivot_embed = torch.mean(pivot_embed, axis = 0).detach().cpu().numpy() # (768, ) 98 | elif agg == 'sum': 99 | pivot_embed = torch.sum(pivot_embed, axis = 0).detach().cpu().numpy() # (768, ) 100 | else: 101 | raise NotImplementedError 102 | 103 | 104 | return pivot_embed -------------------------------------------------------------------------------- /src/dataset_utils/geowebnews_data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import json 5 | import math 6 | 7 | import torch 8 | from transformers import RobertaTokenizer, BertTokenizerFast 9 | from torch.utils.data import Dataset 10 | sys.path.append('/home/zekun/joint_model/src/datasets') 11 | from dataset_loader import SpatialDataset 12 | 13 | import pdb 14 | np.random.seed(2333) 15 | 16 | 17 | class GWN_ToponymDataset(SpatialDataset): 18 | def __init__(self, data_file_path, tokenizer=None, max_token_len = 512, distance_norm_factor = 0.0001, spatial_dist_fill=10, mode = None ): 19 | 20 | if tokenizer is None: 21 | self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') 22 | else: 23 | self.tokenizer = tokenizer 24 | 25 | self.max_token_len = max_token_len 26 | self.spatial_dist_fill = spatial_dist_fill # should be normalized distance fill, larger than all normalized neighbor distance 27 | self.mode = mode 28 | 29 | self.read_file(data_file_path, mode) 30 | 31 | super(ToponymDataset, self).__init__(self.tokenizer , max_token_len , distance_norm_factor, sep_between_neighbors = True ) 32 | 33 | self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.sep_token) 34 | self.cls_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.cls_token) 35 | self.mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) 36 | self.pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token) 37 | 38 | self.label_to_id = {'O': 0, 'B-topo':1, 'I-topo':2} 39 | self.id_to_label = {0:'O',1:'B-topo',2:'I-topo'} 40 | 41 | def read_file(self, data_file_path, mode): 42 | with open(data_file_path, 'r') as f: 43 | data = json.load(f) 44 | 45 | if mode == 'train': 46 | data = data[0:int(len(data) * 0.8)] 47 | elif mode == 'test': 48 | data = data[int(len(data) * 0.8):] 49 | elif mode is None: # use the full dataset (for mlm) 50 | pass 51 | else: 52 | raise NotImplementedError 53 | 54 | print('Dataset length ', len(data)) 55 | self.data = data 56 | self.len_data = len(data) 57 | 58 | def get_offset_mappings(self, offset_mapping): 59 | flat_offset_mapping = np.array(offset_mapping).flatten() 60 | offset_mapping_dict_start = {} 61 | offset_mapping_dict_end = {} 62 | for idx in 
range(0,len(flat_offset_mapping),2): 63 | char_pos = flat_offset_mapping[idx] 64 | if char_pos == 0 and idx != 0: 65 | continue 66 | token_pos = idx//2 + 1 67 | offset_mapping_dict_start[char_pos] = token_pos 68 | for idx in range(1,len(flat_offset_mapping),2): 69 | char_pos = flat_offset_mapping[idx] 70 | if char_pos == 0 and idx != 0: 71 | break 72 | token_pos = (idx-1)//2 + 1 +1 73 | offset_mapping_dict_end[char_pos] = token_pos 74 | 75 | return offset_mapping_dict_start, offset_mapping_dict_end 76 | 77 | def load_data(self, index): 78 | record = self.data[index] 79 | sentence = record['sentence'] 80 | # print(len(sentence)) 81 | 82 | # regular expression to break sentences 83 | # (? self.max_token_len: 119 | rand_start = np.random.randint(1, len(input_ids) - self.max_token_len ) # Do not include CLS and SEP [inclusive, exclusive) 120 | ret_dict['input_ids'] = torch.tensor([self.cls_token_id] + list(input_tokens['input_ids'][rand_start: rand_start + self.max_token_len -2]) + [self.sep_token_id]) 121 | ret_dict['attention_mask'] = torch.tensor([1] + list(input_tokens['attention_mask'][rand_start: rand_start + self.max_token_len -2]) + [1]) 122 | ret_dict['labels'] = torch.tensor([-100] + list(input_tokens['labels'][rand_start: rand_start + self.max_token_len -2]) + [-100]) 123 | elif len(input_ids) < self.max_token_len: 124 | pad_len = self.max_token_len - len(input_ids) 125 | ret_dict['input_ids'] = torch.tensor(list(input_tokens['input_ids']) + [self.pad_token_id] * pad_len ) 126 | ret_dict['attention_mask'] = torch.tensor(list(input_tokens['attention_mask']) + [0] * pad_len) 127 | ret_dict['labels'] = torch.tensor(list(input_tokens['labels']) + [-100] * pad_len) 128 | else: 129 | ret_dict['input_ids'] = torch.tensor(input_tokens['input_ids']) 130 | ret_dict['attention_mask'] = torch.tensor(input_tokens['attention_mask']) 131 | ret_dict['labels'] = torch.tensor(input_tokens['labels'] ) 132 | 133 | ret_dict['sent_position_ids'] = torch.tensor(np.arange(0, self.max_token_len)) 134 | ret_dict['norm_lng_list'] = torch.tensor([self.spatial_dist_fill for i in range(self.max_token_len)]).to(torch.float32) 135 | ret_dict['norm_lat_list'] = torch.tensor([self.spatial_dist_fill for i in range(self.max_token_len)]).to(torch.float32) 136 | ret_dict['token_type_ids'] = torch.zeros(self.max_token_len).int() # 0 for nl data 137 | 138 | return ret_dict 139 | 140 | 141 | def __len__(self): 142 | return self.len_data 143 | 144 | def __getitem__(self, index): 145 | return self.load_data(index) -------------------------------------------------------------------------------- /src/dataset_utils/lgl_data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import json 5 | import math 6 | 7 | import torch 8 | from transformers import RobertaTokenizer, BertTokenizerFast 9 | from torch.utils.data import Dataset 10 | sys.path.append('/home/zekun/joint_model/src/dataset_utils') 11 | from dataset_loader import SpatialDataset 12 | 13 | import pdb 14 | np.random.seed(2333) 15 | 16 | 17 | class ToponymDataset(SpatialDataset): 18 | def __init__(self, data_file_path, tokenizer=None, max_token_len = 512, distance_norm_factor = 0.0001, spatial_dist_fill=10, mode = None ): 19 | 20 | if tokenizer is None: 21 | self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') 22 | else: 23 | self.tokenizer = tokenizer 24 | 25 | self.max_token_len = max_token_len 26 | self.spatial_dist_fill = spatial_dist_fill # should be 
normalized distance fill, larger than all normalized neighbor distance 27 | self.mode = mode 28 | 29 | self.read_file(data_file_path, mode) 30 | 31 | super(ToponymDataset, self).__init__(self.tokenizer , max_token_len , distance_norm_factor, sep_between_neighbors = True ) 32 | 33 | self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.sep_token) 34 | self.cls_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.cls_token) 35 | self.mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) 36 | self.pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token) 37 | 38 | self.label_to_id = {'O': 0, 'B-topo':1, 'I-topo':2} 39 | self.id_to_label = {0:'O',1:'B-topo',2:'I-topo'} 40 | 41 | def read_file(self, data_file_path, mode): 42 | with open(data_file_path, 'r') as f: 43 | data = json.load(f) 44 | 45 | if mode == 'train': 46 | data = data[0:int(len(data) * 0.8)] 47 | elif mode == 'test': 48 | data = data[int(len(data) * 0.8):] 49 | elif mode is None: # use the full dataset (for mlm) 50 | pass 51 | else: 52 | raise NotImplementedError 53 | 54 | print('Dataset length ', len(data)) 55 | self.data = data 56 | self.len_data = len(data) 57 | 58 | def get_offset_mappings(self, offset_mapping): 59 | flat_offset_mapping = np.array(offset_mapping).flatten() 60 | offset_mapping_dict_start = {} 61 | offset_mapping_dict_end = {} 62 | for idx in range(0,len(flat_offset_mapping),2): 63 | char_pos = flat_offset_mapping[idx] 64 | if char_pos == 0 and idx != 0: 65 | continue 66 | token_pos = idx//2 + 1 67 | offset_mapping_dict_start[char_pos] = token_pos 68 | for idx in range(1,len(flat_offset_mapping),2): 69 | char_pos = flat_offset_mapping[idx] 70 | if char_pos == 0 and idx != 0: 71 | break 72 | token_pos = (idx-1)//2 + 1 +1 73 | offset_mapping_dict_end[char_pos] = token_pos 74 | 75 | return offset_mapping_dict_start, offset_mapping_dict_end 76 | 77 | def load_data(self, index): 78 | record = self.data[index] 79 | sentence = record['sentence'] 80 | # print(len(sentence)) 81 | 82 | toponyms = record['toponyms'] 83 | 84 | input_tokens = self.tokenizer(sentence, padding="max_length", max_length=self.max_token_len, truncation = False, return_offsets_mapping = True) 85 | input_ids = input_tokens['input_ids'] 86 | 87 | offset_mapping_dict_start, offset_mapping_dict_end = self.get_offset_mappings(input_tokens['offset_mapping'][1:-1]) 88 | # labels = np.array([self.label_to_id['O'] for i in range(self.max_token_len)]) 89 | labels = np.array([self.label_to_id['O'] for i in range(len(input_ids))]) 90 | labels[0] = -100 # set CLS and SEP to -100 91 | labels[-1] = -100 92 | 93 | for toponym in toponyms: 94 | start = toponym['start'] 95 | end = toponym['end'] 96 | 97 | 98 | if start < 0 or end < 0: 99 | continue # skip wrong annotations in GWN 100 | 101 | if start not in offset_mapping_dict_start: 102 | print('offset_mapping_dict_start', offset_mapping_dict_start) 103 | print('start', start, sentence, input_tokens['offset_mapping'][1:-1]) 104 | if end not in offset_mapping_dict_end and end+1 not in offset_mapping_dict_end: 105 | print(len(sentence)) 106 | print('end', end, sentence, input_tokens['offset_mapping'][1:-1]) 107 | # token_start_idx, token_end_idx = offset_mapping_dict_start[start],offset_mapping_dict_end[end] 108 | try: 109 | token_start_idx, token_end_idx = offset_mapping_dict_start[start],offset_mapping_dict_end[end] 110 | except: 111 | token_start_idx, token_end_idx = offset_mapping_dict_start[start],offset_mapping_dict_end[end-1] 112 | 
assert token_start_idx < token_end_idx # can not be equal 113 | 114 | labels[token_start_idx + 1: token_end_idx ] = 2 115 | labels[token_start_idx] = 1 116 | 117 | input_tokens['labels'] = labels 118 | 119 | ret_dict = {} 120 | if len(input_ids) > self.max_token_len: 121 | rand_start = np.random.randint(1, len(input_ids) - self.max_token_len +1) # Do not include CLS and SEP [inclusive, exclusive) 122 | ret_dict['input_ids'] = torch.tensor([self.cls_token_id] + list(input_tokens['input_ids'][rand_start: rand_start + self.max_token_len -2]) + [self.sep_token_id]) 123 | ret_dict['attention_mask'] = torch.tensor([1] + list(input_tokens['attention_mask'][rand_start: rand_start + self.max_token_len -2]) + [1]) 124 | ret_dict['labels'] = torch.tensor([-100] + list(input_tokens['labels'][rand_start: rand_start + self.max_token_len -2]) + [-100]) 125 | elif len(input_ids) < self.max_token_len: 126 | pad_len = self.max_token_len - len(input_ids) 127 | ret_dict['input_ids'] = torch.tensor(list(input_tokens['input_ids']) + [self.pad_token_id] * pad_len ) 128 | ret_dict['attention_mask'] = torch.tensor(list(input_tokens['attention_mask']) + [0] * pad_len) 129 | ret_dict['labels'] = torch.tensor(list(input_tokens['labels']) + [-100] * pad_len) 130 | else: 131 | ret_dict['input_ids'] = torch.tensor(input_tokens['input_ids']) 132 | ret_dict['attention_mask'] = torch.tensor(input_tokens['attention_mask']) 133 | ret_dict['labels'] = torch.tensor(input_tokens['labels'] ) 134 | 135 | ret_dict['sent_position_ids'] = torch.tensor(np.arange(0, self.max_token_len)) 136 | # ret_dict['norm_lng_list'] = torch.tensor([self.spatial_dist_fill for i in range(self.max_token_len)]).to(torch.float32) 137 | # ret_dict['norm_lat_list'] = torch.tensor([self.spatial_dist_fill for i in range(self.max_token_len)]).to(torch.float32) 138 | ret_dict['norm_lng_list'] = torch.tensor([0 for i in range(self.max_token_len)]).to(torch.float32) 139 | ret_dict['norm_lat_list'] = torch.tensor([0 for i in range(self.max_token_len)]).to(torch.float32) 140 | ret_dict['token_type_ids'] = torch.zeros(self.max_token_len).int() # 0 for nl data 141 | 142 | return ret_dict 143 | 144 | 145 | 146 | def __len__(self): 147 | return self.len_data 148 | 149 | def __getitem__(self, index): 150 | return self.load_data(index) -------------------------------------------------------------------------------- /experiments/toponym_detection/test_geobert_toponym.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from transformers import RobertaTokenizer, BertTokenizer, BertTokenizerFast 4 | from tqdm import tqdm # for our progress bar 5 | from transformers import AdamW 6 | from transformers import AutoModel, AutoTokenizer 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | sys.path.append('../../src/') 11 | 12 | from dataset_utils.lgl_data_loader import ToponymDataset 13 | from models.spatial_bert_model import SpatialBertConfig 14 | from models.spatial_bert_model import SpatialBertForTokenClassification 15 | from pytorch_metric_learning import losses 16 | import torchmetrics 17 | 18 | from seqeval.metrics import classification_report 19 | import numpy as np 20 | import argparse 21 | import pdb 22 | 23 | DEBUG = False 24 | 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = False 27 | torch.manual_seed(42) 28 | torch.cuda.manual_seed_all(42) 29 | 30 | use_amp = True # whether to use automatic mixed precision 31 | 32 | id_to_label = 
{0:'O',1:'B-topo',2:'I-topo',-100:'O'} 33 | 34 | def test(args): 35 | 36 | num_workers = args.num_workers 37 | batch_size = args.batch_size 38 | max_token_len = args.max_token_len 39 | distance_norm_factor = args.distance_norm_factor 40 | spatial_dist_fill=args.spatial_dist_fill 41 | 42 | model_save_path = args.model_save_path 43 | 44 | 45 | 46 | scaler = torch.cuda.amp.GradScaler(enabled=use_amp) 47 | 48 | if args.model_option == 'geobert-base': 49 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased") 50 | config = SpatialBertConfig() 51 | elif args.model_option == 'geobert-large': 52 | tokenizer = BertTokenizerFast.from_pretrained("bert-large-cased") 53 | config = SpatialBertConfig( hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24) 54 | elif args.model_option == 'geobert-simcse-base': 55 | tokenizer = AutoTokenizer.from_pretrained('princeton-nlp/unsup-simcse-bert-base-uncased') 56 | config = SpatialBertConfig() 57 | else: 58 | raise NotImplementedError 59 | 60 | config.num_labels = 3 61 | # config.vocab_size = 28996 # for bert-cased 62 | config.vocab_size = tokenizer.vocab_size 63 | model = SpatialBertForTokenClassification(config) 64 | # model.load_state_dict(bert_model.state_dict() , strict = False) # load sentence position embedding weights as well 65 | model.load_state_dict(torch.load(args.model_save_path)['model'], strict = True) 66 | 67 | test_dataset = ToponymDataset(data_file_path = args.input_file_path, 68 | tokenizer = tokenizer, 69 | max_token_len = max_token_len, 70 | distance_norm_factor = distance_norm_factor, 71 | spatial_dist_fill = spatial_dist_fill, 72 | mode = 'test' 73 | ) 74 | 75 | 76 | 77 | 78 | test_loader = DataLoader(test_dataset, batch_size= batch_size, num_workers=num_workers, 79 | shuffle=False, pin_memory=True, drop_last=False) 80 | 81 | 82 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 83 | model.to(device) 84 | model.eval() 85 | 86 | 87 | 88 | print('start testing...') 89 | 90 | all_labels = [] 91 | all_pred = [] 92 | precision_metric = torchmetrics.Precision(task='multiclass', num_classes = 3, average=None, ignore_index = -100).to(device) 93 | recall_metric = torchmetrics.Recall(task='multiclass', num_classes = 3, average=None, ignore_index = -100).to(device) 94 | f1_metric = torchmetrics.F1Score(task='multiclass', num_classes = 3, average=None, ignore_index = -100).to(device) 95 | 96 | # setup loop with TQDM and dataloader 97 | loop = tqdm(test_loader, leave=True) 98 | iter = 0 99 | count_1, count_2 = 0, 0 100 | for batch in loop: 101 | 102 | with torch.autocast(device_type = 'cuda', dtype = torch.float16, enabled = use_amp): 103 | input_ids = batch['input_ids'].to(device) 104 | labels = batch['labels'].to(device) 105 | attention_mask = batch['attention_mask'].to(device) 106 | sent_position_ids = batch['sent_position_ids'].to(device) 107 | norm_lng_list = batch['norm_lng_list'].to(device) 108 | norm_lat_list = batch['norm_lat_list'].to(device) 109 | # pdb.set_trace() 110 | 111 | with torch.no_grad(): 112 | outputs = model(input_ids = input_ids, attention_mask = attention_mask, labels = labels, sent_position_ids = sent_position_ids, spatial_position_list_x = norm_lng_list, 113 | spatial_position_list_y = norm_lat_list) 114 | 115 | for i in range(input_ids.shape[0]): 116 | for l in labels[i]: 117 | if l == 1: 118 | count_1 += 1 119 | if l == 2: 120 | count_2 += 1 121 | 122 | if DEBUG: 123 | for i in range(input_ids.shape[0]): 124 | 
print(tokenizer.decode(input_ids[i])) 125 | print(tokenizer.convert_ids_to_tokens(input_ids[i])) 126 | print(labels[i]) 127 | print(torch.argmax(outputs.logits[i],axis=-1)) 128 | for l in labels[i]: 129 | if l == 1: 130 | count_1 += 1 131 | if l == 2: 132 | count_2 += 1 133 | for tmp1, tmp2 in zip(tokenizer.convert_ids_to_tokens(input_ids[i]), torch.argmax(outputs.logits[i],axis=-1)): 134 | if tmp2 == 1 or tmp2== 2: 135 | print(tmp1, tmp2) 136 | pdb.set_trace() 137 | 138 | loss = outputs.loss 139 | logits = outputs.logits 140 | 141 | 142 | logits = torch.flatten(logits, start_dim = 0, end_dim = 1) 143 | labels = torch.flatten(labels, start_dim = 0, end_dim = 1) 144 | 145 | # pdb.set_trace() 146 | 147 | 148 | precision_metric(logits, labels) 149 | recall_metric(logits, labels) 150 | f1_metric(logits, labels) 151 | # pdb.set_trace() 152 | 153 | all_labels.append( [id_to_label[a] for a in labels.detach().cpu().numpy().tolist()]) 154 | all_pred.append([id_to_label[a] for a in torch.argmax(logits,dim=1).detach().cpu().numpy().tolist()]) 155 | 156 | print(count_1, count_2) 157 | total_precision = precision_metric.compute() 158 | total_recall = recall_metric.compute() 159 | total_f1 = f1_metric.compute() 160 | 161 | 162 | print(total_precision, total_recall, total_f1) 163 | print(classification_report(all_labels, all_pred, digits=5)) 164 | 165 | 166 | 167 | 168 | def main(): 169 | 170 | parser = argparse.ArgumentParser() 171 | parser.add_argument('--num_workers', type=int, default=5) 172 | parser.add_argument('--batch_size', type=int, default=16) 173 | 174 | parser.add_argument('--max_token_len', type=int, default=512) 175 | 176 | 177 | parser.add_argument('--distance_norm_factor', type=float, default = 100) 178 | parser.add_argument('--spatial_dist_fill', type=float, default = 90000) 179 | 180 | 181 | parser.add_argument('--model_option', type=str, default='geobert-base', choices=['geobert-base','geobert-large','geobert-simcse-base']) 182 | parser.add_argument('--model_save_path', type=str, default=None) 183 | 184 | parser.add_argument('--input_file_path', type=str, default='/home/zekun/toponym_detection/lgl/lgl.json') 185 | 186 | 187 | args = parser.parse_args() 188 | print('\n') 189 | print(args) 190 | print('\n') 191 | 192 | 193 | 194 | 195 | 196 | test(args) 197 | 198 | 199 | 200 | if __name__ == '__main__': 201 | 202 | main() -------------------------------------------------------------------------------- /experiments/toponym_detection/test_baseline_toponym.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from transformers import RobertaTokenizer, BertTokenizer, BertTokenizerFast 4 | from tqdm import tqdm # for our progress bar 5 | from transformers import AdamW 6 | from transformers import AutoTokenizer, AutoModelForTokenClassification 7 | 8 | import torch 9 | from torch.utils.data import DataLoader 10 | 11 | sys.path.append('../../src/') 12 | from utils.baseline_utils import get_baseline_model 13 | 14 | from dataset_utils.lgl_data_loader import ToponymDataset 15 | from transformers.models.bert.modeling_bert import BertForTokenClassification, BertModel 16 | from transformers import RobertaForTokenClassification, AutoModelForTokenClassification 17 | from transformers import BertConfig 18 | from pytorch_metric_learning import losses 19 | import torchmetrics 20 | 21 | from seqeval.metrics import classification_report 22 | import numpy as np 23 | import argparse 24 | import pdb 25 | 26 | 27 | DEBUG = False 28 | 29 | 
torch.backends.cudnn.deterministic = True 30 | torch.backends.cudnn.benchmark = False 31 | torch.manual_seed(42) 32 | torch.cuda.manual_seed_all(42) 33 | 34 | use_amp = True # whether to use automatic mixed precision 35 | 36 | 37 | MODEL_OPTIONS = ['bert-base','bert-large','roberta-base','roberta-large', 38 | 'spanbert-base','spanbert-large','luke-base','luke-large', 39 | 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large', 40 | 'sapbert-base'] 41 | 42 | id_to_label = {0:'O',1:'B-topo',2:'I-topo',-100:'O'} 43 | 44 | def test(args): 45 | 46 | num_workers = args.num_workers 47 | batch_size = args.batch_size 48 | max_token_len = args.max_token_len 49 | distance_norm_factor = args.distance_norm_factor 50 | spatial_dist_fill=args.spatial_dist_fill 51 | 52 | 53 | backbone_option = args.backbone_option 54 | assert(backbone_option in MODEL_OPTIONS) 55 | 56 | 57 | model_save_path = args.model_save_path 58 | 59 | scaler = torch.cuda.amp.GradScaler(enabled=use_amp) 60 | 61 | # backbone_model, tokenizer = get_baseline_model(backbone_option) 62 | 63 | if backbone_option == 'bert-base': 64 | tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") 65 | model = AutoModelForTokenClassification.from_pretrained("bert-base-cased", num_labels=3) 66 | 67 | elif backbone_option == 'bert-large': 68 | model = AutoModelForTokenClassification.from_pretrained("bert-large-cased", num_labels=3) 69 | elif backbone_option == 'roberta-base': 70 | tokenizer = AutoTokenizer.from_pretrained("roberta-base") 71 | model = AutoModelForTokenClassification.from_pretrained("roberta-base", num_labels=3) 72 | elif backbone_option == 'spanbert-base': 73 | tokenizer = AutoTokenizer.from_pretrained('SpanBERT/spanbert-base-cased') 74 | model = AutoModelForTokenClassification.from_pretrained("SpanBERT/spanbert-base-cased", num_labels=3) 75 | elif backbone_option == 'sapbert-base': 76 | tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext") 77 | model = AutoModelForTokenClassification.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", num_labels=3) 78 | 79 | elif backbone_option == 'simcse-bert-base': 80 | tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased") 81 | model = AutoModelForTokenClassification.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased", num_labels=3) 82 | 83 | else: 84 | raise NotImplementedError 85 | 86 | 87 | model.load_state_dict(torch.load(args.model_save_path)['model'], strict = False) 88 | 89 | 90 | test_dataset = ToponymDataset(data_file_path = args.input_file_path, 91 | tokenizer = tokenizer, 92 | max_token_len = max_token_len, 93 | distance_norm_factor = distance_norm_factor, 94 | spatial_dist_fill = spatial_dist_fill, 95 | mode = 'test' 96 | ) 97 | 98 | 99 | 100 | test_loader = DataLoader(test_dataset, batch_size= batch_size, num_workers=num_workers, 101 | shuffle=False, pin_memory=True, drop_last=False) 102 | 103 | 104 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 105 | model.to(device) 106 | model.eval() 107 | 108 | 109 | 110 | print('start testing...') 111 | 112 | all_labels = [] 113 | all_pred = [] 114 | precision_metric = torchmetrics.Precision(task='multiclass', num_classes = 3, average=None, ignore_index = -100).to(device) 115 | recall_metric = torchmetrics.Recall(task='multiclass', num_classes = 3, average=None, ignore_index = -100).to(device) 116 | f1_metric = torchmetrics.F1Score(task='multiclass', num_classes = 3, average=None, 
ignore_index = -100).to(device) 117 | # setup loop with TQDM and dataloader 118 | loop = tqdm(test_loader, leave=True) 119 | iter = 0 120 | for batch in loop: 121 | 122 | with torch.autocast(device_type = 'cuda', dtype = torch.float16, enabled = use_amp): 123 | input_ids = batch['input_ids'].to(device) 124 | labels = batch['labels'].to(device) 125 | attention_mask = batch['attention_mask'].to(device) 126 | 127 | with torch.no_grad(): 128 | outputs = model(input_ids = input_ids, attention_mask = attention_mask, labels = labels) 129 | 130 | if DEBUG: 131 | for i in range(input_ids.shape[0]): 132 | print(tokenizer.decode(input_ids[i])) 133 | print(tokenizer.convert_ids_to_tokens(input_ids[i])) 134 | print(labels[i]) 135 | print(torch.argmax(outputs.logits[i],axis=-1)) 136 | for tmp1, tmp2 in zip(tokenizer.convert_ids_to_tokens(input_ids[i]), torch.argmax(outputs.logits[i],axis=-1)): 137 | if tmp2 == 1 or tmp2== 2: 138 | print(tmp1, tmp2) 139 | pdb.set_trace() 140 | 141 | loss = outputs.loss 142 | logits = outputs.logits 143 | 144 | 145 | logits = torch.flatten(logits, start_dim = 0, end_dim = 1) 146 | labels = torch.flatten(labels, start_dim = 0, end_dim = 1) 147 | 148 | 149 | precision_metric(logits, labels) 150 | recall_metric(logits, labels) 151 | f1_metric(logits, labels) 152 | 153 | all_labels.append( [id_to_label[a] for a in labels.detach().cpu().numpy().tolist()]) 154 | all_pred.append([id_to_label[a] for a in torch.argmax(logits,dim=1).detach().cpu().numpy().tolist()]) 155 | 156 | total_precision = precision_metric.compute() 157 | total_recall = recall_metric.compute() 158 | total_f1 = f1_metric.compute() 159 | print(total_precision, total_recall, total_f1) 160 | 161 | print(total_precision, total_recall, total_f1) 162 | print(classification_report(all_labels, all_pred, digits=5)) 163 | 164 | 165 | 166 | 167 | def main(): 168 | 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument('--num_workers', type=int, default=5) 171 | parser.add_argument('--batch_size', type=int, default=16) 172 | 173 | parser.add_argument('--max_token_len', type=int, default=512) 174 | 175 | 176 | parser.add_argument('--distance_norm_factor', type=float, default = 0.0001) 177 | parser.add_argument('--spatial_dist_fill', type=float, default = 100) 178 | 179 | 180 | parser.add_argument('--backbone_option', type=str, default='bert-base') 181 | parser.add_argument('--model_save_path', type=str, default=None) 182 | 183 | parser.add_argument('--input_file_path', type=str, default='/home/zekun/toponym_detection/lgl/lgl.json') 184 | 185 | 186 | args = parser.parse_args() 187 | print('\n') 188 | print(args) 189 | print('\n') 190 | 191 | 192 | 193 | test(args) 194 | 195 | 196 | 197 | if __name__ == '__main__': 198 | 199 | main() -------------------------------------------------------------------------------- /src/dataset_utils/dataset_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | from pyproj import Transformer as projTransformer 5 | import pdb 6 | 7 | np.random.seed(2333) 8 | 9 | class SpatialDataset(Dataset): 10 | def __init__(self, tokenizer , max_token_len , distance_norm_factor, sep_between_neighbors = False ): 11 | self.tokenizer = tokenizer 12 | self.max_token_len = max_token_len 13 | self.distance_norm_factor = distance_norm_factor 14 | self.sep_between_neighbors = sep_between_neighbors 15 | self.ptransformer = projTransformer.from_crs("EPSG:4326", "EPSG:4087", 
always_xy=True) # https://epsg.io/4087, equidistant cylindrical projection 16 | 17 | 18 | def parse_spatial_context(self, pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill, pivot_dist_fill = 0): 19 | 20 | sep_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.sep_token) 21 | cls_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.cls_token) 22 | mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) 23 | pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token) 24 | max_token_len = self.max_token_len 25 | 26 | 27 | # process pivot 28 | pivot_name_tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(pivot_name)) 29 | pivot_token_len = len(pivot_name_tokens) 30 | 31 | pivot_lng = pivot_pos[0] 32 | pivot_lat = pivot_pos[1] 33 | 34 | pivot_lng, pivot_lat = self.ptransformer.transform(pivot_lng, pivot_lat) 35 | 36 | # prepare entity mask 37 | entity_mask_arr = [] 38 | rand_entity = np.random.uniform(size = len(neighbor_name_list) + 1) # random number for masking entities including neighbors and pivot 39 | # True for mask, False for unmask 40 | 41 | # check if pivot entity needs to be masked out, 15% prob. to be masked out 42 | if rand_entity[0] < 0.15: 43 | entity_mask_arr.extend([True] * pivot_token_len) 44 | else: 45 | entity_mask_arr.extend([False] * pivot_token_len) 46 | 47 | # process neighbors 48 | neighbor_token_list = [] 49 | neighbor_lng_list = [] 50 | neighbor_lat_list = [] 51 | 52 | # add separator between pivot and neighbor tokens 53 | # checking pivot_dist_fill is a trick to avoid adding separator token after the class name (for class name encoding of margin-ranking loss) 54 | if self.sep_between_neighbors and pivot_dist_fill==0: 55 | neighbor_lng_list.append(spatial_dist_fill) 56 | neighbor_lat_list.append(spatial_dist_fill) 57 | neighbor_token_list.append(sep_token_id) 58 | 59 | for neighbor_name, neighbor_geometry, rnd in zip(neighbor_name_list, neighbor_geometry_list, rand_entity[1:]): 60 | 61 | if not neighbor_name[0].isalpha(): 62 | # only consider neighbors starting with letters 63 | continue 64 | 65 | neighbor_token = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(neighbor_name)) 66 | neighbor_token_len = len(neighbor_token) 67 | 68 | # compute the relative distance from neighbor to pivot, 69 | # normalize the relative distance by distance_norm_factor 70 | # apply the calculated distance for all the subtokens of the neighbor 71 | # neighbor_lng_list.extend([(neighbor_geometry[0]- pivot_lng)/self.distance_norm_factor] * neighbor_token_len) 72 | # neighbor_lat_list.extend([(neighbor_geometry[1]- pivot_lat)/self.distance_norm_factor] * neighbor_token_len) 73 | 74 | if 'coordinates' in neighbor_geometry: # to handle different json dict structures 75 | neighbor_lng , neighbor_lat = self.ptransformer.transform(neighbor_geometry['coordinates'][0], neighbor_geometry['coordinates'][1]) 76 | 77 | else: 78 | neighbor_lng , neighbor_lat = self.ptransformer.transform(neighbor_geometry[0], neighbor_geometry[1]) 79 | 80 | neighbor_lng_list.extend([(neighbor_lng - pivot_lng)/self.distance_norm_factor] * neighbor_token_len) 81 | neighbor_lat_list.extend([(neighbor_lat - pivot_lat)/self.distance_norm_factor] * neighbor_token_len) 82 | neighbor_token_list.extend(neighbor_token) 83 | 84 | 85 | if self.sep_between_neighbors: 86 | neighbor_lng_list.append(spatial_dist_fill) 87 | neighbor_lat_list.append(spatial_dist_fill) 88 | 
neighbor_token_list.append(sep_token_id) 89 | 90 | entity_mask_arr.extend([False]) 91 | 92 | 93 | if rnd < 0.15: 94 | #True: mask out, False: Keey original token 95 | entity_mask_arr.extend([True] * neighbor_token_len) 96 | else: 97 | entity_mask_arr.extend([False] * neighbor_token_len) 98 | 99 | 100 | pseudo_sentence = pivot_name_tokens + neighbor_token_list 101 | dist_lng_list = [pivot_dist_fill] * pivot_token_len + neighbor_lng_list 102 | dist_lat_list = [pivot_dist_fill] * pivot_token_len + neighbor_lat_list 103 | 104 | 105 | #including cls and sep 106 | sent_len = len(pseudo_sentence) 107 | 108 | max_token_len_middle = max_token_len -2 # 2 for CLS and SEP token 109 | 110 | # padding and truncation 111 | if sent_len > max_token_len_middle : 112 | pseudo_sentence = [cls_token_id] + pseudo_sentence[:max_token_len_middle] + [sep_token_id] 113 | dist_lat_list = [spatial_dist_fill] + dist_lat_list[:max_token_len_middle]+ [spatial_dist_fill] 114 | dist_lng_list = [spatial_dist_fill] + dist_lng_list[:max_token_len_middle]+ [spatial_dist_fill] 115 | attention_mask = [0] + [1] * max_token_len_middle + [0] # make sure SEP and CLS are not attented to 116 | else: 117 | pad_len = max_token_len_middle - sent_len 118 | assert pad_len >= 0 119 | 120 | pseudo_sentence = [cls_token_id] + pseudo_sentence + [sep_token_id] + [pad_token_id] * pad_len 121 | dist_lat_list = [spatial_dist_fill] + dist_lat_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len 122 | dist_lng_list = [spatial_dist_fill] + dist_lng_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len 123 | attention_mask = [0] + [1] * sent_len + [0] * pad_len + [0] 124 | 125 | 126 | 127 | norm_lng_list = np.array(dist_lng_list) 128 | norm_lat_list = np.array(dist_lat_list) 129 | 130 | 131 | # mask entity in the pseudo sentence 132 | entity_mask_indices = np.where(entity_mask_arr)[0] # true: mask out 133 | masked_entity_input = [mask_token_id if i in entity_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)] 134 | 135 | # mask token in the pseudo sentence 136 | rand_token = np.random.uniform(size = len(pseudo_sentence)) 137 | # do not mask out cls and sep token. 
True: masked tokens False: Keey original token 138 | 139 | token_mask_arr = (rand_token <0.15) & (np.array(pseudo_sentence) != cls_token_id) & (np.array(pseudo_sentence) != sep_token_id) & (np.array(pseudo_sentence) != pad_token_id) 140 | token_mask_indices = np.where(token_mask_arr)[0] 141 | 142 | masked_token_input = [mask_token_id if i in token_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)] 143 | 144 | 145 | # yield masked_token with 50% prob, masked_entity with 50% prob 146 | if np.random.rand() > 0.5: 147 | masked_input = torch.tensor(masked_entity_input) 148 | else: 149 | masked_input = torch.tensor(masked_token_input) 150 | 151 | train_data = {} 152 | # train_data['pivot_name'] = pivot_name 153 | train_data['pivot_token_idx'] = torch.tensor([1,pivot_token_len+1]) 154 | train_data['masked_input'] = masked_input 155 | train_data['sent_position_ids'] = torch.tensor(np.arange(0, len(pseudo_sentence))) 156 | train_data['attention_mask'] = torch.tensor(attention_mask) 157 | train_data['norm_lng_list'] = torch.tensor(norm_lng_list).to(torch.float32) 158 | train_data['norm_lat_list'] = torch.tensor(norm_lat_list).to(torch.float32) 159 | train_data['pseudo_sentence'] = torch.tensor(pseudo_sentence) 160 | 161 | return train_data 162 | 163 | 164 | 165 | def __len__(self): 166 | return NotImplementedError 167 | 168 | def __getitem__(self, index): 169 | raise NotImplementedError -------------------------------------------------------------------------------- /experiments/toponym_detection/train_baseline_toponym.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from transformers import RobertaTokenizer, BertTokenizer, BertTokenizerFast 4 | from tqdm import tqdm # for our progress bar 5 | from transformers import AdamW 6 | from transformers import AutoTokenizer, AutoModelForTokenClassification 7 | 8 | import torch 9 | from torch.utils.data import DataLoader 10 | 11 | sys.path.append('../../src/') 12 | 13 | from dataset_utils.lgl_data_loader import ToponymDataset 14 | from transformers.models.bert.modeling_bert import BertForTokenClassification, BertModel 15 | from transformers import BertConfig, RobertaForTokenClassification, AutoModelForTokenClassification 16 | from pytorch_metric_learning import losses 17 | from utils.baseline_utils import get_baseline_model 18 | 19 | import numpy as np 20 | import argparse 21 | import pdb 22 | 23 | 24 | DEBUG = False 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = False 27 | torch.manual_seed(42) 28 | torch.cuda.manual_seed_all(42) 29 | 30 | use_amp = True # whether to use automatic mixed precision 31 | 32 | 33 | MODEL_OPTIONS = ['bert-base','bert-large','roberta-base','roberta-large', 34 | 'spanbert-base','spanbert-large','luke-base','luke-large', 35 | 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large', 36 | 'sapbert-base'] 37 | 38 | 39 | def training(args): 40 | 41 | num_workers = args.num_workers 42 | batch_size = args.batch_size 43 | epochs = args.epochs 44 | lr = args.lr 45 | save_interval = args.save_interval 46 | max_token_len = args.max_token_len 47 | distance_norm_factor = args.distance_norm_factor 48 | spatial_dist_fill=args.spatial_dist_fill 49 | 50 | 51 | backbone_option = args.backbone_option 52 | assert(backbone_option in MODEL_OPTIONS) 53 | 54 | 55 | model_save_dir = os.path.join(args.model_save_dir, backbone_option) 56 | if not os.path.isdir(model_save_dir): 57 | 
os.makedirs(model_save_dir) 58 | 59 | 60 | scaler = torch.cuda.amp.GradScaler(enabled=use_amp) 61 | 62 | 63 | print('model_save_dir', model_save_dir) 64 | print('\n') 65 | 66 | # backbone_model, tokenizer = get_baseline_model(backbone_option) 67 | # config = backbone_model.config 68 | 69 | 70 | if backbone_option == 'bert-base': 71 | tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") 72 | model = AutoModelForTokenClassification.from_pretrained("bert-base-cased", num_labels=3) 73 | 74 | elif backbone_option == 'roberta-base': 75 | tokenizer = AutoTokenizer.from_pretrained("roberta-base") 76 | model = AutoModelForTokenClassification.from_pretrained("roberta-base", num_labels=3) 77 | 78 | elif backbone_option == 'spanbert-base': 79 | tokenizer = AutoTokenizer.from_pretrained('SpanBERT/spanbert-base-cased') 80 | model = AutoModelForTokenClassification.from_pretrained("SpanBERT/spanbert-base-cased", num_labels=3) 81 | # model.bert = backbone_model 82 | 83 | elif backbone_option == 'sapbert-base': 84 | tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext") 85 | model = AutoModelForTokenClassification.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", num_labels=3) 86 | 87 | elif backbone_option == 'simcse-bert-base': 88 | tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/unsup-simcse-bert-base-uncased") 89 | model = AutoModelForTokenClassification.from_pretrained("princeton-nlp/unsup-simcse-bert-base-uncased", num_labels=3) 90 | 91 | else: 92 | raise NotImplementedError 93 | 94 | train_val_dataset = ToponymDataset(data_file_path = args.input_file_path, 95 | tokenizer = tokenizer, 96 | max_token_len = max_token_len, 97 | distance_norm_factor = distance_norm_factor, 98 | spatial_dist_fill = spatial_dist_fill, 99 | mode = 'train' 100 | ) 101 | 102 | percent_80 = int(len(train_val_dataset) * 0.8) 103 | train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [percent_80, len(train_val_dataset) - percent_80]) 104 | 105 | # pdb.set_trace() 106 | train_loader = DataLoader(train_dataset, batch_size= batch_size, num_workers=num_workers, 107 | shuffle=True, pin_memory=True, drop_last=True) 108 | val_loader = DataLoader(val_dataset, batch_size= batch_size, num_workers=num_workers, 109 | shuffle=False, pin_memory=True, drop_last=False) 110 | 111 | 112 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 113 | model.to(device) 114 | model.train() 115 | 116 | # initialize optimizer 117 | optim = torch.optim.AdamW(model.parameters(), lr = lr) 118 | 119 | print('start training...') 120 | 121 | for epoch in range(epochs): 122 | # setup loop with TQDM and dataloader 123 | loop = tqdm(train_loader, leave=True) 124 | iter = 0 125 | for batch in loop: 126 | 127 | with torch.autocast(device_type = 'cuda', dtype = torch.float16, enabled = use_amp): 128 | input_ids = batch['input_ids'].to(device) 129 | labels = batch['labels'].to(device) 130 | attention_mask = batch['attention_mask'].to(device) 131 | # pdb.set_trace() 132 | 133 | outputs = model(input_ids = input_ids, attention_mask = attention_mask, labels = labels) 134 | 135 | loss = outputs.loss 136 | 137 | 138 | scaler.scale(loss).backward() 139 | scaler.unscale_(optim) 140 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) 141 | scaler.step(optim) 142 | scaler.update() 143 | optim.zero_grad() 144 | 145 | 146 | loop.set_description(f'Epoch {epoch}') 147 | loop.set_postfix({'loss':loss.item()}) 148 | 149 | 150 | iter += 1 151 | 152 
| if iter % save_interval == 0 or iter == loop.total: 153 | loss_valid = validating(val_loader, model, device) 154 | print('validation loss', loss_valid) 155 | 156 | save_path = os.path.join(model_save_dir, 'ep'+str(epoch) + '_iter'+ str(iter).zfill(5) \ 157 | + '_' +str("{:.4f}".format(loss.item())) + '_val' + str("{:.4f}".format(loss_valid)) +'.pth' ) 158 | 159 | checkpoint = {"model": model.state_dict(), 160 | "optimizer": optim.state_dict(), 161 | "scaler": scaler.state_dict()} 162 | torch.save(checkpoint, save_path) 163 | print('saving model checkpoint to', save_path) 164 | 165 | 166 | 167 | def validating(val_loader, model, device): 168 | 169 | with torch.no_grad(): 170 | 171 | loss_valid = 0 172 | data_count = 0 173 | loop = tqdm(val_loader, leave=True) 174 | 175 | for batch in loop: 176 | input_ids = batch['input_ids'].to(device) 177 | labels = batch['labels'].to(device) 178 | attention_mask = batch['attention_mask'].to(device) 179 | 180 | outputs = model(input_ids = input_ids, attention_mask = attention_mask, labels = labels) 181 | 182 | data_count += input_ids.shape[0] 183 | loss_valid += outputs.loss * input_ids.shape[0] 184 | 185 | loss_valid = loss_valid / data_count 186 | 187 | return loss_valid 188 | 189 | def main(): 190 | 191 | parser = argparse.ArgumentParser() 192 | parser.add_argument('--num_workers', type=int, default=5) 193 | parser.add_argument('--batch_size', type=int, default=16) 194 | parser.add_argument('--epochs', type=int, default=30) 195 | parser.add_argument('--save_interval', type=int, default=2000) 196 | parser.add_argument('--max_token_len', type=int, default=512) 197 | 198 | 199 | parser.add_argument('--lr', type=float, default = 1e-5) 200 | parser.add_argument('--distance_norm_factor', type=float, default = 100) # 0.0001) 201 | parser.add_argument('--spatial_dist_fill', type=float, default = 900) # 100) 202 | 203 | 204 | 205 | parser.add_argument('--backbone_option', type=str, default='bert-base') 206 | parser.add_argument('--model_save_dir', type=str, default=None) 207 | 208 | parser.add_argument('--input_file_path', type=str, default='/home/zekun/toponym_detection/lgl/lgl.json') 209 | 210 | 211 | args = parser.parse_args() 212 | print('\n') 213 | print(args) 214 | print('\n') 215 | 216 | # out_dir not None, and out_dir does not exist, then create out_dir 217 | if args.model_save_dir is not None and not os.path.isdir(args.model_save_dir): 218 | os.makedirs(args.model_save_dir) 219 | 220 | training(args) 221 | 222 | 223 | 224 | if __name__ == '__main__': 225 | 226 | main() -------------------------------------------------------------------------------- /experiments/toponym_detection/train_geobert_toponym.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from transformers import RobertaTokenizer, BertTokenizer, BertTokenizerFast 4 | from tqdm import tqdm # for our progress bar 5 | from transformers import AdamW 6 | from transformers import AutoModel, AutoTokenizer 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | sys.path.append('../../src/') 11 | from models.spatial_bert_model import SpatialBertModel 12 | from models.spatial_bert_model import SpatialBertConfig 13 | from models.spatial_bert_model import SpatialBertForTokenClassification 14 | from dataset_utils.lgl_data_loader import ToponymDataset 15 | from pytorch_metric_learning import losses 16 | 17 | import numpy as np 18 | import argparse 19 | import pdb 20 | 21 | 22 | DEBUG = False 23 | 
torch.backends.cudnn.deterministic = True 24 | torch.backends.cudnn.benchmark = False 25 | torch.manual_seed(42) 26 | torch.cuda.manual_seed_all(42) 27 | 28 | use_amp = True # whether to use automatic mixed precision 29 | 30 | 31 | 32 | def training(args): 33 | 34 | num_workers = args.num_workers 35 | batch_size = args.batch_size 36 | epochs = args.epochs 37 | lr = args.lr 38 | save_interval = args.save_interval 39 | max_token_len = args.max_token_len 40 | distance_norm_factor = args.distance_norm_factor 41 | spatial_dist_fill=args.spatial_dist_fill 42 | 43 | 44 | model_save_dir = os.path.join(args.model_save_dir, args.model_option) 45 | if not os.path.isdir(model_save_dir): 46 | os.makedirs(model_save_dir) 47 | 48 | print('model_save_dir', model_save_dir) 49 | print('\n') 50 | 51 | scaler = torch.cuda.amp.GradScaler(enabled=use_amp) 52 | 53 | if args.model_option == 'geobert-base': 54 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased") 55 | config = SpatialBertConfig() 56 | elif args.model_option == 'geobert-large': 57 | tokenizer = BertTokenizerFast.from_pretrained("bert-large-cased") 58 | config = SpatialBertConfig( hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24) 59 | elif args.model_option == 'geobert-simcse-base': 60 | tokenizer = AutoTokenizer.from_pretrained('princeton-nlp/unsup-simcse-bert-base-uncased') 61 | config = SpatialBertConfig() 62 | else: 63 | raise NotImplementedError 64 | 65 | config.num_labels = 3 66 | # config.vocab_size = 28996 # for bert-cased 67 | config.vocab_size = tokenizer.vocab_size 68 | 69 | model = SpatialBertForTokenClassification(config) 70 | 71 | model.load_state_dict(torch.load(args.model_checkpoint_path)['model'] , strict = False) 72 | 73 | train_val_dataset = ToponymDataset(data_file_path = args.input_file_path, 74 | tokenizer = tokenizer, 75 | max_token_len = max_token_len, 76 | distance_norm_factor = distance_norm_factor, 77 | spatial_dist_fill = spatial_dist_fill, 78 | mode = 'train' 79 | ) 80 | 81 | percent_80 = int(len(train_val_dataset) * 0.8) 82 | train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [percent_80, len(train_val_dataset) - percent_80]) 83 | 84 | # pdb.set_trace() 85 | train_loader = DataLoader(train_dataset, batch_size= batch_size, num_workers=num_workers, 86 | shuffle=True, pin_memory=True, drop_last=True) 87 | val_loader = DataLoader(val_dataset, batch_size= batch_size, num_workers=num_workers, 88 | shuffle=False, pin_memory=True, drop_last=False) 89 | 90 | 91 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 92 | model.to(device) 93 | model.train() 94 | 95 | 96 | # initialize optimizer 97 | optim = torch.optim.AdamW(model.parameters(), lr = lr) 98 | 99 | print('start training...') 100 | 101 | for epoch in range(epochs): 102 | # setup loop with TQDM and dataloader 103 | loop = tqdm(train_loader, leave=True) 104 | iter = 0 105 | for batch in loop: 106 | 107 | with torch.autocast(device_type = 'cuda', dtype = torch.float16, enabled = use_amp): 108 | input_ids = batch['input_ids'].to(device) 109 | labels = batch['labels'].to(device) 110 | attention_mask = batch['attention_mask'].to(device) 111 | sent_position_ids = batch['sent_position_ids'].to(device) 112 | norm_lng_list = batch['norm_lng_list'].to(device) 113 | norm_lat_list = batch['norm_lat_list'].to(device) 114 | token_type_ids = batch['token_type_ids'].to(device) 115 | 116 | outputs = model( 117 | input_ids = input_ids, 118 | attention_mask = attention_mask, 
119 | labels = labels, 120 | sent_position_ids = sent_position_ids, 121 | spatial_position_list_x = norm_lng_list, 122 | spatial_position_list_y = norm_lat_list, 123 | token_type_ids = token_type_ids 124 | ) 125 | 126 | loss = outputs.loss 127 | 128 | 129 | scaler.scale(loss).backward() 130 | scaler.unscale_(optim) 131 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) 132 | scaler.step(optim) 133 | scaler.update() 134 | optim.zero_grad() 135 | 136 | 137 | loop.set_description(f'Epoch {epoch}') 138 | loop.set_postfix({'loss':loss.item()}) 139 | 140 | 141 | iter += 1 142 | 143 | if iter % save_interval == 0 or iter == loop.total: 144 | loss_valid = validating(val_loader, model, device) 145 | print('validation loss', loss_valid) 146 | 147 | save_path = os.path.join(model_save_dir, 'ep'+str(epoch) + '_iter'+ str(iter).zfill(5) \ 148 | + '_' +str("{:.4f}".format(loss.item())) + '_val' + str("{:.4f}".format(loss_valid)) +'.pth' ) 149 | 150 | checkpoint = {"model": model.state_dict(), 151 | "optimizer": optim.state_dict(), 152 | "scaler": scaler.state_dict()} 153 | torch.save(checkpoint, save_path) 154 | 155 | print('saving model checkpoint to', save_path) 156 | 157 | 158 | def validating(val_loader, model, device): 159 | 160 | with torch.no_grad(): 161 | 162 | loss_valid = 0 163 | loop = tqdm(val_loader, leave=True) 164 | data_count = 0 165 | for batch in loop: 166 | with torch.autocast(device_type = 'cuda', dtype = torch.float16, enabled = use_amp): 167 | input_ids = batch['input_ids'].to(device) 168 | labels = batch['labels'].to(device) 169 | attention_mask = batch['attention_mask'].to(device) 170 | sent_position_ids = batch['sent_position_ids'].to(device) 171 | norm_lng_list = batch['norm_lng_list'].to(device) 172 | norm_lat_list = batch['norm_lat_list'].to(device) 173 | 174 | outputs = model(input_ids = input_ids, attention_mask = attention_mask, labels = labels, sent_position_ids = sent_position_ids, spatial_position_list_x = norm_lng_list, 175 | spatial_position_list_y = norm_lat_list) 176 | 177 | data_count += input_ids.shape[0] 178 | loss_valid += outputs.loss * input_ids.shape[0] 179 | 180 | loss_valid = loss_valid / data_count 181 | 182 | return loss_valid 183 | 184 | def main(): 185 | 186 | parser = argparse.ArgumentParser() 187 | parser.add_argument('--num_workers', type=int, default=5) 188 | parser.add_argument('--batch_size', type=int, default=16) 189 | parser.add_argument('--epochs', type=int, default=30) 190 | parser.add_argument('--save_interval', type=int, default=2000) 191 | parser.add_argument('--max_token_len', type=int, default=512) 192 | 193 | 194 | parser.add_argument('--lr', type=float, default = 1e-5) 195 | parser.add_argument('--distance_norm_factor', type=float, default = 100) 196 | parser.add_argument('--spatial_dist_fill', type=float, default = 900) 197 | 198 | 199 | parser.add_argument('--model_option', type=str, default='geobert-base', choices=['geobert-base','geobert-large','geobert-simcse-base']) 200 | parser.add_argument('--model_checkpoint_path', type = str, default = None) 201 | parser.add_argument('--model_save_dir', type=str, default=None) 202 | 203 | parser.add_argument('--input_file_path', type=str, default='/home/zekun/toponym_detection/lgl/lgl.json') 204 | 205 | args = parser.parse_args() 206 | print('\n') 207 | print(args) 208 | print('\n') 209 | 210 | 211 | # out_dir not None, and out_dir does not exist, then create out_dir 212 | if args.model_save_dir is not None and not os.path.isdir(args.model_save_dir): 213 | 
os.makedirs(args.model_save_dir) 214 | 215 | training(args) 216 | 217 | 218 | 219 | if __name__ == '__main__': 220 | 221 | main() -------------------------------------------------------------------------------- /experiments/typing/test_cls_joint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from transformers import RobertaTokenizer, BertTokenizer, BertTokenizerFast 4 | from tqdm import tqdm # for our progress bar 5 | from transformers import AdamW 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | import torch.nn.functional as F 10 | 11 | sys.path.append('/home/zekun/joint_model/src') 12 | from models.spatial_bert_model import SpatialBertModel 13 | from models.spatial_bert_model import SpatialBertConfig 14 | from models.spatial_bert_model import SpatialBertForMaskedLM, SpatialBertForSemanticTyping 15 | from datasets.osm_sample_loader import PbfMapDataset 16 | from datasets.const import * 17 | from transformers.models.bert.modeling_bert import BertForMaskedLM 18 | 19 | from sklearn.metrics import label_ranking_average_precision_score 20 | from sklearn.metrics import precision_recall_fscore_support 21 | import numpy as np 22 | import argparse 23 | from sklearn.preprocessing import LabelEncoder 24 | import pdb 25 | 26 | 27 | DEBUG = False 28 | 29 | 30 | torch.backends.cudnn.deterministic = True 31 | torch.backends.cudnn.benchmark = False 32 | torch.manual_seed(42) 33 | torch.cuda.manual_seed_all(42) 34 | 35 | 36 | def testing(args): 37 | 38 | max_token_len = args.max_token_len 39 | batch_size = args.batch_size 40 | num_workers = args.num_workers 41 | distance_norm_factor = args.distance_norm_factor 42 | spatial_dist_fill=args.spatial_dist_fill 43 | with_type = args.with_type 44 | sep_between_neighbors = args.sep_between_neighbors 45 | checkpoint_path = args.checkpoint_path 46 | if_no_spatial_distance = args.no_spatial_distance 47 | 48 | bert_option = args.bert_option 49 | 50 | 51 | 52 | if args.num_classes == 9: 53 | # london_file_path = '/home/zekun/spatial_bert/spatial_bert/experiments/semantic_typing/data/sql_output/osm-point-london-typing.json' 54 | # california_file_path = '/home/zekun/spatial_bert/spatial_bert/experiments/semantic_typing/data/sql_output/osm-point-california-typing.json' 55 | london_file_path = '/home/zekun/datasets/semantic_typing/data/sql_output/osm-point-london-typing.json' 56 | california_file_path = '/home/zekun/datasets/semantic_typing/data/sql_output/osm-point-california-typing.json' 57 | TYPE_LIST = CLASS_9_LIST 58 | type_key_str = 'class' 59 | elif args.num_classes == 74: 60 | london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing-ranking.json' 61 | california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing-ranking.json' 62 | TYPE_LIST = CLASS_74_LIST 63 | type_key_str = 'fine_class' 64 | else: 65 | raise NotImplementedError 66 | 67 | if bert_option == 'bert-base': 68 | # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 69 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased") 70 | config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, num_semantic_types=len(TYPE_LIST)) 71 | elif bert_option == 'bert-large': 72 | # tokenizer = BertTokenizer.from_pretrained("bert-large-uncased") 73 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased") 74 | config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, hidden_size = 1024, 
intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24,num_semantic_types=len(TYPE_LIST)) 75 | else: 76 | raise NotImplementedError 77 | 78 | 79 | config.vocab_size = 28996 80 | model = SpatialBertForSemanticTyping(config) 81 | 82 | 83 | #model.load_state_dict(bert_model.state_dict() , strict = False) # load sentence position embedding weights as well 84 | 85 | 86 | label_encoder = LabelEncoder() 87 | label_encoder.fit(TYPE_LIST) 88 | #model.load_state_dict(torch.load('../weights/mlm_mem_ep0.pth')) 89 | 90 | 91 | london_dataset = PbfMapDataset(data_file_path = london_file_path, 92 | tokenizer = tokenizer, 93 | max_token_len = max_token_len, 94 | distance_norm_factor = distance_norm_factor, 95 | spatial_dist_fill = spatial_dist_fill, 96 | with_type = with_type, 97 | type_key_str = type_key_str, 98 | sep_between_neighbors = sep_between_neighbors, 99 | label_encoder = label_encoder, 100 | mode = 'test') 101 | 102 | california_dataset = PbfMapDataset(data_file_path = california_file_path, 103 | tokenizer = tokenizer, 104 | max_token_len = max_token_len, 105 | distance_norm_factor = distance_norm_factor, 106 | spatial_dist_fill = spatial_dist_fill, 107 | with_type = with_type, 108 | type_key_str = type_key_str, 109 | sep_between_neighbors = sep_between_neighbors, 110 | label_encoder = label_encoder, 111 | mode = 'test') 112 | 113 | test_dataset = torch.utils.data.ConcatDataset([london_dataset, california_dataset]) 114 | 115 | 116 | 117 | test_loader = DataLoader(test_dataset, batch_size= batch_size, num_workers=num_workers, 118 | shuffle=False, pin_memory=True, drop_last=False) 119 | 120 | 121 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 122 | model.to(device) 123 | 124 | model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device(device)), strict = True) #, strict = False) # # load sentence position embedding weights as well 125 | 126 | model.eval() 127 | 128 | 129 | 130 | print('start testing...') 131 | 132 | 133 | # setup loop with TQDM and dataloader 134 | loop = tqdm(test_loader, leave=True) 135 | 136 | 137 | mrr_total = 0. 138 | prec_total = 0. 
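        # Descriptive note: mrr_total / sample_cnt keep a batch-size-weighted label-ranking score,
        # while gt_list / pred_list collect the one-hot labels and raw logits so that per-class and
        # micro-averaged precision/recall/F1 can be computed once after the loop finishes.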
139 | sample_cnt = 0 140 | 141 | gt_list = [] 142 | pred_list = [] 143 | 144 | for batch in loop: 145 | 146 | input_ids = batch['pseudo_sentence'].to(device) 147 | attention_mask = batch['attention_mask'].to(device) 148 | position_list_x = batch['norm_lng_list'].to(device) 149 | position_list_y = batch['norm_lat_list'].to(device) 150 | sent_position_ids = batch['sent_position_ids'].to(device) 151 | 152 | labels = batch['pivot_type'].to(device) 153 | entity_token_idx = batch['pivot_token_idx'].to(device) 154 | token_type_ids = torch.ones(input_ids.shape[0],512).int().to(device) 155 | 156 | outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids, 157 | position_list_x = position_list_x, position_list_y = position_list_y, 158 | labels = labels, token_type_ids = token_type_ids, pivot_token_idx_list=entity_token_idx) 159 | 160 | 161 | onehot_labels = F.one_hot(labels, num_classes=len(TYPE_LIST)) 162 | 163 | gt_list.extend(onehot_labels.cpu().detach().numpy()) 164 | pred_list.extend(outputs.logits.cpu().detach().numpy()) 165 | 166 | mrr = label_ranking_average_precision_score(onehot_labels.cpu().detach().numpy(), outputs.logits.cpu().detach().numpy()) 167 | mrr_total += mrr * input_ids.shape[0] 168 | sample_cnt += input_ids.shape[0] 169 | 170 | precisions, recalls, fscores, supports = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average=None) 171 | 172 | precision, recall, f1, _ = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average='micro') 173 | print('precisions:\n', ["{:.3f}".format(prec) for prec in precisions]) 174 | print('recalls:\n', ["{:.3f}".format(rec) for rec in recalls]) 175 | print('fscores:\n', ["{:.3f}".format(f1) for f1 in fscores]) 176 | print('supports:\n', supports) 177 | print('micro P, micro R, micro F1', "{:.3f}".format(precision), "{:.3f}".format(recall), "{:.3f}".format(f1)) 178 | 179 | #pdb.set_trace() 180 | #print(mrr_total/sample_cnt) 181 | 182 | 183 | 184 | def main(): 185 | 186 | parser = argparse.ArgumentParser() 187 | 188 | parser.add_argument('--max_token_len', type=int, default=512) 189 | parser.add_argument('--batch_size', type=int, default=12) 190 | parser.add_argument('--num_workers', type=int, default=5) 191 | 192 | parser.add_argument('--distance_norm_factor', type=float, default = 100) 193 | parser.add_argument('--spatial_dist_fill', type=float, default = 90000) 194 | parser.add_argument('--num_classes', type=int, default = 9) 195 | 196 | parser.add_argument('--with_type', default=False, action='store_true') 197 | parser.add_argument('--sep_between_neighbors', default=False, action='store_true') 198 | parser.add_argument('--no_spatial_distance', default=False, action='store_true') 199 | 200 | parser.add_argument('--bert_option', type=str, default='bert-base') 201 | parser.add_argument('--prediction_save_dir', type=str, default=None) 202 | 203 | parser.add_argument('--checkpoint_path', type=str, default=None) 204 | 205 | 206 | args = parser.parse_args() 207 | print('\n') 208 | print(args) 209 | print('\n') 210 | 211 | 212 | # out_dir not None, and out_dir does not exist, then create out_dir 213 | if args.prediction_save_dir is not None and not os.path.isdir(args.prediction_save_dir): 214 | os.makedirs(args.prediction_save_dir) 215 | 216 | testing(args) 217 | 218 | if __name__ == '__main__': 219 | 220 | main() 221 | 222 | 
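Note: the evaluation loop in test_cls_joint.py above reduces to two scikit-learn calls over the accumulated one-hot labels and logits. A minimal, self-contained sketch with synthetic arrays (all values and shapes below are illustrative only and are not taken from the repository's data):

    import numpy as np
    from sklearn.metrics import label_ranking_average_precision_score, precision_recall_fscore_support

    gt = np.eye(3)[[0, 2, 1, 1]]                 # one-hot ground truth: 4 samples, 3 classes
    logits = np.array([[2.0, 0.1, 0.3],          # per-class scores from a typing head (synthetic)
                       [0.2, 0.4, 1.5],
                       [0.1, 1.2, 0.9],
                       [0.8, 0.3, 0.2]])

    ranking_score = label_ranking_average_precision_score(gt, logits)      # ranking quality of the logits
    precision, recall, f1, _ = precision_recall_fscore_support(
        gt.argmax(axis=1), logits.argmax(axis=1), average='micro')         # micro P/R/F1 on argmax predictions
    print("{:.3f} {:.3f} {:.3f} {:.3f}".format(ranking_score, precision, recall, f1))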
-------------------------------------------------------------------------------- /src/dataset_utils/osm_sample_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import json 5 | import math 6 | 7 | import torch 8 | from transformers import RobertaTokenizer, BertTokenizer 9 | from torch.utils.data import Dataset 10 | sys.path.append('/home/zekun/joint_model/src/dataset_utils') 11 | from dataset_loader import SpatialDataset 12 | 13 | import pdb 14 | np.random.seed(2333) 15 | 16 | class PbfMapDataset(SpatialDataset): 17 | def __init__(self, data_file_path, tokenizer=None, max_token_len = 512, distance_norm_factor = 0.0001, spatial_dist_fill=10, 18 | with_type = True, sep_between_neighbors = False, label_encoder = None, mode = None, num_neighbor_limit = None, random_remove_neighbor = 0.,type_key_str='class'): 19 | 20 | if tokenizer is None: 21 | self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') 22 | else: 23 | self.tokenizer = tokenizer 24 | 25 | self.max_token_len = max_token_len 26 | self.spatial_dist_fill = spatial_dist_fill # should be normalized distance fill, larger than all normalized neighbor distance 27 | self.with_type = with_type 28 | self.sep_between_neighbors = sep_between_neighbors 29 | self.label_encoder = label_encoder 30 | self.num_neighbor_limit = num_neighbor_limit 31 | self.read_file(data_file_path, mode) 32 | self.random_remove_neighbor = random_remove_neighbor 33 | self.type_key_str = type_key_str # key name of the class type in the input data dictionary 34 | 35 | super(PbfMapDataset, self).__init__(tokenizer , max_token_len , distance_norm_factor, sep_between_neighbors ) 36 | 37 | 38 | def read_file(self, data_file_path, mode): 39 | 40 | with open(data_file_path, 'r') as f: 41 | data = f.readlines() 42 | 43 | if mode == 'train': 44 | data = data[0:int(len(data) * 0.8)] 45 | elif mode == 'test': 46 | data = data[int(len(data) * 0.8):] 47 | elif mode is None: # use the full dataset (for mlm) 48 | pass 49 | else: 50 | raise NotImplementedError 51 | 52 | self.len_data = len(data) # updated data length 53 | self.data = data 54 | 55 | def load_data(self, index): 56 | 57 | spatial_dist_fill = self.spatial_dist_fill 58 | line = self.data[index] # take one line from the input data according to the index 59 | 60 | line_data_dict = json.loads(line) 61 | 62 | # process pivot 63 | pivot_name = line_data_dict['info']['name'] 64 | pivot_pos = line_data_dict['info']['geometry']['coordinates'] 65 | 66 | 67 | neighbor_info = line_data_dict['neighbor_info'] 68 | neighbor_name_list = neighbor_info['name_list'] 69 | neighbor_geometry_list = neighbor_info['geometry_list'] 70 | 71 | if self.random_remove_neighbor != 0: 72 | num_neighbors = len(neighbor_name_list) 73 | rand_neighbor = np.random.uniform(size = num_neighbors) 74 | 75 | neighbor_keep_arr = (rand_neighbor >= self.random_remove_neighbor) # select the neighbors to be removed 76 | neighbor_keep_arr = np.where(neighbor_keep_arr)[0] 77 | 78 | new_neighbor_name_list, new_neighbor_geometry_list = [],[] 79 | for i in range(0, num_neighbors): 80 | if i in neighbor_keep_arr: 81 | new_neighbor_name_list.append(neighbor_name_list[i]) 82 | new_neighbor_geometry_list.append(neighbor_geometry_list[i]) 83 | 84 | neighbor_name_list = new_neighbor_name_list 85 | neighbor_geometry_list = new_neighbor_geometry_list 86 | 87 | if self.num_neighbor_limit is not None: 88 | neighbor_name_list = neighbor_name_list[0:self.num_neighbor_limit] 
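            # Descriptive note: num_neighbor_limit optionally caps the spatial context length;
            # the name and geometry lists are truncated together so they stay index-aligned.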
89 | neighbor_geometry_list = neighbor_geometry_list[0:self.num_neighbor_limit] 90 | 91 | 92 | train_data = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill ) 93 | 94 | if self.with_type: 95 | pivot_type = line_data_dict['info'][self.type_key_str] 96 | train_data['pivot_type'] = torch.tensor(self.label_encoder.transform([pivot_type])[0]) # scalar, label_id 97 | 98 | # if 'ogc_fid' in line_data_dict['info']: 99 | # train_data['ogc_fid'] = line_data_dict['info']['ogc_fid'] 100 | 101 | return train_data 102 | 103 | def __len__(self): 104 | return self.len_data 105 | 106 | def __getitem__(self, index): 107 | return self.load_data(index) 108 | 109 | 110 | 111 | class PbfMapDatasetMarginRanking(SpatialDataset): 112 | def __init__(self, data_file_path, type_list = None, tokenizer=None, max_token_len = 512, distance_norm_factor = 0.0001, spatial_dist_fill=10, 113 | sep_between_neighbors = False, mode = None, num_neighbor_limit = None, random_remove_neighbor = 0., type_key_str='class'): 114 | 115 | if tokenizer is None: 116 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 117 | else: 118 | self.tokenizer = tokenizer 119 | 120 | self.type_list = type_list 121 | self.type_key_str = type_key_str # key name of the class type in the input data dictionary 122 | self.max_token_len = max_token_len 123 | self.spatial_dist_fill = spatial_dist_fill # should be normalized distance fill, larger than all normalized neighbor distance 124 | self.sep_between_neighbors = sep_between_neighbors 125 | # self.label_encoder = label_encoder 126 | self.num_neighbor_limit = num_neighbor_limit 127 | self.read_file(data_file_path, mode) 128 | self.random_remove_neighbor = random_remove_neighbor 129 | self.mode = mode 130 | 131 | 132 | super(PbfMapDatasetMarginRanking, self).__init__(self.tokenizer , max_token_len , distance_norm_factor, sep_between_neighbors ) 133 | 134 | 135 | def read_file(self, data_file_path, mode): 136 | 137 | with open(data_file_path, 'r') as f: 138 | data = f.readlines() 139 | 140 | if mode == 'train': 141 | data = data[0:int(len(data) * 0.8)] 142 | elif mode == 'test': 143 | data = data[int(len(data) * 0.8):] 144 | self.all_types_data = self.prepare_all_types_data() 145 | elif mode is None: # use the full dataset (for mlm) 146 | self.all_types_data = self.prepare_all_types_data() 147 | pass 148 | else: 149 | raise NotImplementedError 150 | 151 | self.len_data = len(data) # updated data length 152 | self.data = data 153 | 154 | def prepare_all_types_data(self): 155 | type_list = self.type_list 156 | spatial_dist_fill = self.spatial_dist_fill 157 | type_data_dict = dict() 158 | for type_name in type_list: 159 | type_pos = [None, None] # use filler values 160 | type_data = self.parse_spatial_context(type_name, type_pos, pivot_dist_fill = 0., 161 | neighbor_name_list = [], neighbor_geometry_list=[], spatial_dist_fill= spatial_dist_fill) 162 | type_data_dict[type_name] = type_data 163 | 164 | return type_data_dict 165 | 166 | def load_data(self, index): 167 | 168 | spatial_dist_fill = self.spatial_dist_fill 169 | line = self.data[index] # take one line from the input data according to the index 170 | 171 | line_data_dict = json.loads(line) 172 | 173 | # process pivot 174 | pivot_name = line_data_dict['info']['name'] 175 | pivot_pos = line_data_dict['info']['geometry']['coordinates'] 176 | 177 | 178 | neighbor_info = line_data_dict['neighbor_info'] 179 | neighbor_name_list = neighbor_info['name_list'] 180 | 
neighbor_geometry_list = neighbor_info['geometry_list'] 181 | 182 | if self.random_remove_neighbor != 0: 183 | num_neighbors = len(neighbor_name_list) 184 | rand_neighbor = np.random.uniform(size = num_neighbors) 185 | 186 | neighbor_keep_arr = (rand_neighbor >= self.random_remove_neighbor) # select the neighbors to be removed 187 | neighbor_keep_arr = np.where(neighbor_keep_arr)[0] 188 | 189 | new_neighbor_name_list, new_neighbor_geometry_list = [],[] 190 | for i in range(0, num_neighbors): 191 | if i in neighbor_keep_arr: 192 | new_neighbor_name_list.append(neighbor_name_list[i]) 193 | new_neighbor_geometry_list.append(neighbor_geometry_list[i]) 194 | 195 | neighbor_name_list = new_neighbor_name_list 196 | neighbor_geometry_list = new_neighbor_geometry_list 197 | 198 | if self.num_neighbor_limit is not None: 199 | neighbor_name_list = neighbor_name_list[0:self.num_neighbor_limit] 200 | neighbor_geometry_list = neighbor_geometry_list[0:self.num_neighbor_limit] 201 | 202 | 203 | train_data = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill ) 204 | 205 | if 'ogc_fid' in line_data_dict['info']: 206 | train_data['ogc_fid'] = line_data_dict['info']['ogc_fid'] 207 | 208 | # train_data['pivot_type'] = torch.tensor(self.label_encoder.transform([pivot_type])[0]) # scalar, label_id 209 | 210 | pivot_type = line_data_dict['info'][self.type_key_str] 211 | train_data['pivot_type'] = pivot_type 212 | 213 | if self.mode == 'train' : 214 | # postive class 215 | postive_name = pivot_type # class type string as input to tokenizer 216 | positive_pos = [None, None] # use filler values 217 | postive_type_data = self.parse_spatial_context(postive_name, positive_pos, pivot_dist_fill = 0., 218 | neighbor_name_list = [], neighbor_geometry_list=[], spatial_dist_fill= spatial_dist_fill) 219 | train_data['positive_type_data'] = postive_type_data 220 | 221 | 222 | # negative class 223 | other_type_list = self.type_list.copy() 224 | other_type_list.remove(pivot_type) 225 | other_type = np.random.choice(other_type_list) 226 | negative_name = other_type 227 | negative_pos = [None, None] # use filler values 228 | negative_type_data = self.parse_spatial_context(negative_name, negative_pos, pivot_dist_fill = 0., 229 | neighbor_name_list = [], neighbor_geometry_list=[], spatial_dist_fill= spatial_dist_fill) 230 | train_data['negative_type_data'] = negative_type_data 231 | 232 | elif self.mode == 'test' or self.mode == None: 233 | # return data for all class types in type_list 234 | train_data['all_types_data'] = self.all_types_data 235 | 236 | else: 237 | raise NotImplementedError 238 | 239 | return train_data 240 | 241 | def __len__(self): 242 | return self.len_data 243 | 244 | def __getitem__(self, index): 245 | return self.load_data(index) 246 | 247 | 248 | -------------------------------------------------------------------------------- /experiments/typing/train_cls_joint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from transformers import RobertaTokenizer, BertTokenizer 4 | from tqdm import tqdm # for our progress bar 5 | from transformers import AdamW 6 | from transformers import BertTokenizerFast 7 | 8 | import torch 9 | from torch.utils.data import DataLoader 10 | 11 | # sys.path.append('../../../') 12 | sys.path.append('/home/zekun/joint_model/src') 13 | from models.spatial_bert_model import SpatialBertModel 14 | from models.spatial_bert_model import SpatialBertConfig 15 | from 
models.spatial_bert_model import SpatialBertForMaskedLM, SpatialBertForSemanticTyping 16 | from datasets.osm_sample_loader import PbfMapDataset 17 | from datasets.const import * 18 | #from utils.common_utils import load_spatial_bert_pretrained_weights 19 | 20 | from transformers.models.bert.modeling_bert import BertForMaskedLM 21 | 22 | import numpy as np 23 | import argparse 24 | from sklearn.preprocessing import LabelEncoder 25 | import pdb 26 | 27 | 28 | DEBUG = False 29 | 30 | 31 | def training(args): 32 | 33 | num_workers = args.num_workers 34 | batch_size = args.batch_size 35 | epochs = args.epochs 36 | lr = args.lr #1e-7 # 5e-5 37 | save_interval = args.save_interval 38 | max_token_len = args.max_token_len 39 | distance_norm_factor = args.distance_norm_factor 40 | spatial_dist_fill=args.spatial_dist_fill 41 | with_type = args.with_type 42 | sep_between_neighbors = args.sep_between_neighbors 43 | freeze_backbone = args.freeze_backbone 44 | mlm_checkpoint_path = args.mlm_checkpoint_path 45 | 46 | if_no_spatial_distance = args.no_spatial_distance 47 | 48 | 49 | bert_option = args.bert_option 50 | 51 | assert bert_option in ['bert-base','bert-large'] 52 | 53 | if args.num_classes == 9: 54 | # london_file_path = '/home/zekun/spatial_bert/spatial_bert/experiments/semantic_typing/data/sql_output/osm-point-london-typing.json' 55 | # california_file_path = '/home/zekun/spatial_bert/spatial_bert/experiments/semantic_typing/data/sql_output/osm-point-california-typing.json' 56 | london_file_path = '/home/zekun/datasets/semantic_typing/data/sql_output/osm-point-london-typing.json' 57 | california_file_path = '/home/zekun/datasets/semantic_typing/data/sql_output/osm-point-california-typing.json' 58 | TYPE_LIST = CLASS_9_LIST 59 | type_key_str = 'class' 60 | elif args.num_classes == 74: 61 | london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing-ranking.json' 62 | california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing-ranking.json' 63 | TYPE_LIST = CLASS_74_LIST 64 | type_key_str = 'fine_class' 65 | else: 66 | raise NotImplementedError 67 | 68 | 69 | if args.model_save_dir is None: 70 | checkpoint_basename = os.path.basename(mlm_checkpoint_path) 71 | checkpoint_prefix = checkpoint_basename.replace("mlm_mem_keeppos_","").strip('.pth') 72 | 73 | sep_pathstr = '_sep' if sep_between_neighbors else '_nosep' 74 | freeze_pathstr = '_freeze' if freeze_backbone else '_nofreeze' 75 | if if_no_spatial_distance: 76 | model_save_dir = '/data3/zekun/spatial_bert_weights_ablation/' 77 | else: 78 | model_save_dir = '/data3/zekun/spatial_bert_weights/' 79 | model_save_dir = os.path.join(model_save_dir, 'typing_lr' + str("{:.0e}".format(lr)) + sep_pathstr +'_'+bert_option+ freeze_pathstr + '_london_california_bsize' + str(batch_size) ) 80 | model_save_dir = os.path.join(model_save_dir, checkpoint_prefix) 81 | 82 | if not os.path.isdir(model_save_dir): 83 | os.makedirs(model_save_dir) 84 | else: 85 | model_save_dir = args.model_save_dir 86 | 87 | 88 | print('model_save_dir', model_save_dir) 89 | print('\n') 90 | 91 | if bert_option == 'bert-base': 92 | # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 93 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased") 94 | config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, num_semantic_types=len(TYPE_LIST)) 95 | elif bert_option == 'bert-large': 96 | # tokenizer = BertTokenizer.from_pretrained("bert-large-uncased") 97 | tokenizer = 
BertTokenizerFast.from_pretrained("bert-base-cased") 98 | config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24, num_semantic_types=len(TYPE_LIST)) 99 | else: 100 | raise NotImplementedError 101 | 102 | 103 | 104 | label_encoder = LabelEncoder() 105 | label_encoder.fit(TYPE_LIST) 106 | 107 | 108 | london_train_val_dataset = PbfMapDataset(data_file_path = london_file_path, 109 | tokenizer = tokenizer, 110 | max_token_len = max_token_len, 111 | distance_norm_factor = distance_norm_factor, 112 | spatial_dist_fill = spatial_dist_fill, 113 | with_type = with_type, 114 | type_key_str = type_key_str, 115 | sep_between_neighbors = sep_between_neighbors, 116 | label_encoder = label_encoder, 117 | mode = 'train') 118 | 119 | percent_80 = int(len(london_train_val_dataset) * 0.8) 120 | london_train_dataset, london_val_dataset = torch.utils.data.random_split(london_train_val_dataset, [percent_80, len(london_train_val_dataset) - percent_80]) 121 | 122 | california_train_val_dataset = PbfMapDataset(data_file_path = california_file_path, 123 | tokenizer = tokenizer, 124 | max_token_len = max_token_len, 125 | distance_norm_factor = distance_norm_factor, 126 | spatial_dist_fill = spatial_dist_fill, 127 | with_type = with_type, 128 | type_key_str = type_key_str, 129 | sep_between_neighbors = sep_between_neighbors, 130 | label_encoder = label_encoder, 131 | mode = 'train') 132 | percent_80 = int(len(california_train_val_dataset) * 0.8) 133 | california_train_dataset, california_val_dataset = torch.utils.data.random_split(california_train_val_dataset, [percent_80, len(california_train_val_dataset) - percent_80]) 134 | 135 | train_dataset = torch.utils.data.ConcatDataset([london_train_dataset, california_train_dataset]) 136 | val_dataset = torch.utils.data.ConcatDataset([london_val_dataset, california_val_dataset]) 137 | 138 | 139 | if DEBUG: 140 | train_loader = DataLoader(train_dataset, batch_size= batch_size, num_workers=num_workers, 141 | shuffle=False, pin_memory=True, drop_last=True) 142 | val_loader = DataLoader(val_dataset, batch_size= batch_size, num_workers=num_workers, 143 | shuffle=False, pin_memory=True, drop_last=False) 144 | else: 145 | train_loader = DataLoader(train_dataset, batch_size= batch_size, num_workers=num_workers, 146 | shuffle=True, pin_memory=True, drop_last=True) 147 | val_loader = DataLoader(val_dataset, batch_size= batch_size, num_workers=num_workers, 148 | shuffle=False, pin_memory=True, drop_last=False) 149 | 150 | 151 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 152 | 153 | config.vocab_size = 28996 154 | model = SpatialBertForSemanticTyping(config) 155 | model.to(device) 156 | 157 | #model = load_spatial_bert_pretrained_weights(model, mlm_checkpoint_path) 158 | 159 | # model.load_state_dict(torch.load(mlm_checkpoint_path), strict = False) # # load sentence position embedding weights as well 160 | model.load_state_dict(torch.load(mlm_checkpoint_path, map_location=torch.device(device))['model'], strict = False) 161 | 162 | model.train() 163 | 164 | 165 | 166 | # initialize optimizer 167 | optim = AdamW(model.parameters(), lr = lr) 168 | 169 | print('start training...') 170 | 171 | for epoch in range(epochs): 172 | # setup loop with TQDM and dataloader 173 | loop = tqdm(train_loader, leave=True) 174 | iter = 0 175 | for batch in loop: 176 | # initialize calculated gradients (from prev step) 177 | optim.zero_grad() 
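            # Descriptive note: each batch here is a geo "pseudo sentence" (pivot entity plus spatial
            # neighbors); token_type_ids is set to all ones further down because segment 1 marks geo
            # context, while segment 0 is used for natural-language text in the paired sample loader.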
178 | # pull all tensor batches required for training 179 | input_ids = batch['pseudo_sentence'].to(device) 180 | attention_mask = batch['attention_mask'].to(device) 181 | position_list_x = batch['norm_lng_list'].to(device) 182 | position_list_y = batch['norm_lat_list'].to(device) 183 | sent_position_ids = batch['sent_position_ids'].to(device) 184 | 185 | #labels = batch['pseudo_sentence'].to(device) 186 | labels = batch['pivot_type'].to(device) 187 | entity_token_idx = batch['pivot_token_idx'].to(device) 188 | # pivot_lens = batch['pivot_token_len'].to(device) 189 | token_type_ids = torch.ones(input_ids.shape[0],max_token_len).int().to(device) 190 | 191 | # pdb.set_trace() 192 | 193 | outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids, 194 | position_list_x = position_list_x, position_list_y = position_list_y, labels = labels, token_type_ids = token_type_ids, 195 | pivot_token_idx_list=entity_token_idx) 196 | 197 | 198 | loss = outputs.loss 199 | loss.backward() 200 | optim.step() 201 | 202 | loop.set_description(f'Epoch {epoch}') 203 | loop.set_postfix({'loss':loss.item()}) 204 | 205 | if DEBUG: 206 | print('ep'+str(epoch)+'_' + '_iter'+ str(iter).zfill(5), loss.item() ) 207 | 208 | iter += 1 209 | 210 | if iter % save_interval == 0 or iter == loop.total: 211 | loss_valid = validating(val_loader, model, device) 212 | 213 | save_path = os.path.join(model_save_dir, 'keeppos_ep'+str(epoch) + '_iter'+ str(iter).zfill(5) \ 214 | + '_' +str("{:.4f}".format(loss.item())) + '_val' + str("{:.4f}".format(loss_valid)) +'.pth' ) 215 | 216 | torch.save(model.state_dict(), save_path) 217 | print('validation loss', loss_valid) 218 | print('saving model checkpoint to', save_path) 219 | 220 | def validating(val_loader, model, device): 221 | 222 | with torch.no_grad(): 223 | 224 | loss_valid = 0 225 | loop = tqdm(val_loader, leave=True) 226 | 227 | for batch in loop: 228 | input_ids = batch['pseudo_sentence'].to(device) 229 | attention_mask = batch['attention_mask'].to(device) 230 | position_list_x = batch['norm_lng_list'].to(device) 231 | position_list_y = batch['norm_lat_list'].to(device) 232 | sent_position_ids = batch['sent_position_ids'].to(device) 233 | 234 | 235 | labels = batch['pivot_type'].to(device) 236 | entity_token_idx = batch['pivot_token_idx'].to(device) 237 | # pivot_lens = batch['pivot_token_len'].to(device) 238 | 239 | token_type_ids = torch.ones(input_ids.shape[0],512).int().to(device) 240 | 241 | outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids, 242 | position_list_x = position_list_x, position_list_y = position_list_y, labels = labels, token_type_ids = token_type_ids, 243 | pivot_token_idx_list=entity_token_idx) 244 | 245 | loss_valid += outputs.loss 246 | 247 | loss_valid /= len(val_loader) 248 | 249 | return loss_valid 250 | 251 | 252 | def main(): 253 | 254 | parser = argparse.ArgumentParser() 255 | parser.add_argument('--num_workers', type=int, default=5) 256 | parser.add_argument('--batch_size', type=int, default=12) 257 | parser.add_argument('--epochs', type=int, default=20) 258 | parser.add_argument('--save_interval', type=int, default=2000) 259 | parser.add_argument('--max_token_len', type=int, default=512) 260 | 261 | 262 | parser.add_argument('--lr', type=float, default = 5e-5) 263 | parser.add_argument('--distance_norm_factor', type=float, default = 100) 264 | parser.add_argument('--spatial_dist_fill', type=float, default = 90000) 265 | parser.add_argument('--num_classes', 
type=int, default = 9) 266 | 267 | parser.add_argument('--with_type', default=False, action='store_true') 268 | parser.add_argument('--sep_between_neighbors', default=False, action='store_true') 269 | parser.add_argument('--freeze_backbone', default=False, action='store_true') 270 | parser.add_argument('--no_spatial_distance', default=False, action='store_true') 271 | 272 | parser.add_argument('--bert_option', type=str, default='bert-base') 273 | parser.add_argument('--model_save_dir', type=str, default=None) 274 | 275 | parser.add_argument('--mlm_checkpoint_path', type=str, default=None) 276 | 277 | 278 | args = parser.parse_args() 279 | print('\n') 280 | print(args) 281 | print('\n') 282 | 283 | 284 | # out_dir not None, and out_dir does not exist, then create out_dir 285 | if args.model_save_dir is not None and not os.path.isdir(args.model_save_dir): 286 | os.makedirs(args.model_save_dir) 287 | 288 | training(args) 289 | 290 | 291 | if __name__ == '__main__': 292 | 293 | main() 294 | 295 | -------------------------------------------------------------------------------- /src/dataset_utils/paired_sample_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import json 5 | import math 6 | 7 | import torch 8 | from transformers import RobertaTokenizer, BertTokenizer, BertTokenizerFast 9 | from torch.utils.data import Dataset 10 | sys.path.append('/home/zekun/joint_model/src/datasets') 11 | from dataset_loader import SpatialDataset 12 | 13 | import pdb 14 | np.random.seed(2333) 15 | 16 | class JointDataset(SpatialDataset): 17 | def __init__(self, geo_file_path, nl_file_path, placename_to_osmid_path, tokenizer=None, max_token_len = 512, distance_norm_factor = 0.0001, spatial_dist_fill=10, 18 | sep_between_neighbors = False, label_encoder = None, if_rand_seq=False, type_key_str='class'): 19 | 20 | if tokenizer is None: 21 | self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased') 22 | else: 23 | self.tokenizer = tokenizer 24 | 25 | self.max_token_len = max_token_len 26 | self.spatial_dist_fill = spatial_dist_fill # should be normalized distance fill, larger than all normalized neighbor distance 27 | self.sep_between_neighbors = sep_between_neighbors 28 | self.label_encoder = label_encoder 29 | self.read_placename2osm_dict(placename_to_osmid_path) # to prepare hard negative samples 30 | self.read_geo_file(geo_file_path) 31 | self.read_nl_file(nl_file_path) 32 | self.type_key_str = type_key_str # key name of the class type in the input data dictionary 33 | self.if_rand_seq = if_rand_seq 34 | 35 | super(JointDataset, self).__init__(self.tokenizer , max_token_len , distance_norm_factor, sep_between_neighbors ) 36 | 37 | self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.sep_token) 38 | self.cls_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.cls_token) 39 | self.mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) 40 | self.pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token) 41 | 42 | def read_placename2osm_dict(self, placename_to_osmid_path): 43 | with open(placename_to_osmid_path, 'r') as f: 44 | placename2osm_dict = json.load(f) 45 | self.placename2osm_dict = placename2osm_dict 46 | 47 | def read_geo_file(self, geo_file_path): 48 | 49 | with open(geo_file_path, 'r') as f: 50 | data = f.readlines() 51 | 52 | self.len_geo_data = len(data) # updated data length 53 | self.geo_data = data 54 | 55 | def 
read_nl_file(self, nl_file_path): 56 | with open(nl_file_path, 'r') as f: 57 | nl_data = json.load(f) 58 | 59 | self.nl_data = nl_data 60 | 61 | def prepare_nl_data(self, pivot_osm_id): 62 | 63 | nl_sample_dict = self.nl_data[pivot_osm_id] 64 | sentences = nl_sample_dict['sentence'] 65 | subject_index_list = nl_sample_dict['subject_index_list'] 66 | 67 | sample_idx = np.random.randint(len(sentences)) 68 | sent = sentences[sample_idx] 69 | subject_schar, subject_tchar = subject_index_list[sample_idx] # start and end index in character 70 | 71 | nl_tokens = self.tokenizer(sent, padding="max_length", max_length=self.max_token_len, truncation = True, return_offsets_mapping = True) 72 | pseudo_sentence = nl_tokens['input_ids'] 73 | 74 | rand = np.random.uniform(size = self.max_token_len) 75 | 76 | mlm_mask_arr = (rand <0.15) & (np.array(pseudo_sentence) != self.cls_token_id) & (np.array(pseudo_sentence) != self.sep_token_id) & (np.array(pseudo_sentence) != self.pad_token_id) 77 | 78 | token_mask_indices = np.where(mlm_mask_arr)[0] 79 | 80 | masked_token_input = [self.mask_token_id if i in token_mask_indices else pseudo_sentence[i] for i in range(0, self.max_token_len)] 81 | 82 | 83 | offset_mapping = nl_tokens['offset_mapping'][1:-1] 84 | flat_offset_mapping = np.array(offset_mapping).flatten() 85 | offset_mapping_dict_start = {} 86 | offset_mapping_dict_end = {} 87 | for idx in range(0,len(flat_offset_mapping),2): 88 | char_pos = flat_offset_mapping[idx] 89 | if char_pos == 0 and idx != 0: 90 | break 91 | token_pos = idx//2 + 1 92 | offset_mapping_dict_start[char_pos] = token_pos 93 | for idx in range(1,len(flat_offset_mapping),2): 94 | char_pos = flat_offset_mapping[idx] 95 | if char_pos == 0 and idx != 0: 96 | break 97 | token_pos = (idx-1)//2 + 1 +1 98 | offset_mapping_dict_end[char_pos] = token_pos 99 | 100 | if subject_schar not in offset_mapping_dict_start or subject_tchar not in offset_mapping_dict_end: 101 | print(pivot_osm_id, sample_idx) 102 | return self.prepare_nl_data(pivot_osm_id) # a work-around, TODO: fix this 103 | 104 | 105 | if offset_mapping_dict_start[subject_schar] == offset_mapping_dict_end[subject_tchar]: 106 | print('\n') 107 | print(offset_mapping_dict_start, offset_mapping_dict_end) 108 | print(subject_schar, subject_tchar) 109 | print(sent) 110 | print(pseudo_sentence) 111 | print(self.tokenizer.convert_ids_to_tokens(pseudo_sentence)) 112 | print('\n') 113 | # token end index is exclusive 114 | token_start_idx, token_end_idx = offset_mapping_dict_start[subject_schar],offset_mapping_dict_end[subject_tchar] 115 | assert token_start_idx < token_end_idx # can not be equal 116 | 117 | 118 | train_data = {} 119 | train_data['masked_input'] = torch.tensor(masked_token_input) 120 | train_data['pivot_token_idx'] = torch.tensor([token_start_idx, token_end_idx]) 121 | 122 | train_data['pseudo_sentence'] = torch.tensor(pseudo_sentence) 123 | train_data['sent_len'] = torch.tensor(np.sum(np.array(nl_tokens['attention_mask']) == 1)) # pseudo sentence length including CLS and SEP token 124 | train_data['attention_mask'] = torch.tensor(nl_tokens['attention_mask']) 125 | train_data['sent_position_ids'] = torch.tensor(np.arange(0, len(pseudo_sentence))) 126 | # train_data['norm_lng_list'] = torch.tensor([self.spatial_dist_fill for i in range(len(pseudo_sentence))]).to(torch.float32) 127 | # train_data['norm_lat_list'] = torch.tensor([self.spatial_dist_fill for i in range(len(pseudo_sentence))]).to(torch.float32) 128 | train_data['norm_lng_list'] = torch.tensor([0 for i in 
range(len(pseudo_sentence))]).to(torch.float32) 129 | train_data['norm_lat_list'] = torch.tensor([0 for i in range(len(pseudo_sentence))]).to(torch.float32) 130 | train_data['token_type_ids'] = torch.zeros(len(pseudo_sentence)).int() # 0 for nl data 131 | 132 | return train_data 133 | 134 | def load_data(self, geo_line_data_dict): 135 | 136 | # process pivot 137 | pivot_name = geo_line_data_dict['info']['name'] 138 | pivot_pos = geo_line_data_dict['info']['geometry']['coordinates'] 139 | pivot_osm_id = geo_line_data_dict['info']['osm_id'] 140 | 141 | neighbor_info = geo_line_data_dict['neighbor_info'] 142 | neighbor_name_list = neighbor_info['name_list'] 143 | neighbor_geometry_list = neighbor_info['geometry_list'] 144 | # print(neighbor_geometry_list) 145 | 146 | train_data = {} 147 | train_data['geo_data'] = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, self.spatial_dist_fill ) 148 | train_data['geo_data']['token_type_ids'] = torch.ones( len(train_data['geo_data']['pseudo_sentence'])).int() # type 1 for geo data 149 | train_data['geo_data']['sent_len'] = torch.sum((train_data['geo_data']['attention_mask']) == 1) # pseudo sentence length including CLS and SEP token 150 | train_data['nl_data'] = self.prepare_nl_data(pivot_osm_id) 151 | 152 | train_data['concat_data'] = {} 153 | nl_data_len = train_data['nl_data']['sent_len'] 154 | geo_data_len = train_data['geo_data']['sent_len'] 155 | 156 | if nl_data_len + geo_data_len <= self.max_token_len: 157 | # if the total length is smaller than max_token_len, take the full sentence (remove [CLS] before geo sentence) 158 | # print (train_data['nl_data']['masked_input'][:nl_data_len].shape, train_data['geo_data']['masked_input'][1 : self.max_token_len - nl_data_len + 1].shape) 159 | train_data['concat_data']['masked_input'] = torch.cat((train_data['nl_data']['masked_input'][:nl_data_len] ,train_data['geo_data']['masked_input'][1 : self.max_token_len - nl_data_len + 1])) 160 | train_data['concat_data']['attention_mask'] = torch.cat((train_data['nl_data']['attention_mask'][:nl_data_len] ,train_data['geo_data']['attention_mask'][1 : self.max_token_len - nl_data_len + 1])) 161 | train_data['concat_data']['sent_position_ids'] = torch.cat((train_data['nl_data']['sent_position_ids'][:nl_data_len] ,train_data['geo_data']['sent_position_ids'][1 : self.max_token_len - nl_data_len + 1])) 162 | train_data['concat_data']['pseudo_sentence'] = torch.cat((train_data['nl_data']['pseudo_sentence'][:nl_data_len] ,train_data['geo_data']['pseudo_sentence'][1 : self.max_token_len - nl_data_len + 1])) 163 | train_data['concat_data']['token_type_ids'] = torch.cat((train_data['nl_data']['token_type_ids'][:nl_data_len] ,train_data['geo_data']['token_type_ids'][1 : self.max_token_len - nl_data_len + 1])) 164 | 165 | else: 166 | # otherwise, 167 | if nl_data_len <= self.max_token_len / 2 : 168 | # if the nl_data_len is <= 0.5 * max_token_len, then truncate geodata 169 | # concat geo data , remove [CLS] from geo_data 170 | # SEP alrady added at the end of nl sentence after tokenization 171 | # print(train_data['geo_data']['masked_input'][1 : self.max_token_len - nl_data_len ].shape, torch.tensor([self.sep_token_id]).shape) 172 | train_data['concat_data']['masked_input'] = torch.cat((train_data['nl_data']['masked_input'][:nl_data_len] , train_data['geo_data']['masked_input'][1 : self.max_token_len - nl_data_len ] , torch.tensor([self.sep_token_id]))) 173 | train_data['concat_data']['attention_mask'] = 
torch.cat((train_data['nl_data']['attention_mask'][:nl_data_len] , train_data['geo_data']['attention_mask'][1 : self.max_token_len - nl_data_len ] , torch.tensor([0]))) 174 | train_data['concat_data']['sent_position_ids'] = torch.cat((train_data['nl_data']['sent_position_ids'][:nl_data_len] , train_data['geo_data']['sent_position_ids'][1 : self.max_token_len - nl_data_len ] , torch.tensor([self.max_token_len - nl_data_len]))) 175 | train_data['concat_data']['pseudo_sentence'] = torch.cat((train_data['nl_data']['pseudo_sentence'][:nl_data_len] , train_data['geo_data']['pseudo_sentence'][1 : self.max_token_len - nl_data_len ] , torch.tensor([self.sep_token_id]))) 176 | train_data['concat_data']['token_type_ids'] = torch.cat((train_data['nl_data']['token_type_ids'][:nl_data_len] , train_data['geo_data']['token_type_ids'][1 : self.max_token_len - nl_data_len ] , torch.tensor([1]))) 177 | 178 | else: 179 | 180 | train_data['concat_data']['masked_input'] = torch.cat((train_data['nl_data']['masked_input'][:self.max_token_len // 2 - 1], torch.tensor([self.sep_token_id]) , train_data['geo_data']['masked_input'][1 : self.max_token_len//2 ], torch.tensor([self.sep_token_id]))) 181 | train_data['concat_data']['attention_mask'] = torch.cat((train_data['nl_data']['attention_mask'][:self.max_token_len // 2 - 1], torch.tensor([1]) , train_data['geo_data']['attention_mask'][1 : self.max_token_len//2 ], torch.tensor([0]))) 182 | train_data['concat_data']['sent_position_ids'] = torch.cat((train_data['nl_data']['sent_position_ids'][:self.max_token_len // 2 - 1], torch.tensor([self.max_token_len // 2 - 1]) , train_data['geo_data']['sent_position_ids'][1 : self.max_token_len//2 ], torch.tensor([self.max_token_len//2]))) 183 | train_data['concat_data']['pseudo_sentence'] = torch.cat((train_data['nl_data']['pseudo_sentence'][:self.max_token_len // 2 - 1], torch.tensor([self.sep_token_id]) , train_data['geo_data']['pseudo_sentence'][1 : self.max_token_len//2 ], torch.tensor([self.sep_token_id]))) 184 | train_data['concat_data']['token_type_ids'] = torch.cat((train_data['nl_data']['token_type_ids'][:self.max_token_len // 2 - 1], torch.tensor([0]) , train_data['geo_data']['token_type_ids'][1 : self.max_token_len//2 ], torch.tensor([1]))) 185 | 186 | # print('c', train_data['concat_data']['masked_input'].shape, train_data['concat_data']['attention_mask'].shape, train_data['concat_data']['sent_position_ids'].shape, 187 | # train_data['concat_data']['sent_position_ids'].shape, train_data['concat_data']['pseudo_sentence'].shape, train_data['concat_data']['token_type_ids'].shape ) 188 | 189 | 190 | train_data['concat_data']['norm_lng_list'] = torch.tensor([self.spatial_dist_fill for i in range(self.max_token_len)]).to(torch.float32) 191 | train_data['concat_data']['norm_lat_list'] = torch.tensor([self.spatial_dist_fill for i in range(self.max_token_len)]).to(torch.float32) 192 | 193 | return train_data 194 | 195 | def __len__(self): 196 | return self.len_geo_data 197 | 198 | def __getitem__(self, index): 199 | spatial_dist_fill = self.spatial_dist_fill 200 | 201 | if self.if_rand_seq: 202 | # randomly take samples, ignoring the index 203 | line = self.geo_data[np.random.randint(self.len_geo_data)] 204 | else: 205 | line = self.geo_data[index] # take one line from the input data according to the index 206 | 207 | geo_line_data_dict = json.loads(line) 208 | 209 | while geo_line_data_dict['info']['osm_id'] not in self.nl_data: 210 | line = self.geo_data[np.random.randint(self.len_geo_data)] 211 | geo_line_data_dict = 
json.loads(line) 212 | 213 | return self.load_data(geo_line_data_dict) 214 | -------------------------------------------------------------------------------- /src/dataset_utils/const.py: -------------------------------------------------------------------------------- 1 | def revert_dict(coarse_to_fine_dict): 2 | fine_to_coarse_dict = dict() 3 | for key, value in coarse_to_fine_dict.items(): 4 | for v in value: 5 | fine_to_coarse_dict[v] = key 6 | return fine_to_coarse_dict 7 | 8 | CLASS_9_LIST = ['education', 'entertainment_arts_culture', 'facilities', 'financial', 'healthcare', 'public_service', 'sustenance', 'transportation', 'waste_management'] 9 | 10 | CLASS_118_LIST=['animal_boarding', 'animal_breeding', 'animal_shelter', 'arts_centre', 'atm', 'baby_hatch', 'baking_oven', 'bank', 'bar', 'bbq', 'bench', 'bicycle_parking', 'bicycle_rental', 'bicycle_repair_station', 'biergarten', 'boat_rental', 'boat_sharing', 'brothel', 'bureau_de_change', 'bus_station', 'cafe', 'car_rental', 'car_sharing', 'car_wash', 'casino', 'charging_station', 'childcare', 'cinema', 'clinic', 'clock', 'college', 'community_centre', 'compressed_air', 'conference_centre', 'courthouse', 'crematorium', 'dentist', 'dive_centre', 'doctors', 'dog_toilet', 'dressing_room', 'drinking_water', 'driving_school', 'events_venue', 'fast_food', 'ferry_terminal', 'fire_station', 'food_court', 'fountain', 'fuel', 'funeral_hall', 'gambling', 'give_box', 'grave_yard', 'grit_bin', 'hospital', 'hunting_stand', 'ice_cream', 'internet_cafe', 'kindergarten', 'kitchen', 'kneipp_water_cure', 'language_school', 'library', 'lounger', 'love_hotel', 'marketplace', 'monastery', 'motorcycle_parking', 'music_school', 'nightclub', 'nursing_home', 'parcel_locker', 'parking', 'parking_entrance', 'parking_space', 'pharmacy', 'photo_booth', 'place_of_mourning', 'place_of_worship', 'planetarium', 'police', 'post_box', 'post_depot', 'post_office', 'prison', 'pub', 'public_bath', 'public_bookcase', 'ranger_station', 'recycling', 'refugee_site', 'restaurant', 'sanitary_dump_station', 'school', 'shelter', 'shower', 'social_centre', 'social_facility', 'stripclub', 'studio', 'swingerclub', 'taxi', 'telephone', 'theatre', 'toilets', 'townhall', 'toy_library', 'training', 'university', 'vehicle_inspection', 'vending_machine', 'veterinary', 'waste_basket', 'waste_disposal', 'waste_transfer_station', 'water_point', 'watering_place'] 11 | 12 | CLASS_95_LIST = ['arts_centre', 'atm', 'baby_hatch', 'bank', 'bar', 'bbq', 'bench', 'bicycle_parking', 'bicycle_rental', 'bicycle_repair_station', 'biergarten', 'boat_rental', 'boat_sharing', 'brothel', 'bureau_de_change', 'bus_station', 'cafe', 'car_rental', 'car_sharing', 'car_wash', 'casino', 'charging_station', 'cinema', 'clinic', 'college', 'community_centre', 'compressed_air', 'conference_centre', 'courthouse', 'dentist', 'doctors', 'dog_toilet', 'dressing_room', 'drinking_water', 'driving_school', 'events_venue', 'fast_food', 'ferry_terminal', 'fire_station', 'food_court', 'fountain', 'fuel', 'gambling', 'give_box', 'grit_bin', 'hospital', 'ice_cream', 'kindergarten', 'language_school', 'library', 'love_hotel', 'motorcycle_parking', 'music_school', 'nightclub', 'nursing_home', 'parcel_locker', 'parking', 'parking_entrance', 'parking_space', 'pharmacy', 'planetarium', 'police', 'post_box', 'post_depot', 'post_office', 'prison', 'pub', 'public_bookcase', 'ranger_station', 'recycling', 'restaurant', 'sanitary_dump_station', 'school', 'shelter', 'shower', 'social_centre', 'social_facility', 'stripclub', 'studio', 
'swingerclub', 'taxi', 'telephone', 'theatre', 'toilets', 'townhall', 'toy_library', 'training', 'university', 'vehicle_inspection', 'veterinary', 'waste_basket', 'waste_disposal', 'waste_transfer_station', 'water_point', 'watering_place'] 13 | 14 | CLASS_74_LIST = ['arts_centre', 'atm', 'bank', 'bar', 'bench', 'bicycle_parking', 'bicycle_rental', 'bicycle_repair_station', 'boat_rental', 'bureau_de_change', 'bus_station', 'cafe', 'car_rental', 'car_sharing', 'car_wash', 'charging_station', 'cinema', 'clinic', 'college', 'community_centre', 'conference_centre', 'courthouse', 'dentist', 'doctors', 'drinking_water', 'driving_school', 'events_venue', 'fast_food', 'ferry_terminal', 'fire_station', 'food_court', 'fountain', 'fuel', 'gambling', 'hospital', 'kindergarten', 'language_school', 'library', 'motorcycle_parking', 'music_school', 'nightclub', 'nursing_home', 'parcel_locker', 'parking', 'pharmacy', 'police', 'post_box', 'post_depot', 'post_office', 'pub', 'public_bookcase', 'recycling', 'restaurant', 'sanitary_dump_station', 'school', 'shelter', 'social_centre', 'social_facility', 'stripclub', 'studio', 'swingerclub', 'taxi', 'telephone', 'theatre', 'toilets', 'townhall', 'university', 'vehicle_inspection', 'veterinary', 'waste_basket', 'waste_disposal', 'waste_transfer_station', 'water_point', 'watering_place'] 15 | 16 | FEWSHOT_CLASS_55_LIST = ['arts_centre', 'atm', 'bank', 'bar', 'bench', 'bicycle_parking', 'bicycle_rental', 'boat_rental', 'bureau_de_change', 'bus_station', 'cafe', 'car_rental', 'car_sharing', 'car_wash', 'charging_station', 'cinema', 'clinic', 'college', 'community_centre', 'courthouse', 'dentist', 'doctors', 'drinking_water', 'driving_school', 'events_venue', 'fast_food', 'ferry_terminal', 'fire_station', 'fountain', 'fuel', 'hospital', 'kindergarten', 'library', 'music_school', 'nightclub', 'parking', 'pharmacy', 'police', 'post_box', 'post_office', 'pub', 'public_bookcase', 'recycling', 'restaurant', 'school', 'shelter', 'social_centre', 'social_facility', 'studio', 'theatre', 'toilets', 'townhall', 'university', 'vehicle_inspection', 'veterinary'] 17 | 18 | 19 | DICT_9to74 = {'sustenance':['bar','cafe','fast_food','food_court','pub','restaurant'], 20 | 'education':['college','driving_school','kindergarten','language_school','library','music_school','school','university'], 21 | 'transportation':['bicycle_parking','bicycle_repair_station','bicycle_rental','boat_rental', 22 | 'bus_station','car_rental','car_sharing','car_wash','vehicle_inspection','charging_station','ferry_terminal', 23 | 'fuel','motorcycle_parking','parking','taxi'], 24 | 'financial':['atm','bank','bureau_de_change'], 25 | 'healthcare':['clinic','dentist','doctors','hospital','nursing_home','pharmacy','social_facility','veterinary'], 26 | 'entertainment_arts_culture':['arts_centre','cinema','community_centre', 27 | 'conference_centre','events_venue','fountain','gambling', 28 | 'nightclub','public_bookcase','social_centre','stripclub','studio','swingerclub','theatre'], 29 | 'public_service':['courthouse','fire_station','police','post_box', 30 | 'post_depot','post_office','townhall'], 31 | 'facilities':['bench','drinking_water','parcel_locker','shelter', 32 | 'telephone','toilets','water_point','watering_place'], 33 | 'waste_management':['sanitary_dump_station','recycling','waste_basket','waste_disposal','waste_transfer_station',] 34 | } 35 | 36 | DICT_74to9 = revert_dict(DICT_9to74) 37 | 38 | DICT_9to95 = { 39 | 
'education':{'college','driving_school','kindergarten','language_school','library','toy_library','training','music_school','school','university'}, 40 | 'entertainment_arts_culture':{'arts_centre','brothel','casino','cinema','community_centre','conference_centre','events_venue','fountain','gambling','love_hotel','nightclub','planetarium','public_bookcase','social_centre','stripclub','studio','swingerclub','theatre'}, 41 | 'facilities':{'bbq','bench','dog_toilet','dressing_room','drinking_water','give_box','parcel_locker','shelter','shower','telephone','toilets','water_point','watering_place'}, 42 | 'financial':{'atm','bank','bureau_de_change'}, 43 | 'healthcare':{'baby_hatch','clinic','dentist','doctors','hospital','nursing_home','pharmacy','social_facility','veterinary'}, 44 | 'public_service':{'courthouse','fire_station','police','post_box','post_depot','post_office','prison','ranger_station','townhall'}, 45 | 'sustenance':{'bar','biergarten','cafe','fast_food','food_court','ice_cream','pub','restaurant',}, 46 | 'transportation':{'bicycle_parking','bicycle_repair_station','bicycle_rental','boat_rental','boat_sharing','bus_station','car_rental','car_sharing','car_wash','compressed_air','vehicle_inspection','charging_station','ferry_terminal','fuel','grit_bin','motorcycle_parking','parking','parking_entrance','parking_space','taxi'}, 47 | 'waste_management':{'sanitary_dump_station','recycling','waste_basket','waste_disposal','waste_transfer_station'}} 48 | 49 | DICT_95to9 = revert_dict(DICT_9to95) 50 | 51 | # DICT_95to9 = { 52 | # 'college':'education', 53 | # 'driving_school':'education', 54 | # 'kindergarten':'education', 55 | # 'language_school':'education', 56 | # 'library':'education', 57 | # 'toy_library':'education', 58 | # 'training':'education', 59 | # 'music_school':'education', 60 | # 'school':'education', 61 | # 'university':'education', 62 | 63 | # 'arts_centre':'entertainment_arts_culture', 64 | # 'brothel':'entertainment_arts_culture', 65 | # 'casino':'entertainment_arts_culture', 66 | # 'cinema':'entertainment_arts_culture', 67 | # 'community_centre':'entertainment_arts_culture', 68 | # 'conference_centre':'entertainment_arts_culture', 69 | # 'events_venue':'entertainment_arts_culture', 70 | # 'fountain':'entertainment_arts_culture', 71 | # 'gambling':'entertainment_arts_culture', 72 | # 'love_hotel':'entertainment_arts_culture', 73 | # 'nightclub':'entertainment_arts_culture', 74 | # 'planetarium':'entertainment_arts_culture', 75 | # 'public_bookcase':'entertainment_arts_culture', 76 | # 'social_centre':'entertainment_arts_culture', 77 | # 'stripclub':'entertainment_arts_culture', 78 | # 'studio':'entertainment_arts_culture', 79 | # 'swingerclub':'entertainment_arts_culture', 80 | # 'theatre':'entertainment_arts_culture', 81 | 82 | # 'bbq': 'facilities', 83 | # 'bench': 'facilities', 84 | # 'dog_toilet': 'facilities', 85 | # 'dressing_room': 'facilities', 86 | # 'drinking_water': 'facilities', 87 | # 'give_box': 'facilities', 88 | # 'parcel_locker': 'facilities', 89 | # 'shelter': 'facilities', 90 | # 'shower': 'facilities', 91 | # 'telephone': 'facilities', 92 | # 'toilets': 'facilities', 93 | # 'water_point': 'facilities', 94 | # 'watering_place': 'facilities', 95 | 96 | # 'atm': 'financial', 97 | # 'bank': 'financial', 98 | # 'bureau_de_change': 'financial', 99 | 100 | # 'baby_hatch':'healthcare', 101 | # 'clinic':'healthcare', 102 | # 'dentist':'healthcare', 103 | # 'doctors':'healthcare', 104 | # 'hospital':'healthcare', 105 | # 'nursing_home':'healthcare', 106 | # 
'pharmacy':'healthcare', 107 | # 'social_facility':'healthcare', 108 | # 'veterinary':'healthcare', 109 | 110 | # 'courthouse': 'public_service', 111 | # 'fire_station': 'public_service', 112 | # 'police': 'public_service', 113 | # 'post_box': 'public_service', 114 | # 'post_depot': 'public_service', 115 | # 'post_office': 'public_service', 116 | # 'prison': 'public_service', 117 | # 'ranger_station': 'public_service', 118 | # 'townhall': 'public_service', 119 | 120 | # 'bar': 'sustenance', 121 | # 'biergarten': 'sustenance', 122 | # 'cafe': 'sustenance', 123 | # 'fast_food': 'sustenance', 124 | # 'food_court': 'sustenance', 125 | # 'ice_cream': 'sustenance', 126 | # 'pub': 'sustenance', 127 | # 'restaurant': 'sustenance', 128 | 129 | # 'bicycle_parking': 'transportation', 130 | # 'bicycle_repair_station': 'transportation', 131 | # 'bicycle_rental': 'transportation', 132 | # 'boat_rental': 'transportation', 133 | # 'boat_sharing': 'transportation', 134 | # 'bus_station': 'transportation', 135 | # 'car_rental': 'transportation', 136 | # 'car_sharing': 'transportation', 137 | # 'car_wash': 'transportation', 138 | # 'compressed_air': 'transportation', 139 | # 'vehicle_inspection': 'transportation', 140 | # 'charging_station': 'transportation', 141 | # 'ferry_terminal': 'transportation', 142 | # 'fuel': 'transportation', 143 | # 'grit_bin': 'transportation', 144 | # 'motorcycle_parking': 'transportation', 145 | # 'parking': 'transportation', 146 | # 'parking_entrance': 'transportation', 147 | # 'parking_space': 'transportation', 148 | # 'taxi': 'transportation', 149 | 150 | # 'sanitary_dump_station': 'waste_management', 151 | # 'recycling': 'waste_management', 152 | # 'waste_basket': 'waste_management', 153 | # 'waste_disposal': 'waste_management', 154 | # 'waste_transfer_station': 'waste_management', 155 | 156 | # } 157 | 158 | # CLASS_9_LIST = ['sustenance', 'education', 'transportation', 'financial', 'healthcare', 'entertainment_arts_culture', 'public_service', 'facilities', 'waste_management'] 159 | 160 | # FINE_LIST = ['bar','biergarten','cafe','fast_food','food_court','ice_cream','pub','restaurant','college','driving_school','kindergarten','language_school','library','toy_library','training','music_school','school','university','bicycle_parking','bicycle_repair_station','bicycle_rental','boat_rental','boat_sharing','bus_station','car_rental','car_sharing','car_wash','compressed_air','vehicle_inspection','charging_station','ferry_terminal','fuel','grit_bin','motorcycle_parking','parking','parking_entrance','parking_space','taxi','atm','bank','bureau_de_change','baby_hatch','clinic','dentist','doctors','hospital','nursing_home','pharmacy','social_facility','veterinary','arts_centre','brothel','casino','cinema','community_centre','conference_centre','events_venue','fountain','gambling','love_hotel','nightclub','planetarium','public_bookcase','social_centre','stripclub','studio','swingerclub','theatre','courthouse','fire_station','police','post_box','post_depot','post_office','prison','ranger_station','townhall','bbq','bench','dog_toilet','dressing_room','drinking_water','give_box','parcel_locker','shelter','shower','telephone','toilets','water_point','watering_place','sanitary_dump_station','recycling','waste_basket','waste_disposal','waste_transfer_station'] 161 | 162 | # FINE_LIST = 
['bar','biergarten','cafe','fast_food','food_court','ice_cream','pub','restaurant','college','driving_school','kindergarten','language_school','library','toy_library','training','music_school','school','university','bicycle_parking','bicycle_repair_station','bicycle_rental','boat_rental','boat_sharing','bus_station','car_rental','car_sharing','car_wash','compressed_air','vehicle_inspection','charging_station','ferry_terminal','fuel','grit_bin','motorcycle_parking','parking','parking_entrance','parking_space','taxi','atm','bank','bureau_de_change','baby_hatch','clinic','dentist','doctors','hospital','nursing_home','pharmacy','social_facility','veterinary','arts_centre','brothel','casino','cinema','community_centre','conference_centre','events_venue','fountain','gambling','love_hotel','nightclub','planetarium','public_bookcase','social_centre','stripclub','studio','swingerclub','theatre','courthouse','fire_station','police','post_box','post_depot','post_office','prison','ranger_station','townhall','bbq','bench','dog_toilet','dressing_room','drinking_water','give_box','parcel_locker','shelter','shower','telephone','toilets','water_point','watering_place','sanitary_dump_station','recycling','waste_basket','waste_disposal','waste_transfer_station','animal_boarding','animal_breeding','animal_shelter','baking_oven','childcare','clock','crematorium','dive_centre','funeral_hall','grave_yard','hunting_stand','internet_cafe','kitchen','kneipp_water_cure','lounger','marketplace','monastery','photo_booth','place_of_mourning','place_of_worship','public_bath','refugee_site','vending_machine'] -------------------------------------------------------------------------------- /src/train_joint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from transformers import RobertaTokenizer, BertTokenizer, BertTokenizerFast 4 | from tqdm import tqdm # for our progress bar 5 | from transformers import AdamW 6 | 7 | import torch 8 | from torch.utils.data import DataLoader, Sampler 9 | 10 | sys.path.append('../../../') 11 | from models.spatial_bert_model import SpatialBertModel 12 | from models.spatial_bert_model import SpatialBertConfig 13 | from models.spatial_bert_model import SpatialBertForMaskedLM 14 | from dataset_utils.osm_sample_loader import PbfMapDataset 15 | from dataset_utils.paired_sample_loader import JointDataset 16 | from transformers.models.bert.modeling_bert import BertForMaskedLM 17 | from pytorch_metric_learning import losses 18 | from transformers import AutoModel, AutoTokenizer 19 | 20 | import numpy as np 21 | import argparse 22 | import pdb 23 | 24 | 25 | DEBUG = False 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = False 28 | torch.manual_seed(42) 29 | torch.cuda.manual_seed_all(42) 30 | 31 | use_amp = True # whether to use automatic mixed precision 32 | 33 | 34 | #TODO: 35 | # unify pivot name and pivot len, ogc_fid in two loaders 36 | 37 | 38 | class MyBatchSampler(Sampler): 39 | def __init__(self, batch_size, single_dataset_len): 40 | batch_size = batch_size // 2 41 | num_batches = single_dataset_len // batch_size 42 | batch_list = [] 43 | for i in range(0, num_batches): 44 | cur_batch_list = [] 45 | cur_batch_list.extend([j for j in range(i * batch_size, (i+1)*batch_size)]) 46 | cur_batch_list.extend([j for j in range(i * batch_size + single_dataset_len, (i+1)*batch_size + single_dataset_len)]) 47 | batch_list.append(cur_batch_list) 48 | 49 | self.batches = batch_list 50 | # 
print(batch_list[0:100]) 51 | 52 | def __iter__(self): 53 | for batch in self.batches: 54 | yield batch 55 | def __len__(self): 56 | return len(self.batches) 57 | 58 | def training(args): 59 | 60 | num_workers = args.num_workers 61 | batch_size = args.batch_size //2 62 | epochs = args.epochs 63 | lr = args.lr #1e-7 # 5e-5 64 | save_interval = args.save_interval 65 | max_token_len = args.max_token_len 66 | distance_norm_factor = args.distance_norm_factor 67 | spatial_dist_fill=args.spatial_dist_fill 68 | # nl_dist_fill = args.nl_dist_fill 69 | with_type = args.with_type 70 | 71 | bert_option = args.bert_option 72 | if_no_spatial_distance = args.no_spatial_distance 73 | 74 | assert bert_option in ['bert-base','bert-large', 'simcse-base'] 75 | 76 | 77 | model_save_dir = args.model_save_dir 78 | 79 | scaler = torch.cuda.amp.GradScaler(enabled=use_amp) 80 | 81 | 82 | print('model_save_dir', model_save_dir) 83 | print('\n') 84 | 85 | if bert_option == 'bert-base': 86 | # bert_model = BertForMaskedLM.from_pretrained('bert-base-cased') 87 | # tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased") 88 | 89 | # same as 90 | tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") 91 | bert_model = AutoModel.from_pretrained('bert-base-cased') 92 | 93 | config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance) 94 | elif bert_option == 'simcse-base': 95 | name_str = 'princeton-nlp/unsup-simcse-bert-base-uncased' # they don't have cased version 96 | tokenizer = AutoTokenizer.from_pretrained(name_str) 97 | bert_model = AutoModel.from_pretrained(name_str) 98 | 99 | # bert_model = BertForMaskedLM.from_pretrained('bert-base-cased') 100 | # tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased") 101 | config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance) 102 | 103 | elif bert_option == 'bert-large': 104 | 105 | # bert_model = BertForMaskedLM.from_pretrained('bert-large-cased') 106 | # tokenizer = BertTokenizerFast.from_pretrained("bert-large-cased") 107 | 108 | # same as 109 | tokenizer = AutoTokenizer.from_pretrained("bert-large-cased") 110 | bert_model = AutoModel.from_pretrained('bert-large-cased') 111 | 112 | config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24) 113 | else: 114 | raise NotImplementedError 115 | 116 | # config.vocab_size = 28996 # for bert-cased 117 | config.vocab_size = tokenizer.vocab_size 118 | 119 | model = SpatialBertForMaskedLM(config) 120 | 121 | 122 | new_state_dict ={} 123 | # Modify the keys in the pretrained_dict to match the new model's prefix 124 | for key, value in bert_model.state_dict().items(): 125 | 126 | new_key = key.replace("encoder.", "bert.encoder.") 127 | new_key = new_key.replace("embeddings.","bert.embeddings.") 128 | new_key = new_key.replace("word_bert.embeddings","word_embeddings") 129 | new_key = new_key.replace("position_bert.embeddings","position_embeddings") 130 | new_key = new_key.replace("token_type_bert.embeddings","token_type_embeddings") 131 | new_state_dict[new_key] = value 132 | 133 | model.load_state_dict(new_state_dict, strict = False) # load sentence position embedding weights as well 134 | 135 | 136 | train_dataset_list = [] 137 | continent_list = ['africa','antarctica','asia','australia_oceania', 138 | 'central_america','europe','north_america','south_america'] 139 | 140 | 141 | geo_pseudo_sent_path = 
os.path.join(args.pseudo_sentence_dir, 'world.json') 142 | nl_sent_path = os.path.join(args.nl_sentence_dir, 'world.json') 143 | 144 | # take samples sequentially to gather hard negative in one batch 145 | train_dataset_hardneg = JointDataset(geo_file_path = geo_pseudo_sent_path, 146 | nl_file_path = nl_sent_path, 147 | placename_to_osmid_path = args.placename_to_osmid_path, 148 | tokenizer = tokenizer, 149 | max_token_len = max_token_len, 150 | distance_norm_factor = distance_norm_factor, 151 | spatial_dist_fill = spatial_dist_fill, 152 | sep_between_neighbors = True, 153 | label_encoder = None, 154 | if_rand_seq = False, 155 | ) 156 | 157 | # take samples randomly 158 | train_dataset_randseq = JointDataset(geo_file_path = geo_pseudo_sent_path, 159 | nl_file_path = nl_sent_path, 160 | placename_to_osmid_path = args.placename_to_osmid_path, 161 | tokenizer = tokenizer, 162 | max_token_len = max_token_len, 163 | distance_norm_factor = distance_norm_factor, 164 | spatial_dist_fill = spatial_dist_fill, 165 | sep_between_neighbors = True, 166 | label_encoder = None, 167 | if_rand_seq = True 168 | ) 169 | 170 | train_dataset = torch.utils.data.ConcatDataset([train_dataset_hardneg, train_dataset_randseq]) 171 | 172 | batch_sampler = MyBatchSampler(batch_size = batch_size, single_dataset_len = len(train_dataset_randseq)) 173 | train_loader = DataLoader(train_dataset, num_workers=num_workers, # batch_size= batch_size, 174 | pin_memory=True, # drop_last=True, shuffle=False, 175 | batch_sampler= batch_sampler) 176 | # shuffle needs to be false 177 | 178 | 179 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 180 | model.to(device) 181 | model.train() 182 | 183 | 184 | # initialize optimizer 185 | optim = torch.optim.AdamW(model.parameters(), lr = lr) 186 | 187 | contrastive_criterion = losses.NTXentLoss(temperature=0.01) 188 | 189 | if args.checkpoint_weight is not None: 190 | print('load weights from checkpoint', args.checkpoint_weight) 191 | checkpoint = torch.load(args.checkpoint_weight) 192 | model.load_state_dict(checkpoint["model"], strict = True) 193 | optim.load_state_dict(checkpoint["optimizer"]) 194 | 195 | 196 | print('start training...') 197 | 198 | for epoch in range(epochs): 199 | # setup loop with TQDM and dataloader 200 | loop = tqdm(train_loader, leave=True) 201 | iter = 0 202 | for batch in loop: 203 | 204 | with torch.autocast(device_type = 'cuda', dtype = torch.float16, enabled = use_amp): 205 | 206 | nl_data = batch['nl_data'] 207 | geo_data = batch['geo_data'] 208 | concat_data = batch['concat_data'] 209 | 210 | nl_input_ids = nl_data['masked_input'].to(device) 211 | nl_entity_token_idx = nl_data['pivot_token_idx'].to(device) 212 | nl_attention_mask = nl_data['attention_mask'].to(device) 213 | nl_position_list_x = nl_data['norm_lng_list'].to(device) 214 | nl_position_list_y = nl_data['norm_lat_list'].to(device) 215 | nl_sent_position_ids = nl_data['sent_position_ids'].to(device) 216 | nl_pseudo_sentence = nl_data['pseudo_sentence'].to(device) 217 | nl_token_type_ids = nl_data['token_type_ids'].to(device) 218 | 219 | geo_input_ids = geo_data['masked_input'].to(device) 220 | geo_entity_token_idx = geo_data['pivot_token_idx'].to(device) 221 | geo_attention_mask = geo_data['attention_mask'].to(device) 222 | geo_position_list_x = geo_data['norm_lng_list'].to(device) 223 | geo_position_list_y = geo_data['norm_lat_list'].to(device) 224 | geo_sent_position_ids = geo_data['sent_position_ids'].to(device) 225 | geo_pseudo_sentence = 
geo_data['pseudo_sentence'].to(device) 226 | geo_token_type_ids = geo_data['token_type_ids'].to(device) 227 | 228 | joint_input_ids = concat_data['masked_input'].to(device) 229 | joint_attention_mask = concat_data['attention_mask'].to(device) 230 | joint_position_list_x = concat_data['norm_lng_list'].to(device) 231 | joint_position_list_y = concat_data['norm_lat_list'].to(device) 232 | joint_sent_position_ids = concat_data['sent_position_ids'].to(device) 233 | joint_pseudo_sentence = concat_data['pseudo_sentence'].to(device) 234 | joint_token_type_ids = concat_data['token_type_ids'].to(device) 235 | # pdb.set_trace() 236 | 237 | 238 | try: 239 | outputs1 = model(joint_input_ids, attention_mask = joint_attention_mask, sent_position_ids = joint_sent_position_ids, pivot_token_idx_list=None, 240 | spatial_position_list_x = joint_position_list_x, spatial_position_list_y = joint_position_list_y, token_type_ids = joint_token_type_ids, labels = joint_pseudo_sentence) 241 | except Exception as e: 242 | print(e) 243 | pdb.set_trace() 244 | 245 | # mlm for on joint geo and nl data 246 | 247 | loss1 = outputs1.loss 248 | 249 | # loss1.backward() 250 | # optim.step() 251 | scaler.scale(loss1).backward() 252 | scaler.unscale_(optim) 253 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) 254 | scaler.step(optim) 255 | scaler.update() 256 | optim.zero_grad() 257 | 258 | 259 | outputs1 = model(geo_pseudo_sentence, attention_mask = geo_attention_mask, sent_position_ids = geo_sent_position_ids, pivot_token_idx_list=geo_entity_token_idx, 260 | spatial_position_list_x = geo_position_list_x, spatial_position_list_y = geo_position_list_y, token_type_ids = geo_token_type_ids, labels = None) 261 | 262 | outputs2 = model(nl_pseudo_sentence, attention_mask = nl_attention_mask, sent_position_ids = nl_sent_position_ids,pivot_token_idx_list=nl_entity_token_idx, 263 | spatial_position_list_x = nl_position_list_x, spatial_position_list_y = nl_position_list_y, token_type_ids = nl_token_type_ids, labels = None) 264 | 265 | embedding = torch.cat((outputs1.hidden_states, outputs2.hidden_states), 0) 266 | indicator = torch.arange(0, batch_size, dtype=torch.float32, requires_grad=False).to(device) 267 | indicator = torch.cat((indicator, indicator),0) 268 | loss3 = contrastive_criterion(embedding, indicator) 269 | 270 | if torch.isnan(loss3): 271 | print(nl_entity_token_idx) 272 | return 273 | 274 | scaler.scale(loss3).backward() 275 | scaler.unscale_(optim) 276 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) 277 | scaler.step(optim) 278 | scaler.update() 279 | optim.zero_grad() 280 | 281 | loss = loss1 + loss3 282 | # pdb.set_trace() 283 | 284 | loop.set_description(f'Epoch {epoch}') 285 | loop.set_postfix({'loss':loss.item(),'mlm':loss1.item(),'contrast':loss3.item()}) 286 | 287 | 288 | if DEBUG: 289 | print('ep'+str(epoch)+'_' + '_iter'+ str(iter).zfill(5), loss.item() ) 290 | 291 | iter += 1 292 | 293 | if iter % save_interval == 0 or iter == loop.total: 294 | save_path = os.path.join(model_save_dir, 'ep'+str(epoch) + '_iter'+ str(iter).zfill(5) \ 295 | + '_' +str("{:.4f}".format(loss.item())) +'.pth' ) 296 | checkpoint = {"model": model.state_dict(), 297 | "optimizer": optim.state_dict(), 298 | "scaler": scaler.state_dict()} 299 | torch.save(checkpoint, save_path) 300 | print('saving model checkpoint to', save_path) 301 | 302 | 303 | def main(): 304 | 305 | parser = argparse.ArgumentParser() 306 | parser.add_argument('--num_workers', type=int, default=5) 307 | 
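# Illustrative note (a hedged sketch derived from training() and MyBatchSampler above; the example
# indices below are assumed for illustration, not taken from the original code): --batch_size is the
# total requested batch size. training() halves it before building the sampler, and MyBatchSampler
# halves it again, drawing one half of each batch from the sequential (hard-negative) dataset and the
# other half from the random-order dataset inside the ConcatDataset. For example, with --batch_size=16
# and single_dataset_len=N, the first sampled batch would look roughly like
#     [0, 1, 2, 3] + [N, N+1, N+2, N+3]
# The contrastive NT-Xent loss in the training loop then concatenates the geo and NL embeddings and
# repeats the indicator labels (torch.arange) twice, so the geo view and the NL view of the same pivot
# entity share a label and act as positives, while the other samples in the batch serve as negatives.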
parser.add_argument('--batch_size', type=int, default=16) 308 | parser.add_argument('--epochs', type=int, default=15) 309 | parser.add_argument('--save_interval', type=int, default=8000) 310 | parser.add_argument('--max_token_len', type=int, default=512) 311 | 312 | parser.add_argument('--lr', type=float, default = 1e-5) 313 | parser.add_argument('--distance_norm_factor', type=float, default = 100) 314 | parser.add_argument('--spatial_dist_fill', type=float, default = 90000) 315 | # parser.add_argument('--nl_dist_fill', type=float, default = 0) 316 | 317 | parser.add_argument('--with_type', default=False, action='store_true') 318 | parser.add_argument('--no_spatial_distance', default=False, action='store_true') 319 | 320 | parser.add_argument('--bert_option', type=str, default='bert-base', choices= ['bert-base','bert-large', 'simcse-base']) 321 | parser.add_argument('--model_save_dir', type=str, default='/data4/zekun/joint_model/weights_base_v1') 322 | parser.add_argument('--checkpoint_weight', type=str, default=None) 323 | 324 | parser.add_argument('--pseudo_sentence_dir', type=str, default = '/data4/zekun/osm_pseudo_sent/world_append_wikidata/') 325 | parser.add_argument('--nl_sentence_dir', type=str, default = '/data4/zekun/wikidata/world_georelation/joint_0618_valid') 326 | parser.add_argument('--placename_to_osmid_path', type=str, default = '/home/zekun/datasets/osm_pseudo_sent/name-osmid-dict/placename_to_osmid.json') 327 | 328 | # parser.add_argument('--wikidata_dir', type=str, default = '/data4/zekun/wikidata/world_geo/') 329 | #parser.add_argument('--wikipedia_dir', type=str, default = '/data2/wikipedia/separate_text_world/') 330 | #parser.add_argument('--trie_file_dir', type=str, default = '/data2/wikipedia/trie_output_world_format/') 331 | 332 | # CUDA_VISIBLE_DEVICES='2' python3 train_joint.py --model_save_dir='../../weights_base_run2' --pseudo_sentence_dir='../../datasets/osm_pseudo_sent/world_append_wikidata/' --nl_sentence_dir='../../datasets/wikidata/world_georelation/joint_v2/' --batch_size=32 333 | # CUDA_VISIBLE_DEVICES='3' python3 train_joint.py --bert_option='bert-large' --pseudo_sentence_dir='../../datasets/osm_pseudo_sent/world_append_wikidata/' --nl_sentence_dir='../../datasets/wikidata/world_georelation/joint_v2/' --model_save_dir='../../weights_large' --batch_size=10 --lr=1e-6 334 | 335 | # CUDA_VISIBLE_DEVICES='2' python3 train_joint.py --model_save_dir='../../weights_base_run2' --pseudo_sentence_dir='../../datasets/osm_pseudo_sent/world_append_wikidata/' --nl_sentence_dir='../../datasets/wikidata/world_georelation/joint_v2/' --batch_size=32 --checkpoint_weight='../../weights_base/ep1_iter60000_0.0440.pth' 336 | 337 | 338 | args = parser.parse_args() 339 | print('\n') 340 | print(args) 341 | print('\n') 342 | 343 | 344 | # out_dir not None, and out_dir does not exist, then create out_dir 345 | if args.model_save_dir is not None and not os.path.isdir(args.model_save_dir): 346 | os.makedirs(args.model_save_dir) 347 | 348 | training(args) 349 | 350 | 351 | 352 | if __name__ == '__main__': 353 | 354 | main() 355 | 356 | -------------------------------------------------------------------------------- /experiments/entity_linking/link_geonames.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import sys 5 | import os 6 | import numpy as np 7 | import pdb 8 | import json 9 | import scipy.spatial as sp 10 | import argparse 11 | 12 | 13 | import torch 14 | from torch.utils.data import 
DataLoader 15 | from transformers import BertTokenizerFast 16 | from tqdm import tqdm # for our progress bar 17 | 18 | sys.path.append('../../src/') 19 | from datasets.dataset_loader import SpatialDataset 20 | from models.spatial_bert_model import SpatialBertModel 21 | from models.spatial_bert_model import SpatialBertConfig 22 | from utils.find_closest import find_ref_closest_match, sort_ref_closest_match 23 | from utils.common_utils import load_spatial_bert_pretrained_weights, get_spatialbert_embedding, get_bert_embedding, write_to_csv 24 | from utils.baseline_utils import get_baseline_model 25 | from transformers import BertModel 26 | from transformers import AutoModel, AutoTokenizer 27 | import pdb 28 | 29 | from haversine import haversine, Unit 30 | 31 | 32 | MODEL_OPTIONS = ['joint-base','joint-large', 'bert-base','bert-large','roberta-base','roberta-large', 33 | 'spanbert-base','spanbert-large',# 'luke-base','luke-large', 34 | 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large', 'simcse-base', 35 | 'sap-bert','mirror-bert'] 36 | 37 | 38 | 39 | def get_offset_mapping(nl_tokens): 40 | offset_mapping = nl_tokens['offset_mapping'][1:-1] 41 | flat_offset_mapping = np.array(offset_mapping).flatten() 42 | offset_mapping_dict_start = {} 43 | offset_mapping_dict_end = {} 44 | for idx in range(0,len(flat_offset_mapping),2): 45 | char_pos = flat_offset_mapping[idx] 46 | if char_pos == 0 and idx != 0: 47 | break 48 | token_pos = idx//2 + 1 49 | offset_mapping_dict_start[char_pos] = token_pos 50 | for idx in range(1,len(flat_offset_mapping),2): 51 | char_pos = flat_offset_mapping[idx] 52 | if char_pos == 0 and idx != 0: 53 | break 54 | token_pos = (idx-1)//2 + 1 +1 55 | offset_mapping_dict_end[char_pos] = token_pos 56 | 57 | return offset_mapping_dict_start, offset_mapping_dict_end 58 | 59 | def get_nl_feature(text, gt_lat, gt_lon, start_span, end_span, model, model_name, tokenizer, spatial_dist_fill, device): 60 | 61 | # text = paragraph['text'] 62 | # gt_lat = paragraph['lat'] 63 | # gt_lon = paragraph['lon'] 64 | 65 | # spans = paragraph['spans'] # TODO: can be improved 66 | # selected_span = spans[0] 67 | 68 | sentence_len = 512 69 | nl_tokens = tokenizer(text, padding="max_length", max_length=sentence_len, truncation = True, return_offsets_mapping = True) 70 | offset_mapping_dict_start, offset_mapping_dict_end = get_offset_mapping(nl_tokens) 71 | if start_span not in offset_mapping_dict_start or end_span not in offset_mapping_dict_end: 72 | # pdb.set_trace() 73 | return None # TODO: exceeds length. 
fix later 74 | 75 | token_start_idx = offset_mapping_dict_start[start_span] 76 | token_end_idx = offset_mapping_dict_end[end_span] 77 | 78 | nl_tokens['sent_position_ids'] = torch.tensor(np.array([np.arange(0, sentence_len)])).to(device) 79 | nl_tokens['norm_lng_list'] = torch.tensor([[0 for i in range(sentence_len)]]).to(torch.float32).to(device) 80 | nl_tokens['norm_lat_list'] = torch.tensor([[0 for i in range(sentence_len)]]).to(torch.float32).to(device) 81 | nl_tokens['token_type_ids'] = torch.zeros(1,sentence_len).int().to(device) 82 | entity_token_idx = torch.tensor([[token_start_idx, token_end_idx]]).to(device) 83 | 84 | if model_name == 'joint-base' or model_name=='joint-large' or model_name == 'simcse-base': 85 | nl_outputs = model(torch.tensor([nl_tokens['input_ids']]).to(device), 86 | attention_mask = torch.tensor([nl_tokens['attention_mask']]).to(device), 87 | sent_position_ids = nl_tokens['sent_position_ids'], 88 | pivot_token_idx_list=entity_token_idx, 89 | position_list_x = nl_tokens['norm_lng_list'], 90 | position_list_y = nl_tokens['norm_lat_list'] , 91 | token_type_ids = nl_tokens['token_type_ids']) 92 | 93 | nl_entity_feature = nl_outputs.pooler_output 94 | nl_entity_feature = nl_entity_feature[0].detach().cpu().numpy() 95 | else: 96 | nl_outputs = model(torch.tensor([nl_tokens['input_ids']]).to(device), 97 | attention_mask =torch.tensor([nl_tokens['attention_mask']]).to(device), 98 | token_type_ids = nl_tokens['token_type_ids'], 99 | ) 100 | embeddings = nl_outputs.last_hidden_state 101 | nl_entity_feature = embeddings[0][token_start_idx:token_end_idx] 102 | nl_entity_feature = torch.mean(nl_entity_feature, axis = 0).detach().cpu().numpy() # (768, ) 103 | # print(nl_entity_feature.shape) 104 | 105 | return nl_entity_feature 106 | 107 | def get_geoname_features(geonames_cand, model, model_name, tokenizer, spatial_dist_fill, device, spatial_dataset): 108 | 109 | 110 | geo_feat_list = [] 111 | geo_id_list = [] 112 | geo_loc_list = [] 113 | for gn_cand in geonames_cand: 114 | 115 | pivot_name = gn_cand['info']['name'] 116 | pivot_pos = gn_cand['info']['geometry']['coordinates'] #(lng, lat) 117 | pivot_geonames_id = gn_cand['info']['geoname_id'] 118 | 119 | 120 | neighbor_info = gn_cand['neighbor_info'] 121 | neighbor_name_list = neighbor_info['name_list'] 122 | neighbor_geometry_list = neighbor_info['geometry_list'] 123 | 124 | 125 | geo_data = spatial_dataset.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill ) 126 | geo_data['token_type_ids'] = torch.ones( len(geo_data['pseudo_sentence'])).int() # type 1 for geo data 127 | 128 | if model_name == 'joint-base' or model_name=='joint-large' or model_name == 'simcse-base': 129 | geo_outputs = model(geo_data['pseudo_sentence'].unsqueeze(0).to(device), 130 | attention_mask = geo_data['attention_mask'].unsqueeze(0).to(device), 131 | sent_position_ids = geo_data['sent_position_ids'].unsqueeze(0).to(device), 132 | pivot_token_idx_list= geo_data['pivot_token_idx'].unsqueeze(0).to(device), 133 | position_list_x = geo_data['norm_lng_list'].unsqueeze(0).to(device), 134 | position_list_y = geo_data['norm_lat_list'].unsqueeze(0).to(device), 135 | token_type_ids = geo_data['token_type_ids'].unsqueeze(0).to(device), 136 | ) 137 | geo_feat = geo_outputs.pooler_output 138 | geo_feat = geo_feat[0].detach().cpu().numpy() 139 | 140 | else: 141 | # pdb.set_trace() 142 | if model_name == 'roberta-base': 143 | geo_outputs = model(geo_data['pseudo_sentence'].unsqueeze(0).to(device), 144 | 
attention_mask = geo_data['attention_mask'].unsqueeze(0).to(device), 145 | ) 146 | else: 147 | geo_outputs = model(geo_data['pseudo_sentence'].unsqueeze(0).to(device), 148 | attention_mask = geo_data['attention_mask'].unsqueeze(0).to(device), 149 | token_type_ids = geo_data['token_type_ids'].unsqueeze(0).to(device), 150 | ) 151 | 152 | embeddings = geo_outputs.last_hidden_state 153 | geo_feat = embeddings[0][geo_data['pivot_token_idx'][0]:geo_data['pivot_token_idx'][1]] 154 | geo_feat = torch.mean(geo_feat, axis = 0).detach().cpu().numpy() # (768, ) 155 | 156 | 157 | # pivot_embed = pivot_embed[0].detach().cpu().numpy() 158 | 159 | # print(geo_feat.shape) 160 | 161 | geo_feat_list.append(geo_feat) 162 | geo_id_list.append(pivot_geonames_id) 163 | geo_loc_list.append({'lon':pivot_pos[0], 'lat':pivot_pos[1]}) 164 | 165 | return geo_feat_list, geo_loc_list, geo_id_list 166 | 167 | 168 | def wiktor_linking(out_path, query_data, geonames_dict, model, model_name, tokenizer, spatial_dist_fill, device, spatial_dataset): 169 | 170 | distance_list = [] 171 | acc_count_at_161 = 0 172 | acc_count_total = 0 173 | correct_geoname_count = 0 174 | 175 | with open(out_path, 'w') as f: 176 | pass # flush 177 | 178 | 179 | # overall list, geodestic distance histogram 180 | for query_name, paragraph_list in query_data.items(): 181 | 182 | if query_name in geonames_dict: 183 | geonames_cand = geonames_dict[query_name] 184 | # print(query_name, len(paragraph_list), len(geonames_dict[query_name])) 185 | geoname_features, geonames_loc_list, geonames_id_list = get_geoname_features(geonames_cand, model, model_name, tokenizer, spatial_dist_fill, device, spatial_dataset) 186 | else: 187 | continue 188 | # print(query_name, 'not in geonames_dict') 189 | 190 | samename_ret_list = [] 191 | for paragraph in paragraph_list: 192 | # cur_dict = {'text':text, 'feature':feature, 'url':url, 'country':country, 'lat':lat, 'lon':lon, 193 | # 'spans':spans} 194 | if 'url' in paragraph: 195 | wiki_url = paragraph['url'] 196 | else: 197 | wiki_url = None 198 | 199 | text = paragraph['text'] 200 | gt_lat = paragraph['lat'] 201 | gt_lon = paragraph['lon'] 202 | 203 | spans = paragraph['spans'] # TODO: can be improved 204 | if len(spans) == 0: 205 | # pdb.set_trace() 206 | continue 207 | selected_span = spans[0] 208 | start_span, end_span = selected_span[0], selected_span[1] 209 | 210 | 211 | nl_feature = get_nl_feature(text, gt_lat, gt_lon, start_span, end_span, model, model_name, tokenizer, spatial_dist_fill, device) 212 | 213 | if nl_feature is None: continue 214 | 215 | # nl_feature_shape: torch.Size([1, 768]) 216 | 217 | sim_matrix = 1 - sp.distance.cdist(np.array(geoname_features), np.array([nl_feature]), 'cosine') 218 | 219 | closest_match_geonames_id = sort_ref_closest_match(sim_matrix, geonames_id_list) 220 | closest_match_geonames_loc = sort_ref_closest_match(sim_matrix, geonames_loc_list) 221 | 222 | sorted_sim_matrix = np.sort(sim_matrix, axis = 0)[::-1] # descending order 223 | 224 | ret_dict = dict() 225 | ret_dict['pivot_name'] = query_name 226 | ret_dict['gt_loc'] = {'lon':paragraph['lon'], 'lat':paragraph['lat']} 227 | ret_dict['wiki_url'] = wiki_url 228 | ret_dict['sorted_match_geoname_id'] = [a[0] for a in closest_match_geonames_id] 229 | ret_dict['closest_match_geonames_loc'] = [a[0] for a in closest_match_geonames_loc] 230 | #ret_dict['sorted_match_des'] = [a[0] for a in closest_match_des] 231 | ret_dict['sorted_sim_matrix'] = [a[0] for a in sorted_sim_matrix] 232 | 233 | samename_ret_list.append(ret_dict) 234 | 
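# Illustrative note (a hedged sketch of the evaluation step below; the example numbers are made up):
# sim_matrix has shape (num_candidates, 1) -- one cosine similarity per GeoNames candidate for this
# mention -- so sorting it in descending order and reordering geonames_id_list / geonames_loc_list the
# same way puts the most similar candidate first. The code below takes that top-1 candidate, measures
# the haversine distance to the ground-truth coordinates, and counts the prediction as correct for
# "acc@161" when the error is under 161 km (roughly 100 miles, a threshold commonly used in
# toponym-resolution evaluation). For example, a top candidate ~40 km from the ground truth would add
# to acc_count_at_161, while one ~500 km away would only contribute to distance_list.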
235 | # print(ret_dict['gt_loc'], ret_dict['wiki_url'], ret_dict['closest_match_geonames_loc']) 236 | 237 | gt_loc = (float(paragraph['lat']), float(paragraph['lon'])) 238 | pred_loc = ret_dict['closest_match_geonames_loc'][0] 239 | pred_loc = (pred_loc['lat'], pred_loc['lon']) 240 | error_dist = haversine(gt_loc, pred_loc) 241 | distance_list.append(error_dist) 242 | # pdb.set_trace() 243 | if error_dist < 161: 244 | acc_count_at_161 += 1 245 | 246 | acc_count_total+=1 247 | 248 | ret_dict['sorted_match_geoname_id'] = ret_dict['sorted_match_geoname_id'] 249 | ret_dict['closest_match_geonames_loc'] = ret_dict['closest_match_geonames_loc'] 250 | 251 | with open(out_path, 'a') as f: 252 | json.dump(ret_dict, f) 253 | f.write('\n') 254 | 255 | return {'distance_list':distance_list, 'acc_at_161:': 1.0*acc_count_at_161/acc_count_total} 256 | 257 | 258 | def toponym_linking(out_path, query_data, geonames_dict, model, model_name, tokenizer, spatial_dist_fill, device, spatial_dataset): 259 | 260 | distance_list = [] 261 | acc_count_at_161 = 0 262 | acc_count_total = 0 263 | correct_geoname_count = 0 264 | 265 | with open(out_path, 'w') as f: 266 | pass # flush 267 | 268 | for sample in query_data: 269 | # cur_dict = {'sentence':sentence, 'toponyms':[]} 270 | text = sample['sentence'] 271 | 272 | for toponym in sample['toponyms']: 273 | if 'geoname_id' not in toponym: 274 | continue # skip this sample in evaluation 275 | 276 | query_name = toponym['text'] 277 | start_span = toponym['start'] 278 | end_span = toponym['end'] 279 | geoname_id = toponym['geoname_id'] 280 | gt_lat = toponym['lat'] 281 | gt_lon = toponym['lon'] 282 | 283 | nl_feature = get_nl_feature(text, gt_lat, gt_lon, start_span, end_span, model, model_name, tokenizer, spatial_dist_fill, device) 284 | 285 | if nl_feature is None: continue 286 | 287 | if query_name in geonames_dict: 288 | geonames_cand = geonames_dict[query_name] 289 | # print(query_name, len(geonames_dict[query_name])) 290 | geoname_features, geonames_loc_list, geonames_id_list = get_geoname_features(geonames_cand, model, model_name, tokenizer, spatial_dist_fill, device, spatial_dataset) 291 | # pdb.set_trace() 292 | # print(geoname_features) 293 | else: 294 | continue 295 | 296 | sim_matrix = 1 - sp.distance.cdist(np.array(geoname_features), np.array([nl_feature]), 'cosine') 297 | 298 | closest_match_geonames_id = sort_ref_closest_match(sim_matrix, geonames_id_list) 299 | closest_match_geonames_loc = sort_ref_closest_match(sim_matrix, geonames_loc_list) 300 | 301 | sorted_sim_matrix = np.sort(sim_matrix, axis = 0)[::-1] # descending order 302 | 303 | ret_dict = dict() 304 | ret_dict['pivot_name'] = query_name 305 | ret_dict['gt_loc'] = {'lon':gt_lon, 'lat':gt_lat} 306 | ret_dict['geoname_id'] = geoname_id 307 | ret_dict['sorted_match_geoname_id'] = [a[0] for a in closest_match_geonames_id] 308 | ret_dict['closest_match_geonames_loc'] = [a[0] for a in closest_match_geonames_loc] 309 | #ret_dict['sorted_match_des'] = [a[0] for a in closest_match_des] 310 | ret_dict['sorted_sim_matrix'] = [a[0] for a in sorted_sim_matrix] 311 | 312 | # samename_ret_list.append(ret_dict) 313 | 314 | # print(ret_dict['gt_loc'], ret_dict['closest_match_geonames_loc']) 315 | 316 | gt_loc = (gt_lat, gt_lon) 317 | pred_loc = ret_dict['closest_match_geonames_loc'][0] 318 | pred_loc = (pred_loc['lat'], pred_loc['lon']) 319 | error_dist = haversine(gt_loc, pred_loc) 320 | distance_list.append(error_dist) 321 | 322 | if error_dist < 161: 323 | acc_count_at_161 += 1 324 | 325 | if 
str(ret_dict['sorted_match_geoname_id'][0]) == geoname_id: 326 | correct_geoname_count += 1 327 | 328 | acc_count_total+=1 329 | ret_dict['sorted_match_geoname_id'] = ret_dict['sorted_match_geoname_id'] 330 | ret_dict['closest_match_geonames_loc'] = ret_dict['closest_match_geonames_loc'] 331 | 332 | with open(out_path, 'a') as f: 333 | json.dump(ret_dict, f) 334 | f.write('\n') 335 | 336 | return {'distance_list':distance_list, 'acc@1': 1.0*correct_geoname_count/acc_count_total, 'acc_at_161:': 1.0*acc_count_at_161/acc_count_total} 337 | 338 | 339 | def entity_linking_func(args): 340 | 341 | model_name = args.model_name 342 | 343 | distance_norm_factor = args.distance_norm_factor 344 | spatial_dist_fill= args.spatial_dist_fill 345 | sep_between_neighbors = True 346 | spatial_bert_weight_dir = args.spatial_bert_weight_dir 347 | spatial_bert_weight_name = args.spatial_bert_weight_name 348 | if_no_spatial_distance = args.no_spatial_distance 349 | 350 | 351 | assert model_name in MODEL_OPTIONS 352 | 353 | 354 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 355 | 356 | out_dir = args.out_dir 357 | 358 | print('out_dir', out_dir) 359 | 360 | if model_name == 'joint-base' or model_name == 'joint-large' or model_name =='simcse-base': 361 | if model_name == 'joint-base' or model_name == 'joint-large': 362 | tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") 363 | else: 364 | tokenizer = AutoTokenizer.from_pretrained('princeton-nlp/unsup-simcse-bert-base-uncased') 365 | 366 | config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance) 367 | 368 | config.vocab_size = tokenizer.vocab_size 369 | 370 | model = SpatialBertModel(config) 371 | 372 | model.to(device) 373 | model.eval() 374 | 375 | # load pretrained weights 376 | weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name) 377 | model = load_spatial_bert_pretrained_weights(model, weight_path) 378 | 379 | elif model_name in MODEL_OPTIONS: #'bert-base': 380 | model, tokenizer = get_baseline_model(model_name) 381 | # model.config.type_vocab_size=2 382 | model.to(device) 383 | model.eval() 384 | else: 385 | raise NotImplementedError 386 | 387 | spatial_dataset = SpatialDataset(tokenizer , max_token_len=512 , distance_norm_factor=distance_norm_factor, sep_between_neighbors = True) 388 | 389 | 390 | with open(args.query_dataset_path,'r') as f: 391 | query_data = json.load(f) 392 | 393 | geonames_dict = {} 394 | 395 | with open(args.ref_dataset_path,'r') as f: 396 | geonames_data = json.load(f) 397 | for info in geonames_data: 398 | key = next(iter(info)) 399 | value = info[key] 400 | 401 | geonames_dict[key] = value 402 | 403 | if 'WikToR' in args.query_dataset_path: 404 | out_path = os.path.join(out_dir, 'wiktor.json') 405 | eval_info = wiktor_linking(out_path, query_data, geonames_dict, model, model_name, tokenizer, spatial_dist_fill, device, spatial_dataset) 406 | elif 'lgl' in args.query_dataset_path or 'geowebnews' in args.query_dataset_path: 407 | if 'lgl' in args.query_dataset_path: 408 | out_path = os.path.join(out_dir, 'lgl.json') 409 | elif 'geowebnews' in args.query_dataset_path: 410 | out_path = os.path.join(out_dir, 'geowebnews.json') 411 | 412 | eval_info = toponym_linking(out_path, query_data, geonames_dict, model, model_name, tokenizer, spatial_dist_fill, device, spatial_dataset) 413 | 414 | # print(distance_list) 415 | # print('acc_at_161:',1.0*acc_count_at_161/acc_count_total) 416 | 417 | with open(out_path, 'a') as f: 418 | 
json.dump(eval_info, f) 419 | f.write('\n') 420 | 421 | print(eval_info) 422 | 423 | 424 | def main(): 425 | parser = argparse.ArgumentParser() 426 | parser.add_argument('--model_name', type=str, default='joint-base') 427 | parser.add_argument('--query_dataset_path', type=str, default='../../data/WikToR.json') 428 | # parser.add_argument('--ref_dataset_path', type=str, default='../../data/geoname-ids-v3-part02.json') 429 | parser.add_argument('--ref_dataset_path', type=str, default='/home/zekun/datasets/geonames/geonames_for_wiktor/geoname-ids.json') 430 | 431 | parser.add_argument('--out_dir', type=str, default=None) 432 | 433 | parser.add_argument('--distance_norm_factor', type=float, default = 100) 434 | parser.add_argument('--spatial_dist_fill', type=float, default = 90000) 435 | 436 | parser.add_argument('--no_spatial_distance', default=False, action='store_true') 437 | 438 | parser.add_argument('--spatial_bert_weight_dir', type = str, default = None) 439 | parser.add_argument('--spatial_bert_weight_name', type = str, default = None) 440 | 441 | args = parser.parse_args() 442 | print('\n') 443 | print(args) 444 | print('\n') 445 | 446 | # out_dir not None, and out_dir does not exist, then create out_dir 447 | if args.out_dir is not None and not os.path.isdir(args.out_dir): 448 | os.makedirs(args.out_dir) 449 | 450 | entity_linking_func(args) 451 | 452 | # python3 link_geonames.py --out_dir='debug' --model_name='joint-base' --distance_norm_factor=100 --spatial_dist_fill=90000 --spatial_bert_weight_dir='/home/zekun/weights_base_run2/' --spatial_bert_weight_name='ep12_iter24000_0.0061.pth' 453 | # ep14_iter88000_0.0039.pth 454 | # ep14_iter96000_0.0382.pth 455 | 456 | # python3 link_geonames.py --out_dir='debug' --model_name='bert-base' 457 | 458 | # python3 link_geonames.py --model_name='bert-base' --query_dataset_path='/home/zekun/toponym_detection/lgl/lgl.json' --ref_dataset_path='/home/zekun/datasets/geonames/geonames_for_lgl/geoname-ids.json' --out_dir='baselines/bert-base' 459 | 460 | # CUDA_VISIBLE_DEVICES='1' python3 link_geonames.py --model_name='joint-base' --query_dataset_path='/home/zekun/toponym_detection/lgl/lgl.json' --ref_dataset_path='/home/zekun/datasets/geonames/geonames_for_lgl/geoname-ids.json' --distance_norm_factor=100 --spatial_dist_fill=90000 --spatial_bert_weight_dir='/home/zekun/weights_base_run2/' --spatial_bert_weight_name='ep14_iter88000_0.0039.pth' 461 | 462 | # May 463 | 464 | # CUDA_VISIBLE_DEVICES='3' python3 link_geonames.py --model_name='joint-base' --query_dataset_path='/home/zekun/toponym_detection/geowebnews/GWN.json' --ref_dataset_path='/home/zekun/datasets/geonames/geonames_for_geowebnews/geoname-ids.json' --distance_norm_factor=100 --spatial_dist_fill=90000 --spatial_bert_weight_dir='/home/zekun/weights_base_0505/' --spatial_bert_weight_name='ep0_iter80000_1.3164.pth' --out_dir='results' 465 | 466 | # CUDA_VISIBLE_DEVICES='3' python3 link_geonames.py --model_name='joint-base' --query_dataset_path='/home/zekun/toponym_detection/lgl/lgl.json' --ref_dataset_path='/home/zekun/datasets/geonames/geonames_for_lgl/geoname-ids.json' --distance_norm_factor=100 --spatial_dist_fill=90000 --spatial_bert_weight_dir='/home/zekun/weights_base_0505/' --spatial_bert_weight_name='ep0_iter80000_1.3164.pth' --out_dir='results' 467 | 468 | # CUDA_VISIBLE_DEVICES='0' python3 link_geonames.py --model_name='joint-base' --query_dataset_path='/home/zekun/toponym_detection/lgl/lgl.json' --ref_dataset_path='/home/zekun/datasets/geonames/geonames_for_lgl/geoname-ids.json' 
--distance_norm_factor=100 --spatial_dist_fill=90000 --spatial_bert_weight_dir='/home/zekun/weights_base_0505/' --spatial_bert_weight_name='ep1_iter52000_1.3994.pth' --out_dir='results' 469 | 470 | # CUDA_VISIBLE_DEVICES='0' python3 link_geonames.py --model_name='joint-base' --query_dataset_path='/home/zekun/toponym_detection/lgl/lgl.json' --ref_dataset_path='/home/zekun/datasets/geonames/geonames_for_lgl/geoname-ids.json' --distance_norm_factor=100 --spatial_dist_fill=90000 --spatial_bert_weight_dir='/home/zekun/weights_base_0506/' --spatial_bert_weight_name='ep0_iter48000_1.5495.pth' --out_dir='results-0506' 471 | 472 | # CUDA_VISIBLE_DEVICES='0' python3 link_geonames.py --model_name='joint-base' --query_dataset_path='/home/zekun/toponym_detection/lgl/lgl.json' --ref_dataset_path='/home/zekun/datasets/geonames/geonames_for_lgl/geoname-ids.json' --distance_norm_factor=100 --spatial_dist_fill=90000 --spatial_bert_weight_dir='/home/zekun/weights_base_run2/' --spatial_bert_weight_name='ep14_iter108000_0.0172.pth' --out_dir='results-run2' 473 | 474 | # CUDA_VISIBLE_DEVICES='0' python3 link_geonames.py --model_name='joint-base' --query_dataset_path='/home/zekun/toponym_detection/geowebnews/GWN.json' --ref_dataset_path='/home/zekun/datasets/geonames/geonames_for_geowebnews/geoname-ids.json' --distance_norm_factor=100 --spatial_dist_fill=90000 --spatial_bert_weight_dir='/home/zekun/weights_base_run2/' --spatial_bert_weight_name='ep14_iter108000_0.0172.pth' --out_dir='results-run2' 475 | 476 | # CUDA_VISIBLE_DEVICES='0' python3 link_geonames.py --model_name='joint-base' --distance_norm_factor=100 --spatial_dist_fill=90000 --spatial_bert_weight_dir='/home/zekun/weights_base_0508/' --spatial_bert_weight_name='ep0_iter144000_0.5711.pth' --out_dir='results-run2' 477 | 478 | # ../../weights_base_0511/ep1_iter84000_0.5168.pth 479 | # CUDA_VISIBLE_DEVICES='0' python3 link_geonames.py --model_name='joint-base' --distance_norm_factor=100 --spatial_dist_fill=90000 --spatial_bert_weight_dir='/home/zekun/weights_base_0511/' --spatial_bert_weight_name='ep1_iter84000_0.5168.pth' --out_dir='debug' 480 | 481 | 482 | # CUDA_VISIBLE_DEVICES='1' python3 link_geonames.py --model_name='joint-base' --distance_norm_factor=100 --spatial_dist_fill=900 --spatial_bert_weight_dir='/data4/zekun/joint_model/weights_0517/' --spatial_bert_weight_name='ep5_iter04000_0.0486.pth' --out_dir='debug' --query_dataset_path='/data4/zekun/toponym_detection/geowebnews/GWN.json' --ref_dataset_path='/data4/zekun/geonames/geonames_for_geowebnews/geoname-ids.json' 483 | # 'acc@1': 0.23718439173680184, 'acc_at_161:': 0.31675592960979343} 484 | 485 | if __name__ == '__main__': 486 | 487 | main() 488 | 489 | 490 | -------------------------------------------------------------------------------- /experiments/entity_linking/multi_link_geonames.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import sys 5 | import os 6 | import numpy as np 7 | import pdb 8 | import json 9 | import scipy.spatial as sp 10 | import argparse 11 | import math 12 | 13 | 14 | import torch 15 | from torch.utils.data import DataLoader 16 | from transformers import BertTokenizerFast 17 | from tqdm import tqdm # for our progress bar 18 | 19 | sys.path.insert(0,'/home/zekun/joint_model/src/') 20 | from dataset_utils.dataset_loader import SpatialDataset 21 | from models.spatial_bert_model import SpatialBertModel 22 | from models.spatial_bert_model import SpatialBertConfig 23 | from 
utils.find_closest import find_ref_closest_match, sort_ref_closest_match 24 | from utils.common_utils import load_spatial_bert_pretrained_weights, get_spatialbert_embedding, get_bert_embedding, write_to_csv 25 | from utils.baseline_utils import get_baseline_model 26 | from transformers import BertModel 27 | from transformers import AutoModel, AutoTokenizer 28 | import pdb 29 | 30 | from haversine import haversine, Unit 31 | 32 | 33 | MODEL_OPTIONS = ['joint-base','joint-large', 'bert-base','bert-large','roberta-base','roberta-large', 34 | 'spanbert-base','spanbert-large',# 'luke-base','luke-large', 35 | 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large', 'simcse-base', 36 | 'sap-bert'] 37 | 38 | 39 | def get_offset_mapping(nl_tokens): 40 | offset_mapping = nl_tokens['offset_mapping'][1:-1] 41 | flat_offset_mapping = np.array(offset_mapping).flatten() 42 | offset_mapping_dict_start = {} 43 | offset_mapping_dict_end = {} 44 | for idx in range(0,len(flat_offset_mapping),2): 45 | char_pos = flat_offset_mapping[idx] 46 | if char_pos == 0 and idx != 0: 47 | break 48 | token_pos = idx//2 + 1 49 | offset_mapping_dict_start[char_pos] = token_pos 50 | for idx in range(1,len(flat_offset_mapping),2): 51 | char_pos = flat_offset_mapping[idx] 52 | if char_pos == 0 and idx != 0: 53 | break 54 | token_pos = (idx-1)//2 + 1 +1 55 | offset_mapping_dict_end[char_pos] = token_pos 56 | 57 | return offset_mapping_dict_start, offset_mapping_dict_end 58 | 59 | def get_nl_feature(text, gt_lat, gt_lon, start_span, end_span, model, model_name, tokenizer, spatial_dist_fill, device): 60 | 61 | # text = paragraph['text'] 62 | # gt_lat = paragraph['lat'] 63 | # gt_lon = paragraph['lon'] 64 | 65 | # spans = paragraph['spans'] # TODO: can be improved 66 | # selected_span = spans[0] 67 | 68 | sentence_len = 512 69 | nl_tokens = tokenizer(text, padding="max_length", max_length=sentence_len, truncation = True, return_offsets_mapping = True) 70 | offset_mapping_dict_start, offset_mapping_dict_end = get_offset_mapping(nl_tokens) 71 | if start_span not in offset_mapping_dict_start or end_span not in offset_mapping_dict_end: 72 | # pdb.set_trace() 73 | return None # TODO: exceeds length. 
fix later 74 | 75 | token_start_idx = offset_mapping_dict_start[start_span] 76 | token_end_idx = offset_mapping_dict_end[end_span] 77 | 78 | nl_tokens['sent_position_ids'] = torch.tensor(np.array([np.arange(0, sentence_len)])).to(device) 79 | nl_tokens['norm_lng_list'] = torch.tensor([[0 for i in range(sentence_len)]]).to(torch.float32).to(device) 80 | nl_tokens['norm_lat_list'] = torch.tensor([[0 for i in range(sentence_len)]]).to(torch.float32).to(device) 81 | nl_tokens['token_type_ids'] = torch.zeros(1,sentence_len).int().to(device) 82 | entity_token_idx = torch.tensor([[token_start_idx, token_end_idx]]).to(device) 83 | 84 | if model_name == 'joint-base' or model_name=='joint-large' or model_name == 'simcse-base': 85 | nl_outputs = model(torch.tensor([nl_tokens['input_ids']]).to(device), 86 | attention_mask = torch.tensor([nl_tokens['attention_mask']]).to(device), 87 | sent_position_ids = nl_tokens['sent_position_ids'], 88 | pivot_token_idx_list=entity_token_idx, 89 | spatial_position_list_x = nl_tokens['norm_lng_list'], 90 | spatial_position_list_y = nl_tokens['norm_lat_list'] , 91 | token_type_ids = nl_tokens['token_type_ids']) 92 | 93 | nl_entity_feature = nl_outputs.pooler_output 94 | nl_entity_feature = nl_entity_feature[0].detach().cpu().numpy() 95 | else: 96 | nl_outputs = model(torch.tensor([nl_tokens['input_ids']]).to(device), 97 | attention_mask =torch.tensor([nl_tokens['attention_mask']]).to(device), 98 | token_type_ids = nl_tokens['token_type_ids'], 99 | ) 100 | embeddings = nl_outputs.last_hidden_state 101 | nl_entity_feature = embeddings[0][token_start_idx:token_end_idx] 102 | nl_entity_feature = torch.mean(nl_entity_feature, axis = 0).detach().cpu().numpy() # (768, ) 103 | # print(nl_entity_feature.shape) 104 | 105 | return nl_entity_feature 106 | 107 | def get_geoname_features(geonames_cand, model, model_name, tokenizer, spatial_dist_fill, device, spatial_dataset, batch_size): 108 | 109 | 110 | geo_feat_list = [] 111 | geo_id_list = [] 112 | geo_loc_list = [] 113 | 114 | # if len(geonames_cand) <= batch_size: 115 | # # process all cands at once 116 | # else: 117 | # # break into rounds 118 | 119 | # break down into chunks 120 | for chunk_idx in range(0, int(math.ceil(1.0 * len(geonames_cand) / batch_size ))): 121 | 122 | pseudo_sent_concat = [] 123 | attention_mask_concat= [] 124 | sent_position_id_concat = [] 125 | pivot_token_idx_concat = [] 126 | position_x_concat = [] 127 | position_y_concat = [] 128 | 129 | for gn_cand in geonames_cand[chunk_idx * batch_size: (chunk_idx+1)*batch_size]: 130 | 131 | pivot_name = gn_cand['info']['name'] 132 | pivot_pos = gn_cand['info']['geometry']['coordinates'] #(lng, lat) 133 | pivot_geonames_id = gn_cand['info']['geoname_id'] 134 | 135 | geo_id_list.append(pivot_geonames_id) 136 | geo_loc_list.append({'lon':pivot_pos[0], 'lat':pivot_pos[1]}) 137 | 138 | 139 | neighbor_info = gn_cand['neighbor_info'] 140 | neighbor_name_list = neighbor_info['name_list'] 141 | neighbor_geometry_list = neighbor_info['geometry_list'] 142 | 143 | 144 | geo_data = spatial_dataset.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill ) 145 | 146 | pseudo_sent_concat.append(geo_data['pseudo_sentence']) 147 | attention_mask_concat.append(geo_data['attention_mask']) 148 | sent_position_id_concat.append(geo_data['sent_position_ids']) 149 | pivot_token_idx_concat.append(geo_data['pivot_token_idx']) 150 | position_x_concat.append(geo_data['norm_lng_list']) 151 | 
position_y_concat.append(geo_data['norm_lat_list']) 152 | n_samples = len(pseudo_sent_concat) 153 | 154 | 155 | if model_name == 'joint-base' or model_name=='joint-large' or model_name == 'simcse-base': 156 | geo_outputs = model(torch.stack(pseudo_sent_concat, dim = 0).to(device), 157 | attention_mask = torch.stack(attention_mask_concat, dim = 0).to(device), 158 | sent_position_ids = torch.stack(sent_position_id_concat, dim = 0).to(device), 159 | pivot_token_idx_list= torch.stack(pivot_token_idx_concat, dim = 0).to(device), 160 | spatial_position_list_x = torch.stack(position_x_concat, dim = 0).to(device), 161 | spatial_position_list_y = torch.stack(position_y_concat, dim = 0).to(device), 162 | token_type_ids = torch.ones( n_samples, len(geo_data['pseudo_sentence'])).int() .to(device), # type 1 for geo data 163 | ) 164 | geo_feat = geo_outputs.pooler_output 165 | geo_feat = geo_feat.detach().cpu().numpy() 166 | 167 | 168 | else: 169 | raise NotImplementedError 170 | # geo_outputs = model(torch.stack(pseudo_sent_concat, dim = 0).to(device), 171 | # attention_mask = torch.stack(attention_mask_concat, dim = 0).to(device), 172 | # token_type_ids = torch.ones( n_samples, len(geo_data['pseudo_sentence'])).int() .to(device), 173 | # ) 174 | 175 | # embeddings = geo_outputs.last_hidden_state 176 | # TODO: update this!! 177 | # geo_feat = embeddings[0][geo_data['pivot_token_idx'][0]:geo_data['pivot_token_idx'][1]] 178 | # geo_feat = torch.mean(geo_feat, axis = 0).detach().cpu().numpy() # (768, ) 179 | 180 | 181 | # pivot_embed = pivot_embed[0].detach().cpu().numpy() 182 | 183 | # print(geo_feat.shape) 184 | 185 | geo_feat_list.append(geo_feat) 186 | # pdb.set_trace() 187 | 188 | geo_feat_list = np.concatenate(geo_feat_list, axis = 0) 189 | 190 | return geo_feat_list, geo_loc_list, geo_id_list 191 | 192 | 193 | def wiktor_linking(out_path, query_data, geonames_dict, model, model_name, tokenizer, spatial_dist_fill, device, spatial_dataset, batch_size): 194 | 195 | distance_list = [] 196 | acc_count_at_161 = 0 197 | acc_count_total = 0 198 | correct_geoname_count = 0 199 | 200 | with open(out_path, 'w') as f: 201 | pass # flush 202 | 203 | 204 | # overall list, geodestic distance histogram 205 | for query_name, paragraph_list in query_data.items(): 206 | 207 | if query_name in geonames_dict: 208 | geonames_cand = geonames_dict[query_name] 209 | # print(query_name, len(paragraph_list), len(geonames_dict[query_name])) 210 | geoname_features, geonames_loc_list, geonames_id_list = get_geoname_features(geonames_cand, model, model_name, tokenizer, spatial_dist_fill, device, spatial_dataset, batch_size) 211 | else: 212 | continue 213 | # print(query_name, 'not in geonames_dict') 214 | 215 | samename_ret_list = [] 216 | for paragraph in paragraph_list: 217 | # cur_dict = {'text':text, 'feature':feature, 'url':url, 'country':country, 'lat':lat, 'lon':lon, 218 | # 'spans':spans} 219 | if 'url' in paragraph: 220 | wiki_url = paragraph['url'] 221 | else: 222 | wiki_url = None 223 | 224 | text = paragraph['text'] 225 | gt_lat = paragraph['lat'] 226 | gt_lon = paragraph['lon'] 227 | 228 | spans = paragraph['spans'] # TODO: can be improved 229 | if len(spans) == 0: 230 | # pdb.set_trace() 231 | continue 232 | selected_span = spans[0] 233 | start_span, end_span = selected_span[0], selected_span[1] 234 | 235 | 236 | nl_feature = get_nl_feature(text, gt_lat, gt_lon, start_span, end_span, model, model_name, tokenizer, spatial_dist_fill, device) 237 | 238 | if nl_feature is None: continue 239 | 240 | # nl_feature_shape: 
torch.Size([1, 768]) 241 | 242 | sim_matrix = 1 - sp.distance.cdist(np.array(geoname_features), np.array([nl_feature]), 'cosine') 243 | 244 | closest_match_geonames_id = sort_ref_closest_match(sim_matrix, geonames_id_list) 245 | closest_match_geonames_loc = sort_ref_closest_match(sim_matrix, geonames_loc_list) 246 | 247 | sorted_sim_matrix = np.sort(sim_matrix, axis = 0)[::-1] # descending order 248 | 249 | ret_dict = dict() 250 | ret_dict['pivot_name'] = query_name 251 | ret_dict['gt_loc'] = {'lon':paragraph['lon'], 'lat':paragraph['lat']} 252 | ret_dict['wiki_url'] = wiki_url 253 | ret_dict['sorted_match_geoname_id'] = [a[0] for a in closest_match_geonames_id] 254 | ret_dict['closest_match_geonames_loc'] = [a[0] for a in closest_match_geonames_loc] 255 | #ret_dict['sorted_match_des'] = [a[0] for a in closest_match_des] 256 | ret_dict['sorted_sim_matrix'] = [a[0] for a in sorted_sim_matrix] 257 | 258 | samename_ret_list.append(ret_dict) 259 | 260 | # print(ret_dict['gt_loc'], ret_dict['wiki_url'], ret_dict['closest_match_geonames_loc']) 261 | 262 | gt_loc = (float(paragraph['lat']), float(paragraph['lon'])) 263 | pred_loc = ret_dict['closest_match_geonames_loc'][0] 264 | pred_loc = (pred_loc['lat'], pred_loc['lon']) 265 | error_dist = haversine(gt_loc, pred_loc) 266 | distance_list.append(error_dist) 267 | # pdb.set_trace() 268 | if error_dist < 161: 269 | acc_count_at_161 += 1 270 | 271 | acc_count_total+=1 272 | 273 | ret_dict['sorted_match_geoname_id'] = ret_dict['sorted_match_geoname_id'][:20] 274 | ret_dict['closest_match_geonames_loc'] = ret_dict['closest_match_geonames_loc'][:20] 275 | 276 | with open(out_path, 'a') as f: 277 | json.dump(ret_dict, f) 278 | f.write('\n') 279 | 280 | return {'distance_list':distance_list, 'acc_at_161:': 1.0*acc_count_at_161/acc_count_total} 281 | 282 | 283 | def toponym_linking(out_path, query_data, geonames_dict, model, model_name, tokenizer, spatial_dist_fill, device, spatial_dataset, batch_size): 284 | 285 | distance_list = [] 286 | acc_count_at_161 = 0 287 | acc_count_total = 0 288 | correct_geoname_count = 0 289 | 290 | with open(out_path, 'w') as f: 291 | pass # flush 292 | 293 | for sample in query_data: 294 | # cur_dict = {'sentence':sentence, 'toponyms':[]} 295 | text = sample['sentence'] 296 | 297 | for toponym in sample['toponyms']: 298 | if 'geoname_id' not in toponym: 299 | continue # skip this sample in evaluation 300 | 301 | query_name = toponym['text'] 302 | start_span = toponym['start'] 303 | end_span = toponym['end'] 304 | geoname_id = toponym['geoname_id'] 305 | gt_lat = toponym['lat'] 306 | gt_lon = toponym['lon'] 307 | 308 | nl_feature = get_nl_feature(text, gt_lat, gt_lon, start_span, end_span, model, model_name, tokenizer, spatial_dist_fill, device) 309 | 310 | if nl_feature is None: continue 311 | 312 | if query_name in geonames_dict: 313 | geonames_cand = geonames_dict[query_name] 314 | # print(query_name, len(geonames_dict[query_name])) 315 | geoname_features, geonames_loc_list, geonames_id_list = get_geoname_features(geonames_cand, model, model_name, tokenizer, spatial_dist_fill, device, spatial_dataset, batch_size) 316 | # pdb.set_trace() 317 | # print(geoname_features) 318 | else: 319 | continue 320 | 321 | sim_matrix = 1 - sp.distance.cdist(np.array(geoname_features), np.array([nl_feature]), 'cosine') 322 | 323 | closest_match_geonames_id = sort_ref_closest_match(sim_matrix, geonames_id_list) 324 | closest_match_geonames_loc = sort_ref_closest_match(sim_matrix, geonames_loc_list) 325 | 326 | sorted_sim_matrix = 
np.sort(sim_matrix, axis = 0)[::-1] # descending order 327 | 328 | ret_dict = dict() 329 | ret_dict['pivot_name'] = query_name 330 | ret_dict['gt_loc'] = {'lon':gt_lon, 'lat':gt_lat} 331 | ret_dict['geoname_id'] = geoname_id 332 | ret_dict['sorted_match_geoname_id'] = [a[0] for a in closest_match_geonames_id] 333 | ret_dict['closest_match_geonames_loc'] = [a[0] for a in closest_match_geonames_loc] 334 | #ret_dict['sorted_match_des'] = [a[0] for a in closest_match_des] 335 | ret_dict['sorted_sim_matrix'] = [a[0] for a in sorted_sim_matrix] 336 | 337 | # samename_ret_list.append(ret_dict) 338 | 339 | gt_loc = (gt_lat, gt_lon) 340 | pred_loc = ret_dict['closest_match_geonames_loc'][0] 341 | pred_loc = (pred_loc['lat'], pred_loc['lon']) 342 | error_dist = haversine(gt_loc, pred_loc) 343 | distance_list.append(error_dist) 344 | 345 | if error_dist < 161: 346 | acc_count_at_161 += 1 347 | 348 | if str(ret_dict['sorted_match_geoname_id'][0]) == geoname_id: 349 | correct_geoname_count += 1 350 | 351 | acc_count_total+=1 352 | ret_dict['sorted_match_geoname_id'] = ret_dict['sorted_match_geoname_id'][:20] 353 | ret_dict['closest_match_geonames_loc'] = ret_dict['closest_match_geonames_loc'][:20] 354 | 355 | with open(out_path, 'a') as f: 356 | json.dump(ret_dict, f) 357 | f.write('\n') 358 | 359 | return {'distance_list':distance_list, 'acc@1': 1.0*correct_geoname_count/acc_count_total, 'acc_at_161:': 1.0*acc_count_at_161/acc_count_total} 360 | 361 | 362 | def entity_linking_func(args): 363 | 364 | model_name = args.model_name 365 | 366 | distance_norm_factor = args.distance_norm_factor 367 | spatial_dist_fill= args.spatial_dist_fill 368 | sep_between_neighbors = True 369 | spatial_bert_weight_dir = args.spatial_bert_weight_dir 370 | spatial_bert_weight_name = args.spatial_bert_weight_name 371 | if_no_spatial_distance = args.no_spatial_distance 372 | 373 | 374 | assert model_name in MODEL_OPTIONS 375 | 376 | 377 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 378 | 379 | out_dir = args.out_dir 380 | 381 | print('out_dir', out_dir) 382 | 383 | if model_name == 'joint-base' or model_name == 'joint-large' or model_name =='simcse-base': 384 | if model_name == 'joint-base' or model_name == 'joint-large': 385 | tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") 386 | else: 387 | tokenizer = AutoTokenizer.from_pretrained('princeton-nlp/unsup-simcse-bert-base-uncased') 388 | 389 | config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance) 390 | 391 | config.vocab_size = tokenizer.vocab_size 392 | 393 | model = SpatialBertModel(config) 394 | 395 | model.to(device) 396 | model.eval() 397 | 398 | # load pretrained weights 399 | weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name) 400 | model = load_spatial_bert_pretrained_weights(model, weight_path) 401 | 402 | 403 | elif model_name in MODEL_OPTIONS: #'bert-base': 404 | model, tokenizer = get_baseline_model(model_name) 405 | # model.config.type_vocab_size=2 406 | model.to(device) 407 | model.eval() 408 | else: 409 | raise NotImplementedError 410 | 411 | spatial_dataset = SpatialDataset(tokenizer , max_token_len=512 , distance_norm_factor=distance_norm_factor, sep_between_neighbors = True) 412 | 413 | 414 | with open(args.query_dataset_path,'r') as f: 415 | query_data = json.load(f) 416 | 417 | geonames_dict = {} 418 | 419 | with open(args.ref_dataset_path,'r') as f: 420 | geonames_data = json.load(f) 421 | for info in geonames_data: 422 | key = next(iter(info)) 
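            # Assumed reference-file layout (inferred from this loop): each entry of the
            # GeoNames JSON appears to be a single-key dict {toponym: candidate_list}, so
            # next(iter(info)) pulls out the toponym and the loop flattens the file into
            # geonames_dict = {toponym: candidate_list}.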
423 | value = info[key] 424 | 425 | geonames_dict[key] = value 426 | 427 | if 'WikToR' in args.query_dataset_path: 428 | out_path = os.path.join(out_dir, 'wiktor.json') 429 | eval_info = wiktor_linking(out_path, query_data, geonames_dict, model, model_name, tokenizer, spatial_dist_fill, device, spatial_dataset, args.batch_size) 430 | elif 'lgl' in args.query_dataset_path or 'geowebnews' in args.query_dataset_path: 431 | if 'lgl' in args.query_dataset_path: 432 | out_path = os.path.join(out_dir, 'lgl.json') 433 | elif 'geowebnews' in args.query_dataset_path: 434 | out_path = os.path.join(out_dir, 'geowebnews.json') 435 | 436 | eval_info = toponym_linking(out_path, query_data, geonames_dict, model, model_name, tokenizer, spatial_dist_fill, device, spatial_dataset, args.batch_size) 437 | 438 | # print(distance_list) 439 | # print('acc_at_161:',1.0*acc_count_at_161/acc_count_total) 440 | 441 | with open(out_path, 'a') as f: 442 | json.dump(eval_info, f) 443 | f.write('\n') 444 | 445 | print(eval_info) 446 | 447 | 448 | def main(): 449 | parser = argparse.ArgumentParser() 450 | parser.add_argument('--model_name', type=str, default='joint-base') 451 | parser.add_argument('--query_dataset_path', type=str, default='../../data/WikToR.json') 452 | # parser.add_argument('--ref_dataset_path', type=str, default='../../data/geoname-ids-v3-part02.json') 453 | parser.add_argument('--ref_dataset_path', type=str, default='/home/zekun/datasets/geonames/geonames_for_wiktor/geoname-ids.json') 454 | 455 | parser.add_argument('--out_dir', type=str, default=None) 456 | 457 | parser.add_argument('--distance_norm_factor', type=float, default = 100) 458 | parser.add_argument('--spatial_dist_fill', type=float, default = 90000) 459 | parser.add_argument('--batch_size', type=int, default = 20) 460 | 461 | parser.add_argument('--no_spatial_distance', default=False, action='store_true') 462 | 463 | parser.add_argument('--spatial_bert_weight_dir', type = str, default = None) 464 | parser.add_argument('--spatial_bert_weight_name', type = str, default = None) 465 | 466 | args = parser.parse_args() 467 | print('\n') 468 | print(args) 469 | print('\n') 470 | 471 | # out_dir not None, and out_dir does not exist, then create out_dir 472 | if args.out_dir is not None and not os.path.isdir(args.out_dir): 473 | os.makedirs(args.out_dir) 474 | 475 | entity_linking_func(args) 476 | 477 | # python3 link_geonames.py --out_dir='debug' --model_name='joint-base' --distance_norm_factor=100 --spatial_dist_fill=90000 --spatial_bert_weight_dir='/home/zekun/weights_base_run2/' --spatial_bert_weight_name='ep12_iter24000_0.0061.pth' 478 | # ep14_iter88000_0.0039.pth 479 | # ep14_iter96000_0.0382.pth 480 | 481 | # python3 link_geonames.py --out_dir='debug' --model_name='bert-base' 482 | 483 | # python3 link_geonames.py --model_name='bert-base' --query_dataset_path='/home/zekun/toponym_detection/lgl/lgl.json' --ref_dataset_path='/home/zekun/datasets/geonames/geonames_for_lgl/geoname-ids.json' --out_dir='baselines/bert-base' 484 | 485 | # CUDA_VISIBLE_DEVICES='1' python3 link_geonames.py --model_name='joint-base' --query_dataset_path='/home/zekun/toponym_detection/lgl/lgl.json' --ref_dataset_path='/home/zekun/datasets/geonames/geonames_for_lgl/geoname-ids.json' --distance_norm_factor=100 --spatial_dist_fill=90000 --spatial_bert_weight_dir='/home/zekun/weights_base_run2/' --spatial_bert_weight_name='ep14_iter88000_0.0039.pth' 486 | 487 | # May 488 | 489 | # CUDA_VISIBLE_DEVICES='3' python3 link_geonames.py --model_name='joint-base' 
--query_dataset_path='/home/zekun/toponym_detection/geowebnews/GWN.json' --ref_dataset_path='/home/zekun/datasets/geonames/geonames_for_geowebnews/geoname-ids.json' --distance_norm_factor=100 --spatial_dist_fill=90000 --spatial_bert_weight_dir='/home/zekun/weights_base_0505/' --spatial_bert_weight_name='ep0_iter80000_1.3164.pth' --out_dir='results' 490 | 
491 | # Additional single-process examples (other datasets and checkpoints) are listed at the bottom of link_geonames.py; multi_link_geonames.py takes the same arguments plus --batch_size.
492 | 
493 | # Example batched run:
494 | # CUDA_VISIBLE_DEVICES='1' python3 multi_link_geonames.py --model_name='joint-base' --distance_norm_factor=100 --spatial_dist_fill=900 --spatial_bert_weight_dir='/home/zekun/weights_base_0511/' --spatial_bert_weight_name='ep4_iter08000_0.0326.pth' --out_dir='debug' --batch_size=24 495 | 
496 | if __name__ == '__main__': 497 | 
498 |     main() 499 | 
500 | 
--------------------------------------------------------------------------------
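Note on the candidate-ranking step in multi_link_geonames.py: both wiktor_linking and toponym_linking compute an (n, 1) cosine-similarity matrix between the n GeoNames candidate embeddings and the single query embedding, pass it to sort_ref_closest_match (imported from utils.find_closest) together with a parallel list of candidate attributes, and then unwrap the result with [a[0] for a in ...]. The following is a minimal illustrative sketch of that assumed behavior, not the repository's implementation; the _sketch suffix marks it as hypothetical.

import numpy as np

def sort_ref_closest_match_sketch(sim_matrix, ref_list):
    # sim_matrix: (n, 1) similarities between n reference candidates and one query.
    # ref_list:   n candidate attributes (e.g. GeoNames ids or {'lat', 'lon'} dicts).
    # Return the candidates ordered by descending similarity, keeping a one-element
    # nesting per row so that [a[0] for a in result] yields the plain ranked list.
    order = np.argsort(sim_matrix[:, 0])[::-1]
    return [[ref_list[i]] for i in order]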