├── .gitattributes
├── .gitignore
├── EVALB
├── COLLINS.prm
├── LICENSE
├── Makefile
├── README
├── evalb
├── evalb.c
├── new.prm
├── sample
│ ├── sample.gld
│ ├── sample.prm
│ ├── sample.rsl
│ └── sample.tst
└── tgrep_proc.prl
├── LICENSE
├── ON_LSTM.py
├── README.md
├── data.py
├── data
└── penn
│ ├── test.txt
│ ├── train.txt
│ └── valid.txt
├── data_ptb.py
├── embed_regularize.py
├── locked_dropout.py
├── main.py
├── model.py
├── parse_comparison.py
├── requirements.txt
├── splitcross.py
├── test_phrase_grammar.py
├── utils.py
└── weight_drop.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | #pycharm
104 | .idea/
105 |
106 | #pytorch
107 | *.pt
108 |
--------------------------------------------------------------------------------
/EVALB/COLLINS.prm:
--------------------------------------------------------------------------------
1 | ##------------------------------------------##
2 | ## Debug mode ##
3 | ## 0: No debugging ##
4 | ## 1: print data for individual sentence ##
5 | ##------------------------------------------##
6 | DEBUG 0
7 |
8 | ##------------------------------------------##
9 | ## MAX error ##
10 | ## Number of error to stop the process. ##
11 | ## This is useful if there could be ##
12 | ##    tokenization error.                     ##
13 | ## The process will stop when this number##
14 | ## of errors are accumulated. ##
15 | ##------------------------------------------##
16 | MAX_ERROR 10
17 |
18 | ##------------------------------------------##
19 | ## Cut-off length for statistics ##
20 | ## At the end of evaluation, the ##
21 | ##    statistics for the sentences of length##
22 | ## less than or equal to this number will##
23 | ## be shown, on top of the statistics ##
24 | ## for all the sentences ##
25 | ##------------------------------------------##
26 | CUTOFF_LEN 40
27 |
28 | ##------------------------------------------##
29 | ## unlabeled or labeled bracketing ##
30 | ## 0: unlabeled bracketing ##
31 | ## 1: labeled bracketing ##
32 | ##------------------------------------------##
33 | LABELED 1
34 |
35 | ##------------------------------------------##
36 | ## Delete labels ##
37 | ## list of labels to be ignored. ##
38 | ## If it is a pre-terminal label, delete ##
39 | ## the word along with the brackets. ##
40 | ## If it is a non-terminal label, just ##
41 | ## delete the brackets (don't delete ##
42 | ##    children).                             ##
43 | ##------------------------------------------##
44 | DELETE_LABEL TOP
45 | DELETE_LABEL -NONE-
46 | DELETE_LABEL ,
47 | DELETE_LABEL :
48 | DELETE_LABEL ``
49 | DELETE_LABEL ''
50 | DELETE_LABEL .
51 |
52 | ##------------------------------------------##
53 | ## Delete labels for length calculation ##
54 | ## list of labels to be ignored for ##
55 | ## length calculation purpose ##
56 | ##------------------------------------------##
57 | DELETE_LABEL_FOR_LENGTH -NONE-
58 |
59 | ##------------------------------------------##
60 | ## Equivalent labels, words ##
61 | ## the pairs are considered equivalent ##
62 | ## This is non-directional. ##
63 | ##------------------------------------------##
64 | EQ_LABEL ADVP PRT
65 |
66 | # EQ_WORD Example example
67 |
--------------------------------------------------------------------------------
/EVALB/LICENSE:
--------------------------------------------------------------------------------
1 | This is free and unencumbered software released into the public domain.
2 |
3 | Anyone is free to copy, modify, publish, use, compile, sell, or
4 | distribute this software, either in source code form or as a compiled
5 | binary, for any purpose, commercial or non-commercial, and by any
6 | means.
7 |
8 | In jurisdictions that recognize copyright laws, the author or authors
9 | of this software dedicate any and all copyright interest in the
10 | software to the public domain. We make this dedication for the benefit
11 | of the public at large and to the detriment of our heirs and
12 | successors. We intend this dedication to be an overt act of
13 | relinquishment in perpetuity of all present and future rights to this
14 | software under copyright law.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 | OTHER DEALINGS IN THE SOFTWARE.
23 |
24 | For more information, please refer to
25 |
--------------------------------------------------------------------------------
/EVALB/Makefile:
--------------------------------------------------------------------------------
1 | all: evalb
2 |
3 | evalb: evalb.c
4 | gcc -Wall -g -o evalb evalb.c
5 |
--------------------------------------------------------------------------------
/EVALB/README:
--------------------------------------------------------------------------------
1 | #################################################################
2 | # #
3 | # Bug fix and additional functionality for evalb #
4 | # #
5 | # This updated version of evalb fixes a bug in which sentences #
6 | # were incorrectly categorized as "length mismatch" when the #
7 | # the parse output had certain mislabeled parts-of-speech. #
8 | # #
9 | # The bug was the result of evalb treating one of the tags (in #
10 | # gold or test) as a label to be deleted (see sections [6],[7] #
11 | # for details), but not the corresponding tag in the other. #
12 | # This most often occurs with punctuation. See the subdir #
13 | # "bug" for an example gld and tst file demonstating the bug, #
14 | # as well as output of evalb with and without the bug fix. #
15 | # #
16 | # For the present version in case of length mismatch, the nodes #
17 | # causing the imbalance are reinserted to resolve the miscount. #
18 | # If the lengths of gold and test truly differ, the error is #
19 | # still reported. The parameter file "new.prm" (derived from #
20 | # COLLINS.prm) shows how to add new potential mislabelings for #
21 | # quotes (",``,',`). #
22 | # #
23 | # I have preserved DJB's revision for modern compilers except #
24 | # for the declaration of "exit" which is provided by stdlib.    #
25 | # #
26 | # Other changes: #
27 | # #
28 | # * output of F-Measure in addition to precision and recall #
29 | # (I did not update the documentation in section [4] for this)  #
30 | # #
31 | # * more comprehensive DEBUG output that includes bracketing #
32 | # information as evalb is processing each sentence #
33 | #   (useful in working through this, and perhaps other bugs).   #
34 | # Use either the "-D" run-time switch or set DEBUG to 2 in #
35 | # the parameter file. #
36 | # #
37 | # * added DELETE_LABEL lines in new.prm for S1 nodes produced #
38 | # by the Charniak parser and "?", "!" punctuation produced by #
39 | # the Bikel parser. #
40 | # #
41 | # #
42 | # David Ellis (Brown) #
43 | # #
44 | # January.2006 #
45 | #################################################################
46 |
47 | #################################################################
48 | # #
49 | # Update of evalb for modern compilers #
50 | # #
51 | # This is an updated version of evalb, for use with modern C #
52 | # compilers. There are a few updates, each marked in the code: #
53 | # #
54 | # /* DJB: explanation of comment */ #
55 | # #
56 | # The updates are purely to help compilation with recent #
57 | # versions of GCC (and other C compilers). There are *NO* other #
58 | # changes to the algorithm itself. #
59 | # #
60 | # I have made these changes following recommendations from #
61 | # users of the Corpora Mailing List, especially Peet Morris and #
62 | # Ramon Ziai. #
63 | # #
64 | # David Brooks (Birmingham) #
65 | # #
66 | # September.2005 #
67 | #################################################################
68 |
69 | #################################################################
70 | # #
71 | # README file for evalb #
72 | # #
73 | # Satoshi Sekine (NYU) #
74 | # Mike Collins (UPenn) #
75 | # #
76 | # October.1997 #
77 | #################################################################
78 |
79 | Contents of this README:
80 |
81 | [0] COPYRIGHT
82 | [1] INTRODUCTION
83 | [2] INSTALLATION AND RUN
84 | [3] OPTIONS
85 | [4] OUTPUT FORMAT FROM THE SCORER
86 | [5] HOW TO CREATE A GOLDFILE FROM THE TREEBANK
87 | [6] THE PARAMETER FILE
88 | [7] MORE DETAILS ABOUT THE SCORING ALGORITHM
89 |
90 |
91 | [0] COPYRIGHT
92 |
93 | The authors abandon the copyright of this program. Everyone is
94 | permitted to copy and distribute the program or a portion of the program
95 | with no charge and no restrictions unless it is harmful to someone.
96 |
97 | However, the authors are grateful for the user's kindness of proper
98 | usage and letting the authors know bugs or problems.
99 |
100 | This software is provided "AS IS", and the authors make no warranties,
101 | express or implied.
102 |
103 | To legally enforce the abandonment of copyright, this package is released
104 | under the Unlicense (see LICENSE).
105 |
106 | [1] INTRODUCTION
107 |
108 | Evaluation of bracketing looks simple, but in fact, there are minor
109 | differences from system to system. This is a program to parameterize
110 | such minor differences and to give an informative result.
111 |
112 | "evalb" evaluates bracketing accuracy in a test-file against a gold-file.
113 | It returns recall, precision, tagging accuracy. It uses an identical
114 | algorithm to that used in (Collins ACL97).
115 |
116 |
117 | [2] Installation and Run
118 |
119 | To compile the scorer, type
120 |
121 | > make
122 |
123 |
124 | To run the scorer:
125 |
126 | > evalb -p Parameter_file Gold_file Test_file
127 |
128 |
129 | For example to use the sample files:
130 |
131 | > evalb -p sample.prm sample.gld sample.tst
132 |
133 |
134 |
135 | [3] OPTIONS
136 |
137 | You can specify system parameters in the command line options.
138 | Other options concerning evaluation metrics should be specified
139 | in parameter file, described later.
140 |
141 | -p param_file parameter file
142 | -d debug mode
143 | -e n number of error to kill (default=10)
144 | -h help
145 |
146 |
147 |
148 | [4] OUTPUT FORMAT FROM THE SCORER
149 |
150 | The scorer gives individual scores for each sentence, for
151 | example:
152 |
153 | Sent. Matched Bracket Cross Correct Tag
154 | ID Len. Stat. Recal Prec. Bracket gold test Bracket Words Tags Accracy
155 | ============================================================================
156 | 1 8 0 100.00 100.00 5 5 5 0 6 5 83.33
157 |
158 | At the end of the output the === Summary === section gives statistics
159 | for all sentences, and for sentences <=40 words in length. The summary
160 | contains the following information:
161 |
162 | i) Number of sentences -- total number of sentences.
163 |
164 | ii) Number of Error/Skip sentences -- should both be 0 if there is no
165 | problem with the parsed/gold files.
166 |
167 | iii) Number of valid sentences = Number of sentences - Number of Error/Skip
168 | sentences
169 |
170 | iv) Bracketing recall = (number of correct constituents)
171 | ----------------------------------------
172 | (number of constituents in the goldfile)
173 |
174 | v) Bracketing precision = (number of correct constituents)
175 | ----------------------------------------
176 | (number of constituents in the parsed file)
177 |
178 | vi) Complete match = percentage of sentences where recall and precision are
179 | both 100%.
180 |
181 | vii) Average crossing = (number of constituents crossing a goldfile constituent)
182 | ----------------------------------------------------
183 | (number of sentences)
184 |
185 | viii) No crossing = percentage of sentences which have 0 crossing brackets.
186 |
187 | ix) 2 or less crossing = percentage of sentences which have <=2 crossing brackets.
188 |
189 | x) Tagging accuracy = percentage of correct POS tags (but see [5].3 for exact
190 | details of what is counted).
191 |
192 |
193 |
194 | [5] HOW TO CREATE A GOLDFILE FROM THE PENN TREEBANK
195 |
196 |
197 | The gold and parsed files are in a format similar to this:
198 |
199 | (TOP (S (INTJ (RB No)) (, ,) (NP (PRP it)) (VP (VBD was) (RB n't) (NP (NNP Black) (NNP Monday))) (. .)))
200 |
201 | To create a gold file from the treebank:
202 |
203 | tgrep -wn '/.*/' | tgrep_proc.prl
204 |
205 | will produce a goldfile in the required format. ("tgrep -wn '/.*/'" prints
206 | parse trees, "tgrep_proc.prl" just skips blank lines).
207 |
208 | For example, to produce a goldfile for section 23 of the treebank:
209 |
210 | tgrep -wn '/.*/' | tail +90895 | tgrep_proc.prl | sed 2416q > sec23.gold
211 |
212 |
213 |
214 | [6] THE PARAMETER (.prm) FILE
215 |
216 |
217 | The .prm file sets options regarding the scoring method. COLLINS.prm gives
218 | the same scoring behaviour as the scorer used in (Collins 97). The options
219 | chosen were:
220 |
221 | 1) LABELED 1
222 |
223 | to give labelled precision/recall figures, i.e. a constituent must have the
224 | same span *and* label as a constituent in the goldfile.
225 |
226 | 2) DELETE_LABEL TOP
227 |
228 | Don't count the "TOP" label (which is always given in the output of tgrep)
229 | when scoring.
230 |
231 | 3) DELETE_LABEL -NONE-
232 |
233 | Remove traces (and all constituents which dominate nothing but traces) when
234 | scoring. For example
235 |
236 | .... (VP (VBD reported) (SBAR (-NONE- 0) (S (-NONE- *T*-1)))) (. .)))
237 |
238 | would be processed to give
239 |
240 | .... (VP (VBD reported)) (. .)))
241 |
242 |
243 | 4)
244 | DELETE_LABEL , -- for the purposes of scoring remove punctuation
245 | DELETE_LABEL :
246 | DELETE_LABEL ``
247 | DELETE_LABEL ''
248 | DELETE_LABEL .
249 |
250 | 5) DELETE_LABEL_FOR_LENGTH -NONE- -- don't include traces when calculating
251 | the length of a sentence (important
252 | when classifying a sentence as <=40
253 | words or >40 words)
254 |
255 | 6) EQ_LABEL ADVP PRT
256 |
257 | Count ADVP and PRT as being the same label when scoring.
258 |
259 |
260 |
261 |
262 | [7] MORE DETAILS ABOUT THE SCORING ALGORITHM
263 |
264 |
265 | 1) The scorer initially processes the files to remove all nodes specified
266 | by DELETE_LABEL in the .prm file. It also recursively removes nodes which
267 | dominate nothing due to all their children being removed. For example, if
268 | -NONE- is specified as a label to be deleted,
269 |
270 | .... (VP (VBD reported) (SBAR (-NONE- 0) (S (-NONE- *T*-1)))) (. .)))
271 |
272 | would be processed to give
273 |
274 | .... (VP (VBD reported)) (. .)))
275 |
276 | 2) The scorer also removes all functional tags attached to non-terminals
277 | (functional tags are prefixed with "-" or "=" in the treebank). For example
278 | "NP-SBJ" is processed to give "NP", "NP=2" is changed to "NP".
279 |
280 |
281 | 3) Tagging accuracy counts tags for all words *except* any tags which are
282 | deleted by a DELETE_LABEL specification in the .prm file. (For example, for
283 | COLLINS.prm, punctuation tagged as "," ":" etc. would not be included).
284 |
285 | 4) When calculating the length of a sentence, all words with POS tags not
286 | included in the "DELETE_LABEL_FOR_LENGTH" list in the .prm file are
287 | counted. (For COLLINS.prm, only "-NONE-" is specified in this list, so
288 | traces are removed before calculating the length of the sentence).
289 |
290 | 5) There are some subtleties in scoring when either the goldfile or parsed
291 | file contains multiple constituents for the same span which have the same
292 | non-terminal label. e.g. (NP (NP the man)) If the goldfile contains n
293 | constituents for the same span, and the parsed file contains m constituents
294 | with that nonterminal, the scorer works as follows:
295 |
296 | i) If m>n, then the precision is n/m, recall is 100%
297 |
298 | ii) If n>m, then the precision is 100%, recall is m/n.
299 |
300 | iii) If n==m, recall and precision are both 100%.
301 |
--------------------------------------------------------------------------------
/EVALB/evalb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yikangshen/Ordered-Neurons/46d63cde024802eaf1eb7cc896431329014dd869/EVALB/evalb
--------------------------------------------------------------------------------
/EVALB/evalb.c:
--------------------------------------------------------------------------------
1 | /*****************************************************************/
2 | /* evalb [-p param_file] [-dh] [-e n] gold-file test-file */
3 | /* */
4 | /* Evaluate bracketing in test-file against gold-file. */
5 | /* Return recall, precision, tagging accuracy. */
6 | /* */
7 | /* */
8 | /* -p param_file parameter file */
9 | /* -d debug mode */
10 | /* -e n number of error to kill (default=10) */
11 | /* -h help */
12 | /* */
13 | /* Satoshi Sekine (NYU) */
14 | /* Mike Collins (UPenn) */
15 | /* */
16 | /* October.1997 */
17 | /* */
18 | /* Please refer README for the update information */
19 | /*****************************************************************/
20 |
21 | #include <stdio.h>
22 | #include <stdlib.h> //### added for exit, atoi decls
23 | #include <string.h>
24 | #include <ctype.h>
25 | #include <malloc.h>
26 |
27 |
28 | /* Internal Data format -------------------------------------------*/
29 | /* */
30 | /* (S (NP (NNX this)) (VP (VBX is) (NP (DT a) (NNX pen))) (SYM .)) */
31 | /* */
32 | /* wn=5 */
33 | /* word label */
34 | /* terminal[0] = this NNX */
35 | /* terminal[1] = is VBX */
36 | /* terminal[2] = a DT */
37 | /* terminal[3] = pen NNX */
38 | /* terminal[4] = . SYM */
39 | /* */
40 | /* bn=4 */
41 | /* start end label */
42 | /* bracket[0] = 0 5 S */
43 | /* bracket[1] = 0 0 NP */
44 | /* bracket[2] = 1 4 VP */
45 | /* bracket[3] = 2 4 NP */
46 | /* */
47 | /* matched bracketing */
48 | /* Recall = --------------------------- */
49 | /* # of bracket in ref-data */
50 | /* */
51 | /* matched bracketing */
52 | /* Recall = --------------------------- */
53 | /* # of bracket in test-data */
54 | /* */
55 | /*-----------------------------------------------------------------*/
56 |
57 | /******************/
58 | /* constant macro */
59 | /******************/
60 |
61 | #define MAX_SENT_LEN 5000
62 | #define MAX_WORD_IN_SENT 200
63 | #define MAX_BRACKET_IN_SENT 200
64 | #define MAX_WORD_LEN 100
65 | #define MAX_LABEL_LEN 30
66 | #define MAX_QUOTE_TERM 20
67 |
68 | #define MAX_DELETE_LABEL 100
69 | #define MAX_EQ_LABEL 100
70 | #define MAX_EQ_WORD 100
71 |
72 | #define MAX_LINE_LEN 500
73 |
74 | #define DEFAULT_MAX_ERROR 10
75 | #define DEFAULT_CUT_LEN 40
76 |
77 | /*************/
78 | /* structure */
79 | /*************/
80 |
81 | typedef struct ss_terminal {
82 | char word[MAX_WORD_LEN];
83 | char label[MAX_LABEL_LEN];
84 | int result; /* 0:unmatch, 1:match, 9:undef */
85 | } s_terminal;
86 |
87 | typedef struct ss_term_ind {
88 | s_terminal term;
89 | int index;
90 | int bracket;
91 | int endslen;
92 | int ends[MAX_BRACKET_IN_SENT];
93 | } s_term_ind;
94 |
95 | typedef struct ss_bracket {
96 | int start;
97 | int end;
98 | unsigned int buf_start;
99 | unsigned int buf_end;
100 | char label[MAX_LABEL_LEN];
101 | int result; /* 0: unmatch, 1:match, 5:delete 9:undef */
102 | } s_bracket;
103 |
104 |
105 | typedef struct ss_equiv {
106 | char *s1;
107 | char *s2;
108 | } s_equiv;
109 |
110 |
111 | /****************************/
112 | /* global variables */
113 | /* gold-data: suffix = 1 */
114 | /* test-data: suffix = 2 */
115 | /****************************/
116 |
117 | /*---------------*/
118 | /* Sentence data */
119 | /*---------------*/
120 | int wn1, wn2; /* number of words in sentence */
121 | int r_wn1; /* number of words in sentence */
122 | /* which only ignores labels in */
123 | /* DELETE_LABEL_FOR_LENGTH */
124 |
125 | s_terminal terminal1[MAX_WORD_IN_SENT]; /* terminal information */
126 | s_terminal terminal2[MAX_WORD_IN_SENT];
127 |
128 | s_term_ind quotterm1[MAX_QUOTE_TERM]; /* special terminals ("'","POS") */
129 | s_term_ind quotterm2[MAX_QUOTE_TERM];
130 |
131 | int bn1, bn2; /* number of brackets */
132 |
133 | int r_bn1, r_bn2; /* number of brackets */
134 | /* after deletion */
135 |
136 | s_bracket bracket1[MAX_BRACKET_IN_SENT]; /* bracket information */
137 | s_bracket bracket2[MAX_BRACKET_IN_SENT];
138 |
139 |
140 | /*------------*/
141 | /* Total data */
142 | /*------------*/
143 | int TOTAL_bn1, TOTAL_bn2, TOTAL_match; /* total number of brackets */
144 | int TOTAL_sent; /* No. of sentence */
145 | int TOTAL_error_sent; /* No. of error sentence */
146 | int TOTAL_skip_sent; /* No. of skip sentence */
147 | int TOTAL_comp_sent; /* No. of complete match sent */
148 | int TOTAL_word; /* total number of word */
149 | int TOTAL_crossing; /* total crossing */
150 | int TOTAL_no_crossing; /* no crossing sentence */
151 | int TOTAL_2L_crossing; /* 2 or less crossing sentence */
152 | int TOTAL_correct_tag; /* total correct tagging */
153 |
154 | int TOT_cut_len = DEFAULT_CUT_LEN; /* Cut-off length in statistics */
155 |
156 | /* data for sentences with len <= CUT_LEN */
157 | /* Historically it was 40. */
158 | int TOT40_bn1, TOT40_bn2, TOT40_match; /* total number of brackets */
159 | int TOT40_sent; /* No. of sentence */
160 | int TOT40_error_sent; /* No. of error sentence */
161 | int TOT40_skip_sent; /* No. of skip sentence */
162 | int TOT40_comp_sent; /* No. of complete match sent */
163 | int TOT40_word; /* total number of word */
164 | int TOT40_crossing; /* total crossing */
165 | int TOT40_no_crossing; /* no crossing sentence */
166 | int TOT40_2L_crossing; /* 2 or less crossing sentence */
167 | int TOT40_correct_tag; /* total correct tagging */
168 |
169 | /*------------*/
170 | /* miscallous */
171 | /*------------*/
172 | int Line; /* line number */
173 | int Error_count = 0; /* Error count */
174 | int Status; /* Result status for each sent */
175 | /* 0: OK, 1: skip, 2: error */
176 |
177 | /*-------------------*/
178 | /* stack manuplation */
179 | /*-------------------*/
180 | int stack_top;
181 | int stack[MAX_BRACKET_IN_SENT];
182 |
183 | /************************************************************/
184 | /* User parameters which can be specified in parameter file */
185 | /************************************************************/
186 |
187 | /*------------------------------------------*/
188 | /* Debug mode */
189 | /* print out data for individual sentence */
190 | /*------------------------------------------*/
191 | int DEBUG=0;
192 |
193 | /*------------------------------------------*/
194 | /* MAX error */
195 | /* Number of error to stop the process. */
196 | /* This is useful if there could be */
197 | /*    tokenization error.                   */
198 | /* The process will stop when this number*/
199 | /* of errors are accumulated. */
200 | /*------------------------------------------*/
201 | int Max_error = DEFAULT_MAX_ERROR;
202 |
203 | /*------------------------------------------*/
204 | /* Cut-off length for statistics */
205 | /* int TOT_cut_len = DEFAULT_CUT_LEN; */
206 | /* (Defined above) */
207 | /*------------------------------------------*/
208 |
209 |
210 | /*------------------------------------------*/
211 | /* unlabeled or labeled bracketing */
212 | /* 0: unlabeled bracketing */
213 | /* 1: labeled bracketing */
214 | /*------------------------------------------*/
215 | int F_label = 1;
216 |
217 | /*------------------------------------------*/
218 | /* Delete labels */
219 | /* list of labels to be ignored. */
220 | /* If it is a pre-terminal label, delete */
221 | /* the word along with the brackets. */
222 | /* If it is a non-terminal label, just */
223 | /* delete the brackets (don't delete */
224 | /*    children).                            */
225 | /*------------------------------------------*/
226 | char *Delete_label[MAX_DELETE_LABEL];
227 | int Delete_label_n = 0;
228 |
229 | /*------------------------------------------*/
230 | /* Delete labels for length calculation */
231 | /* list of labels to be ignored for */
232 | /* length calculation purpose */
233 | /*------------------------------------------*/
234 | char *Delete_label_for_length[MAX_DELETE_LABEL];
235 | int Delete_label_for_length_n = 0;
236 |
237 | /*------------------------------------------*/
238 | /* Labels to be considered for misquote */
239 | /* (could be possesive or quote) */
240 | /*------------------------------------------*/
241 | char *Quote_term[MAX_QUOTE_TERM];
242 | int Quote_term_n = 0;
243 |
244 | /*------------------------------------------*/
245 | /* Equivalent labels, words */
246 | /* the pairs are considered equivalent */
247 | /* This is non-directional. */
248 | /*------------------------------------------*/
249 | s_equiv EQ_label[MAX_EQ_LABEL];
250 | int EQ_label_n = 0;
251 |
252 | s_equiv EQ_word[MAX_EQ_WORD];
253 | int EQ_word_n = 0;
254 |
255 |
256 |
257 | /************************/
258 | /* Function return-type */
259 | /************************/
260 | int main();
261 | void init_global();
262 | void print_head();
263 | void init();
264 | void read_parameter_file();
265 | void set_param();
266 | int narg();
267 | int read_line();
268 |
269 | void pushb();
270 | int popb();
271 | int stackempty();
272 |
273 | void calc_result(unsigned char *buf1,unsigned char *buf);
274 | void fix_quote();
275 | void reinsert_term();
276 | void massage_data();
277 | void modify_label();
278 | void individual_result();
279 | void print_total();
280 | void dsp_info();
281 | int is_terminator();
282 | int is_deletelabel();
283 | int is_deletelabel_for_length();
284 | int is_quote_term();
285 | int word_comp();
286 | int label_comp();
287 |
288 | void Error();
289 | void Fatal();
290 | void Usage();
291 |
292 | /* ### provided by std headers
293 | int fprintf();
294 | int printf();
295 | int atoi();
296 | int fclose();
297 | int sscanf();
298 | */
299 |
300 | /***********/
301 | /* program */
302 | /***********/
303 | #define ARG_CHECK(st) if(!(*++(*argv) || (--argc && *++argv))){ \
304 | fprintf(stderr,"Missing argument: %s\n",st); \
305 | }
306 |
307 | int
308 | main(argc,argv) /* entry point: evalb [-p prm] [-d|-D] [-c n] [-e n] gold-file test-file */
309 | int argc;
310 | char *argv[];
311 | {
312 | char *filename1, *filename2; /* filename1 = gold file, filename2 = test file */
313 | FILE *fd1, *fd2;
314 | unsigned char buff[5000]; /* current input line; reused for the test line below */
315 | unsigned char buff1[5000]; /* saved copy of the gold line, passed to calc_result */
316 |
317 | filename1=NULL;
318 | filename2=NULL;
319 |
320 | for(argc--,argv++;argc>0;argc--,argv++){ /* parse flags and the two positional filenames */
321 | if(**argv == '-'){
322 | while(*++(*argv)){
323 | switch(**argv){
324 |
325 | case 'h': /* help */
326 | Usage();
327 | exit(1);
328 |
329 | case 'd': /* debug mode */
330 | DEBUG = 1;
331 | goto nextarg;
332 |
333 | case 'D': /* debug mode */
334 | DEBUG = 2;
335 | goto nextarg;
336 |
337 | case 'c': /* cut-off length */
338 | ARG_CHECK("cut-off length for statistices");
339 | TOT_cut_len = atoi(*argv);
340 | goto nextarg;
341 |
342 | case 'e': /* max error */
343 | ARG_CHECK("number of error to kill");
344 | Max_error = atoi(*argv);
345 | goto nextarg;
346 |
347 | case 'p': /* parameter file */
348 | ARG_CHECK("parameter file");
349 | read_parameter_file(*argv);
350 | goto nextarg;
351 |
352 | default:
353 | Usage();
354 | exit(0);
355 | }
356 | }
357 | } else {
358 | if(filename1==NULL){ /* first non-flag arg is gold, second is test; extras ignored */
359 | filename1 = *argv;
360 | }else if(filename2==NULL){
361 | filename2 = *argv;
362 | }
363 | }
364 | nextarg: continue;
365 | }
366 |
367 | init_global(); /* zero all cumulative counters before scoring */
368 |
369 |
370 | if((fd1 = fopen(filename1,"r"))==NULL){
371 | Fatal("Can't open gold file (%s)\n",filename1);
372 | }
373 | if((fd2 = fopen(filename2,"r"))==NULL){
374 | Fatal("Can't open test file (%s)\n",filename2);
375 | }
376 |
377 | print_head();
378 |
379 | for(Line=1;fgets(buff,5000,fd1)!=NULL;Line++){ /* one parse tree per line in each file */
380 |
381 | init();
382 |
383 | /* READ 1 */
384 | r_wn1 = read_line(buff,terminal1,quotterm1,&wn1,bracket1,&bn1);
385 |
386 | strcpy(buff1,buff); /* keep the gold line: buff is overwritten by the test line next */
387 |
388 | /* READ 2 */
389 | if(fgets(buff,5000,fd2)==NULL){
390 | Error("Number of lines unmatch (too many lines in gold file)\n");
391 | break;
392 | }
393 |
394 | read_line(buff,terminal2,quotterm2,&wn2,bracket2,&bn2);
395 |
396 | /* Calculate result and print it */
397 | calc_result(buff1,buff);
398 |
399 | if(DEBUG>=1){
400 | dsp_info();
401 | }
402 | }
403 |
404 | if(fgets(buff,5000,fd2)!=NULL){ /* gold exhausted first: test file has extra lines */
405 | Error("Number of lines unmatch (too many lines in test file)\n");
406 | }
407 |
408 | print_total();
409 |
410 | return (0);
411 | }
412 |
413 |
414 | /*-----------------------------*/
415 | /* initialize global variables */
416 | /*-----------------------------*/
417 | void
418 | init_global()
419 | {
420 | TOTAL_bn1 = TOTAL_bn2 = TOTAL_match = 0;
421 | TOTAL_sent = TOTAL_error_sent = TOTAL_skip_sent = TOTAL_comp_sent = 0;
422 | TOTAL_word = TOTAL_correct_tag = 0;
423 | TOTAL_crossing = 0;
424 | TOTAL_no_crossing = TOTAL_2L_crossing = 0;
425 |
426 | TOT40_bn1 = TOT40_bn2 = TOT40_match = 0;
427 | TOT40_sent = TOT40_error_sent = TOT40_skip_sent = TOT40_comp_sent = 0;
428 | TOT40_word = TOT40_correct_tag = 0;
429 | TOT40_crossing = 0;
430 | TOT40_no_crossing = TOT40_2L_crossing = 0;
431 |
432 | }
433 |
434 |
/*------------------*/
/* print head title */
/*------------------*/
/* Print the fixed column header for the per-sentence result table.
   NOTE(review): "Accracy" is a long-standing misspelling in evalb's
   output; it is deliberately left verbatim because expected-result
   files (e.g. sample/sample.rsl) are compared against it byte-for-byte. */
