├── .gitignore
├── .labml.yaml
├── LICENSE
├── evaluate.py
├── extract_code.py
├── logs
│   └── simple_lstm
│       └── 2a86d636936d11eab8740dffb016e7b1
│           ├── artifacts.yaml
│           ├── checkpoints
│           │   └── 72237
│           │       ├── base.pth
│           │       └── info.json
│           ├── indicators.yaml
│           ├── run.yaml
│           ├── source.diff
│           ├── sqlite.db
│           └── tensorboard
│               └── events.out.tfevents.1589190774.varuna-small.31006.5.v2
├── model.py
├── parser
│   ├── __init__.py
│   ├── load.py
│   └── tokenizer.py
├── python-autocomplete.png
├── readme.md
├── requirements.txt
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | .ipynb_checkpoints/
3 | .idea/
4 | __pycache__/
5 | .DS_Store
6 | .*.swp
7 | logs/
8 |
--------------------------------------------------------------------------------
/.labml.yaml:
--------------------------------------------------------------------------------
1 | check_repo_dirty: False
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Varuna Jayasiri
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
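The logs/simple_lstm/2a86d636936d11eab8740dffb016e7b1 directory in the tree above is a labml run folder: run.yaml holds the run metadata, source.diff the uncommitted changes at the time of the run, and checkpoints/72237/base.pth the saved model weights. A rough sketch of how that checkpoint is picked up again follows; these are the same calls evaluate.py's main() makes, with the run uuid and checkpoint step taken from the paths above:

from labml import experiment

from model import SimpleLstmModel
from parser import tokenizer

experiment.create(name="simple_lstm", comment="Simple LSTM")
model = SimpleLstmModel(encoding_size=tokenizer.VOCAB_SIZE,
                        embedding_size=tokenizer.VOCAB_SIZE,
                        lstm_size=1024,
                        lstm_layers=3)
experiment.add_pytorch_models({'base': model})
experiment.load("2a86d636936d11eab8740dffb016e7b1", 72237)  # run uuid, checkpoint step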
22 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import tokenize 4 | from io import BytesIO 5 | from typing import NamedTuple, List, Tuple 6 | 7 | import torch 8 | import torch.nn 9 | from labml import experiment, monit, logger 10 | from labml.logger import Text, Style 11 | 12 | import parser.load 13 | import parser.tokenizer 14 | from model import SimpleLstmModel 15 | from parser import tokenizer 16 | 17 | # Experiment configuration to load checkpoints 18 | experiment.create(name="simple_lstm", 19 | comment="Simple LSTM") 20 | 21 | # device to evaluate on 22 | device = torch.device("cuda:0") 23 | 24 | # Beam search 25 | BEAM_SIZE = 8 26 | 27 | 28 | class Suggestions(NamedTuple): 29 | codes: List[List[int]] 30 | matched: List[int] 31 | scores: List[float] 32 | 33 | 34 | class ScoredItem(NamedTuple): 35 | score: float 36 | idx: Tuple 37 | 38 | 39 | class Predictor: 40 | """ 41 | Predicts the next few characters 42 | """ 43 | 44 | NEW_LINE_TOKENS = {tokenize.NEWLINE, tokenize.NL} 45 | INDENT_TOKENS = {tokenize.INDENT, tokenize.DEDENT} 46 | 47 | def __init__(self, model, lstm_layers, lstm_size): 48 | self.__model = model 49 | 50 | # Initial state 51 | self._h0 = torch.zeros((lstm_layers, 1, lstm_size), device=device) 52 | self._c0 = torch.zeros((lstm_layers, 1, lstm_size), device=device) 53 | 54 | # Last line of source code read 55 | self._last_line = "" 56 | 57 | self._tokens: List[tokenize.TokenInfo] = [] 58 | 59 | # Last token, because we need to input that to the model for inference 60 | self._last_token = 0 61 | 62 | # Last bit of the input string 63 | self._untokenized = "" 64 | 65 | # For timing 66 | self.time_add = 0 67 | self.time_predict = 0 68 | self.time_check = 0 69 | 70 | def __clear_tokens(self, lines: int): 71 | """ 72 | Clears old lines from tokens 73 | """ 74 | for i, t in enumerate(self._tokens): 75 | if t.type in self.NEW_LINE_TOKENS: 76 | lines -= 1 77 | if lines == 0: 78 | self._tokens = self._tokens[i + 1:] 79 | return 80 | 81 | raise RuntimeError() 82 | 83 | def __clear_untokenized(self, tokens): 84 | """ 85 | Remove tokens not properly tokenized; 86 | i.e. 
the last token, unless it's a new line 87 | """ 88 | 89 | limit = 0 90 | for i in reversed(range(len(tokens))): 91 | if tokens[i].type in self.NEW_LINE_TOKENS: 92 | limit = i + 1 93 | break 94 | else: 95 | limit = i 96 | break 97 | 98 | return tokens[:limit] 99 | 100 | @staticmethod 101 | def __get_tokens(it): 102 | tokens: List[tokenize.TokenInfo] = [] 103 | 104 | try: 105 | for t in it: 106 | if t.type in tokenizer.SKIP_TOKENS: 107 | continue 108 | if t.type == tokenize.NEWLINE and t.string == '': 109 | continue 110 | if t.type == tokenize.DEDENT: 111 | continue 112 | if t.type == tokenize.ERRORTOKEN: 113 | continue 114 | tokens.append(t) 115 | except tokenize.TokenError as e: 116 | if not e.args[0].startswith('EOF in'): 117 | print(e) 118 | except IndentationError as e: 119 | print(e) 120 | 121 | return tokens 122 | 123 | def add(self, content): 124 | """ 125 | Add a string of code, this shouldn't have multiple lines 126 | """ 127 | start_time = time.time() 128 | self._last_line += content 129 | 130 | # Remove old lines 131 | lines = self._last_line.split("\n") 132 | if len(lines) > 1: 133 | assert len(lines) <= 3 134 | if lines[-1] == '': 135 | if len(lines) > 2: 136 | self.__clear_tokens(len(lines) - 2) 137 | lines = lines[-2:] 138 | else: 139 | self.__clear_tokens(len(lines) - 1) 140 | lines = lines[-1:] 141 | 142 | line = '\n'.join(lines) 143 | 144 | self._last_line = line 145 | 146 | # Parse the last line 147 | tokens_it = tokenize.tokenize(BytesIO(self._last_line.encode('utf-8')).readline) 148 | tokens = self.__get_tokens(tokens_it) 149 | 150 | # Remove last token 151 | tokens = self.__clear_untokenized(tokens) 152 | 153 | # Check if previous tokens is a prefix 154 | assert len(tokens) >= len(self._tokens) 155 | 156 | for t1, t2 in zip(self._tokens, tokens): 157 | assert t1.type == t2.type 158 | assert t1.string == t2.string 159 | 160 | # Get the untokenized string 161 | if len(tokens) > 0: 162 | assert tokens[-1].end[0] == 1 163 | self._untokenized = line[tokens[-1].end[1]:] 164 | else: 165 | self._untokenized = line 166 | 167 | # Update previous tokens and the model state 168 | if len(tokens) > len(self._tokens): 169 | self.__update_state(tokens[len(self._tokens):]) 170 | self._tokens = tokens 171 | 172 | self.time_add += time.time() - start_time 173 | 174 | def get_predictions(self, codes_batch: List[List[int]]): 175 | # Sequence length and batch size 176 | seq_len = len(codes_batch[0]) 177 | batch_size = len(codes_batch) 178 | 179 | for codes in codes_batch: 180 | assert seq_len == len(codes) 181 | 182 | # Input to the model 183 | x = torch.tensor(codes_batch, device=device) 184 | x = x.transpose(0, 1) 185 | 186 | # Expand state 187 | h0 = self._h0.expand(-1, batch_size, -1).contiguous() 188 | c0 = self._c0.expand(-1, batch_size, -1).contiguous() 189 | 190 | # Get predictions 191 | prediction, _, _ = self.__model(x, h0, c0) 192 | 193 | assert prediction.shape == (seq_len, len(codes_batch), tokenizer.VOCAB_SIZE) 194 | 195 | # Final prediction 196 | prediction = prediction[-1, :, :] 197 | 198 | return prediction.detach().cpu().numpy() 199 | 200 | def get_suggestion(self) -> str: 201 | # Start of with the last token 202 | suggestions = [Suggestions([[self._last_token]], 203 | [0], 204 | [1.])] 205 | 206 | # Do a beam search, up to the untokenized string length and 10 more 207 | for step in range(10 + len(self._untokenized)): 208 | sugg = suggestions[step] 209 | batch_size = len(sugg.codes) 210 | 211 | # Break if empty 212 | if batch_size == 0: 213 | break 214 | 215 | # Get 
predictions 216 | start_time = time.time() 217 | predictions = self.get_predictions(sugg.codes) 218 | self.time_predict += time.time() - start_time 219 | 220 | start_time = time.time() 221 | # Get all choices 222 | choices = [] 223 | for idx in range(batch_size): 224 | for code in range(tokenizer.VOCAB_SIZE): 225 | score = sugg.scores[idx] * predictions[idx, code] 226 | choices.append(ScoredItem( 227 | score * math.sqrt(sugg.matched[idx] + tokenizer.LENGTHS[code]), 228 | (idx, code))) 229 | # Sort them 230 | choices.sort(key=lambda x: x.score, reverse=True) 231 | 232 | # Collect the ones that match untokenized string 233 | codes = [] 234 | matches = [] 235 | scores = [] 236 | len_untokenized = len(self._untokenized) 237 | 238 | for choice in choices: 239 | prev_idx = choice.idx[0] 240 | code = choice.idx[1] 241 | 242 | token = tokenizer.DESERIALIZE[code] 243 | if token.type in tokenizer.LINE_BREAK: 244 | continue 245 | 246 | # Previously mached length 247 | matched = sugg.matched[prev_idx] 248 | 249 | if matched >= len_untokenized: 250 | # Increment the length if already matched 251 | matched += tokenizer.LENGTHS[code] 252 | else: 253 | # Otherwise check if the new token string matches 254 | unmatched = tokenizer.DECODE[code][sugg.codes[prev_idx][-1]] 255 | to_match = self._untokenized[matched:] 256 | 257 | if len(unmatched) < len(to_match): 258 | if not to_match.startswith(unmatched): 259 | continue 260 | else: 261 | matched += len(unmatched) 262 | else: 263 | if not unmatched.startswith(to_match): 264 | continue 265 | else: 266 | matched += len(unmatched) 267 | 268 | # Collect new item 269 | codes.append(sugg.codes[prev_idx] + [code]) 270 | matches.append(matched) 271 | score = sugg.scores[prev_idx] * predictions[prev_idx, code] 272 | scores.append(score) 273 | 274 | # Stop at `BEAM_SIZE` 275 | if len(scores) == BEAM_SIZE: 276 | break 277 | 278 | suggestions.append(Suggestions(codes, matches, scores)) 279 | 280 | self.time_check += time.time() - start_time 281 | 282 | # Collect suggestions of all lengths 283 | choices = [] 284 | for s_idx, sugg in enumerate(suggestions): 285 | batch_size = len(sugg.codes) 286 | for idx in range(batch_size): 287 | length = sugg.matched[idx] - len(self._untokenized) 288 | if length <= 2: 289 | continue 290 | choice = sugg.scores[idx] * math.sqrt(length - 1) 291 | choices.append(ScoredItem(choice, (s_idx, idx))) 292 | choices.sort(key=lambda x: x.score, reverse=True) 293 | 294 | # Return the best option 295 | for choice in choices: 296 | codes = suggestions[choice.idx[0]].codes[choice.idx[1]] 297 | res = "" 298 | prev = self._last_token 299 | for code in codes[1:]: 300 | res += tokenizer.DECODE[code][prev] 301 | prev = code 302 | 303 | res = res[len(self._untokenized):] 304 | 305 | # Skip if blank 306 | if res.strip() == "": 307 | continue 308 | 309 | return res 310 | 311 | # Return blank if there are no options 312 | return '' 313 | 314 | def __update_state(self, tokens): 315 | """ 316 | Update model state 317 | """ 318 | data = parser.tokenizer.parse(tokens) 319 | data = parser.tokenizer.encode(data) 320 | x = [self._last_token] + data[:-1] 321 | self._last_token = data[-1] 322 | 323 | x = torch.tensor([x], device=device) 324 | x = x.transpose(0, 1) 325 | _, _, (hn, cn) = self.__model(x, self._h0, self._c0) 326 | self._h0 = hn.detach() 327 | self._c0 = cn.detach() 328 | 329 | 330 | class Evaluator: 331 | def __init__(self, model, file: parser.load.EncodedFile, 332 | lstm_layers, lstm_size, 333 | skip_spaces=False): 334 | self.__content = 
self.get_content(file.codes) 335 | self.__skip_spaces = skip_spaces 336 | self.__predictor = Predictor(model, lstm_layers, lstm_size) 337 | 338 | @staticmethod 339 | def get_content(codes: List[int]): 340 | tokens = parser.tokenizer.decode(codes) 341 | content = parser.tokenizer.to_string(tokens) 342 | return content.split('\n') 343 | 344 | def eval(self): 345 | keys_saved = 0 346 | 347 | for line, content in enumerate(self.__content): 348 | # Keep reference to rest of the line 349 | rest_of_line = content 350 | 351 | # Build the line for logging with colors 352 | # The line number 353 | logs = [(f"{line: 4d}: ", Text.meta)] 354 | 355 | # Type the line character by character 356 | while rest_of_line != '': 357 | suggestion = self.__predictor.get_suggestion() 358 | 359 | # If suggestion matches 360 | if suggestion != '' and rest_of_line.startswith(suggestion): 361 | # Log 362 | logs.append((suggestion[0], [Style.underline, Text.danger])) 363 | logs.append((suggestion[1:], Style.underline)) 364 | 365 | keys_saved += len(suggestion) - 1 366 | 367 | # Skip the prediction text 368 | rest_of_line = rest_of_line[len(suggestion):] 369 | 370 | # Add text to the predictor 371 | self.__predictor.add(suggestion) 372 | 373 | # If the suggestion doesn't match 374 | else: 375 | # Add the next character 376 | self.__predictor.add(rest_of_line[0]) 377 | logs.append((rest_of_line[0], Text.subtle)) 378 | rest_of_line = rest_of_line[1:] 379 | 380 | # Add a new line 381 | self.__predictor.add("\n") 382 | 383 | # Log the line 384 | logger.log(logs) 385 | 386 | # Log time taken for the file 387 | logger.inspect(add=self.__predictor.time_add, 388 | check=self.__predictor.time_check, 389 | predict=self.__predictor.time_predict) 390 | 391 | total_keys = sum([len(c) for c in self.__content]) 392 | logger.inspect(keys_saved=keys_saved, 393 | percentage_saved=100 * keys_saved / total_keys, 394 | total_keys=total_keys, 395 | total_lines=len(self.__content)) 396 | 397 | 398 | def main(): 399 | lstm_size = 1024 400 | lstm_layers = 3 401 | 402 | with monit.section("Loading data"): 403 | files = parser.load.load_files() 404 | train_files, valid_files = parser.load.split_train_valid(files, is_shuffle=False) 405 | 406 | with monit.section("Create model"): 407 | model = SimpleLstmModel(encoding_size=tokenizer.VOCAB_SIZE, 408 | embedding_size=tokenizer.VOCAB_SIZE, 409 | lstm_size=lstm_size, 410 | lstm_layers=lstm_layers) 411 | model.to(device) 412 | 413 | experiment.add_pytorch_models({'base': model}) 414 | 415 | experiment.load("2a86d636936d11eab8740dffb016e7b1", 72237) 416 | 417 | # For debugging with a specific piece of source code 418 | # predictor = Predictor(model, lstm_layers, lstm_size) 419 | # for s in ['""" """\n', "from __future__"]: 420 | # predictor.add(s) 421 | # s = predictor.get_suggestion() 422 | 423 | # Evaluate all the files in validation set 424 | for file in valid_files: 425 | logger.log(str(file.path), Text.heading) 426 | evaluator = Evaluator(model, file, 427 | lstm_layers, lstm_size, 428 | skip_spaces=True) 429 | evaluator.eval() 430 | 431 | 432 | if __name__ == '__main__': 433 | main() 434 | -------------------------------------------------------------------------------- /extract_code.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Parse all files and write to a single file 5 | """ 6 | import os 7 | from pathlib import Path 8 | from typing import List, NamedTuple 9 | 10 | from labml import logger, monit 11 | 12 | from parser 
import tokenizer 13 | from parser.tokenizer import encode, parse_string 14 | 15 | COMMENT = '#' 16 | MULTI_COMMENT = '"""' 17 | 18 | 19 | class _PythonFile(NamedTuple): 20 | relative_path: str 21 | project: str 22 | path: Path 23 | 24 | 25 | class _GetPythonFiles: 26 | """ 27 | Get list of python files and their paths inside `data/source` folder 28 | """ 29 | 30 | def __init__(self): 31 | self.source_path = Path(os.getcwd()) / 'data' / 'source' 32 | self.files: List[_PythonFile] = [] 33 | self.get_python_files(self.source_path) 34 | 35 | logger.inspect([f.path for f in self.files]) 36 | 37 | def add_file(self, path: Path): 38 | """ 39 | Add a file to the list of tiles 40 | """ 41 | project = path.relative_to(self.source_path).parents 42 | project = project[len(project) - 2] 43 | relative_path = path.relative_to(self.source_path / project) 44 | 45 | self.files.append(_PythonFile(relative_path=str(relative_path), 46 | project=str(project), 47 | path=path)) 48 | 49 | def get_python_files(self, path: Path): 50 | """ 51 | Recursively collect files 52 | """ 53 | for p in path.iterdir(): 54 | if p.is_dir(): 55 | self.get_python_files(p) 56 | else: 57 | if p.suffix == '.py': 58 | self.add_file(p) 59 | 60 | 61 | def _fix_indentation(parsed: List[tokenizer.ParsedToken]) -> List[tokenizer.ParsedToken]: 62 | """ 63 | Change indentation tokens. Remove `DEDENT` tokens and 64 | add `INDENT` tokens to each line. 65 | This is easier for prediction. 66 | """ 67 | res: List[tokenizer.ParsedToken] = [] 68 | indentation = 0 69 | indented = False 70 | for t in parsed: 71 | if t.type == tokenizer.TokenType.indent: 72 | indentation += 1 73 | elif t.type == tokenizer.TokenType.dedent: 74 | indentation -= 1 75 | elif t.type in [tokenizer.TokenType.new_line, 76 | tokenizer.TokenType.eof]: 77 | indented = False 78 | res.append(t) 79 | else: 80 | if not indented: 81 | for _ in range(indentation): 82 | res.append(tokenizer.ParsedToken(tokenizer.TokenType.indent, 0)) 83 | indented = True 84 | 85 | res.append(t) 86 | 87 | return res 88 | 89 | 90 | def _remove_comments(parsed: List[tokenizer.ParsedToken]) -> List[tokenizer.ParsedToken]: 91 | """ 92 | Remove comment tokens 93 | """ 94 | res = [] 95 | for p in parsed: 96 | if p.type == tokenizer.TokenType.comment: 97 | continue 98 | else: 99 | res.append(p) 100 | 101 | return res 102 | 103 | 104 | def _remove_empty_lines(parsed: List[tokenizer.ParsedToken]) -> List[tokenizer.ParsedToken]: 105 | """ 106 | Remove empty lines 107 | """ 108 | 109 | tokens = [tokenizer.TokenType.new_line, tokenizer.TokenType.new_line] 110 | res = [] 111 | for p in parsed: 112 | for i in range(1): 113 | tokens[i] = tokens[i + 1] 114 | tokens[-1] = p.type 115 | all_new_line = True 116 | for t in tokens: 117 | if t != tokenizer.TokenType.new_line: 118 | all_new_line = False 119 | 120 | if all_new_line: 121 | continue 122 | else: 123 | res.append(p) 124 | 125 | return res 126 | 127 | 128 | def _read_file(path: Path) -> List[int]: 129 | """ 130 | Read and encode a file 131 | """ 132 | with open(str(path)) as f: 133 | content = f.read() 134 | 135 | parsed = parse_string(content) 136 | parsed = _remove_comments(parsed) 137 | parsed = _remove_empty_lines(parsed) 138 | parsed = _fix_indentation(parsed) 139 | serialized = encode(parsed) 140 | 141 | # deserialized = tokenizer.deserialize(serialized) 142 | # for i in range(len(serialized)): 143 | # assert deserialized[i] == parsed[i] 144 | # 145 | # res = to_text(deserialized) 146 | # print(res) 147 | 148 | return serialized 149 | 150 | 151 | def main(): 
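    # main() writes everything to a single data/all.py file, two lines per source
    # file: the file's path, then its encoded token ids separated by spaces.
    # parser/load.load_files() later reads these pairs back as EncodedFile records
    # for training (train.py) and evaluation (evaluate.py).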
152 | source_files = _GetPythonFiles().files 153 | 154 | logger.inspect(source_files) 155 | 156 | with open(str(Path(os.getcwd()) / 'data' / 'all.py'), 'w') as f: 157 | for i, source in monit.enum("Parse", source_files): 158 | serialized = _read_file(source.path) 159 | # return 160 | serialized = [str(t) for t in serialized] 161 | f.write(f"{str(source.path)}\n") 162 | f.write(" ".join(serialized) + "\n") 163 | 164 | 165 | if __name__ == '__main__': 166 | main() 167 | -------------------------------------------------------------------------------- /logs/simple_lstm/2a86d636936d11eab8740dffb016e7b1/artifacts.yaml: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /logs/simple_lstm/2a86d636936d11eab8740dffb016e7b1/checkpoints/72237/base.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vpj/python_autocomplete/cd61eb37c9ff9fef835173e4fba460d12e8f8d2e/logs/simple_lstm/2a86d636936d11eab8740dffb016e7b1/checkpoints/72237/base.pth -------------------------------------------------------------------------------- /logs/simple_lstm/2a86d636936d11eab8740dffb016e7b1/checkpoints/72237/info.json: -------------------------------------------------------------------------------- 1 | {"base": "base.pth"} -------------------------------------------------------------------------------- /logs/simple_lstm/2a86d636936d11eab8740dffb016e7b1/indicators.yaml: -------------------------------------------------------------------------------- 1 | train.loss: 2 | class_name: Queue 3 | is_print: true 4 | name: train.loss 5 | queue_size: 500 6 | valid.loss: 7 | class_name: Queue 8 | is_print: true 9 | name: valid.loss 10 | queue_size: 500 11 | -------------------------------------------------------------------------------- /logs/simple_lstm/2a86d636936d11eab8740dffb016e7b1/run.yaml: -------------------------------------------------------------------------------- 1 | comment: Simple LSTM 2 | commit: 301ee43c2edb9523ca7b66f258c7908c8e5be3a6 3 | commit_message: "\U0001F41B validation set wasn't passed" 4 | is_dirty: true 5 | load_run: null 6 | notes: '' 7 | python_file: /home/varuna/ml/python_autocomplete/train.py 8 | start_step: 0 9 | tags: 10 | - lstm 11 | - simple 12 | trial_date: '2020-05-11' 13 | trial_time: '15:22:46' 14 | uuid: 2a86d636936d11eab8740dffb016e7b1 15 | -------------------------------------------------------------------------------- /logs/simple_lstm/2a86d636936d11eab8740dffb016e7b1/source.diff: -------------------------------------------------------------------------------- 1 | diff --git a/.gitignore b/.gitignore 2 | index 37299d1..b2dc025 100644 3 | --- a/.gitignore 4 | +++ b/.gitignore 5 | @@ -4,4 +4,4 @@ data/ 6 | __pycache__/ 7 | .DS_Store 8 | .*.swp 9 | -lab 10 | +logs 11 | diff --git a/.lab.yaml b/.lab.yaml 12 | index e69de29..1e5d0a2 100644 13 | --- a/.lab.yaml 14 | +++ b/.lab.yaml 15 | @@ -0,0 +1 @@ 16 | +check_repo_dirty: False 17 | \ No newline at end of file 18 | diff --git a/evaluate.py b/evaluate.py 19 | index 559cf62..1e66630 100644 20 | --- a/evaluate.py 21 | +++ b/evaluate.py 22 | @@ -6,25 +6,20 @@ from typing import NamedTuple, List, Tuple 23 | 24 | import torch 25 | import torch.nn 26 | +from lab import experiment, monit, logger 27 | +from lab.logger import Text, Style 28 | 29 | import parser.load 30 | import parser.tokenizer 31 | -from lab import colors 32 | -from lab.experiment.pytorch import 
Experiment 33 | from model import SimpleLstmModel 34 | from parser import tokenizer 35 | 36 | # Experiment configuration to load checkpoints 37 | -EXPERIMENT = Experiment(name="simple_lstm", 38 | - python_file=__file__, 39 | - comment="Simple LSTM", 40 | - check_repo_dirty=False, 41 | - is_log_python_file=False) 42 | - 43 | -logger = EXPERIMENT.logger 44 | +experiment.create(name="simple_lstm", 45 | + comment="Simple LSTM") 46 | 47 | # device to evaluate on 48 | -device = torch.device("cuda:1") 49 | +device = torch.device("cuda:0") 50 | 51 | # Beam search 52 | BEAM_SIZE = 8 53 | @@ -349,16 +344,13 @@ class Evaluator: 54 | def eval(self): 55 | keys_saved = 0 56 | 57 | - logger.info(total_keys=sum([len(c) for c in self.__content]), 58 | - total_lines=len(self.__content)) 59 | - 60 | for line, content in enumerate(self.__content): 61 | # Keep reference to rest of the line 62 | rest_of_line = content 63 | 64 | # Build the line for logging with colors 65 | # The line number 66 | - logs = [(f"{line: 4d}: ", colors.BrightColor.cyan)] 67 | + logs = [(f"{line: 4d}: ", Text.meta)] 68 | 69 | # Type the line character by character 70 | while rest_of_line != '': 71 | @@ -367,8 +359,8 @@ class Evaluator: 72 | # If suggestion matches 73 | if suggestion != '' and rest_of_line.startswith(suggestion): 74 | # Log 75 | - logs.append((suggestion[0], colors.BrightColor.green)) 76 | - logs.append((suggestion[1:], colors.BrightBackground.black)) 77 | + logs.append((suggestion[0], Text.danger)) 78 | + logs.append((suggestion[1:], Style.underline)) 79 | 80 | keys_saved += len(suggestion) - 1 81 | 82 | @@ -382,40 +374,45 @@ class Evaluator: 83 | else: 84 | # Add the next character 85 | self.__predictor.add(rest_of_line[0]) 86 | - logs.append((rest_of_line[0], None)) 87 | + logs.append((rest_of_line[0], Text.subtle)) 88 | rest_of_line = rest_of_line[1:] 89 | 90 | # Add a new line 91 | self.__predictor.add("\n") 92 | 93 | # Log the line 94 | - logger.log_color(logs) 95 | + logger.log(logs) 96 | 97 | # Log time taken for the file 98 | - logger.info(add=self.__predictor.time_add, 99 | - check=self.__predictor.time_check, 100 | - predict=self.__predictor.time_predict) 101 | - return keys_saved 102 | + logger.inspect(add=self.__predictor.time_add, 103 | + check=self.__predictor.time_check, 104 | + predict=self.__predictor.time_predict) 105 | + 106 | + total_keys = sum([len(c) for c in self.__content]) 107 | + logger.inspect(keys_saved=keys_saved, 108 | + percentage_saved=100 * keys_saved / total_keys, 109 | + total_keys=total_keys, 110 | + total_lines=len(self.__content)) 111 | 112 | 113 | def main(): 114 | lstm_size = 1024 115 | lstm_layers = 3 116 | 117 | - with logger.section("Loading data"): 118 | + with monit.section("Loading data"): 119 | files = parser.load.load_files() 120 | train_files, valid_files = parser.load.split_train_valid(files, is_shuffle=False) 121 | 122 | - with logger.section("Create model"): 123 | + with monit.section("Create model"): 124 | model = SimpleLstmModel(encoding_size=tokenizer.VOCAB_SIZE, 125 | embedding_size=tokenizer.VOCAB_SIZE, 126 | lstm_size=lstm_size, 127 | lstm_layers=lstm_layers) 128 | model.to(device) 129 | 130 | - EXPERIMENT.add_models({'base': model}) 131 | + experiment.add_pytorch_models({'base': model}) 132 | 133 | - EXPERIMENT.start_replay() 134 | + experiment.load("b3000660936b11ea9aa1c9a10ca3c0a4") 135 | 136 | # For debugging with a specific piece of source code 137 | # predictor = Predictor(model, lstm_layers, lstm_size) 138 | @@ -425,13 +422,11 @@ def main(): 139 | 140 | 
# Evaluate all the files in validation set 141 | for file in valid_files: 142 | - logger.log(str(file.path), color=colors.BrightColor.orange) 143 | + logger.log(str(file.path), Text.heading) 144 | evaluator = Evaluator(model, file, 145 | lstm_layers, lstm_size, 146 | skip_spaces=True) 147 | - keys_saved = evaluator.eval() 148 | - 149 | - logger.info(keys_saved=keys_saved) 150 | + evaluator.eval() 151 | 152 | 153 | if __name__ == '__main__': 154 | diff --git a/extract_code.py b/extract_code.py 155 | old mode 100644 156 | new mode 100755 157 | index 6eee33e..70ed0ab 158 | --- a/extract_code.py 159 | +++ b/extract_code.py 160 | @@ -7,15 +7,14 @@ import os 161 | from pathlib import Path 162 | from typing import List, NamedTuple 163 | 164 | -from lab.logger import Logger 165 | +from lab import logger, monit 166 | + 167 | from parser import tokenizer 168 | from parser.tokenizer import encode, parse_string 169 | 170 | COMMENT = '#' 171 | MULTI_COMMENT = '"""' 172 | 173 | -_logger = Logger() 174 | - 175 | 176 | class _PythonFile(NamedTuple): 177 | relative_path: str 178 | @@ -27,12 +26,13 @@ class _GetPythonFiles: 179 | """ 180 | Get list of python files and their paths inside `data/source` folder 181 | """ 182 | + 183 | def __init__(self): 184 | self.source_path = Path(os.getcwd()) / 'data' / 'source' 185 | self.files: List[_PythonFile] = [] 186 | self.get_python_files(self.source_path) 187 | 188 | - _logger.info([f.path for f in self.files]) 189 | + logger.inspect([f.path for f in self.files]) 190 | 191 | def add_file(self, path: Path): 192 | """ 193 | @@ -151,17 +151,15 @@ def _read_file(path: Path) -> List[int]: 194 | def main(): 195 | source_files = _GetPythonFiles().files 196 | 197 | - _logger.info(source_files) 198 | + logger.inspect(source_files) 199 | 200 | with open(str(Path(os.getcwd()) / 'data' / 'all.py'), 'w') as f: 201 | - with _logger.section("Parse", total_steps=len(source_files)): 202 | - for i, source in enumerate(source_files): 203 | - serialized = _read_file(source.path) 204 | - # return 205 | - serialized = [str(t) for t in serialized] 206 | - f.write(f"{str(source.path)}\n") 207 | - f.write(" ".join(serialized) + "\n") 208 | - _logger.progress(i + 1) 209 | + for i, source in monit.enum("Parse", source_files): 210 | + serialized = _read_file(source.path) 211 | + # return 212 | + serialized = [str(t) for t in serialized] 213 | + f.write(f"{str(source.path)}\n") 214 | + f.write(" ".join(serialized) + "\n") 215 | 216 | 217 | if __name__ == '__main__': 218 | diff --git a/logs/simple_lstm/checkpoints/536410/base_embedding.weight.npy b/logs/simple_lstm/checkpoints/536410/base_embedding.weight.npy 219 | deleted file mode 100644 220 | index 31d17b4..0000000 221 | Binary files a/logs/simple_lstm/checkpoints/536410/base_embedding.weight.npy and /dev/null differ 222 | diff --git a/logs/simple_lstm/checkpoints/536410/base_fc.bias.npy b/logs/simple_lstm/checkpoints/536410/base_fc.bias.npy 223 | deleted file mode 100644 224 | index 775c839..0000000 225 | Binary files a/logs/simple_lstm/checkpoints/536410/base_fc.bias.npy and /dev/null differ 226 | diff --git a/logs/simple_lstm/checkpoints/536410/base_fc.weight.npy b/logs/simple_lstm/checkpoints/536410/base_fc.weight.npy 227 | deleted file mode 100644 228 | index 8bfe8de..0000000 229 | Binary files a/logs/simple_lstm/checkpoints/536410/base_fc.weight.npy and /dev/null differ 230 | diff --git a/logs/simple_lstm/checkpoints/536410/base_lstm.bias_hh_l0.npy b/logs/simple_lstm/checkpoints/536410/base_lstm.bias_hh_l0.npy 231 | deleted file mode 
100644 232 | index a36c62c..0000000 233 | Binary files a/logs/simple_lstm/checkpoints/536410/base_lstm.bias_hh_l0.npy and /dev/null differ 234 | diff --git a/logs/simple_lstm/checkpoints/536410/base_lstm.bias_hh_l1.npy b/logs/simple_lstm/checkpoints/536410/base_lstm.bias_hh_l1.npy 235 | deleted file mode 100644 236 | index bd257e9..0000000 237 | Binary files a/logs/simple_lstm/checkpoints/536410/base_lstm.bias_hh_l1.npy and /dev/null differ 238 | diff --git a/logs/simple_lstm/checkpoints/536410/base_lstm.bias_hh_l2.npy b/logs/simple_lstm/checkpoints/536410/base_lstm.bias_hh_l2.npy 239 | deleted file mode 100644 240 | index f6fda59..0000000 241 | Binary files a/logs/simple_lstm/checkpoints/536410/base_lstm.bias_hh_l2.npy and /dev/null differ 242 | diff --git a/logs/simple_lstm/checkpoints/536410/base_lstm.bias_ih_l0.npy b/logs/simple_lstm/checkpoints/536410/base_lstm.bias_ih_l0.npy 243 | deleted file mode 100644 244 | index 60674f9..0000000 245 | Binary files a/logs/simple_lstm/checkpoints/536410/base_lstm.bias_ih_l0.npy and /dev/null differ 246 | diff --git a/logs/simple_lstm/checkpoints/536410/base_lstm.bias_ih_l1.npy b/logs/simple_lstm/checkpoints/536410/base_lstm.bias_ih_l1.npy 247 | deleted file mode 100644 248 | index 618cc33..0000000 249 | Binary files a/logs/simple_lstm/checkpoints/536410/base_lstm.bias_ih_l1.npy and /dev/null differ 250 | diff --git a/logs/simple_lstm/checkpoints/536410/base_lstm.bias_ih_l2.npy b/logs/simple_lstm/checkpoints/536410/base_lstm.bias_ih_l2.npy 251 | deleted file mode 100644 252 | index ba3eeb7..0000000 253 | Binary files a/logs/simple_lstm/checkpoints/536410/base_lstm.bias_ih_l2.npy and /dev/null differ 254 | diff --git a/logs/simple_lstm/checkpoints/536410/base_lstm.weight_hh_l0.npy b/logs/simple_lstm/checkpoints/536410/base_lstm.weight_hh_l0.npy 255 | deleted file mode 100644 256 | index ac0f9ac..0000000 257 | Binary files a/logs/simple_lstm/checkpoints/536410/base_lstm.weight_hh_l0.npy and /dev/null differ 258 | diff --git a/logs/simple_lstm/checkpoints/536410/base_lstm.weight_hh_l1.npy b/logs/simple_lstm/checkpoints/536410/base_lstm.weight_hh_l1.npy 259 | deleted file mode 100644 260 | index 773440f..0000000 261 | Binary files a/logs/simple_lstm/checkpoints/536410/base_lstm.weight_hh_l1.npy and /dev/null differ 262 | diff --git a/logs/simple_lstm/checkpoints/536410/base_lstm.weight_hh_l2.npy b/logs/simple_lstm/checkpoints/536410/base_lstm.weight_hh_l2.npy 263 | deleted file mode 100644 264 | index ea331c0..0000000 265 | Binary files a/logs/simple_lstm/checkpoints/536410/base_lstm.weight_hh_l2.npy and /dev/null differ 266 | diff --git a/logs/simple_lstm/checkpoints/536410/base_lstm.weight_ih_l0.npy b/logs/simple_lstm/checkpoints/536410/base_lstm.weight_ih_l0.npy 267 | deleted file mode 100644 268 | index 467f734..0000000 269 | Binary files a/logs/simple_lstm/checkpoints/536410/base_lstm.weight_ih_l0.npy and /dev/null differ 270 | diff --git a/logs/simple_lstm/checkpoints/536410/base_lstm.weight_ih_l1.npy b/logs/simple_lstm/checkpoints/536410/base_lstm.weight_ih_l1.npy 271 | deleted file mode 100644 272 | index 631e007..0000000 273 | Binary files a/logs/simple_lstm/checkpoints/536410/base_lstm.weight_ih_l1.npy and /dev/null differ 274 | diff --git a/logs/simple_lstm/checkpoints/536410/base_lstm.weight_ih_l2.npy b/logs/simple_lstm/checkpoints/536410/base_lstm.weight_ih_l2.npy 275 | deleted file mode 100644 276 | index a9730e1..0000000 277 | Binary files a/logs/simple_lstm/checkpoints/536410/base_lstm.weight_ih_l2.npy and /dev/null differ 278 | diff 
--git a/logs/simple_lstm/checkpoints/536410/info.json b/logs/simple_lstm/checkpoints/536410/info.json 279 | deleted file mode 100644 280 | index 3e172d3..0000000 281 | --- a/logs/simple_lstm/checkpoints/536410/info.json 282 | +++ /dev/null 283 | @@ -1 +0,0 @@ 284 | -{"base": {"embedding.weight": "base_embedding.weight.npy", "lstm.weight_ih_l0": "base_lstm.weight_ih_l0.npy", "lstm.weight_hh_l0": "base_lstm.weight_hh_l0.npy", "lstm.bias_ih_l0": "base_lstm.bias_ih_l0.npy", "lstm.bias_hh_l0": "base_lstm.bias_hh_l0.npy", "lstm.weight_ih_l1": "base_lstm.weight_ih_l1.npy", "lstm.weight_hh_l1": "base_lstm.weight_hh_l1.npy", "lstm.bias_ih_l1": "base_lstm.bias_ih_l1.npy", "lstm.bias_hh_l1": "base_lstm.bias_hh_l1.npy", "lstm.weight_ih_l2": "base_lstm.weight_ih_l2.npy", "lstm.weight_hh_l2": "base_lstm.weight_hh_l2.npy", "lstm.bias_ih_l2": "base_lstm.bias_ih_l2.npy", "lstm.bias_hh_l2": "base_lstm.bias_hh_l2.npy", "fc.weight": "base_fc.weight.npy", "fc.bias": "base_fc.bias.npy"}} 285 | \ No newline at end of file 286 | diff --git a/model.py b/model.py 287 | index c736cd5..5e02eb9 100644 288 | --- a/model.py 289 | +++ b/model.py 290 | @@ -1,7 +1,7 @@ 291 | -import torch.nn 292 | +from torch import nn 293 | 294 | 295 | -class SimpleLstmModel(torch.nn.Module): 296 | +class SimpleLstmModel(nn.Module): 297 | def __init__(self, *, 298 | encoding_size, 299 | embedding_size, 300 | @@ -9,12 +9,12 @@ class SimpleLstmModel(torch.nn.Module): 301 | lstm_layers): 302 | super().__init__() 303 | 304 | - self.embedding = torch.nn.Embedding(encoding_size, embedding_size) 305 | - self.lstm = torch.nn.LSTM(input_size=embedding_size, 306 | - hidden_size=lstm_size, 307 | - num_layers=lstm_layers) 308 | - self.fc = torch.nn.Linear(lstm_size, encoding_size) 309 | - self.softmax = torch.nn.Softmax(dim=-1) 310 | + self.embedding = nn.Embedding(encoding_size, embedding_size) 311 | + self.lstm = nn.LSTM(input_size=embedding_size, 312 | + hidden_size=lstm_size, 313 | + num_layers=lstm_layers) 314 | + self.fc = nn.Linear(lstm_size, encoding_size) 315 | + self.softmax = nn.Softmax(dim=-1) 316 | 317 | def forward(self, x, h0, c0): 318 | # shape of x is [seq, batch, feat] 319 | diff --git a/parser/load.py b/parser/load.py 320 | index 70b1d92..43892ff 100644 321 | --- a/parser/load.py 322 | +++ b/parser/load.py 323 | @@ -2,12 +2,10 @@ from pathlib import Path 324 | from typing import NamedTuple, List 325 | 326 | import numpy as np 327 | +from lab import logger 328 | 329 | -from lab.logger import Logger 330 | from parser import tokenizer 331 | 332 | -logger = Logger() 333 | - 334 | 335 | class EncodedFile(NamedTuple): 336 | path: str 337 | @@ -34,7 +32,8 @@ def load_files() -> List[EncodedFile]: 338 | return files 339 | 340 | 341 | -def split_train_valid(files: List[EncodedFile], is_shuffle=True) -> (List[EncodedFile], List[EncodedFile]): 342 | +def split_train_valid(files: List[EncodedFile], 343 | + is_shuffle=True) -> (List[EncodedFile], List[EncodedFile]): 344 | """ 345 | Split training and validation sets 346 | """ 347 | @@ -55,15 +54,15 @@ def split_train_valid(files: List[EncodedFile], is_shuffle=True) -> (List[Encode 348 | if train_size < total_size * 0.60: 349 | raise RuntimeError("Validation set too large") 350 | 351 | - logger.info(train_size=train_size, 352 | - valid_size=valid_size, 353 | - vocab=tokenizer.VOCAB_SIZE) 354 | + logger.inspect(train_size=train_size, 355 | + valid_size=valid_size, 356 | + vocab=tokenizer.VOCAB_SIZE) 357 | return files, valid 358 | 359 | 360 | def main(): 361 | - files, code_to_str = load_files() 
362 | - logger.info(code_to_str) 363 | + files = load_files() 364 | + logger.inspect(files) 365 | 366 | 367 | if __name__ == "__main__": 368 | diff --git a/parser/tokenizer.py b/parser/tokenizer.py 369 | index 806845f..1fde875 100644 370 | --- a/parser/tokenizer.py 371 | +++ b/parser/tokenizer.py 372 | @@ -1,6 +1,8 @@ 373 | import tokenize 374 | from io import BytesIO 375 | -from typing import Optional, List, NamedTuple 376 | +from typing import Optional, List, NamedTuple, Union 377 | + 378 | +import numpy as np 379 | 380 | 381 | class TokenType: 382 | @@ -33,6 +35,7 @@ class _TokenParser: 383 | """ 384 | Parse tokens 385 | """ 386 | + 387 | def __init__(self, token_type, tokenize_type, match_type, values, 388 | replacement=None): 389 | self.offset = 0 390 | @@ -55,32 +58,36 @@ class _TokenParser: 391 | """ 392 | Parse token 393 | """ 394 | - if type(self.tokenize_type) == list: 395 | - if token.type not in self.tokenize_type: 396 | - return None 397 | - else: 398 | - if token.type != self.tokenize_type: 399 | + try: 400 | + if type(self.tokenize_type) == list: 401 | + if token.type not in self.tokenize_type: 402 | + return None 403 | + else: 404 | + if token.type != self.tokenize_type: 405 | + return None 406 | + 407 | + # Perhaps use subclasses? 408 | + if self.match_type == _MatchType.exact: 409 | + if token.string not in self.match_set: 410 | + return None 411 | + return [ParsedToken(self.token_type, self.match_set[token.string])] 412 | + elif self.match_type == _MatchType.each: 413 | + res = [] 414 | + for ch in token.string: 415 | + res.append(ParsedToken(self.token_type, self.match_set[ch])) 416 | + return res 417 | + elif self.match_type == _MatchType.starts: 418 | + for i, pref in enumerate(self.values): 419 | + if token.string.startswith(pref): 420 | + return [ParsedToken(self.token_type, i)] 421 | return None 422 | - 423 | - # Perhaps use subclasses? 
424 | - if self.match_type == _MatchType.exact: 425 | - if token.string not in self.match_set: 426 | - return None 427 | - return [ParsedToken(self.token_type, self.match_set[token.string])] 428 | - elif self.match_type == _MatchType.each: 429 | - res = [] 430 | - for ch in token.string: 431 | - res.append(ParsedToken(self.token_type, self.match_set[ch])) 432 | - return res 433 | - elif self.match_type == _MatchType.starts: 434 | - for i, pref in enumerate(self.values): 435 | - if token.string.startswith(pref): 436 | - return [ParsedToken(self.token_type, i)] 437 | - return None 438 | - elif self.match_type == _MatchType.none: 439 | - return [ParsedToken(self.token_type, 0)] 440 | - else: 441 | - raise RuntimeError(self.match_type) 442 | + elif self.match_type == _MatchType.none: 443 | + return [ParsedToken(self.token_type, 0)] 444 | + else: 445 | + raise RuntimeError(self.match_type) 446 | + except Exception as e: 447 | + print(token) 448 | + raise e 449 | 450 | def calc_serialize_range(self): 451 | for p in _PARSERS: 452 | @@ -103,7 +110,7 @@ _CHARS += [chr(i + ord('a')) for i in range(26)] 453 | _CHARS += [chr(i + ord('A')) for i in range(26)] 454 | _CHARS += [chr(i + ord('0')) for i in range(10)] 455 | 456 | -_NUMS = ['.', '_', 'e', 'x', '-'] 457 | +_NUMS = ['.', '_', 'x', 'X', 'o', 'O', '-', '+', 'j', 'J'] 458 | _NUMS += [chr(i + ord('a')) for i in range(6)] 459 | _NUMS += [chr(i + ord('A')) for i in range(6)] 460 | _NUMS += [chr(i + ord('0')) for i in range(10)] 461 | @@ -120,7 +127,7 @@ _PARSERS = [ 462 | '&', '|', '^', '~', '<<', '>>', 463 | '&=', '|=', '^=', '~=', '<<=', '>>=', 464 | '.', ',', '(', ')', ':', '[', ']', '{', '}', 465 | - '@', '...', ';']), 466 | + '@', '...', ';', '->']), 467 | _TokenParser(TokenType.keyword, tokenize.NAME, _MatchType.exact, 468 | ['and', 'as', 'assert', 'break', 'class', 469 | 'continue', 'def', 'del', 'elif', 'else', 470 | @@ -138,10 +145,19 @@ _PARSERS = [ 471 | _TokenParser(TokenType.comment, tokenize.COMMENT, _MatchType.none, '#') 472 | ] 473 | 474 | + 475 | +def get_vocab_size(token_type: int): 476 | + return len(_PARSERS[token_type]) 477 | + 478 | + 479 | +def get_vocab_offset(token_type: int): 480 | + return _PARSERS[token_type].offset 481 | + 482 | + 483 | VOCAB_SIZE = 0 484 | -DECODE = [] 485 | -LENGTHS = [] 486 | -DESERIALIZE = [] 487 | +DECODE: List[List[str]] = [] 488 | +LENGTHS: List[int] = [] 489 | +DESERIALIZE: List[ParsedToken] = [] 490 | 491 | SKIP_TOKENS = {tokenize.ENCODING, tokenize.ENDMARKER} 492 | EMPTY_TOKENS = {TokenType.eof, TokenType.new_line, TokenType.indent, TokenType.dedent} 493 | @@ -238,7 +254,7 @@ def encode(tokens: List[ParsedToken]) -> List[int]: 494 | return [_encode_token(t) for t in tokens] 495 | 496 | 497 | -def decode(codes: List[int]) -> List[ParsedToken]: 498 | +def decode(codes: Union[np.ndarray, List[int]]) -> List[ParsedToken]: 499 | """ 500 | Decode codes to tokens 501 | """ 502 | diff --git a/readme.md b/readme.md 503 | index 8efedba..f408715 100644 504 | --- a/readme.md 505 | +++ b/readme.md 506 | @@ -1,15 +1,32 @@ 507 | -[This](https://github.com/vpj/python_autocomplete) a toy project we started to see how well a simple LSTM model can autocomplete python code. 508 | - 509 | -It gives quite decent results by saving above 30% key strokes in most files, and close to 50% in some. We calculated key strokes saved by making a single (best) prediction and selecting it with a single key. 510 | - 511 | -We do a beam search to find predictions, upto ~10 characters ahead. 
So far it's too inefficient, if you are wondering about editor integration. 512 | - 513 | -We train and predict on after cleaning comments, strings and blank lines in python code. 514 | -The model is trained after tokenizing python code. It seems more efficient than character level prediction with byte-pair encoding. 515 | - 516 | -A saved model is included in this repo. It is trained on [tensorflow/models](https://github.com/tensorflow/models). 517 | - 518 | -Here's a sample evaluation on a source file from validation set. Green characters are when a autocompletion started; i.e. user presses TAB to select the completion. The green character and and the following characters highlighted in gray are autocompleted. As you can see, it starts and ends completions arbitarily. That is a suggestion could be 'tensorfl' and not the complete identifier 'tensorflow' which can be a little annoying in a real usage scenario. We can limit them to finish on end of tokens to fix that. Also you can notice that it completes across operators as well. Increasing the length of the beam search will let it complete longer pieces of code. 519 | +[This](https://github.com/vpj/python_autocomplete) a toy project we started 520 | +to see how well a simple LSTM model can autocomplete python code. 521 | + 522 | +It gives quite decent results by saving above 30% key strokes in most files, 523 | +and close to 50% in some. 524 | +We calculated key strokes saved by making a single (best) 525 | +prediction and selecting it with a single key. 526 | + 527 | +We do a beam search to find predictions, upto ~10 characters ahead. 528 | +So far it's too inefficient, if you are wondering about editor integration. 529 | + 530 | +We train and predict on after cleaning comments, strings 531 | +and blank lines in python code. 532 | +The model is trained after tokenizing python code. 533 | +It seems more efficient than character level prediction with byte-pair encoding. 534 | + 535 | +A saved model is included in this repo. 536 | +It is trained on [tensorflow/models](https://github.com/tensorflow/models). 537 | + 538 | +Here's a sample evaluation on a source file from validation set. 539 | +Green characters are when a auto-completion started; 540 | +i.e. user presses TAB to select the completion. 541 | +The green character and and the following characters highlighted in gray 542 | +are auto-completed. As you can see, it starts and ends completions arbitrarily. 543 | +That is a suggestion could be 'tensorfl' and not the complete identifier 544 | +'tensorflow' which can be a little annoying in a real usage scenario. 545 | +We can limit them to finish on end of tokens to fix that. 546 | +Also you can notice that it completes across operators as well. 547 | +Increasing the length of the beam search will let it complete longer pieces of code. 548 | 549 |
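The readme text added in the hunk above describes the keystroke-savings metric: a suggestion is accepted with a single key, so only the remaining characters count as saved. A rough sketch of that bookkeeping, reusing the names from Evaluator.eval in evaluate.py (the values here are illustrative and not part of the diff):

suggestion = "import"            # what the beam search proposed
rest_of_line = "import numpy"    # what the user was about to type
if suggestion != '' and rest_of_line.startswith(suggestion):
    keys_saved = len(suggestion) - 1   # 5 keys saved; 1 key spent accepting
    rest_of_line = rest_of_line[len(suggestion):]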

550 | 551 | diff --git a/requirements.txt b/requirements.txt 552 | index e69de29..f039373 100644 553 | --- a/requirements.txt 554 | +++ b/requirements.txt 555 | @@ -0,0 +1,3 @@ 556 | +machine_learning_lab 557 | +torch 558 | +numpy 559 | diff --git a/train.py b/train.py 560 | index 590c4a7..5beb0cf 100644 561 | --- a/train.py 562 | +++ b/train.py 563 | @@ -4,24 +4,19 @@ from typing import List 564 | import numpy as np 565 | import torch 566 | import torch.nn 567 | +from lab import experiment, monit, tracker, loop, logger 568 | +from lab.utils.delayed_keyboard_interrupt import DelayedKeyboardInterrupt 569 | 570 | import parser.load 571 | -from lab.experiment.pytorch import Experiment 572 | from model import SimpleLstmModel 573 | from parser import tokenizer 574 | 575 | -# Configure the experiment 576 | - 577 | -EXPERIMENT = Experiment(name="simple_lstm", 578 | - python_file=__file__, 579 | - comment="Simple LSTM", 580 | - check_repo_dirty=False, 581 | - is_log_python_file=False) 582 | - 583 | -logger = EXPERIMENT.logger 584 | +# Setup the experiment 585 | +experiment.create(name="simple_lstm", 586 | + comment="Simple LSTM") 587 | 588 | # device to train on 589 | -device = torch.device("cuda:1") 590 | +device = torch.device("cuda:0") 591 | 592 | 593 | def list_to_batches(x, batch_size, batches, seq_len): 594 | @@ -118,9 +113,9 @@ class Trainer: 595 | loss.backward() 596 | self.optimizer.step() 597 | 598 | - logger.store("train_loss", loss.cpu().data.item()) 599 | + tracker.add("train.loss", loss.cpu().data.item()) 600 | else: 601 | - logger.store("valid_loss", loss.cpu().data.item()) 602 | + tracker.add("valid.loss", loss.cpu().data.item()) 603 | 604 | 605 | def main_train(): 606 | @@ -129,13 +124,13 @@ def main_train(): 607 | batch_size = 32 608 | seq_len = 32 609 | 610 | - with logger.section("Loading data"): 611 | + with monit.section("Loading data"): 612 | # Load all python files 613 | files = parser.load.load_files() 614 | # Split training and validation data 615 | train_files, valid_files = parser.load.split_train_valid(files, is_shuffle=False) 616 | 617 | - with logger.section("Create model"): 618 | + with monit.section("Create model"): 619 | # Create model 620 | model = SimpleLstmModel(encoding_size=tokenizer.VOCAB_SIZE, 621 | embedding_size=tokenizer.VOCAB_SIZE, 622 | @@ -153,14 +148,14 @@ def main_train(): 623 | c0 = torch.zeros((lstm_layers, batch_size, lstm_size), device=device) 624 | 625 | # Setup logger indicators 626 | - logger.add_indicator("train_loss", queue_limit=500, is_histogram=True) 627 | - logger.add_indicator("valid_loss", queue_limit=500, is_histogram=True) 628 | + tracker.set_queue("train.loss", queue_size=500, is_print=True) 629 | + tracker.set_queue("valid.loss", queue_size=500, is_print=True) 630 | 631 | # Specify the model in [lab](https://github.com/vpj/lab) for saving and loading 632 | - EXPERIMENT.add_models({'base': model}) 633 | + experiment.add_pytorch_models({'base': model}) 634 | 635 | # Start training scratch (step '0') 636 | - EXPERIMENT.start_train(0) 637 | + experiment.start() 638 | 639 | # Number of batches per epoch 640 | batches = math.ceil(sum([len(f[1]) + 1 for f in train_files]) / (batch_size * seq_len)) 641 | @@ -169,7 +164,7 @@ def main_train(): 642 | steps_per_epoch = 200 643 | 644 | # Train for 100 epochs 645 | - for epoch in logger.loop(range(100)): 646 | + for epoch in loop.loop(range(100)): 647 | # Create trainer 648 | trainer = Trainer(files=train_files, 649 | model=model, 650 | @@ -199,46 +194,44 @@ def main_train(): 651 | 652 | # Loop 
through steps 653 | for i in range(1, steps_per_epoch): 654 | - # Set global step 655 | - global_step = epoch * batches + min(batches, (batches * i) // steps_per_epoch) 656 | - logger.set_global_step(global_step) 657 | - 658 | - # Last batch to train and validate 659 | - train_batch_limit = trainer.x.shape[0] * min(1., (i + 1) / steps_per_epoch) 660 | - valid_batch_limit = validator.x.shape[0] * min(1., (i + 1) / steps_per_epoch) 661 | - 662 | try: 663 | - with logger.delayed_keyboard_interrupt(): 664 | + with DelayedKeyboardInterrupt(): 665 | + # Set global step 666 | + global_step = epoch * batches + min(batches, (batches * i) // steps_per_epoch) 667 | + loop.set_global_step(global_step) 668 | 669 | - with logger.section("train", total_steps=trainer.x.shape[0], is_partial=True): 670 | + # Last batch to train and validate 671 | + train_batch_limit = trainer.x.shape[0] * min(1., (i + 1) / steps_per_epoch) 672 | + valid_batch_limit = validator.x.shape[0] * min(1., (i + 1) / steps_per_epoch) 673 | + 674 | + with monit.section("train", total_steps=trainer.x.shape[0], is_partial=True): 675 | model.train() 676 | # Train 677 | while train_batch < train_batch_limit: 678 | trainer.run(train_batch) 679 | - logger.progress(train_batch + 1) 680 | + monit.progress(train_batch + 1) 681 | train_batch += 1 682 | 683 | - with logger.section("valid", total_steps=validator.x.shape[0], is_partial=True): 684 | + with monit.section("valid", total_steps=validator.x.shape[0], is_partial=True): 685 | model.eval() 686 | # Validate 687 | while valid_batch < valid_batch_limit: 688 | validator.run(valid_batch) 689 | - logger.progress(valid_batch + 1) 690 | + monit.progress(valid_batch + 1) 691 | valid_batch += 1 692 | 693 | # Output results 694 | - logger.write() 695 | + tracker.save() 696 | 697 | # 10 lines of logs per epoch 698 | if (i + 1) % (steps_per_epoch // 10) == 0: 699 | - logger.new_line() 700 | - 701 | + logger.log() 702 | except KeyboardInterrupt: 703 | - logger.save_progress() 704 | - logger.save_checkpoint() 705 | - logger.new_line() 706 | + experiment.save_checkpoint() 707 | return 708 | 709 | + experiment.save_checkpoint() 710 | + 711 | 712 | if __name__ == '__main__': 713 | main_train() -------------------------------------------------------------------------------- /logs/simple_lstm/2a86d636936d11eab8740dffb016e7b1/sqlite.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vpj/python_autocomplete/cd61eb37c9ff9fef835173e4fba460d12e8f8d2e/logs/simple_lstm/2a86d636936d11eab8740dffb016e7b1/sqlite.db -------------------------------------------------------------------------------- /logs/simple_lstm/2a86d636936d11eab8740dffb016e7b1/tensorboard/events.out.tfevents.1589190774.varuna-small.31006.5.v2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vpj/python_autocomplete/cd61eb37c9ff9fef835173e4fba460d12e8f8d2e/logs/simple_lstm/2a86d636936d11eab8740dffb016e7b1/tensorboard/events.out.tfevents.1589190774.varuna-small.31006.5.v2 -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class SimpleLstmModel(nn.Module): 5 | def __init__(self, *, 6 | encoding_size, 7 | embedding_size, 8 | lstm_size, 9 | lstm_layers): 10 | super().__init__() 11 | 12 | self.embedding = nn.Embedding(encoding_size, embedding_size) 13 | self.lstm = 
nn.LSTM(input_size=embedding_size, 14 | hidden_size=lstm_size, 15 | num_layers=lstm_layers) 16 | self.fc = nn.Linear(lstm_size, encoding_size) 17 | self.softmax = nn.Softmax(dim=-1) 18 | 19 | def forward(self, x, h0, c0): 20 | # shape of x is [seq, batch, feat] 21 | x = self.embedding(x) 22 | out, (hn, cn) = self.lstm(x, (h0, c0)) 23 | logits = self.fc(out) 24 | 25 | return self.softmax(logits), logits, (hn, cn) 26 | -------------------------------------------------------------------------------- /parser/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vpj/python_autocomplete/cd61eb37c9ff9fef835173e4fba460d12e8f8d2e/parser/__init__.py -------------------------------------------------------------------------------- /parser/load.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, List 2 | 3 | import numpy as np 4 | from labml import logger, lab 5 | 6 | from parser import tokenizer 7 | 8 | 9 | class EncodedFile(NamedTuple): 10 | path: str 11 | codes: List[int] 12 | 13 | 14 | def load_files() -> List[EncodedFile]: 15 | """ 16 | Load encoded files 17 | """ 18 | with open(str(lab.get_data_path() / 'all.py')) as f: 19 | lines = f.readlines() 20 | 21 | files = [] 22 | for i in range(0, len(lines), 2): 23 | path = lines[i][:-1] 24 | content = lines[i + 1][:-1] 25 | if content == '': 26 | content = [] 27 | else: 28 | content = [int(t) for t in content.split(' ')] 29 | files.append(EncodedFile(path, content)) 30 | 31 | return files 32 | 33 | 34 | def split_train_valid(files: List[EncodedFile], 35 | is_shuffle=True) -> (List[EncodedFile], List[EncodedFile]): 36 | """ 37 | Split training and validation sets 38 | """ 39 | if is_shuffle: 40 | np.random.shuffle(files) 41 | 42 | total_size = sum([len(f.codes) for f in files]) 43 | valid = [] 44 | valid_size = 0 45 | while len(files) > 0: 46 | if valid_size > total_size * 0.15: 47 | break 48 | valid.append(files[0]) 49 | valid_size += len(files[0].codes) 50 | files.pop(0) 51 | 52 | train_size = sum(len(f.codes) for f in files) 53 | if train_size < total_size * 0.60: 54 | raise RuntimeError("Validation set too large") 55 | 56 | logger.inspect(train_size=train_size, 57 | valid_size=valid_size, 58 | vocab=tokenizer.VOCAB_SIZE) 59 | return files, valid 60 | 61 | 62 | def main(): 63 | files = load_files() 64 | logger.inspect(files) 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /parser/tokenizer.py: -------------------------------------------------------------------------------- 1 | import tokenize 2 | from io import BytesIO 3 | from typing import Optional, List, NamedTuple, Union 4 | 5 | import numpy as np 6 | 7 | 8 | class TokenType: 9 | eof = 0 10 | new_line = 1 11 | indent = 2 12 | dedent = 3 13 | op = 4 14 | keyword = 5 15 | name = 6 16 | number = 7 17 | string = 8 18 | string_other = 9 19 | comment = 10 20 | 21 | 22 | class _MatchType: 23 | exact = 0 24 | each = 1 25 | starts = 2 26 | none = 3 27 | 28 | 29 | class ParsedToken(NamedTuple): 30 | type: int 31 | value: int 32 | 33 | 34 | class _TokenParser: 35 | """ 36 | Parse tokens 37 | """ 38 | 39 | def __init__(self, token_type, tokenize_type, match_type, values, 40 | replacement=None): 41 | self.offset = 0 42 | self.token_type = token_type 43 | self.tokenize_type = tokenize_type 44 | self.match_type = match_type 45 | self.values = values 46 | self.replacement = 
replacement 47 | 48 | if match_type == _MatchType.exact or match_type == _MatchType.each: 49 | self.match_set = {v: i for i, v in enumerate(values)} 50 | 51 | def __len__(self): 52 | if type(self.values) != list: 53 | return 1 54 | else: 55 | return len(self.values) 56 | 57 | def parse(self, token: tokenize.TokenInfo) -> Optional[List[ParsedToken]]: 58 | """ 59 | Parse token 60 | """ 61 | try: 62 | if type(self.tokenize_type) == list: 63 | if token.type not in self.tokenize_type: 64 | return None 65 | else: 66 | if token.type != self.tokenize_type: 67 | return None 68 | 69 | # Perhaps use subclasses? 70 | if self.match_type == _MatchType.exact: 71 | if token.string not in self.match_set: 72 | return None 73 | return [ParsedToken(self.token_type, self.match_set[token.string])] 74 | elif self.match_type == _MatchType.each: 75 | res = [] 76 | for ch in token.string: 77 | res.append(ParsedToken(self.token_type, self.match_set[ch])) 78 | return res 79 | elif self.match_type == _MatchType.starts: 80 | for i, pref in enumerate(self.values): 81 | if token.string.startswith(pref): 82 | return [ParsedToken(self.token_type, i)] 83 | return None 84 | elif self.match_type == _MatchType.none: 85 | return [ParsedToken(self.token_type, 0)] 86 | else: 87 | raise RuntimeError(self.match_type) 88 | except Exception as e: 89 | print(token) 90 | raise e 91 | 92 | def calc_serialize_range(self): 93 | for p in _PARSERS: 94 | if p == self: 95 | break 96 | self.offset += len(p) 97 | 98 | def get_str(self, value): 99 | if self.replacement is not None: 100 | return self.replacement[value] 101 | 102 | if type(self.values) == str: 103 | return self.values 104 | else: 105 | return self.values[value] 106 | 107 | 108 | _CHARS = ['_'] 109 | _CHARS += [chr(i + ord('a')) for i in range(26)] 110 | _CHARS += [chr(i + ord('A')) for i in range(26)] 111 | _CHARS += [chr(i + ord('0')) for i in range(10)] 112 | 113 | _NUMS = ['.', '_', 'x', 'X', 'o', 'O', '-', '+', 'j', 'J'] 114 | _NUMS += [chr(i + ord('a')) for i in range(6)] 115 | _NUMS += [chr(i + ord('A')) for i in range(6)] 116 | _NUMS += [chr(i + ord('0')) for i in range(10)] 117 | 118 | _PARSERS = [ 119 | _TokenParser(TokenType.eof, None, _MatchType.none, '[eof]'), 120 | _TokenParser(TokenType.new_line, [tokenize.NL, tokenize.NEWLINE], _MatchType.none, '\n'), 121 | _TokenParser(TokenType.indent, tokenize.INDENT, _MatchType.none, ' '), 122 | _TokenParser(TokenType.dedent, tokenize.DEDENT, _MatchType.none, ''), 123 | _TokenParser(TokenType.op, tokenize.OP, _MatchType.exact, 124 | ['+', '-', '*', '/', '%', '**', '//', 125 | '==', '!=', '<>', '>', '<', '>=', '<=', 126 | '=', '+=', '-=', '*=', '/=', '%=', '**=', '//=', 127 | '&', '|', '^', '~', '<<', '>>', 128 | '&=', '|=', '^=', '~=', '<<=', '>>=', 129 | '.', ',', '(', ')', ':', '[', ']', '{', '}', 130 | '@', '...', ';', '->']), 131 | _TokenParser(TokenType.keyword, tokenize.NAME, _MatchType.exact, 132 | ['and', 'as', 'assert', 'break', 'class', 133 | 'continue', 'def', 'del', 'elif', 'else', 134 | 'except', 'False', 'finally', 'for', 'from', 135 | 'global', 'if', 'import', 'in', 'is', 'lambda', 136 | 'None', 'nonlocal', 'not', 'or', 'pass', 'raise', 137 | 'return', 'True', 'try', 'while', 'with', 'yield']), 138 | _TokenParser(TokenType.name, tokenize.NAME, _MatchType.each, _CHARS), 139 | _TokenParser(TokenType.number, tokenize.NUMBER, _MatchType.each, _NUMS), 140 | _TokenParser(TokenType.string, tokenize.STRING, _MatchType.starts, 141 | ['"""', "'''", '"', "'", 'f"'], 142 | ['""" """', "''' '''", '""', "''", 'f""']), 143 | 
_TokenParser(TokenType.string_other, tokenize.STRING, _MatchType.none, ['"'], ['""']), 144 | # regex etc 145 | _TokenParser(TokenType.comment, tokenize.COMMENT, _MatchType.none, '#') 146 | ] 147 | 148 | 149 | def get_vocab_size(token_type: int): 150 | return len(_PARSERS[token_type]) 151 | 152 | 153 | def get_vocab_offset(token_type: int): 154 | return _PARSERS[token_type].offset 155 | 156 | 157 | VOCAB_SIZE = 0 158 | DECODE: List[List[str]] = [] 159 | LENGTHS: List[int] = [] 160 | DESERIALIZE: List[ParsedToken] = [] 161 | 162 | SKIP_TOKENS = {tokenize.ENCODING, tokenize.ENDMARKER} 163 | EMPTY_TOKENS = {TokenType.eof, TokenType.new_line, TokenType.indent, TokenType.dedent} 164 | LINE_BREAK = {TokenType.eof, TokenType.new_line} 165 | 166 | 167 | def _parse_token(token: tokenize.TokenInfo) -> List[ParsedToken]: 168 | if token.type in SKIP_TOKENS: 169 | return [] 170 | 171 | for p in _PARSERS: 172 | res = p.parse(token) 173 | if res is not None: 174 | return res 175 | 176 | raise RuntimeError(token) 177 | 178 | 179 | def _encode_token(token: ParsedToken): 180 | return _PARSERS[token.type].offset + token.value 181 | 182 | 183 | def _decode_code(code: int) -> ParsedToken: 184 | for p in _PARSERS: 185 | if code < p.offset + len(p): 186 | return ParsedToken(p.token_type, code - p.offset) 187 | 188 | 189 | def _token_to_string(token: ParsedToken, prev: Optional[ParsedToken]): 190 | is_spaced = False 191 | if prev is not None: 192 | if prev.type == TokenType.keyword: 193 | if token.type == TokenType.name: 194 | is_spaced = True 195 | if token.type == TokenType.number: 196 | is_spaced = True 197 | if token.type == TokenType.keyword: 198 | is_spaced = True 199 | elif token.type == TokenType.keyword: 200 | if prev.type == TokenType.name: 201 | is_spaced = True 202 | if prev.type == TokenType.number: 203 | is_spaced = True 204 | if prev.type == TokenType.keyword: 205 | is_spaced = True 206 | 207 | string = _PARSERS[token.type].get_str(token.value) 208 | 209 | if is_spaced: 210 | return " " + string 211 | else: 212 | return string 213 | 214 | 215 | def _init(): 216 | """ 217 | Pre-calculate for efficiency 218 | """ 219 | global VOCAB_SIZE, _PARSERS, DESERIALIZE, LENGTHS 220 | 221 | for p in _PARSERS: 222 | p.calc_serialize_range() 223 | VOCAB_SIZE += len(p) 224 | 225 | for c1 in range(VOCAB_SIZE): 226 | t1 = _decode_code(c1) 227 | DESERIALIZE.append(t1) 228 | LENGTHS.append(len(_token_to_string(t1, None))) 229 | dec = [] 230 | for c2 in range(VOCAB_SIZE): 231 | t2 = _decode_code(c2) 232 | dec.append(_token_to_string(t1, t2)) 233 | DECODE.append(dec) 234 | 235 | 236 | _init() 237 | 238 | 239 | def parse(tokens: List[tokenize.TokenInfo]) -> List[ParsedToken]: 240 | """ 241 | Parse tokens 242 | """ 243 | parsed = [] 244 | for t in tokens: 245 | parsed += _parse_token(t) 246 | 247 | return parsed 248 | 249 | 250 | def encode(tokens: List[ParsedToken]) -> List[int]: 251 | """ 252 | Encode tokens to codes 253 | """ 254 | return [_encode_token(t) for t in tokens] 255 | 256 | 257 | def decode(codes: Union[np.ndarray, List[int]]) -> List[ParsedToken]: 258 | """ 259 | Decode codes to tokens 260 | """ 261 | return [DESERIALIZE[c] for c in codes] 262 | 263 | 264 | def parse_string(content: str) -> List[ParsedToken]: 265 | """ 266 | Encode source code 267 | """ 268 | g = tokenize.tokenize(BytesIO(content.encode('utf-8')).readline) 269 | 270 | return parse(g) 271 | 272 | 273 | def to_string(tokens: List[ParsedToken]) -> str: 274 | """ 275 | Convert tokens to source code 276 | """ 277 | res = "" 278 | prev = None 279 
| for t in tokens: 280 | res += _token_to_string(t, prev) 281 | prev = t 282 | 283 | return res 284 | -------------------------------------------------------------------------------- /python-autocomplete.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vpj/python_autocomplete/cd61eb37c9ff9fef835173e4fba460d12e8f8d2e/python-autocomplete.png -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ### ⭐️ We rewrote a simpler version of this at [lab-ml/source_code_modelling](https://github.com/lab-ml/source_code_modelling) and we intend to maintain it for a while 2 | 3 | [This](https://github.com/vpj/python_autocomplete) is a toy project we started 4 | to see how well a simple LSTM model can autocomplete Python code. 5 | 6 | It gives quite decent results, saving over 30% of keystrokes in most files, 7 | and close to 50% in some. 8 | We calculated keystrokes saved by making a single (best) 9 | prediction and selecting it with a single key. 10 | 11 | We do a beam search to find predictions, up to ~10 characters ahead. 12 | So far it is too inefficient for editor integration, in case you are wondering. 13 | 14 | We train and predict after cleaning comments, strings 15 | and blank lines from the Python code. 16 | The model is trained on tokenized Python code. 17 | This seems more efficient than character-level prediction with byte-pair encoding (see the tokenizer sketch after the screenshot below). 18 | 19 | A saved model is included in this repo. 20 | It is trained on [tensorflow/models](https://github.com/tensorflow/models). 21 | 22 | Here's a sample evaluation on a source file from the validation set. 23 | Red characters mark where an auto-completion started; 24 | i.e. the user pressed TAB to select the completion. 25 | The green character and the following characters highlighted in gray 26 | are auto-completed. As you can see, it starts and ends completions arbitrarily. 27 | That is, a suggestion could be 'tensorfl' rather than the complete identifier 28 | 'tensorflow', which can be a little annoying in real usage. 29 | We could fix that by limiting completions to end at token boundaries. 30 | You can also notice that it completes across operators. 31 | Increasing the length of the beam search will let it complete longer pieces of code. 32 | 33 |

34 | ![Sample evaluation screenshot](python-autocomplete.png) 35 |

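To make the tokenization step concrete, here is a minimal sketch (not part of the original readme) of how `parser/tokenizer.py` round-trips a snippet of code. It assumes you run it from the repository root so that the `parser` package is importable:

```python
from parser import tokenizer

source = "def add(a, b):\n    return a + b\n"

tokens = tokenizer.parse_string(source)   # tokenize and map to ParsedToken values
codes = tokenizer.encode(tokens)          # integer codes the LSTM is trained on
decoded = tokenizer.decode(codes)         # back to ParsedToken values

print(tokenizer.VOCAB_SIZE)               # size of the code vocabulary
print(codes)                              # a short list of small integers
print(tokenizer.to_string(decoded))       # approximate reconstruction of the source
```

Strings and comments are replaced by placeholder tokens, so `to_string` reconstructs the cleaned code rather than the exact original text.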
36 | 37 | ## Try it yourself 38 | 39 | 1. Clone this repo 40 | 41 | 2. Install requirements from `requirements.txt` 42 | 43 | 3. Copy data to `./data/source` 44 | 45 | 4. Run `extract_code.py` to collect all Python files, encode them, and merge them into `all.py` 46 | 47 | 5. Run `evaluate.py` to evaluate the model. I have included a checkpoint in the repo. 48 | 49 | 6. Run `train.py` to train the model 50 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | labml 2 | torch 3 | numpy 4 | tensorflow 5 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn 7 | from labml import experiment, monit, tracker, logger 8 | from labml.utils.delayed_keyboard_interrupt import DelayedKeyboardInterrupt 9 | 10 | import parser.load 11 | from model import SimpleLstmModel 12 | from parser import tokenizer 13 | 14 | # Setup the experiment 15 | experiment.create(name="simple_lstm", 16 | comment="Simple LSTM") 17 | 18 | # device to train on 19 | device = torch.device("cuda:0") 20 | 21 | 22 | def list_to_batches(x, batch_size, batches, seq_len): 23 | """ 24 | Reshape flat data into batches ready for the model to consume 25 | """ 26 | x = np.reshape(x, (batch_size, batches, seq_len)) 27 | x = np.transpose(x, (1, 2, 0)) 28 | 29 | return x 30 | 31 | 32 | def get_batches(files: List[parser.load.EncodedFile], eof: int, batch_size=32, seq_len=32): 33 | """ 34 | Convert raw encoded files into training/validation batches 35 | """ 36 | 37 | # Shuffle the order of files 38 | np.random.shuffle(files) 39 | 40 | # Concatenate all the files, adding an `eof` marker at the beginning of each 41 | data = [] 42 | for f in files: 43 | data.append(eof) 44 | data += f.codes 45 | data = np.array(data) 46 | 47 | # Start from a random offset 48 | offset = np.random.randint(seq_len * batch_size) 49 | data = data[offset:] 50 | 51 | # Number of batches 52 | batches = (len(data) - 1) // batch_size // seq_len 53 | 54 | # Extract input 55 | x = data[:(batch_size * seq_len * batches)] 56 | # Extract output, i.e. 
the next token 57 | y = data[1:(batch_size * seq_len * batches) + 1] 58 | 59 | # Convert the flat data into batches 60 | x = list_to_batches(x, batch_size, batches, seq_len) 61 | y = list_to_batches(y, batch_size, batches, seq_len) 62 | 63 | return x, y 64 | 65 | 66 | class Trainer: 67 | """ 68 | Maintains state and data, and trains/validates the model 69 | """ 70 | 71 | def __init__(self, *, files: List[parser.load.EncodedFile], 72 | model, loss_func, optimizer, 73 | eof: int, 74 | batch_size: int, seq_len: int, 75 | is_train: bool, 76 | h0, c0): 77 | # Get batches 78 | x, y = get_batches(files, eof, 79 | batch_size=batch_size, 80 | seq_len=seq_len) 81 | # Convert data to PyTorch tensors 82 | self.x = torch.tensor(x, device=device) 83 | self.y = torch.tensor(y, device=device) 84 | 85 | # Initial state 86 | self.hn = h0 87 | self.cn = c0 88 | 89 | self.model = model 90 | self.loss_func = loss_func 91 | self.optimizer = optimizer 92 | self.p = None 93 | self.is_train = is_train 94 | 95 | def run(self, i): 96 | # Get model output 97 | self.p, logits, (self.hn, self.cn) = self.model(self.x[i], self.hn, self.cn) 98 | 99 | # Flatten outputs 100 | logits = logits.view(-1, self.p.shape[-1]) 101 | yi = self.y[i].reshape(-1) 102 | 103 | # Calculate loss 104 | loss = self.loss_func(logits, yi) 105 | 106 | # Detach the states so gradients do not propagate across batches 107 | self.hn = self.hn.detach() 108 | self.cn = self.cn.detach() 109 | 110 | if self.is_train: 111 | # Take a training step 112 | self.optimizer.zero_grad() 113 | loss.backward() 114 | self.optimizer.step() 115 | 116 | tracker.add("train.loss", loss.cpu().data.item()) 117 | else: 118 | tracker.add("valid.loss", loss.cpu().data.item()) 119 | 120 | 121 | def main_train(): 122 | lstm_size = 1024 123 | lstm_layers = 3 124 | batch_size = 32 125 | seq_len = 32 126 | 127 | with monit.section("Loading data"): 128 | # Load all python files 129 | files = parser.load.load_files() 130 | # Split training and validation data 131 | train_files, valid_files = parser.load.split_train_valid(files, is_shuffle=False) 132 | 133 | with monit.section("Create model"): 134 | # Create model 135 | model = SimpleLstmModel(encoding_size=tokenizer.VOCAB_SIZE, 136 | embedding_size=tokenizer.VOCAB_SIZE, 137 | lstm_size=lstm_size, 138 | lstm_layers=lstm_layers) 139 | # Move model to `device` 140 | model.to(device) 141 | 142 | # Create loss function and optimizer 143 | loss_func = torch.nn.CrossEntropyLoss() 144 | optimizer = torch.optim.Adam(model.parameters()) 145 | 146 | # Initial state is 0 147 | h0 = torch.zeros((lstm_layers, batch_size, lstm_size), device=device) 148 | c0 = torch.zeros((lstm_layers, batch_size, lstm_size), device=device) 149 | 150 | # Setup logger indicators 151 | tracker.set_queue("train.loss", queue_size=500, is_print=True) 152 | tracker.set_queue("valid.loss", queue_size=500, is_print=True) 153 | 154 | # Specify the model in [lab](https://github.com/vpj/lab) for saving and loading 155 | experiment.add_pytorch_models({'base': model}) 156 | 157 | # Start training from scratch (step '0') 158 | experiment.start() 159 | 160 | # Number of batches per epoch 161 | batches = math.ceil(sum([len(f[1]) + 1 for f in train_files]) / (batch_size * seq_len)) 162 | 163 | # Number of steps per epoch. We train and validate on each step. 
164 | steps_per_epoch = 200 165 | 166 | # Train for 100 epochs 167 | for epoch in monit.loop(range(100)): 168 | # Create trainer 169 | trainer = Trainer(files=train_files, 170 | model=model, 171 | loss_func=loss_func, 172 | optimizer=optimizer, 173 | batch_size=batch_size, 174 | seq_len=seq_len, 175 | is_train=True, 176 | h0=h0, 177 | c0=c0, 178 | eof=0) 179 | # Create validator 180 | validator = Trainer(files=valid_files, 181 | model=model, 182 | loss_func=loss_func, 183 | optimizer=optimizer, 184 | is_train=False, 185 | seq_len=seq_len, 186 | batch_size=batch_size, 187 | h0=h0, 188 | c0=c0, 189 | eof=0) 190 | 191 | # Next batch to train and validation 192 | train_batch = 0 193 | valid_batch = 0 194 | 195 | # Loop through steps 196 | for i in range(1, steps_per_epoch): 197 | try: 198 | with DelayedKeyboardInterrupt(): 199 | # Set global step 200 | global_step = epoch * batches + min(batches, (batches * i) // steps_per_epoch) 201 | tracker.set_global_step(global_step) 202 | 203 | # Last batch to train and validate 204 | train_batch_limit = trainer.x.shape[0] * min(1., (i + 1) / steps_per_epoch) 205 | valid_batch_limit = validator.x.shape[0] * min(1., (i + 1) / steps_per_epoch) 206 | 207 | with monit.section("train", total_steps=trainer.x.shape[0], is_partial=True): 208 | model.train() 209 | # Train 210 | while train_batch < train_batch_limit: 211 | trainer.run(train_batch) 212 | monit.progress(train_batch + 1) 213 | train_batch += 1 214 | 215 | with monit.section("valid", total_steps=validator.x.shape[0], is_partial=True): 216 | model.eval() 217 | # Validate 218 | while valid_batch < valid_batch_limit: 219 | validator.run(valid_batch) 220 | monit.progress(valid_batch + 1) 221 | valid_batch += 1 222 | 223 | # Output results 224 | tracker.save() 225 | 226 | # 10 lines of logs per epoch 227 | if (i + 1) % (steps_per_epoch // 10) == 0: 228 | logger.log() 229 | except KeyboardInterrupt: 230 | experiment.save_checkpoint() 231 | return 232 | 233 | experiment.save_checkpoint() 234 | 235 | 236 | if __name__ == '__main__': 237 | main_train() 238 | --------------------------------------------------------------------------------
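As a closing illustration, here is a hedged sketch (not part of the repository) of how `SimpleLstmModel` from `model.py` is queried for a next-token prediction. It builds an untrained model on the CPU purely to show the tensor shapes; the actual evaluation script restores the saved checkpoint instead, and the hyper-parameters below simply mirror those used in `train.py`:

```python
import torch

from model import SimpleLstmModel
from parser import tokenizer

lstm_size = 1024
lstm_layers = 3

# Untrained model, for shape illustration only; evaluate.py loads the real checkpoint
model = SimpleLstmModel(encoding_size=tokenizer.VOCAB_SIZE,
                        embedding_size=tokenizer.VOCAB_SIZE,
                        lstm_size=lstm_size,
                        lstm_layers=lstm_layers)
model.eval()

# Encode a prompt and shape it as [seq, batch] with a batch of one
codes = tokenizer.encode(tokenizer.parse_string("import numpy as np\n"))
x = torch.tensor(codes, dtype=torch.long).unsqueeze(1)

# Zero initial LSTM state: [layers, batch, lstm_size]
h0 = torch.zeros((lstm_layers, 1, lstm_size))
c0 = torch.zeros((lstm_layers, 1, lstm_size))

with torch.no_grad():
    probs, logits, (hn, cn) = model(x, h0, c0)

# Probability distribution over the next token, taken at the last position
next_code = int(probs[-1, 0].argmax())
print(tokenizer.DESERIALIZE[next_code])
```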