├── .gitmodules ├── README.md ├── TODO.md ├── anaconda-project.yml ├── data ├── DeepCoder_data.tar.bz2 ├── T3_test_data_new.p └── T44_test_new.p ├── data_src ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── dc_program.cpython-36.pyc │ ├── makeDeepcoderData.cpython-36.pyc │ └── makeRobustFillData.cpython-36.pyc ├── dc_program.py ├── makeAlgolispData.py ├── makeDeepcoderData.py ├── makeRegexData.py ├── makeRobustFillData.py └── make_T4_test_data.py ├── eval ├── __init__.py ├── evaluate_algolisp.py ├── evaluate_deepcoder.py ├── evaluate_regex.py └── evaluate_robustfill.py ├── execute_any_gpu.sh ├── execute_gf_gpu.sh ├── execute_gpu.sh ├── execute_k80_gpu.sh ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── deepcoderModel.cpython-36.pyc └── deepcoderModel.py ├── plot ├── __init__.py ├── make_final_plots.py ├── make_frontier_plot.py ├── manipulate_results.py ├── manual_plot.py └── p_solved_hack.py ├── tests ├── __init__.py ├── demo.py ├── test_beam.py ├── test_callCompiled.py ├── test_parse.py └── test_sample_deepcoderIO.py ├── train ├── __init__.py ├── algolisp_train_dc_model.py ├── deepcoder_train_dc_model.py ├── main_supervised_algolisp.py ├── main_supervised_deepcoder.py ├── main_supervised_regex.py ├── main_supervised_robustfill.py ├── robustfill_train_dc_model.py └── sketch_project_rl_regex.py └── util ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── deepcoder_util.cpython-36.pyc └── robustfill_util.cpython-36.pyc ├── algolisp_pypy_util.py ├── algolisp_util.py ├── deepcoder_util.py ├── naps_util.py ├── pypy_util.py ├── rb_pypy_util.py ├── regex_pypy_util.py ├── regex_util.py └── robustfill_util.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pinn"] 2 | path = pinn 3 | url = git@github.com:insperatum/pinn.git 4 | [submodule "pregex"] 5 | path = pregex 6 | url = git@github.com:insperatum/pregex.git 7 | [submodule "vhe"] 8 | path = vhe 9 | 
url = git@github.com:insperatum/vhe.git 10 | [submodule "ec"] 11 | path = ec 12 | url = git@github.com:ellisk42/ec.git 13 | [submodule "program_synthesis"] 14 | path = program_synthesis 15 | url = git@github.com:mtensor/program_synthesis.git 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NEURAL SKETCH PROJECT 2 | This is the code used for the ICML 2019 paper [Learning to Infer Program Sketches](https://arxiv.org/abs/1902.06349). 3 | 4 | 5 | ## Usage: 6 | A user should only have to go into the `train` folder, the `eval` folder, and the `plot` folder. 7 | `train` and `eval` folders have train and eval scripts for each domain. 8 | 9 | 10 | the `train` folder is where the training scripts are. You should run from the top level directory, with the `--pretrain` flag, the first time you run. ex: 11 | ``` 12 | anaconda-project run python main_supervised_deepcoder.py --pretrain 13 | ``` 14 | 15 | 16 | To fully train the SketchAdapt system, first train the synthesizer (referred to as the `dc_model` in the codebase): 17 | ``` 18 | python train/deepcoder_train_dc_model.py 19 | ``` 20 | and pretrain the sketch generator: 21 | ``` 22 | python train/main_supervised_deepcoder.py --pretrain 23 | ``` 24 | Then train the sketch generator: 25 | ``` 26 | python train/main_supervised_deepcoder.py 27 | ``` 28 | Evaluation can be run with: 29 | ``` 30 | python eval/evaluate_deepcoder.py 31 | ``` 32 | 33 | 34 | NB: On the MIT openmind computer cluster, the `*.sh` files are used to schedule jobs. 
I usually do the following: 35 | ``` 36 | sbatch execute_gpu.sh python main_supervised_deepcoder.py --pretrain 37 | ``` 38 | 39 | ## THINGS TO NOTE 40 | - for various reasons, the `ec` subdir had to be added to path, so if you are looking at an import statement and don't see the folder in the top level, it's inside `ec/` 41 | - the naming convention around "deepcoder" and "robustfill" is not great. "dc" is often used to represent the deepcoder domain (e.g. the synthesizer is called the `dc_model`). 42 | - I use the `working-mnye` branch, so it is more up to date, with all submodules, etc. If you can't find something on `master`, look in `working-mnye`. 43 | -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | # Neural Sketch project 2 | - learning to learn sketches for programs 3 | 4 | #TODO: 5 | - [x] Get regexes working 6 | - [ ] figure out correct objective for learning holes -completed ish - no 7 | - [x] start by training holes on pretrain 8 | - [x] have lower likelihood of holes 9 | - [x] make hole probability vary with depth 10 | - [x] fuss with correct way to generate sketches 11 | - [ ] Use more sophisticated nn architectures (perhaps only after getting RobustFill to work) 12 | - [x] Try other domains?? 13 | 14 | 15 | #New TODO: 16 | - [X] use fully supervised version of holes 17 | - [ ] more sophisticated nn architectures 18 | 19 | 20 | 21 | #TODO for fully supervised (`at main_supervised.py`): 22 | - [X] implement first pass at make_holey_supervised 23 | - [X] test first pass at make_holey_supervised 24 | - [X] have some sort of enumeration 25 | - [ ] slightly dynamic batch bullshit so that i don't only take the top 1 sketch, and softmax them together as luke suggested 26 | - [ ] get rid of index in line where sketches are created 27 | - [ ] make scores be pytorch tensor 28 | - [ ] perhaps convert to EC domain for enumeration?? - actually not too hard ... 
29 | - [ ] Kevin's continuation representation (see notes) 30 | 31 | #TODO for PROBPROG 2018 submission: 32 | - [ ] make figures: 33 | - [X] One graphic of model for model section 34 | - [X] One example of data 35 | - [X] One example of model outputs or something 36 | - [ ] One results graph, if i can swing it? 37 | - [X] Use regex figure from nampi submission? 38 | - [ ] possibly parallelize for evaluation (or not) 39 | - [ ] write paper 40 | - [X] show examples and stuff 41 | - [ ] explain technical process 42 | - [ ] write intro 43 | - [ ] remember to show 44 | 45 | 46 | 47 | #TODO for ICLR submission 48 | - [ ] refactor/generalize evaluation 49 | - [ ] beam search/Best first search 50 | - [ ] multiple holes (using EC language) 51 | - [X] build up Syntax-checker LSTM 52 | 53 | 54 | #TODO for DEEPCODER for ICLR: 55 | - [X] dataset/how to train 56 | - [X] primitive set for deepcoder 57 | - [X] implement the program.flatten() method for nns (super easy) 58 | - [X] implement parsing from string (not bad either, use stack machine ) 59 | - [X] implement training (pretrain and makeholey) 60 | - [X] implement evaluation 61 | - [ ] test out reversing nn implementation 62 | 63 | #ACTUAL TODO for DEEPCODER ICLR 64 | Training: 65 | - [X] DSL weights (email dude) - x 66 | - [X] generating train/test data efficiently 67 | - [X] offline dataset generation 68 | - [X] constraint based data gen 69 | - [ ] pypy for data gen 70 | - [X] modify mutation code so no left_application is a hole 71 | - [X] adding HOLE to parser when no left_application is a hole 72 | - [X] Compare no left_application is a hole to the full case where everything can be a hole 73 | - [X] write parser for full case 74 | - [X] add HOLE to nn 75 | - [X] dealing with different request types 76 | - [ ] multiple holes (modify mutator) 77 | - [X] deepcoder recognition model in the loop - half completed 78 | - [X] simple deepcoder baseline 79 | - [ ] make syntaxrobustfill have same output signature as regular 
robustfill 80 | - [X] limit depth of programs 81 | - [X] offline dataset collection and training 82 | - [X] generation 83 | - [X] training 84 | - [X] with sketch stuff 85 | - [X] filtering out some dumb programs (ex lambda $0) 86 | - [X] use actual deepcoder data, write converter 87 | - [X] deal with issue of different types and IO effectively in a reasonable manner 88 | - [X] incorporate constraint based stuff from Marc 89 | 90 | 91 | Evaluation: 92 | - [X] parsing w/Holes - sorta 93 | - [ ] beam search 94 | - [X] figure out correct evaluation scheme 95 | - [ ] parallelization? (dealing with eval speed) 96 | - [X] good test set 97 | - [X] validation set 98 | - [ ] test multiple holes training/no_left_application vs other case 99 | - [ ] using speed 100 | 101 | Overall: 102 | - [ ] run training and evaluation (on val set) together for multple experiments to find best model 103 | - [X] refactor everything (a model Class, perhaps??) 104 | 105 | Tweaking: 106 | - [X] tweak topk - did it with a temperature param, seemed to work well 107 | - [ ] neural network tweaks for correct output format (deal with types and such) - I just fudged it 108 | - [ ] 109 | 110 | #TODO for refactoring: 111 | - [X] make one main (domain agnostic) training script/file - decided against 112 | - [X] make one main (domain agnostic) evaluation script/file - decided against 113 | - [X] figure out the correct class structure to make work easier to extend and tweak. - decided against 114 | 115 | 116 | #TODO for NAPS: 117 | - [ ] read their code 118 | - [ ] understand how to extend when needed for enumeration stuff 119 | - [ ] make validation set? 120 | - [ ] EMAIL THEM FOR ADDITIONAL QUESTIONS! TRY TO KNOW WHAT TO ASK FOR BY EOD FRIDAY. (IDEAL) 121 | - [ ] make choices such as types of holes, etc. 122 | - [ ] think about how to enumerate, etc. 123 | 124 | #TODO for TPUs: 125 | - [ ] Graph networks? 
126 | - [ ] TPU stuff 127 | - [ ] read brian's stuff (https://github.com/tensorflow/minigo/blob/master/dual_net.py) 128 | 129 | 130 | #FRIDAY TODO: 131 | - [X] loading dataset 132 | - [X] evaluation code, apart from IO concerns 133 | 134 | 135 | 136 | #OCTOBER CLEAN UP 137 | - [X] switch to hierarchical file structure 138 | - [X] add EC as submodule or something 139 | - [ ] fix 'alternate' bug in evaluate code 140 | - [ ] eval script 141 | - [ ] loader scripts? 142 | - [ ] possibly find better names for saved things 143 | - [ ] remove all magic values 144 | - [ ] deal with silly sh scripts 145 | - [ ] fix readme for other users 146 | - [ ] run those other tests 147 | - [ ] perhaps redesign results stuff 148 | - [ ] make sure pypy stuff still works 149 | - [ ] make sure saved models work 150 | - [ ] figure out what needs to be abstracted, and abstract 151 | 152 | folders to comb through for hierarchical struct: 153 | - [X] train 154 | - [X] eval 155 | - [X] tests 156 | - [X] data_src 157 | - [X] models 158 | - [X] plot 159 | - [X] utils 160 | 161 | - [ ] run dc with smaller train 4 split 162 | 163 | 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /anaconda-project.yml: -------------------------------------------------------------------------------- 1 | # This is an Anaconda project file. 2 | # 3 | # Here you can describe your project and how to run it. 4 | # Use `anaconda-project run` to run the project. 5 | # The file is in YAML format, please see http://www.yaml.org/start.html for more. 6 | # 7 | 8 | # 9 | # Set the 'name' key to name your project 10 | # 11 | name: neural_sketch 12 | # 13 | # Set the 'icon' key to give your project an icon 14 | # 15 | icon: 16 | # 17 | # Set a one-sentence-or-so 'description' key with project details 18 | # 19 | description: 20 | # 21 | # In the commands section, list your runnable scripts, notebooks, and other code. 22 | # Use `anaconda-project add-command` to add commands. 
23 | # 24 | commands: {} 25 | # 26 | # In the variables section, list any environment variables your code depends on. 27 | # Use `anaconda-project add-variable` to add variables. 28 | # 29 | variables: {} 30 | # 31 | # In the services section, list any services that should be 32 | # available before your code runs. 33 | # Use `anaconda-project add-service` to add services. 34 | # 35 | services: {} 36 | # 37 | # In the downloads section, list any URLs to download to local files 38 | # before your code runs. 39 | # Use `anaconda-project add-download` to add downloads. 40 | # 41 | downloads: {} 42 | # 43 | # In the packages section, list any packages that must be installed 44 | # before your code runs. 45 | # Use `anaconda-project add-packages` to add packages. 46 | # 47 | packages: 48 | - pytorch 49 | - torchvision 50 | - scipy 51 | - matplotlib 52 | 53 | # 54 | # In the channels section, list any Conda channel URLs to be searched 55 | # for packages. 56 | # 57 | # For example, 58 | # 59 | # channels: 60 | # - mychannel 61 | # 62 | channels: [] 63 | # 64 | # In the platforms section, list platforms the project should work on 65 | # Examples: "linux-64", "osx-64", "win-64" 66 | # Use `anaconda-project add-platforms` to add platforms. 67 | # 68 | platforms: 69 | - linux-64 70 | - osx-64 71 | - win-64 72 | # 73 | # You can define multiple, named environment specs. 74 | # Each inherits any global packages or channels, 75 | # but can have its own unique ones also. 76 | # Use `anaconda-project add-env-spec` to add environment specs. 
77 | # 78 | env_specs: 79 | default: 80 | description: Default environment spec for running commands 81 | packages: [] 82 | channels: [] 83 | platforms: [] 84 | -------------------------------------------------------------------------------- /data/DeepCoder_data.tar.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/data/DeepCoder_data.tar.bz2 -------------------------------------------------------------------------------- /data/T3_test_data_new.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/data/T3_test_data_new.p -------------------------------------------------------------------------------- /data/T44_test_new.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/data/T44_test_new.p -------------------------------------------------------------------------------- /data_src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/data_src/__init__.py -------------------------------------------------------------------------------- /data_src/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/data_src/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data_src/__pycache__/dc_program.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/data_src/__pycache__/dc_program.cpython-36.pyc -------------------------------------------------------------------------------- /data_src/__pycache__/makeDeepcoderData.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/data_src/__pycache__/makeDeepcoderData.cpython-36.pyc -------------------------------------------------------------------------------- /data_src/__pycache__/makeRobustFillData.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/data_src/__pycache__/makeRobustFillData.cpython-36.pyc -------------------------------------------------------------------------------- /data_src/dc_program.py: -------------------------------------------------------------------------------- 1 | #import sys 2 | #import os 3 | #sys.path.append(os.path.abspath('./')) 4 | 5 | import numpy as np 6 | import os 7 | import sys 8 | 9 | from collections import namedtuple, defaultdict 10 | from math import ceil, sqrt 11 | 12 | 13 | Function = namedtuple('Function', ['src', 'sig', 'fun', 'bounds']) 14 | Program = namedtuple('Program', ['src', 'ins', 'out', 'fun', 'bounds']) 15 | 16 | 17 | # HELPER FUNCTIONS 18 | def type_to_string(t): 19 | if t == int: 20 | return 'int' 21 | if t == [int]: 22 | return '[int]' 23 | if t == bool: 24 | return 'bool' 25 | if t == [bool]: 26 | return '[bool]' 27 | raise ValueError('Type %s cannot be converted to string.' 
% t) 28 | 29 | def scanl1(f, xs): 30 | if len(xs) > 0: 31 | r = xs[0] 32 | for i in range(len(xs)): 33 | if i > 0: 34 | r = f.fun(r, xs[i]) 35 | yield r 36 | 37 | def SQR_bounds(A, B): 38 | l = max(0, A) # inclusive lower bound 39 | u = B - 1 # inclusive upper bound 40 | if l > u: 41 | return [(0, 0)] 42 | # now 0 <= l <= u 43 | # ceil(sqrt(l)) 44 | # Assume that if anything is valid then 0 is valid 45 | return [(-int(sqrt(u)), ceil(sqrt(u+1)))] 46 | 47 | def MUL_bounds(A, B): 48 | return SQR_bounds(0, min(-(A+1), B)) 49 | 50 | def scanl1_bounds(l, A, B, L): 51 | if l.src == '+' or l.src == '-': 52 | return [(int(A/L)+1, int(B/L))] 53 | elif l.src == '*': 54 | return [(int((max(0, A)+1) ** (1.0 / L)), int((max(0, B)) ** (1.0 / L)))] 55 | elif l.src == 'MIN' or l.src == 'MAX': 56 | return [(A, B)] 57 | else: 58 | raise Exception('Unsupported SCANL1 lambda, cannot compute valid input bounds.') 59 | 60 | # LINQ LANGUAGE 61 | def get_language(V): 62 | Null = V 63 | lambdas = [ 64 | Function('IDT', (int, int), lambda i: i, lambda A_B: [(A_B[0], A_B[1])]), 65 | 66 | Function('INC', (int, int), lambda i: i+1, lambda A_B10: [(A_B10[0], A_B10[1]-1)]), 67 | Function('DEC', (int, int), lambda i: i-1, lambda A_B11: [(A_B11[0]+1, A_B11[1])]), 68 | Function('SHL', (int, int), lambda i: i*2, lambda A_B12: [(int((A_B12[0]+1)/2), int(A_B12[1]/2))]), 69 | Function('SHR', (int, int), lambda i: int(float(i)/2), lambda A_B13: [(2*A_B13[0], 2*A_B13[1])]), 70 | Function('doNEG', (int, int), lambda i: -i, lambda A_B14: [(-A_B14[1]+1, -A_B14[0]+1)]), 71 | Function('MUL3', (int, int), lambda i: i*3, lambda A_B15: [(int((A_B15[0]+2)/3), int(A_B15[1]/3))]), 72 | Function('DIV3', (int, int), lambda i: int(float(i)/3), lambda A_B16: [(A_B16[0], A_B16[1])]), 73 | 74 | Function('MUL4', (int, int), lambda i: i*4, lambda A_B17: [(int((A_B17[0]+3)/4), int(A_B17[1]/4))]), 75 | Function('DIV4', (int, int), lambda i: int(float(i)/4), lambda A_B18: [(A_B18[0], A_B18[1])]), 76 | Function('SQR', (int, 
int), lambda i: i*i, lambda A_B19: SQR_bounds(A_B19[0], A_B19[1])), 77 | #Function('SQRT', (int, int), lambda i: int(sqrt(i)), lambda (A, B): [(max(0, A*A), B*B)]), 78 | 79 | Function('isPOS', (int, bool), lambda i: i > 0, lambda A_B20: [(A_B20[0], A_B20[1])]), 80 | Function('isNEG', (int, bool), lambda i: i < 0, lambda A_B21: [(A_B21[0], A_B21[1])]), 81 | Function('isODD', (int, bool), lambda i: i % 2 == 1, lambda A_B22: [(A_B22[0], A_B22[1])]), 82 | Function('isEVEN', (int, bool), lambda i: i % 2 == 0, lambda A_B23: [(A_B23[0], A_B23[1])]), 83 | 84 | Function('+', (int, int, int), lambda i, j: i+j, lambda A_B24: [(int(A_B24[0]/2)+1, int(A_B24[1]/2))]), 85 | Function('-', (int, int, int), lambda i, j: i-j, lambda A_B25: [(int(A_B25[0]/2)+1, int(A_B25[1]/2))]), 86 | Function('*', (int, int, int), lambda i, j: i*j, lambda A_B26: MUL_bounds(A_B26[0], A_B26[1])), 87 | Function('MIN', (int, int, int), lambda i, j: min(i, j), lambda A_B27: [(A_B27[0], A_B27[1])]), 88 | Function('MAX', (int, int, int), lambda i, j: max(i, j), lambda A_B28: [(A_B28[0], A_B28[1])]), 89 | ] 90 | 91 | LINQ = [ 92 | Function('REVERSE', ([int], [int]), lambda xs: list(reversed(xs)), lambda A_B_L: [(A_B_L[0], A_B_L[1])]), 93 | Function('SORT', ([int], [int]), lambda xs: sorted(xs), lambda A_B_L1: [(A_B_L1[0], A_B_L1[1])]), 94 | Function('TAKE', (int, [int], [int]), lambda n, xs: xs[:n], lambda A_B_L2: [(0,A_B_L2[2]), (A_B_L2[0], A_B_L2[1])]), 95 | Function('DROP', (int, [int], [int]), lambda n, xs: xs[n:], lambda A_B_L3: [(0,A_B_L3[2]), (A_B_L3[0], A_B_L3[1])]), 96 | Function('ACCESS', (int, [int], int), lambda n, xs: xs[n] if n>=0 and len(xs)>n else Null, lambda A_B_L4: [(0,A_B_L4[2]), (A_B_L4[0], A_B_L4[1])]), 97 | Function('HEAD', ([int], int), lambda xs: xs[0] if len(xs)>0 else Null, lambda A_B_L5: [(A_B_L5[0], A_B_L5[1])]), 98 | Function('LAST', ([int], int), lambda xs: xs[-1] if len(xs)>0 else Null, lambda A_B_L6: [(A_B_L6[0], A_B_L6[1])]), 99 | Function('MINIMUM', ([int], int), lambda 
xs: min(xs) if len(xs)>0 else Null, lambda A_B_L7: [(A_B_L7[0], A_B_L7[1])]), 100 | Function('MAXIMUM', ([int], int), lambda xs: max(xs) if len(xs)>0 else Null, lambda A_B_L8: [(A_B_L8[0], A_B_L8[1])]), 101 | Function('SUM', ([int], int), lambda xs: sum(xs), lambda A_B_L9: [(int(A_B_L9[0]/A_B_L9[2])+1, int(A_B_L9[1]/A_B_L9[2]))]), 102 | ] + \ 103 | [Function( 104 | 'MAP ' + l.src, 105 | ([int], [int]), 106 | lambda xs, l=l: list(map(l.fun, xs)), 107 | lambda A_B_L, l=l: l.bounds((A_B_L[0], A_B_L[1])) 108 | ) for l in lambdas if l.sig==(int, int)] + \ 109 | [Function( 110 | 'FILTER ' + l.src, 111 | ([int], [int]), 112 | lambda xs, l=l: list(filter(l.fun, xs)), 113 | lambda A_B_L, l=l: [(A_B_L[0], A_B_L[1])], 114 | ) for l in lambdas if l.sig==(int, bool)] + \ 115 | [Function( 116 | 'COUNT ' + l.src, 117 | ([int], int), 118 | lambda xs, l=l: len(list(filter(l.fun, xs))), 119 | lambda A_B_L, l=l: [(-V, V)], 120 | ) for l in lambdas if l.sig==(int, bool)] + \ 121 | [Function( 122 | 'ZIPWITH ' + l.src, 123 | ([int], [int], [int]), 124 | lambda xs, ys, l=l: [l.fun(x, y) for (x, y) in zip(xs, ys)], 125 | lambda A_B_L, l=l: l.bounds((A_B_L[0], A_B_L[1])) + l.bounds((A_B_L[0], A_B_L[1])), 126 | ) for l in lambdas if l.sig==(int, int, int)] + \ 127 | [Function( 128 | 'SCANL1 ' + l.src, 129 | ([int], [int]), 130 | lambda xs, l=l: list(scanl1(l, xs)), 131 | lambda A_B_L, l=l: scanl1_bounds(l, A_B_L[0], A_B_L[1], A_B_L[2]), 132 | ) for l in lambdas if l.sig==(int, int, int)] 133 | 134 | return LINQ, lambdas 135 | 136 | def compile(source_code, V, L, min_input_range_length=0, verbose=False): 137 | """ Taken in a program source code, the integer range V and the tape lengths L, 138 | and produces a Program. 139 | If L is None then input constraints are not computed. 
140 | """ 141 | 142 | # Source code parsing into intermediate representation 143 | LINQ, _ = get_language(V) 144 | LINQ_names = [l.src for l in LINQ] 145 | 146 | input_types = [] 147 | types = [] 148 | functions = [] 149 | pointers = [] 150 | for line in source_code.split('\n'): 151 | instruction = line[5:] 152 | if instruction in ['int', '[int]']: 153 | input_types.append(eval(instruction)) 154 | types.append(eval(instruction)) 155 | functions.append(None) 156 | pointers.append(None) 157 | else: 158 | split = instruction.split(' ') 159 | command = split[0] 160 | args = split[1:] 161 | # Handle lambda 162 | if len(split[1]) > 1 or split[1] < 'a' or split[1] > 'z': 163 | command += ' ' + split[1] 164 | args = split[2:] 165 | f = LINQ[LINQ_names.index(command)] 166 | assert len(f.sig) - 1 == len(args) 167 | ps = [ord(arg) - ord('a') for arg in args] 168 | types.append(f.sig[-1]) 169 | functions.append(f) 170 | pointers.append(ps) 171 | assert [types[p] == t for p, t in zip(ps, f.sig)] 172 | input_length = len(input_types) 173 | program_length = len(types) 174 | 175 | # Validate program by propagating input constraints and check all registers are useful 176 | limits = [(-V, V)]*program_length 177 | if L is not None: 178 | for t in range(program_length-1, -1, -1): 179 | if t >= input_length: 180 | lim_l, lim_u = limits[t] 181 | new_lims = functions[t].bounds((lim_l, lim_u, L)) 182 | num_args = len(functions[t].sig) - 1 183 | for a in range(num_args): 184 | p = pointers[t][a] 185 | limits[pointers[t][a]] = (max(limits[p][0], new_lims[a][0]), 186 | min(limits[p][1], new_lims[a][1])) 187 | #print('t=%d: New limit for %d is %s' % (t, p, limits[pointers[t][a]])) 188 | elif min_input_range_length >= limits[t][1] - limits[t][0]: 189 | if verbose: print(('Program with no valid inputs: %s' % source_code)) 190 | return None 191 | 192 | # for t in xrange(input_length, program_length): 193 | # print('%s (%s)' % (functions[t].src, ' '.join([chr(ord('a') + p) for p in 
pointers[t]]))) 194 | 195 | # Construct executor 196 | my_input_types = list(input_types) 197 | my_types = list(types) 198 | my_functions = list(functions) 199 | my_pointers = list(pointers) 200 | my_program_length = program_length 201 | def program_executor(args): 202 | # print '--->' 203 | # for t in xrange(input_length, my_program_length): 204 | # print('%s <- %s (%s)' % (chr(ord('a') + t), my_functions[t].src, ' '.join([chr(ord('a') + p) for p in my_pointers[t]]))) 205 | 206 | assert len(args) == len(my_input_types) 207 | registers = [None]*my_program_length 208 | for t in range(len(args)): 209 | registers[t] = args[t] 210 | for t in range(len(args), my_program_length): 211 | registers[t] = my_functions[t].fun(*[registers[p] for p in my_pointers[t]]) 212 | return registers[-1] 213 | 214 | return Program( 215 | source_code, 216 | input_types, 217 | types[-1], 218 | program_executor, 219 | limits[:input_length] 220 | ) 221 | 222 | def generate_IO_examples(program, N, L, V): 223 | """ Given a programs, randomly generates N IO examples. 224 | using the specified length L for the input arrays. 
""" 225 | input_types = program.ins 226 | input_nargs = len(input_types) 227 | 228 | # Generate N input-output pairs 229 | IO = [] 230 | for _ in range(N): 231 | input_value = [None]*input_nargs 232 | for a in range(input_nargs): 233 | minv, maxv = program.bounds[a] 234 | if input_types[a] == int: 235 | input_value[a] = np.random.randint(minv, maxv) 236 | elif input_types[a] == [int]: 237 | input_value[a] = list(np.random.randint(minv, maxv, size=L)) 238 | else: 239 | raise Exception("Unsupported input type " + input_types[a] + " for random input generation") 240 | output_value = program.fun(input_value) 241 | IO.append((input_value, output_value)) 242 | assert (program.out == int and output_value <= V) or (program.out == [int] and len(output_value) == 0) or (program.out == [int] and max(output_value) <= V) 243 | return IO 244 | 245 | 246 | if __name__ == '__main__': 247 | import time 248 | source = sys.argv[1] 249 | t = time.time() 250 | source = source.replace(' | ', '\n') 251 | program = compile(source, V=512, L=10) 252 | samples = generate_IO_examples(program, N=5, L=10, V=512) 253 | print(("time:", time.time()-t)) 254 | print(program) 255 | print(samples) 256 | 257 | 258 | -------------------------------------------------------------------------------- /data_src/makeAlgolispData.py: -------------------------------------------------------------------------------- 1 | #generate deepcoder data 2 | import sys 3 | import os 4 | sys.path.append(os.path.abspath('./')) 5 | sys.path.append(os.path.abspath('./ec')) 6 | 7 | import pickle 8 | import time 9 | from collections import namedtuple 10 | #Function = namedtuple('Function', ['src', 'sig', 'fun', 'bounds']) 11 | 12 | from grammar import Grammar, NoCandidates 13 | from utilities import flatten 14 | 15 | from algolispPrimitives import tsymbol, algolispProductions, algolispPrimitives 16 | from program import Application, Hole, Primitive, Index, Abstraction, ParseFailure 17 | import math 18 | import random 19 | from 
type import Context, arrow, tint, tlist, UnificationFailure 20 | 21 | from util.algolisp_util import convert_IO, tree_to_prog, make_holey_algolisp, AlgolispHole, tree_to_seq, seq_to_tree, tree_depth, check_subtree, tokenize_for_dc 22 | 23 | from itertools import zip_longest, chain 24 | from functools import reduce 25 | import torch 26 | 27 | from program_synthesis.algolisp.dataset import dataset 28 | from program_synthesis.algolisp import arguments 29 | 30 | from util.algolisp_pypy_util import test_program_on_IO 31 | 32 | from program_synthesis.algolisp.dataset import executor 33 | #I want a list of namedTuples 34 | 35 | #Datum = namedtuple('Datum', ['tp', 'p', 'pseq', 'IO', 'sketch', 'sketchseq']) 36 | 37 | basegrammar = Grammar.fromProductions(algolispProductions()) # Fix this 38 | 39 | with open('basegrammar.p', 'rb') as file: 40 | basegrammar = pickle.load(file) 41 | 42 | #reweighted basegrammar: 43 | # class FrontierEntry(object): 44 | # def __init__( 45 | # self, 46 | # program, 47 | # _=None, 48 | # logPrior=None, 49 | # logLikelihood=None, 50 | # logPosterior=None): 51 | # self.logPosterior = logPrior + logLikelihood if logPosterior is None else logPosterior 52 | # self.program = program 53 | # self.logPrior = logPrior 54 | # self.logLikelihood = logLikelihood 55 | 56 | # def __repr__(self): 57 | # return "FrontierEntry(program={self.program}, logPrior={self.logPrior}, logLikelihood={self.logLikelihood}".format( 58 | # self=self) 59 | 60 | 61 | # class Frontier(object): 62 | # def __init__(self, frontier, task): 63 | # self.entries = frontier 64 | # self.task = task 65 | 66 | 67 | 68 | def grouper(iterable, n, fillvalue=None): 69 | "Collect data into fixed-length chunks or blocks" 70 | # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx" 71 | args = [iter(iterable)] * n 72 | return zip_longest(*args, fillvalue=fillvalue) 73 | 74 | #redo this ... 
class Datum():
    """A single AlgoLisp training/eval example.

    Attributes:
        tp: program type (currently always ``tsymbol`` — proper typing is
            unimplemented, see ``convert_datum``).
        p: ground-truth program object built by ``tree_to_prog``.
        pseq: flat token sequence of the program (``ex.code_sequence``).
        IO: input/output test cases produced by ``convert_IO(ex.tests)``.
        sketch: holey version of ``p``, or None when sketches not computed.
        sketchseq: flat token sequence of the sketch, or None.
        reward: scalar returned by ``make_holey_algolisp``, or None.
        sketchprob: scalar returned by ``make_holey_algolisp``, or None.
        spec: the natural-language specification (``ex.text``).
        schema_args: argument schema of the task (``ex.schema.args``).
    """

    def __init__(self, tp, p, pseq, IO, sketch, sketchseq, reward, sketchprob, spec, schema_args):
        self.tp = tp
        self.p = p
        self.pseq = pseq
        self.IO = IO
        self.sketch = sketch
        self.sketchseq = sketchseq
        self.reward = reward
        self.sketchprob = sketchprob
        self.spec = spec
        self.schema_args = schema_args

    def __hash__(self):
        # Fold the spec tree (stopping at string leaves via `abort`) together
        # with the program and sketch, so identical examples hash equal.
        return reduce(lambda a, b: hash(a + hash(b)),
                      flatten(self.spec, abort=lambda x: type(x) is str),
                      0) + hash(self.p) + hash(self.sketch)


Batch = namedtuple('Batch', ['tps', 'ps', 'pseqs', 'IOs', 'sketchs', 'sketchseqs', 'rewards', 'sketchprobs', 'specs', 'schema_args'])


def convert_datum(ex,
                  compute_sketches=False,
                  top_k_sketches=20,
                  inv_temp=1.0,
                  reward_fn=None,
                  sample_fn=None,
                  dc_model=None,
                  improved_dc_model=True,
                  use_timeout=False,
                  proper_type=False,
                  only_passable=False,
                  filter_depth=None,
                  nHoles=4,
                  exclude=None,
                  include_only=None,
                  #include_all=None,
                  use_fixed_seed=False,
                  limit_IO_size=None,
                  use_IO=False,
                  rng=None):
    """Convert one raw dataset example ``ex`` into a ``Datum``.

    Returns None when the example is filtered out (wrong depth, excluded
    subexpression, oversized IO encoding, or failing its own tests when
    ``only_passable`` is set).

    ``exclude`` / ``include_only`` are lists of expressions matched as
    subtrees of ``ex.code_tree``.
    """
    if filter_depth:
        # filter_depth is an iterable of the tree depths that are allowed.
        if tree_depth(ex.code_tree) not in filter_depth:
            return None

    # BUGFIX: the original wrapped each `any(...)` below in a redundant
    # `for subex in ...` loop that shadowed `subex` and re-ran the identical
    # whole-list check len(list) times.  A single any() is equivalent.
    if exclude and any(check_subtree(ex.code_tree, subex) for subex in exclude):
        return None
    if include_only and not any(check_subtree(ex.code_tree, subex) for subex in include_only):
        return None

    # find IO
    IO = convert_IO(ex.tests)  # TODO

    if limit_IO_size:
        # Drop examples whose dc-encoded IO would exceed the model's limit.
        # (Removed a stray unused `tokenized = []` accumulator.)
        for io in IO:
            if len(tokenize_for_dc(io, digit_enc=True)) > limit_IO_size:
                return None

    schema_args = ex.schema.args
    if only_passable:
        # Keep only examples whose code tree actually passes its own tests.
        executor_ = executor.LispExecutor()
        hit_tree = test_program_on_IO(ex.code_tree, IO, schema_args, executor_)
        if not hit_tree:
            return None

    # find tp
    if proper_type:
        assert False  # arrow(some_stuff, tsymbol) — unimplemented
        out_tp = convert_tp(ex.schema.return_type)  # BUGFIX: was `x.schema` (undefined name)
    else:
        tp = tsymbol

    # find program token sequence
    pseq = tuple(ex.code_sequence)

    # find program p
    if proper_type:
        assert False  # unimplemented
    else:
        p = tree_to_prog(ex.code_tree)  # TODO: use correct grammar

    spec = ex.text

    if compute_sketches:
        # The dc model conditions either on the NL spec or on the IO examples.
        dc_input = spec if not use_IO else IO
        sketch, reward, sketchprob = make_holey_algolisp(p,
                                                         top_k_sketches,
                                                         tp,
                                                         basegrammar=basegrammar,
                                                         dcModel=dc_model,
                                                         improved_dc_model=improved_dc_model,
                                                         inv_temp=inv_temp,
                                                         reward_fn=reward_fn,
                                                         sample_fn=sample_fn,
                                                         use_timeout=use_timeout,
                                                         return_obj=AlgolispHole,
                                                         dc_input=dc_input,
                                                         nHoles=nHoles,
                                                         use_fixed_seed=use_fixed_seed,
                                                         rng=rng)  # TODO
        # find sketchseq
        sketchseq = tuple(tree_to_seq(sketch.evaluate([])))
    else:
        sketch, sketchseq, reward, sketchprob = None, None, None, None

    return Datum(tp, p, pseq, IO, sketch, sketchseq, reward, sketchprob, spec, schema_args)


def batchloader(data_file,
                batchsize=100,
                compute_sketches=False,
                dc_model=None,
                improved_dc_model=True,
                shuffle=True,
                top_k_sketches=20,
                inv_temp=1.0,
                reward_fn=None,
                sample_fn=None,
                use_timeout=False,
                only_passable=False,
                filter_depth=None,
                nHoles=1,
                limit_data=False,
                use_fixed_seed=False,
                use_dataset_len=False,
                exclude=None,
                include_only=None,
                include_all=None,
                limit_IO_size=None,
                use_IO=False,
                seed=42):
    """Yield ``Datum`` objects (batchsize == 1) or ``Batch`` tuples.

    ``data_file`` selects the AlgoLisp split: 'train', 'dev' or 'eval'.
    Note: ``exclude`` and ``include_only`` take lists of expressions!!!
    don't get confused.

    ``shuffle`` and ``include_all`` are currently unused; kept for
    call-site compatibility.
    """
    mode = 'train' if data_file == 'train' else 'eval'
    parser = arguments.get_arg_parser('Training AlgoLisp', mode)
    # parse_args([]) takes only the default values, allowing us to use the
    # top-level parser for our own flags in main files.
    args = parser.parse_args([])
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    args.batch_size = batchsize  # doesn't even matter now

    if data_file == 'train':
        NearDataset, _ = dataset.get_dataset(args)
        dataset_len = 79214  # hard-coded size of the AlgoLisp train split
    elif data_file == 'dev':
        _, NearDataset = dataset.get_dataset(args)
        assert not use_dataset_len
    elif data_file == 'eval':
        print("WARNING: right now 'eval' gives correct 'eval' test set")
        NearDataset = dataset.get_eval_dataset(args)  # TODO:
        assert not use_dataset_len
    else:
        assert False

    seeded_random = random.Random(seed)  # so that state is shared

    if use_dataset_len:
        # Pre-draw a fixed random index subset of the train split.
        inc_list = seeded_random.sample(range(dataset_len), use_dataset_len)
        counter = 0

    def remove_datum():
        # Subsampling policy: either keep exactly the pre-drawn index set
        # (use_dataset_len) or keep each example independently with
        # probability `limit_data`; otherwise keep everything.
        if use_dataset_len:
            nonlocal counter
            rval = counter not in inc_list
            counter += 1
            return rval
        if limit_data:
            return not (seeded_random.random() < limit_data)
        return False

    data = (convert_datum(ex,
                          compute_sketches=compute_sketches,
                          top_k_sketches=top_k_sketches,
                          inv_temp=inv_temp,
                          reward_fn=reward_fn,
                          sample_fn=sample_fn,
                          dc_model=dc_model,
                          improved_dc_model=improved_dc_model,
                          use_timeout=use_timeout,
                          proper_type=False,
                          only_passable=only_passable,
                          filter_depth=filter_depth,
                          nHoles=nHoles,
                          exclude=exclude,
                          include_only=include_only,
                          use_fixed_seed=use_fixed_seed,
                          limit_IO_size=limit_IO_size,
                          use_IO=use_IO,
                          rng=seeded_random if use_fixed_seed else None)
            for batch in NearDataset for ex in batch if not remove_datum())  # assumes each batch has one ex
    data = (x for x in data if x is not None)

    if batchsize == 1:
        yield from data
    else:
        for group in grouper(data, batchsize):
            # Transpose the group of Datum objects into parallel columns,
            # dropping the grouper's None fill values.
            tps, ps, pseqs, IOs, sketchs, sketchseqs, rewards, sketchprobs, specs, schema_args = zip(*[
                (datum.tp, datum.p, datum.pseq, datum.IO, datum.sketch, datum.sketchseq,
                 datum.reward, datum.sketchprob, datum.spec, datum.schema_args)
                for datum in group if datum is not None])
            yield Batch(tps, ps, pseqs, IOs, sketchs, sketchseqs,
                        torch.FloatTensor(rewards) if any(r is not None for r in rewards) else None,
                        torch.FloatTensor(sketchprobs) if any(s is not None for s in sketchprobs) else None,
                        specs, schema_args)


from frontier import Frontier, FrontierEntry
from task import Task


def reweightbasegrammar(basegrammar, pseudoCounts, filter_depth=None, size=None):
    """Re-estimate ``basegrammar`` production weights via one inside-outside
    pass over (up to ``size``) programs from the train split."""
    # BUGFIX: islice was previously imported only under __main__, so calling
    # this function from another module raised NameError.
    from itertools import islice
    frontiers = []
    for datum in islice(batchloader('train',
                                    batchsize=1,
                                    compute_sketches=False,
                                    filter_depth=filter_depth), size):  # TODO
        # Each program becomes a singleton frontier attached to a dummy task.
        frontiers.append(Frontier(
            [FrontierEntry(datum.p,
                           logPrior=basegrammar.logLikelihood(datum.tp, datum.p),
                           logLikelihood=0)],
            Task('dummyName', datum.tp, [])))

    return basegrammar.insideOutside(frontiers, pseudoCounts, iterations=1)


if __name__ == '__main__':
    from itertools import islice
    from collections import Counter

    # Collect the program-length distribution of the eval split.
    c = Counter()
    max_len = 0
    for i, d in enumerate(batchloader('eval',
                                      batchsize=1,
                                      compute_sketches=False,
                                      only_passable=False)):
        c.update([len(d.pseq)])
        max_len = max(len(d.pseq), max_len)

    print("max_len:", max_len)
    print(c)

    # NOTE(review): the assert below deliberately stops the script before the
    # grammar-reweighting step — presumably left from a one-off measurement
    # run; remove it to regenerate basegrammar.p.
    assert False
    print(basegrammar)
    g = reweightbasegrammar(basegrammar, 0.1, filter_depth=None, size=None)
    print(g)

    print('saving')

    with open('basegrammar.p', 'wb') as savefile:
        pickle.dump(g, savefile)

    for i in c:
        print(i, c[i])
-------------------------------------------------------------------------------- /data_src/makeDeepcoderData.py: -------------------------------------------------------------------------------- 1 | #generate deepcoder data 2 | import sys 3 | import os 4 | sys.path.append(os.path.abspath('./')) 5 | sys.path.append(os.path.abspath('./ec')) 6 | 7 | import pickle 8 | from util.deepcoder_util import parseprogram, make_holey_deepcoder 9 | from util.algolisp_util import make_holey_algolisp 10 | from util.deepcoder_util import basegrammar 11 | import time 12 | from collections import namedtuple 13 | #Function = namedtuple('Function', ['src', 'sig', 'fun', 'bounds']) 14 | 15 | from grammar import Grammar, NoCandidates 16 | from utilities import flatten 17 | 18 | from deepcoderPrimitives import deepcoderProductions, flatten_program 19 | from program import Application, Hole, Primitive, Index, Abstraction, ParseFailure 20 | import math 21 | import random 22 | from type import Context, arrow, tint, tlist, UnificationFailure 23 | from data_src.dc_program import generate_IO_examples, compile 24 | from itertools import zip_longest, chain 25 | from functools import reduce 26 | import torch 27 | 28 | #from dc_program import Program as dc_Program 29 | 30 | def make_deepcoder_data(filename, with_holes=False, size=10000000, k=20): 31 | data_list = [] 32 | save_freq = 1000 33 | 34 | t = time.time() 35 | for i in range(size): 36 | inst = getInstance(with_holes=with_holes, k=k) 37 | data_list.append(inst) 38 | 39 | if i%save_freq==0: 40 | #save data 41 | print(f"iteration {i} out of {size}") 42 | print("saving data") 43 | with open(filename, 'wb') as file: 44 | pickle.dump(data_list, file) 45 | print(f"time since last {save_freq}: {time.time()-t}") 46 | t = time.time() 47 | 48 | #I want a list of namedTuples 49 | 50 | #Datum = namedtuple('Datum', ['tp', 'p', 'pseq', 'IO', 'sketch', 'sketchseq']) 51 | 52 | class Datum(): 53 | def __init__(self, tp, p, pseq, IO, sketch, sketchseq, reward, 
sketchprob): 54 | self.tp = tp 55 | self.p = p 56 | self.pseq = pseq 57 | self.IO = IO 58 | self.sketch = sketch 59 | self.sketchseq = sketchseq 60 | self.reward = reward 61 | self.sketchprob = sketchprob 62 | 63 | def __hash__(self): 64 | return reduce(lambda a, b: hash(a + hash(b)), flatten(self.IO), 0) + hash(self.p) + hash(self.sketch) 65 | 66 | Batch = namedtuple('Batch', ['tps', 'ps', 'pseqs', 'IOs', 'sketchs', 'sketchseqs', 'rewards', 'sketchprobs']) 67 | 68 | def convert_dc_program_to_ec(dc_program, tp): 69 | source = dc_program.src 70 | source = source.split('\n') 71 | source = [line.split(' ') for line in source] 72 | #print(source) 73 | num_inputs = len(tp.functionArguments()) 74 | #print(num_inputs) 75 | del source[:num_inputs] 76 | source = [[l for l in line if l != '<-'] for line in source] 77 | last_var = source[-1][0] 78 | prog = source[-1][1:] 79 | del source[-1] 80 | variables = list('abcdefghigklmnop') 81 | del variables[variables.index(last_var):] #check this line 82 | lookup = {variables[i]: ["input_" + str(i)] for i in range(num_inputs)} 83 | for line in source: 84 | lookup[line[0]] = line[1:] 85 | for variable in reversed(variables): 86 | p2 = [] 87 | for x in prog: 88 | if x==variable: 89 | p2 += lookup[variable] 90 | else: 91 | p2.append(x) 92 | prog = p2 93 | return prog 94 | 95 | 96 | def convert_source_to_datum(source, 97 | N=5, 98 | V=512, 99 | L=10, 100 | compute_sketches=False, 101 | top_k_sketches=20, 102 | inv_temp=1.0, 103 | reward_fn=None, 104 | sample_fn=None, 105 | dc_model=None, 106 | use_timeout=False, 107 | improved_dc_model=False, 108 | nHoles=1): 109 | 110 | source = source.replace(' | ', '\n') 111 | dc_program = compile(source, V=V, L=L) 112 | 113 | if dc_program is None: 114 | return None 115 | # find IO 116 | IO = tuple(generate_IO_examples(dc_program, N=N, L=L, V=V)) 117 | 118 | # find tp 119 | ins = [tint if inp == int else tlist(tint) for inp in dc_program.ins] 120 | if dc_program.out == int: 121 | out = tint 122 | 
else: 123 | assert dc_program.out==[int] 124 | out = tlist(tint) 125 | tp = arrow( *(ins+[out]) ) 126 | 127 | # find program p 128 | pseq = tuple(convert_dc_program_to_ec(dc_program, tp)) 129 | 130 | # find pseq 131 | p = parseprogram(pseq, tp) # TODO: use correct grammar, and 132 | 133 | if compute_sketches: 134 | # find sketch 135 | 136 | # grammar = basegrammar if not dc_model else dc_model.infer_grammar(IO) #This line needs to change 137 | # sketch, reward, sketchprob = make_holey_deepcoder(p, 138 | # top_k_sketches, 139 | # grammar, 140 | # tp, 141 | # inv_temp=inv_temp, 142 | # reward_fn=reward_fn, 143 | # sample_fn=sample_fn, 144 | # use_timeout=use_timeout, 145 | # improved_dc_model=improved_dc_model, 146 | # nHoles=nHoles) #TODO 147 | 148 | sketch, reward, sketchprob = make_holey_algolisp(p, 149 | top_k_sketches, 150 | tp, 151 | basegrammar, 152 | dcModel=dc_model, 153 | improved_dc_model=improved_dc_model, 154 | return_obj=Hole, 155 | dc_input=IO, 156 | inv_temp=inv_temp, 157 | reward_fn=reward_fn, 158 | sample_fn=sample_fn, 159 | use_timeout=use_timeout, 160 | nHoles=nHoles, 161 | domain='list') 162 | 163 | # find sketchseq 164 | sketchseq = tuple(flatten_program(sketch)) 165 | else: 166 | sketch, sketchseq, reward, sketchprob = None, None, None, None 167 | 168 | return Datum(tp, p, pseq, IO, sketch, sketchseq, reward, sketchprob) 169 | 170 | 171 | def grouper(iterable, n, fillvalue=None): 172 | "Collect data into fixed-length chunks or blocks" 173 | # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx" 174 | args = [iter(iterable)] * n 175 | return zip_longest(*args, fillvalue=fillvalue) 176 | 177 | def single_batchloader(data_file, 178 | batchsize=100, 179 | N=5, 180 | V=512, 181 | L=10, 182 | compute_sketches=False, 183 | dc_model=None, 184 | shuffle=True, 185 | top_k_sketches=20, 186 | inv_temp=1.0, 187 | reward_fn=None, 188 | sample_fn=None, 189 | use_timeout=False, 190 | improved_dc_model=False, 191 | nHoles=1): 192 | lines = (line.rstrip('\n') for i, line 
in enumerate(open(data_file)) if i != 0) #remove first line 193 | if shuffle: 194 | lines = list(lines) 195 | random.shuffle(lines) 196 | 197 | data = (convert_source_to_datum(line, 198 | N=N, 199 | V=V, 200 | L=L, 201 | compute_sketches=compute_sketches, 202 | dc_model=dc_model, 203 | top_k_sketches=20, 204 | inv_temp=inv_temp, 205 | reward_fn=reward_fn, 206 | sample_fn=sample_fn, 207 | use_timeout=use_timeout, 208 | improved_dc_model=improved_dc_model, 209 | nHoles=nHoles) for line in lines) 210 | data = (x for x in data if x is not None) 211 | 212 | if batchsize==1: 213 | yield from data 214 | else: 215 | grouped_data = grouper(data, batchsize) 216 | for group in grouped_data: 217 | tps, ps, pseqs, IOs, sketchs, sketchseqs, rewards, sketchprobs = zip(*[(datum.tp, datum.p, datum.pseq, datum.IO, datum.sketch, datum.sketchseq, datum.reward, datum.sketchprob) for datum in group if datum is not None]) 218 | yield Batch(tps, ps, pseqs, IOs, sketchs, sketchseqs, torch.FloatTensor(rewards) if any(r is not None for r in rewards) else None, torch.FloatTensor(sketchprobs) if any(s is not None for s in sketchprobs) else None) # check that his works 219 | 220 | def batchloader(data_file_list, 221 | batchsize=100, 222 | N=5, 223 | V=512, 224 | L=10, 225 | compute_sketches=False, 226 | dc_model=None, 227 | shuffle=True, 228 | top_k_sketches=20, 229 | inv_temp=1.0, 230 | reward_fn=None, 231 | sample_fn=None, 232 | use_timeout=False, 233 | #new 234 | improved_dc_model=False, 235 | nHoles=1): 236 | yield from chain(*[single_batchloader(data_file, 237 | batchsize=batchsize, 238 | N=N, 239 | V=V, 240 | L=L, 241 | compute_sketches=compute_sketches, 242 | dc_model=dc_model, 243 | shuffle=shuffle, 244 | top_k_sketches=top_k_sketches, 245 | inv_temp=inv_temp, 246 | reward_fn=reward_fn, 247 | sample_fn=sample_fn, 248 | use_timeout=use_timeout, 249 | improved_dc_model=improved_dc_model, 250 | nHoles=nHoles) for data_file in data_file_list]) 251 | 252 | if __name__=='__main__': 253 | from 
itertools import islice 254 | #convert_source_to_datum("a <- [int] | b <- [int] | c <- ZIPWITH + b a | d <- COUNT isEVEN c | e <- ZIPWITH MAX a c | f <- MAP MUL4 e | g <- TAKE d f") 255 | 256 | #filename = 'data/DeepCoder_data/T2_A2_V512_L10_train_perm.txt' 257 | train_data = 'data/DeepCoder_data/T3_A2_V512_L10_train_perm.txt' 258 | 259 | #test_data = '' 260 | 261 | #lines = (line.rstrip('\n') for i, line in enumerate(open(filename)) if i != 0) #remove first line 262 | 263 | from util.deepcoder_util import grammar 264 | 265 | import models.deepcoderModel as deepcoderModel 266 | dcModel = torch.load("./saved_models/dc_model.p") 267 | 268 | for datum in islice(batchloader([train_data], 269 | batchsize=1, 270 | N=5, 271 | V=128, 272 | L=10, 273 | compute_sketches=True, 274 | top_k_sketches=20, 275 | inv_temp=0.25, 276 | use_timeout=True), 1): 277 | 278 | print("program:", datum.p) 279 | print("sketch: ", datum.sketch) 280 | grammar = dcModel.infer_grammar(datum.IO) 281 | l = grammar.sketchLogLikelihood(datum.tp, datum.p, datum.sketch) 282 | print(l) 283 | 284 | 285 | # 286 | 287 | # p,_,_ = make_holey_deepcoder(datum.p, 50, grammar, datum.tp, inv_temp=1.0, reward_fn=None, sample_fn=None, verbose=True) 288 | # print("SKETCH:",p) 289 | # p,_,_ = make_holey_deepcoder(datum.p, 50, grammar, datum.tp, inv_temp=1.0, reward_fn=None, sample_fn=None, verbose=True, use_timeout=True) 290 | # print("SKETCH:",p) 291 | 292 | # p,_,_ = make_holey_deepcoder(datum.p, 50, grammar, datum.tp, inv_temp=0.5, reward_fn=None, sample_fn=None, verbose=True) 293 | # print("SKETCH:",p) 294 | # p,_,_ = make_holey_deepcoder(datum.p, 50, grammar, datum.tp, inv_temp=0.5, reward_fn=None, sample_fn=None, verbose=True, use_timeout=True) 295 | # print("SKETCH:",p) 296 | 297 | # p,_,_ = make_holey_deepcoder(datum.p, 50, grammar, datum.tp, inv_temp=0.25, reward_fn=None, sample_fn=None, verbose=True) 298 | # print("SKETCH:",p) 299 | # p,_,_ = make_holey_deepcoder(datum.p, 50, grammar, datum.tp, 
inv_temp=0.25, reward_fn=None, sample_fn=None, verbose=True, use_timeout=True) 300 | # print("SKETCH:",p) 301 | 302 | #path = 'data/pretrain_data_v1_alt.p' 303 | #make_deepcoder_data(path, with_holes=True, k=20) 304 | 305 | 306 | 307 | 308 | -------------------------------------------------------------------------------- /data_src/makeRegexData.py: -------------------------------------------------------------------------------- 1 | #generate deepcoder data 2 | import sys 3 | import os 4 | sys.path.append(os.path.abspath('./')) 5 | sys.path.append(os.path.abspath('./ec')) 6 | 7 | import pickle 8 | # TODO 9 | #from util.deepcoder_util import make_holey_deepcoder # this might be enough 10 | from util.regex_util import basegrammar # TODO 11 | from util.regex_util import sample_program, generate_IO_examples, flatten_program, make_holey_regex 12 | from util.robustfill_util import timing # TODO 13 | 14 | import time 15 | from collections import namedtuple 16 | #Function = namedtuple('Function', ['src', 'sig', 'fun', 'bounds']) 17 | 18 | from grammar import Grammar, NoCandidates 19 | from utilities import flatten 20 | 21 | # TODO 22 | from regexPrimitives import concatPrimitives 23 | 24 | from program import Application, Hole, Primitive, Index, Abstraction, ParseFailure 25 | import math 26 | import random 27 | from type import Context, arrow, tint, tlist, UnificationFailure, tcharacter, tpregex 28 | from itertools import zip_longest, chain, repeat, islice 29 | from functools import reduce 30 | import torch 31 | 32 | 33 | class Datum(): 34 | def __init__(self, tp, p, pseq, IO, sketch, sketchseq, reward, sketchprob): 35 | self.tp = tp 36 | self.p = p 37 | self.pseq = pseq 38 | self.IO = IO 39 | self.sketch = sketch 40 | self.sketchseq = sketchseq 41 | self.reward = reward 42 | self.sketchprob = sketchprob 43 | 44 | def __hash__(self): 45 | return reduce(lambda a, b: hash(a + hash(b)), flatten(self.IO, abort=lambda x: type(x) is str), 0) + hash(self.p) + hash(self.sketch) 46 
| 47 | Batch = namedtuple('Batch', ['tps', 'ps', 'pseqs', 'IOs', 'sketchs', 'sketchseqs', 'rewards', 'sketchprobs']) 48 | 49 | 50 | #TODO##### 51 | def sample_datum(g=basegrammar, N=5, compute_sketches=False, top_k_sketches=100, inv_temp=1.0, reward_fn=None, sample_fn=None, dc_model=None, use_timeout=False, continuation=False): 52 | 53 | # find tp 54 | if continuation: 55 | tp = arrow(tpregex, tpregex) 56 | else: 57 | tp = tpregex 58 | 59 | #sample a program: 60 | #with timing("sample program"): 61 | program = sample_program(g, tp) 62 | 63 | # find IO 64 | #with timing("sample IO:"): 65 | IO = generate_IO_examples(program, num_examples=N, continuation=continuation) 66 | if IO is None: return None 67 | IO = tuple(IO) 68 | 69 | # TODO 70 | 71 | # find pseq 72 | pseq = tuple(flatten_program(program, continuation=continuation)) #TODO 73 | 74 | if compute_sketches: 75 | # find sketch 76 | 77 | # TODO - improved dc_grammar [ ] 78 | # TODO - contextual_grammar [ ] 79 | # TODO - put grammar inference inside make_holey 80 | grammar = g if not dc_model else dc_model.infer_grammar(IO) 81 | 82 | #with timing("make_holey"): 83 | sketch, reward, sketchprob = make_holey_regex(program, top_k_sketches, grammar, tp, inv_temp=inv_temp, reward_fn=reward_fn, sample_fn=sample_fn, use_timeout=use_timeout) #TODO 84 | 85 | # find sketchseq 86 | sketchseq = tuple(flatten_program(sketch, continuation=continuation)) 87 | else: 88 | sketch, sketchseq, reward, sketchprob = None, None, None, None 89 | 90 | return Datum(tp, program, pseq, IO, sketch, sketchseq, reward, sketchprob) 91 | 92 | 93 | def grouper(iterable, n, fillvalue=None): 94 | # "Collect data into fixed-length chunks or blocks" 95 | # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx" 96 | args = [iter(iterable)] * n 97 | return zip_longest(*args, fillvalue=fillvalue) 98 | 99 | #TODO######## 100 | def batchloader(size, batchsize=100, g=basegrammar, N=5, compute_sketches=False, dc_model=None, shuffle=True, top_k_sketches=20, inv_temp=1.0, 
reward_fn=None, sample_fn=None, use_timeout=False): 101 | if batchsize==1: 102 | data = (sample_datum(g=g, N=N, compute_sketches=compute_sketches, dc_model=dc_model, top_k_sketches=20, inv_temp=inv_temp, reward_fn=reward_fn, sample_fn=sample_fn, use_timeout=use_timeout) for _ in repeat(0)) 103 | yield from islice((x for x in data if x is not None), size) 104 | else: 105 | data = (sample_datum(g=g, N=N, compute_sketches=compute_sketches, dc_model=dc_model, top_k_sketches=20, inv_temp=inv_temp, reward_fn=reward_fn, sample_fn=sample_fn, use_timeout=use_timeout) for _ in repeat(0)) 106 | data = (x for x in data if x is not None) 107 | grouped_data = islice(grouper(data, batchsize), size) 108 | 109 | for group in grouped_data: 110 | tps, ps, pseqs, IOs, sketchs, sketchseqs, rewards, sketchprobs = zip(*[(datum.tp, datum.p, datum.pseq, datum.IO, datum.sketch, datum.sketchseq, datum.reward, datum.sketchprob) for datum in group if datum is not None]) 111 | yield Batch(tps, ps, pseqs, IOs, sketchs, sketchseqs, torch.FloatTensor(rewards) if any(r is not None for r in rewards) else None, torch.FloatTensor(sketchprobs) if any(s is not None for s in sketchprobs) else None) # check that his works 112 | 113 | #TODO 114 | def makeTestdata(synth=True, challenge=False): 115 | raise NotImplementedError 116 | tasks = [] 117 | if synth: 118 | tasks = makeTasks() 119 | if challenge: 120 | challenge_tasks, _ = loadPBETasks() 121 | tasks = tasks + challenge_tasks 122 | 123 | tasklist = [] 124 | for task in tasks: 125 | if task.stringConstants==[] and task.request == arrow(tlist(tcharacter), tlist(tcharacter)): 126 | 127 | IO = tuple( (''.join(x[0]), ''.join(y)) for x,y in task.examples) 128 | 129 | program = None 130 | pseq = None 131 | sketch, sketchseq, reward, sketchprob = None, None, None, None 132 | tp = tprogram 133 | 134 | tasklist.append( Datum(tp, program, pseq, IO, sketch, sketchseq, reward, sketchprob) ) 135 | 136 | return tasklist 137 | 138 | #TODO 139 | def 
loadTestTasks(path='rb_test_tasks.p'): 140 | raise NotImplementedError 141 | print("data file:", path) 142 | with open(path, 'rb') as datafile: 143 | tasks = pickle.load(datafile) 144 | return tasks 145 | 146 | if __name__=='__main__': 147 | 148 | import time 149 | import pregex as pre 150 | 151 | g = basegrammar 152 | d = sample_datum(g=g, N=4, compute_sketches=True, top_k_sketches=100, inv_temp=1.0, reward_fn=None, sample_fn=None, dc_model=None) 153 | print(d.p) 154 | print(d.p.evaluate([])) 155 | print(d.sketch) 156 | #print(d.sketch.evaluate([])(pre.String(""))) 157 | print(d.sketch.evaluate([])) 158 | print(d.sketchseq) 159 | for o in d.IO: 160 | print("example") 161 | print(o) 162 | 163 | 164 | from util.regex_util import PregHole, pre_to_prog 165 | 166 | 167 | preg = pre.create(d.sketchseq, lookup={PregHole:PregHole()}) 168 | print(preg) 169 | print(pre_to_prog(preg)) 170 | 171 | -------------------------------------------------------------------------------- /data_src/makeRobustFillData.py: -------------------------------------------------------------------------------- 1 | #generate deepcoder data 2 | 3 | import pickle 4 | # TODO 5 | from util.deepcoder_util import make_holey_deepcoder # this might be enough 6 | #from util.robustfill_util import basegrammar # TODO 7 | from util.robustfill_util import sample_program, generate_IO_examples, timing # TODO 8 | 9 | import time 10 | from collections import namedtuple 11 | #Function = namedtuple('Function', ['src', 'sig', 'fun', 'bounds']) 12 | 13 | from grammar import Grammar, NoCandidates 14 | from utilities import flatten 15 | 16 | # TODO 17 | from RobustFillPrimitives import RobustFillProductions, flatten_program, tprogram 18 | 19 | from program import Application, Hole, Primitive, Index, Abstraction, ParseFailure 20 | import math 21 | import random 22 | from type import Context, arrow, tint, tlist, UnificationFailure, tcharacter 23 | from itertools import zip_longest, chain, repeat, islice 24 | from functools 
class Datum():
    """One training/evaluation example: a program, its IO examples, and optional sketch info."""

    def __init__(self, tp, p, pseq, IO, sketch, sketchseq, reward, sketchprob):
        self.tp = tp                # program type (tprogram for the text domain)
        self.p = p                  # sampled program (None for evaluation-only tasks)
        self.pseq = pseq            # flattened token sequence of the program
        self.IO = IO                # tuple of (input_string, output_string) example pairs
        self.sketch = sketch        # holey version of the program, if computed
        self.sketchseq = sketchseq  # flattened token sequence of the sketch
        self.reward = reward        # reward assigned to the sketch, if computed
        self.sketchprob = sketchprob  # probability of the sketch, if computed

    def __hash__(self):
        # Fold the IO strings, the program and the sketch into one hash so that
        # identical data points collapse when stored in sets/dicts.
        return reduce(lambda a, b: hash(a + hash(b)),
                      flatten(self.IO, abort=lambda x: type(x) is str), 0) \
            + hash(self.p) + hash(self.sketch)


Batch = namedtuple('Batch', ['tps', 'ps', 'pseqs', 'IOs', 'sketchs', 'sketchseqs', 'rewards', 'sketchprobs'])


def _add_input_noise(IO):
    """Return IO with a single random character-level mutation applied.

    Picks one example and one side (0=input, 1=output), then applies a removal,
    substitution, or insertion at a random index.  Empty strings are left
    unchanged.  Uses the module-level `random`/`string` (the original re-imported
    them locally, shadowing the module-level import).
    """
    replace_with = random.choice(string.printable[:-4])  # skip the trailing whitespace chars
    ex = random.choice(range(len(IO)))
    i_or_o = random.choice(range(2))  # 0 = input, 1 = output

    old = IO[ex][i_or_o]
    example = list(IO[ex])  # single conversion (original converted to list twice)
    if len(old) >= 1:
        idx = random.choice(range(len(old)))
        mut = random.choice(range(3))
        if mut == 0:    # removal
            example[i_or_o] = old[:idx] + old[idx + 1:]
        elif mut == 1:  # substitution
            example[i_or_o] = old[:idx] + replace_with + old[idx + 1:]
        else:           # insertion
            example[i_or_o] = old[:idx] + replace_with + old[idx:]
    IO = list(IO)
    IO[ex] = tuple(example)
    return tuple(IO)


def sample_datum(basegrammar,
                 N=5,
                 V=100,
                 L=10,
                 compute_sketches=False,
                 top_k_sketches=100,
                 inv_temp=1.0,
                 reward_fn=None,
                 sample_fn=None,
                 dc_model=None,
                 use_timeout=False,
                 improved_dc_model=False,
                 nHoles=1,
                 input_noise=False):
    """Sample one random RobustFill program with IO examples and (optionally) a sketch.

    Args:
        basegrammar: grammar to sample the program from.
        N: number of IO examples. V: max string size. L: max program length.
        compute_sketches: if True, also produce a holey sketch via make_holey_algolisp.
        dc_model / improved_dc_model / nHoles: sketch-generation knobs.
        input_noise: if True, corrupt one example string with a random mutation.
    Returns:
        a Datum, or None when IO generation fails.
    """
    # sample a program:
    program = sample_program(g=basegrammar, max_len=L, max_string_size=V)  # TODO
    # if program is bad: return None  # TODO

    # find IO
    IO = generate_IO_examples(program, num_examples=N, max_string_size=V)  # TODO
    if IO is None:
        return None
    if input_noise:
        IO = _add_input_noise(list(IO))
    IO = tuple(IO)

    # find tp
    tp = tprogram

    # find pseq
    pseq = tuple(flatten_program(program))  # TODO

    if compute_sketches:
        # find sketch
        sketch, reward, sketchprob = make_holey_algolisp(program,
                                                         top_k_sketches,
                                                         tp,
                                                         basegrammar,
                                                         dcModel=dc_model,
                                                         improved_dc_model=improved_dc_model,
                                                         return_obj=Hole,
                                                         dc_input=IO,
                                                         inv_temp=inv_temp,
                                                         reward_fn=reward_fn,
                                                         sample_fn=sample_fn,
                                                         use_timeout=use_timeout,
                                                         nHoles=nHoles,
                                                         domain='text')
        # find sketchseq
        sketchseq = tuple(flatten_program(sketch))
    else:
        sketch, sketchseq, reward, sketchprob = None, None, None, None

    return Datum(tp, program, pseq, IO, sketch, sketchseq, reward, sketchprob)


def grouper(iterable, n, fillvalue=None):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)
def batchloader(size,
                basegrammar,
                batchsize=100,
                N=5,
                V=100,
                L=10,
                compute_sketches=False,
                dc_model=None,
                shuffle=True,
                top_k_sketches=20,
                inv_temp=1.0,
                reward_fn=None,
                sample_fn=None,
                use_timeout=False,
                improved_dc_model=False,
                nHoles=1,
                input_noise=False):
    """Yield `size` sampled Datums (batchsize==1) or `size` Batches of sampled data.

    Draws from an infinite stream of sample_datum calls, dropping failed samples
    (None).  `shuffle` is accepted for interface compatibility but unused here,
    since the stream is freshly sampled anyway.
    """
    data = (sample_datum(basegrammar,
                         N=N,
                         V=V,
                         L=L,
                         compute_sketches=compute_sketches,
                         dc_model=dc_model,
                         # fix: was hard-coded to 20, silently ignoring the
                         # top_k_sketches parameter passed by callers
                         top_k_sketches=top_k_sketches,
                         inv_temp=inv_temp,
                         reward_fn=reward_fn,
                         sample_fn=sample_fn,
                         use_timeout=use_timeout,
                         improved_dc_model=improved_dc_model,
                         nHoles=nHoles,
                         input_noise=input_noise) for _ in repeat(0))
    data = (x for x in data if x is not None)
    if batchsize == 1:
        yield from islice(data, size)
    else:
        for group in islice(grouper(data, batchsize), size):
            tps, ps, pseqs, IOs, sketchs, sketchseqs, rewards, sketchprobs = zip(
                *[(datum.tp, datum.p, datum.pseq, datum.IO, datum.sketch,
                   datum.sketchseq, datum.reward, datum.sketchprob)
                  for datum in group if datum is not None])
            # rewards/sketchprobs are only tensors when sketches were computed
            yield Batch(tps, ps, pseqs, IOs, sketchs, sketchseqs,
                        torch.FloatTensor(rewards) if any(r is not None for r in rewards) else None,
                        torch.FloatTensor(sketchprobs) if any(s is not None for s in sketchprobs) else None)


def makeTestdata(synth=True, challenge=False):
    """Build evaluation-only Datums from synthesized text tasks and/or PBE challenge tasks.

    Keeps only tasks with no string constants and type list(char) -> list(char);
    program/sketch fields are None since these tasks carry IO examples only.
    """
    tasks = []
    if synth:
        tasks = makeTasks()
    if challenge:
        challenge_tasks, _ = loadPBETasks()
        tasks = tasks + challenge_tasks

    tasklist = []
    for task in tasks:
        if task.stringConstants == [] and task.request == arrow(tlist(tcharacter), tlist(tcharacter)):
            # join the character lists into plain strings for the IO pairs
            IO = tuple((''.join(x[0]), ''.join(y)) for x, y in task.examples)
            tasklist.append(Datum(tprogram, None, None, IO, None, None, None, None))
    return tasklist


def loadTestTasks(path='rb_test_tasks.p'):
    """Unpickle and return the list of test Datums stored at `path`.

    NOTE: pickle.load can execute arbitrary code; only load trusted files.
    """
    print("data file:", path)
    with open(path, 'rb') as datafile:
        tasks = pickle.load(datafile)
    return tasks


if __name__ == '__main__':
    import time

    g = Grammar.fromProductions(RobustFillProductions(max_len=50, max_index=4))
    # fix: sample_datum's first parameter is `basegrammar`; the old call used the
    # nonexistent keyword `g=` and raised TypeError
    d = sample_datum(g, N=4, V=50, L=10, compute_sketches=True, top_k_sketches=100,
                     inv_temp=1.0, reward_fn=None, sample_fn=None, dc_model=None)
    print(d.p)
    for i, o in d.IO:
        print("example")
        print(i)
        print(o)

    tasks = loadTestTasks('rb_all_tasks.p')
    for t in tasks:
        print(t.IO)
#make_T4_test_data.py
# Build and pickle a fixed test set (500 tasks) for the T44 DeepCoder experiments.
import sys
import os
sys.path.append(os.path.abspath('./'))
sys.path.append(os.path.abspath('./ec'))

# star-import supplies batchloader, Vrange, and pickle used below
from eval.evaluate_deepcoder import *
from itertools import islice

##load the test dataset###
# alternative datasets, kept for reference:
# test_data = ['data/DeepCoder_test_data/T3_A2_V512_L10_P500.txt']
# test_data = ['data/DeepCoder_test_data/T5_A2_V512_L10_P100_test.txt'] #modified from original
# test_data = ['data/DeepCoder_data/T3_A2_V512_L10_validation_perm.txt']

# test_data = ['data/DeepCoder_data/T4_A2_V512_L10_train_perm.txt']

test_data = ['data/DeepCoder_data/T44_test.txt']
# test_data = ['data/DeepCoder_data/T44_train.txt']

#test_data = ['data/DeepCoder_data/T2_A2_V512_L10_test_perm.txt']
# stream one datum at a time (batchsize=1), no sketches, then keep the first 500
dataset = batchloader(test_data, batchsize=1, N=5, V=Vrange, L=10, compute_sketches=False)
dataset = list(islice(dataset, 500))
with open('data/T44_test_new.p', 'wb') as savefile:
    pickle.dump(dataset, savefile)
print("test file saved")
-------------------------------------------------------------------------------- 1 | # evaluate.py 2 | # import statements 3 | import sys 4 | import os 5 | sys.path.append(os.path.abspath('./')) 6 | sys.path.append(os.path.abspath('./ec')) 7 | 8 | import argparse 9 | 10 | import torch 11 | from torch import nn, optim 12 | from pinn import RobustFill 13 | import random 14 | import math 15 | import time 16 | import dill 17 | import pickle 18 | from util.deepcoder_util import basegrammar 19 | from util.deepcoder_util import parseprogram, tokenize_for_robustfill 20 | from data_src.makeDeepcoderData import batchloader 21 | from plot.manipulate_results import percent_solved_n_checked, percent_solved_time, plot_result 22 | 23 | from util.pypy_util import DeepcoderResult, alternate, pypy_enumerate, SketchTup 24 | 25 | # TODO 26 | 27 | from program import ParseFailure, Context 28 | from grammar import NoCandidates, Grammar 29 | from utilities import timing, callCompiled 30 | 31 | from itertools import islice, zip_longest 32 | 33 | from models import deepcoderModel 34 | sys.modules['deepcoderModel'] = deepcoderModel 35 | 36 | """ rough schematic of what needs to be done: 37 | 1) Evaluate NN on test inputs 38 | 2) Possible beam search to find candidate sketches 39 | 3) unflatten "parseprogram" the candidate sketches 40 | 4) use hole enumerator from kevin to turn sketch -> program (g.sketchEnumeration) 41 | 5) when hit something, record and move on. 42 | 43 | so I need to build parseprogram, beam search and that may be it. 
parser = argparse.ArgumentParser()
parser.add_argument('--pretrained', action='store_true')
parser.add_argument('--n_test', type=int, default=500)
parser.add_argument('--dcModel', action='store_true')
parser.add_argument('--dcModel_path', type=str, default="./saved_models/list_dc_model.p")
parser.add_argument('--dc_baseline', action='store_true')
parser.add_argument('--n_samples', type=int, default=30)
parser.add_argument('--mdl', type=int, default=14)  #9
parser.add_argument('--n_examples', type=int, default=5)
parser.add_argument('--Vrange', type=int, default=128)
parser.add_argument('--precomputed_data_file', type=str, default='data/prelim_val_data_new.p')
parser.add_argument('--model_path', type=str, default="./saved_models/list_holes.p")
parser.add_argument('--max_to_check', type=int, default=5000)
parser.add_argument('--resultsfile', type=str, default='NA')
parser.add_argument('--shuffled', action='store_true')
parser.add_argument('--beam', action='store_true')
parser.add_argument('--improved_dc_grammar', action='store_true')
args = parser.parse_args()