void
print_head()
{
  printf(" Sent. Matched Bracket Cross Correct Tag\n");
  printf(" ID Len. Stat. Recal Prec. Bracket gold test Bracket Words Tags Accracy\n");
  printf("============================================================================\n");
}
445 |
446 |
447 | /*-----------------------------------------------*/
448 | /* initialization at each individual computation */
449 | /*-----------------------------------------------*/
450 | void
451 | init()
452 | {
453 | int i;
454 |
455 | wn1 = 0;
456 | wn2 = 0;
457 | bn1 = 0;
458 | bn2 = 0;
459 | r_bn1 = 0;
460 | r_bn2 = 0;
461 |
462 | for(i=0;i0 && (isspace(buff[i]) || buff[i]=='\n');i--){
519 | buff[i]='\0';
520 | }
521 | if(buff[0]=='#' || /* comment-line */
522 | strlen(buff)<3){ /* too short, just ignore */
523 | continue;
524 | }
525 |
526 | /* place the parameter and value */
527 | /*-------------------------------*/
528 | for(i=0;!isspace(buff[i]);i++);
529 | for(;isspace(buff[i]) && buff[i]!='\0';i++);
530 | if(buff[i]=='\0'){
531 | fprintf(stderr,"Empty value in parameter file (%d)\n",line);
532 | }
533 |
534 | /* set parameter and value */
535 | /*-------------------------*/
536 | set_param(buff,buff+i);
537 | }
538 |
539 | fclose(fd);
540 | }
541 |
542 |
543 | #define STRNCMP(s) (strncmp(param,s,strlen(s))==0 && \
544 | (param[strlen(s)]=='\0' || isspace(param[strlen(s)])))
545 |
546 |
547 | void
548 | set_param(param,value)
549 | char *param, *value;
550 | {
551 | char l1[MAX_LABEL_LEN], l2[MAX_LABEL_LEN];
552 |
553 | if(STRNCMP("DEBUG")){
554 |
555 | DEBUG = atoi(value);
556 |
557 | }else if(STRNCMP("MAX_ERROR")){
558 |
559 | Max_error = atoi(value);
560 |
561 | }else if(STRNCMP("CUTOFF_LEN")){
562 |
563 | TOT_cut_len = atoi(value);
564 |
565 | }else if(STRNCMP("LABELED")){
566 |
567 | F_label = atoi(value);
568 |
569 | }else if(STRNCMP("DELETE_LABEL")){
570 |
571 | Delete_label[Delete_label_n] = (char *)malloc(strlen(value)+1);
572 | strcpy(Delete_label[Delete_label_n],value);
573 | Delete_label_n++;
574 |
575 | }else if(STRNCMP("DELETE_LABEL_FOR_LENGTH")){
576 |
577 | Delete_label_for_length[Delete_label_for_length_n] = (char *)malloc(strlen(value)+1);
578 | strcpy(Delete_label_for_length[Delete_label_for_length_n],value);
579 | Delete_label_for_length_n++;
580 |
581 | }else if(STRNCMP("QUOTE_LABEL")){
582 |
583 | Quote_term[Quote_term_n] = (char *)malloc(strlen(value)+1);
584 | strcpy(Quote_term[Quote_term_n],value);
585 | Quote_term_n++;
586 |
587 | }else if(STRNCMP("EQ_LABEL")){
588 |
589 | if(narg(value)!=2){
590 | fprintf(stderr,"EQ_LABEL requires two values\n");
591 | return;
592 | }
593 | sscanf(value,"%s %s",l1,l2);
594 | EQ_label[EQ_label_n].s1 = (char *)malloc(strlen(l1)+1);
595 | strcpy(EQ_label[EQ_label_n].s1,l1);
596 | EQ_label[EQ_label_n].s2 = (char *)malloc(strlen(l2)+1);
597 | strcpy(EQ_label[EQ_label_n].s2,l2);
598 | EQ_label_n++;
599 |
600 | }else if(STRNCMP("EQ_WORD")){
601 |
602 | if(narg(value)!=2){
603 | fprintf(stderr,"EQ_WORD requires two values\n");
604 | return;
605 | }
606 | sscanf(value,"%s %s",l1,l2);
607 | EQ_word[EQ_word_n].s1 = (char *)malloc(strlen(l1)+1);
608 | strcpy(EQ_word[EQ_word_n].s1,l1);
609 | EQ_word[EQ_word_n].s2 = (char *)malloc(strlen(l2)+1);
610 | strcpy(EQ_word[EQ_word_n].s2,l2);
611 | EQ_word_n++;
612 |
613 | }else{
614 |
615 | fprintf(stderr,"Unknown keyword (%s) in parameter file\n",param);
616 |
617 | }
618 | }
619 |
620 |
/* Count the whitespace-separated tokens in s.
   Returns 0 for an empty or all-whitespace string. */
int
narg(s)
char *s;
{
  int count;

  count = 0;
  while(*s != '\0'){
    /* skip any whitespace before the next token */
    while(isspace(*s)){
      s++;
    }
    if(*s == '\0'){
      break;
    }
    count++;
    /* consume the token itself */
    while(*s != '\0' && !isspace(*s)){
      s++;
    }
  }

  return(count);
}
642 |
/*-----------------------------*/
/* Read line and gather data.  */
/* Return length of sentence.  */
/*-----------------------------*/
/* Parse one bracketed sentence in buff with a single left-to-right
   scan, filling the caller's terminal/quotterm/bracket arrays.
   Deleted pre-terminals (DELETE_LABEL) are skipped entirely; labels
   listed as QUOTE_LABEL are recorded so fix_quote() can repair
   quote/possessive tokenization mismatches later.  The return value
   is the sentence length used for the cut-off statistics (words whose
   label is in DELETE_LABEL_FOR_LENGTH are not counted). */
int
read_line(buff, terminal, quotterm, wn, bracket, bn)
char *buff;            /* IN: one bracketed sentence, NUL-terminated */
s_terminal terminal[]; /* OUT: kept terminals (word + label) */
s_term_ind quotterm[]; /* OUT: quote/possessive terminal candidates */
int *wn;               /* OUT: number of terminals stored */
s_bracket bracket[];   /* OUT: non-terminal bracket spans */
int *bn;               /* OUT: number of brackets stored */
{
  char *p, *q, label[MAX_LABEL_LEN], word[MAX_WORD_LEN];
  int qt; /* quote term counter */
  int wid, bid; /* word ID, bracket ID */
  int n; /* temporary remembering the position */
  int b; /* temporary remembering bid */
  int i;
  int len; /* length of the sentence */

  len = 0;
  stack_top=0;

  /* one pass over the string; p is the read cursor */
  for(p=buff,qt=0,wid=0,bid=0;*p!='\0';){

    if(isspace(*p)){
      p++;
      continue;

    /* open bracket */
    /*--------------*/
    }else if(*p=='('){

      n=wid;
      /* collect the label that immediately follows '(' */
      for(p++,i=0;!is_terminator(*p);p++,i++){
        label[i]=*p;
      }
      label[i]='\0';

      /* Find terminals */
      q = p;
      if(isspace(*q)){
        for(q++;isspace(*q);q++);
        for(i=0;!is_terminator(*q);q++,i++){
          word[i]=*q;
        }
        word[i]='\0';

        /* compute length */
        /* NOTE(review): "!f(label)==1" parses as "(!f(label))==1",
           i.e. count the word when the label is NOT length-deleted.
           Correct, but easy to misread. */
        if(*q==')' && !is_deletelabel_for_length(label)==1){
          len++;
        }
        if (DEBUG>1)
          printf("label=%s, word=%s, wid=%d\n",label,word,wid);
        /* quote terminal */
        if(*q==')' && is_quote_term(label,word)==1){
          strcpy(quotterm[qt].term.word,word);
          strcpy(quotterm[qt].term.label,label);
          quotterm[qt].index = wid;
          quotterm[qt].bracket = bid;
          /* snapshot of currently-open brackets so fix_quote()
             can later shift their end positions */
          quotterm[qt].endslen = stack_top;
          //quotterm[qt].ends = (int*)malloc(stack_top*sizeof(int));
          memcpy(quotterm[qt].ends,stack,stack_top*sizeof(int));
          qt++;
        }

        /* delete terminal */
        if(*q==')' && is_deletelabel(label)==1){
          p = q+1;
          continue;

        /* valid terminal */
        }else if(*q==')'){
          strcpy(terminal[wid].word,word);
          strcpy(terminal[wid].label,label);
          wid++;
          p = q+1;
          continue;

        /* error */
        }else if(*q!='('){
          Error("More than two elements in a bracket\n");
        }
      }

      /* otherwise non-terminal label */
      bracket[bid].start = wid;
      bracket[bid].buf_start = p-buff;
      strcpy(bracket[bid].label,label);
      pushb(bid);
      bid++;

    /* close bracket */
    /*---------------*/
    }else if(*p==')'){

      /* pop the matching open bracket and record where it ends */
      b = popb();
      bracket[b].end = wid;
      bracket[b].buf_end = p-buff;
      p++;

    /* error */
    /*-------*/
    }else{

      Error("Reading sentence\n");
    }
  }

  if(!stackempty()){
    Error("Bracketing is unbalanced (too many open bracket)\n");
  }

  *wn = wid;
  *bn = bid;

  return(len);
}
762 |
763 |
/*----------------------*/
/* stack operation      */
/* for bracketing pairs */
/*----------------------*/
/* Push an open-bracket id onto the global bracket stack.
   NOTE(review): no overflow check -- assumes the stack array is
   large enough for every simultaneously-open bracket; confirm the
   array size against the maximum bracket count elsewhere in the file. */
