├── .gitignore ├── LICENSE ├── README.md ├── configs ├── attn_lstm_vocab_10k.yml ├── attn_lstm_vocab_1k.yml ├── attn_lstm_vocab_50k.yml ├── label_smoothing_pointer_10k.yml ├── pointer_vocab_10k.yml ├── pointer_vocab_1k.yml ├── pointer_vocab_50k.yml ├── simple_lstm_vocab_10k.yml ├── simple_lstm_vocab_1k.yml └── simple_lstm_vocab_50k.yml ├── data.py ├── eval.ipynb ├── model.py ├── preprocess.py ├── preprocess_utils ├── freq_dict.py ├── get_non_terminal.py ├── get_terminal_dict.py ├── get_terminal_whole.py ├── get_total_length.py └── utils.py ├── run.sh ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *json* 2 | *ipynb_checkpoints* 3 | *__pycache__* 4 | *.log 5 | *.pickle 6 | *.txt 7 | *.pdf 8 | checkpoints 9 | logs 10 | *.out 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Oleg Desheulin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Code-Completion 2 | PyTorch version of [code completion with neural attention and pointer networks](https://arxiv.org/pdf/1711.09573.pdf) 3 | 4 | 5 | ## TODO list: 6 | - [ ] refactor preprocessing code 7 | - [ ] add Python-to-AST conversion code 8 | - [ ] config for preprocessing 9 | - [ ] joint training for type 10 | 11 | ## Requirements list: 12 | 13 | - python3 >= 3.6 14 | - torch >= 1.2 (or tensorboardX for earlier versions) 15 | - pyyaml 16 | 17 | 18 | ## Instructions: 19 | 20 | - run `python3 preprocess.py` for preprocessing 21 | - run `CUDA_VISIBLE_DEVICES=id python3 train.py --config=path/to/config.yml` for training with the specified config; the list of available configs can be found in the configs folder 22 | 23 | ## Results for Python value prediction (acc@1): 24 | 25 | 26 | model | vocab_size 1k | vocab_size 10k | vocab_size 50k 27 | --- | --- | --- | --- 28 | simple_lstm | 66.33 | 65.7 | 61.68, 1 epoch 29 | attn_lstm | 64.95 | 65.77 | 63.15, 1 epoch 30 | pointer_mixture | 66.62 | [67.05](https://www.dropbox.com/s/r69ksk7idd53s9n/epoch_0007.pth?dl=0) | [65.3, 3 epochs](https://www.dropbox.com/s/s40ruwonbeebpxm/epoch_0002.pth?dl=0) 31 | 32 | 33 | 34 | ## Examples: 35 | Examples of code generation will be added here. 36 | -------------------------------------------------------------------------------- /configs/attn_lstm_vocab_10k.yml: -------------------------------------------------------------------------------- 1 | name: attn_lstm_vocab_10k 2 | train: 3 | batch_size: 128 4 | LOAD_EPOCH: 4 5 | epochs: 8 6 | num_workers: 6 7 | eval_period: 1 8 | checkpoint_period: 1 9 | device: cuda 10 | lr: 0.001 11 | lr_decay: 0.6 12 | clip_value: 5 13 | model: 14 | hidden_size: 800 15 | embedding_sizeT: 512 16 | embedding_sizeN: 300 17 | dropout: 0.05 18 | num_layers: 1 19 | label_smoothing: 0 20 | pointer: False 21 | attn: True 22 | data: 23 | truncate_size: 50 24 | N_filename: ./pickle_data/PY_non_terminal_small.pickle 25 | T_filename: ./pickle_data/PY_terminal_10k_whole.pickle 26 | -------------------------------------------------------------------------------- /configs/attn_lstm_vocab_1k.yml: -------------------------------------------------------------------------------- 1 | name: attn_lstm_vocab_1k 2 | train: 3 | batch_size: 256 4 | LOAD_EPOCH: 2 5 | epochs: 5 6 | num_workers: 6 7 | eval_period: 1 8 | checkpoint_period: 1 9 | device: cuda 10 | lr: 0.001 11 | lr_decay: 0.6 12 | clip_value: 5 13 | model: 14 | hidden_size: 800 15 | embedding_sizeT: 512 16 | embedding_sizeN: 300 17 | dropout: 0.05 18 | num_layers: 1 19 | label_smoothing: 0 20 | pointer: False 21 | attn: True 22 | data: 23 | truncate_size: 50 24 | N_filename: ./pickle_data/PY_non_terminal_small.pickle 25 | T_filename: ./pickle_data/PY_terminal_1k_whole.pickle 26 | -------------------------------------------------------------------------------- /configs/attn_lstm_vocab_50k.yml: -------------------------------------------------------------------------------- 1 | name: attn_lstm_vocab_50k 2 | train: 3 | batch_size: 64 4 | LOAD_EPOCH: 5 | epochs: 1 6 | num_workers: 6 7 | eval_period: 1 8 | checkpoint_period: 1 9 | device: cuda 10 | lr: 0.001 11 | lr_decay: 0.6 12 | clip_value: 5 13 | model: 14 | hidden_size: 800 15 | embedding_sizeT: 512 16 | embedding_sizeN: 300 17 | dropout: 0 18 | num_layers: 1 19 | label_smoothing: 0 20 | pointer: False 21 | attn: True 22 | data: 23 |
truncate_size: 50 24 | N_filename: ./pickle_data/PY_non_terminal_small.pickle 25 | T_filename: ./pickle_data/PY_terminal_50k_whole.pickle 26 | -------------------------------------------------------------------------------- /configs/label_smoothing_pointer_10k.yml: -------------------------------------------------------------------------------- 1 | name: label_smoothing_pointer_10k 2 | train: 3 | batch_size: 256 4 | LOAD_EPOCH: 5 | epochs: 6 6 | num_workers: 6 7 | eval_period: 1 8 | checkpoint_period: 1 9 | device: cuda 10 | lr: 0.001 11 | lr_decay: 0.6 12 | clip_value: 5 13 | model: 14 | hidden_size: 800 15 | embedding_sizeT: 512 16 | embedding_sizeN: 300 17 | dropout: 0.1 18 | num_layers: 1 19 | label_smoothing: 0.1 20 | pointer: True 21 | attn: True 22 | data: 23 | truncate_size: 50 24 | N_filename: ./pickle_data/PY_non_terminal_small.pickle 25 | T_filename: ./pickle_data/PY_terminal_10k_whole.pickle 26 | -------------------------------------------------------------------------------- /configs/pointer_vocab_10k.yml: -------------------------------------------------------------------------------- 1 | name: pointer_vocab_10k 2 | train: 3 | batch_size: 256 4 | LOAD_EPOCH: 4 5 | epochs: 8 6 | num_workers: 6 7 | eval_period: 1 8 | checkpoint_period: 1 9 | device: cuda 10 | lr: 0.001 11 | lr_decay: 0.6 12 | clip_value: 5 13 | model: 14 | hidden_size: 800 15 | embedding_sizeT: 512 16 | embedding_sizeN: 300 17 | dropout: 0 18 | num_layers: 1 19 | label_smoothing: 0 20 | pointer: True 21 | attn: True 22 | data: 23 | truncate_size: 50 24 | N_filename: ./pickle_data/PY_non_terminal_small.pickle 25 | T_filename: ./pickle_data/PY_terminal_10k_whole.pickle 26 | -------------------------------------------------------------------------------- /configs/pointer_vocab_1k.yml: -------------------------------------------------------------------------------- 1 | name: pointer_vocab_1k 2 | train: 3 | batch_size: 512 4 | LOAD_EPOCH: 5 | epochs: 5 6 | num_workers: 6 7 | eval_period: 1 8 | checkpoint_period: 1 9 | device: cuda 10 | lr: 0.001 11 | lr_decay: 0.6 12 | clip_value: 5 13 | model: 14 | hidden_size: 800 15 | embedding_sizeT: 512 16 | embedding_sizeN: 300 17 | dropout: 0.05 18 | num_layers: 1 19 | label_smoothing: 0 20 | pointer: True 21 | attn: True 22 | data: 23 | truncate_size: 50 24 | N_filename: ./pickle_data/PY_non_terminal_small.pickle 25 | T_filename: ./pickle_data/PY_terminal_1k_whole.pickle 26 | -------------------------------------------------------------------------------- /configs/pointer_vocab_50k.yml: -------------------------------------------------------------------------------- 1 | name: pointer_vocab_50k 2 | train: 3 | batch_size: 32 4 | LOAD_EPOCH: 0 5 | epochs: 5 6 | num_workers: 6 7 | eval_period: 1 8 | checkpoint_period: 1 9 | device: cuda 10 | lr: 0.001 11 | lr_decay: 0.6 12 | clip_value: 5 13 | model: 14 | hidden_size: 800 15 | embedding_sizeT: 512 16 | embedding_sizeN: 300 17 | dropout: 0 18 | num_layers: 1 19 | label_smoothing: 0 20 | pointer: True 21 | attn: True 22 | data: 23 | truncate_size: 50 24 | N_filename: ./pickle_data/PY_non_terminal_small.pickle 25 | T_filename: ./pickle_data/PY_terminal_50k_whole.pickle 26 | -------------------------------------------------------------------------------- /configs/simple_lstm_vocab_10k.yml: -------------------------------------------------------------------------------- 1 | name: simple_lstm_vocab_10k 2 | train: 3 | batch_size: 256 4 | LOAD_EPOCH: 5 5 | epochs: 8 6 | num_workers: 6 7 | eval_period: 1 8 | checkpoint_period: 1 9 | 
device: cuda 10 | lr: 0.001 11 | lr_decay: 0.6 12 | clip_value: 5 13 | model: 14 | hidden_size: 800 15 | embedding_sizeT: 512 16 | embedding_sizeN: 300 17 | dropout: 0 18 | num_layers: 1 19 | label_smoothing: 0 20 | attn: False 21 | pointer: False 22 | data: 23 | truncate_size: 50 24 | N_filename: ./pickle_data/PY_non_terminal_small.pickle 25 | T_filename: ./pickle_data/PY_terminal_10k_whole.pickle 26 | -------------------------------------------------------------------------------- /configs/simple_lstm_vocab_1k.yml: -------------------------------------------------------------------------------- 1 | name: simple_lstm_vocab_1k 2 | train: 3 | batch_size: 1024 4 | LOAD_EPOCH: 5 | epochs: 5 6 | num_workers: 6 7 | eval_period: 1 8 | checkpoint_period: 1 9 | device: cuda 10 | lr: 0.001 11 | lr_decay: 0.6 12 | clip_value: 5 13 | model: 14 | hidden_size: 800 15 | embedding_sizeT: 512 16 | embedding_sizeN: 300 17 | dropout: 0 18 | num_layers: 1 19 | label_smoothing: 0 20 | attn: False 21 | pointer: False 22 | data: 23 | truncate_size: 50 24 | N_filename: ./pickle_data/PY_non_terminal_small.pickle 25 | T_filename: ./pickle_data/PY_terminal_1k_whole.pickle 26 | -------------------------------------------------------------------------------- /configs/simple_lstm_vocab_50k.yml: -------------------------------------------------------------------------------- 1 | name: simple_lstm_vocab_50k 2 | train: 3 | batch_size: 128 4 | LOAD_EPOCH: 5 | epochs: 1 6 | num_workers: 6 7 | eval_period: 1 8 | checkpoint_period: 1 9 | device: cuda 10 | lr: 0.001 11 | lr_decay: 0.6 12 | clip_value: 5 13 | model: 14 | hidden_size: 800 15 | embedding_sizeT: 512 16 | embedding_sizeN: 300 17 | dropout: 0 18 | num_layers: 1 19 | label_smoothing: 0 20 | attn: False 21 | pointer: False 22 | data: 23 | truncate_size: 50 24 | N_filename: ./pickle_data/PY_non_terminal_small.pickle 25 | T_filename: ./pickle_data/PY_terminal_50k_whole.pickle 26 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from preprocess_utils.utils import * 3 | from tqdm import tqdm 4 | 5 | def fix_parent(p, start_i): 6 | p -= start_i 7 | return 0 if p < 0 else p 8 | 9 | def data_gen(data, split_size): 10 | for sample in data: 11 | accum_n = [] 12 | accum_t = [] 13 | accum_p = [] 14 | start_i = 0 15 | for i, item in enumerate(zip(*sample)): 16 | n, t, p = item 17 | p = fix_parent(p, start_i) 18 | accum_n.append(n) 19 | accum_t.append(t) 20 | accum_p.append(p) 21 | if len(accum_n) == split_size: 22 | yield accum_n, accum_t, accum_p 23 | accum_n = [] 24 | accum_t = [] 25 | accum_p = [] 26 | start_i = i 27 | if len(accum_n) > 0: 28 | yield accum_n, accum_t, accum_p 29 | 30 | class MainDataset(torch.utils.data.Dataset): 31 | def __init__(self, 32 | N_filename = './pickle_data/PY_non_terminal_small.pickle', 33 | T_filename = './pickle_data/PY_terminal_10k_whole.pickle', 34 | is_train=False, 35 | truncate_size=150 36 | ): 37 | super(MainDataset).__init__() 38 | train_dataN, test_dataN, vocab_sizeN, train_dataT, test_dataT, vocab_sizeT, attn_size, train_dataP, test_dataP = input_data( 39 | N_filename, T_filename 40 | ) 41 | self.is_train = is_train 42 | if self.is_train: 43 | self.data = [item for item in data_gen(zip(tqdm(train_dataN), train_dataT, train_dataP), truncate_size)] 44 | else: 45 | self.data = [item for item in data_gen(zip(tqdm(test_dataN), test_dataT, test_dataP), truncate_size)] 46 | self.data = 
sorted(self.data, key=lambda x: len(x[0])) 47 | self.vocab_sizeN = vocab_sizeN 48 | self.vocab_sizeT = vocab_sizeT 49 | self.attn_size = attn_size 50 | self.eof_N_id = vocab_sizeN - 1 51 | self.eof_T_id = vocab_sizeT - 1 52 | self.unk_id = vocab_sizeT - 2 53 | self.truncate_size = truncate_size 54 | 55 | def __len__(self): 56 | return len(self.data) 57 | 58 | def __getitem__(self, idx): 59 | item = self.data[idx] 60 | return item 61 | 62 | def collate_fn(self, samples, device='cpu'): 63 | sent_N = [sample[0] for sample in samples] 64 | sent_T = [sample[1] for sample in samples] 65 | sent_P = [sample[2] for sample in samples] 66 | 67 | s_max_length = max(map(lambda x: len(x), sent_N)) 68 | 69 | sent_N_tensors = [] 70 | sent_T_tensors = [] 71 | sent_P_tensors = [] 72 | 73 | for sn, st, sp in zip(sent_N, sent_T, sent_P): 74 | sn_tensor = torch.ones( 75 | s_max_length 76 | , dtype=torch.long 77 | , device=device 78 | ) * self.eof_N_id 79 | 80 | st_tensor = torch.ones( 81 | s_max_length 82 | , dtype=torch.long 83 | , device=device 84 | ) * self.eof_T_id 85 | 86 | sp_tensor = torch.ones( 87 | s_max_length 88 | , dtype=torch.long 89 | , device=device 90 | ) * 1 91 | 92 | for idx, w in enumerate(sn): 93 | sn_tensor[idx] = w 94 | st_tensor[idx] = st[idx] 95 | sp_tensor[idx] = sp[idx] 96 | sent_N_tensors.append(sn_tensor.unsqueeze(0)) 97 | sent_T_tensors.append(st_tensor.unsqueeze(0)) 98 | sent_P_tensors.append(sp_tensor.unsqueeze(0)) 99 | 100 | sent_N_tensors = torch.cat(sent_N_tensors, dim=0) 101 | sent_T_tensors = torch.cat(sent_T_tensors, dim=0) 102 | sent_P_tensors = torch.cat(sent_P_tensors, dim=0) 103 | 104 | return sent_N_tensors, sent_T_tensors, sent_P_tensors -------------------------------------------------------------------------------- /eval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from model import *\n", 10 | "from data import *\n", 11 | "import os\n", 12 | "from tqdm import tqdm\n", 13 | "import yaml\n", 14 | "from utils import DotDict, adjust_learning_rate, accuracy\n", 15 | "import torch\n", 16 | "import traceback" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "losses = {'attn_lstm_vocab_10k': 10165.64429010069,\n", 26 | " 'attn_lstm_vocab_1k': 15595.033411789249,\n", 27 | " 'attn_lstm_vocab_50k': 6836.341175179966,\n", 28 | " 'pointer_vocab_10k': 21672.131244335622,\n", 29 | " 'pointer_vocab_1k': 43398.31753725069,\n", 30 | " 'pointer_vocab_50k': 7057.6959175422135,\n", 31 | " 'simple_lstm_vocab_10k': 20159.74126170967,\n", 32 | " 'simple_lstm_vocab_1k': 109026.57322798869,\n", 33 | " 'simple_lstm_vocab_50k': 13842.316331750844}\n", 34 | "\n", 35 | "\n", 36 | "epoch = {'attn_lstm_vocab_10k': 8,\n", 37 | " 'attn_lstm_vocab_1k': 5,\n", 38 | " 'attn_lstm_vocab_50k': 1,\n", 39 | " 'pointer_vocab_10k': 8,\n", 40 | " 'pointer_vocab_1k': 5,\n", 41 | " 'pointer_vocab_50k': 1,\n", 42 | " 'simple_lstm_vocab_10k': 8,\n", 43 | " 'simple_lstm_vocab_1k': 5,\n", 44 | " 'simple_lstm_vocab_50k': 1}\n", 45 | "\n", 46 | "acces = {'attn_lstm_vocab_10k': 0.6577360621117113,\n", 47 | " 'attn_lstm_vocab_1k': 0.6494665779492861,\n", 48 | " 'attn_lstm_vocab_50k': 0.6315070176565867,\n", 49 | " 'pointer_vocab_10k': 0.6705065943604063,\n", 50 | " 'pointer_vocab_1k': 0.6662379230942852,\n", 51 | " 'pointer_vocab_50k': 
0.634914800961635,\n", 52 | " 'simple_lstm_vocab_10k': 0.657081022143478,\n", 53 | " 'simple_lstm_vocab_1k': 0.6633081460420689,\n", 54 | " 'simple_lstm_vocab_50k': 0.6168174382832627}" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 15, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "started label_smoothing_pointer_10k\n", 67 | "started pointer_vocab_10k\n", 68 | "already calculated\n", 69 | "started simple_lstm_vocab_10k\n", 70 | "already calculated\n", 71 | "started attn_lstm_vocab_50k\n", 72 | "already calculated\n", 73 | "started simple_lstm_vocab_1k\n", 74 | "reading data from ./pickle_data/PY_non_terminal_small.pickle\n" 75 | ] 76 | }, 77 | { 78 | "name": "stderr", 79 | "output_type": "stream", 80 | "text": [ 81 | "Traceback (most recent call last):\n", 82 | " File \"\", line 13, in \n", 83 | " last_cpk = sorted(os.listdir(checkpoint_folder), key=lambda x: int(x[6:-4]), reverse=True)[0]\n", 84 | "FileNotFoundError: [Errno 2] No such file or directory: 'checkpoints/label_smoothing_pointer_10k'\n" 85 | ] 86 | }, 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "the vocab_sizeN is 330 (not including the eof)\n", 92 | "the number of training data is 100000\n", 93 | "the number of test data is 50000\n", 94 | "\n", 95 | "reading data from ./pickle_data/PY_terminal_1k_whole.pickle\n", 96 | "the vocab_sizeT is 1000 (not including the unk and eof)\n", 97 | "the attn_size is 50\n", 98 | "the number of training data is 100000\n", 99 | "the number of test data is 50000\n", 100 | "Finish reading data and take 13.77\n", 101 | "\n" 102 | ] 103 | }, 104 | { 105 | "name": "stderr", 106 | "output_type": "stream", 107 | "text": [ 108 | "100%|██████████| 50000/50000 [00:33<00:00, 1501.39it/s]\n", 109 | "100%|██████████| 619/619 [04:29<00:00, 2.65it/s]\n" 110 | ] 111 | }, 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "loss: 109026.57322798869 acc: 0.6633081460420689\n", 117 | "started attn_lstm_vocab_10k\n", 118 | "already calculated\n", 119 | "started pointer_vocab_50k\n", 120 | "already calculated\n", 121 | "started pointer_vocab_1k\n", 122 | "already calculated\n", 123 | "started attn_lstm_vocab_1k\n", 124 | "already calculated\n", 125 | "started simple_lstm_vocab_50k\n", 126 | "already calculated\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "for config in os.listdir('configs'):\n", 132 | " if not config.endswith('yml'):\n", 133 | " continue\n", 134 | " config = os.path.join('configs', config)\n", 135 | " with open(config, 'r') as f:\n", 136 | " config = DotDict(yaml.safe_load(f))\n", 137 | " print('started', config.name)\n", 138 | " if config.name in acces:\n", 139 | " print('already calculated')\n", 140 | " continue\n", 141 | " checkpoint_folder = os.path.join('checkpoints', config.name)\n", 142 | " try:\n", 143 | " last_cpk = sorted(os.listdir(checkpoint_folder), key=lambda x: int(x[6:-4]), reverse=True)[0]\n", 144 | " except:\n", 145 | " traceback.print_exc()\n", 146 | " continue\n", 147 | " checkpoint_path = os.path.join(checkpoint_folder, last_cpk)\n", 148 | " \n", 149 | " device = config.train.device\n", 150 | "\n", 151 | " data_val = MainDataset(\n", 152 | " N_filename = config.data.N_filename,\n", 153 | " T_filename = config.data.T_filename,\n", 154 | " is_train=False,\n", 155 | " truncate_size=config.data.truncate_size\n", 156 | " )\n", 157 | "\n", 158 | " test_loader = torch.utils.data.DataLoader(\n", 159 | " 
data_val,\n", 160 | " batch_size=config.train.batch_size,\n", 161 | " shuffle=False,\n", 162 | " num_workers=config.train.num_workers,\n", 163 | " collate_fn=data_val.collate_fn\n", 164 | " )\n", 165 | " \n", 166 | " ignored_index = data_val.vocab_sizeT - 1\n", 167 | " unk_index = data_val.vocab_sizeT - 2\n", 168 | " \n", 169 | " model = MixtureAttention(\n", 170 | " hidden_size = config.model.hidden_size,\n", 171 | " vocab_sizeT = data_val.vocab_sizeT,\n", 172 | " vocab_sizeN = data_val.vocab_sizeN,\n", 173 | " attn_size = data_val.attn_size,\n", 174 | " embedding_sizeT = config.model.embedding_sizeT,\n", 175 | " embedding_sizeN = config.model.embedding_sizeN,\n", 176 | " num_layers = 1,\n", 177 | " dropout = config.model.dropout,\n", 178 | " label_smoothing = config.model.label_smoothing,\n", 179 | " pointer = config.model.pointer,\n", 180 | " attn = config.model.attn,\n", 181 | " device = device\n", 182 | " )\n", 183 | " cpk = torch.load(checkpoint_path)\n", 184 | " model.load_state_dict(cpk['model'])\n", 185 | " model = model.to(device)\n", 186 | " with torch.no_grad():\n", 187 | " model = model.eval()\n", 188 | " acc = 0.\n", 189 | " loss_eval = 0.\n", 190 | " for i, (n, t, p) in enumerate(tqdm(test_loader)):\n", 191 | " n, t, p = n.to(device), t.to(device), p.to(device)\n", 192 | " loss, ans = model(n, t, p)\n", 193 | " loss_eval += loss.item()\n", 194 | " acc += accuracy(ans.cpu().numpy().flatten(), t.cpu().numpy().flatten(), ignored_index, unk_index)\n", 195 | " acc /= len(test_loader)\n", 196 | " loss_eval /= len(test_loader)\n", 197 | " losses[config.name] = loss_eval\n", 198 | " acces[config.name] = acc\n", 199 | " print('loss:', losses[config.name], 'acc:', acces[config.name])\n", 200 | " torch.cuda.empty_cache()" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 3, 206 | "metadata": {}, 207 | "outputs": [ 208 | { 209 | "name": "stderr", 210 | "output_type": "stream", 211 | "text": [ 212 | "/data/anaconda/envs/py35/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n", 213 | " return f(*args, **kwds)\n", 214 | "/data/anaconda/envs/py35/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. 
Expected 96, got 88\n", 215 | " return f(*args, **kwds)\n" 216 | ] 217 | } 218 | ], 219 | "source": [ 220 | "import pandas as pd" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 8, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "def make_table(dic, m = 1):\n", 230 | " data = pd.DataFrame(columns=['1k', '10k', '50k'], index=['simple_lstm', 'attn_lstm', 'pointer'])\n", 231 | " for item in dic:\n", 232 | " id_ = int([i for i, j in enumerate(item.split('_')) if j == 'vocab'][0])\n", 233 | " name = '_'.join(item.split('_')[:id_])\n", 234 | " vocab = item.split('_')[-1]\n", 235 | " data.loc[name, vocab] = dic[item]\n", 236 | " return data * m" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 5, 242 | "metadata": {}, 243 | "outputs": [ 244 | { 245 | "data": { 246 | "text/plain": [ 247 | "{'attn_lstm_vocab_10k': 0.6577360621117113,\n", 248 | " 'attn_lstm_vocab_1k': 0.6494665779492861,\n", 249 | " 'attn_lstm_vocab_50k': 0.6315070176565867,\n", 250 | " 'pointer_vocab_10k': 0.6705065943604063,\n", 251 | " 'pointer_vocab_1k': 0.6662379230942852,\n", 252 | " 'pointer_vocab_50k': 0.634914800961635,\n", 253 | " 'simple_lstm_vocab_10k': 0.657081022143478,\n", 254 | " 'simple_lstm_vocab_1k': 0.6633081460420689,\n", 255 | " 'simple_lstm_vocab_50k': 0.6168174382832627}" 256 | ] 257 | }, 258 | "execution_count": 5, 259 | "metadata": {}, 260 | "output_type": "execute_result" 261 | } 262 | ], 263 | "source": [ 264 | "acces" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 6, 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "data": { 274 | "text/plain": [ 275 | "{'attn_lstm_vocab_10k': 10165.64429010069,\n", 276 | " 'attn_lstm_vocab_1k': 15595.033411789249,\n", 277 | " 'attn_lstm_vocab_50k': 6836.341175179966,\n", 278 | " 'pointer_vocab_10k': 21672.131244335622,\n", 279 | " 'pointer_vocab_1k': 43398.31753725069,\n", 280 | " 'pointer_vocab_50k': 7057.6959175422135,\n", 281 | " 'simple_lstm_vocab_10k': 20159.74126170967,\n", 282 | " 'simple_lstm_vocab_1k': 109026.57322798869,\n", 283 | " 'simple_lstm_vocab_50k': 13842.316331750844}" 284 | ] 285 | }, 286 | "execution_count": 6, 287 | "metadata": {}, 288 | "output_type": "execute_result" 289 | } 290 | ], 291 | "source": [ 292 | "losses" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 9, 298 | "metadata": {}, 299 | "outputs": [ 300 | { 301 | "data": { 302 | "text/html": [ 303 | "
\n", 304 | "\n", 317 | "\n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | "
1k10k50k
simple_lstm66.330865.708161.6817
attn_lstm64.946765.773663.1507
pointer66.623867.050763.4915
\n", 347 | "
" 348 | ], 349 | "text/plain": [ 350 | " 1k 10k 50k\n", 351 | "simple_lstm 66.3308 65.7081 61.6817\n", 352 | "attn_lstm 64.9467 65.7736 63.1507\n", 353 | "pointer 66.6238 67.0507 63.4915" 354 | ] 355 | }, 356 | "execution_count": 9, 357 | "metadata": {}, 358 | "output_type": "execute_result" 359 | } 360 | ], 361 | "source": [ 362 | "make_table(acces, m = 100)" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 10, 368 | "metadata": {}, 369 | "outputs": [ 370 | { 371 | "data": { 372 | "text/html": [ 373 | "
\n", 374 | "\n", 387 | "\n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | "
1k10k50k
simple_lstm10902720159.713842.3
attn_lstm1559510165.66836.34
pointer43398.321672.17057.7
\n", 417 | "
" 418 | ], 419 | "text/plain": [ 420 | " 1k 10k 50k\n", 421 | "simple_lstm 109027 20159.7 13842.3\n", 422 | "attn_lstm 15595 10165.6 6836.34\n", 423 | "pointer 43398.3 21672.1 7057.7" 424 | ] 425 | }, 426 | "execution_count": 10, 427 | "metadata": {}, 428 | "output_type": "execute_result" 429 | } 430 | ], 431 | "source": [ 432 | "make_table(losses)" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": 11, 438 | "metadata": {}, 439 | "outputs": [ 440 | { 441 | "data": { 442 | "text/html": [ 443 | "
\n", 444 | "\n", 457 | "\n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | "
1k10k50k
simple_lstm581
attn_lstm581
pointer581
\n", 487 | "
" 488 | ], 489 | "text/plain": [ 490 | " 1k 10k 50k\n", 491 | "simple_lstm 5 8 1\n", 492 | "attn_lstm 5 8 1\n", 493 | "pointer 5 8 1" 494 | ] 495 | }, 496 | "execution_count": 11, 497 | "metadata": {}, 498 | "output_type": "execute_result" 499 | } 500 | ], 501 | "source": [ 502 | "make_table(epoch)" 503 | ] 504 | }, 505 | { 506 | "cell_type": "markdown", 507 | "metadata": {}, 508 | "source": [ 509 | "### Eval from an example" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": 12, 515 | "metadata": {}, 516 | "outputs": [], 517 | "source": [ 518 | "from model import *\n", 519 | "from data import *\n", 520 | "import os\n", 521 | "from tqdm import tqdm\n", 522 | "import yaml\n", 523 | "from utils import DotDict, adjust_learning_rate, accuracy\n", 524 | "import torch\n", 525 | "import traceback" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 13, 531 | "metadata": {}, 532 | "outputs": [], 533 | "source": [ 534 | "os.environ['CUDA_VISIBLE_DEVICES'] = ''" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": 16, 540 | "metadata": {}, 541 | "outputs": [ 542 | { 543 | "name": "stdout", 544 | "output_type": "stream", 545 | "text": [ 546 | "started pointer_vocab_10k\n", 547 | "reading data from ./pickle_data/PY_non_terminal_small.pickle\n", 548 | "the vocab_sizeN is 330 (not including the eof)\n", 549 | "the number of training data is 100000\n", 550 | "the number of test data is 50000\n", 551 | "\n", 552 | "reading data from ./pickle_data/PY_terminal_10k_whole.pickle\n", 553 | "the vocab_sizeT is 10000 (not including the unk and eof)\n", 554 | "the attn_size is 50\n", 555 | "the number of training data is 100000\n", 556 | "the number of test data is 50000\n", 557 | "Finish reading data and take 13.11\n", 558 | "\n" 559 | ] 560 | }, 561 | { 562 | "name": "stderr", 563 | "output_type": "stream", 564 | "text": [ 565 | "100%|██████████| 50000/50000 [00:33<00:00, 1504.67it/s]\n" 566 | ] 567 | }, 568 | { 569 | "ename": "RuntimeError", 570 | "evalue": "Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. 
If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.", 571 | "output_type": "error", 572 | "traceback": [ 573 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 574 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", 575 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m )\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0mcpk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcheckpoint_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcpk\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'model'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 576 | "\u001b[0;32m/data/anaconda/envs/py35/lib/python3.5/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m 384\u001b[0m \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 385\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 386\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 387\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnew_fd\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 577 | "\u001b[0;32m/data/anaconda/envs/py35/lib/python3.5/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m_load\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[0munpickler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnpickler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 572\u001b[0m \u001b[0munpickler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpersistent_load\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpersistent_load\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 573\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0munpickler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 574\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 
575\u001b[0m \u001b[0mdeserialized_storage_keys\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 578 | "\u001b[0;32m/data/anaconda/envs/py35/lib/python3.5/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mpersistent_load\u001b[0;34m(saved_id)\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mobj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 535\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_torch_load_uninitialized\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 536\u001b[0;31m \u001b[0mdeserialized_objects\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mroot_key\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrestore_location\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 537\u001b[0m \u001b[0mstorage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeserialized_objects\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mroot_key\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 538\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mview_metadata\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 579 | "\u001b[0;32m/data/anaconda/envs/py35/lib/python3.5/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mdefault_restore_location\u001b[0;34m(storage, location)\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdefault_restore_location\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstorage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_package_registry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 119\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstorage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 120\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 580 | "\u001b[0;32m/data/anaconda/envs/py35/lib/python3.5/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m_cuda_deserialize\u001b[0;34m(obj, location)\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_cuda_deserialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mlocation\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstartswith\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'cuda'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 95\u001b[0;31m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalidate_cuda_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 96\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"_torch_load_uninitialized\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0mstorage_type\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 581 | "\u001b[0;32m/data/anaconda/envs/py35/lib/python3.5/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mvalidate_cuda_device\u001b[0;34m(location)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_available\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 79\u001b[0;31m raise RuntimeError('Attempting to deserialize object on a CUDA '\n\u001b[0m\u001b[1;32m 80\u001b[0m \u001b[0;34m'device but torch.cuda.is_available() is False. '\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;34m'If you are running on a CPU-only machine, '\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 582 | "\u001b[0;31mRuntimeError\u001b[0m: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU." 
583 | ] 584 | } 585 | ], 586 | "source": [ 587 | "config = 'configs/pointer_vocab_10k.yml'\n", 588 | "with open(config, 'r') as f:\n", 589 | " config = DotDict(yaml.safe_load(f))\n", 590 | "print('started', config.name)\n", 591 | "checkpoint_folder = os.path.join('checkpoints', config.name)\n", 592 | "last_cpk = sorted(os.listdir(checkpoint_folder), key=lambda x: int(x[6:-4]), reverse=True)[0]\n", 593 | "checkpoint_path = os.path.join(checkpoint_folder, last_cpk)\n", 594 | "\n", 595 | "device = 'cpu'\n", 596 | "\n", 597 | "data_val = MainDataset(\n", 598 | " N_filename = config.data.N_filename,\n", 599 | " T_filename = config.data.T_filename,\n", 600 | " is_train=False,\n", 601 | " truncate_size=config.data.truncate_size\n", 602 | ")\n", 603 | "\n", 604 | "test_loader = torch.utils.data.DataLoader(\n", 605 | " data_val,\n", 606 | " batch_size=config.train.batch_size,\n", 607 | " shuffle=False,\n", 608 | " num_workers=config.train.num_workers,\n", 609 | " collate_fn=data_val.collate_fn\n", 610 | ")\n", 611 | "\n", 612 | "ignored_index = data_val.vocab_sizeT - 1\n", 613 | "unk_index = data_val.vocab_sizeT - 2\n", 614 | "\n", 615 | "model = MixtureAttention(\n", 616 | " hidden_size = config.model.hidden_size,\n", 617 | " vocab_sizeT = data_val.vocab_sizeT,\n", 618 | " vocab_sizeN = data_val.vocab_sizeN,\n", 619 | " attn_size = data_val.attn_size,\n", 620 | " embedding_sizeT = config.model.embedding_sizeT,\n", 621 | " embedding_sizeN = config.model.embedding_sizeN,\n", 622 | " num_layers = 1,\n", 623 | " dropout = config.model.dropout,\n", 624 | " label_smoothing = config.model.label_smoothing,\n", 625 | " pointer = config.model.pointer,\n", 626 | " attn = config.model.attn,\n", 627 | " device = device\n", 628 | ")" 629 | ] 630 | }, 631 | { 632 | "cell_type": "code", 633 | "execution_count": 17, 634 | "metadata": {}, 635 | "outputs": [], 636 | "source": [ 637 | "cpk = torch.load(checkpoint_path, map_location=torch.device('cpu'))\n", 638 | "model.load_state_dict(cpk['model'])\n", 639 | "model = model.to(device)" 640 | ] 641 | }, 642 | { 643 | "cell_type": "code", 644 | "execution_count": 24, 645 | "metadata": {}, 646 | "outputs": [ 647 | { 648 | "name": "stdout", 649 | "output_type": "stream", 650 | "text": [ 651 | "reading data from ./pickle_data/PY_non_terminal_small.pickle\n", 652 | "the vocab_sizeN is 330 (not including the eof)\n", 653 | "the number of training data is 100000\n", 654 | "the number of test data is 50000\n", 655 | "\n", 656 | "reading data from ./pickle_data/PY_terminal_10k_whole.pickle\n", 657 | "the vocab_sizeT is 10000 (not including the unk and eof)\n", 658 | "the attn_size is 50\n", 659 | "the number of training data is 100000\n", 660 | "the number of test data is 50000\n", 661 | "Finish reading data and take 9.66\n", 662 | "\n" 663 | ] 664 | } 665 | ], 666 | "source": [ 667 | "train_dataN, test_dataN, vocab_sizeN, train_dataT, test_dataT, vocab_sizeT, attn_size, train_dataP, test_dataP = input_data(\n", 668 | " config.data.N_filename, config.data.T_filename\n", 669 | ")" 670 | ] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "execution_count": 49, 675 | "metadata": {}, 676 | "outputs": [], 677 | "source": [ 678 | "f = open('pickle_data/terminal_dict_10k_PY.pickle', 'rb')\n", 679 | "t_dict = pickle.load(f)\n", 680 | "f.close()" 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": 55, 686 | "metadata": {}, 687 | "outputs": [ 688 | { 689 | "data": { 690 | "text/plain": [ 691 | "10000" 692 | ] 693 | }, 694 | "execution_count": 55, 695 | 
"metadata": {}, 696 | "output_type": "execute_result" 697 | } 698 | ], 699 | "source": [ 700 | "len(t_dict['terminal_dict'])" 701 | ] 702 | }, 703 | { 704 | "cell_type": "code", 705 | "execution_count": 56, 706 | "metadata": {}, 707 | "outputs": [], 708 | "source": [ 709 | "t_reversed = {val: key for key, val in t_dict['terminal_dict'].items()}" 710 | ] 711 | }, 712 | { 713 | "cell_type": "code", 714 | "execution_count": 74, 715 | "metadata": {}, 716 | "outputs": [], 717 | "source": [ 718 | "def decode(arr):\n", 719 | " return [t_reversed[item.item()] if item.item() < 10000 else item.item() - 10000 for item in arr]" 720 | ] 721 | }, 722 | { 723 | "cell_type": "code", 724 | "execution_count": 70, 725 | "metadata": {}, 726 | "outputs": [], 727 | "source": [ 728 | "sample_n = torch.tensor(test_dataN[5:6])\n", 729 | "sample_t = torch.tensor(test_dataT[5:6])\n", 730 | "sample_p = torch.tensor(test_dataP[5:6])" 731 | ] 732 | }, 733 | { 734 | "cell_type": "code", 735 | "execution_count": 71, 736 | "metadata": {}, 737 | "outputs": [], 738 | "source": [ 739 | "loss, ans = model(sample_n, sample_t, sample_p)" 740 | ] 741 | }, 742 | { 743 | "cell_type": "code", 744 | "execution_count": 75, 745 | "metadata": {}, 746 | "outputs": [ 747 | { 748 | "data": { 749 | "text/plain": [ 750 | "['',\n", 751 | " 'enum',\n", 752 | " 0,\n", 753 | " 'component',\n", 754 | " 'Component',\n", 755 | " 'object',\n", 756 | " 'field',\n", 757 | " 0,\n", 758 | " '',\n", 759 | " 8,\n", 760 | " '',\n", 761 | " '',\n", 762 | " 0,\n", 763 | " '0',\n", 764 | " '',\n", 765 | " 0,\n", 766 | " '1',\n", 767 | " '',\n", 768 | " 0,\n", 769 | " '2',\n", 770 | " '',\n", 771 | " 0,\n", 772 | " '3',\n", 773 | " '',\n", 774 | " 0,\n", 775 | " '',\n", 776 | " 18,\n", 777 | " '',\n", 778 | " '',\n", 779 | " 18,\n", 780 | " '0',\n", 781 | " '',\n", 782 | " 0,\n", 783 | " '1',\n", 784 | " '',\n", 785 | " 0,\n", 786 | " '2',\n", 787 | " '',\n", 788 | " 0,\n", 789 | " '3',\n", 790 | " '',\n", 791 | " 0,\n", 792 | " '',\n", 793 | " 'Component',\n", 794 | " '',\n", 795 | " '',\n", 796 | " 'enabled',\n", 797 | " '',\n", 798 | " 'field',\n", 799 | " 0,\n", 800 | " 'bool',\n", 801 | " '',\n", 802 | " 0,\n", 803 | " '',\n", 804 | " 'field',\n", 805 | " 0,\n", 806 | " '',\n", 807 | " 'materials',\n", 808 | " '',\n", 809 | " 'field',\n", 810 | " 0,\n", 811 | " '',\n", 812 | " 0,\n", 813 | " '',\n", 814 | " 'field',\n", 815 | " 0,\n", 816 | " '',\n", 817 | " 0,\n", 818 | " '',\n", 819 | " 'field',\n", 820 | " 0,\n", 821 | " 'bool',\n", 822 | " '',\n", 823 | " 0,\n", 824 | " '',\n", 825 | " 'field',\n", 826 | " 0,\n", 827 | " 0,\n", 828 | " '',\n", 829 | " 0,\n", 830 | " '',\n", 831 | " 'field',\n", 832 | " 0,\n", 833 | " 0,\n", 834 | " '',\n", 835 | " 0,\n", 836 | " '',\n", 837 | " 'field',\n", 838 | " 0,\n", 839 | " '',\n", 840 | " 0,\n", 841 | " '',\n", 842 | " 'field',\n", 843 | " 0,\n", 844 | " '',\n", 845 | " 0,\n", 846 | " '',\n", 847 | " 'field',\n", 848 | " 0,\n", 849 | " 'bool',\n", 850 | " '',\n", 851 | " 0,\n", 852 | " '',\n", 853 | " 'field',\n", 854 | " 0,\n", 855 | " '',\n", 856 | " 0,\n", 857 | " '',\n", 858 | " 'field',\n", 859 | " 0,\n", 860 | " '',\n", 861 | " 0,\n", 862 | " '',\n", 863 | " 'field',\n", 864 | " 0,\n", 865 | " '',\n", 866 | " 0,\n", 867 | " '',\n", 868 | " 'field',\n", 869 | " 0,\n", 870 | " '',\n", 871 | " 0,\n", 872 | " '',\n", 873 | " 'field',\n", 874 | " 0,\n", 875 | " 'material',\n", 876 | " '',\n", 877 | " '',\n", 878 | " 'self',\n", 879 | " '',\n", 880 | " '',\n", 881 | " '',\n", 882 | " '',\n", 883 | " '',\n", 
884 | " 'self',\n", 885 | " 'materials',\n", 886 | " '',\n", 887 | " '0',\n", 888 | " '',\n", 889 | " 'property',\n", 890 | " '',\n", 891 | " 0,\n", 892 | " '',\n", 893 | " 0,\n", 894 | " '',\n", 895 | " '',\n", 896 | " 0,\n", 897 | " '0',\n", 898 | " '',\n", 899 | " 0,\n", 900 | " '1',\n", 901 | " '',\n", 902 | " 0,\n", 903 | " '2',\n", 904 | " '',\n", 905 | " 0,\n", 906 | " '3',\n", 907 | " '',\n", 908 | " 0,\n", 909 | " '4',\n", 910 | " '',\n", 911 | " 0,\n", 912 | " '',\n", 913 | " 21,\n", 914 | " '',\n", 915 | " '',\n", 916 | " 0,\n", 917 | " '0',\n", 918 | " '',\n", 919 | " 0,\n", 920 | " '1',\n", 921 | " '',\n", 922 | " 0,\n", 923 | " '2',\n", 924 | " '',\n", 925 | " 0,\n", 926 | " '3',\n", 927 | " '',\n", 928 | " 0,\n", 929 | " '',\n", 930 | " 'Component',\n", 931 | " '',\n", 932 | " '',\n", 933 | " '',\n", 934 | " 0,\n", 935 | " '',\n", 936 | " 0,\n", 937 | " '',\n", 938 | " '',\n", 939 | " 0,\n", 940 | " '',\n", 941 | " 'field',\n", 942 | " 0,\n", 943 | " '',\n", 944 | " 0,\n", 945 | " '',\n", 946 | " 'field',\n", 947 | " 0,\n", 948 | " '',\n", 949 | " 0,\n", 950 | " '',\n", 951 | " 'field',\n", 952 | " 0,\n", 953 | " '',\n", 954 | " 0,\n", 955 | " '',\n", 956 | " 'field',\n", 957 | " 0,\n", 958 | " '',\n", 959 | " 0,\n", 960 | " '',\n", 961 | " 'field',\n", 962 | " 0,\n", 963 | " '',\n", 964 | " 0,\n", 965 | " '',\n", 966 | " 'field',\n", 967 | " 0,\n", 968 | " '',\n", 969 | " 0,\n", 970 | " '',\n", 971 | " 36,\n", 972 | " '',\n", 973 | " '',\n", 974 | " 36,\n", 975 | " '',\n", 976 | " 'field',\n", 977 | " 36,\n", 978 | " '',\n", 979 | " 36,\n", 980 | " '',\n", 981 | " 'field',\n", 982 | " 36,\n", 983 | " '',\n", 984 | " 36,\n", 985 | " '',\n", 986 | " 'field',\n", 987 | " 36,\n", 988 | " '',\n", 989 | " 'mesh',\n", 990 | " '',\n", 991 | " 'field',\n", 992 | " 0,\n", 993 | " '',\n", 994 | " 0,\n", 995 | " '',\n", 996 | " 'field',\n", 997 | " 0,\n", 998 | " '',\n", 999 | " 0,\n", 1000 | " '',\n", 1001 | " 'field',\n", 1002 | " 0,\n", 1003 | " '',\n", 1004 | " 0,\n", 1005 | " '',\n", 1006 | " 'field',\n", 1007 | " 0,\n", 1008 | " '',\n", 1009 | " 0,\n", 1010 | " '',\n", 1011 | " 'field',\n", 1012 | " 0,\n", 1013 | " '',\n", 1014 | " 0,\n", 1015 | " '',\n", 1016 | " 'field',\n", 1017 | " 0,\n", 1018 | " 0,\n", 1019 | " '',\n", 1020 | " 0,\n", 1021 | " '',\n", 1022 | " 'field',\n", 1023 | " 0,\n", 1024 | " 0,\n", 1025 | " '',\n", 1026 | " 0,\n", 1027 | " '',\n", 1028 | " 'field',\n", 1029 | " 0,\n", 1030 | " '',\n", 1031 | " 0,\n", 1032 | " '',\n", 1033 | " 'field',\n", 1034 | " 0,\n", 1035 | " '']" 1036 | ] 1037 | }, 1038 | "execution_count": 75, 1039 | "metadata": {}, 1040 | "output_type": "execute_result" 1041 | } 1042 | ], 1043 | "source": [ 1044 | "decode(sample_t[0])" 1045 | ] 1046 | }, 1047 | { 1048 | "cell_type": "code", 1049 | "execution_count": 76, 1050 | "metadata": {}, 1051 | "outputs": [ 1052 | { 1053 | "data": { 1054 | "text/plain": [ 1055 | "['',\n", 1056 | " '',\n", 1057 | " 'Enum',\n", 1058 | " 0,\n", 1059 | " 0,\n", 1060 | " '',\n", 1061 | " 'Object',\n", 1062 | " '',\n", 1063 | " '',\n", 1064 | " 'object',\n", 1065 | " '',\n", 1066 | " '',\n", 1067 | " 0,\n", 1068 | " '',\n", 1069 | " '',\n", 1070 | " 0,\n", 1071 | " '1',\n", 1072 | " '',\n", 1073 | " 0,\n", 1074 | " '2',\n", 1075 | " '',\n", 1076 | " 0,\n", 1077 | " '3',\n", 1078 | " '',\n", 1079 | " 0,\n", 1080 | " '',\n", 1081 | " 9,\n", 1082 | " '',\n", 1083 | " '',\n", 1084 | " 15,\n", 1085 | " '2',\n", 1086 | " '',\n", 1087 | " 24,\n", 1088 | " '1',\n", 1089 | " '',\n", 1090 | " 0,\n", 1091 | " '2',\n", 1092 
| " '',\n", 1093 | " 0,\n", 1094 | " '3',\n", 1095 | " '',\n", 1096 | " 0,\n", 1097 | " '',\n", 1098 | " 9,\n", 1099 | " '',\n", 1100 | " '',\n", 1101 | " 15,\n", 1102 | " '1',\n", 1103 | " '',\n", 1104 | " 'mandatory',\n", 1105 | " '',\n", 1106 | " '',\n", 1107 | " 0,\n", 1108 | " '',\n", 1109 | " 'field',\n", 1110 | " 0,\n", 1111 | " '',\n", 1112 | " 0,\n", 1113 | " '',\n", 1114 | " 'field',\n", 1115 | " 0,\n", 1116 | " '',\n", 1117 | " 0,\n", 1118 | " '',\n", 1119 | " 'field',\n", 1120 | " 0,\n", 1121 | " '',\n", 1122 | " 0,\n", 1123 | " '',\n", 1124 | " 'field',\n", 1125 | " 0,\n", 1126 | " '',\n", 1127 | " '',\n", 1128 | " 0,\n", 1129 | " '',\n", 1130 | " 'field',\n", 1131 | " 0,\n", 1132 | " '',\n", 1133 | " '',\n", 1134 | " 0,\n", 1135 | " '',\n", 1136 | " 'field',\n", 1137 | " 0,\n", 1138 | " '',\n", 1139 | " '',\n", 1140 | " 0,\n", 1141 | " '',\n", 1142 | " 'field',\n", 1143 | " 0,\n", 1144 | " '',\n", 1145 | " 0,\n", 1146 | " '',\n", 1147 | " 'field',\n", 1148 | " 0,\n", 1149 | " '',\n", 1150 | " 0,\n", 1151 | " '',\n", 1152 | " 'field',\n", 1153 | " 0,\n", 1154 | " '',\n", 1155 | " '',\n", 1156 | " 0,\n", 1157 | " '',\n", 1158 | " 'field',\n", 1159 | " 0,\n", 1160 | " '',\n", 1161 | " 0,\n", 1162 | " '',\n", 1163 | " 'field',\n", 1164 | " 0,\n", 1165 | " '',\n", 1166 | " 0,\n", 1167 | " '',\n", 1168 | " 'field',\n", 1169 | " 0,\n", 1170 | " '',\n", 1171 | " 0,\n", 1172 | " '',\n", 1173 | " 'field',\n", 1174 | " 0,\n", 1175 | " '',\n", 1176 | " 0,\n", 1177 | " '',\n", 1178 | " 'field',\n", 1179 | " 0,\n", 1180 | " '',\n", 1181 | " '',\n", 1182 | " '',\n", 1183 | " 'self',\n", 1184 | " '',\n", 1185 | " '',\n", 1186 | " '',\n", 1187 | " '',\n", 1188 | " '',\n", 1189 | " 'self',\n", 1190 | " 0,\n", 1191 | " '',\n", 1192 | " 0,\n", 1193 | " '',\n", 1194 | " 'property',\n", 1195 | " 0,\n", 1196 | " 0,\n", 1197 | " '',\n", 1198 | " 0,\n", 1199 | " '',\n", 1200 | " '',\n", 1201 | " 'code',\n", 1202 | " '',\n", 1203 | " '',\n", 1204 | " 0,\n", 1205 | " '1',\n", 1206 | " '',\n", 1207 | " 0,\n", 1208 | " '2',\n", 1209 | " '',\n", 1210 | " 0,\n", 1211 | " '3',\n", 1212 | " '',\n", 1213 | " 0,\n", 1214 | " '4',\n", 1215 | " '',\n", 1216 | " 0,\n", 1217 | " '',\n", 1218 | " 0,\n", 1219 | " '',\n", 1220 | " '',\n", 1221 | " 0,\n", 1222 | " '1',\n", 1223 | " '',\n", 1224 | " 0,\n", 1225 | " '1',\n", 1226 | " '',\n", 1227 | " 0,\n", 1228 | " '2',\n", 1229 | " '',\n", 1230 | " 0,\n", 1231 | " '3',\n", 1232 | " '',\n", 1233 | " 0,\n", 1234 | " '',\n", 1235 | " '_object',\n", 1236 | " '',\n", 1237 | " '',\n", 1238 | " '',\n", 1239 | " 0,\n", 1240 | " '',\n", 1241 | " 47,\n", 1242 | " '',\n", 1243 | " '',\n", 1244 | " 0,\n", 1245 | " '1',\n", 1246 | " '',\n", 1247 | " 'type',\n", 1248 | " '',\n", 1249 | " 0,\n", 1250 | " '',\n", 1251 | " 'field',\n", 1252 | " 0,\n", 1253 | " '',\n", 1254 | " 0,\n", 1255 | " '',\n", 1256 | " 'field',\n", 1257 | " 0,\n", 1258 | " '',\n", 1259 | " 0,\n", 1260 | " '',\n", 1261 | " 'field',\n", 1262 | " 0,\n", 1263 | " '',\n", 1264 | " 0,\n", 1265 | " '',\n", 1266 | " 'field',\n", 1267 | " 0,\n", 1268 | " '',\n", 1269 | " 0,\n", 1270 | " '',\n", 1271 | " 'field',\n", 1272 | " 0,\n", 1273 | " '',\n", 1274 | " 0,\n", 1275 | " '',\n", 1276 | " '',\n", 1277 | " '',\n", 1278 | " '',\n", 1279 | " 0,\n", 1280 | " '',\n", 1281 | " 'field',\n", 1282 | " 0,\n", 1283 | " '',\n", 1284 | " 0,\n", 1285 | " '',\n", 1286 | " 'field',\n", 1287 | " 0,\n", 1288 | " '',\n", 1289 | " 0,\n", 1290 | " '',\n", 1291 | " 'field',\n", 1292 | " 0,\n", 1293 | " '',\n", 1294 | " 48,\n", 1295 | " 
'',\n", 1296 | " 'field',\n", 1297 | " 'type',\n", 1298 | " '',\n", 1299 | " 0,\n", 1300 | " '',\n", 1301 | " 'field',\n", 1302 | " 0,\n", 1303 | " '',\n", 1304 | " 0,\n", 1305 | " '',\n", 1306 | " 'field',\n", 1307 | " 0,\n", 1308 | " '',\n", 1309 | " 0,\n", 1310 | " '',\n", 1311 | " 'field',\n", 1312 | " 0,\n", 1313 | " '',\n", 1314 | " 0,\n", 1315 | " '',\n", 1316 | " 'field',\n", 1317 | " 0,\n", 1318 | " '',\n", 1319 | " 0,\n", 1320 | " '',\n", 1321 | " 'field',\n", 1322 | " 0,\n", 1323 | " '',\n", 1324 | " '',\n", 1325 | " 0,\n", 1326 | " '',\n", 1327 | " 'field',\n", 1328 | " 0,\n", 1329 | " '',\n", 1330 | " '',\n", 1331 | " 0,\n", 1332 | " '',\n", 1333 | " 'field',\n", 1334 | " 0,\n", 1335 | " '',\n", 1336 | " 0,\n", 1337 | " '',\n", 1338 | " 'field',\n", 1339 | " 0,\n", 1340 | " '']" 1341 | ] 1342 | }, 1343 | "execution_count": 76, 1344 | "metadata": {}, 1345 | "output_type": "execute_result" 1346 | } 1347 | ], 1348 | "source": [ 1349 | "decode(ans[0])" 1350 | ] 1351 | }, 1352 | { 1353 | "cell_type": "code", 1354 | "execution_count": null, 1355 | "metadata": {}, 1356 | "outputs": [], 1357 | "source": [] 1358 | } 1359 | ], 1360 | "metadata": { 1361 | "kernelspec": { 1362 | "display_name": "Python [default]", 1363 | "language": "python", 1364 | "name": "python3" 1365 | }, 1366 | "language_info": { 1367 | "codemirror_mode": { 1368 | "name": "ipython", 1369 | "version": 3 1370 | }, 1371 | "file_extension": ".py", 1372 | "mimetype": "text/x-python", 1373 | "name": "python", 1374 | "nbconvert_exporter": "python", 1375 | "pygments_lexer": "ipython3", 1376 | "version": "3.5.5" 1377 | } 1378 | }, 1379 | "nbformat": 4, 1380 | "nbformat_minor": 2 1381 | } 1382 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | from utils import * 10 | 11 | class DecoderSimple(nn.Module): 12 | def __init__( 13 | self, 14 | hidden_size, 15 | vocab_sizeT, 16 | vocab_sizeN, 17 | embedding_sizeT, 18 | embedding_sizeN, 19 | dropout, 20 | num_layers, 21 | device='cuda' 22 | ): 23 | super(DecoderSimple, self).__init__() 24 | self.num_layers = num_layers 25 | self.hidden_size = hidden_size 26 | self.device = device 27 | self.dropout = dropout 28 | 29 | self.embeddingN = nn.Embedding(vocab_sizeN, embedding_sizeN, vocab_sizeN - 1) 30 | self.embeddingT = nn.Embedding(vocab_sizeT + 3, embedding_sizeT, vocab_sizeT - 1) 31 | 32 | self.lstm = nn.LSTM( 33 | embedding_sizeN + embedding_sizeT, 34 | hidden_size, 35 | num_layers=num_layers, 36 | batch_first=True, 37 | bidirectional=False 38 | ) 39 | self.w_global = nn.Linear(hidden_size * 3, vocab_sizeT + 3) # map into T 40 | 41 | def embedded_dropout(self, embed, words, scale=None): 42 | dropout = self.dropout 43 | if dropout > 0: 44 | mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout) 45 | masked_embed_weight = mask * embed.weight 46 | else: 47 | masked_embed_weight = embed.weight 48 | if scale: 49 | masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight 50 | 51 | padding_idx = embed.padding_idx 52 | if padding_idx is None: 53 | padding_idx = -1 54 | 55 | words[words >= embed.weight.size(0)] = padding_idx 56 | 57 | X = F.embedding(words, masked_embed_weight, 58 | padding_idx, 
embed.max_norm, embed.norm_type, 59 | embed.scale_grad_by_freq, embed.sparse 60 | ) 61 | return X 62 | 63 | def forward( 64 | self, 65 | input, 66 | hc, 67 | enc_out, 68 | mask, 69 | h_parent 70 | ): 71 | n_input, t_input = input 72 | batch_size = n_input.size(0) 73 | 74 | # (enc_out, enc_out_W) [(batch_size, max_length, hidden_size * 2), (batch_size, max_length, hidden_size)] 75 | # mask (batch_size, max_length) 76 | # hidden_prev (batch_size, hidden_size) 77 | 78 | n_input = self.embedded_dropout(self.embeddingN, n_input) 79 | t_input = self.embedded_dropout(self.embeddingT, t_input) 80 | input = torch.cat([n_input, t_input], 1) 81 | 82 | out, (h, c) = self.lstm(input.unsqueeze(1), hc) 83 | 84 | 85 | hidden = h[-1] # use only last layer hidden in attention 86 | out = out.squeeze(1) 87 | 88 | w_t = F.log_softmax(self.w_global(torch.cat([hidden, out, h_parent], dim=1)), dim=1) 89 | return w_t, (h, c) 90 | 91 | 92 | class DecoderAttention(nn.Module): 93 | def __init__( 94 | self, 95 | hidden_size, 96 | vocab_sizeT, 97 | vocab_sizeN, 98 | embedding_sizeT, 99 | embedding_sizeN, 100 | dropout, 101 | num_layers, 102 | attn_size=50, 103 | pointer=True, 104 | device='cuda' 105 | ): 106 | super(DecoderAttention, self).__init__() 107 | self.num_layers = num_layers 108 | self.hidden_size = hidden_size 109 | self.pointer = pointer 110 | self.device = device 111 | self.dropout = dropout 112 | 113 | self.embeddingN = nn.Embedding(vocab_sizeN, embedding_sizeN, vocab_sizeN - 1) 114 | self.embeddingT = nn.Embedding(vocab_sizeT + attn_size + 2, embedding_sizeT, vocab_sizeT - 1) 115 | 116 | self.W_hidden = nn.Linear(hidden_size, hidden_size) 117 | self.W_mem2hidden = nn.Linear(hidden_size, hidden_size) 118 | self.v = nn.Linear(hidden_size, 1) 119 | 120 | self.W_context = nn.Linear( 121 | embedding_sizeN + embedding_sizeT + hidden_size, 122 | hidden_size 123 | ) 124 | self.lstm = nn.LSTM( 125 | embedding_sizeN + embedding_sizeT, 126 | hidden_size, 127 | num_layers=num_layers, 128 | batch_first=True, 129 | bidirectional=False 130 | ) 131 | self.w_global = nn.Linear(hidden_size * 3, vocab_sizeT + 2) # map into T 132 | if self.pointer: 133 | self.w_switcher = nn.Linear(hidden_size * 2, 1) 134 | self.logsigmoid = torch.nn.LogSigmoid() 135 | 136 | def embedded_dropout(self, embed, words, scale=None): 137 | dropout = self.dropout 138 | if dropout > 0: 139 | mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout) 140 | masked_embed_weight = mask * embed.weight 141 | else: 142 | masked_embed_weight = embed.weight 143 | if scale: 144 | masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight 145 | 146 | padding_idx = embed.padding_idx 147 | if padding_idx is None: 148 | padding_idx = -1 149 | 150 | words[words >= embed.weight.size(0)] = padding_idx 151 | 152 | X = F.embedding(words, masked_embed_weight, 153 | padding_idx, embed.max_norm, embed.norm_type, 154 | embed.scale_grad_by_freq, embed.sparse 155 | ) 156 | return X 157 | 158 | def forward( 159 | self, 160 | input, 161 | hc, 162 | enc_out, 163 | mask, 164 | h_parent 165 | ): 166 | n_input, t_input = input 167 | batch_size = n_input.size(0) 168 | 169 | # (enc_out, enc_out_W) [(batch_size, max_length, hidden_size * 2), (batch_size, max_length, hidden_size)] 170 | # mask (batch_size, max_length) 171 | # hidden_prev (batch_size, hidden_size) 172 | 173 | n_input = self.embedded_dropout(self.embeddingN, n_input) 174 | t_input = self.embedded_dropout(self.embeddingT, 
t_input) 175 | input = torch.cat([n_input, t_input], 1) 176 | 177 | out, (h, c) = self.lstm(input.unsqueeze(1), hc) 178 | 179 | 180 | hidden = h[-1] # use only last layer hidden in attention 181 | out = out.squeeze(1) 182 | 183 | scores = self.W_hidden(hidden).unsqueeze(1) # (batch_size, max_length, hidden_size) 184 | if enc_out.shape[1] > 0: 185 | scores_mem = self.W_mem2hidden(enc_out) 186 | scores = scores.repeat(1, scores_mem.shape[1], 1) + scores_mem 187 | scores = torch.tanh(scores) 188 | scores = self.v(scores).squeeze(2) # (batch_size, max_length) 189 | scores = scores.masked_fill(mask, -1e20) # (batch_size, max_length) 190 | attn_weights = F.softmax(scores, dim=1) # (batch_size, max_length) 191 | attn_weights = attn_weights.unsqueeze(1) # (batch_size, 1, max_length) 192 | context = torch.matmul(attn_weights, enc_out).squeeze(1) # (batch_size, hidden_size) 193 | 194 | if self.pointer: 195 | w_t = F.log_softmax(self.w_global(torch.cat([context, out, h_parent], dim=1)), dim=1) 196 | attn_weights = F.log_softmax(scores, dim=1) 197 | w_s = self.w_switcher(torch.cat([context, out], dim=1)) 198 | return torch.cat([self.logsigmoid(w_s) + w_t, self.logsigmoid(-w_s) + attn_weights], dim=1), (h, c) 199 | else: 200 | w_t = F.log_softmax(self.w_global(torch.cat([context, out, h_parent], dim=1)), dim=1) 201 | return w_t, (h, c) 202 | 203 | class MixtureAttention(nn.Module): 204 | def __init__( 205 | self, 206 | hidden_size, 207 | vocab_sizeT, 208 | vocab_sizeN, 209 | embedding_sizeT, 210 | embedding_sizeN, 211 | num_layers, 212 | dropout, 213 | device='cuda', 214 | label_smoothing = 0.1, 215 | attn=True, 216 | pointer=True, 217 | attn_size=50, 218 | SOS_token=0 219 | ): 220 | super(MixtureAttention, self).__init__() 221 | self.device = device 222 | self.hidden_size = hidden_size 223 | self.dropout = dropout 224 | self.eof_N_id = vocab_sizeN - 1 225 | self.eof_T_id = vocab_sizeT - 1 226 | self.unk_id = vocab_sizeT - 2 227 | self.SOS_token = SOS_token 228 | self.attn_size = attn_size 229 | self.vocab_sizeT = vocab_sizeT 230 | self.vocab_sizeN = vocab_sizeN 231 | 232 | self.W_out = nn.Linear(hidden_size * 2, hidden_size) 233 | 234 | if attn: 235 | self.decoder = DecoderAttention( 236 | hidden_size=hidden_size, 237 | vocab_sizeT=vocab_sizeT, 238 | vocab_sizeN=vocab_sizeN, 239 | embedding_sizeT=embedding_sizeT, 240 | embedding_sizeN=embedding_sizeN, 241 | num_layers=num_layers, 242 | attn_size=attn_size, 243 | dropout=dropout, 244 | pointer=pointer, 245 | device=device 246 | ).to(device) 247 | else: 248 | self.decoder = DecoderSimple( 249 | hidden_size=hidden_size, 250 | vocab_sizeT=vocab_sizeT, 251 | vocab_sizeN=vocab_sizeN, 252 | embedding_sizeT=embedding_sizeT, 253 | embedding_sizeN=embedding_sizeN, 254 | num_layers=num_layers, 255 | dropout=dropout, 256 | device=device 257 | ).to(device) 258 | 259 | if label_smoothing > 0: 260 | self.criterion = LabelSmoothingLoss( 261 | label_smoothing, 262 | tgt_vocab_size=vocab_sizeT + attn_size + 3, 263 | ignore_index=self.eof_T_id, 264 | device=self.device 265 | ) # ignore EOF ?! 
266 | else: 267 | self.criterion = nn.NLLLoss(reduction='none', ignore_index=self.eof_T_id) 268 | # 269 | 270 | self.pointer = pointer 271 | 272 | 273 | def forward( 274 | self, 275 | n_tensor, 276 | t_tensor, 277 | p_tensor 278 | ): 279 | batch_size = n_tensor.size(0) 280 | max_length = n_tensor.size(1) 281 | 282 | full_mask = (n_tensor == self.eof_N_id) 283 | 284 | input = ( 285 | torch.ones( 286 | batch_size, 287 | dtype=torch.long, 288 | device=self.device 289 | ) * self.SOS_token, 290 | torch.ones( 291 | batch_size, 292 | dtype=torch.long, 293 | device=self.device 294 | ) * self.SOS_token 295 | ) 296 | hs = torch.zeros( 297 | batch_size, 298 | max_length, 299 | self.hidden_size, 300 | requires_grad=False 301 | ).to(self.device) 302 | hc = None 303 | 304 | parent = torch.zeros( 305 | batch_size, 306 | dtype=torch.long, 307 | device=self.device 308 | ) 309 | 310 | loss = torch.tensor(0.0, device=self.device) 311 | 312 | token_losses = torch.zeros( 313 | batch_size, 314 | max_length 315 | ).to(self.device) 316 | 317 | ans = [] 318 | 319 | for iter in range(max_length): 320 | memory = hs[:, max(iter - self.attn_size, 0) : iter] 321 | output, hc = self.decoder( 322 | input, 323 | hc, 324 | memory.clone().detach(), 325 | full_mask[:, max(iter - self.attn_size, 0) : iter], 326 | hs[torch.arange(batch_size),parent].squeeze(1).clone().detach() 327 | ) 328 | hs[:, iter] = hc[0][-1] # store last layer hidden state only 329 | topv, topi = output.topk(1) 330 | input = (n_tensor[:, iter].clone(), t_tensor[:, iter].clone()) 331 | parent = p_tensor[:, iter] 332 | 333 | # print(output.shape[1]) 334 | 335 | ans.append(topi.detach()) 336 | # cond = (t_tensor[:, iter] < self.vocab_sizeT + self.attn_size).long() 337 | # masked_target = cond * t_tensor[:, iter] + (1 - cond) * self.eof_T_id 338 | target = t_tensor[:, iter] 339 | target[target >= output.shape[1]] = self.unk_id 340 | token_losses[:, iter] = self.criterion(output, t_tensor[:, iter].clone().detach()) 341 | 342 | loss = token_losses.sum() #/ batch_size 343 | return loss, torch.cat(ans, dim=1) 344 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | # import argparse 3 | 4 | if __name__ == '__main__': 5 | os.system('mkdir -p pickle_data') 6 | os.system('python preprocess_utils/freq_dict.py') 7 | os.system('python preprocess_utils/get_non_terminal.py') 8 | os.system('python preprocess_utils/get_terminal_dict.py') 9 | os.system('python preprocess_utils/get_terminal_whole.py') 10 | -------------------------------------------------------------------------------- /preprocess_utils/freq_dict.py: -------------------------------------------------------------------------------- 1 | #freq_dict: each terminal's frequency; terminal_num: a set about all the terminals. 
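# Illustrative sketch of the two structures this script builds (tokens and counts
# below are made up, not taken from the real dataset):
#   freq_dict    -> Counter({'self': 3000000, 'None': 500000, ..., '': <count of nodes without a 'value'>})
#   terminal_num -> {'', 'self', 'None', ...}   # every distinct terminal value seen in train + eval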
2 | 3 | import numpy as np 4 | from six.moves import cPickle as pickle 5 | import json 6 | from collections import Counter 7 | import time 8 | 9 | #attention line 28: for python dataset, not exclude the last one 10 | train_filename = './json_data/python100k_train.json' 11 | test_filename = './json_data/python50k_eval.json' 12 | target_filename = './pickle_data/freq_dict_PY.pickle' 13 | 14 | EMPTY_TOKEN = '' 15 | 16 | freq_dict = Counter() 17 | terminal_num = set() 18 | terminal_num.add(EMPTY_TOKEN) 19 | 20 | def process(filename): 21 | with open(filename, encoding='latin-1') as lines: 22 | print ('Start procesing %s !!!'%(filename)) 23 | line_index = 0 24 | for line in lines: 25 | line_index += 1 26 | if line_index % 1000 == 0: 27 | print ('Processing line:', line_index) 28 | data = json.loads(line) 29 | if len(data) < 3e4: 30 | for i, dic in enumerate(data): #JS data[:-1] or PY data 31 | if 'value' in dic.keys(): 32 | terminal_num.add(dic['value']) 33 | freq_dict[dic['value']] += 1 34 | else: 35 | freq_dict[EMPTY_TOKEN] += 1 36 | 37 | def save(filename): 38 | with open(filename, 'wb') as f: 39 | save = {'freq_dict': freq_dict,'terminal_num': terminal_num} 40 | pickle.dump(save, f, protocol=2) 41 | 42 | 43 | if __name__ == '__main__': 44 | start_time = time.time() 45 | process(train_filename) 46 | process(test_filename) 47 | save(target_filename) 48 | print(freq_dict['EmptY'], freq_dict['Empty'], freq_dict['empty'], freq_dict['EMPTY'], freq_dict[EMPTY_TOKEN]) 49 | print('Finishing generating freq_dict and takes %.2f'%(time.time() - start_time)) 50 | -------------------------------------------------------------------------------- /preprocess_utils/get_non_terminal.py: -------------------------------------------------------------------------------- 1 | # rewrite on 2018/1/8 by xxx, add parent 2 | 3 | import numpy as np 4 | from six.moves import cPickle as pickle 5 | import json 6 | import time 7 | from collections import Counter, defaultdict 8 | 9 | #attention line 42: for python dataset, not exclude the last one 10 | train_filename = './json_data/python100k_train.json' 11 | test_filename = './json_data/python50k_eval.json' 12 | target_filename = './pickle_data/PY_non_terminal_small.pickle' 13 | 14 | # global variables 15 | typeDict = dict() #map N's name into its original ID(before expanding into 4*base_ID) 16 | numID = set() #the set to include all sparse ID 17 | no_empty_set = set() 18 | typeList = list() #the set to include all Types 19 | numType = 0 20 | dicID = dict() #map sparse id to dense id (remove empty id inside 4*base_ID) 21 | 22 | def process(filename): 23 | with open(filename, encoding='latin-1') as lines: 24 | print ('Start procesing %s !!!'%(filename)) 25 | line_index = 0 26 | corpus_N = list() 27 | corpus_parent = list() 28 | 29 | for line in lines: 30 | line_index += 1 31 | if line_index % 1000 == 0: 32 | print ('Processing line: ', line_index) 33 | data = json.loads(line) 34 | line_N = list() 35 | has_sibling = Counter() 36 | parent_counter = defaultdict(lambda: 1) #default parent is previous 1 37 | parent_list = list() 38 | 39 | if len(data) >= 3e4: 40 | continue 41 | 42 | for i, dic in enumerate(data): #JS data[:-1] or PY data 43 | typeName = dic['type'] 44 | if typeName in typeList: 45 | base_ID = typeDict[typeName] 46 | else: 47 | typeList.append(typeName) 48 | global numType 49 | typeDict[typeName] = numType 50 | base_ID = numType 51 | numType = numType + 1 52 | 53 | #expand the ID into the range of 4*base_ID, according to whether it has sibling or children. 
Sibling information is got by the ancestor's children information 54 | if 'children' in dic.keys(): 55 | if has_sibling[i]: 56 | ID = base_ID * 4 + 3 57 | else: 58 | ID = base_ID * 4 + 2 59 | 60 | childs = dic['children'] 61 | for j in childs: 62 | parent_counter[j] = j-i 63 | 64 | if len(childs) > 1: 65 | for j in childs: 66 | has_sibling[j] = 1 67 | else: 68 | if has_sibling[i]: 69 | ID = base_ID * 4 + 1 70 | else: 71 | ID = base_ID * 4 72 | #recording the N which has non-empty T 73 | if 'value' in dic.keys(): 74 | no_empty_set.add(ID) 75 | 76 | line_N.append(ID) 77 | parent_list.append(parent_counter[i]) 78 | numID.add(ID) 79 | 80 | corpus_N.append(line_N) 81 | corpus_parent.append(parent_list) 82 | return corpus_N, corpus_parent 83 | 84 | 85 | 86 | def map_dense_id(data): 87 | result = list() 88 | for line_id in data: 89 | line_new_id = list() 90 | for i in line_id: 91 | if i in dicID.keys(): 92 | line_new_id.append(dicID[i]) 93 | else: 94 | dicID[i] = len(dicID) 95 | line_new_id.append(dicID[i]) 96 | result.append(line_new_id) 97 | return result 98 | 99 | 100 | def save(filename, typeDict, numType, dicID, vocab_size, trainData, testData, trainParent, testParent, empty_set_dense): 101 | with open(filename, 'wb') as f: 102 | save = { 103 | # 'typeDict': typeDict, 104 | # 'numType': numType, 105 | # 'dicID': dicID, 106 | 'vocab_size': vocab_size, 107 | 'trainData': trainData, 108 | 'testData': testData, 109 | 'trainParent': trainParent, 110 | 'testParent': testParent, 111 | # 'typeOnlyHasEmptyValue': empty_set_dense, 112 | } 113 | pickle.dump(save, f, protocol=2) 114 | 115 | if __name__ == '__main__': 116 | start_time = time.time() 117 | trainData, trainParent = process(train_filename) 118 | testData, testParent = process(test_filename) 119 | trainData = map_dense_id(trainData) 120 | testData = map_dense_id(testData) 121 | vocab_size = len(numID) 122 | assert len(dicID) == vocab_size 123 | 124 | #for print the N which can only has empty T 125 | assert no_empty_set.issubset(numID) 126 | empty_set = numID.difference(no_empty_set) 127 | empty_set_dense = set() 128 | # print(dicID) 129 | for i in empty_set: 130 | empty_set_dense.add(dicID[i]) 131 | print('The N set that can only has empty terminals: ',len(empty_set_dense), empty_set_dense) 132 | print('The vocaburary:', vocab_size, numID) 133 | 134 | 135 | save(target_filename, typeDict, numType, dicID, vocab_size, trainData, testData, trainParent, testParent,empty_set_dense) 136 | print('Finishing generating terminals and takes %.2fs'%(time.time() - start_time)) 137 | -------------------------------------------------------------------------------- /preprocess_utils/get_terminal_dict.py: -------------------------------------------------------------------------------- 1 | #sort the freq_dict and get the terminal_dict for top terminals (include EmptY) 2 | 3 | import time 4 | from six.moves import cPickle as pickle 5 | import json 6 | from collections import Counter 7 | import operator 8 | 9 | # vocab_size = 10000 10 | # vocab_size = 1000 11 | vocab_size = 50000 12 | total_length = 92758587 # JS: 160143814, PY 92758587 13 | freq_dict_filename = './pickle_data/freq_dict_PY.pickle' 14 | # target_filename = './pickle_data/terminal_dict_10k_PY.pickle' 15 | # target_filename = './pickle_data/terminal_dict_1k_PY.pickle' 16 | target_filename = './pickle_data/terminal_dict_50k_PY.pickle' 17 | 18 | def restore_freq_dict(filename): 19 | with open(filename, 'rb') as f: 20 | save = pickle.load(f) 21 | freq_dict = save['freq_dict'] 22 | terminal_num = 
save['terminal_num'] 23 | return freq_dict, terminal_num 24 | 25 | def get_terminal_dict(vocab_size, freq_dict, verbose=False): 26 | terminal_dict = dict() 27 | sorted_freq_dict = sorted(freq_dict.items(), key=operator.itemgetter(1), reverse=True) 28 | if verbose == True: 29 | for i in range(100): 30 | print ('the %d frequent terminal: %s, its frequency: %.5f'%(i, sorted_freq_dict[i][0], float(sorted_freq_dict[i][1])/total_length)) 31 | new_freq_dict = sorted_freq_dict[:vocab_size] 32 | for i, (terminal, frequent) in enumerate(new_freq_dict): 33 | terminal_dict[terminal] = i 34 | return terminal_dict, sorted_freq_dict 35 | 36 | def save(filename, terminal_dict, terminal_num, sorted_freq_dict): 37 | with open(filename, 'wb') as f: 38 | save = {'terminal_dict': terminal_dict,'terminal_num': terminal_num, 'vocab_size': vocab_size, 'sorted_freq_dict': sorted_freq_dict,} 39 | pickle.dump(save, f, protocol=2) 40 | 41 | if __name__ == '__main__': 42 | start_time = time.time() 43 | freq_dict, terminal_num = restore_freq_dict(freq_dict_filename) 44 | print(freq_dict[''], freq_dict['empty']) 45 | terminal_dict, sorted_freq_dict = get_terminal_dict(vocab_size, freq_dict, True) 46 | save(target_filename, terminal_dict, terminal_num, sorted_freq_dict) 47 | print('Finishing generating terminal_dict and takes %.2f'%(time.time() - start_time)) 48 | -------------------------------------------------------------------------------- /preprocess_utils/get_terminal_whole.py: -------------------------------------------------------------------------------- 1 | #According to the terminal_dict you choose (i.e. 5k, 10k, 50k), parse the json file and turn them into ids that are stored in pickle file 2 | #Output just one vector for terminal, the upper part is the word id while the lower part is the location 3 | # 0108 revise the Empty into EmptY, normal to NormaL 4 | # Here attn_size matters 5 | 6 | import numpy as np 7 | from six.moves import cPickle as pickle 8 | import json 9 | from collections import deque 10 | import time 11 | 12 | #attention line 48: for python dataset, not exclude the last one 13 | # terminal_dict_filename = './pickle_data/terminal_dict_10k_PY.pickle' 14 | # terminal_dict_filename = './pickle_data/terminal_dict_1k_PY.pickle' 15 | terminal_dict_filename = './pickle_data/terminal_dict_50k_PY.pickle' 16 | train_filename = './json_data/python100k_train.json' 17 | test_filename = './json_data/python50k_eval.json' 18 | # target_filename = './pickle_data/PY_terminal_10k_whole.pickle' 19 | # target_filename = './pickle_data/PY_terminal_1k_whole.pickle' 20 | target_filename = './pickle_data/PY_terminal_50k_whole.pickle' 21 | 22 | 23 | def restore_terminal_dict(filename): 24 | with open(filename, 'rb') as f: 25 | save = pickle.load(f) 26 | terminal_dict = save['terminal_dict'] 27 | terminal_num = save['terminal_num'] 28 | vocab_size = save['vocab_size'] 29 | return terminal_dict, terminal_num, vocab_size #vocab_size is 50k, and also the unk_id 30 | 31 | def process(filename, terminal_dict, unk_id, attn_size, verbose=False, is_train=False): 32 | with open(filename, encoding='latin-1') as lines: 33 | print ('Start procesing %s !!!'%(filename)) 34 | terminal_corpus = list() 35 | attn_que = deque(maxlen=attn_size) 36 | attn_success_total = 0 37 | attn_fail_total = 0 38 | length_total = 0 39 | line_index = 0 40 | for line in lines: 41 | line_index += 1 42 | # if is_train and line_index == 11: 43 | # continue 44 | if line_index % 1000 == 0: 45 | print ('Processing line:', line_index) 46 | data = 
json.loads(line) 47 | if len(data) < 3e4: 48 | terminal_line = list() 49 | attn_que.clear() # have a new queue for each file 50 | attn_success_cnt = 0 51 | attn_fail_cnt = 0 52 | for i, dic in enumerate(data): ##JS data[:-1] or PY data 53 | if 'value' in dic.keys(): 54 | dic_value = dic['value'] 55 | if dic_value in terminal_dict.keys(): #take long time!!! 56 | terminal_line.append(terminal_dict[dic_value]) 57 | attn_que.append('NormaL') 58 | else: 59 | if dic_value in attn_que: 60 | location_index = [len(attn_que)-ind for ind,x in enumerate(attn_que) if x==dic_value][-1] 61 | location_id = unk_id + 1 + (location_index) 62 | # print('\nattn_success!! its value is ', dic_value) 63 | # print('The current file index: ', line_index, ', the location index', location_index,', the location_id: ', location_id, ',\n the attn_que', attn_que) 64 | terminal_line.append(location_id) 65 | attn_success_cnt += 1 66 | else: 67 | attn_fail_cnt += 1 68 | terminal_line.append(unk_id) 69 | attn_que.append(dic_value) 70 | else: 71 | terminal_line.append(terminal_dict['']) 72 | attn_que.append('') 73 | terminal_corpus.append(terminal_line) 74 | attn_success_total += attn_success_cnt 75 | attn_fail_total += attn_fail_cnt 76 | attn_total = attn_success_total + attn_fail_total 77 | length_total += len(data) 78 | # print ('Process line', line_index, 'attn_success_cnt', attn_success_cnt, 'attn_fail_cnt', attn_fail_cnt,'data length', len(data)) 79 | if verbose and line_index % 1000 == 0: 80 | print('\nUntil line %d: attn_success_total: %d, attn_fail_total: %d, success/attn_total: %.4f, length_total: %d, attn_success percentage: %.4f, total unk percentage: %.4f\n'% 81 | (line_index, attn_success_total, attn_fail_total, float(attn_success_total)/attn_total, length_total, 82 | float(attn_success_total)/length_total, float(attn_total)/length_total)) 83 | with open('output.txt', 'a') as fout: 84 | fout.write('Statistics: attn_success_total: %d, attn_fail_total: %d, success/fail: %.4f, length_total: %d, attn_success percentage: %.4f, total unk percentage: %.4f\n'% 85 | (attn_success_total, attn_fail_total, float(attn_success_total)/attn_fail_total, length_total, 86 | float(attn_success_total)/length_total, float(attn_success_total + attn_fail_total)/length_total)) 87 | 88 | return terminal_corpus 89 | 90 | def save(filename, terminal_dict, terminal_num, vocab_size, attn_size, trainData, testData): 91 | with open(filename, 'wb') as f: 92 | save = {'terminal_dict': terminal_dict, 93 | 'terminal_num': terminal_num, 94 | 'vocab_size': vocab_size, 95 | 'attn_size': attn_size, 96 | 'trainData': trainData, 97 | 'testData': testData, 98 | } 99 | pickle.dump(save, f, protocol=2) 100 | 101 | if __name__ == '__main__': 102 | start_time = time.time() 103 | attn_size = 50 104 | terminal_dict, terminal_num, vocab_size = restore_terminal_dict(terminal_dict_filename) 105 | trainData = process(train_filename, terminal_dict, vocab_size, attn_size=attn_size, verbose=True, is_train=True) 106 | testData = process(test_filename, terminal_dict, vocab_size, attn_size=attn_size, verbose=True, is_train=False) 107 | save(target_filename, terminal_dict, terminal_num, vocab_size, attn_size, trainData, testData) 108 | print('Finishing generating terminals and takes %.2f'%(time.time() - start_time)) 109 | -------------------------------------------------------------------------------- /preprocess_utils/get_total_length.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | 4 | 
train_filename = './json_data/python100k_train.json' 5 | test_filename = './json_data/python50k_eval.json' 6 | 7 | def process(filename): 8 | with open(filename, encoding='latin-1') as lines: 9 | print ('Start procesing %s !!!'%(filename)) 10 | length = 0 11 | line_index = 0 12 | for line in lines: 13 | line_index += 1 14 | if line_index % 1000 == 0: 15 | print ('Processing line:', line_index) 16 | data = json.loads(line) 17 | if len(data) < 3e4: 18 | length += len(data[:-1]) # total number of AST nodes 19 | return length 20 | 21 | if __name__ == '__main__': 22 | start_time = time.time() 23 | train_len = process(train_filename) 24 | test_len = process(test_filename) 25 | print('total_length is ', train_len + test_len) 26 | print('Finishing counting the length and takes %.2f'%(time.time() - start_time)) 27 | -------------------------------------------------------------------------------- /preprocess_utils/utils.py: -------------------------------------------------------------------------------- 1 | #Utilities for preprocess the data 2 | 3 | import numpy as np 4 | from six.moves import cPickle as pickle 5 | import json 6 | from collections import deque 7 | import time 8 | 9 | 10 | def read_N_pickle(filename): 11 | with open(filename, 'rb') as f: 12 | print ("Reading data from ", filename) 13 | save = pickle.load(f) 14 | train_data = save['trainData'] 15 | test_data = save['testData'] 16 | vocab_size = save['vocab_size'] 17 | print ('the vocab_size is %d' %vocab_size) 18 | print ('the number of training data is %d' %(len(train_data))) 19 | print ('the number of test data is %d' %(len(test_data))) 20 | print ('Finish reading data!!') 21 | return train_data, test_data, vocab_size 22 | 23 | def read_T_pickle(filename): 24 | with open(filename, 'rb') as f: 25 | print ("Reading data from ", filename) 26 | save = pickle.load(f) 27 | train_data = save['trainData'] 28 | test_data = save['testData'] 29 | vocab_size = save['vocab_size'] 30 | attn_size = save['attn_size'] 31 | print ('the vocab_size is %d' %vocab_size) 32 | print ('the attn_size is %d' %attn_size) 33 | print ('the number of training data is %d' %(len(train_data))) 34 | print ('the number of test data is %d' %(len(test_data))) 35 | print ('Finish reading data!!') 36 | return train_data, test_data, vocab_size, attn_size 37 | 38 | def input_data(N_filename, T_filename): 39 | start_time = time.time() 40 | with open(N_filename, 'rb') as f: 41 | print ("reading data from ", N_filename) 42 | save = pickle.load(f) 43 | train_dataN = save['trainData'] 44 | test_dataN = save['testData'] 45 | train_dataP = save['trainParent'] 46 | test_dataP = save['testParent'] 47 | vocab_sizeN = save['vocab_size'] 48 | print ('the vocab_sizeN is %d (not including the eof)' %vocab_sizeN) 49 | print ('the number of training data is %d' %(len(train_dataN))) 50 | print ('the number of test data is %d\n' %(len(test_dataN))) 51 | 52 | with open(T_filename, 'rb') as f: 53 | print ("reading data from ", T_filename) 54 | save = pickle.load(f) 55 | train_dataT = save['trainData'] 56 | test_dataT = save['testData'] 57 | vocab_sizeT = save['vocab_size'] 58 | attn_size = save['attn_size'] 59 | print ('the vocab_sizeT is %d (not including the unk and eof)' %vocab_sizeT) 60 | print ('the attn_size is %d' %attn_size) 61 | print ('the number of training data is %d' %(len(train_dataT))) 62 | print ('the number of test data is %d' %(len(test_dataT))) 63 | print ('Finish reading data and take %.2f\n'%(time.time()-start_time)) 64 | 65 | return train_dataN, test_dataN, vocab_sizeN, 
train_dataT, test_dataT, vocab_sizeT, attn_size, train_dataP, test_dataP 66 | 67 | 68 | def save(filename, terminal_dict, terminal_num, vocab_size, sorted_freq_dict): 69 | with open(filename, 'wb') as f: 70 | save = {'terminal_dict': terminal_dict,'terminal_num': terminal_num, 'vocab_size': vocab_size, 'sorted_freq_dict': sorted_freq_dict,} 71 | pickle.dump(save, f) 72 | 73 | def change_protocol_for_N(filename): 74 | 75 | f = open(filename, 'rb') 76 | save = pickle.load(f) 77 | typeDict = save['typeDict'] 78 | numType = save['numType'] 79 | dicID = save['dicID'] 80 | vocab_size = save['vocab_size'] 81 | trainData = save['trainData'] 82 | testData = save['testData'] 83 | typeOnlyHasEmptyValue = save['typeOnlyHasEmptyValue'] 84 | f.close() 85 | 86 | f = open(filename, 'wb') 87 | save = { 88 | 'typeDict': typeDict, 89 | 'numType': numType, 90 | 'dicID': dicID, 91 | 'vocab_size': vocab_size, 92 | 'trainData': trainData, 93 | 'testData': testData, 94 | 'typeOnlyHasEmptyValue': typeOnlyHasEmptyValue, 95 | } 96 | pickle.dump(save, f, protocol=2) 97 | f.close() 98 | 99 | 100 | def change_protocol_for_T(filename): 101 | f = open(filename, 'rb') 102 | save = pickle.load(f) 103 | terminal_dict = save['terminal_dict'] 104 | terminal_num = save['terminal_num'] 105 | vocab_size = save['vocab_size'] 106 | attn_size = save['attn_size'] 107 | trainData = save['trainData'] 108 | testData = save['testData'] 109 | f.close() 110 | 111 | f = open(target_filename, 'wb') 112 | save = {'terminal_dict': terminal_dict, 113 | 'terminal_num': terminal_num, 114 | 'vocab_size': vocab_size, 115 | 'attn_size': attn_size, 116 | 'trainData': trainData, 117 | 'testData': testData, 118 | } 119 | pickle.dump(save, f, protocol=2) 120 | f.close() 121 | 122 | if __name__ == '__main__': 123 | 124 | # train_filename = '../json_data/small_programs_training.json' 125 | # test_filename = '../json_data/small_programs_eval.json' 126 | # N_pickle_filename = '../pickle_data/JS_non_terminal.pickle' 127 | # T_pickle_filename = '../pickle_data/JS_terminal_1k.pickle' 128 | filename = '../pickle_data/PY_non_terminal.pickle' 129 | read_N_pickle(filename) 130 | # filename = '../pickle_data/JS_terminal_1k_whole.pickle' 131 | # change_protocol_for_T(filename, target_filename) 132 | 133 | 134 | # N_train_data, N_test_data, N_vocab_size = read_N_pickle(N_pickle_filename) 135 | # T_train_data, T_test_data, T_vocab_size, attn_size = read_T_pickle(T_pickle_filename) 136 | # print(len(N_train_data), len(T_train_data)) 137 | 138 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | rm -r logs/pointer_vocab_50k 2 | # python3 train.py --config=configs/pointer_vocab_10k.yml 3 | python train.py --config=configs/pointer_vocab_50k.yml 4 | python train.py --config=configs/attn_lstm_vocab_1k.yml 5 | python train.py --config=configs/attn_lstm_vocab_50k.yml 6 | #python3 train.py --config=configs/simple_lstm_vocab_1k.yml 7 | python train.py --config=configs/simple_lstm_vocab_50k.yml -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | from data import * 3 | import os 4 | from tqdm import tqdm 5 | import yaml 6 | from utils import DotDict, adjust_learning_rate, accuracy 7 | import torch 8 | try: 9 | from torch.utils.tensorboard import SummaryWriter 10 | except: 11 | from tensorboardX import 
SummaryWriter 12 | import argparse 13 | 14 | def train(config): 15 | writer = SummaryWriter('logs/' + config.name) 16 | 17 | device = config.train.device 18 | 19 | data_train = MainDataset( 20 | N_filename = config.data.N_filename, 21 | T_filename = config.data.T_filename, 22 | is_train=True, 23 | truncate_size=config.data.truncate_size 24 | ) 25 | 26 | data_val = MainDataset( 27 | N_filename = config.data.N_filename, 28 | T_filename = config.data.T_filename, 29 | is_train=False, 30 | truncate_size=config.data.truncate_size 31 | ) 32 | 33 | train_loader = torch.utils.data.DataLoader( 34 | data_train, 35 | batch_size=config.train.batch_size, 36 | shuffle=False, 37 | num_workers=config.train.num_workers, 38 | collate_fn=data_train.collate_fn 39 | ) 40 | 41 | test_loader = torch.utils.data.DataLoader( 42 | data_val, 43 | batch_size=config.train.batch_size, 44 | shuffle=False, 45 | num_workers=config.train.num_workers, 46 | collate_fn=data_val.collate_fn 47 | ) 48 | 49 | ignored_index = data_train.vocab_sizeT - 1 50 | unk_index = data_train.vocab_sizeT - 2 51 | 52 | model = MixtureAttention( 53 | hidden_size = config.model.hidden_size, 54 | vocab_sizeT = data_train.vocab_sizeT, 55 | vocab_sizeN = data_train.vocab_sizeN, 56 | attn_size = data_train.attn_size, 57 | embedding_sizeT = config.model.embedding_sizeT, 58 | embedding_sizeN = config.model.embedding_sizeN, 59 | num_layers = config.model.num_layers, 60 | dropout = config.model.dropout, 61 | label_smoothing = config.model.label_smoothing, 62 | pointer = config.model.pointer, 63 | attn = config.model.attn, 64 | device = device 65 | ) 66 | 67 | start_epoch = 0 68 | if config.train.LOAD_EPOCH is not None: 69 | cpk = torch.load('checkpoints/%s/epoch_%04d.pth' % (config.name, config.train.LOAD_EPOCH)) 70 | model.load_state_dict(cpk['model']) 71 | model = model.to(device) 72 | optimizer = torch.optim.AdamW(model.parameters(), lr=config.train.lr) 73 | optimizer.load_state_dict(cpk['optimizer']) 74 | start_epoch = cpk['epoch'] + 1 75 | print('loaded', start_epoch, '!') 76 | else: 77 | model = model.to(device) 78 | optimizer = torch.optim.AdamW(model.parameters(), lr=config.train.lr) 79 | 80 | for epoch in range(start_epoch, config.train.epochs): 81 | lr = config.train.lr * config.train.lr_decay ** max(epoch - 1, 0) 82 | adjust_learning_rate(optimizer, lr) 83 | print("epoch: %04d" % epoch) 84 | loss_avg, acc_avg = 0, 0 85 | total = len(train_loader) 86 | 87 | model = model.train() 88 | for i, (n, t, p) in enumerate(tqdm(train_loader)): 89 | n, t, p = n.to(device), t.to(device), p.to(device) 90 | optimizer.zero_grad() 91 | 92 | loss, ans = model(n, t, p) 93 | loss_avg += loss.item() 94 | acc_item = accuracy(ans.cpu().numpy().flatten(), t.cpu().numpy().flatten(), ignored_index, unk_index) 95 | acc_avg += acc_item 96 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_value) 97 | loss.backward() 98 | 99 | if (i + 1) % 100 == 0: 100 | print('\ntemp_loss: %f, temp_acc: %f' % (loss.item(), acc_item), flush=True) 101 | writer.add_scalar('train/loss', loss.item(), epoch * total + i) 102 | writer.add_scalar('train/acc', acc_item, epoch * total + i) 103 | 104 | # if (i + 1) % 1000 == 0: 105 | # break 106 | 107 | optimizer.step() 108 | 109 | print("\navg_loss: %f, avg_acc: %f" % (loss_avg/total, acc_avg/total)) 110 | 111 | if (epoch + 1) % config.train.eval_period == 0: 112 | with torch.no_grad(): 113 | model = model.eval() 114 | acc = 0. 115 | loss_eval = 0. 
116 | for i, (n, t, p) in enumerate(tqdm(test_loader)): 117 | n, t, p = n.to(device), t.to(device), p.to(device) 118 | loss, ans = model(n, t, p) 119 | loss_eval += loss.item() 120 | acc += accuracy(ans.cpu().numpy().flatten(), t.cpu().numpy().flatten(), ignored_index, unk_index) 121 | acc /= len(test_loader) 122 | loss_eval /= len(test_loader) 123 | print('\navg acc:', acc, 'avg loss:', loss_eval) 124 | writer.add_scalar('val/loss', loss_eval, epoch) 125 | writer.add_scalar('val/acc', acc, epoch) 126 | if (epoch + 1) % config.train.checkpoint_period == 0: 127 | os.system('mkdir -p checkpoints/' + config.name) 128 | torch.save({ 129 | 'model': model.state_dict(), 130 | 'optimizer': optimizer.state_dict(), 131 | 'epoch': epoch 132 | }, 'checkpoints/%s/epoch_%04d.pth' % (config.name, epoch)) 133 | 134 | if __name__ == '__main__': 135 | parser = argparse.ArgumentParser(description='Training model.') 136 | parser.add_argument('--config', default='configs/pointer_vocab_10k.yml', 137 | help='path to config file') 138 | args = parser.parse_args() 139 | with open(args.config, 'r') as f: 140 | config = DotDict(yaml.safe_load(f)) 141 | train(config) 142 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from sklearn.metrics import accuracy_score 6 | 7 | 8 | class LabelSmoothingLoss(nn.Module): 9 | """ 10 | With label smoothing, 11 | KL-divergence between q_{smoothed ground truth prob.}(w) 12 | and p_{prob. computed by model}(w) is minimized. 13 | """ 14 | def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=1, device='cuda'): 15 | assert 0.0 < label_smoothing <= 1.0 16 | self.ignore_index = ignore_index 17 | super(LabelSmoothingLoss, self).__init__() 18 | self.smoothing_value = label_smoothing / (tgt_vocab_size - 2) 19 | self.device = device 20 | self.confidence = 1.0 - label_smoothing 21 | 22 | def forward(self, output, target): 23 | """ 24 | output (FloatTensor): batch_size x n_classes 25 | target (LongTensor): batch_size 26 | """ 27 | one_hot = torch.full((output.shape[1],), self.smoothing_value).to(self.device) 28 | one_hot[self.ignore_index] = 0 29 | model_prob = one_hot.repeat(target.size(0), 1) 30 | model_prob.scatter_(1, target.unsqueeze(1), self.confidence) 31 | model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) 32 | 33 | return F.kl_div(output, model_prob, reduction='sum') 34 | 35 | def accuracy(out, target, ignored_index, unk_index): 36 | out_ = np.array(out[target != ignored_index]) 37 | target_ = np.array(target[target != ignored_index]) 38 | out_[out_ == unk_index] = -1 39 | return accuracy_score(out_, target_) 40 | 41 | class DotDict(dict): 42 | """A dictionary that supports dot notation 43 | as well as dictionary access notation 44 | usage: d = DotDict() or d = DotDict({'val1':'first'}) 45 | set attributes: d.val2 = 'second' or d['val2'] = 'second' 46 | get attributes: d.val2 or d['val2'] 47 | """ 48 | __getattr__ = dict.__getitem__ 49 | __setattr__ = dict.__setitem__ 50 | __delattr__ = dict.__delitem__ 51 | 52 | def __init__(self, dct): 53 | for key, value in dct.items(): 54 | if hasattr(value, 'keys'): 55 | value = DotDict(value) 56 | self[key] = value 57 | 58 | def adjust_learning_rate(optimizer, lr): 59 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 60 | for param_group in 
optimizer.param_groups: 61 | param_group['lr'] = lr 62 | --------------------------------------------------------------------------------
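A minimal sketch of how the pieces above fit together: train.py wraps the parsed YAML config in DotDict and recomputes the learning rate each epoch before calling adjust_learning_rate. The snippet below is illustrative only (it assumes the repository's utils.py is importable and that the config file exists on disk); it is not an additional file from the repository.

import yaml

from utils import DotDict  # DotDict and adjust_learning_rate are defined in utils.py above

with open('configs/pointer_vocab_10k.yml') as f:
    config = DotDict(yaml.safe_load(f))

# train.py keeps epochs 0 and 1 at the base learning rate, then decays it geometrically
for epoch in range(config.train.epochs):
    lr = config.train.lr * config.train.lr_decay ** max(epoch - 1, 0)
    print(epoch, lr)  # during training this value is passed to adjust_learning_rate(optimizer, lr)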
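The terminal id scheme built by preprocess_utils/get_terminal_whole.py can be summarised with a short sketch. encode_terminal below is a simplified, illustrative restatement of that logic (the real script additionally pushes 'NormaL' / '' placeholders into the queue); the toy vocabulary and ids are made up.

from collections import deque

def encode_terminal(value, terminal_dict, unk_id, attn_que):
    # frequent value: plain dictionary id in [0, vocab_size)
    if value in terminal_dict:
        return terminal_dict[value]
    # rare value that occurred within the last attn_size terminals of the same file:
    # pointer "location id" above unk_id, counted from the most recent occurrence
    matches = [len(attn_que) - i for i, x in enumerate(attn_que) if x == value]
    if matches:
        return unk_id + 1 + matches[-1]
    # otherwise: unknown
    return unk_id

# toy usage with a vocabulary of size 3, so unk_id == 3
terminal_dict = {'': 0, 'self': 1, 'None': 2}
que = deque(['foo', 'self', 'bar'], maxlen=5)
print(encode_terminal('self', terminal_dict, 3, que))  # 1  (in the vocabulary)
print(encode_terminal('bar', terminal_dict, 3, que))   # 5  (3 + 1 + 1, match at the newest queue slot)
print(encode_terminal('baz', terminal_dict, 3, que))   # 3  (out of vocabulary and not in the queue)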