├── .gitignore ├── Knowledge_Extraction.py ├── LICENSE ├── Pattern_Extraction.py ├── README.md ├── Train_and_Predict.py ├── util.py └── words_alpha.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /Knowledge_Extraction.py: -------------------------------------------------------------------------------- 1 | from util import * 2 | 3 | KG_path = 'KG_v0.1.0.db' 4 | 5 | st = time.time() 6 | kg_conn = KG_Connection(db_path=KG_path, mode='memory') 7 | print('Finished in {:.2f} seconds'.format(time.time() - st)) 8 | 9 | print('We are collecting eventualities from ASER...') 10 | selected_eventuality_kg = list() 11 | eventuality_id_to_graph = dict() 12 | 13 | for tmp_key, tmp_eventuality in tqdm(kg_conn.event_cache.items()): 14 | tmp = dict() 15 | tmp['type'] = 'eventuality' 16 | tmp['id'] = tmp_eventuality['_id'] 17 | tmp['verbs'] = tmp_eventuality['verbs'] 18 | tmp['words'] = tmp_eventuality['words'] 19 | tmp['frequency'] = tmp_eventuality['frequency'] 20 | selected_eventuality_kg.append(tmp) 21 | eventuality_id_to_graph[tmp['id']] = eventuality_to_graph(tmp) 22 | 23 | print('We are collecting edges from ASER...') 24 | 25 | selected_edge_kg = list() 26 | edge_id_to_graph = dict() 27 | for tmp_key, tmp_edge in tqdm(kg_conn.relation_cache.items()): 28 | tmp = dict() 29 | event_1 = kg_conn.event_cache[tmp_edge['event1_id']] 30 | event_2 = kg_conn.event_cache[tmp_edge['event2_id']] 31 | tmp['id'] = tmp_edge['_id'] 32 | tmp['type'] = 'edge' 33 | tmp['event_1_verbs'] = event_1['verbs'] 34 | tmp['event_1_words'] = event_1['words'] 35 | tmp['event_2_verbs'] = event_2['verbs'] 36 | tmp['event_2_words'] = event_2['words'] 37 | tmp['frequency'] = tmp_edge['Co_Occurrence'] 38 | tmp['connective'] = 'Co_Occurrence' 39 | for tmp_connective in Connectives: 40 | if tmp_edge[tmp_connective] > 0: 41 | tmp['frequency'] = tmp_edge[tmp_connective] 42 | tmp['connective'] = tmp_connective 43 | selected_edge_kg.append(tmp) 44 | edge_id_to_graph[tmp['id']] = edge_to_graph(tmp) 45 | 46 | number_of_worker = 35 47 | print('Start to collect knowledge from eventualities') 48 | chunked_ASER = chunks(selected_eventuality_kg, number_of_worker) 49 | workers = Pool(number_of_worker) 50 | all_extracted_knowledge = list() 51 | all_eventuality_match = list() 52 | all_results = list() 53 | for aser_subset in chunked_ASER: 54 | tmp_result = workers.apply_async(extract_knowledge_from_eventuality_set, args=(selected_patterns, aser_subset,)) 55 | all_results.append(tmp_result) 56 | workers.close() 57 | workers.join() 58 | all_results = [tmp_result.get() for tmp_result in all_results] 59 | for tmp_result in all_results: 60 | all_extracted_knowledge.append(tmp_result[0]) 61 | all_eventuality_match.append(tmp_result[1]) 62 | 63 | print('Start to merge eventuality knowledge') 64 | extracted_eventuality_knowledge = merge_extracted_knowledge_from_multi_core(all_extracted_knowledge) 65 | 66 | print('Start to merge eventuality matches') 67 | merged_eventuality_match = dict() 68 | for tmp_match_list in all_eventuality_match: 69 | for tmp_k in tmp_match_list: 70 | if tmp_k not in 
merged_eventuality_match: 71 | merged_eventuality_match[tmp_k] = list() 72 | merged_eventuality_match[tmp_k] += tmp_match_list[tmp_k] 73 | 74 | chunked_ASER = chunks(selected_edge_kg, number_of_worker) 75 | workers = Pool(number_of_worker) 76 | all_extracted_knowledge = list() 77 | all_edge_match = list() 78 | all_results = list() 79 | for aser_subset in chunked_ASER: 80 | tmp_result = workers.apply_async(extract_knowledge_from_edge_set, args=(selected_patterns, aser_subset,)) 81 | all_results.append(tmp_result) 82 | workers.close() 83 | workers.join() 84 | all_results = [tmp_result.get() for tmp_result in all_results] 85 | for tmp_result in all_results: 86 | all_extracted_knowledge.append(tmp_result[0]) 87 | all_edge_match.append(tmp_result[1]) 88 | 89 | print('Start to merge edge knowledge') 90 | extracted_edge_knowledge = merge_extracted_knowledge_from_multi_core(all_extracted_knowledge) 91 | 92 | print('Start to merge edge matches') 93 | merged_edge_match = dict() 94 | for tmp_match_list in all_edge_match: 95 | for tmp_k in tmp_match_list: 96 | if tmp_k not in merged_edge_match: 97 | merged_edge_match[tmp_k] = list() 98 | merged_edge_match[tmp_k] += tmp_match_list[tmp_k] 99 | 100 | print('We are loading all words...') 101 | all_words = list() 102 | with open('words_alpha.txt', 'r') as f: 103 | for line in f: 104 | all_words.append(line[:-1]) 105 | 106 | all_words = set(all_words) 107 | 108 | print('start to merge knowledge...') 109 | extracted_knowledge = dict() 110 | for r in extracted_eventuality_knowledge: 111 | extracted_knowledge[r] = dict() 112 | for p in extracted_eventuality_knowledge[r]: 113 | for tmp_triplet in extracted_eventuality_knowledge[r][p]: 114 | if tmp_triplet in extracted_knowledge[r]: 115 | extracted_knowledge[r][tmp_triplet] += extracted_eventuality_knowledge[r][p][tmp_triplet] 116 | else: 117 | extracted_knowledge[r][tmp_triplet] = extracted_eventuality_knowledge[r][p][tmp_triplet] 118 | for p in extracted_edge_knowledge[r]: 119 | for tmp_triplet in extracted_edge_knowledge[r][p]: 120 | if tmp_triplet in extracted_knowledge[r]: 121 | extracted_knowledge[r][tmp_triplet] += extracted_edge_knowledge[r][p][tmp_triplet] 122 | else: 123 | extracted_knowledge[r][tmp_triplet] = extracted_edge_knowledge[r][p][tmp_triplet] 124 | 125 | filtered_knowledge = dict() 126 | # Check if all the extracted words are English words to filter out other languages. 
127 | for r in extracted_knowledge: 128 | print('We are filtering knowledge for relation:', r) 129 | filtered_knowledge[r] = dict() 130 | for tmp_k in tqdm(extracted_knowledge[r]): 131 | head_words = tmp_k.split('$$')[0].split(' ') 132 | tail_words = tmp_k.split('$$')[1].split(' ') 133 | found_invalid_words = False 134 | for w in head_words: 135 | if w not in all_words: 136 | found_invalid_words = True 137 | for w in tail_words: 138 | if w not in all_words: 139 | found_invalid_words = True 140 | if found_invalid_words: 141 | continue 142 | filtered_knowledge[r][tmp_k] = extracted_knowledge[r][tmp_k] 143 | 144 | 145 | # Store extracted knowledge based on relations for next step plausibility prediction 146 | if not os.path.isdir('extracted_knowledge'): 147 | os.mkdir('extracted_knowledge') 148 | 149 | # full_dataset = dict() 150 | missing_count = 0 151 | matched_count = 0 152 | for r in filtered_knowledge: 153 | print('We are working on relation:', r) 154 | tmp_dataset = list() 155 | for tmp_k in tqdm(filtered_knowledge[r]): 156 | tmp_example = dict() 157 | if tmp_k in merged_eventuality_match: 158 | eventuality_observations = list() 159 | for tmp_eventuality in merged_eventuality_match[tmp_k]: 160 | tmp_graph = eventuality_id_to_graph[tmp_eventuality['id']] 161 | tmp_eventuality['graph'] = tmp_graph 162 | eventuality_observations.append(tmp_eventuality) 163 | else: 164 | eventuality_observations = list() 165 | if tmp_k in merged_edge_match: 166 | edge_observations = list() 167 | for tmp_edge in merged_edge_match[tmp_k]: 168 | tmp_graph = edge_id_to_graph[tmp_edge['id']] 169 | tmp_edge['graph'] = tmp_graph 170 | edge_observations.append(tmp_edge) 171 | else: 172 | edge_observations = list() 173 | tmp_example['knowledge'] = tmp_k 174 | tmp_example['eventuality_observations'] = eventuality_observations 175 | tmp_example['edge_observations'] = edge_observations 176 | tmp_example['plausibility'] = 0 177 | if len(eventuality_observations) == 0 and len(edge_observations) == 0: 178 | missing_count += 1 179 | continue 180 | else: 181 | matched_count += 1 182 | tmp_dataset.append(tmp_example) 183 | with open('extracted_knowledge/'+r+'.json', 'w') as f: 184 | json.dump(tmp_dataset, f) 185 | 186 | print('end') 187 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 HKUST-KnowComp 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Pattern_Extraction.py: -------------------------------------------------------------------------------- 1 | import ujson as json 2 | import spacy 3 | from tqdm import tqdm 4 | import math 5 | 6 | 7 | def path_to_pattern(head_edges, between_edges, tail_edges): 8 | # we only focus on the between_edges first 9 | tmp_edges = list() 10 | pattern = '' 11 | current_word = 'HEAD' 12 | next_word = '' 13 | seen_positions = list() 14 | 15 | # we are working on the between_edges first 16 | while next_word != 'TAIL': 17 | new_word = '' 18 | for edge in between_edges: 19 | if edge[0][1] in seen_positions or edge[2][1] in seen_positions: 20 | continue 21 | if edge[0][0] == current_word: 22 | if edge[2][0] == 'TAIL': 23 | pattern += '->' 24 | pattern += edge[1] 25 | pattern += '->' 26 | next_word = 'TAIL' 27 | else: 28 | pattern += '->' 29 | pattern += edge[1] 30 | pattern += '->' 31 | pattern += edge[2][0] 32 | # pattern += '->' 33 | new_word = edge[2][0] 34 | seen_positions.append(edge[0][1]) 35 | break 36 | elif edge[2][0] == current_word: 37 | if edge[0][0] == 'TAIL': 38 | pattern += '<-' 39 | pattern += edge[1] 40 | pattern += '<-' 41 | next_word = 'TAIL' 42 | else: 43 | pattern += '<-' 44 | pattern += edge[1] 45 | pattern += '<-' 46 | pattern += edge[0][0] 47 | # pattern += '<-' 48 | new_word = edge[0][0] 49 | seen_positions.append(edge[2][1]) 50 | break 51 | current_word = new_word 52 | 53 | # we are working on the head_edges 54 | 55 | if len(head_edges) == 0: 56 | head_pattern = '()' 57 | else: 58 | head_pattern = '(-' 59 | for edge in head_edges: 60 | head_pattern += edge[1] 61 | head_pattern += '-' 62 | head_pattern += ')' 63 | 64 | # we are working on the tail edges 65 | if len(tail_edges) == 0: 66 | tail_pattern = '()' 67 | else: 68 | tail_pattern = '(-' 69 | for edge in tail_edges: 70 | tail_pattern += edge[1] 71 | tail_pattern += '-' 72 | tail_pattern += ')' 73 | 74 | overall_pattern = head_pattern + pattern + tail_pattern 75 | 76 | return overall_pattern 77 | 78 | 79 | def find_shortest_path(all_edges, start, end, used_edges): 80 | potential_pathes = list() 81 | for edge in all_edges: 82 | if edge in used_edges: 83 | continue 84 | if edge[0][1] == start: 85 | if edge[2][1] == end: 86 | return [edge], 1 87 | else: 88 | potential_pathes.append({'edges': [edge], 'new_start': edge[2][1]}) 89 | continue 90 | if edge[2][1] == start: 91 | if edge[0][1] == end: 92 | return [edge], 1 93 | else: 94 | potential_pathes.append({'edges': [edge], 'new_start': edge[0][1]}) 95 | if len(potential_pathes) == 0: 96 | return [], 0 97 | shortest_path = list() 98 | shortest_length = 100 99 | for potential_path in potential_pathes: 100 | all_used_edges = used_edges + potential_path['edges'] 101 | tmp_new_edges, tmp_new_length = find_shortest_path(all_edges, potential_path['new_start'], end, all_used_edges) 102 | if len(tmp_new_edges) > 0 and tmp_new_length < shortest_length: 103 | shortest_length = tmp_new_length 104 | shortest_path = tmp_new_edges + potential_path['edges'] 105 | return shortest_path, shortest_length + 1 106 | 107 | 108 | def extract_pattern(OMCS_pair, tmp_graph): 109 | head_words = OMCS_pair.split('$$')[0].split() 110 | tail_words = 
OMCS_pair.split('$$')[1].split() 111 | eventuality_words = tmp_graph['words'].split() 112 | # check repeat words 113 | for w in head_words: 114 | if w in tail_words: 115 | return None 116 | 117 | # locate position 118 | head_positions = list() 119 | tail_positions = list() 120 | for w in head_words: 121 | found_location = False 122 | for i, tmp_word in enumerate(eventuality_words): 123 | if found_location: 124 | if w == tmp_word: 125 | return None 126 | else: 127 | if w == tmp_word: 128 | head_positions.append(i) 129 | found_location = True 130 | for w in tail_words: 131 | found_location = False 132 | for i, tmp_word in enumerate(eventuality_words): 133 | if found_location: 134 | if w == tmp_word: 135 | return None 136 | else: 137 | if w == tmp_word: 138 | tail_positions.append(i) 139 | found_location = True 140 | 141 | doc = nlp(tmp_graph['words']) 142 | all_dependency_edges = list() 143 | for word in doc: 144 | all_dependency_edges.append(((word.head.norm_, word.head.i), word.dep_, (word.norm_, word.i))) 145 | 146 | head_dependency_edges = list() 147 | tail_dependency_edges = list() 148 | 149 | # find head internal edges: 150 | if len(head_positions) > 1: 151 | for position_1 in head_positions: 152 | for position_2 in head_positions: 153 | if position_1 < position_2: 154 | paths, length = find_shortest_path(all_dependency_edges, position_1, position_2, list()) 155 | head_dependency_edges += paths 156 | head_dependency_edges = list(set(head_dependency_edges)) 157 | 158 | # find tail internal edges 159 | if len(tail_positions) > 1: 160 | for position_1 in tail_positions: 161 | for position_2 in tail_positions: 162 | if position_1 < position_2: 163 | paths, length = find_shortest_path(all_dependency_edges, position_1, position_2, list()) 164 | tail_dependency_edges += paths 165 | tail_dependency_edges = list(set(tail_dependency_edges)) 166 | 167 | head_contained_positions = list() 168 | tail_contained_positions = list() 169 | if len(head_dependency_edges) == 0: 170 | head_contained_positions.append(head_positions[0]) 171 | else: 172 | for d_edge in head_dependency_edges: 173 | head_contained_positions.append(d_edge[0][1]) 174 | head_contained_positions.append(d_edge[2][1]) 175 | if len(tail_dependency_edges) == 0: 176 | tail_contained_positions.append(tail_positions[0]) 177 | else: 178 | for d_edge in tail_dependency_edges: 179 | tail_contained_positions.append(d_edge[0][1]) 180 | tail_contained_positions.append(d_edge[2][1]) 181 | 182 | # We need to check if there is overlap 183 | for position in head_contained_positions: 184 | if position in tail_contained_positions: 185 | return None 186 | 187 | new_edges = list() 188 | for d_edge in all_dependency_edges: 189 | if d_edge[0][1] in head_contained_positions: 190 | if d_edge[2][1] in head_contained_positions: 191 | continue 192 | elif d_edge[2][1] in tail_contained_positions: 193 | new_edges.append((('HEAD', 'HEAD'), d_edge[1], ('TAIL', 'TAIL'))) 194 | else: 195 | new_edges.append((('HEAD', 'HEAD'), d_edge[1], d_edge[2])) 196 | elif d_edge[0][1] in tail_contained_positions: 197 | if d_edge[2][1] in head_contained_positions: 198 | new_edges.append((('TAIL', 'TAIL'), d_edge[1], ('HEAD', 'HEAD'))) 199 | elif d_edge[2][1] in tail_contained_positions: 200 | continue 201 | else: 202 | new_edges.append((('TAIL', 'TAIL'), d_edge[1], d_edge[2])) 203 | else: 204 | if d_edge[2][1] in head_contained_positions: 205 | new_edges.append((d_edge[0], d_edge[1], ('HEAD', 'HEAD'))) 206 | elif d_edge[2][1] in tail_contained_positions: 207 | 
new_edges.append((d_edge[0], d_edge[1], ('TAIL', 'TAIL'))) 208 | else: 209 | new_edges.append((d_edge[0], d_edge[1], d_edge[2])) 210 | between_edges, _ = find_shortest_path(new_edges, 'HEAD', 'TAIL', list()) 211 | 212 | # find shortest path between head and tail 213 | if len(between_edges) == 0: 214 | return None 215 | 216 | pattern = path_to_pattern(head_dependency_edges, between_edges, tail_dependency_edges) 217 | 218 | return pattern 219 | 220 | 221 | def extract_pattern_from_edge(OMCS_pair, tmp_edge): 222 | head_words = OMCS_pair.split('$$')[0].split() 223 | tail_words = OMCS_pair.split('$$')[1].split() 224 | 225 | eventuality1_words = tmp_edge['event_1_words'].split() 226 | eventuality2_words = tmp_edge['event_2_words'].split() 227 | 228 | # check repeat words 229 | for w in head_words: 230 | if w in tail_words: 231 | return None 232 | 233 | head_in_event1 = True 234 | head_in_event2 = True 235 | tail_in_event1 = True 236 | tail_in_event2 = True 237 | for w in head_words: 238 | if w not in eventuality1_words: 239 | head_in_event1 = False 240 | if w not in eventuality2_words: 241 | head_in_event2 = False 242 | for w in tail_words: 243 | if w not in eventuality1_words: 244 | tail_in_event1 = False 245 | if w not in eventuality2_words: 246 | tail_in_event2 = False 247 | 248 | if (head_in_event1 and tail_in_event2 and not head_in_event2 and not tail_in_event1) or ( 249 | head_in_event2 and tail_in_event1 and not head_in_event1 and not tail_in_event2): 250 | pass 251 | else: 252 | return None 253 | 254 | all_words = list() 255 | parsed_eventuality1_words = list() 256 | doc = nlp(tmp_edge['event_1_words']) 257 | event1_dependency_edges = list() 258 | event1_verb = [] 259 | for word in doc: 260 | event1_dependency_edges.append(((word.head.norm_, word.head.i), word.dep_, (word.norm_, word.i))) 261 | all_words.append(word.text) 262 | parsed_eventuality1_words.append(word.text) 263 | if word.dep_ == 'ROOT': 264 | event1_verb = (word.norm_, word.i) 265 | 266 | doc = nlp(tmp_edge['event_2_words']) 267 | event2_dependency_edges = list() 268 | event2_verb = [] 269 | for word in doc: 270 | event2_dependency_edges.append(((word.head.norm_, word.head.i + len(parsed_eventuality1_words)), word.dep_, 271 | (word.norm_, word.i + len(parsed_eventuality1_words)))) 272 | all_words.append(word.text) 273 | if word.dep_ == 'ROOT': 274 | event2_verb = (word.norm_, word.i + len(parsed_eventuality1_words)) 275 | 276 | head_dependency_edges = list() 277 | tail_dependency_edges = list() 278 | all_dependency_edges = event1_dependency_edges + event2_dependency_edges 279 | 280 | all_dependency_edges.append((event1_verb, tmp_edge['connective'], event2_verb)) 281 | 282 | # locate position 283 | head_positions = list() 284 | tail_positions = list() 285 | for w in head_words: 286 | found_location = False 287 | for i, tmp_word in enumerate(all_words): 288 | if found_location: 289 | if w == tmp_word: 290 | return None 291 | else: 292 | if w == tmp_word: 293 | head_positions.append(i) 294 | found_location = True 295 | for w in tail_words: 296 | found_location = False 297 | for i, tmp_word in enumerate(all_words): 298 | if found_location: 299 | if w == tmp_word: 300 | return None 301 | else: 302 | if w == tmp_word: 303 | tail_positions.append(i) 304 | found_location = True 305 | 306 | if head_in_event1: 307 | # find head internal edges: 308 | if len(head_positions) > 1: 309 | for position_1 in head_positions: 310 | for position_2 in head_positions: 311 | if position_1 < position_2: 312 | paths, length = 
find_shortest_path(event1_dependency_edges, position_1, position_2, list()) 313 | head_dependency_edges += paths 314 | head_dependency_edges = list(set(head_dependency_edges)) 315 | 316 | # find tail internal edges 317 | if len(tail_positions) > 1: 318 | for position_1 in tail_positions: 319 | for position_2 in tail_positions: 320 | if position_1 < position_2: 321 | paths, length = find_shortest_path(event2_dependency_edges, position_1, position_2, list()) 322 | tail_dependency_edges += paths 323 | tail_dependency_edges = list(set(tail_dependency_edges)) 324 | else: 325 | # find head internal edges: 326 | if len(head_positions) > 1: 327 | for position_1 in head_positions: 328 | for position_2 in head_positions: 329 | if position_1 < position_2: 330 | paths, length = find_shortest_path(event2_dependency_edges, position_1, position_2, list()) 331 | head_dependency_edges += paths 332 | head_dependency_edges = list(set(head_dependency_edges)) 333 | 334 | # find tail internal edges 335 | if len(tail_positions) > 1: 336 | for position_1 in tail_positions: 337 | for position_2 in tail_positions: 338 | if position_1 < position_2: 339 | paths, length = find_shortest_path(event1_dependency_edges, position_1, position_2, list()) 340 | tail_dependency_edges += paths 341 | tail_dependency_edges = list(set(tail_dependency_edges)) 342 | 343 | head_contained_positions = list() 344 | tail_contained_positions = list() 345 | if len(head_dependency_edges) == 0: 346 | head_contained_positions.append(head_positions[0]) 347 | else: 348 | for d_edge in head_dependency_edges: 349 | head_contained_positions.append(d_edge[0][1]) 350 | head_contained_positions.append(d_edge[2][1]) 351 | if len(tail_dependency_edges) == 0: 352 | tail_contained_positions.append(tail_positions[0]) 353 | else: 354 | for d_edge in tail_dependency_edges: 355 | tail_contained_positions.append(d_edge[0][1]) 356 | tail_contained_positions.append(d_edge[2][1]) 357 | 358 | # We need to check if there is overlap 359 | for position in head_contained_positions: 360 | if position in tail_contained_positions: 361 | return None 362 | 363 | new_edges = list() 364 | for d_edge in all_dependency_edges: 365 | if d_edge[0][1] in head_contained_positions: 366 | if d_edge[2][1] in head_contained_positions: 367 | continue 368 | elif d_edge[2][1] in tail_contained_positions: 369 | new_edges.append((('HEAD', 'HEAD'), d_edge[1], ('TAIL', 'TAIL'))) 370 | else: 371 | new_edges.append((('HEAD', 'HEAD'), d_edge[1], d_edge[2])) 372 | elif d_edge[0][1] in tail_contained_positions: 373 | if d_edge[2][1] in head_contained_positions: 374 | new_edges.append((('TAIL', 'TAIL'), d_edge[1], ('HEAD', 'HEAD'))) 375 | elif d_edge[2][1] in tail_contained_positions: 376 | continue 377 | else: 378 | new_edges.append((('TAIL', 'TAIL'), d_edge[1], d_edge[2])) 379 | else: 380 | if d_edge[2][1] in head_contained_positions: 381 | new_edges.append((d_edge[0], d_edge[1], ('HEAD', 'HEAD'))) 382 | elif d_edge[2][1] in tail_contained_positions: 383 | new_edges.append((d_edge[0], d_edge[1], ('TAIL', 'TAIL'))) 384 | else: 385 | new_edges.append((d_edge[0], d_edge[1], d_edge[2])) 386 | between_edges, _ = find_shortest_path(new_edges, 'HEAD', 'TAIL', list()) 387 | 388 | # find shortest path between head and tail 389 | if len(between_edges) == 0: 390 | return None 391 | 392 | pattern = path_to_pattern(head_dependency_edges, between_edges, tail_dependency_edges) 393 | 394 | return pattern 395 | 396 | 397 | def get_unique_score(tmp_p, tmp_r, unique_dict): 398 | tmp_score = None 399 | for 
relation_pair in unique_dict[tmp_p]: 400 | if relation_pair[0] == tmp_r: 401 | tmp_score = relation_pair[1] 402 | break 403 | return tmp_score 404 | 405 | 406 | def compute_length_score(tmp_pattern): 407 | head_pattern = tmp_pattern.split(')')[0][1:] 408 | internal_pattern = tmp_pattern.split(')')[1].split('(')[0] 409 | tail_pattern = tmp_pattern.split('(')[2][:-1] 410 | 411 | head_count = 0 412 | for w in head_pattern.split('-'): 413 | if w not in ['', '<', '>']: 414 | head_count += 1 415 | internal_count = 0 416 | for w in internal_pattern.split('-'): 417 | if w not in ['', '<', '>']: 418 | internal_count += 1 419 | tail_count = 0 420 | for w in tail_pattern.split('-'): 421 | if w not in ['', '<', '>']: 422 | tail_count += 1 423 | 424 | tmp_score = min(3, head_count + internal_count + tail_count) 425 | 426 | return tmp_score 427 | 428 | 429 | def find_discourse_relation(tmp_pattern): 430 | tmp_status = False 431 | for discourse_r in discourse_relations: 432 | if discourse_r in tmp_pattern: 433 | tmp_status = True 434 | break 435 | return tmp_status 436 | 437 | 438 | def check_pattern_stop_relations(stop_relations, pattern): 439 | for tmp_r in stop_relations: 440 | if tmp_r in pattern: 441 | return True 442 | return False 443 | 444 | 445 | selected_relations = ['AtLocation', 'CapableOf', 'Causes', 'CausesDesire', 'CreatedBy', 'DefinedAs', 'Desires', 'HasA', 446 | 'HasPrerequisite', 'HasProperty', 'HasSubevent', 'HasFirstSubevent', 'HasLastSubevent', 447 | 'InstanceOf', 'LocatedNear', 'MadeOf', 'MotivatedByGoal', 'PartOf', 'ReceivesAction', 'UsedFor'] 448 | 449 | discourse_relations = ['Precedence', 'Succession', 'Synchronous', 'Reason', 'Result', 'Condition', 'Contrast', 450 | 'Concession', 'Conjunction', 'Instantiation', 'Restatement', 'Alternative', 'ChosenAlternative', 451 | 'Exception'] 452 | 453 | with open('node_matches.json', 'r') as f: 454 | sample_data = json.load(f) 455 | nlp = spacy.load('en') 456 | 457 | raw_eventuality_patterns = dict() 458 | 459 | for tmp_r in sample_data: 460 | print('We are working on:', tmp_r) 461 | test_data = sample_data[tmp_r] 462 | pattern_counting = dict() 463 | for OMCS_pair in tqdm(test_data): 464 | if len(test_data[OMCS_pair]) == 0: 465 | continue 466 | for tmp_eventuality in test_data[OMCS_pair][:50]: 467 | pattern = extract_pattern(OMCS_pair, tmp_eventuality) 468 | if not pattern: 469 | continue 470 | if pattern not in pattern_counting: 471 | pattern_counting[pattern] = 0 472 | pattern_counting[pattern] += 1 473 | 474 | sorted_patterns = sorted(pattern_counting.items(), key=lambda x: x[1], reverse=True) 475 | selected_patterns = sorted_patterns 476 | raw_eventuality_patterns[tmp_r] = selected_patterns 477 | 478 | 479 | with open('edge_matches.json', 'r') as f: 480 | sample_edge_data = json.load(f) 481 | nlp = spacy.load('en') 482 | 483 | raw_edge_patterns = dict() 484 | 485 | for tmp_r in sample_edge_data: 486 | print('We are working on:', tmp_r) 487 | test_edge_data = sample_edge_data[tmp_r] 488 | pattern_counting = dict() 489 | for OMCS_pair in tqdm(test_edge_data): 490 | if len(test_edge_data[OMCS_pair]) == 0: 491 | continue 492 | selected_match_eventualities = test_edge_data[OMCS_pair][:50] 493 | for tmp_edge in selected_match_eventualities: 494 | pattern = extract_pattern_from_edge(OMCS_pair, tmp_edge) 495 | if not pattern: 496 | continue 497 | if pattern not in pattern_counting: 498 | pattern_counting[pattern] = 0 499 | pattern_counting[pattern] += 1 500 | 501 | sorted_patterns = sorted(pattern_counting.items(), key=lambda x: x[1], 
reverse=True) 502 | selected_patterns = sorted_patterns 503 | raw_edge_patterns[tmp_r] = selected_patterns 504 | 505 | new_eventuality_patterns = dict() 506 | seen_eventuality_patterns = dict() 507 | for r in raw_eventuality_patterns: 508 | if r not in selected_relations: 509 | continue 510 | new_eventuality_patterns[r] = list() 511 | seen_eventuality_patterns[r] = list() 512 | for pattern in raw_eventuality_patterns[r]: 513 | no_direction_pattern = pattern[0].replace('>', '').replace('<', '') 514 | if no_direction_pattern in seen_eventuality_patterns[r]: 515 | continue 516 | seen_eventuality_patterns[r].append(no_direction_pattern) 517 | new_eventuality_patterns[r].append(pattern) 518 | 519 | new_edge_patterns = dict() 520 | seen_edge_patterns = dict() 521 | for r in raw_edge_patterns: 522 | if r not in selected_relations: 523 | continue 524 | new_edge_patterns[r] = list() 525 | seen_edge_patterns[r] = list() 526 | for pattern in raw_edge_patterns[r]: 527 | no_direction_pattern = pattern[0].replace('>', '').replace('<', '') 528 | if no_direction_pattern in seen_edge_patterns[r]: 529 | continue 530 | seen_edge_patterns[r].append(no_direction_pattern) 531 | new_edge_patterns[r].append(pattern) 532 | 533 | eventuality_patterns = new_eventuality_patterns 534 | edge_patterns = new_edge_patterns 535 | 536 | 537 | with open('lemmatized_commonsense_knowledge.json', 'r') as f: 538 | lemmatized_commonsense_knowledge = json.load(f) 539 | 540 | 541 | all_eventualities_patterns_count = dict() 542 | for r in eventuality_patterns: 543 | for p in eventuality_patterns[r]: 544 | if p[0] not in all_eventualities_patterns_count: 545 | all_eventualities_patterns_count[p[0]] = dict() 546 | all_eventualities_patterns_count[p[0]][r] = p[1] / math.sqrt(len(lemmatized_commonsense_knowledge[r])) 547 | 548 | # prepare u_score 549 | new_eventuality_pattern_count = dict() 550 | for p in all_eventualities_patterns_count: 551 | sum_count = 0 552 | for r in all_eventualities_patterns_count[p]: 553 | sum_count += all_eventualities_patterns_count[p][r] 554 | new_tmp_count = list() 555 | for r in all_eventualities_patterns_count[p]: 556 | new_tmp_count.append((r, all_eventualities_patterns_count[p][r] / sum_count)) 557 | sorted_tmp_count = sorted(new_tmp_count, key=lambda x: x[1], reverse=True) 558 | new_eventuality_pattern_count[p] = sorted_tmp_count 559 | 560 | # p[1] is the counting (c_score) 561 | eventuality_patterns_by_score = dict() 562 | for r in eventuality_patterns: 563 | tmp_patterns = list() 564 | for p in eventuality_patterns[r]: 565 | u_score = get_unique_score(p[0], r, new_eventuality_pattern_count) 566 | l_score = compute_length_score(p[0]) 567 | if u_score: 568 | tmp_patterns.append((p[0], p[1] * l_score * u_score)) 569 | eventuality_patterns_by_score[r] = sorted(tmp_patterns, key=lambda x: x[1], reverse=True) 570 | 571 | all_edge_patterns_count = dict() 572 | for r in edge_patterns: 573 | for p in edge_patterns[r]: 574 | if p[0] not in all_edge_patterns_count: 575 | all_edge_patterns_count[p[0]] = dict() 576 | all_edge_patterns_count[p[0]][r] = p[1] / math.sqrt(len(lemmatized_commonsense_knowledge[r])) 577 | 578 | # prepare u_score 579 | new_edge_pattern_count = dict() 580 | for p in all_edge_patterns_count: 581 | sum_count = 0 582 | for r in all_edge_patterns_count[p]: 583 | sum_count += all_edge_patterns_count[p][r] 584 | new_tmp_count = list() 585 | for r in all_edge_patterns_count[p]: 586 | new_tmp_count.append((r, all_edge_patterns_count[p][r] / sum_count)) 587 | sorted_tmp_count = 
sorted(new_tmp_count, key=lambda x: x[1], reverse=True) 588 | new_edge_pattern_count[p] = sorted_tmp_count 589 | 590 | # p[1] is the counting (c_score) 591 | edge_patterns_by_score = dict() 592 | for r in edge_patterns: 593 | tmp_patterns = list() 594 | for p in edge_patterns[r]: 595 | u_score = get_unique_score(p[0], r, new_edge_pattern_count) 596 | l_score = compute_length_score(p[0]) 597 | if u_score: 598 | tmp_patterns.append((p[0], p[1] * l_score * u_score)) 599 | edge_patterns_by_score[r] = sorted(tmp_patterns, key=lambda x: x[1], reverse=True) 600 | 601 | # Merge extracted patterns from eventuality and edge 602 | overall_pattern_by_score = dict() 603 | for r in edge_patterns_by_score: 604 | tmp_patterns = list() 605 | overall_score = 0 606 | for pattern in eventuality_patterns_by_score[r]: 607 | tmp_patterns.append(pattern) 608 | overall_score += pattern[1] 609 | for pattern in edge_patterns_by_score[r]: 610 | tmp_patterns.append((pattern[0], pattern[1])) 611 | overall_score += pattern[1] 612 | tmp_patterns = sorted(tmp_patterns, key=lambda x: x[1], reverse=True) 613 | overall_pattern_by_score[r] = list() 614 | for pattern in tmp_patterns: 615 | overall_pattern_by_score[r].append((pattern[0], pattern[1] / overall_score)) 616 | 617 | # setup the linguistic relation we do not want in our pattern, which is like the stop words filtering. 618 | pattern_stop_relations = ['det'] 619 | threshold = 0.05 620 | selected_patterns = dict() 621 | for r in overall_pattern_by_score: 622 | tmp_selected_pattern = list() 623 | for pattern in overall_pattern_by_score[r]: 624 | if pattern[1] > threshold and not check_pattern_stop_relations(pattern_stop_relations, pattern[0]): 625 | tmp_selected_pattern.append(pattern) 626 | selected_patterns[r] = tmp_selected_pattern 627 | 628 | with open('selected_patterns.json', 'w') as f: 629 | json.dump(selected_patterns, f) 630 | 631 | print('end') 632 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransOMCS 2 | 3 | 4 | This is the github repo for IJCAI 2020 paper ["TransOMCS: From Linguistic Graphs to Commonsense Knowledge"](https://arxiv.org/abs/2005.00206). 5 | 6 | ## Dependency 7 | 8 | Python 3.6, Pytorch 1.0 9 | 10 | 11 | ## Introduction of TransOMCS 12 | 13 | If you only want to use TransOMCS, you can **download** it from [TransOMCS](https://hkustconnect-my.sharepoint.com/:u:/g/personal/hzhangal_connect_ust_hk/EVeNd_qvealEiTi7gs0Xu6sBbPIZI5ncD7Z1MBMdOz5CXw?e=VWCAbg). 14 | 15 | Without any further filtering, TransOMCS contains 20 commonsense relations, 101 thousand unique words, and 18.48 million triplets. 16 | 17 | Here are the statistics and examples of different commonsense relations. 
18 | 
19 | | Relation Name | Number of Triplets | Reasonable Ratio | Example |
20 | | :---: | :---: | :---: | :---: |
21 | | CapableOf | 6,145,829 | 58.4% | (government, CapableOf, protect) |
22 | | UsedFor | 3,475,254 | 50.8% | (kitchen, UsedFor, eat in) |
23 | | HasProperty | 2,127,824 | 59.1% | (account, HasProperty, established) |
24 | | AtLocation | 1,969,298 | 51.3% | (dryer, AtLocation, dishwasher) |
25 | | HasA | 1,562,961 | 68.9% | (forest, HasA, pool) |
26 | | ReceivesAction | 1,492,915 | 53.7% | (news, ReceivesAction, misattribute) |
27 | | InstanceOf | 777,688 | 52.2% | (atlanta, InstanceOf, city) |
28 | | PartOf | 357,486 | 62.8% | (player, PartOf, team) |
29 | | CausesDesire | 249,755 | 52.0% | (music, CausesDesire, listen) |
30 | | MadeOf | 114,111 | 55.3% | (world, MadeOf, country) |
31 | | CreatedBy | 52,957 | 64.6% | (film, CreatedBy, director) |
32 | | Causes | 50,439 | 53.4% | (misinterpret, Causes, apologize) |
33 | | HasPrerequisite | 43,141 | 62.7% | (doubt, HasPrerequisite, bring proof) |
34 | | HasSubevent | 18,904 | 56.1% | (be sure, HasSubevent, ask) |
35 | | MotivatedByGoal | 15,322 | 55.8% | (come, MotivatedByGoal, fun) |
36 | | HasLastSubevent | 14,048 | 58.9% | (hungry, HasLastSubevent, eat) |
37 | | Desires | 10,668 | 56.4% | (dog, Desires, play) |
38 | | HasFirstSubevent | 2,962 | 58.4% | (talk to, HasFirstSubevent, call) |
39 | | DefinedAs | 36 | 37.5% | (door, DefinedAs, entrance) |
40 | | LocatedNear | 19 | 85.7% | (shoe, LocatedNear, foot) |
41 | 
42 | The reasonable ratio scores are annotated on a random sample over all of the extracted knowledge (no knowledge ranking applied).
43 | 
44 | In general, TransOMCS is still quite noisy because it is extracted from raw data with patterns.
45 | However, as shown in the paper, careful use of the data in downstream applications still helps.
46 | We will keep working on improving its quality.
47 | 
48 | ## Construction of TransOMCS
49 | 
50 | If you want to reproduce the process of creating TransOMCS from OMCS and ASER, please follow these steps.
51 | 
52 | 1. Download the core version of ASER from [ASER Homepage](https://hkust-knowcomp.github.io/ASER/) and install ASER 0.1 following [the guideline](https://github.com/HKUST-KnowComp/ASER/blob/master/ASER.ipynb).
53 | 2. Download the selected commonsense OMCS tuples and associated ASER graphs from [OMCS and ASER matches](https://hkustconnect-my.sharepoint.com/:u:/g/personal/hzhangal_connect_ust_hk/EfFZFamzsmdKozyrU0-TtXsBDbStkt_FmPyeFM2kT-K9FQ?e=noAb7u).
54 | 3. Download the randomly split knowledge ranking dataset from [Ranking Dataset](https://hkustconnect-my.sharepoint.com/:u:/g/personal/hzhangal_connect_ust_hk/Efc7NeRYSVpHqcGuflDU3uoBRPaks4Mz1kG_R9OUwviPLw?e=oJB3yA).
55 | 4. Unzip the downloaded matched OMCS tuples and ASER graphs into the same folder.
56 | 5. Extract patterns: `python Pattern_Extraction.py`.
57 | 6. Apply the extracted patterns to extract knowledge from ASER (you need to modify the location of your `.db` file): `python Knowledge_Extraction.py`.
58 | 7. Train a ranking model to rank the extracted knowledge: `python Train_and_Predict.py`.
59 | 
60 | 
61 | ## Application of TransOMCS
62 | 
63 | 
64 | #### Reading Comprehension
65 | Please use the code in [reading comprehension model](https://github.com/intfloat/commonsense-rc) and replace the external knowledge with the subset of TransOMCS that best fits your needs. 
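The ranked knowledge produced by `Train_and_Predict.py` is written to `prediction/TransOMCS.txt` with one triplet per line as tab-separated `head`, `relation`, `tail`, and plausibility score, so a relation- and score-based subset can be selected with a few lines of Python. Below is a minimal sketch; the helper name, the relation list, and the 0.9 threshold are illustrative assumptions, not part of this repo.

```python
def load_transomcs_subset(path='prediction/TransOMCS.txt',
                          relations=('UsedFor', 'AtLocation'),
                          min_score=0.9):
    """Select high-plausibility triplets of the given relations.

    Assumes the tab-separated format written by Train_and_Predict.py:
    head \t relation \t tail \t plausibility score.
    The relation list and score threshold are illustrative choices.
    """
    subset = list()
    with open(path, 'r') as f:
        for line in f:
            head, relation, tail, score = line.rstrip('\n').split('\t')
            # Keep only the requested relations above the score threshold.
            if relation in relations and float(score) >= min_score:
                subset.append((head, relation, tail, float(score)))
    return subset


print('Selected triplets:', len(load_transomcs_subset()))
```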
66 | 67 | #### Dialog Generation 68 | Please use the code in [dialog model](https://github.com/HKUST-KnowComp/ASER/tree/master/experiment/Dialogue) and replace the external knowledge with different subsets of TransOMCS based on your need. 69 | 70 | ## TODO 71 | 72 | 1. Filter the current TransOMCS to further improve the quality (e.g., merge pronouns like 'he' and 'she' to human). 73 | 74 | ## Citation 75 | 76 | @inproceedings{zhang2020TransOMCS, 77 | author = {Hongming Zhang and Daniel Khashabi and Yangqiu Song and Dan Roth}, 78 | title = {TransOMCS: From Linguistic Graphs to Commonsense Knowledge}, 79 | booktitle = {Proceedings of International Joint Conference on Artificial Intelligence (IJCAI) 2020}, 80 | year = {2020} 81 | } 82 | 83 | ## Others 84 | If you have any other questions about this repo, you are welcome to open an issue or send me an [email](mailto:hzhangal@cse.ust.hk), I will respond to that as soon as possible. -------------------------------------------------------------------------------- /Train_and_Predict.py: -------------------------------------------------------------------------------- 1 | from util import * 2 | import torch 3 | from pytorch_transformers import * 4 | import logging 5 | import argparse 6 | from torch.nn.utils.rnn import pad_sequence 7 | import torch.nn.functional as F 8 | import os 9 | import math 10 | import numpy 11 | 12 | Connective_dict = {'Precedence': 'before', 'Succession': 'after', 'Synchronous': 'simultaneously', 'Reason': 'because', 13 | 'Result': 'so', 'Condition': 'if', 'Contrast': 'but', 'Concession': 'although', 14 | 'Conjunction': 'and', 'Instantiation': 'for example', 'Restatement': 'in other words', 15 | 'Alternative': 'or', 'ChosenAlternative': 'instead', 'Exception': 'except'} 16 | 17 | all_relations = ['AtLocation', 'CapableOf', 'Causes', 'CausesDesire', 'CreatedBy', 'DefinedAs', 'Desires', 'HasA', 18 | 'HasPrerequisite', 'HasProperty', 'HasSubevent', 'HasFirstSubevent', 'HasLastSubevent', 'InstanceOf', 19 | 'MadeOf', 'MotivatedByGoal', 'PartOf', 'ReceivesAction', 'UsedFor'] 20 | 21 | 22 | def get_adj_matrix(tmp_tokenized_sentence, tmp_graph, max_length): 23 | raw_matrix = numpy.zeros((max_length, max_length)) 24 | for e in tmp_graph: 25 | head_word_index = tokenizer.encode(e[0][0])[0] 26 | tail_word_index = tokenizer.encode(e[2][0])[0] 27 | if head_word_index in tmp_tokenized_sentence and tail_word_index in tmp_tokenized_sentence: 28 | raw_matrix[tmp_tokenized_sentence.index(head_word_index)][ 29 | tmp_tokenized_sentence.index(tail_word_index)] = 1.0 30 | for tmp_p in range(len(tmp_tokenized_sentence)): 31 | raw_matrix[tmp_p][tmp_p] = 1.0 32 | return torch.tensor(raw_matrix) 33 | 34 | 35 | class TrainingExample: 36 | def __init__(self, raw_example): 37 | self.knowledge = raw_example['knowledge'] 38 | self.eventuality_observation = raw_example['eventuality_observations'] 39 | self.edge_observation = raw_example['edge_observations'] 40 | if raw_example['label'] == 'positive': 41 | self.label = torch.tensor([1]).to(device) 42 | else: 43 | self.label = torch.tensor([0]).to(device) 44 | self.all_observations, self.all_head_masks, self.all_tail_masks, self.all_frequencies, self.all_types, self.all_adj_matrices = self.tensorize_training_example() 45 | 46 | def tensorize_training_example(self): 47 | all_observations = list() 48 | all_head_masks = list() 49 | all_tail_masks = list() 50 | all_frequencies = list() # number of frequency and type of graph 51 | all_types = list() 52 | all_adj_matrices = list() 53 | head_words = 
tokenizer.encode(self.knowledge.split('$$')[0]) 54 | tail_words = tokenizer.encode(self.knowledge.split('$$')[1]) 55 | 56 | max_observation_length = 0 57 | for tmp_observation in self.eventuality_observation: 58 | tmp_sentence = tokenizer.encode('[CLS] ' + tmp_observation['words'] + ' . [SEP]') 59 | if len(tmp_sentence) > max_observation_length: 60 | max_observation_length = len(tmp_sentence) 61 | 62 | for tmp_observation in self.edge_observation: 63 | tmp_sentence = tokenizer.encode('[CLS] ' + tmp_observation['event_1_words'] + ' . [SEP] ' + Connective_dict[ 64 | tmp_observation['connective']] + ' ' + tmp_observation['event_2_words'] + ' . [SEP] ') 65 | if len(tmp_sentence) > max_observation_length: 66 | max_observation_length = len(tmp_sentence) 67 | 68 | for tmp_observation in self.eventuality_observation: 69 | tmp_sentence = tokenizer.encode('[CLS] ' + tmp_observation['words'] + ' . [SEP]') 70 | tmp_head_mask = list() 71 | for w in tmp_sentence: 72 | if w in head_words: 73 | tmp_head_mask.append(1.0) 74 | else: 75 | tmp_head_mask.append(0.0) 76 | tmp_tail_mask = list() 77 | for w in tmp_sentence: 78 | if w in tail_words: 79 | tmp_tail_mask.append(1.0) 80 | else: 81 | tmp_tail_mask.append(0.0) 82 | if tmp_observation['frequency'] > 64: 83 | tmp_frequency = [6] 84 | else: 85 | tmp_frequency = [int(math.log(tmp_observation['frequency'], 2))] 86 | all_observations.append(torch.tensor(tmp_sentence)) 87 | all_head_masks.append(torch.tensor(tmp_head_mask)) 88 | all_tail_masks.append(torch.tensor(tmp_tail_mask)) 89 | all_frequencies.append(torch.tensor(tmp_frequency)) 90 | all_types.append(torch.tensor([0])) 91 | all_adj_matrices.append(get_adj_matrix(tmp_sentence, tmp_observation['graph'], max_observation_length)) 92 | for tmp_observation in self.edge_observation: 93 | tmp_sentence = tokenizer.encode('[CLS] ' + tmp_observation['event_1_words'] + ' . [SEP] ' + Connective_dict[ 94 | tmp_observation['connective']] + ' ' + tmp_observation['event_2_words'] + ' . 
[SEP] ') 95 | tmp_head_mask = list() 96 | for w in tmp_sentence: 97 | if w in head_words: 98 | tmp_head_mask.append(1.0) 99 | else: 100 | tmp_head_mask.append(0.0) 101 | tmp_tail_mask = list() 102 | for w in tmp_sentence: 103 | if w in tail_words: 104 | tmp_tail_mask.append(1.0) 105 | else: 106 | tmp_tail_mask.append(0.0) 107 | if tmp_observation['frequency'] > 64: 108 | tmp_frequency = [6] 109 | else: 110 | tmp_frequency = [int(math.log(tmp_observation['frequency'], 2))] 111 | all_observations.append(torch.tensor(tmp_sentence)) 112 | all_head_masks.append(torch.tensor(tmp_head_mask)) 113 | all_tail_masks.append(torch.tensor(tmp_tail_mask)) 114 | all_frequencies.append(torch.tensor(tmp_frequency)) 115 | all_types.append(torch.tensor([1])) 116 | all_adj_matrices.append(get_adj_matrix(tmp_sentence, tmp_observation['graph'], max_observation_length)) 117 | 118 | tensorized_all_observations = pad_sequence(all_observations, batch_first=True).to(device) 119 | tensorized_head_masks = pad_sequence(all_head_masks, batch_first=True).to(device) 120 | tensorized_tail_masks = pad_sequence(all_tail_masks, batch_first=True).to(device) 121 | tensorized_frequencies = pad_sequence(all_frequencies, batch_first=True).to(device) 122 | tensorized_types = pad_sequence(all_types, batch_first=True).to(device) 123 | tensorized_adj_matrices = pad_sequence(all_adj_matrices, batch_first=True).to(device) 124 | 125 | return tensorized_all_observations, tensorized_head_masks, tensorized_tail_masks, tensorized_frequencies, tensorized_types, tensorized_adj_matrices 126 | 127 | 128 | class DataLoader: 129 | def __init__(self, data_path, relation_name): 130 | with open(data_path, 'r') as f: 131 | raw_dataset = json.load(f) 132 | self.train_set = raw_dataset[relation_name]['train'] 133 | self.test_set = raw_dataset[relation_name]['test'] 134 | random.shuffle(self.train_set) 135 | random.shuffle(self.test_set) 136 | print('Start to tensorize the train set.') 137 | self.tensorized_train = self.tensorize_dataset(self.train_set) 138 | print('Start to tensorize the test set.') 139 | self.tensorized_test = self.tensorize_dataset(self.test_set) 140 | 141 | def random_sample_train_set(self): 142 | print('We are randomly selecting the train example') 143 | positive_dataset = list() 144 | negative_dataset = list() 145 | for tmp_example in self.train_set: 146 | if tmp_example['label'] == 'positive': 147 | positive_dataset.append(tmp_example) 148 | else: 149 | negative_dataset.append(tmp_example) 150 | random.shuffle(positive_dataset) 151 | random.shuffle(negative_dataset) 152 | if len(positive_dataset) > len(negative_dataset): 153 | new_dataset = positive_dataset[:len(negative_dataset)] + negative_dataset 154 | else: 155 | new_dataset = positive_dataset + negative_dataset[:len(positive_dataset)] 156 | random.shuffle(new_dataset) 157 | self.tensorized_train = self.tensorize_dataset(new_dataset) 158 | 159 | def tensorize_dataset(self, input_dataset): 160 | tmp_tensorized_dataset = list() 161 | positive_count = 0 162 | negative_count = 0 163 | for tmp_example in tqdm(input_dataset): 164 | tmp_tensorized_dataset.append(TrainingExample(tmp_example)) 165 | if tmp_example['label'] == 'positive': 166 | positive_count += 1 167 | else: 168 | negative_count += 1 169 | print('Positive count:', positive_count, 'Negative count:', negative_count) 170 | return tmp_tensorized_dataset 171 | 172 | 173 | class DataLoaderPredict: 174 | def __init__(self, data_path): 175 | with open(data_path, 'r') as f: 176 | raw_dataset = json.load(f) 177 | self.test_set = 
list() 178 | for tmp_example in tqdm(raw_dataset): 179 | new_example = tmp_example 180 | new_example['label'] = 'na' 181 | new_eventuality_observations = list() 182 | new_edge_observations = list() 183 | 184 | for tmp_eventuality in tmp_example['eventuality_observations']: 185 | tmp_sentence = '[CLS] ' + tmp_eventuality['words'] + ' . [SEP]' 186 | if len(tmp_sentence.split(' ')) < 30 and len(tokenizer.encode(tmp_sentence)) < 64: 187 | new_eventuality_observations.append(tmp_eventuality) 188 | for tmp_edge in tmp_example['edge_observations']: 189 | tmp_sentence = '[CLS] ' + tmp_edge['event_1_words'] + ' . [SEP] ' + Connective_dict[ 190 | tmp_edge['connective']] + ' ' + tmp_edge['event_2_words'] + ' . [SEP] ' 191 | if len(tmp_sentence.split(' ')) < 30 and len(tokenizer.encode(tmp_sentence)) < 64: 192 | new_edge_observations.append(tmp_edge) 193 | new_example['eventuality_observations'] = new_eventuality_observations 194 | new_example['edge_observations'] = new_edge_observations 195 | if len(new_example['eventuality_observations']) + len(new_example['edge_observations']) > 0: 196 | self.test_set.append(new_example) 197 | print('number of new examples:', len(self.test_set)) 198 | self.trunked_test_sets = list() 199 | 200 | self.number_of_trunks = int(len(self.test_set) / 10000) + 1 201 | print('Number of trunks:', self.number_of_trunks) 202 | for i in range(self.number_of_trunks): 203 | self.trunked_test_sets.append(self.test_set[i * 10000:(i + 1) * 10000]) 204 | 205 | def tensorize_dataset(self, input_dataset): 206 | print('Start to tensorize the set.') 207 | tmp_tensorized_dataset = list() 208 | for tmp_example in tqdm(input_dataset): 209 | tmp_tensorized_dataset.append(TrainingExample(tmp_example)) 210 | return tmp_tensorized_dataset 211 | 212 | 213 | class CommonsenseRelationClassifier(BertModel): 214 | def __init__(self, config): 215 | super(BertModel, self).__init__(config) 216 | self.bert = BertModel(config) 217 | self.edge_attention_weight = torch.nn.Linear(768 * 2, 1) 218 | self.last_layer = torch.nn.Linear(768 * 6, 2) 219 | self.frequency_embedding = torch.nn.Embedding(7, 768) 220 | self.type_embedding = torch.nn.Embedding(2, 768) 221 | self.dropout = torch.nn.Dropout(0.5) 222 | 223 | def forward(self, raw_sentences, first_mask=None, second_mask=None, all_frequencies=None, all_types=None, 224 | adj_matrices=None, attention_mask=None, token_type_ids=None, 225 | position_ids=None, head_mask=None): 226 | number_of_observation = raw_sentences.size(0) 227 | number_of_token = raw_sentences.size(1) 228 | encoding_after_bert = self.bert(raw_sentences) 229 | bert_last_layer = encoding_after_bert[0] # [number_of_observation, number_of_token, embedding_size] 230 | 231 | # start to implement the graph attention 232 | last_layer_pile = bert_last_layer.repeat(1, number_of_token, 1).view( 233 | [number_of_observation, number_of_token, number_of_token, 234 | 768]) # [number_of_observation, number_of_token, number_of_token, embedding_size] 235 | last_layer_repeat = bert_last_layer.repeat(1, 1, number_of_token).view( 236 | [number_of_observation, number_of_token, number_of_token, 237 | 768]) # [number_of_observation, number_of_token, number_of_token, embedding_size] 238 | matched_last_layer = torch.cat([last_layer_pile, last_layer_repeat], 239 | dim=3) # [number_of_observation, number_of_token, number_of_token, embedding_size*2] 240 | attention_weight = self.edge_attention_weight(matched_last_layer).squeeze( 241 | dim=3) # [number_of_observation, number_of_token, number_of_token] 242 | adj_matrices = 
adj_matrices.float() 243 | attention_weight = attention_weight * adj_matrices # [number_of_observation, number_of_token, number_of_token] 244 | weight_after_softmax = F.softmax(attention_weight, dim=2).unsqueeze( 245 | 3) # [number_of_observation, number_of_token, number_of_token, 1] 246 | weight_after_softmax_matrices = weight_after_softmax.repeat(1, 1, 1, 247 | 768) # [number_of_observation, number_of_token, number_of_token, embedding_size] 248 | aggegated_embedding = torch.sum(last_layer_pile * weight_after_softmax_matrices, 249 | dim=2) # [number_of_observation, number_of_token, embedding_size] 250 | 251 | aggegated_embedding = torch.cat([aggegated_embedding, bert_last_layer], dim=2) 252 | 253 | # Start to implement the head/tail mask 254 | first_mask = first_mask[:, :, None] 255 | second_mask = second_mask[:, :, None] 256 | head_selection_mask = first_mask.expand( 257 | [-1, -1, 768 * 2]) # [number_of_observation, number_of_token, embedding_size] 258 | head_representation = torch.mean(aggegated_embedding * head_selection_mask, 259 | dim=1) # [number_of_observation, embedding_size] 260 | tail_selection_mask = second_mask.expand( 261 | [-1, -1, 768 * 2]) # [number_of_observation, number_of_token, embedding_size] 262 | tail_representation = torch.mean(aggegated_embedding * tail_selection_mask, 263 | dim=1) # [number_of_observation, embedding_size] 264 | 265 | # Start to add features 266 | frequency_feature = self.frequency_embedding(all_frequencies).squeeze( 267 | dim=1) # [number_of_observation, embedding_size] 268 | type_feature = self.type_embedding(all_types).squeeze(dim=1) # [number_of_observation, embedding_size] 269 | overall_representation = torch.cat( 270 | [head_representation, tail_representation, frequency_feature, type_feature], 271 | dim=1) # [number_of_observation, embedding_size*2] 272 | 273 | overall_representation = self.dropout(overall_representation) 274 | final_prediction = self.last_layer(overall_representation) # [batch_size, 2] 275 | 276 | final_prediction = torch.mean(final_prediction, dim=0).unsqueeze(0) # [1, 2] 277 | return final_prediction 278 | 279 | 280 | def train(model, train_data): 281 | all_loss = 0 282 | print('training:') 283 | random.shuffle(train_data) 284 | model.train() 285 | for tmp_example in tqdm(train_data): 286 | final_prediction = model(raw_sentences=tmp_example.all_observations, first_mask=tmp_example.all_head_masks, 287 | second_mask=tmp_example.all_tail_masks, all_frequencies=tmp_example.all_frequencies, 288 | all_types=tmp_example.all_types, adj_matrices=tmp_example.all_adj_matrices) # 1 * 2 289 | loss = loss_func(final_prediction, tmp_example.label) 290 | test_optimizer.zero_grad() 291 | loss.backward() 292 | test_optimizer.step() 293 | all_loss += loss.item() 294 | print('current loss:', all_loss / len(current_data.tensorized_train)) 295 | 296 | 297 | def test(model, test_data): 298 | correct_count = 0 299 | print('Testing') 300 | model.eval() 301 | for tmp_example in tqdm(test_data): 302 | final_prediction = model(raw_sentences=tmp_example.all_observations, first_mask=tmp_example.all_head_masks, 303 | second_mask=tmp_example.all_tail_masks, all_frequencies=tmp_example.all_frequencies, 304 | all_types=tmp_example.all_types, adj_matrices=tmp_example.all_adj_matrices) # 1 * 2 305 | if tmp_example.label.data[0] == 1: 306 | # current example is positive 307 | if final_prediction.data[0][1] >= final_prediction.data[0][0]: 308 | correct_count += 1 309 | else: 310 | # current example is negative 311 | if final_prediction.data[0][1] <= 
final_prediction.data[0][0]: 312 | correct_count += 1 313 | 314 | print('current accuracy:', correct_count, '/', len(test_data), correct_count / len(test_data)) 315 | return correct_count / len(test_data) 316 | 317 | 318 | def predict(model, data_for_predict, relation): 319 | model.eval() 320 | tmp_prediction_dict = dict() 321 | print('Start to predict') 322 | for tmp_example in tqdm(data_for_predict): 323 | final_prediction = model(raw_sentences=tmp_example.all_observations, first_mask=tmp_example.all_head_masks, 324 | second_mask=tmp_example.all_tail_masks, all_frequencies=tmp_example.all_frequencies, 325 | all_types=tmp_example.all_types, adj_matrices=tmp_example.all_adj_matrices) # 1 * 2 326 | scores = F.softmax(final_prediction, dim=1) 327 | tmp_prediction_dict[tmp_example.knowledge] = scores.data.tolist()[0][1] 328 | tmp_file_name = 'prediction/' + relation + '.json' 329 | try: 330 | with open(tmp_file_name, 'r') as f: 331 | prediction_dict = json.load(f) 332 | except FileNotFoundError: 333 | prediction_dict = dict() 334 | for tmp_k in tmp_prediction_dict: 335 | prediction_dict[tmp_k] = tmp_prediction_dict[tmp_k] 336 | with open(tmp_file_name, 'w') as f: 337 | json.dump(prediction_dict, f) 338 | 339 | 340 | parser = argparse.ArgumentParser() 341 | 342 | ## Required parameters 343 | parser.add_argument("--gpu", default='0', type=str, required=False, 344 | help="choose which gpu to use") 345 | parser.add_argument("--model", default='graph', type=str, required=False, 346 | help="choose the model to test") 347 | parser.add_argument("--lr", default=0.001, type=float, required=False, 348 | help="initial learning rate") 349 | parser.add_argument("--lrdecay", default=0.8, type=float, required=False, 350 | help="learning rate decay every 5 epochs") 351 | 352 | args = parser.parse_args() 353 | 354 | logging.basicConfig(level=logging.INFO) 355 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 356 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 357 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 358 | print('current device:', device) 359 | n_gpu = torch.cuda.device_count() 360 | print('number of gpu:', n_gpu) 361 | torch.cuda.get_device_name(0) 362 | 363 | current_model = CommonsenseRelationClassifier.from_pretrained('bert-base-uncased') 364 | test_optimizer = torch.optim.SGD(current_model.parameters(), lr=args.lr) 365 | loss_func = torch.nn.CrossEntropyLoss() 366 | current_model.to(device) 367 | 368 | performance_dict = dict() 369 | 370 | selected_relations = all_relations 371 | 372 | if not os.path.isdir('models'): 373 | os.mkdir('models') 374 | 375 | for r in selected_relations: 376 | current_data = DataLoader('ranking_dataset.json', r) 377 | print('Finish loading data') 378 | 379 | test(current_model, current_data.tensorized_test) 380 | 381 | best_performance = 0 382 | tmp_lr = args.lr 383 | current_data.random_sample_train_set() 384 | for i in range(50): 385 | if i % 5 == 0: 386 | test_optimizer = torch.optim.SGD(current_model.parameters(), lr=tmp_lr) 387 | tmp_lr = tmp_lr * args.lrdecay 388 | print('Current Iteration:', i + 1, '|', 'Relation:', r, '|', 389 | 'Current best performance:', best_performance) 390 | train(current_model, current_data.tensorized_train) 391 | tmp_performance = test(current_model, current_data.tensorized_test) 392 | if tmp_performance >= best_performance: 393 | best_performance = tmp_performance 394 | print('We are saving the new best model') 395 | torch.save(current_model.state_dict(), 'models/' + r + '.pth') 396 | 
performance_dict[r] = best_performance 397 | 398 | if not os.path.isdir('prediction'): 399 | os.mkdir('prediction') 400 | # This process might be slow due to the huge dataset scale. 401 | for r in selected_relations: 402 | print('Start to load data...') 403 | data_for_prediction = DataLoaderPredict('extracted_knowledge/' + r + '.json') 404 | print('Finish loading data...') 405 | current_model = CommonsenseRelationClassifier.from_pretrained('bert-base-uncased') 406 | current_model.load_state_dict(torch.load('models/' + r + '.pth')) 407 | current_model.to(device) 408 | print('We are working on relation:', r) 409 | for i in range(data_for_prediction.number_of_trunks): 410 | print('Working on set:', i + 1, '/', data_for_prediction.number_of_trunks, 'relation:', r) 411 | tmp_tensorized_data = data_for_prediction.tensorize_dataset(data_for_prediction.trunked_test_sets[i]) 412 | predict(current_model, tmp_tensorized_data, r) 413 | tmp_tensorized_data = list() 414 | 415 | overall_dict = dict() 416 | for r in all_relations: 417 | print('We are working on:', r) 418 | with open('prediction/' + r + '.json', 'r') as f: 419 | tmp_dict = json.load(f) 420 | for tmp_k in tqdm(tmp_dict): 421 | tmp_head = tmp_k.split('$$')[0] 422 | tmp_tail = tmp_k.split('$$')[1] 423 | new_k = tmp_head + '$$' + r + '$$' + tmp_tail 424 | overall_dict[new_k] = tmp_dict[tmp_k] 425 | 426 | sorted_result = sorted(overall_dict, key=lambda x: overall_dict[x], reverse=True) 427 | with open('prediction/TransOMCS.txt', 'w') as f: 428 | for tmp_k in sorted_result: 429 | f.write(tmp_k.split('$$')[0]) 430 | f.write('\t') 431 | f.write(tmp_k.split('$$')[1]) 432 | f.write('\t') 433 | f.write(tmp_k.split('$$')[2]) 434 | f.write('\t') 435 | f.write(str(overall_dict[tmp_k])) 436 | f.write('\n') 437 | 438 | print('end') 439 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from aser.database.db_API import KG_Connection 2 | import time 3 | from tqdm import tqdm 4 | # import aser 5 | import ujson as json 6 | from multiprocessing import Pool 7 | import spacy 8 | import random 9 | import pandas 10 | import numpy as np 11 | from itertools import combinations 12 | from scipy import spatial 13 | import os 14 | 15 | def get_ConceptNet_info(file_path): 16 | tmp_collection = dict() 17 | with open(file_path, 'r', encoding='utf-8') as f: 18 | for line in f: 19 | tmp_words = line[:-1].split('\t') 20 | if tmp_words[3] == '0': 21 | continue 22 | if tmp_words[0] not in tmp_collection: 23 | tmp_collection[tmp_words[0]] = list() 24 | tmp_collection[tmp_words[0]].append((tmp_words[1], tmp_words[2])) 25 | return tmp_collection 26 | 27 | 28 | def load_jsonlines(file_name): 29 | extracted_file = list() 30 | with open(file_name, 'r') as f: 31 | for line in f: 32 | tmp_info = json.loads(line) 33 | extracted_file.append(tmp_info) 34 | return extracted_file 35 | 36 | 37 | def chunks(l, group_number): 38 | if len(l) < 10: 39 | return [l] 40 | group_size = int(len(l) / group_number) 41 | final_data_groups = list() 42 | for i in range(0, len(l), group_size): 43 | final_data_groups.append(l[i:i+group_size]) 44 | return final_data_groups 45 | 46 | 47 | def match_commonsense_and_aser(sample_pairs, ASER): 48 | tmp_dict = dict() 49 | for tmp_tuple in sample_pairs: 50 | head_words = tmp_tuple['head'].split(' ') 51 | tail_words = tmp_tuple['tail'].split(' ') 52 | all_words = head_words + tail_words 53 | tmp_key = tmp_tuple['head'] + '$$' + 
tmp_tuple['tail'] 54 | matched_eventualities = list() 55 | for tmp_event in ASER: 56 | is_match = True 57 | for w in all_words: 58 | if w not in tmp_event['words'].split(' '): 59 | is_match = False 60 | break 61 | if is_match: 62 | matched_eventualities.append(tmp_event) 63 | tmp_dict[tmp_key] = matched_eventualities 64 | return tmp_dict 65 | 66 | 67 | def match_commonsense_and_aser_edge(sample_pairs, ASER): 68 | tmp_dict = dict() 69 | for tmp_tuple in sample_pairs: 70 | head_words = tmp_tuple['head'].split(' ') 71 | tail_words = tmp_tuple['tail'].split(' ') 72 | all_words = head_words + tail_words 73 | tmp_key = tmp_tuple['head'] + '$$' + tmp_tuple['tail'] 74 | matched_eventualities = list() 75 | for tmp_event in ASER: 76 | is_match = True 77 | edge_words = tmp_event['event_1_words'].split(' ') + tmp_event['event_2_words'].split(' ') 78 | for w in all_words: 79 | if w not in edge_words: 80 | is_match = False 81 | break 82 | if is_match: 83 | matched_eventualities.append(tmp_event) 84 | tmp_dict[tmp_key] = matched_eventualities 85 | return tmp_dict 86 | 87 | 88 | def find_head_tail_position_from_graph(graph, pattern_keywords, direction, loop): 89 | if loop == 5: 90 | print('Current loop is 5, we need to stop, something is wrong, please check.') 91 | return [] 92 | if len(pattern_keywords) == 0: 93 | return [] 94 | if direction == '<': 95 | if len(pattern_keywords) == 3: 96 | potential_links = list() 97 | for edge in graph: 98 | if edge[1] == pattern_keywords[1]: 99 | if pattern_keywords[0] == 'head': 100 | potential_links.append([edge[2], edge[0]]) 101 | else: 102 | if pattern_keywords[0] == edge[2][0]: 103 | potential_links.append([edge[0]]) 104 | return potential_links 105 | else: 106 | potential_links = list() 107 | for edge in graph: 108 | if edge[1] == pattern_keywords[1]: 109 | if pattern_keywords[0] == 'head': 110 | tmp_link = [edge[2], edge[0]] 111 | new_pattern_keywords = pattern_keywords[2:] 112 | rest_links = find_head_tail_position_from_graph(graph, new_pattern_keywords, direction, loop+1) 113 | for tmp_rest_link in rest_links: 114 | tmp_link = tmp_link + tmp_rest_link 115 | potential_links.append(tmp_link) 116 | else: 117 | if pattern_keywords[0] == edge[2][0]: 118 | tmp_link = [edge[0]] 119 | new_pattern_keywords = pattern_keywords[2:] 120 | rest_links = find_head_tail_position_from_graph(graph, new_pattern_keywords, direction, 121 | loop + 1) 122 | for tmp_rest_link in rest_links: 123 | tmp_link = tmp_link + tmp_rest_link 124 | potential_links.append(tmp_link) 125 | return potential_links 126 | else: 127 | if len(pattern_keywords) == 3: 128 | potential_links = list() 129 | for edge in graph: 130 | if edge[1] == pattern_keywords[1]: 131 | if pattern_keywords[0] == 'head': 132 | potential_links.append([edge[0], edge[2]]) 133 | else: 134 | if pattern_keywords[0] == edge[0][0]: 135 | potential_links.append([edge[2]]) 136 | return potential_links 137 | else: 138 | potential_links = list() 139 | for edge in graph: 140 | if edge[1] == pattern_keywords[1]: 141 | if pattern_keywords[0] == 'head': 142 | tmp_link = [edge[0], edge[2]] 143 | new_pattern_keywords = pattern_keywords[2:] 144 | rest_links = find_head_tail_position_from_graph(graph, new_pattern_keywords, direction, loop+1) 145 | for tmp_rest_link in rest_links: 146 | tmp_link = tmp_link + tmp_rest_link 147 | potential_links.append(tmp_link) 148 | else: 149 | if pattern_keywords[0] == edge[0][0]: 150 | tmp_link = [edge[2]] 151 | new_pattern_keywords = pattern_keywords[2:] 152 | rest_links = 
find_head_tail_position_from_graph(graph, new_pattern_keywords, direction, 153 | loop + 1) 154 | for tmp_rest_link in rest_links: 155 | tmp_link = tmp_link + tmp_rest_link 156 | potential_links.append(tmp_link) 157 | return potential_links 158 | 159 | 160 | def extract_knowledge_with_focused_position(graph, pattern_keywords, focused_position): 161 | if len(pattern_keywords) == 0: 162 | return focused_position[0] 163 | else: 164 | extracted_pattern = list() 165 | extracted_nodes = [focused_position] 166 | while len(extracted_pattern) != len(pattern_keywords): 167 | found_new_node = False 168 | for edge in graph: 169 | if edge[1] in pattern_keywords and edge[1] not in extracted_pattern: 170 | if edge[0] in extracted_nodes: 171 | extracted_nodes.append(edge[2]) 172 | found_new_node = True 173 | extracted_pattern.append(edge[1]) 174 | elif edge[2] in extracted_nodes: 175 | extracted_nodes.append(edge[0]) 176 | found_new_node = True 177 | extracted_pattern.append(edge[1]) 178 | if not found_new_node: 179 | break 180 | if len(extracted_pattern) == len(pattern_keywords): 181 | sorted_nodes = sorted(extracted_nodes, key=lambda x: x[1]) 182 | tmp_knowledge = '' 183 | for w in sorted_nodes: 184 | tmp_knowledge += w[0] 185 | tmp_knowledge += ' ' 186 | return tmp_knowledge[:-1] 187 | else: 188 | return None 189 | 190 | 191 | def extract_knowledge_from_graph_with_knowledge(graph, pattern): 192 | head_pattern = pattern.split(')')[0][1:] 193 | if head_pattern == '': 194 | head_keywords = [] 195 | else: 196 | head_keywords = head_pattern.split('-')[1:-1] 197 | internal_pattern = pattern.split(')')[1].split('(')[0] 198 | tail_pattern = pattern.split('(')[2][:-1] 199 | if tail_pattern == '': 200 | tail_keywords = [] 201 | else: 202 | tail_keywords = tail_pattern.split('-')[1:-1] 203 | focus_nodes = list() 204 | 205 | # We need to detect double direction 206 | if '<-' in internal_pattern and '->' in internal_pattern: 207 | all_paths = list() 208 | # we find a double direction 209 | if internal_pattern[0] == '<': 210 | middle_word = internal_pattern.split('<-')[-1].split('->')[0] 211 | first_half_pattern = internal_pattern.split(middle_word)[0] 212 | first_half_keywords = first_half_pattern.split('<-') 213 | first_half_keywords[0] = 'head' 214 | first_half_keywords[-1] = 'tail' 215 | first_half_paths = find_head_tail_position_from_graph(graph=graph, pattern_keywords=first_half_keywords, direction='<', loop=0) 216 | second_half_pattern = internal_pattern.split(middle_word)[1] 217 | second_half_keywords = second_half_pattern.split('->') 218 | second_half_keywords[0] = 'head' 219 | second_half_keywords[-1] = 'tail' 220 | second_half_paths = find_head_tail_position_from_graph(graph=graph, pattern_keywords=second_half_keywords, 221 | direction='>', loop=0) 222 | for tmp_first_half_path in first_half_paths: 223 | for tmp_second_half_path in second_half_paths: 224 | if tmp_first_half_path[-1] == tmp_second_half_path[0] and tmp_first_half_path[-1][0] == middle_word: 225 | all_paths.append((tmp_first_half_path[0], tmp_second_half_path[-1])) 226 | else: 227 | middle_word = internal_pattern.split('->')[-1].split('<-')[0] 228 | first_half_pattern = internal_pattern.split(middle_word)[0] 229 | first_half_keywords = first_half_pattern.split('->') 230 | first_half_keywords[0] = 'head' 231 | first_half_keywords[-1] = 'tail' 232 | first_half_paths = find_head_tail_position_from_graph(graph=graph, pattern_keywords=first_half_keywords, 233 | direction='>', loop=0) 234 | second_half_pattern = 
internal_pattern.split(middle_word)[1] 235 | second_half_keywords = second_half_pattern.split('<-') 236 | second_half_keywords[0] = 'head' 237 | second_half_keywords[-1] = 'tail' 238 | second_half_paths = find_head_tail_position_from_graph(graph=graph, pattern_keywords=second_half_keywords, 239 | direction='<', loop=0) 240 | for tmp_first_half_path in first_half_paths: 241 | for tmp_second_half_path in second_half_paths: 242 | if tmp_first_half_path[-1] == tmp_second_half_path[0] and tmp_first_half_path[-1][0] == middle_word: 243 | all_paths.append((tmp_first_half_path[0], tmp_second_half_path[-1])) 244 | else: 245 | if internal_pattern[0] == '<': 246 | pattern_keywords = internal_pattern.split('<-') 247 | else: 248 | pattern_keywords = internal_pattern.split('->') 249 | pattern_keywords[0] = 'head' 250 | pattern_keywords[-1] = 'tail' 251 | all_paths = find_head_tail_position_from_graph(graph=graph, pattern_keywords=pattern_keywords, direction=internal_pattern[0], loop=0) 252 | 253 | extracted_knowledge_list = list() 254 | for tmp_path in all_paths: 255 | head_knowledge = extract_knowledge_with_focused_position(graph, head_keywords, tmp_path[0]) 256 | tail_knowledge = extract_knowledge_with_focused_position(graph, tail_keywords, tmp_path[-1]) 257 | if head_knowledge and tail_knowledge: 258 | extracted_knowledge_list.append(head_knowledge + '$$' + tail_knowledge) 259 | return extracted_knowledge_list 260 | 261 | 262 | def extract_knowledge_from_eventuality_set(patterns, eventuality_set): 263 | tmp_eventuality_dict = dict() 264 | tmp_extracted_knowledge = dict() 265 | for r in patterns: 266 | tmp_extracted_knowledge[r] = dict() 267 | for tmp_pattern in patterns[r]: 268 | tmp_extracted_knowledge[r][tmp_pattern[0]] = dict() 269 | for i, tmp_e in enumerate(eventuality_set): 270 | doc = nlp(tmp_e['words']) 271 | all_dependency_edges = list() 272 | for word in doc: 273 | all_dependency_edges.append(((word.head.norm_, word.head.i), word.dep_, (word.norm_, word.i))) 274 | for r in patterns: 275 | for pattern in patterns[r]: 276 | tmp_knowledge_list = extract_knowledge_from_graph_with_knowledge(all_dependency_edges, pattern[0]) 277 | for tmp_knowledge in tmp_knowledge_list: 278 | if tmp_knowledge not in tmp_extracted_knowledge[r][pattern[0]]: 279 | tmp_extracted_knowledge[r][pattern[0]][tmp_knowledge] = 0 280 | tmp_extracted_knowledge[r][pattern[0]][tmp_knowledge] += tmp_e['frequency'] 281 | if tmp_knowledge not in tmp_eventuality_dict: 282 | tmp_eventuality_dict[tmp_knowledge] = list() 283 | tmp_e['graph'] = all_dependency_edges 284 | tmp_eventuality_dict[tmp_knowledge].append(tmp_e) 285 | if i % 1000 == 0: 286 | print('finished:', i, '/', len(eventuality_set)) 287 | return tmp_extracted_knowledge, tmp_eventuality_dict 288 | 289 | 290 | def eventuality_to_graph(tmp_eventuality): 291 | doc = nlp(tmp_eventuality['words']) 292 | all_dependency_edges = list() 293 | for word in doc: 294 | all_dependency_edges.append(((word.head.norm_, word.head.i), word.dep_, (word.norm_, word.i))) 295 | return all_dependency_edges 296 | 297 | 298 | def eventuality_set_to_graph_set(eventuality_set): 299 | tmp_event_id_to_graph = dict() 300 | for i, tmp_eventuality in enumerate(eventuality_set): 301 | tmp_graph = eventuality_to_graph(tmp_eventuality) 302 | tmp_event_id_to_graph[tmp_eventuality['id']] = tmp_graph 303 | if i % 10000 == 0: 304 | print(i, '/', len(eventuality_set)) 305 | return tmp_event_id_to_graph 306 | 307 | 308 | def extract_knowledge_from_edge_set(patterns, edge_set): 309 | tmp_edge_dict = dict() 310 | 
tmp_extracted_knowledge = dict() 311 | for r in patterns: 312 | tmp_extracted_knowledge[r] = dict() 313 | for tmp_pattern in patterns[r]: 314 | tmp_extracted_knowledge[r][tmp_pattern[0]] = dict() 315 | for i, tmp_edge in enumerate(edge_set): 316 | parsed_eventuality1_words = list() 317 | doc = nlp(tmp_edge['event_1_words']) 318 | event1_dependency_edges = list() 319 | event1_verb = [] 320 | for word in doc: 321 | event1_dependency_edges.append(((word.head.norm_, word.head.i), word.dep_, (word.norm_, word.i))) 322 | parsed_eventuality1_words.append(word.text) 323 | if word.dep_ == 'ROOT': 324 | event1_verb = (word.norm_, word.i) 325 | 326 | doc = nlp(tmp_edge['event_2_words']) 327 | event2_dependency_edges = list() 328 | event2_verb = [] 329 | for word in doc: 330 | event2_dependency_edges.append(((word.head.norm_, word.head.i + len(parsed_eventuality1_words)), word.dep_, 331 | (word.norm_, word.i + len(parsed_eventuality1_words)))) 332 | if word.dep_ == 'ROOT': 333 | event2_verb = (word.norm_, word.i + len(parsed_eventuality1_words)) 334 | all_dependency_edges = event1_dependency_edges + event2_dependency_edges 335 | all_dependency_edges.append((event1_verb, tmp_edge['connective'], event2_verb)) 336 | for r in patterns: 337 | for pattern in patterns[r]: 338 | tmp_knowledge_list = extract_knowledge_from_graph_with_knowledge(all_dependency_edges, pattern[0]) 339 | for tmp_knowledge in tmp_knowledge_list: 340 | if tmp_knowledge not in tmp_extracted_knowledge[r][pattern[0]]: 341 | tmp_extracted_knowledge[r][pattern[0]][tmp_knowledge] = 0 342 | tmp_extracted_knowledge[r][pattern[0]][tmp_knowledge] += tmp_edge['frequency'] 343 | if tmp_knowledge not in tmp_edge_dict: 344 | tmp_edge_dict[tmp_knowledge] = list() 345 | tmp_edge['graph'] = all_dependency_edges 346 | tmp_edge_dict[tmp_knowledge].append(tmp_edge) 347 | if i % 1000 == 0: 348 | print('finished:', i, '/', len(edge_set)) 349 | return tmp_extracted_knowledge, tmp_edge_dict 350 | 351 | 352 | def edge_to_graph(tmp_edge): 353 | parsed_eventuality1_words = list() 354 | doc = nlp(tmp_edge['event_1_words']) 355 | event1_dependency_edges = list() 356 | event1_verb = [] 357 | for word in doc: 358 | event1_dependency_edges.append(((word.head.norm_, word.head.i), word.dep_, (word.norm_, word.i))) 359 | parsed_eventuality1_words.append(word.text) 360 | if word.dep_ == 'ROOT': 361 | event1_verb = (word.norm_, word.i) 362 | 363 | doc = nlp(tmp_edge['event_2_words']) 364 | event2_dependency_edges = list() 365 | event2_verb = [] 366 | for word in doc: 367 | event2_dependency_edges.append(((word.head.norm_, word.head.i + len(parsed_eventuality1_words)), word.dep_, 368 | (word.norm_, word.i + len(parsed_eventuality1_words)))) 369 | if word.dep_ == 'ROOT': 370 | event2_verb = (word.norm_, word.i + len(parsed_eventuality1_words)) 371 | all_dependency_edges = event1_dependency_edges + event2_dependency_edges 372 | all_dependency_edges.append((event1_verb, tmp_edge['connective'], event2_verb)) 373 | return all_dependency_edges 374 | 375 | 376 | def merge_extracted_knowledge_from_multi_core(all_extracted_knowledge): 377 | merged_knowledge = dict() 378 | for r in selected_patterns: 379 | merged_knowledge[r] = dict() 380 | for tmp_pattern in selected_patterns[r]: 381 | merged_knowledge[r][tmp_pattern[0]] = dict() 382 | for tmp_extracted_knowledge in tqdm(all_extracted_knowledge): 383 | for r in tmp_extracted_knowledge: 384 | for tmp_pattern in tmp_extracted_knowledge[r]: 385 | for tmp_k in tmp_extracted_knowledge[r][tmp_pattern]: 386 | if tmp_k not in 
merged_knowledge[r][tmp_pattern]: 387 | merged_knowledge[r][tmp_pattern][tmp_k] = tmp_extracted_knowledge[r][tmp_pattern][tmp_k] 388 | else: 389 | merged_knowledge[r][tmp_pattern][tmp_k] += tmp_extracted_knowledge[r][tmp_pattern][tmp_k] 390 | return merged_knowledge 391 | 392 | nlp = spacy.load('en_core_web_sm') 393 | 394 | try: 395 | with open('selected_patterns.json', 'r') as f: 396 | selected_patterns = json.load(f) 397 | print('Finish loading the patterns') 398 | except FileNotFoundError: 399 | pass 400 | 401 | Connectives = ['Precedence', 'Succession', 'Synchronous', 'Reason', 'Result', 'Condition', 'Contrast', 'Concession', 'Conjunction', 'Instantiation', 'Restatement', 'ChosenAlternative', 'Alternative', 'Exception'] 402 | 403 | --------------------------------------------------------------------------------
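A minimal usage sketch (not part of the repository) for the dependency-pattern utilities in util.py. It builds the same edge representation as eventuality_to_graph and applies extract_knowledge_from_graph_with_knowledge to one sentence. The pattern string '()->dobj->()' is a hypothetical illustration of the '(head keywords)path(tail keywords)' format; the real patterns come from selected_patterns.json, produced by Pattern_Extraction.py. Assumes the ASER package and the spaCy model en_core_web_sm are installed so that util.py can be imported.

import spacy
from util import extract_knowledge_from_graph_with_knowledge

nlp = spacy.load('en_core_web_sm')

def sentence_to_dependency_graph(words):
    # Same edge format as eventuality_to_graph in util.py:
    # ((governor_norm, governor_index), dependency_label, (token_norm, token_index))
    doc = nlp(words)
    return [((w.head.norm_, w.head.i), w.dep_, (w.norm_, w.i)) for w in doc]

graph = sentence_to_dependency_graph('i eat food')
# '()->dobj->()': empty head/tail keyword groups, one head ->dobj-> tail step.
print(extract_knowledge_from_graph_with_knowledge(graph, '()->dobj->()'))
# Prints something like ['eat$$food'] if spaCy attaches 'food' to 'eat' as dobj.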
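The adjacency-masked attention in CommonsenseRelationClassifier.forward (Train_and_Predict.py above), i.e. the softmax over attention_weight * adj_matrices followed by a weighted sum of token embeddings, reduces to a batched matrix product. The sketch below uses toy tensors and assumes last_layer_pile stacks the BERT token embeddings so that the weighted sum mixes whole token vectors; all names and sizes here are illustrative, not taken from the repository.

import torch
import torch.nn.functional as F

number_of_observation, number_of_token, embedding_size = 2, 6, 768
bert_last_layer = torch.randn(number_of_observation, number_of_token, embedding_size)
attention_weight = torch.randn(number_of_observation, number_of_token, number_of_token)
adj_matrices = torch.randint(0, 2, (number_of_observation, number_of_token, number_of_token)).float()

# Mirror the forward pass: scale the scores by the adjacency matrix, softmax over the
# neighbour dimension, then average the token embeddings with the resulting weights.
weight_after_softmax = F.softmax(attention_weight * adj_matrices, dim=2)
aggregated_embedding = torch.bmm(weight_after_softmax, bert_last_layer)
print(aggregated_embedding.shape)  # torch.Size([2, 6, 768])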
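Train_and_Predict.py writes its final ranking to prediction/TransOMCS.txt as tab-separated (head, relation, tail, plausibility score) lines, sorted by score in descending order. Below is a minimal sketch (not part of the repository) for reading that file back; the 0.5 threshold is an arbitrary example, not a value used anywhere in the code.

triples = []
with open('prediction/TransOMCS.txt', 'r') as f:
    for line in f:
        head, relation, tail, score = line.rstrip('\n').split('\t')
        triples.append((head, relation, tail, float(score)))

# Keep only the tuples the classifier scores as relatively plausible.
confident = [t for t in triples if t[3] >= 0.5]
print(len(confident), 'of', len(triples), 'tuples have a plausibility score >= 0.5')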