void
pushb(item)
int item;
{
  stack[stack_top++]=item;
}
774 |
775 | int
776 | popb()
777 | {
778 | int item;
779 |
780 | item = stack[stack_top-1];
781 |
782 | if(stack_top-- < 0){
783 | Error("Bracketing unbalance (too many close bracket)\n");
784 | }
785 | return(item);
786 | }
787 |
788 | int
789 | stackempty()
790 | {
791 | if(stack_top==0){
792 | return(1);
793 | }else{
794 | return(0);
795 | }
796 | }
797 |
798 |
799 | /*------------------*/
800 | /* calculate result */
801 | /*------------------*/
802 | void
803 | calc_result(unsigned char *buf1,unsigned char *buf)
804 | {
805 | int i, j, l;
806 | int match, crossing, correct_tag;
807 |
808 | int last_i = -1;
809 |
810 | char my_buf[1000];
811 | int match_found = 0;
812 |
813 | char match_j[200];
814 | for (j = 0; j < bn2; ++j) {
815 | match_j[j] = 0;
816 | }
817 |
818 | /* ML */
819 | if (DEBUG>1)
820 | printf("\n");
821 |
822 |
823 | /* Find skip and error */
824 | /*---------------------*/
825 | if(wn2==0){
826 | Status = 2;
827 | individual_result(0,0,0,0,0,0);
828 | return;
829 | }
830 |
831 | if(wn1 != wn2){
832 | //if (DEBUG>1)
833 | //Error("Length unmatch (%d|%d)\n",wn1,wn2);
834 | fix_quote();
835 | if(wn1 != wn2){
836 | Error("Length unmatch (%d|%d)\n",wn1,wn2);
837 | individual_result(0,0,0,0,0,0);
838 | return;
839 | }
840 | }
841 |
842 | for(i=0;i1)
862 | printf("1.res=%d, 2.res=%d, 1.start=%d, 2.start=%d, 1.end=%d, 2.end=%d\n",bracket1[i].result,bracket2[j].result,bracket1[i].start,bracket2[j].start,bracket1[i].end,bracket2[j].end);
863 |
864 | // does bracket match?
865 | if(bracket1[i].result != 5 &&
866 | bracket2[j].result == 0 &&
867 | bracket1[i].start == bracket2[j].start && bracket1[i].end == bracket2[j].end) {
868 |
869 | // (1) do we not care about the label or (2) does the label match?
870 | if (F_label==0 || label_comp(bracket1[i].label,bracket2[j].label)==1) {
871 | bracket1[i].result = bracket2[j].result = 1;
872 | match++;
873 | match_found = 1;
874 | break;
875 | } else {
876 | if (DEBUG>1) {
877 | printf(" LABEL[%d-%d]: ",bracket1[i].start,bracket1[i].end-1);
878 | l = bracket1[i].buf_end-bracket1[i].buf_start;
879 | strncpy(my_buf,buf1+bracket1[i].buf_start,l);
880 | my_buf[l] = '\0';
881 | printf("%s\n",my_buf);
882 | }
883 | match_found = 1;
884 | match_j[j] = 1;
885 | }
886 | }
887 | }
888 |
889 | if (!match_found && bracket1[i].result != 5 && DEBUG>1) {
890 | /* ### ML 09/28/03: gold bracket with no corresponding test bracket */
891 | printf(" BRACKET[%d-%d]: ",bracket1[i].start,bracket1[i].end-1);
892 | l = bracket1[i].buf_end-bracket1[i].buf_start;
893 | strncpy(my_buf,buf1+bracket1[i].buf_start,l);
894 | my_buf[l] = '\0';
895 | printf("%s\n",my_buf);
896 | }
897 | match_found = 0;
898 | }
899 |
900 | for(j=0;j1) {
902 | /* test bracket with no corresponding gold bracket */
903 | printf(" EXTRA[%d-%d]: ",bracket2[j].start,bracket2[j].end-1);
904 | l = bracket2[j].buf_end-bracket2[j].buf_start;
905 | strncpy(my_buf,buf+bracket2[j].buf_start,l);
906 | my_buf[l] = '\0';
907 | printf("%s\n",my_buf);
908 | }
909 | }
910 |
911 | /* crossing */
912 | /*----------*/
913 | crossing = 0;
914 |
915 | /* crossing is counted based on the brackets */
916 | /* in test rather than gold file (by Mike) */
917 | for(j=0;j bracket2[j].start &&
923 | bracket1[i].end < bracket2[j].end) ||
924 | (bracket1[i].start > bracket2[j].start &&
925 | bracket1[i].start < bracket2[j].end &&
926 | bracket1[i].end > bracket2[j].end))){
927 |
928 | /* ### ML 09/01/03: get details on cross-brackettings */
929 | if (i != last_i) {
930 | if (DEBUG>1) {
931 | printf(" CROSSING[%d-%d]: ",bracket1[i].start,bracket1[i].end-1);
932 | l = bracket1[i].buf_end-bracket1[i].buf_start;
933 | strncpy(my_buf,buf1+bracket1[i].buf_start,l);
934 | my_buf[l] = '\0';
935 | printf("%s\n",my_buf);
936 |
937 | /* ML
938 | printf("\n CROSSING at bracket %d:\n",i-1);
939 | printf(" GOLD (tokens %d-%d): ",bracket1[i].start,bracket1[i].end-1);
940 | l = bracket1[i].buf_end-bracket1[i].buf_start;
941 | strncpy(my_buf,buf1+bracket1[i].buf_start,l);
942 | my_buf[l] = '\0';
943 | printf("%s\n",my_buf);
944 | */
945 | }
946 | last_i = i;
947 | }
948 |
949 | /* ML
950 | printf(" TEST (tokens %d-%d): ",bracket2[j].start,bracket2[j].end-1);
951 | l = bracket2[j].buf_end-bracket2[j].buf_start;
952 | strncpy(my_buf,buf+bracket2[j].buf_start,l);
953 | my_buf[l] = '\0';
954 | printf("%s\n",my_buf);
955 | */
956 |
957 | crossing++;
958 | break;
959 | }
960 | }
961 | }
962 |
963 | /* Tagging accuracy */
964 | /*------------------*/
965 | correct_tag=0;
966 | for(i=0;i1) {
983 | for(i=0;iindex;
1026 | int bra = quot->bracket;
1027 | s_terminal* term = "->term;
1028 | int k;
1029 | memmove(&terminal[ind+1],
1030 | &terminal[ind],
1031 | sizeof(s_terminal)*(MAX_WORD_IN_SENT-ind-1));
1032 | strcpy(terminal[ind].label,term->label);
1033 | strcpy(terminal[ind].word,term->word);
1034 | (*wn)++;
1035 | if (DEBUG>1)
1036 | printf("bra=%d, ind=%d\n",bra,ind);
1037 | for(k=0;k1)
1041 | printf("bracket[%d]={%d,%d}\n",k,bracket[k].start,bracket[k].end);
1042 | if (k>=bra) {
1043 | bracket[k].start++;
1044 | bracket[k].end++;
1045 | }
1046 | //if (bracket[k].start<=ind && bracket[k].end>=ind)
1047 | //bracket[k].end++;
1048 | }
1049 | if (DEBUG>1)
1050 | printf("endslen=%d\n",quot->endslen);
1051 | for(k=0;kendslen;k++) {
1052 | //printf("ends[%d]=%d",k,quot->ends[k]);
1053 | bracket[quot->ends[k]].end++;
1054 | }
1055 | //free(quot->ends);
1056 | }
1057 | /*
1058 | void
1059 | adjust_end(ind,bra)
1060 | int ind;
1061 | int bra;
1062 | {
1063 | for(k=0;k=bra)
1068 | bracket[k].end++;
1069 | }
1070 | }
1071 | */
1072 | void
1073 | massage_data()
1074 | {
1075 | int i, j;
1076 |
1077 | /* for GOLD */
1078 | /*----------*/
1079 | for(i=0;i0 && TOTAL_bn2>0){
1246 | printf(" %6.2f %6.2f %6d %5d %5d %5d",
1247 | (TOTAL_bn1>0?100.0*TOTAL_match/TOTAL_bn1:0.0),
1248 | (TOTAL_bn2>0?100.0*TOTAL_match/TOTAL_bn2:0.0),
1249 | TOTAL_match,
1250 | TOTAL_bn1,
1251 | TOTAL_bn2,
1252 | TOTAL_crossing);
1253 | }
1254 |
1255 | printf(" %5d %5d %6.2f",
1256 | TOTAL_word,
1257 | TOTAL_correct_tag,
1258 | (TOTAL_word>0?100.0*TOTAL_correct_tag/TOTAL_word:0.0));
1259 |
1260 | printf("\n");
1261 | printf("=== Summary ===\n");
1262 |
1263 | sentn = TOTAL_sent - TOTAL_error_sent - TOTAL_skip_sent;
1264 |
1265 | printf("\n-- All --\n");
1266 | printf("Number of sentence = %6d\n",TOTAL_sent);
1267 | printf("Number of Error sentence = %6d\n",TOTAL_error_sent);
1268 | printf("Number of Skip sentence = %6d\n",TOTAL_skip_sent);
1269 | printf("Number of Valid sentence = %6d\n",sentn);
1270 |
1271 | r = TOTAL_bn1>0 ? 100.0*TOTAL_match/TOTAL_bn1 : 0.0;
1272 | printf("Bracketing Recall = %6.2f\n",r);
1273 |
1274 | p = TOTAL_bn2>0 ? 100.0*TOTAL_match/TOTAL_bn2 : 0.0;
1275 | printf("Bracketing Precision = %6.2f\n",p);
1276 |
1277 | f = 2*p*r/(p+r);
1278 | printf("Bracketing FMeasure = %6.2f\n",f);
1279 |
1280 | printf("Complete match = %6.2f\n",
1281 | (sentn>0?100.0*TOTAL_comp_sent/sentn:0.0));
1282 | printf("Average crossing = %6.2f\n",
1283 | (sentn>0?1.0*TOTAL_crossing/sentn:0.0));
1284 | printf("No crossing = %6.2f\n",
1285 | (sentn>0?100.0*TOTAL_no_crossing/sentn:0.0));
1286 | printf("2 or less crossing = %6.2f\n",
1287 | (sentn>0?100.0*TOTAL_2L_crossing/sentn:0.0));
1288 | printf("Tagging accuracy = %6.2f\n",
1289 | (TOTAL_word>0?100.0*TOTAL_correct_tag/TOTAL_word:0.0));
1290 |
1291 | sentn = TOT40_sent - TOT40_error_sent - TOT40_skip_sent;
1292 |
1293 | printf("\n-- len<=%d --\n",TOT_cut_len);
1294 | printf("Number of sentence = %6d\n",TOT40_sent);
1295 | printf("Number of Error sentence = %6d\n",TOT40_error_sent);
1296 | printf("Number of Skip sentence = %6d\n",TOT40_skip_sent);
1297 | printf("Number of Valid sentence = %6d\n",sentn);
1298 |
1299 |
1300 | r = TOT40_bn1>0 ? 100.0*TOT40_match/TOT40_bn1 : 0.0;
1301 | printf("Bracketing Recall = %6.2f\n",r);
1302 |
1303 | p = TOT40_bn2>0 ? 100.0*TOT40_match/TOT40_bn2 : 0.0;
1304 | printf("Bracketing Precision = %6.2f\n",p);
1305 |
1306 | f = 2*p*r/(p+r);
1307 | printf("Bracketing FMeasure = %6.2f\n",f);
1308 |
1309 | printf("Complete match = %6.2f\n",
1310 | (sentn>0?100.0*TOT40_comp_sent/sentn:0.0));
1311 | printf("Average crossing = %6.2f\n",
1312 | (sentn>0?1.0*TOT40_crossing/sentn:0.0));
1313 | printf("No crossing = %6.2f\n",
1314 | (sentn>0?100.0*TOT40_no_crossing/sentn:0.0));
1315 | printf("2 or less crossing = %6.2f\n",
1316 | (sentn>0?100.0*TOT40_2L_crossing/sentn:0.0));
1317 | printf("Tagging accuracy = %6.2f\n",
1318 | (TOT40_word>0?100.0*TOT40_correct_tag/TOT40_word:0.0));
1319 |
1320 | }
1321 |
1322 |
1323 | /*--------------------------------*/
1324 | /* display individual information */
1325 | /*--------------------------------*/
1326 | void
1327 | dsp_info()
1328 | {
1329 | int i, n;
1330 |
1331 | printf("-<1>---(wn1=%3d, bn1=%3d)- ",wn1,bn1);
1332 | printf("-<2>---(wn2=%3d, bn2=%3d)-\n",wn2,bn2);
1333 |
1334 | n = (wn1>wn2?wn1:wn2);
1335 |
1336 | for(i=0;ibn2?bn1:bn2);
1354 |
1355 | for(i=0;iMax_error){
1503 | exit(1);
1504 | }
1505 | }
1506 |
1507 |
/*---------------------*/
/* fatal error to exit */
/*---------------------*/
/* Print a printf-style message to stderr and terminate the program.
   K&R-era variadic workaround: up to three format arguments are
   forwarded.  Callers may pass fewer; that is tolerated in practice
   because fprintf only reads as many arguments as the format string
   consumes (a modern rewrite would use <stdarg.h>/vfprintf). */
void
Fatal(s,arg1,arg2,arg3)
char *s, *arg1, *arg2, *arg3;
{
  fprintf(stderr,s,arg1,arg2,arg3);
  exit(1);
}
1518 |
1519 |
/*-------*/
/* Usage */
/*-------*/
/* Print the command-line help text to stderr.
   Fix: the -c line previously read "forstatistics" (missing space). */
void
Usage()
{
  fprintf(stderr," evalb [-dDh][-c n][-e n][-p param_file] gold-file test-file \n");
  fprintf(stderr," \n");
  fprintf(stderr," Evaluate bracketing in test-file against gold-file. \n");
  fprintf(stderr," Return recall, precision, F-Measure, tag accuracy. \n");
  fprintf(stderr," \n");
  fprintf(stderr," \n");
  fprintf(stderr," -d debug mode \n");
  fprintf(stderr," -D debug mode plus bracketing info \n");
  fprintf(stderr," -c n cut-off length for statistics (def.=40)\n");
  fprintf(stderr," -e n number of error to kill (default=10) \n");
  fprintf(stderr," -p param_file parameter file \n");
  fprintf(stderr," -h help \n");
}
1539 |
--------------------------------------------------------------------------------
/EVALB/new.prm:
--------------------------------------------------------------------------------
1 | ##------------------------------------------##
2 | ## Debug mode ##
3 | ## 0: No debugging ##
4 | ## 1: print data for individual sentence ##
5 | ## 2: print detailed bracketing info ##
6 | ##------------------------------------------##
7 | DEBUG 0
8 |
9 | ##------------------------------------------##
10 | ## MAX error ##
11 | ## Number of error to stop the process. ##
12 | ## This is useful if there could be ##
13 | ## tokenization error. ##
14 | ## The process will stop when this number##
15 | ## of errors are accumulated. ##
16 | ##------------------------------------------##
17 | MAX_ERROR 10
18 |
19 | ##------------------------------------------##
20 | ## Cut-off length for statistics ##
21 | ## At the end of evaluation, the ##
22 | ## statistics for the sentences of length##
23 | ## less than or equal to this number will##
24 | ## be shown, on top of the statistics ##
25 | ## for all the sentences ##
26 | ##------------------------------------------##
27 | CUTOFF_LEN 40
28 |
29 | ##------------------------------------------##
30 | ## unlabeled or labeled bracketing ##
31 | ## 0: unlabeled bracketing ##
32 | ## 1: labeled bracketing ##
33 | ##------------------------------------------##
34 | LABELED 1
35 |
36 | ##------------------------------------------##
37 | ## Delete labels ##
38 | ## list of labels to be ignored. ##
39 | ## If it is a pre-terminal label, delete ##
40 | ## the word along with the brackets. ##
41 | ## If it is a non-terminal label, just ##
42 | ## delete the brackets (don't delete ##
43 | ## children). ##
44 | ##------------------------------------------##
45 | DELETE_LABEL TOP
46 | DELETE_LABEL S1
47 | DELETE_LABEL -NONE-
48 | DELETE_LABEL ,
49 | DELETE_LABEL :
50 | DELETE_LABEL ``
51 | DELETE_LABEL ''
52 | DELETE_LABEL .
53 | DELETE_LABEL ?
54 | DELETE_LABEL !
55 |
56 | ##------------------------------------------##
57 | ## Delete labels for length calculation ##
58 | ## list of labels to be ignored for ##
59 | ## length calculation purpose ##
60 | ##------------------------------------------##
61 | DELETE_LABEL_FOR_LENGTH -NONE-
62 |
63 | ##------------------------------------------##
64 | ## Labels to be considered for misquote ##
65 | ## (could be possessive or quote) ##
66 | ##------------------------------------------##
67 | QUOTE_LABEL ``
68 | QUOTE_LABEL ''
69 | QUOTE_LABEL POS
70 |
71 | ##------------------------------------------##
72 | ## These ones are less common, but ##
73 | ## are on occasion output by parsers: ##
74 | ##------------------------------------------##
75 | QUOTE_LABEL NN
76 | QUOTE_LABEL CD
77 | QUOTE_LABEL VBZ
78 | QUOTE_LABEL :
79 |
80 | ##------------------------------------------##
81 | ## Equivalent labels, words ##
82 | ## the pairs are considered equivalent ##
83 | ## This is non-directional. ##
84 | ##------------------------------------------##
85 | EQ_LABEL ADVP PRT
86 |
87 | # EQ_WORD Example example
88 |
--------------------------------------------------------------------------------
/EVALB/sample/sample.gld:
--------------------------------------------------------------------------------
1 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
2 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
3 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
4 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
5 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
6 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
7 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
8 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
9 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
10 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
11 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
12 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
13 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
14 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
15 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
16 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
17 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
18 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
19 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
20 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
21 | (S (A-SBJ-1 (P this)) (B-WHATEVER (Q is) (A (R a) (T test))))
22 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))))
23 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (-NONE- *))
24 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (: *))
25 |
--------------------------------------------------------------------------------
/EVALB/sample/sample.prm:
--------------------------------------------------------------------------------
1 | ##------------------------------------------##
2 | ## Debug mode ##
3 | ## print out data for individual sentence ##
4 | ##------------------------------------------##
5 | DEBUG 0
6 |
7 | ##------------------------------------------##
8 | ## MAX error ##
9 | ## Number of error to stop the process. ##
10 | ## This is useful if there could be ##
11 | ## tokenization error. ##
12 | ## The process will stop when this number##
13 | ## of errors are accumulated. ##
14 | ##------------------------------------------##
15 | MAX_ERROR 10
16 |
17 | ##------------------------------------------##
18 | ## Cut-off length for statistics ##
19 | ## At the end of evaluation, the ##
20 | ## statistics for the sentences of length##
21 | ## less than or equal to this number will##
22 | ## be shown, on top of the statistics ##
23 | ## for all the sentences ##
24 | ##------------------------------------------##
25 | CUTOFF_LEN 40
26 |
27 | ##------------------------------------------##
28 | ## unlabeled or labeled bracketing ##
29 | ## 0: unlabeled bracketing ##
30 | ## 1: labeled bracketing ##
31 | ##------------------------------------------##
32 | LABELED 1
33 |
34 | ##------------------------------------------##
35 | ## Delete labels ##
36 | ## list of labels to be ignored. ##
37 | ## If it is a pre-terminal label, delete ##
38 | ## the word along with the brackets. ##
39 | ## If it is a non-terminal label, just ##
40 | ## delete the brackets (don't delete ##
41 | ## children). ##
42 | ##------------------------------------------##
43 | DELETE_LABEL TOP
44 | DELETE_LABEL -NONE-
45 | DELETE_LABEL ,
46 | DELETE_LABEL :
47 | DELETE_LABEL ``
48 | DELETE_LABEL ''
49 |
50 | ##------------------------------------------##
51 | ## Delete labels for length calculation ##
52 | ## list of labels to be ignored for ##
53 | ## length calculation purpose ##
54 | ##------------------------------------------##
55 | DELETE_LABEL_FOR_LENGTH -NONE-
56 |
57 |
58 | ##------------------------------------------##
59 | ## Equivalent labels, words ##
60 | ## the pairs are considered equivalent ##
61 | ## This is non-directional. ##
62 | ##------------------------------------------##
63 | EQ_LABEL T TT
64 |
65 | EQ_WORD This this
66 |
--------------------------------------------------------------------------------
/EVALB/sample/sample.rsl:
--------------------------------------------------------------------------------
1 | Sent. Matched Bracket Cross Correct Tag
2 | ID Len. Stat. Recal Prec. Bracket gold test Bracket Words Tags Accracy
3 | ============================================================================
4 | 1 4 0 100.00 100.00 4 4 4 0 4 4 100.00
5 | 2 4 0 75.00 75.00 3 4 4 0 4 4 100.00
6 | 3 4 0 100.00 100.00 4 4 4 0 4 3 75.00
7 | 4 4 0 75.00 75.00 3 4 4 0 4 3 75.00
8 | 5 4 0 75.00 75.00 3 4 4 0 4 4 100.00
9 | 6 4 0 50.00 66.67 2 4 3 1 4 4 100.00
10 | 7 4 0 25.00 100.00 1 4 1 0 4 4 100.00
11 | 8 4 0 0.00 0.00 0 4 0 0 4 4 100.00
12 | 9 4 0 100.00 80.00 4 4 5 0 4 4 100.00
13 | 10 4 0 100.00 50.00 4 4 8 0 4 4 100.00
14 | 11 4 2 0.00 0.00 0 0 0 0 4 0 0.00
15 | 12 4 1 0.00 0.00 0 0 0 0 4 0 0.00
16 | 13 4 1 0.00 0.00 0 0 0 0 4 0 0.00
17 | 14 4 2 0.00 0.00 0 0 0 0 4 0 0.00
18 | 15 4 0 100.00 100.00 4 4 4 0 4 4 100.00
19 | 16 4 1 0.00 0.00 0 0 0 0 4 0 0.00
20 | 17 4 1 0.00 0.00 0 0 0 0 4 0 0.00
21 | 18 4 0 100.00 100.00 4 4 4 0 4 4 100.00
22 | 19 4 0 100.00 100.00 4 4 4 0 4 4 100.00
23 | 20 4 1 0.00 0.00 0 0 0 0 4 0 0.00
24 | 21 4 0 100.00 100.00 4 4 4 0 4 4 100.00
25 | 22 44 0 100.00 100.00 34 34 34 0 44 44 100.00
26 | 23 4 0 100.00 100.00 4 4 4 0 4 4 100.00
27 | 24 5 0 100.00 100.00 4 4 4 0 4 4 100.00
28 | ============================================================================
29 | 87.76 90.53 86 98 95 16 108 106 98.15
30 | === Summary ===
31 |
32 | -- All --
33 | Number of sentence = 24
34 | Number of Error sentence = 5
35 | Number of Skip sentence = 2
36 | Number of Valid sentence = 17
37 | Bracketing Recall = 87.76
38 | Bracketing Precision = 90.53
39 | Complete match = 52.94
40 | Average crossing = 0.06
41 | No crossing = 94.12
42 | 2 or less crossing = 100.00
43 | Tagging accuracy = 98.15
44 |
45 | -- len<=40 --
46 | Number of sentence = 23
47 | Number of Error sentence = 5
48 | Number of Skip sentence = 2
49 | Number of Valid sentence = 16
50 | Bracketing Recall = 81.25
51 | Bracketing Precision = 85.25
52 | Complete match = 50.00
53 | Average crossing = 0.06
54 | No crossing = 93.75
55 | 2 or less crossing = 100.00
56 | Tagging accuracy = 96.88
57 |
--------------------------------------------------------------------------------
/EVALB/sample/sample.tst:
--------------------------------------------------------------------------------
1 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
2 | (S (A (P this)) (B (Q is) (C (R a) (T test))))
3 | (S (A (P this)) (B (Q is) (A (R a) (U test))))
4 | (S (C (P this)) (B (Q is) (A (R a) (U test))))
5 | (S (A (P this)) (B (Q is) (R a) (A (T test))))
6 | (S (A (P this) (Q is)) (A (R a) (T test)))
7 | (S (P this) (Q is) (R a) (T test))
8 | (P this) (Q is) (R a) (T test)
9 | (S (A (P this)) (B (Q is) (A (A (R a) (T test)))))
10 | (S (A (P this)) (B (Q is) (A (A (A (A (A (R a) (T test))))))))
11 |
12 | (S (A (P this)) (B (Q was) (A (A (R a) (T test)))))
13 | (S (A (P this)) (B (Q is) (U not) (A (A (R a) (T test)))))
14 |
15 | (TOP (S (A (P this)) (B (Q is) (A (R a) (T test)))))
16 | (S (A (P this)) (NONE *) (B (Q is) (A (R a) (T test))))
17 | (S (A (P this)) (S (NONE abc) (A (NONE *))) (B (Q is) (A (R a) (T test))))
18 | (S (A (P this)) (B (Q is) (A (R a) (TT test))))
19 | (S (A (P This)) (B (Q is) (A (R a) (T test))))
20 | (S (A (P That)) (B (Q is) (A (R a) (T test))))
21 | (S (A (P this)) (B (Q is) (A (R a) (T test))))
22 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))))
23 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (-NONE- *))
24 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (: *))
25 |
--------------------------------------------------------------------------------
/EVALB/tgrep_proc.prl:
--------------------------------------------------------------------------------
#!/usr/local/bin/perl