# module-level aliases used throughout
nSamples = args.n_samples
mdl = args.mdl
nExamples = args.n_examples
Vrange = args.Vrange
max_to_check = args.max_to_check
improved_dc_grammar = args.improved_dc_grammar


def untorch(g):
    """Convert a torch-backed Grammar into a plain-float Grammar (no-op if already plain)."""
    if type(g.logVariable) == float:
        return g
    else:
        return Grammar(g.logVariable.data.tolist()[0],
                       [(l.data.tolist()[0], t, p)
                        for l, t, p in g.productions])


def evaluate_datum(i, datum, model, dcModel, nRepeats, mdl, max_to_check):
    """Evaluate one task: sample/beam sketches from `model`, then enumerate programs.

    Returns the list of DeepcoderResult records (parse failures + enumeration results).
    """
    t = time.time()
    results = []
    samples = {("",)}  # make more general
    n_checked, n_hit = 0, 0

    if model:
        # can replace with a beam search at some point
        # TODO: use score for allocating resources
        tokenized = tokenize_for_robustfill([datum.IO])
        if args.beam:
            samples, _scores = model.beam_decode(tokenized, beam_size=nRepeats)
        else:
            samples, _scores, _ = model.sampleAndScore(tokenized, nRepeats=nRepeats)
        # only loop over unique samples:
        samples = {tuple(sample) for sample in samples}

    if (not improved_dc_grammar) or (not dcModel):
        g = basegrammar if not dcModel else dcModel.infer_grammar(datum.IO)  # TODO pp
        g = untorch(g)
    else:
        # fix: previously `g` was only bound inside the loop below, so if every
        # sample failed to parse, `untorch(g)` at pypy_enumerate raised NameError
        g = untorch(basegrammar)

    sketchtups = []
    for sample in samples:
        try:
            sk = parseprogram(sample, datum.tp)
        except (ParseFailure, NoCandidates):
            n_checked += 1
            results.append(DeepcoderResult(sample, None, False, n_checked, time.time() - t))
            continue

        if improved_dc_grammar:
            g = untorch(dcModel.infer_grammar((datum.IO, sample)))
        sketchtups.append(SketchTup(sk, g))

    # only loop over unique sketches:
    sketchtups = set(sketchtups)
    print(len(sketchtups))
    # fix: without the * splat, sep='\n' was a no-op (single list argument)
    print(*[sk.sketch for sk in sketchtups], sep='\n')
    # alternate which sketch to enumerate from each time
    enum_results, n_checked, n_hit = pypy_enumerate(untorch(g), datum.tp, datum.IO, mdl,
                                                    sketchtups, n_checked, n_hit, t, max_to_check)

    print(f"task {i}:")
    print(f"evaluation for task {i} took {time.time()-t} seconds")
    print(f"For task {i}, tried {n_checked} sketches, found {n_hit} hits", flush=True)
    return results + enum_results

