├── src ├── __init__.py ├── messenger.py └── utils.py ├── requirements.txt ├── .gitignore ├── configs └── config_code2seq.yml ├── README.md ├── LICENSE └── notebooks ├── preparation.ipynb └── code2seq.ipynb /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.2.0 2 | tensorboard==2.0.0 3 | gensim==3.6.0 4 | slackweb==1.0.5 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | #etc.. 2 | logs/ 3 | data/ 4 | runs/ 5 | slack/ 6 | src/__pycache__ 7 | 8 | # Jupyter Notebook 9 | .ipynb_checkpoints/ 10 | notebooks/.ipynb_checkpoints/ 11 | src/.ipynb_checkpoints 12 | 13 | -------------------------------------------------------------------------------- /src/messenger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import slackweb 3 | 4 | class Info(object): 5 | def __init__(self, info_prefix='', slack_url = None): 6 | 7 | self.info_prefix = info_prefix 8 | self.slack = None 9 | if slack_url is not None: 10 | self.slack = slackweb.Slack(url = slack_url) 11 | self.slack.notify(text = "="*80) 12 | 13 | def print_msg(self, msg): 14 | text = self.info_prefix + ' ' + msg 15 | 16 | print(text) 17 | logging.info(text) 18 | if self.slack is not None: 19 | self.slack.notify(text = text) 20 | 21 | 22 | -------------------------------------------------------------------------------- /configs/config_code2seq.yml: -------------------------------------------------------------------------------- 1 | data: 2 | home: ../data 3 | dict: /java-small/java-small.dict.c2s 4 | train: /java-small/train 5 | valid: /java-small/val 6 | test: /java-small/test 7 | 8 | 9 | training: 10 | batch_size: 256 11 | num_epochs: 50 12 | lr: 0.001 13 | teacher_forcing_rate: 0.4 14 | nesterov: True 15 | weight_decay: 0.01 16 | momentum: 0.95 17 | decay_ratio: 0.95 18 | save_name: /model.pth 19 | warm_up: 1 20 | patience: 2 21 | 22 | 23 | model: 24 | token_size: 128 25 | hidden_size: 64 26 | num_layers: 1 27 | bidirectional: True 28 | rnn_dropout: 0.5 29 | embeddings_dropout: 0.3 30 | num_k : 200 31 | 32 | etc: 33 | info_prefix: code2seq 34 | slack_url_path: ../slack/slack_url.yml 35 | 36 | comment: code2seq -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # code2seq [WIP] 2 | 3 | A PyTorch re-implementation code for "[code2seq: Generating Sequences from Structured Representations of Code](https://arxiv.org/abs/1808.01400)" 4 | 5 | * Paper(Arxiv) : https://arxiv.org/abs/1808.01400 6 | * Official Github : https://github.com/tech-srl/code2seq 7 | 8 | ## Requirements 9 | Please see requirements.txt 10 | 11 | ## Usage 12 | * notebooks/preparation.ipynb 13 | for downloading dataset, making some directories etc. 14 | 15 | * notebooks/code2seq.ipynb 16 | for training and evaluating the model. 17 | 18 | ## Memo 19 | * Beam search is not implemented. 20 | * GCP AI Platform Notebooks is used to train model. 21 | * AI Platform Notebooks requires google_compute_engine api so please install this before installing other packages if you use AI Platform Notebooks. 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /notebooks/preparation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Preparation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%cd ../" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "!mkdir data runs logs" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### Packages" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# if you use GCP, install google_compute_engine\n", 42 | "!pip install google_compute_engine\n", 43 | "!pip install -r requirements.txt" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "### Dataset" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "!wget https://s3.amazonaws.com/code2seq/datasets/java-small-preprocessed.tar.gz -P data/" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "!tar -xvzf data/java-small-preprocessed.tar.gz -C data/" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 2, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "/home/jupyter/code2seq/data/java-small\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "%cd data/java-small/\n", 86 | "\n", 87 | "#for dev\n", 88 | "!head -20000 java-small.train.c2s > java-small.train_dev.c2s" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 3, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "!mkdir train train_dev val test" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 7, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "!split -d -a 6 -l 1 --additional-suffix=.txt java-small.test.c2s test/\n", 107 | "!split -d -a 6 -l 1 --additional-suffix=.txt java-small.val.c2s val/\n", 108 | "!split -d -a 6 -l 1 --additional-suffix=.txt java-small.train.c2s train/\n", 109 | "!split -d -a 6 -l 1 --additional-suffix=.txt java-small.train_dev.c2s train_dev/" 110 | ] 111 | } 112 | ], 113 | "metadata": { 114 | "kernelspec": { 115 | "display_name": "Python 3", 116 | "language": "python", 117 | "name": "python3" 118 | }, 119 | "language_info": { 120 | "codemirror_mode": { 121 | "name": "ipython", 122 | "version": 3 123 | }, 124 | "file_extension": ".py", 125 | "mimetype": "text/x-python", 126 | "name": "python", 127 | "nbconvert_exporter": "python", 128 | "pygments_lexer": "ipython3", 129 | "version": "3.7.4" 130 | } 131 | }, 132 | "nbformat": 4, 133 | "nbformat_minor": 4 134 | } 135 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from nltk import bleu_score 3 | 4 | PAD = 0 5 | BOS = 1 6 | EOS = 2 7 | UNK = 3 8 | 9 | class Vocab(object): 10 | def __init__(self, word2id={}): 11 | 12 | self.word2id = dict(word2id) 13 | self.id2word = {v: k for k, v in self.word2id.items()} 14 | 15 | def build_vocab(self, sentences, min_count=1): 16 | word_counter = {} 17 | for word in sentences: 18 | word_counter[word] = word_counter.get(word, 0) + 1 19 | 20 | for word, count in sorted(word_counter.items(), key=lambda x: -x[1]): 21 | if count < min_count: 22 | break 23 | _id = len(self.word2id) 24 | self.word2id.setdefault(word, _id) 25 | self.id2word[_id] = word 26 | 27 | def sentence_to_ids(vocab, sentence): 28 | ids = [vocab.word2id.get(word, UNK) for word in sentence] 29 | ids += [EOS] 30 | return ids 31 | 32 | def ids_to_sentence(vocab, ids): 33 | return [vocab.id2word[_id] for _id in ids] 34 | 35 | def trim_eos(ids): 36 | if EOS in ids: 37 | return ids[:ids.index(EOS)] 38 | else: 39 | return ids 40 | 41 | def calculate_results_set(refs, preds): 42 | #calc precision, recall and F1 43 | #same as https://github.com/tech-srl/code2seq/blob/ec0ae309efba815a6ee8af88301479888b20daa9/model.py#L239 44 | 45 | filterd_refs = [ref[:ref.index(EOS)] for ref in refs] 46 | filterd_preds = [pred[:pred.index(EOS)] if EOS in pred else pred for pred in preds] 47 | 48 | filterd_refs = [list(set(ref)) for ref in filterd_refs] 49 | filterd_preds = [list(set(pred)) for pred in filterd_preds] 50 | 51 | true_positive, false_positive, false_negative = 0, 0, 0 52 | 53 | for filterd_pred, filterd_ref in zip(filterd_preds, filterd_refs): 54 | 55 | for fp in filterd_pred: 56 | if fp in filterd_ref: 57 | true_positive += 1 58 | else: 59 | false_positive += 1 60 | 61 | for fr in filterd_ref: 62 | if not fr in filterd_pred: 63 | false_negative += 1 64 | 65 | # https://github.com/tech-srl/code2seq/blob/ec0ae309efba815a6ee8af88301479888b20daa9/model.py#L282 66 | if true_positive + false_positive > 0: 67 | precision = true_positive / (true_positive + false_positive) 68 | else: 69 | precision = 0 70 | 71 | if true_positive + false_negative > 0: 72 | recall = true_positive / (true_positive + false_negative) 73 | else: 74 | recall = 0 75 | 76 | if precision + recall > 0: 77 | f1 = 2 * precision * recall / (precision + recall) 78 | else: 79 | f1 = 0 80 | 81 | return precision, recall, f1 82 | 83 | def calculate_results(refs, preds): 84 | #calc precision, recall and F1 85 | #same as https://github.com/tech-srl/code2seq/blob/ec0ae309efba815a6ee8af88301479888b20daa9/model.py#L239 86 | 87 | filterd_refs = [ref[:ref.index(EOS)] for ref in refs] 88 | filterd_preds = [pred[:pred.index(EOS)] if EOS in pred else pred for pred in preds] 89 | 90 | true_positive, false_positive, false_negative = 0, 0, 0 91 | 92 | for filterd_pred, filterd_ref in zip(filterd_preds, filterd_refs): 93 | 94 | if filterd_pred == filterd_ref: 95 | true_positive += len(filterd_pred) 96 | continue 97 | 98 | for fp in filterd_pred: 99 | if fp in filterd_ref: 100 | true_positive += 1 101 | else: 102 | false_positive += 1 103 | 104 | for fr in filterd_ref: 105 | if not fr in filterd_pred: 106 | false_negative += 1 107 | 108 | # https://github.com/tech-srl/code2seq/blob/ec0ae309efba815a6ee8af88301479888b20daa9/model.py#L282 109 | if true_positive + false_positive > 0: 110 | precision = true_positive / (true_positive + false_positive) 111 | else: 112 | precision = 0 113 | 114 | if true_positive + false_negative > 0: 115 | recall = true_positive / (true_positive + false_negative) 116 | else: 117 | recall = 0 118 | 119 | if precision + recall > 0: 120 | f1 = 2 * precision * recall / (precision + recall) 121 | else: 122 | f1 = 0 123 | 124 | return precision, recall, f1 125 | 126 | class EarlyStopping(object): 127 | def __init__(self, filename = None, patience=3, warm_up=0, verbose=False): 128 | 129 | self.patience = patience 130 | self.verbose = verbose 131 | self.counter = 0 132 | self.best_score = None 133 | self.early_stop = False 134 | self.warm_up = warm_up 135 | self.filename = filename 136 | 137 | def __call__(self, score, model, epoch): 138 | 139 | if self.best_score is None: 140 | self.best_score = score 141 | self.save_checkpoint(score, model) 142 | 143 | elif (score <= self.best_score) and (epoch > self.warm_up) : 144 | self.counter += 1 145 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 146 | if self.counter >= self.patience: 147 | self.early_stop = True 148 | else: 149 | if (epoch <= self.warm_up): 150 | print('Warming up until epoch', self.warm_up) 151 | 152 | else: 153 | if self.verbose: 154 | print(f'Score improved. ({self.best_score:.6f} --> {score:.6f}).') 155 | 156 | self.best_score = score 157 | self.save_checkpoint(score, model) 158 | self.counter = 0 159 | 160 | def save_checkpoint(self, score, model): 161 | 162 | if self.filename is not None: 163 | torch.save(model.state_dict(), self.filename) 164 | 165 | if self.verbose: 166 | print('Model saved...') 167 | 168 | def pad_seq(seq, max_length): 169 | # pad tail of sequence to extend sequence length up to max_length 170 | res = seq + [PAD for i in range(max_length - len(seq))] 171 | return res 172 | 173 | def calc_bleu(refs, hyps): 174 | _refs = [[ref[:ref.index(EOS)]] for ref in refs] 175 | _hyps = [hyp[:hyp.index(EOS)] if EOS in hyp else hyp for hyp in hyps] 176 | return 100 * bleu_score.corpus_bleu(_refs, _hyps) -------------------------------------------------------------------------------- /notebooks/code2seq.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "5OLd1AL5gIQw" 8 | }, 9 | "source": [ 10 | "# A PyTorch re-implementation code for \"code2seq: Generating Sequences from Structured Representations of Code\"\n", 11 | "\n", 12 | "* Paper(Arxiv) : https://arxiv.org/abs/1808.01400 \n", 13 | "* Official Github : https://github.com/tech-srl/code2seq" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "%load_ext autoreload\n", 23 | "%autoreload 2" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": { 30 | "colab": {}, 31 | "colab_type": "code", 32 | "id": "cnth1kMQwrNl" 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "import sys\n", 37 | "sys.path.append('../')\n", 38 | "\n", 39 | "import os\n", 40 | "import time\n", 41 | "import yaml\n", 42 | "import random\n", 43 | "import numpy as np\n", 44 | "import warnings\n", 45 | "import logging\n", 46 | "import pickle\n", 47 | "from datetime import datetime\n", 48 | "from tqdm import tqdm_notebook as tqdm\n", 49 | "\n", 50 | "from sklearn.model_selection import train_test_split\n", 51 | "from sklearn.utils import shuffle\n", 52 | "\n", 53 | "import torch\n", 54 | "from torch import einsum\n", 55 | "import torch.nn as nn\n", 56 | "import torch.nn.functional as F\n", 57 | "import torch.optim as optim\n", 58 | "from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence\n", 59 | "\n", 60 | "from src import utils, messenger" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "## Parameters" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "config_file = '../configs/config_code2seq.yml'" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "config = yaml.load(open(config_file), Loader=yaml.FullLoader)\n", 86 | "\n", 87 | "# Data source\n", 88 | "DATA_HOME = config['data']['home']\n", 89 | "DICT_FILE = DATA_HOME + config['data']['dict']\n", 90 | "TRAIN_DIR = DATA_HOME + config['data']['train']\n", 91 | "VALID_DIR = DATA_HOME + config['data']['valid']\n", 92 | "TEST_DIR = DATA_HOME + config['data']['test']\n", 93 | "\n", 94 | "# Training parameter\n", 95 | "batch_size = config['training']['batch_size']\n", 96 | "num_epochs = config['training']['num_epochs']\n", 97 | "lr = config['training']['lr']\n", 98 | "teacher_forcing_rate = config['training']['teacher_forcing_rate']\n", 99 | "nesterov = config['training']['nesterov']\n", 100 | "weight_decay = config['training']['weight_decay']\n", 101 | "momentum = config['training']['momentum']\n", 102 | "decay_ratio = config['training']['decay_ratio']\n", 103 | "save_name = config['training']['save_name']\n", 104 | "warm_up = config['training']['warm_up']\n", 105 | "patience = config['training']['patience']\n", 106 | "\n", 107 | "\n", 108 | "\n", 109 | "# Model parameter\n", 110 | "token_size = config['model']['token_size']\n", 111 | "hidden_size = config['model']['hidden_size']\n", 112 | "num_layers = config['model']['num_layers']\n", 113 | "bidirectional = config['model']['bidirectional']\n", 114 | "rnn_dropout = config['model']['rnn_dropout']\n", 115 | "embeddings_dropout = config['model']['embeddings_dropout']\n", 116 | "num_k = config['model']['num_k']\n", 117 | "\n", 118 | "# etc\n", 119 | "slack_url_path = config['etc']['slack_url_path']\n", 120 | "info_prefix = config['etc']['info_prefix']" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "slack_url = None\n", 130 | "if os.path.exists(slack_url_path):\n", 131 | " slack_url = yaml.load(open(slack_url_path), Loader=yaml.FullLoader)['slack_url']" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "warnings.filterwarnings('ignore')\n", 141 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 142 | "\n", 143 | "torch.manual_seed(1)\n", 144 | "random_state = 42\n", 145 | "\n", 146 | "run_id = datetime.now().strftime('%Y-%m-%d--%H-%M-%S')\n", 147 | "log_file = '../logs/' + run_id + '.log'\n", 148 | "exp_dir = '../runs/' + run_id\n", 149 | "os.mkdir(exp_dir)\n", 150 | "\n", 151 | "logging.basicConfig(format='%(asctime)s | %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', filename=log_file, level=logging.DEBUG)\n", 152 | "msgr = messenger.Info(info_prefix, slack_url)\n", 153 | "\n", 154 | "msgr.print_msg('run_id : {}'.format(run_id))\n", 155 | "msgr.print_msg('log_file : {}'.format(log_file))\n", 156 | "msgr.print_msg('exp_dir : {}'.format(exp_dir))\n", 157 | "msgr.print_msg('device : {}'.format(device))\n", 158 | "msgr.print_msg(str(config))" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": { 165 | "colab": {}, 166 | "colab_type": "code", 167 | "id": "JM84CUHawrNv" 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "PAD_TOKEN = '' \n", 172 | "BOS_TOKEN = '' \n", 173 | "EOS_TOKEN = ''\n", 174 | "UNK_TOKEN = ''\n", 175 | "PAD = 0\n", 176 | "BOS = 1\n", 177 | "EOS = 2\n", 178 | "UNK = 3" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": { 185 | "colab": {}, 186 | "colab_type": "code", 187 | "id": "mjRMDX0gwrNy" 188 | }, 189 | "outputs": [], 190 | "source": [ 191 | "# load vocab dict\n", 192 | "with open(DICT_FILE, 'rb') as file:\n", 193 | " subtoken_to_count = pickle.load(file)\n", 194 | " node_to_count = pickle.load(file) \n", 195 | " target_to_count = pickle.load(file)\n", 196 | " max_contexts = pickle.load(file)\n", 197 | " num_training_examples = pickle.load(file)\n", 198 | " msgr.print_msg('Dictionaries loaded.')" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": { 205 | "colab": {}, 206 | "colab_type": "code", 207 | "id": "bRqsluoYwrOC" 208 | }, 209 | "outputs": [], 210 | "source": [ 211 | "# making vocab dicts for terminal subtoken, nonterminal node and target.\n", 212 | "\n", 213 | "word2id = {\n", 214 | " PAD_TOKEN: PAD,\n", 215 | " BOS_TOKEN: BOS,\n", 216 | " EOS_TOKEN: EOS,\n", 217 | " UNK_TOKEN: UNK,\n", 218 | " }\n", 219 | "\n", 220 | "vocab_subtoken = utils.Vocab(word2id=word2id)\n", 221 | "vocab_nodes = utils.Vocab(word2id=word2id)\n", 222 | "vocab_target = utils.Vocab(word2id=word2id)\n", 223 | "\n", 224 | "vocab_subtoken.build_vocab(list(subtoken_to_count.keys()), min_count=0)\n", 225 | "vocab_nodes.build_vocab(list(node_to_count.keys()), min_count=0)\n", 226 | "vocab_target.build_vocab(list(target_to_count.keys()), min_count=0)\n", 227 | "\n", 228 | "vocab_size_subtoken = len(vocab_subtoken.id2word)\n", 229 | "vocab_size_nodes = len(vocab_nodes.id2word)\n", 230 | "vocab_size_target = len(vocab_target.id2word)\n", 231 | "\n", 232 | "\n", 233 | "msgr.print_msg('vocab_size_subtoken:' + str(vocab_size_subtoken))\n", 234 | "msgr.print_msg('vocab_size_nodes:' + str(vocab_size_nodes))\n", 235 | "msgr.print_msg('vocab_size_target:' + str(vocab_size_target))\n", 236 | "\n", 237 | "num_length_train = num_training_examples\n", 238 | "msgr.print_msg('num_examples : ' + str(num_length_train))" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "class DataLoader(object):\n", 248 | "\n", 249 | " def __init__(self, data_path, batch_size, num_k, vocab_subtoken, vocab_nodes, vocab_target, shuffle=True, batch_time = False):\n", 250 | " \n", 251 | " \"\"\"\n", 252 | " data_path : path for data \n", 253 | " num_examples : total lines of data file\n", 254 | " batch_size : batch size\n", 255 | " num_k : max ast pathes included to one examples\n", 256 | " vocab_subtoken : dict of subtoken and its id\n", 257 | " vocab_nodes : dict of node simbol and its id\n", 258 | " vocab_target : dict of target simbol and its id\n", 259 | " \"\"\"\n", 260 | " \n", 261 | " self.data_path = data_path\n", 262 | " self.batch_size = batch_size\n", 263 | " \n", 264 | " self.num_examples = self.file_count(data_path)\n", 265 | " self.num_k = num_k\n", 266 | " \n", 267 | " self.vocab_subtoken = vocab_subtoken\n", 268 | " self.vocab_nodes = vocab_nodes\n", 269 | " self.vocab_target = vocab_target\n", 270 | " \n", 271 | " self.index = 0\n", 272 | " self.pointer = np.array(range(self.num_examples))\n", 273 | " self.shuffle = shuffle\n", 274 | " \n", 275 | " self.batch_time = batch_time\n", 276 | " \n", 277 | " self.reset()\n", 278 | "\n", 279 | " \n", 280 | " def __iter__(self):\n", 281 | " return self\n", 282 | " \n", 283 | " def __next__(self):\n", 284 | " \n", 285 | " if self.batch_time:\n", 286 | " t1 = time.time()\n", 287 | " \n", 288 | " if self.index >= self.num_examples:\n", 289 | " self.reset()\n", 290 | " raise StopIteration()\n", 291 | " \n", 292 | " ids = self.pointer[self.index: self.index + self.batch_size]\n", 293 | " seqs_S, seqs_N, seqs_E, seqs_Y = self.read_batch(ids)\n", 294 | " \n", 295 | " # length_k : (batch_size, k)\n", 296 | " lengths_k = [len(ex) for ex in seqs_N]\n", 297 | " \n", 298 | " # flattening (batch_size, k, l) to (batch_size * k, l)\n", 299 | " # this is useful to make torch.tensor\n", 300 | " seqs_S = [symbol for k in seqs_S for symbol in k]\n", 301 | " seqs_N = [symbol for k in seqs_N for symbol in k] \n", 302 | " seqs_E = [symbol for k in seqs_E for symbol in k] \n", 303 | " \n", 304 | " # Padding\n", 305 | " lengths_S = [len(s) for s in seqs_S]\n", 306 | " lengths_N = [len(s) for s in seqs_N]\n", 307 | " lengths_E = [len(s) for s in seqs_E]\n", 308 | " lengths_Y = [len(s) for s in seqs_Y]\n", 309 | " \n", 310 | " max_length_S = max(lengths_S)\n", 311 | " max_length_N = max(lengths_N)\n", 312 | " max_length_E = max(lengths_E)\n", 313 | " max_length_Y = max(lengths_Y)\n", 314 | "\n", 315 | " padded_S = [utils.pad_seq(s, max_length_S) for s in seqs_S]\n", 316 | " padded_N = [utils.pad_seq(s, max_length_N) for s in seqs_N]\n", 317 | " padded_E = [utils.pad_seq(s, max_length_E) for s in seqs_E]\n", 318 | " padded_Y = [utils.pad_seq(s, max_length_Y) for s in seqs_Y]\n", 319 | " \n", 320 | " # index for split (batch_size * k, l) into (batch_size, k, l)\n", 321 | " index_N = range(len(lengths_N))\n", 322 | " \n", 323 | " # sort for rnn\n", 324 | " seq_pairs = sorted(zip(lengths_N, index_N, padded_N, padded_S, padded_E), key=lambda p: p[0], reverse=True)\n", 325 | " lengths_N, index_N, padded_N, padded_S, padded_E = zip(*seq_pairs)\n", 326 | " \n", 327 | " batch_S = torch.tensor(padded_S, dtype=torch.long, device=device)\n", 328 | " batch_E = torch.tensor(padded_E, dtype=torch.long, device=device)\n", 329 | " \n", 330 | " # transpose for rnn\n", 331 | " batch_N = torch.tensor(padded_N, dtype=torch.long, device=device).transpose(0, 1)\n", 332 | " batch_Y = torch.tensor(padded_Y, dtype=torch.long, device=device).transpose(0, 1)\n", 333 | " \n", 334 | " # update index\n", 335 | " self.index += self.batch_size\n", 336 | " \n", 337 | " if self.batch_time:\n", 338 | " t2 = time.time()\n", 339 | " elapsed_time = t2-t1\n", 340 | " print(f\"batching time:{elapsed_time}\")\n", 341 | "\n", 342 | " return batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N\n", 343 | " \n", 344 | " \n", 345 | " def reset(self):\n", 346 | " if self.shuffle:\n", 347 | " self.pointer = shuffle(self.pointer)\n", 348 | " self.index = 0 \n", 349 | " \n", 350 | " def file_count(self, path):\n", 351 | " lst = [name for name in os.listdir(path) if os.path.isfile(os.path.join(path, name))]\n", 352 | " return len(lst)\n", 353 | " \n", 354 | " def read_batch(self, ids):\n", 355 | " \n", 356 | " seqs_S = []\n", 357 | " seqs_E = []\n", 358 | " seqs_N = []\n", 359 | " seqs_Y = []\n", 360 | " \n", 361 | " for i in ids:\n", 362 | " path = self.data_path + '/{:0>6d}.txt'.format(i)\n", 363 | " with open(path, 'r') as f:\n", 364 | " seq_S = []\n", 365 | " seq_N = []\n", 366 | " seq_E = []\n", 367 | "\n", 368 | " target, *syntax_path = f.readline().split(' ')\n", 369 | " target = target.split('|')\n", 370 | " target = utils.sentence_to_ids(self.vocab_target, target)\n", 371 | "\n", 372 | " # remove '' and '\\n' in sequence, java-small dataset contains many '' in a line.\n", 373 | " syntax_path = [s for s in syntax_path if s != '' and s != '\\n']\n", 374 | "\n", 375 | " # if the amount of ast path exceed the k,\n", 376 | " # uniformly sample ast pathes, as described in the paper.\n", 377 | " if len(syntax_path) > self.num_k:\n", 378 | " sampled_path_index = random.sample(range(len(syntax_path)) , self.num_k)\n", 379 | " else :\n", 380 | " sampled_path_index = range(len(syntax_path))\n", 381 | "\n", 382 | " for j in sampled_path_index:\n", 383 | " terminal1, ast_path, terminal2 = syntax_path[j].split(',')\n", 384 | "\n", 385 | " terminal1 = utils.sentence_to_ids(self.vocab_subtoken, terminal1.split('|'))\n", 386 | " ast_path = utils.sentence_to_ids(self.vocab_nodes, ast_path.split('|'))\n", 387 | " terminal2 = utils.sentence_to_ids(self.vocab_subtoken, terminal2.split('|')) \n", 388 | "\n", 389 | " seq_S.append(terminal1)\n", 390 | " seq_E.append(terminal2)\n", 391 | " seq_N.append(ast_path)\n", 392 | "\n", 393 | " seqs_S.append(seq_S)\n", 394 | " seqs_E.append(seq_E)\n", 395 | " seqs_N.append(seq_N)\n", 396 | " seqs_Y.append(target)\n", 397 | "\n", 398 | " return seqs_S, seqs_N, seqs_E, seqs_Y\n", 399 | " " 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": null, 405 | "metadata": { 406 | "colab": {}, 407 | "colab_type": "code", 408 | "id": "-7lLgR9WwrPS" 409 | }, 410 | "outputs": [], 411 | "source": [ 412 | "class Encoder(nn.Module):\n", 413 | " def __init__(self, input_size_subtoken, input_size_node, token_size, hidden_size, bidirectional = True, num_layers = 2, rnn_dropout = 0.5, embeddings_dropout = 0.25):\n", 414 | " \n", 415 | " \"\"\"\n", 416 | " input_size_subtoken : # of unique subtoken\n", 417 | " input_size_node : # of unique node symbol\n", 418 | " token_size : embedded token size\n", 419 | " hidden_size : size of initial state of decoder\n", 420 | " rnn_dropout = 0.5 : rnn drop out ratio\n", 421 | " embeddings_dropout = 0.25 : dropout ratio for context vector\n", 422 | " \"\"\"\n", 423 | " \n", 424 | " super(Encoder, self).__init__()\n", 425 | " self.hidden_size = hidden_size\n", 426 | " self.token_size = token_size\n", 427 | "\n", 428 | " self.embedding_subtoken = nn.Embedding(input_size_subtoken, token_size, padding_idx=PAD)\n", 429 | " self.embedding_node = nn.Embedding(input_size_node, token_size, padding_idx=PAD)\n", 430 | " \n", 431 | " self.lstm = nn.LSTM(token_size, token_size, num_layers = num_layers, bidirectional=bidirectional, dropout=rnn_dropout)\n", 432 | " self.out = nn.Linear(token_size * 4, hidden_size)\n", 433 | " \n", 434 | " self.dropout = nn.Dropout(embeddings_dropout)\n", 435 | " self.num_directions = 2 if bidirectional else 1\n", 436 | " self.num_layers = num_layers\n", 437 | "\n", 438 | " def forward(self, batch_S, batch_N, batch_E, lengths_k, index_N, hidden=None):\n", 439 | " \n", 440 | " \"\"\"\n", 441 | " batch_S : (B * k, l) start terminals' subtoken of each ast path\n", 442 | " batch_N : (l, B*k) nonterminals' nodes of each ast path\n", 443 | " batch_E : (B * k, l) end terminals' subtoken of each ast path\n", 444 | " \n", 445 | " lengths_k : length of k in each example\n", 446 | " index_N : index for unsorting,\n", 447 | " \"\"\"\n", 448 | " \n", 449 | " bk_size = batch_N.shape[1]\n", 450 | " output_bag = []\n", 451 | " hidden_batch = []\n", 452 | " \n", 453 | " # (B * k, l, d)\n", 454 | " encode_S = self.embedding_subtoken(batch_S)\n", 455 | " encode_E = self.embedding_subtoken(batch_E)\n", 456 | " \n", 457 | " # encode_S (B * k, d) token_representation of each ast path\n", 458 | " encode_S = encode_S.sum(1)\n", 459 | " encode_E = encode_E.sum(1)\n", 460 | " \n", 461 | " \n", 462 | " \"\"\"\n", 463 | " LSTM Outputs: output, (h_n, c_n)\n", 464 | " output (seq_len, batch, num_directions * hidden_size)\n", 465 | " h_n (num_layers * num_directions, batch, hidden_size) : tensor containing the hidden state for t = seq_len.\n", 466 | " c_n (num_layers * num_directions, batch, hidden_size)\n", 467 | " \"\"\"\n", 468 | " \n", 469 | " # emb_N :(l, B*k, d)\n", 470 | " emb_N = self.embedding_node(batch_N)\n", 471 | " packed = pack_padded_sequence(emb_N, lengths_N)\n", 472 | " output, (hidden, cell) = self.lstm(packed, hidden)\n", 473 | " #output, _ = pad_packed_sequence(output)\n", 474 | " \n", 475 | " # hidden (num_layers * num_directions, batch, hidden_size)\n", 476 | " # only last layer, (num_directions, batch, hidden_size)\n", 477 | " hidden = hidden[-self.num_directions:, :, :]\n", 478 | " \n", 479 | " # -> (Bk, num_directions, hidden_size)\n", 480 | " hidden = hidden.transpose(0, 1)\n", 481 | " \n", 482 | " # -> (Bk, 1, hidden_size * num_directions)\n", 483 | " hidden = hidden.contiguous().view(bk_size, 1, -1)\n", 484 | " \n", 485 | " # encode_N (Bk, hidden_size * num_directions)\n", 486 | " encode_N = hidden.squeeze(1)\n", 487 | " \n", 488 | " # encode_SNE : (B*k, hidden_size * num_directions + 2)\n", 489 | " encode_SNE = torch.cat([encode_N, encode_S, encode_E], dim=1)\n", 490 | " \n", 491 | " # encode_SNE : (B*k, d)\n", 492 | " encode_SNE = self.out(encode_SNE)\n", 493 | " \n", 494 | " # unsort as example\n", 495 | " #index = torch.tensor(index_N, dtype=torch.long, device=device)\n", 496 | " #encode_SNE = torch.index_select(encode_SNE, dim=0, index=index)\n", 497 | " index = np.argsort(index_N)\n", 498 | " encode_SNE = encode_SNE[[index]]\n", 499 | " \n", 500 | " # as is in https://github.com/tech-srl/code2seq/blob/ec0ae309efba815a6ee8af88301479888b20daa9/model.py#L511\n", 501 | " encode_SNE = self.dropout(encode_SNE)\n", 502 | " \n", 503 | " # output_bag : [ B, (k, d) ]\n", 504 | " output_bag = torch.split(encode_SNE, lengths_k, dim=0)\n", 505 | " \n", 506 | " # hidden_0 : (1, B, d)\n", 507 | " # for decoder initial state\n", 508 | " hidden_0 = [ob.mean(0).unsqueeze(dim=0) for ob in output_bag]\n", 509 | " hidden_0 = torch.cat(hidden_0, dim=0).unsqueeze(dim=0)\n", 510 | " \n", 511 | " return output_bag, hidden_0\n" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": null, 517 | "metadata": { 518 | "colab": {}, 519 | "colab_type": "code", 520 | "id": "flU2AatIwrPl" 521 | }, 522 | "outputs": [], 523 | "source": [ 524 | "class Decoder(nn.Module):\n", 525 | " def __init__(self, hidden_size, output_size, rnn_dropout):\n", 526 | " \"\"\"\n", 527 | " hidden_size : decoder unit size, \n", 528 | " output_size : decoder output size, \n", 529 | " rnn_dropout : dropout ratio for rnn\n", 530 | " \"\"\"\n", 531 | " \n", 532 | " super(Decoder, self).__init__()\n", 533 | " self.hidden_size = hidden_size\n", 534 | " self.output_size = output_size\n", 535 | "\n", 536 | " self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD)\n", 537 | " self.gru = nn.GRU(hidden_size, hidden_size, dropout=rnn_dropout)\n", 538 | " self.out = nn.Linear(hidden_size * 2, output_size)\n", 539 | "\n", 540 | " def forward(self, seqs, hidden, attn):\n", 541 | " emb = self.embedding(seqs)\n", 542 | " _, hidden = self.gru(emb, hidden)\n", 543 | " \n", 544 | " output = torch.cat((hidden, attn), 2)\n", 545 | " output = self.out(output)\n", 546 | " \n", 547 | " return output, hidden" 548 | ] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "metadata": { 553 | "colab_type": "text", 554 | "id": "XCPLR4MkQ_U1" 555 | }, 556 | "source": [ 557 | "#### Attention\n", 558 | "\n", 559 | "$$\n", 560 | " \\alpha ^ t = softmax(h_t W_a z).\n", 561 | "$$\n", 562 | "\n", 563 | "$$\n", 564 | " c_t = {\\sum^n_{i} \\alpha_i^tz_i} .\n", 565 | "$$\n", 566 | "\n", 567 | "$$\n", 568 | " y_t = \\mathrm{softmax}(W_{s} \\mathrm{tanh} (W_{c} [c_t:h_t] ) )\n", 569 | "$$\n" 570 | ] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "execution_count": null, 575 | "metadata": { 576 | "colab": {}, 577 | "colab_type": "code", 578 | "id": "zsuAuSteQ_U3" 579 | }, 580 | "outputs": [], 581 | "source": [ 582 | "class EncoderDecoder_with_Attention(nn.Module):\n", 583 | " \n", 584 | " \"\"\"Conbine Encoder and Decoder\"\"\"\n", 585 | " \n", 586 | " def __init__(self, input_size_subtoken, input_size_node, token_size, output_size, hidden_size, bidirectional = True, num_layers = 2, rnn_dropout = 0.5, embeddings_dropout = 0.25):\n", 587 | "\n", 588 | " super(EncoderDecoder_with_Attention, self).__init__()\n", 589 | " self.encoder = Encoder(input_size_subtoken, input_size_node, token_size, hidden_size, bidirectional = bidirectional, num_layers = num_layers, rnn_dropout = rnn_dropout, embeddings_dropout = embeddings_dropout)\n", 590 | " self.decoder = Decoder(hidden_size, output_size, rnn_dropout)\n", 591 | " \n", 592 | " self.W_a = torch.rand((hidden_size, hidden_size), dtype=torch.float,device=device , requires_grad=True)\n", 593 | " \n", 594 | " nn.init.xavier_uniform_(self.W_a)\n", 595 | " \n", 596 | " \n", 597 | " def forward(self, batch_S, batch_N, batch_E, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S, max_length_N,max_length_E,max_length_Y, lengths_k, index_N, terget_max_length, batch_Y=None, use_teacher_forcing=False):\n", 598 | "\n", 599 | " # Encoder\n", 600 | " encoder_output_bag, encoder_hidden = \\\n", 601 | " self.encoder(batch_S, batch_N, batch_E, lengths_k, index_N)\n", 602 | " \n", 603 | " _batch_size = len(encoder_output_bag)\n", 604 | " decoder_hidden = encoder_hidden\n", 605 | " \n", 606 | " # make initial input for decoder\n", 607 | " decoder_input = torch.tensor([BOS] * _batch_size, dtype=torch.long, device=device)\n", 608 | " decoder_input = decoder_input.unsqueeze(0) # (1, batch_size)\n", 609 | " \n", 610 | " # output holder\n", 611 | " decoder_outputs = torch.zeros(terget_max_length, _batch_size, self.decoder.output_size, device=device)\n", 612 | " \n", 613 | " #print('=' * 20)\n", 614 | " for t in range(terget_max_length):\n", 615 | " \n", 616 | " # ct\n", 617 | " ct = self.attention(encoder_output_bag, decoder_hidden, lengths_k)\n", 618 | " \n", 619 | " decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, ct)\n", 620 | " \n", 621 | " #print(decoder_output.max(-1)[1])\n", 622 | " \n", 623 | " decoder_outputs[t] = decoder_output\n", 624 | " \n", 625 | " # Teacher Forcing\n", 626 | " if use_teacher_forcing and batch_Y is not None:\n", 627 | " decoder_input = batch_Y[t].unsqueeze(0)\n", 628 | " else: \n", 629 | " decoder_input = decoder_output.max(-1)[1]\n", 630 | " \n", 631 | " return decoder_outputs\n", 632 | " \n", 633 | " def attention(self, encoder_output_bag, hidden, lengths_k):\n", 634 | " \n", 635 | " \"\"\"\n", 636 | " encoder_output_bag : (batch, k, hidden_size) bag of embedded ast path\n", 637 | " hidden : (1 , batch, hidden_size):\n", 638 | " lengths_k : (batch, 1) length of k in each example\n", 639 | " \"\"\"\n", 640 | " \n", 641 | " # e_out : (batch * k, hidden_size)\n", 642 | " e_out = torch.cat(encoder_output_bag, dim=0)\n", 643 | " \n", 644 | " # e_out : (batch * k(i), hidden_size(j))\n", 645 | " # self.W_a : [hidden_size(j), hidden_size(k)]\n", 646 | " # ha -> : [batch * k(i), hidden_size(k)]\n", 647 | " ha = einsum('ij,jk->ik', e_out, self.W_a)\n", 648 | " \n", 649 | " # ha -> : [batch, (k, hidden_size)]\n", 650 | " ha = torch.split(ha, lengths_k, dim=0)\n", 651 | " \n", 652 | " # dh = [batch, (1, hidden_size)]\n", 653 | " hd = hidden.transpose(0,1)\n", 654 | " hd = torch.unbind(hd, dim = 0)\n", 655 | " \n", 656 | " # _ha : (k(i), hidden_size(j))\n", 657 | " # _hd : (1(k), hidden_size(j))\n", 658 | " # at : [batch, ( k(i) ) ]\n", 659 | " at = [F.softmax(torch.einsum('ij,kj->i', _ha, _hd), dim=0) for _ha, _hd in zip(ha, hd)]\n", 660 | " \n", 661 | " # a : ( k(i) )\n", 662 | " # e : ( k(i), hidden_size(j))\n", 663 | " # ct : [batch, (hidden_size(j)) ] -> [batch, (1, hidden_size) ]\n", 664 | " ct = [torch.einsum('i,ij->j', a, e).unsqueeze(0) for a, e in zip(at, encoder_output_bag)]\n", 665 | " \n", 666 | " # ct [batch, hidden_size(k)]\n", 667 | " # -> (1, batch, hidden_size)\n", 668 | " ct = torch.cat(ct, dim=0).unsqueeze(0)\n", 669 | " \n", 670 | " \n", 671 | " return ct" 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": null, 677 | "metadata": { 678 | "colab": {}, 679 | "colab_type": "code", 680 | "id": "YqnxqyXqwrP0" 681 | }, 682 | "outputs": [], 683 | "source": [ 684 | "mce = nn.CrossEntropyLoss(size_average=False, ignore_index=PAD)\n", 685 | "def masked_cross_entropy(logits, target):\n", 686 | " return mce(logits.view(-1, logits.size(-1)), target.view(-1))" 687 | ] 688 | }, 689 | { 690 | "cell_type": "code", 691 | "execution_count": null, 692 | "metadata": { 693 | "colab": {}, 694 | "colab_type": "code", 695 | "id": "7kp7px_O7-1J" 696 | }, 697 | "outputs": [], 698 | "source": [ 699 | "batch_time = False\n", 700 | "train_dataloader = DataLoader(TRAIN_DIR, batch_size, num_k, vocab_subtoken, vocab_nodes, vocab_target, batch_time=batch_time, shuffle=True)\n", 701 | "valid_dataloader = DataLoader(VALID_DIR, batch_size, num_k, vocab_subtoken, vocab_nodes, vocab_target, shuffle=False)\n", 702 | "\n", 703 | "model_args = {\n", 704 | " 'input_size_subtoken' : vocab_size_subtoken,\n", 705 | " 'input_size_node' : vocab_size_nodes,\n", 706 | " 'output_size' : vocab_size_target,\n", 707 | " 'hidden_size' : hidden_size, \n", 708 | " 'token_size' : token_size,\n", 709 | " 'bidirectional' : bidirectional,\n", 710 | " 'num_layers' : num_layers,\n", 711 | " 'rnn_dropout' : rnn_dropout, \n", 712 | " 'embeddings_dropout' : embeddings_dropout\n", 713 | "}\n", 714 | "\n", 715 | "model = EncoderDecoder_with_Attention(**model_args).to(device)\n", 716 | "\n", 717 | "#optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov = nesterov)\n", 718 | "optimizer = optim.Adam(model.parameters(), lr=lr)\n", 719 | "scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda epoch: decay_ratio ** epoch)\n", 720 | "\n", 721 | "fname = exp_dir + save_name\n", 722 | "early_stopping = utils.EarlyStopping(fname, patience, warm_up, verbose=True)" 723 | ] 724 | }, 725 | { 726 | "cell_type": "code", 727 | "execution_count": null, 728 | "metadata": { 729 | "colab": {}, 730 | "colab_type": "code", 731 | "id": "IujU0wrrwrQE" 732 | }, 733 | "outputs": [], 734 | "source": [ 735 | "def compute_loss(batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N, model, optimizer=None, is_train=True):\n", 736 | " model.train(is_train)\n", 737 | " \n", 738 | " use_teacher_forcing = is_train and (random.random() < teacher_forcing_rate)\n", 739 | " \n", 740 | " target_max_length = batch_Y.size(0)\n", 741 | " pred_Y = model(batch_S, batch_N, batch_E, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N, target_max_length, batch_Y, use_teacher_forcing)\n", 742 | " \n", 743 | " loss = masked_cross_entropy(pred_Y.contiguous(), batch_Y.contiguous())\n", 744 | " \n", 745 | " if is_train:\n", 746 | " optimizer.zero_grad()\n", 747 | " loss.backward()\n", 748 | " optimizer.step()\n", 749 | " \n", 750 | " batch_Y = batch_Y.transpose(0, 1).contiguous().data.cpu().tolist()\n", 751 | " pred = pred_Y.max(dim=-1)[1].data.cpu().numpy().T.tolist()\n", 752 | " \n", 753 | " \n", 754 | " return loss.item(), batch_Y, pred" 755 | ] 756 | }, 757 | { 758 | "cell_type": "code", 759 | "execution_count": null, 760 | "metadata": { 761 | "colab": {}, 762 | "colab_type": "code", 763 | "id": "nTCOrAinwrQQ" 764 | }, 765 | "outputs": [], 766 | "source": [ 767 | "#\n", 768 | "# Training Loop\n", 769 | "# \n", 770 | "progress_bar = False # progress bar is visible in progress_bar = False\n", 771 | "\n", 772 | "\n", 773 | "for epoch in range(1, num_epochs+1):\n", 774 | " train_loss = 0.\n", 775 | " train_refs = []\n", 776 | " train_hyps = []\n", 777 | " valid_loss = 0.\n", 778 | " valid_refs = []\n", 779 | " valid_hyps = []\n", 780 | " \n", 781 | " # train\n", 782 | " for batch in tqdm(train_dataloader, total=train_dataloader.num_examples // train_dataloader.batch_size + 1, desc='TRAIN'):\n", 783 | " batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S, max_length_N,max_length_E,max_length_Y, lengths_k, index_N = batch\n", 784 | " \n", 785 | " loss, gold, pred = compute_loss(\n", 786 | " batch_S, batch_N, batch_E, batch_Y, \n", 787 | " lengths_S, lengths_N, lengths_E, lengths_Y, \n", 788 | " max_length_S,max_length_N,max_length_E,max_length_Y, \n", 789 | " lengths_k, index_N, model, optimizer,\n", 790 | " is_train=True\n", 791 | " )\n", 792 | " \n", 793 | " train_loss += loss\n", 794 | " train_refs += gold\n", 795 | " train_hyps += pred\n", 796 | " \n", 797 | " # valid\n", 798 | " for batch in tqdm(valid_dataloader, total=valid_dataloader.num_examples // valid_dataloader.batch_size + 1, desc='VALID'):\n", 799 | "\n", 800 | " batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N = batch\n", 801 | "\n", 802 | " loss, gold, pred = compute_loss(\n", 803 | " batch_S, batch_N, batch_E, batch_Y, \n", 804 | " lengths_S, lengths_N, lengths_E, lengths_Y, \n", 805 | " max_length_S,max_length_N,max_length_E,max_length_Y, \n", 806 | " lengths_k, index_N, model, optimizer,\n", 807 | " is_train=False\n", 808 | " )\n", 809 | " \n", 810 | " valid_loss += loss\n", 811 | " valid_refs += gold\n", 812 | " valid_hyps += pred\n", 813 | " \n", 814 | "\n", 815 | " train_loss = np.sum(train_loss) / train_dataloader.num_examples\n", 816 | " valid_loss = np.sum(valid_loss) / valid_dataloader.num_examples\n", 817 | " \n", 818 | " # F1 etc\n", 819 | " train_precision, train_recall, train_f1 = utils.calculate_results_set(train_refs, train_hyps)\n", 820 | " valid_precision, valid_recall, valid_f1 = utils.calculate_results_set(valid_refs, valid_hyps)\n", 821 | "\n", 822 | " \n", 823 | " early_stopping(valid_f1, model, epoch)\n", 824 | " if early_stopping.early_stop:\n", 825 | " msgr.print_msg(\"Early stopping\")\n", 826 | " break\n", 827 | " \n", 828 | " msgr.print_msg('Epoch {}: train_loss: {:5.2f} train_f1: {:2.4f} valid_loss: {:5.2f} valid_f1: {:2.4f}'.format(\n", 829 | " epoch, train_loss, train_f1, valid_loss, valid_f1))\n", 830 | " \n", 831 | " print('-'*80)\n", 832 | " \n", 833 | " scheduler.step()" 834 | ] 835 | }, 836 | { 837 | "cell_type": "markdown", 838 | "metadata": { 839 | "colab_type": "text", 840 | "id": "QsQay14VEwYJ" 841 | }, 842 | "source": [ 843 | "## Evaluation" 844 | ] 845 | }, 846 | { 847 | "cell_type": "code", 848 | "execution_count": null, 849 | "metadata": { 850 | "colab": {}, 851 | "colab_type": "code", 852 | "id": "swYlOn13Q_VU" 853 | }, 854 | "outputs": [], 855 | "source": [ 856 | "model = EncoderDecoder_with_Attention(**model_args).to(device)\n", 857 | "\n", 858 | "fname = exp_dir + save_name\n", 859 | "ckpt = torch.load(fname)\n", 860 | "model.load_state_dict(ckpt)\n", 861 | "\n", 862 | "model.eval()" 863 | ] 864 | }, 865 | { 866 | "cell_type": "code", 867 | "execution_count": null, 868 | "metadata": { 869 | "colab": {}, 870 | "colab_type": "code", 871 | "id": "2GKrKzSDQ_VW" 872 | }, 873 | "outputs": [], 874 | "source": [ 875 | "test_dataloader = DataLoader(TEST_DIR, batch_size, num_k, vocab_subtoken, vocab_nodes, vocab_target, batch_time=batch_time, shuffle=True)" 876 | ] 877 | }, 878 | { 879 | "cell_type": "code", 880 | "execution_count": null, 881 | "metadata": { 882 | "colab": {}, 883 | "colab_type": "code", 884 | "id": "lY6ty-rNG76n" 885 | }, 886 | "outputs": [], 887 | "source": [ 888 | "refs_list = []\n", 889 | "hyp_list = []\n", 890 | "\n", 891 | "for batch in tqdm(test_dataloader,\n", 892 | " total=test_dataloader.num_examples // test_dataloader.batch_size + 1,\n", 893 | " desc='TEST'):\n", 894 | " \n", 895 | " batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N = batch\n", 896 | " target_max_length = batch_Y.size(0)\n", 897 | " use_teacher_forcing = False\n", 898 | " \n", 899 | " pred_Y = model(batch_S, batch_N, batch_E, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N, target_max_length, batch_Y, use_teacher_forcing)\n", 900 | " \n", 901 | " refs = batch_Y.transpose(0, 1).contiguous().data.cpu().tolist()[0]\n", 902 | " pred = pred_Y.max(dim=-1)[1].data.cpu().numpy().T.tolist()[0]\n", 903 | " \n", 904 | " refs_list.append(refs)\n", 905 | " hyp_list.append(pred)\n" 906 | ] 907 | }, 908 | { 909 | "cell_type": "code", 910 | "execution_count": null, 911 | "metadata": { 912 | "colab": {}, 913 | "colab_type": "code", 914 | "id": "VVDf6TOAPE8z" 915 | }, 916 | "outputs": [], 917 | "source": [ 918 | "msgr.print_msg('Tested model : ' + fname)\n", 919 | "\n", 920 | "test_precision, test_recall, test_f1 = utils.calculate_results(refs_list, hyp_list)\n", 921 | "msgr.print_msg('Test : precision {:1.5f}, recall {:1.5f}, f1 {:1.5f}'.format(test_precision, test_recall, test_f1))\n", 922 | "\n", 923 | "test_precision, test_recall, test_f1 = utils.calculate_results_set(refs_list, hyp_list)\n", 924 | "msgr.print_msg('Test(set) : precision {:1.5f}, recall {:1.5f}, f1 {:1.5f}'.format(test_precision, test_recall, test_f1))" 925 | ] 926 | }, 927 | { 928 | "cell_type": "code", 929 | "execution_count": null, 930 | "metadata": {}, 931 | "outputs": [], 932 | "source": [ 933 | "batch_time = False\n", 934 | "test_dataloader = DataLoader(TEST_DIR, 1, num_k, vocab_subtoken, vocab_nodes, vocab_target, batch_time=batch_time, shuffle=True)" 935 | ] 936 | }, 937 | { 938 | "cell_type": "code", 939 | "execution_count": null, 940 | "metadata": {}, 941 | "outputs": [], 942 | "source": [ 943 | "model.eval()\n", 944 | "\n", 945 | "batch_S, batch_N, batch_E, batch_Y, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N = next(test_dataloader)\n", 946 | "\n", 947 | "sentence_Y = ' '.join(utils.ids_to_sentence(vocab_target, batch_Y.data.cpu().numpy()[:-1, 0]))\n", 948 | "msgr.print_msg('tgt: {}'.format(sentence_Y))\n", 949 | "\n", 950 | "target_max_length = batch_Y.size(0)\n", 951 | "use_teacher_forcing = False\n", 952 | "output = model(batch_S, batch_N, batch_E, lengths_S, lengths_N, lengths_E, lengths_Y, max_length_S,max_length_N,max_length_E,max_length_Y, lengths_k, index_N, target_max_length, batch_Y, use_teacher_forcing)\n", 953 | "\n", 954 | "output = output.max(dim=-1)[1].view(-1).data.cpu().tolist()\n", 955 | "output_sentence = ' '.join(utils.ids_to_sentence(vocab_target, utils.trim_eos(output)))\n", 956 | "msgr.print_msg('out: {}'.format(output_sentence))" 957 | ] 958 | }, 959 | { 960 | "cell_type": "code", 961 | "execution_count": null, 962 | "metadata": {}, 963 | "outputs": [], 964 | "source": [] 965 | } 966 | ], 967 | "metadata": { 968 | "accelerator": "GPU", 969 | "colab": { 970 | "collapsed_sections": [], 971 | "include_colab_link": true, 972 | "name": "code2seq.ipynb", 973 | "private_outputs": true, 974 | "provenance": [], 975 | "version": "0.3.2" 976 | }, 977 | "kernelspec": { 978 | "display_name": "Python 3", 979 | "language": "python", 980 | "name": "python3" 981 | }, 982 | "language_info": { 983 | "codemirror_mode": { 984 | "name": "ipython", 985 | "version": 3 986 | }, 987 | "file_extension": ".py", 988 | "mimetype": "text/x-python", 989 | "name": "python", 990 | "nbconvert_exporter": "python", 991 | "pygments_lexer": "ipython3", 992 | "version": "3.7.4" 993 | } 994 | }, 995 | "nbformat": 4, 996 | "nbformat_minor": 4 997 | } 998 | --------------------------------------------------------------------------------