# Pass through only the input lines that contain "TOP" (the root label
# present on parser-output lines); all other lines are dropped.
# NOTE(review): the original inline comment claimed this skips blank
# lines, but the test below is a /TOP/ match, not a blank-line check.
while(<>)
{
if(m/TOP/) # keep lines containing TOP
{
print;
}
}
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2017,
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/ON_LSTM.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 | import torch
4 |
5 | from locked_dropout import LockedDropout
6 |
7 |
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable affine.

    Computes gamma * (x - mean) / (std + eps) + beta.

    NOTE: divides by (std + eps) rather than sqrt(var + eps), so the
    result is not numerically identical to ``nn.LayerNorm``.
    """

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        # scale and shift start as the identity transform
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        return self.gamma * (x - mu) / (sigma + self.eps) + self.beta
20 |
21 |
class LinearDropConnect(nn.Linear):
    """Linear layer with DropConnect on its weight matrix.

    During training, individual weight entries are zeroed with probability
    `dropout`; the sampled mask is cached so it can be shared across a whole
    sequence (resample via `sample_mask()`). At evaluation time the full
    weight is scaled by (1 - dropout) to match the training expectation.
    """

    def __init__(self, in_features, out_features, bias=True, dropout=0.):
        super(LinearDropConnect, self).__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias
        )
        self.dropout = dropout  # probability of zeroing each weight entry

    def sample_mask(self):
        """Sample a fresh DropConnect mask and cache the masked weight."""
        if self.dropout == 0.:
            self._weight = self.weight
        else:
            # Fix: use a boolean mask — uint8 masks for masked_fill are
            # deprecated (and rejected by recent PyTorch versions).
            mask = self.weight.new_empty(
                self.weight.size(),
                dtype=torch.bool
            )
            mask.bernoulli_(self.dropout)  # True entries are dropped
            self._weight = self.weight.masked_fill(mask, 0.)

    def forward(self, input, sample_mask=False):
        """Apply the linear map.

        In training mode the cached masked weight is used (optionally
        resampled first via `sample_mask=True`); note a mask must have been
        sampled at least once before the first training-mode call. In eval
        mode the full weight scaled by (1 - dropout) is used.
        """
        if self.training:
            if sample_mask:
                self.sample_mask()
            return F.linear(input, self._weight, self.bias)
        else:
            return F.linear(input, self.weight * (1 - self.dropout),
                            self.bias)
50 |
51 |
def cumsoftmax(x, dim=-1):
    """Cumulative softmax: running sum of softmax probabilities along `dim`."""
    probs = F.softmax(x, dim=dim)
    return torch.cumsum(probs, dim=dim)
54 |
55 |
class ONLSTMCell(nn.Module):
    """Ordered-Neurons LSTM cell.

    Besides the four standard LSTM gates, two "master" gates — cumulative
    softmaxes over n_chunk positions — impose an ordering on groups of
    `chunk_size` hidden units, so higher-ranked chunks are updated less often.
    The per-step sums of the master gates are returned as distance estimates.
    """

    def __init__(self, input_size, hidden_size, chunk_size, dropconnect=0.):
        super(ONLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.chunk_size = chunk_size
        # Number of chunks the hidden state is partitioned into.
        self.n_chunk = int(hidden_size / chunk_size)

        # Input projection: 4 standard gates plus the 2 master gates.
        self.ih = nn.Sequential(
            nn.Linear(input_size, 4 * hidden_size + self.n_chunk * 2, bias=True),
            # LayerNorm(3 * hidden_size)
        )
        # Hidden-to-hidden projection with DropConnect on its weights.
        self.hh = LinearDropConnect(hidden_size, hidden_size*4+self.n_chunk*2, bias=True, dropout=dropconnect)

        # self.c_norm = LayerNorm(hidden_size)

        self.drop_weight_modules = [self.hh]

    def forward(self, input, hidden,
                transformed_input=None):
        """Run one time step.

        `transformed_input` lets the caller precompute self.ih(input) for a
        whole sequence; when it is given, `input` may be None.
        Returns (h, c, (forget_distance, in_distance)).
        """
        hx, cx = hidden

        if transformed_input is None:
            transformed_input = self.ih(input)

        gates = transformed_input + self.hh(hx)
        # First 2*n_chunk columns are the master gates; the rest split into
        # the four chunked standard gates.
        cingate, cforgetgate = gates[:, :self.n_chunk*2].chunk(2, 1)
        outgate, cell, ingate, forgetgate = gates[:,self.n_chunk*2:].view(-1, self.n_chunk*4, self.chunk_size).chunk(4,1)

        cingate = 1. - cumsoftmax(cingate)
        cforgetgate = cumsoftmax(cforgetgate)

        # Expected split points of the master gates, reported as distances.
        distance_cforget = 1. - cforgetgate.sum(dim=-1) / self.n_chunk
        distance_cin = cingate.sum(dim=-1) / self.n_chunk

        # Broadcast master gates over each chunk's units.
        cingate = cingate[:, :, None]
        cforgetgate = cforgetgate[:, :, None]

        # Fix: torch.sigmoid/torch.tanh replace the deprecated
        # torch.nn.functional.sigmoid/tanh (same math, no warnings).
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cell = torch.tanh(cell)
        outgate = torch.sigmoid(outgate)

        # cy = cforgetgate * forgetgate * cx + cingate * ingate * cell

        # Combine master gates with the standard gates: inside the overlap
        # region the standard gates act; outside it the master gates dominate.
        overlap = cforgetgate * cingate
        forgetgate = forgetgate * overlap + (cforgetgate - overlap)
        ingate = ingate * overlap + (cingate - overlap)
        cy = forgetgate * cx + ingate * cell

        # hy = outgate * F.tanh(self.c_norm(cy))
        hy = outgate * torch.tanh(cy)
        return hy.view(-1, self.hidden_size), cy, (distance_cforget, distance_cin)

    def init_hidden(self, bsz):
        """Zero-initialized (h, c) for a batch of size `bsz`."""
        weight = next(self.parameters()).data
        return (weight.new(bsz, self.hidden_size).zero_(),
                weight.new(bsz, self.n_chunk, self.chunk_size).zero_())

    def sample_masks(self):
        """Resample DropConnect masks on all weight-dropped submodules."""
        for m in self.drop_weight_modules:
            m.sample_mask()
119 |
120 |
class ONLSTMStack(nn.Module):
    """A stack of ONLSTMCell layers run over a whole sequence.

    `layer_sizes` is [input_size, hidden_1, ..., hidden_L]; locked dropout is
    applied between layers (but not after the last one).
    """

    def __init__(self, layer_sizes, chunk_size, dropout=0., dropconnect=0.):
        super(ONLSTMStack, self).__init__()
        self.cells = nn.ModuleList([ONLSTMCell(layer_sizes[i],
                                               layer_sizes[i+1],
                                               chunk_size,
                                               dropconnect=dropconnect)
                                    for i in range(len(layer_sizes) - 1)])
        self.lockdrop = LockedDropout()
        self.dropout = dropout  # inter-layer (locked) dropout probability
        self.sizes = layer_sizes

    def init_hidden(self, bsz):
        # One (h, c) pair per layer.
        return [c.init_hidden(bsz) for c in self.cells]

    def forward(self, input, hidden):
        """Run the stack over `input` of shape (time, batch, features).

        Returns:
            output: last layer's activations, (time, batch, size)
            prev_state: final (h, c) per layer
            raw_outputs: per-layer activations before inter-layer dropout
            outputs: per-layer activations after inter-layer dropout
            distances: (forget, in) master-gate distance stacks, each of
                shape (layers, time, batch)
        """
        length, batch_size, _ = input.size()

        if self.training:
            # Sample one DropConnect mask per layer, shared across the whole
            # sequence (variational weight dropout).
            for c in self.cells:
                c.sample_masks()

        prev_state = list(hidden)
        prev_layer = input

        raw_outputs = []
        outputs = []
        distances_forget = []
        distances_in = []
        for l in range(len(self.cells)):
            curr_layer = [None] * length
            dist = [None] * length
            # Precompute the input projection for all time steps at once.
            t_input = self.cells[l].ih(prev_layer)

            for t in range(length):
                hidden, cell, d = self.cells[l](
                    None, prev_state[l],
                    transformed_input=t_input[t]
                )
                prev_state[l] = hidden, cell  # overwritten every timestep
                curr_layer[t] = hidden
                dist[t] = d

            prev_layer = torch.stack(curr_layer)
            dist_cforget, dist_cin = zip(*dist)
            dist_layer_cforget = torch.stack(dist_cforget)
            dist_layer_cin = torch.stack(dist_cin)
            raw_outputs.append(prev_layer)
            if l < len(self.cells) - 1:
                # Locked dropout only between layers, never after the last.
                prev_layer = self.lockdrop(prev_layer, self.dropout)
            outputs.append(prev_layer)
            distances_forget.append(dist_layer_cforget)
            distances_in.append(dist_layer_cin)
        output = prev_layer

        return output, prev_state, raw_outputs, outputs, (torch.stack(distances_forget), torch.stack(distances_in))
177 |
178 |
if __name__ == "__main__":
    # Smoke test: push one random (time, batch, feature) tensor through a
    # small two-layer stack and print the final per-layer states.
    # torch.randn replaces the old uninitialized torch.Tensor(...) followed
    # by .data.normal_() — same distribution, idiomatic and safe.
    x = torch.randn(10, 10, 10)
    lstm = ONLSTMStack([10, 10, 10], chunk_size=10)
    print(lstm(x, lstm.init_hidden(10))[1])
184 |
185 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ON-LSTM
2 |
3 | This repository contains the code used for word-level language model and unsupervised parsing experiments in
4 | [Ordered Neurons: Integrating Tree Structures into Recurrent Neural Networks](https://arxiv.org/abs/1810.09536) paper,
5 | originally forked from the
6 | [LSTM and QRNN Language Model Toolkit for PyTorch](https://github.com/salesforce/awd-lstm-lm).
If you use this code or our results in your research, we'd appreciate it if you cite our paper as follows:
8 |
9 | ```
10 | @article{shen2018ordered,
11 | title={Ordered Neurons: Integrating Tree Structures into Recurrent Neural Networks},
12 | author={Shen, Yikang and Tan, Shawn and Sordoni, Alessandro and Courville, Aaron},
13 | journal={arXiv preprint arXiv:1810.09536},
14 | year={2018}
15 | }
16 | ```
17 |
18 | ## Software Requirements
19 | Python 3.6, NLTK and PyTorch 0.4 are required for the current codebase.
20 |
21 | ## Steps
22 |
23 | 1. Install PyTorch 0.4 and NLTK
24 |
25 | 2. Download PTB data. Note that the two tasks, i.e., language modeling and unsupervised parsing share the same model
structure but require different formats of the PTB data. For language modeling we need the standard 10,000 word
27 | [Penn Treebank corpus](https://github.com/pytorch/examples/tree/75e435f98ab7aaa7f82632d4e633e8e03070e8ac/word_language_model/data/penn) data,
28 | and for parsing we need [Penn Treebank Parsed](https://catalog.ldc.upenn.edu/LDC99T42) data.
29 |
30 | 3. Scripts and commands
31 |
32 | + Train Language Modeling
33 | ```python main.py --batch_size 20 --dropout 0.45 --dropouth 0.3 --dropouti 0.5 --wdrop 0.45 --chunk_size 10 --seed 141 --epoch 1000 --data /path/to/your/data```
34 |
35 | + Test Unsupervised Parsing
36 | ```python test_phrase_grammar.py --cuda```
37 |
38 | The default setting in `main.py` achieves a perplexity of approximately `56.17` on PTB test set
39 | and unlabeled F1 of approximately `47.7` on WSJ test set.
40 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from collections import Counter
5 |
6 |
class Dictionary(object):
    """Bidirectional word <-> id mapping with per-token occurrence counts."""

    def __init__(self):
        self.word2idx = {}        # word -> integer id
        self.idx2word = []        # id -> word
        self.counter = Counter()  # id -> occurrence count
        self.total = 0            # total tokens seen

    def add_word(self, word):
        """Record one occurrence of `word`, assigning a fresh id if unseen,
        and return the word's id."""
        if word not in self.word2idx:
            # len before append is exactly the next free index.
            self.word2idx[word] = len(self.idx2word)
            self.idx2word.append(word)
        token_id = self.word2idx[word]
        self.counter[token_id] += 1
        self.total += 1
        return token_id

    def __len__(self):
        return len(self.idx2word)
25 |
26 |
class Corpus(object):
    """Token-id tensors for train/valid/test splits sharing one Dictionary."""

    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Pass 1: grow the dictionary and count tokens.
        # NOTE(review): the appended '' looks like a line-terminator token
        # (possibly an '<eos>' whose angle brackets were stripped) — confirm.
        tokens = 0
        with open(path, 'r') as f:
            for line in f:
                words = line.split() + ['']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Pass 2: map every token to its id in a flat LongTensor.
        ids = torch.LongTensor(tokens)
        with open(path, 'r') as f:
            pos = 0
            for line in f:
                for word in line.split() + ['']:
                    ids[pos] = self.dictionary.word2idx[word]
                    pos += 1

        return ids
57 |
--------------------------------------------------------------------------------
/data_ptb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import pickle
4 | import copy
5 |
6 | import numpy
7 | import torch
8 | import nltk
9 | from nltk.corpus import ptb
10 |
# POS tags treated as real words; leaves with any other tag (punctuation,
# traces, currency symbols, etc.) are filtered out of the sentences.
word_tags = ['CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNS', 'NNP', 'NNPS', 'PDT',
             'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ',
             'WDT', 'WP', 'WP$', 'WRB']
currency_tags_words = ['#', '$', 'C$', 'A$']
# Empty-category / trace markers that appear as leaves in PTB parses.
ellipsis = ['*', '*?*', '0', '*T*', '*ICH*', '*U*', '*RNR*', '*EXP*', '*PPA*', '*NOT*']
punctuation_tags = ['.', ',', ':', '-LRB-', '-RRB-', '\'\'', '``']
punctuation_words = ['.', ',', ':', '-LRB-', '-RRB-', '\'\'', '``', '--', ';', '-', '?', '!', '...', '-LCB-', '-RCB-']

# Split WSJ files into train/valid/test by string comparison on the path.
# NOTE(review): the train range (sections 00-24) overlaps the valid (22) and
# test (23) ranges and these are independent `if`s, not `elif`s — sections
# 22/23 therefore also land in train_file_ids. Confirm this is intended.
file_ids = ptb.fileids()
train_file_ids = []
valid_file_ids = []
test_file_ids = []
rest_file_ids = []
for id in file_ids:
    if 'WSJ/00/WSJ_0000.MRG' <= id <= 'WSJ/24/WSJ_2499.MRG':
        train_file_ids.append(id)
    if 'WSJ/22/WSJ_2200.MRG' <= id <= 'WSJ/22/WSJ_2299.MRG':
        valid_file_ids.append(id)
    if 'WSJ/23/WSJ_2300.MRG' <= id <= 'WSJ/23/WSJ_2399.MRG':
        test_file_ids.append(id)
# elif 'WSJ/00/WSJ_0000.MRG' <= id <= 'WSJ/01/WSJ_0199.MRG' or 'WSJ/24/WSJ_2400.MRG' <= id <= 'WSJ/24/WSJ_2499.MRG':
#     rest_file_ids.append(id)
33 |
34 |
class Dictionary(object):
    """Word <-> id mapping with frequency tracking.

    Id 0 is reserved for the '' token; unknown words fall back to it on
    lookup. NOTE(review): '' is likely an '<unk>'-style token whose angle
    brackets were stripped — confirm against the original source.
    """

    def __init__(self):
        self.word2idx = {'': 0}
        self.idx2word = ['']
        self.word2frq = {}

    def add_word(self, word):
        """Register `word` (minting an id if new), bump its frequency, and
        return its id."""
        if word not in self.word2idx:
            self.word2idx[word] = len(self.idx2word)
            self.idx2word.append(word)
        self.word2frq[word] = self.word2frq.get(word, 0) + 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)

    def __getitem__(self, item):
        # Unknown words map to the id of the '' token (0).
        return self.word2idx.get(item, self.word2idx[''])

    def rebuild_by_freq(self, thd=3):
        """Rebuild the vocabulary, keeping only words seen >= `thd` times.

        Returns the new vocabulary size (including the '' token).
        """
        self.word2idx = {'': 0}
        self.idx2word = ['']

        for k, v in self.word2frq.items():
            if v >= thd and k not in self.idx2word:
                self.word2idx[k] = len(self.idx2word)
                self.idx2word.append(k)

        print('Number of words:', len(self.idx2word))
        return len(self.idx2word)
