├── .gitignore
├── LICENSE
├── README.md
├── amr_diff_ex.PNG
├── amr_utils
│   ├── __init__.py
│   ├── alignments.py
│   ├── amr.py
│   ├── amr_diff.py
│   ├── amr_readers.py
│   ├── display_alignments.py
│   ├── graph_utils.py
│   ├── propbank_frames.py
│   ├── smatch.py
│   └── style.py
├── data
│   ├── test_amrs.alignments.json
│   ├── test_amrs.txt
│   └── test_amrs2.txt
├── display_align_ex.PNG
├── html_ex.PNG
├── latex_ex.PNG
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
.idea/*
*.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Austin Blodgett

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# AMR-utils
AMR-utils is a Python package for working with AMR data, with tools for reading AMRs and alignments, performing graph operations, and displaying and visualizing AMR data. I wrote this package to store operations that I often need or find useful when doing research with AMRs. This code is written and maintained by Austin Blodgett.
### Features:
- Load AMRs from a file or directory, with support for multiple formats
- Load AMR alignments, with support for LDC, JAMR, and ISI alignment formats
- A simple class for accessing AMR nodes, edges, tokens, etc.
- Graph operations for operating on AMR data
- Tools for AMR visualization
  - Convert AMR graphs to LaTeX (using the TikZ library)
  - Display AMR strings as HTML, with overridable display settings for nodes, tokens, and edges
  - AMR Diff: display differences between AMRs as HTML
  - Display AMR alignments as HTML

### Requirements
- Python 3.6 or higher
- [PENMAN library](https://github.com/goodmami/penman)

### Install
```
git clone https://github.com/ablodge/amr-utils
pip install penman
pip install ./amr-utils
```

### Wiki
If you have a question that isn't answered by this document, please check the Wiki.

### Notes
- A small excerpt of code is taken from [smatch](https://github.com/snowblink14/smatch) for AMR-to-AMR alignment in the AMR Diff tool, so that the differences AMR Diff reports correspond directly to the smatch score.

# AMR Reader
The class `AMR_Reader` can be used to load AMRs or AMR alignments from a number of different formats, including LDC, JAMR, and ISI. An `AMR_Reader` can be used as follows.

```
from amr_utils.amr_readers import AMR_Reader

reader = AMR_Reader()
amrs = reader.load(amr_file, remove_wiki=True)
```

AMRs must be separated by empty lines, but otherwise can take various formats.
Simplified format:
```
# Dogs chase cats.
(c/chase-01 :ARG0 (d/dog)
    :ARG1 (c2/cat))
```

JAMR-style graph metadata format:

```
# ::id 1
# ::tok Dogs chase cats.
# ::node c chase-01
# ::node d dog
# ::node c2 cat
# ::root c chase-01
# ::edge chase-01 ARG0 dog c d
# ::edge chase-01 ARG1 cat c c2
(c/chase-01 :ARG0 (d/dog)
    :ARG1 (c2/cat))
```

### Loading Alignments from LDC, JAMR, or ISI
AMR alignments can also be loaded from several formats:
- LDC:
  `# ::alignments 0-1.1 1-1 1-1.1.r 1-1.2.r 2-1.2`
- JAMR:
  `# ::alignments 0-1|0.0 1-2|0 2-3|0.1`
- ISI:
  `(c/chase-01~e.1 :ARG0~e.1 (d/dog~e.0) :ARG1~e.1 (c2/cat~e.2))`

To load them, set the parameter `output_alignments` to `True`.

```
from amr_utils.amr_readers import AMR_Reader

reader = AMR_Reader()
amrs, alignments = reader.load(amr_file, remove_wiki=True, output_alignments=True)
```

By default, `AMR_Reader` uses the LDC/ISI style of node ids, where 1.n is the nth child of the root, with indices starting at 1.
Any alignments are automatically converted to this format for data consistency.

# AMR Alignments JSON Format
The package includes tools for converting AMR alignments to and from JSON in the following format.
```
[{"type": "isi", "tokens": [0], "nodes": ["1.1"], "edges": []},
 {"type": "isi", "tokens": [1], "nodes": ["1"], "edges": [["1", ":ARG0", "1.1"], ["1", ":ARG1", "1.2"]]},
 {"type": "isi", "tokens": [2], "nodes": ["1.2"], "edges": []}
]
```
A saved alignments file maps each AMR id to a list like this one.

The advantages of using JSON are:
- Easy to load and save (no need to write a special script for reading an esoteric format)
- Additional information can be stored in the `type` field to distinguish different kinds of alignments
- Multiple sets of alignments can easily be stored separately, without modifying an AMR file, which makes it easy to compare different sets of alignments.

To read alignments from a JSON file:
```
reader = AMR_Reader()
alignments = reader.load_alignments_from_json(alignments_file)
```
To save alignments to a JSON file:
```
reader = AMR_Reader()
reader.save_alignments_to_json(alignments_file, alignments)
```
# AMR Visualization
AMR-utils includes tools for visualizing AMRs and AMR alignments. See the Wiki for more detail.

## LaTeX
AMR-utils can read AMRs from a text file and output them as LaTeX diagrams, such as the following.
![latex example](https://github.com/ablodge/amr-utils/blob/master/latex_ex.PNG)

### Colors
The default coloring assigns blue to each node, but the parameter `assign_color` can be used to assign colors with a function. To change a color by hand, rewrite, for example, `\node[red]` as `\node[purple]` in the generated LaTeX.

### Instructions
Run as follows:

`python style.py --latex -f [input file] [output file]`

Add these lines to your LaTeX preamble:

```
\usepackage{tikz}
\usetikzlibrary{shapes}
```


## HTML
AMR-utils can read AMRs from a text file and output them as HTML. You can look in `style.css` for an example of styling.
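As a rough illustration of what HTML output of an AMR involves, the sketch below wraps each variable and concept of a PENMAN string in `<span>` elements that a stylesheet can target. It is not the implementation in `style.py`: the function name `amr_to_html`, the CSS class names, and the regular expression are hypothetical, and the pattern assumes variables start with a lowercase letter while concepts and quoted constants contain no whitespace, parentheses, or `/`.

```python
import html
import re

# Hypothetical pattern for "var/concept" pairs in PENMAN notation.
# Assumes variables start with a lowercase letter and concepts contain
# no whitespace or parentheses.
NODE_RE = re.compile(r'([a-z]\w*)\s*/\s*([^\s()]+)')


def amr_to_html(amr_string: str) -> str:
    """Return the AMR string as HTML, wrapping variables and concepts in spans."""
    pieces = []
    last = 0
    for m in NODE_RE.finditer(amr_string):
        # Escape the text between matches (roles, parentheses, quoted constants).
        pieces.append(html.escape(amr_string[last:m.start()]))
        var, concept = m.group(1), m.group(2)
        pieces.append(f'<span class="var">{html.escape(var)}</span>/'
                      f'<span class="concept">{html.escape(concept)}</span>')
        last = m.end()
    pieces.append(html.escape(amr_string[last:]))
    # <pre> keeps the line breaks and indentation of multi-line AMRs.
    return '<pre class="amr">' + ''.join(pieces) + '</pre>'


print(amr_to_html('(c/chase-01 :ARG0 (d/dog) :ARG1 (c2/cat))'))
```

Escaping the text between matches separately, rather than escaping the whole string first, keeps quoted constants such as `"A&B"` safe without breaking the variable/concept pattern.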
![html example](https://github.com/ablodge/amr-utils/blob/master/html_ex.PNG)
### Instructions
Run as follows:

`python style.py --html -f [input file] [output file]`


## AMR Diff

AMR Diff is a tool for comparing two files of AMRs. The tool uses AMR-to-AMR alignment from [smatch](https://github.com/snowblink14/smatch) to find the differences between pairs of AMRs that lower the smatch score. AMR Diff is useful for detailed error analysis of AMR parsers. AMRs are paired by their position in the two files. The display highlights the differences and explains each error in mouse-over text.

![amr diff example](https://github.com/ablodge/amr-utils/blob/master/amr_diff_ex.PNG)
### Instructions
Run as follows:

`python amr_diff.py [amr file1] [amr file2] [output file]`


## Display Alignments
AMR-utils also includes a tool for displaying alignments in an easy-to-read format, with highlights and mouse-over text showing which tokens, nodes, and edges are aligned.
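The mouse-over idea can be sketched with the standard library alone. The function below is a hypothetical illustration, not the code in `display_alignments.py`: it takes a token list, a mapping from node ids to concepts, and alignments as plain dicts in the JSON layout from the AMR Alignments JSON Format section, and wraps every aligned token in a `<span>` whose `title` attribute (shown by browsers on hover) lists the aligned nodes and edges. The name `alignment_tooltips` and the `aligned` CSS class are assumptions.

```python
import html


def alignment_tooltips(tokens, nodes, alignments):
    """Wrap aligned tokens in <span title="..."> elements describing their alignment.

    tokens:     list of token strings
    nodes:      dict mapping node id (e.g. '1.1') to its concept
    alignments: list of dicts with 'tokens', and optionally 'nodes' and 'edges'
    """
    descriptions = {}
    for align in alignments:
        parts = [nodes[n] for n in align.get('nodes', [])]
        parts += [f'{nodes[s]} {r} {nodes[t]}' for s, r, t in align.get('edges', [])]
        for i in align['tokens']:
            # A token that appears in several alignments collects all descriptions.
            descriptions.setdefault(i, []).extend(parts)
    out = []
    for i, tok in enumerate(tokens):
        if descriptions.get(i):
            title = html.escape(', '.join(descriptions[i]))
            out.append(f'<span class="aligned" title="{title}">{html.escape(tok)}</span>')
        else:
            out.append(html.escape(tok))
    return ' '.join(out)


tokens = ['Dogs', 'chase', 'cats', '.']
nodes = {'1': 'chase-01', '1.1': 'dog', '1.2': 'cat'}
alignments = [
    {'type': 'isi', 'tokens': [0], 'nodes': ['1.1'], 'edges': []},
    {'type': 'isi', 'tokens': [1], 'nodes': ['1'],
     'edges': [['1', ':ARG0', '1.1'], ['1', ':ARG1', '1.2']]},
    {'type': 'isi', 'tokens': [2], 'nodes': ['1.2'], 'edges': []},
]
print(alignment_tooltips(tokens, nodes, alignments))
```

Here "chase" gets the tooltip `chase-01, chase-01 :ARG0 dog, chase-01 :ARG1 cat`, while the unaligned "." is emitted as plain escaped text.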

![display alignments example](https://github.com/ablodge/amr-utils/blob/master/display_align_ex.PNG)
### Instructions
Run as follows:

`python display_alignments.py [amr file] [alignment file] [output file]`


# Graph Operations
You can import graph operations from `graph_utils.py`:
```
from amr_utils.graph_utils import get_subgraph, is_rooted_dag, breadth_first_nodes, \
    breadth_first_edges, depth_first_nodes, depth_first_edges, \
    get_shortest_path, get_connected_components
```
Functions in `graph_utils.py` allow you to
- Traverse AMR nodes or edges in depth-first or breadth-first order
- Retrieve an AMR subgraph
- Test whether an AMR or sub-AMR is a rooted DAG
- Get the shortest path between two nodes
- Separate a subset of AMR nodes into connected components

--------------------------------------------------------------------------------
/amr_diff_ex.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ablodge/amr-utils/be5534db1312dc7c6ba25ee50eafeb0d0f5e3f69/amr_diff_ex.PNG
--------------------------------------------------------------------------------
/amr_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ablodge/amr-utils/be5534db1312dc7c6ba25ee50eafeb0d0f5e3f69/amr_utils/__init__.py
--------------------------------------------------------------------------------
/amr_utils/alignments.py:
--------------------------------------------------------------------------------
1 | import json 2 | import sys 3 | 4 | 5 | class AMR_Alignment: 6 | 7 | def __init__(self, type=None, tokens:list=None, nodes:list=None, edges:list=None, amr=None): 8 | self.type = type if type else 'basic' 9 | self.tokens = tokens if tokens else [] 10 | self.nodes = nodes if nodes else [] 11 | self.edges = edges if edges
else [] 12 | self.amr = None 13 | if amr is not None: 14 | self.amr = amr 15 | 16 | def __bool__(self): 17 | return bool(self.tokens) and (bool(self.nodes) or bool(self.edges)) 18 | 19 | def __str__(self): 20 | if self.amr is not None: 21 | return f': tokens {self.tokens} nodes {self.nodes} edges {self.edges} ({self.readable(self.amr)})' 22 | return f': tokens {self.tokens} nodes {self.nodes} edges {self.edges}' 23 | 24 | def copy(self): 25 | align = AMR_Alignment(type=self.type, tokens=self.tokens.copy(), nodes=self.nodes.copy(), edges=self.edges.copy()) 26 | align.amr = self.amr 27 | return align 28 | 29 | def to_json(self, amr=None): 30 | if amr is not None: 31 | return {'type': self.type, 'tokens': self.tokens.copy(), 'nodes': self.nodes.copy(), 'edges': self.edges.copy(), 'string':self.readable(amr)} 32 | if self.amr is not None: 33 | return {'type': self.type, 'tokens': self.tokens.copy(), 'nodes': self.nodes.copy(), 'edges': self.edges.copy(), 'string':self.readable(self.amr)} 34 | return {'type':self.type, 'tokens':self.tokens.copy(), 'nodes':self.nodes.copy(), 'edges':self.edges.copy()} 35 | 36 | def readable(self, amr): 37 | type = '' if self.type=='basic' else self.type 38 | nodes = '' if not self.nodes else ", ".join(amr.nodes[n] for n in self.nodes) 39 | edges = '' if not self.edges else ", ".join(str((amr.nodes[s],r,amr.nodes[t])) for s,r,t in self.edges) 40 | tokens = " ".join(amr.tokens[t] for t in self.tokens) 41 | if nodes and edges: 42 | edges = ', '+edges 43 | if type: 44 | type += ' : ' 45 | return f'{type}{tokens} => {nodes}{edges}' 46 | 47 | 48 | def load_from_json(json_file, amrs=None, unanonymize=False): 49 | if amrs: 50 | amrs = {amr.id:amr for amr in amrs} 51 | with open(json_file, 'r', encoding='utf8') as f: 52 | alignments = json.load(f) 53 | for k in alignments: 54 | if unanonymize: 55 | if unanonymize and not amrs: 56 | raise Exception('To un-anonymize alignments, the parameter "amrs" is required.') 57 | for a in alignments[k]: 58 | 
if 'nodes' not in a: 59 | a['nodes'] = [] 60 | if 'edges' not in a: 61 | a['edges'] = [] 62 | amr = amrs[k] 63 | for i,e in enumerate(a['edges']): 64 | s,r,t = e 65 | if r is None: 66 | new_e = [e2 for e2 in amr.edges if e2[0]==s and e2[2]==t] 67 | if not new_e: 68 | print('Failed to un-anonymize:', amr.id, e, file=sys.stderr) 69 | else: 70 | new_e = new_e[0] 71 | a['edges'][i] = [s, new_e[1], t] 72 | alignments[k] = [AMR_Alignment(a['type'], a['tokens'], a['nodes'], [tuple(e) for e in a['edges']]) for a in alignments[k]] 73 | if amrs: 74 | for k in alignments: 75 | for align in alignments[k]: 76 | if k in amrs: 77 | align.amr = amrs[k] 78 | return alignments 79 | 80 | 81 | def write_to_json(json_file, alignments, anonymize=False, amrs=None): 82 | new_alignments = {} 83 | for k in alignments: 84 | new_alignments[k] = [a.to_json() for a in alignments[k]] 85 | if anonymize: 86 | if anonymize and not amrs: 87 | raise Exception('To anonymize alignments, the parameter "amrs" is required.') 88 | for a in new_alignments[k]: 89 | amr = next(amr_ for amr_ in amrs if amr_.id==k) 90 | for i,e in enumerate(a['edges']): 91 | if len([e2 for e2 in amr.edges if e2[0]==e[0] and e2[2]==e[2]])==1: 92 | a['edges'][i] = [e[0],None,e[2]] 93 | if 'string' in a: 94 | del a['string'] 95 | if 'nodes' in a and not a['nodes']: 96 | del a['nodes'] 97 | if 'edges' in a and not a['edges']: 98 | del a['edges'] 99 | with open(json_file, 'w+', encoding='utf8') as f: 100 | json.dump(new_alignments, f) 101 | 102 | 103 | -------------------------------------------------------------------------------- /amr_utils/amr.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from amr_utils.alignments import AMR_Alignment 4 | 5 | 6 | class AMR: 7 | 8 | def __init__(self, tokens:list=None, id=None, root=None, nodes:dict=None, edges:list=None, metadata:dict=None): 9 | 10 | if edges is None: edges = [] 11 | if nodes is None: nodes = {} 12 | if tokens is None: 
tokens = [] 13 | if metadata is None: metadata = {} 14 | 15 | self.tokens = tokens 16 | self.root = root 17 | self.nodes = nodes 18 | self.edges = edges 19 | self.id = 'None' if id is None else id 20 | self.metadata = metadata 21 | 22 | def copy(self): 23 | return AMR(self.tokens.copy(), self.id, self.root, self.nodes.copy(), self.edges.copy(), self.metadata.copy()) 24 | 25 | def __str__(self): 26 | return metadata_string(self) 27 | 28 | def graph_string(self): 29 | return graph_string(self) 30 | 31 | def amr_string(self): 32 | return metadata_string(self) + graph_string(self)+'\n\n' 33 | 34 | def get_alignment(self, alignments, token_id=None, node_id=None, edge=None): 35 | if not isinstance(alignments, dict): 36 | raise Exception('Alignments object must be a dict.') 37 | if self.id not in alignments: 38 | return AMR_Alignment() 39 | for align in alignments[self.id]: 40 | if token_id is not None and token_id in align.tokens: 41 | return align 42 | if node_id is not None and node_id in align.nodes: 43 | return align 44 | if edge is not None and edge in align.edges: 45 | return align 46 | return AMR_Alignment() 47 | 48 | def triples(self, normalize_inverse_edges=False): 49 | taken_nodes = {self.root} 50 | yield self.root, ':instance', self.nodes[self.root] 51 | for s,r,t in self.edges: 52 | if not self.nodes[t][0].isalpha() or self.nodes[t] in ['imperative', 'expressive', 'interrogative']: 53 | yield s, r, self.nodes[t] 54 | continue 55 | if normalize_inverse_edges and r.endswith('-of') and r not in [':consist-of', ':prep-out-of', ':prep-on-behalf-of']: 56 | yield t, r[:-len('-of')], s 57 | else: 58 | yield s, r, t 59 | if t not in taken_nodes: 60 | yield t, ':instance', self.nodes[t] 61 | taken_nodes.add(t) 62 | 63 | def _rename_node(self, a, b): 64 | if b in self.nodes: 65 | raise Exception('Rename Node: Tried to use existing node name:', b) 66 | self.nodes[b] = self.nodes[a] 67 | del self.nodes[a] 68 | if self.root == a: 69 | self.root = b 70 | for i, e in 
enumerate(self.edges): 71 | s,r,t = e 72 | if a in [s, t]: 73 | if s==a: s=b 74 | if t==a: t=b 75 | self.edges[i] = (s,r,t) 76 | 77 | 78 | 79 | def metadata_string(amr): 80 | ''' 81 | # ::id sentence id 82 | # ::tok tokens... 83 | # ::node node_id node alignments 84 | # ::root root_id root 85 | # ::edge src label trg src_id trg_id alignments 86 | ''' 87 | output = '' 88 | # id 89 | if amr.id: 90 | output += f'# ::id {amr.id}\n' 91 | # tokens 92 | output += '# ::tok ' + (' '.join(amr.tokens)) + '\n' 93 | # metadata 94 | for label in amr.metadata: 95 | if label not in ['tok','id','node','root','edge','alignments']: 96 | output += f'# ::{label} {str(amr.metadata[label])}\n' 97 | # nodes 98 | for n in amr.nodes: 99 | output += f'# ::node\t{n}\t{amr.nodes[n].replace(" ","_") if n in amr.nodes else "None"}\n' 100 | # root 101 | root = amr.root 102 | if amr.root: 103 | output += f'# ::root\t{root}\t{amr.nodes[root] if root in amr.nodes else "None"}\n' 104 | # edges 105 | for i, e in enumerate(amr.edges): 106 | s, r, t = e 107 | r = r.replace(':', '') 108 | output += f'# ::edge\t{amr.nodes[s] if s in amr.nodes else "None"}\t{r}\t{amr.nodes[t] if t in amr.nodes else "None"}\t{s}\t{t}\n' 109 | 110 | return output 111 | 112 | 113 | def graph_string(amr): 114 | amr_string = f'[[{amr.root}]]' 115 | new_ids = {} 116 | for n in amr.nodes: 117 | new_id = amr.nodes[n][0] if amr.nodes[n] else 'x' 118 | if new_id.isalpha() and new_id.islower(): 119 | if new_id in new_ids.values(): 120 | j = 2 121 | while f'{new_id}{j}' in new_ids.values(): 122 | j += 1 123 | new_id = f'{new_id}{j}' 124 | else: 125 | j = 0 126 | while f'x{j}' in new_ids.values(): 127 | j += 1 128 | new_id = f'x{j}' 129 | new_ids[n] = new_id 130 | depth = 1 131 | nodes = {amr.root} 132 | completed = set() 133 | while '[[' in amr_string: 134 | tab = '\t' * depth 135 | for n in nodes.copy(): 136 | id = new_ids[n] if n in new_ids else 'x91' 137 | concept = amr.nodes[n] if n in new_ids and amr.nodes[n] else 'None' 138 | 
edges = sorted([e for e in amr.edges if e[0] == n], key=lambda x: x[1]) 139 | targets = set(t for s, r, t in edges) 140 | edges = [f'{r} [[{t}]]' for s, r, t in edges] 141 | children = f'\n{tab}'.join(edges) 142 | if children: 143 | children = f'\n{tab}' + children 144 | if n not in completed: 145 | if (concept[0].isalpha() and concept not in ['imperative', 'expressive', 'interrogative']) or targets: 146 | amr_string = amr_string.replace(f'[[{n}]]', f'({id}/{concept}{children})', 1) 147 | else: 148 | amr_string = amr_string.replace(f'[[{n}]]', f'{concept}') 149 | completed.add(n) 150 | amr_string = amr_string.replace(f'[[{n}]]', f'{id}') 151 | nodes.remove(n) 152 | nodes.update(targets) 153 | depth += 1 154 | if len(completed) < len(amr.nodes): 155 | missing_nodes = [n for n in amr.nodes if n not in completed] 156 | missing_edges = [(s, r, t) for s, r, t in amr.edges if s in missing_nodes or t in missing_nodes] 157 | missing_nodes= ', '.join(f'{n}/{amr.nodes[n]}' for n in missing_nodes) 158 | missing_edges = ', '.join(f'{s}/{amr.nodes[s]} {r} {t}/{amr.nodes[t]}' for s,r,t in missing_edges) 159 | print('[amr]', 'Failed to print AMR, ' 160 | + str(len(completed)) + ' of ' + str(len(amr.nodes)) + ' nodes printed:\n ' 161 | + str(amr.id) +':\n' 162 | + amr_string + '\n' 163 | + 'Missing nodes: ' + missing_nodes +'\n' 164 | + 'Missing edges: ' + missing_edges +'\n', 165 | file=sys.stderr) 166 | if not amr_string.startswith('('): 167 | amr_string = '(' + amr_string + ')' 168 | if len(amr.nodes) == 0: 169 | amr_string = '(a/amr-empty)' 170 | 171 | return amr_string 172 | -------------------------------------------------------------------------------- /amr_utils/amr_diff.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from amr_readers import AMR_Reader 4 | from style import HTML_AMR 5 | 6 | 7 | from amr_utils.graph_utils import get_node_alignment 8 | 9 | phase = 1 10 | 11 | def style(amr_pairs, other_args, 
assign_node_color=None, assign_node_desc=None, assign_edge_color=None, assign_edge_desc=None, 12 | assign_token_color=None, assign_token_desc=None, limit=None): 13 | global phase 14 | output = '\n' 15 | output += '\n' 16 | output += '\n\n' 19 | output += '\n' 20 | i = 0 21 | for id in amr_pairs: 22 | amr1, amr2 = amr_pairs[id] 23 | prec, rec, f1 = other_args[id][-3:] 24 | output += f'AMR 1:\n' 25 | phase = 1 26 | output += HTML_AMR.html(amr1, 27 | assign_node_color, assign_node_desc, 28 | assign_edge_color, assign_edge_desc, 29 | assign_token_color, assign_token_desc, 30 | other_args) 31 | output += 'AMR 2:\n' 32 | phase = 2 33 | output += HTML_AMR.html(amr2, 34 | assign_node_color, assign_node_desc, 35 | assign_edge_color, assign_edge_desc, 36 | assign_token_color, assign_token_desc, 37 | other_args) 38 | output += f'SMATCH: precision {100*prec:.1f} recall {100*rec:.1f} f1 {100*f1:.1f}\n' 39 | output += '
\n' 40 | i+=1 41 | if limit and i>limit: 42 | break 43 | output += '\n' 44 | output += '\n' 45 | return output 46 | 47 | 48 | def is_correct_node(amr, n, other_args): 49 | amr1, amr2, map1, map2 = other_args[amr.id][:4] 50 | if phase==1: 51 | other_amr = amr2 52 | node_map = map1 53 | else: 54 | other_amr = amr1 55 | node_map = map2 56 | if amr.nodes[n] == other_amr.nodes[node_map[n]]: 57 | return '' 58 | return 'red' 59 | 60 | 61 | def is_correct_edge(amr, e, other_args=None): 62 | amr1, amr2, map1, map2 = other_args[amr.id][:4] 63 | s,r,t = e 64 | if phase == 1: 65 | other_amr = amr2 66 | node_map = map1 67 | else: 68 | other_amr = amr1 69 | node_map = map2 70 | if (node_map[s],r,node_map[t]) in other_amr.edges: 71 | return '' 72 | return 'red' 73 | 74 | 75 | def is_correct_node_desc(amr, n, other_args=None): 76 | amr1, amr2, map1, map2 = other_args[amr.id][:4] 77 | if phase == 1: 78 | other_amr = amr2 79 | node_map = map1 80 | else: 81 | other_amr = amr1 82 | node_map = map2 83 | if amr.nodes[n] == other_amr.nodes[node_map[n]]: 84 | return '' 85 | if not amr.nodes[n][0].isalpha() or amr.nodes[n] in ['imperative', 'expressive', 'interrogative']: 86 | s,r,t = [(s,r,t) for s,r,t in amr.edges if t==n][0] 87 | return f'No corresponding attribute {other_amr.nodes[node_map[s]]} {r} {amr.nodes[t]}' 88 | return f'{amr.nodes[n]} != {other_amr.nodes[node_map[n]]}' 89 | 90 | 91 | def is_correct_edge_desc(amr, e, other_args=None): 92 | amr1, amr2, map1, map2 = other_args[amr.id][:4] 93 | s, r, t = e 94 | if phase == 1: 95 | other_amr = amr2 96 | node_map = map1 97 | else: 98 | other_amr = amr1 99 | node_map = map2 100 | if (node_map[s], r, node_map[t]) in other_amr.edges: 101 | return '' 102 | # attribute 103 | if not amr.nodes[t][0].isalpha() or amr.nodes[t] in ['imperative', 'expressive', 'interrogative']: 104 | return f'No corresponding attribute {other_amr.nodes[node_map[s]]} {r} {amr.nodes[t]}' 105 | # relation 106 | return f'No corresponding relation 
{other_amr.nodes[node_map[s]]} {r} {other_amr.nodes[node_map[t]]}' 107 | 108 | def main(): 109 | global amr_pairs 110 | import argparse 111 | 112 | # parser = argparse.ArgumentParser(description='Visually compare two AMR files') 113 | # parser.add_argument('files', type=str, nargs=2, required=True, 114 | # help='input files (AMRs in JAMR format)') 115 | # parser.add_argument('output', type=str, required=True, 116 | # help='output file (html)') 117 | # args = parser.parse_args() 118 | 119 | file1 = sys.argv[1] 120 | file2 = sys.argv[2] 121 | outfile = sys.argv[3] 122 | 123 | reader = AMR_Reader() 124 | amrs1 = reader.load(file1, remove_wiki=True) 125 | amrs2 = reader.load(file2, remove_wiki=True) 126 | 127 | other_args = {} 128 | amr_pairs = {} 129 | for amr1, amr2 in zip(amrs1, amrs2): 130 | map1, prec, rec, f1 = get_node_alignment(amr1, amr2) 131 | map2, _, _, _ = get_node_alignment(amr2, amr1) 132 | amr2.id = amr1.id 133 | other_args[amr1.id] = (amr1, amr2, map1, map2, prec, rec, f1) 134 | amr_pairs[amr1.id] = (amr1, amr2) 135 | output = style(amr_pairs, 136 | other_args, 137 | assign_node_color=is_correct_node, 138 | assign_node_desc=is_correct_node_desc, 139 | assign_edge_color=is_correct_edge, 140 | assign_edge_desc=is_correct_edge_desc, 141 | limit=2000 142 | ) 143 | 144 | with open(outfile, 'w+', encoding='utf8') as f: 145 | f.write(output) 146 | 147 | 148 | if __name__=='__main__': 149 | main() 150 | -------------------------------------------------------------------------------- /amr_utils/amr_readers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import csv 5 | 6 | import penman 7 | 8 | from amr_utils.alignments import AMR_Alignment, write_to_json, load_from_json 9 | from amr_utils.amr import AMR 10 | 11 | 12 | class Matedata_Parser: 13 | 14 | token_range_re = re.compile(r'^(\d+-\d+|\d+(,\d+)+)$') 15 | metadata_re = re.compile('(?<=[^#]) ::') 16 | 17 | def __init__(self): 18 |
pass 19 | 20 | def get_token_range(self, string): 21 | if '-' in string: 22 | start = int(string.split('-')[0]) 23 | end = int(string.split('-')[-1]) 24 | return [i for i in range(start, end)] 25 | else: 26 | return [int(i) for i in string.split(',')] 27 | 28 | def readlines(self, lines): 29 | lines = self.metadata_re.sub('\n# ::', lines) 30 | metadata = {} 31 | graph_metadata = {} 32 | rows = [self.readline_(line) for line in lines.split('\n')] 33 | labels = {label for label,_ in rows} 34 | for label in labels: 35 | if label in ['root','node','edge']: 36 | graph_metadata[label] = [val for l,val in rows if label==l] 37 | else: 38 | metadata[label] = [val for l,val in rows if label==l][0] 39 | if 'snt' not in metadata and 'tok' not in metadata: 40 | metadata['snt'] = '' 41 | return metadata, graph_metadata 42 | 43 | def readline_(self, line): 44 | if not line.startswith('#'): 45 | label = 'snt' 46 | metadata = line.strip() 47 | elif line.startswith('# ::id'): 48 | label = 'id' 49 | metadata = line[len('# ::id '):].strip() 50 | elif line.startswith("# ::tok"): 51 | label = 'tok' 52 | metadata = line[len('# ::tok '):].strip().split() 53 | elif line.startswith('# ::snt '): 54 | label = 'snt' 55 | metadata = line[len('# ::snt '):].strip() 56 | elif line.startswith('# ::alignments'): 57 | label = 'alignments' 58 | metadata = line[len('# ::alignments '):].strip() 59 | elif line.startswith('# ::node') or line.startswith('# ::root') or line.startswith('# ::edge'): 60 | label = line[len('# ::'):].split()[0] 61 | line = line[len(f'# ::{label} '):] 62 | rows = [row for row in csv.reader([line], delimiter='\t', quotechar='|')] 63 | metadata = rows[0] 64 | for i, s in enumerate(metadata): 65 | if self.token_range_re.match(s): 66 | metadata[i] = self.get_token_range(s) 67 | elif line.startswith('# ::'): 68 | label = line[len('# ::'):].split()[0] 69 | line = line[len(f'# ::{label} '):] 70 | metadata = line 71 | else: 72 | label = 'snt' 73 | metadata = line[len('# '):].strip() 74 |
| return label, metadata 75 | 76 | 77 | from penman.model import Model 78 | 79 | 80 | class TreePenmanModel(Model): 81 | def deinvert(self, triple): 82 | return triple 83 | 84 | def invert(self, triple): 85 | return triple 86 | 87 | 88 | class PENMAN_Wrapper: 89 | 90 | def __init__(self, style='isi'): 91 | self.style = style 92 | 93 | def parse_amr(self, tokens, amr_string): 94 | amr = AMR(tokens=tokens) 95 | g = penman.decode(amr_string, model=TreePenmanModel()) 96 | triples = g.triples() if callable(g.triples) else g.triples 97 | 98 | letter_labels = {} 99 | isi_labels = {g.top: '1'} 100 | isi_edge_labels = {} 101 | jamr_labels = {g.top: '0'} 102 | 103 | new_idx = 0 104 | 105 | isi_edge_idx = {g.top: 1} 106 | jamr_edge_idx = {g.top: 0} 107 | 108 | nodes = [] 109 | attributes = [] 110 | edges = [] 111 | reentrancies = [] 112 | 113 | for i,tr in enumerate(triples): 114 | s, r, t = tr 115 | # an amr node 116 | if r == ':instance': 117 | if reentrancies and edges[-1]==reentrancies[-1]: 118 | s2,r2,t2 = edges[-1] 119 | jamr_labels[t2] = jamr_labels[s2] + '.' + str(jamr_edge_idx[s2]) 120 | isi_labels[t2] = isi_labels[s2] + '.' + str(isi_edge_idx[s2]) 121 | new_s = s 122 | while new_s in letter_labels: 123 | new_idx += 1 124 | new_s = f'x{new_idx}' 125 | letter_labels[s] = new_s 126 | nodes.append(tr) 127 | # an amr edge 128 | elif t not in letter_labels: 129 | if len(t) > 5 or not t[0].isalpha(): 130 | if tr in letter_labels: 131 | isi_labels['ignore'] = isi_labels[s] + '.' + str(isi_edge_idx[s]) 132 | isi_edge_labels['ignore'] = isi_labels[s] + '.' + str(isi_edge_idx[s])+'.r' 133 | isi_edge_idx[s] += 1 134 | jamr_edge_idx[s] += 1 135 | continue 136 | # attribute 137 | new_s = s 138 | while new_s in letter_labels: 139 | new_idx += 1 140 | new_s = f'x{new_idx}' 141 | letter_labels[tr] = new_s 142 | jamr_labels[tr] = jamr_labels[s] + '.' + str(jamr_edge_idx[s]) 143 | isi_labels[tr] = isi_labels[s] + '.' 
+ str(isi_edge_idx[s]) 144 | isi_edge_labels[tr] = isi_labels[s] + '.' + str(isi_edge_idx[s])+'.r' 145 | isi_edge_idx[s] += 1 146 | jamr_edge_idx[s] += 1 147 | attributes.append(tr) 148 | else: 149 | # edge 150 | jamr_edge_idx[t] = 0 151 | isi_edge_idx[t] = 1 152 | jamr_labels[t] = jamr_labels[s] + '.' + str(jamr_edge_idx[s]) 153 | if i+1 5 or not t[0].isalpha(): 198 | align = AMR_Alignment(type='isi', tokens=list(indices), nodes=[default_labels[tr]]) 199 | else: 200 | align = AMR_Alignment(type='isi', tokens=list(indices), edges=[edge_map[tr]]) 201 | aligns.append(align) 202 | 203 | letter_labels = {v: default_labels[k] for k,v in letter_labels.items()} 204 | jamr_labels = {v: default_labels[k] for k, v in jamr_labels.items()} 205 | isi_labels = {v: default_labels[k] if k!='ignore' else k for k, v in isi_labels.items()} 206 | isi_edge_labels = {v: edge_map[k] if k in edge_map else k for k, v in isi_edge_labels.items()} 207 | 208 | return amr, (letter_labels, jamr_labels, isi_labels, isi_edge_labels, aligns) 209 | 210 | 211 | class AMR_Reader: 212 | 213 | def __init__(self, style='isi'): 214 | self.style=style 215 | 216 | def load(self, amr_file_name, remove_wiki=False, output_alignments=False): 217 | print('[amr]', 'Loading AMRs from file:', amr_file_name) 218 | amrs = [] 219 | alignments = {} 220 | penman_wrapper = PENMAN_Wrapper(style=self.style) 221 | metadata_parser = Matedata_Parser() 222 | 223 | with open(amr_file_name, 'r', encoding='utf8') as f: 224 | sents = f.read().replace('\r', '').split('\n\n') 225 | amr_idx = 0 226 | no_tokens = False 227 | if all(sent.strip().startswith('(') for sent in sents): 228 | no_tokens = True 229 | 230 | for sent in sents: 231 | prefix_lines = [line for i,line in enumerate(sent.split('\n')) if line.strip().startswith('#') or (i==0 and not no_tokens)] 232 | prefix = '\n'.join(prefix_lines) 233 | amr_string_lines = [line for i, line in enumerate(sent.split('\n')) 234 | if not line.strip().startswith('#') and (i>0 or 
no_tokens)] 235 | amr_string = ''.join(amr_string_lines).strip() 236 | amr_string = re.sub(' +', ' ', amr_string) 237 | if not amr_string: continue 238 | if not amr_string.startswith('(') or not amr_string.endswith(')'): 239 | raise Exception('Could not parse AMR from: ', amr_string) 240 | metadata, graph_metadata = metadata_parser.readlines(prefix) 241 | tokens = metadata['tok'] if 'tok' in metadata else metadata['snt'].split() 242 | tokens = self._clean_tokens(tokens) 243 | if graph_metadata: 244 | amr, aligns = self._parse_amr_from_metadata(tokens, graph_metadata) 245 | amr.id = metadata['id'] 246 | if output_alignments: 247 | alignments[amr.id] = aligns 248 | else: 249 | amr, other_stuff = penman_wrapper.parse_amr(tokens, amr_string) 250 | if 'id' in metadata: 251 | amr.id = metadata['id'] 252 | else: 253 | amr.id = str(amr_idx) 254 | if output_alignments: 255 | alignments[amr.id] = [] 256 | if 'alignments' in metadata: 257 | aligns = metadata['alignments'].split() 258 | if any('|' in a for a in aligns): 259 | jamr_labels = other_stuff[1] 260 | alignments[amr.id] = self._parse_jamr_alignments(amr, amr_file_name, aligns, jamr_labels, metadata_parser) 261 | else: 262 | isi_labels, isi_edge_labels = other_stuff[2:4] 263 | alignments[amr.id] = self._parse_isi_alignments(amr, amr_file_name, aligns, isi_labels, isi_edge_labels) 264 | else: 265 | aligns = other_stuff[4] 266 | alignments[amr.id] = aligns 267 | amr.metadata = {k:v for k,v in metadata.items() if k not in ['tok','id']} 268 | amrs.append(amr) 269 | amr_idx += 1 270 | if remove_wiki: 271 | for amr in amrs: 272 | wiki_nodes = [] 273 | wiki_edges = [] 274 | for s, r, t in amr.edges.copy(): 275 | if r == ':wiki': 276 | amr.edges.remove((s, r, t)) 277 | del amr.nodes[t] 278 | wiki_nodes.append(t) 279 | wiki_edges.append((s,r,t)) 280 | if alignments and amr.id in alignments: 281 | for align in alignments[amr.id]: 282 | for n in wiki_nodes: 283 | if n in align.nodes: 284 | align.nodes.remove(n) 285 | for e in 
wiki_edges: 286 | if e in align.edges: 287 | align.edges.remove(e) 288 | if output_alignments: 289 | return amrs, alignments 290 | return amrs 291 | 292 | def load_from_dir(self, dir, remove_wiki=False, output_alignments=False): 293 | all_amrs = [] 294 | all_alignments = {} 295 | 296 | taken_ids = set() 297 | for filename in os.listdir(dir): 298 | if filename.endswith('.txt'): 299 | print(filename) 300 | file = os.path.join(dir, filename) 301 | amrs, aligns = self.load(file, output_alignments=True, remove_wiki=remove_wiki) 302 | for amr in amrs: 303 | if amr.id.isdigit(): 304 | old_id = amr.id 305 | amr.id = filename+':'+old_id 306 | aligns[amr.id] = aligns[old_id] 307 | del aligns[old_id] 308 | for amr in amrs: 309 | if amr.id in taken_ids: 310 | old_id = amr.id 311 | amr.id += '#2' 312 | if old_id in aligns: 313 | aligns[amr.id] = aligns[old_id] 314 | del aligns[old_id] 315 | taken_ids.add(amr.id) 316 | all_amrs.extend(amrs) 317 | all_alignments.update(aligns) 318 | if output_alignments: 319 | return all_amrs, all_alignments 320 | return all_amrs 321 | 322 | @staticmethod 323 | def write_to_file(output_file, amrs): 324 | with open(output_file, 'w+', encoding='utf8') as f: 325 | for amr in amrs: 326 | f.write(amr.amr_string()) 327 | 328 | @staticmethod 329 | def load_alignments_from_json(json_file, amrs=None): 330 | return load_from_json(json_file, amrs=amrs) 331 | 332 | @staticmethod 333 | def save_alignments_to_json(json_file, alignments): 334 | write_to_json(json_file, alignments) 335 | 336 | @staticmethod 337 | def _parse_jamr_alignments(amr, amr_file, aligns, jamr_labels, metadata_parser): 338 | aligns = [(metadata_parser.get_token_range(a.split('|')[0]), a.split('|')[-1].split('+')) for a in aligns if '|' in a] 339 | 340 | alignments = [] 341 | for toks, components in aligns: 342 | if not all(n in jamr_labels for n in components) or any(t>=len(amr.tokens) for t in toks): 343 | raise Exception('Could not parse alignment:', amr_file, amr.id, toks, components) 
344 | nodes = [jamr_labels[n] for n in components] 345 | new_align = AMR_Alignment(type='jamr', tokens=toks, nodes=nodes) 346 | alignments.append(new_align) 347 | return alignments 348 | 349 | @staticmethod 350 | def _parse_isi_alignments(amr, amr_file, aligns, isi_labels, isi_edge_labels): 351 | aligns = [(int(a.split('-')[0]), a.split('-')[-1]) for a in aligns if '-' in a] 352 | 353 | alignments = [] 354 | xml_offset = 1 if amr.tokens[0].startswith('<') and amr.tokens[0].endswith('>') else 0 355 | if any(t + xml_offset >= len(amr.tokens) for t, n in aligns): 356 | xml_offset = 0 357 | 358 | for tok, component in aligns: 359 | tok += xml_offset 360 | nodes = [] 361 | edges = [] 362 | if component.replace('.r', '') in isi_labels: 363 | # node or attribute 364 | n = isi_labels[component.replace('.r', '')] 365 | if n=='ignore': continue 366 | nodes.append(n) 367 | if n not in amr.nodes: 368 | raise Exception('Could not parse alignment:', amr_file, amr.id, tok, component) 369 | elif not component.endswith('.r') and component not in isi_labels and component + '.r' in isi_edge_labels: 370 | # reentrancy 371 | e = isi_edge_labels[component + '.r'] 372 | edges.append(e) 373 | if e not in amr.edges: 374 | raise Exception('Could not parse alignment:', amr_file, amr.id, tok, component) 375 | elif component == '0.r': # root; must be tested before the generic '.r' branch, which would otherwise shadow it 376 | nodes.append(amr.root) 377 | elif component.endswith('.r'): 378 | # edge 379 | e = isi_edge_labels[component] 380 | if e == 'ignore': continue 381 | edges.append(e) 382 | if e not in amr.edges: 383 | raise Exception('Could not parse alignment:', amr_file, amr.id, tok, component) 384 | else: 385 | raise Exception('Could not parse alignment:', amr_file, amr.id, tok, component) 386 | if tok >= len(amr.tokens): 387 | raise Exception('Could not parse alignment:', amr_file, amr.id, tok, component) 388 | new_align = AMR_Alignment(type='isi', tokens=[tok], nodes=nodes, edges=edges) 389 | alignments.append(new_align) 390 | return alignments 391 | 392 | @staticmethod
393 | def _parse_amr_from_metadata(tokens, metadata): 394 | ''' 395 | Metadata format is ... 396 | # ::id sentence id 397 | # ::tok tokens... 398 | # ::node node_id node alignments 399 | # ::root root_id root 400 | # ::edge src label trg src_id trg_id alignments 401 | amr graph 402 | ''' 403 | amr = AMR(tokens=tokens) 404 | alignments = [] 405 | 406 | nodes = metadata['node'] 407 | edges = metadata['edge'] if 'edge'in metadata else [] 408 | root = metadata['root'][0] 409 | amr.root = root[0] 410 | for data in nodes: 411 | n, label = data[:2] 412 | if len(data)>2: 413 | toks = data[2] 414 | alignments.append(AMR_Alignment(type='jamr', nodes=[n], tokens=toks)) 415 | amr.nodes[n] = label 416 | for data in edges: 417 | _, r, _, s, t = data[:5] 418 | if len(data)>5: 419 | toks = data[5] 420 | alignments.append(AMR_Alignment(type='jamr', edges=[(s,r,t)], tokens=toks)) 421 | if not r.startswith(':'): r = ':'+r 422 | amr.edges.append((s,r,t)) 423 | return amr, alignments 424 | 425 | @staticmethod 426 | def _clean_tokens(tokens): 427 | line = ' '.join(tokens) 428 | if '<' in line and '>' in line: 429 | tokens_reformat = [] 430 | is_xml = False 431 | for i, tok in enumerate(tokens): 432 | if is_xml: 433 | tokens_reformat[-1] += '_' + tok 434 | if '>' in tok: 435 | is_xml = False 436 | else: 437 | tokens_reformat.append(tok) 438 | if tok.startswith('<') and not '>' in tok: 439 | if len(tok) > 1 and (tok[1].isalpha() or tok[1] == '/'): 440 | if i + 1 < len(tokens) and '=' in tokens[i + 1]: 441 | is_xml = True 442 | tokens = tokens_reformat 443 | return tokens 444 | 445 | 446 | 447 | def main(): 448 | dir = sys.argv[1] 449 | output_file = sys.argv[2] 450 | 451 | reader = AMR_Reader() 452 | amrs, alignments = reader.load_from_dir(dir, output_alignments=True) 453 | 454 | reader.write_to_file(output_file, amrs) 455 | reader.save_alignments_to_json(output_file.replace('.txt','.alignments.json'), alignments) 456 | 457 | 458 | if __name__ == '__main__': 459 | main() 460 | 
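`AMR_Reader._parse_isi_alignments` reads ISI-style alignments, where each whitespace-separated entry pairs a token index with a dot-separated graph address, and a trailing `.r` marks a role (edge) rather than a concept node. The sketch below isolates only that splitting step so the format is easier to see. It is a hypothetical helper (`split_isi_alignments` is not part of amr-utils), it assumes every entry looks like `<token>-<address>`, and it does not apply the XML-token offset the reader uses; resolving addresses to node IDs and edges still requires the `isi_labels` and `isi_edge_labels` maps that the reader builds.

```python
from typing import List, Tuple


def split_isi_alignments(alignment_line: str) -> List[Tuple[int, str, bool]]:
    """Split an ISI-style alignment string into (token, address, is_edge) triples.

    Hypothetical helper, not part of amr-utils. Assumes each entry has the form
    "<token>-<address>", e.g. "2-1.1" (node) or "3-1.1.r" (role/edge).
    """
    triples = []
    for entry in alignment_line.split():
        if '-' not in entry:
            # The reader's list comprehension also skips entries without '-'.
            continue
        token, address = entry.split('-')[0], entry.split('-')[-1]
        is_edge = address.endswith('.r')
        if is_edge:
            address = address[:-len('.r')]  # drop the role marker, keep the path
        triples.append((int(token), address, is_edge))
    return triples


print(split_isi_alignments("0-1 2-1.1 3-1.1.r"))
# [(0, '1', False), (2, '1.1', False), (3, '1.1', True)]
```

Keeping the address as a string (rather than a list of ints) mirrors how the reader uses addresses directly as dictionary keys.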
-------------------------------------------------------------------------------- /amr_utils/display_alignments.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from amr_utils.amr_readers import AMR_Reader 4 | from amr_utils.graph_utils import is_rooted_dag, get_subgraph 5 | from amr_utils.style import HTML_AMR 6 | 7 | 8 | def is_aligned_node(amr, n, alignments): 9 | align = amr.get_alignment(alignments, node_id=n) 10 | if align: 11 | return 'green' 12 | return '' 13 | 14 | 15 | def is_aligned_edge(amr, e, alignments): 16 | align = amr.get_alignment(alignments, edge=e) 17 | return 'grey' if align else '' 18 | 19 | 20 | def is_aligned_token(amr, t, alignments): 21 | align = amr.get_alignment(alignments, token_id=t) 22 | return 'green' if align else '' 23 | 24 | 25 | def get_node_aligned_tokens(amr, n, alignments): 26 | align = amr.get_alignment(alignments, node_id=n) 27 | if align: 28 | return ' '.join(amr.tokens[t] for t in align.tokens) 29 | return '' 30 | 31 | 32 | def get_edge_aligned_tokens(amr, e, alignments): 33 | align = amr.get_alignment(alignments, edge=e) 34 | if align: 35 | return ' '.join(amr.tokens[t] for t in align.tokens) 36 | return '' 37 | 38 | 39 | def get_token_aligned_subgraph(amr, tok, alignments): 40 | align = amr.get_alignment(alignments, token_id=tok) 41 | if align: 42 | elems = [amr.nodes[n] for n in align.nodes] 43 | elems += [r for s,r,t in align.edges] 44 | # return ' '.join(elems) 45 | edges = [(s,r,t) for s,r,t in amr.edges if ((s,r,t) in align.edges or (s in align.nodes and t in align.nodes))] 46 | sg = get_subgraph(amr, align.nodes, edges) 47 | if is_rooted_dag(amr, align.nodes): 48 | out = sg.graph_string() 49 | else: 50 | out = ', '.join(elems) 51 | return out 52 | return '' 53 | 54 | 55 | def style(amrs, alignments, outfile): 56 | output = HTML_AMR.style(amrs[:5000], 57 | assign_node_color=is_aligned_node, 58 | assign_edge_color=is_aligned_edge, 59 | 
assign_token_color=is_aligned_token, 60 | assign_node_desc=get_node_aligned_tokens, 61 | assign_edge_desc=get_edge_aligned_tokens, 62 | assign_token_desc=get_token_aligned_subgraph, 63 | other_args=alignments) 64 | 65 | with open(outfile, 'w+', encoding='utf8') as f: 66 | f.write(output) 67 | 68 | 69 | def main(): 70 | file = sys.argv[1] 71 | align_file = sys.argv[2] 72 | outfile = sys.argv[3] 73 | 74 | reader = AMR_Reader() 75 | amrs = reader.load(file, remove_wiki=True) 76 | alignments = reader.load_alignments_from_json(align_file, amrs) 77 | style(amrs[:5000], alignments, outfile) 78 | 79 | 80 | if __name__=='__main__': 81 | main() 82 | -------------------------------------------------------------------------------- /amr_utils/graph_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from amr_utils.amr import AMR 3 | from amr_utils.smatch import get_best_match 4 | 5 | 6 | def get_subgraph(amr, nodes: list, edges: list): 7 | if not nodes: 8 | return AMR() 9 | potential_root = nodes.copy() 10 | for x, r, y in amr.edges: 11 | if x in nodes and y in nodes: 12 | if y in potential_root: 13 | potential_root.remove(y) 14 | root = potential_root[0] if len(potential_root) > 0 else nodes[0] 15 | sub = AMR(root=root, 16 | edges=edges, 17 | nodes={n: amr.nodes[n] for n in nodes}) 18 | for s,r,t in edges: 19 | if s not in nodes: 20 | sub.nodes[s] = '' 21 | if t not in nodes: 22 | sub.nodes[t] = '' 23 | return sub 24 | 25 | 26 | def is_rooted_dag(amr, nodes): 27 | if not nodes: 28 | return False 29 | roots = nodes.copy() 30 | edges = [(s,r,t) for s,r,t in amr.edges if s in nodes and t in nodes] 31 | for s,r,t in edges: 32 | if t in roots: 33 | roots.remove(t) 34 | if len(roots)==1: 35 | return True 36 | return False 37 | 38 | 39 | def get_connected_components(amr, nodes): 40 | 41 | if not nodes: 42 | return [] 43 | descendants = {n:{n} for n in nodes} 44 | roots = [n for n in nodes] 45 | taken = set() 46 | edges = [(s, r, t) for 
s, r, t in breadth_first_edges(amr, ignore_reentrancies=True) if s in nodes and t in nodes] 47 | for s, r, t in edges: 48 | if t in taken: continue 49 | taken.add(t) 50 | if t in roots: 51 | roots.remove(t) 52 | for d in descendants: 53 | if s in descendants[d]: 54 | descendants[d].update(descendants[t]) 55 | components = [] 56 | for root in roots: 57 | edges = [] 58 | for s,r,t in breadth_first_edges(amr, ignore_reentrancies=True): 59 | if s in descendants[root] and t in descendants[root]: 60 | edges.append((s,r,t)) 61 | sub = AMR(nodes={n:amr.nodes[n] for n in descendants[root]}, root=root, edges=edges) 62 | components.append(sub) 63 | components = sorted(components, key=lambda x:len(x.nodes), reverse=True) 64 | return list(components) 65 | 66 | 67 | def is_projective_node_(amr, n, descendants, positions, ignore=None): 68 | span = {positions[m] for m in descendants if m in positions} 69 | if not span: 70 | return True, [] 71 | max_token = max(span) 72 | min_token = min(span) 73 | if max_token - min_token <= 1: 74 | return True, [i for i in range(min_token,max_token+1)] 75 | for tok in range(min_token + 1, max_token): 76 | if ignore and tok in ignore: 77 | continue 78 | if tok in span: 79 | continue 80 | align = amr.get_alignment(token_id=tok) 81 | if align and align.tokens[0] not in span: 82 | return False, [i for i in range(min_token,max_token+1)] 83 | return True, [i for i in range(min_token,max_token+1)] 84 | 85 | 86 | def is_projective(amr): 87 | 88 | descendants = {n: {n} for n in amr.nodes.keys()} 89 | for s, r, t in breadth_first_edges(amr, ignore_reentrancies=True): 90 | for d in descendants: 91 | if s in descendants[d]: 92 | descendants[d].update(descendants[t]) 93 | positions = {} 94 | alignments = {} 95 | for n in amr.nodes: 96 | align = amr.get_alignment(node_id=n) 97 | alignments[n] = align 98 | if align: 99 | positions[n] = align.tokens[0] 100 | 101 | nonprojective = {} 102 | used = set() 103 | for n in breadth_first_nodes(amr): 104 | if n in used: 
105 | continue 106 | test, span = is_projective_node_(amr, n, descendants[n], positions) 107 | used.update(alignments[n].nodes if alignments[n] else []) 108 | if not test: 109 | nonprojective[n] = span 110 | if not nonprojective: 111 | return True, [] 112 | for n in list(nonprojective.keys()): 113 | for d in descendants[n]: 114 | if d!=n and d in nonprojective: 115 | del nonprojective[n] 116 | break 117 | 118 | used = set() 119 | culprits = [] 120 | for n in nonprojective: 121 | for tok in nonprojective[n]: 122 | align = amr.get_alignment(token_id=tok) 123 | if not align or align in used: 124 | continue 125 | test, _ = is_projective_node_(amr, n, descendants[n], positions, ignore=align.tokens) 126 | used.add(align) 127 | if test: 128 | culprits.append(align) 129 | return False, culprits 130 | 131 | 132 | def breadth_first_nodes(amr): 133 | if amr.root is None: 134 | return 135 | nodes = [amr.root] 136 | children = [(s,r,t) for s,r,t in amr.edges if s in nodes] 137 | children = sorted(children, key=lambda x: x[1].lower()) 138 | edges = [e for e in amr.edges] 139 | yield amr.root 140 | while True: 141 | for s,r,t in children: 142 | if t not in nodes: 143 | nodes.append(t) 144 | yield t 145 | edges.remove((s,r,t)) 146 | children = [(s, r, t) for s, r, t in edges if s in nodes and t not in nodes] 147 | children = list(sorted(children, key=lambda x: x[1].lower())) 148 | if not children: 149 | break 150 | 151 | 152 | def breadth_first_edges(amr, ignore_reentrancies=False): 153 | if amr.root is None: 154 | return 155 | nodes = [amr.root] 156 | children = [(s,r,t) for s,r,t in amr.edges if s in nodes] 157 | children = sorted(children, key=lambda x: x[1].lower()) 158 | edges = [e for e in amr.edges] 159 | while True: 160 | for s,r,t in children: 161 | edges.remove((s, r, t)) 162 | if ignore_reentrancies and t in nodes: 163 | continue 164 | if t not in nodes: 165 | nodes.append(t) 166 | yield (s,r,t) 167 | children = [(s, r, t) for s, r, t in edges if s in nodes] 168 | children =
list(sorted(children, key=lambda x: x[1].lower())) 169 | if not children: 170 | break 171 | 172 | 173 | def depth_first_nodes(amr): 174 | visited, stack = {amr.root}, [] 175 | children = [(s, r, t) for s, r, t in amr.edges if s == amr.root and t not in visited] 176 | children = list(sorted(children, key=lambda x: x[1].lower(), reverse=True)) 177 | stack.extend(children) 178 | edges = [e for e in amr.edges] 179 | yield amr.root 180 | 181 | while stack: 182 | s, r, t = stack.pop() 183 | if t in visited: 184 | continue 185 | yield t 186 | edges.remove((s, r, t)) 187 | visited.add(t) 188 | children = [(s2, r2, t2) for s2, r2, t2 in edges if s2 == t] 189 | children = list(sorted(children, key=lambda x: x[1].lower(), reverse=True)) 190 | stack.extend(children) 191 | 192 | 193 | def depth_first_edges(amr, ignore_reentrancies=False): 194 | visited, stack = {amr.root}, [] 195 | children = [(s, r, t) for s, r, t in amr.edges if s == amr.root and t not in visited] 196 | children = list(sorted(children, key=lambda x: x[1].lower(), reverse=True)) 197 | stack.extend(children) 198 | edges = [e for e in amr.edges] 199 | 200 | while stack: 201 | s,r,t = stack.pop() 202 | if ignore_reentrancies and t in visited: 203 | continue 204 | yield (s,r,t) 205 | edges.remove((s,r,t)) 206 | visited.add(t) 207 | children = [(s2,r2,t2) for s2,r2,t2 in edges if s2==t] 208 | children = list(sorted(children, key=lambda x: x[1].lower(), reverse=True)) 209 | stack.extend(children) 210 | 211 | 212 | def get_shortest_path(amr, n1, n2, ignore_reentrancies=False): 213 | path = [n1] 214 | for s,r,t in depth_first_edges(amr, ignore_reentrancies): 215 | if s in path: 216 | while path[-1]!=s: 217 | path.pop() 218 | path.append(t) 219 | if t==n2: 220 | return path 221 | return None 222 | 223 | 224 | # def is_cycle(amr, nodes): 225 | # descendants = {n: {n} for n in nodes} 226 | # for s, r, t in amr.edges: 227 | # if s in nodes and t in nodes: 228 | # for d in descendants: 229 | # if s in descendants[d]: 230 | 
# descendants[d].update(descendants[t]) 231 | # for n in nodes: 232 | # for n2 in nodes: 233 | # if n==n2: 234 | # continue 235 | # if n in descendants[n2] and n2 in descendants[n]: 236 | # return True 237 | # return False 238 | 239 | 240 | def get_node_alignment(amr1:AMR, amr2:AMR): 241 | prefix1 = "a" 242 | prefix2 = "b" 243 | node_map1 = {} 244 | node_map2 = {} 245 | idx = 0 246 | for n in amr1.nodes.copy(): 247 | amr1._rename_node(n, prefix1+str(idx)) 248 | node_map1[prefix1+str(idx)] = n 249 | idx+=1 250 | idx = 0 251 | for n in amr2.nodes.copy(): 252 | amr2._rename_node(n, prefix2+str(idx)) 253 | node_map2[prefix2 + str(idx)] = n 254 | idx += 1 255 | instance1 = [] 256 | attributes1 = [] 257 | relation1 = [] 258 | for s,r,t in amr1.triples(normalize_inverse_edges=True): 259 | if r==':instance': 260 | instance1.append((r,s,t)) 261 | elif t not in amr1.nodes: 262 | attributes1.append((r,s,t)) 263 | else: 264 | relation1.append((r,s,t)) 265 | instance2 = [] 266 | attributes2 = [] 267 | relation2 = [] 268 | for s,r,t in amr2.triples(normalize_inverse_edges=True): 269 | if r==':instance': 270 | instance2.append((r,s,t)) 271 | elif t not in amr2.nodes: 272 | attributes2.append((r,s,t)) 273 | else: 274 | relation2.append((r,s,t)) 275 | # optionally turn off some of the node comparison 276 | doinstance = doattribute = dorelation = True 277 | (best_mapping, best_match_num) = get_best_match(instance1, attributes1, relation1, 278 | instance2, attributes2, relation2, 279 | prefix1, prefix2, doinstance=doinstance, 280 | doattribute=doattribute, dorelation=dorelation) 281 | test_triple_num = len(instance1) + len(attributes1) + len(relation1) 282 | gold_triple_num = len(instance2) + len(attributes2) + len(relation2) 283 | for n in amr1.nodes.copy(): 284 | amr1._rename_node(n, node_map1[n]) 285 | for n in amr2.nodes.copy(): 286 | amr2._rename_node(n, node_map2[n]) 287 | 288 | align_map = {} 289 | for i,j in enumerate(best_mapping): 290 | a = prefix1 + str(i) 291 | if j==-1: 
292 | continue 293 | b = prefix2 + str(j) 294 | align_map[node_map1[a]] = node_map2[b] 295 | if amr1.root not in align_map: 296 | align_map[amr1.root] = amr2.root 297 | for s,r,t in breadth_first_edges(amr1, ignore_reentrancies=True): 298 | if t not in align_map: 299 | for s2,r2,t2 in amr2.edges: 300 | if align_map[s]==s2 and r==r2 and amr1.nodes[t]==amr2.nodes[t2]: 301 | align_map[t] = t2 302 | if t not in align_map: 303 | align_map[t] = align_map[s] 304 | 305 | if not all(n in align_map for n in amr1.nodes): 306 | raise Exception('Failed to build node alignment:', amr1.id, amr2.id) 307 | prec = best_match_num / test_triple_num if test_triple_num>0 else 0 308 | rec = best_match_num / gold_triple_num if gold_triple_num>0 else 0 309 | f1 = 2*(prec*rec)/(prec+rec) if (prec+rec)>0 else 0 310 | return align_map, prec, rec, f1 311 | -------------------------------------------------------------------------------- /amr_utils/smatch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | This code is taken from https://github.com/snowblink14/smatch 6 | """ 7 | 8 | 9 | """ 10 | This script computes smatch score between two AMRs. 11 | For detailed description of smatch, see http://www.isi.edu/natural-language/amr/smatch-13.pdf 12 | 13 | """ 14 | 15 | import random 16 | 17 | # import amr 18 | import sys 19 | 20 | # total number of iteration in smatch computation 21 | iteration_num = 5 22 | 23 | # verbose output switch. 24 | # Default false (no verbose output) 25 | verbose = False 26 | veryVerbose = False 27 | 28 | # single score output switch. 29 | # Default true (compute a single score for all AMRs in two files) 30 | single_score = True 31 | 32 | # precision and recall output switch. 
33 | # Default false (do not output precision and recall, just output F score) 34 | pr_flag = False 35 | 36 | # Error log location 37 | ERROR_LOG = sys.stderr 38 | 39 | # Debug log location 40 | DEBUG_LOG = sys.stderr 41 | 42 | # dictionary to save pre-computed node mapping and its resulting triple match count 43 | # key: tuples of node mapping 44 | # value: the matching triple count 45 | match_triple_dict = {} 46 | 47 | 48 | def get_best_match(instance1, attribute1, relation1, 49 | instance2, attribute2, relation2, 50 | prefix1, prefix2, doinstance=True, doattribute=True, dorelation=True): 51 | """ 52 | Get the highest triple match number between two sets of triples via hill-climbing. 53 | Arguments: 54 | instance1: instance triples of AMR 1 ("instance", node name, node value) 55 | attribute1: attribute triples of AMR 1 (attribute name, node name, attribute value) 56 | relation1: relation triples of AMR 1 (relation name, node 1 name, node 2 name) 57 | instance2: instance triples of AMR 2 ("instance", node name, node value) 58 | attribute2: attribute triples of AMR 2 (attribute name, node name, attribute value) 59 | relation2: relation triples of AMR 2 (relation name, node 1 name, node 2 name) 60 | prefix1: prefix label for AMR 1 61 | prefix2: prefix label for AMR 2 62 | Returns: 63 | best_match: the node mapping that results in the highest triple matching number 64 | best_match_num: the highest triple matching number 65 | 66 | """ 67 | # Compute candidate pool - all possible node match candidates. 68 | # In the hill-climbing, we only consider candidate in this pool to save computing time. 
69 | # weight_dict maps each candidate node pair (AMR 1 index, AMR 2 index) to the triple matches that pair contributes 70 | (candidate_mappings, weight_dict) = compute_pool(instance1, attribute1, relation1, 71 | instance2, attribute2, relation2, 72 | prefix1, prefix2, doinstance=doinstance, doattribute=doattribute, 73 | dorelation=dorelation) 74 | if veryVerbose: 75 | print("Candidate mappings:", file=DEBUG_LOG) 76 | print(candidate_mappings, file=DEBUG_LOG) 77 | print("Weight dictionary", file=DEBUG_LOG) 78 | print(weight_dict, file=DEBUG_LOG) 79 | match_triple_dict.clear() # reset the module-level cache: match counts cached for a previous AMR pair do not apply to this one 80 | best_match_num = 0 81 | # initialize best match mapping 82 | # the ith entry is the node index in AMR 2 which maps to the ith node in AMR 1 83 | best_mapping = [-1] * len(instance1) 84 | for i in range(iteration_num): 85 | if veryVerbose: 86 | print("Iteration", i, file=DEBUG_LOG) 87 | if i == 0: 88 | # smart initialization used for the first round 89 | cur_mapping = smart_init_mapping(candidate_mappings, instance1, instance2) 90 | else: 91 | # random initialization for the other rounds 92 | cur_mapping = random_init_mapping(candidate_mappings) 93 | # compute current triple match number 94 | match_num = compute_match(cur_mapping, weight_dict) 95 | if veryVerbose: 96 | print("Node mapping at start", cur_mapping, file=DEBUG_LOG) 97 | print("Triple match number at start:", match_num, file=DEBUG_LOG) 98 | while True: 99 | # get best gain 100 | (gain, new_mapping) = get_best_gain(cur_mapping, candidate_mappings, weight_dict, 101 | len(instance2), match_num) 102 | if veryVerbose: 103 | print("Gain after the hill-climbing", gain, file=DEBUG_LOG) 104 | # hill-climbing until there will be no gain for new node mapping 105 | if gain <= 0: 106 | break 107 | # otherwise update match_num and mapping 108 | match_num += gain 109 | cur_mapping = new_mapping[:] 110 | if veryVerbose: 111 | print("Update triple match number to:", match_num, file=DEBUG_LOG) 112 | print("Current mapping:", cur_mapping, file=DEBUG_LOG) 113 | if match_num > best_match_num: 114 | best_mapping = cur_mapping[:] 115 |
best_match_num = match_num 116 | return best_mapping, best_match_num 117 | 118 | 119 | def normalize(item): 120 | """ 121 | lowercase and remove quote signifiers from items that are about to be compared 122 | """ 123 | return item.lower().rstrip('_') 124 | 125 | 126 | def compute_pool(instance1, attribute1, relation1, 127 | instance2, attribute2, relation2, 128 | prefix1, prefix2, doinstance=True, doattribute=True, dorelation=True): 129 | """ 130 | compute all possible node mapping candidates and their weights (the triple matching number gain resulting from 131 | mapping one node in AMR 1 to another node in AMR 2) 132 | 133 | Arguments: 134 | instance1: instance triples of AMR 1 135 | attribute1: attribute triples of AMR 1 (attribute name, node name, attribute value) 136 | relation1: relation triples of AMR 1 (relation name, node 1 name, node 2 name) 137 | instance2: instance triples of AMR 2 138 | attribute2: attribute triples of AMR 2 (attribute name, node name, attribute value) 139 | relation2: relation triples of AMR 2 (relation name, node 1 name, node 2 name) 140 | prefix1: prefix label for AMR 1 141 | prefix2: prefix label for AMR 2 142 | Returns: 143 | candidate_mapping: a list of candidate nodes. 144 | The ith element contains the node indices (in AMR 2) the ith node (in AMR 1) can map to. 145 | (resulting in non-zero triple match) 146 | weight_dict: a dictionary which contains the matching triple number for every pair of node mapping. The key 147 | is a node pair. The value is another dictionary. Key -1 is the triple match resulting from this node 148 | pair alone (instance triples and attribute triples), and other keys are node pairs that can result 149 | in relation triple match together with the first node pair.
150 | 151 | 152 | """ 153 | candidate_mapping = [] 154 | weight_dict = {} 155 | for instance1_item in instance1: 156 | # each candidate mapping is a set of node indices 157 | candidate_mapping.append(set()) 158 | if doinstance: 159 | for instance2_item in instance2: 160 | # if both triples are instance triples and have the same value 161 | if normalize(instance1_item[0]) == normalize(instance2_item[0]) and \ 162 | normalize(instance1_item[2]) == normalize(instance2_item[2]): 163 | # get node index by stripping the prefix 164 | node1_index = int(instance1_item[1][len(prefix1):]) 165 | node2_index = int(instance2_item[1][len(prefix2):]) 166 | candidate_mapping[node1_index].add(node2_index) 167 | node_pair = (node1_index, node2_index) 168 | # use -1 as key in weight_dict for instance triples and attribute triples 169 | if node_pair in weight_dict: 170 | weight_dict[node_pair][-1] += 1 171 | else: 172 | weight_dict[node_pair] = {} 173 | weight_dict[node_pair][-1] = 1 174 | if doattribute: 175 | for attribute1_item in attribute1: 176 | for attribute2_item in attribute2: 177 | # if both attribute relation triple have the same relation name and value 178 | if normalize(attribute1_item[0]) == normalize(attribute2_item[0]) \ 179 | and normalize(attribute1_item[2]) == normalize(attribute2_item[2]): 180 | node1_index = int(attribute1_item[1][len(prefix1):]) 181 | node2_index = int(attribute2_item[1][len(prefix2):]) 182 | candidate_mapping[node1_index].add(node2_index) 183 | node_pair = (node1_index, node2_index) 184 | # use -1 as key in weight_dict for instance triples and attribute triples 185 | if node_pair in weight_dict: 186 | weight_dict[node_pair][-1] += 1 187 | else: 188 | weight_dict[node_pair] = {} 189 | weight_dict[node_pair][-1] = 1 190 | if dorelation: 191 | for relation1_item in relation1: 192 | for relation2_item in relation2: 193 | # if both relation share the same name 194 | if normalize(relation1_item[0]) == normalize(relation2_item[0]): 195 | 
node1_index_amr1 = int(relation1_item[1][len(prefix1):]) 196 | node1_index_amr2 = int(relation2_item[1][len(prefix2):]) 197 | node2_index_amr1 = int(relation1_item[2][len(prefix1):]) 198 | node2_index_amr2 = int(relation2_item[2][len(prefix2):]) 199 | # add mapping between two nodes 200 | candidate_mapping[node1_index_amr1].add(node1_index_amr2) 201 | candidate_mapping[node2_index_amr1].add(node2_index_amr2) 202 | node_pair1 = (node1_index_amr1, node1_index_amr2) 203 | node_pair2 = (node2_index_amr1, node2_index_amr2) 204 | if node_pair2 != node_pair1: 205 | # update weight_dict weight. Note that we need to update both entries for future search 206 | # i.e weight_dict[node_pair1][node_pair2] 207 | # weight_dict[node_pair2][node_pair1] 208 | if node1_index_amr1 > node2_index_amr1: 209 | # swap node_pair1 and node_pair2 210 | node_pair1 = (node2_index_amr1, node2_index_amr2) 211 | node_pair2 = (node1_index_amr1, node1_index_amr2) 212 | if node_pair1 in weight_dict: 213 | if node_pair2 in weight_dict[node_pair1]: 214 | weight_dict[node_pair1][node_pair2] += 1 215 | else: 216 | weight_dict[node_pair1][node_pair2] = 1 217 | else: 218 | weight_dict[node_pair1] = {-1: 0, node_pair2: 1} 219 | if node_pair2 in weight_dict: 220 | if node_pair1 in weight_dict[node_pair2]: 221 | weight_dict[node_pair2][node_pair1] += 1 222 | else: 223 | weight_dict[node_pair2][node_pair1] = 1 224 | else: 225 | weight_dict[node_pair2] = {-1: 0, node_pair1: 1} 226 | else: 227 | # two node pairs are the same. So we only update weight_dict once. 228 | # this generally should not happen. 
229 | if node_pair1 in weight_dict: 230 | weight_dict[node_pair1][-1] += 1 231 | else: 232 | weight_dict[node_pair1] = {-1: 1} 233 | return candidate_mapping, weight_dict 234 | 235 | 236 | def smart_init_mapping(candidate_mapping, instance1, instance2): 237 | """ 238 | Initialize mapping based on the concept mapping (smart initialization) 239 | Arguments: 240 | candidate_mapping: candidate node match list 241 | instance1: instance triples of AMR 1 242 | instance2: instance triples of AMR 2 243 | Returns: 244 | initialized node mapping between two AMRs 245 | 246 | """ 247 | random.seed() 248 | matched_dict = {} 249 | result = [] 250 | # list to store node indices that have no concept match 251 | no_word_match = [] 252 | for i, candidates in enumerate(candidate_mapping): 253 | if not candidates: 254 | # no possible mapping 255 | result.append(-1) 256 | continue 257 | # node value in instance triples of AMR 1 258 | value1 = instance1[i][2] 259 | for node_index in candidates: 260 | value2 = instance2[node_index][2] 261 | # find the first instance triple match in the candidates 262 | # instance triple match is having the same concept value 263 | if value1 == value2: 264 | if node_index not in matched_dict: 265 | result.append(node_index) 266 | matched_dict[node_index] = 1 267 | break 268 | if len(result) == i: 269 | no_word_match.append(i) 270 | result.append(-1) 271 | # if no concept match, generate a random mapping 272 | for i in no_word_match: 273 | candidates = list(candidate_mapping[i]) 274 | while candidates: 275 | # get a random node index from candidates 276 | rid = random.randint(0, len(candidates) - 1) 277 | candidate = candidates[rid] 278 | if candidate in matched_dict: 279 | candidates.pop(rid) 280 | else: 281 | matched_dict[candidate] = 1 282 | result[i] = candidate 283 | break 284 | return result 285 | 286 | 287 | def random_init_mapping(candidate_mapping): 288 | """ 289 | Generate a random node mapping. 
290 | Args: 291 | candidate_mapping: candidate node match list 292 | Returns: 293 | randomly-generated node mapping between two AMRs 294 | 295 | """ 296 | # if needed, a fixed seed could be passed here to reproduce the same random mapping (to help debugging) 297 | random.seed() 298 | matched_dict = {} 299 | result = [] 300 | for c in candidate_mapping: 301 | candidates = list(c) 302 | if not candidates: 303 | # -1 indicates no possible mapping 304 | result.append(-1) 305 | continue 306 | found = False 307 | while candidates: 308 | # randomly generate an index in [0, length of candidates) 309 | rid = random.randint(0, len(candidates) - 1) 310 | candidate = candidates[rid] 311 | # check if it has already been matched 312 | if candidate in matched_dict: 313 | candidates.pop(rid) 314 | else: 315 | matched_dict[candidate] = 1 316 | result.append(candidate) 317 | found = True 318 | break 319 | if not found: 320 | result.append(-1) 321 | return result 322 | 323 | 324 | def compute_match(mapping, weight_dict): 325 | """ 326 | Given a node mapping, compute match number based on weight_dict. 327 | Args: 328 | mapping: a list of node indices in AMR 2. The ith element (value j) means node i in AMR 1 maps to node j in AMR 2. 329 | Returns: 330 | matching triple number 331 | Complexity: O(m*n), m is the node number of AMR 1, n is the node number of AMR 2 332 | 333 | """ 334 | # If this mapping has been investigated before, retrieve the value instead of re-computing.
335 | if veryVerbose: 336 | print("Computing match for mapping", file=DEBUG_LOG) 337 | print(mapping, file=DEBUG_LOG) 338 | if tuple(mapping) in match_triple_dict: 339 | if veryVerbose: 340 | print("saved value", match_triple_dict[tuple(mapping)], file=DEBUG_LOG) 341 | return match_triple_dict[tuple(mapping)] 342 | match_num = 0 343 | # i is node index in AMR 1, m is node index in AMR 2 344 | for i, m in enumerate(mapping): 345 | if m == -1: 346 | # no node maps to this node 347 | continue 348 | # node i in AMR 1 maps to node m in AMR 2 349 | current_node_pair = (i, m) 350 | if current_node_pair not in weight_dict: 351 | continue 352 | if veryVerbose: 353 | print("node_pair", current_node_pair, file=DEBUG_LOG) 354 | for key in weight_dict[current_node_pair]: 355 | if key == -1: 356 | # matching triple resulting from instance/attribute triples 357 | match_num += weight_dict[current_node_pair][key] 358 | if veryVerbose: 359 | print("instance/attribute match", weight_dict[current_node_pair][key], file=DEBUG_LOG) 360 | # only consider node index larger than i to avoid duplicates 361 | # as we store both weight_dict[node_pair1][node_pair2] and 362 | # weight_dict[node_pair2][node_pair1] for a relation 363 | elif key[0] < i: 364 | continue 365 | elif mapping[key[0]] == key[1]: 366 | match_num += weight_dict[current_node_pair][key] 367 | if veryVerbose: 368 | print("relation match with", key, weight_dict[current_node_pair][key], file=DEBUG_LOG) 369 | if veryVerbose: 370 | print("match computing complete, result:", match_num, file=DEBUG_LOG) 371 | # update match_triple_dict 372 | match_triple_dict[tuple(mapping)] = match_num 373 | return match_num 374 | 375 | 376 | def move_gain(mapping, node_id, old_id, new_id, weight_dict, match_num): 377 | """ 378 | Compute the triple match number gain from the move operation 379 | Arguments: 380 | mapping: current node mapping 381 | node_id: remapped node in AMR 1 382 | old_id: original node id in AMR 2 to which node_id is mapped 383 | 
new_id: new node in AMR 2 to which node_id is mapped 384 | weight_dict: weight dictionary 385 | match_num: the original triple matching number 386 | Returns: 387 | the triple match gain number (might be negative) 388 | 389 | """ 390 | # new node mapping after moving 391 | new_mapping = (node_id, new_id) 392 | # node mapping before moving 393 | old_mapping = (node_id, old_id) 394 | # new node mapping list (all node pairs) 395 | new_mapping_list = mapping[:] 396 | new_mapping_list[node_id] = new_id 397 | # if this mapping has already been investigated, use the saved value to avoid duplicate computation 398 | if tuple(new_mapping_list) in match_triple_dict: 399 | return match_triple_dict[tuple(new_mapping_list)] - match_num 400 | gain = 0 401 | # add the triple match incurred by new_mapping to gain 402 | if new_mapping in weight_dict: 403 | for key in weight_dict[new_mapping]: 404 | if key == -1: 405 | # instance/attribute triple match 406 | gain += weight_dict[new_mapping][-1] 407 | elif new_mapping_list[key[0]] == key[1]: 408 | # relation gain incurred by new_mapping and another node pair in new_mapping_list 409 | gain += weight_dict[new_mapping][key] 410 | # deduct the triple match incurred by old_mapping from gain 411 | if old_mapping in weight_dict: 412 | for k in weight_dict[old_mapping]: 413 | if k == -1: 414 | gain -= weight_dict[old_mapping][-1] 415 | elif mapping[k[0]] == k[1]: 416 | gain -= weight_dict[old_mapping][k] 417 | # update match number dictionary 418 | match_triple_dict[tuple(new_mapping_list)] = match_num + gain 419 | return gain 420 | 421 | 422 | def swap_gain(mapping, node_id1, mapping_id1, node_id2, mapping_id2, weight_dict, match_num): 423 | """ 424 | Compute the triple match number gain from the swap operation 425 | Arguments: 426 | mapping: current node mapping list 427 | node_id1: node 1 index in AMR 1 428 | mapping_id1: the node index in AMR 2 node 1 maps to (in the current mapping) 429 | node_id2: node 2 index in AMR 1 430 | mapping_id2: the node index in
AMR 2 node 2 maps to (in the current mapping) 431 | weight_dict: weight dictionary 432 | match_num: the original matching triple number 433 | Returns: 434 | the gain number (might be negative) 435 | 436 | """ 437 | new_mapping_list = mapping[:] 438 | # Before swapping, node_id1 maps to mapping_id1, and node_id2 maps to mapping_id2 439 | # After swapping, node_id1 maps to mapping_id2 and node_id2 maps to mapping_id1 440 | new_mapping_list[node_id1] = mapping_id2 441 | new_mapping_list[node_id2] = mapping_id1 442 | if tuple(new_mapping_list) in match_triple_dict: 443 | return match_triple_dict[tuple(new_mapping_list)] - match_num 444 | gain = 0 445 | new_mapping1 = (node_id1, mapping_id2) 446 | new_mapping2 = (node_id2, mapping_id1) 447 | old_mapping1 = (node_id1, mapping_id1) 448 | old_mapping2 = (node_id2, mapping_id2) 449 | if node_id1 > node_id2: 450 | new_mapping2 = (node_id1, mapping_id2) 451 | new_mapping1 = (node_id2, mapping_id1) 452 | old_mapping1 = (node_id2, mapping_id2) 453 | old_mapping2 = (node_id1, mapping_id1) 454 | if new_mapping1 in weight_dict: 455 | for key in weight_dict[new_mapping1]: 456 | if key == -1: 457 | gain += weight_dict[new_mapping1][-1] 458 | elif new_mapping_list[key[0]] == key[1]: 459 | gain += weight_dict[new_mapping1][key] 460 | if new_mapping2 in weight_dict: 461 | for key in weight_dict[new_mapping2]: 462 | if key == -1: 463 | gain += weight_dict[new_mapping2][-1] 464 | # to avoid duplicate 465 | elif key[0] == node_id1: 466 | continue 467 | elif new_mapping_list[key[0]] == key[1]: 468 | gain += weight_dict[new_mapping2][key] 469 | if old_mapping1 in weight_dict: 470 | for key in weight_dict[old_mapping1]: 471 | if key == -1: 472 | gain -= weight_dict[old_mapping1][-1] 473 | elif mapping[key[0]] == key[1]: 474 | gain -= weight_dict[old_mapping1][key] 475 | if old_mapping2 in weight_dict: 476 | for key in weight_dict[old_mapping2]: 477 | if key == -1: 478 | gain -= weight_dict[old_mapping2][-1] 479 | # to avoid duplicate 480 | 
elif key[0] == node_id1: 481 | continue 482 | elif mapping[key[0]] == key[1]: 483 | gain -= weight_dict[old_mapping2][key] 484 | match_triple_dict[tuple(new_mapping_list)] = match_num + gain 485 | return gain 486 | 487 | 488 | def get_best_gain(mapping, candidate_mappings, weight_dict, instance_len, cur_match_num): 489 | """ 490 | Hill-climbing method to return the best gain swap/move can get 491 | Arguments: 492 | mapping: current node mapping 493 | candidate_mappings: the candidates mapping list 494 | weight_dict: the weight dictionary 495 | instance_len: the number of the nodes in AMR 2 496 | cur_match_num: current triple match number 497 | Returns: 498 | the best gain we can get via swap/move operation 499 | 500 | """ 501 | largest_gain = 0 502 | # True: using swap; False: using move 503 | use_swap = True 504 | # the node to be moved/swapped 505 | node1 = None 506 | # store the other node affected. In swap, this other node is the node swapping with node1. In move, this other 507 | # node is the node node1 will move to. 
508 | node2 = None 509 | # unmatched nodes in AMR 2 510 | unmatched = set(range(instance_len)) 511 | # exclude nodes in current mapping 512 | # get unmatched nodes 513 | for nid in mapping: 514 | if nid in unmatched: 515 | unmatched.remove(nid) 516 | for i, nid in enumerate(mapping): 517 | # current node i in AMR 1 maps to node nid in AMR 2 518 | for nm in unmatched: 519 | if nm in candidate_mappings[i]: 520 | # remap i to another unmatched node (move) 521 | # (i, m) -> (i, nm) 522 | if veryVerbose: 523 | print("Remap node", i, "from ", nid, "to", nm, file=DEBUG_LOG) 524 | mv_gain = move_gain(mapping, i, nid, nm, weight_dict, cur_match_num) 525 | if veryVerbose: 526 | print("Move gain:", mv_gain, file=DEBUG_LOG) 527 | new_mapping = mapping[:] 528 | new_mapping[i] = nm 529 | new_match_num = compute_match(new_mapping, weight_dict) 530 | if new_match_num != cur_match_num + mv_gain: 531 | print(mapping, new_mapping, file=ERROR_LOG) 532 | print("Inconsistency in computing: move gain", cur_match_num, mv_gain, new_match_num, 533 | file=ERROR_LOG) 534 | if mv_gain > largest_gain: 535 | largest_gain = mv_gain 536 | node1 = i 537 | node2 = nm 538 | use_swap = False 539 | # compute swap gain 540 | for i, m in enumerate(mapping): 541 | for j in range(i + 1, len(mapping)): 542 | m2 = mapping[j] 543 | # swap operation (i, m) (j, m2) -> (i, m2) (j, m) 544 | # j starts from i+1, to avoid duplicate swap 545 | if veryVerbose: 546 | print("Swap node", i, "and", j, file=DEBUG_LOG) 547 | print("Before swapping:", i, "-", m, ",", j, "-", m2, file=DEBUG_LOG) 548 | print(mapping, file=DEBUG_LOG) 549 | print("After swapping:", i, "-", m2, ",", j, "-", m, file=DEBUG_LOG) 550 | sw_gain = swap_gain(mapping, i, m, j, m2, weight_dict, cur_match_num) 551 | if veryVerbose: 552 | print("Swap gain:", sw_gain, file=DEBUG_LOG) 553 | new_mapping = mapping[:] 554 | new_mapping[i] = m2 555 | new_mapping[j] = m 556 | print(new_mapping, file=DEBUG_LOG) 557 | new_match_num = compute_match(new_mapping, 
weight_dict) 558 | if new_match_num != cur_match_num + sw_gain: 559 | print(mapping, new_mapping, file=ERROR_LOG) 560 | print("Inconsistency in computing: swap gain", cur_match_num, sw_gain, new_match_num, 561 | file=ERROR_LOG) 562 | if sw_gain > largest_gain: 563 | largest_gain = sw_gain 564 | node1 = i 565 | node2 = j 566 | use_swap = True 567 | # generate a new mapping based on swap/move 568 | cur_mapping = mapping[:] 569 | if node1 is not None: 570 | if use_swap: 571 | if veryVerbose: 572 | print("Use swap gain", file=DEBUG_LOG) 573 | temp = cur_mapping[node1] 574 | cur_mapping[node1] = cur_mapping[node2] 575 | cur_mapping[node2] = temp 576 | else: 577 | if veryVerbose: 578 | print("Use move gain", file=DEBUG_LOG) 579 | cur_mapping[node1] = node2 580 | else: 581 | if veryVerbose: 582 | print("no move/swap gain found", file=DEBUG_LOG) 583 | if veryVerbose: 584 | print("Original mapping", mapping, file=DEBUG_LOG) 585 | print("Current mapping", cur_mapping, file=DEBUG_LOG) 586 | return largest_gain, cur_mapping 587 | 588 | 589 | def print_alignment(mapping, instance1, instance2): 590 | """ 591 | Return the alignment given by a node mapping as a space-separated string 592 | Args: 593 | mapping: current node mapping list 594 | instance1: nodes of AMR 1 595 | instance2: nodes of AMR 2 596 | 597 | """ 598 | result = [] 599 | for instance1_item, m in zip(instance1, mapping): 600 | r = instance1_item[1] + "(" + instance1_item[2] + ")" 601 | if m == -1: 602 | r += "-Null" 603 | else: 604 | instance2_item = instance2[m] 605 | r += "-" + instance2_item[1] + "(" + instance2_item[2] + ")" 606 | result.append(r) 607 | return " ".join(result) 608 | 609 | 610 | # def compute_f(match_num, test_num, gold_num): 611 | # """ 612 | # Compute the f-score based on the matching triple number, 613 | # triple number of AMR set 1, 614 | # triple number of AMR set 2 615 | # Args: 616 | # match_num: matching triple number 617 | # test_num: triple number of AMR 1 (test 
file) 618 | # gold_num: triple number of AMR 2 (gold
file) 619 | # Returns: 620 | # precision: match_num/test_num 621 | # recall: match_num/gold_num 622 | # f_score: 2*precision*recall/(precision+recall) 623 | # """ 624 | # if test_num == 0 or gold_num == 0: 625 | # return 0.00, 0.00, 0.00 626 | # precision = float(match_num) / float(test_num) 627 | # recall = float(match_num) / float(gold_num) 628 | # if (precision + recall) != 0: 629 | # f_score = 2 * precision * recall / (precision + recall) 630 | # if veryVerbose: 631 | # print("F-score:", f_score, file=DEBUG_LOG) 632 | # return precision, recall, f_score 633 | # else: 634 | # if veryVerbose: 635 | # print("F-score:", "0.0", file=DEBUG_LOG) 636 | # return precision, recall, 0.00 637 | # 638 | # 639 | # def generate_amr_lines(f1, f2): 640 | # """ 641 | # Read one AMR line at a time from each file handle 642 | # :param f1: file handle (or any iterable of strings) to read AMR 1 lines from 643 | # :param f2: file handle (or any iterable of strings) to read AMR 2 lines from 644 | # :return: generator of cur_amr1, cur_amr2 pairs: one-line AMR strings 645 | # """ 646 | # while True: 647 | # cur_amr1 = amr.AMR.get_amr_line(f1) 648 | # cur_amr2 = amr.AMR.get_amr_line(f2) 649 | # if not cur_amr1 and not cur_amr2: 650 | # pass 651 | # elif not cur_amr1: 652 | # print("Error: File 1 has fewer AMRs than file 2", file=ERROR_LOG) 653 | # print("Ignoring remaining AMRs", file=ERROR_LOG) 654 | # elif not cur_amr2: 655 | # print("Error: File 2 has fewer AMRs than file 1", file=ERROR_LOG) 656 | # print("Ignoring remaining AMRs", file=ERROR_LOG) 657 | # else: 658 | # yield cur_amr1, cur_amr2 659 | # continue 660 | # break 661 | # 662 | # 663 | # def get_amr_match(cur_amr1, cur_amr2, sent_num=1, justinstance=False, justattribute=False, justrelation=False): 664 | # amr_pair = [] 665 | # for i, cur_amr in (1, cur_amr1), (2, cur_amr2): 666 | # try: 667 | # amr_pair.append(amr.AMR.parse_AMR_line(cur_amr)) 668 | # except Exception as e: 669 | # print("Error in parsing amr %d: %s" % (i,
cur_amr), file=ERROR_LOG) 670 | # print("Please check if the AMR is ill-formatted. Ignoring remaining AMRs", file=ERROR_LOG) 671 | # print("Error message: %s" % e, file=ERROR_LOG) 672 | # amr1, amr2 = amr_pair 673 | # prefix1 = "a" 674 | # prefix2 = "b" 675 | # # Rename node to "a1", "a2", .etc 676 | # amr1.rename_node(prefix1) 677 | # # Renaming node to "b1", "b2", .etc 678 | # amr2.rename_node(prefix2) 679 | # (instance1, attributes1, relation1) = amr1.get_triples() 680 | # (instance2, attributes2, relation2) = amr2.get_triples() 681 | # if verbose: 682 | # print("AMR pair", sent_num, file=DEBUG_LOG) 683 | # print("============================================", file=DEBUG_LOG) 684 | # print("AMR 1 (one-line):", cur_amr1, file=DEBUG_LOG) 685 | # print("AMR 2 (one-line):", cur_amr2, file=DEBUG_LOG) 686 | # print("Instance triples of AMR 1:", len(instance1), file=DEBUG_LOG) 687 | # print(instance1, file=DEBUG_LOG) 688 | # print("Attribute triples of AMR 1:", len(attributes1), file=DEBUG_LOG) 689 | # print(attributes1, file=DEBUG_LOG) 690 | # print("Relation triples of AMR 1:", len(relation1), file=DEBUG_LOG) 691 | # print(relation1, file=DEBUG_LOG) 692 | # print("Instance triples of AMR 2:", len(instance2), file=DEBUG_LOG) 693 | # print(instance2, file=DEBUG_LOG) 694 | # print("Attribute triples of AMR 2:", len(attributes2), file=DEBUG_LOG) 695 | # print(attributes2, file=DEBUG_LOG) 696 | # print("Relation triples of AMR 2:", len(relation2), file=DEBUG_LOG) 697 | # print(relation2, file=DEBUG_LOG) 698 | # # optionally turn off some of the node comparison 699 | # doinstance = doattribute = dorelation = True 700 | # if justinstance: 701 | # doattribute = dorelation = False 702 | # if justattribute: 703 | # doinstance = dorelation = False 704 | # if justrelation: 705 | # doinstance = doattribute = False 706 | # (best_mapping, best_match_num) = get_best_match(instance1, attributes1, relation1, 707 | # instance2, attributes2, relation2, 708 | # prefix1, prefix2, 
doinstance=doinstance, 709 | # doattribute=doattribute, dorelation=dorelation) 710 | # if verbose: 711 | # print("best match number", best_match_num, file=DEBUG_LOG) 712 | # print("best node mapping", best_mapping, file=DEBUG_LOG) 713 | # print("Best node mapping alignment:", print_alignment(best_mapping, instance1, instance2), file=DEBUG_LOG) 714 | # if justinstance: 715 | # test_triple_num = len(instance1) 716 | # gold_triple_num = len(instance2) 717 | # elif justattribute: 718 | # test_triple_num = len(attributes1) 719 | # gold_triple_num = len(attributes2) 720 | # elif justrelation: 721 | # test_triple_num = len(relation1) 722 | # gold_triple_num = len(relation2) 723 | # else: 724 | # test_triple_num = len(instance1) + len(attributes1) + len(relation1) 725 | # gold_triple_num = len(instance2) + len(attributes2) + len(relation2) 726 | # return best_match_num, test_triple_num, gold_triple_num 727 | # 728 | # 729 | # def score_amr_pairs(f1, f2, justinstance=False, justattribute=False, justrelation=False): 730 | # """ 731 | # Score one pair of AMR lines at a time from each file handle 732 | # :param f1: file handle (or any iterable of strings) to read AMR 1 lines from 733 | # :param f2: file handle (or any iterable of strings) to read AMR 2 lines from 734 | # :param justinstance: just pay attention to matching instances 735 | # :param justattribute: just pay attention to matching attributes 736 | # :param justrelation: just pay attention to matching relations 737 | # :return: generator of (precision, recall, f_score) tuples 738 | # """ 739 | # # matching triple number, triple number in test file, triple number in gold file 740 | # total_match_num = total_test_num = total_gold_num = 0 741 | # # Read amr pairs from two files 742 | # for sent_num, (cur_amr1, cur_amr2) in enumerate(generate_amr_lines(f1, f2), start=1): 743 | # best_match_num, test_triple_num, gold_triple_num = get_amr_match(cur_amr1, cur_amr2, 744 | # sent_num=sent_num, # sentence number 745
| # justinstance=justinstance, 746 | # justattribute=justattribute, 747 | # justrelation=justrelation) 748 | # total_match_num += best_match_num 749 | # total_test_num += test_triple_num 750 | # total_gold_num += gold_triple_num 751 | # # clear the matching triple dictionary for the next AMR pair 752 | # match_triple_dict.clear() 753 | # if not single_score: # if each AMR pair should have a score, compute and output it here 754 | # yield compute_f(best_match_num, test_triple_num, gold_triple_num) 755 | # if verbose: 756 | # print("Total match number, total triple number in AMR 1, and total triple number in AMR 2:", file=DEBUG_LOG) 757 | # print(total_match_num, total_test_num, total_gold_num, file=DEBUG_LOG) 758 | # print("---------------------------------------------------------------------------------", file=DEBUG_LOG) 759 | # if single_score: # output document-level smatch score (a single f-score for all AMR pairs in two files) 760 | # yield compute_f(total_match_num, total_test_num, total_gold_num) 761 | # 762 | # 763 | # def main(arguments): 764 | # """ 765 | # Main function of smatch score calculation 766 | # """ 767 | # global verbose 768 | # global veryVerbose 769 | # global iteration_num 770 | # global single_score 771 | # global pr_flag 772 | # global match_triple_dict 773 | # # set the iteration number 774 | # # total iteration number = restart number + 1 775 | # iteration_num = arguments.r + 1 776 | # if arguments.ms: 777 | # single_score = False 778 | # if arguments.v: 779 | # verbose = True 780 | # if arguments.vv: 781 | # veryVerbose = True 782 | # if arguments.pr: 783 | # pr_flag = True 784 | # # significant digits to print out 785 | # floatdisplay = "%%.%df" % arguments.significant 786 | # for (precision, recall, best_f_score) in score_amr_pairs(args.f[0], args.f[1], 787 | # justinstance=arguments.justinstance, 788 | # justattribute=arguments.justattribute, 789 | # justrelation=arguments.justrelation): 790 | # # print("Sentence", sent_num) 791 | # 
if pr_flag: 792 | # print("Precision: " + floatdisplay % precision) 793 | # print("Recall: " + floatdisplay % recall) 794 | # print("F-score: " + floatdisplay % best_f_score) 795 | # args.f[0].close() 796 | # args.f[1].close() 797 | # 798 | # 799 | # if __name__ == "__main__": 800 | # import argparse 801 | # 802 | # parser = argparse.ArgumentParser(description="Smatch calculator") 803 | # parser.add_argument( 804 | # '-f', 805 | # nargs=2, 806 | # required=True, 807 | # type=argparse.FileType('r'), 808 | # help=('Two files containing AMR pairs. ' 809 | # 'AMRs in each file are separated by a single blank line')) 810 | # parser.add_argument( 811 | # '-r', 812 | # type=int, 813 | # default=4, 814 | # help='Restart number (Default:4)') 815 | # parser.add_argument( 816 | # '--significant', 817 | # type=int, 818 | # default=2, 819 | # help='significant digits to output (default: 2)') 820 | # parser.add_argument( 821 | # '-v', 822 | # action='store_true', 823 | # help='Verbose output (Default:false)') 824 | # parser.add_argument( 825 | # '--vv', 826 | # action='store_true', 827 | # help='Very Verbose output (Default:false)') 828 | # parser.add_argument( 829 | # '--ms', 830 | # action='store_true', 831 | # default=False, 832 | # help=('Output multiple scores (one AMR pair a score) ' 833 | # 'instead of a single document-level smatch score ' 834 | # '(Default: false)')) 835 | # parser.add_argument( 836 | # '--pr', 837 | # action='store_true', 838 | # default=False, 839 | # help=('Output precision and recall as well as the f-score. 
' 840 | # 'Default: false')) 841 | # parser.add_argument( 842 | # '--justinstance', 843 | # action='store_true', 844 | # default=False, 845 | # help="just pay attention to matching instances") 846 | # parser.add_argument( 847 | # '--justattribute', 848 | # action='store_true', 849 | # default=False, 850 | # help="just pay attention to matching attributes") 851 | # parser.add_argument( 852 | # '--justrelation', 853 | # action='store_true', 854 | # default=False, 855 | # help="just pay attention to matching relations") 856 | # 857 | # args = parser.parse_args() 858 | # main(args) -------------------------------------------------------------------------------- /amr_utils/style.py: -------------------------------------------------------------------------------- 1 | import html 2 | import sys 3 | 4 | 5 | 6 | class Latex_AMR: 7 | ''' 8 | \begin{tikzpicture}[ 9 | red/.style={rectangle, draw=red!60, fill=red!5, very thick, minimum size=7mm}, 10 | blue/.style={rectangle, draw=blue!60, fill=blue!5, very thick, minimum size=7mm}, 11 | ] 12 | %Nodes 13 | \node[red] (r) at (5,4) {read-01}; 14 | \node[purple](p) at (3.33,2) {person}; 15 | \node[green] (b) at (6.67,2) {book}; 16 | \node[blue] (j) at (5,0) {``John''}; 17 | 18 | %Edges 19 | \draw[->] (r.south) -- (p.north) node[midway, above, sloped] {:ARG0}; 20 | \draw[->] (r.south) -- (b.north) node[midway, above, sloped] {:ARG1}; 21 | \draw[->] (p.south) -- (j.north) node[midway, above, sloped] {:name}; 22 | \end{tikzpicture} 23 | ''' 24 | 25 | @staticmethod 26 | def prefix(): 27 | return '\\usepackage{tikz}\n\\usetikzlibrary{shapes}\n\n' 28 | 29 | 30 | @staticmethod 31 | def latex(amr, assign_color='blue'): 32 | 33 | colors = set() 34 | node_depth = {amr.root:0} 35 | nodes = [amr.root] 36 | done = {amr.root} 37 | depth = 1 38 | while True: 39 | new_nodes = set() 40 | for s,r,t in amr.edges: 41 | if s in done and t not in nodes: 42 | node_depth[t] = depth 43 | new_nodes.add(t) 44 | nodes.append(t) 45 | if not new_nodes: 46 | 
break 47 | depth += 1 48 | done.update(new_nodes) 49 | if len(nodes) < len(amr.nodes): 50 | print('[amr]', 'Failed to print AMR, ' 51 | + str(len(nodes)) + ' of ' + str(len(amr.nodes)) + ' nodes printed:\n ' 52 | + ' '.join(amr.tokens) + '\n' + str(amr), file=sys.stderr) 53 | 54 | max_depth = depth 55 | elems = ['\t% Nodes'] 56 | for n in nodes: 57 | depth = node_depth[n] 58 | row = [n for n in nodes if node_depth[n]==depth] 59 | pos = row.index(n) 60 | x = Latex_AMR._get_x(pos, len(row)) 61 | y = Latex_AMR._get_y(depth, max_depth) 62 | if callable(assign_color): 63 | color = assign_color(amr, n) 64 | else: 65 | color = assign_color 66 | colors.add(color) 67 | if not amr.nodes[n][0].isalpha() or amr.nodes[n] in ['imperative', 'expressive', 'interrogative']: 68 | concept = amr.nodes[n] 69 | else: 70 | concept = f'{n}/{amr.nodes[n]}' 71 | elems.append(f'\t\\node[{color}]({n}) at ({x},{y}) {{{concept}}};') 72 | elems.append('\t% Edges') 73 | for s,r,t in amr.edges: 74 | if node_depth[s] > node_depth[t]: 75 | dir1 = 'north' 76 | dir2 = 'south' 77 | elif node_depth[s] < node_depth[t]: 78 | dir1 = 'south' 79 | dir2 = 'north' 80 | elif node_depth[s] == node_depth[t] and nodes.index(s), thick] ({s}.{dir1}) -- ({t}.{dir2}) node[midway, above, sloped] {{{r}}};') 87 | latex = '\n\\begin{tikzpicture}[\n' 88 | for color in colors: 89 | latex += f'{color}/.style={{rectangle, draw={color}!60, fill={color}!5, very thick, minimum size=7mm}},\n' 90 | latex += ']\n' 91 | latex += '\n'.join(elems) 92 | latex += '\n\\end{tikzpicture}\n' 93 | 94 | return latex 95 | 96 | @staticmethod 97 | def _get_x(i, num_nodes): 98 | return (i+1)*20.0/(num_nodes+1) 99 | 100 | @staticmethod 101 | def _get_y(depth, max_depth): 102 | return 2.5*(max_depth - depth) 103 | 104 | @staticmethod 105 | def style(amrs, assign_color='blue'): 106 | output = Latex_AMR.prefix() 107 | for amr in amrs: 108 | output += Latex_AMR.latex(amr, assign_color) 109 | return output 110 | 111 | 112 | class HTML_AMR: 113 | 
''' 114
|
115 |
116 |     (:ARG0 (d / dog)
117 |         :ARG1 (c2 / cat))
118 |     
119 |
120 | ''' 121 | @staticmethod 122 | def _get_description(frame, propbank_frames_dictionary): 123 | if frame in propbank_frames_dictionary: 124 | return propbank_frames_dictionary[frame].replace('\t', '\n') 125 | return '' 126 | 127 | @staticmethod 128 | def span(text, type, id, desc=''): 129 | desc = html.escape(desc) 130 | text = html.escape(text) 131 | return f'{text}' 132 | 133 | @staticmethod 134 | def style_sheet(): 135 | return ''' 136 | div.amr-container { 137 | font-family: "Cambria Math", sans-serif; 138 | font-size: 14px; 139 | } 140 | 141 | .amr-node { 142 | color : black; 143 | } 144 | 145 | .amr-frame { 146 | color : purple; 147 | text-decoration: underline; 148 | } 149 | 150 | .amr-edge { 151 | color : grey; 152 | } 153 | 154 | .blue { 155 | background: deepskyblue; 156 | color : white; 157 | } 158 | 159 | .red { 160 | background: crimson; 161 | color : white; 162 | } 163 | 164 | .grey { 165 | background: gainsboro; 166 | color : black; 167 | } 168 | 169 | .green { 170 | background: yellowgreen; 171 | color : black; 172 | } 173 | ''' 174 | 175 | @staticmethod 176 | def html(amr, assign_node_color=None, assign_node_desc=None, assign_edge_color=None, assign_edge_desc=None, 177 | assign_token_color=None, assign_token_desc=None, other_args=None): 178 | from amr_utils.propbank_frames import propbank_frames_dictionary 179 | amr_string = f'[[{amr.root}]]' 180 | new_ids = {} 181 | for n in amr.nodes: 182 | new_id = amr.nodes[n][0] if amr.nodes[n] else 'x' 183 | if new_id.isalpha() and new_id.islower(): 184 | if new_id in new_ids.values(): 185 | j = 2 186 | while f'{new_id}{j}' in new_ids.values(): 187 | j += 1 188 | new_id = f'{new_id}{j}' 189 | else: 190 | j = 0 191 | while f'x{j}' in new_ids.values(): 192 | j += 1 193 | new_id = f'x{j}' 194 | new_ids[n] = new_id 195 | depth = 1 196 | nodes = {amr.root} 197 | completed = set() 198 | while '[[' in amr_string: 199 | tab = ' ' * depth 200 | for n in nodes.copy(): 201 | id = new_ids[n] if n in new_ids else 'x91' 
202 | concept = amr.nodes[n] if n in new_ids and amr.nodes[n] else 'None' 203 | edges = sorted([e for e in amr.edges if e[0] == n], key=lambda x: x[1]) 204 | targets = set(t for s, r, t in edges) 205 | edge_spans = [] 206 | for s, r, t in edges: 207 | if assign_edge_color: 208 | color = assign_edge_color(amr, (s,r,t), other_args) 209 | else: 210 | color = False 211 | type = 'amr-edge' + (f' {color}' if color else '') 212 | desc = assign_edge_desc(amr, (s,r,t), other_args) if assign_edge_desc else '' 213 | edge_spans.append(f'{HTML_AMR.span(r, type, f"{s}-{t}", desc)} [[{t}]]') 214 | children = f'\n{tab}'.join(edge_spans) 215 | if children: 216 | children = f'\n{tab}' + children 217 | if assign_node_color: 218 | color = assign_node_color(amr, n, other_args) 219 | else: 220 | color = False 221 | 222 | if n not in completed: 223 | if (concept[0].isalpha() and concept not in ['imperative', 'expressive', 224 | 'interrogative']) or targets or depth==1: 225 | desc = HTML_AMR._get_description(concept, propbank_frames_dictionary) 226 | type = 'amr-frame' if desc else "amr-node" 227 | if assign_node_desc: 228 | desc = assign_node_desc(amr, n, other_args) 229 | if color: 230 | type += f' {color}' 231 | span = HTML_AMR.span(f'{id}/{concept}', type, id, desc) 232 | amr_string = amr_string.replace(f'[[{n}]]', f'({span}{children})', 1) 233 | else: 234 | type = 'amr-node' + (f' {color}' if color else '') 235 | desc = assign_node_desc(amr, n, other_args) if assign_node_desc else '' 236 | span = HTML_AMR.span(f'{concept}', type, id, desc) 237 | amr_string = amr_string.replace(f'[[{n}]]', f'{span}') 238 | completed.add(n) 239 | type = 'amr-node' + (f' {color}' if color else '') 240 | desc = assign_node_desc(amr, n, other_args) if assign_node_desc else '' 241 | span = HTML_AMR.span(f'{id}', type, id, desc) 242 | amr_string = amr_string.replace(f'[[{n}]]', f'{span}') 243 | nodes.remove(n) 244 | nodes.update(targets) 245 | depth += 1 246 | if len(completed) < len(amr.nodes): 247 | 
missing_nodes = [n for n in amr.nodes if n not in completed] 248 | missing_edges = [(s, r, t) for s, r, t in amr.edges if s in missing_nodes or t in missing_nodes] 249 | missing_nodes = ', '.join(f'{n}/{amr.nodes[n]}' for n in missing_nodes) 250 | missing_edges = ', '.join(f'{s}/{amr.nodes[s]} {r} {t}/{amr.nodes[t]}' for s, r, t in missing_edges) 251 | print('[amr]', 'Failed to print AMR, ' 252 | + str(len(completed)) + ' of ' + str(len(amr.nodes)) + ' nodes printed:\n ' 253 | + str(amr.id) + ':\n' 254 | + 'Missing nodes: ' + missing_nodes + '\n' 255 | + 'Missing edges: ' + missing_edges + '\n', 256 | file=sys.stderr) 257 | if not amr_string.startswith('('): 258 | amr_string = '(' + amr_string + ')' 259 | if len(amr.nodes) == 0: 260 | span = HTML_AMR.span('a/amr-empty', "amr-node", 'a') 261 | amr_string = f'({span})' 262 | toks = [t for t in amr.tokens] 263 | if assign_token_color or assign_token_desc: 264 | for i,t in enumerate(toks): 265 | color = assign_token_color(amr, i, other_args) if assign_token_color else '' 266 | desc = assign_token_desc(amr, i, other_args) if assign_token_desc else '' 267 | if color or desc: 268 | toks[i] = HTML_AMR.span(t, color, f'tok{i}', desc) 269 | output = f'
\n
\n{" ".join(toks)}\n\n{amr_string}
\n
\n\n' 270 | return output 271 | 272 | @staticmethod 273 | def style(amrs, assign_node_color=None, assign_node_desc=None, assign_edge_color=None, assign_edge_desc=None, 274 | assign_token_color=None, assign_token_desc=None, other_args=None): 275 | output = '\n' 276 | output += '\n' 277 | output += '\n\n' 280 | output += '\n' 281 | for amr in amrs: 282 | output += HTML_AMR.html(amr, 283 | assign_node_color, assign_node_desc, 284 | assign_edge_color, assign_edge_desc, 285 | assign_token_color, assign_token_desc, 286 | other_args) 287 | output += '
\n' 288 | output += '\n' 289 | output += '\n' 290 | return output 291 | 292 | def main(): 293 | import argparse 294 | from amr_utils.amr_readers import AMR_Reader 295 | 296 | parser = argparse.ArgumentParser(description='Style AMRs as HTML or Latex') 297 | parser.add_argument('-f', '--files', type=str, nargs=2, required=True, 298 | help='input and output files (AMRs in JAMR format)') 299 | parser.add_argument('--latex', action='store_true', help='style as latex') 300 | parser.add_argument('--html', action='store_true', help='style as html') 301 | 302 | args = parser.parse_args() 303 | file = args.files[0] 304 | outfile = args.files[1] 305 | 306 | cr = AMR_Reader(style='letters') 307 | amrs = cr.load(file, remove_wiki=True) 308 | 309 | if args.html: 310 | output = HTML_AMR.style(amrs) 311 | with open(outfile, 'w+', encoding='utf8') as f: 312 | f.write(output) 313 | else: 314 | output = Latex_AMR.style(amrs) 315 | with open(outfile, 'w+', encoding='utf8') as f: 316 | f.write(output) 317 | 318 | 319 | 320 | if __name__=='__main__': 321 | main() 322 | -------------------------------------------------------------------------------- /data/test_amrs.alignments.json: -------------------------------------------------------------------------------- 1 | {"1": [{"type": "basic", "tokens": [1], "nodes": ["1.1"], "edges": []}, 2 | {"type": "basic", "tokens": [2], "nodes": ["1"], "edges": []}, 3 | {"type": "basic", "tokens": [4], "nodes": ["1.2"], "edges": []}, 4 | {"type": "basic", "tokens": [6, 7], "nodes": ["1.2.2", "1.2.2.1", "1.2.2.1.1", "1.2.2.1.2", "1.2.2.1.3"], "edges": []} 5 | ], 6 | "2": [{"type": "basic", "tokens": [0], "nodes": ["1.1"], "edges": []}, 7 | {"type": "basic", "tokens": [1], "nodes": ["1"], "edges": []}, 8 | {"type": "basic", "tokens": [2], "nodes": ["1.2"], "edges": []} 9 | ], 10 | "3": [{"type": "basic", "tokens": [0], "nodes": ["1.1.2", "1.1.2.1"], "edges": []}, 11 | {"type": "basic", "tokens": [1], "nodes": ["1.1.1"], "edges": []}, 12 | {"type": 
"basic", "tokens": [2], "nodes": ["1.1"], "edges": []}, 13 | {"type": "basic", "tokens": [3], "nodes": ["1"], "edges": []}, 14 | {"type": "basic", "tokens": [4], "nodes": ["1.2"], "edges": []} 15 | ] 16 | } -------------------------------------------------------------------------------- /data/test_amrs.txt: -------------------------------------------------------------------------------- 1 | # ::id 1 2 | # ::tok The boy wants to go to New York 3 | (w/want-01 :ARG0 (b/boy) 4 | :ARG1 (g/go-02 :ARG0 b 5 | :ARG4 (c/city :name (n/name :op1 "New" 6 | :op2 "York" 7 | :op3 "City")))) 8 | 9 | # ::id 2 10 | # ::tok Dogs chase cats 11 | (c/chase-01 :ARG0 (d/dog) 12 | :ARG1 (c2/cat)) 13 | 14 | # ::id 3 15 | # ::tok Colorless green ideas sleep furiously . 16 | (s/sleep-01 17 | :ARG0 (i2/idea 18 | :ARG1-of (g/green-02) 19 | :ARG1-of (c/color-01 :polarity -)) 20 | :time (i3/infuriate-01 21 | :ARG1 i2)) -------------------------------------------------------------------------------- /data/test_amrs2.txt: -------------------------------------------------------------------------------- 1 | # ::id 1 2 | # ::tok The boy wants to go to New York 3 | (w/want-01 :ARG0 (b/boy) 4 | :ARG1 (g/go-01 :ARG0 b 5 | :ARG4 (s/state :name (n/name :op1 "New" 6 | :op2 "York")))) 7 | 8 | # ::id 2 9 | # ::tok Dogs chase cats 10 | (c/chase-02 :ARG0 (d/dog) 11 | :ARG1 (c2/cat)) 12 | 13 | # ::id 3 14 | # ::tok Colorless green ideas sleep furiously . 
15 | (s/sleep-01 16 | :ARG0 (i2/idea 17 | :mod (g/green) 18 | :mod (c/colorless)) 19 | :time (f/furious)) 20 | 21 | -------------------------------------------------------------------------------- /display_align_ex.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ablodge/amr-utils/be5534db1312dc7c6ba25ee50eafeb0d0f5e3f69/display_align_ex.PNG -------------------------------------------------------------------------------- /html_ex.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ablodge/amr-utils/be5534db1312dc7c6ba25ee50eafeb0d0f5e3f69/html_ex.PNG -------------------------------------------------------------------------------- /latex_ex.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ablodge/amr-utils/be5534db1312dc7c6ba25ee50eafeb0d0f5e3f69/latex_ex.PNG -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as f: 4 | long_description = f.read() 5 | 6 | 7 | setuptools.setup( 8 | name='amr-utils', 9 | version='1.0', 10 | scripts=[] , 11 | author="Austin Blodgett", 12 | description="A toolkit of operations for AMRs", 13 | long_description=long_description, 14 | long_description_content_type="text/markdown", 15 | url="https://github.com/ablodge/amr-utils", 16 | packages=setuptools.find_packages(), 17 | classifiers=[ 18 | "Programming Language :: Python :: 3", 19 | ], 20 | 21 | ) 22 | --------------------------------------------------------------------------------
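The alignment file `data/test_amrs.alignments.json` maps each AMR id to a list of alignments. Each alignment has a `type`, a list of token indices, a list of node ids, and a list of edges. In the test data, the token indices line up with 0-based positions in the whitespace-split `# ::tok` line of `data/test_amrs.txt`. For example, in AMR "2", token 0 ("Dogs") aligns to node "1.1" (`d/dog`). The node ids are dotted positions in the graph, where "1" is the root and "1.2" is its second child.

The sketch below reads this structure using only the standard library. The helper `aligned_spans` and the inlined JSON string are hypothetical illustrations, not part of amr_utils; for real use, the package's own `AMR_Reader` is the intended way to load alignments.

```python
import json

# Hypothetical inlined copy of AMRs "1" and "2" from data/test_amrs.alignments.json,
# so the sketch runs without the repository files.
ALIGNMENTS_JSON = """
{"1": [{"type": "basic", "tokens": [1], "nodes": ["1.1"], "edges": []},
       {"type": "basic", "tokens": [2], "nodes": ["1"], "edges": []},
       {"type": "basic", "tokens": [4], "nodes": ["1.2"], "edges": []},
       {"type": "basic", "tokens": [6, 7],
        "nodes": ["1.2.2", "1.2.2.1", "1.2.2.1.1", "1.2.2.1.2", "1.2.2.1.3"], "edges": []}],
 "2": [{"type": "basic", "tokens": [0], "nodes": ["1.1"], "edges": []},
       {"type": "basic", "tokens": [1], "nodes": ["1"], "edges": []},
       {"type": "basic", "tokens": [2], "nodes": ["1.2"], "edges": []}]}
"""


def aligned_spans(alignments, amr_id, tokens):
    """Hypothetical helper: pair each aligned token span of one AMR with its node ids.

    Assumes token indices are 0-based positions in `tokens`, as in the test data.
    Multi-token alignments (e.g. "New York") are joined into one span string.
    """
    spans = []
    for align in alignments[amr_id]:
        span = " ".join(tokens[i] for i in align["tokens"])
        spans.append((span, align["nodes"]))
    return spans


if __name__ == "__main__":
    alignments = json.loads(ALIGNMENTS_JSON)
    for span, nodes in aligned_spans(alignments, "2", "Dogs chase cats".split()):
        print(span, "->", nodes)
    # Dogs -> ['1.1']
    # chase -> ['1']
    # cats -> ['1.2']
```

The sketch returns a list of `(span, nodes)` pairs rather than a dict keyed by token text. Keying by text would break on repeated tokens, such as the two occurrences of "to" in AMR "1".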