######TODO: want search time and total time to hit task ######


def evaluate_dataset(model, dataset, nRepeats, mdl, max_to_check, dcModel=None):
    """Evaluate every datum in `dataset`; returns {datum: [results]}.

    NOTE(review): reconstructed so the return is at function level (not under the
    `if`) — otherwise the caller would receive None whenever a model is supplied.
    """
    if model is None:
        print("evaluating dcModel baseline")
    return {datum: list(evaluate_datum(i, datum, model, dcModel, nRepeats, mdl, max_to_check))
            for i, datum in enumerate(dataset)}


#TODO: refactor for strings
def save_results(results, args):
    """Dump `results` with dill into a timestamped file under results/; returns the closed file object."""
    timestr = str(int(time.time()))
    r = '_test' + str(args.n_test) + '_'
    if args.resultsfile != 'NA':
        filename = 'results/' + args.resultsfile + '.p'
    elif args.dc_baseline:
        filename = "results/prelim_results_dc_baseline_" + r + timestr + '.p'
    elif args.pretrained:
        filename = "results/prelim_results_rnn_baseline_" + r + timestr + '.p'
    else:
        dc = 'wdcModel_' if args.dcModel else ''
        filename = "results/prelim_results_" + dc + r + timestr + '.p'
    with open(filename, 'wb') as savefile:
        dill.dump(results, savefile)
    print("results file saved at", filename)
    return savefile


if __name__ == '__main__':
    # load the sketch model (unless running the pure dcModel baseline)
    if args.dc_baseline:
        print("computing dc baseline, no model")
        assert args.dcModel
        model = None
    else:
        print("loading model with holes")
        model = torch.load(args.model_path)  # TODO
    if args.dcModel:
        print("loading dc_model")
        dcModel = torch.load(args.dcModel_path)
    else:
        dcModel = None

    print("data file:", args.precomputed_data_file)
    with open(args.precomputed_data_file, 'rb') as datafile:
        dataset = pickle.load(datafile)

    if args.shuffled:
        random.seed(42)
        random.shuffle(dataset)
    del dataset[args.n_test:]

    results = evaluate_dataset(model, dataset, nSamples, mdl, max_to_check, dcModel=dcModel)

    # count hits
    hits = sum(any(result.hit for result in result_list) for result_list in results.values())
    print(f"hits: {hits} out of {args.n_test}, or {100*hits/args.n_test}% accuracy")

    # %solved as a function of how many candidates were checked
    x_axis = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 400, 600, 800, 900, 1000, 2000, 4000]  # TODO
    y_axis = [percent_solved_n_checked(results, x) for x in x_axis]

    print("percent solved vs number of evaluated programs")
    print("num_checked:", x_axis)
    print("num_solved:", y_axis)

    file = save_results(results, args)

    plot_result(results=results, plot_time=True, model_path=args.model_path)  # doesn't account for changing result thingy