71 |
72 |
class Corpus(object):
    """PTB corpus for the unsupervised-parsing evaluation.

    Builds (or loads a pickled) frequency-thresholded Dictionary from the
    training sections, then tokenizes each split into id tensors alongside
    the raw word lists, nested-list bracketings, and original nltk trees.
    """

    def __init__(self, path):
        dict_file_name = os.path.join(path, 'dict.pkl')
        if os.path.exists(dict_file_name):
            # Reuse a previously built dictionary to skip the counting pass.
            self.dictionary = pickle.load(open(dict_file_name, 'rb'))
        else:
            self.dictionary = Dictionary()
            self.add_words(train_file_ids)
            # self.add_words(valid_file_ids)
            # self.add_words(test_file_ids)
            self.dictionary.rebuild_by_freq()
            pickle.dump(self.dictionary, open(dict_file_name, 'wb'))

        self.train, self.train_sens, self.train_trees, self.train_nltktrees = self.tokenize(train_file_ids)
        # NOTE(review): attribute is spelled 'valid_nltktress' (sic); callers
        # elsewhere may depend on this name, so it is left as-is.
        self.valid, self.valid_sens, self.valid_trees, self.valid_nltktress = self.tokenize(valid_file_ids)
        self.test, self.test_sens, self.test_trees, self.test_nltktrees = self.tokenize(test_file_ids)
        self.rest, self.rest_sens, self.rest_trees, self.rest_nltktrees = self.tokenize(rest_file_ids)

    def filter_words(self, tree):
        """Return lowercased word leaves of `tree` whose POS tag is in
        word_tags, with digit runs collapsed to 'N'."""
        words = []
        for w, tag in tree.pos():
            if tag in word_tags:
                w = w.lower()
                w = re.sub('[0-9]+', 'N', w)
                # if tag == 'CD':
                #     w = 'N'
                words.append(w)
        return words

    def add_words(self, file_ids):
        # Add words to the dictionary
        for id in file_ids:
            sentences = ptb.parsed_sents(id)
            for sen_tree in sentences:
                words = self.filter_words(sen_tree)
                # NOTE(review): the '' sentinels flanking each sentence look
                # like start/end tokens (possibly '<s>'/'</s>' with angle
                # brackets stripped) — confirm against the original source.
                words = [''] + words + ['']
                for word in words:
                    self.dictionary.add_word(word)

    def tokenize(self, file_ids):
        """Tokenize the given PTB files.

        Returns four parallel lists: LongTensors of word ids, word lists,
        nested-list bracketings (via tree2list), and the raw nltk trees.
        """

        def tree2list(tree):
            # Collapse an nltk tree into a nested list of word strings,
            # dropping non-word leaves and flattening unary chains.
            if isinstance(tree, nltk.Tree):
                if tree.label() in word_tags:
                    w = tree.leaves()[0].lower()
                    w = re.sub('[0-9]+', 'N', w)
                    return w
                else:
                    root = []
                    for child in tree:
                        c = tree2list(child)
                        if c != []:
                            root.append(c)
                    if len(root) > 1:
                        return root
                    elif len(root) == 1:
                        return root[0]
            return []

        sens_idx = []
        sens = []
        trees = []
        nltk_trees = []
        for id in file_ids:
            sentences = ptb.parsed_sents(id)
            for sen_tree in sentences:
                words = self.filter_words(sen_tree)
                words = [''] + words + ['']
                # if len(words) > 50:
                #     continue
                sens.append(words)
                idx = []
                for word in words:
                    idx.append(self.dictionary[word])
                sens_idx.append(torch.LongTensor(idx))
                trees.append(tree2list(sen_tree))
                nltk_trees.append(sen_tree)

        return sens_idx, sens, trees, nltk_trees
--------------------------------------------------------------------------------
/embed_regularize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 |
def embedded_dropout(embed, words, dropout=0.1, scale=None):
    """Embedding lookup where whole word vectors are dropped at random.

    Rows of the embedding matrix are zeroed with probability `dropout` and
    the surviving rows rescaled by 1/(1 - dropout), so entire words (rather
    than individual dimensions) vanish for the batch. `scale`, if given, is
    multiplied into the (masked) weight before the lookup.
    """
    weight = embed.weight
    if dropout:
        # One Bernoulli keep-draw per vocabulary row, broadcast over dims.
        keep = weight.data.new().resize_((weight.size(0), 1)).bernoulli_(1 - dropout)
        mask = keep.expand_as(weight) / (1 - dropout)
        masked_embed_weight = mask * weight
    else:
        masked_embed_weight = weight
    if scale:
        masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight

    # F.embedding expects -1 when no padding index is configured.
    padding_idx = embed.padding_idx if embed.padding_idx is not None else -1

    return torch.nn.functional.embedding(
        words, masked_embed_weight,
        padding_idx, embed.max_norm, embed.norm_type,
        embed.scale_grad_by_freq, embed.sparse,
    )
23 |
if __name__ == '__main__':
    # Smoke test: compare a plain lookup against the dropped-out lookup.
    V = 50
    h = 4
    bptt = 10
    batch_size = 2

    embed = torch.nn.Embedding(V, h)

    # np.random.randint's high bound is exclusive, so randint(0, V) matches
    # the old, deprecated random_integers(low=0, high=V-1) inclusive range.
    words = np.random.randint(low=0, high=V, size=(batch_size, bptt))
    words = torch.LongTensor(words)

    origX = embed(words)
    X = embedded_dropout(embed, words)

    print(origX)
    print(X)
40 |
--------------------------------------------------------------------------------
/locked_dropout.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 |
class LockedDropout(nn.Module):
    """Variational ("locked") dropout.

    Samples a single dropout mask of shape (1, batch, features) and reuses
    it across the first (time) dimension, instead of resampling per step.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, dropout=0.5):
        """Apply the shared mask to `x`, assumed (time, batch, features).

        Returns `x` unchanged when not in training mode or dropout is falsy.
        """
        if not self.training or not dropout:
            return x
        # One mask for every time step, rescaled to preserve expectation.
        # Fix: the deprecated no-op torch.autograd.Variable wrapper is
        # dropped; the tensor is used directly (requires_grad is False by
        # default for freshly created tensors).
        m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout)
        mask = m / (1 - dropout)
        return mask.expand_as(x) * x
16 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import math
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim.lr_scheduler as lr_scheduler
8 |
9 | import data
10 | import model
11 |
12 | from utils import batchify, get_batch, repackage_hidden
13 |
# Command-line configuration; defaults mirror the awd-lstm-lm settings.
parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='data/penn/',
                    help='location of the data corpus')
parser.add_argument('--model', type=str, default='LSTM',
                    help='type of recurrent net (LSTM, QRNN, GRU)')
parser.add_argument('--emsize', type=int, default=400,
                    help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=1150,
                    help='number of hidden units per layer')
parser.add_argument('--chunk_size', type=int, default=10,
                    help='number of units per chunk')
parser.add_argument('--nlayers', type=int, default=3,
                    help='number of layers')
parser.add_argument('--lr', type=float, default=30,
                    help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.25,
                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=8000,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=80, metavar='N',
                    help='batch size')
parser.add_argument('--bptt', type=int, default=70,
                    help='sequence length')
parser.add_argument('--dropout', type=float, default=0.4,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--dropouth', type=float, default=0.25,
                    help='dropout for rnn layers (0 = no dropout)')
parser.add_argument('--dropouti', type=float, default=0.5,
                    help='dropout for input embedding layers (0 = no dropout)')
parser.add_argument('--dropoute', type=float, default=0.1,
                    help='dropout to remove words from embedding layer (0 = no dropout)')
parser.add_argument('--wdrop', type=float, default=0.4,
                    help='amount of weight dropout to apply to the RNN hidden to hidden matrix')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
# NOTE(review): the help string below is a copy-paste of --seed's; --nonmono
# is actually the non-monotonicity window used for the ASGD/early-stop checks.
parser.add_argument('--nonmono', type=int, default=5,
                    help='random seed')
# NOTE: store_false means CUDA is ON by default; passing --cuda turns it OFF
# (quirk inherited from awd-lstm-lm).
parser.add_argument('--cuda', action='store_false',
                    help='use CUDA')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                    help='report interval')
# Default save path: quasi-unique name derived from the current timestamp.
randomhash = ''.join(str(time.time()).split('.'))
parser.add_argument('--save', type=str, default=randomhash + '.pt',
                    help='path to save the final model')
parser.add_argument('--alpha', type=float, default=2,
                    help='alpha L2 regularization on RNN activation (alpha = 0 means no regularization)')
parser.add_argument('--beta', type=float, default=1,
                    help='beta slowness regularization applied on RNN activiation (beta = 0 means no regularization)')
parser.add_argument('--wdecay', type=float, default=1.2e-6,
                    help='weight decay applied to all weights')
parser.add_argument('--resume', type=str, default='',
                    help='path of model to resume')
parser.add_argument('--optimizer', type=str, default='sgd',
                    help='optimizer to use (sgd, adam)')
parser.add_argument('--when', nargs="+", type=int, default=[-1],
                    help='When (which epochs) to divide the learning rate by 10 - accepts multiple')
parser.add_argument('--finetuning', type=int, default=500,
                    help='When (which epochs) to switch to finetuning')
parser.add_argument('--philly', action='store_true',
                    help='Use philly cluster')
args = parser.parse_args()
# Encoder/decoder weight tying is always enabled.
args.tied = True

# Set the random seed manually for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    else:
        torch.cuda.manual_seed(args.seed)
85 |
86 |
87 | ###############################################################################
88 | # Load data
89 | ###############################################################################
90 |
def model_save(fn):
    """Persist the full training state (model, criterion, optimizer) to `fn`."""
    if args.philly:
        # On the philly cluster, outputs must live under PT_OUTPUT_DIR.
        fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn)
    state = [model, criterion, optimizer]
    with open(fn, 'wb') as f:
        torch.save(state, f)
96 |
97 |
def model_load(fn):
    """Restore model, criterion and optimizer (module globals) from `fn`."""
    global model, criterion, optimizer
    if args.philly:
        # Same philly path redirection as when saving.
        fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn)
    with open(fn, 'rb') as f:
        loaded = torch.load(f)
    model, criterion, optimizer = loaded
104 |
105 |
import os
import hashlib

# Cache the processed corpus keyed on the md5 of the data path, so repeated
# runs skip re-tokenization.
fn = 'corpus.{}.data'.format(hashlib.md5(args.data.encode()).hexdigest())
if args.philly:
    fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn)
if os.path.exists(fn):
    print('Loading cached dataset...')
    corpus = torch.load(fn)
else:
    print('Producing dataset...')
    corpus = data.Corpus(args.data)
    torch.save(corpus, fn)

# Evaluation uses smaller batches; test uses batch size 1 for exactness.
eval_batch_size = 10
test_batch_size = 1
train_data = batchify(corpus.train, args.batch_size, args)
val_data = batchify(corpus.valid, eval_batch_size, args)
test_data = batchify(corpus.test, test_batch_size, args)

###############################################################################
# Build the model
###############################################################################

from splitcross import SplitCrossEntropyLoss

criterion = None

ntokens = len(corpus.dictionary)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.chunk_size, args.nlayers,
                       args.dropout, args.dropouth, args.dropouti, args.dropoute, args.wdrop, args.tied)
###
if args.resume:
    print('Resuming model ...')
    model_load(args.resume)
    optimizer.param_groups[0]['lr'] = args.lr
    # Re-apply the requested dropout settings to the resumed model.
    model.dropouti, model.dropouth, model.dropout, args.dropoute = args.dropouti, args.dropouth, args.dropout, args.dropoute
    if args.wdrop:
        for rnn in model.rnn.cells:
            rnn.hh.dropout = args.wdrop
###
if not criterion:
    # Adaptive-softmax-style vocabulary splits; PTB-sized vocabularies get a
    # single bucket (empty splits list).
    splits = []
    if ntokens > 500000:
        # One Billion
        # This produces fairly even matrix mults for the buckets:
        # 0: 11723136, 1: 10854630, 2: 11270961, 3: 11219422
        splits = [4200, 35000, 180000]
    elif ntokens > 75000:
        # WikiText-103
        splits = [2800, 20000, 76000]
    print('Using', splits)
    criterion = SplitCrossEntropyLoss(args.emsize, splits=splits, verbose=False)
###
if args.cuda:
    model = model.cuda()
    criterion = criterion.cuda()
###
# Both the model's and the criterion's parameters are optimized/clipped.
params = list(model.parameters()) + list(criterion.parameters())
total_params = sum(x.size()[0] * x.size()[1] if len(x.size()) > 1 else x.size()[0] for x in params if x.size())
print('Args:', args)
print('Model total parameters:', total_params)
168 |
169 |
170 | ###############################################################################
171 | # Training code
172 | ###############################################################################
173 |
def evaluate(data_source, batch_size=10):
    """Return the average per-token loss of the global `model` on `data_source`.

    Runs the data in BPTT-sized chunks, carrying the hidden state across
    chunks (detached each step so no graph accumulates).
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    if args.model == 'QRNN':
        model.reset()
    total_loss = 0
    # (removed a dead local: ntokens was computed but never used)
    hidden = model.init_hidden(batch_size)
    for i in range(0, data_source.size(0) - 1, args.bptt):
        data, targets = get_batch(data_source, i, args, evaluation=True)
        output, hidden = model(data, hidden)
        # Weight each chunk's mean loss by its length for a correct average.
        total_loss += len(data) * criterion(model.decoder.weight, model.decoder.bias, output, targets).data
        hidden = repackage_hidden(hidden)
    return total_loss.item() / len(data_source)
187 |
188 |
def train():
    """Run one epoch of training over `train_data` (uses module globals)."""
    # Turn on training mode which enables dropout.
    if args.model == 'QRNN': model.reset()
    total_loss = 0
    start_time = time.time()
    ntokens = len(corpus.dictionary)  # NOTE(review): unused in this function
    hidden = model.init_hidden(args.batch_size)
    batch, i = 0, 0
    while i < train_data.size(0) - 1 - 1:
        # Randomize the BPTT window (occasionally halved) so batches do not
        # always start at the same token positions across epochs.
        bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
        # Prevent excessively small or negative sequence lengths
        seq_len = max(5, int(np.random.normal(bptt, 5)))
        # There's a very small chance that it could select a very long sequence length resulting in OOM
        # seq_len = min(seq_len, args.bptt + 10)

        # Temporarily rescale the learning rate by the drawn sequence length
        # so shorter sequences do not get an outsized per-token step.
        lr2 = optimizer.param_groups[0]['lr']
        optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt
        model.train()
        data, targets = get_batch(train_data, i, args, seq_len=seq_len)

        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        optimizer.zero_grad()

        output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True)
        # output, hidden = model(data, hidden, return_h=False)
        raw_loss = criterion(model.decoder.weight, model.decoder.bias, output, targets)

        loss = raw_loss
        # Activiation Regularization
        if args.alpha:
            loss = loss + sum(
                args.alpha * dropped_rnn_h.pow(2).mean()
                for dropped_rnn_h in dropped_rnn_hs[-1:]
            )
        # Temporal Activation Regularization (slowness)
        if args.beta:
            loss = loss + sum(
                args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean()
                for rnn_h in rnn_hs[-1:]
            )
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        if args.clip: torch.nn.utils.clip_grad_norm_(params, args.clip)
        optimizer.step()

        # Only the unregularized loss is reported.
        total_loss += raw_loss.data
        # Restore the base learning rate after the length-scaled step.
        optimizer.param_groups[0]['lr'] = lr2
        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss.item() / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss), cur_loss / math.log(2)))
            total_loss = 0
            start_time = time.time()
        ###
        batch += 1
        i += seq_len
251 |
252 |
# Loop over epochs.
lr = args.lr
best_val_loss = []
stored_loss = 100000000

# At any point you can hit Ctrl + C to break out of training early.
try:
    optimizer = None
    # Ensure the optimizer is optimizing params, which includes both the model's weights as well as the criterion's weight (i.e. Adaptive Softmax)
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr, weight_decay=args.wdecay)
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0, 0.999), eps=1e-9, weight_decay=args.wdecay)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', 0.5, patience=2, threshold=0)
    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        train()
        # 't0' in the param group means ASGD is active: validate with the
        # averaged weights ('ax'), then restore the raw weights afterwards.
        if 't0' in optimizer.param_groups[0]:
            tmp = {}
            for prm in model.parameters():
                tmp[prm] = prm.data.clone()
                prm.data = optimizer.state[prm]['ax'].clone()

            val_loss2 = evaluate(val_data, eval_batch_size)
            print('-' * 89)
            print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                  'valid ppl {:8.2f} | valid bpc {:8.3f}'.format(
                epoch, (time.time() - epoch_start_time), val_loss2, math.exp(val_loss2), val_loss2 / math.log(2)))
            print('-' * 89)

            if val_loss2 < stored_loss:
                model_save(args.save)
                print('Saving Averaged!')
                stored_loss = val_loss2

            # Swap the raw (non-averaged) weights back in for training.
            for prm in model.parameters():
                prm.data = tmp[prm].clone()

            if epoch == args.finetuning:
                print('Switching to finetuning')
                optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay)
                best_val_loss = []

            # Stop when validation loss has been non-monotone for too long.
            if epoch > args.finetuning and len(best_val_loss) > args.nonmono and val_loss2 > min(
                    best_val_loss[:-args.nonmono]):
                print('Done!')
                import sys

                sys.exit(1)

        else:
            val_loss = evaluate(val_data, eval_batch_size)
            print('-' * 89)
            print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                  'valid ppl {:8.2f} | valid bpc {:8.3f}'.format(
                epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss), val_loss / math.log(2)))
            print('-' * 89)

            if val_loss < stored_loss:
                model_save(args.save)
                print('Saving model (new best validation)')
                stored_loss = val_loss

            if args.optimizer == 'adam':
                scheduler.step(val_loss)

            # SGD switches to ASGD once validation loss stops improving over
            # the non-monotonicity window.
            if args.optimizer == 'sgd' and 't0' not in optimizer.param_groups[0] and (
                    len(best_val_loss) > args.nonmono and val_loss > min(best_val_loss[:-args.nonmono])):
                print('Switching to ASGD')
                optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay)

            if epoch in args.when:
                print('Saving model before learning rate decreased')
                model_save('{}.e{}'.format(args.save, epoch))
                print('Dividing learning rate by 10')
                optimizer.param_groups[0]['lr'] /= 10.

            best_val_loss.append(val_loss)

        print("PROGRESS: {}%".format((epoch / args.epochs) * 100))

except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

# Load the best saved model.
model_load(args.save)

# Run on test data.
test_loss = evaluate(test_data, test_batch_size)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f} | test bpc {:8.3f}'.format(
    test_loss, math.exp(test_loss), test_loss / math.log(2)))
print('=' * 89)
347 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from embed_regularize import embedded_dropout
5 | from locked_dropout import LockedDropout
6 | from weight_drop import WeightDrop
7 | from ON_LSTM import ONLSTMStack
8 |
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ntoken, ninp, nhid, chunk_size, nlayers, dropout=0.5, dropouth=0.5, dropouti=0.5, dropoute=0.1, wdrop=0, tie_weights=False):
        # rnn_type: only 'LSTM' is supported (asserted below).
        # ntoken/ninp/nhid: vocab size, embedding size, hidden size.
        # chunk_size: ON-LSTM chunk size; nlayers: number of stacked cells.
        # dropout*/wdrop: the various dropout probabilities (output, between
        # layers, on the input embedding, whole-word embedding, DropConnect).
        super(RNNModel, self).__init__()
        self.lockdrop = LockedDropout()
        self.idrop = nn.Dropout(dropouti)
        self.hdrop = nn.Dropout(dropouth)
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        assert rnn_type in ['LSTM'], 'RNN type is not supported'
        # Last layer projects back to ninp so decoder weights can be tied.
        self.rnn = ONLSTMStack(
            [ninp] + [nhid] * (nlayers - 1) + [ninp],
            chunk_size=chunk_size,
            dropconnect=wdrop,
            dropout=dropouth
        )
        self.decoder = nn.Linear(ninp, ntoken)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            #if nhid != ninp:
            #    raise ValueError('When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight

        self.init_weights()

        self.rnn_type = rnn_type
        self.ninp = ninp
        self.nhid = nhid
        self.nlayers = nlayers
        self.dropout = dropout
        self.dropouti = dropouti
        self.dropouth = dropouth
        self.dropoute = dropoute
        self.tie_weights = tie_weights

    def reset(self):
        # QRNN-only hook, kept for interface parity with awd-lstm-lm.
        # NOTE(review): self.rnns does not exist on this class (the stack is
        # self.rnn), but this branch is unreachable since only 'LSTM' passes
        # the constructor assert.
        if self.rnn_type == 'QRNN': [r.reset() for r in self.rnns]

    def init_weights(self):
        # Small uniform init for embeddings/decoder, zero decoder bias.
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden, return_h=False):
        # Whole-word embedding dropout, only while training.
        emb = embedded_dropout(
            self.encoder, input,
            dropout=self.dropoute if self.training else 0
        )

        emb = self.lockdrop(emb, self.dropouti)

        raw_output, hidden, raw_outputs, outputs, distances = self.rnn(emb, hidden)
        # Master-gate distances are stashed for the parsing evaluation.
        self.distance = distances

        output = self.lockdrop(raw_output, self.dropout)

        # Flatten (time, batch, feat) -> (time*batch, feat); the decoder
        # itself is applied externally via SplitCrossEntropyLoss.
        result = output.view(output.size(0)*output.size(1), output.size(2))
        if return_h:
            return result, hidden, raw_outputs, outputs
        else:
            return result, hidden

    def init_hidden(self, bsz):
        return self.rnn.init_hidden(bsz)
