├── README.md
├── compare_smatch
│   ├── __init__.py
│   ├── amr_alignment.py
│   ├── amr_metadata.py
│   └── smatch_graph.py
├── demo
│   ├── AMRICA-demo.html
│   ├── AMRICA-demo.ipynb
│   ├── DF-183-195681-794_9333.5_annoted_UCO-AMR-06_UCO-AMR-05.png
│   ├── wb.eng_0003.13.png
│   ├── wb.eng_0003.14.png
│   ├── wb.eng_0003.14_0.png
│   ├── wb.eng_0003.6.png
│   └── wb.eng_0003.6_0.png
├── disagree.py
├── scripts
│   └── smatch_stats.py
└── smatch
    ├── README.txt
    ├── __init__.py
    ├── amr.py
    ├── smatch-table.py
    └── smatch.py

/README.md:
--------------------------------------------------------------------------------
# Using AMRICA

AMRICA (AMR Inspector for Cross-language Alignments) is a simple tool for aligning and visually representing AMRs [(Banarescu, 2013)](http://www.isi.edu/natural-language/amr/a.pdf), both for bilingual contexts and for monolingual inter-annotator agreement. It is based on and extends the Smatch system [(Cai, 2013)](http://www.newdesign.aclweb.org/anthology-new/P/P13/P13-2131.pdf) for identifying AMR inter-annotator agreement.

It is also possible to use AMRICA to visualize manual alignments you have edited or compiled yourself (see [Common Flags](#common-flags)).

## Getting started

Download the Python source from [github](https://github.com/nsaphra/AMRICA).

### Dependencies

We assume you have `pip`. To install the dependencies (assuming you already have the graphviz dependencies mentioned below), just run:

```
pip install argparse_config networkx==1.8 pygraphviz pynlpl
```

`pygraphviz` requires [graphviz](http://www.graphviz.org/) to work. On Linux, you may have to install `graphviz libgraphviz-dev pkg-config`. Additionally, to prepare bilingual alignment data you will need [GIZA++](https://code.google.com/p/giza-pp/) and possibly [JAMR](https://github.com/jflanigan/jamr/).

### Single View Quick Start

```
./disagree.py -i sample.amr -o sample_out_dir/
```

This command will read the AMRs in `sample.amr` (separated by empty lines) and put their graphviz visualizations in .png files located in `sample_out_dir/`.

### Data Preparation

#### Monolingual AMRICA

To generate visualizations of Smatch alignments, we need an AMR input file in which each annotation has a `::tok` or `::snt` field containing the tokenized sentence, an `::id` field with a sentence ID, and an `::annotator` or `::anno` field with an annotator ID. The annotations for a particular sentence are listed sequentially, and the first annotation is considered the gold standard for visualization purposes.

##### Single AMR View

If you only want to visualize a single annotation per sentence, without inter-annotator agreement, you can use an AMR file with only a single annotator. In this case, the annotator and sentence ID fields are optional. The resulting graph will be all black.

#### Bilingual AMRICA

For bilingual alignments, we start with two AMR files, one containing the target annotations and one the source annotations in the same order, with `::tok` and `::id` fields for each annotation. If we want JAMR alignments for either side, we include those in a `::alignments` field.

The sentence alignments should be in the form of two GIZA++ alignment .NBEST files, one source-target and one target-source. To generate these, use the `--nbestalignments` flag in your GIZA++ config file, set to your preferred n-best count.
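
For reference, a minimal monolingual input entry looks like this (the format echoes the example in `disagree.py`'s docstring; the ID and sentence here are illustrative):

```
# ::id DF-170-181103-888_2097.1 ::annotator ANON-01 ::preferred
# ::tok It did n't take long .
(t / take-10
 :ARG0 (i / it)
 :ARG1 (l2 / long
  :polarity -))
```

For the bilingual case, each entry in a .NBEST file is expected to span three lines: a metadata line, the target sentence, and the source sentence with bracketed word alignments. This sketch shows the shape that `get_nbest_alignments` in `compare_smatch/amr_alignment.py` parses; the sentence pair and score below are made up:

```
# Sentence pair (1) source length 2 target length 5 alignment score : 2.4e-07
it did n't take long
NULL ({ }) 它 ({ 1 }) 多长 ({ 2 3 4 5 })
```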

## Configuration

Flags can be set either at the command line or in a config file. The location of a config file can be set with `-c CONF_FILE` at the command line.

### Common flags

In addition to `--conf_file`, there are several other flags that apply to both monolingual and bilingual text. `--outdir DIR` is the only required one, and specifies the directory to which we will write the image files.

The optional shared flags are:
* `--verbose` to print out sentences as we align them.
* `--no-verbose` to override a verbose default setting.
* `--json FILE.json` to write the alignment graphs to a .json file.
* `--num_restarts N` to specify the number of random restarts Smatch should execute.
* `--align_out FILE.csv` to write the alignments to file.
* `--align_in FILE.csv` to read the alignments from disk instead of running Smatch.
* `--layout LAYOUT` to select the graphviz layout program (`dot` by default).

The alignment .csv files are in a format where each graph matching set is separated by an empty line, and each line within a set contains either a comment or a line indicating an alignment. For example:

```
3	它	-	1	it
2	多长	-	-1
-1		-	2	take
```

The tab-separated fields are the test node index (as processed by Smatch), the test node label, the gold node index, and the gold node label.

### Monolingual

Monolingual alignment requires one additional flag, `--infile FILE.amr`, with `FILE.amr` set to the location of the AMR file.

The following is an example config file:

```
[default]
infile: data/events_amr.txt
outdir: data/events_png/
json: data/events.json
verbose
```

### Bilingual

In bilingual alignment, there are more required flags.

* `--src_amr FILE` for the source annotation AMR file.
* `--tgt_amr FILE` for the target annotation AMR file.
* `--align_tgt2src FILE.A3.NBEST` for the GIZA++ .NBEST file aligning target-to-source (with target as vcb1), generated with `--nbestalignments N`.
* `--align_src2tgt FILE.A3.NBEST` for the GIZA++ .NBEST file aligning source-to-target (with source as vcb1), generated with `--nbestalignments N`.

If `--nbestalignments N` was set to a value greater than 1, we should pass that same value with `--num_aligned_in_file`. If we want to count only the top $k$ of those alignments, we set `--num_align_read` as well.

## Endnotes

`--nbestalignments` is a tricky flag to use, because GIZA++ only generates n-best output on its final alignment run. I could only get it to work with the default GIZA++ settings, myself.

## How It Works

### Smatch Classic

Since AMRICA is a variation on Smatch, one should begin by understanding Smatch. Smatch attempts to identify a matching between the variable nodes of two AMR representations of the same sentence in order to measure inter-annotator agreement. The matching should be selected to maximize the Smatch score, which assigns a point for each edge appearing in both graphs. These edges fall into three categories, each illustrated in the following annotation of "It didn't take long."

```
(t / take-10
 :ARG0 (i / it)
 :ARG1 (l2 / long
  :polarity -))
```

* Instance labels, such as `(instance, t, take-10)`
* Variable-variable edges, such as `(ARG0, t, i)`
* Variable-const edges, such as `(polarity, l2, -)`

Because the problem of finding the matching maximizing the Smatch score is NP-complete, Smatch uses a hill-climbing algorithm to approximate the best solution. It seeds the matching by pairing each node with a node sharing its label if possible, and matching the remaining nodes in the smaller graph (hereafter the target) randomly. Smatch then performs a step by finding the action that will increase the score the most, either switching two target nodes' matchings or moving a matching from its source node to an unmatched source node. It repeats this step until no step can immediately increase the Smatch score.

To avoid local optima, Smatch generally restarts 5 times.

### AMRICA

For technical details about AMRICA's inner workings, it may be more useful to read our [NAACL demo paper](http://speak.clsp.jhu.edu/uploads/publications/papers/1053_pdf.pdf).

AMRICA begins by replacing all constant nodes with variable nodes that are instances of the constant's label. This is necessary so we can align the constant nodes as well as the variables. Consequently, the only points added to the AMRICA score come from matching variable-variable edges and instance labels.

While Smatch tries to match every node in the smaller graph to some node in the larger graph, AMRICA removes matchings that do not increase the modified Smatch score, or AMRICA score.

AMRICA then generates image files from graphviz graphs of the alignments. If a node or edge appears only in the gold data, it is red. If that node or edge appears only in the test data, it is blue. If the node or edge has a matching in our final alignment, it is black.

![](demo/DF-183-195681-794_9333.5_annoted_UCO-AMR-06_UCO-AMR-05.png?raw=true)

#### Bitextual Variant

In AMRICA, instead of adding one point for each perfectly matching instance label, we add a point based on a likelihood score for those labels aligning. The likelihood score $\ell(a_{L_t,L_s}[i] \mid L_t, W_t, L_s, W_s)$, with target label set $L_t$, source label set $L_s$, target sentence $W_t$, source sentence $W_s$, and alignment $a_{L_t,L_s}[i]$ mapping $L_t[i]$ onto some label $L_s[a_{L_t,L_s}[i]]$, is defined by the following rules (a minimal sketch of the computation appears after the list):

* If the labels $L_s[a_{L_t,L_s}[i]]$ and $L_t[i]$ match, add 1 to the likelihood.
* Add to the likelihood:
  $$\sum_{j = 1}^{|W_t|} \sum_{k = 1}^{|W_s|} \ell(a_{L_t,W_t}[i] = j)\, \ell(a_{W_t,W_s}[j] = k)\, \ell(a_{W_s,L_s}[k] = a_{L_t,L_s}[i])$$
  * Compute $\ell(a_{L_t,W_t}[i] = j)$ by one of two methods.
    * If there are JAMR alignments available, for each JAMR alignment containing this node, 1 point is partitioned among the tokens in the range aligned to the label. If there are no such tokens in the range, the 1 point is partitioned among all tokens in the range.
    * If no JAMR alignment contains the $i^{\text{th}}$ node, treat it as though the token ranges with no JAMR-aligned nodes were aligned to the $i^{\text{th}}$ node.
    * If there are no JAMR alignments available, then 1 point is partitioned among all tokens string-matching label $i$.
  * Compute $\ell(a_{W_s,L_s}[k] = a_{L_t,L_s}[i])$ by the same method.
  * Compute $\ell(a_{W_t,W_s}[j] = k)$ from a posterior word alignment score extracted from the source-target and target-source n-best GIZA alignment files, normalized to 1.
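
To make the scoring concrete, here is a minimal sketch of the label-pair likelihood loop, mirroring the one in `Amr2AmrAligner.set_amrs` (`compare_smatch/amr_alignment.py`). The argument names are illustrative: `amr2sent_tgt` and `amr2sent_src` map each node label to its per-token weights $\ell(a_{L,W}[i]=j)$, and `sent2sent[t][s]` is the posterior word-alignment score between target token $t$ and source token $s$:

```python
from collections import defaultdict

def label_pair_likelihoods(amr2sent_tgt, amr2sent_src, sent2sent):
    """Accumulate a likelihood for every (target label, source label) pair."""
    amr2amr = defaultdict(float)
    for tgt_lbl, tgt_scores in amr2sent_tgt.items():
        for src_lbl, src_scores in amr2sent_src.items():
            if src_lbl.lower() == tgt_lbl.lower():
                amr2amr[(tgt_lbl, src_lbl)] += 1.0  # exact label match: add 1
                continue
            # otherwise, sum the alignment evidence over all token pairs
            for t, t_score in enumerate(tgt_scores):
                for s, s_score in enumerate(src_scores):
                    amr2amr[(tgt_lbl, src_lbl)] += t_score * s_score * sent2sent[t][s]
    return amr2amr
```

The resulting table is what the bilingual node weight function consults in place of exact string matching.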

In general, bilingual AMRICA appears to require more random restarts than monolingual AMRICA to perform well. This restart count can be modified with the flag `--num_restarts`.

![](demo/wb.eng_0003.13.png?raw=true)

##### Comparison: Smart Initialization vs. Approximation

We can observe the degree to which using Smatch-like approximations (here, with 20 random initializations) improves accuracy over selecting likely matches from raw alignment data (smart initialization). For a pairing declared structurally compatible by [(Xue 2014)](http://www.lrec-conf.org/proceedings/lrec2014/pdf/384_Paper.pdf):

* After initialization:

![](demo/wb.eng_0003.14_0.png?raw=true)

* After bilingual Smatch, with errors circled:

![](demo/wb.eng_0003.14.png?raw=true)

For a pairing considered incompatible:

* After initialization:

![](demo/wb.eng_0003.6_0.png?raw=true)

* After bilingual Smatch, with errors circled:

![](demo/wb.eng_0003.6.png?raw=true)



*This software was developed partly with the support of the National Science Foundation (USA) under awards 1349902 and 0530118.
The University of Edinburgh is a charitable body, registered in
Scotland, with registration number SC005336.*
--------------------------------------------------------------------------------
/compare_smatch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nsaphra/AMRICA/a3b85a9635b78e3c7f47d548b0b9b3ddbdffe237/compare_smatch/__init__.py
--------------------------------------------------------------------------------
/compare_smatch/amr_alignment.py:
--------------------------------------------------------------------------------
"""
amr_alignment.py

Author: Naomi Saphra (nsaphra@jhu.edu)
Copyright(c) 2014

Builds a weighted mapping of tokens between parallel sentences for use in
weighted cross-language Smatch alignment.
Takes in an output file from GIZA++ (specified at construction).
"""

from collections import defaultdict
from pynlpl.formats.giza import GizaSentenceAlignment
import re

class Amr2AmrAligner(object):
  def __init__(self, num_best=5, num_best_in_file=-1, src2tgt_fh=None, tgt2src_fh=None):
    if src2tgt_fh is None or tgt2src_fh is None:
      self.is_default = True
      self.node_weight_fn = self.dflt_node_weight_fn
      self.edge_weight_fn = self.dflt_edge_weight_fn
    else:
      self.is_default = False
      self.node_weight_fn = None
      self.edge_weight_fn = self.xlang_edge_weight_fn
    self.src2tgt_fh = src2tgt_fh
    self.tgt2src_fh = tgt2src_fh
    self.amr2amr = {}
    self.num_best = num_best
    self.num_best_in_file = num_best_in_file
    self.last_nbest_line = {self.src2tgt_fh:None, self.tgt2src_fh:None}
    if num_best_in_file < 0:
      self.num_best_in_file = num_best
    assert self.num_best_in_file >= self.num_best


  def set_amrs(self, tgt_amr, src_amr):
    if self.is_default:
      return

    self.tgt_toks = tgt_amr.metadata['tok'].strip().split()
    self.src_toks = src_amr.metadata['tok'].strip().split()

    sent2sent_union = align_sent2sent_union(self.tgt_toks, self.src_toks,
      self.get_nbest_alignments(self.src2tgt_fh), self.get_nbest_alignments(self.tgt2src_fh))

    if 'alignments' in tgt_amr.metadata:
      amr2sent_tgt = align_amr2sent_jamr(tgt_amr, self.tgt_toks, tgt_amr.metadata['alignments'].strip().split())
    else:
      amr2sent_tgt = align_amr2sent_dflt(tgt_amr, self.tgt_toks)
    if 'alignments' in src_amr.metadata:
      amr2sent_src = align_amr2sent_jamr(src_amr, self.src_toks, src_amr.metadata['alignments'].strip().split())
    else:
      amr2sent_src = align_amr2sent_dflt(src_amr, self.src_toks)

    self.amr2amr = defaultdict(float)
    for (tgt_lbl, tgt_scores) in amr2sent_tgt.items():
      for (src_lbl, src_scores) in amr2sent_src.items():
        if src_lbl.lower() == tgt_lbl.lower():
          self.amr2amr[(tgt_lbl, src_lbl)] += 1.0
          continue
        for (t, t_score) in enumerate(tgt_scores):
          for (s, s_score) in enumerate(src_scores):
            score = t_score * s_score * sent2sent_union[t][s]
            if score > 0:
              self.amr2amr[(tgt_lbl, src_lbl)] += score

    self.node_weight_fn = lambda t,s : self.amr2amr[(t, s)]


  def const_map_fn(self, const):
    """ Get all const strings from source amr that could map to target const """
    const_matches = [const]
    for (t,s) in filter(lambda (t,s): t == const, self.amr2amr):
      if self.node_weight_fn(t,s) > 0: # weight > 0
        const_matches.append(s)
    return sorted(const_matches, key=lambda x: self.node_weight_fn(const, x), reverse=True)


  @staticmethod
  def dflt_node_weight_fn(tgt_label, src_label):
    return 1.0 if tgt_label.lower() == src_label.lower() else 0.0


  @staticmethod
  def dflt_edge_weight_fn(tgt_label, src_label):
    return 1.0 if tgt_label.lower() == src_label.lower() else 0.0


  def xlang_edge_weight_fn(self, tgt_label, src_label):
    tgt = tgt_label.lower()
    src = src_label.lower()
    if tgt == src:
      return 1.0
    # operand edges are all equivalent
    # TODO make the string match a regex instead?
    if tgt.startswith("op") and src.startswith("op"):
      return 0.9 # TODO this is a frumious hack to favor similar op edges
    return 0.0


  def get_nbest_alignments(self, fh):
    """ Read an entry from the giza alignment .A3 NBEST file.
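    Returns a list of (GizaSentenceAlignment, score) pairs for one sentence
    pair, or None if the file is already exhausted.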
""" 104 | aligns = [] 105 | curr_sent = -1 106 | start_ind = 0 107 | if self.last_nbest_line[fh]: 108 | if self.num_best > 0: 109 | aligns.append(self.last_nbest_line[fh]) 110 | start_ind = 1 111 | curr_sent = self.last_nbest_line[fh][0].index 112 | self.last_nbest_line[fh] = None 113 | 114 | for ind in range(start_ind, self.num_best_in_file): 115 | meta_line = fh.readline() 116 | if meta_line == "": 117 | if len(aligns) == 0: 118 | return None 119 | else: 120 | break 121 | 122 | meta = re.match("# Sentence pair \((\d+)\) "+ 123 | "source length (\d+) target length (\d+) "+ 124 | "alignment score : (.+)", meta_line) 125 | if not meta: 126 | raise Exception 127 | sent = int(meta.group(1)) 128 | if curr_sent < 0: 129 | curr_sent = sent 130 | score = float(meta.group(4)) 131 | 132 | tgt_line = fh.readline() 133 | src_line = fh.readline() 134 | if sent != curr_sent: 135 | self.last_nbest_line[fh] = (GizaSentenceAlignment(src_line, tgt_line, sent), score) 136 | break 137 | if ind < self.num_best: 138 | aligns.append((GizaSentenceAlignment(src_line, tgt_line, sent), score)) 139 | return aligns 140 | 141 | default_aligner = Amr2AmrAligner() 142 | 143 | def get_all_labels(amr): 144 | ret = [v for v in amr.var_values] 145 | for l in amr.const_links: 146 | ret += [v for (k,v) in l.items()] 147 | return ret 148 | 149 | 150 | def align_amr2sent_dflt(amr, sent): 151 | labels = get_all_labels(amr) 152 | align = {l:[0.0 for tok in sent] for l in labels} 153 | for label in labels: 154 | lbl = label.lower() 155 | # checking for multiwords / bad segmentation 156 | # ('_' replaces ' ' in multiword quotes) 157 | # TODO just fix AMR format parser to deal with spaces in quotes 158 | possible_toks = lbl.split('_') 159 | possible_toks.append(lbl) 160 | 161 | matches = [t_ind for (t_ind, t) in enumerate(sent) if t.lower() in possible_toks] 162 | for t_ind in matches: 163 | align[label][t_ind] = 1.0 / len(matches) 164 | return align 165 | 166 | 167 | def parse_jamr_alignment(chunk): 168 | (tok_range, nodes_str) = chunk.split('|') 169 | (start_tok, end_tok) = tok_range.split('-') 170 | node_list = nodes_str.split('+') 171 | return (int(start_tok), int(end_tok), node_list) 172 | 173 | 174 | def align_label2toks_en(label, sent, weights, toks_to_align): 175 | """ 176 | label: node label to map 177 | sent: token list to map label to 178 | weights: list to be modified with new weights 179 | default_full: set True to have the default distribution sum to 1 instead of 0 180 | return list mapping token index to match weight 181 | """ 182 | 183 | # TODO frumious hack. should set up actual stemmer sometime. 
  lbl = label.lower()
  stem = lbl
  wordnet = re.match(r"(.+)-\d\d", lbl)
  if wordnet:
    stem = wordnet.group(1)
  if len(stem) > 4: # arbitrary
    if len(stem) > 5:
      stem = stem[:-2]
    else:
      stem = stem[:-1]

  def is_match(tok):
    return tok == lbl or \
      (len(tok) >= len(stem) and tok[:len(stem)] == stem)

  matches = [t_ind for t_ind in toks_to_align if is_match(sent[t_ind].lower())]
  if len(matches) == 0:
    matches = toks_to_align
  for t_ind in matches:
    weights[t_ind] += 1.0 / len(matches)
  return weights


def align_amr2sent_jamr(amr, sent, jamr_line):
  """
  amr: an amr to map nodes to sentence toks
  sent: sentence array of toks
  jamr_line: metadata field 'alignments', aligned with jamr
  return dict mapping amr node labels to match weights for each tok in sent
  """
  labels = get_all_labels(amr)
  labels_remain = {label:labels.count(label) for label in labels}
  tokens_remain = set(range(len(sent)))
  align = {l:[0.0 for tok in sent] for l in labels}

  for chunk in jamr_line:
    (start_tok, end_tok, node_list) = parse_jamr_alignment(chunk)
    for node_path in node_list:
      label = amr.path2label[node_path]
      toks_to_align = range(start_tok, end_tok)
      align[label] = align_label2toks_en(label, sent, align[label], toks_to_align)
      labels_remain[label] -= 1
      for t in toks_to_align:
        tokens_remain.discard(t)

  for label in labels_remain:
    if labels_remain[label] > 0:
      align[label] = align_label2toks_en(label, sent, align[label], tokens_remain)
  for label in align:
    z = sum(align[label])
    if z == 0:
      continue
    align[label] = [w/z for w in align[label]]
  return align


def align_sent2sent(tgt_toks, src_toks, alignment_scores):
  """
  return list array where entry (i,j) is the likelihood weight of target token i
  aligning to source token j.
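  Scores are normalized by z, the total score mass of the n-best alignments.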
  """
  z = sum([s for (a,s) in alignment_scores])
  tok_align = [[0.0 for s in src_toks] for t in tgt_toks]
  for (align, score) in alignment_scores:
    for srcind, tgtind in align.alignment:
      if tgtind >= 0 and srcind >= 0:
        tok_align[tgtind][srcind] += score

  for targetind, targettok in enumerate(tgt_toks):
    for sourceind, sourcetok in enumerate(src_toks):
      tok_align[targetind][sourceind] /= z
  return tok_align


def align_sent2sent_union(tgt_toks, src_toks, src2tgt, tgt2src):
  """
  return list array where entry (i,j) is the average likelihood of aligning in each
  direction
  """
  src2tgt_align = align_sent2sent(tgt_toks, src_toks, src2tgt)
  tgt2src_align = align_sent2sent(src_toks, tgt_toks, tgt2src)

  tok_align = [[0.0 for s in src_toks] for t in tgt_toks]
  for tgtind, tgttok in enumerate(tgt_toks):
    for srcind, srctok in enumerate(src_toks):
      tok_align[tgtind][srcind] = \
        (src2tgt_align[tgtind][srcind] + tgt2src_align[srcind][tgtind]) / 2.0
  return tok_align
--------------------------------------------------------------------------------
/compare_smatch/amr_metadata.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""
amr_metadata.py

Author: Naomi Saphra (nsaphra@jhu.edu)
Copyright(c) 2014

Read an AMR file while also processing the metadata in comments.
"""

import re

from smatch.amr import AMR

class AmrMeta(AMR):
  def __init__(self, var_list=None, var_value_list=None,
               link_list=None, const_link_list=None, path2label=None,
               base_amr=None, metadata=None):
    if base_amr is None:
      super(AmrMeta, self).__init__(var_list, var_value_list,
                                    link_list, const_link_list, path2label)
    else:
      self.nodes = base_amr.nodes
      self.root = base_amr.root
      self.var_values = base_amr.var_values
      self.links = base_amr.links
      self.const_links = base_amr.const_links
      self.path2label = base_amr.path2label

    # avoid the shared-mutable-default-argument pitfall
    self.metadata = metadata if metadata is not None else {}


  @classmethod
  def from_parse(cls, annotation_line, comment_lines, consts_to_vars=False):
    metadata = {}
    for l in comment_lines:
      matches = re.findall(r'::(\S+)\s(([^:]|:(?!:))+)', l)
      for m in matches:
        metadata[m[0]] = m[1].strip()

    base_amr = AMR.parse_AMR_line(annotation_line, consts_to_vars=consts_to_vars)
    return cls(base_amr=base_amr, metadata=metadata)


def get_amr_line(infile):
  """ Read an entry from the input file. AMRs are separated by blank lines. """
  cur_comments = []
  cur_amr = []
  has_content = False
  for line in infile:
    if line[0] == "(" and len(cur_amr) != 0:
      cur_amr = []
    if line.strip() == "":
      if not has_content:
        continue
      else:
        break
    elif line.strip().startswith("#"):
      cur_comments.append(line.strip())
    else:
      has_content = True
      cur_amr.append(line.strip())
  return ("".join(cur_amr), cur_comments)
--------------------------------------------------------------------------------
/compare_smatch/smatch_graph.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""
smatch_graph.py

Author: Naomi Saphra (nsaphra@jhu.edu)
Copyright(c) 2014

Describes a class for building graphs of AMRs with disagreements highlighted.
"""

import copy
import networkx as nx
import pygraphviz as pgz
from pynlpl.formats.giza import GizaSentenceAlignment

from amr_alignment import Amr2AmrAligner
from amr_alignment import default_aligner
import amr_metadata
from smatch import smatch

GOLD_COLOR = 'blue'
TEST_COLOR = 'red'
DFLT_COLOR = 'black'

class SmatchGraph:
  def __init__(self, inst, rel1, rel2, \
      gold_inst_t, gold_rel1_t, gold_rel2_t, \
      match, const_map_fn=default_aligner.const_map_fn):
    """
    Input:
      (inst, rel1, rel2) from test amr.get_triples2()
      (gold_inst_t, gold_rel1_t, gold_rel2_t) from gold amr2dict()
      match from smatch
      const_map_fn returns a sorted list of gold label matches for a test label input
    """
    (self.inst, self.rel1, self.rel2) = (inst, rel1, rel2)
    (self.gold_inst_t, self.gold_rel1_t, self.gold_rel2_t) = \
      (gold_inst_t, gold_rel1_t, gold_rel2_t)
    self.match = match # test var index -> gold var index
    self.map_fn = const_map_fn

    (self.unmatched_inst, self.unmatched_rel1, self.unmatched_rel2) = \
      [copy.deepcopy(x) for x in (self.gold_inst_t, self.gold_rel1_t, self.gold_rel2_t)]
    self.gold_ind = {} # test variable hash -> gold variable index
    self.G = nx.MultiDiGraph()


  def smatch2graph(self, node_weight_fn=None, edge_weight_fn=None):
    """
    Returns graph of the test AMR / gold AMR union, with disagreements
    highlighted: differing labels on edges and nodes, unmatched nodes and edges.
    """

    for (ind, (i, v, instof)) in enumerate(self.inst):
      self.add_inst(ind, v, instof)

    for (reln, v, const) in self.rel1:
      self.add_rel1(reln, v, const)

    for (reln, v1, v2) in self.rel2:
      self.add_rel2(reln, v1, v2)

    if node_weight_fn and edge_weight_fn:
      self.unmatch_dead_nodes(node_weight_fn, edge_weight_fn)

    # Add gold standard elements not in test
    test_ind = {v:k for (k,v) in self.gold_ind.items()} # reverse lookup from gold ind

    for (ind, instof) in self.unmatched_inst.items():
      test_ind[ind] = u'GOLD %s' % ind
      self.add_node(test_ind[ind], '', instof, test_ind=-1, gold_ind=ind)

    for ((ind, const), relns) in self.unmatched_rel1.items():
      for reln in relns:
        const_hash = test_ind[ind] + ' ' + const
        if const_hash not in test_ind:
          test_ind[const_hash] = const_hash
          self.add_node(const_hash, '', const)
        self.add_edge(test_ind[ind], test_ind[const_hash], '', reln)

    for ((ind1, ind2), relns) in self.unmatched_rel2.items():
      for reln in relns:
        self.add_edge(test_ind[ind1], test_ind[ind2], '', reln)

    return self.G


  def get_text_alignments(self):
    """ Return an array of variable ID mappings, including labels, that are human-readable.
    Call only after smatch2graph().
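    Each line is tab-separated: test index, test label, '-', gold index, gold label.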
""" 91 | align = [] 92 | for (v, attr) in self.G.nodes(data=True): 93 | if attr['test_ind'] < 0 and attr['gold_ind'] < 0: 94 | continue 95 | align.append("%s\t%s\t-\t%s\t%s" % (attr['test_ind'], attr['test_label'], attr['gold_ind'], attr['gold_label'])) 96 | return align 97 | 98 | 99 | def add_edge(self, v1, v2, test_lbl, gold_lbl): 100 | assert(gold_lbl == '' or test_lbl == '' or gold_lbl == test_lbl) 101 | if gold_lbl == '': 102 | self.G.add_edge(v1, v2, label=test_lbl, test_label=test_lbl, gold_label=gold_lbl, color=TEST_COLOR) 103 | elif test_lbl == '': 104 | self.G.add_edge(v1, v2, label=gold_lbl, test_label=test_lbl, gold_label=gold_lbl, color=GOLD_COLOR) 105 | elif test_lbl == gold_lbl: 106 | self.G.add_edge(v1, v2, label=test_lbl, test_label=test_lbl, gold_label=gold_lbl, color=DFLT_COLOR) 107 | 108 | 109 | def add_node(self, v, test_lbl, gold_lbl, test_ind=-1, gold_ind=-1): 110 | assert(gold_lbl or test_lbl) 111 | if gold_lbl == '': 112 | self.G.add_node(v, label=u'%s / *' % test_lbl, test_label=test_lbl, gold_label=gold_lbl, \ 113 | test_ind=test_ind, gold_ind=gold_ind, color=TEST_COLOR) 114 | elif test_lbl == '': 115 | self.G.add_node(v, label=u'* / %s' % gold_lbl, test_label=test_lbl, gold_label=gold_lbl, \ 116 | test_ind=test_ind, gold_ind=gold_ind, color=GOLD_COLOR) 117 | elif test_lbl == gold_lbl: 118 | self.G.add_node(v, label=test_lbl, test_label=test_lbl, gold_label=gold_lbl, \ 119 | test_ind=test_ind, gold_ind=gold_ind, color=DFLT_COLOR) 120 | else: 121 | self.G.add_node(v, label=u'%s / %s' % (test_lbl, gold_lbl), test_label=test_lbl, gold_label=gold_lbl, \ 122 | test_ind=test_ind, gold_ind=gold_ind, color=DFLT_COLOR) 123 | 124 | 125 | def add_inst(self, ind, var, instof): 126 | self.gold_ind[var] = self.match[ind] 127 | gold_lbl = '' 128 | gold_ind = self.match[ind] 129 | if gold_ind >= 0: # there's a gold match 130 | gold_lbl = self.gold_inst_t[gold_ind] 131 | if self.match[ind] in self.unmatched_inst: 132 | del self.unmatched_inst[gold_ind] 133 | self.add_node(var, instof, gold_lbl, test_ind=ind, gold_ind=gold_ind) 134 | 135 | 136 | def add_rel1(self, reln, var, const): 137 | const_matches = self.map_fn(const) 138 | gold_edge_lbl = '' 139 | 140 | # we match const to the highest-ranked match label from the var 141 | gold_node_lbl = '' 142 | node_hash = var+' '+const 143 | for const_match in const_matches: 144 | if (self.gold_ind[var], const_match) in self.gold_rel1_t: 145 | gold_node_lbl = const_match 146 | #TODO put the metatable editing in the helper fcns? 
        if reln not in self.gold_rel1_t[(self.gold_ind[var], const_match)]:
          # relns between existing nodes should be in unmatched rel2
          self.gold_ind[node_hash] = const_match
          self.unmatched_rel2[(self.gold_ind[var], const_match)] = self.unmatched_rel1[(self.gold_ind[var], const_match)]
          del self.unmatched_rel1[(self.gold_ind[var], const_match)]
        else:
          gold_edge_lbl = reln
          self.unmatched_rel1[(self.gold_ind[var], const_match)].remove(reln)
        break

    self.add_node(node_hash, const, gold_node_lbl)
    self.add_edge(var, node_hash, reln, gold_edge_lbl)


  def add_rel2(self, reln, v1, v2):
    gold_lbl = ''
    if (self.gold_ind[v1], self.gold_ind[v2]) in self.gold_rel2_t:
      if reln in self.gold_rel2_t[(self.gold_ind[v1], self.gold_ind[v2])]:
        gold_lbl = reln
        self.unmatched_rel2[(self.gold_ind[v1], self.gold_ind[v2])].remove(reln)
    self.add_edge(v1, v2, reln, gold_lbl)


  def unmatch_dead_nodes(self, node_weight_fn, edge_weight_fn):
    """ Unmap node mappings that don't increase smatch score. """
    node_is_live = {v:(gold == -1) for (v, gold) in self.gold_ind.items()}
    for (v, attr) in self.G.nodes(data=True):
      if node_weight_fn(attr['test_label'], attr['gold_label']) > 0:
        node_is_live[v] = True
    for (v1, links) in self.G.adjacency_iter():
      for (v2, edges) in links.items():
        if len(edges) > 1:
          node_is_live[v2] = True
          node_is_live[v1] = True
          break
        for (ind, attr) in edges.items():
          if attr['test_label'] == attr['gold_label']:
            node_is_live[v2] = True
            node_is_live[v1] = True
            break

    for v in node_is_live.keys():
      if not node_is_live[v]:
        self.unmatched_inst[self.gold_ind[v]] = self.G.node[v]['gold_label']
        self.G.node[v]['gold_label'] = ''
        self.G.node[v]['label'] = u'%s / *' % self.G.node[v]['test_label']
        self.G.node[v]['color'] = TEST_COLOR
        del self.gold_ind[v]


def amr2dict(inst, rel1, rel2):
  """ Get tables of AMR data indexed by variable number """
  node_inds = {}
  inst_t = {}
  for (ind, (i, v, label)) in enumerate(inst):
    node_inds[v] = ind
    inst_t[ind] = label

  rel1_t = {}
  for (label, v1, const) in rel1:
    if (node_inds[v1], const) not in rel1_t:
      rel1_t[(node_inds[v1], const)] = set()
    rel1_t[(node_inds[v1], const)].add(label)

  rel2_t = {}
  for (label, v1, v2) in rel2:
    if (node_inds[v1], node_inds[v2]) not in rel2_t:
      rel2_t[(node_inds[v1], node_inds[v2])] = set()
    rel2_t[(node_inds[v1], node_inds[v2])].add(label)

  return (inst_t, rel1_t, rel2_t)
--------------------------------------------------------------------------------
/demo/AMRICA-demo.ipynb:
--------------------------------------------------------------------------------
{
 "metadata": {
  "celltoolbar": "Raw Cell Format",
  "name": "",
  "signature": "sha256:d1f62dd575ae10212f0278a04166bda353ce6bcd131f449e65789c99b0ca364b"
 },
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "# Using AMRICA\n",
      "\n",
      "AMRICA (AMR Inspector for Cross-language Alignments) is a simple tool for aligning and visually representing AMRs [(Banarescu, 2013)](http://www.isi.edu/natural-language/amr/a.pdf), both for bilingual contexts and for monolingual inter-annotator agreement. It is based on and extends the Smatch system [(Cai, 2013)](http://www.newdesign.aclweb.org/anthology-new/P/P13/P13-2131.pdf) for identifying AMR inter-annotator agreement.\n",
      "\n",
      "## How It Works\n",
      "\n",
      "### Smatch Classic\n",
      "\n",
      "Since AMRICA is a variation on Smatch, one should begin by understanding Smatch. Smatch attempts to identify a matching between the variable nodes of two AMR representations of the same sentence in order to measure inter-annotator agreement. The matching should be selected to maximize the Smatch score, which assigns a point for each edge appearing in both graphs. These edges fall into three categories, each illustrated in the following annotation of \"It didn't take long.\"\n",
      "\n",
      "```\n",
      "(t / take-10\n",
      " :ARG0 (i / it)\n",
      " :ARG1 (l2 / long\n",
      "  :polarity -))\n",
      "```\n",
      "\n",
      "* Instance labels, such as `(instance, t, take-10)`\n",
      "* Variable-variable edges, such as `(ARG0, t, i)`\n",
      "* Variable-const edges, such as `(polarity, l2, -)`\n",
      "\n",
      "Because the problem of finding the matching maximizing the Smatch score is NP-complete, Smatch uses a hill-climbing algorithm to approximate the best solution. It seeds the matching by pairing each node with a node sharing its label if possible, and matching the remaining nodes in the smaller graph (hereafter the target) randomly. Smatch then performs a step by finding the action that will increase the score the most, either switching two target nodes' matchings or moving a matching from its source node to an unmatched source node. It repeats this step until no step can immediately increase the Smatch score.\n",
      "\n",
      "To avoid local optima, Smatch generally restarts 5 times.\n",
      "\n",
      "### AMRICA\n",
      "\n",
      "AMRICA begins by replacing all constant nodes with variable nodes that are instances of the constant's label. This is necessary so we can align the constant nodes as well as the variables. Consequently, the only points added to the AMRICA score come from matching variable-variable edges and instance labels.\n",
      "\n",
      "While Smatch tries to match every node in the smaller graph to some node in the larger graph, AMRICA removes matchings that do not increase the modified Smatch score, or AMRICA score.\n",
      "\n",
      "AMRICA then generates image files from graphviz graphs of the alignments. If a node or edge appears only in the gold data, it is red. If that node or edge appears only in the test data, it is blue. If the node or edge has a matching in our final alignment, it is black.\n",
      "\n",
      "\n",
      "\n",
      "#### Bitextual Variant\n",
      "\n",
      "In AMRICA, instead of adding one point for each perfectly matching instance label, we add a point based on a likelihood score for those labels aligning. The likelihood score $\\ell(a_{L_t,L_s}[i] | L_t, W_t, L_s, W_s)$, with target label set $L_t$, source label set $L_s$, target sentence $W_t$, source sentence $W_s$, and alignment $a_{L_t,L_s}[i]$ mapping $L_{t}[i]$ onto some label $L_{s}[a_{L_t,L_s}[i]]$, is defined by the following rules:\n",
      "\n",
      "* If the labels for $L_s[a_{L_t,L_s}[i]]$ and $L_t[i]$ match, add 1 to the likelihood.\n",
      "* Add to the likelihood:\n",
      "$$\\sum_{j = 1}^{|W_t|} \\sum_{k = 1}^{|W_s|} \\ell(a_{L_t,W_t}[i] = j) \\ell(a_{W_t,W_s}[j] = k) \\ell(a_{W_s,L_s}[k] = a_{L_t, L_s}[i])$$\n",
      "  * Compute $\\ell(a_{L_t,W_t}[i] = j)$ by one of two methods.\n",
      "    * If there are JAMR alignments available, for each JAMR alignment containing this node, 1 point is partitioned among the tokens in the range aligned to the label. If there are no such tokens in the range, the 1 point is partitioned among all tokens in the range.\n",
      "    * If no JAMR alignment contains the $i^{\\textit{th}}$ node, treat it as though the token ranges with no JAMR-aligned nodes were aligned to the $i^{\\textit{th}}$ node.\n",
      "    * If there are no JAMR alignments available, then 1 point is partitioned among all tokens string-matching label $i$.\n",
      "  * Compute $\\ell(a_{W_s,L_s}[k] = a_{L_t, L_s}[i])$ by the same method.\n",
      "  * Compute $\\ell(a_{W_t,W_s}[j] = k)$ from a posterior word alignment score extracted from the source-target and target-source n-best GIZA alignment files, normalized to 1.\n",
      "\n",
      "In general, bilingual AMRICA appears to require more random restarts than monolingual AMRICA to perform well. This restart count can be modified with the flag `--num_restarts`.\n",
      "\n",
      "\n",
      "\n",
      "##### Comparison: Smart Initialization vs. Approximation\n",
      "\n",
      "We can observe the degree to which using Smatch-like approximations (here, with 20 random initializations) improves accuracy over selecting likely matches from raw alignment data (smart initialization). For a pairing declared structurally compatible by [(Xue 2014)](http://www.lrec-conf.org/proceedings/lrec2014/pdf/384_Paper.pdf):\n",
      "\n",
      "* After initialization:\n",
      "\n",
      "\n",
      "\n",
      "* After bilingual Smatch, with errors circled:\n",
      "\n",
      "\n",
      "\n",
      "For a pairing considered incompatible:\n",
      "\n",
      "* After initialization:\n",
      "\n",
      "\n",
      "\n",
      "* After bilingual Smatch, with errors circled:\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "## Getting started\n",
      "\n",
      "Download the Python source from [github](https://github.com/nsaphra/AMRICA).\n",
      "\n",
      "### Dependencies\n",
      "\n",
      "The following Python packages are required to run AMRICA and can be installed with pip: `networkx`, `argparse`, `argparse_config`, `pygraphviz`.\n",
      "\n",
      "Additionally, to prepare bilingual alignment data you will need [GIZA++](https://code.google.com/p/giza-pp/) and possibly [JAMR](https://github.com/jflanigan/jamr/).\n",
      "\n",
      "### Data Preparation\n",
      "\n",
      "#### Monolingual AMRICA\n",
      "\n",
      "To generate visualizations of Smatch alignments, we need an AMR input file in which each \n",
      "`::tok` field contains the tokenized sentence, each `::id` field a sentence ID, and each `::anno` field an annotator ID. The annotations for a particular sentence are listed sequentially, and the first annotation is considered the gold standard for visualization purposes.\n",
      "\n",
      "If you only want to visualize a single annotation per sentence, you can use an AMR file with only a single annotator.\n",
      "\n",
      "#### Bilingual AMRICA\n",
      "\n",
      "For bilingual alignments, we start with two AMR files, one containing the target annotations and one the source annotations in the same order, with `::tok` and `::id` fields for each annotation. If we want JAMR alignments for either side, we include those in a `::alignments` field.\n",
      "\n",
      "The sentence alignments should be in the form of two GIZA++ alignment .NBEST files, one source-target and one target-source. To generate these, use the `--nbestalignments` flag in your GIZA++ config file, set to your preferred n-best count.\n",
      "\n",
      "## Configuration\n",
      "\n",
      "Flags can be set either at the command line or in a config file. The location of a config file can be set with `-c CONF_FILE` at the command line.\n",
      "\n",
      "### Common flags\n",
      "\n",
      "In addition to `--conf_file`, there are several other flags that apply to both monolingual and bilingual text. `--outdir DIR` is the only required one, and specifies the directory to which we will write the image files.\n",
      "\n",
      "The optional shared flags are:\n",
      "* `--verbose` to print out sentences as we align them.\n",
      "* `--no-verbose` to override a verbose default setting.\n",
      "* `--json FILE.json` to write the alignment graphs to a .json file.\n",
      "* `--num_restarts N` to specify the number of random restarts Smatch should execute.\n",
      "* `--align_out FILE.csv` to write the alignments to file.\n",
      "* `--align_in FILE.csv` to read the alignments from disk instead of running Smatch.\n",
      "\n",
      "The alignment .csv files are in a format where each graph matching set is separated by an empty line, and each line within a set contains either a comment or a line indicating an alignment. For example:\n",
      "\n",
      "```\n",
      "3\t\u5b83\t-\t1\tit\n",
      "2\t\u591a\u957f\t-\t-1\n",
      "-1\t\t-\t 2 take\n",
      "```\n",
      "\n",
      "The tab-separated fields are the test node index (as processed by Smatch), the test node label, the gold node index, and the gold node label.\n",
      "\n",
      "### Monolingual\n",
      "\n",
      "Monolingual alignment requires one additional flag, `--infile FILE.amr`, with `FILE.amr` set to the location of the AMR file.\n",
      "\n",
      "The following is an example config file:\n",
      "\n",
      "```\n",
      "[default]\n",
      "infile: data/events_amr.txt\n",
      "outdir: data/events_png/\n",
      "json: data/events.json\n",
      "verbose\n",
      "```\n",
      "\n",
      "### Bilingual\n",
      "\n",
      "In bilingual alignment, there are more required flags.\n",
      "\n",
      "* `--src_amr FILE` for the source annotation AMR file.\n",
      "* `--tgt_amr FILE` for the target annotation AMR file.\n",
      "* `--align_tgt2src FILE.A3.NBEST` for the GIZA++ .NBEST file aligning target-to-source (with target as vcb1), generated with `--nbestalignments N`.\n",
      "* `--align_src2tgt FILE.A3.NBEST` for the GIZA++ .NBEST file aligning source-to-target (with source as vcb1), generated with `--nbestalignments N`.\n",
      "\n",
      "If `--nbestalignments N` was set to a value greater than 1, we should pass that same value with `--num_aligned_in_file`. If we want to count only the top $k$ of those alignments, we set `--num_align_read` as well.\n",
      "\n",
      "## Endnotes\n",
      "\n",
      "`--nbestalignments` is a tricky flag to use, because GIZA++ only generates n-best output on its final alignment run. I could only get it to work with the default GIZA++ settings, myself."
     ]
    }
   ],
   "metadata": {}
  }
 ]
}
--------------------------------------------------------------------------------
/demo/DF-183-195681-794_9333.5_annoted_UCO-AMR-06_UCO-AMR-05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nsaphra/AMRICA/a3b85a9635b78e3c7f47d548b0b9b3ddbdffe237/demo/DF-183-195681-794_9333.5_annoted_UCO-AMR-06_UCO-AMR-05.png
--------------------------------------------------------------------------------
/demo/wb.eng_0003.13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nsaphra/AMRICA/a3b85a9635b78e3c7f47d548b0b9b3ddbdffe237/demo/wb.eng_0003.13.png
--------------------------------------------------------------------------------
/demo/wb.eng_0003.14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nsaphra/AMRICA/a3b85a9635b78e3c7f47d548b0b9b3ddbdffe237/demo/wb.eng_0003.14.png
--------------------------------------------------------------------------------
/demo/wb.eng_0003.14_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nsaphra/AMRICA/a3b85a9635b78e3c7f47d548b0b9b3ddbdffe237/demo/wb.eng_0003.14_0.png
--------------------------------------------------------------------------------
/demo/wb.eng_0003.6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nsaphra/AMRICA/a3b85a9635b78e3c7f47d548b0b9b3ddbdffe237/demo/wb.eng_0003.6.png
--------------------------------------------------------------------------------
/demo/wb.eng_0003.6_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nsaphra/AMRICA/a3b85a9635b78e3c7f47d548b0b9b3ddbdffe237/demo/wb.eng_0003.6_0.png
--------------------------------------------------------------------------------
/disagree.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""
disagree.py

Author: Naomi Saphra (nsaphra@jhu.edu)
Copyright(c) 2014. All rights reserved.

A tool for inspecting AMR data to identify patterns of inter-annotator
disagreement or semantic inequivalence.

An AMR input file is expected in a format where comments above each annotation
indicate the sentence, like so:

# ::id DF-170-181103-888_2097.1 ::date 2013-09-16T07:15:31 ::annotator ANON-01 ::preferred
# ::tok This is a sentence .
(this / file
  :is (an / AMR))

For monolingual disagreement, all annotations of some sentence should occur
consecutively in the monolingual annotation file. For bilingual, annotations
should be in the same order of sentences between the two files.

For bilingual disagreement, you can include a ::alignments field from JAMR to
help with AMR-sentence alignment.
"""

import argparse
import argparse_config
import codecs
import networkx as nx
from networkx.readwrite import json_graph
import os
import pygraphviz as pgz

# internal libraries
from compare_smatch import amr_metadata
from compare_smatch import smatch_graph
from compare_smatch.amr_alignment import Amr2AmrAligner
from compare_smatch.amr_alignment import default_aligner
from compare_smatch.smatch_graph import SmatchGraph
from smatch import smatch

cur_sent_id = 0

def hilight_disagreement(test_amrs, gold_amr, iter_num, aligner=default_aligner, gold_aligned_fh=None):
  """
  Input:
    test_amrs: list of AMRs to compare to
    gold_amr: gold AMR object
    iter_num: number of random restarts to use in the smatch algorithm
  Returns a list of disagreement graphs for each gold-test AMR pair.
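  Optional:
    aligner: Amr2AmrAligner supplying the node/edge weight functions
      (defaults to the monolingual default_aligner).
    gold_aligned_fh: handle to alignments in --align_out format; if set,
      matchings are read from this file instead of computed by Smatch.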
  """

  smatchgraphs = []
  gold_label=u'b'
  gold_amr.rename_node(gold_label)
  (gold_inst, gold_rel1, gold_rel2) = gold_amr.get_triples2()
  (gold_inst_t, gold_rel1_t, gold_rel2_t) = smatch_graph.amr2dict(gold_inst, gold_rel1, gold_rel2)
  # TODO Also compute the weight score if we read gold alignments in from file
  # TODO This would require me to handle constants when we read from file

  for a in test_amrs:
    aligner.set_amrs(a, gold_amr)
    test_label=u'a'
    a.rename_node(test_label)
    (test_inst, test_rel1, test_rel2) = a.get_triples2()
    if gold_aligned_fh:
      best_match = get_next_gold_alignments(gold_aligned_fh)
      best_match_num = -1.0
    else:
      (best_match, best_match_num) = smatch.get_fh(test_inst, test_rel1, test_rel2,
        gold_inst, gold_rel1, gold_rel2,
        test_label, gold_label,
        node_weight_fn=aligner.node_weight_fn, edge_weight_fn=aligner.edge_weight_fn,
        iter_num=iter_num)

    disagreement = SmatchGraph(test_inst, test_rel1, test_rel2, \
      gold_inst_t, gold_rel1_t, gold_rel2_t, \
      best_match, const_map_fn=aligner.const_map_fn)
    smatchgraphs.append((disagreement, best_match_num))
  return smatchgraphs


def get_disagreement_graphs(smatchgraphs, aligner=default_aligner,
    unmatch_dead_nodes=True):
  if unmatch_dead_nodes:
    return [(g.smatch2graph(node_weight_fn=aligner.node_weight_fn,
             edge_weight_fn=aligner.edge_weight_fn),
             score) \
            for (g, score) in smatchgraphs]
  else:
    return [(g.smatch2graph(), score) for (g, score) in smatchgraphs]


def open_output_files(args):
  json_fh = None
  if args.json_out:
    json_fh = codecs.open(args.json_out, 'w', encoding='utf8')
  align_fh = None
  if args.align_out:
    align_fh = codecs.open(args.align_out, 'w', encoding='utf8')
  return (json_fh, align_fh)


def close_output_files(json_fh, align_fh):
  json_fh and json_fh.close()
  align_fh and align_fh.close()


def get_next_gold_alignments(gold_aligned_fh):
  match_hash = {}
  line = gold_aligned_fh.readline().strip()
  while (line):
    if line.startswith('#'): # comment line
      line = gold_aligned_fh.readline().strip()
      continue
    align = line.split('\t')
    test_ind = int(align[0])
    gold_ind = int(align[3])
    if test_ind >= 0:
      match_hash[test_ind] = gold_ind
    line = gold_aligned_fh.readline().strip()

  return [v for (k, v) in sorted(match_hash.items())]


def get_sent_info(metadata, dflt_id=None):
  """ Return the ID and sentence if available, and update metadata to reflect them """
  (sent_id, sent) = (None, None)
  global cur_sent_id
  if 'tok' in metadata:
    sent = metadata['tok']
  else:
    sent = metadata['snt']

  if 'id' in metadata:
    sent_id = metadata['id']
  elif dflt_id is not None:
    sent_id = dflt_id
  else:
    sent_id = "%d" % cur_sent_id
    cur_sent_id += 1

  (metadata['id'], metadata['tok']) = \
    (sent_id, sent)

  return (sent_id, sent)


def monolingual_main(args):
  """ Disagreement graphs for different annotations of a single sentence.
""" 152 | infile = codecs.open(args.infile, encoding='utf8') 153 | gold_aligned_fh = None 154 | if args.align_in: 155 | gold_aligned_fh = codecs.open(args.align_in, encoding='utf8') 156 | (json_fh, align_fh) = open_output_files(args) 157 | 158 | amrs_same_sent = [] 159 | cur_id = "" 160 | while True: 161 | (amr_line, comments) = amr_metadata.get_amr_line(infile) 162 | cur_amr = None 163 | if amr_line: 164 | cur_amr = amr_metadata.AmrMeta.from_parse(amr_line, comments, 165 | consts_to_vars=(gold_aligned_fh != None or align_fh != None)) 166 | get_sent_info(cur_amr.metadata) 167 | if 'annotator' not in cur_amr.metadata: 168 | cur_amr.metadata['annotator'] = '' 169 | if not cur_id: 170 | cur_id = cur_amr.metadata['id'] 171 | 172 | if cur_amr is None or cur_id != cur_amr.metadata['id'] or (args.singleview and len(amrs_same_sent)): 173 | gold_amr = amrs_same_sent[0] 174 | test_amrs = amrs_same_sent[1:] 175 | if len(test_amrs) == 0: 176 | test_amrs = [gold_amr] # single AMR view case 177 | args.num_restarts = 1 # TODO make single AMR view more efficient 178 | smatchgraphs = hilight_disagreement(test_amrs, gold_amr, 179 | args.num_restarts, gold_aligned_fh=gold_aligned_fh) 180 | amr_graphs = get_disagreement_graphs(smatchgraphs, unmatch_dead_nodes=(gold_aligned_fh == None)) 181 | gold_anno = gold_amr.metadata['annotator'] 182 | sent = gold_amr.metadata['tok'] 183 | 184 | if (args.verbose): 185 | print("ID: %s\n Sentence: %s\n gold anno: %s" % (cur_id, sent, gold_anno)) 186 | 187 | for (ind, a) in enumerate(test_amrs): 188 | (g, score) = amr_graphs[ind] 189 | test_anno = a.metadata['annotator'] 190 | if json_fh: 191 | json_fh.write(json_graph.dumps(g) + '\n') 192 | if align_fh: 193 | sg = smatchgraphs[ind][0] 194 | align_fh.write("""# ::id %s\n# ::tok %s\n# ::gold_anno %s\n# ::test_anno %s\n""" % \ 195 | (cur_id, sent, gold_anno, test_anno)) 196 | align_fh.write('\n'.join(sg.get_text_alignments()) + '\n\n') 197 | if (args.verbose): 198 | print(" annotator %s score: %d" % (test_anno, score)) 199 | 200 | ag = nx.drawing.nx_agraph.to_agraph(g) 201 | ag.graph_attr['label'] = sent 202 | ag.layout(prog=args.layout) 203 | ag.draw('%s/%s_annotated_%s_%s.png' % (args.outdir, cur_id, gold_anno, test_anno)) 204 | 205 | amrs_same_sent = [] 206 | if cur_amr is not None: 207 | cur_id = cur_amr.metadata['id'] 208 | else: 209 | break 210 | 211 | amrs_same_sent.append(cur_amr) 212 | 213 | infile.close() 214 | gold_aligned_fh and gold_aligned_fh.close() 215 | close_output_files(json_fh, align_fh) 216 | 217 | 218 | def xlang_main(args): 219 | """ Disagreement graphs for aligned cross-language language. 
""" 220 | src_amr_fh = codecs.open(args.src_amr, encoding='utf8') 221 | tgt_amr_fh = codecs.open(args.tgt_amr, encoding='utf8') 222 | src2tgt_fh = codecs.open(args.align_src2tgt, encoding='utf8') 223 | tgt2src_fh = codecs.open(args.align_tgt2src, encoding='utf8') 224 | gold_aligned_fh = None 225 | if args.align_in: 226 | gold_aligned_fh = codecs.open(args.align_in, encoding='utf8') 227 | (json_fh, align_fh) = open_output_files(args) 228 | 229 | amrs_same_sent = [] 230 | aligner = Amr2AmrAligner(num_best=args.num_align_read, num_best_in_file=args.num_aligned_in_file, src2tgt_fh=src2tgt_fh, tgt2src_fh=tgt2src_fh) 231 | while True: 232 | (src_amr_line, src_comments) = amr_metadata.get_amr_line(src_amr_fh) 233 | if src_amr_line == "": 234 | break 235 | (tgt_amr_line, tgt_comments) = amr_metadata.get_amr_line(tgt_amr_fh) 236 | src_amr = amr_metadata.AmrMeta.from_parse(src_amr_line, src_comments, consts_to_vars=True) 237 | tgt_amr = amr_metadata.AmrMeta.from_parse(tgt_amr_line, tgt_comments, consts_to_vars=True) 238 | (cur_id, src_sent) = get_sent_info(src_amr.metadata) 239 | (tgt_id, tgt_sent) = get_sent_info(tgt_amr.metadata, dflt_id=cur_id) 240 | assert cur_id == tgt_id 241 | 242 | smatchgraphs = hilight_disagreement([tgt_amr], src_amr, args.num_restarts, aligner=aligner, gold_aligned_fh=gold_aligned_fh) 243 | amr_graphs = get_disagreement_graphs(smatchgraphs, aligner=aligner, 244 | unmatch_dead_nodes=(gold_aligned_fh == None)) 245 | 246 | if json_fh: 247 | json_fh.write(json_graph.dumps(amr_graphs[0]) + '\n') 248 | if align_fh: 249 | align_fh.write("""# ::id %s\n# ::src_snt %s\n# ::tgt_snt %s\n""" % (cur_id, src_sent, tgt_sent)) 250 | align_fh.write('\n'.join(smatchgraphs[0].get_text_alignments()) + '\n\n') 251 | if (args.verbose): 252 | print("ID: %s\n Sentence: %s\n Sentence: %s\n Score: %f" % (cur_id, src_sent, tgt_sent, amr_graphs[0][1])) 253 | 254 | ag = nx.drawing.nx_agraph.to_agraph(amr_graphs[0][0]) 255 | ag.graph_attr['label'] = "%s\n%s" % (src_sent, tgt_sent) 256 | ag.layout(prog=args.layout) 257 | ag.draw('%s/%s.png' % (args.outdir, cur_id)) 258 | 259 | src_amr_fh.close() 260 | tgt_amr_fh.close() 261 | src2tgt_fh.close() 262 | tgt2src_fh.close() 263 | gold_aligned_fh and gold_aligned_fh.close() 264 | close_output_files(json_fh, align_fh) 265 | 266 | 267 | if __name__ == '__main__': 268 | parser = argparse.ArgumentParser() 269 | parser.add_argument("-c", "--conf_file", help="Specify config file") 270 | parser.add_argument('-i', '--infile', help='amr input file') 271 | parser.add_argument('-o', '--outdir', help='Visualization output directory') 272 | parser.add_argument('-v', '--verbose', action='store_true') 273 | parser.add_argument('--no-verbose', action='store_true') 274 | parser.add_argument('-b', '--bitext', action='store_true', 275 | help='Input source and target language bitext AMRs.') 276 | parser.add_argument('-s', '--src_amr', 277 | help='In bitext mode, source language AMR file.') 278 | parser.add_argument('-t', '--tgt_amr', 279 | help='In bitext mode, target language AMR file.') 280 | parser.add_argument('--align_src2tgt', 281 | help='In bitext mode, GIZA alignment .NBEST file (see GIZA++ -nbestalignments opt) with source as vcb1.') 282 | parser.add_argument('--align_tgt2src', 283 | help='In bitext mode, GIZA alignment .NBEST file (see GIZA++ -nbestalignments opt) with target as vcb1.') 284 | parser.add_argument('--num_align_read', type=int, 285 | help='N to read from GIZA NBEST file.') 286 | parser.add_argument('--num_aligned_in_file', type=int, default=1, 287 | 
287 | help='N printed to GIZA NBEST file.') 288 | parser.add_argument('-j', '--json_out', 289 | help='File to dump json graphs to.') 290 | parser.add_argument('--num_restarts', type=int, default=5, 291 | help='Number of random restarts to execute during hill-climbing algorithm.') 292 | parser.add_argument('--align_out', 293 | help="Human-readable alignments output file - WARNING, will force conversion of const nodes to var nodes for alignment") 294 | parser.add_argument('--align_in', 295 | help="Alignments from human-editable text file, as from align_out") 296 | parser.add_argument('--layout', default='dot', 297 | help='Graphviz output layout') 298 | parser.add_argument('--singleview', action='store_true', 299 | help='If set, display each AMR in the file individually without alignments') 300 | # TODO make interactive option and option to process a specific range 301 | 302 | args_conf = parser.parse_args() 303 | if args_conf.conf_file: 304 | argparse_config.read_config_file(parser, args_conf.conf_file) 305 | 306 | args = parser.parse_args() 307 | if args.no_verbose: 308 | args.verbose = False 309 | if not args.num_align_read: 310 | args.num_align_read = args.num_aligned_in_file 311 | 312 | if not os.path.exists(args.outdir): 313 | os.makedirs(args.outdir) 314 | 315 | if (args.bitext): 316 | xlang_main(args) 317 | else: 318 | if args.infile is None or args.outdir is None: 319 | parser.error("Both --infile and --outdir are required flags.") 320 | monolingual_main(args) 321 | -------------------------------------------------------------------------------- /scripts/smatch_stats.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | from collections import defaultdict 5 | from networkx.readwrite import json_graph 6 | import sys 7 | sys.path.append("..")  # must run before the compare_smatch import below 8 | 9 | from compare_smatch.smatch_graph import SmatchGraph 10 | 11 | 12 | counts = defaultdict(int) 13 | 14 | def add_counts(g, v1, v2, attr, prefix): 15 | def pfx(lbl): 16 | return '%s_%s' % (prefix, lbl) 17 | def opp(lbl): 18 | if prefix == 'gold': 19 | return 'test_%s' % lbl 20 | if prefix == 'test': 21 | return 'gold_%s' % lbl 22 | else: 23 | raise Exception #TODO think of a good exception 24 | def incr(lbl): 25 | counts['total_%s' % lbl] += 1 26 | counts[pfx(lbl)] += 1 27 | def is_dflt(curr_attr): 28 | return curr_attr['gold_label'] == curr_attr['test_label'] 29 | def is_opp(curr_attr): 30 | return curr_attr['color'] != attr['color'] and not is_dflt(curr_attr) 31 | def is_same(curr_attr): 32 | return curr_attr['color'] == attr['color'] 33 | 34 | incr('edges') 35 | if attr['label'] == 'polarity': 36 | incr('polarity') 37 | if attr['label'].startswith('op') and len(attr['label']) == 3: 38 | incr('name_opt') 39 | return 40 | if g.node[v2]['color'] == attr['color']: 41 | incr('same_color_head') 42 | if g.node[v1]['color'] == attr['color']: 43 | incr('same_color_tail') 44 | if prefix != 'dflt': 45 | for head, edges in g.edge[v1].items(): 46 | for (ind, curr) in edges.items(): 47 | if is_same(curr): 48 | continue 49 | if head == v2: 50 | if is_opp(curr): 51 | incr('cfg1') 52 | else: 53 | for head2, edges2 in g.edge[v1].items(): 54 | for (ind2, curr2) in edges2.items(): 55 | if head2 == v2 and is_opp(curr2): 56 | if is_dflt(curr): 57 | incr('cfg2') 58 | else: # is_opp 59 | incr('cfg3') 60 | 61 | for head, edges in g.edge[v2].items(): 62 | for (ind, attrs) in edges.items(): 63 | if head == v1: 64 | if is_opp(attrs): 65 | incr('cfg1_reverse') 66 | 67 | 
68 | def analyze(g): 69 | for (v1, links) in g.adjacency_iter(): 70 | for (v2, edges) in links.items(): 71 | for (ind, attr) in edges.items(): 72 | if not attr['gold_label']: 73 | add_counts(g, v1, v2, attr, 'test') 74 | elif not attr['test_label']: 75 | add_counts(g, v1, v2, attr, 'gold') 76 | else: 77 | add_counts(g, v1, v2, attr, 'dflt') 78 | 79 | 80 | def print_proportions(): 81 | for k, v in sorted(counts.items()): 82 | if k != 'total_edges': 83 | print '%s: %f' % (k, v / float(counts['total_edges'])) 84 | 85 | def main(args): 86 | input_fh = open(args.input) 87 | while True: 88 | line = input_fh.readline().strip() 89 | if not line: 90 | break 91 | g = json_graph.loads(line) 92 | analyze(g) 93 | input_fh.close() 94 | for k, v in sorted(counts.items()): 95 | print '%s: %d' % (k, v) 96 | print '=======' 97 | print_proportions() 98 | 99 | 100 | if __name__ == '__main__': 101 | parser = argparse.ArgumentParser( 102 | description='Statistical analysis of smatch disagreement graphs.\n' 103 | 'Usage: ./smatch_stats.py -i graphs.json' 104 | ) 105 | parser.add_argument('-i', '--input', 106 | help='Specify .json amr disagreement file') 107 | args = parser.parse_args() 108 | 109 | main(args) -------------------------------------------------------------------------------- /smatch/README.txt: -------------------------------------------------------------------------------- 1 | There is also a pdf version of this documentation, smatch_guide.pdf (with the same content), in the same directory. 2 | 3 | Smatch Tool Guideline 4 | 5 | Shu Cai 03/20/2013 6 | 7 | Smatch is a tool to evaluate the semantic overlap between semantic feature structures. It can be used to compute the inter-annotator agreement of AMRs, and the agreement between an automatically generated AMR and a gold AMR. For multiple AMR pairs, the smatch tool can provide a weighted, overall score for all the AMR pairs. 8 | 9 | I. Content and web demo pages 10 | 11 | The directory contains the Smatch code (mostly Python and some Perl) as well as a guide for Smatch. 12 | 13 | Smatch Webpages 14 | 15 | Smatch tool webpage: http://amr.isi.edu/eval/smatch/compare.html (A quick tutorial can be found on the page) 16 | - input: two AMRs. 17 | - output: the smatch score and the matching/unmatched triples. 18 | 19 | Smatch table tool webpage: http://amr.isi.edu/eval/smatch/table.html 20 | - input: AMR IDs and users. 21 | - output: a table which consists of the smatch scores of every pair of users. 22 | 23 | II. Installation 24 | 25 | Python (version 2.5 or later) is required to run the smatch tool. Python 2.7 is recommended. No compilation is necessary. 26 | 27 | If a user wants to run the smatch tool outside its current location, they can just copy the whole directory. Running the latest smatch tools requires the following files: amr.py (a library called by smatch.py), smatch.py, smatch-table.py. Running the old versions of smatch requires Perl installed, and 28 | esem-format-check.pl, smatch-v0.x.py (x<5), smatch-table-v0.x.py (x<3). 29 | 30 | III. Usage 31 | 32 | The smatch tool consists of two programs written in Python. 33 | 34 | 1. smatch.py: for computing the smatch score(s) for multiple AMRs created by two different groups. 35 | 36 | Input: two files which contain AMRs. Each file may contain multiple AMRs, and consecutive AMRs are separated by a blank line. AMRs can be one-per-line or span multiple lines, as long as there is no blank line inside an AMR. 37 | 38 | Input file format: see test_input1.txt, test_input2.txt in the smatch tool folder. AMRs are separated by one or more blank lines, so no blank lines are allowed inside an AMR. Lines starting with a hash (#) will be ignored.
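For illustration, a hypothetical input file in this format (not one of the shipped samples) could contain two AMRs like this:

```
# ::id sample_1 (comment lines like this are ignored)
(w / want-01
      :ARG0 (b / boy)
      :ARG1 (g / go-01
            :ARG0 b))

(s / sleep-01
      :ARG0 (d / dog))
```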
39 | 40 | Output: Smatch score(s) computed 41 | 42 | Usage: python smatch.py [-h] -f F F [-r R] [-v] [--ms] [--pr] 43 | 44 | arguments: 45 | 46 | -h: help 47 | 48 | -f: two files which contain multiple AMRs. A blank line is used to separate two AMRs. Required argument. 49 | 50 | -r: restart number of the heuristic search during computation, optional. Default value: 4. This argument must be a positive integer. A large restart number will reduce the chance of search error, but also increase the running time. A small restart number will reduce the running time as well as increase the chance of search error. The default value is by far the best trade-off. Users can set a larger number if the AMRs are long (the search space is large) and very high speed is not needed. 51 | 52 | -v: verbose output, optional. Default value: false. The verbose information includes the triples of each AMR, the matching triple number found in each iteration, and the best matching triple number. It is useful when you try to understand how the program works. Users will not need this option most of the time. 53 | 54 | --ms: multiple score, optional. Adding this option will result in a separate smatch score for each AMR pair. Otherwise the output is one single weighted score based on all pairs of AMRs. AMRs are weighted according to their number of triples. 55 | Default value: false 56 | 57 | --pr: output precision and recall as well as the f-score. Default: false 58 | 59 | A typical (and most common) example of running smatch.py: 60 | 61 | python smatch.py -f test_input1.txt test_input2.txt 62 | 63 | The release includes sample files test_input1.txt and test_input2.txt, so you should be able to run the above command as is. The above command should output approximately the following line: 64 | Document F-score: 0.81 65 | 66 | 2. smatch-table.py: it calls the smatch library to compute the smatch scores for a group of users and multiple AMR IDs, and outputs a table showing the smatch score between each pair of users. 67 | 68 | Input: AMR ID list and user list. The AMR ID list can be stored in a file (--fl file) or given on the command line (-f AMR_ID1 AMR_ID2 ...). The user list is given on the command line (-p user1 user2 ...). If no users are given, the program searches for all the users who annotated all the required AMRs. The number of users should be at least 2. 69 | 70 | Input file format: AMR ID list (see sample_file_list in the smatch tool folder) 71 | 72 | Output: A table which shows the overall smatch score between every pair of users. 73 | 74 | Usage: python smatch-table.py [-h] [--fl FL] [-f F [F ...]] [-p [P [P ...]]] 75 | [--fd FD] [-r R] [-v] 76 | 77 | optional arguments: 78 | 79 | -h, --help show this help message and exit 80 | 81 | --fl FL AMR ID list file (a file which contains one line of AMR IDs, separated by blank space) 82 | 83 | -f F [F ...] AMR IDs (at least one). If a valid AMR ID list file is given, this option will be ignored. 84 | 85 | -p [P [P ...]] User list (optional; when it is omitted, the program searches for all the users who annotated all the required AMRs). It is meaningless to give only one user, since smatch-table computes agreement between pairs of users, so the number of P should be at least 2. 86 | 87 | --fd FD AMR File directory. Default=location on isi file system
88 | 89 | -r R Restart number (Default:4), same as the -r option in smatch.py 90 | 91 | -v Verbose output (Default:False), same as the -v option in smatch.py 92 | 93 | 94 | A typical example of running smatch-table.py: 95 | 96 | python smatch-table.py --fd $amr_root_dir --fl sample_file_list -p ulf knight 97 | 98 | which will compare files 99 | $amr_root_dir/ulf/nw_wsj_0001_1.txt $amr_root_dir/knight/nw_wsj_0001_1.txt 100 | $amr_root_dir/ulf/nw_wsj_0001_2.txt $amr_root_dir/knight/nw_wsj_0001_2.txt 101 | etc. 102 | 103 | Note: smatch-table.py computes smatch scores for every pair of users, so it can be slow when the number of users is large or when the -p option is not set (in that case smatch scores are computed for all users who annotated the required AMRs). 104 | -------------------------------------------------------------------------------- /smatch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nsaphra/AMRICA/a3b85a9635b78e3c7f47d548b0b9b3ddbdffe237/smatch/__init__.py -------------------------------------------------------------------------------- /smatch/amr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | amr.py 5 | 6 | This file is from the code for smatch, available at: 7 | 8 | http://amr.isi.edu/download/smatch-v1.0.tar.gz 9 | http://amr.isi.edu/smatch-13.pdf 10 | """ 11 | 12 | import sys 13 | from collections import defaultdict 14 | 15 | 16 | class AMR(object): 17 | 18 | def __init__( 19 | self, 20 | var_list=None, 21 | var_value_list=None, 22 | link_list=None, 23 | const_link_list=None, 24 | path2label=None): 25 | """ 26 | path2label: maps 0.1.0 to the label (inst or const) of the 0-indexed child 27 | of the 1-indexed child of the 0th node (head) 28 | """ 29 | if var_list is None: 30 | self.nodes = [] # AMR variables 31 | self.root = None 32 | else: 33 | self.nodes = var_list[:] 34 | if len(var_list) != 0: 35 | self.root = var_list[0] 36 | else: 37 | self.root = None 38 | if var_value_list is None: 39 | self.var_values = [] 40 | else: 41 | self.var_values = var_value_list[:] 42 | if link_list is None: 43 | # connections between instances #adjacent list representation 44 | self.links = [] 45 | else: 46 | self.links = link_list[:] 47 | if const_link_list is None: 48 | self.const_links = [] 49 | else: 50 | self.const_links = const_link_list[:] 51 | if path2label is None: 52 | self.path2label = {} 53 | else: 54 | self.path2label = path2label 55 | 56 | def add_node(self, node_value): 57 | self.nodes.append(node_value) 58 | 59 | def rename_node(self, prefix): 60 | var_map_dict = {} 61 | for i in range(0, len(self.nodes)): 62 | var_map_dict[self.nodes[i]] = prefix + str(i) 63 | for i, v in enumerate(self.nodes): 64 | self.nodes[i] = var_map_dict[v] 65 | for i, d in enumerate(self.links): 66 | new_dict = {} 67 | for k, v in d.items(): 68 | new_dict[var_map_dict[k]] = v 69 | self.links[i] = new_dict 70 | 71 | def get_triples(self): 72 | """Get the triples in two lists: instance_triple, relation_triple""" 73 | instance_triple = [] 74 | relation_triple = [] 75 | for i in range(len(self.nodes)): 76 | instance_triple.append(("instance", self.nodes[i], self.var_values[i])) 77 | for k, v in self.links[i].items(): 78 | relation_triple.append((v, self.nodes[i], k)) 79 | for k2, v2 in self.const_links[i].items(): 80 | relation_triple.append((k2, self.nodes[i], v2)) 81 | return (instance_triple, relation_triple)
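# Triple formats: ('instance', var, concept) for nodes, (label, var, other_var)
# for variable-to-variable links, and (label, var, value) for constant attributes.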
82 | 83 | def get_triples2(self): 84 | """Get the triples in three lists: instance_triple, relation (one variable) triple, and relation (two variables) triple""" 85 | instance_triple = [] 86 | relation_triple1 = [] 87 | relation_triple2 = [] 88 | for i in range(len(self.nodes)): 89 | instance_triple.append(("instance", self.nodes[i], self.var_values[i])) 90 | for k, v in self.links[i].items(): 91 | relation_triple2.append((v, self.nodes[i], k)) 92 | for k2, v2 in self.const_links[i].items(): 93 | relation_triple1.append((k2, self.nodes[i], v2)) 94 | return (instance_triple, relation_triple1, relation_triple2) 95 | 96 | def __str__(self): 97 | """Return the AMR as a human-readable string""" 98 | lines = [] 99 | for i in range(len(self.nodes)): 100 | lines.append("Variable %d %s" % (i, self.nodes[i])) 101 | lines.append("Dependencies:") 102 | for k, v in self.links[i].items(): 103 | lines.append("Variable %s via %s" % (k, v)) 104 | for k2, v2 in self.const_links[i].items(): 105 | lines.append("Attribute: %s value %s" % (k2, v2)) 106 | return "\n".join(lines) 107 | 108 | def __repr__(self): 109 | return self.__str__() 110 | 111 | def out_amr(self): 112 | print self.__str__() 113 | 114 | @staticmethod 115 | def parse_AMR_line(line, consts_to_vars=False): 116 | # set consts_to_vars True if you want consts represented as variable nodes with instance labels 117 | # significant symbol just encountered: 1 for (, 2 for :, 3 for / 118 | state = -1 119 | stack = [] # variable stack 120 | cur_charseq = [] # current processing char sequence 121 | var_dict = {} # key: var name value: var value 122 | var_list = [] # variable name list (in order of first occurrence) 123 | # key: var name, value: list of (attribute name, other variable) 124 | var_attr_dict1 = defaultdict(list) 125 | # key: var name, value: list of (attribute name, const value) 126 | var_attr_dict2 = defaultdict(list) 127 | cur_attr_name = "" # current attribute name 128 | in_quote = False 129 | curr_path = ['0'] 130 | path2label = {} 131 | path_lookup = {} # (var, reln, const) to path key 132 | 133 | def remove_from_paths(path): 134 | """ Adjust all paths in path2label by removing the node at path 135 | (and any descendants) """ 136 | node_ind = int(path[-1]) 137 | depth = len(path) - 1 138 | prefix = '.'.join(path[:-1]) + '.'
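# Example: removing the node at path ['0', '1'] deletes '0.1' and every
# '0.1.*' key, and shifts the later siblings down one index:
# '0.2' -> '0.1', '0.2.0' -> '0.1.0'.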
139 | # remove node from path2label keys 140 | new_path2label = {} 141 | for (k, v) in path2label.items(): 142 | if k.startswith(prefix): 143 | k_arr = k.split('.') 144 | curr_ind = int(k_arr[depth]) 145 | if curr_ind == node_ind: 146 | continue # deleting node 147 | elif curr_ind > node_ind: 148 | # node index moves down by 1 since middle node removed 149 | k_arr[depth] = str(curr_ind - 1) 150 | new_path2label['.'.join(k_arr)] = v 151 | continue 152 | new_path2label[k] = v 153 | 154 | # remove node from path_lookup vals 155 | for (k, v) in path_lookup.items(): 156 | if v[:depth] == path[:depth]: 157 | curr_ind = int(v[depth]) 158 | if curr_ind == node_ind: 159 | del path_lookup[k] 160 | elif curr_ind > node_ind: 161 | v[depth] = str(curr_ind - 1) 162 | return new_path2label 163 | 164 | for i, c in enumerate(line.strip()): 165 | if c == " ": 166 | if in_quote: 167 | cur_charseq.append('_') 168 | continue 169 | if state == 2: 170 | cur_charseq.append(c) 171 | continue 172 | elif c == "\"": 173 | if in_quote: 174 | in_quote = False 175 | else: 176 | in_quote = True 177 | elif c == "(": 178 | if in_quote: 179 | continue 180 | if state == 2: 181 | if cur_attr_name != "": 182 | print >> sys.stderr, "Format error when processing ", line[0:i + 1] 183 | return None 184 | cur_attr_name = "".join(cur_charseq).strip() 185 | cur_charseq[:] = [] 186 | state = 1 187 | elif c == ":": 188 | if in_quote: 189 | continue 190 | if state == 3: # (...: 191 | var_value = "".join(cur_charseq) 192 | cur_charseq[:] = [] 193 | cur_var_name = stack[-1] 194 | var_dict[cur_var_name] = var_value 195 | path2label['.'.join(curr_path)] = var_value 196 | curr_path.append('0') 197 | elif state == 2: # : ...: 198 | temp_attr_value = "".join(cur_charseq) 199 | cur_charseq[:] = [] 200 | parts = temp_attr_value.split() 201 | if len(parts) < 2: 202 | print >> sys.stderr, "Error in processing", line[0:i + 1] 203 | return None 204 | attr_name = parts[0].strip() 205 | attr_value = parts[1].strip() 206 | if len(stack) == 0: 207 | print >> sys.stderr, "Error in processing", line[ 208 | :i], attr_name, attr_value 209 | return None 210 | # TODO should all labels in quotes be consts?
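# An attribute value that is not a known variable name is stored as a
# constant (var_attr_dict2); one that names an already-seen variable is a
# re-entrancy and is stored as a variable-to-variable link (var_attr_dict1).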
211 | if attr_value not in var_dict: 212 | var_attr_dict2[stack[-1]].append((attr_name, attr_value)) 213 | path2label['.'.join(curr_path)] = attr_value 214 | path_lookup[ 215 | (stack[-1], attr_name, attr_value)] = [i for i in curr_path] 216 | curr_path[-1] = str(int(curr_path[-1]) + 1) 217 | else: 218 | var_attr_dict1[stack[-1]].append((attr_name, attr_value)) 219 | else: 220 | curr_path[-1] = str(int(curr_path[-1]) + 1) 221 | state = 2 222 | elif c == "/": 223 | if in_quote: 224 | continue 225 | if state == 1: 226 | variable_name = "".join(cur_charseq) 227 | cur_charseq[:] = [] 228 | if variable_name in var_dict: 229 | print >> sys.stderr, "Duplicate variable ", variable_name, " in parsing AMR" 230 | return None 231 | stack.append(variable_name) 232 | var_list.append(variable_name) 233 | if cur_attr_name != "": 234 | if not cur_attr_name.endswith("-of"): 235 | var_attr_dict1[stack[-2]].append((cur_attr_name, variable_name)) 236 | else: 237 | var_attr_dict1[variable_name].append( 238 | (cur_attr_name[:-3], stack[-2])) 239 | cur_attr_name = "" 240 | else: 241 | print >> sys.stderr, "Error in parsing AMR", line[0:i + 1] 242 | return None 243 | state = 3 244 | elif c == ")": 245 | if in_quote: 246 | continue 247 | if len(stack) == 0: 248 | print >> sys.stderr, "Unmatched parathesis at position", i, "in processing", line[ 249 | 0:i + 1] 250 | return None 251 | if state == 2: 252 | temp_attr_value = "".join(cur_charseq) 253 | cur_charseq[:] = [] 254 | parts = temp_attr_value.split() 255 | if len(parts) < 2: 256 | print >> sys.stderr, "Error processing", line[ 257 | :i + 1], temp_attr_value 258 | return None 259 | attr_name = parts[0].strip() 260 | attr_value = parts[1].strip() 261 | if cur_attr_name.endswith("-of"): 262 | var_attr_dict1[variable_name].append( 263 | (cur_attr_name[:-3], stack[-2])) 264 | elif attr_value not in var_dict: 265 | var_attr_dict2[stack[-1]].append((attr_name, attr_value)) 266 | else: 267 | var_attr_dict1[stack[-1]].append((attr_name, attr_value)) 268 | path2label['.'.join(curr_path)] = attr_value 269 | path_lookup[ 270 | (stack[-1], attr_name, attr_value)] = [i for i in curr_path] 271 | curr_path.pop() 272 | elif state == 3: 273 | var_value = "".join(cur_charseq) 274 | cur_charseq[:] = [] 275 | cur_var_name = stack[-1] 276 | var_dict[cur_var_name] = var_value 277 | path2label['.'.join(curr_path)] = var_value 278 | else: 279 | curr_path.pop() 280 | stack.pop() 281 | cur_attr_name = "" 282 | state = 4 283 | else: 284 | cur_charseq.append(c) 285 | # create var_list, link_list, attribute 286 | # keep original variable name. 
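# In consts_to_vars mode, each constant becomes a fresh variable node
# _CONST_k whose instance label is the constant, so constants can be
# aligned like ordinary variables (the alignment files written and read
# by disagree.py rely on this conversion).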
287 | var_value_list = [] 288 | link_list = [] 289 | const_attr_list = [] # for monolingual mode 290 | 291 | # consts_to_vars mode variables 292 | const_cnt = 0 293 | const_var_list = [] 294 | const_var_value_list = [] 295 | const_link_list = [] 296 | 297 | for v in var_list: 298 | if v not in var_dict: 299 | print >> sys.stderr, "Error: variable value not found", v 300 | return None 301 | else: 302 | var_value_list.append(var_dict[v]) 303 | link_dict = {} 304 | const_dict = {} 305 | if v in var_attr_dict1: 306 | for v1 in var_attr_dict1[v]: 307 | link_dict[v1[1]] = v1[0] 308 | if v in var_attr_dict2: 309 | for v2 in var_attr_dict2[v]: 310 | const_lbl = v2[1] 311 | if v2[1][0] == "\"" and v2[1][-1] == "\"": 312 | const_lbl = v2[1][1:-1] 313 | elif v2[1] in var_dict: 314 | # not the first occurrence of this child var 315 | link_dict[v2[1]] = v2[0] 316 | path2label = remove_from_paths(path_lookup[(v, v2[0], v2[1])]) 317 | continue 318 | 319 | if consts_to_vars: 320 | const_var = '_CONST_%d' % const_cnt 321 | const_cnt += 1 322 | var_dict[const_var] = const_lbl 323 | const_var_list.append(const_var) 324 | const_var_value_list.append(const_lbl) 325 | const_link_list.append({}) 326 | link_dict[const_var] = v2[0] 327 | else: 328 | const_dict[v2[0]] = const_lbl 329 | 330 | link_list.append(link_dict) 331 | if not consts_to_vars: 332 | const_attr_list.append(const_dict) 333 | link_list[0][var_list[0]] = "TOP" 334 | if consts_to_vars: 335 | var_list += const_var_list 336 | var_value_list += const_var_value_list 337 | link_list += const_link_list 338 | const_attr_list = [{} for v in var_list] 339 | result_amr = AMR( 340 | var_list, 341 | var_value_list, 342 | link_list, 343 | const_attr_list, 344 | path2label) 345 | return result_amr 346 | -------------------------------------------------------------------------------- /smatch/smatch-table.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | smatch-table.py 5 | 6 | This file is from the code for smatch, available at: 7 | 8 | http://amr.isi.edu/download/smatch-v1.0.tar.gz 9 | http://amr.isi.edu/smatch-13.pdf 10 | """ 11 | 12 | import amr 13 | import sys 14 | import subprocess 15 | import smatch 16 | import os 17 | import random 18 | import time 19 | #import optparse 20 | # import argparse #argparse only works for python 2.7. If you are using older versin of Python, you can use optparse instead. 21 | #import locale 22 | 23 | ERROR_LOG = sys.stderr 24 | 25 | verbose = False 26 | 27 | isi_dir_pre = "/nfs/web/isi.edu/cgi-bin/div3/mt/save-amr" 28 | 29 | """ 30 | Get the annotator name list based on a list of files 31 | Args: 32 | file_dir: AMR file folder 33 | files: a list of AMR names, e.g. 
nw_wsj_0001_1 34 | Return: 35 | a list of user names who annotate all the files 36 | """ 37 | 38 | 39 | def get_names(file_dir, files): 40 | # for each user, check if they have files available 41 | # return user name list 42 | total_list = [] 43 | name_list = [] 44 | get_sub = False 45 | for path, subdir, dir_files in os.walk(file_dir): 46 | # print path 47 | if not get_sub: 48 | total_list = subdir[:] 49 | get_sub = True 50 | else: 51 | break 52 | for user in total_list: 53 | # print user 54 | has_file = True 55 | for file in files: 56 | # print file 57 | file_path = file_dir + user + "/" + file + ".txt" 58 | # print file_path 59 | if not os.path.exists(file_path): 60 | has_file = False 61 | break 62 | if has_file: 63 | name_list.append(user) 64 | # print name_list 65 | if len(name_list) == 0: 66 | print >> ERROR_LOG, "********Error: Cannot find any user who completes the files*************" 67 | return name_list 68 | """ 69 | Compute the smatch scores for a file list between two users 70 | Args: 71 | user1: user 1 name 72 | user2: user 2 name 73 | file_list: file list 74 | dir_pre: the file location prefix 75 | start_num: the number of restarts in smatch 76 | Returns: 77 | smatch f score. 78 | """ 79 | 80 | 81 | def compute_files(user1, user2, file_list, dir_pre, start_num): 82 | # print file_list 83 | # print user1, user2 84 | match_total = 0 85 | test_total = 0 86 | gold_total = 0 87 | for fi in file_list: 88 | file1 = dir_pre + user1 + "/" + fi + ".txt" 89 | file2 = dir_pre + user2 + "/" + fi + ".txt" 90 | # print file1,file2 91 | if not os.path.exists(file1): 92 | print >> ERROR_LOG, "*********Error: ", file1, "does not exist*********" 93 | return -1.00 94 | if not os.path.exists(file2): 95 | print >> ERROR_LOG, "*********Error: ", file2, "does not exist*********" 96 | return -1.00 97 | try: 98 | file1_h = open(file1, "r") 99 | file2_h = open(file2, "r") 100 | except: 101 | print >> ERROR_LOG, "Cannot open the files", file1, file2 102 | cur_amr1 = smatch.get_amr_line(file1_h) 103 | cur_amr2 = smatch.get_amr_line(file2_h) 104 | if(cur_amr1 == ""): 105 | print >> ERROR_LOG, "AMR 1 is empty" 106 | continue 107 | if(cur_amr2 == ""): 108 | print >> ERROR_LOG, "AMR 2 is empty" 109 | continue 110 | amr1 = amr.AMR.parse_AMR_line(cur_amr1) 111 | amr2 = amr.AMR.parse_AMR_line(cur_amr2) 112 | test_label = "a" 113 | gold_label = "b" 114 | amr1.rename_node(test_label) 115 | amr2.rename_node(gold_label) 116 | (test_inst, test_rel1, test_rel2) = amr1.get_triples2() 117 | (gold_inst, gold_rel1, gold_rel2) = amr2.get_triples2() 118 | if verbose: 119 | print >> ERROR_LOG, "Instance triples of file 1:", len(test_inst) 120 | print >> ERROR_LOG, test_inst 121 | print >> sys.stderr, "Relation triples of file 1:", len( 122 | test_rel1) + len(test_rel2) 123 | print >>sys.stderr, test_rel1 124 | print >> sys.stderr, test_rel2 125 | print >> ERROR_LOG, "Instance triples of file 2:", len(gold_inst) 126 | print >> ERROR_LOG, gold_inst 127 | print >> sys.stderr, "Relation triples of file 2:", len( 128 | gold_rel1) + len(gold_rel2) 129 | print >> sys.stderr, gold_rel1 130 | print >> sys.stderr, gold_rel2 131 | if len(test_inst) < len(gold_inst): 132 | (best_match, 133 | best_match_num) = smatch.get_fh(test_inst, 134 | test_rel1, 135 | test_rel2, 136 | gold_inst, 137 | gold_rel1, 138 | gold_rel2, 139 | test_label, 140 | gold_label) 141 | if verbose: 142 | print >> ERROR_LOG, "best match number", best_match_num 143 | print >>ERROR_LOG, "Best Match:", smatch.print_alignment( 144 | best_match, test_inst, gold_inst) 
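# The preceding branch handles the case where the test AMR has fewer
# variables; in the else branch below the gold side is the smaller one, so
# the arguments to smatch.get_fh are swapped and print_alignment is called
# with flip=True to keep the printed order consistent.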
145 | else: 146 | (best_match, 147 | best_match_num) = smatch.get_fh(gold_inst, 148 | gold_rel1, 149 | gold_rel2, 150 | test_inst, 151 | test_rel1, 152 | test_rel2, 153 | gold_label, 154 | test_label) 155 | if verbose: 156 | print >> ERROR_LOG, "best match number", best_match_num 157 | print >>ERROR_LOG, "Best Match:", smatch.print_alignment( 158 | best_match, gold_inst, test_inst, True) 159 | #(match_num,test_num,gold_num)=smatch.get_match(tmp_filename1,tmp_filename2,start_num) 160 | # print match_num,test_num,gold_num 161 | # print best_match_num 162 | # print len(test_inst)+len(test_rel1)+len(test_rel2) 163 | # print len(gold_inst)+len(gold_rel1)+len(gold_rel2) 164 | match_total += best_match_num 165 | test_total += len(test_inst) + len(test_rel1) + len(test_rel2) 166 | gold_total += len(gold_inst) + len(gold_rel1) + len(gold_rel2) 167 | smatch.match_num_dict.clear() 168 | (precision, recall, f_score) = smatch.compute_f( 169 | match_total, test_total, gold_total) 170 | return "%.2f" % f_score 171 | 172 | 173 | def get_max_width(table, index): 174 | return max([len(str(row[index])) for row in table]) 175 | """ 176 | Print a table 177 | """ 178 | 179 | 180 | def pprint_table(table): 181 | col_paddings = [] 182 | for i in range(len(table[0])): 183 | col_paddings.append(get_max_width(table, i)) 184 | for row in table: 185 | print row[0].ljust(col_paddings[0] + 1), 186 | for i in range(1, len(row)): 187 | col = str(row[i]).rjust(col_paddings[i] + 2) 188 | print col, 189 | print "\n" 190 | 191 | 192 | def print_help(): 193 | print "Smatch Calculator Program Help" 194 | print "This program prints the smatch score of the two files" 195 | print "Command line arguments:" 196 | print "-h: Show help (Other options won't work if you use -h)" 197 | print "smatch-table.py -h" 198 | print "Usage: smatch-table.py file_list (-f list_file) [ -p user_list ] [-r number of starts]" 199 | print "File list is AMR file ids separated by a blank space" 200 | print "Example: smatch-table.py nw_wsj_0001_1 nw_wsj_0001_2" 201 | print "Or use -f list_file to indicate a file which contains one line of file names, separated by a blank space" 202 | print "Example: smatch.py -f file" 203 | print "-p: (Optional) user list to list the user name in the command line, after the file list. Otherwise the program automatically searches for the users who completes all AMRs you want." 204 | print "Example: smatch.py -f file -p user1 user2" 205 | print "Example: smatch.py nw_wsj_0001_1 nw_wsj_0001_2 -p user1 user2" 206 | print "-r: (Optional) the number of random starts(higher number may results in higher accuracy and slower speed (default number of starts: 10)" 207 | print "Example: smatch.py -f file -p user1 user2 -r 20" 208 | # print "-d: detailed output, including alignment and triples of the two files" 209 | # print "Example (if you want to use all options): smatch.py file1 file2 210 | # -d -r 20" 211 | print "Contact shucai@isi.edu for additional help" 212 | 213 | 214 | def build_arg_parser(): 215 | """Build an argument parser using argparse""" 216 | parser = argparse.ArgumentParser( 217 | description="Smatch table calculator -- arguments") 218 | parser.add_argument( 219 | "--fl", 220 | type=argparse.FileType('r'), 221 | help='AMR ID list file') 222 | parser.add_argument('-f', nargs='+', help='AMR IDs (at least one)') 223 | parser.add_argument("-p", nargs='*', help="User list (can be none)") 224 | parser.add_argument( 225 | "--fd", 226 | default=isi_dir_pre, 227 | help="AMR File directory. 
Default=location on isi machine") 228 | #parser.add_argument("--cd",default=os.getcwd(),help="(Dependent) code directory. Default: current directory") 229 | parser.add_argument( 230 | '-r', 231 | type=int, 232 | default=4, 233 | help='Restart number (Default:4)') 234 | parser.add_argument( 235 | '-v', 236 | action='store_true', 237 | help='Verbose output (Default:False)') 238 | return parser 239 | """ 240 | Callback function to handle variable number of arguments in optparse 241 | """ 242 | 243 | 244 | def cb(option, opt_str, value, parser): 245 | args = [] 246 | args.append(value) 247 | for arg in parser.rargs: 248 | if arg[0] != "-": 249 | args.append(arg) 250 | else: 251 | del parser.rargs[:len(args)] 252 | break 253 | if getattr(parser.values, option.dest): 254 | args.extend(getattr(parser.values, option.dest)) 255 | setattr(parser.values, option.dest, args) 256 | 257 | 258 | def build_arg_parser2(): 259 | """Build an argument parser using optparse""" 260 | usage_str = "Smatch table calculator -- arguments" 261 | parser = optparse.OptionParser(usage=usage_str) 262 | parser.add_option("--fl", dest="fl", type="string", help='AMR ID list file') 263 | parser.add_option( 264 | "-f", 265 | dest="f", 266 | type="string", 267 | action="callback", 268 | callback=cb, 269 | help="AMR IDs (at least one)") 270 | parser.add_option( 271 | "-p", 272 | dest="p", 273 | type="string", 274 | action="callback", 275 | callback=cb, 276 | help="User list") 277 | parser.add_option("--fd", dest="fd", type="string", help="file directory") 278 | #parser.add_option("--cd",dest="cd",type="string",help="code directory") 279 | parser.add_option( 280 | "-r", 281 | "--restart", 282 | dest="r", 283 | type="int", 284 | help='Restart number (Default: 4)') 285 | parser.add_option( 286 | "-v", 287 | "--verbose", 288 | action='store_true', 289 | dest="v", 290 | help='Verbose output (Default:False)') 291 | parser.set_defaults(r=4, v=False, ms=False, fd=isi_dir_pre) 292 | return parser 293 | 294 | 295 | def check_args(args): 296 | """Check if the arguments are valid""" 297 | if not os.path.exists(args.fd): 298 | print >> ERROR_LOG, "Not a valid path", args.fd 299 | return ([], [], False) 300 | # if not os.path.exists(args.cd): 301 | # print >> ERROR_LOG,"Not a valid path", args.cd 302 | # return ([],[],False) 303 | amr_ids = [] 304 | if args.fl is not None: 305 | # we already ensure the file can be opened and opened the file 306 | file_line = args.fl.readline() 307 | amr_ids = file_line.strip().split() 308 | elif args.f is None: 309 | print >> ERROR_LOG, "No AMR ID was given" 310 | return ([], [], False) 311 | else: 312 | amr_ids = args.f 313 | names = [] 314 | check_name = True 315 | if args.p is None: 316 | names = get_names(args.fd, amr_ids) 317 | check_name = False # no need to check names 318 | if len(names) == 0: 319 | print >> ERROR_LOG, "Cannot find any user who tagged these AMR" 320 | return ([], [], False) 321 | else: 322 | names = args.p 323 | if names == []: 324 | print >> ERROR_LOG, "No user was given" 325 | return ([], [], False) 326 | if len(names) == 1: 327 | print >> ERROR_LOG, "Only one user is given. Smatch calculation requires at least two users." 
328 | return ([], [], False) 329 | if "consensus" in names: 330 | con_index = names.index("consensus") 331 | names.pop(con_index) 332 | names.append("consensus") 333 | # check if all the AMR_id and user combinations are valid 334 | if check_name: 335 | pop_name = [] 336 | for i, name in enumerate(names): 337 | for amr in amr_ids: 338 | amr_path = args.fd + name + "/" + amr + ".txt" 339 | if not os.path.exists(amr_path): 340 | print >> ERROR_LOG, "User", name, "fails to tag AMR", amr 341 | pop_name.append(i) 342 | break 343 | if len(pop_name) != 0: 344 | pop_num = 0 345 | for p in pop_name: 346 | print >> ERROR_LOG, "Deleting user", names[ 347 | p - pop_num], "from the name list" 348 | names.pop(p - pop_num) 349 | pop_num += 1 350 | if len(names) < 2: 351 | print >> ERROR_LOG, "Not enough users to evaluate. Smatch requires at least 2 users who tag all the AMRs" 352 | return ("", "", False) 353 | return (amr_ids, names, True) 354 | 355 | 356 | def main(args): 357 | """Main Function""" 358 | global verbose 359 | (ids, names, result) = check_args(args) 360 | if args.v: 361 | verbose = True 362 | if not result: 363 | return 0 364 | acc_time = 0 365 | len_name = len(names) 366 | table = [] 367 | for i in range(0, len_name + 1): 368 | table.append([]) 369 | table[0].append("") 370 | for i in range(0, len_name): 371 | table[0].append(names[i]) 372 | for i in range(0, len_name): 373 | table[i + 1].append(names[i]) 374 | for j in range(0, len_name): 375 | if i != j: 376 | start = time.clock() 377 | table[ 378 | i + 379 | 1].append( 380 | compute_files( 381 | names[i], 382 | names[j], 383 | ids, 384 | args.fd, 385 | args.r)) 386 | end = time.clock() 387 | if table[i + 1][-1] != -1.0: 388 | acc_time += end - start 389 | # if table[i+1][-1]==-1.0: sys.exit(1) 390 | else: 391 | table[i + 1].append("") 392 | # check table 393 | for i in range(0, len_name + 1): 394 | for j in range(0, len_name + 1): 395 | if i != j: 396 | if table[i][j] != table[j][i]: 397 | if table[i][j] > table[j][i]: 398 | table[j][i] = table[i][j] 399 | else: 400 | table[i][j] = table[j][i] 401 | pprint_table(table) 402 | return acc_time 403 | 404 | 405 | if __name__ == "__main__": 406 | # acc_time=0 #accumulated time 407 | whole_start = time.clock() 408 | parser = None 409 | args = None 410 | if sys.version_info[:2] != (2, 7): 411 | # requires version >=2.3! 412 | if sys.version_info[0] != 2 or sys.version_info[1] < 5: 413 | print >> ERROR_LOG, "This program requires python 2.5 or later to run. 
" 414 | exit(1) 415 | import optparse 416 | parser = build_arg_parser2() 417 | (args, opts) = parser.parse_args() 418 | file_handle = None 419 | if args.fl is not None: 420 | try: 421 | file_handle = open(args.fl, "r") 422 | args.fl = file_handle 423 | except: 424 | print >> ERROR_LOG, "The ID list file", args.fl, "does not exist" 425 | args.fl = None 426 | # print args 427 | else: # version 2.7 428 | import argparse 429 | parser = build_arg_parser() 430 | args = parser.parse_args() 431 | # Regularize fd and cd representation 432 | if args.fd[-1] != "/": 433 | args.fd = args.fd + "/" 434 | # if args.cd[-1]!="/": 435 | # args.cd=args.cd+"/" 436 | acc_time = main(args) 437 | whole_end = time.clock() 438 | whole_time = whole_end - whole_start 439 | # print >> ERROR_LOG, "Accumulated time", acc_time 440 | # print >> ERROR_LOG, "whole time", whole_time 441 | # print >> ERROR_LOG, "Percentage", float(acc_time)/float(whole_time) 442 | -------------------------------------------------------------------------------- /smatch/smatch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | """ 4 | smatch.py 5 | 6 | Author: Shu Cai 7 | Copyright(c) 2012. All rights reserved. 8 | 9 | This file is from the code for smatch, available at: 10 | 11 | http://amr.isi.edu/download/smatch-v1.0.tar.gz 12 | http://amr.isi.edu/smatch-13.pdf 13 | """ 14 | import codecs 15 | import sys 16 | import os 17 | import time 18 | import random 19 | import amr 20 | #import optparse 21 | # import argparse #argparse only works for python 2.7. If you are using 22 | # older versin of Python, you can use optparse instead. 23 | 24 | verbose = False # global variable, verbose output control 25 | 26 | single_score = True # global variable, single score output control 27 | 28 | pr_flag = False # global variable, output precision and recall 29 | 30 | ERROR_LOG = sys.stderr 31 | 32 | match_num_dict = {} # key: match number tuples value: the matching number 33 | 34 | 35 | def get_amr_line(input_f): 36 | """Read the amr file. AMRs are separated by a blank line.""" 37 | cur_amr = [] 38 | has_content = False 39 | for line in input_f: 40 | if line[0] == "(" and len(cur_amr) != 0: 41 | cur_amr = [] 42 | if line.strip() == "": 43 | if not has_content: 44 | continue 45 | else: 46 | break 47 | elif line.strip().startswith("#"): 48 | # omit the comment in the AMR file 49 | continue 50 | else: 51 | has_content = True 52 | cur_amr.append(line.strip()) 53 | return "".join(cur_amr) 54 | 55 | 56 | def build_arg_parser(): 57 | """Build an argument parser using argparse""" 58 | parser = argparse.ArgumentParser( 59 | description="Smatch calculator -- arguments") 60 | parser.add_argument( 61 | '-f', 62 | nargs=2, 63 | required=True, 64 | type=argparse.FileType('r'), 65 | help='Two files containing AMR pairs. AMRs in each file are separated by a single blank line') 66 | parser.add_argument( 67 | '-r', 68 | type=int, 69 | default=4, 70 | help='Restart number (Default:4)') 71 | parser.add_argument( 72 | '-v', 73 | action='store_true', 74 | help='Verbose output (Default:False)') 75 | parser.add_argument( 76 | '--ms', 77 | action='store_true', 78 | default=False, 79 | help='Output multiple scores (one AMR pair a score) instead of a single document-level smatch score (Default: False)') 80 | parser.add_argument( 81 | '--pr', 82 | action='store_true', 83 | default=False, 84 | help="Output precision and recall as well as the f-score. 
Default: false") 85 | return parser 86 | 87 | 88 | def build_arg_parser2(): 89 | """Build an argument parser using optparse""" 90 | usage_str = "Smatch calculator -- arguments" 91 | parser = optparse.OptionParser(usage=usage_str) 92 | #parser.add_option("-h","--help",action="help",help="Smatch calculator -- arguments") 93 | parser.add_option( 94 | "-f", 95 | "--files", 96 | nargs=2, 97 | dest="f", 98 | type="string", 99 | help='Two files containing AMR pairs. AMRs in each file are separated by a single blank line. This option is required.') 100 | parser.add_option( 101 | "-r", 102 | "--restart", 103 | dest="r", 104 | type="int", 105 | help='Restart number (Default: 4)') 106 | parser.add_option( 107 | "-v", 108 | "--verbose", 109 | action='store_true', 110 | dest="v", 111 | help='Verbose output (Default:False)') 112 | parser.add_option( 113 | "--ms", 114 | "--multiple_score", 115 | action='store_true', 116 | dest="ms", 117 | help='Output multiple scores (one AMR pair a score) instead of a single document-level smatch score (Default: False)') 118 | parser.add_option( 119 | '--pr', 120 | "--precision_recall", 121 | action='store_true', 122 | dest="pr", 123 | help="Output precision and recall as well as the f-score. Default: false") 124 | parser.set_defaults(r=4, v=False, ms=False, pr=False) 125 | return parser 126 | 127 | 128 | def dflt_label_weighter(test_label, gold_label): 129 | """ 130 | Return score corresponding to the weight to add for matching 131 | test_label and gold_label in the default Smatch setting. 132 | """ 133 | if test_label.lower() == gold_label.lower(): 134 | return 1.0 135 | else: 136 | return 0.0 137 | 138 | 139 | def compute_pool(test_instance, test_relation1, test_relation2, 140 | gold_instance, gold_relation1, gold_relation2, 141 | test_label, gold_label, 142 | node_weight_fn, edge_weight_fn): 143 | """ 144 | compute the possible variable matching candidate (the match which may result in 1) 145 | Args: 146 | test_instance: instance triples in AMR 1 147 | test_relation1: relation triples which contain only one variable in AMR 1 148 | test_relation2: relation triples which contain two variables in AMR 1 149 | gold_instance: instance triples in AMR 2 150 | gold_relation1: relation triples which contain only one variable in AMR 2 151 | gold_relations: relation triples which contain two variables in AMR 2 152 | test_label: the prefix of the variable in AMR 1, e.g. a (variable a1, a2, a3...) 153 | gold_label: the prefix of the variable in AMR 2, e.g. b (variable b1, b2, b3...) 154 | Returns: 155 | candidate_match: a list of candidate mapping variables. Each entry contains a set of the variables the variable can map to. 156 | weight_dict: a dictionary which contains the matching triple number of every pair of variable mapping. 
""" 157 | len_test_inst = len(test_instance) 158 | len_gold_inst = len(gold_instance) 159 | len_test_rel1 = len(test_relation1) 160 | len_gold_rel1 = len(gold_relation1) 161 | len_test_rel2 = len(test_relation2) 162 | len_gold_rel2 = len(gold_relation2) 163 | candidate_match = [] 164 | weight_dict = {} 165 | for i in range(0, len_test_inst): 166 | candidate_match.append(set()) 167 | for i in range(0, len_test_inst): 168 | for j in range(0, len_gold_inst): 169 | if test_instance[i][0].lower() == gold_instance[j][0].lower(): 170 | w = node_weight_fn(test_instance[i][2], gold_instance[j][2]) 171 | var1_num = int(test_instance[i][1][len(test_label):]) 172 | var2_num = int(gold_instance[j][1][len(gold_label):]) 173 | candidate_match[var1_num].add(var2_num) 174 | cur_k = (var1_num, var2_num) 175 | if cur_k in weight_dict: 176 | weight_dict[cur_k][-1] += w 177 | else: 178 | weight_dict[cur_k] = {} 179 | weight_dict[cur_k][-1] = w 180 | for i in range(0, len_test_rel1): 181 | for j in range(0, len_gold_rel1): 182 | if test_relation1[i][0].lower() == gold_relation1[j][0].lower(): 183 | w = node_weight_fn(test_relation1[i][2], gold_relation1[j][2]) 184 | var1_num = int(test_relation1[i][1][len(test_label):]) 185 | var2_num = int(gold_relation1[j][1][len(gold_label):]) 186 | candidate_match[var1_num].add(var2_num) 187 | cur_k = (var1_num, var2_num) 188 | if cur_k in weight_dict: 189 | weight_dict[cur_k][-1] += w 190 | else: 191 | weight_dict[cur_k] = {} 192 | weight_dict[cur_k][-1] = w 193 | 194 | for i in range(0, len_test_rel2): 195 | for j in range(0, len_gold_rel2): 196 | w = edge_weight_fn(test_relation2[i][0], gold_relation2[j][0]) 197 | if w > 0: 198 | var1_num_test = int(test_relation2[i][1][len(test_label):]) 199 | var1_num_gold = int(gold_relation2[j][1][len(gold_label):]) 200 | var2_num_test = int(test_relation2[i][2][len(test_label):]) 201 | var2_num_gold = int(gold_relation2[j][2][len(gold_label):]) 202 | candidate_match[var1_num_test].add(var1_num_gold) 203 | candidate_match[var2_num_test].add(var2_num_gold) 204 | cur_k1 = (var1_num_test, var1_num_gold) 205 | cur_k2 = (var2_num_test, var2_num_gold) 206 | if cur_k2 != cur_k1: 207 | if cur_k1 in weight_dict: 208 | if cur_k2 in weight_dict[cur_k1]: 209 | weight_dict[cur_k1][cur_k2] += w 210 | else: 211 | weight_dict[cur_k1][cur_k2] = w 212 | else: 213 | weight_dict[cur_k1] = {} 214 | weight_dict[cur_k1][-1] = 0 215 | weight_dict[cur_k1][cur_k2] = w 216 | if cur_k2 in weight_dict: 217 | if cur_k1 in weight_dict[cur_k2]: 218 | weight_dict[cur_k2][cur_k1] += w 219 | else: 220 | weight_dict[cur_k2][cur_k1] = w 221 | else: 222 | weight_dict[cur_k2] = {} 223 | weight_dict[cur_k2][-1] = 0 224 | weight_dict[cur_k2][cur_k1] = w 225 | else: 226 | # cycle 227 | if cur_k1 in weight_dict: 228 | weight_dict[cur_k1][-1] += w 229 | else: 230 | weight_dict[cur_k1] = {} 231 | weight_dict[cur_k1][-1] = w 232 | return (candidate_match, weight_dict) 233 | 234 | 235 | def init_match(candidate_match, test_instance, gold_instance, node_weight_fn): 236 | """Initialize match based on the word match 237 | Args: 238 | candidate_match: candidate variable match list 239 | test_instance: test instance 240 | gold_instance: gold instance 241 | Returns: 242 | intialized match result""" 243 | random.seed() 244 | matched_dict = {} 245 | 246 | num_test_matched = 0 247 | matches_by_weight = [] 248 | result = [-1 for c in candidate_match] 249 | for i, c in enumerate(candidate_match): 250 | c2 = list(c) 251 | if len(c2) == 0: 252 | num_test_matched += 1 253 | continue 254 | # 
word in the test instance 255 | test_word = test_instance[i][2] 256 | for j, m_id in enumerate(c2): 257 | gold_word = gold_instance[int(m_id)][2] 258 | curr_score = node_weight_fn(gold_word, test_word) 259 | matches_by_weight.append((int(m_id), i, curr_score)) 260 | 261 | matches_by_weight = sorted(matches_by_weight, key=lambda (x1,x2,x3) : x3, reverse=True) 262 | for (gold, test, score) in matches_by_weight: 263 | if len(matched_dict) == len(gold_instance) \ 264 | or num_test_matched == len(test_instance): 265 | break 266 | if result[test] != -1 or gold in matched_dict: 267 | continue 268 | result[test] = gold 269 | matched_dict[gold] = 1 270 | num_test_matched += 1 271 | 272 | for (i, m) in enumerate(result): 273 | if m != -1: 274 | continue 275 | c2 = list(candidate_match[i]) 276 | found = False 277 | while len(c2) != 1: 278 | rid = random.randint(0, len(c2) - 1) 279 | if c2[rid] in matched_dict: 280 | c2.pop(rid) 281 | else: 282 | matched_dict[c2[rid]] = 1 283 | result[i] = c2[rid] 284 | found = True 285 | break 286 | if not found: 287 | if c2[0] not in matched_dict: 288 | result[i] = c2[0] 289 | matched_dict[c2[0]] = 1 290 | return result 291 | 292 | 293 | def get_random_sol(candidate): 294 | """ 295 | Generate a random variable mapping. 296 | Args: 297 | candidate:a list of set and each set contains the candidate match of a test instance 298 | """ 299 | random.seed() 300 | matched_dict = {} 301 | result = [] 302 | for c in candidate: 303 | c2 = list(c) 304 | found = False 305 | if len(c2) == 0: 306 | result.append(-1) 307 | continue 308 | while len(c2) != 1: 309 | rid = random.randint(0, len(c2) - 1) 310 | if c2[rid] in matched_dict: 311 | c2.pop(rid) 312 | else: 313 | matched_dict[c2[rid]] = 1 314 | result.append(c2[rid]) 315 | found = True 316 | break 317 | if not found: 318 | if c2[0] not in matched_dict: 319 | result.append(c2[0]) 320 | matched_dict[c2[0]] = 1 321 | else: 322 | result.append(-1) 323 | return result 324 | 325 | 326 | def compute_match(match, weight_dict): 327 | """Given a variable match, compute match number based on weight_dict. 
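A mapping is scored by summing, for each mapped pair, its unconditional weight (key -1) plus any pairwise weights whose partner pair is also in the mapping (each pair counted once via the k[0] < i test below); results are memoized in the module-global match_num_dict, keyed by the mapping tuple.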
328 | Args: 329 | match: a list of number in gold set, len(match)= number of test instance 330 | Returns: 331 | matching triple number 332 | Complexity: O(m*n) , m is the length of test instance, n is the length of gold instance""" 333 | # remember matching number of the previous matching we investigated 334 | if tuple(match) in match_num_dict: 335 | return match_num_dict[tuple(match)] 336 | match_num = 0 337 | for i, m in enumerate(match): 338 | if m == -1: 339 | continue 340 | cur_m = (i, m) 341 | if cur_m not in weight_dict: 342 | continue 343 | match_num += weight_dict[cur_m][-1] 344 | for k in weight_dict[cur_m]: 345 | if k == -1: 346 | continue 347 | if k[0] < i: 348 | continue 349 | elif match[k[0]] == k[1]: 350 | match_num += weight_dict[cur_m][k] 351 | match_num_dict[tuple(match)] = match_num 352 | return match_num 353 | 354 | 355 | def move_gain(match, i, m, nm, weight_dict, match_num): 356 | """Compute the triple match number gain by the move operation 357 | Args: 358 | match: current match list 359 | i: the remapped source variable 360 | m: the original id 361 | nm: new mapped id 362 | weight_dict: weight dictionary 363 | match_num: the original matching number 364 | Returns: 365 | the gain number (might be negative)""" 366 | cur_m = (i, nm) 367 | old_m = (i, m) 368 | new_match = match[:] 369 | new_match[i] = nm 370 | if tuple(new_match) in match_num_dict: 371 | return match_num_dict[tuple(new_match)] - match_num 372 | gain = 0 373 | if cur_m in weight_dict: 374 | gain += weight_dict[cur_m][-1] 375 | for k in weight_dict[cur_m]: 376 | if k == -1: 377 | continue 378 | elif match[k[0]] == k[1]: 379 | gain += weight_dict[cur_m][k] 380 | if old_m in weight_dict: 381 | gain -= weight_dict[old_m][-1] 382 | for k in weight_dict[old_m]: 383 | if k == -1: 384 | continue 385 | elif match[k[0]] == k[1]: 386 | gain -= weight_dict[old_m][k] 387 | match_num_dict[tuple(new_match)] = match_num + gain 388 | return gain 389 | 390 | 391 | def swap_gain(match, i, m, j, m2, weight_dict, match_num): 392 | """Compute the triple match number gain by the swap operation 393 | Args: 394 | match: current match list 395 | i: the position 1 396 | m: the original mapped variable of i 397 | j: the position 2 398 | m2: the original mapped variable of j 399 | weight_dict: weight dictionary 400 | match_num: the original matching number 401 | Returns: 402 | the gain number (might be negative)""" 403 | new_match = match[:] 404 | new_match[i] = m2 405 | new_match[j] = m 406 | gain = 0 407 | cur_m = (i, m2) 408 | cur_m2 = (j, m) 409 | old_m = (i, m) 410 | old_m2 = (j, m2) 411 | if cur_m in weight_dict: 412 | gain += weight_dict[cur_m][-1] 413 | if cur_m2 in weight_dict[cur_m]: 414 | gain += weight_dict[cur_m][cur_m2] 415 | for k in weight_dict[cur_m]: 416 | if k == -1: 417 | continue 418 | elif k[0] == j: 419 | continue 420 | elif match[k[0]] == k[1]: 421 | gain += weight_dict[cur_m][k] 422 | if cur_m2 in weight_dict: 423 | gain += weight_dict[cur_m2][-1] 424 | for k in weight_dict[cur_m2]: 425 | if k == -1: 426 | continue 427 | elif k[0] == i: 428 | continue 429 | elif match[k[0]] == k[1]: 430 | gain += weight_dict[cur_m2][k] 431 | if old_m in weight_dict: 432 | gain -= weight_dict[old_m][-1] 433 | if old_m2 in weight_dict[old_m]: 434 | gain -= weight_dict[old_m][old_m2] 435 | for k in weight_dict[old_m]: 436 | if k == -1: 437 | continue 438 | elif k[0] == j: 439 | continue 440 | elif match[k[0]] == k[1]: 441 | gain -= weight_dict[old_m][k] 442 | if old_m2 in weight_dict: 443 | gain -= weight_dict[old_m2][-1] 444 | 
for k in weight_dict[old_m2]: 445 | if k == -1: 446 | continue 447 | elif k[0] == i: 448 | continue 449 | elif match[k[0]] == k[1]: 450 | gain -= weight_dict[old_m2][k] 451 | match_num_dict[tuple(new_match)] = match_num + gain 452 | return gain 453 | 454 | 455 | def get_best_gain( 456 | match, 457 | candidate_match, 458 | weight_dict, 459 | gold_len, 460 | start_match_num): 461 | """ hill-climbing method to return the best gain swap/move can get 462 | Args: 463 | match: the initial variable mapping 464 | candidate_match: the match candidates list 465 | weight_dict: the weight dictionary 466 | gold_len: the number of the variables in file 2 467 | start_match_num: the initial match number 468 | Returns: 469 | the best gain we can get via swap/move operation""" 470 | largest_gain = 0 471 | largest_match_num = 0 472 | swap = True # True: using swap False: using move 473 | change_list = [] 474 | # unmatched gold number 475 | unmatched_gold = set(range(0, gold_len)) 476 | # O(gold_len) 477 | for m in match: 478 | if m in unmatched_gold: 479 | unmatched_gold.remove(m) 480 | unmatch_list = list(unmatched_gold) 481 | for i, m in enumerate(match): 482 | # remap i 483 | for nm in unmatch_list: 484 | if nm in candidate_match[i]: 485 | #(i,m) -> (i,nm) 486 | gain = move_gain(match, i, m, nm, weight_dict, start_match_num) 487 | if verbose: 488 | new_match = match[:] 489 | new_match[i] = nm 490 | new_m_num = compute_match(new_match, weight_dict) 491 | if new_m_num != start_match_num + gain: 492 | print >> sys.stderr, match, new_match 493 | print >> sys.stderr, "Inconsistency in computing: move gain", start_match_num, gain, new_m_num 494 | if gain > largest_gain: 495 | largest_gain = gain 496 | change_list = [i, nm] 497 | swap = False 498 | largest_match_num = start_match_num + gain 499 | for i, m in enumerate(match): 500 | for j, m2 in enumerate(match): 501 | # swap i 502 | if i == j: 503 | continue 504 | new_match = match[:] 505 | new_match[i] = m2 506 | new_match[j] = m 507 | sw_gain = swap_gain(match, i, m, j, m2, weight_dict, start_match_num) 508 | if verbose: 509 | new_match = match[:] 510 | new_match[i] = m2 511 | new_match[j] = m 512 | new_m_num = compute_match(new_match, weight_dict) 513 | if new_m_num != start_match_num + sw_gain: 514 | print >> sys.stderr, match, new_match 515 | print >> sys.stderr, "Inconsistency in computing: swap gain", start_match_num, sw_gain, new_m_num 516 | if sw_gain > largest_gain: 517 | largest_gain = sw_gain 518 | change_list = [i, j] 519 | swap = True 520 | cur_match = match[:] 521 | largest_match_num = start_match_num + largest_gain 522 | if change_list != []: 523 | if swap: 524 | temp = cur_match[change_list[0]] 525 | cur_match[change_list[0]] = cur_match[change_list[1]] 526 | cur_match[change_list[1]] = temp 527 | # print >> sys.stderr,"swap gain" 528 | else: 529 | cur_match[change_list[0]] = change_list[1] 530 | # print >> sys.stderr,"move gain" 531 | return (largest_match_num, cur_match) 532 | 533 | 534 | def get_fh(test_instance, test_relation1, test_relation2, 535 | gold_instance, gold_relation1, gold_relation2, 536 | test_label, gold_label, 537 | node_weight_fn=dflt_label_weighter, edge_weight_fn=dflt_label_weighter, 538 | iter_num=5): 539 | """Get the f-score given two sets of triples 540 | Args: 541 | iter_num: iteration number of heuristic search 542 | test_instance: instance triples of AMR 1 543 | test_relation1: relation triples of AMR 1 (one-variable) 544 | test_relation2: relation triples of AMR 2 (two-variable) 545 | gold_instance: instance triples 
of AMR 2 546 | gold_relation1: relation triples of AMR 2 (one-variable) 547 | gold_relation2: relation triples of AMR 2 (two-variable) 548 | test_label: prefix label for AMR 1 549 | gold_label: prefix label for AMR 2 550 | Returns: 551 | best_match: the variable mapping which results in the best matching triple number 552 | best_match_num: the highest matching number 553 | """ 554 | # compute candidate pool 555 | (candidate_match, 556 | weight_dict) = compute_pool(test_instance, test_relation1, test_relation2, 557 | gold_instance, gold_relation1, gold_relation2, 558 | test_label, gold_label, 559 | node_weight_fn, edge_weight_fn) 560 | best_match_num = 0 561 | best_match = [-1] * len(test_instance) 562 | 563 | # best lexical match 564 | if iter_num == 0: 565 | start_match = init_match( 566 | candidate_match, 567 | test_instance, 568 | gold_instance, 569 | node_weight_fn) 570 | return (start_match, compute_match(start_match, weight_dict)) 571 | 572 | for i in range(0, iter_num): 573 | if verbose: 574 | print >> sys.stderr, "Iteration", i 575 | if i == 0: 576 | # smart initialization 577 | start_match = init_match( 578 | candidate_match, 579 | test_instance, 580 | gold_instance, 581 | node_weight_fn) 582 | else: 583 | # random initialization 584 | start_match = get_random_sol(candidate_match) 585 | # first match_num, and store the match in memory 586 | match_num = compute_match(start_match, weight_dict) 587 | # match_num_dict[tuple(start_match)]=match_num 588 | if verbose: 589 | print >> sys.stderr, "starting point match num:", match_num 590 | print >> sys.stderr, "start match", start_match 591 | # hill-climbing 592 | (largest_match_num, 593 | cur_match) = get_best_gain(start_match, 594 | candidate_match, 595 | weight_dict, 596 | len(gold_instance), 597 | match_num) 598 | if verbose: 599 | print >> sys.stderr, "Largest match number after the hill-climbing", largest_match_num 600 | # match_num=largest_match_num 601 | # hill-climbing until there will be no gain if we generate a new variable 602 | # mapping 603 | while largest_match_num > match_num: 604 | match_num = largest_match_num 605 | (largest_match_num, 606 | cur_match) = get_best_gain(cur_match, 607 | candidate_match, 608 | weight_dict, 609 | len(gold_instance), 610 | match_num) 611 | if verbose: 612 | print >> sys.stderr, "Largest match number after the hill-climbing", largest_match_num 613 | if match_num > best_match_num: 614 | best_match = cur_match[:] 615 | best_match_num = match_num 616 | return (best_match, best_match_num) 617 | 618 | # help of inst_list: record a0 location in the test_instance ...
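To see how the pieces above fit together, here is a minimal sketch of scoring a single AMR pair. It is not part of smatch.py, the AMR strings are invented, and it assumes it is run from this directory so that the `amr` and `smatch` modules resolve:

```
# Hypothetical driver script: score one AMR pair using amr.py and the
# functions defined above in smatch.py.
import amr
import smatch

test = amr.AMR.parse_AMR_line("(w / want-01 :ARG0 (b / boy))")
gold = amr.AMR.parse_AMR_line("(w2 / want-01 :ARG0 (g / girl))")
test.rename_node("a")   # variables become a0, a1, ...
gold.rename_node("b")   # variables become b0, b1, ...

(test_inst, test_rel1, test_rel2) = test.get_triples2()
(gold_inst, gold_rel1, gold_rel2) = gold.get_triples2()

(best_match, best_match_num) = smatch.get_fh(
    test_inst, test_rel1, test_rel2,
    gold_inst, gold_rel1, gold_rel2,
    "a", "b")

(precision, recall, f_score) = smatch.compute_f(
    best_match_num,
    len(test_inst) + len(test_rel1) + len(test_rel2),
    len(gold_inst) + len(gold_rel1) + len(gold_rel2))
print "Smatch: P=%.2f R=%.2f F=%.2f" % (precision, recall, f_score)

# get_fh memoizes mapping scores in the module-global match_num_dict;
# clear it between pairs (as main() does) when scoring a corpus:
smatch.match_num_dict.clear()
```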
619 | 
620 | 
621 | def print_alignment(match, test_instance, gold_instance, flip=False):
622 |     """Print the alignment based on a match
623 |     Args:
624 |         match: current match, denoted by a list
625 |         test_instance: instances of AMR 1
626 |         gold_instance: instances of AMR 2
627 |         flip: whether to flip the test/gold order in the output"""
628 |     result = []
629 |     for i, m in enumerate(match):
630 |         if m == -1:
631 |             if not flip:
632 |                 result.append(
633 |                     test_instance[i][1] +
634 |                     "(" +
635 |                     test_instance[i][2] +
636 |                     ")" +
637 |                     "-Null")
638 |             else:
639 |                 result.append(
640 |                     "Null-" +
641 |                     test_instance[i][1] +
642 |                     "(" +
643 |                     test_instance[i][2] +
644 |                     ")")
645 |         else:
646 |             if not flip:
647 |                 result.append(
648 |                     test_instance[i][1] +
649 |                     "(" +
650 |                     test_instance[i][2] +
651 |                     ")" +
652 |                     "-" +
653 |                     gold_instance[m][1] +
654 |                     "(" +
655 |                     gold_instance[m][2] +
656 |                     ")")
657 |             else:
658 |                 result.append(
659 |                     gold_instance[m][1] +
660 |                     "(" +
661 |                     gold_instance[m][2] +
662 |                     ")" +
663 |                     "-" +
664 |                     test_instance[i][1] +
665 |                     "(" +
666 |                     test_instance[i][2] +
667 |                     ")")
668 |     return " ".join(result)
669 | 
670 | 
671 | def compute_f(match_num, test_num, gold_num):
672 |     """Compute precision, recall, and f-score from the matching triple count and the triple counts of the two AMRs
673 |     Args:
674 |         match_num: matching triple number
675 |         test_num: triple number of AMR 1
676 |         gold_num: triple number of AMR 2
677 |     Returns:
678 |         precision: match_num/test_num
679 |         recall: match_num/gold_num
680 |         f_score: 2*precision*recall/(precision+recall), e.g. match_num=5, test_num=8, gold_num=10 gives precision 0.625, recall 0.5, f-score ~0.56"""
681 |     if test_num == 0 or gold_num == 0:
682 |         return (0.00, 0.00, 0.00)
683 |     precision = float(match_num) / float(test_num)
684 |     recall = float(match_num) / float(gold_num)
685 |     if (precision + recall) != 0:
686 |         f_score = 2 * precision * recall / (precision + recall)
687 |         if verbose:
688 |             print >> sys.stderr, "F-score:", f_score
689 |         return (precision, recall, f_score)
690 |     else:
691 |         if verbose:
692 |             print >> sys.stderr, "F-score:", "0.0"
693 |         return (precision, recall, 0.00)
694 | 
695 | 
696 | def main(args):
697 |     """Main function of the smatch calculation program"""
698 |     global verbose
699 |     global iter_num
700 |     global single_score
701 |     global pr_flag
702 |     global match_num_dict
703 |     # set the restart number (one smart start plus args.r random restarts)
704 |     iter_num = args.r + 1
705 |     verbose = False
706 |     if args.ms:
707 |         single_score = False
708 |     if args.v:
709 |         verbose = True
710 |     if args.pr:
711 |         pr_flag = True
712 |     total_match_num = 0
713 |     total_test_num = 0
714 |     total_gold_num = 0
715 |     sent_num = 1
716 |     while True:
717 |         cur_amr1 = get_amr_line(args.f[0])
718 |         cur_amr2 = get_amr_line(args.f[1])
719 |         if cur_amr1 == "" and cur_amr2 == "":
720 |             break
721 |         if cur_amr1 == "":
722 |             print >> sys.stderr, "Error: File 1 has fewer AMRs than file 2"
723 |             print >> sys.stderr, "Ignoring remaining AMRs"
724 |             break
725 |             # print >> sys.stderr, "AMR 1 is empty"
726 |             # continue
727 |         if cur_amr2 == "":
728 |             print >> sys.stderr, "Error: File 2 has fewer AMRs than file 1"
729 |             print >> sys.stderr, "Ignoring remaining AMRs"
730 |             break
731 |             # print >> sys.stderr, "AMR 2 is empty"
732 |             # continue
733 |         amr1 = amr.AMR.parse_AMR_line(cur_amr1)
734 |         amr2 = amr.AMR.parse_AMR_line(cur_amr2)
735 |         test_label = "a"
736 |         gold_label = "b"
737 |         amr1.rename_node(test_label)
738 |         amr2.rename_node(gold_label)
739 |         (test_inst, test_rel1, test_rel2) = amr1.get_triples2()
740 |         (gold_inst, gold_rel1, gold_rel2) = amr2.get_triples2()
741 |         if verbose:
742 |             print >> sys.stderr, "AMR pair", sent_num
743 |             print >> sys.stderr, "Instance triples of AMR 1:", len(test_inst)
744 |             print >> sys.stderr, test_inst
745 |             # print >> sys.stderr, "Relation triples of AMR 1:", len(test_rel)
746 |             print >> sys.stderr, "Relation triples of AMR 1:", len(
747 |                 test_rel1) + len(test_rel2)
748 |             print >> sys.stderr, test_rel1
749 |             print >> sys.stderr, test_rel2
750 |             # print >> sys.stderr, test_rel
751 |             print >> sys.stderr, "Instance triples of AMR 2:", len(gold_inst)
752 |             print >> sys.stderr, gold_inst
753 |             # print >> sys.stderr, "Relation triples of file 2:", len(gold_rel)
754 |             print >> sys.stderr, "Relation triples of AMR 2:", len(
755 |                 gold_rel1) + len(gold_rel2)
756 |             # print >> sys.stderr, "Relation triples of file 2:", len(gold_rel1) + len(gold_rel2)
757 |             print >> sys.stderr, gold_rel1
758 |             print >> sys.stderr, gold_rel2
759 |             # print >> sys.stderr, gold_rel
760 |         if len(test_inst) < len(gold_inst):  # hill-climb from the side with fewer variables
761 |             (best_match,
762 |              best_match_num) = get_fh(test_inst,
763 |                                       test_rel1,
764 |                                       test_rel2,
765 |                                       gold_inst,
766 |                                       gold_rel1,
767 |                                       gold_rel2,
768 |                                       test_label,
769 |                                       gold_label)
770 |             if verbose:
771 |                 print >> sys.stderr, "AMR pair ", sent_num
772 |                 print >> sys.stderr, "best match number", best_match_num
773 |                 print >> sys.stderr, "best match", best_match
774 |                 print >> sys.stderr, "Best Match:", print_alignment(
775 |                     best_match, test_inst, gold_inst)
776 |         else:
777 |             (best_match,
778 |              best_match_num) = get_fh(gold_inst,
779 |                                       gold_rel1,
780 |                                       gold_rel2,
781 |                                       test_inst,
782 |                                       test_rel1,
783 |                                       test_rel2,
784 |                                       gold_label,
785 |                                       test_label)
786 |             if verbose:
787 |                 print >> sys.stderr, "AMR pair ", sent_num
788 |                 print >> sys.stderr, "best match number", best_match_num
789 |                 print >> sys.stderr, "best match", best_match
790 |                 print >> sys.stderr, "Best Match:", print_alignment(
791 |                     best_match, gold_inst, test_inst, True)
792 |         if not single_score:
793 |             (precision,
794 |              recall,
795 |              best_f_score) = compute_f(best_match_num,
796 |                                        len(test_rel1) + len(test_inst) + len(test_rel2),
797 |                                        len(gold_rel1) + len(gold_inst) + len(gold_rel2))
798 |             print "Sentence", sent_num
799 |             if pr_flag:
800 |                 print "Precision: %.2f" % precision
801 |                 print "Recall: %.2f" % recall
802 |             print "Smatch score: %.2f" % best_f_score
803 |         total_match_num += best_match_num
804 |         total_test_num += len(test_rel1) + len(test_rel2) + len(test_inst)
805 |         total_gold_num += len(gold_rel1) + len(gold_rel2) + len(gold_inst)
806 |         match_num_dict.clear()
807 |         sent_num += 1  # print "F-score:", best_f_score
808 |     if verbose:
809 |         print >> sys.stderr, "Total match num"
810 |         print >> sys.stderr, total_match_num, total_test_num, total_gold_num
811 |     if single_score:
812 |         (precision, recall, best_f_score) = compute_f(
813 |             total_match_num, total_test_num, total_gold_num)
814 |         if pr_flag:
815 |             print "Precision: %.2f" % precision
816 |             print "Recall: %.2f" % recall
817 |         print "Document F-score: %.2f" % best_f_score
818 |     args.f[0].close()
819 |     args.f[1].close()
820 | 
821 | if __name__ == "__main__":
822 |     parser = None
823 |     args = None
824 |     if sys.version_info[:2] != (2, 7):
825 |         if sys.version_info[0] != 2 or sys.version_info[1] < 5:
826 |             print >> ERROR_LOG, "Smatch requires Python 2.5-2.7"
827 |             exit(1)
828 |         import optparse
829 |         if len(sys.argv) == 1:
830 |             print >> ERROR_LOG, "No argument given. Please run smatch.py -h to see the argument descriptions."
831 |             exit(1)
832 |         # optparse path; optparse requires Python >= 2.3
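        # A hedged invocation sketch for the two argument-parsing paths below,
        # assuming the standard Smatch flags (-f, -r, -v, --ms, --pr) defined
        # by build_arg_parser()/build_arg_parser2():
        #
        #     python smatch.py -f test.amr gold.amr -r 4 --ms --pr
        #
        # -r 4 yields iter_num = 5 in main(); --ms reports one score per
        # sentence instead of a single document-level score, and --pr also
        # prints precision and recall alongside each f-score.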
833 |         parser = build_arg_parser2()
834 |         (args, opts) = parser.parse_args()
835 |         # handle file errors
836 |         # if len(args.f) < 2:
837 |         #     print >> ERROR_LOG, "File number given is less than 2"
838 |         #     exit(1)
839 |         file_handle = []
840 |         if args.f is None:
841 |             print >> ERROR_LOG, "smatch.py requires the -f option to indicate two files containing AMRs as input. Please run smatch.py -h to see the argument descriptions."
842 |             exit(1)
843 |         if not os.path.exists(args.f[0]):
844 |             print >> ERROR_LOG, "Given file", args.f[0], "does not exist"
845 |             exit(1)
846 |         else:
847 |             file_handle.append(codecs.open(args.f[0], encoding='utf8'))
848 |         if not os.path.exists(args.f[1]):
849 |             print >> ERROR_LOG, "Given file", args.f[1], "does not exist"
850 |             exit(1)
851 |         else:
852 |             file_handle.append(codecs.open(args.f[1], encoding='utf8'))
853 |         args.f = tuple(file_handle)
854 |     else:  # version 2.7: argparse opens the -f files itself
855 |         import argparse
856 |         parser = build_arg_parser()
857 |         args = parser.parse_args()
858 |     main(args)
859 | 
--------------------------------------------------------------------------------