parser = argparse.ArgumentParser()
# fix: --pretrained was missing although save_results reads args.pretrained
parser.add_argument('--pretrained', action='store_true')
parser.add_argument('--n_test', type=int, default=500)
parser.add_argument('--dcModel', action='store_true')
parser.add_argument('--dcModel_path', type=str, default="./saved_models/dc_model.p")
parser.add_argument('--holeSpecificDcModel', action='store_true')
parser.add_argument('--dc_baseline', action='store_true')
parser.add_argument('--n_samples', type=int, default=30)
parser.add_argument('--mdl', type=int, default=14)  #9
parser.add_argument('--n_examples', type=int, default=5)
parser.add_argument('--precomputed_data_file', type=str, default='data/prelim_val_data.p')
parser.add_argument('--model_path', type=str, default="./saved_models/deepcoder_holes.p")
parser.add_argument('--max_to_check', type=int, default=5000)
parser.add_argument('--resultsfile', type=str, default='NA')
parser.add_argument('--shuffled', action='store_true')
parser.add_argument('--beam', action='store_true')
args = parser.parse_args()

nSamples = args.n_samples
mdl = args.mdl
nExamples = args.n_examples
max_to_check = args.max_to_check
holeSpecificDcModel = args.holeSpecificDcModel
if holeSpecificDcModel:
    assert args.dcModel

# maps the PregHole marker class to a fresh hole instance for pre.create
lookup_d = {PregHole: PregHole()}


def untorch(g):
    """Convert a torch-backed Grammar into a plain-float Grammar (no-op if already plain)."""
    if type(g.logVariable) == float:
        return g
    else:
        return Grammar(g.logVariable.data.tolist()[0],
                       [(l.data.tolist()[0], t, p)
                        for l, t, p in g.productions])


def evaluate_datum(i, datum, model, dcModel, nRepeats, mdl, max_to_check):
    """Generator of results for one regex task.

    Yields (RegexResult, ll) tuples for unparseable samples, then the enumeration
    results from pypy_enumerate.  NOTE(review): parse failures yield 2-tuples
    while enumeration results are yielded as-is — confirm downstream consumers
    handle both shapes.
    """
    t = time.time()
    samples = {(PregHole,)}  # make more general # TODO, i don't think
    n_checked, n_hit = 0, 0
    if model:
        if args.beam:
            samples, _scores = model.beam_decode([datum.IO[:nExamples]], beam_size=nRepeats)
        else:
            samples, _scores, _ = model.sampleAndScore([datum.IO[:nExamples]], nRepeats=nRepeats)
        # only loop over unique samples:
        samples = {tuple(sample) for sample in samples}
    if (not holeSpecificDcModel) or (not dcModel):
        g = basegrammar if not dcModel else dcModel.infer_grammar(datum.IO[:nExamples])  # TODO pp
        g = untorch(g)
    sketchtups = []
    for sample in samples:
        try:
            sk = pre_to_prog(pre.create(sample, lookup=lookup_d))

            if holeSpecificDcModel:
                g = untorch(dcModel.infer_grammar((datum.IO[:nExamples], sample)))  # TODO: verify

            sketchtups.append(SketchTup(sk, g))

        except ParseException:
            n_checked += 1
            yield (RegexResult(sample, None, float('-inf'), n_checked, time.time() - t), float('-inf'))
            continue
    # only loop over unique sketches:
    sketchtups = set(sketchtups)
    # alternate which sketch to enumerate from each time
    results, n_checked, n_hit = pypy_enumerate(datum.tp, datum.IO, mdl, sketchtups,
                                               n_checked, n_hit, t, max_to_check)
    yield from results

    ######TODO: want search time and total time to hit task ######
    print(f"task {i}:")
    print(f"evaluation for task {i} took {time.time()-t} seconds")
    print(f"For task {i}, tried {n_checked} sketches, found {n_hit} hits")


def evaluate_dataset(model, dataset, nRepeats, mdl, max_to_check, dcModel=None):
    """Evaluate every datum in `dataset`; returns {datum: [results]}."""
    if model is None:
        print("evaluating dcModel baseline")
    return {datum: list(evaluate_datum(i, datum, model, dcModel, nRepeats, mdl, max_to_check))
            for i, datum in enumerate(dataset)}


def save_results(results, args):
    """Dump `results` with dill into a timestamped file under results/; returns the closed file object."""
    timestr = str(int(time.time()))
    r = '_test' + str(args.n_test) + '_'
    if args.resultsfile != 'NA':
        filename = 'results/' + args.resultsfile + '.p'
    elif args.dc_baseline:
        filename = "results/prelim_results_dc_baseline_" + r + timestr + '.p'
    elif args.pretrained:
        filename = "results/prelim_results_rnn_baseline_" + r + timestr + '.p'
    else:
        dc = 'wdcModel_' if args.dcModel else ''
        filename = "results/prelim_results_" + dc + r + timestr + '.p'
    with open(filename, 'wb') as savefile:
        dill.dump(results, savefile)
    print("results file saved at", filename)
    return savefile


if __name__ == '__main__':
    # load the model
    if args.dc_baseline:
        print("computing dc baseline, no model")
        assert args.dcModel
        model = None
    else:
        print("loading model with holes")
        model = torch.load(args.model_path)  # TODO
    if args.dcModel:
        print("loading dc_model")
        dcModel = torch.load(args.dcModel_path)
    else:
        dcModel = None

    ###load the test dataset###
    dataset = date_data(20, nExamples=nExamples)
    print("loaded data")

    if args.shuffled:
        random.seed(42)
        random.shuffle(dataset)
    del dataset[args.n_test:]

    results = evaluate_dataset(model, dataset, nSamples, mdl, max_to_check, dcModel=dcModel)

    # fix: removed the duplicate `save_results(results, pretrained=args.pretrained)`
    # call, which did not match save_results' (results, args) signature and raised
    # TypeError before the correct call below ever ran.

    # %solved as a function of how many candidates were checked
    x_axis = [10, 20, 50, 100, 200, 400, 600]  # TODO
    y_axis = [percent_solved_n_checked(results, x) for x in x_axis]

    print("percent solved vs number of evaluated programs")
    print("num_checked:", x_axis)
    print("num_solved:", y_axis)

    file = save_results(results, args)

    ####cool graphic#####
    # guarded: results is keyed by datum, so 'task_tuples' may not be present
    if 'task_tuples' in results:
        for sketch, prog, ll, task in results['task_tuples']:
            print(task, "-->", sketch, "-->", prog, "with ll", ll)
parser = argparse.ArgumentParser()
parser.add_argument('--pretrained', action='store_true')
parser.add_argument('--pretrained_model_path', type=str, default="./saved_models/text_pretrained.p")
parser.add_argument('--n_test', type=int, default=500)
parser.add_argument('--dcModel', action='store_true')
parser.add_argument('--dc_model_path', type=str, default="./saved_models/text_dc_model.pstate_dict")
parser.add_argument('--dc_baseline', action='store_true')
parser.add_argument('--n_samples', type=int, default=30)
parser.add_argument('--mdl', type=int, default=17)  #9
parser.add_argument('--n_rec_examples', type=int, default=4)
parser.add_argument('--max_length', type=int, default=25)
parser.add_argument('--max_index', type=int, default=4)
parser.add_argument('--precomputed_data_file', type=str, default='saved_models/rb_test_tasks.p')
parser.add_argument('--model_path', type=str, default="./saved_models/text_holes.p")
parser.add_argument('--max_to_check', type=int, default=5000)
parser.add_argument('--resultsfile', type=str, default='NA')
parser.add_argument('--test_generalization', action='store_true')
parser.add_argument('--beam', action='store_true')
parser.add_argument('--improved_dc_grammar', action='store_true')
parser.add_argument('--noise_eval', action='store_true')
args = parser.parse_args()

nSamples = args.n_samples
mdl = args.mdl
n_rec_examples = args.n_rec_examples
max_to_check = args.max_to_check
improved_dc_grammar = args.improved_dc_grammar


def untorch(g):
    """Convert a torch-backed Grammar into a plain-float Grammar (no-op if already plain)."""
    if type(g.logVariable) == float:
        return g
    return Grammar(g.logVariable.data.tolist()[0],
                   [(l.data.tolist()[0], t, p)
                    for l, t, p in g.productions])


def evaluate_datum(i, datum, model, dcModel, nRepeats, mdl, max_to_check):
    """Evaluate one RobustFill task: sample/beam sketches, then enumerate programs.

    Returns the list of RobustFillResult records (parse failures + enumeration results).
    """
    t = time.time()
    results = []
    samples = {("",)}  # make more general
    n_checked, n_hit = 0, 0

    if model:
        # can replace with a beam search at some point
        # TODO: use score for allocating resources
        tokenized = tokenize_for_robustfill([datum.IO[:n_rec_examples]])
        if args.beam:
            samples, _scores = model.beam_decode(tokenized, beam_size=nRepeats)
        else:
            samples, _scores, _ = model.sampleAndScore(tokenized, nRepeats=nRepeats)
        # only loop over unique samples:
        samples = {tuple(sample) for sample in samples}

    if (not improved_dc_grammar) or (not dcModel):
        g = basegrammar if not dcModel else dcModel.infer_grammar(datum.IO)  # TODO pp
        g = untorch(g)

    sketchtups = []
    for sample in samples:
        try:
            sk = parseprogram(sample, datum.tp)
        except (ParseFailure, NoCandidates):
            n_checked += 1
            results.append(RobustFillResult(sample, None, False, n_checked, time.time() - t, False))
            continue

        if improved_dc_grammar:
            g = untorch(dcModel.infer_grammar((datum.IO, sample)))
        sketchtups.append(SketchTup(sk, g))

    # only loop over unique sketches:
    sketchtups = set(sketchtups)
    # alternate which sketch to enumerate from each time
    enum_results, n_checked, n_hit = rb_pypy_enumerate(datum.tp, datum.IO, mdl, sketchtups,
                                                       n_checked, n_hit, t, max_to_check,
                                                       args.test_generalization, n_rec_examples,
                                                       input_noise=args.noise_eval)

    ######TODO: want search time and total time to hit task ######
    print(f"task {i}:")
    print(f"evaluation for task {i} took {time.time()-t} seconds")
    print(f"For task {i}, tried {n_checked} sketches, found {n_hit} hits", flush=True)
    return results + enum_results


def evaluate_dataset(model, dataset, nRepeats, mdl, max_to_check, dcModel=None):
    """Evaluate every datum in `dataset`; returns {datum: [results]}."""
    if model is None:
        print("evaluating dcModel baseline")
    return {datum: list(evaluate_datum(i, datum, model, dcModel, nRepeats, mdl, max_to_check))
            for i, datum in enumerate(dataset)}
| if model is None: 134 | print("evaluating dcModel baseline") 135 | return {datum: list(evaluate_datum(i, datum, model, dcModel, nRepeats, mdl, max_to_check)) for i, datum in enumerate(dataset)} 136 | 137 | #TODO: refactor for strings 138 | def save_results(results, args): 139 | timestr = str(int(time.time())) 140 | r = '_test' + str(args.n_test) + '_' 141 | if args.resultsfile != 'NA': 142 | filename = 'rb_results/' + args.resultsfile + '.p' 143 | elif args.dc_baseline: 144 | filename = "rb_results/prelim_results_dc_baseline_" + r + timestr + '.p' 145 | elif args.pretrained: 146 | filename = "rb_results/prelim_results_rnn_baseline_" + r + timestr + '.p' 147 | else: 148 | dc = 'wdcModel_' if args.dcModel else '' 149 | filename = "rb_results/prelim_results_" + dc + r + timestr + '.p' 150 | with open(filename, 'wb') as savefile: 151 | dill.dump(results, savefile) 152 | print("results file saved at", filename) 153 | return savefile 154 | 155 | def randomize_IO(IO): 156 | IO = list(IO) 157 | import string 158 | replace_with = random.choice(string.printable[:-4]) 159 | ex = random.choice(range(len(IO))) 160 | i_or_o = random.choice(range(2)) 161 | 162 | old = IO[ex][i_or_o] 163 | 164 | IO[ex] = list(IO[ex]) 165 | 166 | ln = len(old) 167 | if ln >= 1: 168 | idx = random.choice( range(ln) ) 169 | 170 | mut = random.choice(range(3)) 171 | 172 | if type(IO[ex]) == tuple: 173 | IO[ex] = list(IO[ex]) 174 | 175 | if mut ==0: #removal 176 | IO[ex][i_or_o] = old[:idx] + old[idx+1:] 177 | elif mut==1: #sub 178 | IO[ex][i_or_o] = old[:idx] + replace_with + old[idx+1:] 179 | else: #insertion 180 | IO[ex][i_or_o] = old[:idx] + replace_with + old[idx:] 181 | 182 | IO = tuple(IO) 183 | return IO 184 | 185 | def randomize_datum(datum): 186 | #class Datum(): 187 | #def __init__(self, tp, p, pseq, IO, sketch, sketchseq, reward, sketchprob): 188 | IO = randomize_IO(datum.IO) 189 | return Datum(datum.tp, datum.p, datum.pseq, IO, datum.sketch, datum.sketchseq, datum.reward, 
datum.sketchprob) 190 | 191 | 192 | if __name__=='__main__': 193 | #load the model 194 | if args.pretrained: 195 | print("loading pretrained model") 196 | model = torch.load(args.pretrained_model_path) 197 | elif args.dc_baseline: 198 | print("computing dc baseline, no model") 199 | assert args.dcModel 200 | model = None 201 | else: 202 | print("loading model with holes") 203 | model = torch.load(args.model_path) #TODO 204 | if args.dcModel: 205 | print("loading dc_model") 206 | dcModel = load_rb_dc_model_from_path(args.dc_model_path, args.max_length, args.max_index, improved_dc_grammar, cuda=True) 207 | 208 | 209 | print("data file:", args.precomputed_data_file) 210 | with open(args.precomputed_data_file, 'rb') as datafile: 211 | dataset = pickle.load(datafile) 212 | # optional: 213 | #dataset = random.shuffle(dataset) 214 | if args.noise_eval: 215 | random.seed(42) 216 | import random 217 | dataset = [randomize_datum(datum) for datum in dataset] 218 | 219 | del dataset[args.n_test:] 220 | 221 | results = evaluate_dataset(model, dataset, nSamples, mdl, max_to_check, dcModel=dcModel if args.dcModel else None) 222 | 223 | # count hits 224 | hits = sum(any(result.hit for result in result_list) for result_list in results.values()) 225 | print(f"hits: {hits} out of {len(dataset)}, or {100*hits/len(dataset)}% accuracy") 226 | 227 | # I want a plot of the form: %solved vs n_hits 228 | x_axis = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 400, 600, 800, 900, 1000, 2000, 4000] # TODO 229 | y_axis = [percent_solved_n_checked(results, x) for x in x_axis] 230 | 231 | print("percent solved vs number of evaluated programs") 232 | print("num_checked:", x_axis) 233 | print("num_solved:", y_axis) 234 | 235 | #doesn't really need a full function ... 
236 | file = save_results(results, args) 237 | 238 | plot_result(results=results, plot_time=True, model_path=args.model_path, robustfill=True) #doesn't account for changing result thingy -------------------------------------------------------------------------------- /execute_any_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | #SBATCH --qos=tenenbaum 4 | #SBATCH --time=2000 5 | #SBATCH --mem=30G 6 | #SBATCH --job-name=neural_sketch 7 | #SBATCH --cpus-per-task=1 8 | #SBATCH --gres=gpu:1 9 | 10 | 11 | #export PATH=/om/user/mnye/miniconda3/bin/:$PATH 12 | #source activate /om/user/mnye/vhe/envs/default/ 13 | #cd /om/user/mnye/vhe 14 | anaconda-project run $@ 15 | -------------------------------------------------------------------------------- /execute_gf_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | #SBATCH --qos=tenenbaum 4 | #SBATCH --time=2000 5 | #SBATCH --mem=30G 6 | #SBATCH --job-name=neural_sketch 7 | #SBATCH --cpus-per-task=1 8 | #SBATCH --gres=gpu:GEFORCEGTX1080TI:1 9 | 10 | 11 | #export PATH=/om/user/mnye/miniconda3/bin/:$PATH 12 | #source activate /om/user/mnye/vhe/envs/default/ 13 | #cd /om/user/mnye/vhe 14 | anaconda-project run $@ 15 | -------------------------------------------------------------------------------- /execute_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | #SBATCH --qos=tenenbaum 4 | #SBATCH --time=3000 5 | #SBATCH --mem=50G 6 | #SBATCH --job-name=neural_sketch 7 | #SBATCH --cpus-per-task=1 8 | #SBATCH --gres=gpu:titan-x:1 9 | 10 | 11 | #export PATH=/om/user/mnye/miniconda3/bin/:$PATH 12 | #source activate /om/user/mnye/vhe/envs/default/ 13 | #cd /om/user/mnye/vhe 14 | which python 15 | anaconda-project run $@ 16 | -------------------------------------------------------------------------------- /execute_k80_gpu.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | #SBATCH --qos=tenenbaum 4 | #SBATCH --time=2000 5 | #SBATCH --mem=30G 6 | #SBATCH --job-name=neural_sketch 7 | #SBATCH --cpus-per-task=1 8 | #SBATCH --gres=gpu:tesla-k80:1 9 | 10 | 11 | #export PATH=/om/user/mnye/miniconda3/bin/:$PATH 12 | #source activate /om/user/mnye/vhe/envs/default/ 13 | #cd /om/user/mnye/vhe 14 | anaconda-project run $@ 15 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/models/__init__.py -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/deepcoderModel.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/models/__pycache__/deepcoderModel.cpython-36.pyc -------------------------------------------------------------------------------- /plot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/plot/__init__.py -------------------------------------------------------------------------------- /plot/make_final_plots.py: -------------------------------------------------------------------------------- 1 | #make_final_plots 2 | import sys 3 | import os 4 | 
def find_x_percent(result, x, n_checked_min=0, n_checked_max=40000):
    """Binary-search for the evaluation budget at which `result` solves x.

    Recursively bisects [n_checked_min, n_checked_max] until
    percent_solved_n_checked(result, mid) is within `fudge` of x, or the
    interval has shrunk to fewer than `n_fudge` candidates. Returns the
    midpoint budget found, clamped to [1, 40000].

    NOTE(review): callers pass x as a fraction (e.g. .85) while
    percent_solved_n_checked in plot/manipulate_results.py now returns a
    0-100 percentage -- confirm the intended scale before trusting output.
    """
    fudge = 0.005
    n_fudge = 3

    # BUG FIX: the original computed `int(n_checked_min + n_checked_max)/2`,
    # so int() was a no-op and the midpoint was a float (e.g. 20000.5);
    # use integer floor division for a true integer midpoint.
    n_checked_mid = (n_checked_min + n_checked_max) // 2

    if n_checked_mid > 39999:
        return 40000
    elif n_checked_mid < 2:
        return 1

    percent_mid = percent_solved_n_checked(result, n_checked_mid)
    print("testing", n_checked_mid)
    print(n_checked_mid, "had percent of", percent_mid)

    if abs(x - percent_mid) < fudge:
        # close enough to the target fraction
        return n_checked_mid
    elif abs(n_checked_mid - n_checked_min) < n_fudge or abs(n_checked_mid - n_checked_max) < n_fudge:
        # interval exhausted; best effort
        return n_checked_mid
    elif x > percent_mid:
        return find_x_percent(result, x, n_checked_min=n_checked_mid, n_checked_max=n_checked_max)
    else:
        return find_x_percent(result, x, n_checked_min=n_checked_min, n_checked_max=n_checked_mid)
-------------------------------------------------------------------------------- /plot/manipulate_results.py: -------------------------------------------------------------------------------- 1 | #play with results 2 | import sys 3 | import os 4 | sys.path.append(os.path.abspath('./')) 5 | sys.path.append(os.path.abspath('./ec')) 6 | 7 | import matplotlib 8 | matplotlib.use('Agg') 9 | #added for camera ready so there are no type 3 fonts 10 | matplotlib.rcParams['pdf.fonttype'] = 42 11 | matplotlib.rcParams['ps.fonttype'] = 42 12 | 13 | import matplotlib.pyplot as plt 14 | import dill 15 | import time 16 | import argparse 17 | from plot.p_solved_hack import hack_percent_solved_n_checked 18 | 19 | from data_src import makeRobustFillData 20 | sys.modules['makeRobustFillData'] = makeRobustFillData 21 | 22 | SMALL_SIZE = 14 23 | MEDIUM_SIZE = 16 24 | BIGGER_SIZE = 18 25 | 26 | plt.rc('font', size=SMALL_SIZE) # controls default text sizes 27 | plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title 28 | plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels 29 | plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels 30 | plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels 31 | plt.rc('legend', fontsize=12) # legend fontsize 32 | plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title 33 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0)) 34 | 35 | 36 | mem_problem_list = ['results/dc_T45_0.25_800k_pypy.p', 'results/dc_T34_0.25_2M_pypy.p', 'results/dc_T45_0.25_2M_pypy.p'] 37 | 38 | def solve_time_percentile(results, percentile, use_misses=False): 39 | import numpy as np 40 | #gather the list of times: 41 | time_list = [] 42 | for result_list in results.values(): 43 | candidates = [ result.time for result in result_list if result.hit] 44 | if len(candidates) > 0: 45 | min_val = min( candidates) 46 | time_list.append(min_val) 47 | elif use_misses: 48 | min_val = float('inf') 49 | 
time_list.append(min_val) 50 | else: #do nothing, because don't add to list 51 | pass 52 | #sum(any(result.hit and result.time <= time for result in result_list) for result_list in results.values())/len(results) 53 | return np.percentile(time_list, percentile) 54 | 55 | def percent_solved_n_checked(results, n_checked): 56 | return 100*sum(any(result.hit and result.n_checked <= n_checked for result in result_list) for result_list in results.values())/len(results) 57 | #speed this up!!! 58 | 59 | 60 | def percent_solved_time(results, time): 61 | return sum(any(result.hit and result.time <= time for result in result_list) for result_list in results.values())/len(results) 62 | 63 | def generalization_ratio(results, n_checked): 64 | return 0 if sum(any(result.hit and result.n_checked <= n_checked for result in result_list) for result_list in results.values()) <= 0 else sum(any(result.g_hit and result.n_checked <= n_checked for result in result_list) for result_list in results.values())/sum(any(result.hit and result.n_checked <= n_checked for result in result_list) for result_list in results.values()) 65 | 66 | 67 | def rload(resultsfile): 68 | with open(resultsfile, 'rb') as savefile: 69 | results = dill.load(savefile) 70 | return results 71 | 72 | 73 | def plot_result(results=None, baseresults=None, resultsfile=None, basefile='results/prelim_results_dc_baseline__test50_1536770416.p', plot_time=True, model_path=None, robustfill=False, rnn=False): 74 | assert bool(results) != bool(resultsfile) # xor of them 75 | if not results: 76 | with open(resultsfile, 'rb') as savefile: 77 | results = dill.load(savefile) 78 | if not baseresults: 79 | try: 80 | with open(basefile, 'rb') as savefile: 81 | baseresults = dill.load(savefile) 82 | except FileNotFoundError: 83 | with open('../../'+basefile, 'rb') as savefile: # in case it is not in this directory 84 | baseresults = dill.load(savefile) 85 | 86 | filename = model_path + str(time.time()) if model_path else str(time.time()) 87 
| baseline=basefile 88 | x_axis = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 400, 600, 800, 900, 1000, 2000, 4000] # TODO 89 | y_axis = [percent_solved_n_checked(results, x) for x in x_axis] 90 | base_y = [percent_solved_n_checked(baseresults, x) for x in x_axis] 91 | x_time = [0.1, 1, 1.5, 2, 2.5, 3, 4, 5, 10, 15, 20, 25, 30, 45, 60, 90, 120, 180, 200, 240] 92 | our_time = [percent_solved_time(results, x) for x in x_time] 93 | base_time = [percent_solved_time(baseresults, x) for x in x_time] 94 | print(x_axis) 95 | print(y_axis) 96 | fig, ax = plt.subplots() 97 | plt.plot(x_axis, y_axis, label='flexible neural sketch (ours)', linewidth=4.0, marker="o") 98 | if rnn: plt.plot(x_axis, base_y, label='rnn baseline', linewidth=4.0, marker="o") 99 | else: plt.plot(x_axis, base_y, label='deepcoder baseline', linewidth=4.0, marker="o") 100 | 101 | if robustfill: 102 | ax.set(title='Robustfill problems solved by neural sketch system', ylabel='% of problems solved', xlabel='Number of candidates evaluated per problem') 103 | else: 104 | ax.set(title='Deepcoder problems solved by neural sketch system', ylabel='% of problems solved', xlabel='Number of candidates evaluated per problem') 105 | ax.legend(loc='best') 106 | savefile='plots/' + baseline.split('/')[-1] + filename.split('/')[-1] + '.eps' 107 | plt.savefig(savefile) 108 | if plot_time: 109 | fig, ax = plt.subplots() 110 | plt.plot(x_time, our_time, label='flexible neural sketch (ours)', linewidth=4.0, marker="o") 111 | if rnn: plt.plot(x_time, base_time, label='rnn baseline', linewidth=4.0, marker="o") 112 | else: plt.plot(x_time, base_time, label='deepcoder baseline', linewidth=4.0, marker="o") 113 | 114 | if robustfill: 115 | ax.set(title='Robustfill problems solved by neural sketch system', ylabel='% of problems solved', xlabel='wall clock time (seconds)') 116 | else: 117 | ax.set(title='Deepcoder problems solved by neural sketch system', ylabel='% of problems solved', xlabel='wall clock time (seconds)') 118 | 
ax.legend(loc='best') 119 | savefile='plots/time_' + baseline.split('/')[-1] + filename.split('/')[-1] + '.eps' 120 | #savefile = 'plots/time_prelim.png' 121 | plt.savefig(savefile) 122 | 123 | 124 | def plot_result_list(file_list, legend_list, filename, robustfill=False, plot_time=True, generalization=False, double=False, title='NA', max_budget=500000): 125 | result_list = [] 126 | if double: 127 | print("WARNING: this is a major hack, don't try at home") 128 | l = int(len(file_list)/2) 129 | print("num:", l) 130 | fl = zip(file_list[:l], file_list[l:]) 131 | for f1, f2 in fl: 132 | with open(f1, 'rb') as savefile: 133 | r1 = dill.load(savefile) 134 | print("type:") 135 | print(type(r1)) 136 | with open(f2, 'rb') as savefile: 137 | r2 = dill.load(savefile) 138 | print("combining", f1, "and", f2, "...") 139 | result_list.append({**r1, **r2}) 140 | else: 141 | for file in file_list: 142 | if file in mem_problem_list: 143 | result_list.append(file) 144 | else: 145 | with open(file, 'rb') as savefile: 146 | result_list.append(dill.load(savefile)) 147 | if max_budget > 40000: 148 | x_axis= [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 400, 600, 800, 900, 1000, 2000, 4000, 5000, 8000, 10000, 12000, 15000, 20000, 40000, 50000, 60000, 80000] + list(range(100000, max_budget, 100000)) # TODO 149 | else: 150 | x_axis= [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 400, 600, 800, 900, 1000, 2000, 4000, 5000, 8000, 10000, 12000, 15000, 20000, 40000] # TODO 151 | x_time = [0.1, 1, 1.5, 2, 2.5, 3, 4, 5, 10, 15, 20, 25, 30, 45, 60, 90, 120, 140, 150, 160, 180, 200, 240, 300] #, 400]+ list(range(500,2600,100)) 152 | 153 | fig, ax = plt.subplots() 154 | #fig.set_size_inches(10,6) 155 | 156 | for result, legend in zip(result_list, legend_list): 157 | 158 | if 'RNN' in legend or 'RobustFill' in legend: 159 | y = percent_solved_n_checked(result, x_axis[-1]) 160 | y_axis = [y for _ in x_axis] 161 | if '50' in legend: 162 | plt.semilogx(x_axis, y_axis, label=legend, 
linewidth=2.0, linestyle='--', c='C6') 163 | else: 164 | plt.semilogx(x_axis, y_axis, label=legend, linewidth=2.0, linestyle='--', c='C4') 165 | else: 166 | if generalization: 167 | y_axis = [generalization_ratio(result, x) for x in x_axis] 168 | else: 169 | if result in mem_problem_list: 170 | y_axis = [hack_percent_solved_n_checked(result, x) for x in x_axis] #TODO!!!! 171 | else: 172 | y_axis = [percent_solved_n_checked(result, x) for x in x_axis] 173 | if "Deepcoder" in legend: plt.plot(x_axis, y_axis, label=legend, linewidth=2.0, linestyle='-.', c='C3') 174 | else: plt.semilogx(x_axis, y_axis, label=legend, linewidth=2.0) 175 | 176 | 177 | if robustfill: 178 | if generalization: ax.set(title='String transformation generalization_ratio', ylabel='num correct hits/num total hits', xlabel='Number of candidates evaluated per problem') 179 | else: 180 | ax.set(title='String editing programs' if title=='NA' else title, ylabel='% of problems solved', xlabel='Number of candidates evaluated per problem') 181 | #ax.set_aspect(0.5) 182 | else: 183 | ax.set(title='length 3 test programs' if title=='NA' else title, ylabel='% of problems solved', xlabel='Number of candidates evaluated per problem') 184 | #ax.set(title='Trained on length 4 programs, tested on length 5 programs', ylabel='% of problems solved', xlabel='Number of candidates evaluated per problem') 185 | 186 | ax.legend(loc='best') 187 | savefile='plots/' +filename+ '.eps' 188 | plt.savefig(savefile) 189 | 190 | 191 | if plot_time: 192 | 193 | fig, ax = plt.subplots() 194 | 195 | for result, legend in zip(result_list, legend_list): 196 | our_time = [percent_solved_time(result, x) for x in x_time] 197 | if 'RNN' in legend or 'RobustFill' in legend: 198 | if '50' in legend: 199 | plt.plot(x_time, our_time, label=legend, linewidth=2.0, linestyle='--', c='C6') 200 | else: 201 | plt.plot(x_time, our_time, label=legend, linewidth=2.0, linestyle='--', c='C4') 202 | elif "Deepcoder" in legend: 203 | plt.plot(x_time, 
our_time, label=legend, linewidth=2.0, linestyle='-.', c='C3') 204 | else: 205 | plt.plot(x_time, our_time, label=legend, linewidth=2.0) 206 | 207 | if robustfill: 208 | ax.set(title='String editing problems - evaluation time', ylabel='% of problems solved', xlabel='wall clock time (seconds)') 209 | else: 210 | ax.set(title='List processing problems solved by neural sketch system', ylabel='% of problems solved', xlabel='wall clock time (seconds)') 211 | ax.legend(loc='best') 212 | 213 | savefile='plots/time_' + filename + '.eps' 214 | #savefile = 'plots/time_prelim.png' 215 | plt.savefig(savefile) 216 | 217 | 218 | if __name__=='__main__': 219 | parser = argparse.ArgumentParser() 220 | parser.add_argument('--basefile', type=str, default='results/prelim_results_dc_baseline__test50_1536770416.p') 221 | parser.add_argument('--resultsfile', type=str, default='results/prelim_results_wdcModel__test50_1536803560.p') 222 | parser.add_argument('--rb', action='store_true') 223 | parser.add_argument('--rnnbase', action='store_true') 224 | args = parser.parse_args() 225 | # filename = 'results/prelim_results_dc_baseline__test20_1536185296.p' 226 | # #filename = 'results/prelim_results_wdcModel__test20_1536184791.p' no input, i borked it 227 | # filename = 'results/prelim_results_wdcModel__test20_1536102640.p' 228 | # filename = 'results/prelim_results_dc_baseline__test200_1536244251.p' 229 | # filename = 'results/prelim_results_wdcModel__test50_1536249076.p' 230 | # filename = 'results/prelim_results_dc_baseline__test50_1536246985.p' 231 | #baseline = 'results/prelim_results_dc_baseline__test50_1536260203.p' 232 | #filename = 'results/prelim_results_wdcModel__test50_1536249614.p' 233 | #baseline = 'results/prelim_results_dc_baseline__test50_1536770416.p' 234 | #filename = 'results/prelim_results_wdcModel__test50_1536770193.p' # supervised trained 235 | #filename = 'results/prelim_results_wdcModel__test50_1536770602.p' # rl trained old version 236 | #filename = 
#manual_plot.py