81 |
--------------------------------------------------------------------------------
/parse_comparison.py:
--------------------------------------------------------------------------------
1 | """
2 | Reads a parsed corpus (data_path) and a model report (report_path) from a model
3 | that produces latent tree structures and computes the unlabeled F1 score between
4 | the model's latent trees and:
5 | - The ground-truth trees in the parsed corpus
6 | - Strictly left-branching trees for the sentences in the parsed corpus
7 | - Strictly right-branching trees for the sentences in the parsed corpus
8 |
9 | Note that for binary-branching trees like these, precision, recall, and F1 are
10 | equal by definition, so only one number is shown.
11 |
12 | Usage:
13 | $ python scripts/parse_comparison.py \
14 | --data_path ./snli_1.0/snli_1.0_dev.jsonl \
15 | --report_path ./logs/example-nli.report \
16 | """
17 |
18 | import gflags
19 | import sys
20 | import codecs
21 | import json
22 | import random
23 | import re
24 | import glob
25 | import math
26 | from collections import Counter
27 |
# SNLI gold labels; examples whose gold_label is not in this map (e.g. '-') are skipped.
LABEL_MAP = {'entailment': 0, 'neutral': 1, 'contradiction': 2}

# Global flag container for the gflags definitions at the bottom of the file.
FLAGS = gflags.FLAGS

# ListOps operator tokens (also redefined locally inside count_parse).
mathops = ["[MAX", "[MIN", "[MED", "[SM"]
33 |
34 |
def spaceify(parse):
    """Identity pass-through; token spacing is assumed to be pre-normalized.

    (A previous version inserted spaces around parentheses:
    parse.replace("(", "( ").replace(")", " )").)
    """
    return parse
37 |
38 |
def balance(parse, lowercase=False):
    """Rebuild `parse` as a "half-full" binary tree without padding.

    Unlike the sibling construction, the right subtrees are the half-full
    ones (full_transitions with right_full=True).
    """
    tokens = tokenize_parse(parse)
    if len(tokens) <= 1:
        # Nothing to merge; a single leaf is its own tree.
        return tokens[0]

    stack = []
    for transition in full_transitions(len(tokens), right_full=True):
        if transition == 0:
            # Shift the next leaf onto the stack.
            stack.append(tokens.pop(0))
        else:
            # Reduce: merge the top two items (right popped first).
            right = stack.pop()
            left = stack.pop()
            stack.append("( " + left + " " + right + " )")
    assert len(stack) == 1
    return stack[0]
58 |
59 |
def roundup2(N):
    """Round N up to the nearest power of two, returned as an int."""
    exponent = math.ceil(math.log(N, 2))
    return int(2 ** exponent)
63 |
64 |
def full_transitions(N, left_full=False, right_full=False):
    """
    Recursively build shift/reduce transitions (0 = shift, 1 = reduce) for a
    binary tree with N leaves.

    With left_full (resp. right_full), the left (resp. right) subtree at each
    level is a full binary tree; with neither, N itself must be a power of
    two. NOTE: N may become a float in recursive calls because of the true
    divisions below; the N == 1 / N == 2 comparisons still work for floats.
    """

    if N == 1:
        return [0]

    if N == 2:
        return [0, 0, 1]

    assert not (left_full and right_full), "Please only choose one."

    if not left_full and not right_full:
        N = float(N)

        # Constrain to full binary trees.
        assert math.log(N, 2) % 1 == 0, \
            "Bad value. N={}".format(N)

    # Default: split the leaves evenly between the two subtrees.
    left_N = N / 2
    right_N = N - left_N

    if left_full:
        # Left subtree takes half of N rounded up to a power of two.
        left_N = roundup2(N) / 2
        right_N = N - left_N

    if right_full:
        # Right subtree takes half of N rounded up to a power of two.
        right_N = roundup2(N) / 2
        left_N = N - right_N

    return full_transitions(left_N, left_full=left_full, right_full=right_full) + \
        full_transitions(right_N, left_full=left_full, right_full=right_full) + \
        [1]
100 |
101 |
def tokenize_parse(parse):
    """Return only the leaf tokens of a parse, dropping parentheses."""
    brackets = ('(', ')')
    return [tok for tok in spaceify(parse).split() if tok not in brackets]
105 |
106 |
def to_string(parse):
    """Render a nested [left, right] token structure as a bracketed string."""
    if type(parse) is not list:
        # Already a bare token.
        return parse
    if len(parse) == 1:
        return parse[0]
    left, right = parse[0], parse[1]
    return '( ' + to_string(left) + ' ' + to_string(right) + ' )'
114 |
115 |
def tokens_to_rb(tree):
    """Nest a flat token list into a strictly right-branching pair structure."""
    if type(tree) is not list:
        return tree
    if len(tree) == 1:
        return tree[0]
    head, tail = tree[0], tree[1:]
    return [head, tokens_to_rb(tail)]
123 |
124 |
def to_rb(gt_table):
    """Map every parse in gt_table to its strictly right-branching equivalent."""
    return {key: to_string(tokens_to_rb(tokenize_parse(parse)))
            for key, parse in gt_table.items()}
132 |
133 |
def tokens_to_lb(tree):
    """Nest a flat token list into a strictly left-branching pair structure."""
    if type(tree) is not list:
        return tree
    if len(tree) == 1:
        return tree[0]
    init, last = tree[:-1], tree[-1]
    return [tokens_to_lb(init), last]
141 |
142 |
def to_lb(gt_table):
    """Map every parse in gt_table to its strictly left-branching equivalent."""
    return {key: to_string(tokens_to_lb(tokenize_parse(parse)))
            for key, parse in gt_table.items()}
150 |
151 |
def average_depth(parse):
    """Mean bracket depth over the leaf tokens of `parse`; None if no leaves."""
    depths = []
    depth = 0
    for token in parse.split():
        if token == '(':
            depth += 1
        elif token == ')':
            depth -= 1
        else:
            depths.append(depth)
    if depths:
        return float(sum(depths)) / len(depths)
    return None
166 |
167 |
def corpus_average_depth(corpus):
    """Mean of per-sentence average depths over `corpus`, skipping sentences
    whose average_depth is None (no leaf tokens).

    Fix: the original called average_depth twice per sentence (once for the
    None check, once to append); it is now evaluated once.
    """
    local_averages = []
    for key in corpus:
        depth = average_depth(corpus[key])
        if depth is not None:
            local_averages.append(depth)
    return float(sum(local_averages)) / len(local_averages)
177 |
178 |
def average_length(parse):
    """Token count of `parse`, parentheses included."""
    return len(spaceify(parse).split())
182 |
183 |
def corpus_average_length(corpus):
    """Mean token count (parentheses included) over all parses in `corpus`.

    Fix: the original referenced an undefined name `s` instead of
    corpus[key], raising NameError on the first iteration.
    """
    local_averages = []
    for key in corpus:
        length = average_length(corpus[key])
        if length is not None:
            local_averages.append(length)
    return float(sum(local_averages)) / len(local_averages)
192 |
193 |
def corpus_stats(corpus_1, corpus_2, first_two=False, neg_pair=False, const_parse=False):
    """
    Mean unlabeled F1 of corpus_1 against corpus_2, plus optional extras.

    Note: If a few examples in one dataset are missing from the other (i.e., some examples from the source corpus were not included
    in a model corpus), the shorter dataset must be supplied as corpus_1.

    corpus_1 is the report being evaluated (important for counting complete constituents)

    Returns (stats, const_parsed_ratio). `stats` is the mean F1 as a float,
    or a tab-separated string when first_two/neg_pair extras are requested.
    """

    f1_accum = 0.0
    count = 0.0
    first_two_count = 0.0
    last_two_count = 0.0
    three_count = 0.0
    neg_pair_count = 0.0
    neg_count = 0.0
    const_parsed_1 = 0
    if const_parse:
        const_parsed_2 = 0
    else:
        # Denominator placeholder so the ratio below never divides by zero
        # when constituent parsing is not being counted.
        const_parsed_2 = 1
    for key in corpus_2:
        # Fix: to_indexed_contituents returns (constituents, const_parsed,
        # n_words); the original 2-way unpacking raised ValueError.
        c1, cp1, _ = to_indexed_contituents(corpus_1[key], const_parse)
        c2, cp2, _ = to_indexed_contituents(corpus_2[key], const_parse)
        f1_accum += example_f1(c1, c2)
        count += 1
        const_parsed_1 += cp1
        const_parsed_2 += cp2

        if first_two and len(c1) > 1:
            # Does the model merge the first two / last two words?
            if (0, 2) in c1:
                first_two_count += 1
            num_words = len(c1) + 1
            if (num_words - 2, num_words) in c1:
                last_two_count += 1
            three_count += 1
        if neg_pair:
            # How often is a negation word merged with the word after it?
            word_index = 0
            s = spaceify(corpus_1[key])
            tokens = s.split()
            for token_index, token in enumerate(tokens):
                if token in ['(', ')']:
                    continue
                if token in ["n't", "not", "never", "no", "none", "Not", "Never", "No", "None"]:
                    if tokens[token_index + 1] not in ['(', ')']:
                        neg_pair_count += 1
                    neg_count += 1
                word_index += 1
    stats = f1_accum / count
    if first_two:
        stats = str(stats) + '\t' + str(first_two_count / three_count) + '\t' + str(last_two_count / three_count)
    if neg_pair:
        stats = str(stats) + '\t' + str(neg_pair_count / neg_count)
    return stats, const_parsed_1 / const_parsed_2
247 |
248 |
def corpus_stats_labeled(corpus_unlabeled, corpus_labeled):
    """
    Accumulate, per constituent label, how many labeled gold constituents are
    matched (as spans) by the unlabeled model parses.

    Note: If a few examples in one dataset are missing from the other (i.e., some examples from the source corpus were not included
    in a model corpus), the shorter dataset must be supplied as corpus_1.
    """

    correct = Counter()
    total = Counter()

    for key in corpus_labeled:
        model_spans, _, nwords_model = to_indexed_contituents(corpus_unlabeled[key], False)
        gold_spans, nwords_gold = to_indexed_contituents_labeled(corpus_labeled[key])
        assert nwords_model == nwords_gold
        # Nothing to score for sentences with no non-unary gold constituents.
        if not gold_spans:
            continue

        ex_correct, ex_total = example_labeled_acc(model_spans, gold_spans)
        correct.update(ex_correct)
        total.update(ex_total)
    return correct, total
269 |
270 |
def count_parse(parse, index, const_parsed=None):
    """
    Compute the Constituents Parsed metric for ListOps style examples.

    `parse` is a token list that is mutated in place (counted regions are
    collapsed to a '-' placeholder); `index` points at the operator token
    under consideration. Returns the running count, or None when `parse`
    contains no "]" token.
    """
    # Fix: the original declared const_parsed=[] as a mutable default, so
    # counts leaked across independent top-level calls.
    if const_parsed is None:
        const_parsed = []
    mathops = ["[MAX", "[MIN", "[MED", "[SM"]
    if "]" in parse:
        after = parse[index:]
        before = parse[:index]
        between = after[: after.index("]")]

        # If another operator appears before the closing "]", recurse on the
        # nested expression first, then retry this one.
        nest_check = [m in between[1:] for m in mathops]
        if True in nest_check:
            op_i = nest_check.index(True)
            nested_i = after[1:].index(mathops[op_i]) + 1
            count_parse(parse, index + nested_i, const_parsed)
            count_parse(parse, index, const_parsed)
        else:
            o_b = between.count("(")  # open, between
            c_b = between.count(")")  # close, between

            # Closing brackets immediately following the "]".
            end = after.index("]")
            cafter = after[end + 1:]
            stop = None
            stop_list = []
            for item in cafter:
                stop_list.append(")" == item)
                if stop_list[-1] == False:
                    break

            if False in stop_list:
                stop = stop_list.index(False)
            else:
                stop = None
            cafter = cafter[: stop]
            c_a = cafter.count(")")

            # Opening brackets immediately preceding the operator.
            stop = None
            stop_list = []
            for item in before[::-1]:
                stop_list.append("(" == item)
                if stop_list[-1] == False:
                    break

            if False in stop_list:
                stop = len(before) - stop_list.index(False) - 1
            else:
                stop = None
            cbefore = before[stop:]
            o_a = cbefore.count("(")

            # Argument count: digit tokens plus '-' placeholders left by
            # already-collapsed sub-expressions.
            ints = sum(c.isdigit() for c in between) + between.count("-")
            op = o_a + o_b
            cl = c_a + c_b

            # The expression counts as parsed when its bracket counts match
            # its argument count (one extra for the enclosing bracket).
            if op >= ints and cl >= ints:
                if op == ints + 1 or cl == ints + 1:
                    const_parsed.append(1)
            # Collapse the region so outer calls don't recount it.
            parse[index - o_a: index + len(between) + 1 + c_a] = '-'
        return sum(const_parsed)
331 |
332 |
def to_indexed_contituents(parse, const_parse):
    """
    Convert a bracketed parse string into a set of (start, end) word spans.

    Returns (constituents, const_parsed, n_words). const_parsed is computed
    via count_parse only when `const_parse` is truthy, otherwise 0.
    """
    if parse.count("(") != parse.count(")"):
        # Unbalanced parse: flag it but keep going.
        print(parse)
    parse = spaceify(parse)
    sp = parse.split()
    if len(sp) == 1:
        # Single-token sentence: one trivial constituent covering the word.
        return set([(0, 1)]), 0, 1

    backpointers = []
    indexed_constituents = set()
    word_index = 0
    first_op = -1  # position of the first ListOps operator token, if any
    for index, token in enumerate(sp):
        if token == '(':
            backpointers.append(word_index)
        elif token == ')':
            start = backpointers.pop()
            end = word_index
            constituent = (start, end)
            indexed_constituents.add(constituent)
        elif "[" in token:
            # ListOps operator (e.g. "[MAX"); does not count as a word.
            if first_op == -1:
                first_op = index
        else:
            word_index += 1

    cp = 0
    if const_parse:
        # count_parse mutates sp in place while counting.
        cp = count_parse(sp, first_op, [])
    return indexed_constituents, cp, word_index
370 |
371 |
def to_indexed_contituents_labeled(parse):
    """
    Convert a labeled parse (e.g. "(S (NP the dog ) (VP barks ) )") into a
    set of (start, end, label) spans, keeping only non-unary constituents.

    Returns (constituents, n_words).
    """
    sp = parse.split()
    if len(sp) == 1:
        # Fix: the original returned a bare set here, which broke the
        # caller's `c2, nwords2 = ...` unpacking; and an unlabeled (0, 1)
        # span would crash example_labeled_acc. A single word has no
        # non-unary constituents, so return an empty set and word count 1.
        return set(), 1

    backpointers = []
    indexed_constituents = set()
    word_index = 0
    for token in sp:
        if token[0] == '(':
            # Opening bracket carries the constituent label, e.g. "(NP".
            backpointers.append((word_index, token[1:]))
        elif token == ')':
            start, typ = backpointers.pop()
            end = word_index
            constituent = (start, end, typ)
            if end - start > 1:
                indexed_constituents.add(constituent)
        else:
            word_index += 1
    return indexed_constituents, word_index
393 |
394 |
def example_f1(c1, c2):
    """Unlabeled precision of c1 against c2 (for strictly binary trees, P = R = F1)."""
    overlap = c1 & c2
    return float(len(overlap)) / len(c2)
398 |
399 |
def example_labeled_acc(c1, c2):
    """Count, per label, how many non-unary constituents of the labeled
    (non-binarized) parse c2 appear as spans in the model output c1."""
    correct = Counter()
    total = Counter()
    for start, end, label in c2:
        total[label] += 1
        if (start, end) in c1:
            correct[label] += 1
    return correct, total
409 |
410 |
def randomize(parse):
    """Rebuild `parse` with a randomly chosen binary branching structure."""
    tokens = tokenize_parse(parse)
    while len(tokens) > 1:
        # Pick a random adjacent pair and merge it into one bracketed node.
        i = random.randint(0, len(tokens) - 2)
        tokens[i:i + 2] = ["( " + tokens[i] + " " + tokens[i + 1] + " )"]
    return tokens[0]
418 |
419 |
def to_latex(parse):
    """Convert a bracketed parse into a qtree-style LaTeX \\Tree string."""
    latex = "\\Tree " + parse
    latex = latex.replace('(', '[').replace(')', ']')
    return latex.replace(' . ', ' $.$ ')
422 |
423 |
def read_nli_report(path):
    """Load an NLI report: one JSON object per line, two unpadded trees per example."""
    report = {}
    with codecs.open(path, encoding='utf-8') as f:
        for line in f:
            example = json.loads(line)
            base_id = example['example_id']
            report[base_id + "_1"] = unpad(example['sent1_tree'])
            report[base_id + "_2"] = unpad(example['sent2_tree'])
    return report
432 |
433 |
def read_sst_report(path):
    """Load an SST report: one JSON object per line, one unpadded tree per example."""
    report = {}
    with codecs.open(path, encoding='utf-8') as f:
        for line in f:
            example = json.loads(line)
            report[example['example_id'] + "_1"] = unpad(example['sent1_tree'])
    return report
441 |
442 |
def read_listops_report(path):
    """Load a ListOps report; also prints prediction accuracy over the file."""
    report = {}
    correct = 0
    num = 0
    with codecs.open(path, encoding='utf-8') as f:
        for line in f:
            example = json.loads(line)
            report[example['example_id']] = unpad(example['sent1_tree'])
            num += 1
            correct += int(example['truth'] == example['prediction'])
    print("Accuracy = ", correct / num)
    return report
456 |
457 |
def read_nli_report_padded(path):
    """Load an NLI report keeping the padded trees (no unpad), skipping lines
    that fail to encode.

    Fix: on UnicodeError the original substituted the line with "{}", which
    then raised KeyError on 'example_id' below; bad lines are now skipped.
    """
    report = {}
    with codecs.open(path, encoding='utf-8') as f:
        for line in f:
            try:
                line = line.encode('UTF-8')
            except UnicodeError as e:
                print("ENCODING ERROR:", line, e)
                continue
            loaded_example = json.loads(line)
            report[loaded_example['example_id'] + "_1"] = loaded_example['sent1_tree']
            report[loaded_example['example_id'] + "_2"] = loaded_example['sent2_tree']
    return report
471 |
472 |
def read_ptb_report(path):
    """Load a PTB report: one JSON object per line, keyed by bare example_id."""
    report = {}
    with codecs.open(path, encoding='utf-8') as f:
        for line in f:
            example = json.loads(line)
            report[example['example_id']] = unpad(example['sent1_tree'])
    return report
480 |
481 |
def unpad(parse):
    """Strip _PAD tokens from a padded parse string, then re-balance.

    Everything from the first _PAD onward is dropped (plus the immediately
    preceding token when it is purely structural), and closing brackets are
    appended until the parse is balanced.
    """
    structural = ["(", ")", "_PAD"]
    tokens = parse.split()
    pad_positions = [i for i, tok in enumerate(tokens) if tok == "_PAD"]

    if pad_positions:
        first = pad_positions[0]
        # Also drop the token just before the first _PAD if it is structural.
        cut = first - 1 if tokens[first - 1] in structural else first
        kept = tokens[:cut]
    else:
        kept = tokens

    sent = " ".join(kept)
    while sent.count("(") != sent.count(")"):
        sent += " )"

    return sent
501 |
502 |
def ConvertBinaryBracketedSeq(seq):
    """Split a binary-bracketed token sequence into (tokens, transitions),
    where transitions use 0 for shift and 1 for reduce; "(" is ignored."""
    T_SHIFT = 0
    T_REDUCE = 1

    tokens, transitions = [], []
    for item in seq:
        if item == "(":
            continue
        if item == ")":
            transitions.append(T_REDUCE)
        else:
            tokens.append(item)
            transitions.append(T_SHIFT)
    return tokens, transitions
