├── .gitignore ├── equation_verification ├── __init__.py ├── sequential_model_constants.py ├── constants.py ├── sequential_model.py ├── equation_completion_experiment.py ├── dataset_loading.py └── nn_tree_experiment.py ├── data └── .DS_Store ├── visualization ├── example1.pdf └── example2.pdf ├── explore_trace.sh ├── checkpoints ├── exp_lstm_drop0.1_hidden50_12093-model-68.checkpoint ├── exp_lstm_drop0.1_hidden50_12093-optim-68.checkpoint ├── exp_lstm_drop0.2_hidden60_45678-model-99.checkpoint ├── exp_lstm_drop0.2_hidden60_45678-optim-99.checkpoint ├── exp_vanilla_drop0.1_hidden45_45678-model-66.checkpoint ├── exp_vanilla_drop0.1_hidden45_45678-optim-66.checkpoint ├── exp_vanilla_drop0.2_hidden100_45678-model-98.checkpoint ├── exp_vanilla_drop0.2_hidden100_45678-optim-98.checkpoint ├── exp_gatedpushpop_normalize_no_op_45678_60_0.2-model-74.checkpoint ├── exp_gatedpushpop_normalize_no_op_45678_60_0.2-optim-74.checkpoint ├── exp_gatedpushpop_normalize_no_op_50_0.1_12345-model-97.checkpoint └── exp_gatedpushpop_normalize_no_op_50_0.1_12345-optim-97.checkpoint ├── optimizers.py ├── run_40k_lstm.sh ├── run_40k_vanilla.sh ├── trace_40k_gatedpushpop_normalize_no_op.sh ├── run_40k_gatedpushpop_normalize_no_op.sh ├── LICENSE ├── aggregate_test_splits.py ├── using_explore_trace.txt ├── compmlete_40k.sh ├── README.md ├── parse.py └── explore_trace.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | -------------------------------------------------------------------------------- /equation_verification/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/data/.DS_Store -------------------------------------------------------------------------------- /visualization/example1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/visualization/example1.pdf -------------------------------------------------------------------------------- /visualization/example2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/visualization/example2.pdf -------------------------------------------------------------------------------- /explore_trace.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python3 explore_trace.py --trace-files gatedpushpop_normalize_no_op.trace \ 4 | --dump-dir visualization -------------------------------------------------------------------------------- /checkpoints/exp_lstm_drop0.1_hidden50_12093-model-68.checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/checkpoints/exp_lstm_drop0.1_hidden50_12093-model-68.checkpoint -------------------------------------------------------------------------------- /checkpoints/exp_lstm_drop0.1_hidden50_12093-optim-68.checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/checkpoints/exp_lstm_drop0.1_hidden50_12093-optim-68.checkpoint 
-------------------------------------------------------------------------------- /checkpoints/exp_lstm_drop0.2_hidden60_45678-model-99.checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/checkpoints/exp_lstm_drop0.2_hidden60_45678-model-99.checkpoint -------------------------------------------------------------------------------- /checkpoints/exp_lstm_drop0.2_hidden60_45678-optim-99.checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/checkpoints/exp_lstm_drop0.2_hidden60_45678-optim-99.checkpoint -------------------------------------------------------------------------------- /checkpoints/exp_vanilla_drop0.1_hidden45_45678-model-66.checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/checkpoints/exp_vanilla_drop0.1_hidden45_45678-model-66.checkpoint -------------------------------------------------------------------------------- /checkpoints/exp_vanilla_drop0.1_hidden45_45678-optim-66.checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/checkpoints/exp_vanilla_drop0.1_hidden45_45678-optim-66.checkpoint -------------------------------------------------------------------------------- /checkpoints/exp_vanilla_drop0.2_hidden100_45678-model-98.checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/checkpoints/exp_vanilla_drop0.2_hidden100_45678-model-98.checkpoint -------------------------------------------------------------------------------- /checkpoints/exp_vanilla_drop0.2_hidden100_45678-optim-98.checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/checkpoints/exp_vanilla_drop0.2_hidden100_45678-optim-98.checkpoint -------------------------------------------------------------------------------- /checkpoints/exp_gatedpushpop_normalize_no_op_45678_60_0.2-model-74.checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/checkpoints/exp_gatedpushpop_normalize_no_op_45678_60_0.2-model-74.checkpoint -------------------------------------------------------------------------------- /checkpoints/exp_gatedpushpop_normalize_no_op_45678_60_0.2-optim-74.checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/checkpoints/exp_gatedpushpop_normalize_no_op_45678_60_0.2-optim-74.checkpoint -------------------------------------------------------------------------------- /checkpoints/exp_gatedpushpop_normalize_no_op_50_0.1_12345-model-97.checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/checkpoints/exp_gatedpushpop_normalize_no_op_50_0.1_12345-model-97.checkpoint -------------------------------------------------------------------------------- /checkpoints/exp_gatedpushpop_normalize_no_op_50_0.1_12345-optim-97.checkpoint: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ForoughA/recursiveMemNet/HEAD/checkpoints/exp_gatedpushpop_normalize_no_op_50_0.1_12345-optim-97.checkpoint -------------------------------------------------------------------------------- /optimizers.py: -------------------------------------------------------------------------------- 1 | from torch.optim import SGD 2 | from torch.optim import Adam 3 | 4 | def build_optimizer(params, optimizer_type, lr, momentum, weight_decay, beta1, beta2): 5 | if optimizer_type == "sgd": 6 | return SGD(params, lr, momentum=momentum, weight_decay=weight_decay) 7 | elif optimizer_type == "adam": 8 | return Adam(params, lr, weight_decay=weight_decay, betas=(beta1, beta2)) 9 | else: 10 | raise ValueError("Unhandled optimizer type:%s" % optimizer_type) -------------------------------------------------------------------------------- /run_40k_lstm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | FOLDER=40k 3 | NAME=lstm 4 | mkdir results_${FOLDER} 5 | mkdir models_${FOLDER} 6 | 7 | for SEED in 12093 12345 45678 8 | do 9 | python3 -u -m equation_verification.nn_tree_experiment \ 10 | --seed $SEED \ 11 | --train-path data/40k_train.json \ 12 | --validation-path data/40k_val_shallow.json \ 13 | --test-path data/40k_test.json \ 14 | --model-class LSTMTrees \ 15 | --checkpoint-every-n-epochs 1 \ 16 | --result-path results_${FOLDER}/exp_${NAME}_${SEED}.json \ 17 | --model-prefix models_${FOLDER}/exp_${NAME}_${SEED} > results_${FOLDER}/exp_${NAME}_${SEED}.log & 18 | done -------------------------------------------------------------------------------- /run_40k_vanilla.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | FOLDER=40k 3 | NAME=vanilla 4 | mkdir results_${FOLDER} 5 | mkdir models_${FOLDER} 6 | 7 | for SEED in 12093 12345 45678 8 | do 9 | python3 -u -m equation_verification.nn_tree_experiment \ 10 | --seed $SEED \ 11 | --train-path data/40k_train.json \ 12 | --validation-path data/40k_val_shallow.json \ 13 | --test-path data/40k_test.json \ 14 | --model-class NNTrees \ 15 | --checkpoint-every-n-epochs 1 \ 16 | --result-path results_${FOLDER}/exp_${NAME}_${SEED}.json \ 17 | --model-prefix models_${FOLDER}/exp_${NAME}_${SEED} > results_${FOLDER}/exp_${NAME}_${SEED}.log & 18 | done -------------------------------------------------------------------------------- /trace_40k_gatedpushpop_normalize_no_op.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | FOLDER=fixed_mem2out_stack_gatedpushpop_40k 3 | NAME=gatedpushpop_normalize_no_op 4 | mkdir results_${FOLDER} 5 | mkdir models_${FOLDER} 6 | 7 | SEED=12093 #or 12345 or 45678 8 | python3 -u -m equation_verification.nn_tree_experiment \ 9 | --seed $SEED \ 10 | --train-path data/40k_train.json \ 11 | --validation-path data/40k_val_shallow.json \ 12 | --test-path data/40k_test.json \ 13 | --stack-node-activation tanh \ 14 | --tree-node-activation tanh \ 15 | --model-class StackNNTreesMem2out \ 16 | --top-k 1 \ 17 | --stack-type stack \ 18 | --gate-push-pop \ 19 | --normalize-action \ 20 | --no-op \ 21 | --checkpoint-every-n-epochs 1 \ 22 | --num-epochs 40 \ 23 | --model-prefix models_${FOLDER}/exp_${NAME}_${SEED} \ 24 | --load-epoch 40 \ 25 | --evaluate-only \ 26 | --trace-path gatedpushpop_normalize_no_op.trace \ 
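The trace script above finishes by writing every logged activation to gatedpushpop_normalize_no_op.trace. As a minimal sketch (not part of the repo), the trace file can be inspected directly in Python; this assumes only what explore_trace.py below relies on, namely that the file is a pickled dict mapping split names to lists of trace trees carrying correct and probability attributes:

    import pickle

    # Peek inside a .trace file produced by the trace script above.
    # Note: unpickling requires the repo's tree classes to be importable.
    with open("gatedpushpop_normalize_no_op.trace", "rb") as f:
        trace = pickle.load(f)

    for split, trees in trace.items():
        accuracy = sum(t.correct for t in trees) / len(trees)
        print("%s: %d examples, %.3f accuracy" % (split, len(trees), accuracy))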
-------------------------------------------------------------------------------- /run_40k_gatedpushpop_normalize_no_op.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | FOLDER=fixed_mem2out_stack_gatedpushpop_40k 3 | NAME=gatedpushpop_normalize_no_op 4 | mkdir results_${FOLDER} 5 | mkdir models_${FOLDER} 6 | 7 | 8 | for SEED in 12093 12345 45678 9 | do 10 | python3 -u -m equation_verification.nn_tree_experiment \ 11 | --seed $SEED \ 12 | --train-path data/40k_train.json \ 13 | --validation-path data/40k_val_shallow.json \ 14 | --test-path data/40k_test.json \ 15 | --stack-node-activation tanh \ 16 | --tree-node-activation tanh \ 17 | --model-class StackNNTreesMem2out \ 18 | --top-k 1 \ 19 | --stack-type stack \ 20 | --gate-push-pop \ 21 | --normalize-action \ 22 | --no-op \ 23 | --checkpoint-every-n-epochs 1 \ 24 | --result-path results_${FOLDER}/exp_${NAME}_${SEED}.json \ 25 | --model-prefix models_${FOLDER}/exp_${NAME}_${SEED} > results_${FOLDER}/exp_${NAME}_${SEED}.log & 26 | done -------------------------------------------------------------------------------- /equation_verification/sequential_model_constants.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from equation_verification.constants import UNARY_FNS, BINARY_FNS, NUMBER_ENCODER, \ 3 | NUMBER_DECODER, SYMBOL_ENCODER, CONSTANTS 4 | 5 | SYMBOL_CLASSES = OrderedDict([ 6 | ('Symbol', ['var_%d' % d for d in range(10)]), 7 | ('Integer', [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -2, -3]), 8 | ('Rational', ['2/5', '-1/2', '0']), 9 | ('Float', [0.7]), 10 | ('Unary', UNARY_FNS), 11 | ('Binary', BINARY_FNS), 12 | ('parentheses',['(', ')', ','])]) 13 | 14 | 15 | def build_vocab(): 16 | vocab = OrderedDict() 17 | ctr = 0 18 | 19 | for symbol_type, values in SYMBOL_CLASSES.items(): 20 | for value in values: 21 | vocab["%s_%s" % (symbol_type, str(value))] = ctr 22 | ctr += 1 23 | for symbol_name in CONSTANTS: 24 | vocab[symbol_name] = ctr 25 | ctr += 1 26 | return vocab 27 | 28 | 29 | VOCAB = build_vocab() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Project Tree-SMU 2 | 3 | Copyright for project Tree-SMU is jointly held by Forough Arabshahi and Zhichu Lu, 2019. 4 | 5 | MIT License: 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a 8 | copy of this software and associated documentation files (the 9 | "Software"), to deal in the Software without restriction, including 10 | without limitation the rights to use, copy, modify, merge, publish, 11 | distribute, sublicense, and/or sell copies of the Software, and to 12 | permit persons to whom the Software is furnished to do so, subject to 13 | the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included 16 | in all copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 19 | OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 20 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 21 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 22 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 23 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 24 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /aggregate_test_splits.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import json 4 | folder = sys.argv[1] 5 | head = sys.argv[2] 6 | stat_list = [] 7 | sizes = [277] * 14 + [281] 8 | total = 4159 9 | for file in os.listdir(folder): 10 | if not file.startswith(head) or not file.endswith(".json") or "agg" in file: 11 | continue 12 | path = os.path.join(folder, file) 13 | with open(path) as f: 14 | stats = json.loads(f.read()) 15 | stat_list.append(stats) 16 | assert len(stat_list) == 15 17 | epoch = stat_list[0]["epoch"] 18 | set_ = stat_list[0]["set"] 19 | agg = {"type":"eval", "epoch":epoch, "set": set_, "stats":{}} 20 | for depth in set(sum((list(stat["stats"].keys()) for stat in stat_list),[])): 21 | agg["stats"][depth] = {} 22 | for statistic in stat_list[0]["stats"]["all"]: 23 | numerator = 0 24 | denominator = 0 25 | for stat in stat_list: 26 | if depth in stat["stats"]: 27 | numerator += stat["stats"][depth]["count"] * stat["stats"][depth][statistic] 28 | denominator += stat["stats"][depth]["count"] 29 | agg["stats"][depth][statistic] = numerator / denominator if statistic != "count" else denominator 30 | with open(os.path.join(folder, head+"agg.json"), "w") as f: 31 | json.dump(agg, f, indent=4, sort_keys=True) 32 | -------------------------------------------------------------------------------- /equation_verification/constants.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | UNARY_FNS = ['sin', 'cos', 'csc', 'sec', 'tan', 'cot', 4 | 'asin', 'acos', 'acsc', 'asec', 'atan', 'acot', 5 | 'sinh', 'cosh', 'csch', 'sech', 'tanh', 'coth', 6 | 'asinh', 'acosh', 'acsch', 'asech', 'atanh', 'acoth', 7 | 'exp'] 8 | 9 | BINARY_FNS = ['Add', 'Mul', 'Pow', 'log'] 10 | 11 | NUMBER_ENCODER = "Number_enc" 12 | NUMBER_DECODER = "Number_dec" 13 | SYMBOL_ENCODER = "Symbol" 14 | 15 | SYMBOL_CLASSES = OrderedDict([ 16 | ('Symbol', ['var_%d' % d for d in range(10)]), 17 | ('Integer', [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -2, -3]), 18 | ('Rational', ['2/5', '-1/2', '0']), 19 | ('Float', [0.7])]) 20 | 21 | CONSTANTS = ['NegativeOne', 'NaN', 'Infinity', 'Exp1', 'Pi', 'One', 'Half'] 22 | 23 | 24 | def get_vocab_key(function_name, value): 25 | if function_name in ['Integer', 'Rational', 'Float', 'Symbol']: 26 | key = function_name + '_%s' % str(value) 27 | print('key:', key) 28 | else: 29 | key = function_name 30 | 31 | return key 32 | 33 | 34 | def build_vocab(): 35 | vocab = OrderedDict() 36 | ctr = 0 37 | 38 | for symbol_type, values in SYMBOL_CLASSES.items(): 39 | for value in values: 40 | vocab["%s_%s" % (symbol_type, str(value))] = ctr 41 | ctr += 1 42 | for symbol_name in CONSTANTS: 43 | vocab[symbol_name] = ctr 44 | ctr += 1 45 | return vocab 46 | 47 | 48 | VOCAB = build_vocab() 49 | 50 | PRETTYNAMES = dict([ 51 | ("Equality", "="), 52 | (SYMBOL_ENCODER, "EMBEDDING"), 53 | (NUMBER_ENCODER, "ENCODER"), 54 | (NUMBER_DECODER, "DECODER"), 55 | ("Add", "+"), 56 | ("Pow", "^"), 57 | ("Mul", "\u00D7"), 58 | ('NegativeOne', "-1"), 59 | ('NaN', "nan"), 60 | ('Infinity', "inf"), 61 | ('Exp1', 'e'), 62 | ('Pi', '\u03C0'), 63 | ('One','1'), 64 | ('Half', '1/2'), 65 | (SYMBOL_ENCODER+'_var_0', 'x'), 66 | (SYMBOL_ENCODER+'_var_1', 'y'), 67 | (SYMBOL_ENCODER+'_var_2', 'z'), 68 | (SYMBOL_ENCODER+'_var_3', 'w'), 69 | ]) 
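As a small usage sketch (not part of the repo), here is how the vocabulary defined in equation_verification/constants.py above is keyed; note that this VOCAB covers only leaf symbols (variables and numbers) plus the named constants:

    # Usage sketch for the helpers in equation_verification/constants.py above.
    from equation_verification.constants import VOCAB, get_vocab_key

    key = get_vocab_key("Integer", 3)   # leaf types are keyed as "<type>_<value>"
    print(key)                          # key == "Integer_3" (get_vocab_key also prints a debug line)
    print(VOCAB[key])                   # integer id assigned by build_vocab()
    print(VOCAB["Pi"])                  # named constants are keyed by their name directly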
--------------------------------------------------------------------------------
/using_explore_trace.txt:
--------------------------------------------------------------------------------
 1 | An outline for interactively inspecting what the model is doing:
 2 | 
 3 | 0. train and produce model checkpoints
 4 |     run_40k_gatedpushpop_normalize_no_op.sh
 5 | 1. run the trained model at an epoch of choice and log activations
 6 |     trace_40k_gatedpushpop_normalize_no_op.sh
 7 | 2. interactively select and visualize model activations for train/test examples
 8 |     explore_trace.sh
 9 | 
10 | Some additional tips for using explore_trace.sh:
11 | 
12 | After explore_trace.sh is run, you will see an interactive prompt. We try
13 | to emulate a folder structure for grouping the sets of examples
14 | in our data.
15 | 
16 | For example, you can type 'ls' to see the current top-level folders, and 'cd' to
17 | select one. At the top level there are train, validation, and test. The next level
18 | is divided into bins based on whether an equality is predicted correctly by our model.
19 | For instance, to see correctly predicted examples, you can do 'cd (True,)',
20 | and 'cd (False,)' for incorrect predictions. 'cd ..' will take you back to the
21 | previous level.
22 | 
23 | A bin is the lowest-level folder; once you are in a particular bin, all examples
24 | are essentially in a flat list. The initial printed statistics include the count
25 | for the bin and a percentage N_bin/N_set (e.g. correct_train/all_train).
26 | You can select a random batch with the sel command,
27 | e.g. 'sel mode=random n=10' to select 10 random examples from the bin.
28 | Finally, typing the 'plot' command will create visualizations of the selected
29 | equation trees and save them to the 'visualization' directory.
30 | 
31 | In the visualization documents, the 'mem' rows of a node represent the stack,
32 | and the 'act' rows represent the push/pop actions (if the particular version of
33 | the model has a no-op action, it appears in the third row).
34 | The 'rep' row of a node represents the output of the node.
35 | 
36 | You can change the default parameters (--load-epoch and --trace-path in
37 | trace_40k_gatedpushpop_normalize_no_op.sh, and --trace-files in
38 | explore_trace.sh) to explore the model activations for a
39 | different epoch.
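To make this concrete, here is a hypothetical session (abridged; the '(Cmd)' prompt
comes from Python's cmd module, and the actual bin counts and statistics depend on
your trace file):

    (Cmd) ls
    0 train
    1 validation
    2 test
    (Cmd) cd test
    (Cmd) ls
    0 (True,)
    1 (False,)
    (Cmd) cd (True,)
    (Cmd) sel mode=random n=10
    (Cmd) plot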
40 | 41 | -------------------------------------------------------------------------------- /compmlete_40k.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | FOLDER=checkpoints 3 | RESULTS=results_completion 4 | mkdir ${RESULTS} 5 | 6 | for DATASPLIT in {0..14} 7 | do 8 | NAME=vanilla \ 9 | SEED=45678 \ 10 | EPOCH=66 \ 11 | NUM_HIDDEN=45 \ 12 | DROPOUT=0.1 \ 13 | && \ 14 | python3 -u -m equation_verification.equation_completion_experiment \ 15 | --seed $SEED \ 16 | --train-path 40k_empty.json \ 17 | --validation-path 40k_empty.json \ 18 | --test-path data/blanks/40k_test_blank_${DATASPLIT}.json \ 19 | --candidate-path data/candidate_classes.json \ 20 | --model-class NNTrees \ 21 | --dropout ${DROPOUT} \ 22 | --num-hidden ${NUM_HIDDEN} \ 23 | --checkpoint-every-n-epochs 1 \ 24 | --result-path ${RESULTS}/exp_${NAME}_completion_${NUM_HIDDEN}_${DROPOUT}_${SEED}_${EPOCH}_${DATASPLIT}.json \ 25 | --model-prefix ${FOLDER}/exp_${NAME}_drop${DROPOUT}_hidden${NUM_HIDDEN}_${SEED} \ 26 | --load-epoch ${EPOCH} \ 27 | --evaluate-only \ 28 | --cut 1000000 > ${RESULTS}/exp_${NAME}_completion_${NUM_HIDDEN}_${DROPOUT}_${SEED}_${EPOCH}_${DATASPLIT}.log & 29 | done 30 | 31 | for DATASPLIT in {0..14} 32 | do 33 | NAME=lstm \ 34 | EPOCH=68 \ 35 | NUM_HIDDEN=50 \ 36 | DROPOUT=0.1 \ 37 | SEED=12093 \ 38 | && \ 39 | python3 -u -m equation_verification.equation_completion_experiment \ 40 | --seed $SEED \ 41 | --train-path 40k_empty.json \ 42 | --validation-path 40k_empty.json \ 43 | --test-path data/blanks/40k_test_blank_${DATASPLIT}.json \ 44 | --candidate-path data/candidate_classes.json \ 45 | --model-class LSTMTrees \ 46 | --num-hidden ${NUM_HIDDEN} \ 47 | --dropout ${DROPOUT} \ 48 | --checkpoint-every-n-epochs 1 \ 49 | --result-path ${RESULTS}/exp_${NAME}_completion_${NUM_HIDDEN}_${DROPOUT}_${SEED}_${EPOCH}_${DATASPLIT}.json \ 50 | --model-prefix ${FOLDER}/exp_${NAME}_drop${DROPOUT}_hidden${NUM_HIDDEN}_${SEED} \ 51 | --load-epoch ${EPOCH} \ 52 | --evaluate-only \ 53 | --cut 1000000 > ${RESULTS}/exp_${NAME}_completion_${NUM_HIDDEN}_${DROPOUT}_${SEED}_${EPOCH}_${DATASPLIT}.log & 54 | done 55 | 56 | for DATASPLIT in {0..14} 57 | do 58 | NAME=gatedpushpop_normalize_no_op \ 59 | SEED=12345 \ 60 | EPOCH=97 \ 61 | NUM_HIDDEN=50 \ 62 | DROPOUT=0.1 \ 63 | && \ 64 | python3 -u -m equation_verification.equation_completion_experiment \ 65 | --seed $SEED \ 66 | --train-path 40k_empty.json \ 67 | --validation-path 40k_empty.json \ 68 | --test-path data/blanks/40k_test_blank_${DATASPLIT}.json \ 69 | --candidate-path data/candidate_classes.json \ 70 | --stack-node-activation tanh \ 71 | --tree-node-activation tanh \ 72 | --model-class StackNNTreesMem2out \ 73 | --num-hidden ${NUM_HIDDEN} \ 74 | --dropout ${DROPOUT} \ 75 | --top-k 1 \ 76 | --stack-type stack \ 77 | --gate-push-pop \ 78 | --normalize-action \ 79 | --no-op \ 80 | --checkpoint-every-n-epochs 1 \ 81 | --result-path ${RESULTS}/exp_${NAME}_completion_${NUM_HIDDEN}_${DROPOUT}_${SEED}_${EPOCH}_${DATASPLIT}.json \ 82 | --model-prefix ${FOLDER}/exp_${NAME}_${NUM_HIDDEN}_${DROPOUT}_${SEED} \ 83 | --load-epoch ${EPOCH} \ 84 | --evaluate-only \ 85 | --cut 1000000 > ${RESULTS}/exp_${NAME}_completion_${NUM_HIDDEN}_${DROPOUT}_${SEED}_${EPOCH}_${DATASPLIT}.log & 86 | done -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tree Stack Memory Units 2 | 3 | This is the code for the [Tree Stack 
Memory Units](https://arxiv.org/abs/1911.01545) paper.
 4 | 
 5 | ## Dependencies
 6 | Python 3.6 or higher
 7 | 
 8 | PyTorch 1.0 or higher
 9 | 
10 | ## Visualization tool
11 | With this tool you can explore the learned representations at the output of each node, the elements of the stack, and the learned push and pop operations for each equation from the train, test or validation data. Read the [using_explore_trace.txt](https://github.com/ForoughA/recursiveMemNet/blob/master/using_explore_trace.txt) file to understand how to use the tool. Two example outputs of the visualization tool are stored in the folder [visualization/](https://github.com/ForoughA/recursiveMemNet/tree/master/visualization)
12 | 
13 | ## Navigation
14 | 1. data/
15 |     * is a folder containing the data used in all the experiments reported in the paper. We have included the train/test/validation splits that were used in the paper. There are two sets of data, one for each of the experiments in the paper:
16 |         - Equation Verification
17 |             - data/40k_test.json
18 |             - data/40k_train.json
19 |             - data/40k_val_shallow.json
20 |         - Equation Completion
21 |             - data/blanks/40k_test_blank_{i}.json for i in range 0 through 14 (test data for equation completion; this is the data from data/40k_test.json in which sub-trees of depth 1 and 2 are randomly replaced with blanks)
22 |             - data/candidate_classes.json (contains equivalence classes for the blank candidates, used to compute top-k accuracy. Expressions in the same class evaluate to the same value; for example, 1+1 and 2 are in the same class. If the top-ranked candidate belongs to the same class as the correct answer, then it is considered a correct blank prediction.)
23 | 2. equation_verification/
24 |     * is a folder containing the Python scripts implementing the proposed model. There are 8 files in this folder:
25 |         - equation_verification/__init__.py
26 |         - equation_verification/constants.py
27 |         - equation_verification/dataset_loading.py
28 |         - equation_verification/nn_tree_experiment.py
29 |         - equation_verification/nn_tree_model.py (implementation of Tree-SMU, Tree-LSTM and Tree-RNN)
30 |         - equation_verification/equation_completion_experiment.py
31 |         - equation_verification/sequential_model.py (recurrent neural network (RNN) and LSTM implementation)
32 |         - equation_verification/sequential_model_constants.py
33 | 3. Shell scripts can be used to replicate the experiments in the paper. These are:
34 |     * Equation Verification Experiments (all the hyperparameters are default settings, not the optimal hyperparameters)
35 |         - run_40k_gatedpushpop_normalize_no_op.sh replicates the Tree-SMU results. Other model ablations can be replicated by removing the corresponding command-line flags (--normalize-action and --no-op) in this script.
36 |         - run_40k_lstm.sh replicates the Tree-LSTM results
37 |         - run_40k_vanilla.sh replicates the Tree-RNN results
38 |     * Equation Completion Experiments
39 |         - compmlete_40k.sh replicates the Tree-SMU, Tree-LSTM and Tree-RNN results using the trained models (best seed and best hyperparameters for each architecture)
40 |         - aggregate_test_splits.py is a Python script that aggregates the results of compmlete_40k.sh into a single JSON file. Usage is:
41 |             * python3 aggregate_test_splits.py results_completion ${prefix}
42 |             * prefix is the prefix of the output of compmlete_40k.sh. For example: exp_gatedpushpop_normalize_no_op_completion_50_0.1_12345_97_
43 |             * the output will be written to ${prefix}agg.json in the same folder (results_completion)
44 | 4. checkpoints/
45 |     * This folder contains the model checkpoints for the best trained models for equation verification and equation completion, picked based on the maximum accuracy on the validation data.
46 | 5. Visualization tool that can be used to visualize a learned tree model for any chosen input equation and explore the learned weights and stack operations. Relevant files are:
47 |     - using_explore_trace.txt is a text file that contains instructions for how to use the visualization tool
48 |     - visualization/ the visualization results will be saved in this folder. We have included two visualization examples in this folder just to give a feeling of what to expect.
49 |     - explore_trace.py Python script. There is no need to run this script directly; we have provided shell scripts for that.
50 |     - explore_trace.sh shell script for running the visualization tool. Please refer to using_explore_trace.txt to see how to use this.
51 |     - trace_40k_gatedpushpop_normalize_no_op.sh shell script for generating the traces used by the visualization tool. Please refer to using_explore_trace.txt to see how to use this.
52 |     - parse.py parses equation strings into the tree class used in our Python code
53 |     - optimizers.py builds the model optimizers: Adam and SGD
54 | 
55 | ## Notebooks
56 | Stay tuned for IPython (Jupyter) notebooks...
57 | 
--------------------------------------------------------------------------------
/parse.py:
--------------------------------------------------------------------------------
  1 | import ast
  2 | import json
  3 | 
  4 | import _ast
  5 | from sympy import sympify
  6 | 
  7 | from equation_verification import constants
  8 | 
  9 | NAME = {
 10 |     "x": "var_0",
 11 |     "y": "var_1",
 12 |     "z": "var_2",
 13 |     "w": "var_3",
 14 | }
 15 | 
 16 | 
 17 | def build(t):
 18 |     if isinstance(t, _ast.Name):
 19 |         if t.id == "Pi" or t.id == "pi":
 20 |             return T.PI()
 21 |         return T.v(NAME[t.id])
 22 |     if isinstance(t, _ast.Num):
 23 |         if isinstance(t.n, float):
 24 |             return T.f(t.n)
 25 |         elif isinstance(t.n, int):
 26 |             return T.i(t.n)
 27 |     elif isinstance(t, _ast.Call):
 28 |         return T(t.func.id, None, *[build(c) for c in t.args])
 29 |     elif isinstance(t, _ast.BinOp):
 30 |         op = t.op.__class__.__name__
 31 |         if op == "Add":
 32 |             pass
 33 |         elif op == "Mult":
 34 |             op = "Mul"
 35 |         elif op == "Pow":
 36 |             pass
 37 |         elif op == "Sub":
 38 |             return T("Add", None, build(t.left),
 39 |                      T("Mul", None, T.NEGATIVE_ONE(), build(t.right)))
 40 |         elif op == "Div":
 41 |             return T("Mul", None, build(t.left),
 42 |                      T("Pow", None, build(t.right), T.NEGATIVE_ONE()))
 43 |         else:
 44 |             raise AssertionError("Unknown BinOp %s" % op)
 45 |         return T(op, None, build(t.left), build(t.right))
 46 |     elif isinstance(t, _ast.UnaryOp):
 47 |         op = t.op.__class__.__name__
 48 |         if op == "USub":
 49 |             if not isinstance(t.operand, _ast.Num):
 50 |                 raise AssertionError(
 51 |                     "- only applies to number, got %s" % ast.dump(t.operand))
 52 |             if not t.operand.n == 1:
 53 |                 return T("Mul", None, T.NEGATIVE_ONE(), build(t.operand))
 54 |             # raise AssertionError("- only applies to number 1, got %s" %
 55 |             #                      str(t.operand.n))
 56 |             return T.NEGATIVE_ONE()
 57 |         else:
 58 |             raise AssertionError("Unknown UnaryOp %s" % op)
 59 |     else:
 60 |         raise AssertionError("Unknown Node %s" % t.__class__.__name__)
 61 | 
 62 | def parse_expression(s):
 63 |     tree = ast.parse(s).body[0].value
 64 |     if isinstance(tree, ast.Compare):
 65 |         raise AssertionError("%s" % tree.__class__.__name__)
 66 |     return build(tree)
 67 | 
 68 | def parse_equation(s):
 69 |     tree = ast.parse(s).body[0].value
 70 |     assert isinstance(tree, _ast.Compare)
 71 |     return
T("Equality", None, build(tree.left), build(tree.comparators[0])) 72 | 73 | class T: 74 | 75 | def __init__(self, name, varname, l=None, r=None): 76 | self.name = name 77 | self.varname = varname 78 | self.l = l 79 | self.r = r 80 | self.children = [x for x in [l,r] if x is not None] 81 | 82 | @classmethod 83 | def v(cls, s): 84 | return T("Symbol", s, None, None) 85 | 86 | @classmethod 87 | def i(cls, i): 88 | return T("Integer", str(i), None, None) 89 | 90 | @classmethod 91 | def f(cls, f): 92 | return T("Number", str(round(f,2)), None, None) 93 | 94 | @classmethod 95 | def NEGATIVE_ONE(cls): 96 | return T("NegativeOne", "-1", None, None) 97 | 98 | @classmethod 99 | def PI(cls): 100 | return T("Pi", "pi", None, None) 101 | 102 | def __str__(self): 103 | if len(self.children) == 0: 104 | return self.name + "(" + self.varname + ")" 105 | else: 106 | return self.name + "(" + ", ".join([c.__str__() for c in self.children]) + ")" 107 | 108 | def sympy_str(self): 109 | if len(self.children) == 0: 110 | return self.varname 111 | else: 112 | return (self.name if self.name != "Equality" else "Eq")+ "(" + ", ".join( 113 | [c.sympy_str() for c in self.children]) + ")" 114 | 115 | def inord(self): 116 | ret = [self] 117 | for i in range(2): 118 | if i >= len(self.children): 119 | ret.append("#") 120 | else: 121 | ret.extend(self.children[i].inord()) 122 | return ret 123 | 124 | def depth(self): 125 | if len(self.children) == 0: 126 | return 0 127 | else: 128 | return 1 + max(c.depth() for c in self.children) 129 | 130 | def size(self): 131 | if len(self.children) == 0: 132 | return 1 133 | else: 134 | return 1 + sum(c.size() for c in self.children) 135 | 136 | def dump(self, label=True): 137 | lst = self.inord() 138 | func = [item.name if isinstance(item, T) else item for item in lst] 139 | depth = [str(item.depth()) if isinstance(item, T) else item for item in lst] 140 | nodeNum = [] 141 | n = 0 142 | numNodes = 0 143 | for node in func: 144 | if node != "#": 145 | nodeNum.append(str(n)) 146 | n += 1 147 | numNodes += 1 148 | else: 149 | nodeNum.append("#") 150 | vars = [(str(item.varname) if item.varname is not None else "") if isinstance(item, T) else item for item in lst] 151 | variables = sorted(list(set([var for var in vars if var.startswith("var")]))) 152 | variables = dict(zip(variables, list(range(len(variables))))) 153 | label = label 154 | if not label == sympify(self.sympy_str()): 155 | # print("Wrong", self.sympy_str()) 156 | pass 157 | assert numNodes == self.size() 158 | assert isinstance(label, bool) 159 | ret = {"equation":{ 160 | "depth": ",".join(depth), 161 | "func": ",".join(func), 162 | "nodeNum": ",".join(nodeNum), 163 | "numNodes": str(numNodes), 164 | "vars": ",".join(vars), 165 | "variables": variables, 166 | }, 167 | "label": str(int(label))} 168 | print(ret) 169 | return ret 170 | 171 | # with open("controlled_generation/axioms_basic.txt", "rt") as f: 172 | # for line in f: 173 | # a = parse_equation(line.strip()) 174 | # json.dumps(a.dump()) 175 | # # print(a) 176 | # # print(json.dumps(a.dump(),sort_keys=True,indent=4)) 177 | if __name__ == "__main__": 178 | print(json.dumps(parse_equation("x * (y + z) == x * y + x * z").dump(), indent=4, sort_keys=True),",") 179 | print(json.dumps(parse_equation("x * x == x ** 2").dump(), indent=4, sort_keys=True),",") 180 | print(json.dumps(parse_equation("2 ** -1 * 2 == 1").dump(), indent=4, sort_keys=True),",") 181 | print(json.dumps(parse_equation("3 ** -1 * 3 == 1").dump(), indent=4, sort_keys=True),",") 182 | 
print(json.dumps(parse_equation("4 ** -1 * 4 == 1").dump(), indent=4, sort_keys=True),",") 183 | print(json.dumps(parse_equation("y ** -1 * y == 1").dump(), indent=4, sort_keys=True),",") -------------------------------------------------------------------------------- /explore_trace.py: -------------------------------------------------------------------------------- 1 | import cmd 2 | import pickle 3 | import random 4 | import sys 5 | import os 6 | import traceback 7 | import numpy as np 8 | from collections import deque, OrderedDict 9 | from itertools import product 10 | 11 | from graphviz import Digraph 12 | from equation_verification.constants import NUMBER_ENCODER, NUMBER_DECODER, SYMBOL_ENCODER 13 | from argparse import ArgumentParser 14 | random.seed(0) 15 | 16 | 17 | parser = ArgumentParser() 18 | parser.add_argument("--dump-dir", type=str, default="graphviz_dump") 19 | parser.add_argument("--trace-files", type=str, nargs="+", default=None) 20 | args = parser.parse_args() 21 | 22 | if os.path.exists(args.dump_dir): 23 | txt = None 24 | while txt not in {'yes', 'no'}: 25 | txt=input("The folder %s already exists. Are you sure you want to dump output to this folder? [yes/no]" % args.dump_dir) 26 | if txt == 'yes': 27 | pass 28 | else: 29 | exit(0) 30 | 31 | PRETTYNAMES = dict([ 32 | ("Equality", "="), 33 | (SYMBOL_ENCODER, "EMBEDDING"), 34 | (NUMBER_ENCODER, "ENCODER"), 35 | (NUMBER_DECODER, "DECODER"), 36 | ("Add", "+"), 37 | ("Pow", "^"), 38 | ("Mul", "\u00D7"), 39 | ('NegativeOne', "-1"), 40 | ('NaN', "nan"), 41 | ('Infinity', "inf"), 42 | ('Exp1', 'e'), 43 | ('Pi', '\u03C0'), 44 | ('One','1'), 45 | ('Half', '1/2') 46 | ]) 47 | def visualize(tree, view="False"): 48 | view = eval(view) 49 | s = Digraph('structs', filename=os.path.join(args.dump_dir, '%s.gv' % tree.id), 50 | node_attr={'shape': 'record'}, 51 | ) 52 | s.graph_attr['rankdir'] = 'BT' 53 | def build(tree, id): 54 | myid = id[0] 55 | id[0] += 1 56 | name = tree.function_name 57 | # if name in PRETTYNAMES: 58 | # name = PRETTYNAMES[tree.function_name] 59 | # elif name.startswith("Integer") or name.startswith("Float") or name.startswith("Symbol") or name.startswith("Rational"): 60 | # name = "_".join(name.split("_")[1:]) 61 | 62 | 63 | # draw node 64 | payload = "" 65 | if hasattr(tree, 'output'): 66 | if not isinstance(tree.output, list): 67 | rep_payload = "|{rep | %s}" % str(round(tree.output, 2)) 68 | elif isinstance(tree.output[0], list): 69 | rep_payload = "|{rep | {%s}}" % "|".join( 70 | [("{%s}" % "|".join([str(round(i, 2)) 71 | for i in row])) 72 | for row in tree.output]) 73 | else: 74 | rep_payload = "|{rep | %s}" % "|".join( 75 | [str(round(i, 2)) for i in tree.output]) 76 | payload += rep_payload 77 | if hasattr(tree, 'bias'): 78 | if not isinstance(tree.bias, list): 79 | rep_payload = "|{bias | %s}" % str(round(tree.bias, 2)) 80 | elif isinstance(tree.bias[0], list): 81 | rep_payload = "|{bias | {%s}}" % "|".join( 82 | [("{%s}" % "|".join([str(round(i, 2)) 83 | for i in row])) 84 | for row in tree.bias]) 85 | else: 86 | rep_payload = "|{bias | %s}" % "|".join( 87 | [str(round(i, 2)) for i in tree.bias]) 88 | payload += rep_payload 89 | if myid == 0: 90 | payload += "|{label| %s} |{correct | %s}| {p_equals| %s}" % (str(tree.label), str(tree.correct), str(round(tree.probability, 2))) 91 | if hasattr(tree, 'action'): 92 | if isinstance(tree.action[0], list): 93 | mem_payload = "|{act | {%s}}" % "|".join( 94 | [("{%s}" % "|".join([str(round(i, 2)) 95 | for i in row])) 96 | for row in tree.action]) 97 | else: 98 | 
mem_payload = "|{act | %s}" % "|".join( 99 | [str(round(i, 2)) for i in tree.action]) 100 | payload += mem_payload 101 | if hasattr(tree, 'memory'): 102 | if isinstance(tree.memory[0], list): 103 | mem_payload = "|{mem | {%s}}" % "|".join( 104 | [("{%s}" % "|".join([str(round(i, 2)) 105 | for i in row])) 106 | for row in tree.memory]) 107 | else: 108 | mem_payload = "|{mem | %s}" % "|".join( 109 | [str(round(i, 2)) for i in tree.memory]) 110 | payload += mem_payload 111 | 112 | s.node(str(myid), '{{ %s} %s}' % (name, payload)) 113 | if tree.is_leaf: 114 | pass 115 | elif tree.is_unary: 116 | idl = build(tree.lchild, id) 117 | s.edges([('%d' % idl, '%d' % myid)]) 118 | elif tree.is_binary: 119 | idl = build(tree.lchild, id) 120 | idr = build(tree.rchild, id) 121 | s.edges([( '%d' % idl, '%d' % myid), ('%d' % idr, '%d' % myid)]) 122 | return myid 123 | build(tree, [0]) 124 | s.render(os.path.join(args.dump_dir, "%s %s" % (tree.pretty_str().replace("/"," div "), tree.id)), view=view) 125 | l = np.array(tree.lchild.output) 126 | r = np.array(tree.rchild.output) 127 | lnorm = np.linalg.norm(l) 128 | rnorm = np.linalg.norm(r) 129 | dot = np.dot(l, r) 130 | cosine_similarity = dot / (lnorm * rnorm) 131 | #print(tree.lchild.pretty_str(), round(lnorm,2), tree.rchild.pretty_str(), round(rnorm,2), "dot", round(dot,2), round(cosine_similarity,2)) 132 | 133 | 134 | class Shell(cmd.Cmd): 135 | 136 | def __init__(self, trace): 137 | super().__init__() 138 | self.curr_obj = trace 139 | self.stack = deque() 140 | 141 | def do_ls(self, arg): 142 | if len(self.curr_obj) > 10: 143 | show = self.curr_obj[:10] 144 | else: 145 | show = self.curr_obj 146 | for i, item in enumerate(show): 147 | print(i, item) 148 | 149 | def do_cd(self, arg): 150 | if arg == "..": 151 | if len(self.stack) == 0: 152 | print("already at top level") 153 | else: 154 | self.curr_obj = self.stack.pop() 155 | elif isinstance(self.curr_obj, dict) and arg in self.curr_obj: 156 | self.stack.append(self.curr_obj) 157 | self.curr_obj = self.curr_obj[arg] 158 | else: 159 | print("arg is not a folder") 160 | 161 | def do_sel(self, arg): 162 | kwargs = self._parse(arg) 163 | self._sel(**kwargs) 164 | 165 | def _sel(self, mode="random", n="1", depth=None, choice=None): 166 | n = int(n) 167 | depth = int(depth) if depth else None 168 | self.selection = [] 169 | domain = self.curr_obj 170 | if mode == "random": 171 | if depth is not None: 172 | domain = [tree for tree in domain if tree[0].depth == depth] 173 | self.selection = random.sample(domain, n) 174 | if mode == "select": 175 | self.selection = [domain[int(idx)] for idx in choice.split(",")] 176 | if mode == "all": 177 | self.selection = [x for x in domain] 178 | 179 | def do_plot(self, arg): 180 | kwargs = self._parse(arg) 181 | for trees in self.selection: 182 | print("."*163) 183 | for tree in trees: 184 | print("d{4:d} {1:s} {2:>10.2f} {3:s} {0:50s}".format(tree.id, tree.pretty_str(prettify=True), tree.probability, "Right" if abs(int(tree.raw["label"]) - tree.probability) < 0.5 else "Wrong", tree.depth)) 185 | # print(tree.raw["equation"]["func"], tree.raw["equation"]["vars"], tree.raw["label"]) 186 | visualize(tree, **kwargs) 187 | 188 | def onecmd(self, line): 189 | try: 190 | return super().onecmd(line) 191 | except: 192 | traceback.print_exc() 193 | 194 | def _parse(self, arg): 195 | args = arg.split() 196 | kwargs = {} 197 | for kwarg in args: 198 | key = kwarg.split("=")[0] 199 | val = kwarg.split("=")[1] 200 | kwargs[key] = val 201 | return kwargs 202 | 203 | def main(): 204 | 
traces = [] 205 | for trace_file in args.trace_files: 206 | with open(trace_file, "rb") as f: 207 | trace = pickle.load(f) 208 | name = trace_file.replace(".", "-").replace("/","-") 209 | traces.append((name, trace)) 210 | sets = traces[0][1].keys() 211 | assert all([trace.keys() == sets for _, trace in traces]) 212 | ultimate_trace = dict() 213 | for set in sets: 214 | ultimate_trace[set] = [] 215 | for set in traces[0][1]: 216 | for i, tree in enumerate(trace[set]): 217 | ultimate_trace[set].append([]) 218 | for name, trace in traces: 219 | for set in trace: 220 | for i, tree in enumerate(trace[set]): 221 | assert str(traces[0][1][set][i]) == str(tree) 222 | tree.id = "%s_%d_%s" % (set, i, name) 223 | ultimate_trace[set][i].append(tree) 224 | 225 | ultimate_bins = dict() 226 | for set in ultimate_trace: 227 | binnames = product(*[[True, False] for _ in range(len(traces))]) 228 | bins = OrderedDict((name, []) for name in binnames) 229 | for trees in ultimate_trace[set]: 230 | id = tuple(tree.correct for tree in trees) 231 | bins[id].append(trees) 232 | ultimate_bins[set] = bins 233 | for set in ultimate_bins: 234 | print(set) 235 | for id in list(ultimate_bins[set].keys()): 236 | ultimate_bins[set][str(id)] = ultimate_bins[set][id] 237 | idstr = "|" 238 | for i, correct in enumerate(id): 239 | if correct: 240 | idstr += args.trace_files[i] + "(right)|" 241 | else: 242 | idstr += args.trace_files[i] + "(wrong)|" 243 | print(idstr, len(ultimate_bins[set][id]), len(ultimate_bins[set][id]) / len(ultimate_trace[set])) 244 | for set in ultimate_bins: 245 | for id in list(ultimate_bins[set].keys()): 246 | if isinstance(id, tuple): 247 | del ultimate_bins[set][id] 248 | Shell(ultimate_bins).cmdloop() 249 | 250 | if __name__ == "__main__": 251 | main() 252 | -------------------------------------------------------------------------------- /equation_verification/sequential_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import sys 4 | 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | 11 | from equation_verification.sequential_model_constants import VOCAB, SYMBOL_CLASSES 12 | from equation_verification.constants import UNARY_FNS, BINARY_FNS, NUMBER_ENCODER, \ 13 | NUMBER_DECODER, SYMBOL_ENCODER, CONSTANTS 14 | from equation_verification.dataset_loading import BinaryEqnTree 15 | 16 | 17 | class LSTMchain(torch.nn.Module): 18 | def __init__(self, num_hidden, dropout): 19 | super().__init__() 20 | 21 | self.layer = LSTMnode(num_hidden, num_hidden) 22 | 23 | setattr(self, NUMBER_ENCODER, nn.Sequential(OrderedDict([ 24 | ('linear1', nn.Linear(1, num_hidden)), 25 | ('sgmd1', nn.Sigmoid()), 26 | ('linear2', nn.Linear(num_hidden, num_hidden)), 27 | ('sgmd2', nn.Sigmoid()) 28 | ]))) 29 | setattr(self, NUMBER_DECODER, nn.Sequential(OrderedDict([ 30 | ('linear1', nn.Linear(num_hidden, num_hidden)), 31 | ('sgmd1', nn.Sigmoid()), 32 | ('linear2', nn.Linear(num_hidden, 1)), 33 | ]))) 34 | setattr(self, SYMBOL_ENCODER, nn.Embedding(num_embeddings=len(VOCAB), 35 | embedding_dim=num_hidden)) 36 | self.bias = nn.Parameter(torch.FloatTensor([0])) 37 | self.num_hidden = num_hidden 38 | self.dropout = dropout 39 | 40 | def forward(self, tree, hidden, trace=None): 41 | """ 42 | hidden is a tuple of hidden state and memory 43 | hidden = (h, c) 44 | """ 45 | # print('modules:', self._modules) 46 | # if not tree.is_leaf and tree.function_name not 
in self._modules: 47 | # raise AssertionError("Unknown functional node: %s" % tree.function_name) 48 | 49 | nn_block = getattr(self, SYMBOL_ENCODER) 50 | if tree.is_binary: 51 | root_value = VOCAB['Binary_'+tree.function_name] 52 | root_encoded_value = Variable(torch.LongTensor([root_value])[0]) 53 | inp = nn_block(root_encoded_value) 54 | h, c = self.layer(inp, hidden, dropout=self.dropout) 55 | 56 | parentheses_value_l = VOCAB['parentheses_('] 57 | left_parentheses = Variable(torch.LongTensor([parentheses_value_l])[0]) 58 | inp = nn_block(left_parentheses) 59 | h,c = self.layer(inp, (h,c), dropout=self.dropout) 60 | 61 | hl, cl = self(tree.lchild, (h,c), trace=trace.lchild if trace else None) 62 | 63 | comma_value = VOCAB['parentheses_,'] 64 | comma = Variable(torch.LongTensor([comma_value])[0]) 65 | inp = nn_block(comma) 66 | h, c = self.layer(inp, (hl,cl), dropout=self.dropout) 67 | 68 | hr, cr = self(tree.rchild, (h,c), trace=trace.rchild if trace else None) 69 | 70 | parentheses_value_r = VOCAB['parentheses_)'] 71 | right_parentheses = Variable(torch.LongTensor([parentheses_value_r])[0]) 72 | inp = nn_block(right_parentheses) 73 | h, c = self.layer(inp, (hr, cr), dropout=self.dropout) 74 | 75 | return h, c 76 | 77 | elif tree.is_unary: 78 | 79 | if tree.function_name in {NUMBER_DECODER, NUMBER_ENCODER, 80 | SYMBOL_ENCODER}: 81 | 82 | hl, _ = self(tree.lchild, hidden, trace=trace.lchild if trace else None) 83 | nn_block = getattr(self, tree.function_name) 84 | inp = nn_block(hl) 85 | h,c = self.layer(inp, hidden, dropout=self.dropout) 86 | 87 | else: 88 | root_value = VOCAB['Unary_'+tree.function_name] 89 | root_encoded_value = Variable(torch.LongTensor([root_value])[0]) 90 | inp = nn_block(root_encoded_value) 91 | h, c = self.layer(inp, hidden, dropout=self.dropout) 92 | 93 | parentheses_value_l = VOCAB['parentheses_('] 94 | left_parentheses = Variable(torch.LongTensor([parentheses_value_l])[0]) 95 | inp = nn_block(left_parentheses) 96 | h,c = self.layer(inp, (h,c), dropout=self.dropout) 97 | 98 | hl, cl = self(tree.lchild, (h,c), trace=trace.lchild if trace else None) 99 | 100 | parentheses_value_r = VOCAB['parentheses_)'] 101 | right_parentheses = Variable(torch.LongTensor([parentheses_value_r])[0]) 102 | inp = nn_block(right_parentheses) 103 | h, c = self.layer(inp, (hl, cl), dropout=self.dropout) 104 | 105 | return h, c 106 | 107 | elif tree.is_leaf: 108 | c = Variable(torch.LongTensor([0] * self.num_hidden)) 109 | if trace: 110 | trace.output = tree.encoded_value.tolist() 111 | trace.memory = c.tolist() 112 | return tree.encoded_value, c 113 | else: 114 | raise RuntimeError("Invalid tree:\n%s" % repr(self)) 115 | 116 | def compute_batch(self, batch, trace=None): 117 | record = [] 118 | total_loss = 0 119 | for tree, label, depth in batch: 120 | if trace is not None: 121 | trace_item = eval(repr(tree)) 122 | trace.append(trace_item) 123 | else: 124 | trace_item = None 125 | 126 | hl = Variable(torch.FloatTensor([0] * self.num_hidden)) 127 | cl = Variable(torch.FloatTensor([0] * self.num_hidden)) 128 | lchild, _ = self(tree.lchild, (hl,cl), trace=trace_item.lchild if trace else None) 129 | hr = Variable(torch.FloatTensor([0] * self.num_hidden)) 130 | cr = Variable(torch.FloatTensor([0] * self.num_hidden)) 131 | rchild, _ = self(tree.rchild, (hr, cr), trace=trace_item.rchild if trace else None) 132 | 133 | if tree.is_numeric(): 134 | assert (tree.lchild.is_a_floating_point and tree.rchild.function_name == NUMBER_DECODER) \ 135 | or (tree.rchild.is_a_floating_point and 
tree.lchild.function_name == NUMBER_DECODER) 136 | loss = (lchild - rchild) * (lchild - rchild) 137 | correct = math.isclose(lchild.item(), rchild.item(), rel_tol=1e-3) 138 | if trace_item is not None: 139 | trace_item.probability = lchild.item() 140 | else: 141 | out = torch.cat((Variable(torch.FloatTensor([0])), torch.dot(lchild, rchild).unsqueeze(0) + self.bias), dim=0) 142 | loss = - F.log_softmax(out)[round(label.item())] 143 | correct = F.softmax(out)[round(label.item())].item() > 0.5 144 | 145 | if trace_item is not None: 146 | trace_item.probability = F.softmax(out)[1].item() 147 | trace_item.correct = correct 148 | trace_item.bias = self.bias.tolist() 149 | assert isinstance(correct, bool) 150 | record.append({ 151 | "ex": tree, 152 | "label": round(label.item()), 153 | "loss": loss.item(), 154 | "correct": correct, 155 | "depth": depth, 156 | "score": out[1].item() # WARNING: only works for symbolic data 157 | }) 158 | total_loss += loss 159 | return record, total_loss / len(batch) 160 | 161 | 162 | class LSTMnode(torch.nn.Module): 163 | def __init__(self, num_input, num_hidden): 164 | super().__init__() 165 | self.data = nn.Linear(num_input*2, num_hidden, bias=True) 166 | self.forget = nn.Linear(num_input*2, num_hidden, bias=True) 167 | self.output = nn.Linear(num_input*2, num_hidden, bias=True) 168 | self.input = nn.Linear(num_input*2, num_hidden, bias=True) 169 | 170 | def forward(self, inp, hidden, trace=None, dropout=None): 171 | """ 172 | 173 | Args: 174 | inp : (num_hidden,) 175 | hidden: ((num_hidden,), (num_hidden,)) 176 | 177 | Returns: 178 | (num_hidden,), (num_hidden) 179 | """ 180 | h, c = hidden 181 | h = torch.cat((h, inp), dim=0) 182 | i = F.sigmoid(self.data(h)) 183 | f = F.sigmoid(self.forget(h)) 184 | o = F.sigmoid(self.output(h)) 185 | u = F.tanh(self.input(h)) 186 | if dropout is None: 187 | c = i * u + f * c 188 | else: 189 | c = i * F.dropout(u,p=dropout,training=self.training) + f * c 190 | h = o * F.tanh(c) 191 | if trace: 192 | trace.output = h.tolist() 193 | trace.memory = c.tolist() 194 | trace.i = [f.tolist()] 195 | return h, c 196 | 197 | 198 | 199 | 200 | class RNNchain(torch.nn.Module): 201 | def __init__(self, num_hidden, dropout): 202 | super().__init__() 203 | 204 | self.layer = nn.Linear(2*num_hidden, num_hidden) 205 | self.act1 = nn.Sigmoid() 206 | self.layer2 = nn.Linear(num_hidden,num_hidden) 207 | self.act2 = nn.Sigmoid() 208 | 209 | setattr(self, NUMBER_ENCODER, nn.Sequential(OrderedDict([ 210 | ('linear1', nn.Linear(1, num_hidden)), 211 | ('sgmd1', nn.Sigmoid()), 212 | ('linear2', nn.Linear(num_hidden, num_hidden)), 213 | ('sgmd2', nn.Sigmoid()) 214 | ]))) 215 | setattr(self, NUMBER_DECODER, nn.Sequential(OrderedDict([ 216 | ('linear1', nn.Linear(num_hidden, num_hidden)), 217 | ('sgmd1', nn.Sigmoid()), 218 | ('linear2', nn.Linear(num_hidden, 1)), 219 | ]))) 220 | setattr(self, SYMBOL_ENCODER, nn.Embedding(num_embeddings=len(VOCAB), 221 | embedding_dim=num_hidden)) 222 | self.bias = nn.Parameter(torch.FloatTensor([0])) 223 | self.num_hidden = num_hidden 224 | self.dropout = dropout 225 | 226 | def forward(self, tree, hidden, trace=None): 227 | """ 228 | hidden is the hidden state 229 | hidden = h 230 | """ 231 | # print('modules:', self._modules) 232 | # if not tree.is_leaf and tree.function_name not in self._modules: 233 | # raise AssertionError("Unknown functional node: %s" % tree.function_name) 234 | 235 | nn_block = getattr(self, SYMBOL_ENCODER) 236 | if tree.is_binary: 237 | root_value = VOCAB['Binary_'+tree.function_name] 238 | 
root_encoded_value = Variable(torch.LongTensor([root_value])[0]) 239 | inp = nn_block(root_encoded_value) 240 | inp = torch.cat((inp,hidden),dim=0) 241 | h = self.act1(F.dropout(self.layer(inp),p=self.dropout,training=self.training)) 242 | 243 | parentheses_value_l = VOCAB['parentheses_('] 244 | left_parentheses = Variable(torch.LongTensor([parentheses_value_l])[0]) 245 | inp = nn_block(left_parentheses) 246 | inp = torch.cat((inp,h),dim=0) 247 | h = self.act1(F.dropout(self.layer(inp),p=self.dropout,training=self.training)) 248 | 249 | hl = self(tree.lchild, h, trace=trace.lchild if trace else None) 250 | 251 | comma_value = VOCAB['parentheses_,'] 252 | comma = Variable(torch.LongTensor([comma_value])[0]) 253 | inp = nn_block(comma) 254 | inp = torch.cat((inp,hl),dim=0) 255 | h = self.act1(F.dropout(self.layer(inp),p=self.dropout,training=self.training)) 256 | 257 | hr = self(tree.rchild, h, trace=trace.rchild if trace else None) 258 | 259 | parentheses_value_r = VOCAB['parentheses_)'] 260 | right_parentheses = Variable(torch.LongTensor([parentheses_value_r])[0]) 261 | inp = nn_block(right_parentheses) 262 | inp = torch.cat((inp,hr),dim=0) 263 | h = self.act1(F.dropout(self.layer(inp),p=self.dropout,training=self.training)) 264 | 265 | return h 266 | 267 | elif tree.is_unary: 268 | 269 | if tree.function_name in {NUMBER_DECODER, NUMBER_ENCODER, 270 | SYMBOL_ENCODER}: 271 | 272 | hl = self(tree.lchild, hidden, trace=trace.lchild if trace else None) 273 | nn_block = getattr(self, tree.function_name) 274 | inp = nn_block(hl) 275 | inp = torch.cat((inp,hidden),dim=0) 276 | h = self.act1(F.dropout(self.layer(inp),p=self.dropout,training=self.training)) 277 | 278 | else: 279 | root_value = VOCAB['Unary_'+tree.function_name] 280 | root_encoded_value = Variable(torch.LongTensor([root_value])[0]) 281 | inp = nn_block(root_encoded_value) 282 | inp = torch.cat((inp,hidden),dim=0) 283 | h = self.act1(F.dropout(self.layer(inp),p=self.dropout,training=self.training)) 284 | 285 | parentheses_value_l = VOCAB['parentheses_('] 286 | left_parentheses = Variable(torch.LongTensor([parentheses_value_l])[0]) 287 | inp = nn_block(left_parentheses) 288 | inp = torch.cat((inp,h),dim=0) 289 | h = self.act1(F.dropout(self.layer(inp),p=self.dropout,training=self.training)) 290 | 291 | hl = self(tree.lchild, h, trace=trace.lchild if trace else None) 292 | 293 | parentheses_value_r = VOCAB['parentheses_)'] 294 | right_parentheses = Variable(torch.LongTensor([parentheses_value_r])[0]) 295 | inp = nn_block(right_parentheses) 296 | inp = torch.cat((inp,hl),dim=0) 297 | h = self.act1(F.dropout(self.layer(inp),p=self.dropout,training=self.training)) 298 | 299 | return h 300 | 301 | elif tree.is_leaf: 302 | if trace: 303 | trace.output = tree.encoded_value.tolist() 304 | #trace.memory = c.tolist() 305 | return tree.encoded_value 306 | else: 307 | raise RuntimeError("Invalid tree:\n%s" % repr(self)) 308 | 309 | def compute_batch(self, batch, trace=None): 310 | record = [] 311 | total_loss = 0 312 | for tree, label, depth in batch: 313 | if trace is not None: 314 | trace_item = eval(repr(tree)) 315 | trace.append(trace_item) 316 | else: 317 | trace_item = None 318 | 319 | hl = Variable(torch.FloatTensor([0] * self.num_hidden)) 320 | lchild = self(tree.lchild, hl, trace=trace_item.lchild if trace else None) 321 | lchild = self.act2(F.dropout(self.layer2(lchild),p=self.dropout,training=self.training)) 322 | hr = Variable(torch.FloatTensor([0] * self.num_hidden)) 323 | rchild = self(tree.rchild, hr, trace=trace_item.rchild if 
trace else None) 324 | rchild = self.act2(F.dropout(self.layer2(rchild),p=self.dropout,training=self.training)) 325 | 326 | if tree.is_numeric(): 327 | assert (tree.lchild.is_a_floating_point and tree.rchild.function_name == NUMBER_DECODER) \ 328 | or (tree.rchild.is_a_floating_point and tree.lchild.function_name == NUMBER_DECODER) 329 | loss = (lchild - rchild) * (lchild - rchild) 330 | correct = math.isclose(lchild.item(), rchild.item(), rel_tol=1e-3) 331 | if trace_item is not None: 332 | trace_item.probability = lchild.item() 333 | else: 334 | out = torch.cat((Variable(torch.FloatTensor([0])), torch.dot(lchild, rchild).unsqueeze(0) + self.bias), dim=0) 335 | loss = - F.log_softmax(out)[round(label.item())] 336 | correct = F.softmax(out)[round(label.item())].item() > 0.5 337 | 338 | if trace_item is not None: 339 | trace_item.probability = F.softmax(out)[1].item() 340 | trace_item.correct = correct 341 | trace_item.bias = self.bias.tolist() 342 | assert isinstance(correct, bool) 343 | record.append({ 344 | "ex": tree, 345 | "label": round(label.item()), 346 | "loss": loss.item(), 347 | "correct": correct, 348 | "depth": depth, 349 | "score": out[1].item() # WARNING: only works for symbolic data 350 | }) 351 | total_loss += loss 352 | return record, total_loss / len(batch) 353 | -------------------------------------------------------------------------------- /equation_verification/equation_completion_experiment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import pickle 4 | import traceback 5 | from collections import defaultdict 6 | 7 | import torch 8 | import random 9 | import json 10 | 11 | from equation_verification.dataset_loading import load_equation_tree_examples, \ 12 | load_single_equation_tree_example, sequential_sampler, \ 13 | build_equation_tree_examples_list, load_equation_completion_batch, \ 14 | extract_candidates, load_equation_completion_blank_example 15 | from equation_verification.nn_tree_model import build_nnTree 16 | from optimizers import build_optimizer 17 | from parse import parse_equation 18 | from equation_verification.constants import UNARY_FNS 19 | 20 | NCHOICES = 0 21 | CUT = 100000 22 | 23 | class EquationCompletionExperiment: 24 | 25 | def __init__(self): 26 | """ 27 | Hyperparameters defined here should match nn_tree_experiment.py exactly, 28 | reasoning being that this experiment relies on a model trained using 29 | nn_tree_experiment.py, so all the configurations should be the same. 
30 | """ 31 | parser = argparse.ArgumentParser( 32 | description="Train tree-LSTM on generated equalities", 33 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 34 | 35 | # define and parse hyper-parameters of model/training from commandline 36 | parser.add_argument('--num-hidden', type=int, default=50, 37 | help='hidden layer size') 38 | parser.add_argument('--num-embed', type=int, default=50, 39 | help='embedding layer size') 40 | parser.add_argument('--memory-size', type=int, default=5, 41 | help='max size of the stack/queue') 42 | parser.add_argument('--lr', type=float, default=0.001, 43 | help='initial learning rate') 44 | parser.add_argument('--optimizer', type=str, default='adam', 45 | help='the optimizer type') 46 | parser.add_argument('--mom', type=float, default=0.2, 47 | help='momentum for sgd') 48 | parser.add_argument('--wd', type=float, default=0.00001, 49 | help='weight decay for sgd') 50 | parser.add_argument('--beta1', type=float, default=0.9, 51 | help='beta 1 for optimizer') 52 | parser.add_argument('--beta2', type=float, default=0.999, 53 | help='beta 2 for optimizer') 54 | parser.add_argument('--dropout', type=float, default=None, 55 | help='dropout probability (1.0 - keep probability)') 56 | 57 | # define and parse conditions of the experiment from commandline 58 | parser.add_argument('--model-class', type=str, default="NNTrees", 59 | help='the classname of the model to run') 60 | parser.add_argument('--train-path', type=str, default=None, 61 | help='path to training examples') 62 | parser.add_argument('--candidate-path', type=str, default=None, 63 | help='path to equation completion candidate answers') 64 | parser.add_argument('--unify-one-zero', type=eval, default=True, 65 | help='whether to unify ones and zeros to integer') 66 | parser.add_argument('--validation-path', type=str, default=None, nargs="+", 67 | help='path(s) to validation examples, if multiple files supplied, will evaluate each individually') 68 | parser.add_argument('--test-path', type=str, default=None, 69 | help='path to test examples') 70 | parser.add_argument('--num-epochs', type=int, default=100, 71 | help='max num of epochs') 72 | parser.add_argument('--seed', type=int, default=12093, 73 | help='max num of epochs') 74 | parser.add_argument('--batch-size', type=int, default=1, 75 | help='the batch size') 76 | parser.add_argument('--share-memory-params', default=False, 77 | action='store_true', 78 | help='whether to allow weight sharing for memory ' 79 | 'operations') 80 | parser.add_argument('--disable-sharing', default=False, 81 | action='store_true', 82 | help='whether to not allow weight sharing for memory ' 83 | 'operations') 84 | parser.add_argument('--no-op', default=False, 85 | action='store_true', 86 | help='whether to add no operation to stack push ' 87 | 'and pop') 88 | parser.add_argument('--no-pop', default=False, 89 | action='store_true', 90 | help='whether to add just push and no-op') 91 | parser.add_argument('--stack-type', type=str, default='simple', 92 | help='choose the stack type. 
options are: simple, nn_stack, full_stack, add_stack, simple_gated, full_stack_gated') 93 | parser.add_argument('--likeLSTM', default=False, 94 | action='store_true', 95 | help='whether to make the mem2out stack tree completely like an LSTM+stack') 96 | parser.add_argument('--gate-push-pop', default=False, 97 | action='store_true', 98 | help='whether to make the push pop action a gate rather than a number') 99 | parser.add_argument('--normalize-action', default=False, 100 | action='store_true', 101 | help='whether to normalize push pop weight to 1 before pushing and popping') 102 | parser.add_argument('--gate-top-k', default=False, 103 | action='store_true', 104 | help='whether to gate the top-k instead of weighted average') 105 | parser.add_argument('--top-k', type=int, default=5, 106 | help='the top-k stack elements will be used for computing the output; ' 107 | 'select k. NOTE: k is in range(0,memory_size)') 108 | parser.add_argument('--numeric', default=False, 109 | action='store_true', 110 | help='whether to train on numeric equations') 111 | parser.add_argument('--fast', default=False, 112 | action='store_true', 113 | help='whether to evaluate only on 20%% of the training set') #TODO: This is not used by any code right now 114 | parser.add_argument('--verbose', default=False, 115 | action='store_true', 116 | help='whether to print execution trace outputs') 117 | parser.add_argument('--curriculum', default=None, 118 | help='what type of curriculum (depth/func)') 119 | parser.add_argument('--switch-epoch', default=None, type=int, 120 | help='epoch to switch over (2stage curriculum)') 121 | parser.add_argument('--curriculum-depth', default=None, type=int, 122 | help='max depth in curriculum') 123 | parser.add_argument('--eval-depth', nargs="+", type=int, default=None, 124 | help='list of depths to evaluate on') 125 | parser.add_argument('--tree-node-activation', type=str, default='sigmoid', 126 | help='choose the activation for tree-node') 127 | parser.add_argument('--stack-node-activation', type=str, 128 | default='sigmoid', 129 | help='choose the activation for stack-node') 130 | 131 | 132 | # define and parse logging and save/loading options of the experiment 133 | parser.add_argument('--model-prefix', type=str, default=None, 134 | help='path to save/load model') 135 | parser.add_argument('--result-path', type=str, default=None, 136 | help='path to save results') 137 | parser.add_argument('--load-epoch', type=int, default=None, 138 | help='load from epoch') 139 | parser.add_argument('--evaluate-only', action="store_true", default=False, 140 | help='evaluate only') 141 | parser.add_argument('--trace-path', type=str, default=None, 142 | help='path to save traces') 143 | parser.add_argument('--disp-epoch', type=int, default=1, 144 | help='show progress for every n epochs') 145 | parser.add_argument('--checkpoint-every-n-epochs', type=int, default=5, 146 | help='save model for every n epochs') 147 | parser.add_argument('--interactive', action='store_true', default=False, 148 | help='interactive evaluation') 149 | parser.add_argument('--cut', type=int, default=1000_000, 150 | help='cut after this many examples') 151 | 152 | self.args = parser.parse_args() 153 | with open(self.args.candidate_path) as cddf: 154 | cddjson = json.loads(cddf.read()) 155 | self.candidate_trees = [tup[0] for tup in build_equation_tree_examples_list(cddjson, unify_one_zero=self.args.unify_one_zero)] 156 | self.candidate_classes = defaultdict(list) 157 | for i,tree in enumerate(self.candidate_trees): 158 |
self.candidate_classes[tree.cls].append(i) 159 | 160 | def load_model_and_optim(self, epoch, model, optimizer): 161 | model_path = "%s-model-%d.checkpoint" % (self.args.model_prefix, epoch) 162 | optim_path = "%s-optim-%d.checkpoint"% (self.args.model_prefix, epoch) 163 | model.load_state_dict(torch.load(model_path)) 164 | optimizer.load_state_dict(torch.load(optim_path)) 165 | 166 | def eval(self, model, loader, trace=None): 167 | model.eval() 168 | aggregate_record = [] 169 | aggregate_loss = 0 170 | for batch in loader: 171 | record, loss = model.compute_batch(batch, trace=trace) 172 | aggregate_record.extend(record) 173 | aggregate_loss += loss.item() 174 | return aggregate_record, (aggregate_loss / len(aggregate_record)) if len(aggregate_record) != 0 else 0 175 | 176 | def aggregate_statistics(self, record, depth=None): 177 | assert len(record) % NCHOICES == 0 178 | if depth is not None: 179 | record = [item for item in record if item["depth"]==depth] 180 | if len(record) == 0: 181 | raise ValueError("empty record at depth %d" % depth) 182 | result = {} 183 | ranks = [] 184 | for i in range(0,len(record), NCHOICES): 185 | scores = [item["score"] for item in record[i:i+NCHOICES]] 186 | labels = [item["label"] for item in record[i:i+NCHOICES]] 187 | true_ex = record[i+labels.index(1)]["ex"] 188 | indices = list(range(NCHOICES)) 189 | assert sum(labels) == 1 # should only be one correct 190 | sorted_scores = sorted(list(zip(scores,labels,indices)), key=lambda x:x[0], reverse=True) 191 | print("-------- -------- -------- -------- -------- --------") 192 | print(true_ex.pretty_str()) 193 | for j,(score, label, ind) in enumerate(sorted_scores): # j avoids shadowing the chunk index i 194 | if label == 1: 195 | ranks.append(j+1) 196 | if j < 20: 197 | print(score, math.exp(score)/(1+math.exp(score)),self.candidate_trees[ind].pretty_str()) 198 | print("rank", ranks[-1]) 199 | for i in range(10): 200 | name = f"top{i+1}_acc" 201 | val = len([rank for rank in ranks if rank <= i+1]) / len(ranks) 202 | result[name] = val 203 | result["mrr"] = sum([1/rank for rank in ranks]) / len(ranks) 204 | return result 205 | 206 | def log(self, record, path, epoch, name): 207 | depths = sorted(list(set(item["depth"] for item in record))) 208 | stats = self.aggregate_statistics(record) 209 | with open(path, "a") as fout: 210 | log_entry = {"type": "eval", "epoch":epoch, "set":name, "stats": {"all": stats}} 211 | for depth in depths: 212 | stats_d = self.aggregate_statistics(record, depth=depth) 213 | log_entry["stats"][depth] = stats_d 214 | fout.write(json.dumps(log_entry) + "\n") 215 | 216 | def aggregate_statistics_rank(self, rank_record, depth=None): 217 | if depth is not None: 218 | rank_record = [item for item in rank_record if item["depth"]==depth] 219 | if len(rank_record) == 0: 220 | raise ValueError("empty rank_record at depth %d" % depth) 221 | result = {} 222 | for ranktype in rank_record[0]["ranks"].keys(): 223 | ranks = [item["ranks"][ranktype] for item in rank_record] 224 | for i in range(10): 225 | name = f"top{i+1}_acc_{ranktype}" 226 | val = len([rank for rank in ranks if rank <= i+1]) / len(ranks) 227 | result[name] = val 228 | result[f"mrr_{ranktype}"] = sum([1/rank for rank in ranks]) / len(ranks) 229 | result["count"] = len(rank_record) 230 | return result 231 | 232 | def log_rank(self, ranks, path, epoch, name): 233 | depths = sorted(list(set(item["depth"] for item in ranks))) 234 | stats = self.aggregate_statistics_rank(ranks) 235 | with open(path, "a") as fout: 236 | log_entry = {"type": "eval", "epoch":epoch, "set":name, "stats": {"all":
stats},"raw":ranks} 237 | for depth in depths: 238 | stats_d = self.aggregate_statistics_rank(ranks, depth=depth) 239 | log_entry["stats"][depth] = stats_d 240 | fout.write(json.dumps(log_entry) + "\n") 241 | 242 | def run(self): 243 | """ 244 | Main routine for running this experiment. 245 | """ 246 | torch.manual_seed(self.args.seed) 247 | random.seed(self.args.seed) 248 | global NCHOICES 249 | with open(self.args.candidate_path) as cddf: 250 | cddjson = json.loads(cddf.read()) 251 | candidates = extract_candidates(cddjson) 252 | NCHOICES = 0 253 | for group in cddjson: 254 | NCHOICES += len(group) 255 | assert NCHOICES == len(candidates) 256 | print(f"{NCHOICES} number of candidates for equation completion") 257 | 258 | model = build_nnTree(self.args.model_class, 259 | self.args.num_hidden, 260 | self.args.num_embed, 261 | self.args.memory_size, 262 | self.args.share_memory_params, 263 | self.args.dropout, 264 | self.args.no_op, 265 | self.args.stack_type, 266 | self.args.top_k, 267 | self.args.verbose, 268 | self.args.tree_node_activation, 269 | self.args.stack_node_activation, 270 | self.args.no_pop, 271 | self.args.disable_sharing, 272 | self.args.likeLSTM, 273 | self.args.gate_push_pop, 274 | self.args.gate_top_k, 275 | self.args.normalize_action) 276 | 277 | totParams = 0 278 | for p in model.parameters(): 279 | totParams += p.numel() 280 | print('total num params:', totParams) 281 | 282 | optimizer = build_optimizer((param for param in model.parameters() if 283 | param.requires_grad), 284 | self.args.optimizer, 285 | self.args.lr, 286 | self.args.mom, 287 | self.args.wd, 288 | self.args.beta1, 289 | self.args.beta2) 290 | 291 | if self.args.load_epoch is None or self.args.evaluate_only is False: 292 | with open(self.args.result_path, "wt") as _: # TODO: if fine-tuning option is added this should be fixed 293 | pass 294 | 295 | 296 | if self.args.evaluate_only: 297 | self.load_model_and_optim(self.args.load_epoch, model, optimizer) 298 | with open(self.args.test_path,"rt") as f: 299 | groups = json.loads(f.read()) 300 | # record_agg = [] 301 | rank_agg = [] 302 | done_counter = 0 303 | for group in groups: 304 | for example in group: 305 | if done_counter >= self.args.cut: 306 | break 307 | if example["blankNodeNum"] == "0": 308 | continue 309 | if "Number" in example["equation"]["func"]: 310 | continue 311 | done_counter += 1 312 | print(done_counter) 313 | blank, lbbb, ddd = load_equation_completion_blank_example(example) 314 | test_loader = load_equation_completion_batch( 315 | [[example]], 316 | batch_size=self.args.batch_size, 317 | numeric=self.args.numeric, 318 | eval_depth=self.args.eval_depth, 319 | unify_one_zero=self.args.unify_one_zero, 320 | equation_completion=True, 321 | candidates=candidates 322 | ) 323 | record, _ = self.eval(model, test_loader, trace=None) 324 | assert len(record) == NCHOICES 325 | # record_agg.extend([{"score":item["score"], 326 | # "label":item["label"], 327 | # "depth":item["depth"], 328 | # "ex":item["ex"] if item["label"] else None} for item in record]) 329 | # computing statistics 330 | ranks = self.compute_statistic_single(blank, record) 331 | rank_agg.append(ranks) 332 | if done_counter >= self.args.cut: 333 | break 334 | # self.log(record_agg, self.args.result_path, self.args.load_epoch, "test") 335 | self.log_rank(rank_agg, self.args.result_path, self.args.load_epoch, "test") 336 | else: 337 | print("doing nothing, please set flag --evaluate-only") 338 | 339 | def compute_statistic_single(self, blank, record): 340 | scores = 
[item["score"] for item in record] 341 | labels = [item["label"] for item in record] 342 | indices = list(range(NCHOICES)) 343 | classes = [tree.cls for tree in self.candidate_trees] 344 | assert sum(labels) == 1 # should only be one correct 345 | sorted_scores = sorted(list(zip(scores,labels,indices,classes)), key=lambda x:x[0], reverse=True) 346 | print("-------- -------- -------- -------- -------- --------") 347 | print(blank.pretty_str()) 348 | raw_rank = None 349 | true_class = None 350 | for i,(score, label, ind, cls) in enumerate(sorted_scores): 351 | if label == 1: 352 | raw_rank = i+1 353 | true_class = cls 354 | print(score, math.exp(score)/(1+math.exp(score)), cls, f"[{self.candidate_trees[ind].pretty_str()}]") 355 | if i < 20: 356 | print(score, math.exp(score)/(1+math.exp(score)), cls,self.candidate_trees[ind].pretty_str()) 357 | print("raw_rank", raw_rank) 358 | class_rank = None 359 | for i,(score, label, ind, cls) in enumerate(sorted_scores): 360 | if class_rank is None and cls == true_class: 361 | class_rank = i+1 362 | print(score, math.exp(score)/(1+math.exp(score)), cls, f"[{self.candidate_trees[ind].pretty_str()}]") 363 | if i < 20: 364 | print(score, math.exp(score)/(1+math.exp(score)), cls, self.candidate_trees[ind].pretty_str()) 365 | print("class_rank", class_rank) 366 | collapse_scores = [] 367 | seen = set() 368 | for i,(score, label, ind, cls) in enumerate(sorted_scores): 369 | if cls not in seen: 370 | collapse_scores.append((score, label, ind, cls)) 371 | seen.add(cls) 372 | else: 373 | continue 374 | collapse_rank = None 375 | for i,(score, label, ind, cls) in enumerate(collapse_scores): 376 | if cls == true_class: 377 | collapse_rank = i+1 378 | print(score, math.exp(score)/(1+math.exp(score)), cls, f"[{self.candidate_trees[ind].pretty_str()}]") 379 | if i < 20: 380 | print(score, math.exp(score)/(1+math.exp(score)), cls,self.candidate_trees[ind].pretty_str()) 381 | print("collapse_rank", collapse_rank) 382 | random_collapse_ranks = [] 383 | for sample_s in range(10): 384 | random_collapse_inds = set(random.choice(cands) for cands in self.candidate_classes.values()) 385 | random_collapse_scores = [(s,l,ind,c) for (s,l,ind,c) in sorted_scores if ind in random_collapse_inds] 386 | random_collapse_rank = None 387 | for i,(score, label, ind, cls) in enumerate(random_collapse_scores): 388 | if cls == true_class: 389 | random_collapse_rank = i+1 390 | print(score, math.exp(score)/(1+math.exp(score)), cls, f"[{self.candidate_trees[ind].pretty_str()}]") 391 | if i < 20: 392 | print(score, math.exp(score)/(1+math.exp(score)), cls,self.candidate_trees[ind].pretty_str()) 393 | random_collapse_ranks.append(random_collapse_rank) 394 | random_collapse_rank = sum(random_collapse_ranks) / len(random_collapse_ranks) 395 | random_collapse_rank_std = (sum([(r-random_collapse_rank) ** 2 for r in random_collapse_ranks]))**0.5 396 | print("random_collapse_rank", random_collapse_rank) 397 | print() 398 | print("raw_rank", raw_rank) 399 | print("class_rank", class_rank) 400 | print("collapse_rank", collapse_rank) 401 | print("random_collapse_rank", random_collapse_rank) 402 | bins = [0] * 20 403 | unit = 1/len(bins) 404 | for (score, label, ind, cls) in sorted_scores: 405 | bin_idx = math.floor((math.exp(score)/(1+math.exp(score)) / unit)) 406 | if bin_idx >= len(bins): 407 | assert bin_idx == len(bins) 408 | bin_idx = bin_idx - 1 409 | bins[bin_idx] += 1 410 | maxcnt = max(bins) 411 | normalized_bins = [cnt/maxcnt for cnt in bins] 412 | print("1.0") 413 | for ncnt in 
reversed(normalized_bins): 414 | print(max(math.ceil(ncnt), math.ceil(ncnt * 60) )* "+") 415 | normalized_bins = [cnt/NCHOICES for cnt in bins] 416 | print("0.0") 417 | print("1.0") 418 | for ncnt in reversed(normalized_bins): 419 | print(max(math.ceil(ncnt), math.ceil(ncnt * 60) )* "+") 420 | print("0.0") 421 | return {"ranks":{"raw":raw_rank, "class":class_rank, "collapse":collapse_rank, "random_collapse":random_collapse_rank},"depth":blank.depth,"scores":sorted_scores,"random_collapse_std":random_collapse_rank_std} 422 | 423 | 424 | if __name__ == "__main__": 425 | # train-path as input 426 | EquationCompletionExperiment().run() -------------------------------------------------------------------------------- /equation_verification/dataset_loading.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | from torch.utils.data import Dataset, RandomSampler, SequentialSampler, DataLoader 5 | from equation_verification.constants import VOCAB, SYMBOL_CLASSES, CONSTANTS, BINARY_FNS, UNARY_FNS, \ 6 | NUMBER_ENCODER, SYMBOL_ENCODER, NUMBER_DECODER, PRETTYNAMES 7 | from torch.autograd import Variable 8 | 9 | 10 | def encoded_batch(examples): 11 | # encode the values in-place 12 | def encode_value_at_node(node): 13 | if not node.is_leaf: 14 | node.value = None 15 | node.encoded_value = None 16 | elif node.is_a_floating_point: 17 | node.value = float(node.function_name) 18 | node.encoded_value = Variable(torch.FloatTensor([node.value])) 19 | else: 20 | node.value = VOCAB[node.function_name] 21 | node.encoded_value = Variable(torch.LongTensor([node.value])[0]) 22 | encoded_examples = [] 23 | for tree, label, depth in examples: 24 | tree_copy = eval(repr(tree)) 25 | tree_copy.apply(encode_value_at_node) 26 | encoded_examples.append((tree_copy, Variable(torch.FloatTensor([float(label)])), depth)) 27 | return encoded_examples 28 | 29 | def sequential_sampler(trios, batch_size): 30 | trios = encoded_batch(trios) 31 | # packaging the data using PyTorch compatible structures 32 | dataset = ExampleDataset(trios) 33 | sampler = SequentialSampler(dataset) 34 | loader = DataLoader( 35 | dataset, 36 | batch_size=batch_size, 37 | sampler=sampler, 38 | num_workers=0, 39 | collate_fn=lambda x: x 40 | # TODO: learn whether it helps to encode here 41 | ) 42 | return loader 43 | 44 | def load_equation_completion_blank_example(example_json): 45 | example = generate_blank_example(example_json) 46 | return load_single_equation_tree_example(example) 47 | 48 | def load_equation_completion_batch(example_json, batch_size=1, numeric=False, eval_depth=None, unify_one_zero=True, filter=None, equation_completion=False, candidates=None): 49 | test_trio = build_equation_tree_examples_list(example_json, numeric=numeric, depth=eval_depth, unify_one_zero=unify_one_zero, filter=filter, equation_completion=equation_completion, candidates=candidates) 50 | test_trio = encoded_batch(test_trio) 51 | print("test size: %d" % len(test_trio)) 52 | # packaging the data using PyTorch compatible structures 53 | test_dataset = ExampleDataset(test_trio) 54 | test_sampler = SequentialSampler(test_dataset) 55 | test_loader = DataLoader( 56 | test_dataset, 57 | batch_size=batch_size, 58 | sampler=test_sampler, 59 | num_workers=0, 60 | collate_fn=lambda x: x 61 | # TODO: learn whether it helps to encode here 62 | ) 63 | 64 | return test_loader 65 | 66 | def load_equation_tree_examples(train_path, validation_path, test_path, batch_size=1, numeric=False, eval_depth=None, 
unify_one_zero=True, filter=None): 67 | with open(train_path, "rt") as fin: 68 | train_json = json.loads(fin.read()) 69 | train_trio = build_equation_tree_examples_list(train_json, numeric=numeric, unify_one_zero=unify_one_zero, filter=filter) 70 | train_trio = encoded_batch(train_trio) 71 | print("Train size: %d" % len(train_trio)) 72 | # packaging the data using PyTorch compatible structures 73 | train_dataset = ExampleDataset(train_trio) 74 | train_sampler = RandomSampler(train_dataset) 75 | train_loader = DataLoader( 76 | train_dataset, 77 | batch_size=batch_size, 78 | sampler=train_sampler, 79 | num_workers=0, 80 | collate_fn=lambda x:x #TODO: learn whether it helps to encode here 81 | ) 82 | train_eval_sampler = SequentialSampler(train_dataset) 83 | train_eval_loader = DataLoader( 84 | train_dataset, 85 | batch_size=batch_size, 86 | sampler=train_eval_sampler, 87 | num_workers=0, 88 | collate_fn=lambda x: x # TODO: learn whether it helps to encode here 89 | ) 90 | validation_loaders = [] 91 | for v_path in validation_path: 92 | with open(v_path, "rt") as fin: 93 | validation_json = json.loads(fin.read()) 94 | validation_trio = build_equation_tree_examples_list(validation_json, depth=eval_depth, unify_one_zero=unify_one_zero, filter=filter) 95 | validation_trio = encoded_batch(validation_trio) 96 | print("Validation size: %d" % len(validation_trio)) 97 | # packaging the data using PyTorch compatible structures 98 | validation_dataset = ExampleDataset(validation_trio) 99 | validation_sampler = SequentialSampler(validation_dataset) 100 | validation_loader = DataLoader( 101 | validation_dataset, 102 | batch_size=batch_size, 103 | sampler=validation_sampler, 104 | num_workers=0, 105 | collate_fn=lambda x: x 106 | # TODO: learn whether it helps to encode here 107 | ) 108 | validation_loaders.append(validation_loader) 109 | with open(test_path, "rt") as fin: 110 | test_json = json.loads(fin.read()) 111 | test_trio = build_equation_tree_examples_list(test_json, depth=eval_depth, unify_one_zero=unify_one_zero, filter=filter) 112 | test_trio = encoded_batch(test_trio) 113 | print("test size: %d" % len(test_trio)) 114 | # packaging the data using PyTorch compatible structures 115 | test_dataset = ExampleDataset(test_trio) 116 | test_sampler = SequentialSampler(test_dataset) 117 | test_loader = DataLoader( 118 | test_dataset, 119 | batch_size=batch_size, 120 | sampler=test_sampler, 121 | num_workers=0, 122 | collate_fn=lambda x: x 123 | # TODO: learn whether it helps to encode here 124 | ) 125 | return train_loader, train_eval_loader, validation_loaders, test_loader 126 | 127 | 128 | def build_equation_tree_examples_list(json_dataset, numeric=False, depth=None, unify_one_zero=True, filter=None, equation_completion=False, candidates=None): 129 | """ 130 | 131 | Args: 132 | json_dataset: a json structure with the following schema 133 | [ 134 | [ 135 | {} --> an example at depth x 136 | ... 137 | ] --> a list containing examples at depth x 138 | ... 
139 | ] --> contains lists of examples at various depths 140 | Returns: 141 | a list of trio of BinaryEqnTree and its label and depth 142 | """ 143 | result = [] 144 | for i, group in enumerate(json_dataset): 145 | for base_example in group: 146 | if equation_completion: 147 | l = generate_complete_examples_from_example_with_hole_and_candidates(base_example, candidates) 148 | else: 149 | l = [base_example] 150 | # if we do equation completion we get an expanded list otherwise we get a singleton list 151 | for example in l: 152 | eqn_tree, label, d = load_single_equation_tree_example(example, unify_one_zero=unify_one_zero) 153 | assert label in {0, 1, -1} 154 | if depth is not None and d not in depth: 155 | continue 156 | if eqn_tree.is_numeric() and numeric == False: 157 | continue # do not add numeric examples unless `numeric` flag set 158 | if eqn_tree.is_numeric() and label == 0: 159 | continue 160 | if filter is not None and not filter(eqn_tree): 161 | continue # filter is a criteria for choosing a tree (tree->bool) 162 | result.append((eqn_tree, label, d)) 163 | return result 164 | 165 | 166 | def load_single_equation_tree_example(example, unify_one_zero=True): 167 | """ 168 | 169 | Args: 170 | example: a dictionary of schema 171 | { 172 | "equation": { 173 | "vars": value for each constant node 'NegativeOne', 'Pi', 'One', 174 | 'Half', 'Integer', 'Rational', 'Float' 175 | "numNodes": number of nodes in this tree, discounting # 176 | "variables": dictionary of ?, 177 | "depth": depth of each node in this tree 178 | "nodeNum": unique ids of each node 179 | "func": the actual list of nodes in this (binary) equation tree, 180 | unary functions are still encoded as having two children, 181 | the right one being NULL (#) 182 | }, 183 | "label": "1" if the lhs of the equation equals rhs else "0" 184 | }, 185 | 186 | Returns: 187 | a BinaryEqnTree corresponding to `example`, paired with its label 188 | """ 189 | functions = example['equation']['func'].split(",") 190 | values = example['equation']['vars'].split(",") 191 | if unify_one_zero: 192 | functions = ["Integer" if function == "One" else function 193 | for function in functions] # replace One with Integer 194 | functions = ["Integer" if function == "Rational" and value == "0" else function 195 | for function, value in zip(functions, values)] # replace Rational_0 with Integer_0 196 | eqn_tree = BinaryEqnTree.build_from_preorder(functions, values) 197 | label = int(example['label']) 198 | cls = None if "class" not in example else example["class"] # this is only for candidates used in equation completion 199 | depth = max(int(d) for d in example['equation']["depth"].split(",") if d != "#") 200 | if eqn_tree.is_numeric(): 201 | if eqn_tree.lchild.maybe_extract_number_constant_node() and eqn_tree.rchild.maybe_extract_number_constant_node(): 202 | # raise AssertionError("both sides of the eqn are number constants: %s" % eqn_tree) 203 | # data for training NUMBER_ENCODER and NUMBER_DECODER, label is always the right child 204 | assert eqn_tree.rchild.function_name == NUMBER_ENCODER, 'right child should always be a number constant or label' 205 | eqn_tree.rchild = eqn_tree.rchild.maybe_extract_number_constant_node() 206 | eqn_tree.lchild = BinaryEqnTree(NUMBER_DECODER, eqn_tree.lchild, None) 207 | elif eqn_tree.lchild.maybe_extract_number_constant_node() is not None: 208 | eqn_tree.lchild = eqn_tree.lchild.maybe_extract_number_constant_node() 209 | eqn_tree.rchild = BinaryEqnTree(NUMBER_DECODER, eqn_tree.rchild, None) 210 | elif 
eqn_tree.rchild.maybe_extract_number_constant_node() is not None: 211 | eqn_tree.rchild = eqn_tree.rchild.maybe_extract_number_constant_node() 212 | eqn_tree.lchild = BinaryEqnTree(NUMBER_DECODER, eqn_tree.lchild, None) 213 | else: 214 | raise AssertionError("bad number equation: %s" % eqn_tree) 215 | eqn_tree.raw = example 216 | eqn_tree.label = label 217 | eqn_tree.depth = depth 218 | eqn_tree.cls = cls 219 | return eqn_tree, label, depth 220 | 221 | 222 | class BinaryEqnTree: 223 | 224 | NULL = "#" 225 | 226 | def __init__(self, function_name, lchild, rchild, 227 | is_a_floating_point=False, raw=None, label=None, depth=None): 228 | """ 229 | 230 | Args: 231 | function_name: the name of the node 232 | lchild: the left child (a BinaryEqnTree or None) 233 | rchild: the right child (a BinaryEqnTree or None) 234 | """ 235 | #TODO: make value a more general construct, i.e. a dictionary, or an object so that more than one value can be stored at a node 236 | if lchild is None and rchild is not None: 237 | raise ValueError("A tree can have the following children:" + "\n" 238 | " lchild=None, rchild=None or" + "\n" 239 | " lchild!=None, rchild=None or" + "\n" 240 | " lchild!=None, rchild!=None or" + "\n" 241 | "Got the following instead:" + "\n" 242 | " lchild=%s, rchild=%s" % (repr(lchild), repr(rchild))) 243 | self.function_name = function_name 244 | self.lchild = lchild 245 | self.rchild = rchild 246 | self.is_a_floating_point = is_a_floating_point 247 | self.value = None 248 | self.encoded_value = None 249 | self.is_binary = lchild is not None and rchild is not None 250 | self.is_unary = lchild is not None and rchild is None 251 | self.is_leaf = lchild is None and rchild is None 252 | self.raw = raw 253 | self.label = label 254 | self.depth = depth 255 | self.cls = None 256 | 257 | def apply(self, fn): 258 | if self.lchild is not None: 259 | self.lchild.apply(fn) 260 | if self.rchild is not None: 261 | self.rchild.apply(fn) 262 | fn(self) 263 | 264 | def all(self, pred): 265 | result = pred(self) 266 | if self.lchild is not None: 267 | result = result and self.lchild.all(pred) 268 | if self.rchild is not None: 269 | result = result and self.rchild.all(pred) 270 | return result 271 | 272 | def maybe_extract_number_constant_node(self): 273 | if self.function_name == NUMBER_ENCODER: 274 | return BinaryEqnTree(self.lchild.function_name, None, None, is_a_floating_point=True) 275 | if self.function_name == "Mul": 276 | if self.lchild.function_name == SYMBOL_ENCODER and \ 277 | self.lchild.lchild.function_name == "NegativeOne" and \ 278 | self.rchild.function_name == NUMBER_ENCODER: 279 | return BinaryEqnTree("-"+self.rchild.lchild.function_name, None, None, is_a_floating_point=True) 280 | if self.rchild.function_name == SYMBOL_ENCODER and \ 281 | self.rchild.lchild.function_name == "NegativeOne" and \ 282 | self.lchild.function_name == NUMBER_ENCODER: 283 | return BinaryEqnTree("-"+self.lchild.lchild.function_name, None, None, is_a_floating_point=True) 284 | return None 285 | 286 | def is_numeric(self): 287 | if self.function_name != "Equality": 288 | print("Warning: is_numeric should only be called on the root of an equation tree") 289 | return False #raise ValueError("is_numeric should only be called on the root of an equation tree") 290 | return self._is_numeric() 291 | 292 | def _is_numeric(self): 293 | if self.is_leaf: 294 | return self.is_a_floating_point 295 | if self.is_unary: 296 | return self.lchild._is_numeric() 297 | if self.is_binary: 298 | return self.lchild._is_numeric() or 
self.rchild._is_numeric() 299 | raise AssertionError(str(self)) 300 | 301 | def __str__(self): 302 | if self.is_binary: 303 | return "{}({}, {})".format(self.function_name, 304 | str(self.lchild), 305 | str(self.rchild)) 306 | elif self.is_unary: 307 | return "{}({})".format(self.function_name, 308 | str(self.lchild)) 309 | elif self.is_leaf: 310 | return "{}={}".format(self.function_name, self.value) 311 | else: 312 | raise RuntimeError("Invalid tree:\n%s" % repr(self)) 313 | 314 | def pretty_str(self, prettify=True): 315 | if self.function_name == "Add": 316 | lstr = self.lchild.pretty_str() 317 | rstr = self.rchild.pretty_str() 318 | return "({} + {})".format(lstr, rstr) 319 | elif self.function_name == "Mul": 320 | lstr = self.lchild.pretty_str() 321 | rstr = self.rchild.pretty_str() 322 | if self.lchild.function_name == "Add": 323 | lstr = "(%s)" % lstr 324 | if self.rchild.function_name == "Add": 325 | rstr = "(%s)" % rstr 326 | return "({} * {})".format(lstr, rstr) 327 | elif self.function_name == "Pow": 328 | lstr = self.lchild.pretty_str() 329 | rstr = self.rchild.pretty_str() 330 | return "({}^{})".format(lstr, rstr) 331 | elif self.function_name == "log": 332 | lstr = self.lchild.pretty_str() 333 | rstr = self.rchild.pretty_str() 334 | return "log({}, {})".format( lstr, rstr) 335 | elif self.function_name == "Equality": 336 | lstr = self.lchild.pretty_str() 337 | rstr = self.rchild.pretty_str() 338 | return "{} = {}".format(lstr, rstr) 339 | elif self.function_name == NUMBER_ENCODER or self.function_name == SYMBOL_ENCODER: 340 | return self.lchild.pretty_str() 341 | elif self.is_unary: 342 | return "({}({}))".format(self.function_name, 343 | self.lchild.pretty_str()) 344 | elif self.is_leaf: 345 | name = self.function_name 346 | if not prettify: 347 | return name 348 | if name in PRETTYNAMES: 349 | name = PRETTYNAMES[name] 350 | elif name.startswith("Integer") or name.startswith("Float") or \ 351 | name.startswith("Symbol") or name.startswith("Rational"): 352 | name = "_".join(name.split("_")[1:]) 353 | return "{}".format(name) 354 | else: 355 | raise RuntimeError("Can't prettify tree :\n%s" % repr(self)) 356 | 357 | 358 | def __repr__(self): 359 | return "BinaryEqnTree({},{},{},{},{},{},{})".format(repr(self.function_name), 360 | repr(self.lchild), 361 | repr(self.rchild), 362 | repr(self.is_a_floating_point), 363 | repr(self.raw), 364 | repr(self.label), 365 | repr(self.depth)) 366 | 367 | @staticmethod 368 | def build_from_preorder(functions, values): 369 | """ 370 | Recovers a BinaryEqnTree from its preorder list. 371 | WARNING: This method relies on modifying `functions` in-place. 372 | 373 | Args: 374 | functions: pre-order traversal of a BinaryEqnTree 375 | values: pre-order traversal of a BinaryEqnTree's values 376 | (leaf nodes and only leaf nodes have values, currently each leaf 377 | node has exactly one value, and that is fed as input to the 378 | function at the leaf, e.g. Leaf node is Symbol and value is "x". 379 | Here Symbol can be effectively understood as an embedding layer, 380 | and "x" is the input to that layer. Of course in actual 381 | execution "x" would be turned into an index into the embedding 382 | matrix first, so the input would be something like 383 | torch.LongTensor([12]).) 
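Example (illustrative, assuming 'Symbol' is listed in SYMBOL_CLASSES): functions=['Equality','Symbol','#','#','Symbol','#','#'] paired with values=['','x','#','#','y','#','#'] reconstructs an Equality node whose two children are SYMBOL_ENCODER nodes over the leaves Symbol_x and Symbol_y, i.e. the equation x = y.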
384 | Returns: 385 | the recovered BinaryEqnTree 386 | """ 387 | function_name = functions.pop(0) 388 | value = values.pop(0) 389 | if function_name == BinaryEqnTree.NULL: 390 | return None 391 | 392 | value = None if value == "" else value 393 | 394 | lchild = BinaryEqnTree.build_from_preorder(functions, values) 395 | rchild = BinaryEqnTree.build_from_preorder(functions, values) 396 | 397 | if function_name == "Number": 398 | leaf = BinaryEqnTree(value, None, None, is_a_floating_point=True) 399 | return BinaryEqnTree(NUMBER_ENCODER, leaf, None) 400 | 401 | elif function_name in SYMBOL_CLASSES: 402 | leaf = BinaryEqnTree("%s_%s" % (function_name, value), None, None) 403 | return BinaryEqnTree(SYMBOL_ENCODER, leaf, None) 404 | 405 | elif function_name in CONSTANTS: 406 | leaf = BinaryEqnTree(function_name, None, None) 407 | return BinaryEqnTree(SYMBOL_ENCODER, leaf, None) 408 | 409 | elif function_name in BINARY_FNS or function_name in UNARY_FNS or \ 410 | function_name == "Equality": 411 | return BinaryEqnTree(function_name, lchild, rchild) 412 | 413 | else: 414 | raise RuntimeError("Uncategorized function name: %s" % function_name) 415 | 416 | 417 | def extract_candidates(candidates_json): 418 | result = [] 419 | for group in candidates_json: 420 | for example in group: 421 | result.append({"functions":example["equation"]["func"].split(","), 422 | "values":example["equation"]["vars"].split(",")}) 423 | return result 424 | 425 | def generate_complete_examples_from_example_with_hole_and_candidates(example,candidates): 426 | goldfunc = example['equation']['func'] 427 | goldvar = example['equation']['vars'] 428 | functions = example['equation']['func'].split(",") 429 | values = example['equation']['vars'].split(",") 430 | nodenum = example['equation']['nodeNum'].split(",") 431 | blanknum = example['blankNodeNum'] 432 | idx = nodenum.index(blanknum) 433 | empty = lambda i: nodenum[i] == "#" # "#" 434 | leaf = lambda i: (not empty(i)) and (empty(i+1) and empty(i+2)) # "15,#,#" 435 | d11 = lambda i: (not empty(i)) and (leaf(i+1)) and (empty(i+4)) # "15,16,#,#,#" 436 | d12 = lambda i: (not empty(i)) and (leaf(i+1)) and (leaf(i+4)) #"15,16,#,#,17,#,#" 437 | assert leaf(idx) or d11(idx) or d12(idx) 438 | #TODO support more deep swaps 439 | if leaf(idx): 440 | start, end = idx, idx + 3 441 | elif d11(idx): 442 | start, end = idx, idx + 5 443 | elif d12(idx): 444 | start, end = idx, idx + 7 445 | else: 446 | assert False 447 | examples = [] 448 | correctcount = 0 449 | for candidate in candidates: 450 | candidate_functions = candidate["functions"] 451 | candidate_values = candidate["values"] 452 | newfunctions = functions[:start] + candidate_functions + functions[end:] 453 | newvalues = values[:start] + candidate_values + values[end:] 454 | newfunc = ",".join(newfunctions) 455 | newvar = ",".join(newvalues) 456 | if newfunc == goldfunc and newvar == goldvar: 457 | correctcount += 1 458 | examples.append({ 459 | "label": "1" if newfunc == goldfunc and newvar == goldvar else "0", 460 | "equation": { 461 | "vars":newvar, 462 | "func":newfunc, 463 | "depth":example["equation"]["depth"], 464 | } 465 | }) 466 | if correctcount != 1: 467 | print(correctcount) 468 | for candidate in candidates: 469 | candidate_functions = candidate["functions"] 470 | candidate_values = candidate["values"] 471 | newfunctions = functions[:start] + candidate_functions + functions[end:] 472 | newvalues = values[:start] + candidate_values + values[end:] 473 | newfunc = ",".join(newfunctions) 474 | newvar = ",".join(newvalues) 475 
| print(candidate) 476 | print(functions[end:]) 477 | print(newfunc) 478 | print(newvar) 479 | print("--------") 480 | print(json.dumps(example,indent=4)) 481 | print(start, end) 482 | assert correctcount == 1 483 | return examples 484 | 485 | def generate_blank_example(example): 486 | goldfunc = example['equation']['func'] 487 | goldvar = example['equation']['vars'] 488 | functions = example['equation']['func'].split(",") 489 | values = example['equation']['vars'].split(",") 490 | nodenum = example['equation']['nodeNum'].split(",") 491 | blanknum = example['blankNodeNum'] 492 | idx = nodenum.index(blanknum) 493 | empty = lambda i: nodenum[i] == "#" # "#" 494 | leaf = lambda i: (not empty(i)) and (empty(i+1) and empty(i+2)) # "15,#,#" 495 | d11 = lambda i: (not empty(i)) and (leaf(i+1)) and (empty(i+4)) # "15,16,#,#,#" 496 | d12 = lambda i: (not empty(i)) and (leaf(i+1)) and (leaf(i+4)) #"15,16,#,#,17,#,#" 497 | assert leaf(idx) or d11(idx) or d12(idx) 498 | #TODO support more deep swaps 499 | if leaf(idx): 500 | start, end = idx, idx + 3 501 | elif d11(idx): 502 | start, end = idx, idx + 5 503 | elif d12(idx): 504 | start, end = idx, idx + 7 505 | else: 506 | assert False 507 | examples = [] 508 | correctcount = 0 509 | for candidate in [{"functions":["NaN","#","#"], "values":["NaN", "#", "#"]}]: 510 | candidate_functions = candidate["functions"] 511 | candidate_values = candidate["values"] 512 | newfunctions = functions[:start] + candidate_functions + functions[end:] 513 | newvalues = values[:start] + candidate_values + values[end:] 514 | newfunc = ",".join(newfunctions) 515 | newvar = ",".join(newvalues) 516 | if newfunc == goldfunc and newvar == goldvar: 517 | correctcount += 1 518 | examples.append({ 519 | "label": "1" if newfunc == goldfunc and newvar == goldvar else "0", 520 | "equation": { 521 | "vars":newvar, 522 | "func":newfunc, 523 | "depth":example["equation"]["depth"], 524 | } 525 | }) 526 | return examples[0] 527 | 528 | class ExampleDataset(Dataset): 529 | """ 530 | Generic Dataset. 531 | """ 532 | 533 | def __init__(self, examples): 534 | self.examples = [ex for ex in examples] 535 | 536 | def __getitem__(self, item): 537 | return self.examples[item] 538 | 539 | def __len__(self): 540 | return len(self.examples) 541 | -------------------------------------------------------------------------------- /equation_verification/nn_tree_experiment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import pickle 4 | import traceback 5 | 6 | import torch 7 | import random 8 | import json 9 | 10 | from equation_verification.dataset_loading import load_equation_tree_examples, \ 11 | load_single_equation_tree_example, sequential_sampler, \ 12 | build_equation_tree_examples_list 13 | from equation_verification.nn_tree_model import build_nnTree 14 | from optimizers import build_optimizer 15 | from parse import parse_equation 16 | from equation_verification.constants import UNARY_FNS 17 | from equation_verification.sequential_model import LSTMchain 18 | 19 | class nnTreeEquationVerificationExperiment: 20 | 21 | def __init__(self): 22 | """ 23 | Defines and parses hyper-parameters, conditions of the experiment, 24 | logging and save/loading options from the commandline. 
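An illustrative training invocation (paths are placeholders, not fixed names from this repository): python -m equation_verification.nn_tree_experiment --train-path <train.json> --validation-path <valid.json> --test-path <test.json> --model-prefix <prefix> --result-path <results.json>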
25 | """ 26 | parser = argparse.ArgumentParser( 27 | description="Train tree-LSTM on generated equalities", 28 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 29 | 30 | # define and parse hyper-parameters of model/training from commandline 31 | parser.add_argument('--num-hidden', type=int, default=50, 32 | help='hidden layer size') 33 | parser.add_argument('--num-embed', type=int, default=50, 34 | help='embedding layer size') 35 | parser.add_argument('--memory-size', type=int, default=5, 36 | help='max size of the stack/queue') 37 | parser.add_argument('--lr', type=float, default=0.001, 38 | help='initial learning rate') 39 | parser.add_argument('--optimizer', type=str, default='adam', 40 | help='the optimizer type') 41 | parser.add_argument('--mom', type=float, default=0.2, 42 | help='momentum for sgd') 43 | parser.add_argument('--wd', type=float, default=0.00001, 44 | help='weight decay for sgd') 45 | parser.add_argument('--beta1', type=float, default=0.9, 46 | help='beta 1 for optimizer') 47 | parser.add_argument('--beta2', type=float, default=0.999, 48 | help='beta 2 for optimizer') 49 | parser.add_argument('--dropout', type=float, default=None, 50 | help='dropout probability (1.0 - keep probability)') 51 | 52 | # define and parse conditions of the experiment from commandline 53 | parser.add_argument('--model-class', type=str, default="NNTrees", 54 | help='the classname of the model to run') 55 | parser.add_argument('--train-path', type=str, default=None, 56 | help='path to training examples') 57 | parser.add_argument('--unify-one-zero', type=eval, default=True, 58 | help='whether to unify ones and zeros to integer') 59 | parser.add_argument('--validation-path', type=str, default=None, nargs="+", 60 | help='path(s) to validation examples, if multiple files supplied, will evaluate each individually') 61 | parser.add_argument('--test-path', type=str, default=None, 62 | help='path to test examples') 63 | parser.add_argument('--num-epochs', type=int, default=100, 64 | help='max num of epochs') 65 | parser.add_argument('--seed', type=int, default=12093, 66 | help='max num of epochs') 67 | parser.add_argument('--batch-size', type=int, default=1, 68 | help='the batch size') 69 | parser.add_argument('--share-memory-params', default=False, 70 | action='store_true', 71 | help='whether to allow weight sharing for memory ' 72 | 'operations') 73 | parser.add_argument('--disable-sharing', default=False, 74 | action='store_true', 75 | help='whether to not allow weight sharing for memory ' 76 | 'operations') 77 | parser.add_argument('--no-op', default=False, 78 | action='store_true', 79 | help='whether to add no operation to stack push ' 80 | 'and pop') 81 | parser.add_argument('--no-pop', default=False, 82 | action='store_true', 83 | help='whether to add just push and no-op') 84 | parser.add_argument('--stack-type', type=str, default='simple', 85 | help='choose the stack type. 
options are: simple, nn_stack, full_stack, add_stack, simple_gated, full_stack_gated') 86 | parser.add_argument('--likeLSTM', default=False, 87 | action='store_true', 88 | help='whether to make the mem2out stack tree completely like an LSTM+stack') 89 | parser.add_argument('--gate-push-pop', default=False, 90 | action='store_true', 91 | help='whether to make the push pop action a gate rather than a number') 92 | parser.add_argument('--normalize-action', default=False, 93 | action='store_true', 94 | help='whether to normalize push pop weight to 1 before pushing and popping') 95 | parser.add_argument('--gate-top-k', default=False, 96 | action='store_true', 97 | help='whether to gate the top-k instead of weighted average') 98 | parser.add_argument('--top-k', type=int, default=5, 99 | help='the top-k stack elements will be used for computing the output; ' 100 | 'select k. NOTE: k is in range(0,memory_size)') 101 | parser.add_argument('--numeric', default=False, 102 | action='store_true', 103 | help='whether to train on numeric equations') 104 | parser.add_argument('--fast', default=False, 105 | action='store_true', 106 | help='whether to evaluate only on 20%% of the training set') #TODO: This is not used by any code right now 107 | parser.add_argument('--verbose', default=False, 108 | action='store_true', 109 | help='whether to print execution trace outputs') 110 | parser.add_argument('--curriculum', default=None, 111 | help='what type of curriculum (depth/func)') 112 | parser.add_argument('--switch-epoch', default=None, type=int, 113 | help='epoch to switch over (2stage curriculum)') 114 | parser.add_argument('--curriculum-depth', default=None, type=int, 115 | help='max depth in curriculum') 116 | parser.add_argument('--eval-depth', nargs="+", type=int, default=None, 117 | help='list of depths to evaluate on') 118 | parser.add_argument('--tree-node-activation', type=str, default='sigmoid', 119 | help='choose the activation for tree-node') 120 | parser.add_argument('--stack-node-activation', type=str, 121 | default='sigmoid', 122 | help='choose the activation for stack-node') 123 | 124 | 125 | # define and parse logging and save/loading options of the experiment 126 | parser.add_argument('--model-prefix', type=str, default=None, 127 | help='path to save/load model') 128 | parser.add_argument('--result-path', type=str, default=None, 129 | help='path to save results') 130 | parser.add_argument('--load-epoch', type=int, default=None, 131 | help='load from epoch') 132 | parser.add_argument('--evaluate-only', action="store_true", default=False, 133 | help='evaluate only') 134 | parser.add_argument('--trace-path', type=str, default=None, 135 | help='path to save traces') 136 | parser.add_argument('--disp-epoch', type=int, default=1, 137 | help='show progress for every n epochs') 138 | parser.add_argument('--checkpoint-every-n-epochs', type=int, default=5, 139 | help='save model for every n epochs') 140 | parser.add_argument('--interactive', action='store_true', default=False, 141 | help='interactive evaluation') 142 | parser.add_argument('--force-result', action='store_true', default=False, 143 | help='force write to result in evaluate only') 144 | 145 | self.args = parser.parse_args() 146 | 147 | def aggregate_statistics(self, record, depth=None): 148 | # if len(record) == 0: 149 | # raise ValueError("empty record") 150 | if depth is not None: 151 | record = [item for item in record if item["depth"]==depth] 152 | if len(record) == 0: 153 | raise ValueError("empty record at depth %d" % depth) 154 | symbolic_true_positive
= 0 155 | symbolic_true_negative = 0 156 | symbolic_false_positive = 0 157 | symbolic_false_negative = 0 158 | symbolic_loss = 0 159 | symbolic_count = 0 160 | numeric_true_positive = 0 161 | numeric_true_negative = 0 162 | numeric_false_positive = 0 163 | numeric_false_negative = 0 164 | numeric_loss = 0 165 | numeric_count = 0 166 | for item in record: 167 | if not item["ex"].is_numeric(): 168 | if item["correct"] and (item["label"] == 1): 169 | symbolic_true_positive += 1 170 | elif item["correct"] and (item["label"] == 0): 171 | symbolic_true_negative += 1 172 | elif not item["correct"] and (item["label"] == 1): 173 | symbolic_false_negative += 1 174 | elif not item["correct"] and (item["label"] == 0): 175 | symbolic_false_positive += 1 176 | else: 177 | assert False 178 | symbolic_loss += item["loss"] 179 | symbolic_count += 1 180 | 181 | else: 182 | if item["correct"] and (item["label"] == 1): 183 | numeric_true_positive += 1 184 | elif item["correct"] and (item["label"] == 0): 185 | numeric_true_negative += 1 186 | elif not item["correct"] and (item["label"] == 1): 187 | numeric_false_negative += 1 188 | elif not item["correct"] and (item["label"] == 0): 189 | numeric_false_positive += 1 190 | else: 191 | assert False 192 | 193 | numeric_loss += item["loss"] 194 | numeric_count += 1 195 | assert symbolic_count == (symbolic_true_positive + symbolic_false_positive 196 | + symbolic_true_negative + symbolic_false_negative) 197 | symbolic_accuracy = (symbolic_true_positive + symbolic_true_negative) / symbolic_count if symbolic_count != 0 else 0 198 | symbolic_precision = symbolic_true_positive / (symbolic_true_positive + symbolic_false_positive) if (symbolic_true_positive + symbolic_false_positive) != 0 else 0 199 | symbolic_recall = symbolic_true_positive / (symbolic_true_positive + symbolic_false_negative) if (symbolic_true_positive + symbolic_false_negative) != 0 else 0 200 | symbolic_f1 = 2 * (symbolic_precision * symbolic_recall) / (symbolic_precision + symbolic_recall) if (symbolic_precision + symbolic_recall) != 0 else 0 201 | symbolic_spc = symbolic_true_negative / (symbolic_true_negative + symbolic_false_positive) if (symbolic_true_negative + symbolic_false_positive) != 0 else 0 202 | assert numeric_count == ( 203 | numeric_true_positive + numeric_false_positive 204 | + numeric_true_negative + numeric_false_negative) 205 | 206 | numeric_accuracy = ( 207 | numeric_true_positive + numeric_true_negative) / \ 208 | numeric_count if numeric_count != 0 else 0 209 | numeric_precision = numeric_true_positive / ( 210 | numeric_true_positive + numeric_false_positive) if ( 211 | numeric_true_positive + numeric_false_positive) != 0 else 0 212 | numeric_recall = numeric_true_positive / ( 213 | numeric_true_positive + numeric_false_negative) if ( 214 | numeric_true_positive + numeric_false_negative) != 0 else 0 215 | numeric_f1 = 2 * (numeric_precision * numeric_recall) / ( 216 | numeric_precision + numeric_recall) if ( 217 | numeric_precision + numeric_recall) != 0 else 0 218 | print("overall") 219 | bins = [0] * 20 220 | unit = 1/len(bins) 221 | for item in record: 222 | bin_idx = math.floor((math.exp(item["score"])/(1+math.exp(item["score"])) / unit)) 223 | if bin_idx >= len(bins): 224 | assert bin_idx == len(bins) 225 | bin_idx = bin_idx - 1 226 | bins[bin_idx] += 1 227 | maxcnt = max(bins) 228 | normalized_bins = [cnt/maxcnt for cnt in bins] 229 | print("1.0") 230 | for ncnt in reversed(normalized_bins): 231 | print(max(math.ceil(ncnt), math.ceil(ncnt * 60) )* "+") 232 | normalized_bins 
= [cnt/sum(bins) for cnt in bins] 233 | print("0.0") 234 | print("1.0") 235 | for ncnt in reversed(normalized_bins): 236 | print(max(math.ceil(ncnt), math.ceil(ncnt * 60) )* "+") 237 | print("0.0") 238 | print("condition positive") 239 | bins = [0] * 20 240 | unit = 1/len(bins) 241 | for item in record: 242 | if item["label"] == 0: continue 243 | bin_idx = math.floor((math.exp(item["score"])/(1+math.exp(item["score"])) / unit)) 244 | if bin_idx >= len(bins): 245 | assert bin_idx == len(bins) 246 | bin_idx = bin_idx - 1 247 | bins[bin_idx] += 1 248 | maxcnt = max(bins) 249 | if maxcnt > 0: 250 | normalized_bins = [cnt/maxcnt for cnt in bins] 251 | print("1.0") 252 | for ncnt in reversed(normalized_bins): 253 | print(max(math.ceil(ncnt), math.ceil(ncnt * 60) )* "+") 254 | normalized_bins = [cnt/sum(bins) for cnt in bins] 255 | print("0.0") 256 | print("1.0") 257 | for ncnt in reversed(normalized_bins): 258 | print(max(math.ceil(ncnt), math.ceil(ncnt * 60) )* "+") 259 | print("0.0") 260 | print("condition negative") 261 | bins = [0] * 20 262 | unit = 1/len(bins) 263 | for item in record: 264 | if item["label"] == 1: continue 265 | bin_idx = math.floor((math.exp(item["score"])/(1+math.exp(item["score"])) / unit)) 266 | if bin_idx >= len(bins): 267 | assert bin_idx == len(bins) 268 | bin_idx = bin_idx - 1 269 | bins[bin_idx] += 1 270 | maxcnt = max(bins) 271 | if maxcnt > 0: 272 | normalized_bins = [cnt/maxcnt for cnt in bins] 273 | print("1.0") 274 | for ncnt in reversed(normalized_bins): 275 | print(max(math.ceil(ncnt), math.ceil(ncnt * 60) )* "+") 276 | normalized_bins = [cnt/sum(bins) for cnt in bins] 277 | print("0.0") 278 | print("1.0") 279 | for ncnt in reversed(normalized_bins): 280 | print(max(math.ceil(ncnt), math.ceil(ncnt * 60) )* "+") 281 | print("0.0") 282 | return { 283 | "sym_loss_avg": symbolic_loss / symbolic_count if symbolic_count != 0 else 0, 284 | "sym_acc": symbolic_accuracy, 285 | "sym_precision": symbolic_precision, 286 | "sym_recall": symbolic_recall, 287 | "sym_spc": symbolic_spc, 288 | "sym_f1": symbolic_f1, 289 | "sym_count": symbolic_count, 290 | "num_loss_avg": numeric_loss / numeric_count if numeric_count != 0 else 0, 291 | "num_acc": numeric_accuracy, 292 | "num_precision": numeric_precision, 293 | "num_recall": numeric_recall, 294 | "num_f1": numeric_f1, 295 | "num_count": numeric_count 296 | } 297 | 298 | def train(self, model, optimizer, batch): 299 | model.train() 300 | record, loss = model.compute_batch(batch) 301 | optimizer.zero_grad() 302 | loss.backward() 303 | optimizer.step() 304 | return record, loss.item() 305 | 306 | def eval(self, model, loader, trace=None): 307 | model.eval() 308 | aggregate_record = [] 309 | aggregate_loss = 0 310 | for batch in loader: 311 | record, loss = model.compute_batch(batch, trace=trace) 312 | aggregate_record.extend(record) 313 | aggregate_loss += loss.item() 314 | return aggregate_record, (aggregate_loss / len(aggregate_record)) if len(aggregate_record) != 0 else 0 315 | 316 | def eval_and_log(self, model, loader, name, epoch, path, trace=None): 317 | if trace is not None: 318 | trace[name] = [] 319 | record, loss = self.eval(model, loader, trace=trace[name] if trace is not None else None) 320 | 321 | # log 322 | depths = sorted(list(set(item["depth"] for item in record))) 323 | stats = self.aggregate_statistics(record) 324 | print("epoch={} set={} loss={} depth=all sym_loss={} " 325 | "sym_acc={} sym_prec={} sym_rec={} sym_spc={} sym_f1={} " 326 | "sym_count={}".format( 327 | epoch, name, loss, 328 | 
stats["sym_loss_avg"], stats["sym_acc"], 329 | stats["sym_precision"], stats["sym_recall"],stats["sym_spc"], 330 | stats["sym_f1"], stats["sym_count"] 331 | )) 332 | print("epoch={} set={} loss={} depth=all num_loss={} " 333 | "num_acc={} num_prec={} num_rec={} num_f1={} " 334 | "num_count={}".format( 335 | epoch, name, loss, 336 | stats["num_loss_avg"], stats["num_acc"], 337 | stats["num_precision"], stats["num_recall"], 338 | stats["num_f1"], stats["num_count"] 339 | )) 340 | with open(path, "a") as fout: 341 | log_entry = {"type": "eval", "epoch":epoch, "set":name, "stats": {"all": stats}} 342 | for depth in depths: 343 | stats_d = self.aggregate_statistics(record, depth=depth) 344 | print("epoch={} set={} loss={} depth={} sym_loss={} " 345 | "sym_acc={} sym_prec={} sym_rec={} sym_spc={} sym_f1={} " 346 | "sym_count={}".format( 347 | epoch, name, None, depth, 348 | stats_d["sym_loss_avg"], stats_d["sym_acc"], 349 | stats_d["sym_precision"], stats_d["sym_recall"],stats["sym_spc"], 350 | stats_d["sym_f1"], stats_d["sym_count"] 351 | )) 352 | print("epoch={} set={} loss={} depth={} num_loss={} " 353 | "num_acc={} num_prec={} num_rec={} num_f1={} " 354 | "num_count={}".format( 355 | epoch, name, None, depth, 356 | stats_d["num_loss_avg"], stats_d["num_acc"], 357 | stats_d["num_precision"], stats_d["num_recall"], 358 | stats_d["num_f1"], stats_d["num_count"] 359 | )) 360 | log_entry["stats"][depth] = stats_d 361 | fout.write(json.dumps(log_entry) + "\n") 362 | 363 | def save_model(self, epoch, model, optimizer): 364 | model_path = "%s-model-%d.checkpoint" % (self.args.model_prefix, epoch) 365 | torch.save(model.state_dict(), model_path) 366 | optim_path = "%s-optim-%d.checkpoint"% (self.args.model_prefix, epoch) 367 | torch.save(optimizer.state_dict(), optim_path) 368 | 369 | def load_model_and_optim(self, epoch, model, optimizer): 370 | model_path = "%s-model-%d.checkpoint" % (self.args.model_prefix, epoch) 371 | optim_path = "%s-optim-%d.checkpoint"% (self.args.model_prefix, epoch) 372 | model.load_state_dict(torch.load(model_path)) 373 | optimizer.load_state_dict(torch.load(optim_path)) 374 | 375 | def run(self): 376 | """ 377 | Main routine for running this experiment. 
378 | """ 379 | torch.manual_seed(self.args.seed) 380 | random.seed(self.args.seed) 381 | if not self.args.interactive: 382 | train_loader, train_eval_loader, validation_loaders, test_loader = load_equation_tree_examples( 383 | self.args.train_path, 384 | self.args.validation_path, 385 | self.args.test_path, 386 | batch_size=self.args.batch_size, 387 | numeric=self.args.numeric, 388 | eval_depth=self.args.eval_depth, 389 | unify_one_zero=self.args.unify_one_zero 390 | ) 391 | 392 | if self.args.curriculum == "depth": 393 | easy_train_loader, easy_train_eval_loader, easy_validation_loader, easy_test_loader = load_equation_tree_examples( 394 | self.args.train_path, 395 | self.args.validation_path, 396 | self.args.test_path, 397 | batch_size=self.args.batch_size, 398 | numeric=self.args.numeric, 399 | eval_depth=self.args.eval_depth, 400 | unify_one_zero=self.args.unify_one_zero, 401 | filter=lambda x: x.depth <= self.args.curriculum_depth 402 | ) 403 | if self.args.curriculum == "func": 404 | prohibited = set(UNARY_FNS) 405 | easy_train_loader, easy_train_eval_loader, \ 406 | easy_validation_loader, easy_test_loader = \ 407 | load_equation_tree_examples( 408 | self.args.train_path, 409 | self.args.validation_path, 410 | self.args.test_path, 411 | batch_size=self.args.batch_size, 412 | numeric=self.args.numeric, 413 | eval_depth=self.args.eval_depth, 414 | unify_one_zero=self.args.unify_one_zero, 415 | filter=lambda x: x.all(lambda n:n.function_name not in prohibited) 416 | ) 417 | 418 | 419 | model = build_nnTree(self.args.model_class, 420 | self.args.num_hidden, 421 | self.args.num_embed, 422 | self.args.memory_size, 423 | self.args.share_memory_params, 424 | self.args.dropout, 425 | self.args.no_op, 426 | self.args.stack_type, 427 | self.args.top_k, 428 | self.args.verbose, 429 | self.args.tree_node_activation, 430 | self.args.stack_node_activation, 431 | self.args.no_pop, 432 | self.args.disable_sharing, 433 | self.args.likeLSTM, 434 | self.args.gate_push_pop, 435 | self.args.gate_top_k, 436 | self.args.normalize_action) 437 | 438 | totParams = 0 439 | for p in model.parameters(): 440 | totParams += p.numel() 441 | # print('total num params:', totParams) 442 | # 443 | 444 | optimizer = build_optimizer((param for param in model.parameters() if 445 | param.requires_grad), 446 | self.args.optimizer, 447 | self.args.lr, 448 | self.args.mom, 449 | self.args.wd, 450 | self.args.beta1, 451 | self.args.beta2) 452 | 453 | if self.args.load_epoch is None or self.args.evaluate_only is False: 454 | with open(self.args.result_path, "wt") as _: # TODO: if fine-tuning option is added this should be fixed 455 | pass 456 | 457 | 458 | batch_counter = 0 459 | if self.args.interactive: 460 | self.load_model_and_optim(self.args.load_epoch, model, optimizer) 461 | self.interactive(model) 462 | elif self.args.evaluate_only: 463 | trace = dict() 464 | self.load_model_and_optim(self.args.load_epoch, model, optimizer) 465 | self.eval_and_log(model, train_eval_loader, "train", self.args.load_epoch, "tmp" if not self.args.force_result else self.args.result_path, trace=trace) 466 | for i, vi in enumerate(validation_loaders): 467 | self.eval_and_log(model, vi, "validation:%d" % i, self.args.load_epoch, "tmp" if not self.args.force_result else self.args.result_path, trace=trace) 468 | self.eval_and_log(model, test_loader, "test", self.args.load_epoch, "tmp" if not self.args.force_result else self.args.result_path, trace=trace) 469 | if self.args.trace_path: 470 | with open(self.args.trace_path, "wb") as f: 471 | 
471 |                     f.write(pickle.dumps(trace))
472 |         elif self.args.curriculum is not None:
473 |             for epoch in range(1, self.args.num_epochs + 1):
474 |                 if epoch <= self.args.switch_epoch:
475 |                     curriculum_loader = easy_train_loader
476 |                 else:
477 |                     curriculum_loader = train_loader
478 |                 for batch in curriculum_loader:
479 |                     _, loss = self.train(model, optimizer, batch)
480 |                     batch_counter += 1
481 |                     print("iter %d loss=%f" % (batch_counter, loss))
482 |                     with open(self.args.result_path, "a") as fout:
483 |                         fout.write(json.dumps(
484 |                             {"type": "train", "iter": batch_counter,
485 |                              "set": "train", "loss": loss}) + "\n")
486 |                 if epoch % self.args.disp_epoch == 0:
487 |                     self.eval_and_log(model, easy_train_loader, "easy_train", epoch,
488 |                                       self.args.result_path)
489 |                     self.eval_and_log(model, easy_validation_loader, "easy_validation",
490 |                                       epoch, self.args.result_path)
491 |                     self.eval_and_log(model, easy_test_loader, "easy_test", epoch,
492 |                                       self.args.result_path)
493 |                     self.eval_and_log(model, train_eval_loader, "train", epoch,
494 |                                       self.args.result_path)
495 |                     for i, vi in enumerate(validation_loaders):
496 |                         self.eval_and_log(model, vi, "validation:%d" % i, epoch, self.args.result_path)
497 |                     self.eval_and_log(model, test_loader, "test", epoch,
498 |                                       self.args.result_path)
499 |                 if epoch % self.args.checkpoint_every_n_epochs == 0:
500 |                     self.save_model(epoch, model, optimizer)
501 |         else:
502 |             for epoch in range(1, self.args.num_epochs + 1):
503 |                 for batch in train_loader:
504 |                     _, loss = self.train(model, optimizer, batch)
505 |                     batch_counter += 1
506 |                     print("iter %d loss=%f" % (batch_counter, loss))
507 |                     with open(self.args.result_path, "a") as fout:
508 |                         fout.write(json.dumps(
509 |                             {"type": "train", "iter": batch_counter,
510 |                              "set": "train", "loss": loss}) + "\n")
511 |                 if epoch % self.args.disp_epoch == 0:
512 |                     # self.eval_and_log(model, train_eval_loader, "train", epoch,
513 |                     #                   self.args.result_path)
514 |                     for i, vi in enumerate(validation_loaders):
515 |                         self.eval_and_log(model, vi, "validation:%d" % i,
516 |                                           epoch, self.args.result_path)
517 |                     self.eval_and_log(model, test_loader, "test", epoch,
518 |                                       self.args.result_path)
519 |                 if epoch % self.args.checkpoint_every_n_epochs == 0:
520 |                     self.save_model(epoch, model, optimizer)
521 |
522 |     def interactive(self, model):
523 |         with open(self.args.train_path, "rt") as fin:
524 |             train_json = json.loads(fin.read())
525 |         train_trio = build_equation_tree_examples_list(train_json,
526 |                                                        numeric=self.args.numeric,
527 |                                                        unify_one_zero=self.args.unify_one_zero)
528 |         train = set(str(tree) for tree, _, _ in train_trio)
529 |         gtrace = []
530 |         exit = False
531 |         while True:
532 |             while True:
533 |                 try:
534 |                     s = input("enter an equation:\n")
535 |                     if s == "exit":
536 |                         exit = True
537 |                         break
538 |                     equation = parse_equation(s)
539 |                     break
540 |                 except Exception:
541 |                     traceback.print_exc()
542 |             if exit:
543 |                 break
544 |             trio = load_single_equation_tree_example(equation.dump())
545 |             tree = trio[0]
546 |             tree_r = load_single_equation_tree_example(equation.dump())[0]
547 |             tree_r.lchild, tree_r.rchild = tree_r.rchild, tree_r.lchild
548 |             print((str(tree) in train) or (str(tree_r) in train))
549 |             loader = sequential_sampler(trios=[trio], batch_size=1)
550 |             trace = []
551 |             record, loss = self.eval(model, loader, trace=trace)
552 |             print(trace[0].probability)
553 |             gtrace.append(trace[0])
554 |         if self.args.trace_path:
555 |             with open(self.args.trace_path, "wb") as f:
556 |                 f.write(pickle.dumps({"train": gtrace}))
557 |
558 |
559 | if __name__ == "__main__":
560 |     # train-path as input
561 |     nnTreeEquationVerificationExperiment().run()
562 |
--------------------------------------------------------------------------------
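A minimal usage sketch, not part of the repository: reloading one of the
`-model-`/`-optim-` checkpoint pairs written by `save_model` above. The
`make_model_and_optimizer` helper and the prefix/epoch values are hypothetical
placeholders; in practice the model must be rebuilt with `build_nnTree` and
`build_optimizer` using the same hyperparameters as at training time.

import torch

# Hypothetical helper: rebuild model and optimizer with the training-time
# hyperparameters (cf. build_nnTree / build_optimizer in run() above).
model, optimizer = make_model_and_optimizer()

prefix, epoch = "checkpoints/my_experiment", 50  # placeholder values
# Checkpoint naming mirrors save_model / load_model_and_optim:
# "%s-model-%d.checkpoint" and "%s-optim-%d.checkpoint".
model.load_state_dict(torch.load("%s-model-%d.checkpoint" % (prefix, epoch)))
optimizer.load_state_dict(torch.load("%s-optim-%d.checkpoint" % (prefix, epoch)))
model.eval()  # disable dropout before evaluation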