from plot.manipulate_results import *

# BUG FIX: the original module had bare assignments ("filename =",
# "file_list =", "dc =", "rnn =") that made the whole file a SyntaxError,
# and load_result() opened the undefined name `file` instead of its
# parameter and was invoked with no argument. Filled in explicit
# placeholders and replaced the three copy-pasted load/plot stanzas
# (each missing a `legend` binding) with a loop over the file list.
filename = 'manual_plot'   # output name under plots/
file_list = []             # results pickles to plot, e.g. ['results/foo.p']
legend_list = []           # one legend label per entry of file_list


def load_result(resultsfile):
    """Load and return a dill-pickled results dict from `resultsfile`."""
    with open(resultsfile, 'rb') as savefile:
        return dill.load(savefile)


x_axis = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 400, 600, 800, 900,
          1000, 2000, 4000, 5000, 8000, 10000, 12000, 15000, 20000, 40000]  # TODO
x_time = [0.1, 1, 1.5, 2, 2.5, 3, 4, 5, 10, 15, 20, 25, 30, 45, 60, 90, 120,
          140, 150, 160, 180, 200, 240, 300, 400, 500, 600, 800, 1200]

fig, ax = plt.subplots()

for resultsfile, legend in zip(file_list, legend_list):
    result = load_result(resultsfile)
    y_axis = [percent_solved_n_checked(result, x) for x in x_axis]
    plt.plot(x_axis, y_axis, label=legend, linewidth=4.0)

ax.set(title='String transformation generalization_ratio',
       ylabel='num correct hits/num total hits',
       xlabel='Number of candidates evaluated per problem')
