├── .gitattributes
├── .gitignore
├── .ptignore
├── EVALB
│   ├── COLLINS.prm
│   ├── LICENSE
│   ├── Makefile
│   ├── README
│   ├── evalb
│   ├── evalb.c
│   ├── new.prm
│   ├── sample
│   │   ├── sample.gld
│   │   ├── sample.prm
│   │   ├── sample.rsl
│   │   └── sample.tst
│   └── tgrep_proc.prl
├── LICENSE
├── Ordered_Memory_Slides.pdf
├── README.md
├── data
│   ├── listops
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── load_listops_data.py
│   │   ├── make_data.py
│   │   ├── test_d20s.tsv
│   │   └── train_d20s.tsv
│   └── propositionallogic
│       ├── __init__.py
│       ├── generate_neg_set_data.py
│       ├── test0
│       ├── test1
│       ├── test10
│       ├── test11
│       ├── test12
│       ├── test2
│       ├── test3
│       ├── test4
│       ├── test5
│       ├── test6
│       ├── test7
│       ├── test8
│       ├── test9
│       ├── train0
│       ├── train1
│       ├── train10
│       ├── train11
│       ├── train12
│       ├── train2
│       ├── train3
│       ├── train4
│       ├── train5
│       ├── train6
│       ├── train7
│       ├── train8
│       └── train9
├── listops.py
├── ordered_memory.py
├── proplog.py
├── requirements.txt
├── sentiment.py
└── utils
    ├── __init__.py
    ├── hinton.py
    ├── listops_data.py
    ├── locked_dropout.py
    └── utils.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | core.*
2 |
3 | *.hdf5
4 | *.pt
5 | .vector_cache
6 | .data
7 |
8 | *.xml
9 | *.iml
10 |
11 |
12 | # vim and gedit cache:
13 | *.swp
14 | *.swo
15 | *.swn
16 | *.swl
17 | *~
18 |
19 | # cluster logs
20 | SMART_DISPATCH_LOGS/*
21 |
22 | # model params
23 | model/*
24 | params/*
25 | tmptrees/*
26 |
27 | # logs
28 | tblogs/*
29 | logs/*
30 |
31 | # Byte-compiled / optimized / DLL files
32 | __pycache__/
33 | *.py[cod]
34 | *$py.class
35 | *.pyc
36 |
37 | *.log
38 |
39 |
--------------------------------------------------------------------------------
/.ptignore:
--------------------------------------------------------------------------------
1 | *.pt
2 | Philly*
3 | exp_logic
4 | *.diff
5 | data/data_scan
6 | data/scan
7 | data/propositionallogic_exp1
8 | data/mnli
9 | data/SST
10 | data/penn
11 | data/listops
12 | data/treebank_proc.conllu
13 | data/test_proc.conllu
14 |
--------------------------------------------------------------------------------
/EVALB/COLLINS.prm:
--------------------------------------------------------------------------------
1 | ##------------------------------------------##
2 | ## Debug mode ##
3 | ## 0: No debugging ##
4 | ## 1: print data for individual sentence ##
5 | ##------------------------------------------##
6 | DEBUG 0
7 |
8 | ##------------------------------------------##
9 | ## MAX error ##
10 | ## Number of errors to stop the process. ##
11 | ## This is useful if there could be ##
12 | ## tokenization errors. ##
13 | ## The process will stop when this number##
14 | ## of errors has accumulated. ##
15 | ##------------------------------------------##
16 | MAX_ERROR 10
17 |
18 | ##------------------------------------------##
19 | ## Cut-off length for statistics ##
20 | ## At the end of evaluation, the ##
21 | ## statistics for the sentences of length##
22 | ## less than or equal to this number will##
23 | ## be shown, on top of the statistics ##
24 | ## for all the sentences ##
25 | ##------------------------------------------##
26 | CUTOFF_LEN 40
27 |
28 | ##------------------------------------------##
29 | ## unlabeled or labeled bracketing ##
30 | ## 0: unlabeled bracketing ##
31 | ## 1: labeled bracketing ##
32 | ##------------------------------------------##
33 | LABELED 0
34 |
35 | ##------------------------------------------##
36 | ## Delete labels ##
37 | ## list of labels to be ignored. ##
38 | ## If it is a pre-terminal label, delete ##
39 | ## the word along with the brackets. ##
40 | ## If it is a non-terminal label, just ##
41 | ## delete the brackets (don't delete ##
42 | ## children). ##
43 | ##------------------------------------------##
44 | DELETE_LABEL ROOT
45 |
46 | ##------------------------------------------##
47 | ## Delete labels for length calculation ##
48 | ## list of labels to be ignored for ##
49 | ## length calculation purpose ##
50 | ##------------------------------------------##
51 | DELETE_LABEL_FOR_LENGTH -NONE-
52 |
53 | ##------------------------------------------##
54 | ## Equivalent labels, words ##
55 | ## the pairs are considered equivalent ##
56 | ## This is non-directional. ##
57 | ##------------------------------------------##
58 | EQ_LABEL ADVP PRT
59 |
60 | # EQ_WORD Example example
61 |
--------------------------------------------------------------------------------
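A note on how the options above interact: functional tags are stripped from non-terminals first, then DELETE_LABEL removes nodes, and EQ_LABEL merges label pairs during comparison (see EVALB/README sections [6] and [7]). Below is a minimal C sketch of those three checks, with invented table names standing in for the parsed .prm contents; it is illustrative only, not evalb.c's is_deletelabel()/label_comp() code (and with LABELED 0, as in this COLLINS.prm, label comparison is skipped entirely).

    #include <stdio.h>
    #include <string.h>

    /* Hypothetical stand-ins for the parsed .prm contents. */
    static const char *delete_label[] = { "ROOT" };
    static const char *eq_label[][2]  = { { "ADVP", "PRT" } };

    /* "NP-SBJ" -> "NP", "NP=2" -> "NP" (cf. README section [7].2). */
    static void strip_functional(char *label) {
        label[strcspn(label, "-=")] = '\0';
    }

    static int is_deleted(const char *label) {
        size_t i;
        for (i = 0; i < sizeof delete_label / sizeof *delete_label; i++)
            if (strcmp(label, delete_label[i]) == 0) return 1;
        return 0;
    }

    /* EQ_LABEL is non-directional: (a,b) also matches (b,a). */
    static int labels_equal(const char *a, const char *b) {
        size_t i;
        if (strcmp(a, b) == 0) return 1;
        for (i = 0; i < sizeof eq_label / sizeof *eq_label; i++)
            if ((!strcmp(a, eq_label[i][0]) && !strcmp(b, eq_label[i][1])) ||
                (!strcmp(a, eq_label[i][1]) && !strcmp(b, eq_label[i][0])))
                return 1;
        return 0;
    }

    int main(void) {
        char gold[] = "ADVP-TMP";
        strip_functional(gold);                                  /* -> "ADVP" */
        printf("ADVP-TMP ~ PRT : %d\n", labels_equal(gold, "PRT")); /* 1 */
        printf("ROOT deleted   : %d\n", is_deleted("ROOT"));        /* 1 */
        return 0;
    }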
/EVALB/LICENSE:
--------------------------------------------------------------------------------
1 | This is free and unencumbered software released into the public domain.
2 |
3 | Anyone is free to copy, modify, publish, use, compile, sell, or
4 | distribute this software, either in source code form or as a compiled
5 | binary, for any purpose, commercial or non-commercial, and by any
6 | means.
7 |
8 | In jurisdictions that recognize copyright laws, the author or authors
9 | of this software dedicate any and all copyright interest in the
10 | software to the public domain. We make this dedication for the benefit
11 | of the public at large and to the detriment of our heirs and
12 | successors. We intend this dedication to be an overt act of
13 | relinquishment in perpetuity of all present and future rights to this
14 | software under copyright law.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 | OTHER DEALINGS IN THE SOFTWARE.
23 |
24 | For more information, please refer to <http://unlicense.org/>
25 |
--------------------------------------------------------------------------------
/EVALB/Makefile:
--------------------------------------------------------------------------------
1 | all: evalb
2 |
3 | evalb: evalb.c
4 | gcc -Wall -g -o evalb evalb.c
5 |
--------------------------------------------------------------------------------
/EVALB/README:
--------------------------------------------------------------------------------
1 | #################################################################
2 | # #
3 | # Bug fix and additional functionality for evalb #
4 | # #
5 | # This updated version of evalb fixes a bug in which sentences #
6 | # were incorrectly categorized as "length mismatch" when the #
7 | # the parse output had certain mislabeled parts-of-speech. #
8 | # #
9 | # The bug was the result of evalb treating one of the tags (in #
10 | # gold or test) as a label to be deleted (see sections [6],[7] #
11 | # for details), but not the corresponding tag in the other. #
12 | # This most often occurs with punctuation. See the subdir #
13 | # "bug" for an example gld and tst file demonstating the bug, #
14 | # as well as output of evalb with and without the bug fix. #
15 | # #
16 | # For the present version in case of length mismatch, the nodes #
17 | # causing the imbalance are reinserted to resolve the miscount. #
18 | # If the lengths of gold and test truly differ, the error is #
19 | # still reported. The parameter file "new.prm" (derived from #
20 | # COLLINS.prm) shows how to add new potential mislabelings for #
21 | # quotes (",``,',`). #
22 | # #
23 | # I have preserved DJB's revision for modern compilers except #
24 | # for the declaration of "exit" which is provided by stdlib.    #
25 | # #
26 | # Other changes: #
27 | # #
28 | # * output of F-Measure in addition to precision and recall #
29 | # (I did not update the documentation in section [4] for this)  #
30 | # #
31 | # * more comprehensive DEBUG output that includes bracketing #
32 | # information as evalb is processing each sentence #
33 | # (useful in working through this, and perhaps other bugs).     #
34 | # Use either the "-D" run-time switch or set DEBUG to 2 in #
35 | # the parameter file. #
36 | # #
37 | # * added DELETE_LABEL lines in new.prm for S1 nodes produced #
38 | # by the Charniak parser and "?", "!" punctuation produced by #
39 | # the Bikel parser. #
40 | # #
41 | # #
42 | # David Ellis (Brown) #
43 | # #
44 | # January.2006 #
45 | #################################################################
46 |
47 | #################################################################
48 | # #
49 | # Update of evalb for modern compilers #
50 | # #
51 | # This is an updated version of evalb, for use with modern C #
52 | # compilers. There are a few updates, each marked in the code: #
53 | # #
54 | # /* DJB: explanation of comment */ #
55 | # #
56 | # The updates are purely to help compilation with recent #
57 | # versions of GCC (and other C compilers). There are *NO* other #
58 | # changes to the algorithm itself. #
59 | # #
60 | # I have made these changes following recommendations from #
61 | # users of the Corpora Mailing List, especially Peet Morris and #
62 | # Ramon Ziai. #
63 | # #
64 | # David Brooks (Birmingham) #
65 | # #
66 | # September.2005 #
67 | #################################################################
68 |
69 | #################################################################
70 | # #
71 | # README file for evalb #
72 | # #
73 | # Satoshi Sekine (NYU) #
74 | # Mike Collins (UPenn) #
75 | # #
76 | # October.1997 #
77 | #################################################################
78 |
79 | Contents of this README:
80 |
81 | [0] COPYRIGHT
82 | [1] INTRODUCTION
83 | [2] INSTALLATION AND RUN
84 | [3] OPTIONS
85 | [4] OUTPUT FORMAT FROM THE SCORER
86 | [5] HOW TO CREATE A GOLDFILE FROM THE TREEBANK
87 | [6] THE PARAMETER FILE
88 | [7] MORE DETAILS ABOUT THE SCORING ALGORITHM
89 |
90 |
91 | [0] COPYRIGHT
92 |
93 | The authors abandon the copyright of this program. Everyone is
94 | permitted to copy and distribute the program or a portion of the program
95 | with no charge and no restrictions unless it is harmful to someone.
96 |
97 | However, the authors would be grateful for proper usage and for
98 | being told about any bugs or problems.
99 |
100 | This software is provided "AS IS", and the authors make no warranties,
101 | express or implied.
102 |
103 | To legally enforce the abandonment of copyright, this package is released
104 | under the Unlicense (see LICENSE).
105 |
106 | [1] INTRODUCTION
107 |
108 | Evaluation of bracketing looks simple, but in fact, there are minor
109 | differences from system to system. This is a program to parameterize
110 | such minor differences and to give an informative result.
111 |
112 | "evalb" evaluates bracketing accuracy in a test-file against a gold-file.
113 | It returns recall, precision, and tagging accuracy. It uses an identical
114 | algorithm to that used in (Collins ACL97).
115 |
116 |
117 | [2] INSTALLATION AND RUN
118 |
119 | To compile the scorer, type
120 |
121 | > make
122 |
123 |
124 | To run the scorer:
125 |
126 | > evalb -p Parameter_file Gold_file Test_file
127 |
128 |
129 | For example to use the sample files:
130 |
131 | > evalb -p sample.prm sample.gld sample.tst
132 |
133 |
134 |
135 | [3] OPTIONS
136 |
137 | You can specify system parameters in the command line options.
138 | Other options concerning the evaluation metrics should be specified
139 | in the parameter file, described later.
140 |
141 | -p param_file parameter file
142 | -d debug mode
143 | -e n number of errors to kill (default=10)
144 | -h help
145 |
146 |
147 |
148 | [4] OUTPUT FORMAT FROM THE SCORER
149 |
150 | The scorer gives individual scores for each sentence, for
151 | example:
152 |
153 | Sent. Matched Bracket Cross Correct Tag
154 | ID Len. Stat. Recal Prec. Bracket gold test Bracket Words Tags Accracy
155 | ============================================================================
156 | 1 8 0 100.00 100.00 5 5 5 0 6 5 83.33
157 |
158 | At the end of the output the === Summary === section gives statistics
159 | for all sentences, and for sentences <=40 words in length. The summary
160 | contains the following information:
161 |
162 | i) Number of sentences -- total number of sentences.
163 |
164 | ii) Number of Error/Skip sentences -- should both be 0 if there is no
165 | problem with the parsed/gold files.
166 |
167 | iii) Number of valid sentences = Number of sentences - Number of Error/Skip
168 | sentences
169 |
170 | iv) Bracketing recall = (number of correct constituents)
171 | ----------------------------------------
172 | (number of constituents in the goldfile)
173 |
174 | v) Bracketing precision = (number of correct constituents)
175 | ----------------------------------------
176 | (number of constituents in the parsed file)
177 |
178 | vi) Complete match = percentage of sentences where recall and precision are
179 | both 100%.
180 |
181 | vii) Average crossing = (number of constituents crossing a goldfile constituent)
182 | ----------------------------------------------------
183 | (number of sentences)
184 |
185 | viii) No crossing = percentage of sentences which have 0 crossing brackets.
186 |
187 | ix) 2 or less crossing = percentage of sentences which have <=2 crossing brackets.
188 |
189 | x) Tagging accuracy = percentage of correct POS tags (but see [5].3 for exact
190 | details of what is counted).
191 |
192 |
193 |
194 | [5] HOW TO CREATE A GOLDFILE FROM THE PENN TREEBANK
195 |
196 |
197 | The gold and parsed files are in a format similar to this:
198 |
199 | (TOP (S (INTJ (RB No)) (, ,) (NP (PRP it)) (VP (VBD was) (RB n't) (NP (NNP Black) (NNP Monday))) (. .)))
200 |
201 | To create a gold file from the treebank:
202 |
203 | tgrep -wn '/.*/' | tgrep_proc.prl
204 |
205 | will produce a goldfile in the required format. ("tgrep -wn '/.*/'" prints
206 | parse trees, "tgrep_proc.prl" just skips blank lines).
207 |
208 | For example, to produce a goldfile for section 23 of the treebank:
209 |
210 | tgrep -wn '/.*/' | tail +90895 | tgrep_proc.prl | sed 2416q > sec23.gold
211 |
212 |
213 |
214 | [6] THE PARAMETER (.prm) FILE
215 |
216 |
217 | The .prm file sets options regarding the scoring method. COLLINS.prm gives
218 | the same scoring behaviour as the scorer used in (Collins 97). The options
219 | chosen were:
220 |
221 | 1) LABELED 1
222 |
223 | to give labelled precision/recall figures, i.e. a constituent must have the
224 | same span *and* label as a constituent in the goldfile.
225 |
226 | 2) DELETE_LABEL TOP
227 |
228 | Don't count the "TOP" label (which is always given in the output of tgrep)
229 | when scoring.
230 |
231 | 3) DELETE_LABEL -NONE-
232 |
233 | Remove traces (and all constituents which dominate nothing but traces) when
234 | scoring. For example
235 |
236 | .... (VP (VBD reported) (SBAR (-NONE- 0) (S (-NONE- *T*-1)))) (. .)))
237 |
238 | would be processed to give
239 |
240 | .... (VP (VBD reported)) (. .)))
241 |
242 |
243 | 4)
244 | DELETE_LABEL , -- for the purposes of scoring remove punctuation
245 | DELETE_LABEL :
246 | DELETE_LABEL ``
247 | DELETE_LABEL ''
248 | DELETE_LABEL .
249 |
250 | 5) DELETE_LABEL_FOR_LENGTH -NONE- -- don't include traces when calculating
251 | the length of a sentence (important
252 | when classifying a sentence as <=40
253 | words or >40 words)
254 |
255 | 6) EQ_LABEL ADVP PRT
256 |
257 | Count ADVP and PRT as being the same label when scoring.
258 |
259 |
260 |
261 |
262 | [7] MORE DETAILS ABOUT THE SCORING ALGORITHM
263 |
264 |
265 | 1) The scorer initially processes the files to remove all nodes specified
266 | by DELETE_LABEL in the .prm file. It also recursively removes nodes which
267 | dominate nothing due to all their children being removed. For example, if
268 | -NONE- is specified as a label to be deleted,
269 |
270 | .... (VP (VBD reported) (SBAR (-NONE- 0) (S (-NONE- *T*-1)))) (. .)))
271 |
272 | would be processed to give
273 |
274 | .... (VP (VBD reported)) (. .)))
275 |
276 | 2) The scorer also removes all functional tags attached to non-terminals
277 | (functional tags are prefixed with "-" or "=" in the treebank). For example
278 | "NP-SBJ" is processed to give "NP", "NP=2" is changed to "NP".
279 |
280 |
281 | 3) Tagging accuracy counts tags for all words *except* any tags which are
282 | deleted by a DELETE_LABEL specification in the .prm file. (For example, for
283 | COLLINS.prm, punctuation tagged as "," ":" etc. would not be included).
284 |
285 | 4) When calculating the length of a sentence, all words with POS tags not
286 | included in the "DELETE_LABEL_FOR_LENGTH" list in the .prm file are
287 | counted. (For COLLINS.prm, only "-NONE-" is specified in this list, so
288 | traces are removed before calculating the length of the sentence).
289 |
290 | 5) There are some subtleties in scoring when either the goldfile or parsed
291 | file contains multiple constituents for the same span which have the same
292 | non-terminal label. e.g. (NP (NP the man)) If the goldfile contains n
293 | constituents for the same span, and the parsed file contains m constituents
294 | with that nonterminal, the scorer works as follows:
295 |
296 | i) If m>n, then the precision is n/m, recall is 100%
297 |
298 | ii) If n>m, then the precision is 100%, recall is m/n.
299 |
300 | iii) If n==m, recall and precision are both 100%.
301 |
--------------------------------------------------------------------------------
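The summary statistics defined in sections [4] and [7] of the README reduce to a few ratios. The following is a small sketch of that arithmetic using the totals from EVALB/sample/sample.rsl as input; it is not print_total() itself, and the F-measure is only printed by the updated evalb, not by the sample result file.

    #include <stdio.h>

    int main(void) {
        /* Totals from the last summary row of sample.rsl:
           86 matched brackets, 98 in gold, 95 in test. */
        int match = 86, gold = 98, test = 95;

        double recall    = gold > 0 ? 100.0 * match / gold : 0.0;   /* iv) */
        double precision = test > 0 ? 100.0 * match / test : 0.0;   /* v)  */
        double fmeasure  = (precision + recall > 0)
                         ? 2 * precision * recall / (precision + recall) : 0.0;

        /* Prints R=87.76 P=90.53, matching the sample.rsl summary. */
        printf("R=%.2f P=%.2f F=%.2f\n", recall, precision, fmeasure);

        /* [7].5: n gold vs m test constituents sharing span and label
           contribute min(n,m) matches, hence precision n/m or recall m/n. */
        return 0;
    }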
/EVALB/evalb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yikangshen/Ordered-Memory/2a0d3a22fb70216993a7faf25de3515f42e10431/EVALB/evalb
--------------------------------------------------------------------------------
/EVALB/evalb.c:
--------------------------------------------------------------------------------
1 | /*****************************************************************/
2 | /* evalb [-p param_file] [-dh] [-e n] gold-file test-file */
3 | /* */
4 | /* Evaluate bracketing in test-file against gold-file. */
5 | /* Return recall, precision, tagging accuracy. */
6 | /* */
7 | /* */
8 | /* -p param_file parameter file */
9 | /* -d debug mode */
10 | /* -e n number of errors to kill (default=10) */
11 | /* -h help */
12 | /* */
13 | /* Satoshi Sekine (NYU) */
14 | /* Mike Collins (UPenn) */
15 | /* */
16 | /* October.1997 */
17 | /* */
18 | /* Please refer README for the update information */
19 | /*****************************************************************/
20 |
21 | #include <stdio.h>
22 | #include <stdlib.h> //### added for exit, atoi decls
23 | #include <string.h>
24 | #include <ctype.h>
25 | #include <malloc.h>
26 |
27 |
28 | /* Internal Data format -------------------------------------------*/
29 | /* */
30 | /* (S (NP (NNX this)) (VP (VBX is) (NP (DT a) (NNX pen))) (SYM .)) */
31 | /* */
32 | /* wn=5 */
33 | /* word label */
34 | /* terminal[0] = this NNX */
35 | /* terminal[1] = is VBX */
36 | /* terminal[2] = a DT */
37 | /* terminal[3] = pen NNX */
38 | /* terminal[4] = . SYM */
39 | /* */
40 | /* bn=4 */
41 | /* start end label */
42 | /* bracket[0] = 0 5 S */
43 | /* bracket[1] = 0 0 NP */
44 | /* bracket[2] = 1 4 VP */
45 | /* bracket[3] = 2 4 NP */
46 | /* */
47 | /* matched bracketing */
48 | /* Recall = --------------------------- */
49 | /* # of bracket in ref-data */
50 | /* */
51 | /* matched bracketing */
52 | /* Precision = --------------------------- */
53 | /* # of bracket in test-data */
54 | /* */
55 | /*-----------------------------------------------------------------*/
56 |
57 | /******************/
58 | /* constant macro */
59 | /******************/
60 |
61 | #define MAX_SENT_LEN 5000
62 | #define MAX_WORD_IN_SENT 200
63 | #define MAX_BRACKET_IN_SENT 200
64 | #define MAX_WORD_LEN 100
65 | #define MAX_LABEL_LEN 30
66 | #define MAX_QUOTE_TERM 20
67 |
68 | #define MAX_DELETE_LABEL 100
69 | #define MAX_EQ_LABEL 100
70 | #define MAX_EQ_WORD 100
71 |
72 | #define MAX_LINE_LEN 500
73 |
74 | #define DEFAULT_MAX_ERROR 10
75 | #define DEFAULT_CUT_LEN 40
76 |
77 | /*************/
78 | /* structure */
79 | /*************/
80 |
81 | typedef struct ss_terminal {
82 | char word[MAX_WORD_LEN];
83 | char label[MAX_LABEL_LEN];
84 | int result; /* 0:unmatch, 1:match, 9:undef */
85 | } s_terminal;
86 |
87 | typedef struct ss_term_ind {
88 | s_terminal term;
89 | int index;
90 | int bracket;
91 | int endslen;
92 | int ends[MAX_BRACKET_IN_SENT];
93 | } s_term_ind;
94 |
95 | typedef struct ss_bracket {
96 | int start;
97 | int end;
98 | unsigned int buf_start;
99 | unsigned int buf_end;
100 | char label[MAX_LABEL_LEN];
101 | int result; /* 0: unmatch, 1:match, 5:delete 9:undef */
102 | } s_bracket;
103 |
104 |
105 | typedef struct ss_equiv {
106 | char *s1;
107 | char *s2;
108 | } s_equiv;
109 |
110 |
111 | /****************************/
112 | /* global variables */
113 | /* gold-data: suffix = 1 */
114 | /* test-data: suffix = 2 */
115 | /****************************/
116 |
117 | /*---------------*/
118 | /* Sentence data */
119 | /*---------------*/
120 | int wn1, wn2; /* number of words in sentence */
121 | int r_wn1; /* number of words in sentence */
122 | /* which only ignores labels in */
123 | /* DELETE_LABEL_FOR_LENGTH */
124 |
125 | s_terminal terminal1[MAX_WORD_IN_SENT]; /* terminal information */
126 | s_terminal terminal2[MAX_WORD_IN_SENT];
127 |
128 | s_term_ind quotterm1[MAX_QUOTE_TERM]; /* special terminals ("'","POS") */
129 | s_term_ind quotterm2[MAX_QUOTE_TERM];
130 |
131 | int bn1, bn2; /* number of brackets */
132 |
133 | int r_bn1, r_bn2; /* number of brackets */
134 | /* after deletion */
135 |
136 | s_bracket bracket1[MAX_BRACKET_IN_SENT]; /* bracket information */
137 | s_bracket bracket2[MAX_BRACKET_IN_SENT];
138 |
139 |
140 | /*------------*/
141 | /* Total data */
142 | /*------------*/
143 | int TOTAL_bn1, TOTAL_bn2, TOTAL_match; /* total number of brackets */
144 | int TOTAL_sent; /* No. of sentence */
145 | int TOTAL_error_sent; /* No. of error sentence */
146 | int TOTAL_skip_sent; /* No. of skip sentence */
147 | int TOTAL_comp_sent; /* No. of complete match sent */
148 | int TOTAL_word; /* total number of word */
149 | int TOTAL_crossing; /* total crossing */
150 | int TOTAL_no_crossing; /* no crossing sentence */
151 | int TOTAL_2L_crossing; /* 2 or less crossing sentence */
152 | int TOTAL_correct_tag; /* total correct tagging */
153 |
154 | int TOT_cut_len = DEFAULT_CUT_LEN; /* Cut-off length in statistics */
155 |
156 | /* data for sentences with len <= CUT_LEN */
157 | /* Historically it was 40. */
158 | int TOT40_bn1, TOT40_bn2, TOT40_match; /* total number of brackets */
159 | int TOT40_sent; /* No. of sentence */
160 | int TOT40_error_sent; /* No. of error sentence */
161 | int TOT40_skip_sent; /* No. of skip sentence */
162 | int TOT40_comp_sent; /* No. of complete match sent */
163 | int TOT40_word; /* total number of word */
164 | int TOT40_crossing; /* total crossing */
165 | int TOT40_no_crossing; /* no crossing sentence */
166 | int TOT40_2L_crossing; /* 2 or less crossing sentence */
167 | int TOT40_correct_tag; /* total correct tagging */
168 |
169 | /*------------*/
170 | /* miscellaneous */
171 | /*------------*/
172 | int Line; /* line number */
173 | int Error_count = 0; /* Error count */
174 | int Status; /* Result status for each sent */
175 | /* 0: OK, 1: skip, 2: error */
176 |
177 | /*-------------------*/
178 | /* stack manipulation */
179 | /*-------------------*/
180 | int stack_top;
181 | int stack[MAX_BRACKET_IN_SENT];
182 |
183 | /************************************************************/
184 | /* User parameters which can be specified in parameter file */
185 | /************************************************************/
186 |
187 | /*------------------------------------------*/
188 | /* Debug mode */
189 | /* print out data for individual sentence */
190 | /*------------------------------------------*/
191 | int DEBUG=0;
192 |
193 | /*------------------------------------------*/
194 | /* MAX error */
195 | /* Number of errors to stop the process. */
196 | /* This is useful if there could be */
197 | /* tokenization errors. */
198 | /* The process will stop when this number*/
199 | /* of errors has accumulated. */
200 | /*------------------------------------------*/
201 | int Max_error = DEFAULT_MAX_ERROR;
202 |
203 | /*------------------------------------------*/
204 | /* Cut-off length for statistics */
205 | /* int TOT_cut_len = DEFAULT_CUT_LEN; */
206 | /* (Defined above) */
207 | /*------------------------------------------*/
208 |
209 |
210 | /*------------------------------------------*/
211 | /* unlabeled or labeled bracketing */
212 | /* 0: unlabeled bracketing */
213 | /* 1: labeled bracketing */
214 | /*------------------------------------------*/
215 | int F_label = 1;
216 |
217 | /*------------------------------------------*/
218 | /* Delete labels */
219 | /* list of labels to be ignored. */
220 | /* If it is a pre-terminal label, delete */
221 | /* the word along with the brackets. */
222 | /* If it is a non-terminal label, just */
223 | /* delete the brackets (don't delete */
224 | /* children). */
225 | /*------------------------------------------*/
226 | char *Delete_label[MAX_DELETE_LABEL];
227 | int Delete_label_n = 0;
228 |
229 | /*------------------------------------------*/
230 | /* Delete labels for length calculation */
231 | /* list of labels to be ignored for */
232 | /* length calculation purpose */
233 | /*------------------------------------------*/
234 | char *Delete_label_for_length[MAX_DELETE_LABEL];
235 | int Delete_label_for_length_n = 0;
236 |
237 | /*------------------------------------------*/
238 | /* Labels to be considered for misquote */
239 | /* (could be possessive or quote) */
240 | /*------------------------------------------*/
241 | char *Quote_term[MAX_QUOTE_TERM];
242 | int Quote_term_n = 0;
243 |
244 | /*------------------------------------------*/
245 | /* Equivalent labels, words */
246 | /* the pairs are considered equivalent */
247 | /* This is non-directional. */
248 | /*------------------------------------------*/
249 | s_equiv EQ_label[MAX_EQ_LABEL];
250 | int EQ_label_n = 0;
251 |
252 | s_equiv EQ_word[MAX_EQ_WORD];
253 | int EQ_word_n = 0;
254 |
255 |
256 |
257 | /************************/
258 | /* Function return-type */
259 | /************************/
260 | int main();
261 | void init_global();
262 | void print_head();
263 | void init();
264 | void read_parameter_file();
265 | void set_param();
266 | int narg();
267 | int read_line();
268 |
269 | void pushb();
270 | int popb();
271 | int stackempty();
272 |
273 | void calc_result(unsigned char *buf1,unsigned char *buf);
274 | void fix_quote();
275 | void reinsert_term();
276 | void massage_data();
277 | void modify_label();
278 | void individual_result();
279 | void print_total();
280 | void dsp_info();
281 | int is_terminator();
282 | int is_deletelabel();
283 | int is_deletelabel_for_length();
284 | int is_quote_term();
285 | int word_comp();
286 | int label_comp();
287 |
288 | void Error();
289 | void Fatal();
290 | void Usage();
291 |
292 | /* ### provided by std headers
293 | int fprintf();
294 | int printf();
295 | int atoi();
296 | int fclose();
297 | int sscanf();
298 | */
299 |
300 | /***********/
301 | /* program */
302 | /***********/
303 | #define ARG_CHECK(st) if(!(*++(*argv) || (--argc && *++argv))){ \
304 | fprintf(stderr,"Missing argument: %s\n",st); \
305 | }
306 |
307 | int
308 | main(argc,argv)
309 | int argc;
310 | char *argv[];
311 | {
312 | char *filename1, *filename2;
313 | FILE *fd1, *fd2;
314 | unsigned char buff[5000];
315 | unsigned char buff1[5000];
316 |
317 | filename1=NULL;
318 | filename2=NULL;
319 |
320 | for(argc--,argv++;argc>0;argc--,argv++){
321 | if(**argv == '-'){
322 | while(*++(*argv)){
323 | switch(**argv){
324 |
325 | case 'h': /* help */
326 | Usage();
327 | exit(1);
328 |
329 | case 'd': /* debug mode */
330 | DEBUG = 1;
331 | goto nextarg;
332 |
333 | case 'D': /* debug mode */
334 | DEBUG = 2;
335 | goto nextarg;
336 |
337 | case 'c': /* cut-off length */
338 | ARG_CHECK("cut-off length for statistices");
339 | TOT_cut_len = atoi(*argv);
340 | goto nextarg;
341 |
342 | case 'e': /* max error */
343 | ARG_CHECK("number of error to kill");
344 | Max_error = atoi(*argv);
345 | goto nextarg;
346 |
347 | case 'p': /* parameter file */
348 | ARG_CHECK("parameter file");
349 | read_parameter_file(*argv);
350 | goto nextarg;
351 |
352 | default:
353 | Usage();
354 | exit(0);
355 | }
356 | }
357 | } else {
358 | if(filename1==NULL){
359 | filename1 = *argv;
360 | }else if(filename2==NULL){
361 | filename2 = *argv;
362 | }
363 | }
364 | nextarg: continue;
365 | }
366 |
367 | init_global();
368 |
369 |
370 | if((fd1 = fopen(filename1,"r"))==NULL){
371 | Fatal("Can't open gold file (%s)\n",filename1);
372 | }
373 | if((fd2 = fopen(filename2,"r"))==NULL){
374 | Fatal("Can't open test file (%s)\n",filename2);
375 | }
376 |
377 | print_head();
378 |
379 | for(Line=1;fgets(buff,5000,fd1)!=NULL;Line++){
380 |
381 | init();
382 |
383 | /* READ 1 */
384 | r_wn1 = read_line(buff,terminal1,quotterm1,&wn1,bracket1,&bn1);
385 |
386 | strcpy(buff1,buff);
387 |
388 | /* READ 2 */
389 | if(fgets(buff,5000,fd2)==NULL){
390 | Error("Number of lines unmatch (too many lines in gold file)\n");
391 | break;
392 | }
393 |
394 | read_line(buff,terminal2,quotterm2,&wn2,bracket2,&bn2);
395 |
396 | /* Calculate result and print it */
397 | calc_result(buff1,buff);
398 |
399 | if(DEBUG>=1){
400 | dsp_info();
401 | }
402 | }
403 |
404 | if(fgets(buff,5000,fd2)!=NULL){
405 | Error("Number of lines unmatch (too many lines in test file)\n");
406 | }
407 |
408 | print_total();
409 |
410 | return (0);
411 | }
412 |
413 |
414 | /*-----------------------------*/
415 | /* initialize global variables */
416 | /*-----------------------------*/
417 | void
418 | init_global()
419 | {
420 | TOTAL_bn1 = TOTAL_bn2 = TOTAL_match = 0;
421 | TOTAL_sent = TOTAL_error_sent = TOTAL_skip_sent = TOTAL_comp_sent = 0;
422 | TOTAL_word = TOTAL_correct_tag = 0;
423 | TOTAL_crossing = 0;
424 | TOTAL_no_crossing = TOTAL_2L_crossing = 0;
425 |
426 | TOT40_bn1 = TOT40_bn2 = TOT40_match = 0;
427 | TOT40_sent = TOT40_error_sent = TOT40_skip_sent = TOT40_comp_sent = 0;
428 | TOT40_word = TOT40_correct_tag = 0;
429 | TOT40_crossing = 0;
430 | TOT40_no_crossing = TOT40_2L_crossing = 0;
431 |
432 | }
433 |
434 |
435 | /*------------------*/
436 | /* print head title */
437 | /*------------------*/
438 | void
439 | print_head()
440 | {
441 | printf(" Sent. Matched Bracket Cross Correct Tag\n");
442 | printf(" ID Len. Stat. Recal Prec. Bracket gold test Bracket Words Tags Accracy\n");
443 | printf("============================================================================\n");
444 | }
445 |
446 |
447 | /*-----------------------------------------------*/
448 | /* initialization at each individual computation */
449 | /*-----------------------------------------------*/
450 | void
451 | init()
452 | {
453 | int i;
454 |
455 | wn1 = 0;
456 | wn2 = 0;
457 | bn1 = 0;
458 | bn2 = 0;
459 | r_bn1 = 0;
460 | r_bn2 = 0;
461 |
462 | for(i=0;i<MAX_WORD_IN_SENT;i++){
518 | for(i=strlen(buff)-1;i>0 && (isspace(buff[i]) || buff[i]=='\n');i--){
519 | buff[i]='\0';
520 | }
521 | if(buff[0]=='#' || /* comment-line */
522 | strlen(buff)<3){ /* too short, just ignore */
523 | continue;
524 | }
525 |
526 | /* place the parameter and value */
527 | /*-------------------------------*/
528 | for(i=0;!isspace(buff[i]);i++);
529 | for(;isspace(buff[i]) && buff[i]!='\0';i++);
530 | if(buff[i]=='\0'){
531 | fprintf(stderr,"Empty value in parameter file (%d)\n",line);
532 | }
533 |
534 | /* set parameter and value */
535 | /*-------------------------*/
536 | set_param(buff,buff+i);
537 | }
538 |
539 | fclose(fd);
540 | }
541 |
542 |
543 | #define STRNCMP(s) (strncmp(param,s,strlen(s))==0 && \
544 | (param[strlen(s)]=='\0' || isspace(param[strlen(s)])))
545 |
546 |
547 | void
548 | set_param(param,value)
549 | char *param, *value;
550 | {
551 | char l1[MAX_LABEL_LEN], l2[MAX_LABEL_LEN];
552 |
553 | if(STRNCMP("DEBUG")){
554 |
555 | DEBUG = atoi(value);
556 |
557 | }else if(STRNCMP("MAX_ERROR")){
558 |
559 | Max_error = atoi(value);
560 |
561 | }else if(STRNCMP("CUTOFF_LEN")){
562 |
563 | TOT_cut_len = atoi(value);
564 |
565 | }else if(STRNCMP("LABELED")){
566 |
567 | F_label = atoi(value);
568 |
569 | }else if(STRNCMP("DELETE_LABEL")){
570 |
571 | Delete_label[Delete_label_n] = (char *)malloc(strlen(value)+1);
572 | strcpy(Delete_label[Delete_label_n],value);
573 | Delete_label_n++;
574 |
575 | }else if(STRNCMP("DELETE_LABEL_FOR_LENGTH")){
576 |
577 | Delete_label_for_length[Delete_label_for_length_n] = (char *)malloc(strlen(value)+1);
578 | strcpy(Delete_label_for_length[Delete_label_for_length_n],value);
579 | Delete_label_for_length_n++;
580 |
581 | }else if(STRNCMP("QUOTE_LABEL")){
582 |
583 | Quote_term[Quote_term_n] = (char *)malloc(strlen(value)+1);
584 | strcpy(Quote_term[Quote_term_n],value);
585 | Quote_term_n++;
586 |
587 | }else if(STRNCMP("EQ_LABEL")){
588 |
589 | if(narg(value)!=2){
590 | fprintf(stderr,"EQ_LABEL requires two values\n");
591 | return;
592 | }
593 | sscanf(value,"%s %s",l1,l2);
594 | EQ_label[EQ_label_n].s1 = (char *)malloc(strlen(l1)+1);
595 | strcpy(EQ_label[EQ_label_n].s1,l1);
596 | EQ_label[EQ_label_n].s2 = (char *)malloc(strlen(l2)+1);
597 | strcpy(EQ_label[EQ_label_n].s2,l2);
598 | EQ_label_n++;
599 |
600 | }else if(STRNCMP("EQ_WORD")){
601 |
602 | if(narg(value)!=2){
603 | fprintf(stderr,"EQ_WORD requires two values\n");
604 | return;
605 | }
606 | sscanf(value,"%s %s",l1,l2);
607 | EQ_word[EQ_word_n].s1 = (char *)malloc(strlen(l1)+1);
608 | strcpy(EQ_word[EQ_word_n].s1,l1);
609 | EQ_word[EQ_word_n].s2 = (char *)malloc(strlen(l2)+1);
610 | strcpy(EQ_word[EQ_word_n].s2,l2);
611 | EQ_word_n++;
612 |
613 | }else{
614 |
615 | fprintf(stderr,"Unknown keyword (%s) in parameter file\n",param);
616 |
617 | }
618 | }
619 |
620 |
621 | int
622 | narg(s)
623 | char *s;
624 | {
625 | int n;
626 |
627 | for(n=0;*s!='\0';){
628 | for(;isspace(*s);s++);
629 | if(*s=='\0'){
630 | break;
631 | }
632 | n++;
633 | for(;!isspace(*s);s++){
634 | if(*s=='\0'){
635 | break;
636 | }
637 | }
638 | }
639 |
640 | return(n);
641 | }
642 |
643 | /*-----------------------------*/
644 | /* Read line and gather data. */
645 | /* Return length of sentence. */
646 | /*-----------------------------*/
647 | int
648 | read_line(buff, terminal, quotterm, wn, bracket, bn)
649 | char *buff;
650 | s_terminal terminal[];
651 | s_term_ind quotterm[];
652 | int *wn;
653 | s_bracket bracket[];
654 | int *bn;
655 | {
656 | char *p, *q, label[MAX_LABEL_LEN], word[MAX_WORD_LEN];
657 | int qt; /* quote term counter */
658 | int wid, bid; /* word ID, bracket ID */
659 | int n; /* temporary remembering the position */
660 | int b; /* temporary remembering bid */
661 | int i;
662 | int len; /* length of the sentence */
663 |
664 | len = 0;
665 | stack_top=0;
666 |
667 | for(p=buff,qt=0,wid=0,bid=0;*p!='\0';){
668 |
669 | if(isspace(*p)){
670 | p++;
671 | continue;
672 |
673 | /* open bracket */
674 | /*--------------*/
675 | }else if(*p=='('){
676 |
677 | n=wid;
678 | for(p++,i=0;!is_terminator(*p);p++,i++){
679 | label[i]=*p;
680 | }
681 | label[i]='\0';
682 |
683 | /* Find terminals */
684 | q = p;
685 | if(isspace(*q)){
686 | for(q++;isspace(*q);q++);
687 | for(i=0;!is_terminator(*q);q++,i++){
688 | word[i]=*q;
689 | }
690 | word[i]='\0';
691 |
692 | /* compute length */
693 | if(*q==')' && !is_deletelabel_for_length(label)==1){
694 | len++;
695 | }
696 | if (DEBUG>1)
697 | printf("label=%s, word=%s, wid=%d\n",label,word,wid);
698 | /* quote terminal */
699 | if(*q==')' && is_quote_term(label,word)==1){
700 | strcpy(quotterm[qt].term.word,word);
701 | strcpy(quotterm[qt].term.label,label);
702 | quotterm[qt].index = wid;
703 | quotterm[qt].bracket = bid;
704 | quotterm[qt].endslen = stack_top;
705 | //quotterm[qt].ends = (int*)malloc(stack_top*sizeof(int));
706 | memcpy(quotterm[qt].ends,stack,stack_top*sizeof(int));
707 | qt++;
708 | }
709 |
710 | /* delete terminal */
711 | if(*q==')' && is_deletelabel(label)==1){
712 | p = q+1;
713 | continue;
714 |
715 | /* valid terminal */
716 | }else if(*q==')'){
717 | strcpy(terminal[wid].word,word);
718 | strcpy(terminal[wid].label,label);
719 | wid++;
720 | p = q+1;
721 | continue;
722 |
723 | /* error */
724 | }else if(*q!='('){
725 | Error("More than two elements in a bracket\n");
726 | }
727 | }
728 |
729 | /* otherwise non-terminal label */
730 | bracket[bid].start = wid;
731 | bracket[bid].buf_start = p-buff;
732 | strcpy(bracket[bid].label,label);
733 | pushb(bid);
734 | bid++;
735 |
736 | /* close bracket */
737 | /*---------------*/
738 | }else if(*p==')'){
739 |
740 | b = popb();
741 | bracket[b].end = wid;
742 | bracket[b].buf_end = p-buff;
743 | p++;
744 |
745 | /* error */
746 | /*-------*/
747 | }else{
748 |
749 | Error("Reading sentence\n");
750 | }
751 | }
752 |
753 | if(!stackempty()){
754 | Error("Bracketing is unbalanced (too many open bracket)\n");
755 | }
756 |
757 | *wn = wid;
758 | *bn = bid;
759 |
760 | return(len);
761 | }
762 |
763 |
764 | /*----------------------*/
765 | /* stack operation */
766 | /* for bracketing pairs */
767 | /*----------------------*/
768 | void
769 | pushb(item)
770 | int item;
771 | {
772 | stack[stack_top++]=item;
773 | }
774 |
775 | int
776 | popb()
777 | {
778 | int item;
779 |
780 | item = stack[stack_top-1];
781 |
782 | if(stack_top-- < 0){
783 | Error("Bracketing unbalance (too many close bracket)\n");
784 | }
785 | return(item);
786 | }
787 |
788 | int
789 | stackempty()
790 | {
791 | if(stack_top==0){
792 | return(1);
793 | }else{
794 | return(0);
795 | }
796 | }
797 |
798 |
799 | /*------------------*/
800 | /* calculate result */
801 | /*------------------*/
802 | void
803 | calc_result(unsigned char *buf1,unsigned char *buf)
804 | {
805 | int i, j, l;
806 | int match, crossing, correct_tag;
807 |
808 | int last_i = -1;
809 |
810 | char my_buf[1000];
811 | int match_found = 0;
812 |
813 | char match_j[200];
814 | for (j = 0; j < bn2; ++j) {
815 | match_j[j] = 0;
816 | }
817 |
818 | /* ML */
819 | if (DEBUG>1)
820 | printf("\n");
821 |
822 |
823 | /* Find skip and error */
824 | /*---------------------*/
825 | if(wn2==0){
826 | Status = 2;
827 | individual_result(0,0,0,0,0,0);
828 | return;
829 | }
830 |
831 | if(wn1 != wn2){
832 | //if (DEBUG>1)
833 | //Error("Length unmatch (%d|%d)\n",wn1,wn2);
834 | fix_quote();
835 | if(wn1 != wn2){
836 | Error("Length unmatch (%d|%d)\n",wn1,wn2);
837 | individual_result(0,0,0,0,0,0);
838 | return;
839 | }
840 | }
841 |
842 | for(i=0;i<wn1;i++){
861 | if(DEBUG>1)
862 | printf("1.res=%d, 2.res=%d, 1.start=%d, 2.start=%d, 1.end=%d, 2.end=%d\n",bracket1[i].result,bracket2[j].result,bracket1[i].start,bracket2[j].start,bracket1[i].end,bracket2[j].end);
863 |
864 | // does bracket match?
865 | if(bracket1[i].result != 5 &&
866 | bracket2[j].result == 0 &&
867 | bracket1[i].start == bracket2[j].start && bracket1[i].end == bracket2[j].end) {
868 |
869 | // (1) do we not care about the label or (2) does the label match?
870 | if (F_label==0 || label_comp(bracket1[i].label,bracket2[j].label)==1) {
871 | bracket1[i].result = bracket2[j].result = 1;
872 | match++;
873 | match_found = 1;
874 | break;
875 | } else {
876 | if (DEBUG>1) {
877 | printf(" LABEL[%d-%d]: ",bracket1[i].start,bracket1[i].end-1);
878 | l = bracket1[i].buf_end-bracket1[i].buf_start;
879 | strncpy(my_buf,buf1+bracket1[i].buf_start,l);
880 | my_buf[l] = '\0';
881 | printf("%s\n",my_buf);
882 | }
883 | match_found = 1;
884 | match_j[j] = 1;
885 | }
886 | }
887 | }
888 |
889 | if (!match_found && bracket1[i].result != 5 && DEBUG>1) {
890 | /* ### ML 09/28/03: gold bracket with no corresponding test bracket */
891 | printf(" BRACKET[%d-%d]: ",bracket1[i].start,bracket1[i].end-1);
892 | l = bracket1[i].buf_end-bracket1[i].buf_start;
893 | strncpy(my_buf,buf1+bracket1[i].buf_start,l);
894 | my_buf[l] = '\0';
895 | printf("%s\n",my_buf);
896 | }
897 | match_found = 0;
898 | }
899 |
900 | for(j=0;j<bn2;j++){
901 | if(match_j[j]==0 && bracket2[j].result==0 && DEBUG>1) {
902 | /* test bracket with no corresponding gold bracket */
903 | printf(" EXTRA[%d-%d]: ",bracket2[j].start,bracket2[j].end-1);
904 | l = bracket2[j].buf_end-bracket2[j].buf_start;
905 | strncpy(my_buf,buf+bracket2[j].buf_start,l);
906 | my_buf[l] = '\0';
907 | printf("%s\n",my_buf);
908 | }
909 | }
910 |
911 | /* crossing */
912 | /*----------*/
913 | crossing = 0;
914 |
915 | /* crossing is counted based on the brackets */
916 | /* in test rather than gold file (by Mike) */
917 | for(j=0;j<bn2;j++){
918 | for(i=0;i<bn1;i++){
919 | if(bracket1[i].result != 5 &&
920 | bracket2[j].result != 5 &&
921 | ((bracket1[i].start < bracket2[j].start &&
922 | bracket1[i].end > bracket2[j].start &&
923 | bracket1[i].end < bracket2[j].end) ||
924 | (bracket1[i].start > bracket2[j].start &&
925 | bracket1[i].start < bracket2[j].end &&
926 | bracket1[i].end > bracket2[j].end))){
927 |
928 | /* ### ML 09/01/03: get details on cross-brackettings */
929 | if (i != last_i) {
930 | if (DEBUG>1) {
931 | printf(" CROSSING[%d-%d]: ",bracket1[i].start,bracket1[i].end-1);
932 | l = bracket1[i].buf_end-bracket1[i].buf_start;
933 | strncpy(my_buf,buf1+bracket1[i].buf_start,l);
934 | my_buf[l] = '\0';
935 | printf("%s\n",my_buf);
936 |
937 | /* ML
938 | printf("\n CROSSING at bracket %d:\n",i-1);
939 | printf(" GOLD (tokens %d-%d): ",bracket1[i].start,bracket1[i].end-1);
940 | l = bracket1[i].buf_end-bracket1[i].buf_start;
941 | strncpy(my_buf,buf1+bracket1[i].buf_start,l);
942 | my_buf[l] = '\0';
943 | printf("%s\n",my_buf);
944 | */
945 | }
946 | last_i = i;
947 | }
948 |
949 | /* ML
950 | printf(" TEST (tokens %d-%d): ",bracket2[j].start,bracket2[j].end-1);
951 | l = bracket2[j].buf_end-bracket2[j].buf_start;
952 | strncpy(my_buf,buf+bracket2[j].buf_start,l);
953 | my_buf[l] = '\0';
954 | printf("%s\n",my_buf);
955 | */
956 |
957 | crossing++;
958 | break;
959 | }
960 | }
961 | }
962 |
963 | /* Tagging accuracy */
964 | /*------------------*/
965 | correct_tag=0;
966 | for(i=0;i<wn1;i++){
982 | if (DEBUG>1) {
983 | for(i=0;i<MAX_QUOTE_TERM;i++) {
1025 | int ind = quot->index;
1026 | int bra = quot->bracket;
1027 | s_terminal* term = &quot->term;
1028 | int k;
1029 | memmove(&terminal[ind+1],
1030 | &terminal[ind],
1031 | sizeof(s_terminal)*(MAX_WORD_IN_SENT-ind-1));
1032 | strcpy(terminal[ind].label,term->label);
1033 | strcpy(terminal[ind].word,term->word);
1034 | (*wn)++;
1035 | if (DEBUG>1)
1036 | printf("bra=%d, ind=%d\n",bra,ind);
1037 | for(k=0;k<*bn;k++) {
1040 | if (DEBUG>1)
1041 | printf("bracket[%d]={%d,%d}\n",k,bracket[k].start,bracket[k].end);
1042 | if (k>=bra) {
1043 | bracket[k].start++;
1044 | bracket[k].end++;
1045 | }
1046 | //if (bracket[k].start<=ind && bracket[k].end>=ind)
1047 | //bracket[k].end++;
1048 | }
1049 | if (DEBUG>1)
1050 | printf("endslen=%d\n",quot->endslen);
1051 | for(k=0;k<quot->endslen;k++) {
1052 | //printf("ends[%d]=%d",k,quot->ends[k]);
1053 | bracket[quot->ends[k]].end++;
1054 | }
1055 | //free(quot->ends);
1056 | }
1057 | /*
1058 | void
1059 | adjust_end(ind,bra)
1060 | int ind;
1061 | int bra;
1062 | {
1063 | for(k=0;k<*bn;k++) {
1067 | if (k>=bra)
1068 | bracket[k].end++;
1069 | }
1070 | }
1071 | */
1072 | void
1073 | massage_data()
1074 | {
1075 | int i, j;
1076 |
1077 | /* for GOLD */
1078 | /*----------*/
1079 | for(i=0;i<bn1;i++){
1245 | if(TOTAL_bn1>0 && TOTAL_bn2>0){
1246 | printf(" %6.2f %6.2f %6d %5d %5d %5d",
1247 | (TOTAL_bn1>0?100.0*TOTAL_match/TOTAL_bn1:0.0),
1248 | (TOTAL_bn2>0?100.0*TOTAL_match/TOTAL_bn2:0.0),
1249 | TOTAL_match,
1250 | TOTAL_bn1,
1251 | TOTAL_bn2,
1252 | TOTAL_crossing);
1253 | }
1254 |
1255 | printf(" %5d %5d %6.2f",
1256 | TOTAL_word,
1257 | TOTAL_correct_tag,
1258 | (TOTAL_word>0?100.0*TOTAL_correct_tag/TOTAL_word:0.0));
1259 |
1260 | printf("\n");
1261 | printf("=== Summary ===\n");
1262 |
1263 | sentn = TOTAL_sent - TOTAL_error_sent - TOTAL_skip_sent;
1264 |
1265 | printf("\n-- All --\n");
1266 | printf("Number of sentence = %6d\n",TOTAL_sent);
1267 | printf("Number of Error sentence = %6d\n",TOTAL_error_sent);
1268 | printf("Number of Skip sentence = %6d\n",TOTAL_skip_sent);
1269 | printf("Number of Valid sentence = %6d\n",sentn);
1270 |
1271 | r = TOTAL_bn1>0 ? 100.0*TOTAL_match/TOTAL_bn1 : 0.0;
1272 | printf("Bracketing Recall = %6.2f\n",r);
1273 |
1274 | p = TOTAL_bn2>0 ? 100.0*TOTAL_match/TOTAL_bn2 : 0.0;
1275 | printf("Bracketing Precision = %6.2f\n",p);
1276 |
1277 | f = 2*p*r/(p+r);
1278 | printf("Bracketing FMeasure = %6.2f\n",f);
1279 |
1280 | printf("Complete match = %6.2f\n",
1281 | (sentn>0?100.0*TOTAL_comp_sent/sentn:0.0));
1282 | printf("Average crossing = %6.2f\n",
1283 | (sentn>0?1.0*TOTAL_crossing/sentn:0.0));
1284 | printf("No crossing = %6.2f\n",
1285 | (sentn>0?100.0*TOTAL_no_crossing/sentn:0.0));
1286 | printf("2 or less crossing = %6.2f\n",
1287 | (sentn>0?100.0*TOTAL_2L_crossing/sentn:0.0));
1288 | printf("Tagging accuracy = %6.2f\n",
1289 | (TOTAL_word>0?100.0*TOTAL_correct_tag/TOTAL_word:0.0));
1290 |
1291 | sentn = TOT40_sent - TOT40_error_sent - TOT40_skip_sent;
1292 |
1293 | printf("\n-- len<=%d --\n",TOT_cut_len);
1294 | printf("Number of sentence = %6d\n",TOT40_sent);
1295 | printf("Number of Error sentence = %6d\n",TOT40_error_sent);
1296 | printf("Number of Skip sentence = %6d\n",TOT40_skip_sent);
1297 | printf("Number of Valid sentence = %6d\n",sentn);
1298 |
1299 |
1300 | r = TOT40_bn1>0 ? 100.0*TOT40_match/TOT40_bn1 : 0.0;
1301 | printf("Bracketing Recall = %6.2f\n",r);
1302 |
1303 | p = TOT40_bn2>0 ? 100.0*TOT40_match/TOT40_bn2 : 0.0;
1304 | printf("Bracketing Precision = %6.2f\n",p);
1305 |
1306 | f = 2*p*r/(p+r);
1307 | printf("Bracketing FMeasure = %6.2f\n",f);
1308 |
1309 | printf("Complete match = %6.2f\n",
1310 | (sentn>0?100.0*TOT40_comp_sent/sentn:0.0));
1311 | printf("Average crossing = %6.2f\n",
1312 | (sentn>0?1.0*TOT40_crossing/sentn:0.0));
1313 | printf("No crossing = %6.2f\n",
1314 | (sentn>0?100.0*TOT40_no_crossing/sentn:0.0));
1315 | printf("2 or less crossing = %6.2f\n",
1316 | (sentn>0?100.0*TOT40_2L_crossing/sentn:0.0));
1317 | printf("Tagging accuracy = %6.2f\n",
1318 | (TOT40_word>0?100.0*TOT40_correct_tag/TOT40_word:0.0));
1319 |
1320 | }
1321 |
1322 |
1323 | /*--------------------------------*/
1324 | /* display individual information */
1325 | /*--------------------------------*/
1326 | void
1327 | dsp_info()
1328 | {
1329 | int i, n;
1330 |
1331 | printf("-<1>---(wn1=%3d, bn1=%3d)- ",wn1,bn1);
1332 | printf("-<2>---(wn2=%3d, bn2=%3d)-\n",wn2,bn2);
1333 |
1334 | n = (wn1>wn2?wn1:wn2);
1335 |
1336 | for(i=0;i<n;i++){
1353 | n = (bn1>bn2?bn1:bn2);
1354 |
1355 | for(i=0;i<n;i++){
1502 | if(Error_count++>Max_error){
1503 | exit(1);
1504 | }
1505 | }
1506 |
1507 |
1508 | /*---------------------*/
1509 | /* fatal error to exit */
1510 | /*---------------------*/
1511 | void
1512 | Fatal(s,arg1,arg2,arg3)
1513 | char *s, *arg1, *arg2, *arg3;
1514 | {
1515 | fprintf(stderr,s,arg1,arg2,arg3);
1516 | exit(1);
1517 | }
1518 |
1519 |
1520 | /*-------*/
1521 | /* Usage */
1522 | /*-------*/
1523 | void
1524 | Usage()
1525 | {
1526 | fprintf(stderr," evalb [-dDh][-c n][-e n][-p param_file] gold-file test-file \n");
1527 | fprintf(stderr," \n");
1528 | fprintf(stderr," Evaluate bracketing in test-file against gold-file. \n");
1529 | fprintf(stderr," Return recall, precision, F-Measure, tag accuracy. \n");
1530 | fprintf(stderr," \n");
1531 | fprintf(stderr," \n");
1532 | fprintf(stderr," -d debug mode \n");
1533 | fprintf(stderr," -D debug mode plus bracketing info \n");
1534 | fprintf(stderr," -c n cut-off length forstatistics (def.=40)\n");
1535 | fprintf(stderr," -e n number of error to kill (default=10) \n");
1536 | fprintf(stderr," -p param_file parameter file \n");
1537 | fprintf(stderr," -h help \n");
1538 | }
1539 |
--------------------------------------------------------------------------------
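The crossing count in calc_result() is computed purely from span positions (code lines 917-926 above): a gold bracket crosses a test bracket when the two spans overlap without nesting. A self-contained restatement of that predicate follows, as a readability sketch; the function name crosses() is ours, not evalb.c's.

    #include <stdio.h>

    /* Does gold span [gs,ge) cross test span [ts,te)?  Mirrors the
       condition used in calc_result(); a standalone sketch, not a patch. */
    static int crosses(int gs, int ge, int ts, int te) {
        return (gs < ts && ge > ts && ge < te) ||  /* gold starts before test, ends inside */
               (gs > ts && gs < te && ge > te);    /* gold starts inside test, ends after  */
    }

    int main(void) {
        printf("%d\n", crosses(0, 3, 2, 5));  /* 1: (0,3) crosses (2,5)      */
        printf("%d\n", crosses(2, 4, 1, 5));  /* 0: nested spans don't cross */
        return 0;
    }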
/EVALB/new.prm:
--------------------------------------------------------------------------------
1 | ##------------------------------------------##
2 | ## Debug mode ##
3 | ## 0: No debugging ##
4 | ## 1: print data for individual sentence ##
5 | ## 2: print detailed bracketing info ##
6 | ##------------------------------------------##
7 | DEBUG 0
8 |
9 | ##------------------------------------------##
10 | ## MAX error ##
11 | ## Number of errors to stop the process. ##
12 | ## This is useful if there could be ##
13 | ## tokenization errors. ##
14 | ## The process will stop when this number##
15 | ## of errors has accumulated. ##
16 | ##------------------------------------------##
17 | MAX_ERROR 10
18 |
19 | ##------------------------------------------##
20 | ## Cut-off length for statistics ##
21 | ## At the end of evaluation, the ##
22 | ## statistics for the sentences of length##
23 | ## less than or equal to this number will##
24 | ## be shown, on top of the statistics ##
25 | ## for all the sentences ##
26 | ##------------------------------------------##
27 | CUTOFF_LEN 40
28 |
29 | ##------------------------------------------##
30 | ## unlabeled or labeled bracketing ##
31 | ## 0: unlabeled bracketing ##
32 | ## 1: labeled bracketing ##
33 | ##------------------------------------------##
34 | LABELED 1
35 |
36 | ##------------------------------------------##
37 | ## Delete labels ##
38 | ## list of labels to be ignored. ##
39 | ## If it is a pre-terminal label, delete ##
40 | ## the word along with the brackets. ##
41 | ## If it is a non-terminal label, just ##
42 | ## delete the brackets (don't delete ##
43 | ## children). ##
44 | ##------------------------------------------##
45 | DELETE_LABEL TOP
46 | DELETE_LABEL S1
47 | DELETE_LABEL -NONE-
48 | DELETE_LABEL ,
49 | DELETE_LABEL :
50 | DELETE_LABEL ``
51 | DELETE_LABEL ''
52 | DELETE_LABEL .
53 | DELETE_LABEL ?
54 | DELETE_LABEL !
55 |
56 | ##------------------------------------------##
57 | ## Delete labels for length calculation ##
58 | ## list of labels to be ignored for ##
59 | ## length calculation purpose ##
60 | ##------------------------------------------##
61 | DELETE_LABEL_FOR_LENGTH -NONE-
62 |
63 | ##------------------------------------------##
64 | ## Labels to be considered for misquote ##
65 | ## (could be possessive or quote) ##
66 | ##------------------------------------------##
67 | QUOTE_LABEL ``
68 | QUOTE_LABEL ''
69 | QUOTE_LABEL POS
70 |
71 | ##------------------------------------------##
72 | ## These ones are less common, but ##
73 | ## are on occasion output by parsers: ##
74 | ##------------------------------------------##
75 | QUOTE_LABEL NN
76 | QUOTE_LABEL CD
77 | QUOTE_LABEL VBZ
78 | QUOTE_LABEL :
79 |
80 | ##------------------------------------------##
81 | ## Equivalent labels, words ##
82 | ## the pairs are considered equivalent ##
83 | ## This is non-directional. ##
84 | ##------------------------------------------##
85 | EQ_LABEL ADVP PRT
86 |
87 | # EQ_WORD Example example
88 |
--------------------------------------------------------------------------------
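These QUOTE_LABEL entries drive the length-mismatch repair described in the bug-fix notes at the top of EVALB/README: when gold and test differ in length only because a quote-like terminal was deleted on one side, the token is reinserted before scoring. Below is a heavily simplified sketch of that idea on plain token arrays; the real fix_quote()/reinsert_term() also shift the stored bracket spans.

    #include <stdio.h>
    #include <string.h>

    int main(void) {
        /* The test side lost the quote token at position 2 after
           label deletion; gold kept it. */
        const char *gold[] = { "he", "said", "''", "yes" };
        const char *test[] = { "he", "said", "yes" };
        int gn = 4, tn = 3, miss = 2;

        const char *fixed[8];
        memcpy(fixed, test, miss * sizeof *test);
        fixed[miss] = gold[miss];                 /* reinsert the quote */
        memcpy(fixed + miss + 1, test + miss, (tn - miss) * sizeof *test);
        tn++;

        int ok = (tn == gn);
        for (int i = 0; ok && i < gn; i++) ok = !strcmp(fixed[i], gold[i]);
        printf("lengths agree after reinsertion: %s\n", ok ? "yes" : "no");
        return 0;
    }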
/EVALB/sample/sample.gld:
--------------------------------------------------------------------------------
1 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
2 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
3 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
4 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
5 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
6 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
7 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
8 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
9 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
10 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
11 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
12 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
13 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
14 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
15 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
16 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
17 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
18 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
19 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
20 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
21 | (S (A-SBJ-1 (P this)) (B-WHATEVER (Q is) (A (R a) (T test))))
22 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))))
23 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (-NONE- *))
24 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (: *))
25 |
--------------------------------------------------------------------------------
/EVALB/sample/sample.prm:
--------------------------------------------------------------------------------
1 | ##------------------------------------------##
2 | ## Debug mode ##
3 | ## print out data for individual sentence ##
4 | ##------------------------------------------##
5 | DEBUG 0
6 |
7 | ##------------------------------------------##
8 | ## MAX error ##
9 | ## Number of errors to stop the process. ##
10 | ## This is useful if there could be ##
11 | ## tokenization errors. ##
12 | ## The process will stop when this number##
13 | ## of errors has accumulated. ##
14 | ##------------------------------------------##
15 | MAX_ERROR 10
16 |
17 | ##------------------------------------------##
18 | ## Cut-off length for statistics ##
19 | ## At the end of evaluation, the ##
20 | ## statistics for the sentences of length##
21 | ## less than or equal to this number will##
22 | ## be shown, on top of the statistics ##
23 | ## for all the sentences ##
24 | ##------------------------------------------##
25 | CUTOFF_LEN 40
26 |
27 | ##------------------------------------------##
28 | ## unlabeled or labeled bracketing ##
29 | ## 0: unlabeled bracketing ##
30 | ## 1: labeled bracketing ##
31 | ##------------------------------------------##
32 | LABELED 1
33 |
34 | ##------------------------------------------##
35 | ## Delete labels ##
36 | ## list of labels to be ignored. ##
37 | ## If it is a pre-terminal label, delete ##
38 | ## the word along with the brackets. ##
39 | ## If it is a non-terminal label, just ##
40 | ## delete the brackets (don't delete ##
41 | ## children). ##
42 | ##------------------------------------------##
43 | DELETE_LABEL TOP
44 | DELETE_LABEL -NONE-
45 | DELETE_LABEL ,
46 | DELETE_LABEL :
47 | DELETE_LABEL ``
48 | DELETE_LABEL ''
49 |
50 | ##------------------------------------------##
51 | ## Delete labels for length calculation ##
52 | ## list of labels to be ignored for ##
53 | ## length calculation purpose ##
54 | ##------------------------------------------##
55 | DELETE_LABEL_FOR_LENGTH -NONE-
56 |
57 |
58 | ##------------------------------------------##
59 | ## Equivalent labels, words ##
60 | ## the pairs are considered equivalent ##
61 | ## This is non-directional. ##
62 | ##------------------------------------------##
63 | EQ_LABEL T TT
64 |
65 | EQ_WORD This this
66 |
--------------------------------------------------------------------------------
/EVALB/sample/sample.rsl:
--------------------------------------------------------------------------------
1 | Sent. Matched Bracket Cross Correct Tag
2 | ID Len. Stat. Recal Prec. Bracket gold test Bracket Words Tags Accracy
3 | ============================================================================
4 | 1 4 0 100.00 100.00 4 4 4 0 4 4 100.00
5 | 2 4 0 75.00 75.00 3 4 4 0 4 4 100.00
6 | 3 4 0 100.00 100.00 4 4 4 0 4 3 75.00
7 | 4 4 0 75.00 75.00 3 4 4 0 4 3 75.00
8 | 5 4 0 75.00 75.00 3 4 4 0 4 4 100.00
9 | 6 4 0 50.00 66.67 2 4 3 1 4 4 100.00
10 | 7 4 0 25.00 100.00 1 4 1 0 4 4 100.00
11 | 8 4 0 0.00 0.00 0 4 0 0 4 4 100.00
12 | 9 4 0 100.00 80.00 4 4 5 0 4 4 100.00
13 | 10 4 0 100.00 50.00 4 4 8 0 4 4 100.00
14 | 11 4 2 0.00 0.00 0 0 0 0 4 0 0.00
15 | 12 4 1 0.00 0.00 0 0 0 0 4 0 0.00
16 | 13 4 1 0.00 0.00 0 0 0 0 4 0 0.00
17 | 14 4 2 0.00 0.00 0 0 0 0 4 0 0.00
18 | 15 4 0 100.00 100.00 4 4 4 0 4 4 100.00
19 | 16 4 1 0.00 0.00 0 0 0 0 4 0 0.00
20 | 17 4 1 0.00 0.00 0 0 0 0 4 0 0.00
21 | 18 4 0 100.00 100.00 4 4 4 0 4 4 100.00
22 | 19 4 0 100.00 100.00 4 4 4 0 4 4 100.00
23 | 20 4 1 0.00 0.00 0 0 0 0 4 0 0.00
24 | 21 4 0 100.00 100.00 4 4 4 0 4 4 100.00
25 | 22 44 0 100.00 100.00 34 34 34 0 44 44 100.00
26 | 23 4 0 100.00 100.00 4 4 4 0 4 4 100.00
27 | 24 5 0 100.00 100.00 4 4 4 0 4 4 100.00
28 | ============================================================================
29 | 87.76 90.53 86 98 95 16 108 106 98.15
30 | === Summary ===
31 |
32 | -- All --
33 | Number of sentence = 24
34 | Number of Error sentence = 5
35 | Number of Skip sentence = 2
36 | Number of Valid sentence = 17
37 | Bracketing Recall = 87.76
38 | Bracketing Precision = 90.53
39 | Complete match = 52.94
40 | Average crossing = 0.06
41 | No crossing = 94.12
42 | 2 or less crossing = 100.00
43 | Tagging accuracy = 98.15
44 |
45 | -- len<=40 --
46 | Number of sentence = 23
47 | Number of Error sentence = 5
48 | Number of Skip sentence = 2
49 | Number of Valid sentence = 16
50 | Bracketing Recall = 81.25
51 | Bracketing Precision = 85.25
52 | Complete match = 50.00
53 | Average crossing = 0.06
54 | No crossing = 93.75
55 | 2 or less crossing = 100.00
56 | Tagging accuracy = 96.88
57 |
--------------------------------------------------------------------------------
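The summary block follows directly from the totals row: recall divides matched brackets by gold brackets, precision by test brackets, and tagging accuracy correct tags by words. A quick check in Python:

```
# Re-deriving the "-- All --" numbers from the totals row above.
matched, gold, test = 86, 98, 95
correct_tags, words = 106, 108
print(round(100 * matched / gold, 2))        # 87.76  Bracketing Recall
print(round(100 * matched / test, 2))        # 90.53  Bracketing Precision
print(round(100 * correct_tags / words, 2))  # 98.15  Tagging accuracy
```
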
/EVALB/sample/sample.tst:
--------------------------------------------------------------------------------
1 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
2 | (S (A (P this)) (B (Q is) (C (R a) (T test))))
3 | (S (A (P this)) (B (Q is) (A (R a) (U test))))
4 | (S (C (P this)) (B (Q is) (A (R a) (U test))))
5 | (S (A (P this)) (B (Q is) (R a) (A (T test))))
6 | (S (A (P this) (Q is)) (A (R a) (T test)))
7 | (S (P this) (Q is) (R a) (T test))
8 | (P this) (Q is) (R a) (T test)
9 | (S (A (P this)) (B (Q is) (A (A (R a) (T test)))))
10 | (S (A (P this)) (B (Q is) (A (A (A (A (A (R a) (T test))))))))
11 |
12 | (S (A (P this)) (B (Q was) (A (A (R a) (T test)))))
13 | (S (A (P this)) (B (Q is) (U not) (A (A (R a) (T test)))))
14 |
15 | (TOP (S (A (P this)) (B (Q is) (A (R a) (T test)))))
16 | (S (A (P this)) (NONE *) (B (Q is) (A (R a) (T test))))
17 | (S (A (P this)) (S (NONE abc) (A (NONE *))) (B (Q is) (A (R a) (T test))))
18 | (S (A (P this)) (B (Q is) (A (R a) (TT test))))
19 | (S (A (P This)) (B (Q is) (A (R a) (T test))))
20 | (S (A (P That)) (B (Q is) (A (R a) (T test))))
21 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
22 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))))
23 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (-NONE- *))
24 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (: *))
25 |
--------------------------------------------------------------------------------
/EVALB/tgrep_proc.prl:
--------------------------------------------------------------------------------
1 | #!/usr/local/bin/perl
2 |
3 | while(<>)
4 | {
5 | if(m/TOP/) # keep only lines containing TOP (blank lines are skipped)
6 | {
7 | print;
8 | }
9 | }
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Yikang Shen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Ordered_Memory_Slides.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yikangshen/Ordered-Memory/2a0d3a22fb70216993a7faf25de3515f42e10431/Ordered_Memory_Slides.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Ordered Memory
2 |
3 | This repository contains the code used for [Ordered Memory](https://arxiv.org/abs/1910.13466).
4 |
5 | The code comes with instructions for the following experiments:
6 | + [propositional logic experiments](https://www.aclweb.org/anthology/W15-4002.pdf)
7 |
8 | + [ListOps](https://arxiv.org/pdf/1804.06028.pdf)
9 |
10 | + [SST](https://nlp.stanford.edu/sentiment/treebank.html)
11 |
12 | If you use this code or our results in your research, please cite as appropriate:
13 |
14 | ```
15 | @incollection{NIPS2019_8748,
16 | title = {Ordered Memory},
17 | author = {Shen, Yikang and Tan, Shawn and Hosseini, Arian and Lin, Zhouhan and Sordoni, Alessandro and Courville, Aaron C},
18 | booktitle = {Advances in Neural Information Processing Systems 32},
19 | editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett},
20 | pages = {5038--5049},
21 | year = {2019},
22 | publisher = {Curran Associates, Inc.},
23 | url = {http://papers.nips.cc/paper/8748-ordered-memory.pdf}
24 | }
25 |
26 | ```
27 |
28 | ## Software Requirements
29 |
30 | Python 3, PyTorch 1.2, and torchtext are required for the current codebase.
31 |
32 | ## Experiments
33 |
34 | ### Propositional Logic
35 |
36 | + `python -u proplog.py --cuda --save logic.pt`
37 |
38 | ### ListOps
39 |
40 | + `python -u listops.py --cuda --name listops.pt`
41 |
42 | ### SST
43 |
44 | + `python -u sentiment.py --subtrees --cuda --name sentiment.pt --glove/--elmo (--fine-grained)`
45 |
46 |
--------------------------------------------------------------------------------
/data/listops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yikangshen/Ordered-Memory/2a0d3a22fb70216993a7faf25de3515f42e10431/data/listops/__init__.py
--------------------------------------------------------------------------------
/data/listops/base.py:
--------------------------------------------------------------------------------
1 | from spinn import util
2 |
3 | NUMBERS = list(range(10))
4 |
5 | FIXED_VOCABULARY = {str(x): i + 1 for i, x in enumerate(NUMBERS)}
6 | FIXED_VOCABULARY.update({
7 | util.PADDING_TOKEN: 0,
8 | "[MIN": len(FIXED_VOCABULARY) + 1,
9 | "[MAX": len(FIXED_VOCABULARY) + 2,
10 | "[FIRST": len(FIXED_VOCABULARY) + 3,
11 | "[LAST": len(FIXED_VOCABULARY) + 4,
12 | "[MED": len(FIXED_VOCABULARY) + 5,
13 | "[SM": len(FIXED_VOCABULARY) + 6,
14 | "[PM": len(FIXED_VOCABULARY) + 7,
15 | "[FLSUM": len(FIXED_VOCABULARY) + 8,
16 | "]": len(FIXED_VOCABULARY) + 9
17 | })
18 | assert len(set(FIXED_VOCABULARY.values())) == len(list(FIXED_VOCABULARY.values()))
19 |
--------------------------------------------------------------------------------
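Since the dict literal passed to update() is evaluated before the update lands, len(FIXED_VOCABULARY) is 10 (just the digits) for every operator entry, so the ids come out contiguous. A standalone sketch of the resulting table, with "<pad>" standing in for util.PADDING_TOKEN (the external `spinn` package is not part of this repository):

```
# The table base.py builds, reconstructed without the `spinn` dependency.
vocab = {str(d): d + 1 for d in range(10)}   # digits "0".."9" -> 1..10
vocab["<pad>"] = 0                           # stand-in for util.PADDING_TOKEN
for offset, tok in enumerate(["[MIN", "[MAX", "[FIRST", "[LAST",
                              "[MED", "[SM", "[PM", "[FLSUM", "]"]):
    vocab[tok] = 11 + offset                 # operators and "]" -> 11..19
assert len(set(vocab.values())) == len(vocab)  # same uniqueness check as base.py
```
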
/data/listops/load_listops_data.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from spinn import util
3 |
4 | from spinn.data.listops.base import FIXED_VOCABULARY
5 |
6 | SENTENCE_PAIR_DATA = False
7 | OUTPUTS = list(range(10))
8 | LABEL_MAP = {str(x): i for i, x in enumerate(OUTPUTS)}
9 |
10 | Node = namedtuple('Node', 'tag span')
11 |
12 |
13 | def spans(transitions, tokens=None):
14 | n = (len(transitions) + 1) // 2
15 | stack = []
16 | buf = [Node("leaf", (l, r)) for l, r in zip(list(range(n)), list(range(1, n + 1)))]
17 | buf = list(reversed(buf))
18 |
19 | nodes = []
20 | reduced = [False] * n
21 |
22 | def SHIFT(item):
23 | nodes.append(item)
24 | return item
25 |
26 | def REDUCE(l, r):
27 | tag = None
28 | i = r.span[1] - 1
29 | if tokens is not None and tokens[i] == ']' and not reduced[i]:
30 | reduced[i] = True
31 | tag = "struct"
32 | new_stack_item = Node(tag=tag, span=(l.span[0], r.span[1]))
33 | nodes.append(new_stack_item)
34 | return new_stack_item
35 |
36 | for t in transitions:
37 | if t == 0:
38 | stack.append(SHIFT(buf.pop()))
39 | elif t == 1:
40 | r, l = stack.pop(), stack.pop()
41 | stack.append(REDUCE(l, r))
42 |
43 | return nodes
44 |
45 |
46 | def load_data(path, lowercase=None, choose=lambda x: True, eval_mode=False):
47 | examples = []
48 | with open(path) as f:
49 | for example_id, line in enumerate(f):
50 | line = line.strip()
51 | label, seq = line.split('\t')
52 | if len(seq) <= 1:
53 | continue
54 |
55 | tokens, transitions = util.ConvertBinaryBracketedSeq(
56 | seq.split(' '))
57 |
58 | example = {}
59 | example["label"] = label
60 | example["sentence"] = seq
61 | example["tokens"] = tokens
62 | example["transitions"] = transitions
63 | example["example_id"] = str(example_id)
64 |
65 | examples.append(example)
66 | return examples
67 |
--------------------------------------------------------------------------------
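A hand trace makes the transition encoding in spans() concrete (0 = SHIFT, 1 = REDUCE). The module itself needs the external `spinn` package, so the walk is re-implemented inline here:

```
# Standalone trace of the shift/reduce walk in spans() for "[MIN 2 3 ]".
# Only the REDUCE whose right child ends at a not-yet-consumed ']' is tagged.
tokens = ["[MIN", "2", "3", "]"]
transitions = [0, 0, 1, 0, 1, 0, 1]          # a left-branching binary parse
n = (len(transitions) + 1) // 2              # 4 tokens
stack = []
buf = [(i, i + 1) for i in range(n)][::-1]   # leaf spans, leftmost on top
reduced = [False] * n
for t in transitions:
    if t == 0:                               # SHIFT
        stack.append(buf.pop())
    else:                                    # REDUCE
        r, l = stack.pop(), stack.pop()
        i = r[1] - 1
        tag = 'struct' if tokens[i] == ']' and not reduced[i] else None
        reduced[i] = reduced[i] or tag is not None
        print(tag, (l[0], r[1]))             # None (0,2); None (0,3); struct (0,4)
        stack.append((l[0], r[1]))
```
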
/data/listops/make_data.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 |
4 | MIN = "[MIN"
5 | MAX = "[MAX"
6 | MED = "[MED"
7 | FIRST = "[FIRST"
8 | LAST = "[LAST"
9 | SUM_MOD = "[SM"
10 | END = "]"
11 |
12 | OPERATORS = [MIN, MAX, MED, SUM_MOD] # , FIRST, LAST]
13 | VALUES = range(10)
14 |
15 | VALUE_P = 0.25
16 | MAX_ARGS = 5
17 | MAX_DEPTH = 20
18 |
19 | DATA_POINTS = 1000000
20 |
21 |
22 | def generate_tree(depth):
23 | if depth < MAX_DEPTH:
24 | r = random.random()
25 | else:
26 | r = 1
27 |
28 | if r > VALUE_P:
29 | value = random.choice(VALUES)
30 | return value
31 | else:
32 | num_values = random.randint(2, MAX_ARGS)
33 | values = []
34 | for _ in range(num_values):
35 | values.append(generate_tree(depth + 1))
36 |
37 | op = random.choice(OPERATORS)
38 | t = (op, values[0])
39 | for value in values[1:]:
40 | t = (t, value)
41 | t = (t, END)
42 | return t
43 |
44 |
45 | def to_string(t, parens=True):
46 | if isinstance(t, str):
47 | return t
48 | elif isinstance(t, int):
49 | return str(t)
50 | else:
51 |         if parens:  # parens=False falls through to None; recursive calls always use the default
52 | return '( ' + to_string(t[0]) + ' ' + to_string(t[1]) + ' )'
53 |
54 |
55 | def to_value(t):
56 | if not isinstance(t, tuple):
57 | return t
58 | l = to_value(t[0])
59 | r = to_value(t[1])
60 | if l in OPERATORS: # Create an unsaturated function.
61 | return (l, [r])
62 | elif r == END: # l must be an unsaturated function.
63 | if l[0] == MIN:
64 | return min(l[1])
65 | elif l[0] == MAX:
66 | return max(l[1])
67 | elif l[0] == FIRST:
68 | return l[1][0]
69 | elif l[0] == LAST:
70 | return l[1][-1]
71 | elif l[0] == MED:
72 | return int(np.median(l[1]))
73 | elif l[0] == SUM_MOD:
74 | return (np.sum(l[1]) % 10)
75 | elif isinstance(l, tuple): # We've hit an unsaturated function and an argument.
76 | return (l[0], l[1] + [r])
77 |
78 |
79 | data = set()
80 | while len(data) < DATA_POINTS:
81 | data.add(generate_tree(1))
82 |
83 | for example in data:
84 | print(str(to_value(example)) + '\t' + to_string(example))
--------------------------------------------------------------------------------
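make_data.py runs its generation loop at module scope, so the tuple encoding is easiest to illustrate by hand; here is what generate_tree(), to_string() and to_value() produce for the expression [MIN 2 5 ]:

```
# The nested-tuple encoding generate_tree() produces for "[MIN 2 5 ]",
# built by hand (importing make_data would start the generation loop).
MIN, END = "[MIN", "]"
t = (((MIN, 2), 5), END)
# to_string(t) -> "( ( ( [MIN 2 ) 5 ) ] )"   fully parenthesized, left-branching
# to_value(t)  -> first ("[MIN", [2]), then ("[MIN", [2, 5]); the END token
#                 saturates the function, so the value is min([2, 5]) == 2 and
#                 the emitted TSV line is "2\t( ( ( [MIN 2 ) 5 ) ] )"
```
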
/data/propositionallogic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yikangshen/Ordered-Memory/2a0d3a22fb70216993a7faf25de3515f42e10431/data/propositionallogic/__init__.py
--------------------------------------------------------------------------------
/data/propositionallogic/generate_neg_set_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from itertools import *
3 | from collections import *
4 | import random
5 |
6 |
7 | def powerset(iterable):
8 | "From itertools: powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
9 | s = list(iterable)
10 | return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
11 |
12 |
13 | def get_candidate_worlds(num_vars):
14 | return powerset(set(range(num_vars)))
15 |
16 |
17 | def get_satisfying_worlds_for_tree(tree, candidate_worlds):
18 | if isinstance(tree, tuple):
19 | if tree[0] == 'not':
20 | child = get_satisfying_worlds_for_tree(tree[1], candidate_worlds)
21 | return candidate_worlds.difference(child)
22 | else:
23 | left = get_satisfying_worlds_for_tree(tree[0], candidate_worlds)
24 | right = get_satisfying_worlds_for_tree(tree[2], candidate_worlds)
25 | if tree[1] == "and":
26 | return left.intersection(right)
27 | elif tree[1] == "or":
28 | return left.union(right)
29 | else:
30 |             print('syntax error', tree)
31 | else:
32 | result = []
33 | for world in candidate_worlds:
34 | if tree in world:
35 | result.append(world)
36 | return set(result)
37 |
38 |
39 | def compute_relation(left, right, universe):
40 | ne_intersection = left.intersection(right)
41 | ne_just_left = left.difference(right)
42 | ne_just_right = right.difference(left)
43 | ne_outside = universe.difference(left.union(right))
44 | if ne_intersection and not ne_just_right and not ne_just_left and ne_outside:
45 | return "="
46 | elif ne_intersection and ne_just_right and not ne_just_left and ne_outside:
47 | return "<"
48 | elif ne_intersection and not ne_just_right and ne_just_left and ne_outside:
49 | return ">"
50 | elif not ne_intersection and ne_just_right and ne_just_left and not ne_outside:
51 | return "^"
52 | elif not ne_intersection and ne_just_right and ne_just_left and ne_outside:
53 | return "|"
54 | elif ne_intersection and ne_just_right and ne_just_left and not ne_outside:
55 | return "v"
56 | else:
57 | return "#"
58 |
59 |
60 | def create_sub_statement(universe, maxlen):
61 | operator = random.choice(operators)
62 | temp = ()
63 | if operator == '0' or maxlen < 2:
64 | temp = random.choice(list(universe))
65 | else:
66 |         lhs = create_sub_statement(universe, maxlen // 2)
67 |         rhs = create_sub_statement(universe, maxlen // 2)
68 | temp = tuple([lhs, operator, rhs])
69 |
70 | neg_or_none = random.choice(neg_or_nones)
71 | if neg_or_none == '0':
72 | return temp
73 | else:
74 | return tuple([neg_or_none, temp])
75 |
76 |
77 | def uniq(seq, idfun=None):
78 | # order preserving
79 | if idfun is None:
80 | def idfun(x):
81 | return x
82 | seen = {}
83 | result = []
84 | for item in seq:
85 | marker = idfun(item)
86 | # in old Python versions:
87 | # if seen.has_key(marker)
88 | # but in new ones:
89 | if marker in seen:
90 | continue
91 | seen[marker] = 1
92 | result.append(item)
93 | return result
94 |
95 |
96 | def to_string(expr, individuals):
97 | if isinstance(expr, int):
98 | return individuals[expr]
99 | if isinstance(expr, str):
100 | return expr
101 | elif len(expr) == 3:
102 | return "( " + to_string(expr[0], individuals) + " ( " + to_string(expr[1], individuals) + " " + to_string(expr[2], individuals) + " ) )"
103 | else:
104 | return "( " + to_string(expr[0], individuals) + " " + to_string(expr[1], individuals) + " )"
105 |
106 |
107 | def get_len(tree):
108 | if isinstance(tree, tuple):
109 | accum = 0
110 | for entry in tree:
111 | accum += get_len(entry)
112 | return accum
113 | elif tree == 'and' or tree == 'or' or tree == 'not':
114 | return 1
115 | else:
116 | return 0
117 |
118 | individuals = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
119 |
120 | worlds = set(get_candidate_worlds(6))
121 | universe = set(range(6))
122 |
123 | neg_or_nones = ['not', '0', '0']
124 | operators = ['and', 'or', 'and', 'or', '0', '0', '0', '0', '0']
125 |
126 |
127 | stats = Counter()
128 | total = 0
129 | outputs = {0: [], 1: [], 2: [], 3: [], 4: [], 5: [],
130 | 6: [], 7: [], 8: [], 9: [], 10: [], 11: [], 12: []}
131 | while total < 500000:
132 | subuniverse = random.sample(universe, 4)
133 | lhs = create_sub_statement(subuniverse, 12)
134 | rhs = create_sub_statement(subuniverse, 12)
135 | sat1 = get_satisfying_worlds_for_tree(lhs, worlds)
136 | sat2 = get_satisfying_worlds_for_tree(rhs, worlds)
137 | if sat1 == worlds or len(sat1) == 0:
138 | continue
139 | if sat2 == worlds or len(sat2) == 0:
140 | continue
141 | rel = compute_relation(sat1, sat2, worlds)
142 |
143 | if rel != "?":
144 | stats[rel] += 1
145 | total += 1
146 | max_len = min(max(get_len(rhs), get_len(lhs)), 12)
147 | outputs[max_len].append("" + rel + "\t" + to_string(
148 | lhs, individuals) + "\t" + to_string(rhs, individuals))
149 |
150 | TRAIN_PORTION = 0.85
151 |
152 | for length in outputs.keys():
153 | outputs[length] = uniq(outputs[length])
154 |
155 | filename = 'train' + str(length)
156 | f = open(filename, 'w')
157 | for i in range(int(TRAIN_PORTION * len(outputs[length]))):
158 | output = outputs[length][i]
159 | f.write(output + "\n")
160 | f.close()
161 |
162 | filename = 'test' + str(length)
163 | f = open(filename, 'w')
164 | for i in range(int(TRAIN_PORTION * len(outputs[length])), len(outputs[length])):
165 | output = outputs[length][i]
166 | f.write(output + "\n")
167 | f.close()
168 |
169 | print(stats)
170 |
--------------------------------------------------------------------------------
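The seven-way relation in compute_relation() is easiest to see on a concrete pair of satisfying-world sets (like make_data.py, this script generates on import, so the derivation is repeated inline):

```
# Hand-deriving compute_relation() on a strict-entailment example.
left, right = {1, 2}, {1, 2, 3}
universe = {1, 2, 3, 4, 5}
intersection = left & right              # {1, 2}: non-empty
just_left = left - right                 # empty
just_right = right - left                # {3}: non-empty
outside = universe - (left | right)      # {4, 5}: non-empty
# Non-empty intersection, extra worlds only on the right, and some worlds
# outside both: that branch of compute_relation returns "<" (entailment).
```
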
/data/propositionallogic/test0:
--------------------------------------------------------------------------------
1 | # a b
2 | # a e
3 | = b b
4 | # b d
5 | # f e
6 | # c e
7 |
--------------------------------------------------------------------------------
/data/propositionallogic/test1:
--------------------------------------------------------------------------------
1 | # ( d ( and e ) ) ( d ( and c ) )
2 | # ( e ( and d ) ) ( f ( or f ) )
3 | > ( b ( or c ) ) ( b ( and d ) )
4 | # ( e ( or e ) ) ( a ( and c ) )
5 | # ( not d ) ( c ( or b ) )
6 | # ( a ( and a ) ) ( d ( and f ) )
7 | # ( d ( or f ) ) ( b ( or a ) )
8 | # ( d ( and b ) ) ( a ( or a ) )
9 | > ( f ( or b ) ) ( f ( or f ) )
10 | # ( c ( or c ) ) ( b ( and b ) )
11 | # ( f ( and f ) ) ( a ( and a ) )
12 | < ( d ( and a ) ) ( a ( or c ) )
13 | < ( f ( and c ) ) ( f ( and f ) )
14 | # ( e ( and a ) ) ( f ( and a ) )
15 | # ( a ( and c ) ) ( not f )
16 | # ( f ( and f ) ) ( b ( or e ) )
17 | > ( b ( or f ) ) ( b ( and d ) )
18 | # ( d ( and f ) ) ( e ( or e ) )
19 | # ( f ( and e ) ) ( b ( and f ) )
20 | > ( a ( or a ) ) ( a ( and e ) )
21 | < ( b ( and e ) ) ( e ( or b ) )
22 | # ( d ( or e ) ) ( b ( and b ) )
23 | # ( a ( or d ) ) ( d ( or b ) )
24 | # ( b ( and b ) ) ( c ( and c ) )
25 | < ( a ( or a ) ) ( e ( or a ) )
26 | # ( c ( and f ) ) ( not b )
27 | # ( b ( or a ) ) ( not c )
28 | # ( b ( and a ) ) ( e ( and e ) )
29 | # ( not b ) ( c ( and e ) )
30 | > ( d ( or d ) ) ( e ( and d ) )
31 | > ( f ( or a ) ) ( c ( and f ) )
32 | # ( f ( or a ) ) ( b ( or f ) )
33 | # ( e ( and f ) ) ( a ( and a ) )
34 | # ( e ( or a ) ) ( b ( and b ) )
35 | # ( a ( or e ) ) ( e ( or c ) )
36 | > ( d ( or c ) ) ( d ( and b ) )
37 | < ( b ( and d ) ) ( d ( or c ) )
38 | > ( d ( or f ) ) ( a ( and f ) )
39 | < ( a ( and f ) ) ( e ( or a ) )
40 | = ( c ( and e ) ) ( e ( and c ) )
41 | > ( f ( or f ) ) ( f ( and b ) )
42 | # ( f ( and c ) ) ( d ( or d ) )
43 | # ( c ( and c ) ) ( d ( or d ) )
44 | # ( d ( and d ) ) ( b ( and c ) )
45 | > ( b ( or f ) ) ( b ( and b ) )
46 | < ( a ( and e ) ) ( e ( or a ) )
47 | = ( c ( and d ) ) ( d ( and c ) )
48 | # ( d ( and d ) ) ( e ( or e ) )
49 | # ( c ( or d ) ) ( d ( or a ) )
50 | # ( a ( or e ) ) ( c ( or c ) )
51 | > ( d ( or f ) ) ( d ( and e ) )
52 | # ( d ( or d ) ) ( e ( or b ) )
53 | = ( e ( or f ) ) ( f ( or e ) )
54 | # ( b ( and d ) ) ( c ( and d ) )
55 | > ( c ( or b ) ) ( b ( and b ) )
56 | # ( d ( or d ) ) ( a ( or e ) )
57 | > ( e ( and e ) ) ( e ( and b ) )
58 | # ( b ( or a ) ) ( f ( or c ) )
59 | # ( f ( or d ) ) ( d ( or c ) )
60 | > ( c ( or a ) ) ( a ( and a ) )
61 | = ( e ( or d ) ) ( e ( or d ) )
62 | = ( c ( or e ) ) ( c ( or e ) )
63 | # ( a ( and e ) ) ( f ( or d ) )
64 | # ( a ( or b ) ) ( c ( or d ) )
65 | > ( e ( or a ) ) ( a ( and a ) )
66 | # ( e ( or f ) ) ( d ( or f ) )
67 | # ( b ( or f ) ) ( e ( or c ) )
68 | # ( b ( or a ) ) ( e ( or f ) )
69 | # ( a ( or e ) ) ( f ( and b ) )
70 | < ( f ( and a ) ) ( b ( or f ) )
71 | # ( a ( or e ) ) ( a ( or b ) )
72 | < ( b ( or b ) ) ( b ( or c ) )
73 | # ( f ( or c ) ) ( a ( and d ) )
74 | # ( c ( and a ) ) ( e ( and d ) )
75 | < ( f ( and a ) ) ( a ( or f ) )
76 | # ( a ( and e ) ) ( a ( and d ) )
77 | # ( e ( and a ) ) ( e ( and d ) )
78 | # ( b ( or d ) ) ( f ( and f ) )
79 | < ( a ( and c ) ) ( c ( or d ) )
80 | < ( c ( or c ) ) ( c ( or f ) )
81 | > ( d ( or b ) ) ( b ( and a ) )
82 | # ( d ( and e ) ) ( b ( and c ) )
83 | # ( e ( or a ) ) ( e ( or f ) )
84 | # ( c ( and d ) ) ( c ( and e ) )
85 | # ( a ( or a ) ) ( c ( or f ) )
86 | < ( c ( or c ) ) ( c ( or e ) )
87 | # ( f ( or f ) ) ( c ( or a ) )
88 | # ( c ( and a ) ) ( e ( and a ) )
89 | > ( d ( or a ) ) ( c ( and d ) )
90 | # ( e ( or e ) ) ( d ( and c ) )
91 | > ( b ( or d ) ) ( b ( and a ) )
92 | < ( b ( and d ) ) ( b ( or d ) )
93 | # ( c ( and a ) ) ( f ( or f ) )
94 | # ( b ( or e ) ) ( d ( or b ) )
95 | # ( e ( and f ) ) ( e ( and a ) )
96 | > ( a ( or a ) ) ( a ( and d ) )
97 | > ( f ( and f ) ) ( f ( and b ) )
98 | > ( e ( or f ) ) ( c ( and e ) )
99 | > ( c ( or d ) ) ( d ( and c ) )
100 | > ( e ( or d ) ) ( e ( and e ) )
101 | # ( f ( or e ) ) ( e ( or a ) )
102 | > ( e ( or b ) ) ( e ( and a ) )
103 | # ( d ( or f ) ) ( b ( and a ) )
104 | < ( e ( and a ) ) ( e ( and e ) )
105 | < ( d ( and d ) ) ( a ( or d ) )
106 | < ( b ( and e ) ) ( c ( or e ) )
107 | > ( f ( or b ) ) ( f ( and f ) )
108 | # ( d ( and a ) ) ( c ( and b ) )
109 | = ( e ( and c ) ) ( c ( and e ) )
110 | # ( c ( and c ) ) ( f ( or d ) )
111 | < ( e ( and e ) ) ( c ( or e ) )
112 | # ( e ( or e ) ) ( d ( and d ) )
113 | # ( f ( and f ) ) ( b ( or a ) )
114 | # ( not c ) ( b ( or e ) )
115 | # ( a ( or c ) ) ( f ( or f ) )
116 | # ( d ( or e ) ) ( b ( or a ) )
117 | > ( a ( or f ) ) ( d ( and a ) )
118 | > ( a ( or c ) ) ( c ( and f ) )
119 | # ( not f ) ( a ( and d ) )
120 | > ( e ( or f ) ) ( e ( and f ) )
121 | < ( f ( and f ) ) ( e ( or f ) )
122 | # ( b ( and d ) ) ( b ( and e ) )
123 | # ( c ( and f ) ) ( c ( and a ) )
124 | # ( c ( and f ) ) ( a ( and c ) )
125 | # ( c ( or c ) ) ( f ( and e ) )
126 | # ( b ( or c ) ) ( c ( or a ) )
127 | # ( b ( and b ) ) ( e ( or a ) )
128 | # ( f ( or d ) ) ( b ( or e ) )
129 | # ( c ( or b ) ) ( e ( or a ) )
130 | # ( a ( and a ) ) ( d ( and c ) )
131 | # ( b ( and b ) ) ( c ( and f ) )
132 | # ( a ( and a ) ) ( c ( or e ) )
133 | # ( e ( or f ) ) ( c ( or c ) )
134 | # ( f ( and f ) ) ( a ( and c ) )
135 | # ( not a ) ( b ( and f ) )
136 | # ( a ( or e ) ) ( b ( or e ) )
137 | < ( a ( and e ) ) ( d ( or e ) )
138 | # ( a ( and a ) ) ( f ( or d ) )
139 | = ( f ( or e ) ) ( f ( or e ) )
140 | # ( f ( and f ) ) ( b ( or c ) )
141 | # ( c ( or f ) ) ( d ( or a ) )
142 | < ( f ( and a ) ) ( f ( or f ) )
143 | = ( d ( or b ) ) ( b ( or d ) )
144 | > ( a ( or b ) ) ( d ( and a ) )
145 | < ( d ( and e ) ) ( d ( or b ) )
146 | # ( b ( and b ) ) ( e ( and e ) )
147 | < ( c ( and f ) ) ( b ( or f ) )
148 | # ( f ( or e ) ) ( f ( or b ) )
149 | < ( e ( or e ) ) ( d ( or e ) )
150 | < ( a ( and e ) ) ( b ( or e ) )
151 | # ( a ( or f ) ) ( f ( or b ) )
152 | > ( f ( or a ) ) ( a ( and a ) )
153 | > ( d ( or b ) ) ( b ( and b ) )
154 | = ( a ( and d ) ) ( a ( and d ) )
155 | # ( f ( or f ) ) ( e ( and c ) )
156 | # ( e ( or c ) ) ( e ( or b ) )
157 | # ( f ( and e ) ) ( a ( and f ) )
158 | # ( a ( or c ) ) ( d ( or a ) )
159 | # ( d ( and d ) ) ( f ( and e ) )
160 | # ( e ( and d ) ) ( e ( and f ) )
161 | # ( c ( or a ) ) ( c ( or f ) )
162 | < ( f ( and f ) ) ( c ( or f ) )
163 | = ( d ( and d ) ) ( d ( and d ) )
164 | < ( d ( and e ) ) ( c ( or d ) )
165 | = ( b ( or b ) ) ( b ( or b ) )
166 | < ( b ( and a ) ) ( a ( and a ) )
167 | < ( a ( and c ) ) ( c ( or c ) )
168 | # ( c ( and c ) ) ( f ( and a ) )
169 | # ( d ( and b ) ) ( c ( and c ) )
170 | # ( c ( or d ) ) ( c ( or f ) )
171 | > ( e ( or a ) ) ( a ( and d ) )
172 | > ( c ( or d ) ) ( c ( and a ) )
173 | < ( e ( and a ) ) ( a ( or b ) )
174 | # ( c ( or e ) ) ( f ( or c ) )
175 | > ( c ( or a ) ) ( c ( and c ) )
176 | # ( b ( and f ) ) ( d ( and b ) )
177 | # ( d ( and a ) ) ( e ( and d ) )
178 | # ( f ( and b ) ) ( f ( and e ) )
179 | > ( e ( or b ) ) ( e ( and d ) )
180 | # ( c ( or a ) ) ( d ( or c ) )
181 | < ( c ( or c ) ) ( b ( or c ) )
182 | # ( e ( or d ) ) ( e ( or a ) )
183 | # ( c ( and b ) ) ( a ( and b ) )
184 | > ( c ( or b ) ) ( b ( or b ) )
185 | # ( d ( and f ) ) ( e ( and d ) )
186 | # ( f ( and a ) ) ( e ( and a ) )
187 | > ( a ( or b ) ) ( a ( or a ) )
188 | < ( d ( and a ) ) ( a ( and a ) )
189 | # ( c ( or a ) ) ( d ( or a ) )
190 | # ( a ( and a ) ) ( c ( or c ) )
191 | # ( a ( or d ) ) ( a ( or f ) )
192 | < ( e ( and f ) ) ( f ( or e ) )
193 | < ( c ( and f ) ) ( c ( or f ) )
194 | = ( c ( and b ) ) ( c ( and b ) )
195 | # ( f ( and b ) ) ( a ( and d ) )
196 | # ( a ( and b ) ) ( f ( or c ) )
197 | < ( f ( or f ) ) ( f ( or c ) )
198 | = ( f ( and d ) ) ( d ( and f ) )
199 | # ( d ( or f ) ) ( d ( or b ) )
200 | < ( c ( and a ) ) ( a ( or b ) )
201 | # ( d ( or b ) ) ( e ( or b ) )
202 | # ( c ( and c ) ) ( d ( and f ) )
203 | > ( d ( or a ) ) ( a ( or a ) )
204 | > ( c ( or a ) ) ( a ( and e ) )
205 | # ( not c ) ( a ( and f ) )
206 | < ( d ( and f ) ) ( a ( or f ) )
207 | # ( c ( or f ) ) ( b ( or c ) )
208 | # ( e ( and d ) ) ( e ( and c ) )
209 | # ( a ( or d ) ) ( e ( or c ) )
210 | # ( f ( and f ) ) ( c ( and d ) )
211 | < ( d ( and d ) ) ( d ( or b ) )
212 | < ( f ( and d ) ) ( d ( or a ) )
213 | # ( c ( or c ) ) ( d ( or d ) )
214 | # ( b ( and a ) ) ( b ( and c ) )
215 | # ( a ( and f ) ) ( a ( and c ) )
216 | # ( a ( or d ) ) ( b ( or d ) )
217 | # ( f ( and b ) ) ( d ( or c ) )
218 | # ( c ( or d ) ) ( e ( or c ) )
219 | # ( a ( and b ) ) ( f ( and c ) )
220 | # ( c ( or b ) ) ( f ( and e ) )
221 | # ( e ( or a ) ) ( d ( or f ) )
222 | # ( f ( or f ) ) ( e ( or c ) )
223 | # ( e ( and f ) ) ( a ( and d ) )
224 | # ( c ( and c ) ) ( a ( or e ) )
225 | < ( b ( and f ) ) ( f ( or c ) )
226 | # ( c ( or c ) ) ( f ( or d ) )
227 | > ( c ( or b ) ) ( b ( and a ) )
228 | < ( c ( and b ) ) ( e ( or b ) )
229 | # ( e ( and b ) ) ( c ( or c ) )
230 | # ( e ( and e ) ) ( d ( or c ) )
231 | # ( a ( or d ) ) ( a ( or e ) )
232 | # ( a ( or b ) ) ( e ( or b ) )
233 | < ( e ( and b ) ) ( b ( or f ) )
234 | < ( e ( and f ) ) ( f ( and f ) )
235 | # ( e ( and b ) ) ( a ( or a ) )
236 | < ( e ( and c ) ) ( e ( or f ) )
237 | # ( e ( or e ) ) ( f ( or b ) )
238 | # ( f ( or f ) ) ( c ( or e ) )
239 | # ( f ( and a ) ) ( f ( and e ) )
240 | > ( c ( or b ) ) ( c ( or c ) )
241 | > ( a ( or f ) ) ( e ( and f ) )
242 | # ( d ( and f ) ) ( c ( and c ) )
243 | # ( d ( or e ) ) ( e ( or c ) )
244 | # ( a ( and d ) ) ( a ( and c ) )
245 | > ( d ( or b ) ) ( d ( and d ) )
246 | < ( f ( and a ) ) ( a ( or e ) )
247 | < ( a ( and c ) ) ( a ( or c ) )
248 | > ( d ( or d ) ) ( d ( and a ) )
249 | < ( a ( and b ) ) ( b ( or a ) )
250 | = ( b ( and f ) ) ( f ( and b ) )
251 | # ( c ( or e ) ) ( b ( and d ) )
252 | # ( c ( or b ) ) ( e ( or c ) )
253 | > ( d ( or b ) ) ( b ( and e ) )
254 | < ( d ( and f ) ) ( d ( or f ) )
255 | # ( d ( and f ) ) ( c ( and d ) )
256 | # ( d ( and b ) ) ( e ( or a ) )
257 | < ( d ( and c ) ) ( d ( or d ) )
258 | # ( f ( or a ) ) ( f ( or e ) )
259 | = ( a ( and d ) ) ( d ( and a ) )
260 | # ( c ( and c ) ) ( e ( or b ) )
261 | < ( c ( and e ) ) ( c ( or a ) )
262 | = ( f ( or b ) ) ( b ( or f ) )
263 | # ( c ( and f ) ) ( e ( and f ) )
264 | > ( e ( or f ) ) ( e ( or e ) )
265 | # ( e ( or c ) ) ( a ( and a ) )
266 | # ( f ( and f ) ) ( c ( and b ) )
267 | # ( e ( and a ) ) ( c ( and f ) )
268 | # ( d ( and a ) ) ( e ( and f ) )
269 | # ( b ( or e ) ) ( d ( and d ) )
270 | > ( c ( or e ) ) ( a ( and e ) )
271 | > ( b ( or c ) ) ( b ( and b ) )
272 | # ( a ( or e ) ) ( d ( or d ) )
273 | # ( b ( or e ) ) ( d ( and a ) )
274 | # ( f ( and b ) ) ( e ( and f ) )
275 | # ( not e ) ( f ( or b ) )
276 | # ( c ( or a ) ) ( b ( and d ) )
277 | # ( e ( and f ) ) ( b ( and b ) )
278 | = ( b ( and e ) ) ( e ( and b ) )
279 | # ( a ( or a ) ) ( c ( or e ) )
280 | > ( f ( or c ) ) ( f ( and f ) )
281 | # ( d ( and f ) ) ( b ( or b ) )
282 | # ( b ( and b ) ) ( c ( or e ) )
283 | # ( b ( or b ) ) ( a ( or c ) )
284 | # ( c ( or c ) ) ( d ( and a ) )
285 | # ( a ( or e ) ) ( a ( or f ) )
286 | # ( b ( or e ) ) ( c ( or f ) )
287 | > ( d ( or e ) ) ( d ( or d ) )
288 | # ( f ( and c ) ) ( a ( and d ) )
289 | # ( d ( or e ) ) ( d ( or f ) )
290 | # ( a ( or a ) ) ( b ( or d ) )
291 | # ( f ( or e ) ) ( d ( or e ) )
292 | < ( e ( and f ) ) ( e ( or e ) )
293 | # ( d ( or a ) ) ( f ( or d ) )
294 | # ( e ( or e ) ) ( c ( or b ) )
295 | = ( f ( and f ) ) ( f ( or f ) )
296 | # ( e ( and e ) ) ( d ( and d ) )
297 | > ( f ( or e ) ) ( e ( and a ) )
298 | # ( d ( or a ) ) ( f ( and b ) )
299 | # ( d ( and b ) ) ( d ( and f ) )
300 | # ( f ( or c ) ) ( e ( or a ) )
301 | # ( b ( or d ) ) ( b ( or c ) )
302 | # ( c ( or b ) ) ( d ( or e ) )
303 | # ( d ( and a ) ) ( e ( and a ) )
304 | < ( a ( and c ) ) ( a ( or a ) )
305 | > ( f ( or b ) ) ( f ( and e ) )
306 | # ( f ( or c ) ) ( e ( or f ) )
307 | # ( d ( and d ) ) ( f ( or a ) )
308 | # ( d ( or c ) ) ( e ( and e ) )
309 | < ( a ( and f ) ) ( a ( or e ) )
310 | # ( a ( or d ) ) ( c ( and f ) )
311 | # ( d ( or d ) ) ( c ( and c ) )
312 | > ( e ( or c ) ) ( b ( and c ) )
313 | # ( a ( and d ) ) ( b ( and f ) )
314 | < ( b ( and a ) ) ( b ( and b ) )
315 | = ( f ( or a ) ) ( f ( or a ) )
316 | # ( c ( or c ) ) ( b ( and a ) )
317 | > ( c ( or c ) ) ( f ( and c ) )
318 | > ( d ( and d ) ) ( d ( and a ) )
319 | < ( b ( and c ) ) ( c ( or a ) )
320 | > ( e ( or d ) ) ( d ( and b ) )
321 | # ( f ( and a ) ) ( c ( or c ) )
322 | # ( c ( or f ) ) ( d ( or b ) )
323 | > ( f ( or a ) ) ( f ( and c ) )
324 | > ( f ( or f ) ) ( c ( and f ) )
325 | # ( e ( and f ) ) ( d ( and e ) )
326 | # ( d ( and a ) ) ( d ( and f ) )
327 | # ( f ( or a ) ) ( e ( and e ) )
328 | # ( d ( and c ) ) ( c ( and a ) )
329 | < ( e ( and b ) ) ( e ( or e ) )
330 | # ( c ( or d ) ) ( f ( or c ) )
331 | > ( d ( or c ) ) ( a ( and d ) )
332 | # ( d ( or e ) ) ( a ( or c ) )
333 | # ( e ( and d ) ) ( f ( and f ) )
334 | > ( d ( or c ) ) ( e ( and c ) )
335 | # ( c ( and a ) ) ( e ( and e ) )
336 | # ( b ( or a ) ) ( f ( and f ) )
337 | = ( e ( or d ) ) ( d ( or e ) )
338 | < ( b ( and d ) ) ( e ( or d ) )
339 | # ( f ( or d ) ) ( a ( or d ) )
340 | > ( c ( and c ) ) ( c ( and d ) )
341 | < ( c ( and e ) ) ( b ( or e ) )
342 | # ( a ( or d ) ) ( f ( or b ) )
343 | # ( c ( or c ) ) ( b ( or b ) )
344 | # ( b ( and a ) ) ( e ( or d ) )
345 | # ( a ( or e ) ) ( c ( or b ) )
346 | = ( e ( or c ) ) ( c ( or e ) )
347 | # ( c ( and a ) ) ( b ( or b ) )
348 | # ( a ( or b ) ) ( a ( or e ) )
349 | # ( b ( or d ) ) ( a ( or d ) )
350 | # ( a ( or b ) ) ( c ( and d ) )
351 | # ( d ( or e ) ) ( f ( or e ) )
352 | > ( b ( or f ) ) ( a ( and f ) )
353 | < ( a ( and c ) ) ( e ( or c ) )
354 | < ( b ( or b ) ) ( b ( or e ) )
355 | < ( a ( and f ) ) ( d ( or f ) )
356 | # ( f ( or a ) ) ( d ( and e ) )
357 | = ( a ( or a ) ) ( a ( and a ) )
358 | # ( b ( or b ) ) ( e ( or c ) )
359 | # ( e ( and b ) ) ( c ( and b ) )
360 | < ( f ( and f ) ) ( f ( or e ) )
361 | > ( f ( or d ) ) ( f ( and e ) )
362 | # ( e ( or b ) ) ( b ( or a ) )
363 | # ( a ( and c ) ) ( f ( and d ) )
364 | > ( f ( or b ) ) ( b ( and e ) )
365 | > ( a ( or e ) ) ( c ( and a ) )
366 | > ( d ( or c ) ) ( d ( or d ) )
367 | # ( b ( or d ) ) ( c ( and e ) )
368 | > ( c ( or d ) ) ( d ( and d ) )
369 | # ( d ( and d ) ) ( a ( or a ) )
370 | # ( f ( and a ) ) ( e ( and c ) )
371 | = ( c ( and c ) ) ( c ( and c ) )
372 | # ( f ( and f ) ) ( d ( or d ) )
373 | = ( c ( or c ) ) ( c ( or c ) )
374 | < ( b ( and b ) ) ( d ( or b ) )
375 | > ( b ( or e ) ) ( e ( and c ) )
376 | < ( d ( and b ) ) ( d ( or f ) )
377 | # ( f ( or d ) ) ( a ( and a ) )
378 | # ( f ( or f ) ) ( d ( or e ) )
379 | < ( b ( and c ) ) ( b ( or c ) )
380 | > ( e ( or c ) ) ( e ( and e ) )
381 | < ( b ( and b ) ) ( a ( or b ) )
382 | # ( d ( and d ) ) ( f ( or f ) )
383 | > ( f ( or b ) ) ( c ( and b ) )
384 | < ( a ( and c ) ) ( c ( and c ) )
385 | # ( b ( and d ) ) ( b ( and f ) )
386 | # ( f ( and c ) ) ( a ( or d ) )
387 | < ( d ( and e ) ) ( e ( or a ) )
388 | # ( c ( or b ) ) ( d ( or c ) )
389 | = ( f ( and a ) ) ( a ( and f ) )
390 | < ( b ( and d ) ) ( d ( and d ) )
391 | # ( a ( and f ) ) ( c ( and a ) )
392 | < ( e ( or e ) ) ( e ( or c ) )
393 | < ( c ( and d ) ) ( d ( or c ) )
394 | # ( d ( or d ) ) ( c ( or a ) )
395 | # ( e ( or f ) ) ( a ( or c ) )
396 | # ( c ( or e ) ) ( d ( and f ) )
397 | # ( a ( or b ) ) ( a ( or d ) )
398 | < ( b ( and e ) ) ( e ( and e ) )
399 | # ( e ( and e ) ) ( c ( and a ) )
400 | # ( a ( or a ) ) ( e ( or b ) )
401 | > ( b ( or d ) ) ( e ( and b ) )
402 | # ( a ( and b ) ) ( b ( and f ) )
403 | # ( a ( or f ) ) ( b ( and e ) )
404 | # ( c ( or e ) ) ( c ( or d ) )
405 | # ( a ( and a ) ) ( f ( and f ) )
406 | # ( a ( and d ) ) ( f ( and a ) )
407 | # ( c ( and c ) ) ( b ( or d ) )
408 | > ( c ( or b ) ) ( a ( and c ) )
409 | # ( b ( or d ) ) ( c ( or c ) )
410 | # ( a ( or e ) ) ( d ( and d ) )
411 |
--------------------------------------------------------------------------------
/data/propositionallogic/train0:
--------------------------------------------------------------------------------
1 | = a a
2 | # d f
3 | = f f
4 | # b e
5 | # d c
6 | # c f
7 | # e a
8 | # f c
9 | # e c
10 | # e f
11 | # e d
12 | # f d
13 | # a d
14 | # c a
15 | # c b
16 | # c d
17 | # f b
18 | # d e
19 | # d a
20 | # d b
21 | # a f
22 | = d d
23 | # f a
24 | # a c
25 | # b a
26 | # b c
27 | = e e
28 | # b f
29 | = c c
30 | # e b
31 |
--------------------------------------------------------------------------------
/listops.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim.lr_scheduler as lr_scheduler
9 |
10 | import ordered_memory
11 | from utils.hinton import plot
12 | from utils.listops_data import load_data_and_embeddings, LABEL_MAP, PADDING_TOKEN, get_batch
13 | from utils.utils import build_tree, char2tree, evalb
14 |
15 |
16 | class ListOpsModel(nn.Module):
17 | def __init__(self, args):
18 | super(ListOpsModel, self).__init__()
19 |
20 | self.args = args
21 | self.padding_idx = args.padding_idx
22 | self.embedding = nn.Embedding(args.ntoken, args.ninp,
23 | padding_idx=self.padding_idx)
24 |
25 | self.encoder = ordered_memory.OrderedMemory(args.ninp, args.nhid, args.nslot,
26 | dropout=args.dropout, dropoutm=args.dropoutm,
27 | bidirection=args.bidirection)
28 |
29 | self.mlp = nn.Sequential(
30 | nn.Dropout(args.dropouto),
31 | nn.Linear(args.nhid * 2 if args.bidirection else args.nhid, args.nout),
32 | )
33 |
34 | self.drop_input = nn.Dropout(args.dropouti)
35 | self.drop_output = nn.Dropout(args.dropouto)
36 | self.cost = nn.CrossEntropyLoss()
37 |
38 | def forward(self, input):
39 | mask = (input != self.padding_idx).bool()
40 |
41 | emb = self.embedding(input)
42 | emb.transpose_(0, 1)
43 |
44 | mask.transpose_(0, 1)
45 | emb = self.drop_input(emb)
46 | output = self.encoder(emb, mask, output_last=True)
47 | output = self.mlp(output)
48 | return output
49 |
50 | def set_pretrained_embeddings(self, ext_embeddings, ext_word_to_index, word_to_index, finetune=False):
51 | assert hasattr(self, 'embedding')
52 | embeddings = self.embedding.weight.data.cpu().numpy()
53 | for word, index in word_to_index.items():
54 | if word in ext_word_to_index:
55 | embeddings[index] = ext_embeddings[ext_word_to_index[word]]
56 | embeddings = torch.from_numpy(embeddings).to(self.embedding.weight.device)
57 | self.embedding.weight.data.set_(embeddings)
58 | self.embedding.weight.requires_grad = finetune
59 |
60 |
61 | def model_save(fn):
62 | if args.philly:
63 | fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn)
64 | with open(fn, 'wb') as f:
65 | # torch.save([model, optimizer], f)
66 | torch.save({
67 | 'epoch': epoch,
68 | 'model_state_dict': model.state_dict(),
69 | 'optimizer_state_dict': optimizer.state_dict(),
70 | 'loss': test_loss
71 | }, f)
72 |
73 |
74 | def model_load(fn):
75 | global model, optimizer
76 | if args.philly:
77 | fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn)
78 | with open(fn, 'rb') as f:
79 | checkpoint = torch.load(f)
80 | model.load_state_dict(checkpoint['model_state_dict'])
81 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
82 | epoch = checkpoint['epoch']
83 | test_loss = checkpoint['loss']
84 |
85 |
86 | ###############################################################################
87 | # Training code
88 | ###############################################################################
89 |
90 | @torch.no_grad()
91 | def evaluate(data_iter):
92 | # Turn on evaluation mode which disables dropout.
93 | model.eval()
94 |
95 |     total_loss = 0  # accumulates correct-prediction counts; returned as accuracy
96 | total_datapoints = 0
97 | for batch, data in enumerate(data_iter):
98 | batch_data = get_batch(data)
99 | X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch_data
100 |
101 | X_batch = torch.from_numpy(X_batch).long().to('cuda' if args.cuda else 'cpu')
102 | y_batch = torch.from_numpy(y_batch).long().to('cuda' if args.cuda else 'cpu')
103 |
104 | lin_output = model(X_batch)
105 | count = y_batch.shape[0]
106 | total_loss += torch.sum(
107 | torch.argmax(lin_output, dim=1) == y_batch
108 | ).float().data
109 | total_datapoints += count
110 |
111 | return total_loss.item() / total_datapoints
112 |
113 |
114 | def train():
115 | # Turn on training mode which enables dropout.
116 | model.train()
117 |
118 | total_loss = 0
119 | total_acc = 0
120 | start_time = time.time()
121 | for batch, data in enumerate(training_data_iter):
122 | # print(data)
123 | # batch_data = get_batch(next(training_data_iter))
124 | data, n_batches = data
125 | batch_data = get_batch(data)
126 | X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch_data
127 |
128 | X_batch = torch.from_numpy(X_batch).long().to('cuda' if args.cuda else 'cpu')
129 | y_batch = torch.from_numpy(y_batch).long().to('cuda' if args.cuda else 'cpu')
130 |
131 | optimizer.zero_grad()
132 |
133 | lin_output = model(X_batch)
134 | loss = model.cost(lin_output, y_batch)
135 | acc = torch.mean(
136 | (torch.argmax(lin_output, dim=1) == y_batch).float())
137 | loss.backward()
138 |
139 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
140 | if args.clip:
141 | torch.nn.utils.clip_grad_norm_(params, args.clip)
142 | optimizer.step()
143 |
144 | total_loss += loss.detach().data
145 | total_acc += acc.detach().data
146 | if batch % args.log_interval == 0 and batch > 0:
147 | elapsed = time.time() - start_time
148 | print(
149 | '| epoch {:3d} '
150 | '| {:5d}/ {:5d} batches '
151 | '| lr {:05.5f} | ms/batch {:5.2f} '
152 | '| loss {:5.2f} | acc {:0.2f}'.format(
153 | epoch,
154 | batch,
155 | n_batches,
156 | optimizer.param_groups[0]['lr'],
157 | elapsed * 1000 / args.log_interval,
158 | total_loss.item() / args.log_interval,
159 | total_acc.item() / args.log_interval))
160 | total_loss = 0
161 | total_acc = 0
162 | start_time = time.time()
163 | ###
164 | batch += 1
165 | if batch >= n_batches:
166 | break
167 |
168 |
169 | @torch.no_grad()
170 | def generate_parse(data_iter):
171 | model.eval()
172 | np.set_printoptions(precision=2, suppress=True, linewidth=5000, formatter={'float': '{: 0.2f}'.format})
173 | pred_tree_list = []
174 | targ_tree_list = []
175 | crop_count = 0
176 | total_count = 0
177 | for batch, data in enumerate(data_iter):
178 | sents = data['tokens']
179 | X = np.array([vocabulary[t] for t in data['tokens']])
180 |         # if len(sents) > 100: # In case Evalb fails to process very long sequences
181 | # continue
182 |
183 | X_batch = torch.from_numpy(X).long().to('cuda' if args.cuda else 'cpu')
184 |
185 | model(X_batch[None, :])
186 | probs = model.encoder.probs
187 | distance = torch.argmax(probs, dim=-1)
188 | distance[0] = args.nslot
189 |
190 | total_count += 1
191 | depth = distance[:, 0]
192 | probs_k = probs[:, 0, :].data.cpu().numpy()
193 |
194 | try:
195 | parse_tree = build_tree(depth, sents)
196 | sen_tree = char2tree(data['sentence'].split())
197 | except:
198 | crop_count += 1
199 | print('Unbalanced datapoint!')
200 | continue
201 |
202 | pred_tree_list.append(parse_tree)
203 | targ_tree_list.append(sen_tree)
204 |
205 | if batch % 100 > 0:
206 | continue
207 | print(batch)
208 | for i in range(len(sents)):
209 | if sents[i] == '':
210 | break
211 | print('%20s\t%2.2f\t%s' % (sents[i], depth[i], plot(probs_k[i], 1)))
212 | print(parse_tree)
213 | print(sen_tree)
214 | print()
215 |
216 | print('Cropped: %d, Total: %d' % (crop_count, total_count))
217 | evalb(pred_tree_list, targ_tree_list, evalb_path="../EVALB")
218 |
219 |
220 | if __name__ == "__main__":
221 | parser = argparse.ArgumentParser(description='')
222 |
223 | parser.add_argument('--data', type=str, default='./data/listops',
224 | help='location of the data corpus')
225 |     parser.add_argument('--bidirection', action='store_true',
226 |                         help='use bidirectional model')
227 |     parser.add_argument('--seq_len', type=int, default=100,
228 |                         help='max sequence length')
229 |     parser.add_argument('--seq_len_test', type=int, default=1000,
230 |                         help='max sequence length at test time')
231 |     parser.add_argument('--no-smart-batching', action='store_true',  # reverse
232 |                         help='disable length-based smart batching')
233 |     parser.add_argument('--no-use_peano', action='store_true',
234 |                         help='disable the peano preprocessing')
235 | parser.add_argument('--emsize', type=int, default=128,
236 | help='size of word embeddings')
237 | parser.add_argument('--nhid', type=int, default=128,
238 | help='number of hidden units per layer')
239 | parser.add_argument('--nslot', type=int, default=21,
240 | help='number of memory slots')
241 | parser.add_argument('--lr', type=float, default=0.001,
242 | help='initial learning rate')
243 | parser.add_argument('--clip', type=float, default=1.,
244 | help='gradient clipping')
245 | parser.add_argument('--epochs', type=int, default=50,
246 | help='upper epoch limit')
247 | parser.add_argument('--batch_size', type=int, default=128, metavar='N',
248 | help='batch size')
249 |     parser.add_argument('--batch_size_test', type=int, default=128, metavar='N',
250 |                         help='test batch size')
251 | parser.add_argument('--dropout', type=float, default=0.1,
252 | help='dropout applied to layers (0 = no dropout)')
253 | parser.add_argument('--dropoutm', type=float, default=0.3,
254 | help='dropout applied to memory (0 = no dropout)')
255 | parser.add_argument('--dropouti', type=float, default=0.1,
256 | help='dropout for input embedding layers (0 = no dropout)')
257 |     parser.add_argument('--dropouto', type=float, default=0.2,
258 |                         help='dropout applied to output layers (0 = no dropout)')
259 | parser.add_argument('--seed', type=int, default=1111,
260 | help='random seed')
261 | parser.add_argument('--cuda', action='store_true',
262 | help='use CUDA')
263 | parser.add_argument('--log-interval', type=int, default=100, metavar='N',
264 | help='report interval')
265 | parser.add_argument('--test-only', action='store_true',
266 | help='Test only')
267 |
268 | randomhash = ''.join(str(time.time()).split('.'))
269 | parser.add_argument('--name', type=str, default=randomhash + '.pt',
270 | help='exp name')
271 | parser.add_argument('--wdecay', type=float, default=1.2e-6,
272 | help='weight decay applied to all weights')
273 | parser.add_argument('--std', action='store_true',
274 | help='use standard LSTM')
275 | parser.add_argument('--philly', action='store_true',
276 | help='Use philly cluster')
277 | args = parser.parse_args()
278 |
279 | args.smart_batching = not args.no_smart_batching
280 | args.use_peano = not args.no_use_peano
281 |
282 | # Set the random seed manually for reproducibility.
283 | np.random.seed(args.seed)
284 | torch.manual_seed(args.seed)
285 | if torch.cuda.is_available():
286 | if not args.cuda:
287 | print("WARNING: You have a CUDA device, so you should probably run with --cuda")
288 | else:
289 | torch.cuda.manual_seed(args.seed)
290 |
291 | ###############################################################################
292 | # Load data
293 | ###############################################################################
294 | train_data_path = os.path.join(args.data, 'train_d20s.tsv')
295 | test_data_path = os.path.join(args.data, 'test_d20s.tsv')
296 | vocabulary, initial_embeddings, training_data_iter, eval_iterator, training_data_length, raw_eval_data \
297 | = load_data_and_embeddings(args, train_data_path, test_data_path)
298 | dictionary = {}
299 | for k, v in vocabulary.items():
300 | dictionary[v] = k
301 | # make iterator for splits
302 | vocab_size = len(vocabulary)
303 | num_classes = len(set(LABEL_MAP.values()))
304 | args.__dict__.update({'ntoken': vocab_size,
305 | 'ninp': args.emsize,
306 | 'nout': num_classes,
307 | 'padding_idx': vocabulary[PADDING_TOKEN]})
308 |
309 | model = ListOpsModel(args)
310 |
311 | if args.cuda:
312 | model = model.cuda()
313 |
314 | params = list(model.parameters())
315 | total_params = sum(x.size()[0] * x.size()[1]
316 | if len(x.size()) > 1 else x.size()[0]
317 | for x in params if x.size())
318 | total_params_sanity = sum(np.prod(x.size()) for x in model.parameters())
319 | assert total_params == total_params_sanity
320 | print("TOTAL PARAMS: %d" % sum(np.prod(x.size()) for x in model.parameters()))
321 | print('Args:', args)
322 | print('Model total parameters:', total_params)
323 |
324 |     # Ensure the optimizer is optimizing params, i.e. all of the model's weights (the CrossEntropyLoss criterion has no parameters of its own)
325 | optimizer = torch.optim.Adam(params,
326 | lr=args.lr,
327 | betas=(0, 0.999),
328 | eps=1e-9,
329 | weight_decay=args.wdecay)
330 |
331 | if not args.test_only:
332 | # Loop over epochs.
333 | lr = args.lr
334 | stored_loss = 0.
335 |
336 | # At any point you can hit Ctrl + C to break out of training early.
337 | try:
338 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', 0.5, patience=2, threshold=0)
339 | for epoch in range(1, args.epochs + 1):
340 | epoch_start_time = time.time()
341 | train()
342 | test_loss = evaluate(eval_iterator)
343 |
344 | print('-' * 89)
345 | print(
346 | '| end of epoch {:3d} '
347 | '| time: {:5.2f}s '
348 | '| test acc: {:.4f} '
349 | '|\n'.format(
350 | epoch,
351 | (time.time() - epoch_start_time),
352 | test_loss
353 | )
354 | )
355 |
356 | if test_loss > stored_loss:
357 | model_save(args.name)
358 | print('Saving model (new best validation)')
359 | stored_loss = test_loss
360 | print('-' * 89)
361 |
362 | scheduler.step(test_loss)
363 | except KeyboardInterrupt:
364 | print('-' * 89)
365 | print('Exiting from training early')
366 |
367 | model_load(args.name)
368 | generate_parse(raw_eval_data)
369 | test_loss = evaluate(eval_iterator)
370 | data = {'args': args.__dict__,
371 | 'parameters': total_params,
372 | 'test_acc': test_loss}
373 | print('-' * 89)
374 | print(
375 | '| test acc: {:.4f} '
376 | '|\n'.format(
377 | test_loss
378 | )
379 | )
380 |
--------------------------------------------------------------------------------
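The model class is importable on its own since the training logic sits under the `__main__` guard. A minimal shape check with made-up hyperparameters, assuming the repository's utils package and its dependencies import cleanly:

```
# Minimal shape check for ListOpsModel; all hyperparameters are illustrative.
import argparse
import torch
from listops import ListOpsModel  # training code is under the __main__ guard

args = argparse.Namespace(ntoken=20, ninp=8, nhid=8, nslot=4, nout=10,
                          padding_idx=0, dropout=0.0, dropoutm=0.0,
                          dropouti=0.0, dropouto=0.0, bidirection=False)
model = ListOpsModel(args)
x = torch.randint(1, 20, (2, 11))  # (batch, seq_len); 0 is reserved for padding
print(model(x).shape)              # torch.Size([2, 10]): one logit per class
```
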
/ordered_memory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 |
7 | class Distribution(nn.Module):
8 | def __init__(self, nslot, hidden_size, dropout):
9 | super(Distribution, self).__init__()
10 |
11 | self.query = nn.Sequential(
12 | nn.Dropout(dropout),
13 | nn.Linear(hidden_size, hidden_size),
14 | nn.LayerNorm(hidden_size),
15 | )
16 |
17 | self.key = nn.Sequential(
18 | nn.Dropout(dropout),
19 | nn.Linear(hidden_size, hidden_size),
20 | nn.LayerNorm(hidden_size),
21 | )
22 |
23 | self.beta = nn.Sequential(
24 | nn.ReLU(),
25 | nn.Dropout(dropout),
26 | nn.Linear(hidden_size, 1),
27 | )
28 |
29 | self.hidden_size = hidden_size
30 |
31 | def init_p(self, bsz, nslot):
32 | return None
33 |
34 | @staticmethod
35 | def process_softmax(beta, prev_p):
36 | if prev_p is None:
37 | return torch.zeros_like(beta), torch.ones_like(beta), torch.zeros_like(beta)
38 |
39 | beta_normalized = beta - beta.max(dim=-1)[0][:, None]
40 | x = torch.exp(beta_normalized)
41 |
42 |         prev_cp = torch.cumsum(prev_p, dim=1)  # cumulative mass of the previous step's slot distribution
43 |         mask = prev_cp[:, 1:]
44 |         mask = mask.masked_fill(mask < 1e-5, 0.)  # zero out slots below the previous split point
45 |         mask = F.pad(mask, (0, 1), value=1)  # the last slot is always available
46 |
47 |         x_masked = x * mask  # restrict the new distribution to the allowed slots
48 |
49 |         p = F.normalize(x_masked, p=1)  # renormalize into a distribution over slots
50 |         cp = torch.cumsum(p, dim=1)  # forward cumulative distribution
51 |         rcp = torch.cumsum(p.flip([1]), dim=1).flip([1])  # reverse cumulative distribution
52 | return cp, rcp, p
53 |
54 | def forward(self, in_val, prev_out_M, prev_p):
55 | query = self.query(in_val)
56 | key = self.key(prev_out_M)
57 | beta = self.beta(query[:, None, :] + key).squeeze(dim=2)
58 | beta = beta / math.sqrt(self.hidden_size)
59 | cp, rcp, p = self.process_softmax(beta, prev_p)
60 | return cp, rcp, p
61 |
62 |
63 | class Cell(nn.Module):
64 | def __init__(self, hidden_size, dropout, activation=None):
65 | super(Cell, self).__init__()
66 | self.hidden_size = hidden_size
67 | self.cell_hidden_size = 4 * hidden_size
68 |
69 | self.input_t = nn.Sequential(
70 | nn.Dropout(dropout),
71 | nn.Linear(hidden_size * 2, self.cell_hidden_size),
72 | nn.ReLU(),
73 | nn.Dropout(dropout),
74 | nn.Linear(self.cell_hidden_size, hidden_size * 4),
75 | )
76 |
77 | self.gates = nn.Sequential(
78 | nn.Sigmoid(),
79 | )
80 |
81 | assert activation is not None
82 | self.activation = activation
83 |
84 | self.drop = nn.Dropout(dropout)
85 |
86 | def forward(self, vi, hi):
87 | input = torch.cat([vi, hi], dim=-1)
88 |
89 | g_input, cell = self.input_t(input).split(
90 | (self.hidden_size * 3, self.hidden_size),
91 | dim=-1
92 | )
93 |
94 | gates = self.gates(g_input)
95 | vg, hg, cg = gates.chunk(3, dim=1)
96 | output = self.activation(vg * vi + hg * hi + cg * cell)
97 | return output
98 |
99 |
100 | class OrderedMemoryRecurrent(nn.Module):
101 | def __init__(self, input_size, slot_size, nslot,
102 | dropout=0.2, dropoutm=0.2):
103 | super(OrderedMemoryRecurrent, self).__init__()
104 |
105 | self.activation = nn.LayerNorm(slot_size)
106 | self.input_projection = nn.Sequential(
107 | nn.Linear(input_size, slot_size),
108 | self.activation
109 | )
110 |
111 | self.distribution = Distribution(nslot, slot_size, dropoutm)
112 |
113 | self.cell = Cell(slot_size, dropout, activation=self.activation)
114 |
115 | self.nslot = nslot
116 | self.slot_size = slot_size
117 | self.input_size = input_size
118 |
119 | def init_hidden(self, bsz):
120 | weight = next(self.parameters()).data
121 | zeros = weight.new(bsz, self.nslot, self.slot_size).zero_()
122 | p = self.distribution.init_p(bsz, self.nslot)
123 | return (zeros, zeros, p)
124 |
125 | def omr_step(self, in_val, prev_M, prev_out_M, prev_p):
126 | batch_size, nslot, slot_size = prev_M.size()
127 | _batch_size, slot_size = in_val.size()
128 |
129 | assert self.slot_size == slot_size
130 | assert self.nslot == nslot
131 | assert batch_size == _batch_size
132 |
133 | cp, rcp, p = self.distribution(in_val, prev_out_M, prev_p)
134 |
135 |         curr_M = prev_M * (1 - rcp)[:, :, None] + prev_out_M * rcp[:, :, None]  # blend old memory with candidate writes per the split distribution
136 |
137 | M_list = []
138 | h = in_val
139 | for i in range(nslot):
140 |             if i == nslot - 1 or cp[:, i+1].max() > 0:  # skip the cell when no probability mass reaches this slot
141 |                 h = self.cell(h, curr_M[:, i, :])
142 |             h = in_val * (1 - cp)[:, i, None] + h * cp[:, i, None]  # gate between the raw input and the composed state
143 | M_list.append(h)
144 | out_M = torch.stack(M_list, dim=1)
145 |
146 | output = out_M[:, -1]
147 | return output, curr_M, out_M, p
148 |
149 | def forward(self, X, hidden, mask=None):
150 | prev_M, prev_memory_output, prev_p = hidden
151 | output_list = []
152 | p_list = []
153 | X_projected = self.input_projection(X)
154 | if mask is not None:
155 | padded = ~mask
156 | for t in range(X_projected.size(0)):
157 | output, prev_M, prev_memory_output, prev_p = self.omr_step(
158 | X_projected[t], prev_M, prev_memory_output, prev_p)
159 | if mask is not None:
160 | padded_1 = padded[t, :, None]
161 | padded_2 = padded[t, :, None, None]
162 | output = output.masked_fill(padded_1, 0.)
163 | prev_p = prev_p.masked_fill(padded_1, 0.)
164 | prev_M = prev_M.masked_fill(padded_2, 0.)
165 | prev_memory_output = prev_memory_output.masked_fill(padded_2, 0.)
166 | output_list.append(output)
167 | p_list.append(prev_p)
168 |
169 | output = torch.stack(output_list)
170 | probs = torch.stack(p_list)
171 |
172 | return (output,
173 | probs,
174 | (prev_M, prev_memory_output, prev_p))
175 |
176 |
177 | class OrderedMemory(nn.Module):
178 | def __init__(self, input_size, slot_size,
179 | nslot, dropout=0.2, dropoutm=0.1,
180 | bidirection=False):
181 | super(OrderedMemory, self).__init__()
182 |
183 | self.OM_forward = OrderedMemoryRecurrent(input_size, slot_size, nslot,
184 | dropout=dropout, dropoutm=dropoutm)
185 | if bidirection:
186 | self.OM_backward = OrderedMemoryRecurrent(input_size, slot_size, nslot,
187 | dropout=dropout, dropoutm=dropoutm)
188 |
189 | self.bidirection = bidirection
190 |
191 | def init_hidden(self, bsz):
192 | return self.OM_forward.init_hidden(bsz)
193 |
194 | def forward(self, X, mask, output_last=False):
195 | bsz = X.size(1)
196 | lengths = mask.sum(0)
197 | init_hidden = self.init_hidden(bsz)
198 |
199 | output_list = []
200 | prob_list = []
201 |
202 | om_output_forward, prob_forward, _ = self.OM_forward(X, init_hidden, mask)
203 | if output_last:
204 | output_list.append(om_output_forward[-1])
205 | else:
206 | output_list.append(om_output_forward[lengths - 1, torch.arange(bsz).long()])
207 | prob_list.append(prob_forward)
208 |
209 | if self.bidirection:
210 | om_output_backward, prob_backward, _ = self.OM_backward(X.flip([0]), init_hidden, mask.flip([0]))
211 | output_list.append(om_output_backward[-1])
212 | prob_list.append(prob_backward.flip([0]))
213 |
214 | output = torch.cat(output_list, dim=-1)
215 | self.probs = prob_list[0]
216 |
217 | return output
218 |
--------------------------------------------------------------------------------
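ordered_memory.py depends only on PyTorch, so a smoke test is direct; forward() expects time-major input and a boolean mask:

```
# Smoke test for OrderedMemory with time-major input (seq_len, batch, input).
import torch
from ordered_memory import OrderedMemory

om = OrderedMemory(input_size=8, slot_size=8, nslot=4,
                   dropout=0.0, dropoutm=0.0)
X = torch.randn(5, 2, 8)                      # (seq_len, batch, input_size)
mask = torch.ones(5, 2, dtype=torch.bool)     # no padding in this toy batch
out = om(X, mask)
print(out.shape)        # torch.Size([2, 8]): last-token summary per sequence
print(om.probs.shape)   # torch.Size([5, 2, 4]): per-step slot distributions
```
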
/proplog.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import argparse
4 | import os
5 | import time
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim.lr_scheduler as lr_scheduler
11 |
12 | import ordered_memory
13 | from utils.utils import build_tree, evalb, remove_bracket, char2tree
14 | from utils.hinton import plot
15 |
16 | # from orion.client import report_results
17 |
18 |
19 | if __name__ == "__main__":
20 |     parser = argparse.ArgumentParser(description='Ordered Memory on the propositional logic task')
21 | parser.add_argument('--data', type=str, default='data/propositionallogic/',
22 | help='location of the data corpus')
23 | parser.add_argument('--max-op', type=int, default=6,
24 |                         help='maximum number of operators')
25 | parser.add_argument('--emsize', type=int, default=200,
26 | help='size of word embeddings')
27 | parser.add_argument('--nhid', type=int, default=200,
28 | help='number of hidden units per layer')
29 | parser.add_argument('--nslot', type=int, default=12,
30 | help='number of memory slots')
31 | parser.add_argument('--lr', type=float, default=0.001,
32 | help='initial learning rate')
33 | parser.add_argument('--clip', type=float, default=1.,
34 | help='gradient clipping')
35 | parser.add_argument('--epochs', type=int, default=50,
36 | help='upper epoch limit')
37 | parser.add_argument('--batch_size', type=int, default=128, metavar='N',
38 | help='batch size')
39 | parser.add_argument('--dropout', type=float, default=0.2,
40 | help='dropout applied to layers (0 = no dropout)')
41 |     parser.add_argument('--dropouti', type=float, default=0.1,
42 |                         help='dropout for input embedding layers (0 = no dropout)')
43 |     parser.add_argument('--dropouto', type=float, default=0.3,
44 |                         help='dropout applied to output layers (0 = no dropout)')
45 |     parser.add_argument('--dropoutm', type=float, default=0.2,
46 |                         help='dropout applied to memory (0 = no dropout)')
47 | parser.add_argument('--seed', type=int, default=1111,
48 | help='random seed')
49 | parser.add_argument('--cuda', action='store_true',
50 | help='use CUDA')
51 | parser.add_argument('--test-only', action='store_true',
52 | help='Test only')
53 |
54 | parser.add_argument('--log-interval', type=int, default=200, metavar='N',
55 | help='report interval')
56 | randomhash = ''.join(str(time.time()).split('.'))
57 | parser.add_argument('--save', type=str, default=randomhash + '.pt',
58 | help='path to save the final model')
59 | parser.add_argument('--wdecay', type=float, default=1.2e-6,
60 | help='weight decay applied to all weights')
61 | parser.add_argument('--std', action='store_true',
62 | help='use standard LSTM')
63 | parser.add_argument('--philly', action='store_true',
64 | help='Use philly cluster')
65 | args = parser.parse_args()
66 | args.tied = True
67 |
68 | # Set the random seed manually for reproducibility.
69 | np.random.seed(args.seed)
70 | torch.manual_seed(args.seed)
71 | if torch.cuda.is_available():
72 | if not args.cuda:
73 | print("WARNING: You have a CUDA device, so you should probably run with --cuda")
74 | else:
75 | torch.cuda.manual_seed(args.seed)
76 |
77 |
78 | ###############################################################################
79 | # Load data
80 | ###############################################################################
81 |
82 |
83 | def model_save(fn):
84 | if args.philly:
85 | fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn)
86 | with open(fn, 'wb') as f:
87 | # torch.save([model, optimizer], f)
88 | torch.save({
89 | 'epoch': epoch,
90 | 'model_state_dict': model.state_dict(),
91 | 'optimizer_state_dict': optimizer.state_dict(),
92 | 'loss': val_loss
93 | }, f)
94 |
95 |
96 | def model_load(fn):
97 |     global model, optimizer, epoch, val_loss
98 | if args.philly:
99 | fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn)
100 | with open(fn, 'rb') as f:
101 | checkpoint = torch.load(f)
102 | model.load_state_dict(checkpoint['model_state_dict'])
103 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
104 | epoch = checkpoint['epoch']
105 | val_loss = checkpoint['loss']
106 |
107 |
108 | class LogicInference(object):
109 | def __init__(self, datapath='data/propositionallogic/', maxn=12):
110 | """maxn=0 indicates variable expression length."""
111 | self.num2char = ['(', ')',
112 | 'a', 'b', 'c', 'd', 'e', 'f',
113 | 'or', 'and', 'not']
114 | self.char2num = {self.num2char[i]: i
115 | for i in range(len(self.num2char))}
116 |
117 | self.num2lbl = list('<>=^|#v')
118 | self.lbl2num = {self.num2lbl[i]: i
119 | for i in range(len(self.num2lbl))}
120 |
121 | self.train_set, self.valid_set, self.test_set = [], [], []
122 | counter = 0
123 | for i in range(maxn):
124 | itrainexample = self._readfile(os.path.join(datapath, "train" + str(i)))
125 | for e in itrainexample:
126 | counter += 1
127 | if counter % 10 == 0:
128 | self.valid_set.append(e)
129 | else:
130 | self.train_set.append(e)
131 | # self.train_set = self.train_set + itrainexample
132 |
133 | for i in range(13):
134 | itestexample = self._readfile(os.path.join(datapath, "test" + str(i)))
135 | self.test_set.append(itestexample)
136 |
137 |     def _readfile(self, filepath):
138 |         examples = []
139 |         with open(filepath, 'r') as f:
140 |             for line in f:
141 |                 relation, p1, p2 = line.strip().split('\t')
142 |                 p1 = p1.split()
143 |                 p2 = p2.split()
144 |                 examples.append((self.lbl2num[relation],
145 |                                  [self.char2num[w] for w in p1],
146 |                                  [self.char2num[w] for w in p2]))
147 |         return examples
148 |
149 | def stream(self, dataset, batch_size, shuffle=False, pad=None):
150 | if pad is None:
151 | pad = len(self.num2char)
152 | import random
153 | import math
154 | batch_count = int(math.ceil(len(dataset) / float(batch_size)))
155 |
156 | def shuffle_stream():
157 | if shuffle:
158 | random.shuffle(dataset)
159 | for i in range(batch_count):
160 | yield dataset[i * batch_size: (i + 1) * batch_size]
161 |
162 | def arrayify(stream, pad):
163 | for batch in stream:
164 | batch_lbls = np.array([x[0] for x in batch], dtype=np.int64)
165 | batch_sent = [x[1] for x in batch] + [x[2] for x in batch]
166 | max_len = max(len(s) for s in batch_sent)
167 | batch_idxs = np.full((max_len, len(batch_sent)), pad,
168 | dtype=np.int64)
169 | for i in range(len(batch_sent)):
170 | sentence = batch_sent[i]
171 | batch_idxs[:len(sentence), i] = sentence
172 | yield batch_idxs, batch_lbls
173 |
174 | stream = shuffle_stream()
175 | stream = arrayify(stream, pad)
176 | return stream
177 |
178 |
179 | corpus = LogicInference(datapath=args.data, maxn=args.max_op + 1)
180 |
181 | ###############################################################################
182 | # Build the model
183 | ###############################################################################
184 | ###
185 | # if args.resume:
186 | # print('Resuming model ...')
187 | # model_load(args.resume)
188 | # optimizer.param_groups[0]['lr'] = args.lr
189 | # model.dropouti, model.dropouth, model.dropout, args.dropoute = args.dropouti, args.dropouth, args.dropout, args.dropoute
190 | # if args.wdrop:
191 | # for rnn in model.rnn.cells:
192 | # rnn.hh.dropout = args.wdrop
193 | ###
194 |
195 | ntokens = len(corpus.num2char) + 1  # +1 for the padding index
196 | nlbls = len(corpus.num2lbl)
197 |
198 |
199 | class Classifier(nn.Module):
200 | """Container module with an encoder, a recurrent module, and a decoder."""
201 |
202 | def __init__(self, ntoken, ninp, nhid, nout, nslot, dropout, dropouti, dropouto, dropoutm):
203 | super(Classifier, self).__init__()
204 |
205 | self.padding_idx = ntoken - 1
206 | self.embedding = nn.Embedding(ntoken, ninp,
207 | padding_idx=self.padding_idx)
208 |
209 | self.encoder = ordered_memory.OrderedMemory(ninp, nhid, nslot,
210 | dropout=dropout, dropoutm=dropoutm)
211 |
212 | self.mlp = nn.Sequential(
213 | nn.Dropout(dropouto),
214 | nn.Linear(4 * nhid, nhid),
215 | nn.ELU(),
216 | nn.Dropout(dropouto),
217 | nn.Linear(nhid, nout),
218 | )
219 |
220 | self.drop = nn.Dropout(dropouti)
221 |
222 | self.cost = nn.CrossEntropyLoss()
223 | self.init_weights()
224 |
225 | def init_weights(self):
226 | initrange = 0.1
227 | self.embedding.weight.data.uniform_(-initrange, initrange)
228 |
229 | def forward(self, input):
230 | batch_size = input.size(1)
231 | mask = (input != self.padding_idx)
232 | emb = self.drop(self.embedding(input))
233 | output = self.encoder(emb, mask)
234 | self.probs = self.encoder.probs
235 |
236 | clause_1 = output[:batch_size // 2]
237 | clause_2 = output[batch_size // 2:]
238 | output = self.mlp(torch.cat([clause_1, clause_2,
239 | clause_1 * clause_2,
240 | torch.abs(clause_1 - clause_2)], dim=1))
241 | return output
242 |
243 |
244 | if __name__ == "__main__":
245 | model = Classifier(
246 | ntoken=ntokens,
247 | ninp=args.emsize,
248 | nhid=args.nhid,
249 | nout=nlbls,
250 | nslot=args.nslot,
251 | dropout=args.dropout,
252 | dropouti=args.dropouti,
253 | dropouto=args.dropouto,
254 | dropoutm=args.dropoutm,
255 | )
256 |
257 | if args.cuda:
258 | model = model.cuda()
259 | # model = model.half()
260 |
261 | params = list(model.parameters())
262 | total_params = sum(x.size()[0] * x.size()[1]
263 | if len(x.size()) > 1 else x.size()[0]
264 | for x in params if x.size())
265 | print('Args:', args)
266 | print('Model total parameters:', total_params)
267 |
268 | optimizer = torch.optim.Adam(params,
269 | lr=args.lr,
270 | betas=(0, 0.999),
271 | eps=1e-9,
272 | weight_decay=args.wdecay)
273 |
274 |
275 | ###############################################################################
276 | # Training code
277 | ###############################################################################
278 |
279 | @torch.no_grad()
280 | def valid():
281 | # Turn on evaluation mode which disables dropout.
282 | model.eval()
283 | total_loss = 0
284 | total_datapoints = 0
285 | for sents, lbls in corpus.stream(corpus.valid_set, args.batch_size * 2):
286 | count = lbls.shape[0]
287 | sents = torch.from_numpy(sents)
288 | lbls = torch.from_numpy(lbls)
289 | if args.cuda:
290 | sents = sents.cuda()
291 | lbls = lbls.cuda()
292 | lin_output = model(sents)
293 | total_loss += torch.sum(
294 | torch.argmax(lin_output, dim=1) == lbls
295 | ).float().data
296 | total_datapoints += count
297 | accs = total_loss.item() / total_datapoints
298 | return accs
299 |
300 | @torch.no_grad()
301 | def evaluate():
302 | # Turn on evaluation mode which disables dropout.
303 | model.eval()
304 | model.encoder.OM_forward.nslot = args.nslot * 2
305 |
306 | accs = []
307 | global_loss = 0
308 | global_datapoints = 0
309 | for l in range(13):
310 | total_loss = 0
311 | total_datapoints = 0
312 | for sents, lbls in corpus.stream(corpus.test_set[l], args.batch_size * 2):
313 | count = lbls.shape[0]
314 | sents = torch.from_numpy(sents)
315 | lbls = torch.from_numpy(lbls)
316 | if args.cuda:
317 | sents = sents.cuda()
318 | lbls = lbls.cuda()
319 | lin_output = model(sents)
320 | total_loss += torch.sum(
321 | torch.argmax(lin_output, dim=1) == lbls
322 | ).float().data.item()
323 | total_datapoints += count
324 | accs.append(total_loss / total_datapoints if total_datapoints > 0 else -1)
325 | global_loss += total_loss
326 | global_datapoints += total_datapoints
327 |
328 | accs.append(global_loss / global_datapoints)
329 |
330 | model.encoder.OM_forward.nslot = args.nslot
331 | return accs
332 |
333 |
334 | def train():
335 | # Turn on training mode which enables dropout.
336 | total_loss = 0
337 | total_acc = 0
338 | start_time = time.time()
339 | batch = 0
340 | for sents, lbls in corpus.stream(corpus.train_set, args.batch_size,
341 | shuffle=True):
342 | sents = torch.from_numpy(sents)
343 | lbls = torch.from_numpy(lbls)
344 | if args.cuda:
345 | sents = sents.cuda()
346 | lbls = lbls.cuda()
347 |
348 | model.train()
349 | optimizer.zero_grad()
350 |
351 | lin_output = model(sents)
352 | loss = model.cost(lin_output, lbls)
353 | acc = torch.mean(
354 | (torch.argmax(lin_output, dim=1) == lbls).float())
355 | loss.backward()
356 |
357 |         # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
358 | if args.clip:
359 | torch.nn.utils.clip_grad_norm_(params, args.clip)
360 | optimizer.step()
361 |
362 | total_loss += loss.detach().data
363 | total_acc += acc.detach().data
364 | if batch % args.log_interval == 0 and batch > 0:
365 | elapsed = time.time() - start_time
366 | print(
367 | '| epoch {:3d} '
368 | '| lr {:05.5f} | ms/batch {:5.2f} '
369 | '| loss {:5.2f} | acc {:0.2f}'.format(
370 | epoch,
371 | optimizer.param_groups[0]['lr'],
372 | elapsed * 1000 / args.log_interval,
373 | total_loss.item() / args.log_interval,
374 | total_acc.item() / args.log_interval))
375 | total_loss = 0
376 | total_acc = 0
377 | start_time = time.time()
378 | ###
379 | batch += 1
380 |
381 | @torch.no_grad()
382 | def genparse():
383 | model.eval()
384 | model.encoder.OM_forward.nslot = args.nslot * 2
385 |
386 | np.set_printoptions(precision=2, suppress=True, linewidth=5000, formatter={'float': '{: 0.2f}'.format})
387 | pred_tree_list = []
388 | targ_tree_list = []
389 | for l in range(13):
390 | for sents, lbls in corpus.stream(corpus.test_set[l], args.batch_size * 2):
391 | sents = torch.from_numpy(sents)
392 | if args.cuda:
393 | sents = sents.cuda()
394 |
395 | # hidden = model.encoder.init_hidden(sents.size(1))
396 | # emb = model.drop(model.embedding(sents))
397 | # raw_output, probs_batch, _ = model.encoder(emb, hidden)
398 |
399 | model(sents)
400 | probs_batch = model.probs
401 |
402 | for i in range(sents.size(1)):
403 | probs = probs_batch[:, i].view(-1, args.nslot * 2)
404 | # self.distance = (torch.cumsum(self.probs, dim=-1) < 0.5).sum(dim=-1)
405 |
406 | distance = torch.argmax(probs, dim=-1)
407 | distance[0] = args.nslot * 2
408 | sen = [corpus.num2char[x]
409 | for x in sents[:, i] if x < len(corpus.num2char)]
410 | if len(sen) < 2:
411 | continue
412 | depth = distance[:len(sen)]
413 | probs = probs.data.cpu().numpy()
414 |
415 | parse_tree = remove_bracket(build_tree(depth, sen))
416 | sen_tree = char2tree(sen)
417 |
418 | pred_tree_list.append(parse_tree)
419 | targ_tree_list.append(sen_tree)
420 |
421 |                     if np.random.randint(0, 100) > 0:  # print roughly 1% of examples
422 | continue
423 | print()
424 | for i in range(len(sen)):
425 | print('%5s\t%2.2f\t%s' % (sen[i], distance[i], plot(probs[i], 1)))
426 |
427 | print(' '.join(sen))
428 | # print(sen_tree)
429 | print(parse_tree)
430 | print('')
431 |
432 | evalb(pred_tree_list, targ_tree_list)
433 |
434 | model.encoder.OM_forward.nslot = args.nslot
435 |
436 |
437 | if __name__ == "__main__":
438 | # Loop over epochs.
439 | if not args.test_only:
440 | lr = args.lr
441 | stored_loss = 0.
442 |
443 | # At any point you can hit Ctrl + C to break out of training early.
444 | try:
445 |             # Reduce the learning rate when validation accuracy stops improving.
446 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', 0.5, patience=2, threshold=0)
447 | for epoch in range(1, args.epochs + 1):
448 | epoch_start_time = time.time()
449 | train()
450 | val_loss = valid()
451 | test_loss = evaluate()
452 |
453 | print('-' * 89)
454 | print(
455 | '| end of epoch {:3d} '
456 | '| time: {:5.2f}s '
457 | '| valid acc: {:.2f} '
458 | '|\n'.format(
459 | epoch,
460 | (time.time() - epoch_start_time),
461 | val_loss
462 | ),
463 |                 ', '.join('{:0.2f}'.format(v) for v in test_loss)
464 | )
465 |
466 | if val_loss > stored_loss:
467 | model_save(args.save)
468 | print('Saving model (new best validation)')
469 | stored_loss = val_loss
470 | print('-' * 89)
471 |
472 | scheduler.step(val_loss)
473 | except KeyboardInterrupt:
474 | print('-' * 89)
475 | print('Exiting from training early')
476 | # Load the best saved model.
477 | model_load(args.save)
478 |
479 | genparse()
480 |
481 | test_loss = evaluate()
482 | val_loss = valid()
483 | print('-' * 89)
484 | print(
485 | '| valid acc: {:.2f} '
486 | '|\n'.format(
487 | val_loss
488 | ),
489 |         ', '.join('{:0.2f}'.format(v) for v in test_loss)
490 | )
491 |
492 | # report_results([dict(
493 | # name='val_loss',
494 | # type='objective',
495 | # value=val_loss)])
496 |
497 |
--------------------------------------------------------------------------------
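For reference, a hedged sketch of how LogicInference maps one dataset record to
indices, and why Classifier.forward can split the encoder output at
batch_size // 2: stream() stacks all first clauses and then all second clauses
along the batch axis, so each half of a batch holds one side of every pair.
The record string below is illustrative, not taken from the data files.

num2char = ['(', ')', 'a', 'b', 'c', 'd', 'e', 'f', 'or', 'and', 'not']
char2num = {c: i for i, c in enumerate(num2char)}
lbl2num = {c: i for i, c in enumerate('<>=^|#v')}

line = '<\t( a ( and b ) )\t( not a )'   # tab-separated: relation, clause 1, clause 2
relation, p1, p2 = line.split('\t')
example = (lbl2num[relation],
           [char2num[w] for w in p1.split()],
           [char2num[w] for w in p2.split()])
print(example)   # (0, [0, 2, 0, 9, 3, 1, 1], [0, 10, 2, 1])

The two clause encodings are then compared through the usual
[h1; h2; h1 * h2; |h1 - h2|] feature vector, which is why the MLP's input
width is 4 * nhid.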
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.13.3
2 | matplotlib>=3.0.3
3 | python_gflags
4 | nltk
5 | spacy
6 | torch
7 | torchtext
--------------------------------------------------------------------------------
/sentiment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import argparse
4 | import os
5 | import time
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim.lr_scheduler as lr_scheduler
11 |
12 | import ordered_memory
13 | from utils.hinton import plot
14 | from utils.locked_dropout import LockedDropout
15 | from utils.utils import build_tree
16 |
17 |
18 | class SSTClassifier(nn.Module):
19 | def __init__(self, args, elmo=None, glove=None):
20 | super(SSTClassifier, self).__init__()
21 |
22 | self.args = args
23 | self.padding_idx = args.padding_idx
24 |
25 | ninp = args.emsize
26 | if ninp > 0:
27 | self.embedding = nn.Embedding(
28 | args.ntoken, ninp,
29 | padding_idx=self.padding_idx,
30 | )
31 | else:
32 | self.embedding = None
33 |
34 | self.elmo = elmo
35 | if elmo is not None:
36 | ninp += 1024
37 |
38 | self.glove = glove
39 | if glove is not None:
40 | ninp += 300
41 |
42 | self.lockdrop = LockedDropout(dropout=args.dropouti)
43 |
44 | self.encoder = ordered_memory.OrderedMemory(ninp, args.nhid, args.nslot,
45 | dropout=args.dropout, dropoutm=args.dropoutm,
46 | bidirection=args.bidirection)
47 |
48 | self.mlp = nn.Sequential(
49 | nn.Dropout(args.dropouto),
50 | nn.Linear(args.nhid, args.nhid),
51 | nn.ReLU(),
52 | nn.Dropout(args.dropouto),
53 | nn.Linear(args.nhid, args.nout),
54 | )
55 |
56 | self.drop_input = nn.Dropout(args.dropouti)
57 | self.cost = nn.CrossEntropyLoss()
58 |
59 | def forward(self, input):
60 | if self.elmo is not None:
61 | input_elmo, input_torchtext = input
62 | else:
63 | input_torchtext = input
64 | mask = (input_torchtext != self.padding_idx)
65 |
66 | emb_list = []
67 | if self.embedding is not None:
68 | emb_torchtext = self.embedding(input_torchtext)
69 | emb_list.append(emb_torchtext)
70 | if self.glove is not None:
71 | emb_glove = self.glove(input_torchtext).detach()
72 | emb_list.append(emb_glove)
73 | if self.elmo is not None:
74 | emb_elmo = self.elmo(input_elmo)
75 | assert (mask.long() == emb_elmo['mask']).all()
76 | emb_elmo = emb_elmo['elmo_representations'][0]
77 | emb_list.append(emb_elmo)
78 | emb = torch.cat(emb_list, dim=-1)
79 |
80 | emb.transpose_(0, 1)
81 | mask.transpose_(0, 1)
82 | emb = self.lockdrop(emb)
83 |
84 | output = self.encoder(emb, mask)
85 |
86 | output = self.mlp(output)
87 |
88 | return output
89 |
90 | @staticmethod
91 | def load_model(input_path):
92 | state = torch.load(input_path)
93 | print('Loading model from %s' % input_path)
94 | model = SSTClassifier(state['args'])
95 | model.load_state_dict(state['state_dict'])
96 | return model
97 |
98 | def save(self, output_path):
99 | state = dict(args=self.args,
100 | state_dict=self.state_dict())
101 | torch.save(state, output_path)
102 |
103 | def set_pretrained_embeddings(self, ext_embeddings, ext_word_to_index, word_to_index, finetune=False):
104 | assert hasattr(self, 'embedding')
105 | embeddings = self.embedding.weight.data.cpu().numpy()
106 | for word, index in word_to_index.items():
107 | if word in ext_word_to_index:
108 | embeddings[index] = ext_embeddings[ext_word_to_index[word]]
109 | embeddings = torch.from_numpy(embeddings).to(self.embedding.weight.device)
110 | self.embedding.weight.data.set_(embeddings)
111 | self.embedding.weight.requires_grad = finetune
112 |
113 |
114 | def model_save(fn):
115 | if args.philly:
116 | fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn)
117 | with open(fn, 'wb') as f:
118 | # torch.save([model, optimizer], f)
119 | torch.save({
120 | 'epoch': epoch,
121 | 'model_state_dict': model.state_dict(),
122 | 'optimizer_state_dict': optimizer.state_dict(),
123 | 'loss': val_loss
124 | }, f)
125 |
126 |
127 | def model_load(fn):
128 |     global model, optimizer, epoch, val_loss
129 | if args.philly:
130 | fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn)
131 | with open(fn, 'rb') as f:
132 | checkpoint = torch.load(f)
133 | model.load_state_dict(checkpoint['model_state_dict'])
134 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
135 | epoch = checkpoint['epoch']
136 | val_loss = checkpoint['loss']
137 |
138 |
139 | ###############################################################################
140 | # Training code
141 | ###############################################################################
142 |
143 |
144 | def evaluate(data_iter):
145 | # Turn on evaluation mode which disables dropout.
146 | model.eval()
147 | total_loss = 0
148 | total_datapoints = 0
149 | for batch, data in enumerate(data_iter):
150 | sents = data.text
151 | lbls = data.label
152 | count = lbls.shape[0]
153 | lin_output = model(sents)
154 | total_loss += torch.sum(
155 | torch.argmax(lin_output, dim=1) == lbls
156 | ).float().data
157 | total_datapoints += count
158 |
159 | return total_loss.item() / total_datapoints
160 |
161 |
162 | def train():
163 | # Turn on training mode which enables dropout.
164 | total_loss = 0
165 | total_acc = 0
166 | start_time = time.time()
167 | for batch, data in enumerate(train_iter):
168 | sents = data.text
169 | lbls = data.label
170 |
171 | model.train()
172 | optimizer.zero_grad()
173 |
174 | lin_output = model(sents)
175 | loss = model.cost(lin_output, lbls)
176 | acc = torch.mean(
177 | (torch.argmax(lin_output, dim=1) == lbls).float())
178 | loss.backward()
179 |
180 |         # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
181 | if args.clip:
182 | torch.nn.utils.clip_grad_norm_(params, args.clip)
183 | optimizer.step()
184 |
185 | total_loss += loss.detach().data
186 | total_acc += acc.detach().data
187 | if batch % args.log_interval == 0 and batch > 0:
188 | elapsed = time.time() - start_time
189 | print(
190 | '| epoch {:3d} '
191 | '| {:5d}/{:5d} batches '
192 | '| lr {:05.5f} | ms/batch {:5.2f} '
193 | '| loss {:5.2f} | acc {:0.2f}'.format(
194 | epoch,
195 | batch, len(train_iter),
196 | optimizer.param_groups[0]['lr'],
197 | elapsed * 1000 / args.log_interval,
198 | total_loss.item() / args.log_interval,
199 | total_acc.item() / args.log_interval))
200 | total_loss = 0
201 | total_acc = 0
202 | start_time = time.time()
203 | ###
204 | batch += 1
205 |
206 |
207 | def generate_parse():
208 | from nltk import Tree
209 | from utils import evalb
210 |
211 | batch = []
212 | pred_tree_list = []
213 | targ_tree_list = []
214 |
215 | def process_batch():
216 | nonlocal batch, pred_tree_list, targ_tree_list
217 | idx = TEXT.process([example['sents'] for example in batch], device=hidden[0].device)
218 |
219 | model(idx)
220 |
221 | probs = model.encoder.probs
222 | distance = torch.argmax(probs, dim=-1)
223 | distance[0] = args.nslot
224 | probs = probs.data.cpu().numpy()
225 |
226 | for i, example in enumerate(batch):
227 | sents = example['sents']
228 | sents_tree = example['sents_tree']
229 | depth = distance[:, i]
230 |
231 | parse_tree = build_tree(depth, sents)
232 |
233 | if len(sents) <= 100:
234 | pred_tree_list.append(parse_tree)
235 | targ_tree_list.append(sents_tree)
236 |
237 | if i == 0:
238 | for j in range(len(sents)):
239 | print('%20s\t%2.2f\t%s' % (sents[j], depth[j], plot(probs[j, i], 1.)))
240 | print(parse_tree)
241 | print(sents_tree)
242 | print('-' * 80)
243 |
244 | batch = []
245 |
246 | np.set_printoptions(precision=2, suppress=True, linewidth=5000, formatter={'float': '{: 0.2f}'.format})
247 |
248 | model.eval()
249 | hidden = model.encoder.init_hidden(1)
250 |
251 |     with open('.data/sst/trees/dev.txt', 'r') as fin:
252 |         for line in fin:
253 |             line = line.lower()
254 |             sents_tree = Tree.fromstring(line)
255 |             sents = sents_tree.leaves()
256 |             batch.append({'sents_tree': sents_tree, 'sents': sents})
257 | 
258 |             if len(batch) == 16:
259 |                 process_batch()
260 |
261 | if len(batch) > 0:
262 | process_batch()
263 |
264 | evalb(pred_tree_list, targ_tree_list, evalb_path='./EVALB')
265 |
266 |
267 | if __name__ == "__main__":
268 |     parser = argparse.ArgumentParser(description='Ordered Memory sentiment classifier for SST')
269 |
270 |     parser.add_argument('--fine-grained', action='store_true',
271 |                         help='use fine-grained (5-class) labels')
272 |     parser.add_argument('--subtrees', action='store_true',
273 |                         help='train on all subtrees')
274 |     parser.add_argument('--glove', action='store_true',
275 |                         help='use pretrained GloVe embeddings')
276 |     parser.add_argument('--elmo', action='store_true',
277 |                         help='use pretrained ELMo embeddings')
278 |     parser.add_argument('--bidirection', action='store_true',
279 |                         help='use a bidirectional model')
280 | parser.add_argument('--emsize', type=int, default=0,
281 | help='size of word embeddings')
282 | parser.add_argument('--nhid', type=int, default=300,
283 | help='number of hidden units per layer')
284 | parser.add_argument('--nslot', type=int, default=15,
285 | help='number of memory slots')
286 | parser.add_argument('--lr', type=float, default=0.001,
287 | help='initial learning rate')
288 | parser.add_argument('--clip', type=float, default=1.,
289 | help='gradient clipping')
290 | parser.add_argument('--epochs', type=int, default=50,
291 | help='upper epoch limit')
292 | parser.add_argument('--batch_size', type=int, default=128, metavar='N',
293 | help='batch size')
294 | parser.add_argument('--dropout', type=float, default=0.2,
295 | help='dropout applied to layers (0 = no dropout)')
296 | parser.add_argument('--dropouti', type=float, default=0.3,
297 | help='dropout for input embedding layers (0 = no dropout)')
298 | parser.add_argument('--dropouto', type=float, default=0.4,
299 |                         help='dropout applied to output layers (0 = no dropout)')
300 | parser.add_argument('--dropoutm', type=float, default=0.2,
301 | help='dropout applied to memory (0 = no dropout)')
302 | parser.add_argument('--attention', type=str, default='softmax',
303 | help='attention method')
304 | parser.add_argument('--seed', type=int, default=1111,
305 | help='random seed')
306 | parser.add_argument('--cuda', action='store_true',
307 | help='use CUDA')
308 | parser.add_argument('--log-interval', type=int, default=100, metavar='N',
309 | help='report interval')
310 | parser.add_argument('--test-only', action='store_true',
311 | help='Test only')
312 |
313 | randomhash = ''.join(str(time.time()).split('.'))
314 | parser.add_argument('--name', type=str, default=randomhash + '.pt',
315 | help='exp name')
316 | parser.add_argument('--wdecay', type=float, default=1.2e-6,
317 | help='weight decay applied to all weights')
318 | parser.add_argument('--std', action='store_true',
319 | help='use standard LSTM')
320 | parser.add_argument('--philly', action='store_true',
321 | help='Use philly cluster')
322 | parser.add_argument('--resume', action='store_true',
323 | help='resume from checkpoint')
324 | args = parser.parse_args()
325 |
326 | # Set the random seed manually for reproducibility.
327 | np.random.seed(args.seed)
328 | torch.manual_seed(args.seed)
329 | if torch.cuda.is_available():
330 | if not args.cuda:
331 | print("WARNING: You have a CUDA device, so you should probably run with --cuda")
332 | else:
333 | torch.cuda.manual_seed(args.seed)
334 |
335 | ###############################################################################
336 | # Load data
337 | ###############################################################################
338 | from torchtext import data
339 | from torchtext import datasets
340 | from torchtext.vocab import GloVe
341 |
342 | # set up fields
343 | TEXT = data.Field(lower=True, include_lengths=False, batch_first=True)
344 | LABEL = data.Field(sequential=False, unk_token=None)
345 |
346 | # make splits for data
347 | filter_pred = None
348 | if not args.fine_grained:
349 | filter_pred = lambda ex: ex.label != 'neutral'
350 | train_set, dev_set, test_set = datasets.SST.splits(
351 | TEXT, LABEL,
352 | train_subtrees=args.subtrees,
353 | fine_grained=args.fine_grained,
354 | filter_pred=filter_pred
355 | )
356 |
357 | # build the vocabulary
358 | if args.glove:
359 | TEXT.build_vocab(train_set, dev_set, test_set, min_freq=1, vectors=GloVe(name='840B', dim=300))
360 | else:
361 | TEXT.build_vocab(train_set, min_freq=2)
362 | LABEL.build_vocab(train_set)
363 |
364 | # make iterator for splits
365 | train_iter, dev_iter, test_iter = data.BucketIterator.splits(
366 | (train_set, dev_set, test_set),
367 | batch_size=args.batch_size,
368 | device='cuda' if args.cuda else 'cpu'
369 | )
370 |
371 | args.__dict__.update({'ntoken': len(TEXT.vocab),
372 | 'nout': len(LABEL.vocab),
373 |                           'padding_idx': TEXT.vocab.stoi['<pad>']})
374 |
375 | if args.elmo:
376 | from allennlp.modules.elmo import Elmo, batch_to_ids
377 |
378 | options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
379 | weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
380 |
381 | elmo = Elmo(options_file, weight_file, 1, requires_grad=False, dropout=0)
382 |
383 | torchtext_process = TEXT.process
384 |
385 |
386 | def elmo_process(batch, device):
387 | elmo_tensor = batch_to_ids(batch)
388 | elmo_tensor = elmo_tensor.to(device=device)
389 | torchtext_tensor = torchtext_process(batch, device)
390 | return (elmo_tensor, torchtext_tensor)
391 |
392 |
393 | TEXT.process = elmo_process
394 | else:
395 | elmo = None
396 |
397 | if args.glove:
398 | glove = torch.nn.Embedding(args.ntoken, 300, _weight=TEXT.vocab.vectors)
399 | else:
400 | glove = None
401 |
402 | model = SSTClassifier(args, elmo=elmo, glove=glove)
403 |
404 | if args.resume:
405 | model_load(args.name)
406 |
407 | if args.cuda:
408 | model = model.cuda()
409 |
410 | params = list(model.parameters())
411 | total_params = sum(x.size()[0] * x.size()[1]
412 | if len(x.size()) > 1 else x.size()[0]
413 | for x in params if x.size())
414 | print('Args:', args)
415 | print('Model total parameters:', total_params)
416 |
417 | optimizer = torch.optim.Adam(params,
418 | lr=args.lr,
419 | betas=(0, 0.999),
420 | eps=1e-9,
421 | weight_decay=args.wdecay)
422 |
423 | if not args.test_only:
424 | # Loop over epochs.
425 | lr = args.lr
426 | stored_loss = 0.
427 |
428 | # At any point you can hit Ctrl + C to break out of training early.
429 | try:
430 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', 0.5, patience=2, threshold=0)
431 | for epoch in range(1, args.epochs + 1):
432 | epoch_start_time = time.time()
433 | train()
434 | val_loss = evaluate(dev_iter)
435 | test_loss = evaluate(test_iter)
436 |
437 | print('-' * 89)
438 | print(
439 | '| end of epoch {:3d} '
440 | '| time: {:5.2f}s '
441 | '| valid acc: {:.4f} '
442 | '| test acc: {:.4f} '
443 | '|\n'.format(
444 | epoch,
445 | (time.time() - epoch_start_time),
446 | val_loss,
447 | test_loss
448 | )
449 | )
450 |
451 | if val_loss > stored_loss:
452 | model_save(args.name)
453 | print('Saving model (new best validation)')
454 | stored_loss = val_loss
455 | print('-' * 89)
456 |
457 | scheduler.step(val_loss)
458 |
459 | except KeyboardInterrupt:
460 | print('-' * 89)
461 | print('Exiting from training early')
462 |
463 | model_load(args.name)
464 | test_loss = evaluate(test_iter)
465 | val_loss = evaluate(dev_iter)
466 |
467 | try:
468 | generate_parse()
469 |     except Exception:
470 | print('Unable to parse')
471 |
472 | print('-' * 89)
473 | print(
474 | '| valid acc: {:.4f} '
475 | '| test acc: {:.4f} '
476 | '|\n'.format(
477 | val_loss,
478 | test_loss
479 | )
480 | )
481 |
--------------------------------------------------------------------------------
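A small sketch of the input-width arithmetic in SSTClassifier.__init__: the
learned embedding (--emsize, possibly 0), an optional 300-d GloVe vector, and
an optional 1024-d ELMo vector are concatenated feature-wise before the
OrderedMemory encoder. The helper below is hypothetical, written only to make
the arithmetic explicit:

def encoder_input_width(emsize, use_glove, use_elmo):
    # mirrors the ninp bookkeeping in SSTClassifier.__init__
    ninp = emsize
    if use_elmo:
        ninp += 1024   # ELMo representation size
    if use_glove:
        ninp += 300    # GloVe 840B dimension
    return ninp

assert encoder_input_width(0, use_glove=True, use_elmo=True) == 1324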
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yikangshen/Ordered-Memory/2a0d3a22fb70216993a7faf25de3515f42e10431/utils/__init__.py
--------------------------------------------------------------------------------
/utils/hinton.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import numpy as np
3 | chars = [" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"]
4 |
5 |
6 | class BarHack(str):
7 |
8 | def __str__(self):
9 | return self.internal
10 |
11 | def __len__(self):
12 | return 1
13 |
14 |
15 | def plot(arr, max_val=None):
16 | if max_val is None:
17 | max_arr = arr
18 | max_val = max(abs(np.max(max_arr)), abs(np.min(max_arr)))
19 |
20 | opts = np.get_printoptions()
21 | np.set_printoptions(edgeitems=500)
22 | fig = np.array2string(arr,
23 | formatter={
24 | 'float_kind': lambda x: visual(x, max_val),
25 | 'int_kind': lambda x: visual(x, max_val)},
26 | max_line_width=5000
27 | )
28 | np.set_printoptions(**opts)
29 |
30 | return fig
31 |
32 |
33 | def visual(val, max_val):
34 | val = np.clip(val, 0, max_val)
35 | if abs(val) == max_val:
36 | step = len(chars) - 1
37 | else:
38 | step = int(abs(float(val) / max_val) * len(chars))
39 | colourstart = ""
40 | colourend = ""
41 | if val < 0:
42 | colourstart, colourend = '\033[90m', '\033[0m'
43 | return colourstart + chars[step] + colourend
44 |
--------------------------------------------------------------------------------
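plot() is what renders the per-slot attention bars printed by genparse in
proplog.py and generate_parse in sentiment.py: each value in [0, max_val] is
mapped to one of nine block glyphs. A hedged usage sketch (assuming the repo
is on the import path; the exact glyphs may differ slightly):

import numpy as np
from utils.hinton import plot

probs = np.array([0.05, 0.25, 0.70])
print(plot(probs, 1.0))   # a bracketed row of bar glyphs, roughly "[  ▂ ▆]"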
/utils/listops_data.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import random
3 | import itertools
4 | import time
5 | import sys
6 |
7 | import numpy as np
8 |
9 | from utils.utils import ConvertBinaryBracketedSeq
10 | NUMBERS = list(range(10))
11 | PADDING_TOKEN = "_PAD"
12 | UNK_TOKEN = "_"
13 | SENTENCE_PADDING_SYMBOL = 0
14 |
15 | FIXED_VOCABULARY = {str(x): i + 1 for i, x in enumerate(NUMBERS)}
16 | FIXED_VOCABULARY.update({
17 | PADDING_TOKEN: 0,
18 | "[MIN": len(FIXED_VOCABULARY) + 1,
19 | "[MAX": len(FIXED_VOCABULARY) + 2,
20 | "[FIRST": len(FIXED_VOCABULARY) + 3,
21 | "[LAST": len(FIXED_VOCABULARY) + 4,
22 | "[MED": len(FIXED_VOCABULARY) + 5,
23 | "[SM": len(FIXED_VOCABULARY) + 6,
24 | "[PM": len(FIXED_VOCABULARY) + 7,
25 | "[FLSUM": len(FIXED_VOCABULARY) + 8,
26 | "]": len(FIXED_VOCABULARY) + 9
27 | })
28 | assert len(set(FIXED_VOCABULARY.values())) == len(list(FIXED_VOCABULARY.values()))
29 |
30 |
31 | SENTENCE_PAIR_DATA = False
32 | OUTPUTS = list(range(10))
33 | LABEL_MAP = {str(x): i for i, x in enumerate(OUTPUTS)}
34 |
35 | Node = namedtuple('Node', 'tag span')
36 |
37 |
38 | def spans(transitions, tokens=None):
39 | n = (len(transitions) + 1) // 2
40 | stack = []
41 | buf = [Node("leaf", (l, r)) for l, r in zip(list(range(n)), list(range(1, n + 1)))]
42 | buf = list(reversed(buf))
43 |
44 | nodes = []
45 | reduced = [False] * n
46 |
47 | def SHIFT(item):
48 | nodes.append(item)
49 | return item
50 |
51 | def REDUCE(l, r):
52 | tag = None
53 | i = r.span[1] - 1
54 | if tokens is not None and tokens[i] == ']' and not reduced[i]:
55 | reduced[i] = True
56 | tag = "struct"
57 | new_stack_item = Node(tag=tag, span=(l.span[0], r.span[1]))
58 | nodes.append(new_stack_item)
59 | return new_stack_item
60 |
61 | for t in transitions:
62 | if t == 0:
63 | stack.append(SHIFT(buf.pop()))
64 | elif t == 1:
65 | r, l = stack.pop(), stack.pop()
66 | stack.append(REDUCE(l, r))
67 |
68 | return nodes
69 |
70 | def PreprocessDataset(
71 | dataset,
72 | vocabulary,
73 | seq_length,
74 | eval_mode=False,
75 | sentence_pair_data=False,
76 | simple=True,
77 | allow_cropping=False,
78 | pad_from_left=True):
79 | dataset = TrimDataset(
80 | dataset,
81 | seq_length,
82 | eval_mode=eval_mode,
83 | sentence_pair_data=sentence_pair_data,
84 | logger=None,
85 | allow_cropping=allow_cropping)
86 | dataset = TokensToIDs(
87 | vocabulary,
88 | dataset,
89 | sentence_pair_data=sentence_pair_data)
90 |
91 | dataset = CropAndPadSimple(
92 | dataset,
93 | seq_length,
94 | logger=None,
95 | sentence_pair_data=sentence_pair_data,
96 | allow_cropping=allow_cropping,
97 | pad_from_left=pad_from_left)
98 |
99 | if sentence_pair_data:
100 | X = np.transpose(np.array([[example["premise_tokens"] for example in dataset],
101 | [example["hypothesis_tokens"] for example in dataset]],
102 | dtype=np.int32), (1, 2, 0))
103 | if simple:
104 | transitions = np.zeros((len(dataset), 2, 0))
105 | num_transitions = np.transpose(np.array(
106 | [[len(np.array(example["premise_tokens"]).nonzero()[0]) for example in dataset],
107 | [len(np.array(example["hypothesis_tokens"]).nonzero()[0]) for example in dataset]],
108 | dtype=np.int32), (1, 0))
109 | else:
110 | transitions = np.transpose(np.array([[example["premise_transitions"] for example in dataset],
111 | [example["hypothesis_transitions"] for example in dataset]],
112 | dtype=np.int32), (1, 2, 0))
113 | num_transitions = np.transpose(np.array(
114 | [[example["num_premise_transitions"] for example in dataset],
115 | [example["num_hypothesis_transitions"] for example in dataset]],
116 | dtype=np.int32), (1, 0))
117 | else:
118 | X = np.array([example["tokens"] for example in dataset],
119 | dtype=np.int32)
120 | if simple:
121 | transitions = np.zeros((len(dataset), 0))
122 | num_transitions = np.array(
123 | [len(np.array(example["tokens"]).nonzero()[0]) for example in dataset],
124 | dtype=np.int32)
125 | else:
126 | transitions = np.array([example["transitions"]
127 | for example in dataset], dtype=np.int32)
128 | num_transitions = np.array(
129 | [example["num_transitions"] for example in dataset],
130 | dtype=np.int32)
131 |
132 | y = np.array(
133 | [LABEL_MAP[example["label"]] for example in dataset],
134 | dtype=np.int32)
135 |
136 | # NP Array of Strings
137 | example_ids = np.array([example["example_id"] for example in dataset])
138 |
139 | return X, transitions, y, num_transitions, example_ids
140 |
141 |
142 | def load_data(path, lowercase=None, choose=lambda x: True, eval_mode=False):
143 | examples = []
144 | with open(path) as f:
145 | for example_id, line in enumerate(f):
146 | line = line.strip()
147 | label, seq = line.split('\t')
148 | if len(seq) <= 1:
149 | continue
150 |
151 | tokens, transitions = ConvertBinaryBracketedSeq(
152 | seq.split(' '))
153 |
154 | example = {}
155 | example["label"] = label
156 | example["sentence"] = seq
157 | example["tokens"] = tokens
158 | example["transitions"] = transitions
159 | example["example_id"] = str(example_id)
160 |
161 | examples.append(example)
162 | return examples
163 |
164 |
165 | def load_data_and_embeddings(args, training_data_path, eval_data_path):
166 | raw_training_data = load_data(training_data_path, None, eval_mode=False)
167 | raw_eval_data = load_data(eval_data_path, None, eval_mode=True)
168 | import copy
169 | raw_eval_data_copy = copy.deepcopy(raw_eval_data)
170 | # Prepare the vocabulary
171 | vocabulary = FIXED_VOCABULARY
172 | print("In fixed vocabulary mode. Training embeddings from scratch.")
173 | initial_embeddings = None
174 | # Trim dataset, convert token sequences to integer sequences, crop, and
175 | # pad.
176 | print("Preprocessing training data.")
177 | training_data = PreprocessDataset(
178 | raw_training_data,
179 | vocabulary,
180 |         args.seq_len,  # defaults to 100
181 | eval_mode=False,
182 | sentence_pair_data=SENTENCE_PAIR_DATA,
183 | simple=True,
184 | allow_cropping=False,
185 | pad_from_left=True)
186 | training_data_iter = MakeTrainingIterator(training_data, args.batch_size, args.smart_batching, args.use_peano, sentence_pair_data=SENTENCE_PAIR_DATA)
187 | training_data_length = len(training_data[0])
188 | # Preprocess eval sets.
189 | eval_data = PreprocessDataset(
190 | raw_eval_data,
191 | vocabulary,
192 | args.seq_len_test,
193 | eval_mode=True,
194 | sentence_pair_data=SENTENCE_PAIR_DATA,
195 |         simple=True,  # simple transitions, for RNN-style (non-SPINN) models
196 | allow_cropping=True,
197 | pad_from_left=True)
198 | eval_it = MakeEvalIterator(eval_data, args.batch_size_test, None,
199 | bucket_eval=True,
200 | shuffle=False)
201 |
202 | return vocabulary, initial_embeddings, training_data_iter, eval_it, training_data_length, raw_eval_data_copy
203 |
204 |
205 | def MakeTrainingIterator(
206 | sources,
207 | batch_size,
208 | smart_batches=True,
209 | use_peano=True,
210 | sentence_pair_data=True,
211 | pad_from_left=True):
212 | # Make an iterator that exposes a dataset as random minibatches.
213 |
214 | def get_key(num_transitions):
215 | if use_peano and sentence_pair_data:
216 | prem_len, hyp_len = num_transitions
217 |             key = (prem_len + hyp_len) * (prem_len + hyp_len + 1) // 2 + hyp_len  # Cantor/Peano pairing (Peano() is not defined in this repo)
218 | return key
219 | else:
220 | if not isinstance(num_transitions, list):
221 | num_transitions = [num_transitions]
222 | return max(num_transitions)
223 |
224 | def build_batches():
225 | dataset_size = len(sources[0])
226 | order = list(range(dataset_size))
227 | random.shuffle(order)
228 | order = np.array(order)
229 |
230 | num_splits = 10 # TODO: Should we be smarter about split size?
231 | order_limit = len(order) // num_splits * num_splits
232 | order = order[:order_limit]
233 | order_splits = np.split(order, num_splits)
234 | batches = []
235 |
236 | for split in order_splits:
237 | # Put indices into buckets based on example length.
238 | keys = []
239 | for i in split:
240 | num_transitions = sources[3][i]
241 | key = get_key(num_transitions)
242 | keys.append((i, key))
243 | keys = sorted(keys, key=lambda __key: __key[1])
244 |
245 | # Group indices from buckets into batches, so that
246 | # examples in each batch have similar length.
247 | batch = []
248 | for i, _ in keys:
249 | batch.append(i)
250 | if len(batch) == batch_size:
251 | batches.append(batch)
252 | batch = []
253 | return batches
254 |
255 | def batch_iter():
256 | batches = build_batches()
257 | num_batches = len(batches)
258 | idx = -1
259 | order = list(range(num_batches))
260 | random.shuffle(order)
261 |
262 | while True:
263 | idx += 1
264 | if idx >= num_batches:
265 | # Start another epoch.
266 | batches = build_batches()
267 | num_batches = len(batches)
268 | idx = 0
269 | order = list(range(num_batches))
270 | random.shuffle(order)
271 | batch_indices = batches[order[idx]]
272 | yield tuple(source[batch_indices] for source in sources), num_batches
273 |
274 | def data_iter():
275 | dataset_size = len(sources[0])
276 | start = -1 * batch_size
277 | order = list(range(dataset_size))
278 | random.shuffle(order)
279 |
280 | while True:
281 | start += batch_size
282 | if start > dataset_size - batch_size:
283 | # Start another epoch.
284 | start = 0
285 | random.shuffle(order)
286 | batch_indices = order[start:start + batch_size]
287 | yield tuple(source[batch_indices] for source in sources)
288 |
289 | train_iter = batch_iter if smart_batches else data_iter
290 |
291 | return train_iter()
292 |
293 | def MakeBucketEvalIterator(sources, batch_size):
294 | # Order in eval should not matter. Use batches sorted by length for speed
295 | # improvement.
296 |
297 | def single_sentence_key(num_transitions):
298 | return num_transitions
299 |
300 | def sentence_pair_key(num_transitions):
301 | sent1_len, sent2_len = num_transitions
302 |         return (sent1_len + sent2_len) * (sent1_len + sent2_len + 1) // 2 + sent2_len  # Cantor/Peano pairing
303 |
304 | dataset_size = len(sources[0])
305 |
306 | # Sort examples by length. From longest to shortest.
307 | num_transitions = sources[3]
308 | sort_key = sentence_pair_key if len(
309 | num_transitions.shape) == 2 else single_sentence_key
310 | order = sorted(zip(list(range(dataset_size)), num_transitions),
311 | key=lambda x: sort_key(x[1]))
312 | order = list(reversed(order))
313 | order = [x[0] for x in order]
314 |
315 | num_batches = dataset_size // batch_size
316 | batches = []
317 |
318 | # Roll examples into batches so they have similar length.
319 | for i in range(num_batches):
320 | batch_indices = order[i * batch_size:(i + 1) * batch_size]
321 | batch = tuple(source[batch_indices] for source in sources)
322 | batches.append(batch)
323 |
324 | examples_leftover = dataset_size - num_batches * batch_size
325 |
326 | # Create a short batch:
327 | if examples_leftover > 0:
328 | batch_indices = order[num_batches *
329 | batch_size:num_batches *
330 | batch_size +
331 | examples_leftover]
332 | batch = tuple(source[batch_indices] for source in sources)
333 | batches.append(batch)
334 |
335 | return batches
336 |
337 |
338 | def MakeEvalIterator(sources, batch_size, limit=None, shuffle=False, rseed=123, bucket_eval=False):
339 | return MakeBucketEvalIterator(sources, batch_size)[:limit]
340 |
341 | def TrimDataset(dataset, seq_length, eval_mode=False,
342 | sentence_pair_data=False, logger=None, allow_cropping=False):
343 | """Avoid using excessively long training examples."""
344 |
345 | if sentence_pair_data:
346 | trimmed_dataset = [
347 | example for example in dataset if len(
348 | example["premise_transitions"]) <= seq_length and len(
349 | example["hypothesis_transitions"]) <= seq_length]
350 | else:
351 | trimmed_dataset = [example for example in dataset if
352 | len(example["transitions"]) <= seq_length]
353 |
354 | diff = len(dataset) - len(trimmed_dataset)
355 | if eval_mode:
356 | assert allow_cropping or diff == 0, "allow_eval_cropping is false but there are over-length eval examples."
357 | if logger and diff > 0:
358 | logger.Log(
359 | "Warning: Cropping " +
360 | str(diff) +
361 | " over-length eval examples.")
362 | return dataset
363 | else:
364 | if allow_cropping:
365 | if logger and diff > 0:
366 | logger.Log(
367 | "Cropping " +
368 | str(diff) +
369 | " over-length training examples.")
370 | return dataset
371 | else:
372 | if logger and diff > 0:
373 | logger.Log(
374 | "Discarding " +
375 | str(diff) +
376 | " over-length training examples.")
377 | return trimmed_dataset
378 |
379 | def TokensToIDs(vocabulary, dataset, sentence_pair_data=False):
380 | """Replace strings in original boolean dataset with token IDs."""
381 | if sentence_pair_data:
382 | keys = ["premise_tokens", "hypothesis_tokens"]
383 | else:
384 | keys = ["tokens"]
385 |
386 | tokens = 0
387 | unks = 0
388 | lowers = 0
389 | raises = 0
390 |
391 | for key in keys:
392 | if UNK_TOKEN in vocabulary:
393 | unk_id = vocabulary[UNK_TOKEN]
394 | for example in dataset:
395 | for i, token in enumerate(example[key]):
396 | if token in vocabulary:
397 | example[key][i] = vocabulary[token]
398 | elif token.lower() in vocabulary:
399 | example[key][i] = vocabulary[token.lower()]
400 | lowers += 1
401 | elif token.upper() in vocabulary:
402 | example[key][i] = vocabulary[token.upper()]
403 | raises += 1
404 | else:
405 | example[key][i] = unk_id
406 | unks += 1
407 | tokens += 1
408 | print("Unk rate {:2.6f}%, downcase rate {:2.6f}%, upcase rate {:2.6f}%".format((unks * 100.0 / tokens), (lowers * 100.0 / tokens), (raises * 100.0 / tokens)))
409 | else:
410 | for example in dataset:
411 | example[key] = [vocabulary[token]
412 | for token in example[key]]
413 | return dataset
414 |
415 | def CropAndPadExample(
416 | example,
417 | padding_amount,
418 | target_length,
419 | key,
420 | symbol=0,
421 | logger=None,
422 | allow_cropping=False,
423 | pad_from_left=True):
424 | """
425 | Crop/pad a sequence value of the given dict `example`.
426 | """
427 | if padding_amount < 0:
428 | if not allow_cropping:
429 | raise NotImplementedError(
430 | "Cropping not allowed. "
431 | "Please set seq_length and eval_seq_length to some sufficiently large value or (for non-SPINN models) use --allow_cropping and --allow_eval_cropping..")
432 | # Crop, then pad normally.
433 | if pad_from_left:
434 | example[key] = example[key][-padding_amount:]
435 | else:
436 | example[key] = example[key][:padding_amount]
437 | padding_amount = 0
438 | alternate_side_padding = target_length - \
439 | (padding_amount + len(example[key]))
440 | if pad_from_left:
441 | example[key] = ([symbol] * padding_amount) + \
442 | example[key] + ([symbol] * alternate_side_padding)
443 | else:
444 | example[key] = ([symbol] * alternate_side_padding) + \
445 | example[key] + ([symbol] * padding_amount)
446 |
447 |
448 | def CropAndPadSimple(
449 | dataset,
450 | length,
451 | logger=None,
452 | sentence_pair_data=False,
453 | allow_cropping=True,
454 | pad_from_left=True):
455 | # NOTE: This can probably be done faster in NumPy if it winds up making a
456 | # difference.
457 | if sentence_pair_data:
458 | keys = ["premise_tokens",
459 | "hypothesis_tokens"]
460 | else:
461 | keys = ["tokens"]
462 |
463 | for example in dataset:
464 | for tokens_key in keys:
465 | num_tokens = len(example[tokens_key])
466 | tokens_padding_amount = length - num_tokens
467 | CropAndPadExample(
468 | example,
469 | tokens_padding_amount,
470 | length,
471 | tokens_key,
472 | symbol=SENTENCE_PADDING_SYMBOL,
473 | logger=logger,
474 | allow_cropping=allow_cropping,
475 | pad_from_left=pad_from_left)
476 | return dataset
477 |
478 | def truncate(data, seq_length, max_length, left_padded):
479 | if left_padded:
480 | data = data[:, seq_length - max_length:]
481 | else:
482 | data = data[:, :max_length]
483 | return data
484 |
485 |
486 | def get_batch(batch):
487 | X_batch, transitions_batch, y_batch, num_transitions_batch, example_ids = batch
488 |
489 | # Truncate each batch to max length within the batch.
490 | X_batch_is_left_padded = True
491 | transitions_batch_is_left_padded = True
492 | max_length = np.max(num_transitions_batch)
493 | seq_length = X_batch.shape[1]
494 |
495 | # Truncate batch.
496 | X_batch = truncate(X_batch, seq_length, max_length, X_batch_is_left_padded)
497 | transitions_batch = truncate(transitions_batch, seq_length,
498 | max_length, transitions_batch_is_left_padded)
499 |
500 | return X_batch, transitions_batch, y_batch, num_transitions_batch, example_ids
501 |
--------------------------------------------------------------------------------
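As built above, FIXED_VOCABULARY maps _PAD to 0, the digits '0'..'9' to 1..10,
and '[MIN' through ']' to 11..19 (each len(FIXED_VOCABULARY) in the update
literal is evaluated before the update is applied, so all of them see length
10). A quick tokenisation check on an illustrative ListOps string:

from utils.listops_data import FIXED_VOCABULARY

seq = '[MAX 2 9 [MIN 4 7 ] 0 ]'.split()
ids = [FIXED_VOCABULARY[tok] for tok in seq]
print(ids)   # [12, 3, 10, 11, 5, 8, 19, 1, 19]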
/utils/locked_dropout.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class LockedDropout(nn.Module):
6 | def __init__(self, dropout=0.5, dim=0):
7 | super().__init__()
8 |
9 | assert dim in [0, 1]
10 | self.dim = dim
11 | self.dropout = dropout
12 |
13 | def forward(self, x):
14 | assert len(x.size()) == 3
15 | if not self.training or not self.dropout:
16 | return x
17 | if self.dim == 0:
18 | m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - self.dropout)
19 | elif self.dim == 1:
20 | m = x.data.new(x.size(0), 1, x.size(2)).bernoulli_(1 - self.dropout)
21 |         mask = m / (1 - self.dropout)  # inverted-dropout rescaling
22 | mask = mask.expand_as(x)
23 | return mask * x
24 |
--------------------------------------------------------------------------------
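A quick behavioural check for LockedDropout (variational dropout): with dim=0
a single Bernoulli mask is sampled per (batch, feature) position and reused at
every time step, unlike nn.Dropout, which resamples at each position. The
sketch assumes the class above is importable from the repo:

import torch
from utils.locked_dropout import LockedDropout

torch.manual_seed(0)
drop = LockedDropout(dropout=0.5, dim=0)
drop.train()                      # masks are only applied in training mode
x = torch.ones(4, 2, 3)           # [seq_len, batch, nhid]
y = drop(x)
assert torch.equal(y[0], y[-1])   # identical mask at every time step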
/utils/utils.py:
--------------------------------------------------------------------------------
1 | from collections import deque, Counter
2 |
3 | import torch
4 |
5 |
6 | class Dictionary(object):
7 | def __init__(self):
8 | self.word2idx = {}
9 | self.idx2word = []
10 | self.counter = Counter()
11 | self.total = 0
12 |
13 | def add_word(self, word):
14 | if word not in self.word2idx:
15 | self.idx2word.append(word)
16 | self.word2idx[word] = len(self.idx2word) - 1
17 | token_id = self.word2idx[word]
18 | self.counter[token_id] += 1
19 | self.total += 1
20 | return self.word2idx[word]
21 |
22 | def __len__(self):
23 | return len(self.idx2word)
24 |
25 |
26 | def build_tree(depth, sen):
27 |     # depth: per-token syntactic distances; sen: the token sequence
28 | queue = deque(sen)
29 | stack = [queue.popleft()]
30 | head = depth[0] - 1
31 | for point in depth[1:]:
32 | d = point - head
33 | if d > 0:
34 | for _ in range(d):
35 | if len(stack) == 1:
36 | break
37 | x1 = stack.pop()
38 | x2 = stack.pop()
39 | stack.append([x2, x1])
40 | if len(queue) > 0:
41 | stack.append(queue.popleft())
42 | head = point - 1
43 | while len(stack) > 2 and isinstance(stack, list):
44 | x1 = stack.pop()
45 | x2 = stack.pop()
46 | stack.append([x2, x1])
47 | while len(stack) == 1 and isinstance(stack, list):
48 | stack = stack.pop()
49 | return stack
50 |
51 |
52 | def repackage_hidden(h):
53 | """Wraps hidden states in new Tensors,
54 | to detach them from their history."""
55 | if isinstance(h, torch.Tensor):
56 | return h.detach()
57 | elif h is None:
58 | return None
59 | else:
60 | return tuple(repackage_hidden(v) for v in h)
61 |
62 |
63 | def batchify(data, bsz, args):
64 | # Work out how cleanly we can divide the dataset into bsz parts.
65 | nbatch = data.size(0) // bsz
66 | # Trim off any extra elements that wouldn't cleanly fit (remainders).
67 | data = data.narrow(0, 0, nbatch * bsz)
68 | # Evenly divide the data across the bsz batches.
69 | data = data.view(bsz, -1).t().contiguous()
70 | if args.cuda:
71 | data = data.cuda()
72 | return data
73 |
74 |
75 | def get_batch(source, i, args, seq_len=None, evaluation=False):
76 | seq_len = min(seq_len if seq_len else args.bptt, len(source) - 1 - i)
77 | data = source[i:i + seq_len]
78 | target = source[i + 1:i + 1 + seq_len].view(-1)
79 | return data, target
80 |
81 |
82 | def evalb(pred_tree_list, targ_tree_list, evalb_path="EVALB"):
83 | import os
84 | import subprocess
85 | import re
86 | import nltk
87 | import tempfile
88 |
89 | temp_path = tempfile.TemporaryDirectory(prefix="evalb-")
90 | # temp_path = './test/'
91 | temp_file_path = os.path.join(temp_path.name, "pred_trees.txt")
92 | temp_targ_path = os.path.join(temp_path.name, "true_trees.txt")
93 | temp_eval_path = os.path.join(temp_path.name, "evals.txt")
94 |
95 | print("Temp: {}, {}".format(temp_file_path, temp_targ_path))
96 | temp_tree_file = open(temp_file_path, "w")
97 | temp_targ_file = open(temp_targ_path, "w")
98 |
99 | for pred_tree, targ_tree in zip(pred_tree_list, targ_tree_list):
100 | def process_str_tree(str_tree):
101 | return re.sub('[ |\n]+', ' ', str_tree)
102 |
103 | def list2tree(node):
104 | if isinstance(node, nltk.Tree):
105 | return node
106 | if isinstance(node, list):
107 | tree = []
108 | for child in node:
109 | tree.append(list2tree(child))
110 | return nltk.Tree('', tree)
111 | elif isinstance(node, str):
112 | return nltk.Tree('', [node])
113 |
114 | if re.search(r'[RRB|rrb]- [0-9]', process_str_tree(str(list2tree(targ_tree)))) is not None:
115 | continue
116 | temp_tree_file.write(process_str_tree(str(list2tree(pred_tree))) + '\n')
117 | temp_targ_file.write(process_str_tree(str(list2tree(targ_tree))) + '\n')
118 |
119 | temp_tree_file.close()
120 | temp_targ_file.close()
121 |
122 | evalb_dir = os.path.join(os.getcwd(), evalb_path)
123 | evalb_param_path = os.path.join(evalb_dir, "COLLINS.prm")
124 | evalb_program_path = os.path.join(evalb_dir, "evalb")
125 | command = "{} -p {} {} {} > {}".format(
126 | evalb_program_path,
127 | evalb_param_path,
128 | temp_targ_path,
129 | temp_file_path,
130 | temp_eval_path)
131 |
132 | subprocess.run(command, shell=True)
133 |
134 | with open(temp_eval_path) as infile:
135 | for line in infile:
136 | match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line)
137 | if match:
138 | evalb_recall = float(match.group(1))
139 | match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line)
140 | if match:
141 | evalb_precision = float(match.group(1))
142 | match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line)
143 | if match:
144 | evalb_fscore = float(match.group(1))
145 | break
146 |
147 | temp_path.cleanup()
148 |
149 | print('-' * 80)
150 | print('Evalb Prec:', evalb_precision,
151 | ', Evalb Reca:', evalb_recall,
152 | ', Evalb F1:', evalb_fscore)
153 |
154 | return evalb_fscore
155 |
156 |
157 | def remove_bracket(tree):
158 | if isinstance(tree, str):
159 | if tree in ['(', ')']:
160 | return None
161 | else:
162 | return tree
163 | elif isinstance(tree, list):
164 | new_tree = []
165 | for child in tree:
166 | new_child = remove_bracket(child)
167 | if new_child is not None:
168 | new_tree.append(new_child)
169 | if new_tree == []:
170 | return None
171 | else:
172 | while len(new_tree) == 1 and isinstance(new_tree, list):
173 | new_tree = new_tree[0]
174 | return new_tree
175 |
176 |
177 | def char2tree(s):
178 | stack = []
179 | for w in s:
180 | if w == '(':
181 | stack.append(w)
182 | elif w == ')':
183 | node = []
184 | e = stack.pop()
185 | while not e == '(':
186 | node.append(e)
187 | e = stack.pop()
188 | node = node[::-1]
189 | stack.append(node)
190 | else:
191 | stack.append(w)
192 | while len(stack) == 1 and isinstance(stack, list):
193 | stack = stack[0]
194 | return stack
195 |
196 |
197 |
198 | def makedirs(name):
199 | """helper function for python 2 and 3 to call os.makedirs()
200 | avoiding an error if the directory to be created already exists"""
201 |
202 | import os, errno
203 |
204 | try:
205 | os.makedirs(name)
206 | except OSError as ex:
207 | if ex.errno == errno.EEXIST and os.path.isdir(name):
208 | # ignore existing directory
209 | pass
210 | else:
211 | # a different error happened
212 | raise
213 |
214 |
215 | def ConvertBinaryBracketedSeq(seq):
216 | T_SHIFT = 0
217 | T_REDUCE = 1
218 | T_SKIP = 2
219 |
220 | tokens, transitions = [], []
221 | for item in seq:
222 | if item != "(":
223 | if item != ")":
224 | tokens.append(item)
225 | transitions.append(T_REDUCE if item == ")" else T_SHIFT)
226 | return tokens, transitions
227 |
--------------------------------------------------------------------------------
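Finally, a hedged usage sketch tying the tree helpers together: char2tree
turns a bracketed token list into the nested-list tree that evalb() later
converts to an nltk.Tree, and remove_bracket strips the literal parenthesis
tokens that build_tree can leave inside a predicted tree. The inputs are
illustrative.

from utils.utils import char2tree, remove_bracket

gold = char2tree('( a ( and b ) )'.split())
print(gold)                  # ['a', ['and', 'b']]

pred = ['(', ['a', ['(', ['and', 'b'], ')']], ')']   # a made-up raw prediction
print(remove_bracket(pred))  # ['a', ['and', 'b']]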