├── .DS_Store ├── .gitignore ├── LICENSE.txt ├── README.rst ├── perf.prof ├── requirements.txt ├── scripts ├── ast │ ├── data_download.sh │ ├── data_preporcess.py │ ├── data_process.py │ ├── data_process.sh │ ├── eval.sh │ ├── glove_tokens.py │ ├── grid_search.sh │ ├── run.sh │ └── train.sh ├── pyast │ ├── data_download.sh │ ├── data_process.sh │ ├── run.sh │ └── train.sh ├── tensorboard.sh └── token │ ├── accuracy.sh │ ├── data_process.py │ ├── run.sh │ ├── tokenizer.js │ └── train.sh └── zerogercrnn ├── __init__.py ├── experiments ├── __init__.py ├── ast_level │ ├── __init__.py │ ├── ast_core.py │ ├── common.py │ ├── data.py │ ├── main.py │ ├── metrics.py │ ├── nt2n_base │ │ ├── __init__.py │ │ ├── main.py │ │ └── model.py │ ├── nt2n_base_attention │ │ ├── __init__.py │ │ ├── main.py │ │ └── model.py │ ├── nt2n_base_attention_plus_layered │ │ ├── __init__.py │ │ ├── main.py │ │ └── model.py │ ├── ntn2t_base │ │ ├── __init__.py │ │ ├── main.py │ │ └── model.py │ ├── raw_data.py │ ├── results.py │ ├── utils.py │ └── vis │ │ ├── __init__.py │ │ ├── accuracies.py │ │ ├── ast_info.py │ │ ├── compare.py │ │ ├── model.py │ │ ├── plots.py │ │ ├── post_accuracy.py │ │ ├── pre_accuracy.py │ │ └── utils.py ├── common.py ├── pyast │ └── metrics.py ├── temp │ ├── __init__.py │ └── mnist_norm_test.py └── token_level │ ├── __init__.py │ ├── base │ ├── __init__.py │ ├── main.py │ └── model.py │ ├── common.py │ ├── core.py │ ├── data.py │ ├── main.py │ ├── metrics.py │ └── results.py ├── global_constants.py ├── lib ├── __init__.py ├── accuracies.py ├── argutils.py ├── attn.py ├── calculation.py ├── constants.py ├── core.py ├── data.py ├── embedding.py ├── file.py ├── health.py ├── log.py ├── metrics.py ├── preprocess.py ├── run.py ├── utils.py └── visualization │ ├── __init__.py │ ├── embeddings.py │ ├── html_helper.py │ ├── plotter.py │ └── text.py ├── test └── lib │ ├── __init__.py │ ├── calculation_test.py │ ├── data_test.py │ └── metrics_test.py └── testutils └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Idea 2 | .idea/ 3 | 4 | # Helper directories 5 | data/ 6 | venv/ 7 | saved/ 8 | tensorboard/ 9 | eval/ 10 | eval_local/ 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | env/ 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | .hypothesis/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # SageMath parsed files 90 | *.sage.py 91 | 92 | # dotenv 93 | .env 94 | 95 | # virtualenv 96 | .venv 97 | venv/ 98 | ENV/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ############################################# 2 | Bachelor's thesis on neural code completion 3 | ############################################# 4 | 5 | Initial setup 6 | ================= 7 | Create virtual environment: ``./venv.sh`` 8 | 9 | Activate virtual environment: ``source env/bin/activate`` 10 | 11 | The proposed models work with ASTs, so in principle code in any language can be completed. For now the models can be tested on two datasets: 12 | 13 | 1. JavaScript (`js150 dataset link `_) 14 | 2. Python (`py150 dataset link `_) 15 | 16 | 17 | JavaScript 18 | ============== 19 | To train a model on the JavaScript dataset: 20 | 21 | 1. Download data: ``./scripts/ast/data_download.sh`` 22 | 2. Process data: ``./scripts/ast/data_process.sh`` 23 | 3. Train model: ``./scripts/ast/run.sh`` 24 | 25 | To change model parameters, edit ``scripts/ast/train.sh`` 26 | 27 | Python 28 | ============== 29 | To train a model on the Python dataset: 30 | 31 | 1. Download data: ``./scripts/pyast/data_download.sh`` 32 | 2. Process data: ``./scripts/pyast/data_process.sh`` 33 | 3. Train model: ``./scripts/pyast/run.sh`` 34 | 35 | To change model parameters, edit ``scripts/pyast/train.sh`` 36 | 37 | Results 38 | ============= 39 | TensorBoard is used for accuracy visualization.
To run it, use: ``./scripts/tensorboard.sh`` 40 | -------------------------------------------------------------------------------- /perf.prof: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/perf.prof -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch==0.4.0 3 | torchvision 4 | tqdm 5 | jupyter 6 | tensorflow 7 | tensorboard 8 | tensorboardX 9 | matplotlib 10 | visdom 11 | pylint 12 | -------------------------------------------------------------------------------- /scripts/ast/data_download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget 'http://files.srl.inf.ethz.ch/data/js_dataset.tar.gz' 4 | gunzip -c js_dataset.tar.gz | tar xopf - 5 | rm js_dataset.tar.gz 6 | rm data.tar.gz 7 | rm programs_training.txt 8 | rm programs_eval.txt 9 | rm README.txt 10 | mkdir data 11 | mkdir data/ast 12 | mv programs_training.json data/ast/programs_training.json 13 | mv programs_eval.json data/ast/programs_eval.json 14 | -------------------------------------------------------------------------------- /scripts/ast/data_preporcess.py: -------------------------------------------------------------------------------- 1 | from zerogercrnn.lib.constants import ENCODING 2 | from zerogercrnn.lib.log import tqdm_lim 3 | 4 | 5 | def create_smaller_file(file, new_file, lim): 6 | in_file = open(file, mode='r', encoding=ENCODING) 7 | out_file = open(new_file, mode='w', encoding=ENCODING) 8 | 9 | for line in tqdm_lim(in_file, lim=lim): 10 | out_file.write(line) 11 | 12 | in_file.close() 13 | out_file.close() 14 | 15 | 16 | if __name__ == '__main__': 17 | create_smaller_file('data/ast/programs_eval.json', 'data/ast/programs_eval_5k.json', lim=5000) 18 | -------------------------------------------------------------------------------- /scripts/ast/data_process.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from itertools import chain 4 | 5 | from zerogercrnn.experiments.ast_level.raw_data import TokensRetriever, JsonConverter, OneHotConverter 6 | from zerogercrnn.lib.constants import ENCODING 7 | from zerogercrnn.lib.preprocess import extract_jsons_info, JsonListKeyExtractor 8 | 9 | parser = argparse.ArgumentParser(description='Data processing for AST level neural network') 10 | parser.add_argument('--file_train_raw', type=str, help='Raw train file') 11 | parser.add_argument('--file_eval_raw', type=str, help='Raw eval file') 12 | parser.add_argument('--file_non_terminals', type=str, help='File to store non-terminals') 13 | parser.add_argument('--file_terminals', type=str, help='File to store terminals') 14 | parser.add_argument('--file_train_converted', type=str, help='Sequence train file') 15 | parser.add_argument('--file_eval_converted', type=str, help='Sequence eval file') 16 | parser.add_argument('--file_train', type=str, help='One-hot train file') 17 | parser.add_argument('--file_eval', type=str, help='One-hot eval file') 18 | parser.add_argument('--file_glove_map', type=str, help='File from glove_tokens.py storing map from token to number') 19 | parser.add_argument('--file_glove_vocab', type=str, help='Vocabulary of trained Glove vectors') 20 | parser.add_argument('--file_glove_terminals',
type=str, help='Where to put terminals corpus of GloVe') 21 | parser.add_argument('--file_glove_non_terminals', type=str, help='Where to put non-terminals corpus of GloVe') 22 | parser.add_argument('--last_is_zero', action='store_true', help='Whether program JSONs end with 0') 23 | 24 | LIM = 100000 25 | 26 | """ 27 | Script that forms one-hot sequences of (N, T) from JS dataset. 28 | """ 29 | 30 | 31 | def create_glove_terminals_file(args): 32 | json_data = open(args.file_glove_map).read() 33 | term2id = json.loads(json_data) 34 | id2term = {} 35 | for (k, v) in term2id.items(): 36 | id2term[v] = k 37 | 38 | terminals = [] 39 | with open(args.file_glove_vocab, mode='r', encoding=ENCODING) as f: 40 | for line in f: 41 | t = line.split(' ') 42 | terminals.append(id2term[int(t[0])]) 43 | 44 | glove_terminals = open(args.file_glove_terminals, mode='w', encoding=ENCODING) 45 | glove_terminals.write(json.dumps(terminals)) 46 | glove_terminals.close() 47 | 48 | def get_tokens(args): 49 | TokensRetriever().get_and_write_tokens( 50 | dataset=args.file_train_raw, 51 | non_terminal_dest=args.file_non_terminals, 52 | terminal_dest=args.file_terminals, 53 | encoding=ENCODING, 54 | append_eof=True, 55 | lim=LIM 56 | ) 57 | 58 | 59 | def convert_files(args): 60 | print('Train') 61 | JsonConverter.convert_file( 62 | raw_file=args.file_train_raw, 63 | dest_file=args.file_train_converted, 64 | terminals_file=args.file_terminals, 65 | encoding=ENCODING, 66 | append_eof=True, 67 | lim=LIM, 68 | last_is_zero=args.last_is_zero 69 | ) 70 | 71 | print('Eval') 72 | JsonConverter.convert_file( 73 | raw_file=args.file_eval_raw, 74 | dest_file=args.file_eval_converted, 75 | terminals_file=args.file_terminals, 76 | encoding=ENCODING, 77 | append_eof=True, 78 | lim=LIM, 79 | last_is_zero=args.last_is_zero 80 | ) 81 | 82 | 83 | def form_one_hot(args): 84 | converter = OneHotConverter( 85 | file_non_terminals=args.file_non_terminals, 86 | file_terminals=args.file_terminals 87 | ) 88 | 89 | print('Train') 90 | converter.convert_file( 91 | src_file=args.file_train_converted, 92 | dst_file=args.file_train, 93 | lim=LIM 94 | ) 95 | 96 | print('Eval') 97 | converter.convert_file( 98 | src_file=args.file_eval_converted, 99 | dst_file=args.file_eval, 100 | lim=LIM 101 | ) 102 | 103 | 104 | def create_glove_non_terminals_file(args): 105 | """Create GloVe non-terminals file from one-hot file.
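    Writes the non-terminal ids of each program as one space-separated stream,
    appending ten -1 separator tokens after every program (see the chain call
    below), e.g. ``12 7 3 ... -1 -1``. The resulting plain-text corpus is
    intended as input for GloVe training.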
""" 106 | 107 | with open(file=args.file_glove_non_terminals, mode='w', encoding=ENCODING) as f: 108 | nt_extractor = JsonListKeyExtractor(key='N') # extract non terminals for one-hot files 109 | for nt_generator in extract_jsons_info(nt_extractor, args.file_train, args.file_eval): 110 | f.write(' '.join(map(str, chain(nt_generator, [-1] * 10)))) 111 | f.write(' ') 112 | 113 | 114 | def main(): 115 | args = parser.parse_args() 116 | 117 | # print('Retrieving Glove terminals') 118 | # create_glove_terminals_file(args) 119 | 120 | print('Retrieving tokens ...') 121 | get_tokens(args) 122 | 123 | print('Converting to sequences ...') 124 | convert_files(args) 125 | 126 | print('Forming one-hot ...') 127 | form_one_hot(args) 128 | 129 | # print('Creating GloVe non-terminals corpus') 130 | # create_glove_non_terminals_file(args) 131 | 132 | print('Train file: {}'.format(args.file_train)) 133 | print('Eval file: {}'.format(args.file_eval)) 134 | 135 | 136 | if __name__ == '__main__': 137 | main() 138 | -------------------------------------------------------------------------------- /scripts/ast/data_process.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PYTHONPATH=. python3 -m cProfile -o perf.prof scripts/ast/data_process.py \ 4 | --file_train_raw "data/ast/programs_training.json" \ 5 | --file_eval_raw "data/ast/programs_eval.json" \ 6 | --file_non_terminals "data/ast/non_terminals.json" \ 7 | --file_terminals "data/ast/terminals.json" \ 8 | --file_train_converted "data/ast/programs_training_seq.json" \ 9 | --file_eval_converted "data/ast/programs_eval_seq.json" \ 10 | --file_train "data/ast/file_train.json" \ 11 | --file_eval "data/ast/file_eval.json" \ 12 | --file_glove_map "data/ast/terminals_map.json" \ 13 | --file_glove_vocab "data/ast/vocab.txt" \ 14 | --file_glove_terminals "data/ast/glove_terminals.json" \ 15 | --file_glove_non_terminals "data/ast/glove_non_terminals_corpus.txt" \ 16 | --last_is_zero -------------------------------------------------------------------------------- /scripts/ast/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #!/bin/bash 3 | 4 | mkdir saved 5 | mkdir saved/$2 6 | mkdir eval 7 | mkdir eval/$2 8 | 9 | PYTHONPATH=. python3 zerogercrnn/experiments/ast_level/main.py \ 10 | --title $2 \ 11 | --prediction $1 \ 12 | --eval \ 13 | --eval_results_directory eval/$2 \ 14 | --eval_file "data/ast/file_eval.json" \ 15 | --data_limit 100000 \ 16 | --model_save_dir saved/$2 \ 17 | --saved_model "saved/" \ 18 | --seq_len 50 \ 19 | --batch_size 128 \ 20 | --learning_rate 0.001 \ 21 | --epochs 8 \ 22 | --decay_after_epoch 0 \ 23 | --decay_multiplier 0.6 \ 24 | --weight_decay=0. 
\ 25 | --hidden_size 1500 \ 26 | --num_layers 1 \ 27 | --dropout 0.01 \ 28 | --layered_hidden_size 500 \ 29 | --non_terminals_num 97 \ 30 | --non_terminal_embedding_dim 300 \ 31 | --num_tree_layers 50 \ 32 | --non_terminals_file "data/ast/non_terminals.json" \ 33 | --non_terminal_embeddings_file "data/ast/non_terminal_embeddings.txt" \ 34 | --terminals_num 50001 \ 35 | --terminal_embedding_dim 1200 \ 36 | --terminals_file "data/ast/terminals.json" \ 37 | --terminal_embeddings_file "data/ast/terminal_embeddings.txt" \ 38 | --node_depths_embedding_dim 20 \ 39 | --nodes_depths_stat_file "eval/ast/stat/node_depths.json" -------------------------------------------------------------------------------- /scripts/ast/glove_tokens.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import json 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser(description='GloVe tokens corpus preparation') 6 | parser.add_argument('--task', type=str, help='One of: terminals, non-terminals') 7 | parser.add_argument('--input_file', type=str, help='Input file for task') 8 | parser.add_argument('--output_file', type=str, help='Output file for task') 9 | parser.add_argument('--token_map_file', type=str, help='Map from token name to int file') 10 | 11 | LIM = 100000 12 | ENCODING = 'ISO-8859-1' 13 | EMP_TOKEN = '' 14 | 15 | 16 | def write_map(file, raw_map): 17 | f_write = open(file, mode='w', encoding=ENCODING) 18 | f_write.write(json.dumps(raw_map)) 19 | 20 | 21 | def create_terminals_file(args, lim=LIM): 22 | """Create a file for terminals consisting of sequences of token numbers. Each token is mapped to its number, 23 | e.g. data.x.y -> 0 1 2 1 3 24 | 25 | NB: token numbers can be large (up to the number of different tokens in the data). 26 | No unk tokens here. 27 | """ 28 | 29 | f_write = open(args.output_file, mode='w', encoding=ENCODING) 30 | 31 | terminals = {EMP_TOKEN: 0} 32 | current_id = 1 33 | 34 | it = 0 35 | with open(args.input_file, mode='r', encoding=ENCODING) as f: 36 | for l in tqdm(f, total=min(lim, 100000)): 37 | it += 1 38 | 39 | raw_json = json.loads(l) 40 | converted = [] 41 | for node in raw_json: 42 | if node == 0: 43 | break 44 | 45 | # add terminal 46 | if 'value' in node: 47 | node_value = str(node['value']) 48 | if node_value not in terminals.keys(): 49 | terminals[node_value] = current_id 50 | current_id += 1 51 | else: 52 | node_value = EMP_TOKEN 53 | 54 | converted.append(terminals[node_value]) 55 | 56 | f_write.write(' '.join([str(x) for x in converted])) 57 | f_write.write(' ') 58 | 59 | if (lim is not None) and (it == lim): 60 | break 61 | 62 | write_map(args.token_map_file, terminals) 63 | 64 | if __name__ == '__main__': 65 | args = parser.parse_args() 66 | 67 | if args.task == 'terminals': 68 | create_terminals_file(args) 69 | else: 70 | raise Exception('Not supported task') 71 | -------------------------------------------------------------------------------- /scripts/ast/grid_search.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir saved 4 | mkdir saved/$2 5 | mkdir eval 6 | mkdir eval/$2 7 | 8 | PYTHONPATH=.
python3 -m cProfile -o program.prof zerogercrnn/experiments/ast_level/main.py \ 9 | --title $2 \ 10 | --prediction $1 \ 11 | --eval_results_directory eval/$2 \ 12 | --train_file "data/ast/file_train.json" \ 13 | --data_limit 100000 \ 14 | --model_save_dir saved/$2 \ 15 | --seq_len 50 \ 16 | --batch_size 80 \ 17 | --learning_rate 0.001 \ 18 | --epochs 20 \ 19 | --decay_after_epoch 0 \ 20 | --decay_multiplier 0.9 \ 21 | --weight_decay=0. \ 22 | --hidden_size 500 \ 23 | --num_layers 1 \ 24 | --dropout 0.01 \ 25 | --layered_hidden_size 500 \ 26 | --non_terminals_num 97 \ 27 | --non_terminal_embedding_dim 50 \ 28 | --non_terminals_file "data/ast/non_terminals.json" \ 29 | --non_terminal_embeddings_file "data/ast/non_terminal_embeddings.txt" \ 30 | --terminals_num 50001 \ 31 | --terminal_embedding_dim 50 \ 32 | --terminals_file "data/ast/terminals.json" \ 33 | --terminal_embeddings_file "data/ast/terminal_embeddings.txt" \ 34 | --node_depths_embedding_dim 20 \ 35 | --nodes_depths_stat_file "eval/ast/stat/node_depths.json" \ 36 | --grid_name hidden_size \ 37 | --grid_values 100 500 1000 1500 2000 38 | 39 | -------------------------------------------------------------------------------- /scripts/ast/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ./scripts/ast/train.sh nt2n_base_attention_plus_layered eval_nt2n_base_attention_plus_layered_large_embeddings 4 | -------------------------------------------------------------------------------- /scripts/ast/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir saved 4 | mkdir saved/$2 5 | mkdir eval 6 | mkdir eval/$2 7 | 8 | PYTHONPATH=. python3 zerogercrnn/experiments/ast_level/main.py \ 9 | --title $2 \ 10 | --prediction $1 \ 11 | --eval_results_directory eval/$2 \ 12 | --train_file "data/ast/file_train.json" \ 13 | --data_limit 100000 \ 14 | --model_save_dir saved/$2 \ 15 | --seq_len 50 \ 16 | --batch_size 128 \ 17 | --learning_rate 0.0001 \ 18 | --epochs 8 \ 19 | --decay_after_epoch 0 \ 20 | --decay_multiplier 0.6 \ 21 | --weight_decay=0. \ 22 | --hidden_size 1500 \ 23 | --num_layers 1 \ 24 | --dropout 0.01 \ 25 | --layered_hidden_size 500 \ 26 | --num_tree_layers 50 \ 27 | --non_terminals_num 97 \ 28 | --non_terminal_embedding_dim 300 \ 29 | --non_terminals_file "data/ast/non_terminals.json" \ 30 | --non_terminal_embeddings_file "data/ast/non_terminal_embeddings.txt" \ 31 | --terminals_num 50001 \ 32 | --terminal_embedding_dim 1200 \ 33 | --terminals_file "data/ast/terminals.json" \ 34 | --terminal_embeddings_file "data/ast/terminal_embeddings.txt" \ 35 | --node_depths_embedding_dim 20 \ 36 | --nodes_depths_stat_file "eval/ast/stat/node_depths.json" 37 | 38 | -------------------------------------------------------------------------------- /scripts/pyast/data_download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget 'http://files.srl.inf.ethz.ch/data/py150.tar.gz' 4 | gunzip -c py150.tar.gz | tar xopf - 5 | rm py150.tar.gz 6 | mkdir -p data/pyast 7 | mv python50k_eval.json data/pyast/python50k_eval.json 8 | mv python100k_train.json data/pyast/python100k_train.json 9 | cd .. -------------------------------------------------------------------------------- /scripts/pyast/data_process.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PYTHONPATH=.
python3 -m cProfile -o perf.prof scripts/ast/data_process.py \ 4 | --file_train_raw "data/pyast/python100k_train.json" \ 5 | --file_eval_raw "data/pyast/python50k_eval.json" \ 6 | --file_non_terminals "data/pyast/non_terminals.json" \ 7 | --file_terminals "data/pyast/terminals.json" \ 8 | --file_train_converted "data/pyast/programs_training_seq.json" \ 9 | --file_eval_converted "data/pyast/programs_eval_seq.json" \ 10 | --file_train "data/pyast/file_train.json" \ 11 | --file_eval "data/pyast/file_eval.json" -------------------------------------------------------------------------------- /scripts/pyast/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ./scripts/pyast/train.sh nt2n_base nt2n_base_try -------------------------------------------------------------------------------- /scripts/pyast/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir saved 4 | mkdir saved/$2 5 | mkdir eval 6 | mkdir eval/$2 7 | 8 | PYTHONPATH=. python3 -m cProfile -o program.prof zerogercrnn/experiments/ast_level/main.py \ 9 | --title $2 \ 10 | --prediction $1 \ 11 | --eval_results_directory eval/$2 \ 12 | --train_file "data/pyast/file_train.json" \ 13 | --data_limit 100 \ 14 | --model_save_dir saved/$2 \ 15 | --seq_len 5 \ 16 | --batch_size 5 \ 17 | --learning_rate 0.001 \ 18 | --epochs 30 \ 19 | --decay_after_epoch 0 \ 20 | --decay_multiplier 0.8 \ 21 | --weight_decay=0. \ 22 | --hidden_size 500 \ 23 | --num_layers 1 \ 24 | --dropout 0.01 \ 25 | --layered_hidden_size 100 \ 26 | --num_tree_layers 30 \ 27 | --non_terminals_num 322 \ 28 | --non_terminal_embedding_dim 50 \ 29 | --non_terminals_file "data/pyast/non_terminals.json" \ 30 | --terminals_num 50001 \ 31 | --terminal_embedding_dim 50 \ 32 | --terminals_file "data/pyast/terminals.json" \ 33 | --node_depths_embedding_dim 20 \ 34 | --nodes_depths_stat_file "eval/ast/stat/node_depths.json" 35 | 36 | -------------------------------------------------------------------------------- /scripts/tensorboard.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir tensorboard 4 | mkdir tensorboard/runs 5 | tensorboard --logdir tensorboard/runs -------------------------------------------------------------------------------- /scripts/token/accuracy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir saved 4 | mkdir saved/$1 5 | PYTHONPATH=. python3 zerogercrnn/experiments/token_level/main.py \ 6 | --title $1 \ 7 | --task accuracy \ 8 | --eval_file "data/tokens/file_eval.json" \ 9 | --embeddings_file "data/tokens/vectors.txt" \ 10 | --data_limit 10000 \ 11 | --model_save_dir saved/$1 \ 12 | --tokens_count 51000 \ 13 | --seq_len 50 \ 14 | --batch_size 100 \ 15 | --learning_rate 0.005 \ 16 | --epochs 20 \ 17 | --decay_after_epoch 0 \ 18 | --decay_multiplier 0.9 \ 19 | --embedding_size 50 \ 20 | --hidden_size 1500 \ 21 | --num_layers 1 \ 22 | --dropout 0.01 \ 23 | --weight_decay=0. 
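# Usage sketch (assumed from the $1 references above): ./scripts/token/accuracy.sh <run_title>,
# where <run_title> names the run and the saved/<run_title> directory used for checkpoints.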
24 | -------------------------------------------------------------------------------- /scripts/token/data_process.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import re 4 | 5 | from tqdm import tqdm 6 | 7 | from zerogercrnn.lib.file import read_lines 8 | 9 | parser = argparse.ArgumentParser(description='Data processing for token level neural network') 10 | parser.add_argument('--task', type=str, help='One of: token, one_hot_json, one_hot_plain') 11 | parser.add_argument('--tokens_file', type=str, help='File with tokens') 12 | parser.add_argument('--input_file', type=str, help='Input file for task') 13 | parser.add_argument('--output_file', type=str, help='Output file for task') 14 | 15 | TRAIN_FILE = '/Users/zerogerc/Documents/datasets/js_dataset.tar/programs_training_tokenized.json' 16 | TRAIN_FILE_ONE_HOT = '/Users/zerogerc/Yandex.Disk.localized/shared_files/University/diploma/rnn-autocomplete/data/tokens/file_train.json' 17 | TRAIN_FILE_PLAIN = '/Users/zerogerc/Yandex.Disk.localized/shared_files/University/diploma/rnn-autocomplete/data/tokens/file_train_plain.txt' 18 | 19 | EVAL_FILE = 'data/programs_eval_tokenized.json' 20 | EVAL_FILE_ONE_HOT = 'data/tokens/file_eval.json' 21 | EVAL_FILE_PLAIN = 'data/tokens/file_eval_plain.txt' 22 | 23 | TOKENS_FILE = '/Users/zerogerc/Yandex.Disk.localized/shared_files/University/diploma/rnn-autocomplete/data/tokens/tokens.txt' 24 | 25 | ENCODING = 'ISO-8859-1' 26 | 27 | EMPTY_STRING_SPACE = re.compile('[ \t]+') 28 | EMPTY_STRING_NEWLINE = re.compile('[ \n\t]+') 29 | 30 | 31 | # UNK TOKEN is zero 32 | 33 | def normalize_token(token): 34 | if EMPTY_STRING_SPACE.fullmatch(token): 35 | return ' ' 36 | if EMPTY_STRING_NEWLINE.fullmatch(token): 37 | return '\n' 38 | return token.strip() 39 | 40 | 41 | def read_tokens(tokens_path): 42 | id2token = read_lines(tokens_path, encoding=ENCODING) 43 | token2id = {} 44 | for id, token in enumerate(id2token): 45 | token2id[token] = id 46 | 47 | return token2id, id2token 48 | 49 | 50 | def get_tokens(file_path, output_path, lim=100): 51 | tokens = {} 52 | for l in open(file=file_path, mode='r', encoding=ENCODING): 53 | for t in json.loads(l): 54 | t = normalize_token(t) 55 | if t not in tokens.keys(): 56 | tokens[t] = 0 57 | tokens[t] += 1 58 | 59 | with open(output_path, mode='w', encoding=ENCODING) as f: 60 | sorted_terminals = sorted(tokens.keys(), key=lambda key: tokens[key], reverse=True) 61 | for t in sorted_terminals[:lim]: 62 | f.write('{}\n'.format(t)) 63 | 64 | 65 | def convert_to_one_hot(file_path, tokens_path, output_file, total): 66 | token2id, id2token = read_tokens(tokens_path) 67 | 68 | all_tokens = 0 69 | unk_tokens = 0 70 | 71 | out_file = open(file=output_file, mode='w', encoding=ENCODING) 72 | for l in tqdm(open(file=file_path, mode='r', encoding=ENCODING), total=total): 73 | one_hot = [] 74 | for t in json.loads(l): 75 | t = normalize_token(t) 76 | if t == ' ': # skip spaces 77 | continue 78 | 79 | all_tokens += 1 80 | if t in token2id.keys(): 81 | one_hot.append(1 + token2id[t]) 82 | else: 83 | one_hot.append(0) 84 | unk_tokens += 1 85 | 86 | out_file.write(json.dumps(one_hot)) 87 | out_file.write('\n') 88 | 89 | print('Unk tokens percentage: {}'.format(float(unk_tokens) / all_tokens)) 90 | 91 | 92 | def convert_to_plain_text(file_path, tokens_path, output_file, total): 93 | token2id, id2token = read_tokens(tokens_path) 94 | 95 | out_file = open(file=output_file, mode='w', encoding=ENCODING) 96 | for l in
tqdm(open(file=file_path, mode='r', encoding=ENCODING), total=total): 97 | one_hot = [] 98 | for t in json.loads(l): 99 | one_hot.append(str(t)) 100 | out_file.write(' '.join(one_hot)) 101 | out_file.write(' ') 102 | 103 | 104 | if __name__ == '__main__': 105 | args = parser.parse_args() 106 | 107 | # input_file = args.input_file 108 | # output_file = args.output_file 109 | # tokens_file = args.tokens_file 110 | 111 | input_file = EVAL_FILE 112 | output_file = EVAL_FILE_ONE_HOT 113 | tokens_file = TOKENS_FILE 114 | 115 | if args.task == 'token': 116 | get_tokens(input_file, output_file, lim=50000) 117 | elif args.task == 'one_hot_json': 118 | convert_to_one_hot(input_file, tokens_file, output_file, total=100000) 119 | elif args.task == 'one_hot_plain': 120 | convert_to_plain_text(input_file, tokens_file, output_file, total=100000) 121 | else: 122 | raise Exception('Unknown task type') 123 | # elif args.task == 'one_hot_text': 124 | # convert_to_plain_text(TRAIN_FILE, TOKENS_FILE, ) 125 | -------------------------------------------------------------------------------- /scripts/token/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ./scripts/token/train.sh token_base 28May_token_base_hs500 -------------------------------------------------------------------------------- /scripts/token/tokenizer.js: -------------------------------------------------------------------------------- 1 | const fs = require('fs'); 2 | const readline = require('readline'); 3 | const jsTokens = require('js-tokens').default; 4 | 5 | 6 | const pathRoot = '/Users/zerogerc/Documents/datasets/js_dataset.tar/'; 7 | 8 | const evalInputPath = '/Users/zerogerc/Documents/datasets/js_dataset.tar/programs_training.txt'; 9 | const evalOutputPath = '/Users/zerogerc/Documents/datasets/js_dataset.tar/programs_training_tokenized.txt'; 10 | 11 | function appendLineToFile(filePath, text) { 12 | fs.appendFileSync(filePath, text + '\n') 13 | } 14 | 15 | function readFile(filePath) { 16 | return fs.readFileSync(filePath, {encoding: 'utf-8'}) 17 | } 18 | 19 | function processFile(outputPath, filePath) { 20 | const content = readFile(filePath); 21 | const tokens = content.match(jsTokens); 22 | appendLineToFile(outputPath, JSON.stringify(tokens)) 23 | } 24 | 25 | function main() { 26 | readline.createInterface({ 27 | input: fs.createReadStream(evalInputPath) 28 | }).on('line', function (line) { 29 | if (fs.existsSync(pathRoot + line)) { 30 | processFile(evalOutputPath, pathRoot + line) 31 | } 32 | }) 33 | } 34 | 35 | main(); 36 | -------------------------------------------------------------------------------- /scripts/token/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir saved 4 | mkdir saved/$2 5 | mkdir eval 6 | mkdir eval/$2 7 | 8 | PYTHONPATH=. python3 -m cProfile -o program.prof zerogercrnn/experiments/token_level/main.py \ 9 | --title $2 \ 10 | --prediction $1 \ 11 | --eval_results_directory eval/$2 \ 12 | --train_file "data/tokens/file_train.json" \ 13 | --data_limit 100000 \ 14 | --model_save_dir saved/$2 \ 15 | --seq_len 50 \ 16 | --batch_size 80 \ 17 | --learning_rate 0.001 \ 18 | --epochs 30 \ 19 | --decay_after_epoch 0 \ 20 | --decay_multiplier 0.9 \ 21 | --weight_decay=0. 
\ 22 | --hidden_size 500 \ 23 | --num_layers 1 \ 24 | --dropout 0.01 \ 25 | --tokens_num 51000 \ 26 | --token_embedding_dim 50 27 | -------------------------------------------------------------------------------- /zerogercrnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/zerogercrnn/__init__.py -------------------------------------------------------------------------------- /zerogercrnn/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/zerogercrnn/experiments/__init__.py -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/zerogercrnn/experiments/ast_level/__init__.py -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/ast_core.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import torch 3 | from zerogercrnn.lib.utils import setup_tensor, repackage_hidden 4 | from zerogercrnn.lib.core import CombinedModule, AlphaBetaSumLayer, EmbeddingsModule, LinearLayer, NormalizationLayer 5 | from zerogercrnn.lib.attn import Attn 6 | from zerogercrnn.lib.calculation import calc_attention_combination 7 | import torch.nn.functional as F 8 | 9 | from zerogercrnn.experiments.ast_level.data import ASTInput 10 | 11 | 12 | class ASTNT2NModule(CombinedModule): 13 | 14 | def __init__( 15 | self, 16 | non_terminals_num, 17 | non_terminal_embedding_dim, 18 | terminals_num, 19 | terminal_embedding_dim, 20 | recurrent_output_size 21 | ): 22 | super().__init__() 23 | 24 | self.non_terminals_num = non_terminals_num 25 | self.non_terminal_embedding_dim = non_terminal_embedding_dim 26 | self.terminals_num = terminals_num 27 | self.terminal_embedding_dim = terminal_embedding_dim 28 | self.recurrent_output_size = recurrent_output_size 29 | 30 | self.nt_embedding = self.module(EmbeddingsModule( 31 | num_embeddings=self.non_terminals_num, 32 | embedding_dim=self.non_terminal_embedding_dim, 33 | sparse=True 34 | )) 35 | 36 | self.t_embedding = self.module(EmbeddingsModule( 37 | num_embeddings=self.terminals_num, 38 | embedding_dim=self.terminal_embedding_dim, 39 | sparse=True 40 | )) 41 | 42 | self.h2o = self.module(LinearLayer( 43 | input_size=self.recurrent_output_size, 44 | output_size=self.non_terminals_num 45 | )) 46 | 47 | @abstractmethod 48 | def get_recurrent_output(self, combined_input, ast_input: ASTInput, m_hidden, forget_vector): 49 | """Method should return tensor that will be passed to h2o and updated hidden that will be passed 50 | to next invocation of get_recurrent_output as m_hidden.""" 51 | 52 | return None, None 53 | 54 | def forward(self, ast_input: ASTInput, m_hidden, forget_vector): 55 | non_terminal_input = ast_input.non_terminals 56 | terminal_input = ast_input.terminals 57 | 58 | nt_embedded = self.nt_embedding(non_terminal_input) 59 | t_embedded = self.t_embedding(terminal_input) 60 | combined_input = torch.cat([nt_embedded, t_embedded], dim=-1) 61 | 62 | recurrent_output, new_m_hidden = 
self.get_recurrent_output( 63 | combined_input=combined_input, 64 | ast_input=ast_input, 65 | m_hidden=m_hidden, 66 | forget_vector=forget_vector 67 | ) 68 | 69 | m_output = self.h2o(recurrent_output) 70 | return m_output, new_m_hidden 71 | 72 | 73 | class LastKBuffer: 74 | def __init__(self, window_len, hidden_size): 75 | self.buffer = None 76 | self.window_len = window_len 77 | self.hidden_size = hidden_size 78 | self.it = 0 79 | 80 | def add_vector(self, vector): 81 | self.buffer[self.it] = vector 82 | self.it += 1 83 | if self.it >= self.window_len: 84 | self.it = 0 85 | 86 | def get(self): 87 | return torch.stack(self.buffer, dim=1) 88 | 89 | def init_buffer(self, batch_size): 90 | self.buffer = [setup_tensor(torch.zeros((batch_size, self.hidden_size))) for _ in range(self.window_len)] 91 | 92 | def repackage_and_forget_buffer_partly(self, forget_vector): 93 | self.buffer = [repackage_hidden(b.mul(forget_vector)) for b in self.buffer] 94 | 95 | 96 | class LastKAttention(CombinedModule): 97 | """TODO: make it use base K attention""" 98 | def __init__(self, hidden_size, k=50, ab_transform=False): 99 | super().__init__() 100 | self.hidden_size = hidden_size 101 | self.k = k 102 | self.ab_transform = ab_transform 103 | 104 | self.context_buffer = None 105 | self.attn = self.module(Attn(method='general', hidden_size=self.hidden_size)) 106 | if self.ab_transform: 107 | self.alpha_beta_sum = self.module(AlphaBetaSumLayer(min_value=-1, max_value=2)) 108 | 109 | def repackage_and_forget_buffer_partly(self, forget_vector): 110 | self.context_buffer.repackage_and_forget_buffer_partly(forget_vector) 111 | 112 | def init_hidden(self, batch_size): 113 | self.context_buffer = LastKBuffer(window_len=self.k, hidden_size=self.hidden_size) 114 | self.context_buffer.init_buffer(batch_size) 115 | 116 | def forward(self, current_hidden): 117 | if self.context_buffer is None: 118 | raise Exception('You should init buffer first') 119 | 120 | current_buffer = self.context_buffer.get() 121 | attn_output_coefficients = self.attn(current_hidden, current_buffer) 122 | attn_output = calc_attention_combination(attn_output_coefficients, current_buffer) 123 | 124 | buffer_vector = current_hidden 125 | if self.ab_transform: 126 | buffer_vector = self.alpha_beta_sum(current_hidden, attn_output) 127 | 128 | self.context_buffer.add_vector(buffer_vector) 129 | return attn_output 130 | 131 | 132 | class LastKAttentionBase(CombinedModule): 133 | """Layer that stores buffer of last k vectors passed to add_vector. 134 | forward will compute attention of input to all vectors in buffer. 
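    The buffer must be allocated via init_hidden before the first forward call.
    A minimal usage sketch (batch_size, hidden_size and current_hidden are
    assumed to be defined by the caller):

        attn = LastKAttentionBase(hidden_size=hidden_size, k=50)
        attn.init_hidden(batch_size)     # allocates k zero vectors
        context = attn(current_hidden)   # attention-weighted sum over the buffer
        attn.add_vector(current_hidden)  # push current state for future steps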
135 | """ 136 | 137 | def __init__(self, hidden_size, k=50): 138 | super().__init__() 139 | self.hidden_size = hidden_size 140 | self.k = k 141 | 142 | self.context_buffer = None 143 | self.attn = self.module(Attn(method='general', hidden_size=self.hidden_size)) 144 | 145 | def add_vector(self, vector): 146 | self.context_buffer.add_vector(vector) 147 | 148 | def forward(self, current_hidden): 149 | if self.context_buffer is None: 150 | raise Exception('You should init buffer first') 151 | 152 | current_buffer = self.context_buffer.get() 153 | attn_output_coefficients = self.attn(current_hidden, current_buffer) 154 | attn_output = calc_attention_combination(attn_output_coefficients, current_buffer) 155 | 156 | return attn_output 157 | 158 | def repackage_and_forget_buffer_partly(self, forget_vector): 159 | self.context_buffer.repackage_and_forget_buffer_partly(forget_vector) 160 | 161 | def init_hidden(self, batch_size): 162 | self.context_buffer = LastKBuffer(window_len=self.k, hidden_size=self.hidden_size) 163 | self.context_buffer.init_buffer(batch_size) 164 | 165 | 166 | class GatedLastKAttention(CombinedModule): 167 | 168 | def __init__(self, input_size, hidden_size, k): 169 | super().__init__() 170 | self.input_size = input_size 171 | self.hidden_size = hidden_size 172 | 173 | self.base_attn = self.module( 174 | LastKAttentionBase( 175 | hidden_size=self.hidden_size, 176 | k=k 177 | ) 178 | ) 179 | 180 | self.x_norm = self.module(NormalizationLayer(features_num=self.input_size)) 181 | self.h_norm = self.module(NormalizationLayer(features_num=self.hidden_size)) 182 | 183 | self.w_cntx = self.module(LinearLayer( 184 | input_size=self.hidden_size + self.input_size, 185 | output_size=self.hidden_size, 186 | bias=False 187 | )) 188 | 189 | self.w_h = self.module(LinearLayer( 190 | input_size=self.hidden_size + self.input_size, 191 | output_size=self.hidden_size, 192 | bias=False 193 | )) 194 | 195 | def forward(self, current_input, current_hidden): 196 | x = self.x_norm(current_input) 197 | h = current_hidden 198 | cntx = self.base_attn(current_hidden) 199 | 200 | # combine cntx and h with current_input to allow model to make different decisions based on current input. 201 | cntx_x = torch.cat((cntx, x), dim=-1) 202 | h_x = torch.cat((h, x), dim=-1) 203 | 204 | # calculate gated functions 205 | g_cntx = F.tanh(self.w_cntx(cntx_x)) 206 | g_h = F.tanh(self.w_h(h_x)) 207 | 208 | # calculate output as sum of cntx and h multiplied by corresponding activations. 
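        # i.e. m_output = g_cntx * cntx + g_h * h (elementwise), letting the model
        # interpolate between the attention context and the recurrent state; the
        # result is also pushed into the attention buffer for subsequent steps.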
209 | m_output = (g_cntx * cntx) + (g_h * h) 210 | self.base_attn.add_vector(m_output) 211 | 212 | return m_output 213 | 214 | def repackage_and_forget_buffer_partly(self, forget_vector): 215 | self.base_attn.repackage_and_forget_buffer_partly(forget_vector) 216 | 217 | def init_hidden(self, batch_size): 218 | self.base_attn.init_hidden(batch_size) 219 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/common.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from zerogercrnn.experiments.ast_level.data import ASTInput, ASTTarget, ASTDataReader, ASTDataGenerator 7 | from zerogercrnn.experiments.common import Main 8 | from zerogercrnn.lib.embedding import Embeddings 9 | from zerogercrnn.lib.metrics import Metrics 10 | from zerogercrnn.lib.run import NetworkRoutine 11 | from zerogercrnn.lib.utils import filter_requires_grad 12 | 13 | 14 | # region Utils 15 | 16 | def create_terminal_embeddings(args): 17 | return Embeddings( 18 | embeddings_size=args.terminal_embedding_dim, 19 | vector_file=args.terminal_embeddings_file, 20 | squeeze=True 21 | ) 22 | 23 | 24 | def create_non_terminal_embeddings(args): 25 | return Embeddings( 26 | embeddings_size=args.non_terminal_embedding_dim, 27 | vector_file=args.non_terminal_embeddings_file, 28 | squeeze=False 29 | ) 30 | 31 | 32 | def create_data_generator(args): 33 | data_reader = ASTDataReader( 34 | file_train=args.train_file, 35 | file_eval=args.eval_file, 36 | seq_len=args.seq_len, 37 | number_of_seq=20, 38 | limit=args.data_limit 39 | ) 40 | 41 | data_generator = ASTDataGenerator( 42 | data_reader=data_reader, 43 | seq_len=args.seq_len, 44 | batch_size=args.batch_size 45 | ) 46 | 47 | return data_generator 48 | 49 | 50 | # endregion 51 | 52 | # region Loss 53 | 54 | class ASTLoss(nn.Module): 55 | 56 | @abstractmethod 57 | def forward(self, prediction: torch.Tensor, target: ASTTarget): 58 | pass 59 | 60 | 61 | class NonTerminalsCrossEntropyLoss(ASTLoss): 62 | 63 | def __init__(self): 64 | super().__init__() 65 | self.criterion = nn.CrossEntropyLoss() 66 | 67 | def forward(self, prediction: torch.Tensor, target: ASTTarget): 68 | return self.criterion(prediction.view(-1, prediction.size()[-1]), target.non_terminals.view(-1)) 69 | 70 | 71 | class TerminalsCrossEntropyLoss(ASTLoss): 72 | 73 | def __init__(self): 74 | super().__init__() 75 | self.criterion = nn.CrossEntropyLoss() 76 | 77 | def forward(self, prediction: torch.Tensor, target: ASTTarget): 78 | return self.criterion(prediction.view(-1, prediction.size()[-1]), target.terminals.view(-1)) 79 | 80 | 81 | # endregion 82 | 83 | # region Metrics 84 | 85 | class NonTerminalMetrics(Metrics): 86 | 87 | def __init__(self, base: Metrics): 88 | super().__init__() 89 | self.base = base 90 | 91 | def drop_state(self): 92 | self.base.drop_state() 93 | 94 | def report(self, prediction_target): 95 | prediction, target = prediction_target 96 | self.base.report((prediction, target.non_terminals)) 97 | 98 | def get_current_value(self, should_print=False): 99 | return self.base.get_current_value(should_print=should_print) 100 | 101 | 102 | class TerminalMetrics(Metrics): 103 | 104 | def __init__(self, base: Metrics): 105 | super().__init__() 106 | self.base = base 107 | 108 | def drop_state(self): 109 | self.base.drop_state() 110 | 111 | def report(self, prediction_target): 112 | prediction, target = prediction_target 113 | 
self.base.report((prediction, target.terminals)) 114 | 115 | def get_current_value(self, should_print=False): 116 | return self.base.get_current_value(should_print=should_print) 117 | 118 | 119 | # endregion 120 | 121 | # region Routine 122 | 123 | def run_model(model, iter_data, hidden, batch_size): 124 | (m_input, m_target), forget_vector = iter_data 125 | assert forget_vector.size()[0] == batch_size 126 | 127 | m_input = ASTInput.setup(m_input) 128 | m_target = ASTTarget.setup(m_target) 129 | 130 | m_input.current_non_terminals = m_target.non_terminals 131 | 132 | if hidden is None: 133 | hidden = model.init_hidden(batch_size=batch_size) 134 | 135 | prediction, hidden = model(m_input, hidden, forget_vector=forget_vector) 136 | 137 | return prediction, m_target, hidden 138 | 139 | 140 | class ASTRoutine(NetworkRoutine): 141 | 142 | def __init__(self, model, batch_size, seq_len, criterion: ASTLoss, optimizers): 143 | super().__init__(model) 144 | self.model = self.network 145 | self.batch_size = batch_size 146 | self.seq_len = seq_len 147 | self.criterion = criterion 148 | self.optimizers = optimizers 149 | 150 | self.hidden = None 151 | 152 | def optimize(self, loss): 153 | # Backward pass 154 | loss.backward() 155 | torch.nn.utils.clip_grad_norm_(filter_requires_grad(self.model.parameters()), 5) 156 | # torch.nn.utils.clip_grad_norm_(filter_requires_grad(self.model.sparse_parameters()), 5) 157 | 158 | # Optimizer step 159 | for optimizer in self.optimizers: 160 | optimizer.step() 161 | 162 | def run(self, iter_num, iter_data): 163 | if self.optimizers is not None: 164 | for optimizer in self.optimizers: 165 | optimizer.zero_grad() 166 | 167 | prediction, target, hidden = run_model( 168 | model=self.model, 169 | iter_data=iter_data, 170 | hidden=self.hidden, 171 | batch_size=self.batch_size 172 | ) 173 | self.hidden = hidden 174 | 175 | if self.optimizers is not None: 176 | loss = self.criterion(prediction, target) 177 | self.optimize(loss) 178 | 179 | return prediction, target 180 | 181 | 182 | # endregion 183 | 184 | class ASTMain(Main): 185 | def __init__(self, args): 186 | self.non_terminal_embeddings = self.create_non_terminal_embeddings(args) 187 | self.terminal_embeddings = self.create_terminal_embeddings(args) 188 | super().__init__(args) 189 | 190 | @abstractmethod 191 | def create_model(self, args): 192 | pass 193 | 194 | @abstractmethod 195 | def create_criterion(self, args): 196 | pass 197 | 198 | @abstractmethod 199 | def create_train_metrics(self, args): 200 | pass 201 | 202 | def create_data_generator(self, args): 203 | return create_data_generator(args) 204 | 205 | def create_terminal_embeddings(self, args): 206 | return None 207 | 208 | def create_non_terminal_embeddings(self, args): 209 | return None 210 | 211 | def create_train_routine(self, args): 212 | return ASTRoutine( 213 | model=self.model, 214 | batch_size=args.batch_size, 215 | seq_len=args.seq_len, 216 | criterion=self.criterion, 217 | optimizers=self.optimizers 218 | ) 219 | 220 | def create_validation_routine(self, args): 221 | return ASTRoutine( 222 | model=self.model, 223 | batch_size=args.batch_size, 224 | seq_len=args.seq_len, 225 | criterion=self.criterion, 226 | optimizers=None 227 | ) 228 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | 6 | from 
zerogercrnn.experiments.ast_level.nt2n_base_attention.main import NT2NBaseAttentionMain 7 | from zerogercrnn.experiments.ast_level.nt2n_base.main import NT2NBaseMain 8 | from zerogercrnn.experiments.ast_level.ntn2t_base.main import NTN2TBaseMain 9 | from zerogercrnn.experiments.ast_level.nt2n_base_attention_plus_layered.main import NT2NBaseAttentionPlusLayeredMain 10 | from zerogercrnn.experiments.ast_level.vis.utils import draw_line_plot, visualize_tensor 11 | from zerogercrnn.lib.argutils import add_general_arguments, add_batching_data_args, add_optimization_args, \ 12 | add_recurrent_core_args, add_non_terminal_args, add_terminal_args 13 | from zerogercrnn.lib.log import logger 14 | 15 | parser = argparse.ArgumentParser(description='AST level neural network') 16 | add_general_arguments(parser) 17 | add_batching_data_args(parser) 18 | add_optimization_args(parser) 19 | add_recurrent_core_args(parser) 20 | add_non_terminal_args(parser) 21 | add_terminal_args(parser) 22 | 23 | parser.add_argument('--prediction', type=str, help='One of: nt2n_base, ntn2t_base, nt2n_base_attention, nt2n_base_attention_plus_layered') 24 | parser.add_argument('--save_model_every', type=int, help='How often to save model', default=1) 25 | 26 | # This is for evaluation purposes 27 | parser.add_argument('--eval', action='store_true', help='Evaluate or train') 28 | parser.add_argument('--eval_results_directory', type=str, help='Where to save results of evaluation') 29 | 30 | # Grid search parameters 31 | parser.add_argument('--grid_name', type=str, help='Parameter to grid search') 32 | parser.add_argument( 33 | '--grid_values', nargs='+', type=int, 34 | help='Values for grid searching' 35 | ) # how to make it int or float? 36 | 37 | # Additional parameters for specific models 38 | parser.add_argument( 39 | '--node_depths_embedding_dim', type=int, 40 | help='Dimension of continuous representation of node depth' 41 | ) 42 | parser.add_argument( 43 | '--nodes_depths_stat_file', type=str, 44 | help='File with the number of times each depth occurs in the train file' 45 | ) 46 | 47 | 48 | def get_main(args): 49 | if args.prediction == 'nt2n_base': 50 | main = NT2NBaseMain(args) 51 | elif args.prediction == 'ntn2t_base': 52 | main = NTN2TBaseMain(args) 53 | elif args.prediction == 'nt2n_base_attention': 54 | main = NT2NBaseAttentionMain(args) 55 | elif args.prediction == 'nt2n_base_attention_plus_layered': 56 | main = NT2NBaseAttentionPlusLayeredMain(args) 57 | else: 58 | raise Exception('Not supported prediction type: {}'.format(args.prediction)) 59 | 60 | return main 61 | 62 | 63 | def train(args): 64 | get_main(args).train(args) 65 | 66 | 67 | def evaluate(args): 68 | if args.saved_model is None: 69 | print('WARNING: Running eval without saved_model.
Not a good idea') 70 | get_main(args).eval(args) 71 | 72 | 73 | def grid_search(args): 74 | parameter_name = args.grid_name 75 | parameter_values = args.grid_values 76 | 77 | initial_title = args.title 78 | initial_save_dir = args.model_save_dir 79 | 80 | for p in parameter_values: 81 | suffix = '_grid_' + parameter_name + '_' + str(p) 82 | args.title = initial_title + suffix 83 | args.model_save_dir = initial_save_dir + suffix 84 | if not os.path.exists(args.model_save_dir): 85 | os.makedirs(args.model_save_dir) 86 | 87 | setattr(args, parameter_name, p) 88 | 89 | main = get_main(args) 90 | main.train(args) 91 | 92 | 93 | def visualize(args): 94 | main = get_main(args) 95 | model = main.model 96 | 97 | h2o = model.h2o.affine.weight 98 | h2o_line = torch.sum(h2o, dim=0).detach().numpy() 99 | 100 | draw_line_plot(h2o_line) 101 | visualize_tensor(h2o) 102 | 103 | 104 | if __name__ == '__main__': 105 | print(torch.__version__) 106 | _args = parser.parse_args() 107 | assert _args.title is not None 108 | logger.should_log = _args.log 109 | 110 | if _args.grid_name is not None: 111 | grid_search(_args) 112 | elif _args.eval: 113 | evaluate(_args) 114 | else: 115 | train(_args) 116 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/nt2n_base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/zerogercrnn/experiments/ast_level/nt2n_base/__init__.py -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/nt2n_base/main.py: -------------------------------------------------------------------------------- 1 | from zerogercrnn.experiments.ast_level.common import ASTMain, NonTerminalMetrics, NonTerminalsCrossEntropyLoss 2 | from zerogercrnn.experiments.ast_level.metrics import NonTerminalsMetricsWrapper, SingleNonTerminalAccuracyMetrics 3 | from zerogercrnn.experiments.ast_level.nt2n_base.model import NT2NBaseModel 4 | from zerogercrnn.lib.metrics import SequentialMetrics, MaxPredictionAccuracyMetrics, ResultsSaver, MaxPredictionWrapper, TopKWrapper 5 | 6 | 7 | class NT2NBaseMain(ASTMain): 8 | def create_model(self, args): 9 | return NT2NBaseModel( 10 | non_terminals_num=args.non_terminals_num, 11 | non_terminal_embedding_dim=args.non_terminal_embedding_dim, 12 | terminals_num=args.terminals_num, 13 | terminal_embedding_dim=args.terminal_embedding_dim, 14 | hidden_dim=args.hidden_size, 15 | num_layers=args.num_layers, 16 | dropout=args.dropout 17 | ) 18 | 19 | def create_criterion(self, args): 20 | return NonTerminalsCrossEntropyLoss() 21 | 22 | def create_train_metrics(self, args): 23 | return NonTerminalMetrics(base=MaxPredictionAccuracyMetrics()) 24 | 25 | def create_eval_metrics(self, args): 26 | return SequentialMetrics([ 27 | NonTerminalMetrics(base=MaxPredictionAccuracyMetrics()), 28 | SingleNonTerminalAccuracyMetrics( 29 | non_terminals_file=args.non_terminals_file, 30 | results_dir=args.eval_results_directory 31 | ), 32 | NonTerminalsMetricsWrapper(TopKWrapper(base=ResultsSaver(dir_to_save=args.eval_results_directory))) 33 | ]) -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/nt2n_base/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from zerogercrnn.experiments.ast_level.data import ASTInput 4 | 
from zerogercrnn.lib.core import CombinedModule, EmbeddingsModule, RecurrentCore, LinearLayer 5 | from zerogercrnn.lib.utils import repackage_hidden, forget_hidden_partly 6 | from zerogercrnn.experiments.ast_level.ast_core import ASTNT2NModule 7 | 8 | 9 | class NT2NBaseModel(ASTNT2NModule): 10 | def __init__( 11 | self, 12 | non_terminals_num, 13 | non_terminal_embedding_dim, 14 | terminals_num, 15 | terminal_embedding_dim, 16 | hidden_dim, 17 | num_layers, 18 | dropout 19 | ): 20 | super().__init__( 21 | non_terminals_num=non_terminals_num, 22 | non_terminal_embedding_dim=non_terminal_embedding_dim, 23 | terminals_num=terminals_num, 24 | terminal_embedding_dim=terminal_embedding_dim, 25 | recurrent_output_size=hidden_dim 26 | ) 27 | self.hidden_dim = hidden_dim 28 | self.num_layers = num_layers 29 | self.dropout = dropout 30 | 31 | self.recurrent_core = self.module(RecurrentCore( 32 | input_size=self.non_terminal_embedding_dim + self.terminal_embedding_dim, 33 | hidden_size=self.hidden_dim, 34 | num_layers=self.num_layers, 35 | dropout=self.dropout, 36 | model_type='lstm' 37 | )) 38 | 39 | def get_recurrent_output(self, combined_input, ast_input: ASTInput, m_hidden, forget_vector): 40 | hidden = m_hidden 41 | 42 | hidden = forget_hidden_partly(hidden, forget_vector=forget_vector) 43 | hidden = repackage_hidden(hidden) 44 | 45 | recurrent_output, new_hidden = self.recurrent_core(combined_input, hidden) 46 | 47 | return recurrent_output, new_hidden 48 | 49 | def init_hidden(self, batch_size): 50 | return self.recurrent_core.init_hidden(batch_size) 51 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/nt2n_base_attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/zerogercrnn/experiments/ast_level/nt2n_base_attention/__init__.py -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/nt2n_base_attention/main.py: -------------------------------------------------------------------------------- 1 | from zerogercrnn.experiments.ast_level.common import ASTMain, NonTerminalMetrics, NonTerminalsCrossEntropyLoss 2 | from zerogercrnn.experiments.ast_level.metrics import NonTerminalsMetricsWrapper, SingleNonTerminalAccuracyMetrics 3 | from zerogercrnn.experiments.ast_level.nt2n_base_attention.model import NT2NBaseAttentionModel 4 | from zerogercrnn.lib.metrics import SequentialMetrics, MaxPredictionAccuracyMetrics, ResultsSaver, MaxPredictionWrapper, TopKWrapper 5 | 6 | 7 | class NT2NBaseAttentionMain(ASTMain): 8 | def create_model(self, args): 9 | return NT2NBaseAttentionModel( 10 | non_terminals_num=args.non_terminals_num, 11 | non_terminal_embedding_dim=args.non_terminal_embedding_dim, 12 | terminals_num=args.terminals_num, 13 | terminal_embedding_dim=args.terminal_embedding_dim, 14 | hidden_dim=args.hidden_size, 15 | num_layers=args.num_layers, 16 | dropout=args.dropout, 17 | is_eval=args.eval 18 | ) 19 | 20 | def create_criterion(self, args): 21 | return NonTerminalsCrossEntropyLoss() 22 | 23 | def create_train_metrics(self, args): 24 | return NonTerminalMetrics(base=MaxPredictionAccuracyMetrics()) 25 | 26 | def create_eval_metrics(self, args): 27 | return SequentialMetrics([ 28 | NonTerminalMetrics(base=MaxPredictionAccuracyMetrics()), 29 | SingleNonTerminalAccuracyMetrics( 30 | non_terminals_file=args.non_terminals_file, 
31 | results_dir=args.eval_results_directory 32 | ), 33 | NonTerminalsMetricsWrapper(TopKWrapper(base=ResultsSaver(dir_to_save=args.eval_results_directory))) 34 | ]) 35 | 36 | def register_eval_hooks(self): 37 | return add_eval_hooks(self.model) 38 | 39 | 40 | def add_eval_hooks(model: NT2NBaseAttentionModel): 41 | return [model.last_k_attention.attn_metrics] 42 | 43 | 44 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/nt2n_base_attention/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from zerogercrnn.experiments.ast_level.data import ASTInput 5 | from zerogercrnn.lib.core import CombinedModule, EmbeddingsModule, RecurrentCore, LinearLayer, LSTMCellDropout 6 | from zerogercrnn.lib.attn import Attn 7 | from zerogercrnn.lib.utils import repackage_hidden, forget_hidden_partly, get_best_device, setup_tensor, forget_hidden_partly_lstm_cell 8 | from zerogercrnn.lib.calculation import calc_attention_combination, set_layered_hidden 9 | from zerogercrnn.experiments.ast_level.metrics import PerNtAttentionMetrics 10 | 11 | 12 | class LastKBuffer: 13 | def __init__(self, window_len, hidden_size, is_eval=False): 14 | self.buffer = None 15 | self.window_len = window_len 16 | self.hidden_size = hidden_size 17 | self.it = 0 18 | self.is_eval = is_eval 19 | 20 | def add_vector(self, vector): 21 | self.buffer[self.it] = vector 22 | self.it += 1 23 | if self.it >= self.window_len: 24 | self.it = 0 25 | 26 | def get(self): 27 | return torch.stack(self.buffer, dim=1) 28 | 29 | def init_buffer(self, batch_size): 30 | c = 1 31 | if self.is_eval: 32 | c = 2 33 | 34 | self.buffer = [setup_tensor(torch.zeros((batch_size, self.hidden_size))) for _ in range(c * self.window_len)] 35 | 36 | def repackage_and_forget_buffer_partly(self, forget_vector): 37 | # self.buffer = self.buffer.mul(forget_vector.unsqueeze(1)) TODO: implement forgetting 38 | self.buffer = [repackage_hidden(b) for b in self.buffer] 39 | 40 | 41 | class LastKAttention(CombinedModule): 42 | def __init__(self, hidden_size, k=50, is_eval=False): 43 | super().__init__() 44 | self.hidden_size = hidden_size 45 | self.k = k 46 | self.context_buffer = None 47 | self.attn = self.module(Attn(method='general', hidden_size=self.hidden_size)) 48 | 49 | self.is_eval = is_eval 50 | if self.is_eval: 51 | self.attn_metrics = PerNtAttentionMetrics() 52 | 53 | def repackage_and_forget_buffer_partly(self, forget_vector): 54 | self.context_buffer.repackage_and_forget_buffer_partly(forget_vector) 55 | 56 | def init_hidden(self, batch_size): 57 | self.context_buffer = LastKBuffer(window_len=self.k, hidden_size=self.hidden_size) 58 | self.context_buffer.init_buffer(batch_size) 59 | 60 | def forward(self, current_input, current_hidden): 61 | if self.context_buffer is None: 62 | raise Exception('You should init buffer first') 63 | 64 | current_buffer = self.context_buffer.get() 65 | attn_output_coefficients = self.attn(current_hidden, current_buffer) 66 | attn_output = calc_attention_combination(attn_output_coefficients, current_buffer) 67 | 68 | if self.is_eval: 69 | self.attn_metrics.report(current_input, attn_output_coefficients) 70 | 71 | self.context_buffer.add_vector(current_hidden) 72 | return attn_output 73 | 74 | 75 | class NT2NBaseAttentionModel(CombinedModule): 76 | """Base Model with attention on last n hidden states of LSTM.""" 77 | 78 | def __init__( 79 | self, 80 | non_terminals_num, 81 | 
non_terminal_embedding_dim, 82 | terminals_num, 83 | terminal_embedding_dim, 84 | hidden_dim, 85 | num_layers, 86 | dropout, 87 | is_eval 88 | ): 89 | super().__init__() 90 | 91 | self.non_terminals_num = non_terminals_num 92 | self.non_terminal_embedding_dim = non_terminal_embedding_dim 93 | self.terminals_num = terminals_num 94 | self.terminal_embedding_dim = terminal_embedding_dim 95 | self.hidden_dim = hidden_dim 96 | self.num_layers = num_layers 97 | self.dropout = dropout 98 | 99 | self.nt_embedding = self.module(EmbeddingsModule( 100 | num_embeddings=self.non_terminals_num, 101 | embedding_dim=self.non_terminal_embedding_dim, 102 | sparse=True 103 | )) 104 | 105 | self.t_embedding = self.module(EmbeddingsModule( 106 | num_embeddings=self.terminals_num, 107 | embedding_dim=self.terminal_embedding_dim, 108 | sparse=True 109 | )) 110 | 111 | # self.recurrent_core = self.module(RecurrentCore( 112 | # input_size=self.non_terminal_embedding_dim + self.terminal_embedding_dim, 113 | # hidden_size=self.hidden_dim, 114 | # num_layers=self.num_layers, 115 | # dropout=self.dropout, 116 | # model_type='lstm' 117 | # )) 118 | 119 | self.recurrent_cell = self.module(LSTMCellDropout( 120 | input_size=self.non_terminal_embedding_dim + self.terminal_embedding_dim, 121 | hidden_size=self.hidden_dim, 122 | dropout=self.dropout 123 | )) 124 | 125 | self.last_k_attention = self.module(LastKAttention( 126 | hidden_size=self.hidden_dim, 127 | k=50, 128 | is_eval=is_eval 129 | )) 130 | 131 | self.h2o = self.module(LinearLayer( 132 | input_size=2 * self.hidden_dim, 133 | output_size=self.non_terminals_num 134 | )) 135 | 136 | def forward(self, m_input: ASTInput, hidden, forget_vector): 137 | non_terminal_input = m_input.non_terminals 138 | terminal_input = m_input.terminals 139 | 140 | nt_embedded = self.nt_embedding(non_terminal_input) 141 | t_embedded = self.t_embedding(terminal_input) 142 | 143 | combined_input = torch.cat([nt_embedded, t_embedded], dim=2) 144 | 145 | recurrent_output, new_hidden, attn_output = self.get_recurrent_layers_outputs( 146 | ast_input=m_input, 147 | combined_input=combined_input, 148 | hidden=hidden, 149 | forget_vector=forget_vector 150 | ) 151 | 152 | concatenated_output = torch.cat((recurrent_output, attn_output), dim=-1) 153 | prediction = self.h2o(concatenated_output) 154 | 155 | return prediction, new_hidden 156 | 157 | def get_recurrent_layers_outputs( 158 | self, ast_input: ASTInput, combined_input, hidden, forget_vector): 159 | hidden = repackage_hidden(forget_hidden_partly_lstm_cell(hidden, forget_vector=forget_vector)) 160 | self.last_k_attention.repackage_and_forget_buffer_partly(forget_vector) 161 | 162 | recurrent_output = [] 163 | layered_attn_output = [] 164 | for i in range(combined_input.size()[0]): 165 | reinit_dropout = i == 0 166 | 167 | # core recurrent part 168 | cur_h, cur_c = self.recurrent_cell(combined_input[i], hidden, reinit_dropout=reinit_dropout) 169 | hidden = (cur_h, cur_c) 170 | recurrent_output.append(cur_h) 171 | 172 | # layered part 173 | attn_output = self.last_k_attention(ast_input.non_terminals[i], cur_h) 174 | layered_attn_output.append(attn_output) 175 | 176 | 177 | # combine outputs from different layers 178 | recurrent_output = torch.stack(recurrent_output, dim=0) 179 | layered_attn_output = torch.stack(layered_attn_output, dim=0) 180 | 181 | return recurrent_output, hidden, layered_attn_output 182 | 183 | def init_hidden(self, batch_size): 184 | self.last_k_attention.init_hidden(batch_size) 185 | return 
self.recurrent_cell.init_hidden(batch_size) 186 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/nt2n_base_attention_plus_layered/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/zerogercrnn/experiments/ast_level/nt2n_base_attention_plus_layered/__init__.py -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/nt2n_base_attention_plus_layered/main.py: -------------------------------------------------------------------------------- 1 | from zerogercrnn.experiments.ast_level.common import ASTMain, NonTerminalMetrics, NonTerminalsCrossEntropyLoss 2 | from zerogercrnn.experiments.ast_level.metrics import NonTerminalsMetricsWrapper, SingleNonTerminalAccuracyMetrics 3 | from zerogercrnn.experiments.ast_level.nt2n_base_attention_plus_layered.model import NT2NBaseAttentionPlusLayeredModel 4 | from zerogercrnn.lib.metrics import SequentialMetrics, MaxPredictionAccuracyMetrics, ResultsSaver, MaxPredictionWrapper, TopKWrapper, FeaturesMeanVarianceMetrics 5 | from zerogercrnn.lib.utils import register_input_hook 6 | 7 | 8 | class NT2NBaseAttentionPlusLayeredMain(ASTMain): 9 | def create_model(self, args): 10 | return NT2NBaseAttentionPlusLayeredModel( 11 | non_terminals_num=args.non_terminals_num, 12 | non_terminal_embedding_dim=args.non_terminal_embedding_dim, 13 | terminals_num=args.terminals_num, 14 | terminal_embedding_dim=args.terminal_embedding_dim, 15 | hidden_dim=args.hidden_size, 16 | layered_hidden_size=args.layered_hidden_size, 17 | num_tree_layers=args.num_tree_layers, 18 | dropout=args.dropout 19 | ) 20 | 21 | def create_criterion(self, args): 22 | return NonTerminalsCrossEntropyLoss() 23 | 24 | def create_train_metrics(self, args): 25 | return NonTerminalMetrics(base=MaxPredictionAccuracyMetrics()) 26 | 27 | def create_eval_metrics(self, args): 28 | return SequentialMetrics([ 29 | NonTerminalMetrics(base=MaxPredictionAccuracyMetrics()), 30 | SingleNonTerminalAccuracyMetrics( 31 | non_terminals_file=args.non_terminals_file, 32 | results_dir=args.eval_results_directory 33 | ), 34 | NonTerminalsMetricsWrapper(TopKWrapper(base=ResultsSaver(dir_to_save=args.eval_results_directory))) 35 | ]) 36 | 37 | def register_eval_hooks(self): 38 | return [] 39 | 40 | 41 | def add_eval_hooks(model: NT2NBaseAttentionPlusLayeredModel): 42 | metrics = FeaturesMeanVarianceMetrics() 43 | register_input_hook(model.h2o, metrics) 44 | 45 | return [metrics] 46 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/nt2n_base_attention_plus_layered/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from zerogercrnn.experiments.ast_level.data import ASTInput 5 | from zerogercrnn.lib.attn import calc_attention_combination 6 | from zerogercrnn.lib.core import CombinedModule, EmbeddingsModule, RecurrentCore, LinearLayer, LSTMCellDropout, \ 7 | LayeredRecurrentUpdateAfter 8 | from zerogercrnn.lib.utils import repackage_hidden, forget_hidden_partly, get_best_device, setup_tensor, \ 9 | forget_hidden_partly_lstm_cell 10 | from zerogercrnn.experiments.ast_level.ast_core import LastKAttention 11 | from zerogercrnn.lib.attn import Attn 12 | from zerogercrnn.experiments.ast_level.ast_core import 
ASTNT2NModule 13 | 14 | 15 | class LayeredAttentionRecurrent(LayeredRecurrentUpdateAfter): 16 | 17 | def pick_current_output(self, layered_hidden, nodes_depth): 18 | return None 19 | 20 | 21 | class NT2NBaseAttentionPlusLayeredModel(ASTNT2NModule): 22 | """Base Model with attention on last n hidden states of LSTM.""" 23 | 24 | def __init__( 25 | self, 26 | non_terminals_num, 27 | non_terminal_embedding_dim, 28 | terminals_num, 29 | terminal_embedding_dim, 30 | hidden_dim, 31 | layered_hidden_size, 32 | num_tree_layers, 33 | dropout 34 | ): 35 | super().__init__( 36 | non_terminals_num=non_terminals_num, 37 | non_terminal_embedding_dim=non_terminal_embedding_dim, 38 | terminals_num=terminals_num, 39 | terminal_embedding_dim=terminal_embedding_dim, 40 | recurrent_output_size=2 * hidden_dim + layered_hidden_size 41 | ) 42 | 43 | self.hidden_dim = hidden_dim 44 | self.layered_hidden_size = layered_hidden_size 45 | self.dropout = dropout 46 | self.num_tree_layers = num_tree_layers 47 | 48 | self.recurrent_cell = self.module(LSTMCellDropout( 49 | input_size=self.non_terminal_embedding_dim + self.terminal_embedding_dim, 50 | hidden_size=self.hidden_dim, 51 | dropout=self.dropout 52 | )) 53 | self.layered_attention = self.module(Attn(method='general', hidden_size=self.layered_hidden_size)) 54 | 55 | self.last_k_attention = self.module(LastKAttention( 56 | hidden_size=self.hidden_dim, 57 | k=50 58 | )) 59 | 60 | self.layered_recurrent = self.module(LayeredAttentionRecurrent( 61 | input_size=self.non_terminal_embedding_dim + self.terminal_embedding_dim, 62 | num_tree_layers=self.num_tree_layers, 63 | single_hidden_size=self.layered_hidden_size 64 | )) 65 | 66 | def get_recurrent_output(self, combined_input, ast_input: ASTInput, m_hidden, forget_vector): 67 | hidden, layered_hidden = m_hidden 68 | nodes_depth = ast_input.nodes_depth 69 | 70 | # repackage hidden and forgot hidden if program file changed 71 | hidden = repackage_hidden(forget_hidden_partly_lstm_cell(hidden, forget_vector=forget_vector)) 72 | layered_hidden = LayeredRecurrentUpdateAfter.repackage_and_partly_forget_hidden( 73 | layered_hidden=layered_hidden, 74 | forget_vector=forget_vector 75 | ) 76 | self.last_k_attention.repackage_and_forget_buffer_partly(forget_vector) 77 | 78 | # prepare node depths (store only self.num_tree_layers) 79 | nodes_depth = torch.clamp(nodes_depth, min=0, max=self.num_tree_layers - 1) 80 | 81 | recurrent_output = [] 82 | attn_output = [] 83 | layered_output = [] 84 | b_h = None 85 | for i in range(combined_input.size()[0]): 86 | reinit_dropout = i == 0 87 | 88 | # core recurrent part 89 | cur_h, cur_c = self.recurrent_cell(combined_input[i], hidden, reinit_dropout=reinit_dropout) 90 | hidden = (cur_h, cur_c) 91 | b_h = hidden 92 | recurrent_output.append(cur_h) 93 | 94 | # attn part 95 | cur_attn_output = self.last_k_attention(cur_h) 96 | attn_output.append(cur_attn_output) 97 | 98 | # layered part 99 | l_h, l_c = self.layered_recurrent( 100 | combined_input[i], 101 | nodes_depth[i], 102 | layered_hidden=layered_hidden, 103 | reinit_dropout=reinit_dropout 104 | ) 105 | 106 | layered_hidden = LayeredRecurrentUpdateAfter.update_layered_lstm_hidden( 107 | layered_hidden=layered_hidden, 108 | node_depths=nodes_depth[i], 109 | new_value=(l_h, l_c) 110 | ) 111 | 112 | layered_output_coefficients = self.layered_attention(l_h, layered_hidden[0]) 113 | cur_layered_output = calc_attention_combination(layered_output_coefficients, layered_hidden[0]) 114 | layered_output.append(cur_layered_output) # maybe cat? 
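# At this point each timestep has produced three vectors: cur_h from the
# plain LSTM cell, cur_attn_output from attention over the last k hidden
# states, and cur_layered_output from attention over the per-depth layered
# hidden states. They are stacked over time below and concatenated into a
# vector of size 2 * hidden_dim + layered_hidden_size, matching the
# recurrent_output_size passed to ASTNT2NModule in __init__.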
115 | 116 | # combine outputs from different layers 117 | recurrent_output = torch.stack(recurrent_output, dim=0) 118 | attn_output = torch.stack(attn_output, dim=0) 119 | layered_output = torch.stack(layered_output, dim=0) 120 | 121 | assert b_h == hidden 122 | concatenated_output = torch.cat((recurrent_output, attn_output, layered_output), dim=-1) 123 | 124 | return concatenated_output, (hidden, layered_hidden) 125 | 126 | def init_hidden(self, batch_size): 127 | self.last_k_attention.init_hidden(batch_size) 128 | return self.recurrent_cell.init_hidden(batch_size), \ 129 | self.layered_recurrent.init_hidden(batch_size) 130 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/ntn2t_base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/zerogercrnn/experiments/ast_level/ntn2t_base/__init__.py -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/ntn2t_base/main.py: -------------------------------------------------------------------------------- 1 | from zerogercrnn.experiments.ast_level.common import ASTMain, TerminalMetrics, TerminalsCrossEntropyLoss 2 | from zerogercrnn.experiments.ast_level.ntn2t_base.model import NTN2TBaseModel 3 | from zerogercrnn.lib.metrics import MaxPredictionAccuracyMetrics, TopKWrapper, ResultsSaver 4 | 5 | 6 | class NTN2TBaseMain(ASTMain): 7 | 8 | def create_model(self, args): 9 | return NTN2TBaseModel( 10 | non_terminals_num=args.non_terminals_num, 11 | non_terminal_embedding_dim=args.non_terminal_embedding_dim, 12 | terminals_num=args.terminals_num, 13 | terminal_embedding_dim=args.terminal_embedding_dim, 14 | hidden_dim=args.hidden_size, 15 | num_layers=args.num_layers, 16 | dropout=args.dropout 17 | ) 18 | 19 | def create_criterion(self, args): 20 | return TerminalsCrossEntropyLoss() 21 | 22 | def create_train_metrics(self, args): 23 | return TerminalMetrics(base=MaxPredictionAccuracyMetrics()) 24 | 25 | def create_eval_metrics(self, args): 26 | return TerminalMetrics(TopKWrapper(base=ResultsSaver(dir_to_save=args.eval_results_directory))) 27 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/ntn2t_base/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from zerogercrnn.experiments.ast_level.data import ASTInput 4 | from zerogercrnn.lib.core import EmbeddingsModule, RecurrentCore, \ 5 | LinearLayer, CombinedModule 6 | from zerogercrnn.lib.utils import forget_hidden_partly, repackage_hidden 7 | 8 | 9 | class NTN2TBaseModel(CombinedModule): 10 | """Pure NT2N Model (no pretrained embeddings, no attention, no layered hidden)""" 11 | 12 | def __init__( 13 | self, 14 | non_terminals_num, 15 | non_terminal_embedding_dim, 16 | terminals_num, 17 | terminal_embedding_dim, 18 | hidden_dim, 19 | num_layers, 20 | dropout 21 | ): 22 | super().__init__() 23 | 24 | self.non_terminals_num = non_terminals_num 25 | self.non_terminal_embedding_dim = non_terminal_embedding_dim 26 | self.terminals_num = terminals_num 27 | self.terminal_embedding_dim = terminal_embedding_dim 28 | self.hidden_dim = hidden_dim 29 | self.num_layers = num_layers 30 | self.dropout = dropout 31 | 32 | self.nt_embedding = self.module(EmbeddingsModule( 33 | num_embeddings=self.non_terminals_num, 34 | 
embedding_dim=self.non_terminal_embedding_dim, 35 | sparse=True 36 | )) 37 | 38 | self.t_embedding = self.module(EmbeddingsModule( 39 | num_embeddings=self.terminals_num, 40 | embedding_dim=self.terminal_embedding_dim, 41 | sparse=True 42 | )) 43 | 44 | self.recurrent_core = self.module(RecurrentCore( 45 | input_size=2 * self.non_terminal_embedding_dim + self.terminal_embedding_dim, 46 | hidden_size=self.hidden_dim, 47 | num_layers=self.num_layers, 48 | dropout=self.dropout, 49 | model_type='lstm' 50 | )) 51 | 52 | self.h2t = self.module(LinearLayer( 53 | input_size=self.hidden_dim, 54 | output_size=self.terminals_num 55 | )) 56 | 57 | def forward(self, m_input: ASTInput, hidden, forget_vector): 58 | non_terminal_input = m_input.non_terminals 59 | terminal_input = m_input.terminals 60 | current_non_terminal_input = m_input.current_non_terminals 61 | 62 | nt_embedded = self.nt_embedding(non_terminal_input) 63 | t_embedded = self.t_embedding(terminal_input) 64 | cur_nt_embedded = self.nt_embedding(current_non_terminal_input) 65 | 66 | combined_input = torch.cat([nt_embedded, cur_nt_embedded, t_embedded], dim=2) 67 | 68 | hidden = repackage_hidden(hidden) 69 | hidden = forget_hidden_partly(hidden, forget_vector=forget_vector) 70 | recurrent_output, new_hidden = self.recurrent_core(combined_input, hidden) 71 | 72 | prediction = self.h2t(recurrent_output) 73 | 74 | return prediction, new_hidden 75 | 76 | def init_hidden(self, batch_size): 77 | return self.recurrent_core.init_hidden(batch_size) 78 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | from zerogercrnn.experiments.ast_level.main import get_main 6 | from zerogercrnn.lib.argutils import add_general_arguments, add_batching_data_args, add_optimization_args, \ 7 | add_recurrent_core_args, add_non_terminal_args, add_terminal_args 8 | from zerogercrnn.lib.log import logger 9 | from zerogercrnn.lib.log import tqdm_lim 10 | from zerogercrnn.lib.metrics import MaxPredictionAccuracyMetrics 11 | 12 | parser = argparse.ArgumentParser(description='AST level neural network') 13 | add_general_arguments(parser) 14 | add_batching_data_args(parser) 15 | add_optimization_args(parser) 16 | add_recurrent_core_args(parser) 17 | add_non_terminal_args(parser) 18 | add_terminal_args(parser) 19 | parser.add_argument('--terminal_embeddings_file', type=str, help='File with pretrained terminal embeddings') 20 | parser.add_argument('--prediction', type=str, help='One of: nt2n, nt2nt, ntn2t') 21 | 22 | 23 | def print_results(args): 24 | # assert args.prediction == 'nt2n' 25 | 26 | # seed = 1000 27 | # random.seed(seed) 28 | # numpy.random.seed(seed) 29 | 30 | main = get_main(args) 31 | 32 | routine = main.validation_routine 33 | 34 | metrics = MaxPredictionAccuracyMetrics() 35 | metrics.drop_state() 36 | main.model.eval() 37 | 38 | for iter_num, iter_data in enumerate(tqdm_lim(main.data_generator.get_eval_generator(), lim=1000)): 39 | metrics_data = routine.run(iter_num, iter_data) 40 | metrics.report(metrics_data) 41 | metrics.get_current_value(should_print=True) 42 | 43 | 44 | if __name__ == '__main__': 45 | _args = parser.parse_args() 46 | assert _args.title is not None 47 | logger.should_log = _args.log 48 | 49 | print_results(_args) 50 | -------------------------------------------------------------------------------- 
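A note on the metric driven by print_results above: MaxPredictionAccuracyMetrics reduces, in essence, to an argmax comparison between logits and targets. A minimal self-contained sketch of that reduction, with illustrative shapes and a helper name that is not part of the library:

import torch

def max_prediction_accuracy(prediction, target):
    # prediction: [seq_len, batch, classes] logits; target: [seq_len, batch] class ids.
    # Simplified stand-in for what MaxPredictionAccuracyMetrics accumulates per batch.
    _, predicted = torch.max(prediction, dim=-1)
    hits = (predicted == target).sum().item()
    return hits / target.numel()

logits = torch.randn(7, 4, 97)               # illustrative sizes only
labels = torch.randint(0, 97, (7, 4))
print(max_prediction_accuracy(logits, labels))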
/zerogercrnn/experiments/ast_level/utils.py: -------------------------------------------------------------------------------- 1 | from zerogercrnn.lib.constants import EOF_TOKEN, EMPTY_TOKEN, UNKNOWN_TOKEN, EOF_TOKEN_ID, EMPTY_TOKEN_ID, \ 2 | UNKNOWN_TOKEN_ID 3 | from zerogercrnn.lib.preprocess import read_json 4 | 5 | DEFAULT_TERMINALS_FILE = 'data/ast/terminals.json' 6 | DEFAULT_NON_TERMINALS_FILE = 'data/ast/non_terminals.json' 7 | 8 | 9 | def read_terminals(terminals_file=DEFAULT_TERMINALS_FILE): 10 | """Returns all terminals in order that they are coded in file_train. """ 11 | terminals = read_json(terminals_file)[:50000 - 1] 12 | return [EMPTY_TOKEN] + terminals + [UNKNOWN_TOKEN] 13 | 14 | 15 | def read_non_terminals(non_terminals_file=DEFAULT_NON_TERMINALS_FILE): 16 | """Returns all non-terminals in order that they are coded in file_train. """ 17 | non_terminals = read_json(non_terminals_file) 18 | return non_terminals + [EOF_TOKEN] 19 | 20 | 21 | def get_str2id(strings_array): 22 | """Returns map from string to index in array. """ 23 | str2id = {} 24 | for i in range(len(strings_array)): 25 | str2id[strings_array[i]] = i 26 | 27 | return str2id 28 | 29 | 30 | if __name__ == '__main__': 31 | nt = read_non_terminals() 32 | t = read_terminals() 33 | 34 | nt2id = get_str2id(nt) 35 | t2id = get_str2id(t) 36 | 37 | assert EOF_TOKEN_ID == nt2id[EOF_TOKEN] 38 | assert EMPTY_TOKEN_ID == t2id[EMPTY_TOKEN] 39 | assert UNKNOWN_TOKEN_ID == t2id[UNKNOWN_TOKEN] 40 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/vis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/zerogercrnn/experiments/ast_level/vis/__init__.py -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/vis/accuracies.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import json 4 | import os 5 | 6 | from zerogercrnn.lib.preprocess import read_json 7 | from zerogercrnn.experiments.ast_level.data import ASTTarget 8 | from zerogercrnn.experiments.ast_level.metrics import NonTerminalsMetricsWrapper 9 | from zerogercrnn.experiments.ast_level.metrics import SingleNonTerminalAccuracyMetrics, EmptyNonEmptyWrapper, EmptyNonEmptyTerminalTopKAccuracyWrapper 10 | from zerogercrnn.experiments.token_level.metrics import AggregatedTokenMetrics 11 | from zerogercrnn.lib.metrics import SequentialMetrics, BaseAccuracyMetrics, TopKAccuracy 12 | 13 | from zerogercrnn.experiments.pyast.metrics import PythonPerNonTerminalAccuracyMetrics 14 | 15 | # region Utils 16 | 17 | class ResultsReader: 18 | """Class that could read results from lib.metrics.ResultsSaver and then produce matrices for visualization.""" 19 | 20 | def __init__(self, results_dir): 21 | self.results_dir = results_dir 22 | self.predicted = np.load(self.results_dir + '/predicted') 23 | self.target = np.load(self.results_dir + '/target') 24 | 25 | 26 | def run_nt_metrics(reader, metrics): 27 | metrics.drop_state() 28 | metrics.report(( 29 | torch.from_numpy(reader.predicted), 30 | ASTTarget(torch.from_numpy(reader.target), None) 31 | )) 32 | metrics.get_current_value(should_print=True) 33 | 34 | 35 | def run_metrics(reader, metrics): 36 | metrics.drop_state() 37 | metrics.report(( 38 | torch.from_numpy(reader.predicted), 39 | 
torch.from_numpy(reader.target), 40 | )) 41 | metrics.get_current_value(should_print=True) 42 | 43 | 44 | def get_accuracy_result(results_dir): 45 | reader = ResultsReader(results_dir=results_dir) 46 | metrics = NonTerminalsMetricsWrapper(BaseAccuracyMetrics()) 47 | run_nt_metrics(reader, metrics) 48 | 49 | 50 | def get_per_nt_result(results_dir, save_dir, group=False): 51 | reader = ResultsReader(results_dir=results_dir) 52 | metrics = SingleNonTerminalAccuracyMetrics( 53 | non_terminals_file='data/ast/non_terminals.json', 54 | results_dir=save_dir, 55 | group=group, 56 | dim=None 57 | ) 58 | 59 | run_nt_metrics(reader, metrics) 60 | 61 | 62 | # endregion 63 | 64 | 65 | def eval_nt(results_dir, save_dir, group=False): 66 | reader = ResultsReader(results_dir=results_dir) 67 | 68 | metrics = SequentialMetrics([ 69 | NonTerminalsMetricsWrapper(BaseAccuracyMetrics()), 70 | SingleNonTerminalAccuracyMetrics( 71 | non_terminals_file='data/ast/non_terminals.json', 72 | results_dir=save_dir, 73 | group=group, 74 | dim=None 75 | ) 76 | ]) 77 | 78 | # run_nt_metrics(reader, metrics) 79 | 80 | metrics.drop_state() 81 | metrics.report(( 82 | torch.from_numpy(reader.predicted[:, :, 0]), 83 | ASTTarget(torch.from_numpy(reader.target), None) 84 | )) 85 | metrics.get_current_value(should_print=True) 86 | 87 | def eval_t(res_dir, save_dir): 88 | reader = ResultsReader(results_dir=res_dir) 89 | # metrics = EmptyNonEmptyWrapper(AggregatedTerminalMetrics(), AggregatedTerminalMetrics()) 90 | metrics = EmptyNonEmptyTerminalTopKAccuracyWrapper() 91 | run_metrics(reader, metrics) 92 | 93 | 94 | def eval_token(res_dir, save_dir): 95 | reader = ResultsReader(results_dir=res_dir) 96 | metrics = AggregatedTokenMetrics() 97 | run_metrics(reader, metrics) 98 | 99 | 100 | def calc_top_accuracies(results_dir): 101 | reader = ResultsReader(results_dir=results_dir) 102 | metrics = TopKAccuracy(k=5) 103 | 104 | metrics.drop_state() 105 | metrics.report(( 106 | torch.from_numpy(reader.predicted), 107 | torch.from_numpy(reader.target) 108 | )) 109 | metrics.get_current_value(should_print=True) 110 | 111 | 112 | def convert_to_top1(dir_from, dir_to): 113 | reader = ResultsReader(results_dir=dir_from) 114 | np.save(os.path.join(dir_to, 'predicted'), reader.predicted[:, :, 0]) 115 | np.save(os.path.join(dir_to, 'target'), reader.target) 116 | 117 | 118 | def main(task, model): 119 | common_dirs = { 120 | 'base': 'eval_verified/nt2n_base_30k/top5_new', 121 | 'base_large_embeddings': 'eval_verified/nt2n_base_large_embeddings_30k', 122 | 'attention': 'eval_verified/nt2n_base_attention_30k', 123 | 'layered': 'eval_verified/nt2n_base_attention_plus_layered_30k/top5_new', 124 | 'layered_old': 'eval_verified/nt2n_layered_attention', 125 | 'token': 'eval_verified/token_base', 126 | 'terminal': 'eval_verified/ntn2t_base' 127 | } 128 | topk_dirs = { 129 | 'base': 'eval_verified/nt2n_base_30k/top5_new', 130 | 'base_large_embeddings': 'eval_verified/nt2n_base_large_embeddings_30k', 131 | 'attention': 'eval_verified/nt2n_base_attention_30k/top5', 132 | 'layered': 'eval_verified/nt2n_base_attention_plus_layered_30k/top5', 133 | 'layered_old': 'eval_verified/nt2n_layered_attention/top5', 134 | 'token': 'eval_verified/token_base/top5', 135 | 'terminal': 'eval_verified/ntn2t_base/top5' 136 | } 137 | save_dir = 'eval_local' 138 | 139 | if task == 'nt_eval': 140 | res_dir = topk_dirs[model] 141 | eval_nt( 142 | results_dir=res_dir, 143 | save_dir=save_dir, 144 | group=True 145 | ) 146 | elif task == 'token_eval': 147 | res_dir = common_dirs[model] 
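# common_dirs point at top-1 prediction arrays, topk_dirs at top-k arrays;
# convert_to_top1 above derives the former from the latter.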
148 | eval_token(res_dir, save_dir) 149 | elif task == 't_eval': 150 | res_dir = topk_dirs[model] 151 | eval_t(res_dir, save_dir) 152 | elif task == 'topk': 153 | res_dir = topk_dirs[model] 154 | calc_top_accuracies(results_dir=res_dir) 155 | elif task == 'to_top1': 156 | convert_to_top1(topk_dirs[model], common_dirs[model]) 157 | else: 158 | raise Exception('Unknown task type') 159 | 160 | 161 | def calculate_python_per_nt_acc(non_terminals_file, directory): 162 | reader = ResultsReader(results_dir=directory) 163 | metrics = PythonPerNonTerminalAccuracyMetrics( 164 | non_terminals_file=non_terminals_file, 165 | results_dir=directory, 166 | add_unk=True, 167 | dim=None 168 | ) 169 | 170 | metrics.report(( 171 | torch.from_numpy(reader.predicted[:, :, 0]), 172 | ASTTarget(torch.from_numpy(reader.target), None) 173 | )) 174 | 175 | metrics.get_current_value(should_print=True) 176 | 177 | 178 | def show_python_per_nt_accuracies(file, group=False, to_save_file=None): 179 | result = read_json(file) 180 | 181 | hits = {} 182 | misses = {} 183 | 184 | for i in range(len(result)): 185 | nt_type = result[i]['type'] 186 | cur_hits = result[i]['hits'] 187 | cur_misses = result[i]['misses'] 188 | 189 | if group: 190 | if nt_type != 'EOF' and nt_type !='': 191 | nt_type = nt_type[:-2] 192 | if nt_type.startswith('Compare'): 193 | nt_type = 'Compare' 194 | 195 | if nt_type not in hits: 196 | hits[nt_type] = 0 197 | misses[nt_type] = 0 198 | 199 | hits[nt_type] += cur_hits 200 | misses[nt_type] += cur_misses 201 | 202 | to_save = [] 203 | for nt in sorted(hits.keys()): 204 | accuracy = 0 205 | if hits[nt] + misses[nt] != 0: 206 | accuracy = hits[nt] / (hits[nt] + misses[nt]) 207 | 208 | to_save.append({'type': nt, 'accuracy': accuracy}) 209 | print('Accuracy on {} is {}'.format(nt, accuracy)) 210 | 211 | if to_save_file is not None: 212 | f = open(to_save_file, mode='w') 213 | f.write(json.dumps(to_save)) 214 | 215 | if __name__ == '__main__': 216 | # calculate_python_per_nt_acc( 217 | # non_terminals_file='data/pyast_server/non_terminals.json', 218 | # directory='eval_verified/py_nt2n_base' 219 | # ) 220 | 221 | # show_python_per_nt_accuracies( 222 | # file='eval_verified/py_nt2n_base/py_nt_acc.txt', 223 | # group=True, 224 | # to_save_file='eval_verified/py_nt2n_base/nt_acc_grouped.txt' 225 | # ) 226 | 227 | _tasks = ['topk', 'to_top1', 'nt_eval', 't_eval', 'token_eval'] 228 | _models = ['base_large_embeddings'] 229 | main(task='nt_eval', model=_models[0]) 230 | # main(task='topk', model='base') 231 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/vis/ast_info.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | from zerogercrnn.lib.preprocess import write_json, read_jsons, read_json, extract_jsons_info, JsonExtractor 7 | 8 | FILE_TRAINING = 'data/programs_training.json' 9 | FILE_TRAINING_PROCESSED = 'data/ast/file_train.json' 10 | FILE_STAT_TREE_HEIGHTS = 'data/ast/stat_tree_heights.json' 11 | FILE_STAT_PROGRAM_LENGTHS = 'data/ast/stat_program_lengths.json' 12 | 13 | PY_FILE_TRAINING_PROCESSED = 'data/pyast/file_train.json' 14 | 15 | # region Utils 16 | 17 | def draw_plot(x, y, x_label=None, y_label=None): 18 | plt.plot(x, y) 19 | if x_label is not None: 20 | plt.xlabel(x_label) 21 | if y_label is not None: 22 | plt.ylabel(y_label) 23 | plt.show() 24 | 25 | 26 | def get_percentile_plot(stat): 27 | """Get x and y for 
plot describing percentile of stat < x for each x. 28 | 29 | :param stat: array of int values 30 | """ 31 | stat = sorted(stat) 32 | x = [] 33 | y = [] 34 | for i in range(len(stat)): 35 | if (i == len(stat) - 1) or (stat[i] != stat[i + 1]): 36 | x.append(stat[i]) 37 | y.append(float(i + 1) / len(stat)) 38 | 39 | return x, y 40 | 41 | 42 | def plot_percentile_from_file(file, x_label, y_label): 43 | stat = list(read_jsons(file))[0] 44 | x, y = get_percentile_plot(stat) 45 | draw_plot(x, y, x_label=x_label, y_label=y_label) 46 | 47 | 48 | # endregion 49 | 50 | # region TreeHeight 51 | def print_tree_heights_stats(tree_heights): 52 | print('Min height of the tree: {}'.format(min(tree_heights))) 53 | print('Max height of the tree: {}'.format(max(tree_heights))) 54 | print('Average height of the trees: {}'.format(float(sum(tree_heights)) / len(tree_heights))) 55 | 56 | 57 | def print_tree_heights_stats_from_file(tree_heights_file): 58 | print_tree_heights_stats(list(read_jsons(tree_heights_file))[0]) 59 | 60 | 61 | def tree_heights_distribution(tree_heights_file): 62 | plot_percentile_from_file(tree_heights_file, x_label='Tree height', y_label='Percent of data') 63 | 64 | 65 | class JsonTreeHeightExtractor(JsonExtractor): 66 | 67 | def __init__(self): 68 | self.buffer = {} 69 | 70 | def extract(self, raw_json): 71 | return self._calc_height(raw_json) 72 | 73 | def _calc_height(self, raw_json): 74 | to_calc = [] 75 | for node in raw_json: 76 | if node == 0: 77 | break 78 | 79 | if 'children' in node: 80 | to_calc.append(node) 81 | else: 82 | self.buffer[int(node['id'])] = 1 83 | 84 | for node in reversed(to_calc): 85 | id = int(node['id']) 86 | self.buffer[id] = 0 87 | for children in node['children']: 88 | self.buffer[id] = max(self.buffer[id], self.buffer[int(children)] + 1) 89 | 90 | return self.buffer[0] 91 | 92 | 93 | def calc_tree_heights(heights_file): 94 | tree_heights = [] 95 | 96 | extractor = JsonTreeHeightExtractor() 97 | for current_height in extract_jsons_info(extractor, FILE_TRAINING): 98 | tree_heights.append(current_height) 99 | 100 | if heights_file is not None: 101 | write_json(heights_file, tree_heights) 102 | print_tree_heights_stats(tree_heights=tree_heights) 103 | 104 | 105 | # endregion 106 | 107 | # region ProgramLen 108 | 109 | 110 | def plot_program_len_percentiles(lengths_file): 111 | plot_percentile_from_file(lengths_file, x_label='Program lengths', y_label='Percentile') 112 | 113 | 114 | class JsonProgramLenExtractor(JsonExtractor): 115 | 116 | def extract(self, raw_json): 117 | return len(raw_json) - 1 118 | 119 | 120 | def calc_programs_len(lengths_file): 121 | extractor = JsonProgramLenExtractor() 122 | 123 | program_lengths = list(extract_jsons_info(extractor, FILE_TRAINING)) 124 | if lengths_file is not None: 125 | write_json(lengths_file, program_lengths) 126 | 127 | 128 | # endregion 129 | 130 | 131 | class JsonProgramDepthStatExtractor(JsonExtractor): 132 | 133 | def extract(self, raw_json): 134 | depths_prob = np.zeros(50) 135 | for node in raw_json: 136 | depths_prob[min(node['d'], 49)] += 1 137 | 138 | return depths_prob 139 | 140 | 141 | def extract_depths_histogram(file_train): 142 | extractor = JsonProgramDepthStatExtractor() 143 | 144 | depths_prob = np.zeros(50) 145 | for info in extract_jsons_info(extractor, file_train): 146 | depths_prob = depths_prob + info 147 | 148 | res = [x for x in depths_prob] 149 | with open('eval_local/node_depths.json', 'w') as f: 150 | f.write(json.dumps(res)) 151 | 152 | 153 | def draw_histogram(file): 154 | values = 
np.array(read_json(file)) 155 | total = np.sum(values) 156 | values /= total 157 | 158 | plt.plot(values) 159 | # n, bins, patches = plt.hist(values, 100, density=True, facecolor='g', alpha=0.75) 160 | 161 | plt.xlabel('Smarts') 162 | plt.ylabel('Probability') 163 | plt.show() 164 | 165 | 166 | class EasyNonTerminalsExtractor(JsonExtractor): 167 | 168 | def __init__(self): 169 | super().__init__() 170 | self.parents_table = {} 171 | 172 | def extract(self, raw_json): 173 | for i in range(len(raw_json) - 1): 174 | node = raw_json[i] 175 | if 'children' in node: 176 | parent_type = node['type'] 177 | children = node['children'] 178 | for position in range(len(children)): 179 | child_type = raw_json[int(children[position])]['type'] 180 | 181 | if child_type not in self.parents_table: 182 | self.parents_table[child_type] = [] 183 | 184 | self.parents_table[child_type].append(parent_type + '_' + str(position)) 185 | 186 | 187 | def get_easy_non_terminals(file, lim=None): 188 | extractor = EasyNonTerminalsExtractor() 189 | for info in extract_jsons_info(extractor, file, lim=lim): 190 | print(info) 191 | 192 | 193 | class NonTerminalsStatExtractor(JsonExtractor): 194 | def __init__(self): 195 | super().__init__() 196 | self.stat = {} 197 | 198 | def extract(self, raw_json): 199 | for i in range(len(raw_json) - 1): 200 | t = raw_json[i]['type'] 201 | if t not in self.stat: 202 | self.stat[t] = 0 203 | 204 | self.stat[t] += 1 205 | 206 | return True 207 | 208 | 209 | def visualize_nt_stat(file): 210 | stat = read_json(file) 211 | labels = [] 212 | values = [] 213 | sum = 0 214 | for k in sorted(stat.keys()): 215 | labels.append(k) 216 | values.append(stat[k]) 217 | sum += stat[k] 218 | 219 | x = np.arange(len(values)) 220 | y = np.array(values) / sum * 100 221 | 222 | plt.xticks(x, labels, rotation=30, horizontalalignment='right', fontsize=5) 223 | plt.grid(True) 224 | 225 | plt.plot(x, y) 226 | plt.show() 227 | 228 | 229 | class UNKNTExtractor(JsonExtractor): 230 | 231 | def extract(self, raw_json): 232 | unk_count = 0 233 | for node in raw_json: 234 | if node['N'] == 321: 235 | unk_count += 1 236 | return len(raw_json), unk_count 237 | 238 | 239 | def get_unk_nt_percentage(file_eval): # unk percentage: 8.86e-7 240 | extractor = UNKNTExtractor() 241 | total_count = 0 242 | unk_count = 0 243 | for (c_t, c_u) in extract_jsons_info(extractor, file_eval): 244 | total_count += c_t 245 | unk_count += c_u 246 | 247 | print(float(unk_count) / total_count) 248 | 249 | 250 | def get_non_terminals_statistic(file, lim=None): 251 | extractor = NonTerminalsStatExtractor() 252 | list(extract_jsons_info(extractor, file, lim=lim)) 253 | 254 | with open('data/ast/stat_nt_occurrences.json', mode='w') as f: 255 | f.write(json.dumps(extractor.stat)) 256 | 257 | 258 | def run_main(): 259 | # get_unk_nt_percentage('data/pyast/file_eval.json') 260 | # extract_depths_histogram(PY_FILE_TRAINING_PROCESSED) 261 | # extract_depths_histogram(FILE_TRAINING_PROCESSED) 262 | draw_histogram('eval_local/node_depths.json') 263 | 264 | # get_easy_non_terminals(file='data/programs_eval_10000.json', lim=100) 265 | # get_non_terminals_statistic(file='data/programs_eval_10000.json', lim=10000) 266 | # visualize_nt_stat(file='data/ast/stat_nt_occurrences.json') 267 | 268 | # calc_programs_len(FILE_STAT_PROGRAM_LENGTHS) 269 | # plot_program_len_percentiles(FILE_STAT_PROGRAM_LENGTHS) 270 | 271 | # calc_tree_heights(heights_file=FILE_STAT_TREE_HEIGHTS) 272 | # print_tree_heights_stats_from_file(tree_heights_file=FILE_STAT_TREE_HEIGHTS) 273 | # 
tree_heights_distribution(tree_heights_file=FILE_STAT_TREE_HEIGHTS) 274 | 275 | 276 | if __name__ == '__main__': 277 | run_main() 278 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/vis/compare.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | from zerogercrnn.lib.constants import EOF_TOKEN 5 | from zerogercrnn.lib.preprocess import read_json, read_jsons 6 | 7 | 8 | def compare_per_nt(file1, file2, y_label): 9 | nt1, res1 = list(read_jsons(file1)) 10 | nt2, res2 = list(read_jsons(file2)) 11 | assert nt1 == nt2 12 | 13 | x = np.arange(len(nt1)) 14 | y1 = np.array(res1) 15 | y2 = np.array(res2) 16 | 17 | my_xticks = nt1 18 | plt.xticks(x, my_xticks, rotation=30, horizontalalignment='right', fontsize=5) 19 | plt.ylabel(y_label) 20 | plt.grid(True) 21 | 22 | plt.plot(x, (y2 - y1) * 100) 23 | plt.show() 24 | 25 | # print('Diff as second - first:') 26 | # for i in range(len(nt1)): 27 | # print('{} : {}'.format(nt1[i], res2[i] - res1[i])) 28 | 29 | 30 | def compare_per_two_plots(file1, file2, y_label): 31 | nt1, res1 = list(read_jsons(file1)) 32 | nt2, res2 = list(read_jsons(file2)) 33 | assert nt1 == nt2 34 | 35 | x = np.arange(len(nt1)) 36 | y1 = np.array(res1) 37 | y2 = np.array(res2) 38 | 39 | my_xticks = nt1 40 | plt.xticks(x, my_xticks, rotation=30, horizontalalignment='right', fontsize=5) 41 | plt.ylabel(y_label) 42 | plt.grid(True) 43 | 44 | plt.plot(x, y1, 'r', x, y2, 'g') 45 | plt.show() 46 | 47 | # print('Diff as second - first:') 48 | # for i in range(len(nt1)): 49 | # print('{} : {}'.format(nt1[i], res2[i] - res1[i])) 50 | 51 | 52 | def run_main(): 53 | values_base = np.array(read_json('eval/nt2n_base/nt_acc.json')) 54 | values_layered = np.array(read_json('eval/nt2n_layered_attention/nt_acc.json')) 55 | non_terminals = read_json('data/ast/non_terminals.json') 56 | non_terminals.append(EOF_TOKEN) 57 | 58 | diff = values_layered - values_base 59 | for i in range(len(non_terminals)): 60 | print('{}: {}'.format(non_terminals[i], diff[i])) 61 | 62 | 63 | if __name__ == '__main__': 64 | # run_main() 65 | compare_per_nt( 66 | file1='eval_local/nt2n_base/nt_acc_no_group.txt', 67 | file2='eval_local/nt2n_layered_attention/nt_acc_no_group.txt', 68 | y_label='Gain of layered attention comparing to base model' 69 | ) 70 | # compare_per_two_plots( 71 | # file1='eval_local/nt2n_base/nt_acc_no_group.txt', 72 | # file2='eval_local/nt2n_layered_attention/nt_acc_no_group.txt', 73 | # y_label='Accuracies: Red - base model, Green - layered model' 74 | # ) 75 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/vis/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | 6 | from zerogercrnn.experiments.ast_level.vis.utils import visualize_attention_matrix 7 | 8 | """ 9 | Tools for model visualization. 
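All helpers read numpy arrays saved to disk (e.g. under eval/temp) during evaluation runs.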
10 | """ 11 | 12 | 13 | def visualize_attention(file): 14 | """Visualize attention matrix stored as numpy array""" 15 | tensor = np.load(file) 16 | visualize_attention_matrix(tensor) 17 | 18 | 19 | def visualize_output_combination(file_before, file_after): 20 | tensor_before = np.load(file_before) 21 | print(np.sum(tensor_before[:1500]) / 1500) 22 | print(np.sum(tensor_before[-500:]) / 500) 23 | plt.plot(tensor_before, 'r') 24 | 25 | tensor_after = np.load(file_after) 26 | print(np.sum(tensor_after[:1500]) / 1500) 27 | print(np.sum(tensor_after[-500:]) / 500) 28 | plt.plot(tensor_after, 'g') 29 | 30 | plt.show() 31 | 32 | 33 | def visualize_line(file): 34 | line = np.load(file) 35 | plt.plot(line) 36 | plt.show() 37 | 38 | 39 | def visualize_running_mean_and_variance(mean_file, variance_file): 40 | mean = np.load(mean_file) 41 | variance = np.load(variance_file) 42 | 43 | plt.plot(variance) 44 | plt.show() 45 | 46 | 47 | def draw_1d_plot_from_file(*files): 48 | legend = [] 49 | for f in files: 50 | cur, = plt.plot(np.load(f), label=f) 51 | legend.append(cur) 52 | 53 | plt.legend(handles=legend) 54 | plt.show() 55 | 56 | 57 | def draw_mean_deviation_variance(directory='eval/temp'): 58 | draw_1d_plot_from_file( 59 | os.path.join(directory, 'mean.npy'), 60 | os.path.join(directory, 'deviation.npy'), 61 | os.path.join(directory, 'variance.npy') 62 | ) 63 | 64 | 65 | def draw_mean_variance(directory='eval/temp'): 66 | c1, = plt.plot(np.load(os.path.join(directory, 'mean.npy')), label=os.path.join(directory, 'mean.npy')) 67 | c2, = plt.plot(np.sqrt(np.load(os.path.join(directory, 'variance.npy'))), label=os.path.join(directory, 'std.npy')) 68 | plt.legend(handles=[c1,c2]) 69 | plt.show() 70 | 71 | 72 | if __name__ == '__main__': 73 | draw_mean_variance(directory='eval/temp') 74 | # draw_mean_variance(directory='eval/temp/nt2n_base_attention_norm_after') 75 | # visualize_attention(file='eval/temp/attention/per_depth_matrix.npy') 76 | # draw_mean_variance(directory='eval/temp/before_input') 77 | # draw_mean_variance(directory='eval/temp/after_input') 78 | # 79 | # draw_mean_variance(directory='eval/temp/before_output') 80 | # draw_mean_variance(directory='eval/temp/after_output') 81 | # draw_1d_plot_from_file('eval/temp/deviation.npy') 82 | # draw_1d_plot_from_file('eval/temp/variance.npy') 83 | 84 | # visualize_line('eval/temp/layered_input_matrix.npy') 85 | # visualize_attention('eval_local/attention/per_depth_matrix.npy') 86 | # visualize_output_combination( 87 | # file_before='eval/temp/new_output_sum_before_matrix.npy', 88 | # file_after='eval/temp/new_output_sum_after_matrix.npy' 89 | # ) 90 | # visualize_output_combination( 91 | # file_before='eval/temp/test_before.npy', 92 | # file_after='eval/temp/test_after.npy' 93 | # ) 94 | # visualize_running_mean_and_variance( 95 | # mean_file='eval/temp/running_mean.npy', 96 | # variance_file='eval/temp/running_var.npy' 97 | # ) 98 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/vis/plots.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import json 4 | 5 | from zerogercrnn.lib.constants import EOF_TOKEN, EMPTY_TOKEN 6 | from zerogercrnn.lib.preprocess import read_json, read_jsons 7 | from matplotlib.lines import Line2D 8 | 9 | COLOR_BASE = '#607D8B' 10 | COLOR_RED = '#f44336' 11 | COLOR_GREEN = '#4CAF50' 12 | 13 | 14 | class Plot: 15 | def __init__(self, data, label=None): 16 | 
self.data = data 17 | self.label = label 18 | 19 | 20 | def add_nt_x_ticks(nt): 21 | x = np.arange(len(nt)) 22 | plt.gcf().subplots_adjust(bottom=0.2) 23 | res = plt.xticks(x, nt, rotation=30, horizontalalignment='right', fontsize=12) 24 | for id, cur in enumerate(nt): 25 | if cur == 'FunctionDeclaration' or cur == 'DoWhileStatement' or cur == 'CatchClause': 26 | res[1][id].set_weight('bold') 27 | 28 | 29 | def draw_per_nt_plot_inner(nt, *plots, y_label=None): 30 | add_nt_x_ticks(nt) 31 | plt.ylabel(y_label) 32 | plt.grid(True) 33 | 34 | legend = [] 35 | for p in plots: 36 | cur_legend, = plt.plot(p.data, label=p.label) 37 | legend.append(cur_legend) 38 | 39 | plt.grid(True) 40 | plt.legend(handles=legend) 41 | plt.show() 42 | 43 | 44 | def draw_per_nt_plot(file, y_label='Per NT accuracy'): 45 | nt, data = list(read_jsons(file)) 46 | draw_per_nt_plot_inner(nt, Plot(data=data), y_label=y_label) 47 | 48 | 49 | def draw_per_nt_bar_chart(nt, *plots, y_label='Per NT accuracy'): 50 | ind = np.arange(len(nt)) 51 | legend_rects = [] 52 | legend_labels = [] 53 | width = 0.4 / len(plots) 54 | current_shift = -0.2 55 | for p in plots: 56 | cur = plt.bar(ind + current_shift + width / 2, p.data, width=width) 57 | legend_rects.append(cur[0]) 58 | legend_labels.append(p.label) 59 | current_shift += width 60 | 61 | plt.legend(legend_rects, legend_labels) 62 | add_nt_x_ticks(nt) 63 | plt.show() 64 | 65 | 66 | def bar_chart(): 67 | import numpy as np 68 | import matplotlib.pyplot as plt 69 | 70 | N = 5 71 | menMeans = (20, 35, 30, 35, 27) 72 | womenMeans = (25, 32, 34, 20, 25) 73 | menStd = (2, 3, 4, 1, 2) 74 | womenStd = (3, 5, 2, 3, 3) 75 | ind = np.arange(N) # the x locations for the groups 76 | width = 0.35 # the width of the bars: can also be len(x) sequence 77 | 78 | p1 = plt.bar(ind, menMeans, width, yerr=menStd) 79 | p2 = plt.bar(ind, womenMeans, width, 80 | bottom=menMeans, yerr=womenStd) 81 | 82 | plt.ylabel('Scores') 83 | plt.title('Scores by group and gender') 84 | plt.xticks(ind, ('G1', 'G2', 'G3', 'G4', 'G5')) 85 | plt.yticks(np.arange(0, 81, 10)) 86 | plt.legend((p1[0], p2[0]), ('Men', 'Women')) 87 | 88 | plt.show() 89 | 90 | 91 | def compare_per_nt(file1, file2, y_label='New'): 92 | nt1, data1 = list(read_jsons(file1)) 93 | nt2, data2 = list(read_jsons(file2)) 94 | assert nt1 == nt2 95 | nt = nt1 96 | data1 = np.array(data1) 97 | data2 = np.array(data2) 98 | diff = data2 - data1 99 | 100 | ind = np.arange(len(nt)) 101 | p1 = plt.bar(ind, data1) 102 | for bar in p1: 103 | bar.set_facecolor(COLOR_BASE) 104 | 105 | p2 = plt.bar(ind, diff, bottom=data1) 106 | 107 | for id, bar in enumerate(p2): 108 | if diff[id] >= 0: 109 | bar.set_facecolor(COLOR_GREEN) 110 | else: 111 | bar.set_facecolor(COLOR_RED) 112 | 113 | custom_lines = [Line2D([0], [0], color=COLOR_BASE, lw=4), 114 | Line2D([0], [0], color=COLOR_GREEN, lw=4), 115 | Line2D([0], [0], color=COLOR_RED, lw=4)] 116 | plt.legend(custom_lines, ['База', 'Улучшение', 'Ухудшение']) 117 | 118 | # plt.legend((p1[0], p2[0]), (file1, file2)) 119 | add_nt_x_ticks(nt) 120 | 121 | plt.ylabel(y_label, fontsize=12) 122 | plt.show() 123 | 124 | 125 | def compare_per_nt_diff_only(file1, file2, y_label='New'): 126 | nt1, data1 = list(read_jsons(file1)) 127 | nt2, data2 = list(read_jsons(file2)) 128 | assert nt1 == nt2 129 | 130 | nt = read_json('data/ast/non_terminals_plot_modified_attention.json') 131 | assert len(nt1) == len(nt) 132 | 133 | data1 = np.array(data1) 134 | data2 = np.array(data2) 135 | diff = data2 - data1 136 | 137 | ind = np.arange(len(nt)) 
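# The bars below show (file2 - file1) in percentage points: green marks an
# improvement, red a regression, with non-terminal names on the x axis.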
138 | p1 = plt.bar(ind, (diff) * 100, width=1) 139 | 140 | for id, bar in enumerate(p1): 141 | if diff[id] >= 0: 142 | bar.set_facecolor(COLOR_GREEN) 143 | else: 144 | bar.set_facecolor(COLOR_RED) 145 | 146 | custom_lines = [ 147 | Line2D([0], [0], color=COLOR_GREEN, lw=4), 148 | Line2D([0], [0], color=COLOR_RED, lw=4) 149 | ] 150 | plt.legend(custom_lines, ['Улучшение', 'Ухудшение'], prop={'size': 16}) 151 | plt.grid(True) 152 | 153 | add_nt_x_ticks(nt) 154 | 155 | plt.ylabel(y_label, fontsize=14) 156 | plt.show() 157 | 158 | 159 | def main(): 160 | res_file_base = 'eval_verified/nt2n_base_30k/nt_acc_grouped.txt' 161 | res_file_base_attention = 'eval_verified/nt2n_base_attention_30k/nt_acc_grouped.txt' 162 | res_file_layered = 'eval_verified/nt2n_base_attention_plus_layered_30k/nt_acc_grouped.txt' 163 | 164 | res_file_base_old = 'eval_verified/nt2n_base/nt_acc.txt' 165 | res_file_layered_attention_old = 'eval_verified/nt2n_layered_attention/nt_acc.txt' 166 | # draw_per_nt_plot(res_file_layered_attention_old) 167 | compare_per_nt_diff_only(res_file_base, res_file_layered, y_label='Разница в процентных пунктах') 168 | 169 | 170 | def create_non_terminals_plot(): 171 | nt = read_json('data/ast/non_terminals.json') 172 | with open('data/ast/non_terminals_plot.json', mode='w') as f: 173 | nt_grouped = list(set([c[:-2] for c in nt])) 174 | f.write(json.dumps(sorted(nt_grouped + [EOF_TOKEN]))) 175 | 176 | 177 | if __name__ == '__main__': 178 | main() 179 | # bar_chart() 180 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/vis/post_accuracy.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from zerogercrnn.lib.constants import EOF_TOKEN 5 | from zerogercrnn.lib.preprocess import read_json 6 | 7 | JS_NON_TERMINALS = 'data/ast/non_terminals.json' 8 | 9 | 10 | # region Utils 11 | 12 | def save_pretty_json(json_file): 13 | jsn = read_json(json_file) 14 | with open(os.path.splitext(json_file)[0] + '_pretty.json', mode='w') as f: 15 | f.write(json.dumps(jsn, indent=4, sort_keys=True)) 16 | 17 | 18 | def accuracy(hits, misses): 19 | if hits + misses == 0: 20 | return 0 21 | return float(hits) / (hits + misses) 22 | 23 | 24 | # endregion 25 | 26 | 27 | class NtMapUtils: 28 | @staticmethod 29 | def per_nt_accuracies(mp): 30 | res = {} 31 | for key, value in mp.items(): 32 | res[key] = accuracy(hits=value['hits'], misses=value['misses']) 33 | return res 34 | 35 | @staticmethod 36 | def total_accuracy(mp): 37 | hits = 0 38 | misses = 0 39 | for key, value in mp.items(): 40 | hits += value['hits'] 41 | misses += value['misses'] 42 | return accuracy(hits=hits, misses=misses) 43 | 44 | @staticmethod 45 | def grouped_per_nt_accuracies(mp): 46 | hits = {} 47 | misses = {} 48 | for key, value in mp.items(): 49 | if key != EOF_TOKEN: 50 | key = key[:-2] 51 | 52 | hits[key] = hits.get(key, 0) + value['hits'] 53 | misses[key] = misses.get(key, 0) + value['misses'] 54 | 55 | res = {} 56 | for k in hits.keys(): 57 | res[k] = accuracy(hits=hits[k], misses=misses[k]) 58 | return res 59 | 60 | 61 | def draw_top_1(file): 62 | res = read_json(os.path.join(file, 'topk.json')) 63 | print(NtMapUtils.total_accuracy(res['top0'])) 64 | print(json.dumps( 65 | NtMapUtils.grouped_per_nt_accuracies(res['top0']), 66 | indent=4, 67 | sort_keys=True 68 | )) 69 | 70 | 71 | def main(): 72 | res_dir_base_large_embeddings = 'eval_verified/nt2n_base_large_embeddings_30k' 73 | 
draw_top_1(res_dir_base_large_embeddings) 74 | 75 | 76 | if __name__ == '__main__': 77 | save_pretty_json('eval_verified/nt2n_base_large_embeddings_30k/topk.json') 78 | # main() 79 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/vis/pre_accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import json 4 | import os 5 | from zerogercrnn.lib.accuracies import indexed_topk_hits, topk_hits 6 | from zerogercrnn.experiments.ast_level.data import ASTTarget 7 | from zerogercrnn.lib.metrics import Metrics 8 | from zerogercrnn.lib.preprocess import read_json 9 | from zerogercrnn.lib.constants import EOF_TOKEN 10 | 11 | JS_NON_TERMINALS = 'data/ast/non_terminals.json' 12 | 13 | # region Utils 14 | 15 | class ResultsReader: 16 | """Class that could read results from lib.metrics.ResultsSaver and then produce matrices for visualization.""" 17 | 18 | def __init__(self, results_dir): 19 | self.results_dir = results_dir 20 | self.predicted = torch.from_numpy(np.load(self.results_dir + '/predicted')) 21 | self.target = torch.from_numpy(np.load(self.results_dir + '/target')) 22 | 23 | def get_nt_predicted_target(self): 24 | return self.predicted, ASTTarget(self.target, None) 25 | 26 | 27 | def run_nt_metrics(reader: ResultsReader, metrics: Metrics): 28 | metrics.drop_state() 29 | metrics.report(reader.get_nt_predicted_target()) 30 | metrics.get_current_value(should_print=True) 31 | 32 | 33 | # endregion Utils 34 | 35 | def get_per_nt_hits_and_misses(nt_id, predicted, target): 36 | index = (target == nt_id).nonzero().squeeze() 37 | return indexed_topk_hits(predicted, target, index) 38 | 39 | 40 | def transform_to_per_nt_topk(reader, non_terminals): 41 | assert torch.max(reader.target) <= len(non_terminals) - 1 42 | 43 | res = {} 44 | 45 | predicted = reader.predicted.view(-1, reader.predicted.size()[-1]) 46 | target = reader.target.view(-1) 47 | 48 | for id, nt in enumerate(non_terminals): 49 | topk_hits, total = get_per_nt_hits_and_misses(id, predicted, target) 50 | 51 | for i in range(topk_hits.size(0)): 52 | key = 'top' + str(i) 53 | if key not in res: 54 | res[key] = {} 55 | 56 | if nt not in res[key]: 57 | res[key][nt] = {} 58 | res[key][nt]['hits'] = 0 59 | res[key][nt]['misses'] = 0 60 | 61 | res[key][nt]['hits'] += topk_hits[i].item() 62 | res[key][nt]['misses'] += total - topk_hits[i].item() 63 | 64 | return res 65 | 66 | 67 | def task_transform_to_per_nt_topk(results_dir, non_terminals_file, res_dir): 68 | reader = ResultsReader(results_dir=results_dir) 69 | non_terminals = read_json(non_terminals_file) 70 | non_terminals.append(EOF_TOKEN) 71 | 72 | res = transform_to_per_nt_topk(reader, non_terminals) 73 | with open(os.path.join(res_dir, 'topk.json'), mode='w') as f: 74 | f.write(json.dumps(res)) 75 | 76 | 77 | def main(): 78 | task_transform_to_per_nt_topk( 79 | results_dir='eval_verified/nt2n_base_large_embeddings_30k', 80 | non_terminals_file=JS_NON_TERMINALS, 81 | res_dir='eval_local' 82 | ) 83 | # tasks = ['per_nt_top_k'] 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/ast_level/vis/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from matplotlib import pyplot as plt 4 | 5 | 6 | def visualize_attention_matrix(matrix): 7 | 
plt.matshow(matrix) 8 | plt.colorbar() 9 | plt.show() 10 | 11 | 12 | def draw_line_plot(line): 13 | plt.plot(line) 14 | plt.show() 15 | 16 | 17 | def visualize_tensor(tensor_to_visualize): 18 | """Draws a heatmap of tensor.""" 19 | tensor_to_visualize = tensor_to_visualize.detach().numpy() 20 | X = np.arange(0, tensor_to_visualize.shape[0]) 21 | Y = np.arange(0, tensor_to_visualize.shape[1]) 22 | X, Y = np.meshgrid(X, Y, indexing='ij') 23 | 24 | plt.figure() 25 | plt.pcolor(X, Y, tensor_to_visualize) 26 | plt.colorbar() 27 | plt.show() -------------------------------------------------------------------------------- /zerogercrnn/experiments/common.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.optim.lr_scheduler import MultiStepLR 8 | 9 | from zerogercrnn.lib.core import BaseModule 10 | from zerogercrnn.lib.data import BatchedDataGenerator 11 | from zerogercrnn.lib.file import load_if_saved, load_cuda_on_cpu 12 | from zerogercrnn.lib.metrics import Metrics 13 | from zerogercrnn.lib.run import TrainEpochRunner, NetworkRoutine 14 | from zerogercrnn.lib.utils import filter_requires_grad, get_best_device 15 | 16 | 17 | # region CreateUtils 18 | 19 | def get_optimizer(args, model): 20 | return optim.Adam( 21 | params=filter_requires_grad(model.parameters()), 22 | lr=args.learning_rate, 23 | weight_decay=args.weight_decay 24 | ) 25 | 26 | 27 | def get_sparse_optimizer(args, model): 28 | return optim.SparseAdam( 29 | params=filter_requires_grad(model.sparse_parameters()), 30 | lr=args.learning_rate 31 | ) 32 | 33 | 34 | def get_optimizers(args, model): 35 | optimizers = [] 36 | if len(list(filter_requires_grad(model.parameters()))) != 0: 37 | optimizers.append(get_optimizer(args, model)) 38 | if len(list(filter_requires_grad(model.sparse_parameters()))) != 0: 39 | optimizers.append(get_sparse_optimizer(args, model)) 40 | 41 | if len(optimizers) == 0: 42 | raise Exception('Model has no parameters!') 43 | 44 | return optimizers 45 | 46 | 47 | def get_scheduler(args, optimizer): 48 | return MultiStepLR( 49 | optimizer=optimizer, 50 | milestones=list(range(args.decay_after_epoch, args.epochs + 20)), 51 | gamma=args.decay_multiplier 52 | ) 53 | 54 | 55 | # endregion 56 | 57 | 58 | class Main: 59 | def __init__(self, args): 60 | self.model = self.create_model(args).to(get_best_device()) 61 | self.load_model(args) 62 | 63 | self.optimizers = self.create_optimizers(args) 64 | self.schedulers = self.create_schedulers(args) 65 | self.criterion = self.create_criterion(args) 66 | 67 | self.data_generator = self.create_data_generator(args) 68 | 69 | self.train_routine = self.create_train_routine(args) 70 | self.validation_routine = self.create_validation_routine(args) 71 | self.train_metrics = self.create_train_metrics(args) 72 | self.eval_metrics = self.create_eval_metrics(args) 73 | self.plotter = 'tensorboard' 74 | 75 | @abstractmethod 76 | def create_data_generator(self, args) -> BatchedDataGenerator: 77 | pass 78 | 79 | @abstractmethod 80 | def create_model(self, args) -> BaseModule: 81 | pass 82 | 83 | @abstractmethod 84 | def create_criterion(self, args) -> nn.Module: 85 | pass 86 | 87 | @abstractmethod 88 | def create_train_routine(self, args) -> NetworkRoutine: 89 | pass 90 | 91 | @abstractmethod 92 | def create_validation_routine(self, args) -> NetworkRoutine: 93 | pass 94 | 95 | @abstractmethod 96 | def 
create_train_metrics(self, args) -> Metrics: 97 | pass 98 | 99 | @abstractmethod 100 | def create_eval_metrics(self, args) -> Metrics: 101 | return self.create_train_metrics(args) # Good enough if you don't want to eval now 102 | 103 | def train(self, args): 104 | runner = TrainEpochRunner( 105 | network=self.model, 106 | train_routine=self.train_routine, 107 | validation_routine=self.validation_routine, 108 | metrics=self.train_metrics, 109 | data_generator=self.data_generator, 110 | schedulers=self.schedulers, 111 | plotter=self.plotter, 112 | save_dir=args.model_save_dir, 113 | title=args.title, 114 | report_train_every=10, 115 | plot_train_every=50, 116 | save_model_every=args.save_model_every 117 | ) 118 | 119 | runner.run(number_of_epochs=args.epochs) 120 | 121 | def eval(self, args): 122 | print('Evaluation started!') 123 | if not os.path.exists(args.eval_results_directory): 124 | os.makedirs(args.eval_results_directory) 125 | 126 | self.model.eval() 127 | self.eval_metrics.eval() 128 | self.eval_metrics.drop_state() 129 | it = 0 130 | hook_metrics = self.register_eval_hooks() 131 | 132 | with torch.no_grad(): 133 | for iter_data in self.data_generator.get_eval_generator(): 134 | metrics_values = self.validation_routine.run( 135 | iter_num=it, 136 | iter_data=iter_data 137 | ) 138 | self.eval_metrics.report(metrics_values) 139 | it += 1 140 | 141 | if it % 1000 == 0: 142 | self.eval_metrics.get_current_value(should_print=True) 143 | 144 | self.eval_metrics.decrease_hits(self.data_generator.data_reader.eval_tails) 145 | self.eval_metrics.get_current_value(should_print=True) 146 | 147 | for m in hook_metrics: 148 | m.get_current_value(should_print=True) 149 | 150 | def register_eval_hooks(self): 151 | return [] 152 | 153 | def create_optimizers(self, args): 154 | return get_optimizers(args, self.model) 155 | 156 | def create_schedulers(self, args): 157 | return [get_scheduler(args, opt) for opt in self.optimizers] 158 | 159 | def load_model(self, args): 160 | if args.saved_model is not None: 161 | if torch.cuda.is_available(): 162 | load_if_saved(self.model, args.saved_model) 163 | else: 164 | load_cuda_on_cpu(self.model, args.saved_model) 165 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/pyast/metrics.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | 6 | from zerogercrnn.experiments.ast_level.utils import read_non_terminals 7 | from zerogercrnn.lib.metrics import Metrics, IndexedAccuracyMetrics 8 | 9 | 10 | class PythonPerNonTerminalAccuracyMetrics(Metrics): 11 | """Metrics that show accuracies per non-terminal. 
It should not be used for plotting, but to 12 | print results to the console during model evaluation.""" 13 | 14 | def __init__(self, non_terminals_file, results_dir=None, add_unk=False, dim=2): 15 | """ 16 | 17 | :param non_terminals_file: file with json of non-terminals 18 | :param results_dir: where to save json with accuracies per non-terminal 19 | :param dim: dimension to run max function on for predicted values 20 | """ 21 | super().__init__() 22 | print('PythonPerNonTerminalAccuracyMetrics created!') 23 | 24 | self.non_terminals = read_non_terminals(non_terminals_file) 25 | if add_unk: 26 | self.non_terminals.append('') 27 | 28 | self.non_terminals_number = len(self.non_terminals) 29 | self.results_dir = results_dir 30 | self.dim = dim 31 | 32 | self.accuracies = [IndexedAccuracyMetrics(label='ERROR') for _ in self.non_terminals] 33 | 34 | def drop_state(self): 35 | for accuracy in self.accuracies: 36 | accuracy.drop_state() 37 | 38 | def report(self, data): 39 | prediction, target = data 40 | if self.dim is None: 41 | predicted = prediction 42 | else: 43 | _, predicted = torch.max(prediction, dim=self.dim) 44 | predicted = predicted.view(-1) 45 | target = target.non_terminals.view(-1) 46 | 47 | for cur in range(len(self.non_terminals)): 48 | indices = (target == cur).nonzero().squeeze() 49 | self.accuracies[cur].report(predicted, target, indices) 50 | 51 | def get_current_value(self, should_print=False): 52 | result = [] 53 | for cur in range(len(self.non_terminals)): 54 | cur_hits = self.accuracies[cur].metrics.hits 55 | cur_misses = self.accuracies[cur].metrics.misses 56 | result.append({ 57 | 'type': self.non_terminals[cur], 58 | 'hits': cur_hits, 59 | 'misses': cur_misses 60 | }) 61 | 62 | if should_print: 63 | accuracy = 0 64 | if cur_hits + cur_misses != 0: 65 | accuracy = cur_hits / (cur_hits + cur_misses) 66 | 67 | print('Accuracy on {} is {}'.format(self.non_terminals[cur], accuracy)) 68 | 69 | self.save_to_file(result) 70 | 71 | return 0 # this metric is only for printing 72 | 73 | def save_to_file(self, result): 74 | if self.results_dir is not None: 75 | with open(os.path.join(self.results_dir, 'py_nt_acc.txt'), mode='w') as f: 76 | f.write(json.dumps(result)) 77 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/temp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/zerogercrnn/experiments/temp/__init__.py -------------------------------------------------------------------------------- /zerogercrnn/experiments/temp/mnist_norm_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.optim as optim 4 | import torchvision.datasets as datasets 5 | import torchvision.transforms as transforms 6 | from torch import nn as nn 7 | 8 | from zerogercrnn.lib.core import CombinedModule, LinearLayer, NormalizationLayer 9 | from zerogercrnn.lib.log import tqdm_lim 10 | from zerogercrnn.lib.metrics import TensorVisualizer2DMetrics 11 | 12 | 13 | class MNISTClassifier(CombinedModule): 14 | 15 | def __init__(self, num_inputs, action_space, hidden_size1=256, hidden_size2=128): 16 | super(MNISTClassifier, self).__init__() 17 | self.action_space = action_space 18 | num_outputs = action_space 19 | 20 | self.linear1 = self.module(LinearLayer(num_inputs, hidden_size1)) 21 | self.linear2 = 
self.module(LinearLayer(hidden_size1, hidden_size2)) 22 | self.linear3 = self.module(LinearLayer(hidden_size2, num_outputs)) 23 | self.bn1 = self.module(NormalizationLayer(hidden_size1)) 24 | self.bn2 = self.module(NormalizationLayer(hidden_size2)) 25 | 26 | def forward(self, inputs): 27 | x = inputs 28 | x = self.bn1(F.relu(self.linear1(x))) 29 | x = self.bn2(F.relu(self.linear2(x))) 30 | out = self.linear3(x) 31 | return out 32 | 33 | 34 | def get_data_loader(batch_size): 35 | return torch.utils.data.DataLoader( 36 | datasets.MNIST( 37 | 'data/temp', 38 | train=True, 39 | download=True, 40 | transform=transforms.Compose([ 41 | transforms.ToTensor(), 42 | transforms.Normalize((0.1307,), (0.3081,)) 43 | ])), 44 | batch_size=batch_size, 45 | shuffle=True 46 | ) 47 | 48 | 49 | if __name__ == '__main__': 50 | batch_size = 10 51 | 52 | train_loader = get_data_loader(batch_size) 53 | model = MNISTClassifier(num_inputs=28 * 28, action_space=10) 54 | 55 | loss_func = nn.CrossEntropyLoss() 56 | optimizer = optim.Adam(model.parameters(), lr=0.01) 57 | 58 | metrics_before = TensorVisualizer2DMetrics(file='eval/temp/test_before') 59 | metrics_after = TensorVisualizer2DMetrics(file='eval/temp/test_after') 60 | 61 | 62 | def norm_hook(module, m_input, m_output): 63 | metrics_before.report(m_input[0]) 64 | metrics_after.report(m_output) 65 | 66 | 67 | model.bn2.register_forward_hook(norm_hook) 68 | 69 | model.train() 70 | it = 0 71 | for i in range(1): 72 | for data, target in tqdm_lim(train_loader, lim=2000): 73 | optimizer.zero_grad() 74 | m_output = model(data.view(batch_size, -1)) 75 | loss = loss_func(m_output, target) 76 | loss.backward() 77 | optimizer.step() 78 | 79 | if it % 1000 == 0: 80 | print(loss.item()) 81 | it += 1 82 | 83 | metrics_before.get_current_value(should_print=True) 84 | metrics_after.get_current_value(should_print=False) 85 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/token_level/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/zerogercrnn/experiments/token_level/__init__.py -------------------------------------------------------------------------------- /zerogercrnn/experiments/token_level/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/zerogercrnn/experiments/token_level/base/__init__.py -------------------------------------------------------------------------------- /zerogercrnn/experiments/token_level/base/main.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | 3 | from zerogercrnn.experiments.token_level.common import TokenMain, TokensLoss 4 | from zerogercrnn.experiments.token_level.base.model import TokenBaseModel 5 | from zerogercrnn.lib.core import BaseModule 6 | from zerogercrnn.lib.metrics import Metrics, MaxPredictionAccuracyMetrics, SequentialMetrics, TopKWrapper, ResultsSaver 7 | 8 | 9 | class TokenBaseMain(TokenMain): 10 | 11 | def __init__(self, args): 12 | super().__init__(args) 13 | 14 | def create_model(self, args) -> BaseModule: 15 | return TokenBaseModel( 16 | num_tokens=args.tokens_num, 17 | embedding_dim=args.token_embedding_dim, 18 | hidden_size=args.hidden_size 19 | ) 20 | 21 | def create_criterion(self, args) -> nn.Module: 22 | 
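# --- Editor's note: a hedged sketch (not repo code) of what the TokensLoss
# criterion returned below receives. TokensLoss (defined in token_level/common.py
# later in this listing) flattens a [seq_len, batch, vocab] prediction so that
# nn.CrossEntropyLoss sees one row per token; all sizes here are illustrative.
#
#   import torch
#   import torch.nn as nn
#   criterion = nn.CrossEntropyLoss()
#   prediction = torch.randn(50, 80, 1000)      # [seq_len, batch_size, tokens_num]
#   target = torch.randint(0, 1000, (50, 80))   # [seq_len, batch_size]
#   loss = criterion(prediction.view(-1, prediction.size()[-1]), target.view(-1))
# --- end note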
return TokensLoss() 23 | 24 | def create_train_metrics(self, args) -> Metrics: 25 | return MaxPredictionAccuracyMetrics() 26 | 27 | def create_eval_metrics(self, args) -> Metrics: 28 | return SequentialMetrics([ 29 | MaxPredictionAccuracyMetrics(), 30 | TopKWrapper(base=ResultsSaver(dir_to_save=args.eval_results_directory)) 31 | ]) 32 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/token_level/base/model.py: -------------------------------------------------------------------------------- 1 | from zerogercrnn.experiments.token_level.core import TokenModel 2 | from zerogercrnn.lib.core import RecurrentCore 3 | from zerogercrnn.lib.utils import forget_hidden_partly, repackage_hidden 4 | 5 | 6 | class TokenBaseModel(TokenModel): 7 | def __init__(self, num_tokens, embedding_dim, hidden_size): 8 | super().__init__(num_tokens=num_tokens, embedding_dim=embedding_dim, recurrent_output_size=hidden_size) 9 | self.hidden_size = hidden_size 10 | 11 | self.lstm = self.module(RecurrentCore( 12 | input_size=self.embedding_dim, 13 | hidden_size=self.hidden_size, 14 | num_layers=1, 15 | dropout=0., 16 | model_type='lstm' 17 | )) 18 | 19 | def get_recurrent_output(self, input_embedded, hidden, forget_vector): 20 | hidden = repackage_hidden(forget_hidden_partly(hidden, forget_vector)) 21 | return self.lstm(input_embedded, hidden) 22 | 23 | def init_hidden(self, batch_size): 24 | return self.lstm.init_hidden(batch_size) 25 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/token_level/common.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from zerogercrnn.experiments.common import Main 7 | from zerogercrnn.experiments.token_level.data import TokensDataReader, TokensDataGenerator 8 | from zerogercrnn.lib.core import BaseModule 9 | from zerogercrnn.lib.data import BatchedDataGenerator 10 | from zerogercrnn.lib.metrics import Metrics 11 | from zerogercrnn.lib.run import NetworkRoutine 12 | from zerogercrnn.lib.utils import filter_requires_grad, get_best_device 13 | 14 | 15 | def create_data_generator(args) -> BatchedDataGenerator: 16 | reader = TokensDataReader( 17 | train_file=args.train_file, 18 | eval_file=args.eval_file, 19 | seq_len=args.seq_len, 20 | limit=args.data_limit 21 | ) 22 | 23 | data_generator = TokensDataGenerator( 24 | data_reader=reader, 25 | seq_len=args.seq_len, 26 | batch_size=args.batch_size 27 | ) 28 | 29 | return data_generator 30 | 31 | 32 | class TokensLoss(nn.Module): 33 | 34 | def __init__(self): 35 | super().__init__() 36 | self.criterion = nn.CrossEntropyLoss() 37 | 38 | def forward(self, prediction: torch.Tensor, target: torch.Tensor): 39 | return self.criterion(prediction.view(-1, prediction.size()[-1]), target.view(-1)) 40 | 41 | 42 | def run_model(model: nn.Module, iter_data, hidden, batch_size): 43 | (n_input, n_target), forget_vector = iter_data 44 | assert forget_vector.size()[0] == batch_size 45 | 46 | n_input = n_input.to(get_best_device()) 47 | n_target = n_target.to(get_best_device()) 48 | 49 | if hidden is None: 50 | hidden = model.init_hidden(batch_size=batch_size) 51 | 52 | prediction, hidden = model(n_input, hidden, forget_vector=forget_vector) 53 | 54 | return prediction, n_target, hidden 55 | 56 | 57 | class TokenLevelRoutine(NetworkRoutine): 58 | 59 | def __init__(self, model: nn.Module, batch_size, seq_len, criterion: 
nn.Module, optimizers): 60 | super().__init__(model) 61 | self.model = self.network 62 | self.batch_size = batch_size 63 | self.seq_len = seq_len 64 | self.criterion = criterion 65 | self.optimizers = optimizers 66 | 67 | self.hidden = None 68 | 69 | def optimize(self, loss): 70 | # Backward pass 71 | loss.backward() 72 | torch.nn.utils.clip_grad_norm_(filter_requires_grad(self.model.parameters()), 5) 73 | # torch.nn.utils.clip_grad_norm_(filter_requires_grad(self.model.sparse_parameters()), 5) 74 | 75 | # Optimizer step 76 | for optimizer in self.optimizers: 77 | optimizer.step() 78 | 79 | def run(self, iter_num, iter_data): 80 | if self.optimizers is not None: 81 | for optimizer in self.optimizers: 82 | optimizer.zero_grad() 83 | 84 | prediction, m_target, hidden = run_model( 85 | model=self.model, 86 | iter_data=iter_data, 87 | hidden=self.hidden, 88 | batch_size=self.batch_size 89 | ) 90 | self.hidden = hidden 91 | 92 | loss = self.criterion(prediction, m_target) 93 | if self.optimizers is not None: 94 | self.optimize(loss) 95 | 96 | return prediction, m_target 97 | 98 | 99 | class TokenMain(Main): 100 | 101 | def __init__(self, args): 102 | super().__init__(args) 103 | 104 | @abstractmethod 105 | def create_model(self, args) -> BaseModule: 106 | pass 107 | 108 | @abstractmethod 109 | def create_criterion(self, args) -> nn.Module: 110 | pass 111 | 112 | @abstractmethod 113 | def create_train_metrics(self, args) -> Metrics: 114 | pass 115 | 116 | @abstractmethod 117 | def create_eval_metrics(self, args) -> Metrics: 118 | pass 119 | 120 | def create_data_generator(self, args) -> BatchedDataGenerator: 121 | return create_data_generator(args) 122 | 123 | def create_train_routine(self, args) -> NetworkRoutine: 124 | return TokenLevelRoutine( 125 | model=self.model, 126 | batch_size=args.batch_size, 127 | seq_len=args.seq_len, 128 | criterion=self.criterion, 129 | optimizers=self.optimizers 130 | ) 131 | 132 | def create_validation_routine(self, args) -> NetworkRoutine: 133 | return TokenLevelRoutine( 134 | model=self.model, 135 | batch_size=args.batch_size, 136 | seq_len=args.seq_len, 137 | criterion=self.criterion, 138 | optimizers=None 139 | ) 140 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/token_level/core.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import torch 4 | 5 | from zerogercrnn.lib.core import CombinedModule, EmbeddingsModule, LinearLayer 6 | 7 | 8 | class TokenModel(CombinedModule): 9 | def __init__(self, num_tokens, embedding_dim, recurrent_output_size): 10 | super().__init__() 11 | self.num_tokens = num_tokens 12 | self.embedding_dim = embedding_dim 13 | self.recurrent_output_size = recurrent_output_size 14 | 15 | self.token_embeddings = self.module(EmbeddingsModule( 16 | num_embeddings=self.num_tokens, 17 | embedding_dim=self.embedding_dim, 18 | sparse=True 19 | )) 20 | 21 | self.h2o = self.module(LinearLayer( 22 | input_size=self.recurrent_output_size, 23 | output_size=self.num_tokens, 24 | bias=True 25 | )) 26 | 27 | def forward(self, m_input: torch.Tensor, hidden: torch.Tensor, forget_vector: torch.Tensor): 28 | input_embedded = self.token_embeddings(m_input) 29 | recurrent_output, hidden = self.get_recurrent_output(input_embedded, hidden, forget_vector) 30 | m_output = self.h2o(recurrent_output) 31 | return m_output, hidden 32 | 33 | @abstractmethod 34 | def get_recurrent_output(self, input_embedded, hidden, forget_vector): 
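# --- Editor's note: an implementation sketch for this abstract hook, mirroring
# TokenBaseModel earlier in this listing: reset the hidden state for sequences
# whose program ended (zeros in forget_vector), detach the graph from previous
# steps, then run the recurrent core.
#
#   def get_recurrent_output(self, input_embedded, hidden, forget_vector):
#       hidden = repackage_hidden(forget_hidden_partly(hidden, forget_vector))
#       return self.lstm(input_embedded, hidden)
# --- end note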
35 | return None, None 36 | 37 | @abstractmethod 38 | def init_hidden(self, batch_size): 39 | pass 40 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/token_level/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from zerogercrnn.lib.data import DataChunk, BatchedDataGenerator, split_train_validation, DataReader 7 | from zerogercrnn.lib.utils import get_best_device 8 | 9 | # hack for tqdm 10 | tqdm.monitor_interval = 0 11 | 12 | from zerogercrnn.lib.calculation import pad_tensor 13 | 14 | VECTOR_FILE = 'data/tokens/vectors.txt' 15 | TRAIN_FILE = 'data/tokens/file_train.json' 16 | EVAL_FILE = 'data/tokens/file_eval.json' 17 | ENCODING = 'ISO-8859-1' 18 | 19 | 20 | class TokensDataChunk(DataChunk): 21 | def __init__(self, one_hot_tensor): 22 | super().__init__() 23 | 24 | self.one_hot_tensor = one_hot_tensor.to(get_best_device()) 25 | self.seq_len = None 26 | 27 | def prepare_data(self, seq_len): 28 | self.seq_len = seq_len 29 | self.one_hot_tensor = pad_tensor(tensor=self.one_hot_tensor, seq_len=seq_len) 30 | 31 | def get_by_index(self, index): 32 | if self.seq_len is None: 33 | raise Exception('You should call prepare_data with specified seq_len first') 34 | if index + self.seq_len > self.size(): 35 | raise Exception('Not enough data in chunk') 36 | 37 | input_tensor_emb = self.one_hot_tensor.narrow(dim=0, start=index, length=self.seq_len - 1) 38 | target_tensor = self.one_hot_tensor.narrow(dim=0, start=index + 1, length=self.seq_len - 1) 39 | 40 | return input_tensor_emb, target_tensor 41 | 42 | def size(self): 43 | return self.one_hot_tensor.size()[0] 44 | 45 | 46 | class TokensDataReader(DataReader): 47 | """Reads the data from file and transform it to torch Tensors.""" 48 | 49 | def __init__(self, train_file, eval_file, seq_len, limit=100000): 50 | super().__init__() 51 | self.train_file = train_file 52 | self.eval_file = eval_file 53 | self.seq_len = seq_len 54 | 55 | print('Start data reading') 56 | if self.train_file is not None: 57 | self.train_data, self.validation_data = split_train_validation( 58 | data=self._read_file(train_file, limit=limit, label='Train'), 59 | split_coefficient=0.8 60 | ) 61 | 62 | if self.eval_file is not None: 63 | self.eval_data = self._read_file(eval_file, limit=limit, label='Eval') 64 | 65 | print('Data reading finished') 66 | print('Train size: {}, Validation size: {}, Eval size: {}'.format( 67 | len(self.train_data), 68 | len(self.validation_data), 69 | len(self.eval_data) 70 | )) 71 | 72 | def _read_file(self, file_path, limit=100000, label='Data'): 73 | print('Reading {} ... 
'.format(label)) 74 | data = [] 75 | it = 0 76 | for l in tqdm(open(file=file_path, mode='r', encoding=ENCODING), total=limit): 77 | it += 1 78 | 79 | tokens = json.loads(l) 80 | one_hot = torch.LongTensor(tokens).to(get_best_device()) 81 | 82 | data.append(TokensDataChunk(one_hot_tensor=one_hot)) 83 | 84 | if (limit is not None) and (it == limit): 85 | break 86 | 87 | return list(filter(lambda d: d.size() >= self.seq_len, data)) 88 | 89 | 90 | class TokensDataGenerator(BatchedDataGenerator): 91 | 92 | def __init__(self, data_reader: DataReader, seq_len, batch_size): 93 | super().__init__(data_reader, seq_len=seq_len, batch_size=batch_size) 94 | 95 | def _retrieve_batch(self, key, buckets): 96 | inputs = [] 97 | targets = [] 98 | 99 | for b in buckets: 100 | id, chunk = b.get_next_index_with_chunk() 101 | 102 | i, t = chunk.get_by_index(id) 103 | 104 | inputs.append(i) 105 | targets.append(t) 106 | 107 | return torch.stack(inputs, dim=1), torch.stack(targets, dim=1) 108 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/token_level/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | from zerogercrnn.experiments.common import Main 6 | from zerogercrnn.experiments.token_level.base.main import TokenBaseMain 7 | from zerogercrnn.lib.argutils import add_general_arguments, add_batching_data_args, add_optimization_args, \ 8 | add_recurrent_core_args, add_tokens_args 9 | from zerogercrnn.lib.log import logger 10 | 11 | parser = argparse.ArgumentParser(description='Token level neural network') 12 | add_general_arguments(parser) 13 | add_batching_data_args(parser) 14 | add_optimization_args(parser) 15 | add_recurrent_core_args(parser) 16 | add_tokens_args(parser) 17 | 18 | parser.add_argument('--prediction', type=str, help='Type of prediction to run. For this token-level network only: token_base') 19 | parser.add_argument('--save_model_every', type=int, help='How often to save model', default=1) 20 | 21 | # This is for evaluation purposes 22 | parser.add_argument('--eval', action='store_true', help='Evaluate or train') 23 | parser.add_argument('--eval_results_directory', type=str, help='Where to save results of evaluation') 24 | 25 | 26 | def get_main(args) -> Main: 27 | if args.prediction == 'token_base': 28 | main = TokenBaseMain(args) 29 | else: 30 | raise Exception('Unknown type of prediction: {}'.format(args.prediction)) 31 | 32 | return main 33 | 34 | 35 | def train(args): 36 | get_main(args).train(args) 37 | 38 | 39 | def evaluate(args): 40 | get_main(args).eval(args) 41 | 42 | 43 | if __name__ == '__main__': 44 | print(torch.__version__) 45 | _args = parser.parse_args() 46 | assert _args.title is not None 47 | logger.should_log = _args.log 48 | 49 | if _args.eval: 50 | evaluate(_args) 51 | else: 52 | train(_args) 53 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/token_level/metrics.py: -------------------------------------------------------------------------------- 1 | from zerogercrnn.lib.metrics import Metrics, BaseAccuracyMetrics, IndexedAccuracyMetrics 2 | 3 | 4 | class AggregatedTokenMetrics(Metrics): 5 | 6 | def __init__(self): 7 | super().__init__() 8 | self.common = BaseAccuracyMetrics() 9 | self.target_non_unk = IndexedAccuracyMetrics('Target not unk') 10 | self.prediction_non_unk = IndexedAccuracyMetrics('Prediction not unk') 11 | 12 | def drop_state(self): 13 | self.common.drop_state() 14 | 
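# --- Editor's note: a hypothetical invocation of token_level/main.py above.
# The flags are the ones its parsers register via argutils; every path and
# value here is illustrative only, not a recommended configuration.
#
#   python zerogercrnn/experiments/token_level/main.py \
#       --title token_base --prediction token_base \
#       --train_file data/tokens/file_train.json \
#       --eval_file data/tokens/file_eval.json \
#       --tokens_num 51000 --token_embedding_dim 50 \
#       --hidden_size 500 --num_layers 1 --dropout 0.0 \
#       --seq_len 50 --batch_size 80 \
#       --learning_rate 0.001 --epochs 8 \
#       --decay_after_epoch 4 --decay_multiplier 0.9 --weight_decay 0.0 \
#       --model_save_dir saved/token_base
# --- end note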
self.target_non_unk.drop_state() 15 | self.prediction_non_unk.drop_state() 16 | 17 | def report(self, prediction_target): 18 | prediction, target = prediction_target 19 | prediction = prediction.view(-1) 20 | target = target.view(-1) 21 | 22 | self.common.report((prediction, target)) 23 | 24 | pred_non_unk_indices = (prediction != 0).nonzero().squeeze() 25 | target_non_unk_indices = (target != 0).nonzero().squeeze() 26 | 27 | self.prediction_non_unk.report(prediction, target, pred_non_unk_indices) 28 | self.target_non_unk.report(prediction, target, target_non_unk_indices) 29 | 30 | def get_current_value(self, should_print=False): 31 | print('P1 = {}'.format(self.common.get_current_value(False))) 32 | print('P2 = {}'.format(self.prediction_non_unk.metrics.hits / (self.common.hits + self.common.misses))) 33 | print('P3 = {}'.format(self.target_non_unk.get_current_value(False))) 34 | print('P4 = {}'.format(self.prediction_non_unk.get_current_value(False))) 35 | -------------------------------------------------------------------------------- /zerogercrnn/experiments/token_level/results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | import numpy 5 | import torch 6 | 7 | from zerogercrnn.experiments.token_level.main import create_data_generator, create_model 8 | from zerogercrnn.experiments.token_level.main import run_model 9 | from zerogercrnn.lib.file import read_lines, load_if_saved, load_cuda_on_cpu 10 | from zerogercrnn.lib.visualization.text import show_diff 11 | 12 | parser = argparse.ArgumentParser(description='AST level neural network') 13 | parser.add_argument('--title', type=str, help='Title for this run. Used in tensorboard and in saving of models.') 14 | parser.add_argument('--train_file', type=str, help='File with training data') 15 | parser.add_argument('--eval_file', type=str, help='File with eval data') 16 | parser.add_argument('--embeddings_file', type=str, help='File with embedding vectors') 17 | parser.add_argument('--data_limit', type=int, help='How much lines of data to process (only for fast checking)') 18 | parser.add_argument('--model_save_dir', type=str, help='Where to save trained models') 19 | parser.add_argument('--saved_model', type=str, help='File with trained model if not fresh train') 20 | parser.add_argument('--log', action='store_true', help='Log performance?') 21 | parser.add_argument('--vocab', type=str, help='Vocabulary of used tokens') 22 | 23 | parser.add_argument('--tokens_count', type=int, help='All possible tokens count') 24 | parser.add_argument('--seq_len', type=int, help='Recurrent layer time unrolling') 25 | parser.add_argument('--batch_size', type=int, help='Size of batch') 26 | parser.add_argument('--learning_rate', type=float, help='Learning rate') 27 | parser.add_argument('--epochs', type=int, help='Number of epochs to run model') 28 | parser.add_argument('--decay_after_epoch', type=int, help='Multiply lr by decay_multiplier each epoch') 29 | parser.add_argument('--decay_multiplier', type=float, help='Multiply lr by this number after decay_after_epoch') 30 | parser.add_argument('--embedding_size', type=int, help='Size of embedding to use') 31 | parser.add_argument('--hidden_size', type=int, help='Hidden size of recurrent part of model') 32 | parser.add_argument('--num_layers', type=int, help='Number of recurrent layers') 33 | parser.add_argument('--dropout', type=float, help='Dropout to apply to recurrent layer') 34 | parser.add_argument('--weight_decay', 
type=float, help='Weight decay for l2 regularization') 35 | 36 | ENCODING = 'ISO-8859-1' 37 | 38 | 39 | def load_model(args, model): 40 | if args.saved_model is not None: 41 | if torch.cuda.is_available(): 42 | load_if_saved(model, args.saved_model) 43 | else: 44 | load_cuda_on_cpu(model, args.saved_model) 45 | 46 | 47 | def load_dictionary(tokens_path): 48 | id2token = read_lines(tokens_path, encoding=ENCODING) 49 | token2id = {} 50 | for id, token in enumerate(id2token): 51 | token2id[token] = id 52 | 53 | return token2id, id2token 54 | 55 | 56 | def single_data_prediction(args, model, iter_data, hidden): 57 | prediction, target, hidden = run_model(model=model, iter_data=iter_data, hidden=hidden, batch_size=args.batch_size) 58 | return prediction, target, hidden 59 | 60 | 61 | def get_token_for_print(id2token, id): 62 | if id == 0: 63 | return 'UNK' 64 | else: 65 | return id2token[id - 1] 66 | 67 | 68 | def print_results_for_current_prediction(id2token, prediction, target): 69 | prediction_values, prediction = torch.max(prediction, dim=2) 70 | prediction = prediction.view(-1) 71 | target = target.view(-1) 72 | 73 | text_actual = [] 74 | text_predicted = [] 75 | 76 | for i in range(len(prediction)): 77 | is_true = prediction.data[i] == target.data[i] 78 | 79 | text_actual.append(get_token_for_print(id2token, target.data[i])) 80 | text_predicted.append(get_token_for_print(id2token, prediction.data[i])) 81 | 82 | return text_actual, text_predicted 83 | 84 | 85 | def format_text(text): 86 | formatted = [] 87 | it = 0 88 | for t in text: 89 | if it % 20 == 0: 90 | formatted.append('\n') 91 | it += 1 92 | formatted.append(t) 93 | formatted.append(' ') 94 | 95 | return formatted 96 | 97 | 98 | def print_prediction(args): 99 | model = create_model(args) 100 | 101 | if args.batch_size != 1: 102 | raise Exception('batch_size should be 1 for visualization') 103 | 104 | if args.saved_model is not None: 105 | if torch.cuda.is_available(): 106 | load_if_saved(model, args.saved_model) 107 | else: 108 | load_cuda_on_cpu(model, args.saved_model) 109 | 110 | generator = create_data_generator(args) 111 | 112 | model.eval() 113 | hidden = None 114 | 115 | lim = 1 116 | it = 0 117 | 118 | token2id, id2token = load_dictionary(args.vocab) 119 | text_actual = [] 120 | text_predicted = [] 121 | for iter_data in generator.get_eval_generator(): 122 | prediction, target, n_hidden = single_data_prediction(args, model, iter_data, hidden) 123 | c_a, c_p = print_results_for_current_prediction(id2token, prediction, target) 124 | text_actual += c_a 125 | text_predicted += c_p 126 | 127 | hidden = n_hidden 128 | it += 1 129 | if it == lim: 130 | break 131 | 132 | show_diff(format_text(text_predicted), format_text(text_actual)) 133 | 134 | 135 | if __name__ == '__main__': 136 | # good seeds: 10 5 11 137 | # random.seed(seed) 138 | # numpy.random.seed(seed) 139 | 140 | _args = parser.parse_args() 141 | 142 | for seed in [13, 14, 15, 16, 17, 18, 19, 20]: 143 | random.seed(seed) 144 | numpy.random.seed(seed) 145 | print_prediction(_args) 146 | -------------------------------------------------------------------------------- /zerogercrnn/global_constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 4 | DEFAULT_ENCODING = 'ISO-8859-1' 5 | -------------------------------------------------------------------------------- /zerogercrnn/lib/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/zerogercrnn/lib/__init__.py -------------------------------------------------------------------------------- /zerogercrnn/lib/accuracies.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def indexed_topk_hits(prediction, target, index): 5 | """ 6 | :param prediction: tensor of size [N, K], where K is the number of top predictions 7 | :param target: tensor of size [N] 8 | :param index: tensor of size [I] with indexes to count accuracy on 9 | :return: array with T topk hits, total_entries 10 | """ 11 | selected_prediction = torch.index_select(prediction, 0, index) 12 | selected_target = torch.index_select(target, 0, index) 13 | 14 | if selected_prediction.size()[0] == 0: 15 | return torch.zeros((prediction.size()[-1]), dtype=torch.int64), 0 16 | return topk_hits(selected_prediction, selected_target) 17 | 18 | 19 | def topk_hits(prediction, target): 20 | """ 21 | :param prediction: tensor of size [N, K], where K is the number of top predictions 22 | :param target: tensor of size [N] 23 | :return: array with T topk hits, total_entries 24 | """ 25 | n = prediction.size()[0] 26 | k = prediction.size()[1] 27 | 28 | hits = torch.zeros(k, dtype=torch.int64) 29 | correct = prediction.eq(target.unsqueeze(1).expand_as(prediction)) 30 | for tk in range(k): 31 | cur_hits = correct[:, :tk + 1] 32 | hits[tk] += cur_hits.sum() 33 | 34 | return hits, n 35 | -------------------------------------------------------------------------------- /zerogercrnn/lib/argutils.py: -------------------------------------------------------------------------------- 1 | def add_general_arguments(parser): 2 | parser.add_argument('--title', type=str, help='Title for this run. 
Used in tensorboard and in saving of models.') 3 | parser.add_argument('--train_file', type=str, help='File with training data') 4 | parser.add_argument('--eval_file', type=str, help='File with eval data') 5 | parser.add_argument('--data_limit', type=int, help='How many lines of data to process (only for fast checking)') 6 | parser.add_argument('--model_save_dir', type=str, help='Where to save trained models') 7 | parser.add_argument('--saved_model', type=str, help='File with trained model if not fresh train') 8 | parser.add_argument('--log', action='store_true', help='Log performance?') 9 | 10 | 11 | def add_batching_data_args(parser): 12 | parser.add_argument('--seq_len', type=int, help='Recurrent layer time unrolling') 13 | parser.add_argument('--batch_size', type=int, help='Size of batch') 14 | 15 | 16 | def add_optimization_args(parser): 17 | parser.add_argument('--learning_rate', type=float, help='Learning rate') 18 | parser.add_argument('--epochs', type=int, help='Number of epochs to run model') 19 | parser.add_argument('--decay_after_epoch', type=int, help='Epoch after which lr is multiplied by decay_multiplier every epoch') 20 | parser.add_argument('--decay_multiplier', type=float, help='Multiply lr by this number after decay_after_epoch') 21 | parser.add_argument('--weight_decay', type=float, help='Weight decay for l2 regularization') 22 | 23 | 24 | def add_recurrent_core_args(parser): 25 | parser.add_argument('--hidden_size', type=int, help='Hidden size of recurrent part of model') 26 | parser.add_argument('--num_layers', type=int, help='Number of recurrent layers') 27 | parser.add_argument('--dropout', type=float, help='Dropout to apply to recurrent layer') 28 | # Layered LSTM args, ignored if not layered 29 | parser.add_argument('--layered_hidden_size', type=int, help='Size of hidden state in layered lstm') 30 | parser.add_argument('--num_tree_layers', type=int, help='Number of layers to distribute hidden size') 31 | 32 | 33 | def add_non_terminal_args(parser): 34 | parser.add_argument('--non_terminals_num', type=int, help='Number of different non-terminals') 35 | parser.add_argument('--non_terminal_embedding_dim', type=int, help='Dimension of non-terminal embeddings') 36 | parser.add_argument('--non_terminals_file', type=str, help='Json file with all non-terminals') 37 | parser.add_argument('--non_terminal_embeddings_file', type=str, help='File with pretrained non-terminal embeddings') 38 | 39 | 40 | def add_terminal_args(parser): 41 | parser.add_argument('--terminals_num', type=int, help='Number of different terminals') 42 | parser.add_argument('--terminal_embedding_dim', type=int, help='Dimension of terminal embeddings') 43 | parser.add_argument('--terminals_file', type=str, help='Json file with all terminals') 44 | parser.add_argument('--terminal_embeddings_file', type=str, help='File with pretrained terminal embeddings') 45 | 46 | 47 | def add_tokens_args(parser): 48 | parser.add_argument('--tokens_num', type=int, help='Number of different tokens in train file') 49 | parser.add_argument('--token_embedding_dim', type=int, help='Size of continuous token representation') -------------------------------------------------------------------------------- /zerogercrnn/lib/attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from zerogercrnn.lib.calculation import drop_matrix_rows_3d, calc_attention_combination 6 | from zerogercrnn.lib.core import BaseModule 7 | from 
zerogercrnn.lib.utils import init_layers_uniform, get_best_device 8 | 9 | 10 | class CyclicBuffer: 11 | def __init__(self, buffer): 12 | self.buffer = buffer 13 | self.it = 0 14 | 15 | def add_vector(self, vector): 16 | self.buffer[:, self.it, :].copy_(vector) # TODO: general way 17 | self.it += 1 18 | if self.it >= self.buffer.size()[1]: 19 | self.it = 0 20 | 21 | def get(self): 22 | return self.buffer 23 | 24 | 25 | class LastKBuffer: 26 | def __init__(self, window_len, buffer): 27 | assert window_len <= buffer.size()[1] 28 | self.buffer_size = buffer.size()[1] 29 | self.window_len = window_len 30 | self.buffer = buffer 31 | 32 | self.it = window_len 33 | 34 | def add_vector(self, vector): 35 | self.buffer[:, self.it, :].copy_(vector.detach()) # TODO: general way 36 | self.it += 1 37 | if self.it >= self.buffer_size: 38 | self.buffer.narrow(dim=1, start=0, length=self.window_len).copy_( 39 | self.buffer.narrow(dim=1, start=self.buffer_size - self.window_len, length=self.window_len) 40 | ) 41 | self.it = self.window_len 42 | 43 | def get(self): 44 | return self.buffer.narrow(dim=1, start=self.it - self.window_len, length=self.window_len) 45 | 46 | 47 | class Attn(BaseModule): 48 | def __init__(self, method, hidden_size): 49 | super(Attn, self).__init__() 50 | 51 | self.method = method 52 | self.hidden_size = hidden_size 53 | 54 | if self.method == 'general': 55 | self.attn = nn.Linear(self.hidden_size, self.hidden_size) 56 | init_layers_uniform(-0.05, 0.05, [self.attn]) 57 | 58 | # elif self.method == 'concat': 59 | # self.attn = nn.Linear(self.hidden_size * 2, hidden_size) 60 | # self.other = nn.Parameter(torch.FloatTensor(1, hidden_size)) 61 | # nn.init.uniform(self.attn.parameters(), -0.05, 0.05) 62 | # nn.init.uniform(self.other, -0.05, 0.05) 63 | 64 | def forward(self, main_vector, attn_vectors): 65 | """ 66 | :param main_vector: matrix of size [batch_size, N] 67 | :param attn_vectors: matrix of size [batch_size, seq_len, N] 68 | :return: 69 | """ 70 | seq_len = attn_vectors.size()[1] 71 | 72 | # Calculate energies for each encoder output 73 | attn_energies = self.score(main_vector, attn_vectors) 74 | 75 | return F.softmax(attn_energies, dim=1) 76 | 77 | def score(self, main_vector, attn_vectors): 78 | """ 79 | :param main_vector: matrix of size [batch_size, N] 80 | :param attn_vectors: matrix of size [batch_size, seq_len, N] 81 | :return: matrix with attention coefficients of size [batch_size, seq_len, 1] 82 | """ 83 | if self.method == 'dot': 84 | pass # all is ready 85 | elif self.method == 'general': 86 | attn_vectors = self.attn(attn_vectors) 87 | else: 88 | raise Exception('Unknown attention method: {}'.format(self.method)) 89 | 90 | # main_vector [batch_size, N] -> [batch_size, 1, 1, N] 91 | main_vector = main_vector.unsqueeze(1).unsqueeze(1) 92 | # att_vectors [batch_size, seq_len, N, 1] 93 | attn_vectors = attn_vectors.unsqueeze(3) 94 | # after multiplication [batch_size, seq_len, 1, 1] -> [batch_size, seq_len, 1, 1] 95 | energy = main_vector.matmul(attn_vectors).squeeze(-1) 96 | return energy 97 | 98 | # TODO: implement concat 99 | # elif self.method == 'concat': 100 | # energy = self.attn(torch.cat((hidden, encoder_output), 1)) 101 | # energy = self.other.dot(energy) 102 | # return energy 103 | 104 | 105 | class ContextAttention(BaseModule): 106 | """Attention layer that calculate attention of past seq_len reported inputs to the currently reported input.""" 107 | 108 | def __init__(self, context_len, hidden_size): 109 | super().__init__() 110 | self.seq_len = context_len 
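# --- Editor's note: a toy shape walk-through of Attn.score above; all sizes
# are invented for illustration (batch_size=2, seq_len=3, N=4).
#
#   import torch
#   main_vector = torch.randn(2, 4)            # [batch_size, N] current hidden state
#   attn_vectors = torch.randn(2, 3, 4)        # [batch_size, seq_len, N] past contexts
#   m = main_vector.unsqueeze(1).unsqueeze(1)  # [2, 1, 1, 4]
#   a = attn_vectors.unsqueeze(3)              # [2, 3, 4, 1]
#   energy = m.matmul(a).squeeze(-1)           # [2, 3, 1]: one score per context
#   weights = torch.softmax(energy, dim=1)     # softmax over the seq_len axis
# --- end note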
111 | self.hidden_size = hidden_size 112 | self.it = 0 113 | 114 | # Layer that applies attention to past self.cntx hidden states of contexts 115 | self.attn = Attn(method='general', hidden_size=self.hidden_size) 116 | 117 | # Matrix that will hold past seq_len contexts. No backprop will be computed 118 | # size: [batch_size, seq_len, hidden_size] 119 | self.context_buffer = None 120 | 121 | def init_hidden(self, batch_size): 122 | b_matrix = torch.FloatTensor(batch_size, 2 * self.seq_len, self.hidden_size).to(get_best_device()) 123 | self.context_buffer = LastKBuffer(window_len=self.seq_len, buffer=b_matrix) 124 | 125 | def forget_context_partly(self, forget_vector): 126 | """Method to drop context for programs that ended. 127 | :param forget_vector vector of size [batch_size, 1] with either 0 or 1 128 | """ 129 | drop_matrix_rows_3d(self.context_buffer.get(), forget_vector) 130 | 131 | def forward(self, h_t): 132 | """ 133 | :param h_t: current hidden state of size [batch_size, hidden_size] 134 | :return: hidden state with applied sum attention of size [batch_size, hidden_size] 135 | """ 136 | assert self.context_buffer is not None 137 | 138 | current_context = self.context_buffer.get() 139 | attn_weights = self.attn(h_t, current_context) 140 | 141 | # self.it += 1 142 | # if self.it % 10000 == 0: 143 | # print(attn_weights.data[0]) 144 | 145 | # Calc current context vector as sum of previous contexts multiplied by attention coefficients 146 | cntx = calc_attention_combination(attn_weights, current_context) 147 | 148 | self.context_buffer.add_vector(h_t) 149 | return cntx 150 | -------------------------------------------------------------------------------- /zerogercrnn/lib/calculation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def shift_left(matrix, dimension): 5 | """Shift tensor left by one along specified dimension. This operation performed in-place""" 6 | m_len = matrix.size()[dimension] 7 | matrix.narrow(dim=dimension, start=0, length=m_len - 1) \ 8 | .copy_(matrix.narrow(dim=dimension, start=1, length=m_len - 1)) 9 | 10 | 11 | def pad_tensor(tensor, seq_len): 12 | """Pad tensor with last element along 0 dimension.""" 13 | sz = list(tensor.size()) 14 | sz[0] = seq_len - tensor.size()[0] % seq_len 15 | 16 | tail = tensor[-1].clone().expand(sz).to(tensor.device) 17 | tensor = torch.cat((tensor, tail)) 18 | return tensor 19 | 20 | 21 | def calc_attention_combination(attention_weights, matrix): 22 | """Calculate sum of vectors of matrix along dim=1 with coefficients specified by attention_weights. 23 | 24 | :param attention_weights: size - [batch_size, seq_len, 1] 25 | :param matrix: size - [batch_size, seq_len, vector_dim] 26 | :return: matrix of size [batch_size, vector_dim] 27 | """ 28 | return attention_weights.transpose(1, 2).matmul(matrix).squeeze(1) 29 | 30 | 31 | def drop_matrix_rows_3d(matrix, forget_vector): 32 | """ 33 | Zeroing blocks along first dimension according to forget_vector. Forget vector should consist of 0s and 1s. 
34 | 35 | :param matrix: size - [N1, N2, N3] 36 | :param forget_vector: size - [N1, 1] 37 | :return: size - [N1, N2, N3] 38 | """ 39 | return matrix.mul(forget_vector.unsqueeze(2)) 40 | 41 | 42 | def select_layered_hidden(layered_hidden, node_depths): 43 | """Selects hidden state for each element in the batch according to layer number in node_depths. 44 | 45 | :param layered_hidden: tensor of size [batch_size, layers_num, hidden_size] 46 | :param node_depths: for each batch line contains layer that should be picked. shape: [batch_size] 47 | """ 48 | batch_size = layered_hidden.size()[0] 49 | layers_num = layered_hidden.size()[1] 50 | hidden_size = layered_hidden.size()[2] 51 | depths_one_hot = layered_hidden.new(batch_size, layers_num) 52 | 53 | depths_one_hot.zero_().scatter_(1, node_depths.unsqueeze(1), 1) 54 | mask = depths_one_hot.unsqueeze(2).byte() 55 | mask = mask.to(layered_hidden.device) 56 | 57 | return torch.masked_select(layered_hidden, mask).view(batch_size, 1, hidden_size) 58 | 59 | 60 | def set_layered_hidden(layered_hidden, node_depths, updated): 61 | """Returns new tensor that represents updated hidden state. Only layers that are specified in node_depths get updated. 62 | 63 | :param layered_hidden: tensor of size [batch_size, layers_num, hidden_size] 64 | :param node_depths: for each batch line contains layer that should be updated. shape: [batch_size] 65 | :param updated: updated hidden states for particular layer. shape: [batch_size, hidden_size] 66 | """ 67 | batch_size = layered_hidden.size()[0] 68 | layers_num = layered_hidden.size()[1] 69 | hidden_size = layered_hidden.size()[2] 70 | 71 | node_depths_update = node_depths.unsqueeze(1).unsqueeze(2).expand(batch_size, 1, hidden_size) 72 | updated = updated.unsqueeze(1) 73 | node_depths_update = node_depths_update.to(layered_hidden.device) # .to() is not in-place; keep the result 74 | 75 | return layered_hidden.scatter(1, node_depths_update, updated) 76 | 77 | 78 | def create_one_hot(vector, one_hot_size): 79 | """Creates one-hot matrix from 1D vector""" 80 | batch_size = vector.size()[0] 81 | depths_one_hot = vector.new(batch_size, one_hot_size) 82 | return depths_one_hot.zero_().scatter_(1, vector.unsqueeze(1), 1).float() 83 | -------------------------------------------------------------------------------- /zerogercrnn/lib/constants.py: -------------------------------------------------------------------------------- 1 | ENCODING = 'ISO-8859-1' 2 | 3 | EMPTY_TOKEN = '' # token meaning that a particular terminal has no corresponding non-terminal 4 | EMPTY_TOKEN_ID = 0 5 | 6 | UNKNOWN_TOKEN = '' # token meaning that a non-terminal token is rare 7 | UNKNOWN_TOKEN_ID = 50000 8 | 9 | EOF_TOKEN = 'EOF' # token indicating the end of a program 10 | EOF_TOKEN_ID = 96 11 | -------------------------------------------------------------------------------- /zerogercrnn/lib/data.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from zerogercrnn.lib.utils import get_best_device 8 | 9 | 10 | def split_train_validation(data, split_coefficient): 11 | train_examples = int(len(data) * split_coefficient) 12 | return data[:train_examples], data[train_examples:len(data)] 13 | 14 | 15 | def get_shuffled_indexes(length): 16 | temp = np.arange(length) 17 | np.random.shuffle(temp) 18 | return temp 19 | 20 | 21 | def get_random_index(length): 22 | return np.random.randint(length) 23 | 24 | 25 | class DataReader: 26 | """General interface for readers of text files into 
format for DataGenerator. 27 | Should provide fields for train, validation, eval.""" 28 | 29 | def __init__(self): 30 | self.train_data = [] 31 | self.validation_data = [] 32 | self.eval_data = [] 33 | self.eval_tails = 0 34 | 35 | 36 | class DataGenerator: 37 | """General interface for generators of data for training and validation.""" 38 | 39 | @abstractmethod 40 | def get_train_generator(self): 41 | """Provides data for one epoch of training.""" 42 | pass 43 | 44 | @abstractmethod 45 | def get_validation_generator(self): 46 | """Provides data for one validation cycle.""" 47 | pass 48 | 49 | @abstractmethod 50 | def get_eval_generator(self): 51 | """Provides data for evaluation of trained model.""" 52 | pass 53 | 54 | 55 | class DataChunk: 56 | 57 | @abstractmethod 58 | def prepare_data(self, seq_len): 59 | """Align data with seq_len.""" 60 | pass 61 | 62 | @abstractmethod 63 | def get_by_index(self, index): 64 | pass 65 | 66 | @abstractmethod 67 | def size(self): 68 | pass 69 | 70 | 71 | class DataChunksPool: 72 | """Pool with chunks of data. Able to split data into some parts to produce many epochs from one data pool.""" 73 | 74 | def __init__(self, chunks, splits=1, shuffle=True): 75 | self.chunks = chunks 76 | self.splits = splits 77 | self.shuffle = shuffle 78 | self.epoch_size = len(self.chunks) // self.splits 79 | 80 | self.current = 0 81 | self.right = 0 82 | self._recreate_indexes() 83 | 84 | def start_epoch(self): 85 | if self.current != self.right: 86 | raise Exception( 87 | 'You should finish previous epoch first, cur: {}, right: {}'.format(self.current, self.right) 88 | ) 89 | 90 | if self.current + self.epoch_size > len(self.chunks): # need to start new epoch from begining of data 91 | self.current = 0 92 | self._recreate_indexes() 93 | 94 | self.right = min(self.current + self.epoch_size, len(self.chunks)) 95 | 96 | def get_chunk(self): 97 | """Return next chunks of data in current epoch. Returns None if epoch is finished.""" 98 | if self.current == self.right: 99 | return None 100 | else: 101 | cur = self.current 102 | self.current += 1 103 | 104 | if self.current % 100 == 0: 105 | print('Processed {} programs'.format(self.current)) 106 | 107 | return self.chunks[self.indexes[cur]] 108 | 109 | def is_epoch_finished(self): 110 | return self.current == self.right 111 | 112 | def _recreate_indexes(self): 113 | if self.shuffle: 114 | self.indexes = get_shuffled_indexes(length=len(self.chunks)) 115 | else: 116 | self.indexes = np.arange(start=0, stop=len(self.chunks)) 117 | 118 | 119 | class DataBucket: 120 | """Bucket with DataChunks. 
Could return index to get data from DataChunk and refills automatically from pool.""" 121 | 122 | def __init__(self, pool: DataChunksPool, seq_len, on_new_chunk=None): 123 | self.pool = pool 124 | self.seq_len = seq_len 125 | self.on_new_chunk = on_new_chunk 126 | 127 | self.chunk = None 128 | self.index = 0 129 | 130 | def get_next_index_with_chunk(self): 131 | """Returns next index to get data from DataChunk.""" 132 | if self.is_empty(): 133 | print('Chunk: {}, Index: {}'.format(self.chunk, self.index)) 134 | raise Exception('No data in bucket') 135 | 136 | if (self.index == 0) and (self.on_new_chunk is not None): 137 | self.on_new_chunk() 138 | 139 | start = self.index 140 | chunk = self.chunk 141 | 142 | self.index += self.seq_len 143 | self.refill_if_necessary() 144 | 145 | return start, chunk 146 | 147 | def is_empty(self): 148 | """Indicates whether this bucket contains at least one more sequence.""" 149 | return (self.chunk is None) or (self.chunk.size() == self.index) 150 | 151 | def refill_if_necessary(self): 152 | if self.is_empty(): 153 | self.chunk = self.pool.get_chunk() 154 | self.index = 0 155 | 156 | 157 | class BucketsBatch: 158 | def __init__(self, pool: DataChunksPool, seq_len, batch_size): 159 | self.pool = pool 160 | self.seq_len = seq_len 161 | self.batch_size = batch_size 162 | self.buckets = [] 163 | 164 | self.forget_vector = torch.FloatTensor(batch_size, 1).to(get_best_device()) 165 | 166 | def forget(x): 167 | self.forget_vector[x] = 0 168 | 169 | def get_forget(x): 170 | return lambda: forget(x) 171 | 172 | for i in range(self.batch_size): 173 | self.buckets.append( 174 | DataBucket( 175 | pool=self.pool, 176 | seq_len=self.seq_len, 177 | on_new_chunk=get_forget(i) 178 | )) 179 | 180 | def get_epoch(self, retriever): 181 | self.pool.start_epoch() 182 | 183 | for b in self.buckets: 184 | b.refill_if_necessary() 185 | 186 | while True: 187 | self.forget_vector.fill_(1) 188 | yield retriever(self.buckets), self.forget_vector 189 | if self.pool.is_epoch_finished(): 190 | should_exit = False 191 | for b in self.buckets: 192 | if b.chunk is None: 193 | should_exit = True 194 | if should_exit: 195 | break 196 | 197 | 198 | class BatchedDataGenerator(DataGenerator): 199 | """Provides batched data for training and evaluation of model.""" 200 | 201 | def __init__(self, data_reader, seq_len, batch_size, shuffle=True): 202 | super(BatchedDataGenerator, self).__init__() 203 | self.data_reader = data_reader 204 | self.seq_len = seq_len 205 | self.batch_size = batch_size 206 | 207 | self.batches = {} 208 | 209 | if data_reader.train_data is not None: 210 | self.train_pool = self._prepare_data_(data_reader.train_data, splits=1, shuffle=shuffle) 211 | self.train_batcher = BucketsBatch(self.train_pool, self.seq_len, self.batch_size) 212 | 213 | if data_reader.validation_data is not None: 214 | self.validation_pool = self._prepare_data_(data_reader.validation_data, splits=1, shuffle=shuffle) 215 | self.validation_batcher = BucketsBatch(self.validation_pool, self.seq_len, self.batch_size) 216 | 217 | if data_reader.eval_data is not None: 218 | self.eval_pool = self._prepare_data_(data_reader.eval_data, splits=1, shuffle=True) 219 | self.eval_batcher = BucketsBatch(self.eval_pool, self.seq_len, self.batch_size) 220 | 221 | @abstractmethod 222 | def _retrieve_batch(self, key, buckets): 223 | """Create batch of data for model using buckets. Buckets are guaranteed to contain data. 
224 | Key can be used for caching.""" 225 | pass 226 | 227 | def _get_batched_epoch(self, key, batcher): 228 | return batcher.get_epoch(retriever=lambda buckets: self._retrieve_batch(key, buckets)) 229 | 230 | # override 231 | def get_train_generator(self): 232 | return self._get_batched_epoch('train', self.train_batcher) 233 | 234 | # override 235 | def get_validation_generator(self): 236 | return self._get_batched_epoch('validation', self.validation_batcher) 237 | 238 | # override 239 | def get_eval_generator(self): 240 | return self._get_batched_epoch('eval', self.eval_batcher) 241 | 242 | def _prepare_data_(self, data, splits=5, shuffle=True): 243 | for i in tqdm(range(len(data))): 244 | data[i].prepare_data(self.seq_len) 245 | 246 | return DataChunksPool(chunks=data, splits=splits, shuffle=shuffle) 247 | -------------------------------------------------------------------------------- /zerogercrnn/lib/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ENCODING = 'ISO-8859-1' 4 | 5 | 6 | class Embeddings: 7 | 8 | def __init__(self, embeddings_size, vector_file, squeeze=False): 9 | self.embedding_size = embeddings_size 10 | self.vector_file = vector_file 11 | self.embeddings_tensor = None 12 | 13 | if squeeze: 14 | self._read_embeddings_squeezed(vector_file) 15 | else: 16 | self._read_embeddings(vector_file) 17 | 18 | def index_select(self, index, out=None): 19 | """Make sure there is no unknown token id in the index. Embeddings for 20 | non-vocabulary words will be equal to zero.""" 21 | 22 | return torch.index_select(self.embeddings_tensor, dim=0, index=index, out=out) 23 | 24 | def _read_embeddings_squeezed(self, vector_file): 25 | embeddings = [] 26 | for l in open(vector_file, mode='r', encoding=ENCODING): 27 | numbers = l.split() 28 | assert len(numbers) == self.embedding_size + 1 29 | embeddings.append(torch.FloatTensor([float(x) for x in numbers[1:]])) 30 | 31 | self.embeddings_tensor = torch.stack(embeddings, dim=0) 32 | 33 | def _read_embeddings(self, vector_file): 34 | embeddings = {} 35 | max_emb_id = 0 36 | for l in open(vector_file, mode='r', encoding=ENCODING): 37 | numbers = l.split(' ') 38 | assert len(numbers) == self.embedding_size + 1 39 | 40 | id = numbers[0] 41 | cur_emb = torch.FloatTensor([float(x) for x in numbers[1:]]) 42 | 43 | if id == '': 44 | self.unk_embedding = cur_emb 45 | else: 46 | assert int(id) not in embeddings.keys() 47 | max_emb_id = max(int(id), max_emb_id) 48 | embeddings[int(id)] = cur_emb 49 | 50 | self.embeddings_tensor = torch.FloatTensor(max_emb_id + 1, self.embedding_size) 51 | for k, v in embeddings.items(): 52 | if k >= 0: 53 | self.embeddings_tensor[k].copy_(v) 54 | else: 55 | print('Key {} skipped during embeddings load.'.format(k)) 56 | -------------------------------------------------------------------------------- /zerogercrnn/lib/file.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals, print_function, division 2 | 3 | import os 4 | from io import open 5 | 6 | import torch 7 | 8 | from zerogercrnn.lib.constants import ENCODING 9 | 10 | 11 | def create_directory_if_not_exists(path): 12 | os.makedirs(path, exist_ok=True) 13 | 14 | 15 | def read_lines(filename, encoding=ENCODING): 16 | """ 17 | Read a file and split into lines. 18 | Returns the list of lines without trailing newlines. 19 | """ 20 | lines = [] 21 | with open(filename, mode='r', encoding=encoding) as f: 22 | for l in f: 23 | if l[-1] == '\n': 24 | lines.append(l[:-1]) 25 | else: 26 | lines.append(l) 27 | return lines 28 | 29 | 30 | def load_if_saved(model, path): 31 | """Loads state of the model if previously saved.""" 32 | if os.path.isfile(path): 33 | model.load_state_dict(torch.load(path)) 34 | print('Model restored from file.') 35 | else: 36 | raise Exception('Model file does not exist. File: {}'.format(path)) 37 | 38 | 39 | def load_cuda_on_cpu(model, path): 40 | """Loads CUDA model for testing on non CUDA device.""" 41 | if os.path.isfile(path): 42 | model.load_state_dict(torch.load(path, map_location=lambda storage, loc: storage)) 43 | print('Model restored from file.') 44 | else: 45 | raise Exception('Model file does not exist. File: {}'.format(path)) 46 | 47 | 48 | def save_model(model, path): 49 | """Saves state of the model by specified path.""" 50 | torch.save(model.state_dict(), path) 51 | -------------------------------------------------------------------------------- /zerogercrnn/lib/health.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class HealthCheck: 5 | """Class that does some check on the model. Usually it prints some info about the model at the end of an epoch.""" 6 | 7 | @abstractmethod 8 | def do_check(self): 9 | pass 10 | 11 | 12 | class AlphaBetaSumHealthCheck(HealthCheck): 13 | 14 | def __init__(self, module): 15 | super().__init__() 16 | self.module = module 17 | 18 | def do_check(self): 19 | print('Alpha: {}'.format(self.module.mult_alpha)) 20 | print('Beta: {}'.format(self.module.mult_beta)) -------------------------------------------------------------------------------- /zerogercrnn/lib/log.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from itertools import islice 3 | import time 4 | 5 | # hack for tqdm 6 | tqdm.monitor_interval = 0 7 | 8 | 9 | class Logger: 10 | def __init__(self): 11 | self.ct = time.perf_counter() # time.clock() was deprecated and removed in Python 3.8 12 | self.should_log = False 13 | 14 | def reset_time(self): 15 | self.ct = time.perf_counter() 16 | 17 | def log_time_s(self, label): 18 | self.__log_time__(label, 1) 19 | 20 | def log_time_ms(self, label): 21 | self.__log_time__(label, 1000) 22 | 23 | def __log_time__(self, label, multiplier): 24 | if self.should_log: 25 | print("{}: {}".format(label, multiplier * (time.perf_counter() - self.ct))) 26 | self.ct = time.perf_counter() 27 | 28 | 29 | logger = Logger() 30 | 31 | 32 | def tqdm_lim(iter, total=None, lim=None): 33 | if (total is None) and (lim is None): 34 | return tqdm(iter) 35 | 36 | right = 1000000000 37 | if total is not None: 38 | right = min(right, total) 39 | 40 | if lim is not None: 41 | right = min(right, lim) 42 | 43 | return tqdm(islice(iter, 0, right), total=right) 44 | -------------------------------------------------------------------------------- /zerogercrnn/lib/preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | from abc import abstractmethod 3 | 4 | from zerogercrnn.lib.constants import ENCODING 5 | from zerogercrnn.lib.log import tqdm_lim 6 | 7 | 8 | def write_json(file, raw_json): 9 | """Writes json as is to file.""" 10 | 11 | with open(file, mode='w', encoding=ENCODING) as f: 12 | f.write(json.dumps(raw_json)) 13 | 14 | 15 | def read_lines(file, total=None, lim=None): 16 | """Returns generator of lines from file. 
17 | 18 | :param file: path to file 19 | :param total: total number of lines in file 20 | :param lim: limit on number of read lines 21 | """ 22 | with open(file, mode='r', encoding=ENCODING) as f: 23 | for line in tqdm_lim(f, total=total, lim=lim): 24 | yield line 25 | 26 | 27 | def read_jsons(*files, lim=None): 28 | """Reads jsons from passed files. Suppose files to contain json lines separated by newlines. 29 | 30 | :param files: files to read jsons from 31 | :param lim: limit number of read jsons for all files 32 | """ 33 | for file in files: 34 | for line in read_lines(file, lim=lim): 35 | yield json.loads(line) 36 | 37 | 38 | def read_json(file): 39 | """Reads single json from file. 40 | 41 | :param file: file to read jsons from 42 | """ 43 | return list(read_jsons(file))[0] 44 | 45 | 46 | class JsonExtractor: 47 | """Extracts some info from passed json. See specific implementations for more info.""" 48 | 49 | @abstractmethod 50 | def extract(self, raw_json): 51 | pass 52 | 53 | 54 | class JsonListKeyExtractor(JsonExtractor): 55 | """Extracts values by specified key if it present. Suppose json to be a list of jsons.""" 56 | 57 | def __init__(self, key): 58 | self.key = key 59 | 60 | def extract(self, raw_json): 61 | for node in raw_json: 62 | if node == 0: 63 | break 64 | 65 | if self.key in node: 66 | yield node[self.key] 67 | 68 | 69 | def extract_jsons_info(extractor: JsonExtractor, *files, lim=None): 70 | """Read jsons from files and run extractor on them.""" 71 | for raw_json in read_jsons(*files, lim=lim): 72 | yield extractor.extract(raw_json) 73 | 74 | 75 | def test(): 76 | nt_extractor = JsonListKeyExtractor(key='type') 77 | 78 | for info_gen in extract_jsons_info(nt_extractor, 'data/programs_eval.json', lim=10): 79 | for val in info_gen: 80 | print(val) 81 | 82 | 83 | if __name__ == '__main__': 84 | test() 85 | -------------------------------------------------------------------------------- /zerogercrnn/lib/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from abc import abstractmethod 4 | 5 | from tqdm import tqdm 6 | 7 | from zerogercrnn.lib.metrics import Metrics 8 | 9 | # hack for tqdm 10 | tqdm.monitor_interval = 0 11 | 12 | import torch.nn as nn 13 | 14 | from zerogercrnn.lib.visualization.plotter import TensorboardPlotter, \ 15 | TensorboardPlotterCombined 16 | from zerogercrnn.lib.file import save_model 17 | from zerogercrnn.lib.data import DataGenerator 18 | 19 | LOG_EVERY = 1000 20 | 21 | 22 | class NetworkRoutine: 23 | """Base class for running single iteration of RNN. Enable to train or validate networks.""" 24 | 25 | def __init__(self, network): 26 | self.network = network 27 | 28 | @abstractmethod 29 | def run(self, iter_num, iter_data): 30 | """ Run routine and return value for plotting. 
-------------------------------------------------------------------------------- /zerogercrnn/lib/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from abc import abstractmethod 4 | 5 | from tqdm import tqdm 6 | 7 | from zerogercrnn.lib.metrics import Metrics 8 | 9 | # hack for tqdm 10 | tqdm.monitor_interval = 0 11 | 12 | import torch.nn as nn 13 | 14 | from zerogercrnn.lib.visualization.plotter import TensorboardPlotter, \ 15 | TensorboardPlotterCombined 16 | from zerogercrnn.lib.file import save_model 17 | from zerogercrnn.lib.data import DataGenerator 18 | 19 | LOG_EVERY = 1000 20 | 21 | 22 | class NetworkRoutine: 23 | """Base class for running a single iteration of an RNN. Enables training or validation of networks.""" 24 | 25 | def __init__(self, network): 26 | self.network = network 27 | 28 | @abstractmethod 29 | def run(self, iter_num, iter_data): 30 | """Run the routine and return values for plotting. 31 | 32 | :param iter_num: number of the iteration 33 | :param iter_data: data for this iteration 34 | """ 35 | pass 36 | 37 | 38 | def save_current_model(model, dir, name): 39 | if dir is not None: 40 | print('Saving model: {}'.format(name)) 41 | save_model( 42 | model=model, 43 | path=os.path.join(dir, name) 44 | ) 45 | print('Saved!') 46 | 47 | 48 | class TrainEpochRunner: 49 | def __init__( 50 | self, 51 | network: nn.Module, 52 | train_routine: NetworkRoutine, 53 | validation_routine: NetworkRoutine, 54 | metrics: Metrics, 55 | data_generator: DataGenerator, 56 | schedulers=None, 57 | plotter='tensorboard', 58 | save_dir=None, 59 | title=None, 60 | report_train_every=1, 61 | plot_train_every=1, 62 | save_model_every=1 63 | ): 64 | """Create train runner. 65 | 66 | :param network: network to train. 67 | :param train_routine: routine that will run on each train input. 68 | :param validation_routine: routine that will run after each epoch of training on each validation input. 69 | :param metrics: metrics to plot. Should correspond to the routines. 70 | :param data_generator: generator of data for training and validation 71 | :param schedulers: schedulers for learning rate. If None, learning rate will be constant 72 | :param plotter: visualization tool. Either 'tensorboard' or 'tensorboard_combined'. 73 | :param save_dir: if specified model will be saved in this directory after each epoch with name "model_epoch_X". 74 | :param title: used for visualization 75 | """ 76 | self.network = network 77 | self.train_routine = train_routine 78 | self.validation_routine = validation_routine 79 | self.metrics = metrics 80 | self.data_generator = data_generator 81 | self.schedulers = schedulers 82 | self.save_dir = save_dir 83 | self.report_train_every = report_train_every 84 | self.plot_train_every = plot_train_every 85 | self.save_model_every = save_model_every 86 | 87 | self.epoch = None # current epoch 88 | self.it = None # current iteration 89 | 90 | if self.plot_train_every % self.report_train_every != 0: 91 | raise Exception('report_train_every should divide plot_train_every') 92 | 93 | if plotter == 'tensorboard': 94 | self.plotter = TensorboardPlotter(title=title) 95 | elif plotter == 'tensorboard_combined': 96 | self.plotter = TensorboardPlotterCombined(title=title) 97 | else: 98 | raise Exception('Unknown plotter') 99 | 100 | def run(self, number_of_epochs): 101 | self.epoch = -1 102 | self.it = 0 103 | # self._validate() # first validation for plot.
104 | 105 | try: 106 | while self.epoch + 1 < number_of_epochs: 107 | self.epoch += 1 108 | if self.schedulers is not None: 109 | t = 1 110 | # TODO: general way 111 | if self.epoch > 20: 112 | t = 5 113 | for i in range(t): 114 | for scheduler in self.schedulers: 115 | scheduler.step() 116 | 117 | self._run_for_epoch() 118 | self._validate() 119 | 120 | for hc in self.network.health_checks(): 121 | hc.do_check() 122 | 123 | if (self.epoch + 1) % self.save_model_every == 0: 124 | save_current_model(self.network, self.save_dir, name='model_epoch_{}'.format(self.epoch)) 125 | except KeyboardInterrupt: 126 | print('-' * 89) 127 | print('Exiting from training early') 128 | finally: 129 | # plot graphs of validation and train losses 130 | self.plotter.on_finish() 131 | 132 | def _run_for_epoch(self): 133 | self.metrics.train() 134 | self.metrics.drop_state() 135 | 136 | self.network.train() 137 | train_data = self.data_generator.get_train_generator() 138 | # print('Expected number of iterations for epoch: {}'.format(train_generator.size // batch_size)) 139 | 140 | for iter_data in train_data: 141 | if self.it % LOG_EVERY == 0: 142 | print('Training... Epoch: {}, Iters: {}'.format(self.epoch, self.it)) 143 | 144 | metrics_values = self.train_routine.run( 145 | iter_num=self.it, 146 | iter_data=iter_data 147 | ) 148 | 149 | # if self.it % self.report_train_every == 0: 150 | 151 | if self.it % self.plot_train_every == 0: 152 | self.metrics.drop_state() 153 | self.metrics.report(metrics_values) 154 | self.plotter.on_new_point( 155 | label='train', 156 | x=self.it, 157 | y=self.metrics.get_current_value(should_print=False) 158 | ) 159 | 160 | self.it += 1 161 | 162 | def _validate(self): 163 | """Perform validation and calculate loss as an average over the whole validation dataset.""" 164 | validation_data = self.data_generator.get_validation_generator() 165 | 166 | self.metrics.drop_state() 167 | # self.metrics.eval() 168 | self.network.eval() 169 | 170 | with torch.no_grad(): 171 | validation_it = 0 172 | for iter_data in validation_data: 173 | if validation_it % LOG_EVERY == 0: 174 | print('Validating... Epoch: {} Iters: {}'.format(self.epoch, validation_it)) 175 | 176 | metrics_values = self.validation_routine.run( 177 | iter_num=self.it, 178 | iter_data=iter_data 179 | ) 180 | 181 | self.metrics.report(metrics_values) 182 | 183 | validation_it += 1 184 | 185 | self.plotter.on_new_point( 186 | label='validation', 187 | x=self.it, 188 | y=self.metrics.get_current_value(should_print=False) 189 | ) 190 | 191 | print('Validation done. Epoch: {}'.format(self.epoch)) 192 | self.metrics.get_current_value(should_print=True) 193 |
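# A sketch of how the classes above are wired together. The concrete network,
# routines, metrics and data generator are assumed to come from elsewhere in the
# project (e.g. zerogercrnn.experiments); nothing below besides TrainEpochRunner
# is defined in this module. Note that network.health_checks() is called after
# every epoch, so the network is expected to provide it.
def _demo_train(network, train_routine, validation_routine, metrics, data_generator):
    runner = TrainEpochRunner(
        network=network,
        train_routine=train_routine,
        validation_routine=validation_routine,
        metrics=metrics,
        data_generator=data_generator,
        plotter='tensorboard',  # or 'tensorboard_combined' for (non-terminal, terminal) pairs
        save_dir=None,          # pass a directory to save model_epoch_N checkpoints
        title='demo_run'
    )
    runner.run(number_of_epochs=8)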
-------------------------------------------------------------------------------- /zerogercrnn/lib/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def get_best_device(): 6 | """Return cuda device if cuda is available.""" 7 | return get_device(torch.cuda.is_available()) 8 | 9 | 10 | def get_device(cuda): 11 | return torch.device("cuda" if cuda else "cpu") 12 | 13 | 14 | def init_recurrent_layers(*layers): 15 | for layer in layers: 16 | for name, param in layer.named_parameters(): 17 | if 'bias' in name: 18 | nn.init.constant_(param, 0.0) 19 | elif 'weight' in name: 20 | nn.init.xavier_normal_(param) 21 | 22 | 23 | def init_layers_uniform(min_value, max_value, layers): 24 | for layer in layers: 25 | for name, param in layer.named_parameters(): 26 | nn.init.uniform_(param, min_value, max_value) 27 | 28 | 29 | def repackage_hidden(h): 30 | """Forgets history of the current hidden state.""" 31 | if isinstance(h, torch.Tensor): 32 | return h.detach().requires_grad_(h.requires_grad) 33 | else: 34 | return tuple(repackage_hidden(v) for v in h) 35 | 36 | 37 | def forget_hidden_partly_lstm_cell(h, forget_vector): 38 | return h[0].mul(forget_vector), h[1].mul(forget_vector) 39 | 40 | 41 | def forget_hidden_partly(h, forget_vector): 42 | if isinstance(h, torch.Tensor): 43 | return h.mul(forget_vector.unsqueeze(0)) # TODO: check 44 | else: 45 | return tuple(forget_hidden_partly(v, forget_vector) for v in h) 46 | 47 | 48 | def setup_tensor(tensor): 49 | return tensor.to(get_best_device()) 50 | 51 | 52 | def filter_requires_grad(parameters): 53 | return filter(lambda p: p.requires_grad, parameters) 54 | 55 | 56 | def register_forward_hook(module, metrics, picker): 57 | module.register_forward_hook(lambda _, m_input, m_output: metrics.report(picker(m_input, m_output))) 58 | 59 | 60 | def register_output_hook(module, metrics, picker=None): 61 | if picker is None: 62 | picker = lambda m_output: m_output 63 | register_forward_hook(module, metrics, lambda m_input, m_output: picker(m_output)) 64 | 65 | 66 | def register_input_hook(module, metrics, picker=None): 67 | if picker is None: 68 | picker = lambda m_input: m_input[0] 69 | register_forward_hook(module, metrics, lambda m_input, m_output: picker(m_input)) 70 | 71 | 72 | if __name__ == '__main__': 73 | h1 = torch.randn((1, 8, 10)) 74 | forget_vector = torch.ones(8, 1) 75 | forget_vector[1][0] = 0 76 | forget_vector[2][0] = 0 77 | 78 | h1.mul(forget_vector, out=h1) 79 | 80 | print(h1) 81 |
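# A small sketch of repackage_hidden/forget_hidden_partly as they are meant to be
# used between sequence chunks (truncated BPTT): zero the state of finished
# sequences and cut the autograd history. Shapes are illustrative.
import torch

def _demo_hidden_handling():
    hidden = torch.zeros(1, 4, 8, requires_grad=True)       # (layers, batch, hidden)
    forget_vector = torch.tensor([[1.], [0.], [1.], [1.]])  # 0 resets batch element 1
    hidden = forget_hidden_partly(hidden, forget_vector)    # zero out finished rows
    hidden = repackage_hidden(hidden)                       # detach from the old graph
    assert hidden.grad_fn is None and hidden.requires_grad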
-------------------------------------------------------------------------------- /zerogercrnn/lib/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | 4 | from tensorboardX import SummaryWriter 5 | 6 | # tensorboard --logdir=/tensorboard/runs 7 | if __name__ == '__main__': 8 | writer = SummaryWriter('/tensorboard/runs/test') 9 | 10 | for step in range(10): 11 | dummy_s1 = torch.rand(1) 12 | writer.add_scalar('data/random', dummy_s1, step) 13 | time.sleep(1) 14 | 15 | # writer.add_scalar( 16 | # tag='data/scalar2', 17 | # scalar_value=10., 18 | # global_step=3 19 | # ) 20 | -------------------------------------------------------------------------------- /zerogercrnn/lib/visualization/embeddings.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from sklearn.manifold import TSNE 4 | 5 | from zerogercrnn.lib.preprocess import read_jsons 6 | from zerogercrnn.lib.embedding import Embeddings 7 | 8 | 9 | def tsne_plot(emb: Embeddings, vocab): 10 | "Creates a TSNE model and plots it" 11 | labels = [] 12 | tokens = [] 13 | 14 | for cur in range(emb.embeddings_tensor.size()[0]): 15 | if cur == -1: 16 | labels.append('SPACE') 17 | else: 18 | labels.append(vocab[cur]) 19 | tokens.append(emb.embeddings_tensor[cur].numpy()) 20 | 21 | tsne_model = TSNE(perplexity=40, n_components=2, init='pca', n_iter=2500, random_state=23) 22 | new_values = tsne_model.fit_transform(tokens) 23 | 24 | x = [] 25 | y = [] 26 | for value in new_values: 27 | x.append(value[0]) 28 | y.append(value[1]) 29 | 30 | nearest(x, y, vocab, 'TryStatement10') 31 | 32 | plt.figure(figsize=(16, 16)) 33 | for i in range(len(x)): 34 | plt.scatter(x[i], y[i]) 35 | plt.annotate(labels[i], 36 | xy=(x[i], y[i]), 37 | xytext=(5, 2), 38 | textcoords='offset points', 39 | ha='right', 40 | va='bottom') 41 | plt.show() 42 | 43 | 44 | def nearest(x, y, vocab, word): 45 | w_id = -1 46 | for i in range(len(vocab)): 47 | if vocab[i] == word: 48 | w_id = i 49 | if w_id == -1: 50 | raise Exception('No such word in vocabulary: {}'.format(word)) 51 | 52 | px = x[w_id] 53 | py = y[w_id] 54 | p = np.array([px, py]) 55 | 56 | points_with_distance = [] 57 | for i in range(len(x)): 58 | points_with_distance.append((i, np.linalg.norm(np.array([x[i], y[i]]) - p))) 59 | 60 | print('Nearest to {}:'.format(vocab[w_id])) 61 | for c_p in sorted(points_with_distance, key=lambda x: x[1])[:10]: 62 | print(vocab[c_p[0]]) 63 | 64 | 65 | if __name__ == '__main__': 66 | emb = Embeddings(vector_file='/Users/zerogerc/Documents/diploma/GloVe/vectors.txt', embeddings_size=5) 67 | vocab = list(read_jsons('data/ast/non_terminals.json'))[0] 68 | vocab.append('EOF') 69 | tsne_plot(emb, vocab) 70 |
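# A minimal sketch of `nearest` above on synthetic coordinates (no GloVe vectors
# or dataset files needed); it prints the vocabulary entries closest to the given
# word in the 2D projection.
def _demo_nearest():
    vocab = ['ForStatement', 'WhileStatement', 'TryStatement10']
    x = [0.0, 0.1, 5.0]
    y = [0.0, 0.2, 5.0]
    nearest(x, y, vocab, 'TryStatement10')  # prints 'TryStatement10' first, then the rest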
-------------------------------------------------------------------------------- /zerogercrnn/lib/visualization/html_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import webbrowser 4 | from urllib.request import pathname2url 5 | 6 | from zerogercrnn.lib.constants import ENCODING 7 | 8 | VIS_PACKAGE = os.path.dirname(os.path.abspath(__file__)) 9 | POPUP_CSS = os.path.join(VIS_PACKAGE, 'html_tools/popup.css') 10 | 11 | htmlCodes = { 12 | "'": '&#39;', 13 | '"': '&quot;', 14 | '>': '&gt;', 15 | '<': '&lt;', 16 | '&': '&amp;', 17 | '\n': '<br>', 18 | '\t': '&emsp;', 19 | ' ': '&nbsp;' 20 | } 21 | 22 | 23 | def char_to_html(c): 24 | if str(c) in htmlCodes: 25 | return htmlCodes[str(c)] 26 | else: 27 | return c 28 | 29 | 30 | def string_to_html(s): 31 | return ''.join([char_to_html(c) for c in s]) 32 | 33 | 34 | def show_html_page(page, save_file=None): 35 | """Show string *page* as html in the browser. If save_file is specified, the page will be saved there.""" 36 | html_path = save_file or os.path.join(tempfile.gettempdir(), 'diff.html') 37 | f = open(html_path, encoding=ENCODING, mode='w') 38 | f.write(page) 39 | f.close() 40 | webbrowser.open(url='file:{}'.format(pathname2url(html_path))) 41 | 42 | 43 | class HtmlBuilder: 44 | HEAD = """ 45 | <head> 46 | <link rel="stylesheet" type="text/css" href="{}"> 47 | </head> 48 | """.format(POPUP_CSS) 49 | 50 | BODY = """ 51 | <body> 52 | <div> 53 | {} 54 | </div> 55 | </body> 56 | """ 57 | 58 | def __init__(self): 59 | self.message = "" 60 | 61 | def add_popup(self, anchor, popup, background=None): 62 | """Add popup with two texts: one for the anchor text and one for the popup. 63 | You should not append any html elements to anchor or popup because these texts will be *converted*. 64 | """ 65 | self.message += HtmlBuilder.get_popup_html(anchor, popup, background) 66 | 67 | def build(self): 68 | return HtmlBuilder.get_popup_html_page(self.message) 69 | 70 | @staticmethod 71 | def get_popup_html(anchor, popup, background=None): 72 | background = background or '#FFF' 73 | return """ 74 | <div class="popup" style="background-color: {}"> 75 | <span>{}</span> 76 | <span>{}</span> 77 | </div> 78 | """.format(background, string_to_html(anchor), string_to_html(popup)) 79 | 80 | @staticmethod 81 | def get_popup_html_page(body): 82 | return HtmlBuilder.HEAD + HtmlBuilder.BODY.format(body) 83 | 84 | 85 | if __name__ == '__main__': 86 | builder = HtmlBuilder() 87 | builder.add_popup( 88 | anchor="first", 89 | popup="First\nSecond", 90 | background="#81C784" 91 | 92 | ) 93 | for i in range(100): 94 | builder.add_popup( 95 | anchor="second", 96 | popup="1\n2\n3", 97 | background="#EF9A9A" 98 | ) 99 | show_html_page( 100 | page=builder.build() 101 | ) 102 | -------------------------------------------------------------------------------- /zerogercrnn/lib/visualization/plotter.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import visdom 6 | from tensorboardX import SummaryWriter 7 | 8 | TENSORBOARD_DIR = 'tensorboard/runs/' 9 | 10 | 11 | class Plotter: 12 | def on_new_point(self, label, x, y): 13 | pass 14 | 15 | def on_finish(self): 16 | pass 17 | 18 | 19 | class MatplotlibPlotter(Plotter): 20 | def __init__(self, title): 21 | super(MatplotlibPlotter, self).__init__() 22 | self.title = title 23 | self.plots = {} 24 | 25 | def on_new_point(self, label, x, y): 26 | if label not in self.plots: 27 | self.plots[label] = PlotData() 28 | 29 | self.plots[label].x.append(x) 30 | self.plots[label].y.append(y) 31 | 32 | def on_finish(self): 33 | for label in self.plots: 34 | plt.plot(self.plots[label].x, self.plots[label].y, label=label) 35 | 36 | plt.title(self.title) 37 | plt.legend() 38 | plt.show() 39 | 40 | 41 | class VisdomPlotter(Plotter): 42 | def __init__(self, title, plots): 43 | super(VisdomPlotter, self).__init__() 44 | self.title = title 45 | self.vis = visdom.Visdom() 46 | self.plots = set(plots) 47 | 48 | self.vis.line( 49 | X=np.zeros((1, len(plots))), 50 | Y=np.zeros((1, len(plots))), 51 | win=self.title, 52 | opts=dict(legend=plots) 53 | ) 54 | 55 | def on_new_point(self, label, x, y): 56 | if label not in self.plots: 57 | raise Exception('Plot should be in plots set!') 58 | 59 | self.vis.line( 60 | X=np.array([x]), 61 | Y=np.array([y]), 62 | win=self.title, 63 | name=label, 64 | update='append' 65 | ) 66 | 67 | 68 | class TensorboardPlotter(Plotter): 69 | def __init__(self, title): 70 | path = os.path.join(os.getcwd(), TENSORBOARD_DIR + title) 71 | self.writer = SummaryWriter(path) 72 | 73 | def on_new_point(self, label, x, y): 74 | self.writer.add_scalar( 75 | tag=label, 76 | scalar_value=y, 77 | global_step=x 78 | ) 79 | 80 | 81 | class TensorboardPlotterCombined(Plotter): 82 | """x is the step; y holds two values: one for non-terminals, one for terminals.""" 83 | 84 | def __init__(self, title): 85 | path = os.path.join(os.getcwd(), TENSORBOARD_DIR + title) 86 | self.writer = SummaryWriter(path) 87 | 88 | def on_new_point(self, label, x, y): 89 | self.writer.add_scalar( 90 | tag=label + ' non-terminals', 91 | scalar_value=y[0], 92 | global_step=x 93 | ) 94 | self.writer.add_scalar( 95 | tag=label + ' terminals', 96 | scalar_value=y[1], 97 | global_step=x 98 | ) 99 | 100 | 101 | class PlotData: 102 | def __init__(self): 103 | self.x = [] 104 | self.y = [] 105 | 106 | def add(self, x, y): 107 | self.x.append(x) 108 | self.y.append(y) 109 | 110 | 111 | if __name__ == '__main__': 112 | plotter = VisdomPlotter(title='x', plots=['y', 'z']) 113 |
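# A minimal sketch of the tensorboard-backed plotter above: it writes two scalar
# curves under tensorboard/runs/demo (view them with
# `tensorboard --logdir=tensorboard/runs`). The values are dummy data.
def _demo_plotter():
    plotter = TensorboardPlotter(title='demo')
    for step in range(10):
        plotter.on_new_point(label='train', x=step, y=1.0 / (step + 1))
        plotter.on_new_point(label='validation', x=step, y=1.5 / (step + 1))
    plotter.on_finish()  # no-op for this plotter; kept for interface symmetry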
-------------------------------------------------------------------------------- /zerogercrnn/lib/visualization/text.py: -------------------------------------------------------------------------------- 1 | from zerogercrnn.lib.visualization.html_helper import char_to_html, string_to_html, show_html_page 2 | 3 | 4 | def get_diff(text, actual): 5 | """Return two html-colored strings. Green if text[i] == actual[i], red otherwise.""" 6 | 7 | assert (len(text) == len(actual)) 8 | 9 | out_text = '' 10 | out_actual = '' 11 | 12 | green = '<span style="color: green">{}</span>' 13 | red = '<span style="color: red">{}</span>' 14 | 15 | for i in range(len(text)): 16 | if text[i] == actual[i]: 17 | out_text += green.format(char_to_html(text[i])) 18 | out_actual += green.format(char_to_html(actual[i])) 19 | else: 20 | out_text += red.format(char_to_html(text[i])) 21 | out_actual += red.format(char_to_html(actual[i])) 22 | return out_text, out_actual 23 | 24 | 25 | def show_diff(text, actual, file=None): 26 | """ 27 | Shows the difference between two strings as html. 28 | 29 | :param text: text produced by some algorithm 30 | :param actual: actual text to compare with 31 | :param file: if specified html will be stored there 32 | """ 33 | assert (len(text) == len(actual)) 34 | diff_text, diff_actual = get_diff(text, actual) 35 | message = """ 36 | <html> 37 | <head> 38 | <style> 48 | </style> 49 | </head> 50 | <body> 51 | {} 52 | </body> 53 | </html> 54 | """.format(diff_actual) 55 | 56 | show_html_page( 57 | page=message, 58 | save_file=file 59 | ) 60 | 61 | 62 | def show_token_diff(predicted, actual, file=None): 63 | assert len(predicted) == len(actual) 64 | predicted_html = string_to_html(' '.join(predicted)) 65 | actual_html = string_to_html(' '.join(actual)) 66 | 67 | message = """ 68 | <html> 69 | <head> 70 | <style> 83 | </style> 84 | </head> 85 | <body> 86 | <div> 87 | <div> 88 | <div>Text</div> 89 | <div>{}</div> 90 | </div> 91 | <div> 92 | <div>Actual</div> 93 | <div>{}</div> 94 | </div> 95 | </div> 96 | </body> 97 | </html> 98 | """.format(predicted_html, actual_html) 99 | 100 | show_html_page( 101 | page=message, 102 | save_file=file 103 | ) 104 | 105 | 106 | if __name__ == '__main__': 107 | show_diff( 108 | " Here we are\te\n !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[" 109 | "\\]^_`abcdefghijklmnopqrstuvwxyz{|}~¥©ÂÃ", 110 | " Hear We are\te\n !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[" 111 | "\\]^_`abcdefghijklmnopqrstuvwxyz{|}~¥©ÂÃ" 112 | ) 113 | -------------------------------------------------------------------------------- /zerogercrnn/test/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zerogerc/rnn-autocomplete/39dc8dd7c431cb8ac9e15016388ec823771388e4/zerogercrnn/test/lib/__init__.py -------------------------------------------------------------------------------- /zerogercrnn/test/lib/calculation_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from zerogercrnn.lib.calculation import shift_left, pad_tensor, calc_attention_combination, drop_matrix_rows_3d, \ 5 | select_layered_hidden, set_layered_hidden, create_one_hot 6 | from zerogercrnn.testutils.utils import assert_tensors_equal 7 | 8 | 9 | def test_move_right_should_move_dim0(): 10 | matrix = torch.LongTensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]]) 11 | dimension = 0 12 | shift_left(matrix, dimension) 13 | assert_tensors_equal(matrix, torch.LongTensor([[2, 2, 2], [3, 3, 3], [3, 3, 3]])) 14 | 15 | 16 | def test_move_right_should_move_dim1(): 17 | matrix = torch.LongTensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) 18 | dimension = 1 19 | shift_left(matrix, dimension) 20 | assert_tensors_equal(matrix, torch.LongTensor([[2, 3, 3], [2, 3, 3], [2, 3, 3]])) 21 | 22 | 23 | def test_pad_tensor_long(): 24 | tensor = torch.tensor([1, 2, 3], dtype=torch.long) 25 | assert_tensors_equal(pad_tensor(tensor, seq_len=5), torch.tensor([1, 2, 3, 3, 3], dtype=torch.long)) 26 | 27 | 28 | def test_pad_tensor_float(): 29 | tensor = torch.tensor([0., 0.5, 1.2], dtype=torch.float32) 30 | assert_tensors_equal( 31 | pad_tensor(tensor, seq_len=6), 32 | torch.tensor([0., 0.5, 1.2, 1.2, 1.2, 1.2], dtype=torch.float32) 33 | ) 34 | 35 | 36 | def test_pad_tensor_2d(): 37 | tensor = torch.tensor([[3, 2, 1], [0, 10, 5]], dtype=torch.long) 38 | assert_tensors_equal( 39 | pad_tensor(tensor, seq_len=4), 40 | torch.tensor( 41 | [[3, 2, 1], [0, 10, 5], [0, 10, 5], [0, 10, 5]], 42 | dtype=torch.float32) 43 | ) 44 | 45 | 46 | def test_calc_attention_combination_should_work(): 47 | matrix = torch.FloatTensor([ 48 | [ 49 | [1, 10], 50 | [1, 1], 51 | [1, 8] 52 | ], 53 | [ 54 | [1, 1], 55 | [4, 4], 56 | [9, 6] 57 | ] 58 | ]) 59 | 60 | attention_weights = torch.FloatTensor([ 61 | [ 62 | [1. / 2], 63 | [1.], 64 | [1. / 2] 65 | ], 66 | [ 67 | [1.], 68 | [1. / 2], 69 | [1. / 3] 70 | ] 71 | ]) 72 | 73 | expected = torch.FloatTensor([ 74 | [2., 10.], 75 | [6., 5.]
76 | ]) 77 | 78 | attentioned = calc_attention_combination(attention_weights, matrix) 79 | assert_tensors_equal(attentioned, expected) 80 | 81 | 82 | def test_drop_matrix_rows_3d(): 83 | matrix = torch.FloatTensor([ 84 | [ 85 | [1, 1, 1], 86 | [2, 2, 2], 87 | [3, 3, 3] 88 | ], 89 | [ 90 | [4, 4, 4], 91 | [5, 5, 5], 92 | [6, 6, 6] 93 | ] 94 | ]) 95 | 96 | forget_vector = torch.FloatTensor([ 97 | [0], 98 | [1] 99 | ]) 100 | 101 | expected = torch.FloatTensor([ 102 | [ 103 | [0, 0, 0], 104 | [0, 0, 0], 105 | [0, 0, 0] 106 | ], 107 | [ 108 | [4, 4, 4], 109 | [5, 5, 5], 110 | [6, 6, 6] 111 | ] 112 | ]) 113 | 114 | assert_tensors_equal(drop_matrix_rows_3d(matrix, forget_vector), expected) 115 | 116 | 117 | def test_select_layered_hidden(): 118 | batch_size = 5 119 | layers = 50 120 | hidden_size = 10 121 | 122 | node_depths = torch.LongTensor([0, 2, layers - 1, 2, 5]) 123 | layered_hidden = torch.randn((batch_size, layers, hidden_size)) 124 | 125 | selected = select_layered_hidden(layered_hidden, node_depths) 126 | 127 | for i in range(node_depths.size()[0]): 128 | assert torch.nonzero(selected[i][0] == layered_hidden[i][node_depths[i]]).size()[0] == hidden_size 129 | 130 | 131 | def test_set_layered_hidden(): 132 | batch_size = 6 133 | layers = 50 134 | hidden_size = 10 135 | 136 | layered_hidden = torch.randn((batch_size, layers, hidden_size)) 137 | node_depths = torch.LongTensor([0, 1, layers - 1, 2, 5, 1]) 138 | updated = torch.randn((batch_size, hidden_size)) 139 | old_hidden = layered_hidden.clone() 140 | 141 | layered_hidden = set_layered_hidden(layered_hidden, node_depths, updated) 142 | 143 | assert torch.nonzero(old_hidden - layered_hidden).size()[0] == batch_size * hidden_size 144 | for i in range(node_depths.size()[0]): 145 | assert torch.nonzero(layered_hidden[i][node_depths[i]] == updated[i]).size()[0] == hidden_size 146 | 147 | 148 | def test_create_one_hot(): 149 | tensor = torch.from_numpy(np.array([1, 2, 3, 4, 0])) 150 | size = 5 151 | expected = torch.from_numpy(np.array([ 152 | [0, 1, 0, 0, 0], 153 | [0, 0, 1, 0, 0], 154 | [0, 0, 0, 1, 0], 155 | [0, 0, 0, 0, 1], 156 | [1, 0, 0, 0, 0] 157 | ] 158 | )) 159 | 160 | one_hot = create_one_hot(tensor, size) 161 | assert_tensors_equal(one_hot, expected) 162 | -------------------------------------------------------------------------------- /zerogercrnn/test/lib/data_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from zerogercrnn.lib.data import split_train_validation, get_shuffled_indexes, get_random_index, DataChunk, \ 5 | DataChunksPool, DataBucket, BucketsBatch, DataReader, BatchedDataGenerator 6 | 7 | 8 | def test_split_train_validation(): 9 | data = np.arange(0, 10) 10 | train, val = split_train_validation(data, 0.8) 11 | assert np.all(train == np.arange(0, 8)) 12 | assert np.all(val == np.array([8, 9])) 13 | 14 | 15 | def test_get_shuffled_indexes(): 16 | data = get_shuffled_indexes(100) 17 | assert len(data) == 100 18 | 19 | 20 | def test_get_random_index(): 21 | np.random.seed(1) 22 | for i in range(100): 23 | assert get_random_index(100) < 100 24 | 25 | 26 | def test_data_chunks_pool_no_shuffle(): 27 | data_size = 10 28 | splits = data_size // 2 29 | pool = create_test_data_pool(data_size, splits=splits, split_coefficient=0.5, shuffle=False) 30 | 31 | # emit all splits of data 32 | for i in range(splits): 33 | pool.start_epoch() 34 | assert pool.get_chunk().id == 2 * i 35 | assert pool.get_chunk().id == 2 * i + 1 36 | assert 
pool.is_epoch_finished() 37 | 38 | # do not crash when data is finished; start again from the beginning 39 | pool.start_epoch() 40 | assert pool.get_chunk().id == 0 41 | assert pool.get_chunk().id == 1 42 | 43 | 44 | def test_data_chunks_pool_shuffle(): 45 | data_size = 10 46 | splits = data_size // 2 47 | pool = create_test_data_pool(data_size, splits=splits, split_coefficient=0.5, shuffle=True) 48 | 49 | ids = set() 50 | first_chunk = None 51 | # emit all splits of data 52 | for i in range(splits): 53 | pool.start_epoch() 54 | if first_chunk is None: 55 | first_chunk = pool.get_chunk().id 56 | ids.add(first_chunk) 57 | else: 58 | ids.add(pool.get_chunk().id) 59 | ids.add(pool.get_chunk().id) 60 | assert pool.is_epoch_finished() 61 | 62 | assert len(ids) == data_size 63 | pool.start_epoch() 64 | 65 | 66 | def test_data_chunks_pool_return_none_on_finished_epoch(): 67 | data_size = 10 68 | pool = create_test_data_pool(data_size, splits=data_size // 2, split_coefficient=0.5, shuffle=True) 69 | 70 | assert pool.get_chunk() is None 71 | 72 | 73 | def test_data_chunks_pool_exceptions_on_not_finished_epoch(): 74 | data_size = 10 75 | pool = create_test_data_pool(data_size, splits=data_size // 2, split_coefficient=0.5, shuffle=True) 76 | 77 | with pytest.raises(Exception): 78 | pool.start_epoch() 79 | pool.start_epoch() 80 | 81 | 82 | def test_data_bucket_emit_all_data_and_then_raise(): 83 | data_size = 100 84 | seq_len = 50 85 | pool = create_test_data_pool(data_size, splits=10, split_coefficient=0.5, shuffle=False) 86 | pool.start_epoch() 87 | 88 | bucket = DataBucket(pool, seq_len) 89 | bucket.refill_if_necessary() 90 | for i in range(20): 91 | index, chunk = bucket.get_next_index_with_chunk() 92 | assert chunk.id == i // 2 93 | assert index == 50 * (i % 2) 94 | 95 | with pytest.raises(Exception) as excinfo: 96 | bucket.get_next_index_with_chunk() 97 | assert 'No data in bucket' in str(excinfo.value) 98 | 99 | 100 | def test_buckets_batch(): 101 | data_size = 60 102 | seq_len = 50 103 | batch_size = 3 104 | pool = create_test_data_pool(data_size, splits=10, split_coefficient=0.5, shuffle=False) 105 | pool.chunks[0].data = np.arange(200) 106 | 107 | cur_chunk = [0] 108 | chunk_numbers = [0, 1, 2, 0, 1, 2, 0, 3, 4, 0, 3, 4] 109 | 110 | def retriever(buckets): 111 | assert len(buckets) == batch_size 112 | for i in range(batch_size): 113 | index, chunk = buckets[i].get_next_index_with_chunk() 114 | assert chunk.id == chunk_numbers[cur_chunk[0]] 115 | cur_chunk[0] += 1 116 | 117 | batch = BucketsBatch(pool, seq_len=seq_len, batch_size=batch_size) 118 | for i, (data, forget_vector) in enumerate(batch.get_epoch(retriever)): 119 | if i == 0: 120 | assert np.all(forget_vector == 0) 121 | elif i == 1: 122 | assert np.all(forget_vector == 1) 123 | elif i == 2: 124 | assert np.all(forget_vector.view(-1).cpu().numpy().astype(int) == [1, 0, 0]) 125 | elif i == 3: 126 | assert np.all(forget_vector == 1) 127 | 128 | 129 | def test_batched_generator(): 130 | data_size = 36 131 | seq_len = 50 132 | batch_size = 3 133 | reader = create_test_data_reader(data_size, data_size // 6, split_coefficient=5/6) 134 | 135 | def get_retriever(to_check_key, start=0): 136 | cur_chunk = [0] 137 | chunk_numbers = np.array([0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5]) + start 138 | 139 | def retriever(key, buckets): 140 | assert key == to_check_key 141 | for i in range(batch_size): 142 | index, chunk = buckets[i].get_next_index_with_chunk() 143 | assert chunk.id == chunk_numbers[cur_chunk[0]] 144 | cur_chunk[0] += 1 145 | 146 | return retriever 147 |
148 | generator = TestBatchedGenerator(reader, seq_len, batch_size, get_retriever('train')) 149 | for data in generator.get_train_generator(): 150 | assert data[1].size()[0] == batch_size 151 | 152 | generator = TestBatchedGenerator(reader, seq_len, batch_size, get_retriever('validation', start=30)) 153 | for data in generator.get_validation_generator(): 154 | assert data[1].size()[0] == batch_size 155 | 156 | # generator = TestBatchedGenerator(reader, seq_len, batch_size, get_retriever('eval')) 157 | # for data in generator.get_eval_generator(): 158 | # assert data[1].size()[0] == batch_size 159 | 160 | 161 | # region Utils 162 | 163 | class TestBatchedGenerator(BatchedDataGenerator): 164 | 165 | def __init__(self, data_reader, seq_len, batch_size, retriever): 166 | super().__init__(data_reader, seq_len, batch_size, shuffle=False) 167 | self.retriever = retriever 168 | 169 | def _retrieve_batch(self, key, buckets): 170 | self.retriever(key, buckets) 171 | 172 | 173 | def create_test_data_pool(data_size, splits, split_coefficient=0.5, shuffle=False): 174 | np.random.seed(1) 175 | reader = create_test_data_reader(2 * data_size, data_size, split_coefficient=split_coefficient) 176 | return DataChunksPool(reader.train_data, splits=splits, shuffle=shuffle) 177 | 178 | 179 | def create_test_data_reader(train_length, test_length, split_coefficient=0.5): 180 | train_data = None 181 | test_data = None 182 | if train_length is not None: 183 | train_data = [create_test_data_chunk(100, i) for i in range(train_length)] 184 | if test_length is not None: 185 | test_data = [create_test_data_chunk(100, i) for i in range(test_length)] 186 | 187 | reader = DataReader() 188 | reader.train_data, reader.validation_data = split_train_validation(train_data, split_coefficient) 189 | reader.eval_data = test_data 190 | return reader 191 | 192 | 193 | class TestDataChunk(DataChunk): 194 | def __init__(self, data, id): 195 | self.id = id 196 | self.data = data 197 | self.seq_len = None 198 | 199 | def prepare_data(self, seq_len): 200 | self.data = self.data[:len(self.data) - (len(self.data) % seq_len)] 201 | 202 | def get_by_index(self, index): 203 | assert self.seq_len is not None 204 | assert len(self.data) % self.seq_len == 0 205 | assert (index < self.size()) 206 | return self.data[index] 207 | 208 | def size(self): 209 | return len(self.data) 210 | 211 | 212 | def create_test_data_chunk(length, _id): 213 | return TestDataChunk(np.arange(length), _id) 214 | 215 | # endregion 216 | -------------------------------------------------------------------------------- /zerogercrnn/test/lib/metrics_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from zerogercrnn.lib.metrics import TensorVisualizerMetrics, TensorVisualizer2DMetrics, TensorVisualizer3DMetrics 5 | from zerogercrnn.testutils.utils import assert_tensors_equal 6 | 7 | 8 | def test_tensor_visualizer_metrics(): 9 | metrics = TensorVisualizerMetrics(file=None) 10 | 11 | t1 = torch.from_numpy(np.array([[1, 2, 3], [4, 5, 6]])) 12 | t2 = torch.from_numpy(np.array([[3, 4, 5], [6, 7, 8]])) 13 | 14 | metrics.drop_state() 15 | metrics.report(t1) 16 | metrics.report(t2) 17 | 18 | expected = (t1 + t2) / 2 19 | 20 | assert_tensors_equal(metrics.get_current_value(), expected) 21 | 22 | 23 | def test_tensor_visualizer_metrics_random(): 24 | torch.manual_seed(123) 25 | metrics = TensorVisualizerMetrics(file=None) 26 | 27 | t1 = torch.randn((2, 3)) 28 | t2 = torch.randn((2, 3)) 29 | 30 | 
metrics.drop_state() 31 | metrics.report(t1) 32 | metrics.report(t2) 33 | 34 | expected = (t1 + t2) / 2 35 | 36 | assert_tensors_equal(metrics.get_current_value(), expected, eps=1e-6) 37 | 38 | 39 | def test_tensor_visualizer2d_metrics_dim0(): 40 | metrics = TensorVisualizer2DMetrics(dim=0, file=None) 41 | 42 | t1 = torch.from_numpy(np.array([[1, 2, 3], [4, 5, 6]])) 43 | t2 = torch.from_numpy(np.array([[3, 4, 5], [6, 7, 8]])) 44 | 45 | metrics.drop_state() 46 | metrics.report(t1) 47 | metrics.report(t2) 48 | 49 | expected = torch.sum(t1 + t2, dim=0).float() / 4 50 | 51 | assert_tensors_equal(metrics.get_current_value(), expected) 52 | 53 | 54 | def test_tensor_visualizer2d_metrics_dim0_random(): 55 | torch.manual_seed(100) 56 | metrics = TensorVisualizer2DMetrics(dim=0, file=None) 57 | 58 | t1 = torch.randn((2, 3)) 59 | t2 = torch.randn((2, 3)) 60 | 61 | metrics.drop_state() 62 | metrics.report(t1) 63 | metrics.report(t2) 64 | 65 | expected = torch.sum(t1 + t2, dim=0).float() / 4 66 | 67 | assert_tensors_equal(metrics.get_current_value(), expected, eps=1e-6) 68 | 69 | 70 | def test_tensor_visualizer2d_metrics_dim1(): 71 | metrics = TensorVisualizer2DMetrics(dim=1, file=None) 72 | 73 | t1 = torch.from_numpy(np.array([[1, 2, 3], [4, 5, 6]])) 74 | t2 = torch.from_numpy(np.array([[3, 4, 5], [6, 7, 8]])) 75 | 76 | metrics.drop_state() 77 | metrics.report(t1) 78 | metrics.report(t2) 79 | 80 | expected = torch.sum(t1 + t2, dim=1).float() / 6 81 | 82 | assert_tensors_equal(metrics.get_current_value(), expected) 83 | 84 | 85 | def test_tensor_visualizer2d_metrics_dim1_random(): 86 | torch.manual_seed(145) 87 | metrics = TensorVisualizer2DMetrics(dim=1, file=None) 88 | 89 | t1 = torch.randn((2, 3)) 90 | t2 = torch.randn((2, 3)) 91 | 92 | metrics.drop_state() 93 | metrics.report(t1) 94 | metrics.report(t2) 95 | 96 | expected = torch.sum(t1 + t2, dim=1).float() / 6 97 | 98 | assert_tensors_equal(metrics.get_current_value(), expected, eps=1e-6) 99 | 100 | 101 | def test_tensor_visualizer3d_metrics(): 102 | torch.manual_seed(2) 103 | metrics = TensorVisualizer3DMetrics(file=None) 104 | 105 | t1 = torch.from_numpy(np.array([ 106 | [ 107 | [1, 1, 1, 1, 1], 108 | [2, 2, 2, 2, 2], 109 | [3, 3, 3, 3, 3], 110 | [4, 4, 4, 4, 4], 111 | ], 112 | [ 113 | [1, 1, 1, 1, 8], 114 | [2, 10, 2, 4, 2], 115 | [3, 3, 3, 3, 3], 116 | [9, 4, 4, 4, 4], 117 | ], 118 | [ 119 | 120 | [1, 1, -10, 1, 9], 121 | [2, -1, 2, 5, 2], 122 | [0, 3, 7, 3, 3], 123 | [4, 2, 4, 4, 4], 124 | ] 125 | ])) 126 | t2 = torch.from_numpy(np.array([ 127 | [ 128 | [10, 1, 10, 1, 1], 129 | [2, 22, 28, 29, 2], 130 | [3, 34, 33, 3, 31], 131 | [4, 4, 4, 44, 4], 132 | ], 133 | [ 134 | [1, 11, 1, 1, 8], 135 | [2, 10, 2, 4, 2], 136 | [3, 3, 33, 3, 3], 137 | [9, 4, 4, 47, 40], 138 | ], 139 | [ 140 | 141 | [1000, 1, -10, 1, 91], 142 | [2, -111, 2, 5, 2], 143 | [0, 32, 7, 3, 3], 144 | [4, 24, 4, 4, 4], 145 | ] 146 | ])) 147 | 148 | metrics.drop_state() 149 | metrics.report(t1) 150 | metrics.report(t2) 151 | 152 | expected = (t1 + t2).sum(0).sum(0) / (12 * 2) 153 | 154 | assert_tensors_equal(metrics.get_current_value(), expected, eps=1e-6) 155 | 156 | 157 | def test_tensor_visualizer3d_metrics_random(): 158 | torch.manual_seed(2) 159 | metrics = TensorVisualizer3DMetrics(file=None) 160 | 161 | t1 = torch.randn((3, 4, 5)) 162 | t2 = torch.randn((3, 4, 5)) 163 | 164 | metrics.drop_state() 165 | metrics.report(t1) 166 | metrics.report(t2) 167 | 168 | expected = (t1 + t2).sum(0).sum(0) / (12 * 2) 169 | 170 | assert_tensors_equal(metrics.get_current_value(), expected, 
eps=1e-6) 171 | -------------------------------------------------------------------------------- /zerogercrnn/testutils/utils.py: -------------------------------------------------------------------------------- 1 | def assert_numbers_almost_equal(x, y, eps=1e-9): 2 | assert abs(x - y) < eps 3 | 4 | 5 | def assert_tensors_equal(t1, t2, eps=1e-9): 6 | assert t1.size() == t2.size() 7 | 8 | t1 = t1.view(-1) 9 | t2 = t2.view(-1) 10 | 11 | for i in range(t1.size()[0]): 12 | assert_numbers_almost_equal(t1[i].item(), t2[i].item(), eps=eps) 13 | --------------------------------------------------------------------------------
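# A small usage sketch of the assertion helpers above, runnable under pytest or
# directly; the tensors and numbers are dummy data.
import torch

def test_assert_helpers_demo():
    a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    assert_tensors_equal(a, a.clone())                     # identical tensors pass
    assert_numbers_almost_equal(0.1 + 0.2, 0.3, eps=1e-6)  # tolerant float compare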