#ax.set(title='List processing problems solved by neural sketch system', ylabel='% of problems solved', xlabel='Number of candidates evaluated per problem')
ax.legend(loc='best')
savefile = 'plots/' + filename + '.eps'
plt.savefig(savefile)
88, 53 | 12910, 54 | 1234602, 55 | 2598, 56 | 17689, 57 | 167744, 58 | 86728, 59 | 10419, 60 | 1905392, 61 | 70292, 62 | 1886154, 63 | 1834284, 64 | 15165, 65 | 3, 66 | 86520, 67 | 51462, 68 | 445236, 69 | 445236, 70 | 262896, 71 | 30792, 72 | 637748, 73 | 459280, 74 | 1, 75 | 531075, 76 | 31822, 77 | 1419164, 78 | 5423, 79 | 42411, 80 | 733196, 81 | 66150, 82 | 697590, 83 | 2386, 84 | 82505, 85 | 8024, 86 | 694229, 87 | 559142, 88 | 537222, 89 | 9513, 90 | 37, 91 | 193725, 92 | 1839235, 93 | 2339, 94 | 27338, 95 | 1, 96 | 196727, 97 | 868, 98 | 66016, 99 | 1706659, 100 | 370707, 101 | 75508, 102 | 28217, 103 | 56909, 104 | 10046, 105 | 1, 106 | 383314, 107 | 86560, 108 | 1088, 109 | 294219, 110 | 1, 111 | 24859]) 112 | 113 | 114 | d['results/dc_T45_0.25_2M_pypy.p'] = (96, [133, 115 | 163417, 116 | 311212, 117 | 657815, 118 | 17, 119 | 1468774, 120 | 119987, 121 | 220, 122 | 1, 123 | 871, 124 | 57035, 125 | 3466, 126 | 165497, 127 | 1234700, 128 | 3787, 129 | 30, 130 | 592317, 131 | 1, 132 | 28217, 133 | 192901, 134 | 473666, 135 | 1699662, 136 | 1362055, 137 | 389, 138 | 174519, 139 | 1, 140 | 109525, 141 | 41394, 142 | 84, 143 | 9460, 144 | 349045, 145 | 1639884, 146 | 9, 147 | 11291, 148 | 1793633, 149 | 411375, 150 | 31879, 151 | 1436020, 152 | 930180, 153 | 26136, 154 | 337179]) 155 | 156 | #def percent_solved_n_checked(results, n_checked): 157 | # return sum(any(result.hit and result.n_checked <= n_checked for result in result_list) for result_list in results.values())/len(results) 158 | # #speed this up!!! 
159 | 160 | def hack_percent_solved_n_checked(results, n_checked): 161 | n, x = d[results] 162 | return sum(num_checked <= n_checked for num_checked in x)/n 163 | 164 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/tests/__init__.py -------------------------------------------------------------------------------- /tests/demo.py: -------------------------------------------------------------------------------- 1 | #demo.py 2 | 3 | 4 | import argparse 5 | import torch 6 | from torch import nn, optim 7 | 8 | from pinn import RobustFill 9 | import pregex as pre 10 | #from vhe import VHE, DataLoader, Factors, Result, RegexPrior 11 | import random 12 | 13 | from sketch_project import Hole 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--pretrained', action='store_true') 18 | parser.add_argument('--pretrain_holes', action='store_true') 19 | args = parser.parse_args() 20 | 21 | 22 | print("loading model") 23 | 24 | if args.pretrained: 25 | print("loading pretrained_model") 26 | model=torch.load("./sketch_model_pretrained.p") 27 | elif args.pretrain_holes: 28 | print("loading pretrain_holes") 29 | model=torch.load("./sketch_model_pretrain_holes.p") 30 | else: 31 | print("loading model with holes") 32 | model=torch.load("./sketch_model_holes.p") 33 | 34 | 35 | for i in range(999): 36 | print("-"*20, "\n") 37 | if i==0: 38 | examples = ["bar", "car", "dar"] 39 | print("Using examples:") 40 | for e in examples: print(e) 41 | print() 42 | else: 43 | print("Please enter examples (one per line):") 44 | examples = [] 45 | nextInput = True 46 | while nextInput: 47 | s = input() 48 | if s=="": 49 | nextInput=False 50 | else: 51 | examples.append(s) 52 | 53 | print("calculating... 
") 54 | samples, scores = model.sampleAndScore([examples], nRepeats=2) 55 | print(samples) 56 | print(scores) 57 | print(len(scores), len(samples)) 58 | index = scores.index(max(scores)) 59 | #print(samples[index]) 60 | try: sample = pre.create(list(samples[index])) 61 | except: sample = samples[index] 62 | #sample = samples[index] 63 | print("best example by nn score:", sample, ", nn score:", max(scores)) 64 | 65 | 66 | pregexes = [] 67 | pscores = [] 68 | for samp in samples: 69 | try: 70 | reg = pre.create(list(samp)) 71 | pregexes.append(reg) 72 | pscores.append(sum(reg.match(ex) for ex in examples )) 73 | except: 74 | pregexes.append(samp) 75 | pscores.append(float('-inf')) 76 | 77 | index = pscores.index(max(pscores)) 78 | preg = pregexes[index] 79 | 80 | print("best example by pregex score:", preg, ", preg score:", max(pscores)) 81 | -------------------------------------------------------------------------------- /tests/test_beam.py: -------------------------------------------------------------------------------- 1 | #test_beam.py 2 | from eval.evaluate_deepcoder import * 3 | 4 | path = 'experiments/deepcoder_timeout_0.5_1537326548865/deepcoder_holes.p' 5 | 6 | model = torch.load(path) 7 | 8 | with open(args.precomputed_data_file, 'rb') as datafile: 9 | dataset = pickle.load(datafile) 10 | 11 | print("loaded model and dataset") 12 | 13 | 14 | datum = dataset[0] 15 | 16 | print("IO:") 17 | print(datum.IO) 18 | 19 | tokenized = tokenize_for_robustfill([datum.IO]) 20 | #samples, _scores, _ = model.beam_decode(tokenized, nRepeats=nRepeats) 21 | 22 | # samples, _scores, _ = model.sampleAndScore(tokenized, nRepeats=10) 23 | # print(samples, _scores) 24 | # assert False 25 | targets, scores = model.beam_decode(tokenized, beam_size=10, vocab_filter=None) 26 | 27 | 28 | 29 | print(targets, scores) -------------------------------------------------------------------------------- /tests/test_callCompiled.py: 
-------------------------------------------------------------------------------- 1 | #test callCompiled 2 | import sys 3 | import os 4 | sys.path.append(os.path.abspath('./')) 5 | sys.path.append(os.path.abspath('./ec')) 6 | 7 | from utilities import callCompiled, eprint 8 | 9 | from fun import f 10 | 11 | x = 6 12 | ans = callCompiled(f, x) 13 | 14 | eprint(ans) -------------------------------------------------------------------------------- /tests/test_parse.py: -------------------------------------------------------------------------------- 1 | #Sketch project 2 | 3 | 4 | #from builtins import super 5 | #import pickle 6 | #import string 7 | #import argparse 8 | #import random 9 | 10 | #import torch 11 | #from torch import nn, optim 12 | 13 | #from pinn import RobustFill 14 | #from pinn import SyntaxCheckingRobustFill #TODO 15 | #import random 16 | #import math 17 | 18 | #from collections import OrderedDict 19 | #from util import enumerate_reg, Hole 20 | 21 | import sys 22 | import os 23 | sys.path.append(os.path.abspath('./')) 24 | sys.path.append(os.path.abspath('./ec')) 25 | 26 | from grammar import Grammar 27 | from deepcoderPrimitives import deepcoderProductions, flatten_program 28 | 29 | #from program import Application, Hole 30 | 31 | #import math 32 | from type import Context, arrow, tint, tlist, tbool, UnificationFailure 33 | from program import prettyProgram 34 | 35 | from train.main_supervised_deepcoder import parseprogram, make_holey_deepcoder 36 | 37 | #g = Grammar.uniform(deepcoderPrimitives()) 38 | g = Grammar.fromProductions(deepcoderProductions(), logVariable=.9) #TODO - find correct grammar weights 39 | request = arrow(tlist(tint), tint, tint) 40 | 41 | p = g.sample(request) 42 | 43 | sketch = make_holey_deepcoder(p, 10, g, request) 44 | 45 | print("request:", request) 46 | print("program:") 47 | print(prettyProgram(p)) 48 | print("flattened_program:") 49 | flat = flatten_program(p) 50 | print(flat) 51 | 52 | prog = parseprogram(flat, request) 
53 | print("recovered program:") 54 | print(prettyProgram(prog)) 55 | print("-----") 56 | print("sketch:") 57 | print(sketch) 58 | print("flattend sketch:") 59 | flatsketch = flatten_program(sketch) 60 | print(flatsketch) 61 | print("recovered sketch") 62 | recovered_sketch = parseprogram(flatsketch, request) 63 | print(recovered_sketch) 64 | 65 | 66 | for i in range(1000): 67 | p = g.sample(request) 68 | 69 | sketch = make_holey_deepcoder(p, 10, g, request) 70 | 71 | 72 | flat = flatten_program(p) 73 | 74 | 75 | prog = parseprogram(flat, request) 76 | 77 | flatsketch = flatten_program(sketch) 78 | 79 | recovered_sketch = parseprogram(flatsketch, request) 80 | if not flatsketch == ['']: 81 | assert recovered_sketch == sketch 82 | -------------------------------------------------------------------------------- /tests/test_sample_deepcoderIO.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.abspath('./')) 4 | sys.path.append(os.path.abspath('./ec')) 5 | 6 | from grammar import Grammar 7 | from deepcoderPrimitives import deepcoderProductions, flatten_program 8 | 9 | #from program import Application, Hole 10 | 11 | #import math 12 | from type import Context, arrow, tint, tlist, tbool, UnificationFailure 13 | from program import prettyProgram 14 | 15 | from train.main_supervised_deepcoder import parseprogram, make_holey_deepcoder, sampleIO, getInstance, grammar 16 | import time 17 | 18 | max_length=30 19 | 20 | inst = getInstance(5, verbose=True) 21 | print("program:") 22 | print(inst['p']) 23 | print("IO:") 24 | print(inst['IO']) 25 | 26 | t = time.time() 27 | p = make_holey_deepcoder(inst['p'], 5, grammar, inst['tp']) 28 | print(time.time() - t) 29 | 30 | t = time.time() 31 | sketch = flatten_program(p) 32 | print(time.time() - t) -------------------------------------------------------------------------------- /train/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/train/__init__.py -------------------------------------------------------------------------------- /train/algolisp_train_dc_model.py: -------------------------------------------------------------------------------- 1 | #Training deepcoderModel 2 | from builtins import super 3 | import pickle 4 | import string 5 | import argparse 6 | import random 7 | 8 | import sys 9 | import os 10 | sys.path.append(os.path.abspath('./')) 11 | sys.path.append(os.path.abspath('./ec')) 12 | import torch 13 | from torch import nn, optim 14 | 15 | from pinn import RobustFill 16 | from pinn import SyntaxCheckingRobustFill #TODO 17 | import random 18 | import math 19 | import time 20 | 21 | from collections import OrderedDict 22 | from itertools import chain 23 | 24 | 25 | from grammar import Grammar, NoCandidates 26 | from algolispPrimitives import algolispProductions, primitive_lookup, algolisp_input_vocab, algolisp_IO_vocab, digit_enc_vocab 27 | from program import Application, Hole, Primitive, Index, Abstraction, ParseFailure 28 | from type import Context, arrow, tint, tlist, tbool, UnificationFailure 29 | #from util.deepcoder_util import parseprogram, grammar 30 | from data_src.makeAlgolispData import batchloader, basegrammar 31 | 32 | from models.deepcoderModel import SketchFeatureExtractor, HoleSpecificFeatureExtractor, ImprovedRecognitionModel, AlgolispIOFeatureExtractor 33 | 34 | # from deepcoderModel import 35 | 36 | 37 | def newDcModel(cuda=True, IO2seq=False, digit_enc=False): 38 | if IO2seq: 39 | input_vocab = algolisp_IO_vocab() if not digit_enc else digit_enc_vocab()# TODO 40 | algolisp_vocab = list(primitive_lookup.keys()) + ['(',')', ''] 41 | specExtractor = AlgolispIOFeatureExtractor(input_vocab, hidden=128, use_cuda=cuda, digit_enc=digit_enc) # Is this okay? 
max length 42 | sketchExtractor = SketchFeatureExtractor(algolisp_vocab, hidden=128, use_cuda=cuda) 43 | extractor = HoleSpecificFeatureExtractor(specExtractor, sketchExtractor, hidden=128, use_cuda=cuda) 44 | dcModel = ImprovedRecognitionModel(extractor, basegrammar, hidden=[128], cuda=cuda, contextual=False) 45 | else: 46 | input_vocab = algolisp_input_vocab# TODO 47 | algolisp_vocab = list(primitive_lookup.keys()) + ['(',')', ''] 48 | specExtractor = SketchFeatureExtractor(input_vocab, hidden=128, use_cuda=cuda) # Is this okay? max length 49 | sketchExtractor = SketchFeatureExtractor(algolisp_vocab, hidden=128, use_cuda=cuda) 50 | extractor = HoleSpecificFeatureExtractor(specExtractor, sketchExtractor, hidden=128, use_cuda=cuda) 51 | dcModel = ImprovedRecognitionModel(extractor, basegrammar, hidden=[128], cuda=cuda, contextual=False) 52 | 53 | 54 | return(dcModel) 55 | 56 | if __name__ == "__main__": 57 | 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('--debug', action='store_true') 60 | parser.add_argument('--nosave', action='store_true') 61 | parser.add_argument('-k', type=int, default=40) #TODO 62 | parser.add_argument('--max_epochs', type=int, default=10) 63 | parser.add_argument('--save_model_path', type=str, default='./saved_models/algolisp_dc_model.p') 64 | parser.add_argument('--load_model_path', type=str, default='./saved_models/algolisp_dc_model.p') 65 | parser.add_argument('--new', action='store_true') 66 | parser.add_argument('--train_data', type=str, default='train') 67 | parser.add_argument('--use_dc_grammar', action='store_true') 68 | parser.add_argument('--improved_dc_model', action='store_true', default=True) 69 | parser.add_argument('--inv_temp', type=float, default=0.01) #idk what the deal with this is ... 
70 | parser.add_argument('--use_timeout', action='store_true', default=True) 71 | parser.add_argument('--filter_depth', nargs='+', type=int, default=None) 72 | parser.add_argument('--nHoles', type=int, default=1) 73 | parser.add_argument('--limit_data', type=float, default=False) 74 | parser.add_argument('--IO2seq', action='store_true') 75 | parser.add_argument('--seed', type=int, default=42) 76 | parser.add_argument('--use_dataset_len', type=int, default=False) 77 | 78 | parser.add_argument('--exclude_odd', action='store_true') 79 | parser.add_argument('--exclude_even', action='store_true') 80 | parser.add_argument('--exclude_geq', action='store_true') 81 | parser.add_argument('--exclude_gt', action='store_true') 82 | 83 | parser.add_argument('--digit_enc', action='store_true') 84 | parser.add_argument('--limit_IO_size', type=int, default=None) 85 | args = parser.parse_args() 86 | 87 | assert not (args.exclude_even and args.exclude_odd) 88 | 89 | #xor all the options classes: 90 | if any([args.exclude_even, args.exclude_odd, args.exclude_geq]): 91 | assert (args.exclude_even or args.exclude_odd) != args.exclude_geq 92 | 93 | if args.exclude_odd: 94 | exclude = [ ["lambda1", ["==", ["%", "arg1", "2"], "1"]] ] 95 | elif args.exclude_even: 96 | exclude = [ ["lambda1", ["==", ["%", "arg1", "2"], "0"]] ] 97 | elif args.exclude_geq: 98 | exclude = [">="] 99 | elif args.exclude_gt: 100 | exclude = [">"] 101 | else: 102 | exclude = None 103 | 104 | batchsize = 1 105 | max_epochs = args.max_epochs 106 | use_dc_grammar = args.use_dc_grammar 107 | improved_dc_model = args.improved_dc_model 108 | top_k_sketches = args.k 109 | inv_temp = args.inv_temp 110 | use_timeout = args.use_timeout 111 | 112 | if not improved_dc_model: assert False, "unimplemented" 113 | 114 | 115 | train_datas = args.train_data 116 | 117 | print("Loading model", flush=True) 118 | try: 119 | if args.new: 120 | raise FileNotFoundError 121 | dcModel=newDcModel(IO2seq=args.IO2seq, digit_enc=args.digit_enc) 
122 | dcModel.load_state_dict(torch.load(args.load_model_path)) 123 | print('found saved dcModel, loading ...') 124 | except FileNotFoundError: 125 | print("no saved dcModel, creating new one") 126 | 127 | 128 | #extractor = LearnedFeatureExtractor(deepcoder_io_vocab, hidden=128) 129 | #dcModel = DeepcoderRecognitionModel(extractor, grammar, hidden=[128], cuda=True) 130 | 131 | dcModel = newDcModel(IO2seq=args.IO2seq, digit_enc=args.digit_enc) 132 | 133 | print("number of parameters is", sum(p.numel() for p in dcModel.parameters() if p.requires_grad)) 134 | 135 | ######## TRAINING ######## 136 | #make this a function... 137 | t2 = time.time() 138 | print("training", flush=True) 139 | if not hasattr(dcModel, 'iteration'): 140 | dcModel.iteration = 0 141 | dcModel.scores = [] 142 | if not hasattr(dcModel, 'epochs'): 143 | dcModel.epochs = 0 144 | 145 | for j in range(dcModel.epochs, max_epochs): #TODO 146 | print(f"\tepoch {j}:") 147 | 148 | for i, datum in enumerate(batchloader(train_datas, 149 | batchsize=1, 150 | compute_sketches=True, 151 | dc_model=dcModel if use_dc_grammar and (dcModel.epochs > 1) else None, # This means dcModel updated every epoch, but not first two 152 | improved_dc_model=improved_dc_model, 153 | top_k_sketches=args.k, 154 | inv_temp=args.inv_temp, 155 | reward_fn=None, 156 | sample_fn=None, 157 | nHoles=args.nHoles, 158 | use_timeout=args.use_timeout, 159 | filter_depth=args.filter_depth, 160 | limit_data=args.limit_data, 161 | seed=args.seed, 162 | use_dataset_len=args.use_dataset_len, 163 | limit_IO_size=args.limit_IO_size, 164 | exclude=exclude)): #TODO 165 | 166 | spec = datum.spec if not args.IO2seq else datum.IO 167 | t = time.time() 168 | t3 = t-t2 169 | #score = dcModel.optimizer_step(datum.IO, datum.p, datum.tp) #TODO make sure inputs are correctly formatted 170 | score = dcModel.optimizer_step((spec, datum.sketchseq), datum.p, datum.sketch, datum.tp) 171 | t2 = time.time() 172 | 173 | dcModel.scores.append(score) 174 | 
dcModel.iteration += 1 175 | if i%500==0 and not i==0: 176 | print("iteration", i, "average score:", sum(dcModel.scores[-500:])/500, flush=True) 177 | print(f"network time: {t2-t}, other time: {t3}") 178 | #if i%5000==0: 179 | if not args.nosave: 180 | torch.save(dcModel.state_dict(), args.save_model_path+f'_{str(j)}_iter_{str(i)}.p') 181 | torch.save(dcModel.state_dict(), args.save_model_path) 182 | #to prevent overwriting model: 183 | dcModel.epochs += 1 184 | if not args.nosave: 185 | torch.save(dcModel.state_dict(), args.save_model_path+'_{}.p'.format(str(j))) 186 | torch.save(dcModel.state_dict(), args.save_model_path) 187 | 188 | 189 | ######## End training ######## 190 | -------------------------------------------------------------------------------- /train/deepcoder_train_dc_model.py: -------------------------------------------------------------------------------- 1 | #Training deepcoderModel 2 | from builtins import super 3 | import pickle 4 | import string 5 | import argparse 6 | import random 7 | 8 | import sys 9 | import os 10 | sys.path.append(os.path.abspath('./')) 11 | sys.path.append(os.path.abspath('./ec')) 12 | import torch 13 | from torch import nn, optim 14 | 15 | from pinn import RobustFill 16 | from pinn import SyntaxCheckingRobustFill #TODO 17 | import random 18 | import math 19 | import time 20 | 21 | from collections import OrderedDict 22 | from itertools import chain 23 | 24 | from grammar import Grammar, NoCandidates 25 | from deepcoderPrimitives import deepcoderProductions, flatten_program 26 | from program import Application, Hole, Primitive, Index, Abstraction, ParseFailure 27 | from type import Context, arrow, tint, tlist, tbool, UnificationFailure 28 | from util.deepcoder_util import parseprogram, basegrammar, deepcoder_vocab 29 | from data_src.makeDeepcoderData import batchloader 30 | from models.deepcoderModel import LearnedFeatureExtractor, DeepcoderRecognitionModel, SketchFeatureExtractor, HoleSpecificFeatureExtractor, 
ImprovedRecognitionModel 31 | 32 | 33 | # from deepcoderModel import 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--debug', action='store_true') 36 | parser.add_argument('--nosave', action='store_true') 37 | parser.add_argument('-k', type=int, default=50) #TODO 38 | parser.add_argument('--Vrange', type=int, default=128) 39 | parser.add_argument('--max_epochs', type=int, default=50) 40 | parser.add_argument('--max_list_length', type=int, default=10) 41 | parser.add_argument('--save_model_path', type=str, default='./saved_models/list_dc_model.p') 42 | parser.add_argument('--load_model_path', type=str, default='./saved_models/list_dc_model.p') 43 | parser.add_argument('--new', action='store_true') 44 | parser.add_argument('--max_n_inputs', type=int, default=2) 45 | parser.add_argument('--improved_dc_model', action='store_true') 46 | parser.add_argument('--train_data', nargs='*', 47 | default=['data/DeepCoder_data/T2_A2_V512_L10_train.txt', 'data/DeepCoder_data/T3_A2_V512_L10_train_perm.txt']) 48 | parser.add_argument('--cuda', action='store_true', default=True) 49 | parser.add_argument('--inv_temp', type=float, default=0.1) #idk what the deal with this is ... 
50 | parser.add_argument('--use_timeout', action='store_true', default=True) 51 | parser.add_argument('--nHoles', type=int, default=1) 52 | parser.add_argument('--use_dc_grammar', action='store_true') 53 | parser.add_argument('--max_iterations', type=int, default=100000000) 54 | args = parser.parse_args() 55 | 56 | max_length = 30 57 | batchsize = 1 58 | Vrange = args.Vrange 59 | max_epochs = args.max_epochs 60 | max_list_length = args.max_list_length 61 | cuda=args.cuda 62 | use_dc_grammar=args.use_dc_grammar 63 | 64 | deepcoder_io_vocab = list(range(-Vrange, Vrange+1)) + ["LIST_START", "LIST_END"] 65 | 66 | if __name__ == "__main__": 67 | 68 | train_datas = args.train_data 69 | 70 | vocab = deepcoder_vocab(basegrammar, n_inputs=args.max_n_inputs) 71 | 72 | print("Loading model", flush=True) 73 | try: 74 | if args.new: 75 | raise FileNotFoundError 76 | dcModel=torch.load(args.load_model_path) 77 | print('found saved dcModel, loading ...') 78 | except FileNotFoundError: 79 | print("no saved dcModel, creating new one") 80 | if args.improved_dc_model: 81 | print("creating new improved dc model") 82 | specExtractor = LearnedFeatureExtractor(deepcoder_io_vocab, hidden=128, use_cuda=cuda) 83 | sketchExtractor = SketchFeatureExtractor(vocab, hidden=128, use_cuda=cuda) 84 | extractor = HoleSpecificFeatureExtractor(specExtractor, sketchExtractor, hidden=128, use_cuda=cuda) 85 | dcModel = ImprovedRecognitionModel(extractor, basegrammar, hidden=[128], cuda=cuda, contextual=False) 86 | else: 87 | extractor = LearnedFeatureExtractor(deepcoder_io_vocab, hidden=128) 88 | dcModel = DeepcoderRecognitionModel(extractor, basegrammar, hidden=[128], cuda=True) 89 | 90 | print("number of parameters is", 91 | sum(p.numel() for p in dcModel.parameters() if p.requires_grad)) 92 | 93 | ######## TRAINING ######## 94 | #make this a function... 
95 | t2 = time.time() 96 | print("training", flush=True) 97 | if not hasattr(dcModel, 'iteration'): 98 | dcModel.iteration = 0 99 | dcModel.scores = [] 100 | if not hasattr(dcModel, 'epochs'): 101 | dcModel.epochs = 0 102 | 103 | for j in range(dcModel.epochs, max_epochs): #TODO 104 | print(f"\tepoch {j}:") 105 | 106 | for i, datum in enumerate( 107 | batchloader(train_datas, 108 | batchsize=batchsize, 109 | N=5, 110 | V=Vrange, 111 | L=max_list_length, 112 | compute_sketches=args.improved_dc_model, 113 | dc_model=dcModel if use_dc_grammar and (dcModel.epochs > 1) else None, # TODO 114 | improved_dc_model=args.improved_dc_model, 115 | top_k_sketches=args.k, 116 | inv_temp=args.inv_temp, 117 | reward_fn=None, 118 | sample_fn=None, 119 | nHoles=args.nHoles, 120 | use_timeout=args.use_timeout, 121 | shuffle=True)): #TODO 122 | 123 | t = time.time() 124 | t3 = t-t2 125 | if args.improved_dc_model: 126 | score = dcModel.optimizer_step((datum.IO, datum.sketchseq), datum.p, datum.sketch, datum.tp) 127 | else: 128 | score = dcModel.optimizer_step(datum.IO, datum.p, datum.tp) #TODO make sure inputs are correctly formatted 129 | 130 | t2 = time.time() 131 | #print(datum.sketch) 132 | dcModel.scores.append(score) 133 | dcModel.iteration += 1 134 | if dcModel.iteration > args.max_iterations: 135 | print('done training') 136 | break 137 | if i%500==0 and not i==0: 138 | print("pretrain iteration", i, "average score:", sum(dcModel.scores[-500:])/500, flush=True) 139 | print(f"network time: {t2-t}, other time: {t3}") 140 | if i%50000==0: 141 | #do some inference 142 | #g = dcModel.infer_grammar(IO) #TODO 143 | if not args.nosave: 144 | torch.save(dcModel, args.save_model_path+f'_{str(j)}_iter_{str(i)}.p') 145 | torch.save(dcModel, args.save_model_path) 146 | 147 | #to prevent overwriting model: 148 | if not args.nosave: 149 | torch.save(dcModel, args.save_model_path+'_{}.p'.format(str(j))) 150 | torch.save(dcModel, args.save_model_path) 151 | 152 | 153 | ######## End training 
######## 154 | -------------------------------------------------------------------------------- /train/main_supervised_deepcoder.py: -------------------------------------------------------------------------------- 1 | #Sketch project 2 | from builtins import super 3 | import pickle 4 | import string 5 | import argparse 6 | import random 7 | import torch 8 | from torch import nn, optim 9 | import random 10 | import math 11 | import time 12 | from collections import OrderedDict 13 | from itertools import chain 14 | import math 15 | 16 | import sys 17 | import os 18 | #sys.path.append(os.path.abspath('./')) 19 | sys.path.append(os.path.abspath('./')) 20 | sys.path.append(os.path.abspath('./ec')) 21 | 22 | from pinn import RobustFill 23 | from pinn import SyntaxCheckingRobustFill #TODO 24 | 25 | from grammar import Grammar, NoCandidates 26 | from program import Application, Hole, Primitive, Index, Abstraction, ParseFailure 27 | from type import Context, arrow, tint, tlist, tbool, UnificationFailure 28 | from deepcoderPrimitives import deepcoderProductions, flatten_program 29 | from utilities import timing 30 | 31 | from data_src.makeDeepcoderData import batchloader 32 | from util.deepcoder_util import parseprogram, basegrammar, tokenize_for_robustfill, deepcoder_vocab 33 | 34 | 35 | 36 | # import sys 37 | # sys.path.append("/om/user/mnye/ec") 38 | 39 | # from grammar import Grammar, NoCandidates 40 | # from program import Application, Hole, Primitive, Index, Abstraction, ParseFailure 41 | # from type import Context, arrow, tint, tlist, tbool, UnificationFailure 42 | # from deepcoderPrimitives import deepcoderProductions, flatten_program 43 | # from utilities import timing 44 | 45 | # from makeDeepcoderData import batchloader 46 | # import math 47 | # from deepcoder_util import parseprogram, grammar, tokenize_for_robustfill 48 | # from itertools import chain 49 | 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument('--pretrain', action='store_true') 52 | 
parser.add_argument('--debug', action='store_true') 53 | parser.add_argument('--nosave', action='store_true') 54 | #parser.add_argument('--start_with_holes', action='store_true') 55 | parser.add_argument('--variance_reduction', action='store_true') 56 | parser.add_argument('-k', type=int, default=50) #TODO 57 | parser.add_argument('--new', action='store_true') 58 | parser.add_argument('--rnn_max_length', type=int, default=30) 59 | parser.add_argument('--batchsize', type=int, default=50) 60 | parser.add_argument('--Vrange', type=int, default=128) 61 | parser.add_argument('--n_examples', type=int, default=5) 62 | parser.add_argument('--max_list_length', type=int, default=10) 63 | parser.add_argument('--max_n_inputs', type=int, default=2) 64 | parser.add_argument('--max_pretrain_epochs', type=int, default=10) 65 | parser.add_argument('--max_pretrain_iterations', type=int, default=100000) 66 | parser.add_argument('--max_iterations', type=int, default=100000) 67 | parser.add_argument('--max_epochs', type=int, default=10) 68 | parser.add_argument('--train_data', nargs='*', 69 | default=['data/DeepCoder_data/T2_A2_V512_L10_train.txt', 'data/DeepCoder_data/T3_A2_V512_L10_train_perm.txt']) 70 | # save and load files 71 | parser.add_argument('--load_pretrained_model_path', type=str, default="./saved_models/list_pretrained.p") 72 | parser.add_argument('--save_pretrained_model_path', type=str, default="./saved_models/list_pretrained.p") 73 | parser.add_argument('--save_model_path', type=str, default="./saved_models/list_holes.p") 74 | parser.add_argument('--save_freq', type=int, default=400) 75 | parser.add_argument('--print_freq', type=int, default=1) 76 | parser.add_argument('--top_k_sketches', type=int, default=100) 77 | parser.add_argument('--inv_temp', type=float, default=1.0) 78 | parser.add_argument('--use_rl', action='store_true') 79 | parser.add_argument('--imp_weight_trunc', action='store_true') 80 | parser.add_argument('--rl_no_syntax', action='store_true') 81 | 
parser.add_argument('--use_dc_grammar', type=str, default='NA') # 82 | parser.add_argument('--rl_lr', type=float, default=0.001) 83 | parser.add_argument('--reward_fn', type=str, default='original', choices=['original','linear','exp', 'flat']) 84 | parser.add_argument('--sample_fn', type=str, default='original', choices=['original','linear','exp', 'flat']) 85 | parser.add_argument('--r_max', type=int, default=8) 86 | parser.add_argument('--timing', action='store_true') 87 | parser.add_argument('--num_half_lifes', type=float, default=4) 88 | parser.add_argument('--use_timeout', action='store_true', default=True) 89 | parser.add_argument('--improved_dc_model', action='store_true') 90 | parser.add_argument('--nHoles', type=int, default=3) 91 | parser.add_argument('--dcModelcpu', action='store_true') 92 | args = parser.parse_args() 93 | 94 | #assume we want num_half_life half lives to occur by the r_max value ... 95 | 96 | alpha = math.log(2)*args.num_half_lifes/math.exp(args.r_max) 97 | 98 | reward_fn = { 99 | 'original': None, 100 | 'linear': lambda x: max(math.exp(args.r_max) - math.exp(-x), 0)/math.exp(args.r_max), 101 | 'exp': lambda x: math.exp(-alpha*math.exp(-x)), 102 | 'flat': lambda x: 1 if x > -args.r_max else 0 103 | }[args.reward_fn] 104 | sample_fn = { 105 | 'original': None, 106 | 'linear': lambda x: max(math.exp(args.r_max) - math.exp(-x), 0), 107 | 'exp': lambda x: math.exp(-alpha*math.exp(-x)), 108 | 'flat': lambda x: 1 if x > -args.r_max else 0 109 | }[args.sample_fn] 110 | 111 | 112 | batchsize = args.batchsize 113 | Vrange = args.Vrange 114 | train_datas = args.train_data 115 | 116 | if args.improved_dc_model: assert not args.use_dc_grammar == 'NA' 117 | 118 | if args.use_dc_grammar == 'NA': 119 | use_dc_grammar = False 120 | dc_grammar_path = None 121 | else: 122 | use_dc_grammar = True 123 | dc_model_path = args.use_dc_grammar 124 | 125 | vocab = deepcoder_vocab(basegrammar, n_inputs=args.max_n_inputs) 126 | 127 | if __name__ == "__main__": 128 | 
print("Loading model", flush=True) 129 | try: 130 | if args.new: raise FileNotFoundError 131 | else: 132 | model=torch.load(args.load_pretrained_model_path) 133 | print('found saved model, loaded pretrained model (without holes)') 134 | except FileNotFoundError: 135 | print("no saved model, creating new one") 136 | model = SyntaxCheckingRobustFill( 137 | input_vocabularies=[list(range(-Vrange, Vrange+1)) + ["LIST_START", "LIST_END"], 138 | list(range(-Vrange, Vrange+1)) + ["LIST_START", "LIST_END"]], 139 | target_vocabulary=vocab, max_length=args.rnn_max_length, hidden_size=512) # TODO 140 | model.pretrain_iteration = 0 141 | model.pretrain_scores = [] 142 | model.pretrain_epochs = 0 143 | model.iteration = 0 144 | model.hole_scores = [] 145 | model.epochs = 0 146 | 147 | if use_dc_grammar: 148 | print("loading dc model") 149 | dc_model=torch.load(dc_model_path) 150 | if args.dcModelcpu: 151 | dc_model.cpu() 152 | 153 | model.cuda() 154 | print("number of parameters is", sum(p.numel() for p in model.parameters() if p.requires_grad)) 155 | 156 | if args.use_rl: model._get_optimiser(lr=args.rl_lr) 157 | if args.variance_reduction: 158 | #if not hasattr(model, 'variance_red'): 159 | # print("creating variance_red param") 160 | #variance_red = nn.Parameter(torch.Tensor([0], requires_grad=True, device="cuda")) 161 | #variance_red = torch.zeros(1, requires_grad=True, device="cuda") 162 | variance_red = torch.Tensor([.95]).cuda().requires_grad_() 163 | model.opt.add_param_group({"params": variance_red}) 164 | #model._clear_optimiser() 165 | 166 | ####### train with holes ######## 167 | pretraining = args.pretrain and model.pretrain_epochs < args.max_pretrain_epochs 168 | training = model.epochs < args.max_epochs 169 | 170 | t2 = time.time() 171 | while pretraining or training: 172 | j = model.pretrain_epochs if pretraining else model.epochs 173 | if pretraining: print(f"\tpretraining epoch {j}:") 174 | else: print(f"\ttraining epoch {j}:") 175 | path = 
args.save_pretrained_model_path if pretraining else args.save_model_path 176 | 177 | #TODO: fix the batch loader: 178 | for i, batch in enumerate(batchloader(train_datas, 179 | batchsize=batchsize, 180 | N=args.n_examples, 181 | V=Vrange, 182 | L=args.max_list_length, 183 | compute_sketches=not pretraining, 184 | dc_model=dc_model if use_dc_grammar else None, 185 | top_k_sketches=args.top_k_sketches, 186 | inv_temp=args.inv_temp, 187 | reward_fn=reward_fn, 188 | sample_fn=sample_fn, 189 | use_timeout=args.use_timeout, 190 | improved_dc_model=args.improved_dc_model, 191 | nHoles=args.nHoles)): 192 | 193 | IOs = tokenize_for_robustfill(batch.IOs) 194 | if args.timing: t = time.time() 195 | objective, syntax_score = model.optimiser_step(IOs, batch.pseqs if pretraining else batch.sketchseqs) 196 | if args.timing: 197 | print(f"network time: {time.time()-t}, other time: {t-t2}") 198 | t2 = time.time() 199 | if pretraining: 200 | model.pretrain_scores.append(objective) 201 | model.pretrain_iteration += 1 202 | if model.pretrain_iteration >= args.max_pretrain_iterations: break 203 | else: 204 | model.iteration += 1 205 | if model.iteration >= args.max_iterations: break 206 | model.hole_scores.append(objective) 207 | if i%args.print_freq==0: 208 | if args.use_rl: print("reweighted_reward:", reweighted_reward.mean().data.item()) 209 | print("iteration", i, "score:", objective if not args.use_rl else score.mean().data.item() , "syntax_score:", syntax_score if not args.use_rl else syntax_score.data.item(), flush=True) 210 | if i%args.save_freq==0: 211 | if not args.nosave: 212 | torch.save(model, path+f'_{str(j)}_iter_{str(i)}.p') 213 | torch.save(model, path) 214 | if not args.nosave: 215 | torch.save(model, path+'_{}.p'.format(str(j))) 216 | torch.save(model, path) 217 | if pretraining: model.pretrain_epochs += 1 218 | else: model.epochs += 1 219 | if model.pretrain_epochs >= args.max_pretrain_epochs: pretraining = False 220 | if model.epochs >= args.max_epochs: training = 
False 221 | 222 | ####### End train with holes ######## 223 | 224 | # RL formalism w luke 225 | # add temperature parameter - x 226 | # think about RL objective 227 | 228 | # merge pretrain and regular train - x 229 | # use with timing(nn training) - idk 230 | # deal with model attributes - x 231 | -------------------------------------------------------------------------------- /train/main_supervised_robustfill.py: -------------------------------------------------------------------------------- 1 | #Sketch project 2 | from builtins import super 3 | import pickle 4 | import string 5 | import argparse 6 | import random 7 | import torch 8 | from torch import nn, optim 9 | 10 | import sys 11 | import os 12 | sys.path.append(os.path.abspath('./')) 13 | sys.path.append(os.path.abspath('./ec')) 14 | 15 | from pinn import RobustFill 16 | from pinn import SyntaxCheckingRobustFill #TODO 17 | import random 18 | import math 19 | import time 20 | from collections import OrderedDict 21 | #from util import enumerate_reg, Hole 22 | 23 | 24 | from grammar import Grammar, NoCandidates 25 | from program import Application, Hole, Primitive, Index, Abstraction, ParseFailure 26 | from type import Context, arrow, tint, tlist, tbool, UnificationFailure 27 | from RobustFillPrimitives import RobustFillProductions, flatten_program 28 | from utilities import timing 29 | 30 | 31 | from models.deepcoderModel import LearnedFeatureExtractor, DeepcoderRecognitionModel, RobustFillLearnedFeatureExtractor, load_rb_dc_model_from_path 32 | from data_src.makeRobustFillData import batchloader 33 | import math 34 | from util.robustfill_util import tokenize_for_robustfill, robustfill_vocab 35 | from itertools import chain 36 | from string import printable 37 | 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--pretrain', action='store_true') 40 | parser.add_argument('--debug', action='store_true') 41 | parser.add_argument('--nosave', action='store_true') 42 | 
#parser.add_argument('--start_with_holes', action='store_true') 43 | parser.add_argument('--variance_reduction', action='store_true') 44 | parser.add_argument('-k', type=int, default=50) #TODO 45 | parser.add_argument('--new', action='store_true') 46 | parser.add_argument('--rnn_max_length', type=int, default=20) # TODO 47 | parser.add_argument('--batchsize', type=int, default=50) 48 | parser.add_argument('--max_length', type=int, default=25) 49 | parser.add_argument('--n_examples', type=int, default=4) 50 | parser.add_argument('--max_list_length', type=int, default=10) 51 | parser.add_argument('--max_index',type=int, default=4) 52 | # parser.add_argument('--max_pretrain_epochs', type=int, default=4) 53 | # parser.add_argument('--max_epochs', type=int, default=5) 54 | parser.add_argument('--max_pretrain_iteration', type=int, default=400*5*2*4) 55 | parser.add_argument('--max_iteration', type=int, default=400*5*2*4) 56 | parser.add_argument('--train_data',type=str,default='NA') 57 | # save and load files 58 | parser.add_argument('--load_pretrained_model_path', type=str, default="./saved_models/text_pretrained.p") 59 | parser.add_argument('--save_pretrained_model_path', type=str, default="./saved_models/text_pretrained.p") 60 | parser.add_argument('--save_model_path', type=str, default="./saved_models/text_holes.p") 61 | parser.add_argument('--save_freq', type=int, default=200) 62 | parser.add_argument('--print_freq', type=int, default=1) 63 | parser.add_argument('--top_k_sketches', type=int, default=100) 64 | parser.add_argument('--inv_temp', type=float, default=1.0) 65 | parser.add_argument('--use_rl', action='store_true') 66 | #parser.add_argument('--imp_weight_trunc', action='store_true') 67 | #parser.add_argument('--rl_no_syntax', action='store_true') 68 | parser.add_argument('--use_dc_grammar', type=str, default='NA') 69 | parser.add_argument('--rl_lr', type=float, default=0.001) 70 | parser.add_argument('--reward_fn', type=str, default='original', 
choices=['original','linear','exp', 'flat']) 71 | parser.add_argument('--sample_fn', type=str, default='original', choices=['original','linear','exp', 'flat']) 72 | parser.add_argument('--r_max', type=int, default=8) 73 | parser.add_argument('--timing', action='store_true') 74 | parser.add_argument('--num_half_lifes', type=float, default=4) 75 | parser.add_argument('--use_timeout', action='store_true') 76 | 77 | parser.add_argument('--improved_dc_model', action='store_true') 78 | parser.add_argument('--nHoles', type=int, default=3) 79 | parser.add_argument('--input_noise', action='store_true') 80 | 81 | parser.add_argument('--load_trained_model', action='store_true') 82 | parser.add_argument('--load_trained_model_path', type=str, default="./saved_models/algolisp_holes.p") 83 | 84 | args = parser.parse_args() 85 | 86 | #assume we want num_half_life half lives to occur by the r_max value ... 87 | 88 | alpha = math.log(2)*args.num_half_lifes/math.exp(args.r_max) 89 | 90 | reward_fn = { 91 | 'original': None, 92 | 'linear': lambda x: max(math.exp(args.r_max) - math.exp(-x), 0)/math.exp(args.r_max), 93 | 'exp': lambda x: math.exp(-alpha*math.exp(-x)), 94 | 'flat': lambda x: 1 if x > -args.r_max else 0 95 | }[args.reward_fn] 96 | sample_fn = { 97 | 'original': None, 98 | 'linear': lambda x: max(math.exp(args.r_max) - math.exp(-x), 0), 99 | 'exp': lambda x: math.exp(-alpha*math.exp(-x)), 100 | 'flat': lambda x: 1 if x > -args.r_max else 0 101 | }[args.sample_fn] 102 | 103 | 104 | batchsize = args.batchsize 105 | max_length = args.max_length 106 | if args.train_data != 'NA': 107 | train_datas = args.train_data 108 | assert False 109 | 110 | if args.use_dc_grammar == 'NA': 111 | use_dc_grammar = False 112 | dc_grammar_path = None 113 | else: 114 | use_dc_grammar = True 115 | dc_model_path = args.use_dc_grammar 116 | 117 | basegrammar = Grammar.fromProductions(RobustFillProductions(args.max_length, args.max_index)) 118 | 119 | vocab = robustfill_vocab(basegrammar) 120 | 121 
| if __name__ == "__main__": 122 | print("Loading model", flush=True) 123 | try: 124 | if args.new: raise FileNotFoundError 125 | elif args.load_trained_model: 126 | model=torch.load(args.load_trained_model_path) 127 | print("loading saved trained model, continuing training") 128 | else: 129 | model=torch.load(args.load_pretrained_model_path) 130 | print('found saved model, loaded pretrained model (without holes)') 131 | except FileNotFoundError: 132 | print("no saved model, creating new one") 133 | model = SyntaxCheckingRobustFill( 134 | input_vocabularies=[printable[:-4], printable[:-4]], 135 | target_vocabulary=vocab, max_length=args.rnn_max_length, hidden_size=512) # TODO 136 | model.pretrain_iteration = 0 137 | model.pretrain_scores = [] 138 | model.iteration = 0 139 | model.hole_scores = [] 140 | 141 | if use_dc_grammar: 142 | print("loading dc model") 143 | dc_model = load_rb_dc_model_from_path(dc_model_path, args.max_length, args.max_index, args.improved_dc_model, cuda=True) 144 | 145 | model.cuda() 146 | print("number of parameters is", sum(p.numel() for p in model.parameters() if p.requires_grad)) 147 | 148 | if args.use_rl: model._get_optimiser(lr=args.rl_lr) 149 | if args.variance_reduction: 150 | #if not hasattr(model, 'variance_red'): 151 | # print("creating variance_red param") 152 | #variance_red = nn.Parameter(torch.Tensor([0], requires_grad=True, device="cuda")) 153 | #variance_red = torch.zeros(1, requires_grad=True, device="cuda") 154 | variance_red = torch.Tensor([.95]).cuda().requires_grad_() 155 | model.opt.add_param_group({"params": variance_red}) 156 | #model._clear_optimiser() 157 | 158 | ####### train with holes ######## 159 | pretraining = args.pretrain and model.pretrain_iteration < args.max_pretrain_iteration 160 | training = model.iteration < args.max_iteration and not pretraining 161 | 162 | t2 = time.time() 163 | while pretraining or training: 164 | path = args.save_pretrained_model_path if pretraining else args.save_model_path 165 
| iter_remaining = args.max_pretrain_iteration - model.pretrain_iteration if pretraining else args.max_iteration - model.iteration 166 | print("pretraining:", pretraining) 167 | print("iter to train:", iter_remaining) 168 | for i, batch in zip(range(iter_remaining), batchloader(iter_remaining, 169 | basegrammar, 170 | batchsize=batchsize, 171 | N=args.n_examples, 172 | V=max_length, 173 | L=args.max_list_length, 174 | compute_sketches=not pretraining, 175 | dc_model=dc_model if use_dc_grammar else None, 176 | top_k_sketches=args.top_k_sketches, 177 | inv_temp=args.inv_temp, 178 | reward_fn=reward_fn, 179 | sample_fn=sample_fn, 180 | use_timeout=args.use_timeout, 181 | improved_dc_model=args.improved_dc_model, 182 | nHoles=args.nHoles, 183 | input_noise=args.input_noise)): 184 | IOs = tokenize_for_robustfill(batch.IOs) 185 | if args.timing: t = time.time() 186 | if not pretraining and args.use_rl: 187 | #if not hasattr(model, 'opt'): 188 | # model._get_optimiser(lr=args.rl_lr) #todo 189 | # model.opt.add_param_group({"params": variance_red}) 190 | model.opt.zero_grad() 191 | if args.imp_weight_trunc: 192 | print("not finished implementing") 193 | assert False 194 | else: 195 | score, syntax_score = model.score(IOs, batch.sketchseqs, autograd=True) 196 | print("rewards:", batch.rewards.mean()) 197 | print("sketchprobs:", batch.sketchprobs.mean()) 198 | if args.variance_reduction: 199 | if not args.rl_no_syntax: 200 | objective = torch.exp(score.data)/batch.sketchprobs.cuda() * (batch.rewards.cuda() - variance_red.data) * (score + syntax_score) - torch.pow((batch.rewards.cuda() - variance_red),2) 201 | else: 202 | objective = torch.exp(score.data)/batch.sketchprobs.cuda() * (batch.rewards.cuda() - variance_red.data) * score - torch.pow((batch.rewards.cuda() - variance_red),2) 203 | reweighted_reward = torch.exp(score.data)/batch.sketchprobs.cuda() * batch.rewards.cuda() 204 | else: 205 | if not args.rl_no_syntax: 206 | reweighted_reward = 
torch.exp(score.data)/batch.sketchprobs.cuda() * batch.rewards.cuda() 207 | objective = reweighted_reward * (score + syntax_score) 208 | else: 209 | reweighted_reward = torch.exp(score.data)/batch.sketchprobs.cuda() * batch.rewards.cuda() 210 | objective = reweighted_reward * score 211 | objective = objective.mean() 212 | (-objective).backward() 213 | model.opt.step() 214 | #for the purpose of printing: 215 | syntax_score = syntax_score.mean() 216 | objective 217 | if args.variance_reduction: 218 | print("variance_red_baseline:", variance_red.data.item()) 219 | else: 220 | objective, syntax_score = model.optimiser_step(IOs, batch.pseqs if pretraining else batch.sketchseqs) 221 | if args.timing: 222 | print(f"network time: {time.time()-t}, other time: {t-t2}", flush=True) 223 | t2 = time.time() 224 | if pretraining: 225 | model.pretrain_scores.append(objective) 226 | model.pretrain_iteration += 1 227 | else: 228 | model.iteration += 1 229 | model.hole_scores.append(objective) 230 | if i%args.print_freq==0: 231 | if args.use_rl: print("reweighted_reward:", reweighted_reward.mean().data.item()) 232 | print("iteration", i, "score:", objective if not args.use_rl else score.mean().data.item() , "syntax_score:", syntax_score if not args.use_rl else syntax_score.data.item(), flush=True) 233 | if i%args.save_freq==0: 234 | if not args.nosave: 235 | torch.save(model, path+f'_iter_{str(i)}.p') 236 | torch.save(model, path) 237 | if model.pretrain_iteration >= args.max_pretrain_iteration: pretraining = False 238 | if not pretraining and model.iteration < args.max_iteration: training = True 239 | if training and model.iteration >= args.max_iteration: training = False 240 | 241 | ####### End train with holes ######## 242 | 243 | # RL formalism w luke 244 | # add temperature parameter - x 245 | # think about RL objective 246 | 247 | # merge pretrain and regular train - x 248 | # use with timing(nn training) - idk 249 | # deal with model attributes - x 250 | 
-------------------------------------------------------------------------------- /train/robustfill_train_dc_model.py: -------------------------------------------------------------------------------- 1 | #Training robustfill deepcoderModel 2 | from builtins import super 3 | import pickle 4 | import string 5 | import argparse 6 | import random 7 | 8 | import torch 9 | from torch import nn, optim 10 | 11 | import sys 12 | import os 13 | sys.path.append(os.path.abspath('./')) 14 | sys.path.append(os.path.abspath('./ec')) 15 | 16 | from pinn import RobustFill 17 | from pinn import SyntaxCheckingRobustFill #TODO 18 | import random 19 | import math 20 | import time 21 | 22 | from collections import OrderedDict 23 | #from util import enumerate_reg, Hole 24 | 25 | 26 | from grammar import Grammar, NoCandidates 27 | from RobustFillPrimitives import RobustFillProductions, flatten_program 28 | from program import Application, Hole, Primitive, Index, Abstraction, ParseFailure 29 | import math 30 | from type import Context, arrow, tint, tlist, tbool, UnificationFailure 31 | from util.robustfill_util import parseprogram, robustfill_vocab 32 | from data_src.makeRobustFillData import batchloader 33 | from itertools import chain 34 | from models.deepcoderModel import LearnedFeatureExtractor, DeepcoderRecognitionModel 35 | from models.deepcoderModel import RobustFillLearnedFeatureExtractor, load_rb_dc_model_from_path 36 | from models.deepcoderModel import SketchFeatureExtractor, HoleSpecificFeatureExtractor, ImprovedRecognitionModel 37 | 38 | from string import printable 39 | 40 | # from deepcoderModel import 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--debug', action='store_true') 43 | parser.add_argument('--nosave', action='store_true') 44 | parser.add_argument('-k', type=int, default=3) #TODO 45 | parser.add_argument('--Vrange', type=int, default=128) 46 | parser.add_argument('--max_list_length', type=int, default=10) 47 | 
parser.add_argument('--save_model_path', type=str, default='./saved_models/text_dc_model.p') 48 | parser.add_argument('--load_model_path', type=str, default='./saved_models/text_dc_model.p') 49 | parser.add_argument('--new', action='store_true') 50 | parser.add_argument('--n_examples', type=int, default=4) 51 | parser.add_argument('--max_length', type=int, default=25) 52 | parser.add_argument('--max_index', type=int, default=4) 53 | parser.add_argument('--max_iteration', type=int, default=50*400*100) #approximate other model 54 | parser.add_argument('--improved_dc_model', action='store_true') 55 | parser.add_argument('--cuda', action='store_true', default=True) 56 | parser.add_argument('--inv_temp', type=float, default=0.1) #idk what the deal with this is ... 57 | parser.add_argument('--use_timeout', action='store_true', default=True) 58 | parser.add_argument('--nHoles', type=int, default=1) 59 | parser.add_argument('--use_dc_grammar', action='store_true') 60 | parser.add_argument('--input_noise', action='store_true') 61 | 62 | args = parser.parse_args() 63 | 64 | batchsize = 1 65 | max_iteration = args.max_iteration 66 | max_list_length = args.max_list_length 67 | use_dc_grammar = args.use_dc_grammar 68 | cuda=args.cuda 69 | 70 | robustfill_io_vocab = printable[:-4] 71 | 72 | basegrammar = Grammar.fromProductions(RobustFillProductions(args.max_length, args.max_index)) 73 | vocab = robustfill_vocab(basegrammar) 74 | 75 | if __name__ == "__main__": 76 | 77 | 78 | print("Loading model", flush=True) 79 | try: 80 | if args.new: 81 | raise FileNotFoundError 82 | dcModel=torch.load(args.load_model_path) 83 | print('found saved dcModel, loading ...') 84 | except FileNotFoundError: 85 | print("no saved dcModel, creating new one") 86 | if args.improved_dc_model: 87 | print("creating new improved dc model") 88 | ### 89 | specExtractor = RobustFillLearnedFeatureExtractor(robustfill_io_vocab, hidden=128, use_cuda=cuda) 90 | sketchExtractor = SketchFeatureExtractor(vocab, 
hidden=128, use_cuda=cuda) 91 | extractor = HoleSpecificFeatureExtractor(specExtractor, sketchExtractor, hidden=128, use_cuda=cuda) 92 | dcModel = ImprovedRecognitionModel(extractor, basegrammar, hidden=[128], cuda=cuda, contextual=False) 93 | ### 94 | else: 95 | extractor = RobustFillLearnedFeatureExtractor(robustfill_io_vocab, hidden=128) # probably want to make it much deeper .... 96 | dcModel = DeepcoderRecognitionModel(extractor, basegrammar, hidden=[128], cuda=True) # probably want to make it much deeper .... 97 | 98 | print("number of parameters is", sum(p.numel() for p in dcModel.parameters() if p.requires_grad)) 99 | 100 | ######## TRAINING ######## 101 | #make this a function... 102 | t2 = time.time() 103 | print("training", flush=True) 104 | if not hasattr(dcModel, 'iteration'): 105 | dcModel.iteration = 0 106 | dcModel.scores = [] 107 | 108 | if dcModel.iteration <= max_iteration: 109 | for i, datum in zip(range(max_iteration - dcModel.iteration), 110 | batchloader(max_iteration - dcModel.iteration, 111 | basegrammar, 112 | batchsize=1, 113 | N=args.n_examples, 114 | V=args.max_length, 115 | L=args.max_list_length, 116 | compute_sketches=args.improved_dc_model, 117 | dc_model=dcModel if use_dc_grammar and (dcModel.epochs > 1) else None, # TODO 118 | improved_dc_model=args.improved_dc_model, 119 | top_k_sketches=args.k, 120 | inv_temp=args.inv_temp, 121 | nHoles=args.nHoles, 122 | use_timeout=args.use_timeout, 123 | input_noise=args.input_noise)): #TODO 124 | 125 | 126 | t = time.time() 127 | t3 = t-t2 128 | if args.improved_dc_model: 129 | score = dcModel.optimizer_step((datum.IO, datum.sketchseq), datum.p, datum.sketch, datum.tp) 130 | else: 131 | score = dcModel.optimizer_step(datum.IO, datum.p, datum.tp) #TODO make sure inputs are correctly formatted 132 | t2 = time.time() 133 | 134 | dcModel.scores.append(score) 135 | dcModel.iteration += 1 136 | if i%500==0 and not i==0: 137 | print("pretrain iteration", i, "average score:", 
def make_holey(r: pre.Pregex, p=0.05) -> (pre.Pregex, torch.Tensor):
    """Randomly replace subexpressions of a regex with Holes.

    Each node is independently turned into a Hole with probability p;
    the prior score (under regex_prior) of every subtree that was cut
    out is accumulated and returned alongside the holey regex.
    """
    total_score = 0

    def punch(node: pre.Pregex) -> pre.Pregex:
        nonlocal total_score
        if random.random() < p:
            # record how much prior mass the removed subtree carried
            total_score += regex_prior.scoreregex(node)
            return Hole()
        return node.map(punch)

    return punch(r), torch.Tensor([total_score])