514 |
515 |
def run():
    """Load gold parses and model reports, then print parsing statistics:
    F1 against gold / left- / right-branching trees, average tree depth,
    constituents-parsed ratio, and (optionally) labeled PTB accuracy."""
    gt = {}
    # Fix: gt_labeled was commented out but is assigned below for NLI data,
    # which raised NameError at runtime.
    gt_labeled = {}
    with codecs.open(FLAGS.main_data_path, encoding='utf-8') as f:
        for example_id, line in enumerate(f):
            if FLAGS.data_type == "nli":
                loaded_example = json.loads(line)
                if loaded_example["gold_label"] not in LABEL_MAP:
                    continue
                if '512-4841' in loaded_example['sentence1_binary_parse'] \
                        or '512-8581' in loaded_example['sentence1_binary_parse'] \
                        or '412-4841' in loaded_example['sentence1_binary_parse'] \
                        or '512-4841' in loaded_example['sentence2_binary_parse'] \
                        or '512-8581' in loaded_example['sentence2_binary_parse'] \
                        or '412-4841' in loaded_example['sentence2_binary_parse']:
                    continue  # Stanford parser tree binarizer doesn't handle phone numbers properly.
                gt[loaded_example['pairID'] + "_1"] = loaded_example['sentence1_binary_parse']
                gt[loaded_example['pairID'] + "_2"] = loaded_example['sentence2_binary_parse']
                gt_labeled[loaded_example['pairID'] + "_1"] = loaded_example['sentence1_parse']
                gt_labeled[loaded_example['pairID'] + "_2"] = loaded_example['sentence2_parse']

            elif FLAGS.data_type == "sst":
                line = line.strip()
                stack = []
                words = line.replace(')', ' )')
                words = words.split(' ')
                for index, word in enumerate(words):
                    if word[0] != "(":
                        if word == ")":
                            # Ignore unary merges
                            if words[index - 1] == ")":
                                # NOTE(review): the first pop is the most
                                # recently pushed subtree — confirm the child
                                # order here is intended.
                                newg = "( " + stack.pop() + " " + stack.pop() + " )"
                                stack.append(newg)
                        else:
                            stack.append(word)
                gt[str(example_id) + "_1"] = stack[0]

            elif FLAGS.data_type == "listops":
                line = line.strip()
                label, seq = line.split('\t')
                if len(seq) <= 1:
                    continue

                tokens, transitions = ConvertBinaryBracketedSeq(
                    seq.split(' '))

                example = {}
                example["label"] = label
                example["sentence"] = seq
                example["tokens"] = tokens
                example["transitions"] = transitions
                example["example_id"] = str(example_id)
                gt[example["example_id"]] = example["sentence"]

    # Baseline corpora: strictly left- and right-branching versions of gold.
    lb = to_lb(gt)
    rb = to_rb(gt)
    print("GT average depth", corpus_average_depth(gt))

    ptb = {}
    ptb_labeled = {}
    if FLAGS.ptb_data_path != "_":
        with codecs.open(FLAGS.ptb_data_path, encoding='utf-8') as f:
            for line in f:
                loaded_example = json.loads(line)
                if loaded_example["gold_label"] not in LABEL_MAP:
                    continue
                ptb[loaded_example['pairID']] = loaded_example['sentence1_binary_parse']
                ptb_labeled[loaded_example['pairID']] = loaded_example['sentence1_parse']

    reports = []
    ptb_reports = []
    if FLAGS.use_random_parses:
        print("Creating five sets of random parses for the main data.")
        report_paths = list(range(5))
        for _ in report_paths:
            report = {}
            for sentence in gt:
                report[sentence] = randomize(gt[sentence])
            reports.append(report)

        print("Creating five sets of random parses for the PTB data.")
        ptb_report_paths = list(range(5))
        for _ in ptb_report_paths:
            report = {}
            for sentence in ptb:
                report[sentence] = randomize(ptb[sentence])
            ptb_reports.append(report)
    # Fix: this was a plain `if`, so requesting random parses still fell
    # through to the `else` below and loaded disk reports as well.
    elif FLAGS.use_balanced_parses:
        print("Creating five sets of balanced binary parses for the main data.")
        report_paths = list(range(5))
        for _ in report_paths:
            report = {}
            for sentence in gt:
                report[sentence] = balance(gt[sentence])
            reports.append(report)

        print("Creating five sets of balanced binary parses for the PTB data.")
        ptb_report_paths = list(range(5))
        for _ in ptb_report_paths:
            report = {}
            for sentence in ptb:
                report[sentence] = balance(ptb[sentence])
            ptb_reports.append(report)
    else:
        report_paths = glob.glob(FLAGS.main_report_path_template)
        for path in report_paths:
            print("Loading", path)
            if FLAGS.data_type == "nli":
                reports.append(read_nli_report(path))
            elif FLAGS.data_type == "sst":
                reports.append(read_sst_report(path))
            elif FLAGS.data_type == "listops":
                reports.append(read_listops_report(path))
        # Fix: this previously tested main_report_path_template, so PTB
        # reports were globbed even when none were supplied.
        if FLAGS.ptb_report_path_template != "_":
            ptb_report_paths = glob.glob(FLAGS.ptb_report_path_template)
            for path in ptb_report_paths:
                print("Loading", path)
                ptb_reports.append(read_ptb_report(path))

    if len(reports) > 1 and FLAGS.compute_self_f1:
        f1s = []
        for i in range(len(report_paths) - 1):
            for j in range(i + 1, len(report_paths)):
                # Fix: corpus_stats returns (stats, ratio); summing the raw
                # tuples raised TypeError.
                f1 = corpus_stats(reports[i], reports[j])[0]
                f1s.append(f1)
        print("Mean Self F1:\t" + str(sum(f1s) / len(f1s)))

    correct = Counter()
    total = Counter()
    for i, report in enumerate(reports):
        print(report_paths[i])
        if FLAGS.print_latex > 0:
            for index, sentence in enumerate(gt):
                if index == FLAGS.print_latex:
                    break
                print(to_latex(gt[sentence]))
                print(to_latex(report[sentence]))
                print()

        # Constituents-parsed is only meaningful for ListOps data.
        if FLAGS.data_type == "listops":
            gtf1, gtcp = corpus_stats(report, gt, first_two=FLAGS.first_two, neg_pair=FLAGS.neg_pair, const_parse=True)
        else:
            gtf1, gtcp = corpus_stats(report, gt, first_two=FLAGS.first_two, neg_pair=FLAGS.neg_pair, const_parse=False)
        print("Left:", str(corpus_stats(report, lb)[0]) + '\t' + "Right:",
              str(corpus_stats(report, rb)[0]) + '\t' + "Groud-truth", str(gtf1) + '\t' + "Tree depth:",
              str(corpus_average_depth(report)), '\t', "Constituent parsed:", str(gtcp))

    correct = Counter()
    total = Counter()
    for i, report in enumerate(ptb_reports):
        print(ptb_report_paths[i])
        if FLAGS.print_latex > 0:
            for index, sentence in enumerate(ptb):
                if index == FLAGS.print_latex:
                    break
                print(to_latex(ptb[sentence]))
                print(to_latex(report[sentence]))
                print()
        # Fix: print the stats element of the (stats, ratio) tuple.
        print(str(corpus_stats(report, ptb)[0]) + '\t' + str(corpus_average_depth(report)))
        set_correct, set_total = corpus_stats_labeled(report, ptb_labeled)
        correct.update(set_correct)
        total.update(set_total)

    # Per-label accuracy over all PTB reports.
    for key in sorted(total):
        print(key + '\t' + str(correct[key] * 1. / total[key]))
686 |
687 |
if __name__ == '__main__':
    # Command-line flags, parsed by python-gflags below.
    # (Fix: the main_data_path help string was copy-pasted from the report
    # flag; the "\*" escapes are now written explicitly as "\\*".)
    gflags.DEFINE_string("main_report_path_template", "./checkpoints/*.report",
                         "A template (with wildcards input as \\*) for the paths to the main reports.")
    gflags.DEFINE_string("main_data_path", "./snli_1.0/snli_1.0_dev.jsonl",
                         "The path to the ground-truth corpus file for the main data.")
    gflags.DEFINE_string("ptb_report_path_template", "_",
                         "A template (with wildcards input as \\*) for the paths to the PTB reports, or '_' if not available.")
    gflags.DEFINE_string("ptb_data_path", "_", "The path to the PTB data in SNLI format, or '_' if not available.")
    gflags.DEFINE_boolean("compute_self_f1", True,
                          "Compute self F1 over all reports matching main_report_path_template.")
    gflags.DEFINE_boolean("use_random_parses", False,
                          "Replace all report trees with randomly generated trees. Report path template flags are not used when this is set.")
    gflags.DEFINE_boolean("use_balanced_parses", False,
                          "Replace all report trees with roughly-balanced binary trees. Report path template flags are not used when this is set.")
    gflags.DEFINE_boolean("first_two", False, "Show 'first two' and 'last two' metrics.")
    gflags.DEFINE_boolean("neg_pair", False, "Show 'neg_pair' metric.")
    gflags.DEFINE_enum("data_type", "nli", ["nli", "sst", "listops"], "Data Type")
    gflags.DEFINE_integer("print_latex", 0, "Print this many trees in LaTeX format for each report.")

    # Parse command-line flags in place, then run the comparison.
    FLAGS(sys.argv)

    run()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.13.3
2 | matplotlib>=3.0.3
3 | python_gflags
4 | nltk
5 | torch
6 |
--------------------------------------------------------------------------------
/splitcross.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | import numpy as np
7 |
8 |
class SplitCrossEntropyLoss(nn.Module):
    r'''SplitCrossEntropyLoss calculates an approximate softmax.

    The vocabulary is partitioned into a frequent "head" plus one or more
    "tail" splits; each tail split is represented in the head softmax by a
    single learned "tombstone" token, so the full softmax is only computed
    over the slice(s) actually needed.
    '''
    def __init__(self, hidden_size, splits, verbose=False):
        # We assume splits is [0, split1, split2, N] where N >= |V|
        # For example, a vocab of 1000 words may have splits [0] + [100, 500] + [inf]
        super(SplitCrossEntropyLoss, self).__init__()
        self.hidden_size = hidden_size
        # The 100M sentinel acts as "infinity" for the final split boundary.
        self.splits = [0] + splits + [100 * 1000000]
        self.nsplits = len(self.splits) - 1
        # Per-split work sizes collected for verbose profiling output.
        self.stats = defaultdict(list)
        self.verbose = verbose
        # Each of the splits that aren't in the head require a pretend token, we'll call them tombstones
        # The probability given to this tombstone is the probability of selecting an item from the represented split
        if self.nsplits > 1:
            self.tail_vectors = nn.Parameter(torch.zeros(self.nsplits - 1, hidden_size))
            self.tail_bias = nn.Parameter(torch.zeros(self.nsplits - 1))

    def logprob(self, weight, bias, hiddens, splits=None, softmaxed_head_res=None, verbose=False):
        """Return log-probabilities for `hiddens` over the vocabulary slices
        named in `splits` (all splits by default). An already-computed head
        softmax may be passed via `softmaxed_head_res` to avoid recomputing it.
        """
        # First we perform the first softmax on the head vocabulary and the tombstones
        if softmaxed_head_res is None:
            start, end = self.splits[0], self.splits[1]
            head_weight = None if end - start == 0 else weight[start:end]
            head_bias = None if end - start == 0 else bias[start:end]
            # We only add the tombstones if we have more than one split
            if self.nsplits > 1:
                head_weight = self.tail_vectors if head_weight is None else torch.cat([head_weight, self.tail_vectors])
                head_bias = self.tail_bias if head_bias is None else torch.cat([head_bias, self.tail_bias])

            # Perform the softmax calculation for the word vectors in the head for all splits
            # We need to guard against empty splits as torch.cat does not like random lists
            head_res = torch.nn.functional.linear(hiddens, head_weight, bias=head_bias)
            softmaxed_head_res = torch.nn.functional.log_softmax(head_res, dim=-1)

        if splits is None:
            splits = list(range(self.nsplits))

        results = []
        running_offset = 0
        for idx in splits:

            # For those targets in the head (idx == 0) we only need to return their loss
            if idx == 0:
                results.append(softmaxed_head_res[:, :-(self.nsplits - 1)])

            # If the target is in one of the splits, the probability is the p(tombstone) * p(word within tombstone)
            else:
                start, end = self.splits[idx], self.splits[idx + 1]
                tail_weight = weight[start:end]
                tail_bias = bias[start:end]

                # Calculate the softmax for the words in the tombstone
                tail_res = torch.nn.functional.linear(hiddens, tail_weight, bias=tail_bias)

                # Then we calculate p(tombstone) * p(word in tombstone)
                # Adding is equivalent to multiplication in log space
                # Tombstone columns sit at the end of the head softmax, so
                # split idx's tombstone is column -idx.
                head_entropy = (softmaxed_head_res[:, -idx]).contiguous()
                tail_entropy = torch.nn.functional.log_softmax(tail_res, dim=-1)
                results.append(head_entropy.view(-1, 1) + tail_entropy)

        if len(results) > 1:
            return torch.cat(results, dim=1)
        return results[0]

    def split_on_targets(self, hiddens, targets):
        """Partition `targets` (and their hidden states) by vocabulary split,
        returning parallel lists with one entry per split."""
        # Split the targets into those in the head and in the tail
        split_targets = []
        split_hiddens = []

        # Determine to which split each element belongs (for each start split value, add 1 if equal or greater)
        # This method appears slower at least for WT-103 values for approx softmax
        #masks = [(targets >= self.splits[idx]).view(1, -1) for idx in range(1, self.nsplits)]
        #mask = torch.sum(torch.cat(masks, dim=0), dim=0)
        ###
        # This is equally fast for smaller splits as method below but scales linearly
        mask = None
        for idx in range(1, self.nsplits):
            partial_mask = targets >= self.splits[idx]
            mask = mask + partial_mask if mask is not None else partial_mask
        ###
        #masks = torch.stack([targets] * (self.nsplits - 1))
        #mask = torch.sum(masks >= self.split_starts, dim=0)
        for idx in range(self.nsplits):
            # If there are no splits, avoid costly masked select
            if self.nsplits == 1:
                split_targets, split_hiddens = [targets], [hiddens]
                continue
            # If all the words are covered by earlier targets, we have empties so later stages don't freak out
            if sum(len(t) for t in split_targets) == len(targets):
                split_targets.append([])
                split_hiddens.append([])
                continue
            # Are you in our split?
            tmp_mask = mask == idx
            split_targets.append(torch.masked_select(targets, tmp_mask))
            split_hiddens.append(hiddens.masked_select(tmp_mask.unsqueeze(1).expand_as(hiddens)).view(-1, hiddens.size(1)))
        return split_targets, split_hiddens

    def forward(self, weight, bias, hiddens, targets, verbose=False):
        """Mean cross-entropy of `targets` under the split softmax defined by
        `weight`/`bias` (typically the tied embedding matrix and its bias)."""
        if self.verbose or verbose:
            for idx in sorted(self.stats):
                print('{}: {}'.format(idx, int(np.mean(self.stats[idx]))), end=', ')
            print()

        total_loss = None
        # Flatten (seq_len, batch, hidden) inputs to (seq_len * batch, hidden).
        if len(hiddens.size()) > 2: hiddens = hiddens.view(-1, hiddens.size(2))

        split_targets, split_hiddens = self.split_on_targets(hiddens, targets)

        # First we perform the first softmax on the head vocabulary and the tombstones
        start, end = self.splits[0], self.splits[1]
        head_weight = None if end - start == 0 else weight[start:end]
        head_bias = None if end - start == 0 else bias[start:end]

        # We only add the tombstones if we have more than one split
        if self.nsplits > 1:
            head_weight = self.tail_vectors if head_weight is None else torch.cat([head_weight, self.tail_vectors])
            head_bias = self.tail_bias if head_bias is None else torch.cat([head_bias, self.tail_bias])

        # Perform the softmax calculation for the word vectors in the head for all splits
        # We need to guard against empty splits as torch.cat does not like random lists
        combo = torch.cat([split_hiddens[i] for i in range(self.nsplits) if len(split_hiddens[i])])
        ###
        all_head_res = torch.nn.functional.linear(combo, head_weight, bias=head_bias)
        softmaxed_all_head_res = torch.nn.functional.log_softmax(all_head_res, dim=-1)
        if self.verbose or verbose:
            self.stats[0].append(combo.size()[0] * head_weight.size()[0])

        # Rows of softmaxed_all_head_res are grouped by split in the same
        # order `combo` was assembled; running_offset walks those groups.
        running_offset = 0
        for idx in range(self.nsplits):
            # If there are no targets for this split, continue
            if len(split_targets[idx]) == 0: continue

            # For those targets in the head (idx == 0) we only need to return their loss
            if idx == 0:
                softmaxed_head_res = softmaxed_all_head_res[running_offset:running_offset + len(split_hiddens[idx])]
                entropy = -torch.gather(softmaxed_head_res, dim=1, index=split_targets[idx].view(-1, 1))
            # If the target is in one of the splits, the probability is the p(tombstone) * p(word within tombstone)
            else:
                softmaxed_head_res = softmaxed_all_head_res[running_offset:running_offset + len(split_hiddens[idx])]

                if self.verbose or verbose:
                    start, end = self.splits[idx], self.splits[idx + 1]
                    tail_weight = weight[start:end]
                    self.stats[idx].append(split_hiddens[idx].size()[0] * tail_weight.size()[0])

                # Calculate the softmax for the words in the tombstone
                tail_res = self.logprob(weight, bias, split_hiddens[idx], splits=[idx], softmaxed_head_res=softmaxed_head_res)

                # Then we calculate p(tombstone) * p(word in tombstone)
                # Adding is equivalent to multiplication in log space
                head_entropy = softmaxed_head_res[:, -idx]
                # All indices are shifted - if the first split handles [0,...,499] then the 500th in the second split will be 0 indexed
                indices = (split_targets[idx] - self.splits[idx]).view(-1, 1)
                # Warning: if you don't squeeze, you get an N x 1 return, which acts oddly with broadcasting
                tail_entropy = torch.gather(torch.nn.functional.log_softmax(tail_res, dim=-1), dim=1, index=indices).squeeze()
                entropy = -(head_entropy + tail_entropy)
            ###
            running_offset += len(split_hiddens[idx])
            total_loss = entropy.float().sum() if total_loss is None else total_loss + entropy.float().sum()

        return (total_loss / len(targets)).type_as(weight)
