├── LICENSE ├── README.md ├── code ├── checkpoints │ └── StReg │ │ └── pretrained.tar ├── data.py ├── datasets │ └── StReg │ │ ├── exs-teste.txt │ │ ├── exs-testi.txt │ │ ├── exs-train.txt │ │ ├── exs-val.txt │ │ ├── map-teste.txt │ │ ├── map-testi.txt │ │ ├── map-train.txt │ │ ├── map-val.txt │ │ ├── rec-teste.pkl │ │ ├── rec-testi.pkl │ │ ├── rec-train.pkl │ │ ├── rec-val.pkl │ │ ├── src-indexer.pkl │ │ ├── src-teste.txt │ │ ├── src-testi.txt │ │ ├── src-train.txt │ │ ├── src-val.txt │ │ ├── targ-indexer.pkl │ │ ├── targ-teste.txt │ │ ├── targ-testi.txt │ │ ├── targ-train.txt │ │ └── targ-val.txt ├── decode.py ├── decodes │ └── StReg │ │ ├── teste-pretrained.txt │ │ └── testi-pretrained.txt ├── eval.py ├── external │ ├── datagen.jar │ └── lib │ │ └── antlr-4.7.1-complete.jar ├── gadget.py ├── models.py ├── train.py └── utils.py ├── data ├── const_anonymized │ ├── dev.tsv │ ├── teste.tsv │ ├── testi.tsv │ └── train.tsv ├── raw │ ├── dev.tsv │ ├── teste.tsv │ ├── testi.tsv │ └── train.tsv └── tokenized │ ├── dev.tsv │ ├── teste.tsv │ ├── testi.tsv │ └── train.tsv ├── easy_eval ├── eval.py ├── external │ ├── datagen.jar │ └── lib │ │ └── antlr-4.7.1-complete.jar ├── streg_utils.py └── usage_example.py ├── quick_eval ├── external │ ├── backend.jar │ └── lib │ │ └── antlr-4.7.1-complete.jar └── regex_backend.py └── toolkit ├── README.md ├── base.py ├── constraints.py ├── external ├── jars │ └── datagen.jar └── lib │ ├── antlr-4.7.1-complete.jar │ └── sempre-core.jar ├── filters.py ├── gen_regex_data.py ├── postprocess.py ├── prepare_regex_data.py ├── regex_io.py ├── template.py └── usage_example.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Xi Ye 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StructuredRegex 2 | Data and code for the paper **Benchmarking Multimodal Regex Synthesis with Complex Structures**. 3 | 4 | ## Data 5 | We provide *raw* data, *tokenized* data, and data with *anonymized* const strings. 6 | 7 | The natural language descriptions in the *raw* version contain the original raw annotations from Turkers. In the *tokenized* version, we preprocessed and tokenized the descriptions. In the *anonymized* version, we further replaced the constants mentioned in the descriptions with anonymous symbols. E.g., given the NL-regex pair [*it must contain the string "ABC".* --> contain(<"ABC">)], we replace *"ABC"* with the symbol *const0* in both the NL and the regex; hence the const-anonymized NL-regex pair becomes [*it must contain the string const0.* --> contain(const0)].
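To undo this anonymization, the symbols are simply substituted back; below is a minimal sketch that mirrors `inverse_regex_with_map` in `code/eval.py` (the `<c>` vs. `const(<str>)` cases follow that function; the function name here is our own):

```python
# Substitute anonymous const symbols with their concrete values, following
# code/eval.py: a single character becomes <c>, a longer string const(<str>).
def inverse_regex(regex, const_map):
    for symbol, value in const_map:  # e.g. [("const0", "ABC")]
        dst = "<{}>".format(value) if len(value) == 1 else "const(<{}>)".format(value)
        regex = regex.replace(symbol, dst)
    return regex

print(inverse_regex("contain(const0)", [("const0", "ABC")]))  # contain(const(<ABC>))
```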
8 | 9 | All data is presented in **TSV** format, with fields including: 10 | 11 | * problem_id -- unique ID of the target regex. 12 | * description -- Turker-annotated description. 13 | * regex -- target regex. 14 | * pos_examples -- positive examples. 15 | * neg_examples -- negative examples. 16 | * const_values -- mapping from symbols to the real string values; exists only in the anonymized version. 17 | 18 | ## Code 19 | #### Requirements 20 | * pytorch > 1.0.0 21 | 22 | We've attached a pretrained checkpoint at `code/checkpoints/StReg/pretrained.tar`, which is ready to use. You can also reproduce the experimental results by following the steps below (please execute the commands in the `code` directory). 23 | 24 | **Train** 25 | 26 | `python train.py StReg --model_id <model_id>`. The models will be stored in the `checkpoints/StReg` directory with names following `<model_id>*.tar`. 27 | 28 | **Decode** 29 | 30 | `python decode.py StReg <model_id> --split test*`. The derivations will be generated using `checkpoints/StReg/<model_id>.tar` and will be output to the `decodes/StReg` directory. 31 | 32 | **Evaluate** 33 | 34 | `python eval.py StReg <model_id> --split test*`. Note that we report DFA accuracy (refer to the paper for more details). 35 | 36 | 37 | ## Sampling Regexes and I/O Examples 38 | see the README and `usage_example.py` in `toolkit`. 39 | 40 | ## Easy API for Checking Equivalence and I/O Consistency 41 | see `easy_eval/usage_example.py` 42 | 43 | It also contains code for parsing the specification into an AST that is easy to manipulate, and some code skeletons that can be completed to convert specifications in our DSL into standard regexes. -------------------------------------------------------------------------------- /code/checkpoints/StReg/pretrained.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/code/checkpoints/StReg/pretrained.tar -------------------------------------------------------------------------------- /code/data.py: -------------------------------------------------------------------------------- 1 | import random 2 | from utils import * 3 | from os.path import join 4 | 5 | PAD_SYMBOL = "<PAD>" 6 | UNK_SYMBOL = "<UNK>" 7 | SOS_SYMBOL = "<SOS>" 8 | EOS_SYMBOL = "<EOS>" 9 | 10 | # Wrapper class for an example.
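# A concrete (hypothetical) instance of the fields below, assuming the
# const-anonymized data and a whitespace-tokenized target string (the exact
# target tokenization and any ids are made up for illustration):
#   x = "it must contain the string const0"  ->  x_tok = ["it", "must", "contain", "the", "string", "const0"]
#   y = "contain ( const0 )"                 ->  y_tok = ["contain", "(", "const0", ")"]
#   x_indexed / y_indexed hold the Indexer ids of those tokens, with y_indexed
#   terminated by the id of EOS_SYMBOL (see index_data below).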
11 | # x = the natural language as one string 12 | # x_tok = tokenized NL, a list of strings 13 | # x_indexed = indexed tokens, a list of ints 14 | # y = the logical form 15 | # y_tok = tokenized logical form, a list of strings 16 | # y_indexed = indexed logical form 17 | class Example(object): 18 | def __init__(self, x, x_tok, x_indexed, y, y_tok, y_indexed): 19 | self.x = x 20 | self.x_tok = x_tok 21 | self.x_indexed = x_indexed 22 | self.y = y 23 | self.y_tok = y_tok 24 | self.y_indexed = y_indexed 25 | self.id = 0 26 | 27 | def __repr__(self): 28 | return " ".join(self.x_tok) + " => " + " ".join(self.y_tok) + "\n indexed as: " + repr(self.x_indexed) + " => " + repr(self.y_indexed) 29 | 30 | def __str__(self): 31 | return self.__repr__() 32 | 33 | # Wrapper for a Derivation consisting of an Example object, a score/probability associated with that example, 34 | # and the tokenized prediction. 35 | class Derivation(object): 36 | def __init__(self, example, p, y_toks): 37 | self.example = example 38 | self.p = p 39 | self.y_toks = y_toks 40 | 41 | def __str__(self): 42 | return "%s (%s)" % (self.y_toks, self.p) 43 | 44 | def __repr__(self): 45 | return self.__str__() 46 | 47 | def read_lines(filename): 48 | with open(filename) as f: 49 | lines = f.readlines() 50 | lines = [x.rstrip('\n') for x in lines] 51 | 52 | return lines 53 | 54 | def get_model_file(dataset, model_id): 55 | return join('./checkpoints', dataset, model_id + '.tar') 56 | 57 | def get_decode_file(dataset, split, model_id): 58 | return join('./decodes', dataset, '{}-{}.txt'.format(split, model_id)) 59 | 60 | # Reads the training, dev, and test data from the corresponding files. 61 | def load_datasets(dataset): 62 | output_path = join('./datasets', dataset) 63 | 64 | train_raw = load_dataset(output_path, 'train') 65 | dev_raw = load_dataset(output_path, 'val') 66 | 67 | input_indexer = Indexer() 68 | output_indexer = Indexer() 69 | 70 | input_indexer.load_from_file(join(output_path, 'src-indexer.pkl')) 71 | output_indexer.load_from_file(join(output_path, 'targ-indexer.pkl')) 72 | return train_raw, dev_raw, input_indexer, output_indexer 73 | 74 | def load_test_dataset(dataset, split): 75 | output_path = join('./datasets', dataset) 76 | 77 | test_raw = load_dataset(output_path, split) 78 | 79 | input_indexer = Indexer() 80 | output_indexer = Indexer() 81 | 82 | input_indexer.load_from_file(join(output_path, 'src-indexer.pkl')) 83 | output_indexer.load_from_file(join(output_path, 'targ-indexer.pkl')) 84 | return test_raw, input_indexer, output_indexer 85 | 86 | 87 | # Reads a dataset in from the given file 88 | def load_dataset(output_path, split): 89 | 90 | src_lines = read_lines(join(output_path, 'src-%s.txt' % (split))) 91 | targ_lines = read_lines(join(output_path, 'targ-%s.txt' % (split))) 92 | 93 | data_raw = list(zip(src_lines, targ_lines)) 94 | 95 | return data_raw 96 | 97 | def read_map_file(filename): 98 | with open(filename) as f: 99 | lines = f.readlines() 100 | lines = [x.rstrip() for x in lines] 101 | maps = [] 102 | for l in lines: 103 | fields = l.split(" ") 104 | num = int(fields[0]) 105 | fields = fields[1:] 106 | if num == 0: 107 | maps.append([]) 108 | continue 109 | m = [] 110 | for f in fields: 111 | pair = f.split(",", 1) 112 | m.append((pair[0], pair[1])) 113 | maps.append(m) 114 | return maps 115 | 116 | def load_const_maps(dataset, split): 117 | filename = join("./datasets", dataset, 'map-%s.txt' % (split)) 118 | return read_map_file(filename) 119 | 120 | def load_exs(dataset, split): 121 | filename = 
join("./datasets", dataset, 'exs-%s.txt' % (split)) 122 | lines = read_lines(filename) 123 | lines = [x.split(" ") for x in lines] 124 | lines = [[(y.split(",", 1)[0], y.split(",", 1)[1]) 125 | for y in x] for x in lines] 126 | return lines 127 | 128 | def load_rec(dataset, split): 129 | filename = join("./datasets", dataset, 'rec-%s.pkl' % (split)) 130 | with open(filename, "rb") as f: 131 | rec = pickle.load(f) 132 | return rec 133 | 134 | # Whitespace tokenization 135 | def tokenize(x): 136 | return x.split() 137 | 138 | 139 | def index(x_tok, indexer): 140 | return [indexer.index_of(xi) if indexer.index_of(xi) >= 0 else indexer.index_of(UNK_SYMBOL) for xi in x_tok] 141 | 142 | 143 | def index_data(data, input_indexer, output_indexer, example_len_limit): 144 | data_indexed = [] 145 | for (x, y) in data: 146 | x_tok = tokenize(x) 147 | y_tok = tokenize(y)[0:example_len_limit] 148 | data_indexed.append(Example(x, x_tok, index(x_tok, input_indexer), y, y_tok, 149 | index(y_tok, output_indexer) + [output_indexer.get_index(EOS_SYMBOL)])) 150 | for i, exs in enumerate(data_indexed): 151 | exs.id = i + 1 152 | return data_indexed 153 | 154 | def filter_data(data_indexed): 155 | return [exs for exs in data_indexed if exs.y != "null"] 156 | 157 | # Indexes train and test datasets where all words occurring less than or equal to unk_threshold times are 158 | # replaced by UNK tokens. 159 | def index_datasets(train_data, dev_data, input_indexer, output_indexer, example_len_limit): 160 | # Index things 161 | train_data_indexed = index_data(train_data, input_indexer, output_indexer, example_len_limit) 162 | dev_data_indexed = index_data(dev_data, input_indexer, output_indexer, example_len_limit) 163 | # test_data_indexed = index_data(test_data, input_indexer, output_indexer, example_len_limit) 164 | return train_data_indexed, dev_data_indexed 165 | -------------------------------------------------------------------------------- /code/datasets/StReg/map-teste.txt: -------------------------------------------------------------------------------- 1 | 2 const0,O const1,H 2 | 2 const0,h const1,O 3 | 1 const0,N 4 | 0 5 | 1 const0,NU 6 | 1 const0,5 7 | 3 const0,K const1,4 const2,7 8 | 0 9 | 4 const0,! const1,6 const2,8 const3,$ 10 | 4 const0,s const1,Z const2,*## const3,_;, 11 | 1 const0,8 12 | 1 const0,U 13 | 0 14 | 0 15 | 0 16 | 1 const0,X 17 | 5 const0,^ const1,* const2,B const3,a const4,Y 18 | 3 const0,* const1,B const2,Y 19 | 1 const0,P 20 | 1 const0,# 21 | 0 22 | 0 23 | 0 24 | 1 const0,N 25 | 1 const0,0 26 | 2 const0,T const1,B 27 | 0 28 | 0 29 | 2 const0,I const1,F 30 | 2 const0,I const1,f 31 | 0 32 | 2 const0,Q const1,L 33 | 0 34 | 0 35 | 3 const0,gd const1,scy const2,| 36 | 0 37 | 1 const0,5 38 | 3 const0,AK const1,RHK const2,4 39 | 0 40 | 3 const0,v const1,B const2,T 41 | 1 const0,e 42 | 0 43 | 4 const0,. 
const1,K const2,C const3,RC 44 | 1 const0,C 45 | 1 const0,G 46 | 2 const0,ST const1,xa 47 | 3 const0,70 const1,604 const2,1 48 | 2 const0,K const1,EC 49 | 3 const0,wI const1,kO const2,G 50 | 3 const0,7 const1,4 const2,t 51 | 2 const0,c const1,I 52 | 0 53 | 2 const0,z const1,C 54 | 0 55 | 1 const0,-Y- 56 | 0 57 | 1 const0,S 58 | 2 const0,579 const1,719 59 | 2 const0,579 const1,719 60 | 0 61 | 0 62 | 1 const0,h 63 | 2 const0,rp const1,I 64 | 0 65 | 2 const0,% const1,42 66 | 1 const0,MPK 67 | 0 68 | 0 69 | 1 const0,r 70 | 0 71 | 1 const0,og 72 | 0 73 | 0 74 | 4 const0,YT const1,A const2,S const3,LL 75 | 2 const0,TCA const1,& 76 | 0 77 | 3 const0,9 const1,21 const2,37 78 | 1 const0,5 79 | 0 80 | 0 81 | 2 const0,S const1,v 82 | 0 83 | 0 84 | 0 85 | 2 const0,. const1,0 86 | 0 87 | 1 const0,hv 88 | 1 const0,hv 89 | 2 const0,uy const1,CN 90 | 1 const0,P 91 | 1 const0,T 92 | 1 const0,C 93 | 1 const0,GUC 94 | 0 95 | 1 const0,kD 96 | 2 const0,:; const1,#;; 97 | 1 const0,U 98 | 3 const0,s const1,p const2,O 99 | 0 100 | 0 101 | 2 const0,Vi const1,Kjh 102 | 1 const0,t 103 | 0 104 | 2 const0,HAD const1,IIV 105 | 2 const0,s const1,B 106 | 2 const0,s const1,B 107 | 0 108 | 0 109 | 2 const0,.,+ const1,,_$ 110 | 4 const0,T const1,K const2,i const3,6 111 | 1 const0,4 112 | 4 const0,r const1,q const2,257 const3,786 113 | 4 const0,t const1,W const2,A const3,6 114 | 2 const0,UJQ const1,IW 115 | 3 const0,k const1,;$ const2,_!: 116 | 1 const0,P 117 | 3 const0,@+ const1,_* const2,B 118 | 1 const0,I 119 | 2 const0,Y const1,y 120 | 1 const0,1 121 | 1 const0,I 122 | 0 123 | 4 const0,S const1,xa const2,O const3,fit 124 | 0 125 | 1 const0,L 126 | 1 const0,1 127 | 1 const0,1 128 | 0 129 | 2 const0,5 const1,I 130 | 0 131 | 0 132 | 3 const0,U const1,glc const2,tf 133 | 3 const0,0 const1,+ const2,# 134 | 0 135 | 0 136 | 0 137 | 0 138 | 1 const0,w 139 | 0 140 | 1 const0,D 141 | 2 const0,N const1,D 142 | 0 143 | 1 const0,K 144 | 1 const0,A 145 | 2 const0,o const1,3 146 | 4 const0,q const1,z const2,OC const3,LL 147 | 2 const0,Y const1,% 148 | 1 const0,g 149 | 1 const0,6 150 | 0 151 | 0 152 | 2 const0,IHC const1,N 153 | 0 154 | 3 const0,- const1,i const2,q 155 | 0 156 | 0 157 | 0 158 | 1 const0,0 159 | 0 160 | 0 161 | 0 162 | 4 const0,jow const1,go const2,K const3,T 163 | 0 164 | 2 const0,d const1,1 165 | 0 166 | 1 const0,O 167 | 1 const0,7 168 | 2 const0,w const1,# 169 | 2 const0,3 const1,KN 170 | 0 171 | 0 172 | 1 const0,VCR 173 | 0 174 | 2 const0,r const1,B 175 | 2 const0,r const1,B 176 | 2 const0,ja const1,ys 177 | 1 const0,t 178 | 2 const0,5 const1,20 179 | 0 180 | 1 const0,1 181 | 2 const0,x const1,VRq 182 | 2 const0,JN const1,S 183 | 0 184 | 0 185 | 0 186 | 2 const0,! 
const1,n 187 | 2 const0,& const1,09 188 | 0 189 | 0 190 | 2 const0,y const1,b 191 | 0 192 | 1 const0,9 193 | 1 const0,TMA 194 | 1 const0,ac 195 | 0 196 | 4 const0,LKM const1,HVA const2,F const3,J 197 | 1 const0,Q 198 | 1 const0,e 199 | 0 200 | 2 const0,D const1,R 201 | 1 const0,A 202 | 0 203 | 0 204 | 2 const0,a const1,E 205 | 0 206 | 1 const0,n 207 | 2 const0,YWA const1,QEL 208 | 1 const0,3 209 | 0 210 | 1 const0,7 211 | 2 const0,3 const1,E 212 | 1 const0,0 213 | 0 214 | 3 const0,xs const1,FY const2,2 215 | 3 const0,xs const1,FY const2,2 216 | 3 const0,r const1,injg const2,hog 217 | 0 218 | 0 219 | 1 const0,n 220 | 2 const0,G const1,V 221 | 1 const0,8 222 | 1 const0,kk 223 | 1 const0,2 224 | 1 const0,P 225 | 0 226 | 2 const0,F const1,Q 227 | 0 228 | 1 const0,A 229 | 2 const0,A const1,AAA 230 | 3 const0,S const1,q const2,t 231 | 1 const0,o 232 | 0 233 | 6 const0,yp const1,tp const2,H const3,o const4,e const5,8 234 | 0 235 | 0 236 | 6 const0,Ya const1,yN const2,s const3,v const4,35 const5,612 237 | 2 const0,LT const1,CN 238 | 2 const0,s const1,s" 239 | 5 const0,O const1,C const2,F const3,N const4,W 240 | 5 const0,O const1,C const2,F const3,N const4,W 241 | 1 const0,o 242 | 1 const0,H 243 | 2 const0,O const1,K 244 | 1 const0,x 245 | 0 246 | 1 const0,9 247 | 1 const0,9 248 | 1 const0,3 249 | 2 const0,tx const1,rx 250 | 0 251 | 0 252 | 0 253 | 1 const0,p 254 | 1 const0,m 255 | 0 256 | 0 257 | 0 258 | 1 const0,UCY 259 | 2 const0,F const1,L 260 | 0 261 | 0 262 | 1 const0,56 263 | 3 const0,p const1,1 const2,7 264 | 0 265 | 3 const0,. const1,U const2,0 266 | 2 const0,. const1,H 267 | 0 268 | 1 const0,9 269 | 2 const0,tJa const1,ELb 270 | 0 271 | 0 272 | 0 273 | 2 const0,91 const1,5 274 | 0 275 | 1 const0,569 276 | 0 277 | 1 const0,d 278 | 1 const0,d 279 | 1 const0,b 280 | 1 const0,bbb 281 | 1 const0,449 282 | 2 const0,n const1,m 283 | 1 const0,6 284 | 2 const0,666 const1,6666 285 | 0 286 | 0 287 | 0 288 | 2 const0,P const1,v 289 | 0 290 | 0 291 | 1 const0,684 292 | 2 const0,NO const1,sFH 293 | 1 const0,l 294 | 1 const0,I 295 | 0 296 | 0 297 | 0 298 | 0 299 | 2 const0,6 const1,3 300 | 1 const0,b 301 | 0 302 | 4 const0,. const1,J const2,I const3,2 303 | 2 const0,c const1,F 304 | 0 305 | 1 const0,O 306 | 2 const0,f const1,l 307 | 2 const0,f const1,l 308 | 0 309 | 0 310 | 1 const0,z 311 | 1 const0,. 
312 | 2 const0,V const1,D 313 | 2 const0,VE const1,JJV 314 | 2 const0,i const1,3 315 | 1 const0,o 316 | 4 const0,66 const1,855 const2,t const3,x 317 | 2 const0,_ const1,5 318 | 1 const0,0 319 | 0 320 | 1 const0,d 321 | 1 const0,n 322 | 0 323 | 6 const0,N const1,1 const2,2 const3,tnl const4,R const5,q 324 | 0 325 | 0 326 | 1 const0,Y 327 | 2 const0,W const1,76 328 | 2 const0,q const1,N 329 | 0 330 | 2 const0,drr const1,aiK 331 | 2 const0,P const1,% 332 | 0 333 | 4 const0,EMC const1,X const2,U const3,Z 334 | 4 const0,EMC const1,X const2,U const3,Z 335 | 1 const0,88 336 | 3 const0,y const1,NPZ const2,EX 337 | 0 338 | 2 const0,07 const1,2 339 | 3 const0,eg const1,28 const2,buD 340 | 0 341 | 1 const0,% 342 | 0 343 | 0 344 | 0 345 | 1 const0,00 346 | 2 const0,Z const1,O 347 | 0 348 | 2 const0,^ const1,@ 349 | 1 const0,pkj 350 | 4 const0,= const1,% const2,ryh const3,kdq 351 | 1 const0,g 352 | 0 353 | 0 354 | 0 355 | 1 const0,= 356 | 0 357 | 0 358 | 1 const0,2 359 | 1 const0,2 360 | 4 const0,0 const1,6 const2,7 const3,712 361 | 4 const0,0 const1,6 const2,7 const3,712 362 | 2 const0,d const1,% 363 | 1 const0,p 364 | 5 const0,UWD const1,S const2,GKW const3,XC const4,Y 365 | 1 const0,8 366 | 1 const0,8 367 | 1 const0,O -------------------------------------------------------------------------------- /code/datasets/StReg/map-testi.txt: -------------------------------------------------------------------------------- 1 | 2 const0,h const1,O 2 | 1 const0,n 3 | 1 const0,n 4 | 0 5 | 0 6 | 2 const0,NU const1,DG 7 | 2 const0,NU const1,DG 8 | 1 const0,5 9 | 1 const0,5 10 | 3 const0,K const1,4 const2,7 11 | 3 const0,K const1,4 const2,7 12 | 0 13 | 0 14 | 5 const0,6 const1,2 const2,8 const3,$ const4,^ 15 | 0 16 | 4 const0,s const1,Z const2,*## const3,_;, 17 | 4 const0,s const1,Z const2,_;, const3,## 18 | 1 const0,8 19 | 1 const0,8 20 | 1 const0,u 21 | 0 22 | 0 23 | 0 24 | 0 25 | 0 26 | 0 27 | 0 28 | 1 const0,X 29 | 1 const0,X 30 | 5 const0,^ const1,* const2,B const3,a const4,Y 31 | 1 const0,P 32 | 1 const0,P 33 | 0 34 | 2 const0,- const1,# 35 | 0 36 | 0 37 | 0 38 | 0 39 | 0 40 | 0 41 | 2 const0,N const1,g 42 | 2 const0,N const1,g 43 | 1 const0,0 44 | 1 const0,0 45 | 2 const0,T const1,b 46 | 2 const0,T const1,b 47 | 0 48 | 0 49 | 0 50 | 0 51 | 2 const0,I const1,f 52 | 0 53 | 0 54 | 2 const0,Q const1,L 55 | 2 const0,Q const1,L 56 | 0 57 | 0 58 | 0 59 | 0 60 | 2 const0,gd const1,scy 61 | 3 const0,gd const1,scy const2,l 62 | 0 63 | 0 64 | 2 const0,5 const1,4 65 | 2 const0,5 const1,4 66 | 3 const0,AK const1,RHK const2,4 67 | 3 const0,AK const1,RHK const2,4 68 | 0 69 | 0 70 | 3 const0,v const1,B const2,T 71 | 3 const0,BB const1,TT const2,v".s 72 | 1 const0,e 73 | 1 const0,e 74 | 0 75 | 1 const0,3 76 | 3 const0,K const1,C const2,RC 77 | 1 const0,C 78 | 1 const0,C 79 | 1 const0,G 80 | 2 const0,; const1,G 81 | 2 const0,ST const1,xa 82 | 2 const0,ST const1,xa 83 | 2 const0,70 const1,604 84 | 3 const0,70 const1,604 const2,1 85 | 2 const0,K const1,EC 86 | 2 const0,K const1,EC 87 | 2 const0,kO const1,G 88 | 2 const0,wlG const1,kOG 89 | 3 const0,7 const1,4 const2,t 90 | 3 const0,7 const1,4 const2,t 91 | 1 const0,cllll 92 | 2 const0,c const1,I 93 | 0 94 | 0 95 | 1 const0,z 96 | 2 const0,z const1,C 97 | 0 98 | 1 const0,; 99 | 2 const0,6 const1,-Y- 100 | 2 const0,Y const1,6 101 | 0 102 | 0 103 | 1 const0,S 104 | 1 const0,S 105 | 2 const0,579 const1,719 106 | 0 107 | 0 108 | 0 109 | 0 110 | 1 const0,h 111 | 1 const0,h 112 | 2 const0,l const1,rp 113 | 2 const0,rp const1,I 114 | 0 115 | 0 116 | 2 const0,% const1,42 117 | 2 const0,% 
const1,42 118 | 1 const0,MPK 119 | 1 const0,MPK 120 | 0 121 | 0 122 | 0 123 | 0 124 | 2 const0,i const1,r 125 | 2 const0,i const1,r 126 | 0 127 | 0 128 | 1 const0,og 129 | 1 const0,og 130 | 0 131 | 4 const0,YT const1,A const2,S const3,LL 132 | 4 const0,YT const1,A const2,S const3,LL 133 | 2 const0,TCA const1,& 134 | 2 const0,TCA const1,& 135 | 0 136 | 0 137 | 3 const0,9 const1,21 const2,37 138 | 3 const0,9 const1,21 const2,37 139 | 1 const0,555 140 | 1 const0,5 141 | 0 142 | 0 143 | 0 144 | 0 145 | 2 const0,S const1,v. 146 | 2 const0,S const1,V 147 | 0 148 | 0 149 | 0 150 | 2 const0,. const1,.0 151 | 2 const0,. const1,0 152 | 0 153 | 0 154 | 1 const0,hv 155 | 0 156 | 2 const0,uy const1,CN 157 | 1 const0,P 158 | 1 const0,P 159 | 1 const0,T 160 | 1 const0,T 161 | 2 const0,, const1,C 162 | 1 const0,C 163 | 1 const0,GUC 164 | 1 const0,GUC 165 | 0 166 | 0 167 | 1 const0,kD 168 | 1 const0,kD 169 | 2 const0,:; const1,#;; 170 | 2 const0,:; const1,#;; 171 | 1 const0,U 172 | 1 const0,U 173 | 3 const0,s const1,p const2,O 174 | 3 const0,s const1,p const2,O 175 | 0 176 | 0 177 | 0 178 | 0 179 | 2 const0,Vi const1,Kjh 180 | 2 const0,Vi const1,Kjh 181 | 1 const0,t 182 | 1 const0,t 183 | 0 184 | 0 185 | 2 const0,HAD const1,IIV 186 | 2 const0,HAD const1,IIV 187 | 2 const0,s const1,B 188 | 0 189 | 0 190 | 0 191 | 0 192 | 2 const0,.,+ const1,,_$ 193 | 2 const0,.,+ const1,,_$ 194 | 4 const0,T const1,K const2,i const3,6 195 | 4 const0,T const1,K const2,i const3,6 196 | 1 const0,4 197 | 1 const0,4 198 | 4 const0,r const1,q const2,257 const3,786 199 | 3 const0,q const1,257 const2,786 200 | 4 const0,t const1,W const2,A const3,6 201 | 4 const0,t const1,W const2,A const3,6 202 | 2 const0,UJQ const1,IW 203 | 2 const0,UJQ const1,IW 204 | 2 const0,k const1,;$ 205 | 3 const0,k const1,;$ const2,_!: 206 | 1 const0,P 207 | 1 const0,P 208 | 3 const0,@+ const1,_* const2,B 209 | 3 const0,@+ const1,_* const2,B 210 | 1 const0,I 211 | 0 212 | 2 const0,Y const1,y 213 | 2 const0,Y const1,y 214 | 1 const0,1 215 | 1 const0,1 216 | 1 const0,l 217 | 1 const0,I 218 | 0 219 | 0 220 | 4 const0,S const1,flt const2,xa const3,O 221 | 4 const0,flt const1,xa const2,O const3,s 222 | 0 223 | 0 224 | 1 const0,L 225 | 1 const0,L 226 | 1 const0,1 227 | 1 const0,1 228 | 0 229 | 0 230 | 0 231 | 0 232 | 1 const0,5 233 | 2 const0,5 const1,l 234 | 0 235 | 0 236 | 0 237 | 0 238 | 3 const0,U const1,glc const2,tf 239 | 3 const0,U const1,glc const2,tf 240 | 0 241 | 3 const0,0 const1,+ const2,# 242 | 0 243 | 0 244 | 0 245 | 0 246 | 0 247 | 0 248 | 0 249 | 0 250 | 1 const0,w 251 | 1 const0,w 252 | 0 253 | 0 254 | 1 const0,D 255 | 1 const0,D 256 | 2 const0,N const1,D 257 | 2 const0,N const1,D 258 | 0 259 | 0 260 | 1 const0,K 261 | 2 const0,; const1,K 262 | 1 const0,A 263 | 1 const0,A 264 | 2 const0,o const1,3 265 | 2 const0,o const1,3 266 | 4 const0,q const1,z const2,OC const3,LL 267 | 4 const0,q const1,z const2,OC const3,LL 268 | 2 const0,Y const1,% 269 | 2 const0,Y const1,% 270 | 1 const0,g 271 | 1 const0,g 272 | 1 const0,6 273 | 1 const0,6 274 | 0 275 | 0 276 | 0 277 | 0 278 | 2 const0,IHC const1,N 279 | 3 const0,- const1,IHC const2,N 280 | 0 281 | 0 282 | 2 const0,i const1,q 283 | 2 const0,i const1,q 284 | 1 const0,- 285 | 0 286 | 0 287 | 2 const0,; const1,0 288 | 1 const0,0 289 | 0 290 | 1 const0,; 291 | 0 292 | 1 const0,, 293 | 0 294 | 0 295 | 4 const0,jow const1,go const2,K const3,T 296 | 4 const0,jow const1,go const2,K const3,T 297 | 0 298 | 0 299 | 3 const0,- const1,d const2,1 300 | 2 const0,d const1,1 301 | 0 302 | 0 303 | 1 const0,O 304 | 1 
const0,O 305 | 1 const0,7 306 | 1 const0,7 307 | 2 const0,w const1,# 308 | 2 const0,w const1,# 309 | 2 const0,3 const1,KN 310 | 1 const0,KN 311 | 0 312 | 1 const0,VCR 313 | 1 const0,VCR 314 | 0 315 | 0 316 | 2 const0,r const1,B 317 | 2 const0,ja const1,ys 318 | 2 const0,ja const1,ys 319 | 1 const0,t 320 | 1 const0,t 321 | 2 const0,5 const1,20 322 | 1 const0,20 323 | 0 324 | 0 325 | 1 const0,1 326 | 1 const0,1 327 | 2 const0,x const1,VRq 328 | 2 const0,x const1,VRq 329 | 2 const0,JN const1,S 330 | 0 331 | 0 332 | 0 333 | 0 334 | 0 335 | 0 336 | 2 const0,! const1,n 337 | 1 const0,n 338 | 2 const0,& const1,09 339 | 2 const0,& const1,09 340 | 0 341 | 0 342 | 0 343 | 0 344 | 2 const0,y const1,b 345 | 2 const0,y const1,b 346 | 0 347 | 0 348 | 1 const0,9 349 | 1 const0,9 350 | 1 const0,TMA 351 | 1 const0,TMA 352 | 1 const0,ac 353 | 1 const0,ac 354 | 0 355 | 0 356 | 4 const0,LKM const1,HVA const2,F const3,J 357 | 4 const0,LKM const1,HVA const2,F const3,J 358 | 1 const0,Q 359 | 1 const0,Q 360 | 1 const0,e 361 | 1 const0,e 362 | 0 363 | 0 364 | 2 const0,D const1,R 365 | 2 const0,D const1,R 366 | 1 const0,A 367 | 1 const0,A 368 | 0 369 | 0 370 | 0 371 | 0 372 | 1 const0,E 373 | 2 const0,a const1,E 374 | 0 375 | 0 376 | 1 const0,n 377 | 1 const0,n 378 | 2 const0,YWA const1,QEL 379 | 2 const0,YWA const1,QEL 380 | 1 const0,3 381 | 1 const0,3 382 | 0 383 | 0 384 | 1 const0,7 385 | 1 const0,7 386 | 2 const0,3 const1,E 387 | 2 const0,3 const1,E 388 | 1 const0,0 389 | 1 const0,0 390 | 0 391 | 0 392 | 3 const0,xs const1,FY const2,2 393 | 4 const0,r const1,inj const2,ho const3,g 394 | 4 const0,r const1,inj const2,ho const3,g 395 | 0 396 | 0 397 | 0 398 | 0 399 | 1 const0,n 400 | 1 const0,n 401 | 2 const0,G const1,V 402 | 2 const0,G const1,V 403 | 1 const0,8 404 | 1 const0,8 405 | 1 const0,k 406 | 1 const0,k 407 | 1 const0,2 408 | 1 const0,2 409 | 1 const0,P 410 | 1 const0,P 411 | 0 412 | 0 413 | 2 const0,F const1,Q 414 | 3 const0,F const1,Q const2,1 415 | 0 416 | 0 417 | 1 const0,A 418 | 3 const0,S const1,q const2,t 419 | 3 const0,S const1,q const2,t 420 | 1 const0,o 421 | 1 const0,o 422 | 0 423 | 0 424 | 6 const0,yp const1,tp const2,H const3,o const4,e const5,8 425 | 6 const0,yp const1,tp const2,H const3,o const4,e const5,8 426 | 0 427 | 0 428 | 0 429 | 0 430 | 6 const0,Ya const1,yN const2,s const3,v const4,35 const5,612 431 | 6 const0,Ya const1,yN const2,s const3,v const4,35 const5,612 432 | 2 const0,LT const1,CN 433 | 2 const0,LT const1,CN 434 | 1 const0,s 435 | 1 const0,s 436 | 5 const0,O const1,C const2,F const3,N const4,W 437 | 1 const0,o 438 | 1 const0,o 439 | 1 const0,H 440 | 1 const0,H 441 | 2 const0,O const1,K 442 | 1 const0,x 443 | 0 444 | 0 445 | 1 const0,9 446 | 1 const0,3 447 | 1 const0,3 448 | 2 const0,tx const1,rx 449 | 3 const0,; const1,tx const2,rx 450 | 0 451 | 0 452 | 0 453 | 0 454 | 0 455 | 1 const0,p 456 | 1 const0,p 457 | 1 const0,m 458 | 1 const0,m 459 | 0 460 | 0 461 | 2 const0,; const1,UCY 462 | 1 const0,UCY 463 | 2 const0,F const1,L 464 | 2 const0,F const1,L 465 | 0 466 | 0 467 | 0 468 | 0 469 | 2 const0,; const1,56 470 | 4 const0,; const1,p const2,1 const3,7 471 | 3 const0,p const1,1 const2,7 472 | 0 473 | 0 474 | 2 const0,U const1,0 475 | 2 const0,U const1,0 476 | 1 const0,H 477 | 1 const0,H 478 | 1 const0,; 479 | 1 const0,; 480 | 1 const0,9999 481 | 1 const0,9999 482 | 3 const0,; const1,tJa const2,ELb 483 | 0 484 | 0 485 | 0 486 | 0 487 | 1 const0,- 488 | 0 489 | 2 const0,91 const1,5 490 | 2 const0,91 const1,5 491 | 0 492 | 0 493 | 1 const0,569 494 | 1 const0,- 495 | 0 496 | 1 
const0,d 497 | 1 const0,b 498 | 1 const0,449 499 | 1 const0,449 500 | 2 const0,n const1,m 501 | 2 const0,n const1,m 502 | 0 503 | 0 504 | 1 const0,, 505 | 0 506 | 0 507 | 2 const0,P const1,v 508 | 2 const0,P const1,v 509 | 0 510 | 0 511 | 0 512 | 2 const0,NO const1,sFH 513 | 2 const0,NO const1,sFH 514 | 1 const0,I 515 | 0 516 | 0 517 | 0 518 | 0 519 | 0 520 | 2 const0,6 const1,3 521 | 2 const0,6 const1,3 522 | 2 const0,- const1,b 523 | 0 524 | 4 const0,. const1,J const2,I const3,2 525 | 3 const0,J const1,I const2,2 526 | 2 const0,c const1,F 527 | 2 const0,c const1,F 528 | 0 529 | 0 530 | 1 const0,O 531 | 1 const0,O 532 | 0 533 | 0 534 | 1 const0,; 535 | 1 const0,,z, 536 | 1 const0,z, 537 | 0 538 | 1 const0,. 539 | 2 const0,V const1,D 540 | 2 const0,V const1,D 541 | 2 const0,VE const1,JJV 542 | 2 const0,i const1,3 543 | 2 const0,i const1,3 544 | 2 const0,; const1,o 545 | 1 const0,o 546 | 4 const0,66 const1,855 const2,t const3,x 547 | 5 const0,- const1,66 const2,855 const3,t const4,x 548 | 2 const0,_ const1,5 549 | 1 const0,5 550 | 1 const0,0 551 | 1 const0,0 552 | 0 553 | 0 554 | 1 const0,d 555 | 1 const0,d 556 | 1 const0,n 557 | 1 const0,n 558 | 0 559 | 0 560 | 6 const0,N const1,1 const2,2 const3,tnl const4,R const5,q 561 | 0 562 | 0 563 | 0 564 | 0 565 | 1 const0,Y 566 | 1 const0,Y 567 | 2 const0,W const1,76 568 | 2 const0,W const1,76 569 | 2 const0,q const1,N 570 | 0 571 | 0 572 | 2 const0,drr const1,aiK 573 | 2 const0,drr const1,aiK 574 | 2 const0,P const1,% 575 | 2 const0,P const1,% 576 | 0 577 | 0 578 | 4 const0,EMC const1,X const2,U const3,Z 579 | 1 const0,8 580 | 1 const0,88 581 | 3 const0,y const1,NPZ const2,EX 582 | 3 const0,y const1,NPZ const2,EX 583 | 0 584 | 0 585 | 2 const0,07 const1,2 586 | 2 const0,07 const1,2 587 | 3 const0,eg const1,28 const2,buD 588 | 3 const0,eg const1,28 const2,buD 589 | 0 590 | 0 591 | 1 const0,% 592 | 1 const0,% 593 | 0 594 | 0 595 | 0 596 | 0 597 | 0 598 | 0 599 | 1 const0,00 600 | 1 const0,00 601 | 2 const0,Z const1,O 602 | 2 const0,Z const1,O 603 | 1 const0,_ 604 | 1 const0,_ 605 | 2 const0,^ const1,@ 606 | 2 const0,^ const1,@ 607 | 1 const0,pjk. 
608 | 1 const0,pkj 609 | 4 const0,= const1,% const2,ryh const3,kdq 610 | 4 const0,= const1,% const2,ryh const3,kdq 611 | 1 const0,g 612 | 0 613 | 0 614 | 1 const0,= 615 | 0 616 | 0 617 | 0 618 | 0 619 | 0 620 | 4 const0,0 const1,6 const2,7 const3,712 621 | 2 const0,d const1,% 622 | 2 const0,d const1,% 623 | 1 const0,p 624 | 1 const0,p 625 | 5 const0,UWD const1,S const2,GKW const3,XC const4,Y 626 | 5 const0,UWD const1,S const2,GKW const3,XC const4,Y 627 | 1 const0,8 628 | 2 const0,O const1,w 629 | 2 const0,O const1,w -------------------------------------------------------------------------------- /code/datasets/StReg/map-val.txt: -------------------------------------------------------------------------------- 1 | 4 const0,LEj const1,uWt const2,* const3,& 2 | 4 const0,LEj const1,uWt const2,* const3,& 3 | 2 const0,LEj const1,uWt 4 | 1 const0,4 5 | 0 6 | 0 7 | 0 8 | 0 9 | 0 10 | 4 const0,6 const1,Y const2,984 const3,42 11 | 4 const0,6 const1,Y const2,42 const3,982 12 | 4 const0,6 const1,Y const2,984 const3,42 13 | 1 const0,J 14 | 1 const0,J 15 | 1 const0,J 16 | 1 const0,x 17 | 1 const0,x 18 | 1 const0,x 19 | 1 const0,H 20 | 1 const0,H 21 | 1 const0,H 22 | 0 23 | 0 24 | 0 25 | 0 26 | 0 27 | 0 28 | 0 29 | 0 30 | 0 31 | 1 const0,r 32 | 1 const0,r 33 | 1 const0,r 34 | 1 const0,XU 35 | 1 const0,XU 36 | 1 const0,XU 37 | 1 const0,u 38 | 1 const0,u 39 | 1 const0,u 40 | 1 const0,4 41 | 1 const0,4 42 | 1 const0,4 43 | 0 44 | 3 const0,R const1,J const2,j 45 | 3 const0,R const1,J const2,j 46 | 2 const0,h const1,v 47 | 2 const0,h const1,v 48 | 2 const0,h const1,v 49 | 0 50 | 0 51 | 1 const0,- 52 | 1 const0,0 53 | 1 const0,0 54 | 1 const0,0 55 | 1 const0,f 56 | 1 const0,f 57 | 1 const0,f 58 | 2 const0,- const1,; 59 | 0 60 | 0 61 | 3 const0,i const1,X const2,C 62 | 2 const0,X const1,C 63 | 2 const0,i const1,C 64 | 1 const0,3 65 | 1 const0,3 66 | 1 const0,3 67 | 1 const0,ecz 68 | 1 const0,ecz 69 | 1 const0,ecz 70 | 0 71 | 0 72 | 0 73 | 1 const0,_ 74 | 1 const0,_ 75 | 1 const0,_ 76 | 1 const0,. 
77 | 0 78 | 0 79 | 5 const0,#_ const1,__; const2,F const3,z const4,y 80 | 5 const0,#_ const1,F const2,z const3,y const4,__ 81 | 5 const0,#_ const1,__; const2,F const3,z const4,y 82 | 2 const0,BKZ const1,uO 83 | 2 const0,BKZ const1,uO 84 | 2 const0,BKZ const1,uO 85 | 1 const0,b 86 | 1 const0,b 87 | 1 const0,b 88 | 2 const0,# const1,LLLL 89 | 2 const0,L const1,# 90 | 2 const0,L const1,# 91 | 4 const0,K const1,HNR const2,XCJ const3,W 92 | 4 const0,K const1,HNR const2,XCJ const3,W 93 | 4 const0,K const1,HNR const2,XCJ const3,W 94 | 0 95 | 0 96 | 0 97 | 1 const0,o 98 | 1 const0,o 99 | 1 const0,o 100 | 1 const0,2 101 | 1 const0,2 102 | 1 const0,2 103 | 2 const0,L const1,F 104 | 2 const0,L const1,F 105 | 2 const0,3 const1,0 106 | 2 const0,3 const1,0 107 | 1 const0,0 108 | 2 const0,N const1,D 109 | 2 const0,N const1,D 110 | 2 const0,N const1,D 111 | 0 112 | 0 113 | 0 114 | 3 const0,G const1,F const2,T 115 | 3 const0,G const1,F const2,T 116 | 3 const0,G const1,F const2,T 117 | 0 118 | 0 119 | 0 120 | 1 const0,d 121 | 1 const0,d 122 | 1 const0,d 123 | 2 const0,h const1,8 124 | 3 const0,- const1,h const2,8 125 | 3 const0,- const1,h const2,8 126 | 2 const0,rzq const1,cz 127 | 2 const0,cf const1,rzq 128 | 2 const0,rzq const1,,":"cf 129 | 1 const0,j 130 | 1 const0,j 131 | 1 const0,j 132 | 0 133 | 0 134 | 2 const0,qt const1,CC 135 | 2 const0,qt const1,C 136 | 1 const0,t 137 | 1 const0,t 138 | 1 const0,t 139 | 0 140 | 0 141 | 0 142 | 1 const0,m 143 | 2 const0,m const1,m- 144 | 1 const0,m 145 | 1 const0,K 146 | 2 const0,K const1,, 147 | 2 const0,; const1,K 148 | 2 const0,- const1,Q 149 | 1 const0,Q 150 | 2 const0,- const1,Q- 151 | 2 const0,n const1,JyE 152 | 2 const0,n const1,JyE 153 | 2 const0,n const1,JyE 154 | 0 155 | 0 156 | 0 157 | 2 const0,0 const1,2 158 | 2 const0,0 const1,2 159 | 2 const0,0 const1,2 160 | 3 const0,x const1,j const2,5 161 | 4 const0,, const1,x const2,j const3,5 162 | 4 const0,x const1,j const2,5 const3,4 163 | 1 const0,K 164 | 1 const0,K 165 | 1 const0,K 166 | 1 const0,c. 167 | 1 const0,c 168 | 1 const0,c 169 | 1 const0,1 170 | 2 const0,; const1,1 171 | 0 172 | 0 173 | 0 174 | 0 175 | 1 const0,b 176 | 1 const0,b. 177 | 1 const0,b 178 | 1 const0,v 179 | 1 const0,v 180 | 1 const0,v 181 | 2 const0,P const1,Dtl 182 | 3 const0,- const1,P const2,Dtl 183 | 2 const0,P const1,Dtl 184 | 0 185 | 0 186 | 0 187 | 3 const0,= const1,e const2,O 188 | 3 const0,= const1,e const2,O 189 | 3 const0,= const1,e const2,O 190 | 0 191 | 0 192 | 0 193 | 2 const0,3 const1,4 194 | 2 const0,3 const1,4 195 | 2 const0,3 const1,4 196 | 1 const0,S 197 | 1 const0,SSS 198 | 1 const0,S 199 | 2 const0,N const1,F 200 | 1 const0,N 201 | 2 const0,N const1,F 202 | 2 const0,mxJ const1,rt 203 | 2 const0,mxJ const1,rt 204 | 2 const0,mxJ const1,rt 205 | 1 const0,J 206 | 1 const0,J 207 | 1 const0,J 208 | 1 const0,A 209 | 1 const0,A 210 | 1 const0,A 211 | 1 const0,ksh 212 | 1 const0,ksh 213 | 1 const0,ksh 214 | 1 const0,5 215 | 1 const0,5 216 | 1 const0,5 217 | 2 const0,. const1,u 218 | 2 const0,. const1,u 219 | 1 const0,u 220 | 0 221 | 0 222 | 0 223 | 1 const0,* 224 | 1 const0,* 225 | 0 226 | 0 227 | 0 228 | 0 229 | 0 230 | 0 231 | 0 232 | 3 const0,P const1,I const2,^! 233 | 3 const0,P const1,I const2,^! 
234 | 3 const0,P const1,I const2,!^ 235 | 1 const0,A 236 | 1 const0,A 237 | 1 const0,A 238 | 1 const0,u 239 | 1 const0,u 240 | 1 const0,u 241 | 2 const0,2 const1,m7 242 | 2 const0,2 const1,m7 243 | 2 const0,2 const1,m7 244 | 1 const0,P 245 | 1 const0,P 246 | 2 const0,H const1,0 247 | 3 const0,H const1,3 const2,0 248 | 6 const0,X const1,aUI const2,CW const3,V const4,e const5,v 249 | 7 const0,X const1,x const2,aUI const3,CW const4,V const5,e const6,v 250 | 1 const0,K 251 | 1 const0,K 252 | 1 const0,K 253 | 0 254 | 0 255 | 0 256 | 2 const0,x const1,M 257 | 2 const0,x const1,M 258 | 2 const0,x const1,M 259 | 1 const0,j7 260 | 2 const0,j const1,7 261 | 2 const0,j const1,7 262 | 0 263 | 0 264 | 0 265 | 0 266 | 0 267 | 0 268 | 2 const0,1 const1,5 269 | 2 const0,1 const1,5 270 | 2 const0,1 const1,5 271 | 3 const0,71 const1,023 const2,1 272 | 3 const0,71 const1,023 const2,1 273 | 3 const0,71 const1,023 const2,1 274 | 1 const0,X 275 | 1 const0,X 276 | 1 const0,X 277 | 2 const0,q const1,e 278 | 2 const0,q const1,e 279 | 2 const0,q const1,e 280 | 2 const0,s const1,d 281 | 2 const0,s const1,d 282 | 2 const0,s const1,d 283 | 2 const0,Z const1,g 284 | 2 const0,Z const1,g 285 | 2 const0,Z const1,g 286 | 2 const0,v const1,o 287 | 2 const0,v const1,o 288 | 2 const0,v const1,o 289 | 1 const0,O 290 | 1 const0,O 291 | 4 const0,0 const1,8 const2,E const3,s 292 | 4 const0,0 const1,8 const2,E const3,s 293 | 3 const0,0 const1,8 const2,E 294 | 1 const0,e 295 | 2 const0,e const1,-. 296 | 1 const0,e 297 | 0 298 | 0 299 | 0 300 | 0 301 | 0 302 | 1 const0,W 303 | 1 const0,W 304 | 1 const0,W 305 | 1 const0,U 306 | 1 const0,U 307 | 1 const0,U 308 | 0 309 | 0 310 | 0 311 | 1 const0,a 312 | 0 313 | 1 const0,a 314 | 0 315 | 0 316 | 0 317 | 1 const0,id 318 | 1 const0,id 319 | 1 const0,id 320 | 3 const0,VO const1,ZB const2,MF 321 | 3 const0,VO const1,ZB const2,MF 322 | 3 const0,VO const1,ZB const2,MF 323 | 2 const0,k const1,u 324 | 2 const0,k const1,u 325 | 2 const0,k const1,u 326 | 1 const0,10 327 | 1 const0,10 328 | 1 const0,10 329 | 1 const0,0 330 | 1 const0,0 331 | 1 const0,0 332 | 2 const0,$= const1,:&_ 333 | 2 const0,$= const1,:&_ 334 | 1 const0,2 335 | 1 const0,2 336 | 1 const0,2 337 | 0 338 | 0 339 | 0 340 | 1 const0,l 341 | 4 const0,dh const1,l const2,mq const3,nu 342 | 3 const0,dh const1,mq const2,nu 343 | 1 const0,sxg 344 | 1 const0,sxg 345 | 1 const0,sxg 346 | 0 347 | 0 348 | 0 349 | 0 350 | 1 const0,= 351 | 1 const0,= -------------------------------------------------------------------------------- /code/datasets/StReg/rec-teste.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/code/datasets/StReg/rec-teste.pkl -------------------------------------------------------------------------------- /code/datasets/StReg/rec-testi.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/code/datasets/StReg/rec-testi.pkl -------------------------------------------------------------------------------- /code/datasets/StReg/rec-train.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/code/datasets/StReg/rec-train.pkl -------------------------------------------------------------------------------- /code/datasets/StReg/rec-val.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/code/datasets/StReg/rec-val.pkl -------------------------------------------------------------------------------- /code/datasets/StReg/src-indexer.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/code/datasets/StReg/src-indexer.pkl -------------------------------------------------------------------------------- /code/datasets/StReg/targ-indexer.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/code/datasets/StReg/targ-indexer.pkl -------------------------------------------------------------------------------- /code/decode.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import numpy as np 4 | import time 5 | import torch 6 | from torch import optim 7 | # from lf_evaluator import * 8 | from models import * 9 | from data import * 10 | from utils import * 11 | import math 12 | from os.path import join 13 | from gadget import * 14 | import os 15 | import shutil 16 | 17 | def _parse_args(): 18 | parser = argparse.ArgumentParser(description='main.py') 19 | 20 | parser.add_argument('dataset', help='specified dataset') 21 | parser.add_argument('model_id', help='specified model id') 22 | 23 | parser.add_argument('--split', type=str, default='test', help='test split') 24 | parser.add_argument('--do_eval', dest='do_eval', default=False, action='store_true', help='only output') 25 | # parser.add_argument('--outfile', dest='outfile', default='beam_output.txt', help='output file of beam') 26 | # Some common arguments for your convenience 27 | 28 | parser.add_argument('--gpu', type=str, default=None, help='gpu id') 29 | parser.add_argument('--seed', type=int, default=0, help='RNG seed (default = 0)') 30 | parser.add_argument('--beam_size', type=int, default=20, help='beam size') 31 | 32 | # 65 is all you need for GeoQuery 33 | parser.add_argument('--decoder_len_limit', type=int, default=170, help='output length limit of the decoder') 34 | parser.add_argument('--input_dim', type=int, default=100, help='input vector dimensionality') 35 | parser.add_argument('--output_dim', type=int, default=100, help='output vector dimensionality') 36 | parser.add_argument('--hidden_size', type=int, default=200, help='hidden state dimensionality') 37 | 38 | # Hyperparameters for the encoder -- feel free to play around with these! 
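# (With the defaults below: 100-dim embeddings feed a 200-dim bidirectional LSTM,
#  so the decoder consumes 2 * hidden_size = 400-dim encoder states -- see the
#  AttnRNNDecoder construction in decode() further down. These figures just
#  restate the defaults; tune them freely.)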
39 | parser.add_argument('--no_bidirectional', dest='bidirectional', default=True, action='store_false', help='bidirectional LSTM') 40 | parser.add_argument('--reverse_input', dest='reverse_input', default=False, action='store_true') 41 | parser.add_argument('--emb_dropout', type=float, default=0.2, help='input dropout rate') 42 | parser.add_argument('--rnn_dropout', type=float, default=0.2, help='dropout rate internal to encoder RNN') 43 | args = parser.parse_args() 44 | return args 45 | 46 | def make_input_tensor(exs, reverse_input): 47 | x = np.array(exs.x_indexed) 48 | len_x = len(exs.x_indexed) 49 | if reverse_input: 50 | x = np.array(x[::-1]) 51 | # add batch dim 52 | x = x[np.newaxis, :] 53 | len_x = np.array([len_x]) 54 | x = torch.from_numpy(x).long() 55 | len_x = torch.from_numpy(len_x) 56 | return x, len_x 57 | 58 | def decode(model_path, test_data, input_indexer, output_indexer, args): 59 | device = config.device 60 | if 'cpu' in str(device): 61 | checkpoint = torch.load(model_path, map_location=device) 62 | else: 63 | checkpoint = torch.load(model_path) 64 | 65 | # Create model 66 | model_input_emb = EmbeddingLayer(args.input_dim, len(input_indexer), args.emb_dropout) 67 | model_enc = RNNEncoder(args.input_dim, args.hidden_size, args.rnn_dropout, args.bidirectional) 68 | model_output_emb = EmbeddingLayer(args.output_dim, len(output_indexer), args.emb_dropout) 69 | model_dec = AttnRNNDecoder(args.input_dim, args.hidden_size, 2 * args.hidden_size if args.bidirectional else args.hidden_size,len(output_indexer), args.rnn_dropout) 70 | 71 | # load dict 72 | model_input_emb.load_state_dict(checkpoint['input_emb']) 73 | model_enc.load_state_dict(checkpoint['enc']) 74 | model_output_emb.load_state_dict(checkpoint['output_emb']) 75 | model_dec.load_state_dict(checkpoint['dec']) 76 | 77 | # map device 78 | model_input_emb.to(device) 79 | model_enc.to(device) 80 | model_output_emb.to(device) 81 | model_dec.to(device) 82 | 83 | # switch to eval 84 | model_input_emb.eval() 85 | model_enc.eval() 86 | model_output_emb.eval() 87 | model_dec.eval() 88 | 89 | pred_derivations = [] 90 | with torch.no_grad(): 91 | for i, ex in enumerate(test_data): 92 | if i % 50 == 0: 93 | print("Done", i) 94 | x, len_x = make_input_tensor(ex, args.reverse_input) 95 | x, len_x = x.to(device), len_x.to(device) 96 | 97 | enc_out_each_word, enc_context_mask, enc_final_states = \ 98 | encode_input_for_decoder(x, len_x, model_input_emb, model_enc) 99 | 100 | pred_derivations.append(beam_decoder(enc_out_each_word, enc_context_mask, enc_final_states, 101 | output_indexer, model_output_emb, model_dec, args.decoder_len_limit, args.beam_size)) 102 | 103 | 104 | output_derivations(test_data, pred_derivations, args) 105 | 106 | def beam_decoder(enc_out_each_word, enc_context_mask, enc_final_states, output_indexer, 107 | model_output_emb, model_dec, decoder_len_limit, beam_size): 108 | ders, scores = batched_beam_sampling(enc_out_each_word, enc_context_mask, enc_final_states, output_indexer, 109 | model_output_emb, model_dec, decoder_len_limit, beam_size) 110 | pred_tokens = [[output_indexer.get_object(t) for t in y] for y in ders] 111 | return pred_tokens 112 | 113 | def output_derivations(test_data, pred_derivations, args): 114 | outfile = get_decode_file(args.dataset, args.split, args.model_id) 115 | with open(outfile, "w") as out: 116 | for i, pred_ders in enumerate(pred_derivations): 117 | out.write(" ".join(["".join(x[1]) for x in enumerate(pred_ders)]) + "\n") 118 | 119 | if __name__ == '__main__': 120 | args = 
_parse_args() 121 | print(args) 122 | # global device 123 | set_global_device(args.gpu) 124 | 125 | print("PyTorch using device ", config.device) 126 | random.seed(args.seed) 127 | np.random.seed(args.seed) 128 | # Load the training and test data 129 | test, input_indexer, output_indexer = load_test_dataset(args.dataset, args.split) 130 | test_data_indexed = index_data(test, input_indexer, output_indexer, args.decoder_len_limit) 131 | # test_data_indexed = tricky_filter_data(test_data_indexed) 132 | 133 | model_path = get_model_file(args.dataset, args.model_id) 134 | decode(model_path, test_data_indexed, input_indexer, output_indexer, args) -------------------------------------------------------------------------------- /code/eval.py: -------------------------------------------------------------------------------- 1 | from data import * 2 | from os.path import join 3 | import numpy as np 4 | import argparse 5 | from tqdm import tqdm 6 | import sys 7 | import subprocess 8 | import os 9 | 10 | def _parse_args(): 11 | parser = argparse.ArgumentParser(description='main.py') 12 | 13 | parser.add_argument('dataset', help='specified dataset') 14 | parser.add_argument('decodes_file', help='specified decodes file') 15 | parser.add_argument('--split', type=str, default='test', help='test split') 16 | parser.add_argument('--decoder_len_limit', type=int, default=170, help='output length limit of the decoder') 17 | 18 | args = parser.parse_args() 19 | return args 20 | 21 | def read_derivations(decode_file): 22 | with open(decode_file) as f: 23 | lines = f.readlines() 24 | lines = [x.rstrip() for x in lines] 25 | lines = [x.split(" ") for x in lines] 26 | 27 | return lines 28 | 29 | def inverse_regex_with_map(r, maps): 30 | for m in maps: 31 | src = m[0] 32 | if len(m[1]) == 1: 33 | dst = "<{}>".format(m[1]) 34 | else: 35 | dst = "const(<{}>)".format(m[1]) 36 | r = r.replace(src, dst) 37 | return r 38 | 39 | def external_evaluation(gt_spec, preds, exs, flag_force=False): 40 | pred_line = " ".join(preds) 41 | exs_line = " ".join(["{},{}".format(x[0], x[1]) for x in exs]) 42 | flag_str = "true" if flag_force else "false" 43 | 44 | flag_use_file = len(pred_line) > 200 45 | if flag_use_file: 46 | filename = join("./external/", "tmp.in") 47 | with open(filename, "w") as f: 48 | f.write(pred_line + "\n") 49 | f.write(exs_line + "\n") 50 | f.write(gt_spec) 51 | out = subprocess.check_output( 52 | ['java', '-cp', './external/datagen.jar:./external/lib/*', '-ea', 'datagen.Main', 'evaluate_single_file', 53 | filename, flag_str], stderr=subprocess.DEVNULL) 54 | os.remove(filename) 55 | else: 56 | out = subprocess.check_output( 57 | ['java', '-cp', './external/datagen.jar:./external/lib/*', '-ea', 'datagen.Main', 'evaluate_single', 58 | pred_line, exs_line, gt_spec, flag_str], stderr=subprocess.DEVNULL) 59 | 60 | out = out.decode("utf-8") 61 | out = out.rstrip() 62 | vals = out.split(" ") 63 | return vals[0], vals[1:] 64 | 65 | def filtering_test(gt, preds, m, exs, flag_force=False): 66 | gt = gt.replace(" ", "") 67 | gt = inverse_regex_with_map(gt, m) 68 | preds = [inverse_regex_with_map(x, m) for x in preds] 69 | global_res, pred_res = external_evaluation(gt, preds, exs, flag_force) 70 | if global_res in ["exact", "equiv"]: 71 | return True, global_res, pred_res 72 | else: 73 | return False, global_res, pred_res 74 | 75 | if __name__ == "__main__": 76 | args = _parse_args() 77 | print(args) 78 | test, input_indexer, output_indexer = load_test_dataset(args.dataset, args.split) 79 | 80 | const_maps =
load_const_maps(args.dataset, args.split) 81 | exs_lists = load_exs(args.dataset, args.split) 82 | 83 | decode_file = get_decode_file(args.dataset, args.split, args.decodes_file) 84 | pred_derivations = read_derivations(decode_file) 85 | 86 | cnt = 0 87 | results = [] 88 | for (_,gt), p, m, exs in tqdm(zip(test, pred_derivations, const_maps, exs_lists), desc='eval', file=sys.stdout, total=len(test)): 89 | # print(gt) 90 | match_result = filtering_test(gt, p, m, exs, flag_force=True) 91 | # print(match_result[0]) 92 | results.append(match_result) 93 | 94 | # Top 0 DFA Acc 95 | num_top0_correct = sum([x[2][0] in ["exact", "equiv"] for x in results]) 96 | print('Top-1 Derivation DFA ACC', num_top0_correct * 1. / len(results)) 97 | # Filter Acc 98 | num_correct_after_filtering = sum([x[0] for x in results]) 99 | print('ACC with Filtering', num_correct_after_filtering * 1. / len(results)) 100 | -------------------------------------------------------------------------------- /code/external/datagen.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/code/external/datagen.jar -------------------------------------------------------------------------------- /code/external/lib/antlr-4.7.1-complete.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/code/external/lib/antlr-4.7.1-complete.jar -------------------------------------------------------------------------------- /code/gadget.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import sys 4 | from torch import optim 5 | # from lf_evaluator import * 6 | from models import * 7 | from data import * 8 | from utils import * 9 | 10 | class config(): 11 | device = None 12 | 13 | def set_global_device(gpu): 14 | if gpu is not None: 15 | config.device = torch.device(('cuda:' + gpu) if torch.cuda.is_available() else 'cpu') 16 | else: 17 | config.device = 'cpu' 18 | 19 | # Analogous to make_padded_input_tensor, but without the option to reverse input 20 | def make_padded_output_tensor(exs, output_indexer, max_len): 21 | return np.array([[ex.y_indexed[i] if i < len(ex.y_indexed) else output_indexer.index_of(PAD_SYMBOL) for i in range(0, max_len)] for ex in exs]) 22 | 23 | # Takes the given Examples and their input indexer and turns them into a numpy array by padding them out to max_len. 24 | # Optionally reverses them. 
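# A worked instance (made-up token ids, assuming index_of(PAD_SYMBOL) is 0):
#   rows [5, 7] and [9] with max_len 3 and reverse_input=False -> [[5, 7, 0], [9, 0, 0]];
#   with reverse_input=True, each row flips in place before padding -> [[7, 5, 0], [9, 0, 0]].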
25 | def make_padded_input_tensor(exs, input_indexer, max_len, reverse_input): 26 | if reverse_input: 27 | return np.array( 28 | [[ex.x_indexed[len(ex.x_indexed) - 1 - i] if i < len(ex.x_indexed) else input_indexer.index_of(PAD_SYMBOL) 29 | for i in range(0, max_len)] 30 | for ex in exs]) 31 | else: 32 | return np.array([[ex.x_indexed[i] if i < len(ex.x_indexed) else input_indexer.index_of(PAD_SYMBOL) 33 | for i in range(0, max_len)] 34 | for ex in exs]) 35 | 36 | def masked_cross_entropy(voc_scores, gt, mask): 37 | cross_entropy = -torch.log(torch.gather(voc_scores, 1, gt.view(-1, 1))) 38 | loss = cross_entropy.squeeze(1).masked_select(mask).sum() 39 | return loss 40 | 41 | def encode_input_for_decoder(x_tensor, inp_lens_tensor, model_input_emb, model_enc): 42 | input_emb = model_input_emb.forward(x_tensor) 43 | (enc_output_each_word, enc_context_mask, enc_final_states) = model_enc.forward(input_emb, inp_lens_tensor) 44 | # print(enc_output_each_word.size(), enc_context_mask.size()) 45 | enc_final_states_reshaped = (enc_final_states[0].unsqueeze(0), enc_final_states[1].unsqueeze(0)) 46 | return (enc_output_each_word, enc_context_mask, enc_final_states_reshaped) 47 | 48 | def batched_beam_sampling(enc_out_each_word, enc_context_mask, enc_final_states, output_indexer, 49 | model_output_emb, model_dec, decoder_len_limit, beam_size): 50 | device = config.device 51 | EOS = output_indexer.get_index(EOS_SYMBOL) 52 | context_inf_mask = get_inf_mask(enc_context_mask) 53 | 54 | completed = [] 55 | cur_beam = [([], .0, .0)] 56 | # 0 toks, 1 score 57 | input_words = torch.LongTensor([[output_indexer.index_of(SOS_SYMBOL)]]).to(device) 58 | input_states = enc_final_states 59 | for _ in range(decoder_len_limit): 60 | 61 | # input_words = torch.LongTensor([[x[1] for x in cur_beam]]).to(device) 62 | input_embedded_words = model_output_emb.forward(input_words) 63 | 64 | batch_voc_scores, batch_next_states = model_dec(input_embedded_words, input_states, enc_out_each_word, context_inf_mask) 65 | batch_voc_scores = torch.log(batch_voc_scores) 66 | batch_voc_scores_cpu = batch_voc_scores.tolist() 67 | 68 | next_beam = [] 69 | action_pool = [] 70 | for b_id, voc_scores in enumerate(batch_voc_scores_cpu): 71 | base_score = cur_beam[b_id][1] 72 | for voc_id, score_cpu in enumerate(voc_scores): 73 | # next_beam.append() 74 | action_pool.append((b_id, voc_id, base_score + score_cpu, True)) 75 | 76 | for b_id, (_, score, _) in enumerate(completed): 77 | action_pool.append((b_id, 0, score, False)) 78 | 79 | action_pool.sort(key=lambda x: x[2], reverse=True) 80 | kept_b_id = [] 81 | next_input_words = [] 82 | next_completed = [] 83 | for b_id, voc_id, new_score, is_gen in action_pool[:beam_size]: 84 | if is_gen: 85 | if voc_id == EOS: 86 | next_completed.append((cur_beam[b_id][0], new_score, cur_beam[b_id][2] + batch_voc_scores[b_id][voc_id])) 87 | else: 88 | next_beam.append((cur_beam[b_id][0] + [voc_id], new_score, cur_beam[b_id][2] + batch_voc_scores[b_id][voc_id])) 89 | next_input_words.append(voc_id) 90 | kept_b_id.append(b_id) 91 | else: 92 | next_completed.append(completed[b_id]) 93 | completed = next_completed 94 | if not next_beam: 95 | break 96 | kept_b_id = torch.LongTensor(kept_b_id).to(device) 97 | input_words = torch.LongTensor([next_input_words]).to(device) 98 | cur_beam = next_beam 99 | input_states = batch_next_states[0].index_select(1, kept_b_id), batch_next_states[1].index_select(1, kept_b_id) 100 | 101 | completed.sort(key=lambda x: x[1], reverse=True) 102 | ders = [x[0] for x in completed] 103 |
sum_probs = [x[2] for x in completed] 104 | 105 | # print('---------------') 106 | # pred_tokens = [[output_indexer.get_object(t) for t in y] for y in ders] 107 | # [print(''.join(p)) for p in pred_tokens] 108 | return ders, sum_probs 109 | -------------------------------------------------------------------------------- /code/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import random 5 | 6 | import numpy as np 7 | import math 8 | 9 | class BatchDataLoader(object): 10 | 11 | def __init__(self, data, all_in, all_out, batch_size=1, shuffle=False, drop_last=False, return_id=False): 12 | self.data = data 13 | self.all_in = torch.from_numpy(all_in).long() 14 | # self.all_in_lens = torch.from_numpy(np.asarray([min(100,len(ex.x_indexed)) for ex in data])) 15 | self.all_in_lens = torch.from_numpy(np.asarray([len(ex.x_indexed) for ex in data])) 16 | self.all_out = torch.from_numpy(all_out).long() 17 | self.all_out_lens = torch.from_numpy(np.asarray([len(ex.y_indexed) for ex in data])) 18 | self.all_id = torch.from_numpy(np.asarray([ex.id for ex in data])) 19 | self.batch_size = batch_size 20 | self.shuffle = shuffle 21 | self.drop_last = drop_last 22 | self.batch_cnt = 0 23 | self.max_batch = 0 24 | self.sampler = None 25 | self.return_id = return_id 26 | 27 | def __len__(self): 28 | return int(math.ceil(len(self.data) * 1.0 / self.batch_size)) 29 | 30 | def __next__(self): 31 | if self.batch_cnt == self.max_batch: 32 | # self.__iter__() 33 | raise StopIteration 34 | 35 | b_start = self.batch_cnt * self.batch_size 36 | b_end = min(b_start + self.batch_size, len(self.data)) 37 | b_idx = self.sampler[b_start:b_end] 38 | b_idx = np.sort(b_idx) 39 | 40 | b_in = self.all_in[b_idx] 41 | b_in_lens = self.all_in_lens[b_idx] 42 | # print(b_in) 43 | b_out = self.all_out[b_idx] 44 | b_out_lens = self.all_out_lens[b_idx] 45 | self.batch_cnt += 1 46 | if self.return_id: 47 | b_ids = self.all_id[b_idx] 48 | return b_in, b_in_lens, b_out, b_out_lens, b_ids 49 | else: 50 | return b_in, b_in_lens, b_out, b_out_lens 51 | 52 | def __iter__(self): 53 | self.batch_cnt = 0 54 | self.max_batch = self.__len__() 55 | if self.shuffle: 56 | self.sampler = np.random.choice(len(self.data), len(self.data), replace=False) 57 | # print(self.sampler) 58 | else: 59 | self.sampler = np.arange(len(self.data)) 60 | return self 61 | 62 | def sent_lens_to_mask(lens, max_length): 63 | mask = torch.from_numpy(np.asarray([[1 if j < lens.data[i].item() else 0 64 | for j in range(0, max_length)] for i in range(0, lens.shape[0])])).bool() 65 | # match device of input 66 | return mask.to(lens.device) 67 | 68 | def get_inf_mask(mask): 69 | inf_mask = torch.zeros_like(mask, dtype=torch.float32) 70 | inf_mask.masked_fill_(~mask, float("-inf")) 71 | return inf_mask 72 | 73 | # Embedding layer that has a lookup table of symbols that is [full_dict_size x input_dim]. Includes dropout. 
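# A minimal usage sketch (sizes hypothetical):
#   emb = EmbeddingLayer(input_dim=100, full_dict_size=5000, embedding_dropout_rate=0.2)
#   vecs = emb(torch.LongTensor([[4, 8, 15]]))   # shape [1, 3, 100]
# The dropout zeroes entries of the embedded vectors during training only.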
74 | # Works for both non-batched and batched inputs 75 | class EmbeddingLayer(nn.Module): 76 | # Parameters: dimension of the word embeddings, number of words, and the dropout rate to apply 77 | # (0.2 is often a reasonable value) 78 | def __init__(self, input_dim, full_dict_size, embedding_dropout_rate): 79 | super(EmbeddingLayer, self).__init__() 80 | self.dropout = nn.Dropout(embedding_dropout_rate) 81 | self.word_embedding = nn.Embedding(full_dict_size, input_dim) 82 | 83 | # Takes either a non-batched input [sent len x input_dim] or a batched input 84 | # [batch size x sent len x input dim] 85 | def forward(self, input): 86 | embedded_words = self.word_embedding(input) 87 | final_embeddings = self.dropout(embedded_words) 88 | return final_embeddings 89 | 90 | 91 | # One-layer RNN encoder for batched inputs -- handles multiple sentences at once. You're free to call it with a 92 | # leading dimension of 1 (batch size 1) but it does expect this dimension. 93 | class RNNEncoder(nn.Module): 94 | # Parameters: input size (should match embedding layer), hidden size for the LSTM, dropout rate for the RNN, 95 | # and a boolean flag for whether or not we're using a bidirectional encoder 96 | def __init__(self, input_size, hidden_size, dropout, bidirect): 97 | super(RNNEncoder, self).__init__() 98 | self.bidirect = bidirect 99 | self.input_size = input_size 100 | self.hidden_size = hidden_size 101 | self.reduce_h_W = nn.Linear(hidden_size * 2, hidden_size, bias=True) 102 | self.reduce_c_W = nn.Linear(hidden_size * 2, hidden_size, bias=True) 103 | self.rnn = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True, 104 | dropout=dropout, bidirectional=self.bidirect) 105 | self.init_weight() 106 | 107 | # Initializes weight matrices using Xavier initialization 108 | def init_weight(self): 109 | nn.init.xavier_uniform_(self.rnn.weight_hh_l0, gain=1) 110 | nn.init.xavier_uniform_(self.rnn.weight_ih_l0, gain=1) 111 | if self.bidirect: 112 | nn.init.xavier_uniform_(self.rnn.weight_hh_l0_reverse, gain=1) 113 | nn.init.xavier_uniform_(self.rnn.weight_ih_l0_reverse, gain=1) 114 | nn.init.constant_(self.rnn.bias_hh_l0, 0) 115 | nn.init.constant_(self.rnn.bias_ih_l0, 0) 116 | if self.bidirect: 117 | nn.init.constant_(self.rnn.bias_hh_l0_reverse, 0) 118 | nn.init.constant_(self.rnn.bias_ih_l0_reverse, 0) 119 | 120 | def get_output_size(self): 121 | return self.hidden_size * 2 if self.bidirect else self.hidden_size 122 | 123 | # embedded_words should be a [batch size x sent len x input dim] tensor 124 | # input_lens is a tensor containing the length of each input sentence 125 | # Returns output (each word's representation), context_mask (a mask of 0s and 1s 126 | # reflecting where the model's output should be considered), and h_t, a *tuple* containing 127 | # the final states h and c from the encoder for each sentence. 128 | def forward(self, embedded_words, input_lens): 129 | # Takes the embedded sentences, "packs" them into an efficient Pytorch-internal representation 130 | packed_embedding = nn.utils.rnn.pack_padded_sequence(embedded_words, input_lens, batch_first=True) 131 | # Runs the RNN over each sequence. 
Returns output at each position as well as the last vectors of the RNN 132 | state for each sentence (first/last vectors for bidirectional) 133 | output, hn = self.rnn(packed_embedding) 134 | # Unpacks the Pytorch representation into normal tensors 135 | output, _ = nn.utils.rnn.pad_packed_sequence(output) 136 | max_length = input_lens.data[0].item() 137 | context_mask = sent_lens_to_mask(input_lens, max_length) 138 | 139 | # Grabs the encoded representations out of hn, which is an (h, c) tuple. 140 | # Note: if you want multiple LSTM layers, you'll need to change this to consult the penultimate layer 141 | # or gather representations from all layers. 142 | if self.bidirect: 143 | h, c = hn[0], hn[1] 144 | # Grab the representations from forward and backward LSTMs 145 | h_, c_ = torch.cat((h[0], h[1]), dim=1), torch.cat((c[0], c[1]), dim=1) 146 | # Reduce them by multiplying by a weight matrix so that the hidden size sent to the decoder is the same 147 | # as the hidden size in the encoder 148 | new_h = self.reduce_h_W(h_) 149 | new_c = self.reduce_c_W(c_) 150 | h_t = (new_h, new_c) 151 | else: 152 | h, c = hn[0][0], hn[1][0] 153 | h_t = (h, c) 154 | # print(max_length, output.size(), h_t[0].size(), h_t[1].size()) 155 | 156 | return (output, context_mask, h_t) 157 | 158 | 159 | # Basic (attention-free) RNN decoder 160 | # in: embedded output tokens and the previous hidden state; out: a distribution over the output vocabulary 161 | # Expects a leading sequence dimension (length 1 when decoding a single step). 162 | class RNNDecoder(nn.Module): 163 | # Parameters: input size (should match embedding layer), hidden size for the LSTM, output vocabulary size, 164 | # and the dropout rate for the RNN 165 | def __init__(self, input_size, hidden_size, voc_size, dropout): 166 | super(RNNDecoder, self).__init__() 167 | self.input_size = input_size 168 | self.hidden_size = hidden_size 169 | self.voc_size = voc_size 170 | 171 | self.reduce_h_v = nn.Linear(hidden_size, voc_size, bias=True) 172 | self.rnn = nn.LSTM(input_size, hidden_size, num_layers=1, dropout=dropout) 173 | self.init_weight() 174 | 175 | # Initializes weight matrices using Xavier initialization 176 | def init_weight(self): 177 | nn.init.xavier_uniform_(self.rnn.weight_hh_l0, gain=1) 178 | nn.init.xavier_uniform_(self.rnn.weight_ih_l0, gain=1) 179 | nn.init.constant_(self.rnn.bias_hh_l0, 0) 180 | nn.init.constant_(self.rnn.bias_ih_l0, 0) 181 | 182 | def forward(self, embedded_words, hidden_states): 183 | # Runs the decoding step(s) and maps the LSTM outputs to vocabulary probabilities 184 | outputs, hn = self.rnn(embedded_words, hidden_states) 185 | voc_scores = self.reduce_h_v(outputs) 186 | voc_scores = voc_scores.reshape((-1, self.voc_size)) 187 | voc_scores = F.softmax(voc_scores, 1) 188 | return voc_scores, hn 189 | 190 | class LuongAttention(nn.Module): 191 | 192 | def __init__(self, hidden_size, context_size=None): 193 | super(LuongAttention, self).__init__() 194 | self.hidden_size = hidden_size 195 | self.context_size = hidden_size if context_size is None else context_size 196 | self.attn = torch.nn.Linear(self.context_size, self.hidden_size) 197 | 198 | self.init_weight() 199 | 200 | def init_weight(self): 201 | nn.init.xavier_uniform_(self.attn.weight, gain=1) 202 | nn.init.constant_(self.attn.bias, 0) 203 | 204 | # input query: q x batch x hidden (sequence-first), context: c x batch x context_size 205 | # output: attended context of shape q x batch x context_size (plus weights batch x q x c if requires_weight) 206 | def forward(self, query, context, inf_mask=None, requires_weight=False): 207 | # Calculate the attention weights (energies) based on the given method
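# (illustrative shape walk-through, added note) with batch B=2, one query step q=1, c=5 encoder positions,
# and hidden/context size 200: after the transposes below, query is [2, 1, 200] and context is [2, 5, 200];
# e = query @ (W context)^T is [2, 1, 5]; masked positions receive -inf from inf_mask before the softmax
# over the 5 context positions, so the weights w sum to 1; w @ context is [2, 1, 200], transposed back
# to sequence-first on return.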
208 | query = query.transpose(0, 1) 209 | context = context.transpose(0, 1) 210 | 211 | e = self.attn(context) 212 | # e is B * C * H after the linear layer; the scores below are B * Q * C 213 | e = torch.matmul(query, e.transpose(1, 2)) 214 | if inf_mask is not None: 215 | e = e + inf_mask.unsqueeze(1) 216 | 217 | # dim w: B * Q * C, context: B * C * H, wanted B * Q * H 218 | w = F.softmax(e, dim=2) 219 | c = torch.matmul(w, context) 220 | # Return the attention-weighted context (and the attention weights if requested) 221 | if requires_weight: 222 | return c.transpose(0, 1), w 223 | return c.transpose(0, 1) 224 | 225 | # Attention-augmented decoder 226 | # in: embedded output tokens, previous hidden state, encoder outputs (attention contexts), and a -inf padding mask 227 | # Expects a leading sequence dimension (length 1 when decoding one step at a time). 228 | class AttnRNNDecoder(nn.Module): 229 | # Parameters: input size (should match embedding layer), hidden size for the LSTM, hidden size of the 230 | # encoder contexts being attended over, output vocabulary size, and the dropout rate for the RNN 231 | def __init__(self, input_size, hidden_size, context_hidden_size, voc_size, dropout): 232 | super(AttnRNNDecoder, self).__init__() 233 | self.input_size = input_size 234 | self.hidden_size = hidden_size 235 | self.voc_size = voc_size 236 | self.context_hidden_size = context_hidden_size 237 | self.attn = LuongAttention(hidden_size, context_hidden_size) 238 | 239 | self.reduce_h_v = nn.Linear(hidden_size + context_hidden_size, voc_size, bias=True) 240 | self.rnn = nn.LSTM(input_size, hidden_size, num_layers=1, dropout=dropout) 241 | self.init_weight() 242 | 243 | # Initializes weight matrices using Xavier initialization 244 | def init_weight(self): 245 | nn.init.xavier_uniform_(self.rnn.weight_hh_l0, gain=1) 246 | nn.init.xavier_uniform_(self.rnn.weight_ih_l0, gain=1) 247 | nn.init.constant_(self.rnn.bias_hh_l0, 0) 248 | nn.init.constant_(self.rnn.bias_ih_l0, 0) 249 | 250 | def forward(self, embedded_words, hidden_states, context_states, context_inf_mask): 251 | 252 | outputs, hn = self.rnn(embedded_words, hidden_states) 253 | 254 | # attend over the encoder states using the new hidden state hn[0] as the query 255 | # context_states: len * batch * context_hidden_size (sequence-first; LuongAttention transposes internally) 256 | # contexts = torch.bmm(attn_weights.unsqueeze(1), context_states.transpose(0, 1)) 257 | # output_contexts = contexts.view((1, -1, self.context_hidden_size)) 258 | output_contexts = self.attn(hn[0], context_states, inf_mask=context_inf_mask) 259 | concat_outputs = torch.cat((outputs, output_contexts), 2) 260 | # concat_outputs = outputs 261 | # concat_outputs = F.relu(concat_outputs) 262 | voc_scores = self.reduce_h_v(concat_outputs) 263 | 264 | voc_scores = voc_scores.reshape((-1, self.voc_size)) 265 | voc_scores = F.softmax(voc_scores, 1) 266 | return voc_scores, hn 267 | 268 | def forward_and_extract_attn(self, embedded_words, hidden_states, context_states, context_inf_mask): 269 | outputs, hn = self.rnn(embedded_words, hidden_states) 270 | 271 | # same as forward, but also returns the attention weights for inspection 272 | # context_states: len * batch * context_hidden_size (sequence-first) 273 | 274 | output_contexts, attn_weights = self.attn(hn[0], context_states, inf_mask=context_inf_mask, requires_weight=True) 275 | concat_outputs = torch.cat((outputs, output_contexts), 2) 276 | # concat_outputs = outputs 277 | # concat_outputs = F.relu(concat_outputs) 278 | voc_scores = self.reduce_h_v(concat_outputs) 279 | 280 | voc_scores = voc_scores.reshape((-1, self.voc_size)) 281 | voc_scores = F.softmax(voc_scores, 1) 282 | return voc_scores, hn, attn_weights.squeeze() 283 | -------------------------------------------------------------------------------- /code/train.py: 
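# --- Editorial example (not a file in the repository): a minimal, self-contained shape check for the
# LuongAttention module above. It assumes you run it from the `code` directory so models.py is importable;
# the tensor sizes are illustrative (context_size=400 mimics a bidirectional encoder).
import torch
from models import LuongAttention

attn = LuongAttention(hidden_size=200, context_size=400)
query = torch.randn(1, 2, 200)     # [q, batch, hidden], sequence-first
context = torch.randn(5, 2, 400)   # [c, batch, context_size], sequence-first
out, w = attn(query, context, requires_weight=True)
print(out.shape)   # torch.Size([1, 2, 400]) -- attended context, sequence-first
print(w.shape)     # torch.Size([2, 1, 5])   -- attention weights over the 5 context positions
# --- end editorial example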
-------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import sys 4 | import numpy as np 5 | import time 6 | import torch 7 | from torch import optim 8 | from gadget import * 9 | from models import * 10 | from data import * 11 | from utils import * 12 | import math 13 | 14 | def _parse_args(): 15 | parser = argparse.ArgumentParser(description='main.py') 16 | 17 | parser.add_argument('dataset', help='specified dataset') 18 | # General system running and configuration options 19 | 20 | # Some common arguments for your convenience 21 | parser.add_argument("--gpu", type=str, default="0", help="gpu id") 22 | parser.add_argument('--seed', type=int, default=0, help='RNG seed (default = 0)') 23 | parser.add_argument('--epochs', type=int, default=100, help='num epochs to train for') 24 | parser.add_argument('--lr', type=float, default=.001, help='learning rate') 25 | parser.add_argument('--batch_size', type=int, default=32, help='batch size') 26 | parser.add_argument('--clip_grad', type=float, default=10.0, help='gradient clipping threshold') 27 | 28 | # regarding model saving 29 | parser.add_argument('--model_id', type=str, default=None, help='model identifier') 30 | parser.add_argument('--saving_from', type=int, default=50, help='epoch to start saving checkpoints from') 31 | parser.add_argument('--saving_interval', type=int, default=10, help='saving interval') 32 | 33 | # an output limit of 65 suffices for GeoQuery; StructuredRegex derivations are longer, hence the default of 170 34 | parser.add_argument('--decoder_len_limit', type=int, default=170, help='output length limit of the decoder') 35 | parser.add_argument('--input_dim', type=int, default=100, help='input vector dimensionality') 36 | parser.add_argument('--output_dim', type=int, default=100, help='output vector dimensionality') 37 | parser.add_argument('--hidden_size', type=int, default=200, help='hidden state dimensionality') 38 | 39 | # Hyperparameters for the encoder -- feel free to play around with these!
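# (added note) --no_bidirectional is a store_false flag: passing it sets args.bidirectional to False,
# i.e. the encoder is bidirectional by default; --reverse_input is the opposite pattern, an opt-in store_true flag.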
40 | parser.add_argument('--no_bidirectional', dest='bidirectional', default=True, action='store_false', help='bidirectional LSTM') 41 | parser.add_argument('--reverse_input', dest='reverse_input', default=False, action='store_true', help='feed the input sequence in reverse order') 42 | parser.add_argument('--emb_dropout', type=float, default=0.2, help='input dropout rate') 43 | parser.add_argument('--rnn_dropout', type=float, default=0.2, help='dropout rate internal to encoder RNN') 44 | 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def train_decode_with_output_of_encoder(enc_out_each_word, enc_context_mask, 50 | enc_final_states, output_indexer, gt_out, gt_out_lens, 51 | model_output_emb, model_dec, decoder_len_limit, p_forcing): 52 | batch_size = enc_context_mask.size(0) 53 | context_inf_mask = get_inf_mask(enc_context_mask) 54 | input_words = torch.from_numpy(np.asarray([output_indexer.index_of(SOS_SYMBOL) for _ in range(batch_size)])) 55 | input_words = input_words.to(config.device) 56 | input_words = input_words.unsqueeze(1) 57 | dec_hidden_states = enc_final_states 58 | 59 | gt_out_mask = sent_lens_to_mask(gt_out_lens, gt_out.size(1)) 60 | output_max_len = torch.max(gt_out_lens).item() 61 | 62 | using_teacher_forcing = np.random.uniform() < p_forcing 63 | loss = 0 64 | 65 | if using_teacher_forcing: 66 | for i in range(output_max_len): 67 | input_embedded_words = model_output_emb.forward(input_words) 68 | input_embedded_words = input_embedded_words.reshape((1, batch_size, -1)) 69 | voc_scores, dec_hidden_states = model_dec(input_embedded_words, dec_hidden_states, enc_out_each_word, context_inf_mask) 70 | input_words = gt_out[:, i].view((-1, 1)) 71 | 72 | loss += masked_cross_entropy(voc_scores, gt_out[:, i], gt_out_mask[:, i]) 73 | 74 | else: 75 | for i in range(output_max_len): 76 | input_embedded_words = model_output_emb.forward(input_words) 77 | input_embedded_words = input_embedded_words.reshape((1, batch_size, -1)) 78 | voc_scores, dec_hidden_states = model_dec(input_embedded_words, dec_hidden_states, enc_out_each_word, context_inf_mask) 79 | output_words = voc_scores.argmax(dim=1, keepdim=True) 80 | input_words = output_words.detach() 81 | loss += masked_cross_entropy(voc_scores, gt_out[:, i], gt_out_mask[:, i]) 82 | 83 | num_entry = gt_out_lens.sum().float().item() 84 | loss = loss / num_entry 85 | return loss, num_entry 86 | 87 | def model_perplexity(test_loader, 88 | model_input_emb, model_enc, model_output_emb, model_dec, 89 | input_indexer, output_indexer, args): 90 | device = config.device 91 | model_input_emb.eval() 92 | model_enc.eval() 93 | model_output_emb.eval() 94 | model_dec.eval() 95 | 96 | test_iter = iter(test_loader) 97 | epoch_loss = 0.0 98 | epoch_num_entry = 0.0 99 | 100 | with torch.no_grad(): 101 | for _, batch_data in enumerate(test_iter): 102 | batch_in, batch_in_lens, batch_out, batch_out_lens = batch_data 103 | batch_in, batch_in_lens, batch_out, batch_out_lens = \ 104 | batch_in.to(device), batch_in_lens.to(device), batch_out.to(device), batch_out_lens.to(device) 105 | 106 | enc_out_each_word, enc_context_mask, enc_final_states = \ 107 | encode_input_for_decoder(batch_in, batch_in_lens, model_input_emb, model_enc) 108 | 109 | loss, num_entry = \ 110 | train_decode_with_output_of_encoder(enc_out_each_word, enc_context_mask, enc_final_states, output_indexer, 111 | batch_out, batch_out_lens, model_output_emb, model_dec, args.decoder_len_limit, 1) # p_forcing=1: condition on the gold prefix during evaluation 112 | epoch_loss += (loss.item() * num_entry) 113 | epoch_num_entry += num_entry 114 | perplexity = epoch_loss / epoch_num_entry # average per-token negative log-likelihood; exp() of this is the true perplexity 115 | return perplexity 116 | 117 | def train_model_encdec_ml(train_data, test_data, input_indexer, output_indexer, args): 118 | device = config.device 119 | # Sort in descending order by x_indexed, essential for pack_padded_sequence 120 | train_data.sort(key=lambda ex: len(ex.x_indexed), reverse=True) 121 | test_data.sort(key=lambda ex: len(ex.x_indexed), reverse=True) 122 | 123 | # Create indexed input 124 | train_input_max_len = np.max(np.asarray([len(ex.x_indexed) for ex in train_data])) 125 | test_input_max_len = np.max(np.asarray([len(ex.x_indexed) for ex in test_data])) 126 | input_max_len = max(train_input_max_len, test_input_max_len) 127 | # input_max_len = 100 128 | 129 | all_train_input_data = make_padded_input_tensor(train_data, input_indexer, input_max_len, args.reverse_input) 130 | all_test_input_data = make_padded_input_tensor(test_data, input_indexer, input_max_len, args.reverse_input) 131 | 132 | train_output_max_len = np.max(np.asarray([len(ex.y_indexed) for ex in train_data])) 133 | test_output_max_len = np.max(np.asarray([len(ex.y_indexed) for ex in test_data])) 134 | output_max_len = max(train_output_max_len, test_output_max_len) 135 | all_train_output_data = make_padded_output_tensor(train_data, output_indexer, output_max_len) 136 | all_test_output_data = make_padded_output_tensor(test_data, output_indexer, np.max(np.asarray([len(ex.y_indexed) for ex in test_data])) ) 137 | all_test_output_data = np.maximum(all_test_output_data, 0) 138 | 139 | print("Train length: %i" % input_max_len) 140 | print("Train output length: %i" % np.max(np.asarray([len(ex.y_indexed) for ex in train_data]))) 141 | print("Train matrix: %s; shape = %s" % (all_train_input_data, all_train_input_data.shape)) 142 | 143 | # Create model 144 | model_input_emb = EmbeddingLayer(args.input_dim, len(input_indexer), args.emb_dropout) 145 | model_enc = RNNEncoder(args.input_dim, args.hidden_size, args.rnn_dropout, args.bidirectional) 146 | model_output_emb = EmbeddingLayer(args.output_dim, len(output_indexer), args.emb_dropout) 147 | model_dec = AttnRNNDecoder(args.input_dim, args.hidden_size, 2 * args.hidden_size if args.bidirectional else args.hidden_size,len(output_indexer), args.rnn_dropout) 148 | 149 | model_input_emb.to(device) 150 | model_enc.to(device) 151 | model_output_emb.to(device) 152 | model_dec.to(device) 153 | 154 | # Loop over epochs, loop over examples, given some indexed words, call encode_input_for_decoder, then call your 155 | # decoder, accumulate losses, update parameters 156 | 157 | # optimizer = None 158 | train_loader = BatchDataLoader(train_data, all_train_input_data, all_train_output_data, batch_size=args.batch_size, shuffle=True) 159 | test_loader = BatchDataLoader(test_data, all_test_input_data, all_test_output_data, batch_size=args.batch_size, shuffle=False) 160 | 161 | train_iter = iter(train_loader) 162 | 163 | optimizer = optim.Adam([ 164 | {'params': model_input_emb.parameters()}, 165 | {'params': model_enc.parameters()}, 166 | {'params': model_output_emb.parameters()}, 167 | {'params': model_dec.parameters()}], lr=args.lr) # learning rate comes from the --lr argument 168 | 169 | get_teacher_forcing_ratio = lambda x: 1.0 # constant schedule: teacher forcing is always used 170 | clip = args.clip_grad 171 | 172 | best_dev_perplexity = np.inf 173 | for epoch in range(1, args.epochs + 1): 174 | 175 | model_input_emb.train() 176 | model_enc.train() 177 | model_output_emb.train() 178 | model_dec.train() 179 | 180 | print('epoch {}'.format(epoch)) 181 | epoch_loss = 0.0 182 | epoch_num_entry = 0.0 183 | for batch_idx, batch_data in enumerate(train_iter): 184 | 185 | 
optimizer.zero_grad() 186 | 187 | batch_in, batch_in_lens, batch_out, batch_out_lens = batch_data 188 | batch_in, batch_in_lens, batch_out, batch_out_lens = \ 189 | batch_in.to(device), batch_in_lens.to(device), batch_out.to(device), batch_out_lens.to(device) 190 | 191 | enc_out_each_word, enc_context_mask, enc_final_states = \ 192 | encode_input_for_decoder(batch_in, batch_in_lens, model_input_emb, model_enc) 193 | 194 | tf_ratio = get_teacher_forcing_ratio(epoch) # always 1.0 here; swap in a decaying schedule for scheduled sampling 195 | loss, num_entry = \ 196 | train_decode_with_output_of_encoder(enc_out_each_word, enc_context_mask, enc_final_states, output_indexer, 197 | batch_out, batch_out_lens, model_output_emb, model_dec, args.decoder_len_limit, tf_ratio) 198 | 199 | loss.backward() 200 | epoch_loss += (loss.item() * num_entry) 201 | epoch_num_entry += num_entry 202 | # print('epoch loss', epoch_loss, 'epoch entry', epoch_num_entry) 203 | _ = torch.nn.utils.clip_grad_norm_(model_input_emb.parameters(), clip) 204 | _ = torch.nn.utils.clip_grad_norm_(model_enc.parameters(), clip) 205 | _ = torch.nn.utils.clip_grad_norm_(model_output_emb.parameters(), clip) 206 | _ = torch.nn.utils.clip_grad_norm_(model_dec.parameters(), clip) 207 | optimizer.step() 208 | 209 | print('epoch {} tf: {} train loss: {}'.format(epoch, tf_ratio, epoch_loss / epoch_num_entry)) 210 | 211 | if (epoch < args.saving_from) or (args.model_id is None): 212 | continue 213 | 214 | # start saving 215 | dev_perplexity = model_perplexity(test_loader, model_input_emb, model_enc, model_output_emb, model_dec, input_indexer, output_indexer, args) 216 | print('epoch {} tf: {} dev loss: {}'.format(epoch, tf_ratio, dev_perplexity)) 217 | 218 | if dev_perplexity < best_dev_perplexity: 219 | parameters = {'input_emb': model_input_emb.state_dict(), 'enc': model_enc.state_dict(), 220 | 'output_emb': model_output_emb.state_dict(), 'dec': model_dec.state_dict()} 221 | best_dev_perplexity = dev_perplexity 222 | torch.save(parameters, get_model_file(args.dataset, args.model_id + "-best")) 223 | 224 | if (epoch - args.saving_from) % args.saving_interval == 0: 225 | parameters = {'input_emb': model_input_emb.state_dict(), 'enc': model_enc.state_dict(), 226 | 'output_emb': model_output_emb.state_dict(), 'dec': model_dec.state_dict()} 227 | torch.save(parameters, get_model_file(args.dataset, args.model_id + "-" + str(epoch))) 228 | 229 | if __name__ == '__main__': 230 | args = _parse_args() 231 | print(args) 232 | # global device 233 | set_global_device(args.gpu) 234 | 235 | print("Pytorch using device ", config.device) 236 | random.seed(args.seed) 237 | np.random.seed(args.seed); torch.manual_seed(args.seed) # also seed torch so model init and dropout are reproducible 238 | # Load the training and test data 239 | train, dev, input_indexer, output_indexer = load_datasets(args.dataset) 240 | train_data_indexed, dev_data_indexed = index_datasets(train, dev, input_indexer, output_indexer, args.decoder_len_limit) 241 | 242 | train_model_encdec_ml(train_data_indexed, dev_data_indexed, input_indexer, output_indexer, args) 243 | -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | # utils.py 2 | import time 3 | import pickle 4 | 5 | def render_ratio(numer, denom): 6 | return "%i / %i = %.3f" % (numer, denom, float(numer)/denom) 7 | 8 | class TimeLogger(object): 9 | 10 | def __init__(self): 11 | self.time_start = time.time() 12 | 13 | def restart(self): 14 | self.time_start = time.time() 15 | 16 | def log(self, msg): 17 | print("{} cost {}s".format(msg, time.time() - 
self.time_start)) 18 | 19 | # Bijection between objects and integers starting at 0. Useful for mapping 20 | # labels, features, etc. into coordinates of a vector space. 21 | class Indexer(object): 22 | def __init__(self): 23 | self.objs_to_ints = {} 24 | self.ints_to_objs = {} 25 | 26 | def __repr__(self): 27 | return str([str(self.get_object(i)) for i in range(0, len(self))]) 28 | 29 | def __str__(self): 30 | return self.__repr__() 31 | 32 | def __len__(self): 33 | return len(self.objs_to_ints) 34 | 35 | def get_object(self, index): 36 | if (index not in self.ints_to_objs): 37 | return None 38 | else: 39 | return self.ints_to_objs[index] 40 | 41 | def contains(self, object): 42 | return self.index_of(object) != -1 43 | 44 | # Returns -1 if the object isn't present, index otherwise 45 | def index_of(self, object): 46 | if (object not in self.objs_to_ints): 47 | return -1 48 | else: 49 | return self.objs_to_ints[object] 50 | 51 | # Adds the object to the index if it isn't present, always returns a nonnegative index 52 | def get_index(self, object, add=True): 53 | if not add: 54 | return self.index_of(object) 55 | if (object not in self.objs_to_ints): 56 | new_idx = len(self.objs_to_ints) 57 | self.objs_to_ints[object] = new_idx 58 | self.ints_to_objs[new_idx] = object 59 | return self.objs_to_ints[object] 60 | 61 | def save_to_file(self, filename): 62 | obj = { 63 | 'objs_to_ints': self.objs_to_ints, 64 | 'ints_to_objs': self.ints_to_objs 65 | } 66 | with open(filename, 'wb') as f: 67 | pickle.dump(obj, f) 68 | 69 | def load_from_file(self, filename): 70 | with open(filename, 'rb') as f: 71 | obj = pickle.load(f) 72 | self.objs_to_ints = obj['objs_to_ints'] 73 | self.ints_to_objs = obj['ints_to_objs'] 74 | 75 | 76 | # Map from objects to doubles that has a default value of 0 for all elements 77 | # Relatively inefficient (dictionary-backed); useful for sparse encoding of things like gradients, but shouldn't be 78 | # used for dense things like weight vectors (instead use an Indexer over the objects and use a numpy array to store the 79 | # values) 80 | class Counter(object): 81 | def __init__(self): 82 | self.counter = {} 83 | 84 | def __repr__(self): 85 | return str([str(key) + ": " + str(self.get_count(key)) for key in self.counter.keys()]) 86 | 87 | def __str__(self): 88 | return self.__repr__() 89 | 90 | def __len__(self): 91 | return len(self.counter) 92 | 93 | def keys(self): 94 | return self.counter.keys() 95 | 96 | def get_count(self, key): 97 | if key in self.counter: 98 | return self.counter[key] 99 | else: 100 | return 0 101 | 102 | def increment_count(self, obj, count): 103 | if obj in self.counter: 104 | self.counter[obj] = self.counter[obj] + count 105 | else: 106 | self.counter[obj] = count 107 | 108 | def increment_all(self, objs_list, count): 109 | for obj in objs_list: 110 | self.increment_count(obj, count) 111 | 112 | def set_count(self, obj, count): 113 | self.counter[obj] = count 114 | 115 | def add(self, otherCounter): 116 | for key in otherCounter.counter.keys(): 117 | self.increment_count(key, otherCounter.counter[key]) 118 | 119 | # Bad O(n) implementation right now 120 | def argmax(self): 121 | best_key = None 122 | for key in self.counter.keys(): 123 | if best_key is None or self.get_count(key) > self.get_count(best_key): 124 | best_key = key 125 | return best_key 126 | 127 | 128 | # Beam data structure. Maintains a list of scored elements like a Counter, but only keeps the top n 129 | # elements after every insertion operation. 
Insertion is O(n) (list is maintained in 130 | # sorted order), access is O(1). Still fast enough for practical purposes for small beams. 131 | class Beam(object): 132 | def __init__(self, size): 133 | self.size = size 134 | self.elts = [] 135 | self.scores = [] 136 | 137 | def __repr__(self): 138 | return "Beam(" + repr(list(self.get_elts_and_scores())) + ")" 139 | 140 | def __str__(self): 141 | return self.__repr__() 142 | 143 | def __len__(self): 144 | return len(self.elts) 145 | 146 | # Adds the element to the beam with the given score if the beam has room or if the score 147 | # is better than the score of the worst element currently on the beam 148 | def add(self, elt, score): 149 | if len(self.elts) == self.size and score < self.scores[-1]: 150 | # Do nothing because this element is the worst 151 | return 152 | # If the list contains the item with a lower score, remove it 153 | # i = 0 154 | # while i < len(self.elts): 155 | # if self.elts[i] == elt and score > self.scores[i]: 156 | # del self.elts[i] 157 | # del self.scores[i] 158 | # i += 1 159 | # If the list is empty, just insert the item 160 | if len(self.elts) == 0: 161 | self.elts.insert(0, elt) 162 | self.scores.insert(0, score) 163 | # Find the insertion point with binary search 164 | else: 165 | lb = 0 166 | ub = len(self.scores) - 1 167 | # We're searching for the index of the first element with score less than score 168 | while lb < ub: 169 | m = (lb + ub) // 2 170 | # Check > because the list is sorted in descending order 171 | if self.scores[m] > score: 172 | # Put the lower bound ahead of m because all elements before this are greater 173 | lb = m + 1 174 | else: 175 | # m could still be the insertion point 176 | ub = m 177 | # lb and ub should be equal and indicate the index of the first element with score less than score. 178 | # Might be necessary to insert at the end of the list. 179 | if self.scores[lb] > score: 180 | self.elts.insert(lb + 1, elt) 181 | self.scores.insert(lb + 1, score) 182 | else: 183 | self.elts.insert(lb, elt) 184 | self.scores.insert(lb, score) 185 | # Drop an item from the beam if necessary 186 | if len(self.scores) > self.size: 187 | self.elts.pop() 188 | self.scores.pop() 189 | 190 | def get_elts(self): 191 | return self.elts 192 | 193 | def get_elts_and_scores(self): 194 | return zip(self.elts, self.scores) 195 | 196 | def head(self): 197 | return self.elts[0] 198 | 199 | 200 | # Indexes a string feat using feature_indexer and adds it to feats. 
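# (illustrative example: maybe_add_feature(feats, indexer, True, "w=the") appends the integer id assigned to "w=the".)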
201 | # If add_to_indexer is true, that feature is indexed and added even if it is new 202 | # If add_to_indexer is false, unseen features will be discarded 203 | def maybe_add_feature(feats, feature_indexer, add_to_indexer, feat): 204 | if add_to_indexer: 205 | feats.append(feature_indexer.get_index(feat)) 206 | else: 207 | feat_idx = feature_indexer.index_of(feat) 208 | if feat_idx != -1: 209 | feats.append(feat_idx) 210 | 211 | 212 | # Computes the dot product over a list of features (i.e., a sparse feature vector) 213 | # and a weight vector (numpy array) 214 | def score_indexed_features(feats, weights): 215 | score = 0.0 216 | for feat in feats: 217 | score += weights[feat] 218 | return score 219 | 220 | 221 | ################## 222 | # Tests 223 | def test_counter(): 224 | print("TESTING COUNTER") 225 | ctr = Counter() 226 | ctr.increment_count("a", 5) 227 | ctr.increment_count("b", 3) 228 | print(repr(ctr.get_count("a")) + " should be 5") 229 | ctr.increment_count("a", 5) 230 | print(repr(ctr.get_count("a")) + " should be 10") 231 | print(str(ctr.counter)) 232 | for key in ctr.counter.keys(): 233 | print(key) 234 | ctr2 = Counter() 235 | ctr2.increment_count("a", 3) 236 | ctr2.increment_count("c", 4) 237 | ctr.add(ctr2) 238 | print("%s should be ['a: 13', 'c: 4', 'b: 3']" % ctr) 239 | 240 | 241 | def test_beam(): 242 | print("TESTING BEAM") 243 | beam = Beam(3) 244 | beam.add("a", 5) 245 | beam.add("b", 7) 246 | beam.add("c", 6) 247 | beam.add("d", 4) 248 | print("Should contain b, c, a: %s" % beam) 249 | beam.add("e", 8) 250 | beam.add("f", 6.5) 251 | print("Should contain e, b, f: %s" % beam) 252 | beam.add("f", 9.5) 253 | print("Should contain f, e, b: %s" % beam) 254 | 255 | beam = Beam(5) 256 | beam.add("a", 5) 257 | beam.add("b", 7) 258 | beam.add("c", 6) 259 | beam.add("d", 4) 260 | print("Should contain b, c, a, d: %s" % beam) 261 | beam.add("e", 8) 262 | beam.add("f", 6.5) 263 | print("Should contain e, b, f, c, a: %s" % beam) 264 | 265 | if __name__ == '__main__': 266 | test_counter() 267 | test_beam() 268 | -------------------------------------------------------------------------------- /easy_eval/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | import subprocess 4 | 5 | def check_equiv(spec0, spec1): 6 | if spec0 == spec1: 7 | # print("exact", spec0, spec1) 8 | return True 9 | # try: 10 | out = subprocess.check_output( 11 | ['java', '-cp', './external/datagen.jar:./external/lib/*', '-ea', 'datagen.Main', 'equiv', 12 | spec0, spec1], stderr=subprocess.DEVNULL) 13 | out = out.decode("utf-8") 14 | out = out.rstrip() 15 | # if out == "true": 16 | # print("true", spec0, spec1) 17 | 18 | return out == "true" 19 | 20 | # examples: (xxx,'+), (xxx, '-') 21 | def check_io_consistency(spec, examples): 22 | # pred_line = " ".join(preds) 23 | pred_line = "{} {}".format(spec, spec) 24 | exs_line = " ".join(["{},{}".format(x[1], x[0]) for x in examples]) 25 | 26 | try: 27 | out = subprocess.check_output( 28 | ['java', '-cp', './external/datagen.jar:./external/lib/*', '-ea', 'datagen.Main', 'preverify', 29 | pred_line, exs_line], stderr=subprocess.DEVNULL, timeout=5) 30 | except subprocess.TimeoutExpired as e: 31 | return False 32 | 33 | # stderr=subprocess.DEVNULL 34 | out = out.decode("utf-8") 35 | out = out.rstrip() 36 | # print(streg_ast.debug_form()) 37 | return out == "true" 38 | -------------------------------------------------------------------------------- /easy_eval/external/datagen.jar: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/easy_eval/external/datagen.jar -------------------------------------------------------------------------------- /easy_eval/external/lib/antlr-4.7.1-complete.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/easy_eval/external/lib/antlr-4.7.1-complete.jar -------------------------------------------------------------------------------- /easy_eval/streg_utils.py: -------------------------------------------------------------------------------- 1 | def tokenize_specification(x): 2 | y = [] 3 | while len(x) > 0: 4 | head = x[0] 5 | if head in ["(", ")", ","]: 6 | y.append(head) 7 | x = x[1:] 8 | elif head == "<": 9 | end = x.index(">") + 1 10 | y.append(x[:end]) 11 | x = x[end:] 12 | else: 13 | leftover = [(i in ["(", ")", "<", ">", ","]) for i in x] 14 | end = leftover.index(True) 15 | y.append(x[:end]) 16 | x = x[end:] 17 | return y 18 | 19 | def _consume_an_ast_node(tokens, cursor): 20 | pass 21 | 22 | # _NODE_CLASS_TO_RULE = { 23 | # "not": "Not", 24 | # "notcc": "NotCC", 25 | # "star": "Star", 26 | # "optional": "Optional", 27 | # "startwith": "StartWith", 28 | # "endwith": "EndWith", 29 | # "contain": "Contain", 30 | # "concat": "Concat", 31 | # "and": "And", 32 | # "or": "Or", 33 | # "repeat": "Repeat", 34 | # "repeatatleast": "RepeatAtleast", 35 | # "repeatrange": "RepeatRange", 36 | # "const": "String" 37 | # } 38 | 39 | def parse_spec_toks_to_ast(tokens): 40 | ast, final_cursor = _parse_spec_toks_to_ast(tokens, 0) 41 | assert final_cursor == len(tokens) 42 | return ast 43 | 44 | def _parse_spec_toks_to_ast(tokens, cursor): 45 | cur_tok = tokens[cursor] 46 | # unary operator 47 | if cur_tok in ['not', 'notcc', 'star', 'optional', 'startwith', 'endwith', 'contain']: 48 | assert tokens[cursor + 1] == '(' 49 | child, cursor = _parse_spec_toks_to_ast(tokens, cursor + 2) 50 | assert tokens[cursor] == ')' 51 | cursor += 1 52 | node = StRegNode(cur_tok, [child]) 53 | elif cur_tok in ['and', 'or', 'concat']: 54 | assert tokens[cursor + 1] == '(' 55 | left_child, cursor = _parse_spec_toks_to_ast(tokens, cursor + 2) 56 | assert tokens[cursor] == ',' 57 | right_child, cursor = _parse_spec_toks_to_ast(tokens, cursor + 1) 58 | assert tokens[cursor] == ')' 59 | cursor += 1 60 | node = StRegNode(cur_tok, [left_child, right_child]) 61 | elif cur_tok in ['repeat', 'repeatatleast']: 62 | assert tokens[cursor + 1] == '(' 63 | child, cursor = _parse_spec_toks_to_ast(tokens, cursor + 2) 64 | assert tokens[cursor] == ',' 65 | assert tokens[cursor + 1].isdigit() 66 | int_val = int(tokens[cursor + 1]) 67 | assert tokens[cursor + 2] == ')' 68 | cursor = cursor + 3 69 | node = StRegNode(cur_tok, [child], [int_val]) 70 | elif cur_tok in ['repeatrange']: 71 | assert tokens[cursor + 1] == '(' 72 | child, cursor = _parse_spec_toks_to_ast(tokens, cursor + 2) 73 | assert tokens[cursor] == ',' 74 | assert tokens[cursor + 1].isdigit() 75 | int_val1 = int(tokens[cursor + 1]) 76 | assert tokens[cursor + 2] == ',' 77 | assert tokens[cursor + 3].isdigit() 78 | int_val2 = int(tokens[cursor + 3]) 79 | assert tokens[cursor + 4] == ')'; cursor = cursor + 5 # consume ',', d1, ',', d2, and the closing ')' 80 | node = StRegNode(cur_tok, [child], [int_val1, int_val2]) 81 | elif cur_tok.startswith('<') and cur_tok.endswith('>'): 82 | cursor += 1 83 | node = StRegNode(cur_tok) 84 | # not really a valid ast, 
need to be replaced with real const 85 | elif cur_tok.startswith('const'): 86 | cursor += 1 87 | node = StRegNode(cur_tok) 88 | else: 89 | raise RuntimeError('Not parsable', cur_tok) 90 | return node, cursor 91 | 92 | # parse a specification into an AST 93 | def parse_spec_to_ast(x): 94 | toks = tokenize_specification(x) 95 | ast = parse_spec_toks_to_ast(toks) 96 | assert x == ast.logical_form() 97 | return ast 98 | 99 | 100 | # AST node 101 | # node_class: the name of nonterminal or terminal 102 | # children: list of children nodes 103 | # params: integers for repeat/repeatatleast/repeatrange 104 | class StRegNode: 105 | def __init__(self, node_class, children=None, params=None): 106 | self.node_class = node_class 107 | self.children = children if children is not None else [] 108 | self.params = params if params is not None else [] 109 | 110 | def logical_form(self): 111 | if len(self.children) + len(self.params) > 0: 112 | return self.node_class + "(" + ",".join([x.logical_form() for x in self.children] + [str(x) for x in self.params]) + ")" 113 | else: 114 | return self.node_class 115 | 116 | def debug_form(self): 117 | if len(self.children) + len(self.params) > 0: 118 | return str(self.node_class) + "(" + ",".join([x.debug_form() if x is not None else str(x) for x in self.children] + [str(x) for x in self.params]) + ")" 119 | else: 120 | return str(self.node_class) 121 | 122 | def short_debug_form(self): 123 | x = self.debug_form() 124 | trunc_pairs = [('None', '?'), ('concat', 'cat'), ('repeatatleast', 'rp+'), ('repeatrange', 'rprng'), ('repeat', 'rp'), ('optional', 'optn')] 125 | # x = x.replace('concat', 'cat') 126 | for a, b in trunc_pairs: 127 | x = x.replace(a, b) 128 | return x 129 | 130 | def tokenized_logical_form(self): 131 | if len(self.children) + len(self.params) > 0: 132 | toks = [self.node_class] + ["("] 133 | toks.extend(self.children[0].tokenized_logical_form()) 134 | for c in self.children[1:]: 135 | toks.append(",") 136 | toks.extend(c.tokenized_logical_form()) 137 | for p in [str(x) for x in self.params]: 138 | toks.append(",") 139 | toks.append(p) 140 | toks.append(")") 141 | return toks 142 | else: 143 | return [self.node_class] 144 | 145 | # some operators can't be converted: not, and 146 | def standard_regex(self): 147 | if self.node_class == '<let>': 148 | return '[A-Za-z]' 149 | elif self.node_class == '<num>': 150 | return '[0-9]' 151 | elif self.node_class == 'concat': 152 | return '(%s)(%s)' % (self.children[0].standard_regex(), self.children[1].standard_regex()) 153 | elif self.node_class == 'contain': 154 | return '.*(%s).*' % (self.children[0].standard_regex()) 155 | elif self.node_class == 'repeatatleast': 156 | return '(%s){%d,}' % (self.children[0].standard_regex(), self.params[0]) 157 | else: 158 | # add code for parsing other terminals and operators 159 | raise NotImplementedError('Please fill in') 160 | -------------------------------------------------------------------------------- /easy_eval/usage_example.py: -------------------------------------------------------------------------------- 1 | from eval import check_equiv, check_io_consistency 2 | from streg_utils import parse_spec_to_ast 3 | 4 | # check equivalence 5 | print('EXPECTED TRUE', check_equiv('or(,)', 'or(,)')) 6 | print('EXPECTED TRUE',check_equiv('or(,)', '')) 7 | print('EXPECTED FALSE',check_equiv('concat(,)', 'concat(,)')) 8 | 9 | # check example consistency 10 | 11 | spec = 'and(repeatatleast(or(,or(,<^>)),1),and(not(startwith()),and(not(startwith(<^>)),not(contain(concat(notcc(),<^>))))))' 12 | good_examples = [('ItrdY', '+'), ('JIQD', '+'), 
('GAFXvIc^j^l^o^op', '+'), ('WZpg^y^eMrXSfXTqHw^', '+'), ('Y', '+'), ('Jw', '+'), ('cvZpBMcQKAqAXj', '-'), ('X^^mwwSbU^Wk^', '-'), ('ZHQgmLzM^', '-'), ('.-;-g', '-'), (':;A:', '-'), ('Ew^^B^Kcc^zR', '-')] 13 | bad_examples1 = [('123ItrdY', '+'), ('ItrdY', '+'), ('JIQD', '+'), ('GAFXvIc^j^l^o^op', '+'), ('WZpg^y^eMrXSfXTqHw^', '+'), ('Y', '+'), ('Jw', '+'), ('cvZpBMcQKAqAXj', '-'), ('X^^mwwSbU^Wk^', '-'), ('ZHQgmLzM^', '-'), ('.-;-g', '-'), (':;A:', '-'), ('Ew^^B^Kcc^zR', '-')] 14 | bad_examples2 = [('ItrdY', '-'), ('JIQD', '+'), ('GAFXvIc^j^l^o^op', '+'), ('WZpg^y^eMrXSfXTqHw^', '+'), ('Y', '+'), ('Jw', '+'), ('cvZpBMcQKAqAXj', '-'), ('X^^mwwSbU^Wk^', '-'), ('ZHQgmLzM^', '-'), ('.-;-g', '-'), (':;A:', '-'), ('Ew^^B^Kcc^zR', '-')] 15 | 16 | print('EXPECTED TRUE',check_io_consistency(spec, good_examples)) 17 | print('EXPECTED FALSE',check_io_consistency(spec, bad_examples1)) 18 | print('EXPECTED FALSE',check_io_consistency(spec, bad_examples2)) 19 | 20 | 21 | # skeleton for converting to standard regex 22 | print('concat(,)', parse_spec_to_ast('concat(,)').standard_regex()) 23 | print('contain()', parse_spec_to_ast('contain()').standard_regex()) 24 | print('repeatatleast(concat(,),3)', parse_spec_to_ast('repeatatleast(concat(,),3)').standard_regex()) 25 | 26 | 27 | ast = parse_spec_to_ast(spec) 28 | print(ast.logical_form()) 29 | std_regex = ast.standard_regex() 30 | -------------------------------------------------------------------------------- /quick_eval/external/backend.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/quick_eval/external/backend.jar -------------------------------------------------------------------------------- /quick_eval/external/lib/antlr-4.7.1-complete.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/quick_eval/external/lib/antlr-4.7.1-complete.jar -------------------------------------------------------------------------------- /quick_eval/regex_backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | from os.path import join 4 | import subprocess 5 | from select import select 6 | 7 | class DatagenBackend: 8 | proc = None 9 | stdin = None 10 | stdout = None 11 | stderr = None 12 | 13 | @classmethod 14 | def start_backend(cls): 15 | DatagenBackend.proc = subprocess.Popen( 16 | ['java', '-cp', './external/backend.jar:./external/lib/*', '-ea', 'datagen.Main', 'preverify_server'], 17 | stdin=subprocess.PIPE, 18 | stdout=subprocess.PIPE, 19 | stderr=subprocess.PIPE, 20 | ) 21 | 22 | DatagenBackend.stdin = io.TextIOWrapper(DatagenBackend.proc.stdin, line_buffering=True) 23 | DatagenBackend.stdout = io.TextIOWrapper(DatagenBackend.proc.stdout) 24 | DatagenBackend.stderr = io.TextIOWrapper(DatagenBackend.proc.stderr) 25 | 26 | @classmethod 27 | def restart_backend(cls): 28 | if DatagenBackend.proc: 29 | if DatagenBackend.proc.poll() is None: 30 | DatagenBackend.proc.kill() 31 | DatagenBackend.start_backend() 32 | 33 | @classmethod 34 | def exit_backend(cls): 35 | if DatagenBackend.proc: 36 | if DatagenBackend.proc.poll() is None: 37 | DatagenBackend.proc.kill() 38 | # DatagenBackend.proc.kill() 39 | 40 | @classmethod 41 | def exec(cls, command, timeout=1, is_retry=False): 42 | try: 43 | # out,err = 
c.proc.communicate("{}\t{}".format(pred_line, exs_line).encode("utf-8"), timeout=2) 44 | DatagenBackend.stdin.write(command) 45 | # print(command) 46 | r_list, _, _ = select([DatagenBackend.stdout], [], [], timeout) # honor the timeout argument 47 | if r_list: 48 | out = DatagenBackend.stdout.readline() 49 | else: 50 | DatagenBackend.restart_backend() 51 | raise subprocess.TimeoutExpired(command, timeout) 52 | except BrokenPipeError as e: 53 | if is_retry: 54 | raise e 55 | else: 56 | DatagenBackend.restart_backend() 57 | return cls.exec(command, timeout, is_retry=True) 58 | except Exception as e: 59 | DatagenBackend.restart_backend() 60 | raise e 61 | return out 62 | 63 | def check_regex_equiv(spec0, spec1): 64 | if spec0 == spec1: 65 | # print("exact", spec0, spec1) 66 | return True 67 | cmd_line = "CHECKEQUIV\t{}\t{}\n".format(spec0, spec1) 68 | try: 69 | out = DatagenBackend.exec(cmd_line, timeout=1) 70 | except subprocess.TimeoutExpired as e: 71 | return False 72 | except BrokenPipeError as e: 73 | return False 74 | out = out.rstrip() 75 | return out == "true" 76 | 77 | 78 | 79 | def check_io_consistency(spec, examples): 80 | # pred_line = " ".join(preds) 81 | 82 | pred_line = "{} {}".format(spec, spec) 83 | exs_line = " ".join(["{},{}".format(x[1], x[0]) for x in examples]) 84 | cmd_line = "{}\t{}\n".format(pred_line, exs_line) 85 | try: 86 | out = DatagenBackend.exec(cmd_line, timeout=1) 87 | except Exception as e: 88 | return False 89 | 90 | out = out.rstrip() 91 | 92 | return out == "true" 93 | 94 | if __name__=="__main__": 95 | DatagenBackend.start_backend() 96 | print('EXPECTED TRUE', check_regex_equiv('or(,)', 'or(,)')) 97 | print('EXPECTED TRUE',check_regex_equiv('or(,)', '')) 98 | print('EXPECTED FALSE',check_regex_equiv('concat(,)', 'concat(,)')) 99 | 100 | spec = 'and(repeatatleast(or(,or(,<^>)),1),and(not(startwith()),and(not(startwith(<^>)),not(contain(concat(notcc(),<^>))))))' 101 | good_examples = [('ItrdY', '+'), ('JIQD', '+'), ('GAFXvIc^j^l^o^op', '+'), ('WZpg^y^eMrXSfXTqHw^', '+'), ('Y', '+'), ('Jw', '+'), ('cvZpBMcQKAqAXj', '-'), ('X^^mwwSbU^Wk^', '-'), ('ZHQgmLzM^', '-'), ('.-;-g', '-'), (':;A:', '-'), ('Ew^^B^Kcc^zR', '-')] 102 | bad_examples1 = [('123ItrdY', '+'), ('ItrdY', '+'), ('JIQD', '+'), ('GAFXvIc^j^l^o^op', '+'), ('WZpg^y^eMrXSfXTqHw^', '+'), ('Y', '+'), ('Jw', '+'), ('cvZpBMcQKAqAXj', '-'), ('X^^mwwSbU^Wk^', '-'), ('ZHQgmLzM^', '-'), ('.-;-g', '-'), (':;A:', '-'), ('Ew^^B^Kcc^zR', '-')] 103 | bad_examples2 = [('ItrdY', '-'), ('JIQD', '+'), ('GAFXvIc^j^l^o^op', '+'), ('WZpg^y^eMrXSfXTqHw^', '+'), ('Y', '+'), ('Jw', '+'), ('cvZpBMcQKAqAXj', '-'), ('X^^mwwSbU^Wk^', '-'), ('ZHQgmLzM^', '-'), ('.-;-g', '-'), (':;A:', '-'), ('Ew^^B^Kcc^zR', '-')] 104 | print('EXPECTED TRUE',check_io_consistency(spec, good_examples)) 105 | print('EXPECTED FALSE',check_io_consistency(spec, bad_examples1)) 106 | print('EXPECTED FALSE',check_io_consistency(spec, bad_examples2)) 107 | DatagenBackend.exit_backend() 108 | -------------------------------------------------------------------------------- /toolkit/README.md: -------------------------------------------------------------------------------- 1 | # Sampling Regexes and Examples 2 | See `usage_example.py` -------------------------------------------------------------------------------- /toolkit/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | LOG_FLAG = False 4 | def ctrl_logger(*args): 5 | if LOG_FLAG: 6 | print(args) 7 | 8 | def random_decision(p): 9 | r = random.random() 10 | return r < p 11 | 
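# (illustrative note) random_decision(0.3) returns True with probability 0.3; weighted_random_decision below
# picks one element of `choices` with probabilities proportional to `p`, since random.choices draws with weights,
# e.g. weighted_random_decision(['a', 'b'], [0.9, 0.1]) returns 'a' about 90% of the time.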
12 | def weighted_random_decision(choices, p): 13 | return random.choices(choices, weights=p)[0] 14 | 15 | class Function: 16 | def __init__(self, *args): 17 | self.parent = None 18 | self.children = [] 19 | self.params = [] 20 | self.lineage = [] 21 | for value in args: 22 | if issubclass(value.__class__, Function): 23 | self.children.append(value) 24 | value.parent = self 25 | else: 26 | self.params.append(value) 27 | 28 | def logical_form(self): 29 | raise Exception("Not implemented for {}".format( 30 | self.__class__.__name__)) 31 | 32 | def description(self): 33 | raise Exception("Not implemented for {}".format( 34 | self.__class__.__name__)) 35 | 36 | def specification(self): 37 | raise Exception("Not implemented for {}".format( 38 | self.__class__.__name__)) 39 | # specification 40 | def ground_truth(self): 41 | return self.specification() 42 | 43 | # String form of the DSL used for generating regexes. This is NOT the specification; it is used to reconstruct the Function class. 44 | def to_string(self): 45 | return "{}({})".format(self.__class__.__name__, ",".join([x.to_string() for x in self.children] + [str(x) for x in self.params])) 46 | 47 | def get_all_functions_flat_list(self): 48 | cur_list = [] 49 | for child in self.children: 50 | if child: 51 | cur_list = cur_list + [child] 52 | cid = child.get_all_functions_flat_list() 53 | if cid: 54 | cur_list = cur_list + cid 55 | return cur_list 56 | 57 | def set_root(self, root): 58 | self.root = root 59 | 60 | def sample_negative(self): 61 | raise Exception("Not implemented for {}".format( 62 | self.__class__.__name__)) 63 | 64 | def negative_candidates(self): 65 | raise Exception("Not implemented for {}".format( 66 | self.__class__.__name__)) 67 | 68 | class NoneToken(Function): 69 | def logical_form(self): 70 | return "NONE" 71 | 72 | def description(self): 73 | return "NONE" 74 | 75 | class Token(Function): 76 | cnt = -1 77 | 78 | def specification(self): 79 | return self.logical_form() 80 | 81 | def sample_negative(self): 82 | return NotCCCons(self) 83 | 84 | class NumToken(Token): 85 | def logical_form(self): 86 | return "<num>" 87 | 88 | def description(self): 89 | return "a number" 90 | 91 | @classmethod 92 | def random_tok(cls): 93 | return str(random.randint(0, 9)) 94 | 95 | @classmethod 96 | def nice_tok(cls): 97 | cls.cnt += 1 98 | # return ["0", "1", "2", "3"][cls.cnt % 4] 99 | return random.choice(["0", "1", "2", "3"]) 100 | 101 | @classmethod 102 | def nice_string(cls): 103 | cls.cnt += 1 104 | # return ["00", "012", "999", "99"][cls.cnt % 4] 105 | return random.choice(["00", "012", "01", "999", "99"]) 106 | 107 | # Tokens 108 | class CapitalToken(Token): 109 | def logical_form(self): 110 | return "<cap>" 111 | 112 | def description(self): 113 | return "a capital letter" 114 | 115 | @classmethod 116 | def random_tok(cls): 117 | return chr(random.randint(1, 26) + 64) 118 | 119 | @classmethod 120 | def nice_tok(cls): 121 | cls.cnt += 1 122 | # return ["A", "B", "C", "D"][cls.cnt % 4] 123 | return random.choice(["A", "B", "C", "D"]) 124 | 125 | @classmethod 126 | def nice_string(cls): 127 | cls.cnt += 1 128 | # return ["AA", "ABC", "XX", "AA"][cls.cnt % 4] 129 | return random.choice(["AA", "ABC", "XX", "AA", "XYZ"]) 130 | 131 | class LowerToken(Token): 132 | def logical_form(self): 133 | return "<low>" 134 | 135 | def description(self): 136 | return "a lower-case letter" 137 | 138 | @classmethod 139 | def random_tok(cls): 140 | return chr(random.randint(1, 26) + 96) 141 | 142 | @classmethod 143 | def nice_tok(cls): 144 | cls.cnt += 1 145 | # 
return ["a", "b", "c", "d"][cls.cnt % 4] 146 | return random.choice(["a", "b", "c", "d"]) 147 | 148 | 149 | @classmethod 150 | def nice_string(cls): 151 | cls.cnt += 1 152 | # return ["aa", "abc", "xx", "aa"][cls.cnt % 4] 153 | return random.choice(["aa", "abc", "xx", "aa", "xyz"]) 154 | 155 | 156 | class LetterToken(Token): 157 | def logical_form(self): 158 | return "" 159 | 160 | def description(self): 161 | return "a letter" 162 | 163 | @classmethod 164 | def random_tok(cls): 165 | x = random.random() 166 | if x > 0.5: 167 | return LowerToken.random_tok() 168 | else: 169 | return CapitalToken.random_tok() 170 | 171 | @classmethod 172 | def nice_tok(cls): 173 | cls.cnt += 1 174 | # return ["a", "A", "b", "B"][cls.cnt % 4] 175 | return random.choice(["a", "A", "b", "B"]) 176 | 177 | @classmethod 178 | def nice_string(cls): 179 | cls.cnt += 1 180 | # return ["aA", "Aa", "aAa", "AaA"][cls.cnt % 4] 181 | return random.choice(["aA", "Aa", "aB", "aAa", "AaA"]) 182 | 183 | # allowed space, dash, comma, colon, dot, plus, underscore, 184 | class SpecialToken(Token): 185 | spec_toks_ = [ "-", ",", ";", ".", "_", "+", ":", "!", "@", "#", "$", "%", "&", "^", "*", "="] 186 | nice_toks_ = ["-", ",", ";", "."] 187 | div_toks_ = ["-", ",", ";", "."] 188 | spec_toks = [ "-", ",", ";", ".", "_", "+", ":", "!", "@", "#", "$", "%", "&", "^", "*", "="] 189 | nice_toks = ["-", ",", ";", "."] 190 | def logical_form(self): 191 | return "" 192 | 193 | def description(self): 194 | return "a special character" 195 | 196 | @classmethod 197 | def random_tok(cls): 198 | return random.choice(cls.spec_toks) 199 | 200 | @classmethod 201 | def nice_tok(cls): 202 | cls.cnt += 1 203 | # return cls.nice_toks[cls.cnt % len(cls.nice_toks)] 204 | return random.choice(cls.nice_toks) 205 | 206 | @classmethod 207 | def gen_div_tok(cls): 208 | tok = random.choice(cls.div_toks_) 209 | return SingleToken(SpecialToken, tok) 210 | 211 | @classmethod 212 | def nice_string(cls): 213 | cls.cnt += 1 214 | # return 2*cls.nice_toks[cls.cnt % len(cls.nice_toks)] 215 | return 2 * random.choice(cls.nice_toks) 216 | 217 | @classmethod 218 | def screen_tok(cls, tok): 219 | cls.spec_toks = [x for x in cls.spec_toks_ if x != tok] 220 | cls.nice_toks = [x for x in cls.nice_toks_ if x != tok] 221 | 222 | @classmethod 223 | def restore(cls): 224 | cls.spec_toks = cls.spec_toks_[:] 225 | cls.nice_toks = cls.nice_toks_[:] 226 | 227 | class CharacterToken(Token): 228 | def logical_form(self): 229 | return "" 230 | 231 | def description(self): 232 | return "a character" 233 | 234 | 235 | # allowed space, dash, comma, colon, dot, plus, underscore, 236 | # input a list of CC or SingleToken 237 | 238 | class SingleToken(Token): 239 | def __init__(self, cc_type, tok): 240 | super().__init__() 241 | self.parent = None 242 | self.cc_type = cc_type 243 | self.tok = tok 244 | 245 | @classmethod 246 | def generate(cls, choices, is_random=True): 247 | base = random.choice(choices) 248 | if isinstance(base, SingleToken): 249 | cc_type = base.cc_type 250 | tok = base.tok 251 | return cls(cc_type, tok) 252 | else: 253 | cc_type = base 254 | if is_random: 255 | tok = cc_type.random_tok() 256 | else: 257 | # tok = cc_type.nice_tok() 258 | tok = cc_type.random_tok() 259 | return cls(cc_type, tok) 260 | 261 | def logical_form(self): 262 | return '<{}>'.format(self.tok) 263 | 264 | def description(self): 265 | return '"{}"'.format(self.tok) 266 | 267 | def to_string(self): 268 | return "SingleToken(<{}>,<{}>)".format(self.cc_type.__name__, self.tok) 269 | 270 | def 
sample_negative(self): 271 | return SingleToken.generate([self.cc_type]) 272 | 273 | class StringToken(Token): 274 | def __init__(self, cc_type, tok): 275 | super().__init__() 276 | self.parent = None 277 | self.cc_type = cc_type 278 | self.tok = tok 279 | 280 | @classmethod 281 | def generate(cls, choices, is_random=True, length=-1): 282 | cc_s = [x for x in choices if not isinstance(x, SingleToken)] 283 | if cc_s: 284 | base = random.choice(cc_s) 285 | if length > 0: 286 | tok = "".join([base.random_tok() for _ in range(length)]) 287 | else: 288 | if is_random: 289 | tok = "".join([base.random_tok() for _ in range(random.randint(2,3))]) 290 | else: 291 | # tok = base.nice_string() 292 | tok = "".join([base.random_tok() for _ in range(random.randint(2,3))]) 293 | return cls(base, tok) 294 | else: 295 | return NoneToken() 296 | 297 | def logical_form(self): 298 | return '<{}>'.format(self.tok) 299 | 300 | def description(self): 301 | return '"the string {}"'.format(self.tok) 302 | 303 | def specification(self): 304 | return "const(<{}>)".format(self.tok) 305 | 306 | def ground_truth(self): 307 | return 'const(<{}>)'.format(self.tok) 308 | 309 | def to_string(self): 310 | return "StringToken(<{}>,<{}>)".format(self.cc_type.__name__, self.tok) 311 | 312 | def sample_negative(self): 313 | return StringToken.generate([self.cc_type]) 314 | 315 | class SingleOrSingleToken(Token): 316 | pass 317 | 318 | class StringOrStringToken(Token): 319 | pass 320 | 321 | class Composition(Function): 322 | pass 323 | 324 | class OrComp(Composition): 325 | def logical_form(self): 326 | return "Or({},{})".format(self.children[0].logical_form(), self.children[1].logical_form()) 327 | 328 | def description(self): 329 | return "{} or {}".format(self.children[0].description(),self.children[1].description()) 330 | 331 | def specification(self): 332 | return "or({},{})".format(self.children[0].specification(), self.children[1].specification()) 333 | 334 | def sample_negative(self): 335 | return random.choice(self.children).sample_negative() 336 | 337 | class ConcatComp(Composition): 338 | def logical_form(self): 339 | if len(self.children) == 1: 340 | return self.children[0].logical_form() 341 | else: 342 | return "Concat({})".format(",".join([x.logical_form() for x in self.children])) 343 | 344 | def description(self): 345 | if len(self.children) == 1: 346 | return self.children[0].logical_form() 347 | else: 348 | return " followed by ".join([x.logical_form() for x in self.children]) 349 | 350 | def specification(self): 351 | return self.concat_type_specification(self.children) 352 | 353 | def sample_negative(self): 354 | i = random.choice(range(len(self.children))) 355 | new_children = self.children[:i] + [self.children[i].sample_negative()] + self.children[i+1:] 356 | if len(new_children) == 0: 357 | return NoneToken() 358 | if len(new_children) == 1: 359 | return new_children[0] 360 | return ConcatComp(*new_children) 361 | 362 | @staticmethod 363 | def concat_type_specification(children): 364 | if len(children) == 1: 365 | return "{}".format(children[0].specification()) 366 | else: 367 | r = [x.specification() for x in children] 368 | r.reverse() # fold right-to-left: children [A, B, C] become concat(A,concat(B,C)) 369 | y = r[0] 370 | for c in r[1:]: 371 | y = 'concat({},{})'.format(c,y) 372 | return y 373 | 374 | class AndComp(Composition): 375 | def logical_form(self): 376 | return "And({},{})".format(self.children[0].logical_form(), self.children[1].logical_form()) 377 | 378 | def description(self): 379 | return "{} and 
{}".format(self.children[0].description(),self.children[1].description()) 380 | 381 | def specification(self): 382 | return "and({},{})".format(self.children[0].specification(), self.children[1].specification()) 383 | 384 | def sample_negative(self): 385 | i = random.choice(range(len(self.children))) 386 | new_children = self.children[:i] + [self.children[i].sample_negative()] + self.children[i+1:] 387 | return AndComp(*new_children) 388 | 389 | @staticmethod 390 | def and_type_specification(children): 391 | if len(children) == 1: 392 | return "{}".format(children[0].specification()) 393 | else: 394 | r = [x.specification() for x in children] 395 | r.reverse() 396 | y = r[0] 397 | for c in r[1:]: 398 | y = 'and({},{})'.format(c,y) 399 | return y 400 | 401 | class Constraint(Function): 402 | def sample_negative(self): 403 | return NotCons(self) 404 | 405 | class NotCons(Constraint): 406 | def logical_form(self): 407 | return "Not({})".format(self.children[0].logical_form()) 408 | 409 | def description(self): 410 | return "Not {} ".format(self.children[0].description()) 411 | 412 | def specification(self): 413 | return "not({})".format(self.children[0].specification()) 414 | 415 | def sample_negative(self): 416 | return self.children[0] 417 | 418 | class NotCCCons(Constraint): 419 | def logical_form(self): 420 | return "NotCC({})".format(self.children[0].logical_form()) 421 | 422 | def description(self): 423 | return "NotCC {} ".format(self.children[0].description()) 424 | 425 | def specification(self): 426 | return "notcc({})".format(self.children[0].specification()) 427 | 428 | def sample_negative(self): 429 | return self.children[0] 430 | 431 | class CCToken(Token): 432 | pass 433 | 434 | class ConstSet2Token(Token): 435 | pass 436 | 437 | class SimpleConcat2Token(Token): 438 | pass 439 | 440 | class SimpleOr2Token(Token): 441 | pass 442 | 443 | class Modifier(Function): 444 | pass 445 | 446 | def soft_remove(l, x): 447 | if x in l: 448 | l.remove(x) 449 | 450 | def is_valid(x): 451 | if not isinstance(x, NoneToken): 452 | return all([is_valid(y) for y in x.children]) 453 | else: 454 | return False 455 | 456 | def tok(x): 457 | y = [] 458 | while len(x) > 0: 459 | head = x[0] 460 | if head in ["?", "(", ")", "{", "}", ","]: 461 | y.append(head) 462 | x = x[1:] 463 | elif head == "<": 464 | end = x.index(">") + 1 465 | y.append(x[:end]) 466 | x = x[end:] 467 | else: 468 | leftover = [(i in ["?", "(", ")", "{", "}", "<", ">", ","]) for i in x] 469 | end = leftover.index(True) 470 | y.append(x[:end]) 471 | x = x[end:] 472 | return "|".join(y) 473 | -------------------------------------------------------------------------------- /toolkit/constraints.py: -------------------------------------------------------------------------------- 1 | from base import * 2 | 3 | CC = [NumToken, LetterToken, CapitalToken, LowerToken, SpecialToken] 4 | CC_NO_SPEC = [NumToken, LetterToken, CapitalToken, LowerToken] 5 | SINGLE_TOKEN_CHOICE = [NumToken, LetterToken, CapitalToken, LowerToken, SpecialToken] 6 | 7 | class ComposedByCons(Constraint): 8 | def logical_form(self): 9 | return "Star(Or({}))".format(",".join([x.logical_form() for x in self.children])) 10 | 11 | def description(self): 12 | return "Composed by {}".format(", " .join([x.logical_form() for x in self.children])) 13 | 14 | def specification(self): 15 | if len(self.children) == 1: 16 | return "repeatatleast({},1)".format(",".join([x.specification() for x in self.children])) 17 | else: 18 | r = [x.specification() for x in self.children] 19 | 
r.reverse() 20 | y = r[0] 21 | for c in r[1:]: 22 | y = 'or({},{})'.format(c,y) 23 | return "repeatatleast({},1)".format(y) 24 | 25 | @classmethod 26 | def generate(cls, max_element=4): 27 | cls.config_children_attr() 28 | things = [] 29 | allowed = [] 30 | num_elements = random.randint(1, max_element) 31 | 32 | element_choices = CC[:] 33 | single_choices = CC[:] 34 | if num_elements == 1: 35 | picked_type = random.choice(CC_NO_SPEC) 36 | picked = picked_type() 37 | things.append(picked) 38 | allowed.append(picked_type) 39 | else: 40 | for _ in range(num_elements): 41 | if random_decision(0.15 if (len(element_choices) > 0) else 1.0): 42 | picked_type = SingleToken 43 | else: 44 | picked_type = random.choice(element_choices) 45 | if picked_type == SingleToken: 46 | picked = SingleToken.generate(single_choices, False) 47 | allowed.append(picked) 48 | else: 49 | picked = picked_type() 50 | allowed.append(picked_type) 51 | things.append(picked) 52 | cls.resolve_avaliable_choices( 53 | element_choices, single_choices, picked_type, picked) 54 | if len(element_choices) == 0 and len(single_choices) == 0: 55 | break 56 | if things: 57 | ComposedByCons.sort_tokens(things) 58 | return cls(*things), allowed 59 | return NoneToken(), None 60 | 61 | @staticmethod 62 | def sort_tokens(tokens): 63 | tokens.sort(key=lambda x: 1 if isinstance(x, SingleToken) else 0) 64 | 65 | @classmethod 66 | def resolve_avaliable_choices(self, element_choices, single_choices, picked_type, picked): 67 | # deal with element choice 68 | # deal with single choice 69 | if picked_type == SingleToken: 70 | soft_remove(element_choices, picked.cc_type) 71 | # soft_remove(element_choices, SingleToken) 72 | # soft_remove(element_choices, SingleOrSingleToken) 73 | if picked.cc_type == LetterToken: 74 | soft_remove(element_choices, LowerToken) 75 | soft_remove(element_choices, CapitalToken) 76 | soft_remove(single_choices, LowerToken) 77 | soft_remove(single_choices, CapitalToken) 78 | if picked.cc_type == CapitalToken or picked.cc_type == LowerToken: 79 | soft_remove(element_choices, LetterToken) 80 | soft_remove(single_choices, LetterToken) 81 | return 82 | 83 | soft_remove(element_choices, picked_type) 84 | soft_remove(single_choices, picked_type) 85 | 86 | if picked_type == LetterToken: 87 | soft_remove(single_choices, LowerToken) 88 | soft_remove(single_choices, CapitalToken) 89 | soft_remove(element_choices, LowerToken) 90 | soft_remove(element_choices, CapitalToken) 91 | if picked_type == CapitalToken or picked_type == LowerToken: 92 | soft_remove(element_choices, LetterToken) 93 | soft_remove(single_choices, LetterToken) 94 | 95 | @classmethod 96 | def config_children_attr(cls): 97 | for c in CC: 98 | c.cnt = -1 99 | 100 | def sample_negative(self): 101 | # avaliable violation 102 | options = [] 103 | CC_LEFT = CC[:] 104 | for c in self.children: 105 | if isinstance(c, SingleToken): 106 | options.append(c.cc_type) 107 | elif isinstance(c, LetterToken): 108 | soft_remove(CC_LEFT, LetterToken) 109 | soft_remove(CC_LEFT, CapitalToken) 110 | soft_remove(CC_LEFT, LowerToken) 111 | elif isinstance(c, CapitalToken): 112 | soft_remove(CC_LEFT, CapitalToken) 113 | soft_remove(CC_LEFT, LetterToken) 114 | elif isinstance(c, LowerToken): 115 | soft_remove(CC_LEFT, LowerToken) 116 | soft_remove(CC_LEFT, LetterToken) 117 | elif isinstance(c, NumToken): 118 | soft_remove(CC_LEFT, NumToken) 119 | elif isinstance(c, SpecialToken): 120 | soft_remove(CC_LEFT, SpecialToken) 121 | # SingleToken -> Full 122 | # Digit Letter 123 | if len(options): 
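# Negative-sampling note (annotation): when any literal SingleToken child was
# allowed (e.g. just "a"), prefer adding that token's full character class to
# the composition, so sampled strings can step outside the original alphabet;
# otherwise the fallback below draws from whatever classes remain in CC_LEFT.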
124 | new_base = random.choice(options) 125 | new_children = self.children + [new_base()] 126 | return ComposedByCons(*new_children) 127 | options.extend(CC_LEFT) 128 | if len(options): 129 | new_base = random.choice(options) 130 | new_children = self.children + [new_base()] 131 | return ComposedByCons(*new_children) 132 | else: 133 | return NoneToken() 134 | 135 | class NotOnlyComposedByCons(Constraint): 136 | pass 137 | 138 | class ContainCons(Constraint): 139 | def logical_form(self): 140 | return "Contain({})".format(self.children[0].logical_form()) 141 | 142 | def description(self): 143 | return "Contain {}".format(self.children[0].description()) 144 | 145 | def specification(self): 146 | return "contain({})".format(self.children[0].specification()) 147 | 148 | @classmethod 149 | def generate(cls, allowed=None, ban_orcons=False): 150 | things = [] 151 | 152 | containable_helper = ContainableFiledsHelper(allowed) 153 | things = [containable_helper.generate(ban_orcons=ban_orcons)] 154 | if things: 155 | return cls(*things) 156 | return NoneToken() 157 | 158 | class NotContainCons(Constraint): 159 | @classmethod 160 | def generate(cls, allowed=None): 161 | contain_cons = ContainCons.generate(allowed=allowed, ban_orcons=True) 162 | 163 | if contain_cons is not None: 164 | return NotCons(*[contain_cons]) 165 | else: 166 | return NoneToken() 167 | 168 | 169 | # (can) contain x, but must be followed by y 170 | # (can) contain x, but must be preceded by y 171 | # y must be a CC 172 | 173 | class ConditionalContainCons(Constraint): 174 | @classmethod 175 | def generate(cls, allowed=None): 176 | if allowed: 177 | cc_s = [y for y in allowed if y in CC_NO_SPEC] 178 | else: 179 | cc_s = CC_NO_SPEC[:] 180 | if not cc_s: 181 | return NoneToken() 182 | y = random.choice(cc_s) 183 | x_s = allowed[:] if allowed else CC[:] 184 | soft_remove(x_s, y) 185 | if not x_s: 186 | return NoneToken() 187 | x = random.choice(x_s) 188 | if x in CC: 189 | x = x() 190 | if y in CC: 191 | y = y() 192 | condition = NotCCCons(y) 193 | if random_decision(0.5): 194 | order = [x, condition] 195 | cons = NotCons(ContainCons(ConcatComp(*order))) 196 | return AndComp(NotCons(EndwithCons(x)), cons) 197 | else: 198 | order = [condition, x] 199 | cons = NotCons(ContainCons(ConcatComp(*order))) 200 | return AndComp(NotCons(StartwithCons(x)), cons) 201 | 202 | class StartwithCons(Constraint): 203 | def logical_form(self): 204 | return "StartWith({})".format(self.children[0].logical_form()) 205 | 206 | def description(self): 207 | return "StartWith {}".format(self.children[0].description()) 208 | 209 | def specification(self): 210 | return "startwith({})".format(self.children[0].specification()) 211 | 212 | @classmethod 213 | def generate(cls, allowed=None, ban_orcons=False): 214 | things = [] 215 | helper = SEWithFiledHelper(allowed) 216 | things = [helper.generate(ban_orcons=ban_orcons)] 217 | if things: 218 | return cls(*things) 219 | return NoneToken() 220 | 221 | class EndwithCons(Constraint): 222 | def logical_form(self): 223 | return "EndWith({})".format(self.children[0].logical_form()) 224 | 225 | def description(self): 226 | return "End with {}".format(self.children[0].description()) 227 | 228 | def specification(self): 229 | return "endwith({})".format(self.children[0].specification()) 230 | 231 | @classmethod 232 | def generate(cls, allowed=None, ban_orcons=False): 233 | things = [] 234 | 235 | helper = SEWithFiledHelper(allowed) 236 | things = [helper.generate(ban_orcons=ban_orcons)] 237 | if things: 238 | return 
cls(*things) 239 | return NoneToken() 240 | 241 | class NotStartwithCons(Constraint): 242 | @classmethod 243 | def generate(cls, allowed=None): 244 | contain_cons = StartwithCons.generate(allowed=allowed, ban_orcons=True) 245 | 246 | if contain_cons is not None: 247 | return NotCons(*[contain_cons]) 248 | else: 249 | return NoneToken() 250 | 251 | class NotEndwithCons(Constraint): 252 | @classmethod 253 | def generate(cls, allowed=None): 254 | contain_cons = EndwithCons.generate(allowed=allowed, ban_orcons=True) 255 | 256 | if is_valid(contain_cons): 257 | return NotCons(*[contain_cons]) 258 | else: 259 | return NoneToken() 260 | 261 | class ConditionalStartwithCons(Constraint): 262 | @classmethod 263 | def generate(cls, allowed=None): 264 | things = [] 265 | helper = SEWithFiledHelper(allowed) 266 | sup_set, child_set = helper.generate_inclusive_pair() 267 | if is_valid(sup_set) and is_valid(child_set): 268 | return AndComp(StartwithCons(sup_set),NotCons(StartwithCons(child_set))) 269 | return NoneToken() 270 | 271 | class ConditionalEndwithCons(Constraint): 272 | @classmethod 273 | def generate(cls, allowed=None): 274 | things = [] 275 | helper = SEWithFiledHelper(allowed) 276 | sup_set, child_set = helper.generate_inclusive_pair() 277 | if is_valid(sup_set) and is_valid(child_set): 278 | return AndComp(EndwithCons(sup_set),NotCons(EndwithCons(child_set))) 279 | return NoneToken() 280 | 281 | class LengthCons(Constraint): 282 | pass 283 | 284 | class LengthOfCons(LengthCons): 285 | @classmethod 286 | def generate(cls, rng): 287 | return RepeatMod.generate(CharacterToken(), rng) 288 | 289 | class LengthLessThanCons(LengthCons): 290 | @classmethod 291 | def generate(cls, rng): 292 | return RepeatRangeMod(CharacterToken(), 0, random.randint(rng[0], rng[1])) 293 | 294 | class LengthMoreThanCons(LengthCons): 295 | @classmethod 296 | def generate(cls, rng): 297 | return RepeatAtLeastMod.generate(CharacterToken(),rng) 298 | 299 | class LengthBetweenCons(LengthCons): 300 | @classmethod 301 | def generate(cls, min_range, max_range): 302 | return RepeatRangeMod.generate(CharacterToken(), min_range, max_range) 303 | 304 | class RepeatMod(Modifier): 305 | def __init__(self, child, x): 306 | super().__init__(*[child]) 307 | self.params.append(x) 308 | 309 | def logical_form(self): 310 | if self.params[0] == 1: 311 | return self.children[0].logical_form() 312 | else: 313 | return "Repeat({},{})".format(self.children[0].logical_form(), self.params[0]) 314 | 315 | def description(self): 316 | return "Repeat {}, {} times".format(self.children[0].description(),self.params[0]) 317 | 318 | def specification(self): 319 | if self.params[0] == 1: 320 | return self.children[0].specification() 321 | else: 322 | return "repeat({},{})".format(self.children[0].specification(), self.params[0]) 323 | 324 | @classmethod 325 | def generate(cls, child, rng): 326 | x = random.randint(rng[0], rng[1]) 327 | return cls(child, x) 328 | 329 | def sample_negative(self): 330 | x = self.params[0] 331 | _x = random.choice([x -1, x + 1]) 332 | return RepeatMod(self.children[0], _x) 333 | 334 | class RepeatAtLeastMod(Modifier): 335 | def __init__(self, child, x): 336 | super().__init__(*[child]) 337 | self.params.append(x) 338 | 339 | def logical_form(self): 340 | return "RepeatAtLeast({},{})".format(self.children[0].logical_form(), self.params[0]) 341 | 342 | def description(self): 343 | return "Repeat {}, {} times".format(self.children[0].description(),self.params[0]) 344 | 345 | def specification(self): 346 | return 
"repeatatleast({},{})".format(self.children[0].specification(), self.params[0]) 347 | 348 | @classmethod 349 | def generate(cls, child, rng): 350 | x = random.randint(rng[0], rng[1]) 351 | return cls(child, x) 352 | 353 | def sample_negative(self): 354 | x = self.params[0] 355 | _x = x - 1 356 | return RepeatMod(self.children[0], _x) 357 | 358 | class RepeatRangeMod(Modifier): 359 | def __init__(self, child, lb, hb): 360 | super().__init__(*[child]) 361 | self.params.append(lb) 362 | self.params.append(hb) 363 | 364 | def logical_form(self): 365 | return "RepeatRange({},{},{})".format(self.children[0].logical_form(), self.params[0], self.params[1]) 366 | 367 | def description(self): 368 | return "Repeat range {}, {} to {} times".format(self.children[0].description(), self.params[0], self.params[1]) 369 | 370 | def specification(self): 371 | return "repeatrange({},{},{})".format(self.children[0].specification(), self.params[0], self.params[1]) 372 | 373 | @classmethod 374 | def generate(cls, child, min_range, max_range): 375 | lb = random.randint(min_range[0], min_range[1]) 376 | hb = random.randint(max(max_range[0], lb + 1), max_range[1]) 377 | 378 | return cls(child, lb, hb) 379 | 380 | def sample_negative(self): 381 | lb = self.params[0] 382 | hb = self.params[1] 383 | if lb == 0: 384 | _x = hb + 1 385 | else: 386 | _x = random.choice([lb -1, hb + 1]) 387 | return RepeatMod(self.children[0], _x) 388 | 389 | class FieldHelper: 390 | def __init__(self, allowed=None): 391 | self.allowed = allowed 392 | 393 | def sample_single(self): 394 | if self.allowed: 395 | return SingleToken.generate(self.allowed, False) 396 | else: 397 | return SingleToken.generate(CC, False) 398 | 399 | def sample_string(self): 400 | if self.allowed: 401 | return StringToken.generate(self.allowed, False) 402 | else: 403 | return StringToken.generate(CC_NO_SPEC, False) 404 | 405 | def sample_single_or_cc(self): 406 | if self.allowed: 407 | choices = self.allowed 408 | picked_type = random.choice(choices) 409 | if picked_type in CC: 410 | return picked_type() 411 | else: 412 | return SingleToken.generate([picked_type]) 413 | else: 414 | choices = CC + [SingleToken] 415 | picked_type = random.choice(choices) 416 | if picked_type in CC: 417 | return picked_type() 418 | else: 419 | return SingleToken.generate(CC, False) 420 | 421 | def sample_simple_or_fields(self): 422 | choices = [StringToken, RepeatMod] 423 | weights = [0.3, 1.0] 424 | picked_type = weighted_random_decision(choices, weights) 425 | if picked_type == StringToken: 426 | if self.allowed: 427 | cc_candidates = [x for x in self.allowed if x in CC] 428 | if not cc_candidates: 429 | return NoneToken() 430 | else: 431 | cc_candidates = CC 432 | picked_cc = random.choice(cc_candidates) 433 | return OrComp(StringToken.generate([picked_cc]), StringToken.generate([picked_cc])) 434 | 435 | # in repeat type 436 | else: 437 | # choose between single token or cc 438 | tok_type = weighted_random_decision([SingleToken, CCToken], [1.0, 3.5]) 439 | if tok_type == SingleToken: 440 | if self.allowed: 441 | cc_candidates = [x for x in self.allowed if x in CC] 442 | if not cc_candidates: 443 | return NoneToken() 444 | else: 445 | cc_candidates = CC 446 | picked_cc = random.choice(cc_candidates) 447 | tok1 = SingleToken.generate([picked_cc], False) 448 | tok2 = SingleToken.generate([picked_cc], False) 449 | else: 450 | tok1, tok2 = self.sample_cc_pair() 451 | if isinstance(tok1, NoneToken): 452 | return NoneToken() 453 | 454 | repeating_time = weighted_random_decision([1, 2, 
3], [0.55, 0.25, 0.25]) 455 | # return 456 | if repeating_time == 1: 457 | return OrComp(tok1, tok2) 458 | else: 459 | return OrComp(RepeatMod(tok1, repeating_time), RepeatMod(tok2, repeating_time)) 460 | 461 | def sample_cc_pair(self): 462 | if self.allowed: 463 | cc_candidates = [x for x in self.allowed if x in CC] 464 | else: 465 | cc_candidates = CC[:] 466 | 467 | toks = [] 468 | if len(cc_candidates) == 0: 469 | return NoneToken(), NoneToken() 470 | 471 | if len(cc_candidates) == 1: 472 | if cc_candidates[0] == LetterToken: 473 | toks = [LowerToken, CapitalToken] 474 | random.shuffle(toks) 475 | return toks[0](), toks[1]() 476 | else: 477 | return NoneToken(), NoneToken() 478 | else: 479 | if self.allowed and (LetterToken in cc_candidates): 480 | cc_candidates.append(CapitalToken) 481 | cc_candidates.append(LowerToken) 482 | print(cc_candidates) 483 | tok1 = random.choice(cc_candidates) 484 | print(tok1) 485 | soft_remove(cc_candidates, tok1) 486 | if tok1 == LetterToken: 487 | soft_remove(cc_candidates, CapitalToken) 488 | soft_remove(cc_candidates, LowerToken) 489 | if tok1 == CapitalToken or tok1 == LowerToken: 490 | soft_remove(cc_candidates, LetterToken) 491 | print(cc_candidates) 492 | tok2 = random.choice(cc_candidates) 493 | 494 | return tok1(), tok2() 495 | 496 | def sample_cc(self): 497 | if self.allowed: 498 | cc_list = [x for x in self.allowed if x in CC] 499 | if not cc_list: 500 | return NoneToken() 501 | picked_type = random.choice(cc_list) 502 | return picked_type() 503 | else: 504 | picked_type = random.choice(CC) 505 | return picked_type() 506 | 507 | def instantiate_cat_candidate(self, picked_type): 508 | if picked_type == SingleToken: 509 | return self.sample_single() 510 | elif picked_type in CC: 511 | return picked_type() 512 | elif picked_type == RepeatMod: 513 | return RepeatMod.generate(self.sample_single_or_cc(), (2,3)) 514 | 515 | def sample_concat(self): 516 | if self.allowed: 517 | candidates = [SingleToken] + [x for x in self.allowed if x in CC] + [RepeatMod] 518 | else: 519 | candidates = [SingleToken] + CC + [RepeatMod] 520 | 521 | type_c1 = random.choice(candidates) 522 | type_c2 = random.choice(candidates) 523 | 524 | c1 = self.instantiate_cat_candidate(type_c1) 525 | c2 = self.instantiate_cat_candidate(type_c2) 526 | return ConcatComp(c1,c2) 527 | 528 | 529 | # only help constuct 530 | # containabble helper , has to be true subset 531 | class ContainableFiledsHelper(FieldHelper): 532 | def generate(self, ban_orcons=True): 533 | OR_WEIGHT = 1.5 534 | if not self.allowed: 535 | feasible_list = CC + [SingleToken, StringToken, RepeatMod, SimpleConcat2Token] 536 | if self.allowed: 537 | if len(self.allowed) == 1: 538 | feasible_list = [SingleToken, StringToken] 539 | else: 540 | feasible_list = [x for x in self.allowed if x in CC] + [SingleToken, StringToken, RepeatMod] 541 | # if two distinct Token 542 | if len(self.allowed) > 2: 543 | feasible_list.append(SimpleConcat2Token) 544 | feasible_weights = [1.0 for _ in feasible_list] 545 | if not ban_orcons: 546 | feasible_list.append(SimpleOr2Token) 547 | feasible_weights.append(OR_WEIGHT) 548 | 549 | picked_type = weighted_random_decision(feasible_list, feasible_weights) 550 | if picked_type == SingleToken: 551 | return self.sample_single() 552 | elif picked_type in CC: 553 | return picked_type() 554 | elif picked_type == RepeatMod: 555 | return RepeatMod.generate(self.sample_single_or_cc(), (2,3)) 556 | elif picked_type == SimpleConcat2Token: 557 | return self.sample_concat() 558 | elif picked_type == 
StringToken: 559 | return self.sample_string() 560 | elif picked_type == SimpleOr2Token: 561 | return self.sample_simple_or_fields() 562 | 563 | class OptionalCons(Constraint): 564 | def logical_form(self): 565 | return "Optional({})".format(self.children[0].logical_form()) 566 | 567 | def description(self): 568 | return "Optional {}".format(self.children[0].description()) 569 | 570 | def specification(self): 571 | return "optional({})".format(self.children[0].specification()) 572 | 573 | def sample_negative(self): 574 | return self.children[0].sample_negative() 575 | 576 | # only help constuct 577 | class SEWithFiledHelper(FieldHelper): 578 | def generate(self, ban_orcons=True): 579 | OR_WEIGHT = 3.0 580 | if not self.allowed: 581 | if random_decision(0.1): 582 | picked_type = random.choice([SingleToken, StringToken]) 583 | else: 584 | feasible_list = CC + [RepeatMod, SimpleConcat2Token] 585 | feasible_weights = [1.0 for _ in feasible_list] 586 | if not ban_orcons: 587 | feasible_list.append(SimpleOr2Token) 588 | feasible_weights.append(OR_WEIGHT) 589 | picked_type = weighted_random_decision(feasible_list, feasible_weights) 590 | if self.allowed: 591 | if len(self.allowed) == 1: 592 | picked_type = random.choice([SingleToken, StringToken]) 593 | else: 594 | if random_decision(0.1): 595 | picked_type = random.choice([SingleToken, StringToken]) 596 | feasible_list = [x for x in self.allowed if x in CC] + [RepeatMod] 597 | # if two distinct Token 598 | if len(self.allowed) > 2: 599 | feasible_list.append(SimpleConcat2Token) 600 | feasible_weights = [1.0 for _ in feasible_list] 601 | if not ban_orcons: 602 | feasible_list.append(SimpleOr2Token) 603 | feasible_weights.append(OR_WEIGHT) 604 | picked_type = weighted_random_decision(feasible_list, feasible_weights) 605 | 606 | if picked_type == SingleToken: 607 | return self.sample_single() 608 | elif picked_type in CC: 609 | return picked_type() 610 | elif picked_type == RepeatMod: 611 | return RepeatMod.generate(self.sample_single_or_cc(), (2,3)) 612 | elif picked_type == SimpleConcat2Token: 613 | return self.sample_concat() 614 | elif picked_type == StringToken: 615 | return self.sample_string() 616 | elif picked_type == SimpleOr2Token: 617 | return self.sample_simple_or_fields() 618 | 619 | def generate_inclusive_pair(self, exclude=None): 620 | if not self.allowed: 621 | # sup_set_type = random.choice(CC + [RepeatMod]) 622 | if random_decision(0.55): 623 | sup_set_type = random.choice(CC_NO_SPEC) 624 | else: 625 | sup_set_type = RepeatMod 626 | if self.allowed: 627 | if len(self.allowed) == 1: 628 | return NoneToken(), NoneToken() 629 | feasible_list = [x for x in self.allowed if x in CC] 630 | if not feasible_list: 631 | return NoneToken(), NoneToken() 632 | if random_decision(0.55): 633 | sup_set_type = random.choice(CC_NO_SPEC) 634 | else: 635 | sup_set_type = RepeatMod 636 | 637 | if sup_set_type in CC: 638 | sup_set = sup_set_type() 639 | child_set_type = weighted_random_decision([SingleToken, StringToken], [0.6, 0.45]) 640 | if child_set_type == SingleToken: 641 | child_set = SingleToken.generate([sup_set_type], False) 642 | elif child_set_type == StringToken: 643 | child_set = StringToken.generate([sup_set_type], False) 644 | 645 | elif sup_set_type == RepeatMod: 646 | if not self.allowed: 647 | sup_set_child_type = random.choice(CC) 648 | if self.allowed: 649 | feasible_list = [x for x in self.allowed if x in CC] 650 | if not feasible_list: 651 | return NoneToken(), NoneToken() 652 | sup_set_child_type = random.choice(feasible_list) 
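# sup_set here is a 2-3x repeat of the chosen class, e.g. Repeat(<num>,3); the
# child_set built below is a strict subset of it: either a fixed string of the
# same length over that class, or the same repeat applied to a single literal.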
653 | sup_set_child = sup_set_child_type() 654 | sup_set = RepeatMod.generate(sup_set_child, (2,3)) 655 | if random_decision(0.4): 656 | sup_length = sup_set.params[0] 657 | child_set = StringToken.generate([sup_set_child_type], length=sup_length) 658 | else: 659 | child_set_child = SingleToken.generate([sup_set_child_type], False) 660 | child_set = RepeatMod(child_set_child, sup_set.params[0]) 661 | return sup_set, child_set 662 | 663 | class NotOnlyCons(Constraint): 664 | pass 665 | -------------------------------------------------------------------------------- /toolkit/external/jars/datagen.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/toolkit/external/jars/datagen.jar -------------------------------------------------------------------------------- /toolkit/external/lib/antlr-4.7.1-complete.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/toolkit/external/lib/antlr-4.7.1-complete.jar -------------------------------------------------------------------------------- /toolkit/external/lib/sempre-core.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiye17/StructuredRegex/7d49aa2a176677fb5a553b05a6ad7d8fc24bc37d/toolkit/external/lib/sempre-core.jar -------------------------------------------------------------------------------- /toolkit/filters.py: -------------------------------------------------------------------------------- 1 | from base import * 2 | from template import * 3 | from functools import reduce 4 | import subprocess 5 | from prepare_regex_data import gen_pos_examples 6 | 7 | 8 | def filter_regexes(regexes): 9 | regexes = [(x[1],print("Is Valiid", x[0],len(regexes)))[0] for x in enumerate(regexes) if is_valid(x[1])] 10 | regexes = [(x[1],print("Is Good", x[0],len(regexes)))[0] for x in enumerate(regexes) if is_good(x[1])] 11 | regexes = [(x[1],print("Is Diverse", x[0],len(regexes)))[0] for x in enumerate(regexes) if is_diverse(x[1])] 12 | return regexes 13 | 14 | def is_diverse(node): 15 | pos_exs = gen_pos_examples(node, 100) 16 | num_pos = len(set(pos_exs)) 17 | return num_pos > 10 18 | 19 | def is_good(node): 20 | if not all([is_good(x) for x in node.children]): 21 | return False 22 | 23 | # if isinstance(node, RepeatMod): 24 | # child = node.children[0] 25 | # times = node.params[0] 26 | # if isinstance(child, SingleToken) and times == 3: 27 | # return False 28 | 29 | if isinstance(node, OrComp): 30 | child_0 = node.children[0] 31 | child_1 = node.children[1] 32 | if child_0.logical_form() == child_1.logical_form(): 33 | return False 34 | 35 | if isinstance(node, ConcatenationField) or isinstance(node, ConcatComp): 36 | if not check_cat_type(node): 37 | return False 38 | 39 | if isinstance(node, AndComp): 40 | if not check_and_type(node): 41 | return False 42 | 43 | if isinstance(node, UnstructuredField): 44 | if not check_uns_type(node): 45 | return False 46 | 47 | if isinstance(node, ComposedByCons): 48 | if not check_or_type(node): 49 | return False 50 | 51 | return True 52 | 53 | def extract_terminal(node): 54 | if isinstance(node, Token): 55 | return (node.logical_form(),) 56 | elif isinstance(node, NotCCCons): 57 | return ("not" + extract_terminal(node.children[0])[0],) 58 | else: 59 | return reduce(lambda x,y: x + y, 
[extract_terminal(x) for x in node.children]) 60 | 61 | def check_cat_type(node): 62 | if all([isinstance(c, OptionalCons) for c in node.children]): 63 | return False 64 | 65 | flat_children = [] 66 | for c in node.children: 67 | if isinstance(c, OptionalCons): 68 | if isinstance(c.children[0], ConcatComp): 69 | flat_children.extend(c.children[0].children) 70 | else: 71 | flat_children.append(c.children[0]) 72 | else: 73 | flat_children.append(c) 74 | terminals = [extract_terminal(x) for x in flat_children] 75 | for i in range(len(terminals) - 1): 76 | if terminals[i] == terminals[i + 1]: 77 | return False 78 | return True 79 | 80 | def check_and_type(node): 81 | children_logical_forms = [x.logical_form() for x in node.children] 82 | if len(set(children_logical_forms)) < len(children_logical_forms): 83 | return False 84 | return True 85 | 86 | def check_or_type(node): 87 | children_logical_forms = [x.logical_form() for x in node.children] 88 | if len(set(children_logical_forms)) < len(children_logical_forms): 89 | return False 90 | return True 91 | 92 | def check_uns_type(node): 93 | if not check_and_type(node): 94 | return False 95 | 96 | complexity = 0 97 | # complexity check 98 | for child in node.children: 99 | if isinstance(child, AndComp): 100 | complexity += 2 101 | elif isinstance(child, StartwithCons) or isinstance(child, EndwithCons) or isinstance(child, ContainCons): 102 | if "Or(" in child.logical_form(): 103 | complexity += 1 104 | else: 105 | complexity += 1 106 | 107 | max_complexity = 3 if isinstance(node, SimpleUnstructuredField) else 6 108 | if complexity > max_complexity: 109 | print("Complexity ({}|{}) Filter".format(complexity, max_complexity)) 110 | print(node.description()) 111 | return False 112 | 113 | # check compabality 114 | 115 | 116 | cons = [] 117 | not_contain_cons = [] 118 | composed_by_cons = None 119 | for child in node.children: 120 | if isinstance(child, NotCons): 121 | construct = child.children[0] 122 | else: 123 | construct = child 124 | if isinstance(construct, ComposedByCons): 125 | cons.append(child) 126 | composed_by_cons = child 127 | if isinstance(construct, ContainCons): 128 | cons.append(child) 129 | if isinstance(child, NotCons): 130 | not_contain_cons.append(construct) 131 | if isinstance(construct, StartwithCons): 132 | cons.append(child) 133 | if isinstance(construct, EndwithCons): 134 | cons.append(child) 135 | 136 | if (composed_by_cons is not None) and not_contain_cons: 137 | for nc_cons in not_contain_cons: 138 | banned_tok = nc_cons.children[0].logical_form() 139 | for tok in composed_by_cons.children: 140 | if banned_tok == tok.logical_form(): 141 | print("Semantic Filter") 142 | print(node.description()) 143 | return False 144 | 145 | if len(cons) >= 2: 146 | origin_spec = AndComp.and_type_specification(cons) 147 | for i in range(len(cons)): 148 | reduced_spec = AndComp.and_type_specification(cons[:i] + cons[i + 1:]) 149 | if check_equiv(origin_spec, reduced_spec) == "true": 150 | print("Redundancy Filter") 151 | print(node.description()) 152 | return False 153 | 154 | return True 155 | 156 | 157 | def check_equiv(spec0, spec1): 158 | # try: 159 | out = subprocess.check_output( 160 | ['java', '-cp', './external/jars/datagen.jar:./external/lib/*', '-ea', 'datagen.Main', 'equiv', 161 | spec0, spec1], stderr=subprocess.DEVNULL) 162 | out = out.decode("utf-8") 163 | out = out.rstrip() 164 | return out 165 | 166 | 167 | -------------------------------------------------------------------------------- /toolkit/gen_regex_data.py: 
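# A minimal usage sketch for check_equiv from /toolkit/filters.py above. It
# shells out to the bundled datagen.jar and returns the string "true" or
# "false" for DFA equivalence of two DSL specifications. Assumes it is run
# from the toolkit directory (so the relative ./external paths resolve) on a
# Unix-like system (the ':' classpath separator is not valid on Windows).
from filters import check_equiv

# Double negation should be DFA-equivalent to the original specification.
assert check_equiv("contain(const(<abc>))",
                   "not(not(contain(const(<abc>))))") == "true"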
-------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | from collections import Counter 4 | import subprocess 5 | import re 6 | from base import * 7 | from template import * 8 | from constraints import ComposedByCons 9 | from filters import * 10 | from os.path import join 11 | from regex_io import read_tsv_file, build_func_from_str 12 | 13 | def get_posted_regexes(): 14 | posted_batches = [1, 2] 15 | posted_regexes = [] 16 | for b in posted_batches: 17 | fname = join("./HitArxiv", "batch" + str(b), "batch-{}-record.txt".format(b)) 18 | regexes = [x[1] for x in read_tsv_file(fname, delimiter=" ")] 19 | regexes = [build_func_from_str(x) for x in regexes] 20 | posted_regexes.extend(regexes) 21 | return posted_regexes 22 | 23 | def gen_pilot_template(): 24 | # uns_regexes =[UnstructuredField.generate(5) for _ in range(300)] 25 | # cat_regexes = [ConcatenationField.generate(6) for _ in range(350)] 26 | sep_regexes = [SeperatedField.generate() for _ in range(350)] 27 | # print("uns", len(uns_regexes)) 28 | # print("cat", len(cat_regexes)) 29 | print("sep", len(sep_regexes)) 30 | 31 | # do filtering 32 | # uns_regexes = filter_regexes(uns_regexes) 33 | # cat_regexes = filter_regexes(cat_regexes) 34 | sep_regexes = filter_regexes(sep_regexes) 35 | 36 | # print("uns", len(uns_regexes)) 37 | # print("cat", len(cat_regexes)) 38 | print("sep", len(sep_regexes)) 39 | 40 | posted_regexes = get_posted_regexes() 41 | posted_forms = [x.logical_form() for x in posted_regexes] 42 | # print(len(uns_regexes), len(cat_regexes), len(sep_regexes)) 43 | # uns_regexes = [x for x in uns_regexes if x.logical_form() not in posted_forms] 44 | # cat_regexes = [x for x in cat_regexes if x.logical_form() not in posted_forms] 45 | sep_regexes = [x for x in sep_regexes if x.logical_form() not in posted_forms] 46 | # print(len(uns_regexes), len(cat_regexes), len(sep_regexes)) 47 | 48 | prefix = "regexes-raw" 49 | # random.shuffle(uns_regexes) 50 | # with open(join(prefix, "batch3_uns.txt"), "w") as f: 51 | # [f.write("{}\n".format(r.to_string())) for r in uns_regexes] 52 | # with open(join(prefix, "uns-plot.txt"), "w") as f: 53 | # [f.write("{}\n".format(tok(r.logical_form()))) for r in uns_regexes] 54 | 55 | # random.shuffle(cat_regexes) 56 | # with open(join(prefix, "batch3_cat.txt"), "w") as f: 57 | # [f.write("{}\n".format(r.to_string())) for r in cat_regexes] 58 | # with open(join(prefix, "cat-plot.txt"), "w") as f: 59 | # [f.write("id {}\n".format(tok(r.logical_form()))) for r in cat_regexes] 60 | 61 | random.shuffle(sep_regexes) 62 | with open(join(prefix, "batch3_sep.txt"), "w") as f: 63 | [f.write("{}\n".format(r.to_string())) for r in sep_regexes] 64 | # with open(join(prefix, "sep-plot.txt"), "w") as f: 65 | # [f.write("{}\n".format(tok(r.logical_form()))) for r in sep_regexes] 66 | 67 | def main(): 68 | gen_pilot_template() 69 | 70 | if __name__ == "__main__": 71 | random.seed(34341) 72 | main() 73 | -------------------------------------------------------------------------------- /toolkit/postprocess.py: -------------------------------------------------------------------------------- 1 | from regex_io import * 2 | from os.path import join 3 | from functools import reduce 4 | from base import * 5 | from template import * 6 | import random 7 | import re 8 | import nltk 9 | from nltk.translate import IBMModel1 10 | from nltk.translate import Alignment, AlignedSent 11 | from nltk.translate.ibm3 import IBMModel3 12 | from nltk.translate.ibm2 import 
IBMModel2 13 | from prepare_regex_data import gen_bad_examples 14 | import numpy as np 15 | import spacy 16 | import pickle 17 | 18 | # postprocess responses 19 | def filter_responces(resps): 20 | resps = [x for x in resps if len(x[4]["description"].split()) > 4] 21 | return resps 22 | 23 | def mannually_filter_responces(resps): 24 | # filter by worker id 25 | bad_workers = ["A3HM4LUQL1ASJ1", "A2C33799VENPIP", "A2V0BI2FTB17HN"] 26 | print("Before M", len(resps)) 27 | resps = [x for x in resps if x[4]["worker_id"] not in bad_workers] 28 | print("After M", len(resps)) 29 | return resps 30 | 31 | def build_reg_libs(): 32 | regexes_b1 = read_tsv_file(join("./results", "batch1_record.txt"), delimiter=" ") 33 | regexes_b2 = read_tsv_file(join("./results", "batch2_record.txt"), delimiter=" ") 34 | regexes_b3 = read_tsv_file(join("./results", "batch3_record.txt"), delimiter=" ") 35 | regexes = regexes_b1 + regexes_b2 + regexes_b3 36 | regexes = dict([(x[0],build_func_from_str(x[1])) for x in regexes]) 37 | return regexes 38 | 39 | def extract_const_comps(node): 40 | if isinstance(node, SingleToken): 41 | return [("tok", node.tok)] 42 | elif isinstance(node, StringToken): 43 | return [("str", node.tok)] 44 | else: 45 | return reduce(lambda x,y: x + y, [extract_const_comps(x) for x in node.children], []) 46 | 47 | def extract_ints(node): 48 | toks = node.ground_truth() 49 | toks = tokenize(toks) 50 | toks = [x for x in toks if x.isdigit()] 51 | return toks 52 | 53 | # "a" or "." or "," -> unique mode 54 | # digit -> unique 55 | def map_tokens_with_consts(tokens, consts, ints): 56 | maps = [] 57 | tokens_bak = tokens[:] 58 | quoated_tokens = [] 59 | symbol_tokens = [] 60 | for const in consts: 61 | if const[0] == "tok": 62 | if const[1] == "a" or const[1] == "." 
or const[1] == ",": 63 | continue 64 | if const[1].isdigit(): 65 | # overlapping 66 | if const[1] in ints: 67 | continue 68 | 69 | target = const[1] 70 | quoated_tokens = [] 71 | for i, tok in enumerate(tokens): 72 | if tok == target: 73 | quoated_tokens.append('"{}"'.format(target)) 74 | else: 75 | quoated_tokens.append(tok) 76 | tokens = quoated_tokens 77 | if not quoated_tokens: 78 | quoated_tokens = tokens[:] 79 | 80 | c_id = 0 81 | for const in consts: 82 | target = const[1] 83 | symbol = "const{}".format(c_id) 84 | symbol_tokens = [] 85 | flag_mapped = False 86 | for i, tok in enumerate(quoated_tokens): 87 | if tok == '"{}"'.format(target): 88 | symbol_tokens.append(symbol) 89 | flag_mapped = True 90 | else: 91 | symbol_tokens.append(tok) 92 | if flag_mapped: 93 | maps.append((symbol, target)) 94 | c_id += 1 95 | quoated_tokens = symbol_tokens 96 | 97 | if not symbol_tokens: 98 | symbol_tokens = quoated_tokens[:] 99 | # leftout free consts 100 | free_consts = [] 101 | for tok in symbol_tokens: 102 | if tok[0] == '"' and tok[-1] == '"' and len(tok) > 2: 103 | if tok[1:-1] not in free_consts: 104 | free_consts.append(tok[1:-1]) 105 | map_before = maps[:] 106 | symbol_descp = " ".join(symbol_tokens) 107 | for fcons in free_consts: 108 | symbol = "const{}".format(c_id) 109 | symbol_descp = symbol_descp.replace('"{}"'.format(fcons), symbol) 110 | maps.append((symbol, fcons)) 111 | c_id += 1 112 | symbol_tokens = symbol_descp.split(" ") 113 | return symbol_tokens, maps 114 | 115 | def remove_duplicate_const(consts): 116 | y = [] 117 | for c in consts: 118 | if c not in y: 119 | y.append(c) 120 | 121 | return y 122 | 123 | def special_processing(descp): 124 | new_tokens = [] 125 | tokens = descp.split(" ") 126 | a_to_b_re = re.compile("\d+-\d+") 127 | a_or_more_re = re.compile("\d+\+") 128 | for tok in tokens: 129 | if a_to_b_re.match(tok): 130 | nums = tok.split("-") 131 | new_tokens.append(nums[0]) 132 | new_tokens.append("-") 133 | new_tokens.append(nums[1]) 134 | continue 135 | if a_or_more_re.match(tok): 136 | nums = tok[:-1] 137 | new_tokens.append(nums) 138 | new_tokens.append("+") 139 | continue 140 | new_tokens.append(tok) 141 | 142 | # clear . again 143 | new_descp = " ".join(new_tokens) 144 | new_descp = new_descp.replace(".", " . 
") 145 | new_tokens = new_descp.split(" ") 146 | new_tokens = [x for x in new_tokens if len(x)] 147 | 148 | return " ".join(new_tokens) 149 | 150 | def replace_regex_with_map(r, maps): 151 | for m in maps: 152 | if len(m[1]) == 1: 153 | src = "<{}>".format(m[1]) 154 | else: 155 | src = "const(<{}>)".format(m[1]) 156 | dst = m[0] 157 | r = r.replace(src, dst) 158 | return r 159 | 160 | _CONTRACTIONS2 = [ 161 | r"(?i)\b(can)(?#X)(not)\b", 162 | r"(?i)\b(gon)(?#X)(na)\b", 163 | r"(?i)\b(got)(?#X)(ta)\b", 164 | r"(?i)\b(wan)(?#X)(na)b", 165 | r"(?i)\b(there)(?#X)('s)b", 166 | r"(?i)\b(it)(?#X)('s)b", 167 | ] 168 | _CONTRACTIONS3 = [r"(?i)\b(\w+)(?#X)(n't)\b"] 169 | CONTRACTIONS2 = list(map(re.compile, _CONTRACTIONS2)) 170 | CONTRACTIONS3 = list(map(re.compile, _CONTRACTIONS3)) 171 | 172 | def remove_contractions(x): 173 | for regexp in CONTRACTIONS2: 174 | x = regexp.sub(r'\1 \2', x) 175 | x = x.replace("can't", "can not") 176 | for regexp in CONTRACTIONS3: 177 | x = regexp.sub(r'\1 not', x) 178 | # x = x.replace("can't", "can n‘t") 179 | # for regexp in CONTRACTIONS3: 180 | # x = regexp.sub(r'\1 \2', x) 181 | return x 182 | 183 | 184 | # like 3 a's or 3 "a"s 185 | def clear_prural(tokens): 186 | new_toks = [] 187 | for t in tokens: 188 | if t.endswith("'s") and t != "'s": 189 | new_toks.append(t[:-2]) 190 | elif t.endswith('"s'): 191 | new_toks.append(t[:-1]) 192 | else: 193 | new_toks.append(t) 194 | return new_toks 195 | 196 | # fix unpaired token 197 | def fix_unpaired_quotas(tokens): 198 | s = " ".join(tokens) 199 | double_quot_regex = re.compile("\"([^'\s]+)\"") 200 | s = re.sub(double_quot_regex, r' "\1" ', s) 201 | 202 | new_toks = [] 203 | for t in s.split(" "): 204 | if "'" in t or '"' in t: 205 | if t == "'s" or t == "n't": 206 | new_toks.append(t) 207 | continue 208 | if t[0] == '"' and t[-1] == '"': 209 | new_toks.append(t) 210 | continue 211 | 212 | t = t.replace("'", " ") 213 | t = t.replace('"', " ") 214 | new_toks.extend(t.split(" ")) 215 | else: 216 | new_toks.append(t) 217 | new_toks = [x for x in new_toks if len(x)] 218 | return new_toks 219 | 220 | def spacy_processing(descp): 221 | doc = spacy_tokenizer(descp) 222 | return " ".join([x.text for x in doc]) 223 | 224 | def record_to_metadata(r, reg_libs): 225 | # print(r["problem_id"]) 226 | descp = r["description"] 227 | id = r["problem_id"] 228 | regex = reg_libs[id] 229 | consts = extract_const_comps(regex) 230 | consts = remove_duplicate_const(consts) 231 | ints = extract_ints(regex) 232 | # deal with , . 233 | descp_clear_punc = descp + " " 234 | descp_clear_punc = descp_clear_punc.replace(". ", " . 
") 235 | descp_clear_punc = descp_clear_punc.replace(", ", " , ") 236 | descp_clear_punc = descp_clear_punc.replace("(", " ( ") 237 | descp_clear_punc = descp_clear_punc.replace(")", " ) ") 238 | descp_clear_punc = descp_clear_punc.rstrip() 239 | 240 | descp_clear_punc = remove_contractions(descp_clear_punc) 241 | # fix minor issue 242 | descp_clear_punc = descp_clear_punc.replace("’", "'") 243 | single_quot_regex = re.compile("'([^'\s]+)'") 244 | descp_clear_single = re.sub(single_quot_regex, r'"\1"', descp_clear_punc) 245 | 246 | # clear prural 247 | descp_tokens = descp_clear_single.split(" ") 248 | descp_tokens = [x for x in descp_tokens if len(x)] 249 | descp_tokens = clear_prural(descp_tokens) 250 | 251 | # fix unpaired quota 252 | descp_tokens = fix_unpaired_quotas(descp_tokens) 253 | 254 | # map const, if not quoated, quote them 255 | symbol_tokens, token_const_maps = map_tokens_with_consts(descp_tokens, consts, ints) 256 | 257 | symbol_descp = " ".join(symbol_tokens) 258 | symbol_descp = symbol_descp.lower() 259 | 260 | num_pairs = [(' one ', ' 1 '), (' two ', ' 2 '), (' three ', ' 3 '), (' four ', ' 4 '), 261 | (' five ', ' 5 '), (' six ', ' 6 '), (' seven ', ' 7 '), (' eight ', ' 8 '), (' nine ', ' 9 '), (' ten ', ' 10 ')] 262 | for pair in num_pairs: 263 | symbol_descp = symbol_descp.replace(pair[0], pair[1]) 264 | symbol_descp = special_processing(symbol_descp) 265 | symbol_descp = spacy_processing(symbol_descp) 266 | # dealwith groudtruth 267 | ground_truth = regex.ground_truth() 268 | ground_truth = replace_regex_with_map(ground_truth, token_const_maps) 269 | ground_truth = " ".join(tokenize(ground_truth)) 270 | 271 | pos_exs = r["pos_examples"] 272 | neg_exs = r["neg_examples"] 273 | exs = [] 274 | for pex in pos_exs.split("\n"): 275 | exs.append("+,"+pex) 276 | for nex in neg_exs.split("\n"): 277 | exs.append("-,"+nex) 278 | exs = " ".join(exs) 279 | rec = r 280 | return symbol_descp, ground_truth, token_const_maps, exs, rec, regex 281 | 282 | def write_dataset(records, split): 283 | print(len(records)) 284 | print("Split", split) 285 | print(len(records)) 286 | descps = [x[0] for x in records] 287 | gts = [x[1] for x in records] 288 | maps = [x[2] for x in records] 289 | exs = [x[3] for x in records] 290 | rec = [x[4] for x in records] 291 | regex = [x[5] for x in records] 292 | 293 | prefix = "./ARealBase" 294 | with open(join(prefix, "src-{}.txt".format(split)), "w") as f: 295 | descps_lines = "\n".join(descps) 296 | f.write(descps_lines) 297 | 298 | with open(join(prefix, "targ-{}.txt".format(split)), "w") as f: 299 | gts_lines = "\n".join(gts) 300 | f.write(gts_lines) 301 | 302 | with open(join(prefix, "map-{}.txt".format(split)), "w") as f: 303 | maps_lines = [ " ".join([str(len(m))] + ["{},{}".format(pair[0], pair[1]) for pair in m]) for m in maps] 304 | maps_lines = "\n".join(maps_lines) 305 | f.write(maps_lines) 306 | 307 | with open(join(prefix, "exs-{}.txt".format(split)), "w") as f: 308 | exs_lines = "\n".join(exs) 309 | f.write(exs_lines) 310 | 311 | with open(join(prefix, "rec-{}.pkl".format(split)), "wb") as f: 312 | rec_lines = [{"id":r["problem_id"], "worker_id": r["worker_id"]} for r in rec] 313 | pickle.dump(rec_lines, f) 314 | 315 | def make_random_example(): 316 | reg_libs = build_reg_libs() 317 | prefix = "./dataset" 318 | with open(join(prefix, "id-val.txt")) as f: 319 | lines = f.readlines() 320 | lines = [x.rstrip() for x in lines] 321 | id_lines = lines 322 | id_set = list(set(id_lines)) 323 | 324 | bad_exs = [] 325 | for id in id_set: 326 | 
print(id) 327 | r = reg_libs[id] 328 | spec = r.specification() 329 | exs = gen_bad_examples(spec) 330 | bad_exs.append(exs) 331 | id_exs_dict = dict(zip(id_set, bad_exs)) 332 | 333 | with open(join(prefix, "bad_exs-val.txt"), "w") as f: 334 | exs_lines = [id_exs_dict[x] for x in id_lines] 335 | exs_lines = [" ".join(["{},{}".format(x[1],x[0]) for x in exs]) for exs in exs_lines] 336 | f.write("\n".join(exs_lines)) 337 | 338 | def make_dr_data(): 339 | reg_libs = build_reg_libs() 340 | records_b1 = read_result("results/batch1_res.csv") 341 | records_b2 = read_result("results/batch2_res.csv") 342 | records_b3 = read_result("results/batch3_res.csv") 343 | records = records_b1 + records_b2 + records_b3 344 | records = group_by_filed(records, "problem_id") 345 | print(len(records)) 346 | records = [records[k] for k in records] 347 | random.shuffle(records) 348 | split_point = 900 349 | train_records = records[:split_point] 350 | test_records = records[split_point:] 351 | 352 | train_records = list(reduce(lambda x,y: x + y, map(lambda z: [record_to_metadata(r, reg_libs) for r in z], train_records))) 353 | train_records = filter_responces(train_records) 354 | train_records = mannually_filter_responces(train_records) 355 | test_records = list(reduce(lambda x,y: x + y, map(lambda z: [record_to_metadata(r, reg_libs) for r in z], test_records))) 356 | test_records = filter_responces(test_records) 357 | test_records = mannually_filter_responces(test_records) 358 | write_dataset(train_records, "train") 359 | write_dataset(test_records, "val") 360 | 361 | def post_process_data(): 362 | reg_libs = build_reg_libs() 363 | records_b1 = read_result("results/batch1_res.csv") 364 | records_b2 = read_result("results/batch2_res.csv") 365 | records_b3 = read_result("results/batch3_res.csv") 366 | records = records_b1 + records_b2 + records_b3 367 | records = group_by_filed(records, "problem_id") 368 | print(len(records)) 369 | records = [records[k] for k in records] 370 | train_records = records 371 | 372 | train_records = list(reduce(lambda x,y: x + y, map(lambda z: [record_to_metadata(r, reg_libs) for r in z], train_records))) 373 | train_records = mannually_filter_responces(train_records) 374 | train_records = filter_responces(train_records) 375 | write_dataset(train_records, "col") 376 | 377 | def read_file_lines(filename): 378 | with open(filename) as f: 379 | lines = f.readlines() 380 | lines = [x.rstrip() for x in lines] 381 | return lines 382 | 383 | def write_file_lines(filename, lines): 384 | with open(filename, "w") as f: 385 | f.write("\n".join(lines)) 386 | return lines 387 | 388 | def load_dataset(split): 389 | prefix = "./FinalSplit0" 390 | descps = read_file_lines(join(prefix, "src-{}.txt".format(split))) 391 | gts = read_file_lines(join(prefix, "targ-{}.txt".format(split))) 392 | exs = read_file_lines(join(prefix, "exs-{}.txt".format(split))) 393 | maps = read_file_lines(join(prefix, "map-{}.txt".format(split))) 394 | 395 | with open(join(prefix, "rec-{}.pkl".format(split)), "rb") as f: 396 | records = pickle.load(f) 397 | return list(zip(descps, gts, maps, exs, records)) 398 | 399 | def num_of_three_types(rec): 400 | return [len([y for y in rec if t in y[4]["id"]]) for t in ["uns", "cat", "sep"]] 401 | 402 | def mannually_decide_testex(records_by_worker, full_records): 403 | random.shuffle(records_by_worker) 404 | min_allowed = 350 405 | found = [] 406 | for r in records_by_worker: 407 | print(len(r), end=" ") 408 | found.extend(r) 409 | if len(found) > min_allowed: 410 | print() 411 | break 
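# Balance check (annotation): the greedily pooled external-test workers are
# rejected unless each of the three regex families (uns/cat/sep) contributes
# a count strictly between 110 and 130.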
412 | num_type = [len([y for y in found if t in y["id"]]) for t in ["uns", "cat", "sep"]] 413 | if any([x >= 130 or x <= 110 for x in num_type]): 414 | return False 415 | print([len([y for y in found if t in y["id"]]) for t in ["uns", "cat", "sep"]]) 416 | workers = list(set([x["worker_id"] for x in found])) 417 | ids = list(set([x["id"] for x in found])) 418 | print(workers) 419 | print(len(workers), len(found)) 420 | print(len(ids), len(found), 3 * len(ids) - len(found)) 421 | left_out = [x for x in full_records if x["id"] in ids and x["worker_id"] not in workers] 422 | print([len([y for y in left_out if t in y["id"]]) for t in ["uns", "cat", "sep"]], len(left_out)) 423 | return True 424 | 425 | 426 | TO_VOID_IN_TESTE = ['A2RMJNF6IPI42F', 'A2ECRNQ3X5LEXD', 'A2PFLDMSADON5K'] 427 | # split1 428 | # TEST_EX_WORKER = ['AE861G0AY5RGT', 'A6PRQVQM8YZ4W', 'A2AKGKD22DWZHI', 'A2F5AZQ55LXHKT', 'A20OJ1Q95TMP8B'] 429 | 430 | # split2 431 | TEST_EX_WORKER = ['A1U5D0C8S15TIK', 'A2F5AZQ55LXHKT', 'A2AKGKD22DWZHI', 'A6PRQVQM8YZ4W', 'A38Z99XF4NDNH0', 'A3RMDIRX16L60E', 'A3DCO9GJ4XDVE2', 'A2DDPSXH2X96RF', 'A3UVLUYTHE86UA'] 432 | 433 | #split3 434 | # TEST_EX_WORKER = ['A1E0WK5W1BFPWR', 'A2F5AZQ55LXHKT', 'A2NYUS12FHF2Y', 'AE861G0AY5RGT', 'A28QUR0QYD2WI7', 'A1MKCTVKE7J0ZP', 'A2AKGKD22DWZHI', 'A1U5D0C8S15TIK', 'A3UVLUYTHE86UA', 'A38Z99XF4NDNH0'] 435 | 436 | #split4 437 | # TEST_EX_WORKER = ['A1IATW3PMVL6J3', 'A28QUR0QYD2WI7', 'A2UFGZT4QUY5ON', 'A3DCO9GJ4XDVE2', 'A2BDIIXOFUX18', 'A2RLSRUHS830A7', 'A38Z99XF4NDNH0', 'A2F5AZQ55LXHKT'] 438 | 439 | # split5 440 | # TEST_EX_WORKER = ['AE861G0AY5RGT', 'A3UVLUYTHE86UA', 'A2UFGZT4QUY5ON', 'A3K0GYICW2CXM3', 'A1IATW3PMVL6J3', 'A3DCO9GJ4XDVE2'] 441 | 442 | # split6 443 | # TEST_EX_WORKER = ['A3RMDIRX16L60E', 'A12LP4V8NTTEUE', 'A28QUR0QYD2WI7', 'A1MKCTVKE7J0ZP', 'A38Z99XF4NDNH0', 'A279JAYOXWD7PO', 'A002160837SWJFPIAI7L7', 'A1U5D0C8S15TIK', 'A2NYUS12FHF2Y', 'AE861G0AY5RGT'] 444 | 445 | def dump_dataset(records, split): 446 | descps = [x[0] for x in records] 447 | gts = [x[1] for x in records] 448 | maps = [x[2] for x in records] 449 | exs = [x[3] for x in records] 450 | rec = [x[4] for x in records] 451 | prefix = "./FinalSplit" 452 | write_file_lines(join(prefix, "src-{}.txt".format(split)), descps) 453 | write_file_lines(join(prefix, "targ-{}.txt".format(split)), gts) 454 | write_file_lines(join(prefix, "exs-{}.txt".format(split)), exs) 455 | write_file_lines(join(prefix, "map-{}.txt".format(split)), maps) 456 | 457 | with open(join(prefix, "rec-{}.pkl".format(split)), "wb") as f: 458 | pickle.dump(rec, f) 459 | 460 | 461 | def decide_ex_workers(): 462 | records = load_dataset("col") 463 | 464 | records = [r[4] for r in records] 465 | # print(records) 466 | records_by_worker = group_by_filed(records, "worker_id") 467 | records_by_worker = [records_by_worker[k] for k in records_by_worker if k not in TO_VOID_IN_TESTE] 468 | 469 | 470 | records_by_worker = [x for x in records_by_worker if (len(x) < 110 and len(x) > 20)] 471 | while(not mannually_decide_testex(records_by_worker, records)): 472 | pass 473 | # mannually_decide_testex(records_by_worker) 474 | 475 | def make_testex_split(): 476 | records = load_dataset("col") 477 | 478 | records = [x for x in records if x[4]["worker_id"] in TEST_EX_WORKER] 479 | dump_dataset(records, "teste") 480 | 481 | def make_train_dev_testi_split(): 482 | records = load_dataset("col") 483 | 484 | ex_records = [x for x in records if x[4]["worker_id"] in TEST_EX_WORKER] 485 | ex_ids = list(set([x[4]["id"] for x in ex_records])) 486 | 487 | 
records = [x for x in records if x[4]["worker_id"] not in TEST_EX_WORKER] 488 | print(len(records)) 489 | must_in_test_records = [x for x in records if x[4]["id"] in ex_ids] 490 | print(len(must_in_test_records)) 491 | 492 | rest_records = [x for x in records if x[4]["id"] not in ex_ids] 493 | print(len(rest_records)) 494 | rest_ids = list(set([x[4]["id"] for x in rest_records])) 495 | # random.seed(235) 496 | random.shuffle(rest_ids) 497 | num_each_type = 40 498 | rest_ids_by_types = [[x for x in rest_ids if t in x] for t in ["uns", "cat", "sep"]] 499 | rem_ids = sum([x[:num_each_type] for x in rest_ids_by_types], []) 500 | print(len(rem_ids)) 501 | 502 | rem_records = [x for x in rest_records if x[4]["id"] in rem_ids] 503 | train_records = [x for x in rest_records if x[4]["id"] not in rem_ids] 504 | 505 | dev_records = rem_records 506 | testi_records = must_in_test_records 507 | print(num_of_three_types(train_records)) 508 | print(num_of_three_types(dev_records)) 509 | print(num_of_three_types(testi_records)) 510 | 511 | dump_dataset(train_records, "train") 512 | dump_dataset(dev_records, "val") 513 | dump_dataset(testi_records, "testi") 514 | simple_stats(dev_records) 515 | 516 | def simple_stats(records): 517 | 518 | 519 | def calc_ast_depth(x): 520 | return 1 + max([0] + [calc_ast_depth(c) for c in x.children]) 521 | 522 | def calc_ast_size(x): 523 | return 1 + sum([calc_ast_size(c) for c in x.children]) 524 | 525 | descps = [x[0] for x in records] 526 | gts = [x[1] for x in records] 527 | maps = [x[2] for x in records] 528 | 529 | descp_lens = np.array([len(x.split(" ")) for x in descps]) 530 | print(descp_lens.mean()) 531 | print(np.quantile(descp_lens, [0.0, 0.25, 0.5, 0.75, 1.0])) 532 | 533 | gt_toks = [x.replace(" ", "") for x in gts] 534 | targ_toks = [tokenize(x) for x in gt_toks] 535 | 536 | targ_asts = [build_dataset_ast_from_toks(x, 0)[0] for x in targ_toks] 537 | ast_sizes = np.array([calc_ast_size(x) for x in targ_asts]) 538 | ast_depths = np.array([calc_ast_depth(x) for x in targ_asts]) 539 | 540 | print("Avg Size", ast_sizes.mean()) 541 | print("Qutiles Size", np.quantile(ast_sizes, [0,0.25,0.5,0.75,1.0])) 542 | # view_special_ones(targ_lines, ast_sizes) 543 | print("Avg Depth", ast_depths.mean()) 544 | print("Qutiles Depth", np.quantile(ast_depths, [0,0.25,0.5,0.75,1.0])) 545 | 546 | # def make_train_dev_testi_split(): 547 | # records = load_dataset("col") 548 | 549 | # ex_records = [x for x in records if x[4]["worker_id"] in TEST_EX_WORKER] 550 | # ex_ids = list(set([x[4]["id"] for x in ex_records])) 551 | 552 | # records = [x for x in records if x[4]["worker_id"] not in TEST_EX_WORKER] 553 | # print(len(records)) 554 | # must_in_test_records = [x for x in records if x[4]["id"] in ex_ids] 555 | # print(len(must_in_test_records)) 556 | 557 | # rest_records = [x for x in records if x[4]["id"] not in ex_ids] 558 | # print(len(rest_records)) 559 | # rest_ids = list(set([x[4]["id"] for x in rest_records])) 560 | # random.seed(666) 561 | # random.shuffle(rest_ids) 562 | # target_num = 710 - len(must_in_test_records) 563 | 564 | # rem_ids = [] 565 | # cnt = 0 566 | # for id in rest_ids: 567 | # cnt += sum([x[4]["id"] == id for x in rest_records]) 568 | # rem_ids.append(id) 569 | # if cnt > target_num: 570 | # break 571 | 572 | # print(len(rem_ids), cnt) 573 | 574 | # rem_records = [x for x in rest_records if x[4]["id"] in rem_ids] 575 | # train_records = [x for x in rest_records if x[4]["id"] not in rem_ids] 576 | # all_test_records = rem_records + must_in_test_records 577 | 
# half_num = len(all_test_records) // 2 578 | # print(len(train_records), len(all_test_records), half_num) 579 | # random.shuffle(all_test_records) 580 | 581 | # dev_records = all_test_records[:half_num] 582 | # testi_records = all_test_records[half_num:] 583 | # dump_dataset(train_records, "train") 584 | # dump_dataset(dev_records, "dev") 585 | # dump_dataset(testi_records, "testi") 586 | 587 | def verify_partion(): 588 | train = load_dataset("train") 589 | dev = load_dataset("dev") 590 | testi = load_dataset("testi") 591 | teste = load_dataset("teste") 592 | 593 | train_descps = [x[0] for x in train] 594 | print(sum([x[0] in train_descps for x in dev])) 595 | print(sum([x[0] in train_descps for x in testi])) 596 | print(sum([x[0] in train_descps for x in teste])) 597 | 598 | testi_descps = [x[0] for x in testi] 599 | dev_descps = [x[0] for x in dev] 600 | print(sum([x[0] in dev_descps for x in teste])) 601 | print(sum([x[0] in testi_descps for x in teste])) 602 | # print(sum([x in train_descps for x in dev])) 603 | 604 | train_regexes = [x[1] for x in train] 605 | print(sum([x[1] in train_regexes for x in dev])) 606 | print(sum([x[1] in train_regexes for x in testi])) 607 | print(sum([x[1] in train_regexes for x in teste])) 608 | 609 | testi_regexes = [x[1] for x in testi] 610 | dev_regexes = [x[1] for x in dev] 611 | print(sum([x[1] in dev_regexes for x in teste])) 612 | print(sum([x[1] in testi_regexes for x in teste])) 613 | 614 | def stats_partion(): 615 | train = load_dataset("train") 616 | dev = load_dataset("val") 617 | testi = load_dataset("testi") 618 | teste = load_dataset("teste") 619 | 620 | def targ_length(x): 621 | targs = [y[1] for y in x] 622 | return max([len(y.split()) for y in targs]) 623 | print(targ_length(train)) 624 | print(targ_length(dev)) 625 | print(targ_length(testi)) 626 | print(targ_length(teste)) 627 | simple_stats(dev) 628 | 629 | def mannually_assesment(): 630 | train = load_dataset("train") 631 | dev = load_dataset("val") 632 | 633 | gts = [x[0] for x in train] + [x[0] for x in dev] 634 | targs = [x[1] for x in train] + [x[1] for x in dev] 635 | recs = [x[4] for x in train] + [x[4] for x in dev] 636 | 637 | types = ["uns", "cat", "sep"] 638 | infos = list(zip(gts, targs, recs)) 639 | random.shuffle(infos) 640 | def sample_n_things(t): 641 | t_samples = [] 642 | id_sets = set() 643 | for gt, targ, r in infos: 644 | id = r["id"] 645 | if len(id_sets) == 50: 646 | break 647 | if t not in id: 648 | continue 649 | if id in id_sets: 650 | continue 651 | id_sets.add(id) 652 | t_samples.append((gt, targ)) 653 | 654 | return t_samples 655 | 656 | samples = [sample_n_things(t) for t in types] 657 | samples = sum(samples, []) 658 | sample_length = np.array([len(x[0].split()) for x in samples]) 659 | print(sample_length.mean()) 660 | with open("random_sample.txt", "w") as f: 661 | for gt, targ in samples: 662 | f.write('"{}","{}"\n'.format(gt, targ)) 663 | 664 | 665 | 666 | 667 | 668 | def get_spacy_tokenizer(): 669 | disable = ["vectors", "textcat", "tagger", "parser", "ner"] 670 | spacy_model = spacy.load("en_core_web_sm", disable=disable) 671 | return spacy_model 672 | 673 | # spacy.load("en_core_web_sm") 674 | # spacy_tokenizer = get_spacy_tokenizer() 675 | if __name__ == "__main__": 676 | # random.seed(2333) 677 | # post_process_data() 678 | # decide_ex_workers() 679 | # make_testex_split() 680 | # make_train_dev_testi_split() 681 | # verify_partion() 682 | stats_partion() 683 | # mannually_assesment() 
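# The commented-out calls in __main__ below reflect the intended order of the
# pipeline (a sketch; each stage consumes the splits written by the previous):
#   post_process_data()            # collate batches 1-3 into the "col" split
#   decide_ex_workers()            # sample a worker pool for the external test set
#   make_testex_split()            # freeze those workers' records as "teste"
#   make_train_dev_testi_split()   # split the rest into train / val / testi
#   verify_partion()               # check no description or regex leaks across splits
#   stats_partion()                # report description-length and AST statistics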
-------------------------------------------------------------------------------- /toolkit/prepare_regex_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from base import tok, is_valid 3 | from os.path import join 4 | from regex_io import build_func_from_str, read_tsv_file 5 | from template import InfSeperatedField 6 | import subprocess 7 | import random 8 | # REGEX_TYPES = ["uns", "cat", "sep"] 9 | REGEX_TYPES = ["uns", "cat", "sep"] 10 | RAW_DIR = "regexes-raw" 11 | NUM_TRIED = 100 12 | NUM_KEEP = 6 13 | NUM_EXS = 10 14 | 15 | def prepare_pos_examples(): 16 | # regexes_by_type = [read_regex_file(join(RAW_DIR, t + ".txt")) for t in REGEX_TYPES] 17 | for t in REGEX_TYPES: 18 | regexes = read_tsv_file(join(RAW_DIR, t + ".txt")) 19 | # with open(join(pr)) 20 | examples = [] 21 | for i, x in enumerate(regexes): 22 | print(t, i) 23 | examples.append(gen_examples(x[1]))  # NB: gen_examples is not defined in this module 24 | # examples = [gen_examples(x) for x in regexes] 25 | examples_lines = ["\t".join(x) for x in examples] 26 | with open(join(RAW_DIR, t + "-pos.txt"), "w") as f: 27 | f.write("\n".join(examples_lines)) 28 | 29 | def gen_random_examples(regex, num_keep=NUM_KEEP, num_gen=NUM_TRIED): 30 | out = subprocess.check_output( 31 | ['java', '-cp', './external/jars/datagen.jar:./external/lib/*', '-ea', 'datagen.Main', 'example', 32 | str(num_gen), regex]) 33 | out = out.decode("utf-8") 34 | lines = out.split("\n") 35 | lines = lines[1:] 36 | lines = [x for x in lines if len(x)] 37 | fields = [(x[1:-3],x[-1]) for x in lines] 38 | pos_exs = [x for x in fields if x[1] == "+"] 39 | neg_exs = [x for x in fields if x[1] == "-"] 40 | random.shuffle(pos_exs) 41 | random.shuffle(neg_exs) 42 | pos_exs = pos_exs[:num_keep] 43 | neg_exs = neg_exs[:num_keep] 44 | exs = pos_exs + neg_exs 45 | return exs 46 | 47 | def gen_pos_examples(regex, num_gen=NUM_TRIED, is_spec=False): 48 | # try: 49 | if not is_spec: 50 | regex = regex.specification() 51 | print("Gen", regex) 52 | out = subprocess.check_output( 53 | ['java', '-cp', './external/jars/datagen.jar:./external/lib/*', '-ea', 'datagen.Main', 'example', 54 | str(num_gen), regex]) 55 | out = out.decode("utf-8") 56 | return parse_examples(out) 57 | 58 | def match_spec_example(regex, example): 59 | # try: 60 | print("Match", regex, example) 61 | out = subprocess.check_output( 62 | ['java', '-cp', './external/jars/datagen.jar:./external/lib/*', '-ea', 'datagen.Main', 'evaluate', 63 | regex, example]) 64 | out = out.decode("utf-8") 65 | out = out.rstrip() 66 | return out 67 | 68 | def gen_neg_examples(regex, num_keep=NUM_KEEP): 69 | # for each negative candidate, randomly sample a path and generate positive examples for it 70 | exs = [] 71 | reg_candidates = regex.negative_candidates() 72 | for neg_regex in reg_candidates: 73 | if not is_valid(neg_regex): 74 | continue 75 | over_exs = gen_pos_examples(neg_regex, num_gen=50) 76 | random.shuffle(over_exs) 77 | exs.extend(over_exs[:num_keep]) 78 | return exs 79 | 80 | def parse_examples(exs_out): 81 | lines = exs_out.split("\n") 82 | lines = lines[1:] 83 | lines = [x for x in lines if len(x)] 84 | fields = [(x[1:-3],x[-1]) for x in lines] 85 | fields = [x[0] for x in fields if x[1] == "+"] 86 | return fields 87 | 88 | def gen_examples_file(filename, regex): 89 | out = subprocess.check_output( 90 | ['java', '-cp', './external/jars/datagen.jar:./external/lib/*', '-ea', 'datagen.Main', 'example', 91 | str(NUM_TRIED), regex]) 92 | exs_out = out.decode("utf-8") 93 | 94 | lines = exs_out.split("\n") 95 | lines = lines[1:] 96 | 
lines = [x for x in lines if len(x)] 97 | random.shuffle(lines) 98 | lines = lines[:20] 99 | lines = ["// examples"] + lines + ["", "// gt", regex] 100 | 101 | with open(filename, "w") as f: 102 | f.write("\n".join(lines)) 103 | 104 | def make_examples_file(filename, regex, pos_exs, neg_exs): 105 | pos_lines = pos_exs 106 | random.shuffle(pos_lines) 107 | pos_lines = pos_lines[:NUM_EXS] 108 | pos_lines = ['"{}",+'.format(x) for x in pos_lines] 109 | 110 | neg_lines = neg_exs 111 | random.shuffle(neg_lines) 112 | neg_lines = neg_lines[:NUM_EXS] 113 | 114 | spec = regex.specification() 115 | gt = regex.ground_truth() 116 | match_results = [match_spec_example(spec, x) for x in neg_lines] 117 | neg_lines = [x[0] for x in zip(neg_lines, match_results) if x[1] == "false"] 118 | print("before:", len(neg_exs), "after:", len(neg_lines)) 119 | neg_lines = ['"{}",-'.format(x) for x in neg_lines] 120 | 121 | lines = ["// examples"] + pos_lines + neg_lines + ["", "// gt", gt] 122 | with open(filename, "w") as f: 123 | f.write("\n".join(lines)) 124 | 125 | # def prepare_data(): 126 | # regexes_by_type = [read_tsv_file(join(RAW_DIR, t + ".txt")) for t in REGEX_TYPES] 127 | # regexes_by_type = [[build_func_from_str(x[0]) for x in regexes] for regexes in regexes_by_type] 128 | 129 | # pos_examples_by_type = [read_tsv_file(join(RAW_DIR, t + "-pos.txt")) for t in REGEX_TYPES] 130 | # neg_examples_by_type = [read_tsv_file(join(RAW_DIR, t + "-neg.txt")) for t in REGEX_TYPES] 131 | # # tgt_num_by_type = [15, 15, 30] 132 | # tgt_num_by_type = [50, 50, 50] 133 | # data_regexes = [] 134 | # for regexes, pos_exs, neg_exs, num_tgt in zip(regexes_by_type, pos_examples_by_type, neg_examples_by_type, tgt_num_by_type): 135 | # reg_ex_pairs = list(zip(regexes, pos_exs, neg_exs)) 136 | # print("Len before:",len(reg_ex_pairs)) 137 | # reg_ex_pairs = [x for x in reg_ex_pairs if len(x[1]) >= NUM_KEEP] 138 | # print("Len after:",len(reg_ex_pairs)) 139 | # data_regexes.extend(reg_ex_pairs[:num_tgt]) 140 | 141 | # with open("pilot.txt", "w") as f: 142 | # for p in data_regexes: 143 | # f.write(tok(p[0].logical_form()) + "\n") 144 | 145 | # with open("pilot-spec.txt", "w") as f: 146 | # for p in data_regexes: 147 | # f.write(p[0].specification() + "\n") 148 | 149 | # with open("pilot.csv", "w") as f: 150 | # lines = [] 151 | # lines.append("image_url,str_examples") 152 | # for i, p in enumerate(data_regexes): 153 | # img_url = '"http://taur.cs.utexas.edu/hidden/p/{}.png"'.format(i) 154 | # random.shuffle(p[1]) 155 | # exs_str = '"<ul>{}</ul>"'.format("".join(["<li>{}</li>".format(x) for x in p[1][:NUM_KEEP]])) 156 | # lines.append("{},{}".format(img_url, exs_str)) 157 | # f.write("\n".join(lines)) 158 | 159 | def prepare_data(): 160 | regexes_by_type = [read_tsv_file(join(RAW_DIR, t + ".txt")) for t in REGEX_TYPES] 161 | regexes_by_type = [[build_func_from_str(x[0]) for x in regexes] for regexes in regexes_by_type] 162 | 163 | pos_examples_by_type = [read_tsv_file(join(RAW_DIR, t + "-pos.txt")) for t in REGEX_TYPES] 164 | # neg_examples_by_type = [read_tsv_file(join(RAW_DIR, t + "-neg.txt")) for t in REGEX_TYPES] 165 | # tgt_num_by_type = [15, 15, 30] 166 | tgt_num_by_type = [50, 50, 50] 167 | data_regexes = [] 168 | for regexes, pos_exs, num_tgt in zip(regexes_by_type, pos_examples_by_type, tgt_num_by_type): 169 | reg_ex_pairs = list(zip(regexes, pos_exs)) 170 | print("Len before:",len(reg_ex_pairs)) 171 | reg_ex_pairs = [x for x in reg_ex_pairs if len(x[1]) >= NUM_KEEP] 172 | print("Len after:",len(reg_ex_pairs)) 173 | data_regexes.extend(reg_ex_pairs[:num_tgt]) 174 | 175 | with open("pilot.txt", "w") as f: 176 | for p in data_regexes: 177 | f.write(tok(p[0].logical_form()) + "\n") 178 | 179 | with open("pilot.csv", "w") as f: 180 | lines = [] 181 | lines.append("image_url,str_examples") 182 | for i, p in enumerate(data_regexes): 183 | img_url = '"http://taur.cs.utexas.edu/hidden/p/{}.png"'.format(i) 184 | random.shuffle(p[1]) 185 | exs_str = '"<ul>{}</ul>"'.format("".join(["<li>{}</li>".format(x) for x in p[1][:NUM_KEEP]]))
186 | lines.append("{},{}".format(img_url, exs_str)) 187 | f.write("\n".join(lines)) 188 | 189 | def io_test(): 190 | for t in REGEX_TYPES: 191 | regexes = read_tsv_file(join(RAW_DIR, t + ".txt")) 192 | 193 | examples = [build_func_from_str(x[0]).to_string() for x in regexes] 194 | with open(join(RAW_DIR, t + "-re.txt"), "w") as f: 195 | [f.write(r + "\n") for r in examples] 196 | 197 | def prepare_examples(): 198 | for t in REGEX_TYPES: 199 | regexes = read_tsv_file(join(RAW_DIR, t + ".txt")) 200 | regexes = [build_func_from_str(x[0]) for x in regexes] 201 | 202 | # neg_examples = [gen_neg_examples(x) for x in regexes] 203 | # neg_examples_lines = ["\t".join(x) for x in neg_examples] 204 | # with open(join(RAW_DIR, t + "-neg.txt"), "w") as f: 205 | # f.write("\n".join(neg_examples_lines)) 206 | 207 | pos_examples = [gen_pos_examples(x) for x in regexes] 208 | pos_examples_lines = ["\t".join(x) for x in pos_examples] 209 | with open(join(RAW_DIR, t + "-pos.txt"), "w") as f: 210 | f.write("\n".join(pos_examples_lines)) 211 | 212 | def read_example_file(filename): 213 | with open(filename) as f: 214 | lines = f.readlines() 215 | lines = [x.rstrip() for x in lines] 216 | lines = lines[1:-2] 217 | lines = [x for x in lines if len(x)] 218 | fields = [(x[1:-3],x[-1]) for x in lines] 219 | pos_exs = [x[0] for x in fields if x[1] == "+"] 220 | neg_exs = [x[0] for x in fields if x[1] == "-"] 221 | return pos_exs, neg_exs 222 | 223 | def compare_neg_examples(): 224 | 225 | root = HtmlElement("html")  # HtmlElement / ImgElement / PlanText are HTML helpers defined outside this file 226 | for id in range(60): 227 | id = str(id) 228 | new_fname = join("./benchmark", id) 229 | old_fname = join("./benchmark-old", id) 230 | _, new_neg_exs = read_example_file(new_fname) 231 | _, old_neg_exs = read_example_file(old_fname) 232 | 233 | div = HtmlElement("div") 234 | img_url = '"http://taur.cs.utexas.edu/hidden/p/{}.png"'.format(id) 235 | root.add_children(HtmlElement("p", [ImgElement(img_url)])) 236 | root.add_children(HtmlElement("p", [PlanText("ID:" + id)])) 237 | 238 | with open("compare_neg.html", "w") as f: 239 | f.write(root.html()) 240 | 241 | def gen_hit_pos_exs(regex): 242 | hit_pos_exs = [] 243 | if isinstance(regex, InfSeperatedField): 244 | spec_regexes = regex.required_special_examples() 245 | for spec_r in spec_regexes: 246 | spec_pos = gen_pos_examples(spec_r, is_spec=True) 247 | if spec_pos: 248 | hit_pos_exs.append(random.choice(spec_pos)) 249 | else: 250 | print(spec_r, "Not Good") 251 | pos_exs = gen_pos_examples(regex) 252 | random.shuffle(pos_exs) 253 | hit_pos_exs.extend(pos_exs[:(NUM_KEEP - len(hit_pos_exs))]) 254 | return hit_pos_exs 255 | 256 | def gen_hit_neg_exs(regex): 257 | hit_neg_exs = gen_neg_examples(regex) 258 | random.shuffle(hit_neg_exs) 259 | hit_neg_exs = hit_neg_exs[:(2*NUM_KEEP)] 260 | spec = regex.specification() 261 | match_results = [match_spec_example(spec, x) for x in hit_neg_exs] 262 | hit_neg_exs = [x[0] for x in zip(hit_neg_exs, match_results) if x[1] == "false"] 263 | return hit_neg_exs[:NUM_KEEP] 264 | 265 | def prepare_hit_fields(name, regex): 266 | # id, img_url, pos_exs, neg_exs 267 | id = name 268 | img_url = "http://taur.cs.utexas.edu/hidden/p/{}.png".format(id) 269 | pos_exs = gen_hit_pos_exs(regex) 270 | pos_exs = '<ul>{}</ul>'.format("".join(["<li>{}</li>".format(x) for x in pos_exs]))
271 | neg_exs = gen_hit_neg_exs(regex) 272 | neg_exs = '<ul>{}</ul>'.format("".join(["<li>{}</li>".format(x) for x in neg_exs])) 273 | return (id, img_url, pos_exs, neg_exs) 274 | 275 | def prepare_hits(): 276 | batch_id = 3 277 | # prepare batch.txt 278 | # list of id regex 279 | named_regexes = [] 280 | tgt_num_by_type = [150, 150, 150] 281 | for num_tgt, t in zip(tgt_num_by_type, REGEX_TYPES): 282 | regexes = read_tsv_file(join(RAW_DIR, "batch{}_{}.txt".format(batch_id, t))) 283 | regexes = [build_func_from_str(x[0]) for x in regexes] 284 | regexes = regexes[:num_tgt] 285 | regexes = [("b-{}_t-{}_id-{}".format(batch_id, t, i), x) for i, x in enumerate(regexes)] 286 | named_regexes.extend(regexes) 287 | 288 | with open("batch-{}-record.txt".format(batch_id), "w") as f: 289 | for name, r in named_regexes: 290 | f.write("{} {} {}\n".format(name, r.to_string(), tok(r.logical_form()))) 291 | 292 | with open("batch-{}.txt".format(batch_id), "w") as f: 293 | for name, r in named_regexes: 294 | f.write("{} {}\n".format(name, tok(r.logical_form()))) 295 | 296 | # prepare batch.csv 297 | # id, img_url, pos_exs, neg_exs 298 | csv_fields = [] 299 | for name, r in named_regexes: 300 | print(name) 301 | fields = prepare_hit_fields(name, r) 302 | csv_fields.append(fields) 303 | csv_lines = [] 304 | csv_lines.append("id,img_url,pos_exs,neg_exs\n") 305 | for id, img_url, pos_exs, neg_exs in csv_fields: 306 | csv_lines.append('"{}","{}","{}","{}"\n'.format(id,img_url,pos_exs,neg_exs)) 307 | with open("batch-{}.csv".format(batch_id), "w") as f: 308 | f.writelines(csv_lines) 309 | 310 | # preview 311 | root = HtmlElement("html") 312 | for id, img_url, pos_exs, neg_exs in csv_fields: 313 | div = HtmlElement("div") 314 | img_url = '"http://0.0.0.0:8000/preview/{}.png"'.format(id) 315 | root.add_children(HtmlElement("p", [ImgElement(img_url)])) 316 | root.add_children(HtmlElement("p", [PlanText("ID:" + id)])) 317 | root.add_children(HtmlElement("p", [PlanText("POS:")])) 318 | root.add_children(HtmlElement("p", [PlanText(pos_exs)])) 319 | root.add_children(HtmlElement("p", [PlanText("NEG:")])) 320 | root.add_children(HtmlElement("p", [PlanText(neg_exs)])) 321 | 322 | with open("preview_dataset.html", "w") as f: 323 | f.write(root.html()) 324 | 325 | 326 | def main(): 327 | # prepare_examples() 328 | # prepare_data() 329 | # compare_neg_examples() 330 | prepare_hits() 331 | 332 | if __name__ == "__main__": 333 | random.seed(123) 334 | main() -------------------------------------------------------------------------------- /toolkit/regex_io.py: -------------------------------------------------------------------------------- 1 | from base import * 2 | from template import * 3 | from constraints import * 4 | import csv 5 | import sys 6 | 7 | def str_to_class(name): 8 | return getattr(sys.modules[__name__], name) 9 | 10 | def tokenize(x): 11 | y = [] 12 | while len(x) > 0: 13 | head = x[0] 14 | if head in ["?", "(", ")", "{", "}", ","]: 15 | y.append(head) 16 | x = x[1:] 17 | elif head == "<": 18 | end = x.index(">") + 1 19 | y.append(x[:end]) 20 | x = x[end:] 21 | else: 22 | leftover = [(i in ["?", "(", ")", "{", "}", "<", ">", ","]) for i in x] 23 | end = leftover.index(True) 24 | y.append(x[:end]) 25 | x = x[end:] 26 | return y 27 | 28 | class ASTNode: 29 | def __init__(self, node_class, children=None, params=None): 30 | self.node_class = node_class 31 | self.children = children if children is not None else []  # avoid shared mutable defaults 32 | self.params = params if params is not None else [] 33 | 34 | def logical_form(self): 35 | if len(self.children) + len(self.params) > 0: 36 | return self.node_class + "(" + ",".join([x.logical_form() for x in self.children] + [str(x) for x in self.params]) + ")"
142 | return ASTNode(node_class), cur + 2 143 | else: 144 | ret_vals = build_dataset_ast_from_toks(toks, cur + 1) 145 | children.append(ret_vals[0]) 146 | cur = ret_vals[1] 147 | else: 148 | node_class = head 149 | cur = cur + 1 150 | print(cur, node_class, children, params) 151 | 152 | def read_tsv_file(filename, delimiter="\t"): 153 | with open(filename) as f: 154 | lines = f.readlines() 155 | lines = [x.rstrip() for x in lines] 156 | lines = [x.split(delimiter) for x in lines] 157 | return lines 158 | 159 | def row_to_record(row, header): 160 | record = {} 161 | record["hit_id"] = row[header.index("HITId")] 162 | record["worker_id"] = row[header.index("WorkerId")] 163 | record["work_time"] = row[header.index("WorkTimeInSeconds")] 164 | pos_exs = row[header.index("Input.pos_exs")] 165 | neg_exs = row[header.index("Input.neg_exs")] 166 | # <ul><li>x</li><li>x</li></ul> 167 | pos_exs = pos_exs[8:-10].split("</li><li>") 168 | neg_exs = neg_exs[8:-10].split("</li><li>") 169 | record["pos_examples"] = "\n".join(pos_exs) 170 | record["neg_examples"] = "\n".join(neg_exs) 171 | record["imgurl"] = row[header.index("Input.img_url")] 172 | record["problem_id"] = row[header.index("Input.id")] 173 | record["description"] = row[header.index("Answer.description")] 174 | record["pos_exs"] = row[header.index("Answer.pos_example")] 175 | if len(row) < len(header): 176 | row.append("") 177 | row.append("") 178 | return record 179 | 180 | def read_result(filename): 181 | with open(filename) as csv_file: 182 | csv_reader = csv.reader(csv_file, delimiter=',') 183 | header = next(csv_reader) 184 | print(header) 185 | # exit() 186 | return [row_to_record(x, header) for x in csv_reader] 187 | 188 | def group_by_filed(records, key): 189 | key_set = list(set([x[key] for x in records])) 190 | key_set.sort() 191 | 192 | grouped_records = dict(zip(key_set, [[x for x in records if x[key] == y] for y in key_set])) 193 | return grouped_records 194 | -------------------------------------------------------------------------------- /toolkit/usage_example.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | import random 6 | import sys 7 | 8 | from collections import Counter 9 | import subprocess 10 | import re 11 | from base import * 12 | from template import * 13 | from constraints import ComposedByCons 14 | from filters import * 15 | from os.path import join 16 | from regex_io import read_tsv_file, build_func_from_str 17 | from prepare_regex_data import gen_hit_pos_exs, gen_hit_neg_exs, gen_random_examples 18 | 19 | # RegexClass: refer to Function class in base.py 20 | def sample_regexes_usage(): 21 | random.seed(123) 22 | 23 | # sample a single regex of each type described in the paper 24 | print(UnstructuredField.generate(5).specification()) 25 | print(ConcatenationField.generate(6).specification()) 26 | print(SeperatedField.generate().specification()) 27 | 28 | # this will yield some crappy regexes, so we need to do rejection sampling 29 | num_per_type = 10 30 | 31 | uns_regexes = [UnstructuredField.generate(5) for _ in range(num_per_type)] 32 | cat_regexes = [ConcatenationField.generate(6) for _ in range(num_per_type)] 33 | sep_regexes = [SeperatedField.generate() for _ in range(num_per_type)] 34 | regexes = uns_regexes + cat_regexes + sep_regexes 35 | 36 | # do filtering 37 | regexes = filter_regexes(regexes) 38 | 39 | print(len(regexes)) 40 | print(regexes[0].specification()) 41 | 42 | # save to file 43 | with open(join("usage_example_raw_regexes.txt"), "w") as f: 44 | [f.write("{}\n".format(r.to_string())) for r in uns_regexes] 45 | 46 | 47 | # x: Function class 48 | def sample_distinguish_examples(x): 49 | pos_examples = gen_hit_pos_exs(x) 50 | neg_examples = gen_hit_neg_exs(x) 51 | pos_examples = [(x,'+') for x in pos_examples] 52 | neg_examples = [(x,'-') for x in neg_examples] 53 | return pos_examples + neg_examples 54 | 55 | # Distinguishing examples, as described in the paper 56 | def sample_distinguishing_examples_usage(): 57 | random.seed(123) 58 | 59 | # read the stored file, and get the first data point as example 60 | regexes = read_tsv_file("usage_example_raw_regexes.txt") 61 | regexes = [build_func_from_str(x[0]) for x in regexes] 62 | regex = regexes[0] 63 | print(regex.to_string()) 64 | print(regex.specification()) 65 | 66 | examples = sample_distinguish_examples(regex) 67 | print(examples) 68 | 69 | def sample_random_examples(x): 70 | return gen_random_examples(x.specification())
71 | 72 | # random examples 73 | def sample_random_examples_usage(): 74 | random.seed(123) 75 | 76 | # read the stored file, and get the first data point as example 77 | regexes = read_tsv_file("usage_example_raw_regexes.txt") 78 | regexes = [build_func_from_str(x[0]) for x in regexes] 79 | regex = regexes[0] 80 | print(regex.to_string()) 81 | print(regex.specification()) 82 | 83 | examples = sample_random_examples(regex) 84 | print(examples) 85 | 86 | 87 | # input: a specification a and a specification b; make sure a and b aren't equivalent 88 | # sample differentiating examples that satisfy spec a but do not satisfy spec b 89 | # return the list of such examples 90 | # may return an empty list when b is a superset of a, e.g., a == , b == 91 | # try to exchange a and b if that happens 92 | # it may also return an empty list in some cases even if b is not a superset of a, due to the low probability of those differentiating examples (some subtle differences) 93 | def sample_examples_to_differentiate_two_regexes(spec_a, spec_b): 94 | joint_spec = 'and({},not({}))'.format(spec_a, spec_b) 95 | examples = gen_random_examples(joint_spec, num_keep=100, num_gen=100) 96 | examples = [x for x in examples if x[1] == '+'] 97 | return examples 98 | 99 | # sample differentiating examples to differentiate two regexes; each example will satisfy one regex but not the other 100 | def sample_differentiating_examples_usage(): 101 | print(sample_examples_to_differentiate_two_regexes('','')) 102 | print(sample_examples_to_differentiate_two_regexes('concat(repeatatleast(,3),)','concat(repeatatleast(,4),)')) 103 | 104 | if __name__ == '__main__': 105 | sample_regexes_usage() 106 | sample_distinguishing_examples_usage() 107 | sample_random_examples_usage() 108 | sample_differentiating_examples_usage() 109 | --------------------------------------------------------------------------------
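# --- Added usage sketch (not part of the repository) ---
# A short, hedged illustration of the DSL tokenizer and AST builder in
# toolkit/regex_io.py. The specification string below is made up for
# demonstration; run this from the toolkit directory so regex_io imports.
from regex_io import tokenize, build_ast_from_toks

toks = tokenize("concat(repeatatleast(<let>,3),<num>)")
# -> ['concat', '(', 'repeatatleast', '(', '<let>', ',', '3', ')', ',', '<num>', ')']
ast, _ = build_ast_from_toks(toks, 0)
print(ast.logical_form())  # round-trips to concat(repeatatleast(<let>,3),<num>)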