def sketch_logprior(preg: pre.Pregex, p=0.05) -> torch.Tensor:
    """Log-probability of a sketch under the independent-hole process.

    Walks the regex tree; each leaf at depth d contributes the chance
    that no ancestor was replaced, and a Hole additionally pays log(p)
    for the replacement itself.
    """
    keep = math.log(1 - p)
    cut = math.log(p)
    total = 0
    for node, depth in preg.walk():
        # TODO: this membership test stands in for a real "is a leaf" check
        if type(node) in (pre.String, pre.CharacterClass, Hole):
            if type(node) is Hole:
                total += cut + depth * keep
            else:
                total += (depth + 1) * keep

    return torch.tensor([total])
107 | c_input = [c] 108 | if all(len(x)= E_{S~Q) log P(y|S) 201 | = E{S~R} Q(S)/R(S) log P(y|S)""" 202 | 203 | #variance reduction term to help with learning 204 | #Dc is a list of lists of strings (which are iterables) 205 | 206 | #can do something silly like average embedding 207 | 208 | 209 | objective = torch.exp(model.score(Dc, sketch, autograd=True)) / torch.exp(sketch_prior) * (holescore - model.variance_red.data) - torch.pow((holescore - model.variance_red),2) 210 | #objective = model.score(Dc, sketch, autograd=True) / torch.exp(sketch_prior) * holescore 211 | 212 | #objective = model.score(Dc, sketch, autograd=True)*(holescore - sketch_prior) 213 | #control: 214 | #objective = model.score(Dc, c, autograd=True) 215 | #control 2: 216 | #objective = model.score(Dc, sketch, autograd=True) 217 | 218 | #objective = model.score(Dc, sketch, autograd=True)*(holescore - full_program_score) 219 | #print(objective.size()) 220 | objective = objective.mean() 221 | #print(objective) 222 | (-objective).backward() 223 | optimizer.step() 224 | 225 | else: #if args.pretrain_holes: 226 | objective = model.optimiser_step(Dc, sketch) 227 | 228 | model.iteration += 1 229 | model.hole_scores.append(objective) 230 | if i%1==0: 231 | print("iteration", i, "score:", objective, flush=True) 232 | if i%100==0: 233 | inst = getInstance() 234 | samples, scores = model.sampleAndScore([inst['Dc']], nRepeats=100) 235 | index = scores.index(max(scores)) 236 | #print(samples[index]) 237 | try: sample = pre.create(list(samples[index])) 238 | except: sample = samples[index] 239 | sample = samples[index] 240 | print("actual program:", pre.create(inst['c'])) 241 | print("generated examples:") 242 | print(*inst['Dc']) 243 | print("inferred:", sample) 244 | 245 | if i%100==0: # and not i==0: 246 | if not args.nosave: 247 | torch.save(model, './saved_models/sketch_model_holes_ep_{}.p'.format(str(i))) 248 | torch.save(model, './saved_models/sketch_model_holes.p') 249 | 250 | ####### End train with 
holes ######## 251 | 252 | 253 | 254 | ######testing###### 255 | 256 | 257 | # ###### full RL training with enumeration ######## 258 | # optimizer = optim.Adam(model.parameters(), lr=1e-3) 259 | 260 | # for batch in actual_data: #or synthetic data, whatever: 261 | # optimizer.zero_grad() 262 | # samples, scores = model.sampleAndScore(batch, autograd=True) #TODO: change so you can get grad through here, and so that scores are seperate??? 263 | # objective = [] 264 | # for sample, score, examples in zip(samples, scores, batch): 265 | # #DO RL 266 | # objective.append = -score*find_ll_reward_with_enumeration(sample, examples, time=10) #TODO, ideally should be per hole 267 | 268 | 269 | # objective = torch.sum(objective) 270 | # #DO RL 271 | # #TODO???? oh god. Make reward in positive range??) 272 | # #Q: is it even 273 | 274 | # objective.backward() 275 | # optimizer.step() 276 | 277 | #from holes, enumerate: 278 | #option 1: use ec enumeration 279 | #option 2: use something else 280 | """ 281 | RL questions: 282 | - should I do param updates for each batch??? 283 | - can we even get gradients through the whole sample? no, but not too hard I think 284 | - 285 | """ 286 | #trees will be nice because you can enumerate within the tree by just sampling more --- then is it even worth it?? 
287 | ###### End full training ######## 288 | 289 | 290 | 291 | 292 | #informal testing: 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/util/__init__.py -------------------------------------------------------------------------------- /util/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/util/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/deepcoder_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/util/__pycache__/deepcoder_util.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/robustfill_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtensor/neural_sketch/687d6b6c68ad534de32285398f54317b7f2edceb/util/__pycache__/robustfill_util.cpython-36.pyc -------------------------------------------------------------------------------- /util/algolisp_pypy_util.py: -------------------------------------------------------------------------------- 1 | #pypy_util.py 2 | 3 | 4 | import time 5 | from program import ParseFailure, Context 6 | from grammar import NoCandidates, Grammar, SketchEnumerationFailure 7 | from utilities import timing, callCompiled 8 | from collections import namedtuple 9 | from itertools import islice, zip_longest 10 | from functools import 
reduce 11 | 12 | from program_synthesis.algolisp.dataset import executor 13 | #from memory_profiler import profile 14 | 15 | SketchTup = namedtuple("SketchTup", ['sketch', 'g']) 16 | AlgolispResult = namedtuple("AlgolispResult", ["sketch", "prog", "hit", "n_checked", "time"]) 17 | 18 | 19 | #from algolisp code 20 | #executor_ = executor.LispExecutor() 21 | 22 | # #for reference: 23 | # def get_stats_from_code(args): 24 | # res, example, executor_ = args 25 | # if len(example.tests) == 0: 26 | # return None 27 | # if executor_ is not None: 28 | # stats = executor.evaluate_code( 29 | # res.code_tree if res.code_tree else res.code_sequence, example.schema.args, example.tests, 30 | # executor_) 31 | # stats['exact-code-match'] = is_same_code(example, res) 32 | # stats['correct-program'] = int(stats['tests-executed'] == stats['tests-passed']) 33 | # else: assert False 34 | # #what is a res? 35 | 36 | 37 | def test_program_on_IO(e, IO, schema_args, executor_): 38 | """ 39 | run executor 40 | """ 41 | stats = executor.evaluate_code( 42 | e, schema_args, IO, 43 | executor_) 44 | #print(stats['tests-executed'], stats['tests-passed']) 45 | return stats['tests-executed'] == stats['tests-passed'] 46 | 47 | def alternate(*args): 48 | # note: python 2 - use izip_longest 49 | for iterable in zip_longest(*args): 50 | for item in iterable: 51 | if item is not None: 52 | yield item 53 | 54 | 55 | def algolisp_enumerate(tp, IO, schema_args, mdl, sketchtups, n_checked, n_hit, t, max_to_check, i): 56 | results = [] 57 | executor_ = executor.LispExecutor() 58 | 59 | # empty = all(False for _ in alternate(*(((sk.sketch, x) for x in sk.g.sketchEnumeration(Context.EMPTY, [], tp, sk.sketch, mdl)) for sk in sketchtups))) 60 | # if empty: 61 | # print("ALTERNATE EMPTY") #TODO 62 | #print("lensketchups", len(sketchtups)) 63 | #print(p for p in sk.g.sketchEnumeration(Context.EMPTY, [], tp, sk.sketch, mdl) for sk in sketchtups) 64 | 65 | #for sketch, xp in alternate(*(((sk.sketch, x) for x in 
sk.g.sketchEnumeration(Context.EMPTY, [], tp, sk.sketch, mdl)) for sk in sketchtups)): 66 | 67 | 68 | f = lambda tup: map( lambda x: (tup.sketch, x), tup.g.sketchEnumeration(Context.EMPTY, [], tp, tup.sketch, mdl, maximumDepth=100)) 69 | sIterable = list(map(f, sketchtups)) 70 | 71 | hit = False 72 | #print("task", i, "starting enum loop inner, took", time.time()-t, "seconds", flush=True) 73 | for sketch, xp in alternate(* sIterable ): 74 | #prevent overflow maybe? 75 | if n_checked % 1000 == 0: 76 | executor.code_lisp._EXECUTION_CACHE = {} 77 | executor_ = executor.LispExecutor() 78 | _, _, p = xp 79 | e = p.evaluate([]) 80 | #print(e) 81 | hit = test_program_on_IO(e, IO, schema_args, executor_) 82 | prog = p if hit else None 83 | n_checked += 1 84 | n_hit += 1 if hit else 0 85 | if hit: 86 | results.append( AlgolispResult(sketch, prog, hit, n_checked, time.time()-t) ) 87 | break 88 | if n_checked >= max_to_check: 89 | del sketch 90 | del xp 91 | break 92 | if n_checked < len(sketchtups) and not hit: print("WARNING: not all candidate sketches checked") 93 | #print("task", i, "done enum loop inner, took", time.time()-t, "seconds", flush=True) 94 | del executor_ 95 | del sIterable 96 | del f 97 | #print("ex cache len:") 98 | #print(len(executor.code_lisp._EXECUTION_CACHE)) 99 | executor.code_lisp._EXECUTION_CACHE = {} 100 | return results, n_checked, n_hit 101 | 102 | 103 | #pypy_enumerate(datum.tp, datum.IO, mdl, sketchtups, n_checked, n_hit, t, max_to_check) 104 | def pypy_enumerate(tp, IO, schema_args, mdl, sketchtups, n_checked, n_hit, t, max_to_check, i): 105 | return callCompiled(algolisp_enumerate, tp, IO, schema_args, mdl, sketchtups, n_checked, n_hit, t, max_to_check, i) 106 | #if pypy doesn't work we can just call algolisp_enumerate normally 107 | -------------------------------------------------------------------------------- /util/deepcoder_util.py: -------------------------------------------------------------------------------- 1 | #deepcoder_util.py 2 
def deepcoder_vocab(grammar, n_inputs=2):
    """Target vocabulary for the seq2seq model: grammar primitive names,
    input placeholders, and the hole sentinel."""
    names = [prim.name for prim in grammar.primitives]
    placeholders = ['input_' + str(i) for i in range(n_inputs)]
    # NOTE(review): the '' sentinel below may originally have been an
    # angle-bracketed token (e.g. '<HOLE>') lost to markup stripping — confirm.
    return names + placeholders + ['']


def tokenize_for_robustfill(IOs):
    """Tokenize a batch of example sets for RobustFill.

    Each example is an (xs, y) pair. List-valued entries are wrapped in
    LIST_START/LIST_END markers, scalars become one-element lists, and
    all inputs of one example are concatenated into a flat sequence.
    """
    def _serialize(value):
        # lists get explicit delimiters; scalars become singleton lists
        if isinstance(value, list):
            return ["LIST_START"] + value + ["LIST_END"]
        return [value]

    batch = []
    for examples in IOs:
        tokenized = []
        for xs, y in examples:
            flat_inputs = []
            for x in xs:
                flat_inputs.extend(_serialize(x))
            tokenized.append((flat_inputs, _serialize(y)))
        batch.append(tokenized)
    return batch
def buildCandidate(request, context, environment, parsecontext, index_dict={}):
    """Consume one token from `parsecontext` and resolve it to a typed candidate.

    Returns (remaining_parsecontext, (type, program, newContext)).
    Raises NoCandidates when the token stream is empty, and ParseFailure
    when the token cannot be unified with the requested type.

    NOTE(review): `index_dict={}` is a mutable default argument; it is only
    read here so it is harmless in practice, but `None` + fallback would be
    the safer idiom.
    """
    # assumes at most 4 program inputs — TODO confirm against callers
    variable_list = ['input_' + str(i) for i in range(4)]

    if len(parsecontext) == 0: raise NoCandidates()
    # consume exactly one token from the front of the stream
    chosen_str = parsecontext[0]
    parsecontext = parsecontext[1:] #is this right?

    candidate = None

    if chosen_str in Primitive.GLOBALS: #if it is a primitive
        p = Primitive.GLOBALS[chosen_str]
        t = p.tp
        try:
            # instantiate the primitive's type and unify its return with the request
            newContext, t = t.instantiate(context)
            newContext = newContext.unify(t.returns(), request)
            t = t.apply(newContext)
            candidate = (t, p, newContext)

        except UnificationFailure:
            raise ParseFailure()

    elif chosen_str in variable_list:
        # an input placeholder: map it to a de Bruijn index via index_dict
        try:
            j = index_dict[chosen_str]
        except KeyError:
            raise ParseFailure()
        t = environment[j]
        try:
            newContext = context.unify(t.returns(), request)
            t = t.apply(newContext)
            candidate = (t, Index(j), newContext)
        except UnificationFailure:
            raise ParseFailure()
    else: #if it is a hole:
        # NOTE(review): the expected sentinel below reads as the empty string;
        # it may originally have been an angle-bracketed token (e.g. '<HOLE>')
        # lost to markup stripping — confirm before relying on this.
        try: assert chosen_str == '' #TODO, choose correct representation of program
        except AssertionError as e:
            print("bad string:", chosen_str)
            assert False
        p = Hole()
        # a hole takes on the requested type directly
        t = request #[try all possibilities and backtrack] #p.inferType(context, environment, freeVariables) #TODO
        try:
            newContext, t = t.instantiate(context)
            newContext = newContext.unify(t.returns(), request)
            t = t.apply(newContext)
            candidate = (t, p, newContext)

        except UnificationFailure:
            raise ParseFailure()

    if candidate == None:
        raise NoCandidates()


    return parsecontext, candidate
111 | try: 112 | newContext, t = t.instantiate(context) 113 | newContext = newContext.unify(t.returns(), request) 114 | t = t.apply(newContext) 115 | #candidates.append((l, t, p, newContext)) 116 | candidate = (t, p, newContext) 117 | 118 | except UnificationFailure: 119 | raise ParseFailure() 120 | 121 | if candidate == None: 122 | raise NoCandidates() 123 | 124 | 125 | return parsecontext, candidate 126 | 127 | 128 | def parseprogram(pseq, request): #TODO 129 | num_inputs = len(request.functionArguments()) 130 | 131 | index_dict = {'input_' + str(i): num_inputs-i-1 for i in range(num_inputs)} 132 | 133 | #request = something #TODO 134 | 135 | def _parse(request, parsecontext, context, environment): 136 | if request.isArrow(): 137 | parsecontext, context, expression = _parse( 138 | request.arguments[1], parsecontext, context, [ 139 | request.arguments[0]] + environment) 140 | return parsecontext, context, Abstraction(expression) #TODO 141 | 142 | parsecontext, candidate = buildCandidate(request, context, environment, parsecontext, index_dict=index_dict) 143 | 144 | newType, chosenPrimitive, context = candidate 145 | 146 | # Sample the arguments 147 | xs = newType.functionArguments() 148 | returnValue = chosenPrimitive 149 | 150 | for x in xs: 151 | x = x.apply(context) 152 | parsecontext, context, x = _parse( 153 | x, parsecontext, context, environment) 154 | returnValue = Application(returnValue, x) 155 | 156 | return parsecontext, context, returnValue 157 | 158 | _, _, e = _parse( 159 | request, pseq, Context.EMPTY, []) 160 | return e 161 | 162 | def make_holey_deepcoder(prog, 163 | k, 164 | g, 165 | request, 166 | inv_temp=1.0, 167 | reward_fn=None, 168 | sample_fn=None, 169 | verbose=False, 170 | use_timeout=False): 171 | #need to add improved_dc_model=False, nHoles=1 172 | """ 173 | inv_temp==1 => use true mdls 174 | inv_temp==0 => sample uniformly 175 | 0 < inv_temp < 1 ==> something in between 176 | """ 177 | choices = g.enumerateHoles(request, prog, k=k) 
178 | 179 | if len(list(choices)) == 0: 180 | #if there are none, then use the original program 181 | choices = [(prog, 0)] 182 | #print("prog:", prog, "choices", list(choices)) 183 | progs, weights = zip(*choices) 184 | 185 | # if verbose: 186 | # for c in choices: print(c) 187 | 188 | if sample_fn is None: 189 | sample_fn = lambda x: inv_temp*math.exp(inv_temp*x) 190 | 191 | if use_timeout: 192 | # sample timeout 193 | r = random.random() 194 | t = -math.log(r)/inv_temp 195 | 196 | cs = list(zip(progs, [-w for w in weights])) 197 | if t < list(cs)[0][1]: return prog, None, None 198 | 199 | below_cutoff_choices = [(p, w) for p,w in cs if t > w] 200 | 201 | 202 | _, max_w = max(below_cutoff_choices, key=lambda item: item[1]) 203 | 204 | options = [(p, None, None) for p, w in below_cutoff_choices if w==max_w] 205 | x = random.choices(options, k=1) 206 | return x[0] 207 | 208 | # cdf = lambda x: 1 - math.exp(-inv_temp*(x)) 209 | # weights = [-w for w in weights] 210 | # probs = list(weights) 211 | # #probs[0] = cdf(weights[0]) 212 | # for i in range(0, len(weights)-1): 213 | # probs[i] = cdf(weights[i+1]) - cdf(weights[i]) 214 | 215 | # probs[-1] = 1 - cdf(weights[-1]) 216 | # weights = tuple(probs) 217 | else: 218 | #normalize weights, and then rezip 219 | 220 | weights = [sample_fn(w) for w in weights] 221 | #normalize_weights 222 | w_sum = sum(w for w in weights) 223 | weights = [w/w_sum for w in weights] 224 | 225 | 226 | if reward_fn is None: 227 | reward_fn = math.exp 228 | rewards = [reward_fn(w) for w in weights] 229 | 230 | prog_reward_probs = list(zip(progs, rewards, weights)) 231 | 232 | if verbose: 233 | for p, r, prob in prog_reward_probs: 234 | print(p, prob) 235 | 236 | if k > 1: 237 | x = random.choices(prog_reward_probs, weights=weights, k=1) 238 | return x[0] #outputs prog, prob 239 | else: 240 | return prog_reward_probs[0] #outputs prog, prob 241 | 242 | 243 | 244 | 245 | # ####unused#### 246 | # def sample_request(): #TODO 247 | # requests = [ 248 
| # arrow(tlist(tint), tlist(tint)), 249 | # arrow(tlist(tint), tint), 250 | # #arrow(tint, tlist(tint)), 251 | # arrow(tint, tint) 252 | # ] 253 | # return random.choices(requests, weights=[4,3,1])[0] #TODO 254 | 255 | # def isListFunction(tp): 256 | # try: 257 | # Context().unify(tp, arrow(tlist(tint), tint)) #TODO, idk if this will work 258 | # return True 259 | # except UnificationFailure: 260 | # try: 261 | # Context().unify(tp, arrow(tlist(tint), tlist(tint))) #TODO, idk if this will work 262 | # return True 263 | # except UnificationFailure: 264 | # return False 265 | 266 | 267 | # def isIntFunction(tp): 268 | # try: 269 | # Context().unify(tp, arrow(tint, tint)) #TODO, idk if this will work 270 | # return True 271 | # except UnificationFailure: 272 | # try: 273 | # Context().unify(tp, arrow(tint, tlist(tint))) #TODO, idk if this will work 274 | # return True 275 | # except UnificationFailure: 276 | # return False 277 | 278 | # def sampleIO(program, tp, k_shot=4, verbose=False): #TODO 279 | # #needs to have constraint stuff 280 | # N_EXAMPLES = 5 281 | # RANGE = 30 #TODO 282 | # LIST_LEN_RANGE = 8 283 | # OUTPUT_RANGE = 128 284 | 285 | # #stolen from Luke. Should be abstracted in a class of some sort. 286 | # def _featuresOfProgram(program, tp, k_shot=4): 287 | # e = program.evaluate([]) 288 | # examples = [] 289 | # if isListFunction(tp): 290 | # sample = lambda: random.sample(range(-RANGE, RANGE), random.randint(0, LIST_LEN_RANGE)) 291 | # elif isIntFunction(tp): 292 | # sample = lambda: random.randint(-RANGE, RANGE-1) 293 | # else: 294 | # return None 295 | # for _ in range(N_EXAMPLES*3): 296 | # x = sample() 297 | # #try: 298 | # #print("program", program, "e", e, "x", x) 299 | # y = e(x) 300 | # #eprint(tp, program, x, y) 301 | 302 | # if x == [] or y == []: 303 | # if verbose: print("tripped empty list continue ") 304 | # continue 305 | 306 | # if type(y) == int: 307 | # y = [y] #TODO fix this dumb hack ... 
308 | # if type(x) == int: 309 | # x = [x] #TODO fix this dumb hack ... 310 | 311 | # if any((num >= OUTPUT_RANGE) or (num < -OUTPUT_RANGE) for num in y): #this is a total hack 312 | # if verbose: print("tripped range continue", flush=True) 313 | # continue 314 | 315 | # examples.append( (x, y) ) 316 | 317 | 318 | # # except: 319 | # # print("tripped continue 2", flush=True) 320 | # # continue 321 | 322 | # if len(examples) >= k_shot: break 323 | # else: 324 | # return None #What do I do if I get a None?? Try another program ... 325 | # return examples 326 | 327 | # return _featuresOfProgram(program, tp, k_shot=k_shot) 328 | 329 | # def getInstance(k_shot=4, max_length=30, verbose=False, with_holes=False, k=None): 330 | # """ 331 | # Returns a single problem instance, as input/target strings 332 | # """ 333 | # #TODO 334 | # assert False, "this function has been depricated" 335 | # while True: 336 | # #request = arrow(tlist(tint), tint, tint) 337 | # #print("starting getIntance loop") 338 | # request = sample_request() 339 | # #print("request", request) 340 | # p = grammar.sample(request, maximumDepth=4) #grammar not abstracted well in this script 341 | # #print("program:", p) 342 | 343 | # IO = sampleIO(p, request, k_shot, verbose=verbose) 344 | # if IO == None: #TODO, this is a hack!!! 
#             if verbose: print("tripped IO==None continue")
#             continue
#         if any(y==None for x,y in IO):
#             if verbose: print("tripped y==None continue")
#             assert False
#             continue

#         pseq = flatten_program(p)

#         if all(len(x) ...  (line truncated in extraction)

# NOTE(review): the extraction lost a span here -- the rest of the
# commented-out getInstance(...) above, this file's import header, and the
# opening of the token lookup tables below. Only the last entries of
# `primitive_lookup` survive (one key ending in '>' mapped to "map");
# restore the missing entries and the `type_lookup` table from the original
# util file before relying on uast_to_ec.
primitive_lookup = {
    # ... earlier entries lost in extraction ...
    "record_name#": "record_name",
}


def convert_to_tp(x, target_tp):
    """Wrap expression `x` in the registered converter primitive of type original->target.

    Raises IndexError when no converter with exactly that arrow type exists.
    """
    original_tp = x.infer()
    request = arrow(original_tp, target_tp)
    converter = [prim for prim in Primitive.GLOBALS.values() if prim.tp == request][0]  # TODO - hack
    return Application(converter, x)


def uast_to_ec(uast) -> Program:  # TODO
    """Translate a UAST (nested lists of strings/ints) into an EC Program,
    inserting type converters where a subtree's inferred type disagrees with
    the argument type its parent expects."""

    def recurse_list(l, target_tp):
        # Fold a python list of subtrees into an EC list of target_tp,
        # building singleton-then-cons from the first element outward.
        x = recurse(l[0])
        x = convert_to_tp(x, target_tp)
        e = convert_to_tp(x, tlist(target_tp))
        for exp in l[1:]:
            x = recurse(exp)
            if x.infer() != target_tp:
                x = convert_to_tp(x, target_tp)  # maybe always convert?
            request = arrow(target_tp, tlist(target_tp), tlist(target_tp))
            list_converter = [prim for prim in Primitive.GLOBALS.values() if prim.tp == request][0]  # TODO
            e = Application(Application(list_converter, x), e)
        return e

    def recurse(l):
        if type(l) == list:
            # head is the operator, the rest are its arguments
            e = recurse(l[0])
            tp_args = e.infer().functionArguments()
            for tp_arg, exp in zip(tp_args, l[1:]):
                if tp_arg.name == 'list':
                    # for now, assume the correct num of brackets
                    x = recurse_list(exp, tp_arg.arguments[0])
                else:
                    x = recurse(exp)
                    if tp_arg != x.infer():
                        x = convert_to_tp(x, tp_arg)
                e = Application(e, x)
            return e
        elif l in primitive_lookup:
            e = Primitive.GLOBALS[primitive_lookup[l]]
        elif l in type_lookup:  # TODO fix this part
            # fixed: original raised `unimplemented()`, an undefined name
            # NOTE(review): type_lookup's definition was lost in extraction
            raise NotImplementedError("type_lookup entries are not handled yet")
        elif type(l) == int:
            l = str(l)
            e = recurse(l)
        else:
            print("l:", l)
            assert False
        return e

    return recurse(uast)  # ec_program


def tokenize_lisp_expr(sexp):
    """Tokenize an s-expression string into parens/brackets and atoms.

    Fixes the original stub, which returned the undefined name `slist`
    (a guaranteed NameError).
    """
    for ch in "()[]":
        sexp = sexp.replace(ch, " " + ch + " ")
    return sexp.split()


if __name__ == '__main__':
    # ["assign", "bool", ["var", "bool", "var5"], ["val", "bool", False]]
    u = ["invoke", "bool", "&&", [["invoke", "bool", "!", [["invoke", "bool", "!=", [["val", "int", -1], ["invoke", "int", "string_find", [["var", "char*", "var0"], ["val", "char*", "1"]]]]]]], ["invoke", "bool", "!=", [["val", "int", -1], ["invoke", "int", "string_find", [["var", "char*", "var1"], ["val", "char*", "1"]]]]]]]
    p = uast_to_ec(u)
    print(p)
    e = p.evaluate([])
    print(e)
    print(u == e)


# -----------------------------------------------------------------------------
# util/pypy_util.py
# pypy_util.py

import time
from program import ParseFailure, Context
from grammar import NoCandidates, Grammar
from utilities import timing, callCompiled
from collections import namedtuple
from itertools import islice, zip_longest
from functools import reduce
from deepcoderPrimitives import deepcoderPrimitives

SketchTup = namedtuple("SketchTup", ['sketch', 'g'])
DeepcoderResult = namedtuple("DeepcoderResult", ["sketch", "prog", "hit", "n_checked", "time"])
deepcoderPrimitives()  # called for its side effect: registers primitives into Primitive.GLOBALS


def test_program_on_IO(e, IO):
    """True iff curried program `e`, fed each example's inputs xs, yields y."""
    return all(reduce(lambda a, b: a(b), xs, e) == y for xs, y in IO)


def alternate(*args):
    """Round-robin over several iterables, skipping exhausted ones."""
    # note: python 2 - use izip_longest
    for iterable in zip_longest(*args):
        for item in iterable:
            if item is not None:
                yield item


def dc_enumerate(tp, IO, mdl, sketchtups, n_checked, n_hit, t, max_to_check):
    """Enumerate concretizations of each sketch (round-robin across sketches)
    and test them on IO until a hit or max_to_check programs are checked.

    Returns (results, n_checked, n_hit); `t` is the wall-clock start time used
    to timestamp each DeepcoderResult.
    """
    results = []

    f = lambda tup: map(lambda x: (tup.sketch, x),
                        tup.g.sketchEnumeration(Context.EMPTY, [], tp, tup.sketch, mdl, maximumDepth=20))
    sIterable = list(map(f, sketchtups))

    for sk, x in alternate(*sIterable):
        _, _, p = x
        e = p.evaluate([])
        hit = test_program_on_IO(e, IO)
        prog = p if hit else None
        n_checked += 1
        n_hit += 1 if hit else 0
        results.append(DeepcoderResult(sk, prog, hit, n_checked, time.time() - t))
        if hit: break
        if n_checked >= max_to_check: break

    return results, n_checked, n_hit


def pypy_enumerate(g, tp, IO, mdl, sketchtups, n_checked, n_hit, t, max_to_check):
    # NOTE(review): `g` is accepted but unused -- kept for caller compatibility
    return callCompiled(dc_enumerate, tp, IO, mdl, sketchtups, n_checked, n_hit, t, max_to_check)


# -----------------------------------------------------------------------------
# util/rb_pypy_util.py
# rb_pypy_util.py

import sys
import os
sys.path.append(os.path.abspath('./'))
sys.path.append(os.path.abspath('./ec'))

import time
from program import ParseFailure, Context
from grammar import NoCandidates, Grammar
from utilities import timing, callCompiled
from collections import namedtuple
from itertools import islice, zip_longest
from functools import reduce
from RobustFillPrimitives import robustFillPrimitives, RobustFillProductions
from util.pypy_util import alternate

#A nasty hack!!!
max_len = 25
max_index = 4

#_ = robustFillPrimitives(max_len=max_len, max_index=max_index)
basegrammar = Grammar.fromProductions(RobustFillProductions(max_len, max_index))


SketchTup = namedtuple("SketchTup", ['sketch', 'g'])
RobustFillResult = namedtuple("RobustFillResult", ["sketch", "prog", "hit", "n_checked", "time", "g_hit"])


def test_program_on_IO(e, IO, n_rec_examples, generalization=False, noise=False):
    """Test program `e` on the recognition examples (or on all of IO when
    generalization=True).  With noise=True, tolerate at most one failing
    example.  IndexError from the program counts as a miss."""
    examples = IO if generalization else IO[:n_rec_examples]

    if noise:
        l = len(examples)
        try:
            return sum(e(x) == y for x, y in examples) >= l - 1  # TODO: check that this makes sense
        except IndexError:
            return False
    try:
        return all(e(x) == y for x, y in examples)  # TODO: check that this makes sense
    except IndexError:
        return False


def rb_enumerate(tp, IO, mdl, sketchtups, n_checked, n_hit, t, max_to_check, test_generalization, n_rec_examples, input_noise=False):
    """Enumerate concretizations of each sketch (round-robin), testing each on
    the IO examples; optionally also record whether a candidate generalizes to
    the held-out examples.  Returns (results, n_checked, n_hit)."""
    results = []  # fixed: was initialized twice

    f = lambda tup: map(lambda x: (tup.sketch, x),
                        tup.g.sketchEnumeration(Context.EMPTY, [], tp, tup.sketch, mdl, maximumDepth=20))
    sIterable = list(map(f, sketchtups))

    for sk, x in alternate(*sIterable):
        _, _, p = x
        e = p.evaluate([])
        try:
            hit = test_program_on_IO(e, IO, n_rec_examples, noise=input_noise)
        except: hit = False  # NOTE(review): deliberately broad -- any crash is a miss
        prog = p if hit else None
        n_checked += 1
        n_hit += 1 if hit else 0
        # testing generalization
        gen_hit = test_program_on_IO(e, IO, n_rec_examples, generalization=True) if test_generalization else False
        results.append(RobustFillResult(sk, prog, hit, n_checked, time.time() - t, gen_hit))
        if hit: break
        if n_checked >= max_to_check: break

    return results, n_checked, n_hit


def rb_pypy_enumerate(tp, IO, mdl, sketchtups, n_checked, n_hit, t, max_to_check, test_generalization, n_rec_examples, input_noise=False):
    return callCompiled(rb_enumerate, tp, IO, mdl, sketchtups, n_checked, n_hit, t, max_to_check, test_generalization, n_rec_examples, input_noise=input_noise)


# -----------------------------------------------------------------------------
# util/regex_pypy_util.py
# regex_pypy_util.py

import time
from program import ParseFailure, Context
from grammar import NoCandidates, Grammar
from utilities import timing, callCompiled
from collections import namedtuple
from itertools import islice, zip_longest
from functools import reduce
import pregex as pre  # oh god

RegexResult = namedtuple("RegexResult", ["sketch", "prog", "ll", "n_checked", "time"])
SketchTup = namedtuple("SketchTup", ['sketch', 'g'])


def alternate(*args):
    """Round-robin over several iterables, skipping exhausted ones."""
    # note: python 2 - use izip_longest
    for iterable in zip_longest(*args):
        for item in iterable:
            if item is not None:
                yield item


def regex_enumerate(tp, IO, mdl, sketchtups, n_checked, n_hit, t, max_to_check, nExamples=0):
    """Enumerate regex sketch concretizations, scoring each by its average
    per-character match log-likelihood on the held-out examples IO[nExamples:].

    `nExamples` is new (default 0 == score against all of IO): the original
    referenced an undefined name `nExamples` and raised NameError on first use.
    Returns (results, n_checked, n_hit) where each result is paired with the
    best log-likelihood seen so far.
    """
    results = []
    best_ll = float('-inf')
    heldout = IO[nExamples:]  # NOTE(review): empty heldout would divide by zero -- assumed non-empty
    for sketch, x in alternate(*(((sk.sketch, x) for x in sk.g.sketchEnumeration(Context.EMPTY, [], tp, sk.sketch, mdl)) for sk in sketchtups)):  # TODO!! sketches n grammars n stuff
        _, _, p = x
        e = p.evaluate([])
        ll = sum(e.match(example) / len(example) for example in heldout) / len(heldout)  # TODO -- check this
        if ll > best_ll:
            best_ll = ll
        # fixed: was `ll is not float('-inf')`, an identity comparison against
        # a fresh float object, which is always True
        hit = ll > float('-inf')
        prog = p if hit else None
        n_checked += 1
        n_hit += 1 if hit else 0
        results.append((RegexResult(sketch, prog, ll, n_checked, time.time() - t), best_ll))
        if n_checked >= max_to_check: break

    return results, n_checked, n_hit


def pypy_enumerate(tp, IO, mdl, sketchtups, n_checked, n_hit, t, max_to_check):
    return callCompiled(regex_enumerate, tp, IO, mdl, sketchtups, n_checked, n_hit, t, max_to_check)


# -----------------------------------------------------------------------------
# util/robustfill_util.py

from builtins import super
import pickle
import string
import argparse
import random
import math
import time
from string import printable

import torch
from torch import nn, optim

from pinn import RobustFill
from pinn import SyntaxCheckingRobustFill

from pregex import pregex as pre

from collections import OrderedDict
#from util import enumerate_reg, Hole

import re

from grammar import Grammar, NoCandidates
#from deepcoderPrimitives import deepcoderProductions, flatten_program
from RobustFillPrimitives import RobustFillProductions, flatten_program, tprogram, Constraint_prop, delimiters
from program import Application, Hole, Primitive, Index, Abstraction, ParseFailure, prettyProgram
from type import Context, arrow, UnificationFailure

#productions = RobustFillProductions() # TODO - figure out good production probs ...
#basegrammar = Grammar.fromProductions(productions, logVariable=0.0) # TODO

def robustfill_vocab(grammar):
    """Token vocabulary: every primitive name plus the hole token ''."""
    return [prim.name for prim in grammar.primitives] + ['']  # TODO


class timing(object):
    """Context manager that prints how long its `with` body took."""

    def __init__(self, message):
        self.message = message

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):  # renamed: no longer shadows builtin `type`
        print(self.message, f"in {time.time() - self.start} seconds")