170 |
171 |
if __name__ == '__main__':
    # Smoke test: fit a tiny embedding + split softmax with plain SGD and
    # check that per-token probabilities behave sensibly.
    np.random.seed(42)
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)

    vocab_size, hidden_size, n_tokens, n_epochs = 8, 10, 100, 10

    embed = torch.nn.Embedding(vocab_size, hidden_size)
    crit = SplitCrossEntropyLoss(hidden_size=hidden_size, splits=[vocab_size // 2])
    bias = torch.nn.Parameter(torch.ones(vocab_size))
    params = list(embed.parameters()) + list(crit.parameters())
    optimizer = torch.optim.SGD(params, lr=1)

    for _ in range(n_epochs):
        # Random (input, target) token ids uniformly drawn from [0, vocab_size)
        prev = torch.autograd.Variable((torch.rand(n_tokens, 1) * 0.999 * vocab_size).int().long())
        x = torch.autograd.Variable((torch.rand(n_tokens, 1) * 0.999 * vocab_size).int().long())
        y = embed(prev).squeeze()
        c = crit(embed.weight, bias, y, x.view(n_tokens))
        print('Crit', c.exp().data[0])

        # Probabilities over the full vocabulary should sum to one per token
        logprobs = crit.logprob(embed.weight, bias, y[:2]).exp()
        print(logprobs)
        print(logprobs.sum(dim=1))

        optimizer.zero_grad()
        c.backward()
        optimizer.step()
202 |
--------------------------------------------------------------------------------
/test_phrase_grammar.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import re
3 |
4 | import matplotlib.pyplot as plt
5 | import nltk
6 | import numpy
7 | import torch
8 | import torch.nn as nn
9 | from torch.autograd import Variable
10 |
11 | import data
12 | import data_ptb
13 | from utils import batchify, get_batch, repackage_hidden, evalb
14 |
15 | from parse_comparison import corpus_stats_labeled, corpus_average_depth
16 | from data_ptb import word_tags
17 |
18 |
criterion = nn.CrossEntropyLoss()


def evaluate(data_source, batch_size=1):
    """Return the average per-token cross-entropy of `model` on `data_source`.

    Reads the module-level `model`, `corpus` and `args` globals.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    total_loss = 0
    for start in range(0, data_source.size(0) - 1, args.bptt):
        inputs, targets = get_batch(data_source, start, args, evaluation=True)
        output, hidden = model(inputs, hidden)
        logits = model.decoder(output).view(-1, ntokens)
        # Weight each chunk's loss by its length so the result is per token
        total_loss += len(inputs) * criterion(logits, targets).data
        # Detach the hidden state so the graph does not grow across chunks
        hidden = repackage_hidden(hidden)
    return total_loss / len(data_source)
34 |
def corpus2idx(sentence):
    """Map a whitespace-separated sentence to a (seq_len, 1) LongTensor of ids.

    Looks words up in the module-level `data` corpus dictionary; raises
    KeyError for out-of-vocabulary words.
    """
    # Bug fix: this module imports `numpy`, not `np` — `np.array` raised
    # NameError whenever this helper was called.
    arr = numpy.array([data.dictionary.word2idx[c] for c in sentence.split()], dtype=numpy.int32)
    return torch.from_numpy(arr[:, None]).long()
38 |
39 |
40 | # Test model
def build_tree(depth, sen):
    """Greedily binarize `sen` top-down, splitting at the max gate value.

    `depth` holds one syntactic-distance score per word; the word with the
    largest score becomes the split point at each level. Returns a nested
    list of words (a bare word for a single-word span).
    """
    assert len(depth) == len(sen)

    if len(depth) == 1:
        return sen[0]

    split = numpy.argmax(depth)
    # Left context (may be empty when the split is the first word)
    left = build_tree(depth[:split], sen[:split]) if split > 0 else None
    # Split word, grouped with the recursively-built right context
    right = sen[split]
    if split + 1 < len(sen):
        right = [right, build_tree(depth[split + 1:], sen[split + 1:])]
    return right if left is None else [left, right]
61 |
62 |
63 | # def build_tree(depth, sen):
64 | # assert len(depth) == len(sen)
65 | # assert len(depth) >= 0
66 | #
67 | # if len(depth) == 1:
68 | # parse_tree = sen[0]
69 | # else:
70 | # idx_max = numpy.argmax(depth[1:]) + 1
71 | # parse_tree = []
72 | # if len(sen[:idx_max]) > 0:
73 | # tree0 = build_tree(depth[:idx_max], sen[:idx_max])
74 | # parse_tree.append(tree0)
75 | # if len(sen[idx_max:]) > 0:
76 | # tree1 = build_tree(depth[idx_max:], sen[idx_max:])
77 | # parse_tree.append(tree1)
78 | # return parse_tree
79 |
80 |
def get_brackets(tree, idx=0):
    """Collect (start, end) leaf spans of all multi-word constituents.

    Accepts a nested list or an nltk.Tree. Returns (brackets, next_idx),
    where next_idx is the leaf position one past this subtree. A node's own
    span is recorded by its parent, so the root span is never included.
    """
    brackets = set()
    if not (isinstance(tree, list) or isinstance(tree, nltk.Tree)):
        # Leaf: consumes one position and contributes no bracket
        return brackets, idx + 1
    for child in tree:
        child_brackets, next_idx = get_brackets(child, idx)
        # Record the child's span only if it covers more than one word
        if next_idx - idx > 1:
            brackets.add((idx, next_idx))
        brackets.update(child_brackets)
        idx = next_idx
    return brackets, idx
93 |
def MRG(tr):
    """Render a nested-list tree as an unlabeled bracketed string."""
    if isinstance(tr, str):
        # Leaf word; this rendering carries no category labels
        return tr + ' '
    parts = ['( ']
    parts.extend(MRG(child) for child in tr)
    parts.append(') ')
    return ''.join(parts)
104 |
def MRG_labeled(tr):
    """Render an nltk.Tree as a labeled bracketed string.

    POS nodes (label in `word_tags`) emit their single leaf word; phrasal
    labels keep only the part before any '-' / '=' function annotation.
    Non-Tree input yields the empty string.
    """
    if not isinstance(tr, nltk.Tree):
        return ''
    if tr.label() in word_tags:
        return tr.leaves()[0] + ' '
    # Strip function tags / coindexing, e.g. "NP-SBJ" -> "NP"
    label = re.split(r'[-=]', tr.label())[0]
    body = ''.join(MRG_labeled(child) for child in tr)
    return '(%s ' % label + body + ') '
117 |
def mean(x):
    """Arithmetic mean of a non-empty sequence of numbers."""
    total = 0
    for value in x:
        total += value
    return total / len(x)
120 |
121 |
def test(model, corpus, cuda, prt=False):
    """Induce parses from the model's gate values and score them.

    For each sentence, the ON-LSTM's recorded `distance` gates (layer 1)
    are converted into a binary tree via `build_tree`, then compared to the
    gold tree: unlabeled bracketing precision/recall/F1 per sentence, plus
    labeled-constituent stats and external EVALB scoring in the summary.

    Args:
        model: trained model exposing `distance` after a forward pass.
        corpus: data_ptb.Corpus with sentences and gold (nltk) trees.
        cuda: move inputs to the GPU when True.
        prt: print per-sentence diagnostics (every 100 sentences), save gate
            bar charts to figure/<n>.png, and print the summary tables.

    Returns:
        Mean unlabeled F1 as a numpy array of shape (1,).

    Note: also reads the module-level `args` global (wsj10 flag).
    """
    model.eval()

    prec_list = []
    reca_list = []
    f1_list = []

    pred_tree_list = []
    targ_tree_list = []

    nsens = 0
    word2idx = corpus.dictionary.word2idx
    # WSJ10 evaluates on short training-set sentences; default is the test set
    if args.wsj10:
        dataset = zip(corpus.train_sens, corpus.train_trees, corpus.train_nltktrees)
    else:
        dataset = zip(corpus.test_sens, corpus.test_trees, corpus.test_nltktrees)

    corpus_sys = {}
    corpus_ref = {}
    for sen, sen_tree, sen_nltktree in dataset:
        if args.wsj10 and len(sen) > 12:
            continue
        # NOTE(review): '' is the fallback key for OOV words — presumably
        # '<unk>' was intended; confirm the dictionary actually contains ''.
        x = numpy.array([word2idx[w] if w in word2idx else word2idx[''] for w in sen])
        input = Variable(torch.LongTensor(x[:, None]))
        if cuda:
            input = input.cuda()

        hidden = model.init_hidden(1)
        _, hidden = model(input, hidden)

        # Per-layer gate values the model records during the forward pass
        # (assumed shape: layers x sentence length — TODO confirm)
        distance = model.distance[0].squeeze().data.cpu().numpy()
        distance_in = model.distance[1].squeeze().data.cpu().numpy()

        nsens += 1
        if prt and nsens % 100 == 0:
            for i in range(len(sen)):
                print('%15s\t%s\t%s' % (sen[i], str(distance[:, i]), str(distance_in[:, i])))
            print('Standard output:', sen_tree)

        # Drop sentence-boundary tokens before inducing the tree
        sen_cut = sen[1:-1]
        # gates = distance.mean(axis=0)
        # Only the second layer's gates are used; alternatives kept commented
        for gates in [
            # distance[0],
            distance[1],
            # distance[2],
            # distance.mean(axis=0)
        ]:
            depth = gates[1:-1]
            parse_tree = build_tree(depth, sen_cut)

            corpus_sys[nsens] = MRG(parse_tree)
            corpus_ref[nsens] = MRG_labeled(sen_nltktree)

            pred_tree_list.append(parse_tree)
            targ_tree_list.append(sen_tree)

            model_out, _ = get_brackets(parse_tree)
            std_out, _ = get_brackets(sen_tree)
            overlap = model_out.intersection(std_out)

            # Epsilons avoid division by zero; bracket-less trees count as
            # perfect on the empty side.
            prec = float(len(overlap)) / (len(model_out) + 1e-8)
            reca = float(len(overlap)) / (len(std_out) + 1e-8)
            if len(std_out) == 0:
                reca = 1.
            if len(model_out) == 0:
                prec = 1.
            f1 = 2 * prec * reca / (prec + reca + 1e-8)
            prec_list.append(prec)
            reca_list.append(reca)
            f1_list.append(f1)

            if prt and nsens % 100 == 0:
                print('Model output:', parse_tree)
                print('Prec: %f, Reca: %f, F1: %f' % (prec, reca, f1))

        if prt and nsens % 100 == 0:
            print('-' * 80)

            # Visualize both gate families for the first three layers
            f, axarr = plt.subplots(3, sharex=True, figsize=(distance.shape[1] // 2, 6))
            axarr[0].bar(numpy.arange(distance.shape[1])-0.2, distance[0], width=0.4)
            axarr[0].bar(numpy.arange(distance_in.shape[1])+0.2, distance_in[0], width=0.4)
            axarr[0].set_ylim([0., 1.])
            axarr[0].set_ylabel('1st layer')
            axarr[1].bar(numpy.arange(distance.shape[1]) - 0.2, distance[1], width=0.4)
            axarr[1].bar(numpy.arange(distance_in.shape[1]) + 0.2, distance_in[1], width=0.4)
            axarr[1].set_ylim([0., 1.])
            axarr[1].set_ylabel('2nd layer')
            axarr[2].bar(numpy.arange(distance.shape[1]) - 0.2, distance[2], width=0.4)
            axarr[2].bar(numpy.arange(distance_in.shape[1]) + 0.2, distance_in[2], width=0.4)
            axarr[2].set_ylim([0., 1.])
            axarr[2].set_ylabel('3rd layer')
            plt.sca(axarr[2])
            plt.xlim(xmin=-0.5, xmax=distance.shape[1] - 0.5)
            plt.xticks(numpy.arange(distance.shape[1]), sen, fontsize=10, rotation=45)

            # Assumes a 'figure/' directory exists — TODO confirm
            plt.savefig('figure/%d.png' % (nsens))
            plt.close()

    prec_list, reca_list, f1_list \
        = numpy.array(prec_list).reshape((-1,1)), numpy.array(reca_list).reshape((-1,1)), numpy.array(f1_list).reshape((-1,1))
    if prt:
        print('-' * 80)

        numpy.set_printoptions(precision=4)
        print('Mean Prec:', prec_list.mean(axis=0),
              ', Mean Reca:', reca_list.mean(axis=0),
              ', Mean F1:', f1_list.mean(axis=0))
        print('Number of sentence: %i' % nsens)

        # Labeled constituent recall per category (from parse_comparison)
        correct, total = corpus_stats_labeled(corpus_sys, corpus_ref)
        print(correct)
        print(total)
        print('ADJP:', correct['ADJP'], total['ADJP'])
        print('NP:', correct['NP'], total['NP'])
        print('PP:', correct['PP'], total['PP'])
        print('INTJ:', correct['INTJ'], total['INTJ'])
        print(corpus_average_depth(corpus_sys))

        # External EVALB scoring of the predicted vs. gold trees
        evalb(pred_tree_list, targ_tree_list)

    return f1_list.mean(axis=0)
242 |
243 |
if __name__ == '__main__':
    # Kept from the original script; not read below
    marks = [' ', '-', '=']

    numpy.set_printoptions(precision=2, suppress=True, linewidth=5000)

    parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')

    # Model parameters.
    parser.add_argument('--data', type=str, default='data/ptb',
                        help='location of the data corpus')
    parser.add_argument('--checkpoint', type=str, default='PTB.pt',
                        help='model checkpoint to use')
    parser.add_argument('--seed', type=int, default=1111,
                        help='random seed')
    parser.add_argument('--cuda', action='store_true',
                        help='use CUDA')
    parser.add_argument('--wsj10', action='store_true',
                        help='use WSJ10')
    args = parser.parse_args()
    args.bptt = 70

    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)

    # Load model checkpoint saved as a (model, criterion, optimizer) triple
    with open(args.checkpoint, 'rb') as f:
        model, _, _ = torch.load(f)
        torch.cuda.manual_seed(args.seed)
        model.cpu()
        if args.cuda:
            model.cuda()

    # Load data
    import hashlib

    # The training script caches its corpus under a hash of the data path;
    # reuse that cache so this run shares the model's word dictionary.
    fn = 'corpus.{}.data'.format(hashlib.md5('data/penn'.encode()).hexdigest())
    print('Loading cached dataset...')
    corpus = torch.load(fn)
    dictionary = corpus.dictionary

    # test_batch_size = 1
    # test_data = batchify(corpus.test, test_batch_size, args)
    # test_loss = evaluate(test_data, test_batch_size)
    # print('=' * 89)
    # print('| End of training | test loss {:5.2f} | test ppl {:8.2f} | test bpc {:8.3f}'.format(
    #     test_loss, math.exp(test_loss), test_loss / math.log(2)))
    # print('=' * 89)

    print('Loading PTB dataset...')
    # Parse the treebank, but keep the cached LM dictionary so word ids
    # match the checkpointed model's embedding table.
    corpus = data_ptb.Corpus(args.data)
    corpus.dictionary = dictionary

    test(model, corpus, args.cuda, prt=True)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def repackage_hidden(h):
    """Detach hidden state(s) from the autograd graph.

    Accepts a Tensor or an arbitrarily nested tuple of Tensors and returns
    the same structure with every tensor detached from its history.
    """
    if isinstance(h, torch.Tensor):
        return h.detach()
    # Recurse into nested containers, rebuilding them as tuples
    return tuple(repackage_hidden(item) for item in h)
11 |
12 |
def batchify(data, bsz, args):
    """Reshape a 1-D token stream into a (nbatch, bsz) matrix.

    Trailing tokens that do not fill a whole batch are dropped; column i
    holds the i-th contiguous slice of the corpus. The result is moved to
    the GPU when args.cuda is set.
    """
    nbatch = data.size(0) // bsz
    # Keep only the tokens that divide evenly into bsz columns
    trimmed = data.narrow(0, 0, nbatch * bsz)
    batched = trimmed.view(bsz, -1).t().contiguous()
    return batched.cuda() if args.cuda else batched
23 |
24 |
def get_batch(source, i, args, seq_len=None, evaluation=False):
    """Slice a (data, target) pair of length up to bptt starting at row i.

    Targets are the inputs shifted one step forward, flattened. `seq_len`
    overrides args.bptt when truthy; `evaluation` is accepted for API
    compatibility but unused.
    """
    # Clamp to what remains so the shifted target slice stays in bounds
    span = min(seq_len if seq_len else args.bptt, len(source) - 1 - i)
    return source[i:i + span], source[i + 1:i + 1 + span].view(-1)
30 |
31 |
def load_embeddings_txt(path):
    """Load a GloVe/word2vec-style whitespace-separated embedding text file.

    Each line is "<word> <v1> <v2> ...". Returns (matrix, word_to_index,
    index_to_word) where `matrix` is a numpy array whose rows align with
    `index_to_word`.
    """
    # Bug fix: `pd` and `csv` were used but never imported anywhere in this
    # module; import locally so the module's import surface is unchanged for
    # callers that never use this helper.
    import csv
    import pandas as pd

    words = pd.read_csv(path, sep=" ", index_col=0,
                        na_values=None, keep_default_na=False, header=None,
                        quoting=csv.QUOTE_NONE)
    matrix = words.values
    index_to_word = list(words.index)
    word_to_index = {
        word: ind for ind, word in enumerate(index_to_word)
    }
    return matrix, word_to_index, index_to_word
42 |
def evalb(pred_tree_list, targ_tree_list):
    """Score predicted vs. gold trees with the external EVALB program.

    Writes both tree lists to temporary files, runs the EVALB binary
    (expected at ./EVALB/evalb with the COLLINS.prm parameter file), parses
    its report and returns the bracketing F1. Prints precision, recall and
    F1 as a side effect; metrics stay None if EVALB's output is missing.
    """
    import os
    import subprocess
    import tempfile
    import re
    import nltk

    # Helpers hoisted out of the loop (they were redefined per iteration)
    def process_str_tree(str_tree):
        # Collapse whitespace runs; '|' is kept in the class for output
        # compatibility with the original implementation
        return re.sub('[ |\n]+', ' ', str_tree)

    def list2tree(node):
        # Convert nested word lists into an unlabeled nltk.Tree
        if isinstance(node, list):
            return nltk.Tree('', [list2tree(child) for child in node])
        elif isinstance(node, str):
            return nltk.Tree('', [node])

    temp_path = tempfile.TemporaryDirectory(prefix="evalb-")
    temp_file_path = os.path.join(temp_path.name, "pred_trees.txt")
    temp_targ_path = os.path.join(temp_path.name, "true_trees.txt")
    temp_eval_path = os.path.join(temp_path.name, "evals.txt")

    print("Temp: {}, {}".format(temp_file_path, temp_targ_path))
    # Bug fix: use `with` so both files are flushed and closed even if tree
    # conversion raises (the originals leaked open handles on error).
    with open(temp_file_path, "w") as temp_tree_file, \
            open(temp_targ_path, "w") as temp_targ_file:
        for pred_tree, targ_tree in zip(pred_tree_list, targ_tree_list):
            temp_tree_file.write(process_str_tree(str(list2tree(pred_tree)).lower()) + '\n')
            temp_targ_file.write(process_str_tree(str(list2tree(targ_tree)).lower()) + '\n')

    evalb_dir = os.path.join(os.getcwd(), "EVALB")
    evalb_param_path = os.path.join(evalb_dir, "COLLINS.prm")
    evalb_program_path = os.path.join(evalb_dir, "evalb")
    command = "{} -p {} {} {} > {}".format(
        evalb_program_path,
        evalb_param_path,
        temp_targ_path,
        temp_file_path,
        temp_eval_path)

    subprocess.run(command, shell=True)

    # Bug fix: initialize so a malformed/empty EVALB report produces None
    # instead of an UnboundLocalError at the print below.
    evalb_recall = evalb_precision = evalb_fscore = None
    with open(temp_eval_path) as infile:
        for line in infile:
            match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line)
            if match:
                evalb_recall = float(match.group(1))
            match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line)
            if match:
                evalb_precision = float(match.group(1))
            match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line)
            if match:
                evalb_fscore = float(match.group(1))
                break

    temp_path.cleanup()

    print('-' * 80)
    print('Evalb Prec:', evalb_precision,
          ', Evalb Reca:', evalb_recall,
          ', Evalb F1:', evalb_fscore)

    return evalb_fscore
111 |
--------------------------------------------------------------------------------
/weight_drop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 | from functools import wraps
4 |
class WeightDrop(torch.nn.Module):
    """Wrap a module and apply DropConnect to named weight matrices.

    On every forward pass, each listed weight is rebuilt from a stored
    `<name>_raw` parameter with dropout applied to the weights themselves
    (rather than the activations).
    """

    def __init__(self, module, weights, dropout=0, variational=False):
        # module: the wrapped module (typically an RNN)
        # weights: attribute names to drop, e.g. ['weight_hh_l0']
        # dropout: drop probability applied to those weight matrices
        # variational: drop whole rows (one shared mask value per row)
        #   instead of independent elements
        super(WeightDrop, self).__init__()
        self.module = module
        self.weights = weights
        self.dropout = dropout
        self.variational = variational
        self._setup()

    def widget_demagnetizer_y2k_edition(*args, **kwargs):
        # We need to replace flatten_parameters with a nothing function
        # It must be a function rather than a lambda as otherwise pickling explodes
        # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION!
        # (╯°□°)╯︵ ┻━┻
        return

    def _setup(self):
        # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
        if issubclass(type(self.module), torch.nn.RNNBase):
            self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition

        for name_w in self.weights:
            print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
            w = getattr(self.module, name_w)
            # Re-register the weight under '<name>_raw'; the original name is
            # re-assigned with dropout applied on every forward pass.
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', Parameter(w.data))

    def _setweights(self):
        # Rebuild each dropped weight from its raw parameter.
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            w = None
            if self.variational:
                # One mask value per row, broadcast across the whole row;
                # note: mask is always generated in training mode here
                mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
                if raw_w.is_cuda: mask = mask.cuda()
                mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
                w = mask.expand_as(raw_w) * raw_w
            else:
                w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        # Refresh the dropped weights, then delegate to the wrapped module.
        self._setweights()
        return self.module.forward(*args)
48 |
if __name__ == '__main__':
    import torch
    from weight_drop import WeightDrop

    # Input is (seq, batch, input); requires a CUDA device
    batch = torch.autograd.Variable(torch.randn(2, 1, 10)).cuda()
    h0 = None

    print('Testing WeightDrop')
    print('=-=-=-=-=-=-=-=-=-=')

    # --- Linear: the weight matrix is re-dropped on every forward pass,
    # so every output should differ between two calls on the same input. ---
    print('Testing WeightDrop with Linear')

    lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9)
    lin.cuda()
    first = [row.sum() for row in lin(batch).data]
    second = [row.sum() for row in lin(batch).data]

    print('All items should be different')
    print('Run 1:', first)
    print('Run 2:', second)

    assert first[0] != second[0]
    assert first[1] != second[1]

    print('---')

    # --- LSTM: only hidden-to-hidden weights are dropped, so the first
    # timestep (which has no recurrent input) should match across runs. ---
    print('Testing WeightDrop with LSTM')

    wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9)
    wdrnn.cuda()

    first = [step.sum() for step in wdrnn(batch, h0)[0].data]
    second = [step.sum() for step in wdrnn(batch, h0)[0].data]

    print('First timesteps should be equal, all others should differ')
    print('Run 1:', first)
    print('Run 2:', second)

    # First time step, not influenced by hidden to hidden weights, should be equal
    assert first[0] == second[0]
    # Second step should not
    assert first[1] != second[1]

    print('---')
--------------------------------------------------------------------------------