├── .gitignore ├── LICENSE ├── README.md ├── download.sh ├── models.py ├── preprocess.py ├── run_conv.py ├── run_copy.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | #data 107 | /data 108 | 109 | #saves 110 | /.save 111 | /save 112 | /saves 113 | /.vector_cache 114 | /logs 115 | /checkpoints 116 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Ben Trevett 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Convolutional Attention Network for Extreme Summarization of Source Code 2 | 3 | Implementation of [A Convolutional Attention Network for Extreme Summarization of Source Code](https://arxiv.org/abs/1602.03001) in PyTorch using TorchText. 4 | 5 | Uses Python 3.6, PyTorch 0.4 and TorchText 0.2.3. 6 | 7 | **Note**: only the *Convolutional Attention Model* currently works; the *Copy Convolutional Attention Model* is in progress. 8 | 9 | To use: 10 | 11 | 1. Run `./download.sh` to grab the dataset 12 | 1. Run `python preprocess.py` to preprocess the dataset into json format 13 | 1. Run `python run_conv.py` to run the Convolutional Attention Model with default parameters 14 | 15 | Use `python run_conv.py -h` to see all the parameters that can be changed, e.g. to run the model on a different Java project within the dataset, use: `python run_conv.py --project {project name}`. -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | wget http://groups.inf.ed.ac.uk/cup/codeattention/dataset.zip 2 | mkdir data 3 | mv dataset.zip data 4 | cd data 5 | unzip -o dataset.zip -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import random 6 | 7 | class AttentionFeatures(nn.Module): 8 | """ 9 | Page 3 of the paper 10 | attention_features (code tokens c, context h_{t-1}) 11 | C <- lookupandpad(c, E) 12 | L1 <- ReLU(Conv1d(C, K_{l1})) 13 | L2 <- Conv1d(L1, K_{l2}) * h_{t-1} 14 | Lfeat <- L2/||L2||_2 15 | return Lfeat 16 | """ 17 | def __init__(self, embedding_dim, k1, w1, k2, w2, w3, dropout): 18 | super().__init__() 19 | 20 | self.w1 = w1 21 | self.k1 = k1 22 | 23 | self.w2 = w2 24 | self.k2 = k2 25 | 26 | #self.w3 = w3 #use this to calculate padding 27 | 28 | self.conv1 = nn.Conv1d(embedding_dim, k1, w1) 29 | self.conv2 = nn.Conv1d(k1, k2, w2) 30 | self.do = nn.Dropout(dropout) 31 | self.activation = nn.PReLU() 32 | def forward(self, C, h_t): 33 | 34 | #C = embedded body tokens 35 | #h_t = previous hidden state used to predict name token 36 | 37 | #C = [bodies len, batch size, emb dim] 38 | #h_t = [1, batch size, k2] 39 | 40 | C = C.permute(1, 2, 0) #input to conv needs n_channels as dim 1 41 | 42 | #C = [batch size, emb dim, bodies len] 43 | 44 | h_t = h_t.permute(1, 2, 0) #from [1, batch size, k2] to [batch size, k2, 1] 45 | 46 | #h_t = [batch size, k2, 1] 47 | 48 | L_1 = self.do(self.activation(self.conv1(C))) 49 | 50 | #L_1 = [batch size, k1, bodies len - w1 + 1] 51 | 52 | L_2 = self.do(self.conv2(L_1)) * h_t 53 | 54 | #L_2 = [batch size, k2, bodies len - w1 - w2 + 2] 55 | 56 | L_feat = F.normalize(L_2, p=2, dim=1) 57 | 58 | #L_feat = [batch size, k2, bodies len - w1 - w2 + 2] 59 | 60 | return L_feat 61 | 62 | class AttentionWeights(nn.Module): 63 | """ 64 | Page 3 of the paper 65 | attention_weights (attention features Lfeat, kernel K) 66 | return Softmax(Conv1d(Lfeat, K)) 67 | """ 68 | def __init__(self, k2, w3, dropout): 69 | super().__init__() 70 | 71 | self.conv1 = nn.Conv1d(k2, 1, w3) 72 | self.do = nn.Dropout(dropout) 73 | 74 | def forward(self, L_feat, log=False): 75 | 76 | #L_feat = [batch size, k2,
bodies len - w1 - w2 + 2] 77 | 78 | x = self.do(self.conv1(L_feat)) 79 | 80 | #x = [batch size, 1, bodies len - w1 - w2 - w3 + 3] 81 | 82 | x = x.squeeze(1) 83 | 84 | #x = [batch size, bodies len - w1 - w2 - w3 + 3] 85 | 86 | if log: 87 | x = F.log_softmax(x, dim=1) 88 | else: 89 | x = F.softmax(x, dim=1) 90 | 91 | #x = [batch size, bodies len - w1 - w2 - w3 + 3] 92 | 93 | return x 94 | 95 | class ConvAttentionNetwork(nn.Module): 96 | def __init__(self, vocab_size, embedding_dim, k1, k2, w1, w2, w3, dropout, pad_idx): 97 | super().__init__() 98 | 99 | self.vocab_size = vocab_size 100 | self.embedding_dim = embedding_dim 101 | self.k1 = k1 102 | self.k2 = k2 103 | self.w1 = w1 104 | self.w2 = w2 105 | self.w3 = w3 106 | self.dropout = dropout 107 | self.pad_idx = pad_idx 108 | 109 | self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx) 110 | self.do = nn.Dropout(dropout) 111 | self.gru = nn.GRU(embedding_dim, k2) 112 | self.attn_feat = AttentionFeatures(embedding_dim, k1, w1, k2, w2, w3, dropout) 113 | self.attn_weights = AttentionWeights(k2, w3, dropout) 114 | self.bias = nn.Parameter(torch.ones(vocab_size)) 115 | 116 | n_padding = w1 + w2 + w3 - 3 117 | self.padding = torch.zeros(n_padding, 1).fill_(pad_idx).long() 118 | 119 | def forward(self, bodies, names, tf=None): 120 | 121 | if tf is None: 122 | tf = self.dropout 123 | 124 | #bodies = [bodies len, batch size] 125 | #names = [names len, batch size] 126 | 127 | #stores the probabilities generated for each token 128 | outputs = torch.zeros(names.shape[0], names.shape[1], self.vocab_size).to(names.device) 129 | 130 | #outputs = [name len, batch size, vocab dim] 131 | 132 | #need to pad the function body so after it has been fed through 133 | #the convolutional layers it is the same size as the original function body 134 | bodies_padded = torch.cat((bodies, self.padding.expand(-1, bodies.shape[1]).to(bodies.device))) 135 | 136 | #bodies_padded = [bodies len + w1 + w2 + w3 - 3, batch_size] 137 | 138 | #from now on when we refer to bodies len, we mean the padded version 139 | 140 | #convert function body tokens into their embeddings 141 | emb_b = self.embedding(bodies_padded) 142 | 143 | #emb_b = [bodies len, batch size, emb dim] 144 | 145 | #first input to the gru is the first token of the function name 146 | #which is a start of sentence token 147 | output = names[0] 148 | 149 | #generate predicted function name tokens one at a time 150 | for i in range(1, names.shape[0]): 151 | 152 | #initial hidden state is start of sentence token passed through gru 153 | #subsequent hidden states from either the previous token predicted by the model 154 | #or the ground truth token the model should have predicted 155 | _, h_t = self.gru(self.embedding(output).unsqueeze(0)) 156 | 157 | #h_t = [1, batch size, k2] 158 | 159 | #computes `k2` features for each token which are scaled by h_t 160 | L_feat = self.attn_feat(emb_b, h_t) 161 | 162 | #L_feat = [batch size, k2, bodies len - w1 - w2 + 2] 163 | 164 | #computes the attention values for each token in the function body 165 | #the second dimension is now equal to the original unpadded `bodies len` size 166 | alpha = self.attn_weights(L_feat) 167 | 168 | #alpha = [batch size, bodies len - w1 - w2 - w3 + 3] 169 | 170 | #emb_b also contains the padding tokens so we slice these off 171 | emb_b_slice = emb_b.permute(1, 0, 2)[:, :bodies.shape[0], :] 172 | 173 | #emb_b = [batch_size, bodies len, emb dim] 174 | 175 | #apply the attention to the embedded function body tokens 176 | n_hat = 
torch.sum(alpha.unsqueeze(2) * emb_b_slice, dim=1) 177 | 178 | #n_hat = [batch size, emb dim] 179 | 180 | #E is the embedding layer weights 181 | E = self.embedding.weight.unsqueeze(0).expand(bodies.shape[1],-1,-1) 182 | 183 | #E = [batch size, vocab size, emb dim] 184 | 185 | #matrix multiply E and n_hat and apply a bias 186 | #n is the probability distribution over the vocabulary for the predicted next token 187 | n = torch.bmm(E, n_hat.unsqueeze(2)).squeeze(2) + self.bias.unsqueeze(0).expand(bodies.shape[1], -1) 188 | 189 | #n = [batch size, vocab size] 190 | 191 | #store prediction probability distribution in large tensor that holds 192 | #predictions for each token in the function name 193 | outputs[i] = n 194 | 195 | #with probability of `tf`, use the model's prediction of the next token 196 | #as the next token to feed into the model (to become the next h_t) 197 | #with probability 1-`tf`, use the actual ground truth next token as 198 | #the next token to feed into the model 199 | #teacher forcing ratio is equal to dropout during training and 0 during inference 200 | if random.random() < tf: 201 | 202 | #model's predicted token highest value in the probability distribution 203 | top1 = n.max(1)[1] 204 | output = top1 205 | 206 | else: 207 | output = names[i] 208 | 209 | #outputs = [name len, batch size, vocab dim] 210 | 211 | return outputs 212 | 213 | class CopyAttentionNetwork(nn.Module): 214 | def __init__(self, vocab_size, embedding_dim, k1, k2, w1, w2, w3, dropout, pad_idx): 215 | super().__init__() 216 | 217 | self.vocab_size = vocab_size 218 | self.embedding_dim = embedding_dim 219 | self.k1 = k1 220 | self.k2 = k2 221 | self.w1 = w1 222 | self.w2 = w2 223 | self.w3 = w3 224 | self.dropout = dropout 225 | self.pad_idx = pad_idx 226 | 227 | self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx) 228 | self.do = nn.Dropout(dropout) 229 | self.gru = nn.GRU(embedding_dim, k2) 230 | self.attn_feat = AttentionFeatures(embedding_dim, k1, w1, k2, w2, w3, dropout) 231 | self.attn_weights_alpha = AttentionWeights(k2, w3, dropout) 232 | self.attn_weights_kappa = AttentionWeights(k2, w3, dropout) 233 | self.conv1 = nn.Conv1d(k2, 1, w3) 234 | self.bias = nn.Parameter(torch.ones(vocab_size)) 235 | 236 | n_padding = w1 + w2 + w3 - 3 237 | self.padding = torch.zeros(n_padding, 1).fill_(pad_idx).long() 238 | 239 | def forward(self, bodies, names, tf=None): 240 | 241 | if tf is None: 242 | tf = self.dropout 243 | 244 | #bodies = [bodies len, batch size] 245 | #names = [names len, batch size] 246 | 247 | #stores the probabilities generated for each token 248 | outputs = torch.zeros(names.shape[0], names.shape[1], self.vocab_size).to(names.device) 249 | 250 | #outputs = [names len, batch size, vocab dim] 251 | 252 | #stores the copy attention generated for each token 253 | kappas = torch.zeros(names.shape[0], names.shape[1], bodies.shape[0]).to(names.device) 254 | 255 | #kappas = [name len, batch size, bodies len] 256 | 257 | #stores the prob of doing a copy for each token 258 | lambdas = torch.zeros(names.shape[0], names.shape[1]).to(names.device) 259 | 260 | #lambdas = [name len, batch size] 261 | 262 | #need to pad the function body so after it has been fed through 263 | #the convolutional layers it is the same size as the original function body 264 | bodies_padded = torch.cat((bodies, self.padding.expand(-1, bodies.shape[1]).to(bodies.device))) 265 | 266 | #bodies_padded = [bodies len + w1 + w2 + w3 - 3, batch_size] 267 | 268 | #from now on when we refer to bodies len, 
we mean the padded version 269 | 270 | #convert function body tokens into their embeddings 271 | emb_b = self.embedding(bodies_padded) 272 | 273 | #emb_b = [bodies len, batch size, emb dim] 274 | 275 | #first input to the gru is the first token of the function name 276 | #which is a start of sentence token 277 | output = names[0] 278 | 279 | #generate predicted function name tokens one at a time 280 | for i in range(1, names.shape[0]): 281 | 282 | #initial hidden state is start of sentence token passed through gru 283 | #subsequent hidden states from either the previous token predicted by the model 284 | #or the ground truth token the model should have predicted 285 | _, h_t = self.gru(self.embedding(output).unsqueeze(0)) 286 | 287 | #h_t = [1, batch size, k2] 288 | 289 | #computes `k2` features for each token which are scaled by h_t 290 | L_feat = self.attn_feat(emb_b, h_t) 291 | 292 | #L_feat = [batch size, k2, bodies len - w1 - w2 + 2] 293 | 294 | #alpha is the attention values for each token in the function body 295 | #kappa is the probability that each token in the function body is copied 296 | #the second dimension is now equal to the original unpadded `bodies len` size 297 | alpha = self.attn_weights_alpha(L_feat) 298 | kappa = self.attn_weights_kappa(L_feat, log=True) 299 | 300 | #alpha = [batch size, bodies len - w1 - w2 - w3 + 3] 301 | #kappa = [batch size, bodies len - w1 - w2 - w3 + 3] 302 | 303 | #calculate the weight of predicting by copying from body vs. predicting by guessing from vocab 304 | _lambda = F.max_pool1d(torch.sigmoid(self.do(self.conv1(L_feat))), alpha.shape[1]).squeeze(2) 305 | 306 | lambdas[i] = _lambda.permute(1, 0) 307 | 308 | #emb_b also contains the padding tokens so we slice these off 309 | emb_b_slice = emb_b.permute(1, 0, 2)[:, :bodies.shape[0], :] 310 | 311 | #emb_b = [batch_size, bodies len, emb dim] 312 | 313 | #apply the attention to the embedded function body tokens 314 | n_hat = torch.sum(alpha.unsqueeze(2) * emb_b_slice, dim=1) 315 | 316 | #n_hat = [batch size, emb dim] 317 | 318 | #E is the embedding layer weights 319 | E = self.embedding.weight.unsqueeze(0).expand(bodies.shape[1],-1,-1) 320 | 321 | #E = [batch size, vocab size, emb dim] 322 | 323 | #matrix multiply E and n_hat and apply a bias 324 | #n is the probability distribution over the vocabulary for the predicted next token 325 | n = torch.bmm(E, n_hat.unsqueeze(2)).squeeze(2) + self.bias.unsqueeze(0).expand(bodies.shape[1], -1) 326 | 327 | #n = [batch size, vocab size] 328 | 329 | #store prediction probability distribution in large tensor that holds 330 | #predictions for each token in the function name 331 | outputs[i] = F.log_softmax(n,dim=1) 332 | 333 | #store copy probability distribution 334 | kappas[i] = kappa 335 | 336 | #with probability of `tf`, use the model's prediction of the next token 337 | #as the next token to feed into the model (to become the next h_t) 338 | #with probability 1-`tf`, use the actual ground truth next token as 339 | #the next token to feed into the model 340 | #teacher forcing ratio is equal to dropout during training and 0 during inference 341 | if random.random() < tf: 342 | 343 | #model's predicted token highest value in the probability distribution 344 | top1 = n.max(1)[1] 345 | output = top1 346 | 347 | else: 348 | output = names[i] 349 | 350 | #outputs = [name len, batch size, vocab dim] 351 | 352 | return outputs, kappas, lambdas 353 | -------------------------------------------------------------------------------- /preprocess.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | DATA_DIR = 'data/json/' 5 | 6 | START_TOKEN = '<sos>' 7 | END_TOKEN = '<eos>' 8 | 9 | PROJECTS = ['cassandra', 'elasticsearch', 'gradle', 'hadoop-common', 'hibernate-orm', 'intellij-community', 'libgdx', 'liferay-portal', 'presto', 'spring-framework', 'wildfly'] 10 | 11 | def files_to_data(DIR, FILE): 12 | """ 13 | DIR is the base directory containing all the json files 14 | FILE is the filename of the json file you want to get the data from 15 | """ 16 | data = [] 17 | with open(os.path.join(DIR, FILE), 'r') as r: 18 | project = json.load(r) 19 | for method in project: 20 | assert type(method['filename']) is str 21 | assert type(method['name']) is list 22 | assert type(method['tokens']) is list 23 | method_name = [START_TOKEN] + method['name'] + [END_TOKEN] #add start and end of sequence tokens to method name 24 | method_body = [x.lower() for x in method['tokens'] if (x != '<SENTENCE_START>' and x != '<SENTENCE_END>')] #lowercase and remove sentence boundary tags 25 | 26 | #when the method name appears in the body it is represented by a %self% token 27 | #we replace the %self% token with the actual method name 28 | while '%self%' in method_body: 29 | self_idx = method_body.index('%self%') 30 | method_body = method_body[:self_idx] + method['name'] + method_body[self_idx+1:] 31 | 32 | data.append({'name': method_name, 'body': method_body}) 33 | 34 | return data 35 | 36 | for project in PROJECTS: 37 | 38 | print(f'Project: {project}') 39 | 40 | train_file = f'{project}_train_methodnaming.json' 41 | test_file = f'{project}_test_methodnaming.json' 42 | 43 | train_data = files_to_data(DATA_DIR, train_file) 44 | test_data = files_to_data(DATA_DIR, test_file) 45 | 46 | print(f'Training examples: {len(train_data)}') 47 | print(f'Testing examples: {len(test_data)}') 48 | 49 | with open(f'data/{project}_train.json', 'w') as w: 50 | for example in train_data: 51 | json.dump(example, w) 52 | w.write('\n') 53 | 54 | with open(f'data/{project}_test.json', 'w') as w: 55 | for example in test_data: 56 | json.dump(example, w) 57 | w.write('\n') 58 | --------------------------------------------------------------------------------
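Each line of the files written above is a single JSON object holding the subtokenized name and body. A minimal sketch of the transformation, using a made-up record in the shape of a `*_methodnaming.json` entry (the sentence-tag and `<sos>`/`<eos>` strings follow the constants assumed in `preprocess.py` above):

import json

method = {'filename': 'Foo.java',
          'name': ['get', 'name'],
          'tokens': ['<SENTENCE_START>', 'return', '%self%', '(', ')', '<SENTENCE_END>']}

name = ['<sos>'] + method['name'] + ['<eos>']
body = [t.lower() for t in method['tokens'] if t not in ('<SENTENCE_START>', '<SENTENCE_END>')]
while '%self%' in body:  #expand the self-reference into the method's own subtokens
    i = body.index('%self%')
    body = body[:i] + method['name'] + body[i+1:]

print(json.dumps({'name': name, 'body': body}))
#{"name": ["<sos>", "get", "name", "<eos>"], "body": ["return", "get", "name", "(", ")"]}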
type=int) 26 | parser.add_argument('--k2', default=8, type=int) 27 | parser.add_argument('--w1', default=24, type=int) 28 | parser.add_argument('--w2', default=29, type=int) 29 | parser.add_argument('--w3', default=10, type=int) 30 | parser.add_argument('--dropout', default=0.25, type=float) 31 | parser.add_argument('--clip', default=1.0, type=float) 32 | parser.add_argument('--epochs', default=100, type=int) 33 | parser.add_argument('--seed', default=1234, type=int) 34 | parser.add_argument('--load', action='store_true', help='Use this to load model parameters, parameters should be saved as: {checkpoints_dir}/{project name}-conv-model.pt') 35 | 36 | args = parser.parse_args() 37 | 38 | assert os.path.exists(f'{args.data_dir}/{args.project}_train.json') 39 | assert os.path.exists(f'{args.data_dir}/{args.project}_test.json') 40 | 41 | if not os.path.exists(f'{args.checkpoints_dir}'): 42 | os.mkdir(f'{args.checkpoints_dir}') 43 | 44 | #make deterministic 45 | torch.backends.cudnn.deterministic = True 46 | torch.manual_seed(args.seed) 47 | torch.cuda.manual_seed_all(args.seed) 48 | random.seed(args.seed) 49 | 50 | #get available device 51 | device = torch.device('cuda' if (torch.cuda.is_available() and not args.no_cuda) else 'cpu') 52 | 53 | #set up fields 54 | BODY = Field() 55 | NAME = Field() 56 | fields = {'name': ('name', NAME), 'body': ('body', BODY)} 57 | 58 | #get data from json 59 | train, test = TabularDataset.splits( 60 | path = 'data', 61 | train = f'{args.project}_train.json', 62 | test = f'{args.project}_test.json', 63 | format = 'json', 64 | fields = fields 65 | ) 66 | 67 | #build the vocabulary 68 | BODY.build_vocab(train.body, train.name, min_freq=args.min_freq) 69 | NAME.build_vocab(train.body, train.name, min_freq=args.min_freq) 70 | 71 | # make iterator for splits 72 | train_iter, test_iter = BucketIterator.splits( 73 | (train, test), 74 | batch_size=args.batch_size, 75 | sort_key=lambda x: len(x.name), 76 | repeat=False, 77 | device=-1 if device == 'cpu' else None) 78 | 79 | #calculate these for the model 80 | vocab_size = len(BODY.vocab) 81 | pad_idx = BODY.vocab.stoi[''] 82 | unk_idx = BODY.vocab.stoi[''] 83 | 84 | #initialize model 85 | model = models.ConvAttentionNetwork(vocab_size, args.emb_dim, args.k1, args.k2, args.w1, args.w2, args.w3, args.dropout, pad_idx) 86 | 87 | #place on GPU if available 88 | model = model.to(device) 89 | 90 | if args.load: 91 | model.load_state_dict(torch.load(f'{args.checkpoints_dir}/{args.project}-conv-model.pt')) 92 | 93 | #initialize optimizer and loss function 94 | criterion = nn.CrossEntropyLoss(ignore_index = pad_idx) 95 | optimizer = optim.RMSprop(model.parameters(), lr=1e-3, momentum=0.9) 96 | 97 | criterion = criterion.to(device) 98 | 99 | def train(model, iterator, optimizer, criterion, clip): 100 | 101 | #turn on dropout/bn 102 | model.train() 103 | 104 | epoch_loss = 0 105 | n_examples = 0 106 | precision = 0 107 | recall = 0 108 | f1 = 0 109 | 110 | for _, batch in enumerate(iterator): 111 | 112 | bodies = batch.body 113 | names = batch.name 114 | 115 | optimizer.zero_grad() 116 | 117 | output = model(bodies, names) 118 | 119 | #take highest probability token as prediction 120 | preds = output.max(2)[1] 121 | 122 | examples = names.shape[1] 123 | n_examples += examples 124 | 125 | #calculate precision, recall and f1 126 | #this is probably very inefficient 127 | for ex in range(examples): 128 | actual = [n.item() for n in names[:,ex][1:]] 129 | predicted = [p.item() for p in preds[:,ex][1:]] 130 | _precision, _recall, _f1 = 
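#A quick sanity check (an editor's sketch, not part of the original script) of the padding
#arithmetic used by models.AttentionFeatures and models.AttentionWeights: three unpadded
#Conv1ds of widths w1, w2, w3 each shrink the length dimension by (width - 1), which is
#why models.py appends w1 + w2 + w3 - 3 pad tokens to the body. With toy sizes:

import torch
import torch.nn as nn

emb_dim, k1, k2 = 4, 8, 8
w1, w2, w3 = 3, 5, 2
body_len, batch_size = 20, 2

x = torch.randn(batch_size, emb_dim, body_len + (w1 + w2 + w3 - 3))  #padded body embeddings
l1 = nn.Conv1d(emb_dim, k1, w1)(x)  #length shrinks by w1 - 1
l2 = nn.Conv1d(k1, k2, w2)(l1)      #length shrinks by w2 - 1
attn = nn.Conv1d(k2, 1, w3)(l2)     #length shrinks by w3 - 1
assert attn.shape[-1] == body_len   #one attention weight per original body token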
98 | 99 | def train(model, iterator, optimizer, criterion, clip): 100 | 101 | #turn on dropout/bn 102 | model.train() 103 | 104 | epoch_loss = 0 105 | n_examples = 0 106 | precision = 0 107 | recall = 0 108 | f1 = 0 109 | 110 | for _, batch in enumerate(iterator): 111 | 112 | bodies = batch.body 113 | names = batch.name 114 | 115 | optimizer.zero_grad() 116 | 117 | output = model(bodies, names) 118 | 119 | #take highest probability token as prediction 120 | preds = output.max(2)[1] 121 | 122 | examples = names.shape[1] 123 | n_examples += examples 124 | 125 | #calculate precision, recall and f1 126 | #this is probably very inefficient 127 | for ex in range(examples): 128 | actual = [n.item() for n in names[:,ex][1:]] 129 | predicted = [p.item() for p in preds[:,ex][1:]] 130 | _precision, _recall, _f1 = utils.token_precision_recall(predicted, actual, unk_idx, pad_idx) 131 | precision += _precision 132 | recall += _recall 133 | f1 += _f1 134 | 135 | #calculate loss 136 | loss = criterion(output[1:].view(-1, output.shape[2]), names[1:].view(-1)) 137 | 138 | #calculate gradients wrt loss 139 | loss.backward() 140 | 141 | #clip gradients 142 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip) 143 | 144 | #update parameters 145 | optimizer.step() 146 | 147 | epoch_loss += loss.item() 148 | 149 | return epoch_loss / len(iterator), precision/n_examples, recall/n_examples, f1/n_examples 150 | 151 | def evaluate(model, iterator, criterion): 152 | 153 | #turn off bn/dropout 154 | model.eval() 155 | 156 | epoch_loss = 0 157 | n_examples = 0 158 | precision = 0 159 | recall = 0 160 | f1 = 0 161 | 162 | #ensures no gradients are calculated, speeds up calculations 163 | with torch.no_grad(): 164 | 165 | for _, batch in enumerate(iterator): 166 | 167 | bodies = batch.body.to(device) 168 | names = batch.name.to(device) 169 | 170 | output = model(bodies, names, 0) #set teacher forcing to zero 171 | 172 | preds = output.max(2)[1] 173 | 174 | examples = names.shape[1] 175 | n_examples += examples 176 | 177 | for ex in range(examples): 178 | actual = [n.item() for n in names[:,ex][1:]] 179 | predicted = [p.item() for p in preds[:,ex][1:]] 180 | _precision, _recall, _f1 = utils.token_precision_recall(predicted, actual, unk_idx, pad_idx) 181 | precision += _precision 182 | recall += _recall 183 | f1 += _f1 184 | 185 | loss = criterion(output[1:].view(-1, output.shape[2]), names[1:].view(-1)) 186 | 187 | epoch_loss += loss.item() 188 | 189 | return epoch_loss / len(iterator), precision/n_examples, recall/n_examples, f1/n_examples 190 | 191 | best_test_loss = float('inf') 192 | 193 | if not os.path.isdir(f'{args.checkpoints_dir}'): 194 | os.makedirs(f'{args.checkpoints_dir}') 195 | 196 | for epoch in range(args.epochs): 197 | 198 | train_loss, train_precision, train_recall, train_f1 = train(model, train_iter, optimizer, criterion, args.clip) 199 | test_loss, test_precision, test_recall, test_f1 = evaluate(model, test_iter, criterion) 200 | 201 | if test_loss < best_test_loss: 202 | best_test_loss = test_loss 203 | torch.save(model.state_dict(), f'{args.checkpoints_dir}/{args.project}-conv-model.pt') 204 | 205 | print(f'| Epoch: {epoch+1:03} | Train Loss: {train_loss:.3f} | Train F1: {train_f1:.3f} | Test Loss: {test_loss:.3f} | Test F1: {test_f1:.3f}') --------------------------------------------------------------------------------
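`run_conv.py` only reports loss and F1; to read off an actual predicted name you can decode greedily. A rough sketch, not part of the original script: it assumes the setup above has run (so `model` and `NAME` exist), that names were wrapped in `<sos>`/`<eos>` by `preprocess.py`, and it caps output at a made-up `max_len`:

import torch

def predict_name(model, body, max_len=10):
    #body = [body len, 1] tensor of body token indices, already on the right device
    model.eval()
    sos_idx = NAME.vocab.stoi['<sos>']
    #forward() only reads `names` for its length and names[0]; with tf=1.0 the model
    #always feeds back its own greedy predictions, so the dummy tokens are never used
    names = torch.full((max_len, 1), sos_idx, dtype=torch.long, device=body.device)
    with torch.no_grad():
        output = model(body, names, tf=1.0)
    preds = output.max(2)[1].squeeze(1).tolist()[1:]  #position 0 is never predicted
    tokens = []
    for p in preds:
        token = NAME.vocab.itos[p]
        if token == '<eos>':
            break
        tokens.append(token)
    return tokens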
/run_copy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | 5 | from torchtext.data import Field 6 | from torchtext.data import TabularDataset 7 | from torchtext.data import BucketIterator 8 | 9 | import os 10 | import argparse 11 | import random 12 | 13 | import models 14 | import utils 15 | 16 | parser = argparse.ArgumentParser(description='Implementation of \'A Convolutional Attention Network for Extreme Summarization of Source Code\'', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | 18 | parser.add_argument('--project', default='cassandra', type=str, help='Which project to run on') 19 | parser.add_argument('--data_dir', default='data', type=str, help='Where to find the training data') 20 | parser.add_argument('--checkpoints_dir', default='checkpoints', type=str, help='Where to save the model checkpoints') 21 | parser.add_argument('--no_cuda', action='store_true', help='Use this flag to stop using the GPU') 22 | parser.add_argument('--min_freq', default=2, type=int, help='Minimum times a token must appear in the dataset to not be unk\'d') 23 | parser.add_argument('--batch_size', default=64, type=int) 24 | parser.add_argument('--emb_dim', default=128, type=int) 25 | parser.add_argument('--k1', default=32, type=int) 26 | parser.add_argument('--k2', default=16, type=int) 27 | parser.add_argument('--w1', default=18, type=int) 28 | parser.add_argument('--w2', default=19, type=int) 29 | parser.add_argument('--w3', default=2, type=int) 30 | parser.add_argument('--dropout', default=0.4, type=float) 31 | parser.add_argument('--clip', default=0.75, type=float) 32 | parser.add_argument('--epochs', default=100, type=int) 33 | parser.add_argument('--seed', default=1234, type=int) 34 | 35 | args = parser.parse_args() 36 | 37 | assert os.path.exists(f'{args.data_dir}/{args.project}_train.json') 38 | assert os.path.exists(f'{args.data_dir}/{args.project}_test.json') 39 | 40 | if not os.path.exists(f'{args.checkpoints_dir}'): 41 | os.mkdir(f'{args.checkpoints_dir}') 42 | 43 | #make deterministic 44 | torch.backends.cudnn.deterministic = True 45 | torch.manual_seed(args.seed) 46 | torch.cuda.manual_seed_all(args.seed) 47 | random.seed(args.seed) 48 | 49 | #get available device 50 | device = torch.device('cuda' if (torch.cuda.is_available() and not args.no_cuda) else 'cpu') 51 | 52 | #set up fields 53 | BODY = Field() 54 | NAME = Field() 55 | 56 | fields = {'name': ('name', NAME), 'body': ('body', BODY)} 57 | 58 | #get data from json 59 | train, test = TabularDataset.splits( 60 | path = args.data_dir, 61 | train = f'{args.project}_train.json', 62 | test = f'{args.project}_test.json', 63 | format = 'json', 64 | fields = fields 65 | ) 66 | 67 | #build the vocabulary 68 | BODY.build_vocab(train.body, train.name, min_freq=args.min_freq) 69 | NAME.build_vocab(train.body, train.name, min_freq=args.min_freq) 70 | 71 | # make iterator for splits 72 | train_iter, test_iter = BucketIterator.splits( 73 | (train, test), 74 | batch_size=args.batch_size, 75 | sort_key=lambda x: len(x.name), 76 | repeat=False, 77 | device=-1 if device == 'cpu' else None) 78 | 79 | #calculate these for the model 80 | vocab_size = len(BODY.vocab) 81 | pad_idx = BODY.vocab.stoi['<pad>'] 82 | unk_idx = BODY.vocab.stoi['<unk>'] 83 | 84 | #initialize model 85 | model = models.CopyAttentionNetwork(vocab_size, args.emb_dim, args.k1, args.k2, args.w1, args.w2, args.w3, args.dropout, pad_idx) 86 | 87 | #place on GPU if available 88 | model = model.to(device) 89 | 90 | #initialize optimizer 91 | optimizer = optim.RMSprop(model.parameters(), momentum=0.9, lr=1e-3) 92 |
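#An editor's toy illustration (not part of the original script) of the indicator tensor I
#built inside train()/evaluate() below: for every target name position it marks where that
#exact token occurs in the body; these are the positions the copy attention should select.

import torch

toy_bodies = torch.tensor([[5], [9], [5], [2]])  #[bodies len, batch size] of token indices
toy_names = torch.tensor([[1], [5], [3]])        #[names len, batch size], 1 being <sos> here

toy_I = torch.zeros(toy_names.shape[0], toy_names.shape[1], toy_bodies.shape[0])
for j, name in enumerate(toy_names):
    for k, token in enumerate(name):
        toy_I[j, k, :] = (toy_bodies[:, k] == token).float()

print(toy_I[1, 0])  #tensor([1., 0., 1., 0.]), token 5 occurs at body positions 0 and 2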
93 | def train(model, iterator, optimizer, clip): 94 | 95 | #turn on dropout/bn 96 | model.train() 97 | 98 | epoch_loss = 0 99 | n_examples = 0 100 | precision = 0 101 | recall = 0 102 | f1 = 0 103 | 104 | for i, batch in enumerate(iterator): 105 | 106 | bodies = batch.body 107 | names = batch.name 108 | 109 | optimizer.zero_grad() 110 | 111 | I = torch.zeros(names.shape[0], names.shape[1], bodies.shape[0]).to(device) 112 | 113 | _ones = torch.ones(bodies.shape[0]).to(device) 114 | _zeros = torch.zeros(bodies.shape[0]).to(device) 115 | 116 | #create the I tensor 117 | #the length of the method body where elements are: 118 | # 1 in the positions where the current token you are trying to predict appears in the body 119 | # 0 otherwise 120 | for j, name in enumerate(names): 121 | for k, token in enumerate(name): 122 | I[j,k,:] = torch.where(bodies[:,k] == token, _ones, _zeros) 123 | 124 | #output is predictions 125 | #kappas are copy-attention over the body 126 | #lambdas are probability of copy over generate from vocab 127 | output, kappas, lambdas = model(bodies, names) 128 | 129 | examples = names.shape[1] 130 | n_examples += examples 131 | 132 | copy_preds = kappas.max(2)[1] 133 | vocab_preds = output.max(2)[1] 134 | 135 | for ex in range(examples): 136 | predicted = [] 137 | actual = [n.item() for n in names[:,ex][1:]] 138 | for n, l in enumerate(lambdas[:,ex][1:], start=1): 139 | if l.item() >= 0.5: #do copy 140 | copied_token_position = copy_preds[n,ex] 141 | predicted.append(bodies[copied_token_position, ex].item()) 142 | else: 143 | predicted.append(vocab_preds[n,ex].item()) 144 | _precision, _recall, _f1 = utils.token_precision_recall(predicted, actual, unk_idx, pad_idx) 145 | precision += _precision 146 | recall += _recall 147 | f1 += _f1 148 | 149 | #reshape parameters 150 | output = output[1:].view(-1, output.shape[2]) 151 | kappas = kappas[1:].view(-1, kappas.shape[2]) 152 | lambdas = lambdas[1:].view(-1) 153 | I = I[1:].view(-1, I.shape[2]) 154 | names = names[1:].view(-1, 1) 155 | 156 | #log-probability of producing the correct token by copying and by generating from the vocab 157 | use_copy = torch.log(lambdas + 1e-8) + torch.sum(I * kappas, dim=1) 158 | use_model = torch.log(1 - lambdas + 1e-8) + torch.gather(output, 1, names).squeeze(1) 159 | 160 | #calculate loss: negative log-likelihood of the marginal probability over both paths 161 | loss = -torch.mean(utils.logsumexp(use_copy, use_model)) 162 | 163 | #calculate gradients 164 | loss.backward() 165 | 166 | #clip to prevent exploding gradients 167 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip) 168 | 169 | #update parameters 170 | optimizer.step() 171 | 172 | epoch_loss += loss.item() 173 | 174 | return epoch_loss / len(iterator), precision/n_examples, recall/n_examples, f1/n_examples
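#The loss above marginalizes over the two ways of producing the correct token: copying
#(weight lambda) and generating from the vocabulary (weight 1 - lambda), combined in log
#space with utils.logsumexp and negated to give a negative log-likelihood. An editor's
#numeric check of that combination with made-up probabilities:

import math
import torch
import utils  #assumes this is run from the repository root

lam = torch.tensor(0.3)                     #probability of copying
log_p_copy = torch.log(torch.tensor(0.8))   #log-prob of the correct token via copying
log_p_vocab = torch.log(torch.tensor(0.1))  #log-prob of the correct token from the vocab

use_copy = torch.log(lam) + log_p_copy
use_model = torch.log(1 - lam) + log_p_vocab
log_p = utils.logsumexp(use_copy, use_model)

#identical to combining the two paths directly in probability space
assert math.isclose(log_p.item(), math.log(0.3 * 0.8 + 0.7 * 0.1), rel_tol=1e-6)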
175 | 176 | def evaluate(model, iterator): 177 | 178 | #turn off bn/dropout 179 | model.eval() 180 | 181 | epoch_loss = 0 182 | n_examples = 0 183 | precision = 0 184 | recall = 0 185 | f1 = 0 186 | 187 | with torch.no_grad(): 188 | 189 | for i, batch in enumerate(iterator): 190 | 191 | bodies = batch.body 192 | names = batch.name 193 | 194 | I = torch.zeros(names.shape[0], names.shape[1], bodies.shape[0]).to(device) 195 | 196 | _ones = torch.ones(bodies.shape[0]).to(device) 197 | _zeros = torch.zeros(bodies.shape[0]).to(device) 198 | 199 | for j, name in enumerate(names): 200 | for k, token in enumerate(name): 201 | I[j,k,:] = torch.where(bodies[:,k] == token, _ones, _zeros) 202 | 203 | output, kappas, lambdas = model(bodies, names, 0) #set teacher forcing to zero 204 | 205 | examples = names.shape[1] 206 | n_examples += examples 207 | 208 | copy_preds = kappas.max(2)[1] 209 | vocab_preds = output.max(2)[1] 210 | 211 | for ex in range(examples): 212 | predicted = [] 213 | actual = [n.item() for n in names[:,ex][1:]] 214 | for n, l in enumerate(lambdas[:,ex][1:], start=1): 215 | if l.item() >= 0.5: #do copy 216 | copied_token_position = copy_preds[n,ex] 217 | predicted.append(bodies[copied_token_position, ex].item()) 218 | else: 219 | predicted.append(vocab_preds[n,ex].item()) 220 | _precision, _recall, _f1 = utils.token_precision_recall(predicted, actual, unk_idx, pad_idx) 221 | precision += _precision 222 | recall += _recall 223 | f1 += _f1 224 | 225 | output = output[1:].view(-1, output.shape[2]) 226 | kappas = kappas[1:].view(-1, kappas.shape[2]) 227 | lambdas = lambdas[1:].view(-1) 228 | I = I[1:].view(-1, I.shape[2]) 229 | names = names[1:].view(-1,1) 230 | 231 | use_copy = torch.log(lambdas + 1e-8) + torch.sum(I * kappas, dim=1) 232 | use_model = torch.log(1 - lambdas + 1e-8) + torch.gather(output, 1, names).squeeze(1) 233 | 234 | loss = -torch.mean(utils.logsumexp(use_copy, use_model)) 235 | 236 | epoch_loss += loss.item() 237 | 238 | return epoch_loss / len(iterator), precision/n_examples, recall/n_examples, f1/n_examples 239 | 240 | best_test_loss = float('inf') 241 | 242 | if not os.path.isdir(f'{args.checkpoints_dir}'): 243 | os.makedirs(f'{args.checkpoints_dir}') 244 | 245 | for epoch in range(args.epochs): 246 | 247 | train_loss, train_precision, train_recall, train_f1 = train(model, train_iter, optimizer, args.clip) 248 | test_loss, test_precision, test_recall, test_f1 = evaluate(model, test_iter) 249 | 250 | if test_loss < best_test_loss: 251 | best_test_loss = test_loss 252 | torch.save(model.state_dict(), f'{args.checkpoints_dir}/{args.project}-copy-model.pt') 253 | 254 | print(f'| Epoch: {epoch+1:03} | Train Loss: {train_loss:.3f} | Train F1: {train_f1:.3f} | Test Loss: {test_loss:.3f} | Test F1: {test_f1:.3f}') -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def token_precision_recall(y_pred, y_true, unk_idx, pad_idx): 4 | """ 5 | Get the precision, recall and F1 for a predicted sequence of (sub)tokens. 6 | :param y_pred: a list of predicted (sub)token indices 7 | :param y_true: a list of the gold (sub)token indices 8 | :return: precision, recall, f1 as floats 9 | """ 10 | 11 | ground_truth = y_true[:] 12 | 13 | tp = 0 14 | for subtoken in set(y_pred): 15 | if subtoken == unk_idx or subtoken == pad_idx: 16 | continue 17 | if subtoken in ground_truth: 18 | ground_truth.remove(subtoken) 19 | tp += 1 20 | 21 | assert tp <= len(y_pred), (tp, len(y_pred)) 22 | 23 | if len(y_pred) > 0: 24 | precision = float(tp) / len(y_pred) 25 | else: 26 | precision = 0 27 | 28 | assert tp <= len(y_true), (y_true) 29 | 30 | if len(y_true) > 0: 31 | recall = float(tp) / len(y_true) 32 | else: 33 | recall = 0 34 | 35 | if precision + recall > 0: 36 | f1 = 2 * precision * recall / (precision + recall) 37 | else: 38 | f1 = 0. 39 | 40 | return precision, recall, f1 41 | 42 | def logsumexp(x, y): 43 | max = torch.where(x > y, x, y) 44 | min = torch.where(x > y, y, x) 45 | return torch.log1p(torch.exp(min - max)) + max --------------------------------------------------------------------------------
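For completeness, a quick worked example of `token_precision_recall` with toy indices (assuming here that `0` is `<unk>` and `1` is `<pad>`):

import utils

#gold name subtokens [get, file, name] as indices [7, 12, 4]
#predicted subtokens [get, name, name] as indices [7, 4, 4]
precision, recall, f1 = utils.token_precision_recall([7, 4, 4], [7, 12, 4], unk_idx=0, pad_idx=1)

#set(y_pred) = {7, 4} and both occur in the gold name, so tp = 2;
#precision = 2/3 of the predicted tokens, recall = 2/3 of the gold tokens
print(precision, recall, f1)  #0.666... 0.666... 0.666...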