# escaped-delimiter lookup used to build pregex patterns for constraints
d = {re.escape(i): i for i in delimiters}
d['\\.'] = '\\.'
d['\\)'] = '\\)'
d['\\('] = '\\('
d['('] = '\\('
d[')'] = '\\)'
preg_dict = {r'[A-Z][a-z]+': '\\u\\l+', r'[A-Z]': '\\u', r'[a-z]': '\\l', **d}  # note: uses escaped versions of delimiters for constraints


def extract_constraints(program):  # TODO
    """Propagate constraints through `program`; returns (constraint_dict, min_size)."""
    # throw an issue if min bigger than max
    return Constraint_prop().execute(program)


def sample_program(g=None, max_len=10, max_string_size=100):
    """Sample programs from grammar `g` until one uses fewer than `max_len`
    concat_list tokens.

    Rewritten from unbounded recursion to a loop so a grammar that keeps
    producing long programs cannot blow the stack; behavior is unchanged.
    """
    assert g is not None
    request = tprogram
    while True:
        #with timing("sample from grammar"):
        p = g.sample(request, maximumDepth=5, maxAttempts=None)  # todo args??
        if flatten_program(p).count('concat_list') < max_len:
            return p
        # otherwise resample


def generate_inputs_from_constraints(constraint_dict, min_size, max_string_size=100):
    """Sample a printable input string of length in [min_size, max_string_size]
    satisfying each occurrence-count constraint, or None if they cannot fit."""
    # sample a size from min to max
    size = random.randint(min_size, max_string_size)
    indices = set(range(size))
    slist = random.choices(printable[:-4], k=size)  # printable minus its trailing whitespace chars
    for item in constraint_dict:
        # how many more occurrences of `item` still need to be injected
        num_to_insert = max(0, constraint_dict[item] - len(re.findall(re.escape(item), ''.join(slist))))
        if len(indices) < num_to_insert: return None
        # fixed: random.sample on a set was deprecated in 3.9 and removed in 3.11
        indices_to_insert = set(random.sample(list(indices), k=num_to_insert))
        for i in indices_to_insert:
            slist[i] = pre.create(item).sample() if item not in preg_dict else pre.create(preg_dict[item]).sample()
        indices = indices - indices_to_insert
    # may be too big but whatever
    string = ''.join(slist)
    if len(string) > max_string_size: return string[:max_string_size]  # may break but whatever
    return string


def generate_IO_examples(program, num_examples=5, max_string_size=100):
    """Generate up to num_examples (input, output) pairs for `program`.

    Can go wrong if:
    - min_size > max_string_size: give up immediately (caller samples a new prog)
    - program(input) exceeds max_string_size or raises (imperfect constraint
      propagation): retried up to 2*num_examples times
    Returns None when not enough examples could be produced.
    """
    constraint_dict, min_size = extract_constraints(program)
    if min_size > max_string_size: return None
    examples = []
    for _ in range(2 * num_examples):
        #with timing("generate_from constraints"):
        instring = generate_inputs_from_constraints(constraint_dict, min_size, max_string_size=max_string_size)
        if instring is None: continue
        try: outstring = program.evaluate([])(instring)
        except IndexError: continue
        if len(outstring) > max_string_size: continue  # might change to return None for speed
        examples.append((instring, outstring))
        if len(examples) >= num_examples: break
    else: return None  # loop exhausted without collecting enough examples
    return examples


def tokenize_for_robustfill(IOs):
    """
    tokenizes a batch of IOs ... I think none is necessary ...
    """
    return IOs


def buildCandidate(request, context, environment, parsecontext):  # TODO
    """Consume the next token of `parsecontext` and turn it into a typed candidate.

    Returns (remaining_parsecontext, (tp, primitive_or_hole, newContext)).
    Raises NoCandidates on an empty token stream and ParseFailure when the
    token cannot be unified with `request`.
    """
    if len(parsecontext) == 0: raise NoCandidates()
    chosen_str = parsecontext[0]
    parsecontext = parsecontext[1:]  # is this right?

    candidate = None
    if chosen_str in Primitive.GLOBALS:  # it is a known primitive
        p = Primitive.GLOBALS[chosen_str]
        t = p.tp
        try:
            newContext, t = t.instantiate(context)
            newContext = newContext.unify(t.returns(), request)
            t = t.apply(newContext)
            candidate = (t, p, newContext)
        except UnificationFailure:
            raise ParseFailure()

    else:  # otherwise it must be the hole token ''
        try: assert chosen_str == ''  # TODO, choose correct representation of program
        except AssertionError:
            print("bad string:", chosen_str)
            assert False
        p = Hole()
        t = request  # the hole takes on the requested type
        try:
            newContext, t = t.instantiate(context)
            newContext = newContext.unify(t.returns(), request)
            t = t.apply(newContext)
            candidate = (t, p, newContext)
        except UnificationFailure:
            raise ParseFailure()

    if candidate is None:  # fixed: was `== None`
        raise NoCandidates()
    return parsecontext, candidate


def parseprogram(pseq, request):  # TODO
    """Parse a flat token sequence `pseq` back into a Program of type `request`."""
    num_inputs = len(request.functionArguments())

    def _parse(request, parsecontext, context, environment):
        if request.isArrow():
            parsecontext, context, expression = _parse(
                request.arguments[1], parsecontext, context,
                [request.arguments[0]] + environment)
            return parsecontext, context, Abstraction(expression)  # TODO

        parsecontext, candidate = buildCandidate(request, context, environment, parsecontext)
        newType, chosenPrimitive, context = candidate

        # Parse the arguments
        xs = newType.functionArguments()
        returnValue = chosenPrimitive
        for x in xs:
            x = x.apply(context)
            parsecontext, context, x = _parse(
                x, parsecontext, context, environment)
            returnValue = Application(returnValue, x)
        return parsecontext, context, returnValue

    _, _, e = _parse(request, pseq, Context.EMPTY, [])
    return e


def make_holey_deepcoder(prog, k, g, request, inv_temp=1.0, reward_fn=None, sample_fn=None):  # TODO
    """Sample a holey version (sketch) of `prog` from grammar `g`.

    inv_temp==1 => use true mdls
    inv_temp==0 => sample uniformly
    0 < inv_temp < 1 ==> something in between

    Returns a (program, reward, probability) triple.
    """
    # Materialize once: enumerateHoles may return a generator, and the original
    # `len(list(choices))` emptiness test exhausted it before zip() below ran.
    choices = list(g.enumerateHoles(request, prog, k=k))
    if not choices:
        # if there are none, then use the original program
        choices = [(prog, 0)]
    progs, weights = zip(*choices)
    # compute rewards and normalized sampling weights, then rezip
    if reward_fn is None:
        reward_fn = math.exp
    if sample_fn is None:
        sample_fn = lambda x: math.exp(inv_temp * x)
    rewards = [reward_fn(w) for w in weights]
    weights = [sample_fn(w) for w in weights]
    # normalize_weights
    w_sum = sum(weights)
    weights = [w / w_sum for w in weights]

    prog_reward_probs = list(zip(progs, rewards, weights))

    if k > 1:
        x = random.choices(prog_reward_probs, weights=weights, k=1)
        return x[0]  # outputs prog, reward, prob
    else:
        return prog_reward_probs[0]  # outputs prog, reward, prob


if __name__ == '__main__':
    # smoke test: sample, flatten, reparse, and run a program
    g = Grammar.fromProductions(RobustFillProductions())
    print(len(g))
    request = tprogram
    p = g.sample(request)
    print("request:", request)
    print("program:")
    print(prettyProgram(p))
    s = 'abcdefg'
    e = p.evaluate([])
    print("prog applied to", s)
    print(e(s))
    print("flattened_program:")
    flat = flatten_program(p)
    print(flat)
    pr = parseprogram(flat, request)
    print(prettyProgram(pr))
--------------------------------------------------------------------------------