├── .gitignore
├── LICENSE
├── README.md
├── docly
│   ├── __init__.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── args.py
│   │   ├── docly_clean.py
│   │   ├── docly_gen.py
│   │   ├── docly_restore.py
│   │   └── setup_env.py
│   ├── config
│   │   └── __init__.py
│   ├── ioutils
│   │   ├── __init__.py
│   │   ├── apply_diff.py
│   │   ├── console_printer.py
│   │   ├── convert_ipynb.py
│   │   └── table_printer.py
│   ├── logic
│   │   ├── __init__.py
│   │   ├── example.py
│   │   ├── input_features.py
│   │   ├── logic_main.py
│   │   ├── model.py
│   │   └── model_new.py
│   ├── parser
│   │   ├── __init__.py
│   │   └── parser.py
│   └── tokenizers
│       └── __init__.py
├── logo
│   └── docly.png
├── requirements.txt
├── setup.py
└── test_files
    ├── api.py
    ├── flask_files
    │   ├── aa.py
    │   └── cli.py
    ├── inner_dir
    │   └── _internal_utils.py
    ├── notebooks
    │   ├── 6_lstm.ipynb
    │   ├── clipboards.py
    │   └── dream.ipynb
    ├── random.txt
    ├── simple_funcs
    │   └── simple_funcs.py
    └── test_config_file.ini

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
140 | # vscode
141 |
142 | .vscode/
143 |
144 | # Python-version
145 |
146 | .python-version
147 |
148 | # VScode workspace
149 |
150 | *.code-workspace

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Codist AI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # docly
2 |
3 | ## Please note-------
4 |
5 | **Docly is not maintained anymore. Codist is in the process of shutting down. Sorry for the inconvenience**
6 |
7 | ![Docly - Automatic source code commenting](https://github.com/autosoft-dev/docly/blob/master/logo/docly.png)
8 |
9 | [![parser: tree-hugger](https://img.shields.io/badge/parser-tree--hugger-lightgrey)](https://github.com/autosoft-dev/tree-hugger/)
10 |
11 | Automatically generate docstrings for your python functions
12 |
13 | ## Main documentation (FAQ and all)
14 |
15 | http://beta.thedocly.io/
16 |
17 |
18 | ## Installing
19 |
20 | Requires python 3.6+
21 |
22 | _NOTE THAT if you are getting an error like `error: invalid command 'bdist_wheel'` (observed in a fresh virtualenv on Ubuntu), then please install wheel by doing `pip install wheel`_
23 |
24 | _ALSO NOTE that if tree-sitter fails to build because you do not have gcc installed, you can install it using `sudo apt-get install gcc python3-dev`; for other distros please check [here](https://stackoverflow.com/questions/21530577/fatal-error-python-h-no-such-file-or-directory)_
25 |
26 | First install setuptools-rust:
27 |
28 | ```
29 | pip install setuptools-rust
30 | ```
31 |
32 | Then
33 |
34 | ```
35 | pip install docly
36 | ```
37 |
38 | ## Using
39 |
40 | To generate comments -
41 |
42 | ```
43 | docly-gen /path/to/file_or_folder_with_python_files
44 | ```
45 | _Please note that if you do not have the necessary engine (models) downloaded before running the command (which is going to be the case the first time you run it), then it will download and set them up, which may take a bit of time_
46 |
47 |
48 | It will produce something like this (shown here on a single file, but you can also run it on a directory full of files)
49 |
50 | ```
51 | The diff has been generated, do you want to see the suggestions for missing Docstrings? [Y/n]
52 | Y
53 | +-----------------+------------------------------+---------------------------------------+
54 | | File Name       | Function Name                | Docstring                             |
55 | +-----------------+------------------------------+---------------------------------------+
56 | | simple_funcs.py | add                          | Add two numbers .                     |
57 | | simple_funcs.py | check_if_even                | Checks if number is even .            |
58 | | simple_funcs.py | check_even_numbers_in_a_list | Return list of numbers in base_list . |
59 | | simple_funcs.py | open_file                    | Open a file .                         |
60 | +-----------------+------------------------------+---------------------------------------+
61 | Do you want to apply the suggestions? [Y/n]
62 | Y
63 | Applying diff
64 | Diff applied. Good bye!
65 | ```
66 |
67 | Instead, if you just want the above report without applying the changes, then do this -
68 |
69 | ```
70 | docly-gen --no_generate_diff /path/to/file_or_folder_with_python_files
71 | ```
72 |
73 | If you want to revert the changes we applied, then use
74 |
75 | ```
76 | docly-restore
77 | ```
78 |
79 | This will bring back ALL the files that we had touched to the exact state before we applied the changes
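
Docly also reads an optional config file (`docly-config.ini` by default; change it with `--config_file`). A minimal sketch of one — the `[skipDirs]` section is the only one `DoclyConfig` (in `docly/config/__init__.py`) requires; only the keys matter (values are ignored), and the directory name below is hypothetical:

```
[skipDirs]
test_files/notebooks = true
```

Any file whose parent directory path ends with one of these keys is skipped during generation.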
--------------------------------------------------------------------------------
/docly/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.3.0"

--------------------------------------------------------------------------------
/docly/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autosoft-dev/docly/0bd6216b8a9735e9fa76bffd4ffea6cec6cc4a01/docly/cli/__init__.py

--------------------------------------------------------------------------------
/docly/cli/args.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from argparse import ArgumentDefaultsHelpFormatter
3 | from pathlib import Path
4 |
5 |
6 | def setup_cmdline_args_for_docly_gen(parser: ArgumentParser):
7 |     parser.add_argument("--no_generate_diff", action="store_false",
8 |                         help="Do not generate the diff. Only prints a report on the console")
9 |     parser.add_argument("--no_generate_args_list", action="store_false",
10 |                         help="Do not generate argument list in the docstring")
11 |     parser.add_argument("--no_print_report", action="store_false",
12 |                         help="Do not prompt to show the report once the diff is generated")
13 |     parser.add_argument("--run_on_notebooks", action="store_true",
14 |                         help="If you want docly to run on notebook (.ipynb) files (Requires jupytext and defaults false)")
15 |     parser.add_argument("--docstring_style", type=str, default="google",
16 |                         help="What style of docstring you want [google, numpy, sphinx]. Defaults to `google` style.")
17 |     parser.add_argument("--config_file", type=str, default="docly-config.ini",
18 |                         help="Configuration file for docly")
19 |     parser.add_argument("--use-old-model", action="store_true",
20 |                         help="Do you want to run the older model? Saves a second download, but bad quality")
21 |     parser.add_argument("--force", action="store_true",
22 |                         help="Applies the changes without an interactive prompt")
23 |     # This is mandatory
24 |     parser.add_argument("files", type=str, nargs="+",
25 |                         help="List the files/dirs you want to run it on")
26 |
27 |
28 | def setup_cmdline_args_for_docly_restore(parser: ArgumentParser):
29 |     parser.add_argument("--force", action="store_true",
30 |                         help="Disables interactive restoring")

--------------------------------------------------------------------------------
/docly/cli/docly_clean.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from docly.ioutils.console_printer import print_on_console
4 |
5 |
6 | def main():
7 |     print_on_console("Cleaning old model", color="green")
8 |     if Path((Path.home() / ".docly" / "model" / "pytorch_model.bin")).exists():
9 |         Path((Path.home() / ".docly" / "model" / "pytorch_model.bin")).unlink()
10 |     print_on_console("Cleaning done", color="green", emoji="heavy_check_mark")

--------------------------------------------------------------------------------
/docly/cli/docly_gen.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
3 | import time
4 | import sys
5 | import shutil
6 |
7 | from pyfiglet import Figlet
8 | import transformers
9 | from halo import Halo
10 |
11 | from .args import setup_cmdline_args_for_docly_gen
12 | from .setup_env import inspect_and_download_latest_model, inspect_and_download_latest_tslibs
13 | from docly.config import DoclyConfig
14 | from docly.ioutils import (is_dir,
15 |                            check_out_path,
16 |                            process_file,
17 |                            is_python_file,
18 |                            is_ipynb_notebook,
19 |                            query_yes_no,
20 |                            look_for_update
21 |                            )
22 | from docly.ioutils.console_printer import print_on_console
23 | from docly.ioutils.apply_diff import apply_diff
24 | from docly.ioutils.table_printer import print_results_as_table
25 | from docly.ioutils.convert_ipynb import convert_ipynb_to_python
26 | from docly.logic.logic_main import load_model, predict_docstring
27 |
28 | transformers.logger.setLevel(transformers.logging.CRITICAL)
29 |
30 | parser = ArgumentParser()
31 | ROOT = (Path.home() / ".docly")
32 | MODEL_DOWNLOAD_ROOT = "https://docly-model.s3.amazonaws.com/pytorch_model.bin"
33 | NEW_MODEL_DOWNLOAD_ROOT = "https://code-summary.s3.amazonaws.com/pytorch_model_new.bin"
34 | TSLIBS_DOWNLOAD_ROOT = "https://func2docstr-py.s3.amazonaws.com/"
35 |
36 | f = Figlet(font='slant')
37 |
38 |
39 | def _print_welcome():
40 |     print(f.renderText('Docly'))
41 |
42 |
43 | # def _predict_docstrings_from_file(f_path: Path, ts_lib_path: str, model: object, tokenizer: object):
44 | #     func_names = []
45 | #     ct = []
46 | #     for code_tokens, raw_code, start_index, function_name in process_file(f_path, ts_lib_path):
47 | #         func_names.append(func_names)
48 | #         ct.append(code_tokens)
49 | #     docstrs = predict_docstring(model, tokenizer, ct)
50 | #     print(docstrs)
51 |
52 |
53 | def _apply_diff(docstr_loc, no_generate_args_list, ipynb_files, docstring_style):
54 |     print_on_console("Applying diff", color="green")
55 |     apply_diff(docstr_loc, no_generate_args_list, ipynb_files, docstring_style)
56 |     print_on_console("Diff applied. Good bye!", color="green", emoji="thumbsup")
Good bye!", color="green", emoji="thumbsup") 57 | 58 | 59 | def _remove_converted_python_files(ipynb_files): 60 | for py_file_loc, _ in ipynb_files.items(): 61 | Path(py_file_loc).unlink() 62 | 63 | 64 | def _if_jupytext_is_installed(): 65 | try: 66 | import jupytext 67 | return True 68 | except ModuleNotFoundError: 69 | return False 70 | 71 | 72 | def _deal_with_result(args, table_rows, docstr_loc, ipynb_files): 73 | if not args.no_generate_diff and table_rows: 74 | print_results_as_table(table_rows) 75 | elif args.force and docstr_loc: 76 | _apply_diff(docstr_loc, args.no_generate_args_list, ipynb_files, args.docstring_style) 77 | _remove_converted_python_files(ipynb_files) 78 | elif args.no_generate_diff and docstr_loc: 79 | if args.no_print_report: 80 | choice = query_yes_no("The diff has been generated, do you want to see the suggestions for missing Docstrings?") 81 | if choice: 82 | print_results_as_table(table_rows) 83 | choice = query_yes_no("Do you want to apply the suggestions?") 84 | else: 85 | choice = query_yes_no("Do you want to apply the suggestions?") 86 | else: 87 | choice = query_yes_no("Do you want to apply the suggestions?") 88 | 89 | if choice: 90 | _apply_diff(docstr_loc, args.no_generate_args_list, ipynb_files, args.docstring_style) 91 | else: 92 | _remove_converted_python_files(ipynb_files) 93 | print_on_console("Nothing changed. Good bye!", color="green", emoji="thumbsup") 94 | else: 95 | print_on_console("\n\nNothing to be done. Good bye!", color="green", emoji="thumbsup") 96 | 97 | 98 | @Halo(text='Processing files', spinner='dots') 99 | def _process(args, model, tokenizer, ts_lib_path, config: DoclyConfig): 100 | """ 101 | Terribly written code. Refactor ASAP 102 | """ 103 | table_rows = [] 104 | docstr_loc = {} # Very badly named variable. Need to change 105 | ipynb_files = {} 106 | 107 | for file in args.files: 108 | f_path = Path(file) 109 | if is_dir(file): 110 | for f in check_out_path(f_path): 111 | if not is_dir(f) and is_python_file(f): 112 | #### 113 | # Very bad implementation. 
114 |                     ####
115 |                     if not config.is_dir_skipped(str(f).split("/")[:-1]):
116 |                         for code_tokens, params, start_index, function_name, ds in process_file(f, ts_lib_path, args.use_old_model):
117 |                             if ds == "":
118 |                                 docstr = predict_docstring(model, tokenizer, code_tokens, args.use_old_model)
119 |                                 if docstr_loc.get(str(f)) is None:
120 |                                     docstr_loc[str(f)] = {start_index[0]:
121 |                                                           (start_index[1],
122 |                                                            docstr[0],
123 |                                                            params
124 |                                                            )
125 |                                                           }
126 |                                 else:
127 |                                     docstr_loc[str(f)][start_index[0]] = (start_index[1],
128 |                                                                           docstr[0],
129 |                                                                           params)
130 |                                 table_rows.append([f.name, function_name, docstr[0]])
131 |                 elif not is_dir(f) and is_ipynb_notebook(f) and args.run_on_notebooks:
132 |                     if not config.is_dir_skipped(str(f).split("/")[:-1]):
133 |                         py_file = convert_ipynb_to_python(f)
134 |                         if py_file:
135 |                             for code_tokens, params, start_index, function_name, ds in process_file(py_file, ts_lib_path, args.use_old_model):
136 |                                 if ds == "":
137 |                                     docstr = predict_docstring(model, tokenizer, code_tokens, args.use_old_model)
138 |                                     if docstr_loc.get(str(py_file)) is None:
139 |                                         docstr_loc[str(py_file)] = {start_index[0]:
140 |                                                                     (start_index[1],
141 |                                                                      docstr[0],
142 |                                                                      params
143 |                                                                      )
144 |                                                                     }
145 |                                     else:
146 |                                         docstr_loc[str(py_file)][start_index[0]] = (start_index[1],
147 |                                                                                     docstr[0],
148 |                                                                                     params)
149 |                                     table_rows.append([f.name, function_name, docstr[0]])
150 |                             ipynb_files[str(py_file.absolute())] = f
151 |         else:
152 |             if is_python_file(f_path):
153 |                 if not config.is_dir_skipped(str(f_path.absolute()).split("/")[:-1]):
154 |                     for code_tokens, params, start_index, function_name, ds in process_file(f_path, ts_lib_path, args.use_old_model):
155 |                         if ds == "":
156 |                             docstr = predict_docstring(model, tokenizer, code_tokens, args.use_old_model)
157 |                             if docstr_loc.get(str(f_path.absolute())) is None:
158 |                                 docstr_loc[str(f_path.absolute())] = {start_index[0]:
159 |                                                                       (start_index[1],
160 |                                                                        docstr[0],
161 |                                                                        params
162 |                                                                        )
163 |                                                                       }
164 |                             else:
165 |                                 docstr_loc[str(f_path.absolute())][start_index[0]] = (start_index[1],
166 |                                                                                       docstr[0],
167 |                                                                                       params)
168 |                             table_rows.append([f_path.name, function_name, docstr[0]])
169 |             elif is_ipynb_notebook(f_path) and args.run_on_notebooks:
170 |                 if not config.is_dir_skipped(str(f_path.absolute()).split("/")[:-1]):
171 |                     py_file = convert_ipynb_to_python(f_path)
172 |                     if py_file:
173 |                         for code_tokens, params, start_index, function_name, ds in process_file(py_file, ts_lib_path, args.use_old_model):
174 |                             if ds == "":
175 |                                 docstr = predict_docstring(model, tokenizer, code_tokens, args.use_old_model)
176 |                                 if docstr_loc.get(str(py_file)) is None:
177 |                                     docstr_loc[str(py_file)] = {start_index[0]:
178 |                                                                 (start_index[1],
179 |                                                                  docstr[0],
180 |                                                                  params
181 |                                                                  )
182 |                                                                 }
183 |                                 else:
184 |                                     docstr_loc[str(py_file)][start_index[0]] = (start_index[1],
185 |                                                                                 docstr[0],
186 |                                                                                 params)
187 |                                 table_rows.append([f_path.name, function_name, docstr[0]])
188 |                         ipynb_files[str(py_file.absolute())] = f_path.absolute()
189 |     return table_rows, docstr_loc, ipynb_files
190 |
191 |
192 |
193 | def main():
194 |     # if look_for_update():
195 |     #     print_on_console("There is an update available. Please run `pip install --upgrade docly`", color="green", emoji="rotating_light")
196 |     _print_welcome()
197 |
198 |     setup_cmdline_args_for_docly_gen(parser)
199 |     args = parser.parse_args()
200 |
201 |     if args.run_on_notebooks and not _if_jupytext_is_installed():
202 |         print_on_console("You have mentioned `run_on_notebooks` but the needed dependency is not present. Please run `pip install 'docly[jupyter]'` for that. This switch will be ignored", color="green")
203 |         args.run_on_notebooks = False
204 |
205 |     # if args.run_on_notebooks:
206 |     #     print_on_console("You have mentioned the `run_on_notebooks` switch. It is experimental", color="red", emoji="rotating_light")
207 |     #     choice = query_yes_no("Do you want to continue?")
208 |     #     if not choice:
209 |     #         args.run_on_notebooks = False
210 |
211 |     config = DoclyConfig(args.config_file)
212 |
213 |     try:
214 |         mdp = NEW_MODEL_DOWNLOAD_ROOT if not args.use_old_model else MODEL_DOWNLOAD_ROOT
215 |         inspect_and_download_latest_model(ROOT, mdp, args.use_old_model)
216 |     except KeyboardInterrupt:
217 |         print_on_console("You stopped the download. Docly won't work", color="red", emoji="X")
218 |         shutil.rmtree(str(ROOT / "model"))
219 |         sys.exit(1)
220 |
221 |     try:
222 |         ready, tslib_file = inspect_and_download_latest_tslibs(ROOT, TSLIBS_DOWNLOAD_ROOT)
223 |         if not ready:
224 |             print_on_console("===== OS version not supported =====", color="red", emoji="X")
225 |             return
226 |     except KeyboardInterrupt:
227 |         print_on_console("You stopped the download. Docly won't work", color="red", emoji="X")
228 |         sys.exit(1)
229 |
230 |     print_on_console("Loading Engine. Please wait", color="green")
231 |     model_name = "pytorch_model_new.bin" if not args.use_old_model else "pytorch_model.bin"
232 |     model, tokenizer = load_model(str(ROOT / "model" / model_name), args.use_old_model)
233 |     print_on_console("Engine Loaded.", color="green", emoji="heavy_check_mark")
234 |
235 |     ts_lib_path = str(ROOT / "tslibs" / tslib_file)
236 |
237 |     table_rows, docstr_loc, ipynb_files = _process(args, model, tokenizer, ts_lib_path, config)
238 |
239 |     _deal_with_result(args, table_rows, docstr_loc, ipynb_files)
240 |
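
For orientation, the `docstr_loc` mapping that `_process` builds above (and `apply_diff` later consumes) has the following shape — a hypothetical illustration inferred from the code, not part of the package:

```python
# file path -> {function start line: (indent column, predicted docstring, params)}
docstr_loc = {
    "/abs/path/simple_funcs.py": {
        12: (4, "Add two numbers .", [("a", None, None), ("b", None, None)]),
        # params are (name, type, default) triples from the parser
    }
}
```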
Are you sure?") 24 | else: 25 | # Forceful application of restore command 26 | choice = True 27 | if choice: 28 | print_on_console("Restoring files", color="green") 29 | try: 30 | for file in check_out_path(Path().cwd().absolute()): 31 | if not is_dir(file) and (is_python_file(file) or is_ipynb_notebook(file)): 32 | full_path = str(Path(file).absolute()) 33 | comparison_key = full_path[1:].replace("/", "#") 34 | if comparison_key in all_cached_files: 35 | source_file = str(CACHE_DIR / comparison_key) 36 | final_file = str(Path(file).absolute()) 37 | shutil.move(source_file, final_file) 38 | print_on_console("Restoring done", color="green", emoji="thumbsup") 39 | except KeyboardInterrupt: 40 | print_on_console("Restoration not finished", color="red", emoji="X") -------------------------------------------------------------------------------- /docly/cli/setup_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import platform 4 | 5 | from docly.ioutils import download_from_url 6 | from docly.ioutils.console_printer import print_on_console 7 | 8 | SUPPORTED_PLATFORMS = ["linux", "darwin"] 9 | 10 | 11 | def inspect_and_download_latest_model(model_root: Path, download_url: str, is_old=False) -> bool: 12 | model_file = "pytorch_model_new.bin" if not is_old else "pytorch_model.bin" 13 | 14 | if not model_root.exists(): 15 | os.makedirs(str(model_root)) 16 | os.mkdir(str(model_root / "model")) 17 | elif model_root.exists() and not (model_root / "model").exists(): 18 | os.mkdir(str(model_root / "model")) 19 | 20 | if Path(model_root/ "model" / model_file).exists() and Path(model_root/ "model" / model_file).is_file(): 21 | return True 22 | 23 | print_on_console("There is no model. Downloading (maybe because you chose to use the newer model)", color="green") 24 | download_from_url(download_url, str(Path(model_root/ "model" / model_file))) 25 | print_on_console("Download complete", color="green", emoji="heavy_check_mark") 26 | return True 27 | 28 | 29 | def inspect_and_download_latest_tslibs(tslibs_root: Path, download_url: str) -> bool: 30 | os_name = platform.system().lower() 31 | if os_name not in SUPPORTED_PLATFORMS: 32 | return (False, None) 33 | 34 | file_name = "python_ts_darwin64.so" if os_name == "darwin" else "python_ts_nix64.so" 35 | download_url = f"{download_url}{file_name}" 36 | 37 | if (tslibs_root/ "tslibs" / file_name).exists() and (tslibs_root/ "tslibs" / file_name).is_file(): 38 | return (True, file_name) 39 | 40 | if not (tslibs_root / "tslibs").exists(): 41 | os.mkdir(str(tslibs_root / "tslibs")) 42 | 43 | print_on_console("There is no tree-sitter lib. 
Downloading", color="green") 44 | download_from_url(download_url, str(Path(tslibs_root/ "tslibs" / file_name))) 45 | print_on_console("Download complete", color="green", emoji="heavy_check_mark") 46 | return (True, file_name) 47 | -------------------------------------------------------------------------------- /docly/config/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import configparser 3 | 4 | 5 | class BadConfigError(Exception): 6 | pass 7 | 8 | 9 | class DoclyConfig(object): 10 | 11 | def __init__(self, config_file): 12 | if Path(config_file).exists() and Path(config_file).is_file(): 13 | self.config = configparser.ConfigParser() 14 | self.config.read(config_file) 15 | if "skipDirs" not in self.config: 16 | raise BadConfigError("\n\n=====>\nYou have mentioned a config file but it is badly formatted") 17 | else: 18 | self.config = None 19 | 20 | def _parent_contains_child(self, parent_path: str, child_path: str): 21 | child_path_rev = list(reversed(child_path.split("/"))) 22 | parent_path_list = list(reversed(parent_path)) 23 | parent_path_to_compare = parent_path_list[:len(child_path_rev)] 24 | return child_path_rev == parent_path_to_compare 25 | 26 | def is_dir_skipped(self, file_path_until_last_parent: str): 27 | if not self.config: 28 | return False 29 | for p in self.config["skipDirs"].keys(): 30 | if self._parent_contains_child(file_path_until_last_parent, p): 31 | return True 32 | return False -------------------------------------------------------------------------------- /docly/ioutils/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import requests 4 | import shutil 5 | import sys 6 | from distutils.version import LooseVersion 7 | import time 8 | 9 | from tqdm import tqdm 10 | 11 | from docly.parser import parser as py_parser 12 | from docly.tokenizers import tokenize_code_string 13 | from docly import __version__ 14 | 15 | # from c2nl.objects import Code 16 | UPDATE_CHECK_URL = "http://3.80.2.138:8584/vercheck/check-version/" 17 | # UPDATE_CHECK_URL = "http://127.0.0.1:5000/vercheck/check-version/" 18 | 19 | interaction_cache = lambda : Path(Path.home() / ".docly" / "interaction_cache") 20 | 21 | CACHE_DIR = (Path().home() / ".docly" / "file_cache") 22 | 23 | cache_exists = lambda : CACHE_DIR.exists() 24 | make_cache_dir = lambda : os.mkdir(str(CACHE_DIR)) 25 | 26 | 27 | def _compare_installed_version_with_latest(v1, v2): 28 | try: 29 | current_version = LooseVersion(v1) 30 | latest_version = LooseVersion(v2) 31 | assert current_version == latest_version 32 | return True 33 | except AssertionError: 34 | return False 35 | 36 | 37 | def look_for_update(): 38 | with requests.sessions.Session() as s: 39 | try: 40 | r = s.get(UPDATE_CHECK_URL, timeout=2) 41 | r.raise_for_status() 42 | if not _compare_installed_version_with_latest(__version__, r.text): 43 | i_c = interaction_cache() 44 | return True 45 | return False 46 | except Exception: 47 | i_c = interaction_cache() 48 | if not i_c.exists(): 49 | os.mkdir(i_c) 50 | if not (i_c / "icache.txt").exists(): 51 | with open((i_c / "icache.txt"), "w") as f: 52 | f.write(str(int(time.time())) + "\n") 53 | else: 54 | with open((i_c / "icache.txt"), "a") as f: 55 | f.write(str(int(time.time())) + "\n") 56 | return False 57 | 58 | 59 | def is_dir(base_path): 60 | if isinstance(base_path, Path): 61 | return base_path.is_dir() 62 | elif isinstance(base_path, str): 63 | 
--------------------------------------------------------------------------------
/docly/ioutils/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import requests
4 | import shutil
5 | import sys
6 | from distutils.version import LooseVersion
7 | import time
8 |
9 | from tqdm import tqdm
10 |
11 | from docly.parser import parser as py_parser
12 | from docly.tokenizers import tokenize_code_string
13 | from docly import __version__
14 |
15 | # from c2nl.objects import Code
16 | UPDATE_CHECK_URL = "http://3.80.2.138:8584/vercheck/check-version/"
17 | # UPDATE_CHECK_URL = "http://127.0.0.1:5000/vercheck/check-version/"
18 |
19 | interaction_cache = lambda : Path(Path.home() / ".docly" / "interaction_cache")
20 |
21 | CACHE_DIR = (Path().home() / ".docly" / "file_cache")
22 |
23 | cache_exists = lambda : CACHE_DIR.exists()
24 | make_cache_dir = lambda : os.mkdir(str(CACHE_DIR))
25 |
26 |
27 | def _compare_installed_version_with_latest(v1, v2):
28 |     try:
29 |         current_version = LooseVersion(v1)
30 |         latest_version = LooseVersion(v2)
31 |         assert current_version == latest_version
32 |         return True
33 |     except AssertionError:
34 |         return False
35 |
36 |
37 | def look_for_update():
38 |     with requests.sessions.Session() as s:
39 |         try:
40 |             r = s.get(UPDATE_CHECK_URL, timeout=2)
41 |             r.raise_for_status()
42 |             if not _compare_installed_version_with_latest(__version__, r.text):
43 |                 i_c = interaction_cache()
44 |                 return True
45 |             return False
46 |         except Exception:
47 |             i_c = interaction_cache()
48 |             if not i_c.exists():
49 |                 os.mkdir(i_c)
50 |             if not (i_c / "icache.txt").exists():
51 |                 with open((i_c / "icache.txt"), "w") as f:
52 |                     f.write(str(int(time.time())) + "\n")
53 |             else:
54 |                 with open((i_c / "icache.txt"), "a") as f:
55 |                     f.write(str(int(time.time())) + "\n")
56 |             return False
57 |
58 |
59 | def is_dir(base_path):
60 |     if isinstance(base_path, Path):
61 |         return base_path.is_dir()
62 |     elif isinstance(base_path, str):
63 |         return Path(base_path).is_dir()
64 |     else:
65 |         return False
66 |
67 |
68 | def is_python_file(file_path):
69 |     if isinstance(file_path, Path):
70 |         return file_path.suffix == ".py"
71 |     elif isinstance(file_path, str):
72 |         return Path(file_path).suffix == ".py"
73 |     else:
74 |         return False
75 |
76 |
77 | def is_ipynb_notebook(file_path):
78 |     if isinstance(file_path, Path):
79 |         return file_path.suffix == ".ipynb"
80 |     elif isinstance(file_path, str):
81 |         return Path(file_path).suffix == ".ipynb"
82 |     else:
83 |         return False
84 |
85 |
86 | def download_from_url(url, dst):
87 |     """
88 |     @param: url to download file
89 |     @param: dst place to put the file
90 |     """
91 |     file_size = int(requests.head(url).headers["Content-Length"])
92 |     if os.path.exists(dst):
93 |         first_byte = os.path.getsize(dst)
94 |     else:
95 |         first_byte = 0
96 |     if first_byte >= file_size:
97 |         return file_size
98 |     header = {"Range": "bytes=%s-%s" % (first_byte, file_size)}
99 |     pbar = tqdm(
100 |         total=file_size, initial=first_byte,
101 |         unit='B', unit_scale=True, desc=dst.split('/')[-1])
102 |     req = requests.get(url, headers=header, stream=True)
103 |     with(open(dst, 'ab')) as f:
104 |         for chunk in req.iter_content(chunk_size=1024):
105 |             if chunk:
106 |                 f.write(chunk)
107 |                 pbar.update(1024)
108 |     pbar.close()
109 |     return file_size
110 |
111 |
112 | def check_out_path(target_path: Path):
113 |     """
114 |     This function recursively yields all contents of a pathlib.Path object
115 |     """
116 |     yield target_path
117 |     for file in target_path.iterdir():
118 |         if file.is_dir():
119 |             yield from check_out_path(file)
120 |         else:
121 |             yield file.absolute()
122 |
123 |
124 | def process_file(file_path: Path, ts_lib_path: str, use_old=False):
125 |     result, parser_obj = py_parser.parse(file_path, ts_lib_path)
126 |     func_and_params = parser_obj.get_all_function_names_with_params()
127 |     if result:
128 |         for func_name, data in py_parser.get_func_body_and_docstr(parser_obj):
129 |             # print(py_tokenizer.tokenize_code_string(func_body))
130 |             # code.tokens = tokenizer.tokenize(func_body).data
131 |             # code.text = func_body
132 |             (func_body, docstr), start, end = data
133 |             ret_start = (start[0]+1, start[1])
134 |             params = func_and_params[func_name]
135 |
136 |             code_str = [tokenize_code_string(func_body)] if use_old else func_body
137 |
138 |             yield code_str, params, ret_start, func_name, docstr.strip()
139 |
140 |
141 | def query_yes_no(question, default="yes"):
142 |     """Ask a yes/no question and return their answer.
143 |
144 |     "question" is a string that is presented to the user.
145 |     "default" is the presumed answer if the user just hits Enter.
146 |     It must be "yes", "no", or None (meaning
147 |     an answer is required of the user).
148 |
149 |     The "answer" return value is True for "yes" or False for "no".
150 | """ 151 | valid = {"yes": True, "y": True, "ye": True, 152 | "no": False, "n": False} 153 | if default is None: 154 | prompt = " [y/n] " 155 | elif default == "yes": 156 | prompt = " [Y/n] " 157 | elif default == "no": 158 | prompt = " [y/N] " 159 | else: 160 | raise ValueError("invalid default answer: '{}}'".format(default)) 161 | 162 | while True: 163 | print(question + prompt) 164 | choice = input().lower() 165 | if default is not None and choice == '': 166 | return valid[default] 167 | elif choice in valid: 168 | return valid[choice] 169 | else: 170 | print("Please respond with 'yes' or 'no' " 171 | "(or 'y' or 'n').\n") -------------------------------------------------------------------------------- /docly/ioutils/apply_diff.py: -------------------------------------------------------------------------------- 1 | ######################### 2 | 3 | # This entire file should be re-written 4 | # using ast or tree-hugger tree traversal 5 | # The way it is now, can cause problems 6 | # in some corner cases. 7 | 8 | ####################### 9 | 10 | import os 11 | import shutil 12 | from pathlib import Path 13 | from typing import Dict 14 | # from tabnanny import check 15 | 16 | from docly.ioutils import CACHE_DIR, cache_exists, make_cache_dir 17 | from docly.ioutils.convert_ipynb import convert_python_to_ipynb 18 | 19 | 20 | def _generate_main_docstring(docstring, spaces, style, will_follow_param: bool): 21 | if not will_follow_param: 22 | if style == "google": 23 | return f'{spaces}"""\n{spaces}{docstring}\n\n{spaces}(Generated by docly)\n{spaces}"""\n' 24 | elif style == "sphinx": 25 | return f'{spaces}"""{docstring}\n\n{spaces}(Generated by docly)\n{spaces}"""\n' 26 | elif style == "numpy": 27 | return f"{spaces}'''\n{spaces}{docstring}\n\n{spaces}(Generated by docly)\n{spaces}'''\n" 28 | else: 29 | if style == "google": 30 | return f'{spaces}"""\n{spaces}{docstring}' 31 | elif style == "sphinx": 32 | return f'{spaces}"""{docstring}' 33 | elif style == "numpy": 34 | return f"{spaces}'''\n{spaces}{docstring}" 35 | 36 | 37 | def _generate_param_list(params, style, spaces): 38 | line_to_write = "" 39 | if style == "google": 40 | line_to_write = line_to_write + f"\n\n{spaces}Args:\n" 41 | for (param_name, param_type, default_val) in params: 42 | param_desc = f"{param_name}" 43 | if param_type: 44 | param_desc = param_desc + f" ({param_type})" 45 | if default_val: 46 | param_desc = param_desc + f" : Defaults to {default_val}" 47 | if param_desc.find(":") == -1: 48 | param_desc = param_desc + " :" 49 | if len(spaces) == 4: 50 | line_to_write = line_to_write + f"{spaces}{spaces}{param_desc}\n" 51 | if len(spaces) > 4: 52 | line_to_write = line_to_write + f"{spaces} {param_desc}\n" 53 | elif style == "numpy": 54 | line_to_write = line_to_write + f"\n\n{spaces}Parameters\n{spaces}----------" 55 | for (param_name, param_type, default_val) in params: 56 | param_desc = f"\n{spaces}{param_name} :" 57 | if param_type: 58 | param_desc = param_desc + f" {param_type}" 59 | line_to_write = line_to_write + f"{param_desc}" 60 | elif style == "sphinx": 61 | line_to_write = "\n" 62 | for (param_name, param_type, default_val) in params: 63 | param_desc = f"\n{spaces}:param {param_name}:" 64 | if default_val: 65 | param_desc = param_desc + f"\n{spaces} defaults to {default_val}" 66 | if param_type: 67 | param_desc = param_desc + f"\n{spaces}:type {param_name}: {param_type}" 68 | line_to_write = line_to_write + f"{param_desc}" 69 | return line_to_write 70 | 71 | 72 | def _get_line_to_write(docstrs, line_num, 
73 |     docstr_line = docstrs.get(line_num+1)[1]
74 |     num_spaces = int(docstrs.get(line_num+1)[0])
75 |     params = docstrs.get(line_num+1)[2]
76 |
77 |     line_to_write = None
78 |
79 |     spaces = " ".join([''] * (num_spaces + 1))
80 |     if not params:
81 |         line_to_write = _generate_main_docstring(docstr_line, spaces, docstring_style, False)
82 |     else:
83 |         line_to_write = _generate_main_docstring(docstr_line, spaces, docstring_style, True)
84 |         if should_write_args_list:
85 |             line_to_write = line_to_write + _generate_param_list(params, docstring_style, spaces)
86 |         if docstring_style == "google" or docstring_style == "sphinx":
87 |             line_to_write = line_to_write + f'\n\n{spaces}(Generated by docly)\n{spaces}"""\n'
88 |         elif docstring_style == "numpy":
89 |             line_to_write = line_to_write + f"\n\n{spaces}(Generated by docly)\n{spaces}'''\n"
90 |     return line_to_write
91 |
92 |
93 | def apply_diff(docstr_loc: Dict[str, Dict[int, tuple]], should_write_args_list: bool, ipynb_files: Dict, docstring_style: str):
94 |     try:
95 |         for file_loc, docstrs in docstr_loc.items():
96 |             # l = check(file_loc)
97 |             temp_file_name = f"{str(Path(file_loc).stem)}.pytemp"
98 |             final_file_name = f"{str(Path(file_loc).stem)}.py"
99 |             temp_file = (Path(file_loc).parent / temp_file_name)
100 |             final_file = (Path(file_loc).parent / final_file_name)
101 |             write_handle = open(temp_file, "w")
102 |
103 |             with open(file_loc) as f:
104 |                 for line_num, line in enumerate(f):
105 |                     if docstrs.get(line_num+1):
106 |                         line_to_write = _get_line_to_write(docstrs, line_num, should_write_args_list, docstring_style)
107 |                         write_handle.write(line_to_write)
108 |                         write_handle.write(line)
109 |                     else:
110 |                         write_handle.write(line)
111 |
112 |             write_handle.close()
113 |             if len(ipynb_files) == 0:
114 |                 cache_file_name = file_loc[1:].replace("/", "#")
115 |
116 |                 if not cache_exists():
117 |                     make_cache_dir()
118 |
119 |                 if (CACHE_DIR / cache_file_name).exists():
120 |                     (CACHE_DIR / cache_file_name).unlink()
121 |
122 |                 shutil.move(file_loc, str(CACHE_DIR / cache_file_name))
123 |                 shutil.move(str(temp_file), str(final_file))
124 |             elif len(ipynb_files) > 0 and file_loc not in ipynb_files.keys():
125 |                 cache_file_name = file_loc[1:].replace("/", "#")
126 |
127 |                 if not cache_exists():
128 |                     make_cache_dir()
129 |
130 |                 if (CACHE_DIR / cache_file_name).exists():
131 |                     (CACHE_DIR / cache_file_name).unlink()
132 |
133 |                 shutil.move(file_loc, str(CACHE_DIR / cache_file_name))
134 |                 shutil.move(str(temp_file), str(final_file))
135 |             else:
136 |                 if Path(final_file).exists():
137 |                     Path(final_file).unlink()
138 |
139 |                 shutil.move(str(temp_file), str(final_file))
140 |                 convert_python_to_ipynb(Path(final_file))
141 |
142 |                 Path(final_file).unlink()
143 |     except KeyboardInterrupt:
144 |         temp_file.unlink()
145 |
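
To make the string templates above concrete: for the default `google` style and a function with no parameters, the generated insert looks like this (traced by hand from `_generate_main_docstring`; the function and docstring text are hypothetical):

```python
def add(a, b):
    """
    Add two numbers .

    (Generated by docly)
    """
    return a + b
```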
--------------------------------------------------------------------------------
/docly/ioutils/console_printer.py:
--------------------------------------------------------------------------------
1 | from rich import print as rprint
2 |
3 |
4 | def print_on_console(text, color="green", emoji=None):
5 |     if not emoji:
6 |         rprint(f"[{color}]{text}[/{color}]")
7 |     else:
8 |         rprint(f"[{color}]{text}[/{color}]", f":{emoji}:")

--------------------------------------------------------------------------------
/docly/ioutils/convert_ipynb.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import shutil
3 | import logging
4 |
5 | from invoke import run
6 | from docly.ioutils import CACHE_DIR, cache_exists, make_cache_dir
7 |
8 | try:
9 |     import jupytext
10 |     JUPYTEXT_AVAILABLE = True
11 | except Exception:
12 |     JUPYTEXT_AVAILABLE = False
13 |
14 |
15 | IPYNB_TO_PY_CMD = "jupytext --set-formats ipynb,py:percent"
16 | PY_TO_IPYNB_CMD = "jupytext --update --to ipynb"
17 |
18 |
19 | def convert_ipynb_to_python(notebook_path: Path):
20 |     if JUPYTEXT_AVAILABLE:
21 |         try:
22 |             jupytext.read(notebook_path.absolute())
23 |         except Exception:
24 |             return None
25 |
26 |     actual_file_path = str(notebook_path.absolute())
27 |     cache_file_name = actual_file_path[1:].replace("/", "#")
28 |
29 |     if not cache_exists():
30 |         make_cache_dir()
31 |
32 |     if (CACHE_DIR / cache_file_name).exists():
33 |         (CACHE_DIR / cache_file_name).unlink()
34 |
35 |     shutil.copy(actual_file_path, str(CACHE_DIR / cache_file_name))
36 |
37 |     result = run(f"{IPYNB_TO_PY_CMD} {str(notebook_path.absolute())}", hide=True, warn=True)
38 |
39 |     if not result.ok:
40 |         logging.error("Could not run the conversion command. Are you using an old version of Jupyter notebook? Otherwise, maybe use `pip install 'docly[jupyter]'`")
41 |         return None
42 |     else:
43 |         return (notebook_path.absolute().parent / (notebook_path.stem + '.py')).absolute()
44 |
45 |
46 | def convert_python_to_ipynb(python_file_path: Path):
47 |     result = run(f"{PY_TO_IPYNB_CMD} {str(python_file_path.absolute())}", hide=True, warn=True)
48 |     if not result.ok:
49 |         logging.error("Could not run the conversion command. Maybe use `pip install 'docly[jupyter]'`")
50 |     else:
51 |         return (python_file_path.absolute().parent / (python_file_path.stem + '.ipynb')).absolute()

--------------------------------------------------------------------------------
/docly/ioutils/table_printer.py:
--------------------------------------------------------------------------------
1 | from rich.table import Table
2 | from rich.console import Console
3 |
4 |
5 | def print_results_as_table(rows):
6 |     table = Table(title="Functions and Docstrings")
7 |
8 |     table.add_column("File Name", justify="left")
9 |     table.add_column("Function Name", justify="left")
10 |     table.add_column("Docstring", justify="left")
11 |
12 |     for row in rows:
13 |         table.add_row(*row)
14 |
15 |     console = Console()
16 |     console.print(table)
17 |

--------------------------------------------------------------------------------
/docly/logic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autosoft-dev/docly/0bd6216b8a9735e9fa76bffd4ffea6cec6cc4a01/docly/logic/__init__.py

--------------------------------------------------------------------------------
/docly/logic/example.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 |
4 | class Example(object):
5 |     """A single training/test example."""
6 |     def __init__(self,
7 |                  idx,
8 |                  source,
9 |                  target,
10 |                  ):
11 |         self.idx = idx
12 |         self.source = source
13 |         self.target = target
14 |
15 |
16 | def make_example(code_tokens):
17 |     idx = random.randint(1, 100000)
18 |     examples = []
19 |     for ct in code_tokens:
20 |         ex = Example(idx=idx,
21 |                      source=" ".join(ct),
22 |                      target="")
23 |         examples.append(ex)
24 |     return examples
25 |
26 |
27 | class NewExample(object):
28 |
29 |     def __init__(self,
30 |                  source,
31 |                  target
32 |                  ):
33 |         self.source = source
34 |         self.target = target
35 |
36 |
37 | def make_new_example(code):
38 |     return [NewExample(source=code, target=None)]
39 |
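
A quick sketch of the two input formats defined above (the token list and code string are hypothetical):

```python
from docly.logic.example import make_example, make_new_example

old = make_example([["def", "add", "(", "a", ",", "b", ")", ":"]])
print(old[0].source)   # "def add ( a , b ) :" -- tokens joined for the old model

new = make_new_example("def add(a, b): return a + b")
print(new[0].source)   # raw source string; the new model's tokenizer handles it
```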
--------------------------------------------------------------------------------
/docly/logic/input_features.py:
--------------------------------------------------------------------------------
1 | class InputFeatures(object):
2 |     """A single training/test features for an example."""
3 |     def __init__(self,
4 |                  example_id,
5 |                  source_ids,
6 |                  target_ids,
7 |                  source_mask,
8 |                  target_mask,
9 |
10 |                  ):
11 |         self.example_id = example_id
12 |         self.source_ids = source_ids
13 |         self.target_ids = target_ids
14 |         self.source_mask = source_mask
15 |         self.target_mask = target_mask
16 |
17 |
18 | def convert_examples_to_features(examples,
19 |                                  tokenizer,
20 |                                  max_source_length=256,
21 |                                  max_target_length=128):
22 |     features = []
23 |     for example_index, example in enumerate(examples):
24 |         source_tokens = tokenizer.tokenize(example.source)[:max_source_length-2]
25 |         source_tokens = [tokenizer.cls_token] + source_tokens + [tokenizer.sep_token]
26 |         source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
27 |         source_mask = [1] * (len(source_tokens))
28 |         padding_length = max_source_length - len(source_ids)
29 |         source_ids += [tokenizer.pad_token_id] * padding_length
30 |         source_mask += [0] * padding_length
31 |
32 |         target_tokens = tokenizer.tokenize("None")
33 |         target_tokens = [tokenizer.cls_token] + target_tokens + [tokenizer.sep_token]
34 |         target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
35 |         target_mask = [1] * len(target_ids)
36 |         padding_length = max_target_length - len(target_ids)
37 |         target_ids += [tokenizer.pad_token_id] * padding_length
38 |         target_mask += [0] * padding_length
39 |
40 |         features.append(
41 |             InputFeatures(
42 |                 example_index,
43 |                 source_ids,
44 |                 target_ids,
45 |                 source_mask,
46 |                 target_mask,
47 |             )
48 |         )
49 |     return features
50 |
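
The truncation and padding above are easiest to check by shape; a minimal sketch using the CodeBERT tokenizer that `logic_main` loads (the sample function is hypothetical):

```python
from transformers import RobertaTokenizer
from docly.logic.example import make_new_example
from docly.logic.input_features import convert_examples_to_features

tok = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
feats = convert_examples_to_features(make_new_example("def add(a, b): return a + b"), tok)
assert len(feats[0].source_ids) == 256   # always padded to max_source_length
assert len(feats[0].target_ids) == 128   # always padded to max_target_length
# the masks are 1 over <s> ... </s> and 0 over the padding
```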
--------------------------------------------------------------------------------
/docly/logic/logic_main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from io import open
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 |
9 | from .example import make_example, make_new_example
10 | from .input_features import convert_examples_to_features
11 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, TensorDataset
12 |
13 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
14 |                           RobertaConfig, RobertaModel, RobertaTokenizer)
15 |
16 |
17 | MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
18 |
19 | model_name_or_path = "microsoft/codebert-base"
20 | beam_size = 10
21 | max_target_length = 128
22 | max_source_length = 256
23 | seed = 42
24 |
25 |
26 | def load_model(model_path, is_old=False):
27 |     if is_old:
28 |         from .model import Seq2Seq
29 |     else:
30 |         from .model_new import Seq2Seq
31 |
32 |     config_class, model_class, tokenizer_class = MODEL_CLASSES['roberta']
33 |
34 |     config = config_class.from_pretrained(model_name_or_path)
35 |     if is_old:
36 |         tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
37 |     else:
38 |         tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=False)
39 |
40 |     encoder = model_class.from_pretrained(model_name_or_path, config=config)
41 |     decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size,
42 |                                                nhead=config.num_attention_heads)
43 |     decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
44 |     model = Seq2Seq(encoder=encoder,
45 |                     decoder=decoder,
46 |                     config=config,
47 |                     beam_size=beam_size,
48 |                     max_length=max_target_length,
49 |                     sos_id=tokenizer.cls_token_id,
50 |                     eos_id=tokenizer.sep_token_id
51 |                     )
52 |     if is_old:
53 |         if not torch.cuda.is_available():
54 |             model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
55 |         else:
56 |             model.load_state_dict(torch.load(model_path))
57 |     else:
58 |         if not torch.cuda.is_available():
59 |             model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
60 |         else:
61 |             model.load_state_dict(torch.load(model_path), strict=False)
62 |     if not torch.cuda.is_available():
63 |         model.to("cpu")
64 |     model.eval()
65 |
66 |     return model, tokenizer
67 |
68 |
69 | def predict_docstring(model, tokenizer, code_tokens, is_old):
70 |     examples = make_example(code_tokens) if is_old else make_new_example(code_tokens)
71 |
72 |     features = convert_examples_to_features(examples, tokenizer)
73 |     if is_old:
74 |         all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
75 |         all_source_mask = torch.tensor([f.source_mask for f in features], dtype=torch.long)
76 |     else:
77 |         all_source_ids = torch.tensor([f.source_ids[: max_source_length] for f in features], dtype=torch.long)
78 |         all_source_mask = torch.tensor([f.source_mask[: max_source_length] for f in features], dtype=torch.long)
79 |
80 |     eval_data = TensorDataset(all_source_ids, all_source_mask)
81 |
82 |     eval_sampler = SequentialSampler(eval_data)
83 |     batch_size = len(code_tokens) if is_old else len(eval_data)
84 |
85 |     eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=batch_size)
86 |
87 |     p = []
88 |     for batch in eval_dataloader:
89 |         if not torch.cuda.is_available():
90 |             batch = tuple(t.to('cpu') for t in batch)
91 |         else:
92 |             batch = tuple(t for t in batch)
93 |         source_ids, source_mask = batch
94 |
95 |         with torch.no_grad():
96 |             preds = model(source_ids=source_ids, source_mask=source_mask)
97 |
98 |         for pred in preds:
99 |             t = pred[0].cpu().numpy()
100 |             t = list(t)
101 |             if 0 in t:
102 |                 t = t[:t.index(0)]
103 |             text = tokenizer.decode(t, clean_up_tokenization_spaces=False)
104 |             p.append(text)
105 |
106 |     px = p[0].split()
107 |     if px[-1] == ".":
108 |         px[-2] = px[-2].strip() + "."
109 |         px.pop()
110 |
111 |     return [" ".join(px)]
112 |
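
End to end, the two functions above are all the CLI really calls; a minimal sketch (the model path is hypothetical — `docly-gen` normally downloads the weights into `~/.docly/model/` first):

```python
from pathlib import Path
from docly.logic.logic_main import load_model, predict_docstring

model_path = Path.home() / ".docly" / "model" / "pytorch_model_new.bin"
model, tokenizer = load_model(str(model_path), is_old=False)

docstr = predict_docstring(model, tokenizer, "def add(a, b): return a + b", is_old=False)
print(docstr[0])  # e.g. "Add two numbers ."
```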
24 | """ 25 | def __init__(self, encoder,decoder,config,beam_size=None,max_length=None,sos_id=None,eos_id=None): 26 | super(Seq2Seq, self).__init__() 27 | self.encoder = encoder 28 | self.decoder=decoder 29 | self.config=config 30 | self.register_buffer("bias", torch.tril(torch.ones(2048, 2048))) 31 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 32 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 33 | self.lsm = nn.LogSoftmax(dim=-1) 34 | self.tie_weights() 35 | 36 | self.beam_size=beam_size 37 | self.max_length=max_length 38 | self.sos_id=sos_id 39 | self.eos_id=eos_id 40 | 41 | def _tie_or_clone_weights(self, first_module, second_module): 42 | """ Tie or clone module weights depending of weither we are using TorchScript or not 43 | """ 44 | if self.config.torchscript: 45 | first_module.weight = nn.Parameter(second_module.weight.clone()) 46 | else: 47 | first_module.weight = second_module.weight 48 | 49 | def tie_weights(self): 50 | """ Make sure we are sharing the input and output embeddings. 51 | Export to TorchScript can't handle parameter sharing so we are cloning them instead. 52 | """ 53 | self._tie_or_clone_weights(self.lm_head, 54 | self.encoder.embeddings.word_embeddings) 55 | 56 | def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=None,args=None): 57 | outputs = self.encoder(source_ids, attention_mask=source_mask) 58 | encoder_output = outputs[0].permute([1,0,2]).contiguous() 59 | if target_ids is not None: 60 | attn_mask=-1e4 *(1-self.bias[:target_ids.shape[1],:target_ids.shape[1]]) 61 | tgt_embeddings = self.encoder.embeddings(target_ids).permute([1,0,2]).contiguous() 62 | out = self.decoder(tgt_embeddings,encoder_output,tgt_mask=attn_mask,memory_key_padding_mask=(1-source_mask).bool()) 63 | hidden_states = torch.tanh(self.dense(out)).permute([1,0,2]).contiguous() 64 | lm_logits = self.lm_head(hidden_states) 65 | # Shift so that tokens < n predict n 66 | active_loss = target_mask[..., 1:].ne(0).view(-1) == 1 67 | shift_logits = lm_logits[..., :-1, :].contiguous() 68 | shift_labels = target_ids[..., 1:].contiguous() 69 | # Flatten the tokens 70 | loss_fct = nn.CrossEntropyLoss(ignore_index=-1) 71 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[active_loss], 72 | shift_labels.view(-1)[active_loss]) 73 | 74 | outputs = loss,loss*active_loss.sum(),active_loss.sum() 75 | return outputs 76 | else: 77 | #Predict 78 | preds=[] 79 | zero=torch.LongTensor(1).fill_(0) 80 | for i in range(source_ids.shape[0]): 81 | context=encoder_output[:,i:i+1] 82 | context_mask=source_mask[i:i+1,:] 83 | beam = Beam(self.beam_size,self.sos_id,self.eos_id) 84 | input_ids=beam.getCurrentState() 85 | context=context.repeat(1, self.beam_size,1) 86 | context_mask=context_mask.repeat(self.beam_size,1) 87 | for _ in range(self.max_length): 88 | if beam.done(): 89 | break 90 | attn_mask=-1e4 *(1-self.bias[:input_ids.shape[1],:input_ids.shape[1]]) 91 | tgt_embeddings = self.encoder.embeddings(input_ids).permute([1,0,2]).contiguous() 92 | out = self.decoder(tgt_embeddings,context,tgt_mask=attn_mask,memory_key_padding_mask=(1-context_mask).bool()) 93 | out = torch.tanh(self.dense(out)) 94 | hidden_states=out.permute([1,0,2]).contiguous()[:,-1,:] 95 | out = self.lsm(self.lm_head(hidden_states)).data 96 | beam.advance(out) 97 | input_ids.data.copy_(input_ids.data.index_select(0, beam.getCurrentOrigin())) 98 | input_ids=torch.cat((input_ids,beam.getCurrentState()),-1) 99 | hyp= beam.getHyp(beam.getFinal()) 100 | 
101 |                 pred = [torch.cat([x.view(-1) for x in p] + [zero] * (self.max_length - len(p))).view(1, -1) for p in pred]
102 |                 preds.append(torch.cat(pred, 0).unsqueeze(0))
103 |
104 |             preds = torch.cat(preds, 0)
105 |             return preds
106 |
107 |
108 | class Beam(object):
109 |     def __init__(self, size, sos, eos):
110 |         self.size = size
111 |         self.tt = torch
112 |         # The score for each translation on the beam.
113 |         self.scores = self.tt.FloatTensor(size).zero_()
114 |         # The backpointers at each time-step.
115 |         self.prevKs = []
116 |         # The outputs at each time-step.
117 |         self.nextYs = [self.tt.LongTensor(size)
118 |                        .fill_(0)]
119 |         self.nextYs[0][0] = sos
120 |         # Has EOS topped the beam yet.
121 |         self._eos = eos
122 |         self.eosTop = False
123 |         # Time and k pair for finished.
124 |         self.finished = []
125 |
126 |     def getCurrentState(self):
127 |         "Get the outputs for the current timestep."
128 |         batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)
129 |         return batch
130 |
131 |     def getCurrentOrigin(self):
132 |         "Get the backpointers for the current timestep."
133 |         return self.prevKs[-1]
134 |
135 |     def advance(self, wordLk):
136 |         """
137 |         Given prob over words for every last beam `wordLk` and attention
138 |         `attnOut`: Compute and update the beam search.
139 |
140 |         Parameters:
141 |
142 |         * `wordLk`- probs of advancing from the last step (K x words)
143 |         * `attnOut`- attention at the last step
144 |
145 |         Returns: True if beam search is complete.
146 |         """
147 |         numWords = wordLk.size(1)
148 |
149 |         # Sum the previous scores.
150 |         if len(self.prevKs) > 0:
151 |             beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
152 |
153 |             # Don't let EOS have children.
154 |             for i in range(self.nextYs[-1].size(0)):
155 |                 if self.nextYs[-1][i] == self._eos:
156 |                     beamLk[i] = -1e20
157 |         else:
158 |             beamLk = wordLk[0]
159 |         flatBeamLk = beamLk.view(-1)
160 |         bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
161 |
162 |         self.scores = bestScores
163 |
164 |         # bestScoresId is flattened beam x word array, so calculate which
165 |         # word and beam each score came from
166 |         prevK = bestScoresId // numWords
167 |         self.prevKs.append(prevK)
168 |         self.nextYs.append((bestScoresId - prevK * numWords))
169 |
170 |
171 |         for i in range(self.nextYs[-1].size(0)):
172 |             if self.nextYs[-1][i] == self._eos:
173 |                 s = self.scores[i]
174 |                 self.finished.append((s, len(self.nextYs) - 1, i))
175 |
176 |         # End condition is when top-of-beam is EOS and no global score.
177 |         if self.nextYs[-1][0] == self._eos:
178 |             self.eosTop = True
179 |
180 |     def done(self):
181 |         return self.eosTop and len(self.finished) >= self.size
182 |
183 |     def getFinal(self):
184 |         if len(self.finished) == 0:
185 |             self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
186 |         self.finished.sort(key=lambda a: -a[0])
187 |         if len(self.finished) != self.size:
188 |             unfinished = []
189 |             for i in range(self.nextYs[-1].size(0)):
190 |                 if self.nextYs[-1][i] != self._eos:
191 |                     s = self.scores[i]
192 |                     unfinished.append((s, len(self.nextYs) - 1, i))
193 |             unfinished.sort(key=lambda a: -a[0])
194 |             self.finished += unfinished[:self.size - len(self.finished)]
195 |         return self.finished[:self.size]
196 |
197 |     def getHyp(self, beam_res):
198 |         """
199 |         Walk back to construct the full hypothesis.
200 | """ 201 | hyps=[] 202 | for _,timestep, k in beam_res: 203 | hyp = [] 204 | for j in range(len(self.prevKs[:timestep]) - 1, -1, -1): 205 | hyp.append(self.nextYs[j+1][k]) 206 | k = self.prevKs[j][k] 207 | hyps.append(hyp[::-1]) 208 | return hyps 209 | 210 | def buildTargetTokens(self, preds): 211 | sentence=[] 212 | for pred in preds: 213 | tokens = [] 214 | for tok in pred: 215 | if tok==self._eos: 216 | break 217 | tokens.append(tok) 218 | sentence.append(tokens) 219 | return sentence 220 | -------------------------------------------------------------------------------- /docly/logic/model_new.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class Seq2Seq(nn.Module): 9 | """ 10 | Build Seqence-to-Sequence. 11 | 12 | Parameters: 13 | 14 | * `encoder`- encoder of seq2seq model. e.g. roberta 15 | * `decoder`- decoder of seq2seq model. e.g. transformer 16 | * `config`- configuration of encoder model. 17 | * `beam_size`- beam size for beam search. 18 | * `max_length`- max length of target for beam search. 19 | * `sos_id`- start of symbol ids in target for beam search. 20 | * `eos_id`- end of symbol ids in target for beam search. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | encoder, 26 | decoder, 27 | config, 28 | beam_size=None, 29 | max_length=None, 30 | sos_id=None, 31 | eos_id=None, 32 | ): 33 | super(Seq2Seq, self).__init__() 34 | self.encoder = encoder 35 | self.decoder = decoder 36 | self.config = config 37 | self.register_buffer("bias", torch.tril(torch.ones(2048, 2048))) 38 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 39 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 40 | self.lsm = nn.LogSoftmax(dim=-1) 41 | self.tie_weights() 42 | 43 | self.beam_size = beam_size 44 | self.max_length = max_length 45 | self.sos_id = sos_id 46 | self.eos_id = eos_id 47 | 48 | def _tie_or_clone_weights(self, first_module, second_module): 49 | """Tie or clone module weights depending of weither we are using TorchScript or not""" 50 | if self.config.torchscript: 51 | first_module.weight = nn.Parameter(second_module.weight.clone()) 52 | else: 53 | first_module.weight = second_module.weight 54 | 55 | def tie_weights(self): 56 | """Make sure we are sharing the input and output embeddings. 57 | Export to TorchScript can't handle parameter sharing so we are cloning them instead. 
58 | """ 59 | self._tie_or_clone_weights( 60 | self.lm_head, self.encoder.embeddings.word_embeddings 61 | ) 62 | 63 | def forward( 64 | self, 65 | source_ids=None, 66 | source_mask=None, 67 | target_ids=None, 68 | target_mask=None, 69 | args=None, 70 | ): 71 | outputs = self.encoder(source_ids, attention_mask=source_mask) 72 | encoder_output = outputs[0].permute([1, 0, 2]).contiguous() 73 | if target_ids is not None: 74 | attn_mask = -1e4 * ( 75 | 1 - self.bias[: target_ids.shape[1], : target_ids.shape[1]] 76 | ) 77 | tgt_embeddings = ( 78 | self.encoder.embeddings(target_ids).permute([1, 0, 2]).contiguous() 79 | ) 80 | out = self.decoder( 81 | tgt_embeddings, 82 | encoder_output, 83 | tgt_mask=attn_mask, 84 | memory_key_padding_mask=(1 - source_mask).bool(), 85 | ) 86 | hidden_states = torch.tanh(self.dense(out)).permute([1, 0, 2]).contiguous() 87 | lm_logits = self.lm_head(hidden_states) 88 | # Shift so that tokens < n predict n 89 | active_loss = target_mask[..., 1:].ne(0).view(-1) == 1 90 | shift_logits = lm_logits[..., :-1, :].contiguous() 91 | shift_labels = target_ids[..., 1:].contiguous() 92 | # Flatten the tokens 93 | loss_fct = nn.CrossEntropyLoss(ignore_index=-1) 94 | loss = loss_fct( 95 | shift_logits.view(-1, shift_logits.size(-1))[active_loss], 96 | shift_labels.view(-1)[active_loss], 97 | ) 98 | 99 | outputs = loss, loss * active_loss.sum(), active_loss.sum() 100 | return outputs 101 | else: 102 | # Predict 103 | preds = [] 104 | if source_ids.device.type == "cuda": 105 | zero = torch.cuda.LongTensor(1).fill_(0) 106 | elif source_ids.device.type == "cpu": 107 | zero = torch.LongTensor(1).fill_(0) 108 | for i in range(source_ids.shape[0]): 109 | context = encoder_output[:, i : i + 1] 110 | context_mask = source_mask[i : i + 1, :] 111 | beam = Beam( 112 | self.beam_size, 113 | self.sos_id, 114 | self.eos_id, 115 | device=source_ids.device.type, 116 | ) 117 | input_ids = beam.getCurrentState() 118 | context = context.repeat(1, self.beam_size, 1) 119 | context_mask = context_mask.repeat(self.beam_size, 1) 120 | for _ in range(self.max_length): 121 | if beam.done(): 122 | break 123 | attn_mask = -1e4 * ( 124 | 1 - self.bias[: input_ids.shape[1], : input_ids.shape[1]] 125 | ) 126 | tgt_embeddings = ( 127 | self.encoder.embeddings(input_ids) 128 | .permute([1, 0, 2]) 129 | .contiguous() 130 | ) 131 | out = self.decoder( 132 | tgt_embeddings, 133 | context, 134 | tgt_mask=attn_mask, 135 | memory_key_padding_mask=(1 - context_mask).bool(), 136 | ) 137 | out = torch.tanh(self.dense(out)) 138 | hidden_states = out.permute([1, 0, 2]).contiguous()[:, -1, :] 139 | out = self.lsm(self.lm_head(hidden_states)).data 140 | beam.advance(out) 141 | input_ids.data.copy_( 142 | input_ids.data.index_select(0, beam.getCurrentOrigin()) 143 | ) 144 | input_ids = torch.cat((input_ids, beam.getCurrentState()), -1) 145 | hyp = beam.getHyp(beam.getFinal()) 146 | pred = beam.buildTargetTokens(hyp)[: self.beam_size] 147 | pred = [ 148 | torch.cat( 149 | [x.view(-1) for x in p] + [zero] * (self.max_length - len(p)) 150 | ).view(1, -1) 151 | for p in pred 152 | ] 153 | preds.append(torch.cat(pred, 0).unsqueeze(0)) 154 | 155 | preds = torch.cat(preds, 0) 156 | return preds 157 | 158 | 159 | class Beam(object): 160 | def __init__(self, size, sos, eos, device): 161 | self.size = size 162 | if device == "cuda": 163 | self.tt = torch.cuda 164 | elif device == "cpu": 165 | self.tt = torch 166 | # The score for each translation on the beam. 
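# (Allocated through `self.tt`, which was resolved to `torch` or `torch.cuda` above, so all beam state lives on the requested device; every tensor below has one slot per live hypothesis.)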
167 | self.scores = self.tt.FloatTensor(size).zero_() 168 | # The backpointers at each time-step. 169 | self.prevKs = [] 170 | # The outputs at each time-step. 171 | self.nextYs = [self.tt.LongTensor(size).fill_(0)] 172 | self.nextYs[0][0] = sos 173 | # Has EOS topped the beam yet. 174 | self._eos = eos 175 | self.eosTop = False 176 | # Time and k pair for finished. 177 | self.finished = [] 178 | 179 | def getCurrentState(self): 180 | "Get the outputs for the current timestep." 181 | batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1) 182 | return batch 183 | 184 | def getCurrentOrigin(self): 185 | "Get the backpointers for the current timestep." 186 | return self.prevKs[-1] 187 | 188 | def advance(self, wordLk): 189 | """ 190 | Given the probabilities over words for every live beam `wordLk`, 191 | compute and update the beam search state. 192 | 193 | Parameters: 194 | 195 | * `wordLk`- probs of advancing from the last step (K x words) 196 | 197 | 198 | Returns: None. Use `done()` to check whether the search is complete. 199 | """ 200 | numWords = wordLk.size(1) 201 | 202 | # Sum the previous scores. 203 | if len(self.prevKs) > 0: 204 | beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk) 205 | 206 | # Don't let EOS have children. 207 | for i in range(self.nextYs[-1].size(0)): 208 | if self.nextYs[-1][i] == self._eos: 209 | beamLk[i] = -1e20 210 | else: 211 | beamLk = wordLk[0] 212 | flatBeamLk = beamLk.view(-1) 213 | bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True) 214 | 215 | self.scores = bestScores 216 | 217 | # bestScoresId is flattened beam x word array, so calculate which 218 | # word and beam each score came from 219 | prevK = bestScoresId // numWords 220 | self.prevKs.append(prevK) 221 | self.nextYs.append((bestScoresId - prevK * numWords)) 222 | 223 | for i in range(self.nextYs[-1].size(0)): 224 | if self.nextYs[-1][i] == self._eos: 225 | s = self.scores[i] 226 | self.finished.append((s, len(self.nextYs) - 1, i)) 227 | 228 | # End condition is when top-of-beam is EOS and no global score. 229 | if self.nextYs[-1][0] == self._eos: 230 | self.eosTop = True 231 | 232 | def done(self): 233 | return self.eosTop and len(self.finished) >= self.size 234 | 235 | def getFinal(self): 236 | if len(self.finished) == 0: 237 | self.finished.append((self.scores[0], len(self.nextYs) - 1, 0)) 238 | self.finished.sort(key=lambda a: -a[0]) 239 | if len(self.finished) != self.size: 240 | unfinished = [] 241 | for i in range(self.nextYs[-1].size(0)): 242 | if self.nextYs[-1][i] != self._eos: 243 | s = self.scores[i] 244 | unfinished.append((s, len(self.nextYs) - 1, i)) 245 | unfinished.sort(key=lambda a: -a[0]) 246 | self.finished += unfinished[: self.size - len(self.finished)] 247 | return self.finished[: self.size] 248 | 249 | def getHyp(self, beam_res): 250 | """ 251 | Walk back to construct the full hypothesis.
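For a finished entry (s, 3, k3), for example, this yields [nextYs[1][k1], nextYs[2][k2], nextYs[3][k3]] after reversal, where k2 = prevKs[2][k3] and k1 = prevKs[1][k2].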
252 | """ 253 | hyps = [] 254 | for _, timestep, k in beam_res: 255 | hyp = [] 256 | for j in range(len(self.prevKs[:timestep]) - 1, -1, -1): 257 | hyp.append(self.nextYs[j + 1][k]) 258 | k = self.prevKs[j][k] 259 | hyps.append(hyp[::-1]) 260 | return hyps 261 | 262 | def buildTargetTokens(self, preds): 263 | sentence = [] 264 | for pred in preds: 265 | tokens = [] 266 | for tok in pred: 267 | if tok == self._eos: 268 | break 269 | tokens.append(tok) 270 | sentence.append(tokens) 271 | return sentence 272 | -------------------------------------------------------------------------------- /docly/parser/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autosoft-dev/docly/0bd6216b8a9735e9fa76bffd4ffea6cec6cc4a01/docly/parser/__init__.py -------------------------------------------------------------------------------- /docly/parser/parser.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from tree_hugger.core import PythonParser 4 | 5 | 6 | def set_up_parser(): 7 | return 8 | 9 | 10 | def parse(file_path: str, tslib_path): 11 | python_parser = PythonParser(library_loc=tslib_path) 12 | return python_parser.parse_file(file_path), python_parser 13 | 14 | 15 | def get_func_body_and_docstr(python_parser: PythonParser): 16 | res = python_parser.get_all_function_bodies(strip_docstr=True, get_index=True) 17 | for key, value in res.items(): 18 | yield key, value -------------------------------------------------------------------------------- /docly/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from tokenize import tokenize 2 | from io import BytesIO 3 | 4 | 5 | def tokenize_code_string(text): 6 | code_tokens = [] 7 | for tok in tokenize(BytesIO(text.encode('utf-8')).readline): 8 | if tok.string.strip() != "" and tok.string.strip() != "utf-8": 9 | code_tokens.append(tok.string.strip().lower()) 10 | return code_tokens -------------------------------------------------------------------------------- /logo/docly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autosoft-dev/docly/0bd6216b8a9735e9fa76bffd4ffea6cec6cc4a01/logo/docly.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appdirs==1.4.4 2 | appnope==0.1.0 3 | args==0.1.0 4 | astroid==2.4.2 5 | attrs==20.2.0 6 | autopep8==1.5.4 7 | azure-common==1.1.25 8 | azure-nspkg==3.0.2 9 | azure-storage==0.36.0 10 | backcall==0.2.0 11 | black==20.8b1 12 | boto3==1.15.10 13 | botocore==1.18.10 14 | cached-property==1.5.2 15 | certifi==2020.6.20 16 | cffi==1.14.3 17 | chardet==3.0.4 18 | click==7.1.2 19 | clint==0.5.1 20 | cryptography==3.1.1 21 | dataclasses==0.7 22 | decorator==4.4.2 23 | docopt==0.6.2 24 | dpu-utils==0.2.18 25 | filelock==3.0.12 26 | flake8==3.8.3 27 | future==0.18.2 28 | idna==2.10 29 | importlib-metadata==1.7.0 30 | iniconfig==1.0.1 31 | ipython==7.16.1 32 | ipython-genutils==0.2.0 33 | isort==5.5.3 34 | jedi==0.17.2 35 | jmespath==0.10.0 36 | joblib==0.16.0 37 | lazy-object-proxy==1.4.3 38 | mccabe==0.6.1 39 | more-itertools==8.5.0 40 | mypy-extensions==0.4.3 41 | nltk==3.5 42 | numpy==1.19.2 43 | packaging==20.4 44 | parso==0.7.1 45 | pathspec==0.8.0 46 | pexpect==4.8.0 47 | pickleshare==0.7.5 48 | pluggy==0.13.1 49 | 
prettytable==0.7.2 50 | prompt-toolkit==3.0.7 51 | ptyprocess==0.6.0 52 | py==1.10.0 53 | pycodestyle==2.6.0 54 | pycparser==2.20 55 | pyfiglet==0.8.post1 56 | pyflakes==2.2.0 57 | pygit2==1.3.0 58 | Pygments==2.7.1 59 | pylint==2.6.0 60 | pyparsing==2.4.7 61 | pytest==6.0.2 62 | python-dateutil==2.8.1 63 | PyYAML==5.4 64 | regex==2020.7.14 65 | requests==2.24.0 66 | s3transfer==0.3.3 67 | sacremoses==0.0.43 68 | sentencepiece==0.1.91 69 | SetSimilaritySearch==0.1.7 70 | six==1.15.0 71 | tokenizers==0.5.0 72 | toml==0.10.1 73 | torch==1.6.0 74 | tqdm==4.49.0 75 | traitlets==4.3.3 76 | transformers==2.5.0 77 | tree-hugger==0.8.3 78 | tree-sitter==0.1.1 79 | typed-ast==1.4.1 80 | typing-extensions==3.7.4.3 81 | urllib3==1.25.10 82 | wcwidth==0.2.5 83 | wrapt==1.12.1 84 | zipp==3.1.0 85 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from setuptools import find_packages, setup 3 | 4 | from docly import __version__ 5 | 6 | # The directory containing this file 7 | HERE = pathlib.Path(__file__).parent 8 | 9 | # The text of the README file 10 | README = (HERE / "README.md").read_text() 11 | 12 | 13 | setup( 14 | name='docly', 15 | description="Generate docstrings for your python functions. Automatically!", 16 | long_description=README, 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/autosoft-dev/docly", 19 | author="CodistAI", 20 | author_email="shubhadeep@codist-ai.com", 21 | include_package_data=True, 22 | license="MIT", 23 | classifiers=[ 24 | "License :: OSI Approved :: MIT License", 25 | "Programming Language :: Python :: 3", 26 | "Programming Language :: Python :: 3.6", 27 | ], 28 | version=__version__, 29 | packages=find_packages(exclude=("tests",)), 30 | install_requires=["tqdm", "requests", "torch", 31 | "pyfiglet", "rich", "dpu-utils", "numpy", 32 | "nltk", "transformers==2.5.0", "tree_hugger", 33 | "halo", "invoke", 34 | ], 35 | extras_require={ 36 | "jupyter": ["jupytext"] 37 | }, 38 | entry_points = { 39 | 'console_scripts': ['docly-gen=docly.cli.docly_gen:main', 40 | 'docly-restore=docly.cli.docly_restore:main', 41 | 'docly-clean=docly.cli.docly_clean:main'], 42 | }, 43 | ) 44 | -------------------------------------------------------------------------------- /test_files/api.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | requests.api 5 | ~~~~~~~~~~~~ 6 | 7 | This module implements the Requests API. 8 | 9 | :copyright: (c) 2012 by Kenneth Reitz. 10 | :license: Apache2, see LICENSE for more details. 11 | """ 12 | 13 | from . import sessions 14 | 15 | 16 | def request(method, url, **kwargs): 17 | 18 | # By using the 'with' statement we are sure the session is closed, thus we 19 | # avoid leaving sockets open which can trigger a ResourceWarning in some 20 | # cases, and look like a memory leak in others. 
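# (A new Session is built per call; callers making many requests to the same host can instead create one `sessions.Session()` themselves and reuse it to benefit from connection pooling.)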
21 | with sessions.Session() as session: 22 | return session.request(method=method, url=url, **kwargs) 23 | 24 | 25 | def get(url, params=None, **kwargs): 26 | 27 | kwargs.setdefault('allow_redirects', True) 28 | return request('get', url, params=params, **kwargs) 29 | 30 | 31 | def options(url, **kwargs): 32 | 33 | kwargs.setdefault('allow_redirects', True) 34 | return request('options', url, **kwargs) 35 | 36 | 37 | def head(url, **kwargs): 38 | 39 | kwargs.setdefault('allow_redirects', False) 40 | return request('head', url, **kwargs) 41 | 42 | 43 | def post(url, data=None, json=None, **kwargs): 44 | 45 | return request('post', url, data=data, json=json, **kwargs) 46 | 47 | 48 | def put(url, data=None, **kwargs): 49 | 50 | return request('put', url, data=data, **kwargs) 51 | 52 | 53 | def patch(url, data=None, **kwargs): 54 | 55 | return request('patch', url, data=data, **kwargs) 56 | 57 | 58 | def delete(url, **kwargs): 59 | 60 | return request('delete', url, **kwargs) 61 | -------------------------------------------------------------------------------- /test_files/flask_files/aa.py: -------------------------------------------------------------------------------- 1 | def new_function_with_one_argument(arg): 2 | pass 3 | -------------------------------------------------------------------------------- /test_files/flask_files/cli.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import inspect 3 | import os 4 | import platform 5 | import re 6 | import sys 7 | import traceback 8 | import warnings 9 | from functools import update_wrapper 10 | from operator import attrgetter 11 | from threading import Lock 12 | from threading import Thread 13 | 14 | import click 15 | from werkzeug.utils import import_string 16 | 17 | from .globals import current_app 18 | from .helpers import get_debug_flag 19 | from .helpers import get_env 20 | from .helpers import get_load_dotenv 21 | 22 | try: 23 | import dotenv 24 | except ImportError: 25 | dotenv = None 26 | 27 | try: 28 | import ssl 29 | except ImportError: 30 | ssl = None 31 | 32 | 33 | class NoAppException(click.UsageError): 34 | """Raised if an application cannot be found or loaded.""" 35 | 36 | 37 | def find_best_app(script_info, module): 38 | from . import Flask 39 | 40 | # Search for the most common names first. 41 | for attr_name in ("app", "application"): 42 | app = getattr(module, attr_name, None) 43 | 44 | if isinstance(app, Flask): 45 | return app 46 | 47 | # Otherwise find the only object that is a Flask instance. 48 | matches = [v for v in module.__dict__.values() if isinstance(v, Flask)] 49 | 50 | if len(matches) == 1: 51 | return matches[0] 52 | elif len(matches) > 1: 53 | raise NoAppException( 54 | "Detected multiple Flask applications in module" 55 | f" {module.__name__!r}. Use 'FLASK_APP={module.__name__}:name'" 56 | f" to specify the correct one." 57 | ) 58 | 59 | # Search for app factory functions. 60 | for attr_name in {"create_app", "make_app"}: 61 | app_factory = getattr(module, attr_name, None) 62 | 63 | if inspect.isfunction(app_factory): 64 | try: 65 | app = call_factory(script_info, app_factory) 66 | 67 | if isinstance(app, Flask): 68 | return app 69 | except TypeError: 70 | if not _called_with_wrong_args(app_factory): 71 | raise 72 | raise NoAppException( 73 | f"Detected factory {attr_name!r} in module {module.__name__!r}," 74 | " but could not call it without arguments. Use" 75 | f" \"FLASK_APP='{module.__name__}:{attr_name}(args)'\"" 76 | " to specify arguments." 
77 | ) 78 | 79 | raise NoAppException( 80 | "Failed to find Flask application or factory in module" 81 | f" {module.__name__!r}. Use 'FLASK_APP={module.__name__}:name'" 82 | " to specify one." 83 | ) 84 | 85 | 86 | def call_factory(script_info, app_factory, args=None, kwargs=None): 87 | sig = inspect.signature(app_factory) 88 | args = [] if args is None else args 89 | kwargs = {} if kwargs is None else kwargs 90 | 91 | if "script_info" in sig.parameters: 92 | warnings.warn( 93 | "The 'script_info' argument is deprecated and will not be" 94 | " passed to the app factory function in 2.1.", 95 | DeprecationWarning, 96 | ) 97 | kwargs["script_info"] = script_info 98 | 99 | if ( 100 | not args 101 | and len(sig.parameters) == 1 102 | and next(iter(sig.parameters.values())).default is inspect.Parameter.empty 103 | ): 104 | warnings.warn( 105 | "Script info is deprecated and will not be passed as the" 106 | " single argument to the app factory function in 2.1.", 107 | DeprecationWarning, 108 | ) 109 | args.append(script_info) 110 | 111 | return app_factory(*args, **kwargs) 112 | 113 | 114 | def _called_with_wrong_args(f): 115 | 116 | tb = sys.exc_info()[2] 117 | 118 | try: 119 | while tb is not None: 120 | if tb.tb_frame.f_code is f.__code__: 121 | # In the function, it was called successfully. 122 | return False 123 | 124 | tb = tb.tb_next 125 | 126 | # Didn't reach the function. 127 | return True 128 | finally: 129 | # Delete tb to break a circular reference. 130 | # https://docs.python.org/2/library/sys.html#sys.exc_info 131 | del tb 132 | 133 | 134 | def find_app_by_string(script_info, module, app_name): 135 | from . import Flask 136 | 137 | # Parse app_name as a single expression to determine if it's a valid 138 | # attribute name or function call. 139 | try: 140 | expr = ast.parse(app_name.strip(), mode="eval").body 141 | except SyntaxError: 142 | raise NoAppException( 143 | f"Failed to parse {app_name!r} as an attribute name or function call." 144 | ) 145 | 146 | if isinstance(expr, ast.Name): 147 | name = expr.id 148 | args = kwargs = None 149 | elif isinstance(expr, ast.Call): 150 | # Ensure the function name is an attribute name only. 151 | if not isinstance(expr.func, ast.Name): 152 | raise NoAppException( 153 | f"Function reference must be a simple name: {app_name!r}." 154 | ) 155 | 156 | name = expr.func.id 157 | 158 | # Parse the positional and keyword arguments as literals. 159 | try: 160 | args = [ast.literal_eval(arg) for arg in expr.args] 161 | kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in expr.keywords} 162 | except ValueError: 163 | # literal_eval gives cryptic error messages, show a generic 164 | # message with the full expression instead. 165 | raise NoAppException( 166 | f"Failed to parse arguments as literal values: {app_name!r}." 167 | ) 168 | else: 169 | raise NoAppException( 170 | f"Failed to parse {app_name!r} as an attribute name or function call." 171 | ) 172 | 173 | try: 174 | attr = getattr(module, name) 175 | except AttributeError: 176 | raise NoAppException( 177 | f"Failed to find attribute {name!r} in {module.__name__!r}." 178 | ) 179 | 180 | # If the attribute is a function, call it with any args and kwargs 181 | # to get the real application. 
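# (`args` and `kwargs` were produced by `ast.literal_eval` above, so only literal values such as strings, numbers, and tuples can ever reach the factory call.)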
182 | if inspect.isfunction(attr): 183 | try: 184 | app = call_factory(script_info, attr, args, kwargs) 185 | except TypeError: 186 | if not _called_with_wrong_args(attr): 187 | raise 188 | 189 | raise NoAppException( 190 | f"The factory {app_name!r} in module" 191 | f" {module.__name__!r} could not be called with the" 192 | " specified arguments." 193 | ) 194 | else: 195 | app = attr 196 | 197 | if isinstance(app, Flask): 198 | return app 199 | 200 | raise NoAppException( 201 | "A valid Flask application was not obtained from" 202 | f" '{module.__name__}:{app_name}'." 203 | ) 204 | 205 | 206 | def prepare_import(path): 207 | path = os.path.realpath(path) 208 | 209 | fname, ext = os.path.splitext(path) 210 | if ext == ".py": 211 | path = fname 212 | 213 | if os.path.basename(path) == "__init__": 214 | path = os.path.dirname(path) 215 | 216 | module_name = [] 217 | 218 | # move up until outside package structure (no __init__.py) 219 | while True: 220 | path, name = os.path.split(path) 221 | module_name.append(name) 222 | 223 | if not os.path.exists(os.path.join(path, "__init__.py")): 224 | break 225 | 226 | if sys.path[0] != path: 227 | sys.path.insert(0, path) 228 | 229 | return ".".join(module_name[::-1]) 230 | 231 | 232 | def locate_app(script_info, module_name, app_name, raise_if_not_found=True): 233 | __traceback_hide__ = True # noqa: F841 234 | 235 | try: 236 | __import__(module_name) 237 | except ImportError: 238 | # Reraise the ImportError if it occurred within the imported module. 239 | # Determine this by checking whether the trace has a depth > 1. 240 | if sys.exc_info()[2].tb_next: 241 | raise NoAppException( 242 | f"While importing {module_name!r}, an ImportError was" 243 | f" raised:\n\n{traceback.format_exc()}" 244 | ) 245 | elif raise_if_not_found: 246 | raise NoAppException(f"Could not import {module_name!r}.") 247 | else: 248 | return 249 | 250 | module = sys.modules[module_name] 251 | 252 | if app_name is None: 253 | return find_best_app(script_info, module) 254 | else: 255 | return find_app_by_string(script_info, module, app_name) 256 | 257 | 258 | def get_version(ctx, param, value): 259 | if not value or ctx.resilient_parsing: 260 | return 261 | 262 | import werkzeug 263 | from . import __version__ 264 | 265 | click.echo( 266 | f"Python {platform.python_version()}\n" 267 | f"Flask {__version__}\n" 268 | f"Werkzeug {werkzeug.__version__}", 269 | color=ctx.color, 270 | ) 271 | ctx.exit() 272 | 273 | 274 | version_option = click.Option( 275 | ["--version"], 276 | help="Show the flask version", 277 | expose_value=False, 278 | callback=get_version, 279 | is_flag=True, 280 | is_eager=True, 281 | ) 282 | 283 | 284 | class DispatchingApp: 285 | """Special application that dispatches to a Flask application which 286 | is imported by name in a background thread. If an error happens 287 | it is recorded and shown as part of the WSGI handling which in case 288 | of the Werkzeug debugger means that it shows up in the browser. 
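Until the background import finishes, incoming WSGI calls block on the loader lock, and any exception captured by the loading thread is re-raised on the next request rather than silently dropped.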
289 | """ 290 | 291 | def __init__(self, loader, use_eager_loading=None): 292 | self.loader = loader 293 | self._app = None 294 | self._lock = Lock() 295 | self._bg_loading_exc_info = None 296 | 297 | if use_eager_loading is None: 298 | use_eager_loading = os.environ.get("WERKZEUG_RUN_MAIN") != "true" 299 | 300 | if use_eager_loading: 301 | self._load_unlocked() 302 | else: 303 | self._load_in_background() 304 | 305 | def _load_in_background(self): 306 | def _load_app(): 307 | __traceback_hide__ = True # noqa: F841 308 | with self._lock: 309 | try: 310 | self._load_unlocked() 311 | except Exception: 312 | self._bg_loading_exc_info = sys.exc_info() 313 | 314 | t = Thread(target=_load_app, args=()) 315 | t.start() 316 | 317 | def _flush_bg_loading_exception(self): 318 | __traceback_hide__ = True # noqa: F841 319 | exc_info = self._bg_loading_exc_info 320 | if exc_info is not None: 321 | self._bg_loading_exc_info = None 322 | raise exc_info 323 | 324 | def _load_unlocked(self): 325 | __traceback_hide__ = True # noqa: F841 326 | self._app = rv = self.loader() 327 | self._bg_loading_exc_info = None 328 | return rv 329 | 330 | def __call__(self, environ, start_response): 331 | __traceback_hide__ = True # noqa: F841 332 | if self._app is not None: 333 | return self._app(environ, start_response) 334 | self._flush_bg_loading_exception() 335 | with self._lock: 336 | if self._app is not None: 337 | rv = self._app 338 | else: 339 | rv = self._load_unlocked() 340 | return rv(environ, start_response) 341 | 342 | 343 | class ScriptInfo: 344 | """Helper object to deal with Flask applications. This is usually not 345 | necessary to interface with as it's used internally in the dispatching 346 | to click. In future versions of Flask this object will most likely play 347 | a bigger role. Typically it's created automatically by the 348 | :class:`FlaskGroup` but you can also manually create it and pass it 349 | onwards as click object. 350 | """ 351 | 352 | def __init__(self, app_import_path=None, create_app=None, set_debug_flag=True): 353 | #: Optionally the import path for the Flask application. 354 | self.app_import_path = app_import_path or os.environ.get("FLASK_APP") 355 | #: Optionally a function that is passed the script info to create 356 | #: the instance of the application. 357 | self.create_app = create_app 358 | #: A dictionary with arbitrary data that can be associated with 359 | #: this script info. 360 | self.data = {} 361 | self.set_debug_flag = set_debug_flag 362 | self._loaded_app = None 363 | 364 | def load_app(self): 365 | """Loads the Flask app (if not yet loaded) and returns it. Calling 366 | this multiple times will just result in the already loaded app to 367 | be returned. 368 | """ 369 | __traceback_hide__ = True # noqa: F841 370 | 371 | if self._loaded_app is not None: 372 | return self._loaded_app 373 | 374 | if self.create_app is not None: 375 | app = call_factory(self, self.create_app) 376 | else: 377 | if self.app_import_path: 378 | path, name = ( 379 | re.split(r":(?![\\/])", self.app_import_path, 1) + [None] 380 | )[:2] 381 | import_name = prepare_import(path) 382 | app = locate_app(self, import_name, name) 383 | else: 384 | for path in ("wsgi.py", "app.py"): 385 | import_name = prepare_import(path) 386 | app = locate_app(self, import_name, None, raise_if_not_found=False) 387 | 388 | if app: 389 | break 390 | 391 | if not app: 392 | raise NoAppException( 393 | "Could not locate a Flask application. 
You did not provide " 394 | 'the "FLASK_APP" environment variable, and a "wsgi.py" or ' 395 | '"app.py" module was not found in the current directory.' 396 | ) 397 | 398 | if self.set_debug_flag: 399 | # Update the app's debug flag through the descriptor so that 400 | # other values repopulate as well. 401 | app.debug = get_debug_flag() 402 | 403 | self._loaded_app = app 404 | return app 405 | 406 | 407 | pass_script_info = click.make_pass_decorator(ScriptInfo, ensure=True) 408 | 409 | """Wraps a callback so that it's guaranteed to be executed with the 410 | script's application context. If callbacks are registered directly 411 | to the ``app.cli`` object then they are wrapped with this function 412 | by default unless it's disabled. 413 | """ 414 | def with_appcontext(f): 415 | 416 | @click.pass_context 417 | def decorator(__ctx, *args, **kwargs): 418 | with __ctx.ensure_object(ScriptInfo).load_app().app_context(): 419 | return __ctx.invoke(f, *args, **kwargs) 420 | 421 | return update_wrapper(decorator, f) 422 | 423 | 424 | class AppGroup(click.Group): 425 | """This works similar to a regular click :class:`~click.Group` but it 426 | changes the behavior of the :meth:`command` decorator so that it 427 | automatically wraps the functions in :func:`with_appcontext`. 428 | 429 | Not to be confused with :class:`FlaskGroup`. 430 | """ 431 | 432 | def command(self, *args, **kwargs): 433 | """This works exactly like the method of the same name on a regular 434 | :class:`click.Group` but it wraps callbacks in :func:`with_appcontext` 435 | unless it's disabled by passing ``with_appcontext=False``. 436 | """ 437 | wrap_for_ctx = kwargs.pop("with_appcontext", True) 438 | 439 | def decorator(f): 440 | if wrap_for_ctx: 441 | f = with_appcontext(f) 442 | return click.Group.command(self, *args, **kwargs)(f) 443 | 444 | return decorator 445 | 446 | def group(self, *args, **kwargs): 447 | """This works exactly like the method of the same name on a regular 448 | :class:`click.Group` but it defaults the group class to 449 | :class:`AppGroup`. 450 | """ 451 | kwargs.setdefault("cls", AppGroup) 452 | return click.Group.group(self, *args, **kwargs) 453 | 454 | 455 | class FlaskGroup(AppGroup): 456 | """Special subclass of the :class:`AppGroup` group that supports 457 | loading more commands from the configured Flask app. Normally a 458 | developer does not have to interface with this class but there are 459 | some very advanced use cases for which it makes sense to create an 460 | instance of this. see :ref:`custom-scripts`. 461 | 462 | :param add_default_commands: if this is True then the default run and 463 | shell commands will be added. 464 | :param add_version_option: adds the ``--version`` option. 465 | :param create_app: an optional callback that is passed the script info and 466 | returns the loaded app. 467 | :param load_dotenv: Load the nearest :file:`.env` and :file:`.flaskenv` 468 | files to set environment variables. Will also change the working 469 | directory to the directory containing the first file found. 470 | :param set_debug_flag: Set the app's debug flag based on the active 471 | environment 472 | 473 | .. versionchanged:: 1.0 474 | If installed, python-dotenv will be used to load environment variables 475 | from :file:`.env` and :file:`.flaskenv` files. 
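A minimal custom script built on this group might look like the sketch below (``create_my_app`` is a hypothetical application factory, not part of Flask)::

    from flask.cli import FlaskGroup

    # create_app is called lazily, once a command actually needs the app.
    cli = FlaskGroup(create_app=create_my_app)

    if __name__ == "__main__":
        cli()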
476 | """ 477 | 478 | def __init__( 479 | self, 480 | add_default_commands=True, 481 | create_app=None, 482 | add_version_option=True, 483 | load_dotenv=True, 484 | set_debug_flag=True, 485 | **extra, 486 | ): 487 | params = list(extra.pop("params", None) or ()) 488 | 489 | if add_version_option: 490 | params.append(version_option) 491 | 492 | AppGroup.__init__(self, params=params, **extra) 493 | self.create_app = create_app 494 | self.load_dotenv = load_dotenv 495 | self.set_debug_flag = set_debug_flag 496 | 497 | if add_default_commands: 498 | self.add_command(run_command) 499 | self.add_command(shell_command) 500 | self.add_command(routes_command) 501 | 502 | self._loaded_plugin_commands = False 503 | 504 | def _load_plugin_commands(self): 505 | if self._loaded_plugin_commands: 506 | return 507 | try: 508 | import pkg_resources 509 | except ImportError: 510 | self._loaded_plugin_commands = True 511 | return 512 | 513 | for ep in pkg_resources.iter_entry_points("flask.commands"): 514 | self.add_command(ep.load(), ep.name) 515 | self._loaded_plugin_commands = True 516 | 517 | def get_command(self, ctx, name): 518 | self._load_plugin_commands() 519 | # Look up built-in and plugin commands, which should be 520 | # available even if the app fails to load. 521 | rv = super().get_command(ctx, name) 522 | 523 | if rv is not None: 524 | return rv 525 | 526 | info = ctx.ensure_object(ScriptInfo) 527 | 528 | # Look up commands provided by the app, showing an error and 529 | # continuing if the app couldn't be loaded. 530 | try: 531 | return info.load_app().cli.get_command(ctx, name) 532 | except NoAppException as e: 533 | click.secho(f"Error: {e.format_message()}\n", err=True, fg="red") 534 | 535 | def list_commands(self, ctx): 536 | self._load_plugin_commands() 537 | # Start with the built-in and plugin commands. 538 | rv = set(super().list_commands(ctx)) 539 | info = ctx.ensure_object(ScriptInfo) 540 | 541 | # Add commands provided by the app, showing an error and 542 | # continuing if the app couldn't be loaded. 543 | try: 544 | rv.update(info.load_app().cli.list_commands(ctx)) 545 | except NoAppException as e: 546 | # When an app couldn't be loaded, show the error message 547 | # without the traceback. 548 | click.secho(f"Error: {e.format_message()}\n", err=True, fg="red") 549 | except Exception: 550 | # When any other errors occurred during loading, show the 551 | # full traceback. 552 | click.secho(f"{traceback.format_exc()}\n", err=True, fg="red") 553 | 554 | return sorted(rv) 555 | 556 | def main(self, *args, **kwargs): 557 | # Set a global flag that indicates that we were invoked from the 558 | # command line interface. This is detected by Flask.run to make the 559 | # call into a no-op. This is necessary to avoid ugly errors when the 560 | # script that is loaded here also attempts to start a server. 561 | os.environ["FLASK_RUN_FROM_CLI"] = "true" 562 | 563 | if get_load_dotenv(self.load_dotenv): 564 | load_dotenv() 565 | 566 | obj = kwargs.get("obj") 567 | 568 | if obj is None: 569 | obj = ScriptInfo( 570 | create_app=self.create_app, set_debug_flag=self.set_debug_flag 571 | ) 572 | 573 | kwargs["obj"] = obj 574 | kwargs.setdefault("auto_envvar_prefix", "FLASK") 575 | return super().main(*args, **kwargs) 576 | 577 | 578 | def _path_is_ancestor(path, other): 579 | """Take ``other`` and remove the length of ``path`` from it. Then join it 580 | to ``path``. 
If it is the original value, ``path`` is an ancestor of 581 | ``other``.""" 582 | return os.path.join(path, other[len(path) :].lstrip(os.sep)) == other 583 | 584 | 585 | def load_dotenv(path=None): 586 | """Load "dotenv" files in order of precedence to set environment variables. 587 | 588 | If an env var is already set it is not overwritten, so earlier files in the 589 | list are preferred over later files. 590 | 591 | This is a no-op if `python-dotenv`_ is not installed. 592 | 593 | .. _python-dotenv: https://github.com/theskumar/python-dotenv#readme 594 | 595 | :param path: Load the file at this location instead of searching. 596 | :return: ``True`` if a file was loaded. 597 | 598 | .. versionchanged:: 1.1.0 599 | Returns ``False`` when python-dotenv is not installed, or when 600 | the given path isn't a file. 601 | 602 | .. versionadded:: 1.0 603 | """ 604 | if dotenv is None: 605 | if path or os.path.isfile(".env") or os.path.isfile(".flaskenv"): 606 | click.secho( 607 | " * Tip: There are .env or .flaskenv files present." 608 | ' Do "pip install python-dotenv" to use them.', 609 | fg="yellow", 610 | err=True, 611 | ) 612 | 613 | return False 614 | 615 | # if the given path specifies the actual file then return True, 616 | # else False 617 | if path is not None: 618 | if os.path.isfile(path): 619 | return dotenv.load_dotenv(path) 620 | 621 | return False 622 | 623 | new_dir = None 624 | 625 | for name in (".env", ".flaskenv"): 626 | path = dotenv.find_dotenv(name, usecwd=True) 627 | 628 | if not path: 629 | continue 630 | 631 | if new_dir is None: 632 | new_dir = os.path.dirname(path) 633 | 634 | dotenv.load_dotenv(path) 635 | 636 | return new_dir is not None # at least one file was located and loaded 637 | 638 | 639 | def show_server_banner(env, debug, app_import_path, eager_loading): 640 | """Show extra startup messages the first time the server is run, 641 | ignoring the reloader. 642 | """ 643 | if os.environ.get("WERKZEUG_RUN_MAIN") == "true": 644 | return 645 | 646 | if app_import_path is not None: 647 | message = f" * Serving Flask app {app_import_path!r}" 648 | 649 | if not eager_loading: 650 | message += " (lazy loading)" 651 | 652 | click.echo(message) 653 | 654 | click.echo(f" * Environment: {env}") 655 | 656 | if env == "production": 657 | click.secho( 658 | " WARNING: This is a development server. Do not use it in" 659 | " a production deployment.", 660 | fg="red", 661 | ) 662 | click.secho(" Use a production WSGI server instead.", dim=True) 663 | 664 | if debug is not None: 665 | click.echo(f" * Debug mode: {'on' if debug else 'off'}") 666 | 667 | 668 | class CertParamType(click.ParamType): 669 | """Click option type for the ``--cert`` option. Allows either an 670 | existing file, the string ``'adhoc'``, or an import for a 671 | :class:`~ssl.SSLContext` object. 
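For example, ``--cert adhoc`` serves with a throwaway self-signed certificate (this requires the ``cryptography`` package), while ``--cert cert.pem --key key.pem`` uses an existing pair; the file names here are only illustrative.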
672 | """ 673 | 674 | name = "path" 675 | 676 | def __init__(self): 677 | self.path_type = click.Path(exists=True, dir_okay=False, resolve_path=True) 678 | 679 | def convert(self, value, param, ctx): 680 | if ssl is None: 681 | raise click.BadParameter( 682 | 'Using "--cert" requires Python to be compiled with SSL support.', 683 | ctx, 684 | param, 685 | ) 686 | 687 | try: 688 | return self.path_type(value, param, ctx) 689 | except click.BadParameter: 690 | value = click.STRING(value, param, ctx).lower() 691 | 692 | if value == "adhoc": 693 | try: 694 | import cryptography # noqa: F401 695 | except ImportError: 696 | raise click.BadParameter( 697 | "Using ad-hoc certificates requires the cryptography library.", 698 | ctx, 699 | param, 700 | ) 701 | 702 | return value 703 | 704 | obj = import_string(value, silent=True) 705 | 706 | if isinstance(obj, ssl.SSLContext): 707 | return obj 708 | 709 | raise 710 | 711 | 712 | def _validate_key(ctx, param, value): 713 | """The ``--key`` option must be specified when ``--cert`` is a file. 714 | Modifies the ``cert`` param to be a ``(cert, key)`` pair if needed. 715 | """ 716 | cert = ctx.params.get("cert") 717 | is_adhoc = cert == "adhoc" 718 | is_context = ssl and isinstance(cert, ssl.SSLContext) 719 | 720 | if value is not None: 721 | if is_adhoc: 722 | raise click.BadParameter( 723 | 'When "--cert" is "adhoc", "--key" is not used.', ctx, param 724 | ) 725 | 726 | if is_context: 727 | raise click.BadParameter( 728 | 'When "--cert" is an SSLContext object, "--key is not used.', ctx, param 729 | ) 730 | 731 | if not cert: 732 | raise click.BadParameter('"--cert" must also be specified.', ctx, param) 733 | 734 | ctx.params["cert"] = cert, value 735 | 736 | else: 737 | if cert and not (is_adhoc or is_context): 738 | raise click.BadParameter('Required when using "--cert".', ctx, param) 739 | 740 | return value 741 | 742 | 743 | class SeparatedPathType(click.Path): 744 | """Click option type that accepts a list of values separated by the 745 | OS's path separator (``:``, ``;`` on Windows). Each value is 746 | validated as a :class:`click.Path` type. 747 | """ 748 | 749 | def convert(self, value, param, ctx): 750 | items = self.split_envvar_value(value) 751 | super_convert = super().convert 752 | return [super_convert(item, param, ctx) for item in items] 753 | 754 | 755 | @click.command("run", short_help="Run a development server.") 756 | @click.option("--host", "-h", default="127.0.0.1", help="The interface to bind to.") 757 | @click.option("--port", "-p", default=5000, help="The port to bind to.") 758 | @click.option( 759 | "--cert", type=CertParamType(), help="Specify a certificate file to use HTTPS." 760 | ) 761 | @click.option( 762 | "--key", 763 | type=click.Path(exists=True, dir_okay=False, resolve_path=True), 764 | callback=_validate_key, 765 | expose_value=False, 766 | help="The key file to use when specifying a certificate.", 767 | ) 768 | @click.option( 769 | "--reload/--no-reload", 770 | default=None, 771 | help="Enable or disable the reloader. By default the reloader " 772 | "is active if debug is enabled.", 773 | ) 774 | @click.option( 775 | "--debugger/--no-debugger", 776 | default=None, 777 | help="Enable or disable the debugger. By default the debugger " 778 | "is active if debug is enabled.", 779 | ) 780 | @click.option( 781 | "--eager-loading/--lazy-loading", 782 | default=None, 783 | help="Enable or disable eager loading. 
By default eager " 784 | "loading is enabled if the reloader is disabled.", 785 | ) 786 | @click.option( 787 | "--with-threads/--without-threads", 788 | default=True, 789 | help="Enable or disable multithreading.", 790 | ) 791 | @click.option( 792 | "--extra-files", 793 | default=None, 794 | type=SeparatedPathType(), 795 | help=( 796 | "Extra files that trigger a reload on change. Multiple paths" 797 | f" are separated by {os.path.pathsep!r}." 798 | ), 799 | ) 800 | @pass_script_info 801 | def run_command( 802 | info, host, port, reload, debugger, eager_loading, with_threads, cert, extra_files 803 | ): 804 | """Run a local development server. 805 | 806 | This server is for development purposes only. It does not provide 807 | the stability, security, or performance of production WSGI servers. 808 | 809 | The reloader and debugger are enabled by default if 810 | FLASK_ENV=development or FLASK_DEBUG=1. 811 | """ 812 | debug = get_debug_flag() 813 | 814 | if reload is None: 815 | reload = debug 816 | 817 | if debugger is None: 818 | debugger = debug 819 | 820 | show_server_banner(get_env(), debug, info.app_import_path, eager_loading) 821 | app = DispatchingApp(info.load_app, use_eager_loading=eager_loading) 822 | 823 | from werkzeug.serving import run_simple 824 | 825 | run_simple( 826 | host, 827 | port, 828 | app, 829 | use_reloader=reload, 830 | use_debugger=debugger, 831 | threaded=with_threads, 832 | ssl_context=cert, 833 | extra_files=extra_files, 834 | ) 835 | 836 | 837 | @click.command("shell", short_help="Run a shell in the app context.") 838 | @with_appcontext 839 | def shell_command(): 840 | """Run an interactive Python shell in the context of a given 841 | Flask application. The application will populate the default 842 | namespace of this shell according to its configuration. 843 | 844 | This is useful for executing small snippets of management code 845 | without having to manually configure the application. 846 | """ 847 | import code 848 | from .globals import _app_ctx_stack 849 | 850 | app = _app_ctx_stack.top.app 851 | banner = ( 852 | f"Python {sys.version} on {sys.platform}\n" 853 | f"App: {app.import_name} [{app.env}]\n" 854 | f"Instance: {app.instance_path}" 855 | ) 856 | ctx = {} 857 | 858 | # Support the regular Python interpreter startup script if someone 859 | # is using it. 860 | startup = os.environ.get("PYTHONSTARTUP") 861 | if startup and os.path.isfile(startup): 862 | with open(startup) as f: 863 | eval(compile(f.read(), startup, "exec"), ctx) 864 | 865 | ctx.update(app.make_shell_context()) 866 | 867 | code.interact(banner=banner, local=ctx) 868 | 869 | 870 | @click.command("routes", short_help="Show the routes for the app.") 871 | @click.option( 872 | "--sort", 873 | "-s", 874 | type=click.Choice(("endpoint", "methods", "rule", "match")), 875 | default="endpoint", 876 | help=( 877 | 'Method to sort routes by. "match" is the order that Flask will match ' 878 | "routes when dispatching a request." 
879 | ), 880 | ) 881 | @click.option("--all-methods", is_flag=True, help="Show HEAD and OPTIONS methods.") 882 | @with_appcontext 883 | def routes_command(sort, all_methods): 884 | """Show all registered routes with endpoints and methods.""" 885 | 886 | rules = list(current_app.url_map.iter_rules()) 887 | if not rules: 888 | click.echo("No routes were registered.") 889 | return 890 | 891 | ignored_methods = set(() if all_methods else ("HEAD", "OPTIONS")) 892 | 893 | if sort in ("endpoint", "rule"): 894 | rules = sorted(rules, key=attrgetter(sort)) 895 | elif sort == "methods": 896 | rules = sorted(rules, key=lambda rule: sorted(rule.methods)) 897 | 898 | rule_methods = [", ".join(sorted(rule.methods - ignored_methods)) for rule in rules] 899 | 900 | headers = ("Endpoint", "Methods", "Rule") 901 | widths = ( 902 | max(len(rule.endpoint) for rule in rules), 903 | max(len(methods) for methods in rule_methods), 904 | max(len(rule.rule) for rule in rules), 905 | ) 906 | widths = [max(len(h), w) for h, w in zip(headers, widths)] 907 | row = "{{0:<{0}}} {{1:<{1}}} {{2:<{2}}}".format(*widths) 908 | 909 | click.echo(row.format(*headers).strip()) 910 | click.echo(row.format(*("-" * width for width in widths))) 911 | 912 | for rule, methods in zip(rules, rule_methods): 913 | click.echo(row.format(rule.endpoint, methods, rule.rule).rstrip()) 914 | 915 | 916 | cli = FlaskGroup( 917 | help="""\ 918 | A general utility script for Flask applications. 919 | 920 | Provides commands from Flask, extensions, and the application. Loads the 921 | application defined in the FLASK_APP environment variable, or from a wsgi.py 922 | file. Setting the FLASK_ENV environment variable to 'development' will enable 923 | debug mode. 924 | 925 | \b 926 | {prefix}{cmd} FLASK_APP=hello.py 927 | {prefix}{cmd} FLASK_ENV=development 928 | {prefix}flask run 929 | """.format( 930 | cmd="export" if os.name == "posix" else "set", 931 | prefix="$ " if os.name == "posix" else "> ", 932 | ) 933 | ) 934 | 935 | 936 | def main(as_module=False): 937 | # TODO omit sys.argv once https://github.com/pallets/click/issues/536 is fixed 938 | cli.main(args=sys.argv[1:], prog_name="python -m flask" if as_module else None) 939 | 940 | 941 | if __name__ == "__main__": 942 | main(as_module=True) 943 | -------------------------------------------------------------------------------- /test_files/inner_dir/_internal_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | requests._internal_utils 5 | ~~~~~~~~~~~~~~ 6 | 7 | Provides utility functions that are consumed internally by Requests 8 | which depend on extremely few external helpers (such as compat) 9 | """ 10 | 11 | from .compat import is_py2, builtin_str, str 12 | 13 | """Given a string object, regardless of type, returns a representation of 14 | that string in the native string type, encoding and decoding where 15 | necessary. This assumes ASCII unless told otherwise. 16 | """ 17 | def to_native_string(string, encoding='ascii'): 18 | if isinstance(string, builtin_str): 19 | out = string 20 | else: 21 | if is_py2: 22 | out = string.encode(encoding) 23 | else: 24 | out = string.decode(encoding) 25 | 26 | return out 27 | 28 | """Determine if unicode string only contains ASCII characters. 29 | 30 | :param str u_string: unicode string to check. Must be unicode 31 | and not Python 2 `str`. 
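For example, ``unicode_is_ascii(u'check')`` returns ``True``, while a string containing a character such as ``é`` returns ``False``.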
32 | :rtype: bool 33 | """ 34 | def unicode_is_ascii(u_string): 35 | 36 | assert isinstance(u_string, str) 37 | try: 38 | u_string.encode('ascii') 39 | return True 40 | except UnicodeEncodeError: 41 | return False 42 | -------------------------------------------------------------------------------- /test_files/notebooks/6_lstm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "D7tqLMoKF6uq" 8 | }, 9 | "source": [ 10 | "Deep Learning with TensorFlow\n", 11 | "=============\n", 12 | "\n", 13 | "Credits: Forked from [TensorFlow](https://github.com/tensorflow/tensorflow) by Google\n", 14 | "\n", 15 | "Setup\n", 16 | "------------\n", 17 | "\n", 18 | "Refer to the [setup instructions](https://github.com/donnemartin/data-science-ipython-notebooks/tree/feature/deep-learning/deep-learning/tensor-flow-exercises/README.md).\n", 19 | "\n", 20 | "Exercise 6\n", 21 | "------------\n", 22 | "\n", 23 | "After training a skip-gram model in `5_word2vec.ipynb`, the goal of this exercise is to train a LSTM character model over [Text8](http://mattmahoney.net/dc/textdata) data." 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": { 30 | "cellView": "both", 31 | "colab": { 32 | "autoexec": { 33 | "startup": false, 34 | "wait_interval": 0 35 | } 36 | }, 37 | "colab_type": "code", 38 | "collapsed": true, 39 | "id": "MvEblsgEXxrd" 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "# These are all the modules we'll be using later. Make sure you can import them\n", 44 | "# before proceeding further.\n", 45 | "import os\n", 46 | "import numpy as np\n", 47 | "import random\n", 48 | "import string\n", 49 | "import tensorflow as tf\n", 50 | "import urllib\n", 51 | "import zipfile" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": { 58 | "cellView": "both", 59 | "colab": { 60 | "autoexec": { 61 | "startup": false, 62 | "wait_interval": 0 63 | }, 64 | "output_extras": [ 65 | { 66 | "item_id": 1 67 | } 68 | ] 69 | }, 70 | "colab_type": "code", 71 | "collapsed": false, 72 | "executionInfo": { 73 | "elapsed": 5993, 74 | "status": "ok", 75 | "timestamp": 1445965582896, 76 | "user": { 77 | "color": "#1FA15D", 78 | "displayName": "Vincent Vanhoucke", 79 | "isAnonymous": false, 80 | "isMe": true, 81 | "permissionId": "05076109866853157986", 82 | "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", 83 | "sessionId": "6f6f07b359200c46", 84 | "userId": "102167687554210253930" 85 | }, 86 | "user_tz": 420 87 | }, 88 | "id": "RJ-o3UBUFtCw", 89 | "outputId": "d530534e-0791-4a94-ca6d-1c8f1b908a9e" 90 | }, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "Found and verified text8.zip\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "url = 'http://mattmahoney.net/dc/'\n", 102 | "\n", 103 | "def maybe_download(filename, expected_bytes):\n", 104 | " \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n", 105 | " if not os.path.exists(filename):\n", 106 | " filename, _ = urllib.urlretrieve(url + filename, filename)\n", 107 | " statinfo = os.stat(filename)\n", 108 | " if statinfo.st_size == expected_bytes:\n", 109 | " print 'Found and verified', filename\n", 110 | " else:\n", 111 | " print statinfo.st_size\n", 112 | " raise Exception(\n", 113 | " 'Failed to verify ' + filename 
+ '. Can you get to it with a browser?')\n", 114 | " return filename\n", 115 | "\n", 116 | "filename = maybe_download('text8.zip', 31344016)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": { 123 | "cellView": "both", 124 | "colab": { 125 | "autoexec": { 126 | "startup": false, 127 | "wait_interval": 0 128 | }, 129 | "output_extras": [ 130 | { 131 | "item_id": 1 132 | } 133 | ] 134 | }, 135 | "colab_type": "code", 136 | "collapsed": false, 137 | "executionInfo": { 138 | "elapsed": 5982, 139 | "status": "ok", 140 | "timestamp": 1445965582916, 141 | "user": { 142 | "color": "#1FA15D", 143 | "displayName": "Vincent Vanhoucke", 144 | "isAnonymous": false, 145 | "isMe": true, 146 | "permissionId": "05076109866853157986", 147 | "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", 148 | "sessionId": "6f6f07b359200c46", 149 | "userId": "102167687554210253930" 150 | }, 151 | "user_tz": 420 152 | }, 153 | "id": "Mvf09fjugFU_", 154 | "outputId": "8f75db58-3862-404b-a0c3-799380597390" 155 | }, 156 | "outputs": [ 157 | { 158 | "name": "stdout", 159 | "output_type": "stream", 160 | "text": [ 161 | "Data size 100000000\n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "def read_data(filename):\n", 167 | " f = zipfile.ZipFile(filename)\n", 168 | " for name in f.namelist():\n", 169 | " return f.read(name)\n", 170 | " f.close()\n", 171 | " \n", 172 | "text = read_data(filename)\n", 173 | "print \"Data size\", len(text)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": { 179 | "colab_type": "text", 180 | "id": "ga2CYACE-ghb" 181 | }, 182 | "source": [ 183 | "Create a small validation set." 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": { 190 | "cellView": "both", 191 | "colab": { 192 | "autoexec": { 193 | "startup": false, 194 | "wait_interval": 0 195 | }, 196 | "output_extras": [ 197 | { 198 | "item_id": 1 199 | } 200 | ] 201 | }, 202 | "colab_type": "code", 203 | "collapsed": false, 204 | "executionInfo": { 205 | "elapsed": 6184, 206 | "status": "ok", 207 | "timestamp": 1445965583138, 208 | "user": { 209 | "color": "#1FA15D", 210 | "displayName": "Vincent Vanhoucke", 211 | "isAnonymous": false, 212 | "isMe": true, 213 | "permissionId": "05076109866853157986", 214 | "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", 215 | "sessionId": "6f6f07b359200c46", 216 | "userId": "102167687554210253930" 217 | }, 218 | "user_tz": 420 219 | }, 220 | "id": "w-oBpfFG-j43", 221 | "outputId": "bdb96002-d021-4379-f6de-a977924f0d02" 222 | }, 223 | "outputs": [ 224 | { 225 | "name": "stdout", 226 | "output_type": "stream", 227 | "text": [ 228 | "99999000 ons anarchists advocate social relations based upon voluntary as\n", 229 | "1000 anarchism originated as a term of abuse first used against earl\n" 230 | ] 231 | } 232 | ], 233 | "source": [ 234 | "valid_size = 1000\n", 235 | "valid_text = text[:valid_size]\n", 236 | "train_text = text[valid_size:]\n", 237 | "train_size = len(train_text)\n", 238 | "print train_size, train_text[:64]\n", 239 | "print valid_size, valid_text[:64]" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": { 245 | "colab_type": "text", 246 | "id": "Zdw6i4F8glpp" 247 | }, 248 | "source": [ 249 | "Utility functions to map characters to vocabulary IDs and back." 
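, "\n", "(As the next cell shows, char2id maps a-z to 1-26 and the space to 0, and id2char inverts that mapping.)"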
250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": { 256 | "cellView": "both", 257 | "colab": { 258 | "autoexec": { 259 | "startup": false, 260 | "wait_interval": 0 261 | }, 262 | "output_extras": [ 263 | { 264 | "item_id": 1 265 | } 266 | ] 267 | }, 268 | "colab_type": "code", 269 | "collapsed": false, 270 | "executionInfo": { 271 | "elapsed": 6276, 272 | "status": "ok", 273 | "timestamp": 1445965583249, 274 | "user": { 275 | "color": "#1FA15D", 276 | "displayName": "Vincent Vanhoucke", 277 | "isAnonymous": false, 278 | "isMe": true, 279 | "permissionId": "05076109866853157986", 280 | "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", 281 | "sessionId": "6f6f07b359200c46", 282 | "userId": "102167687554210253930" 283 | }, 284 | "user_tz": 420 285 | }, 286 | "id": "gAL1EECXeZsD", 287 | "outputId": "88fc9032-feb9-45ff-a9a0-a26759cc1f2e" 288 | }, 289 | "outputs": [ 290 | { 291 | "name": "stdout", 292 | "output_type": "stream", 293 | "text": [ 294 | "1 26 0 Unexpected character: ï\n", 295 | "0\n", 296 | "a z \n" 297 | ] 298 | } 299 | ], 300 | "source": [ 301 | "vocabulary_size = len(string.ascii_lowercase) + 1 # [a-z] + ' '\n", 302 | "first_letter = ord(string.ascii_lowercase[0])\n", 303 | "\n", 304 | "def char2id(char):\n", 305 | " if char in string.ascii_lowercase:\n", 306 | " return ord(char) - first_letter + 1\n", 307 | " elif char == ' ':\n", 308 | " return 0\n", 309 | " else:\n", 310 | " print 'Unexpected character:', char\n", 311 | " return 0\n", 312 | " \n", 313 | "def id2char(dictid):\n", 314 | " if dictid > 0:\n", 315 | " return chr(dictid + first_letter - 1)\n", 316 | " else:\n", 317 | " return ' '\n", 318 | "\n", 319 | "print char2id('a'), char2id('z'), char2id(' '), char2id('ï')\n", 320 | "print id2char(1), id2char(26), id2char(0)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "metadata": { 326 | "colab_type": "text", 327 | "id": "lFwoyygOmWsL" 328 | }, 329 | "source": [ 330 | "Function to generate a training batch for the LSTM model." 
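, "\n", "(Each call to BatchGenerator.next() below returns num_unrollings + 1 one-hot batches, starting with the last batch of the previous call, so inputs and labels shifted by one character can be sliced from the same array.)"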
331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "metadata": { 337 | "cellView": "both", 338 | "colab": { 339 | "autoexec": { 340 | "startup": false, 341 | "wait_interval": 0 342 | }, 343 | "output_extras": [ 344 | { 345 | "item_id": 1 346 | } 347 | ] 348 | }, 349 | "colab_type": "code", 350 | "collapsed": false, 351 | "executionInfo": { 352 | "elapsed": 6473, 353 | "status": "ok", 354 | "timestamp": 1445965583467, 355 | "user": { 356 | "color": "#1FA15D", 357 | "displayName": "Vincent Vanhoucke", 358 | "isAnonymous": false, 359 | "isMe": true, 360 | "permissionId": "05076109866853157986", 361 | "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", 362 | "sessionId": "6f6f07b359200c46", 363 | "userId": "102167687554210253930" 364 | }, 365 | "user_tz": 420 366 | }, 367 | "id": "d9wMtjy5hCj9", 368 | "outputId": "3dd79c80-454a-4be0-8b71-4a4a357b3367" 369 | }, 370 | "outputs": [ 371 | { 372 | "name": "stdout", 373 | "output_type": "stream", 374 | "text": [ 375 | "['ons anarchi', 'when milita', 'lleria arch', ' abbeys and', 'married urr', 'hel and ric', 'y and litur', 'ay opened f', 'tion from t', 'migration t', 'new york ot', 'he boeing s', 'e listed wi', 'eber has pr', 'o be made t', 'yer who rec', 'ore signifi', 'a fierce cr', ' two six ei', 'aristotle s', 'ity can be ', ' and intrac', 'tion of the', 'dy to pass ', 'f certain d', 'at it will ', 'e convince ', 'ent told hi', 'ampaign and', 'rver side s', 'ious texts ', 'o capitaliz', 'a duplicate', 'gh ann es d', 'ine january', 'ross zero t', 'cal theorie', 'ast instanc', ' dimensiona', 'most holy m', 't s support', 'u is still ', 'e oscillati', 'o eight sub', 'of italy la', 's the tower', 'klahoma pre', 'erprise lin', 'ws becomes ', 'et in a naz', 'the fabian ', 'etchy to re', ' sharman ne', 'ised empero', 'ting in pol', 'd neo latin', 'th risky ri', 'encyclopedi', 'fense the a', 'duating fro', 'treet grid ', 'ations more', 'appeal of d', 'si have mad']\n", 376 | "['ists advoca', 'ary governm', 'hes nationa', 'd monasteri', 'raca prince', 'chard baer ', 'rgical lang', 'for passeng', 'the nationa', 'took place ', 'ther well k', 'seven six s', 'ith a gloss', 'robably bee', 'to recogniz', 'ceived the ', 'icant than ', 'ritic of th', 'ight in sig', 's uncaused ', ' lost as in', 'cellular ic', 'e size of t', ' him a stic', 'drugs confu', ' take to co', ' the priest', 'im to name ', 'd barred at', 'standard fo', ' such as es', 'ze on the g', 'e of the or', 'd hiver one', 'y eight mar', 'the lead ch', 'es classica', 'ce the non ', 'al analysis', 'mormons bel', 't or at lea', ' disagreed ', 'ing system ', 'btypes base', 'anguages th', 'r commissio', 'ess one nin', 'nux suse li', ' the first ', 'zi concentr', ' society ne', 'elatively s', 'etworks sha', 'or hirohito', 'litical ini', 'n most of t', 'iskerdoo ri', 'ic overview', 'air compone', 'om acnm acc', ' centerline', 'e than any ', 'devotional ', 'de such dev']\n", 377 | "[' a']\n", 378 | "['an']\n" 379 | ] 380 | } 381 | ], 382 | "source": [ 383 | "batch_size=64\n", 384 | "num_unrollings=10\n", 385 | "\n", 386 | "class BatchGenerator(object):\n", 387 | " def __init__(self, text, batch_size, num_unrollings):\n", 388 | " self._text = text\n", 389 | " self._text_size = len(text)\n", 390 | " self._batch_size = batch_size\n", 391 | " self._num_unrollings = num_unrollings\n", 392 | " segment = self._text_size / batch_size\n", 393 | " self._cursor = [ offset * segment for offset in xrange(batch_size)]\n", 
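"    # The cursors start one segment apart, so each of the batch_size rows streams its own evenly spaced slice of the text.\n",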
394 | " self._last_batch = self._next_batch()\n", 395 | " \n", 396 | " def _next_batch(self):\n", 397 | " \"\"\"Generate a single batch from the current cursor position in the data.\"\"\"\n", 398 | " batch = np.zeros(shape=(self._batch_size, vocabulary_size), dtype=np.float)\n", 399 | " for b in xrange(self._batch_size):\n", 400 | " batch[b, char2id(self._text[self._cursor[b]])] = 1.0\n", 401 | " self._cursor[b] = (self._cursor[b] + 1) % self._text_size\n", 402 | " return batch\n", 403 | " \n", 404 | " def next(self):\n", 405 | " \"\"\"Generate the next array of batches from the data. The array consists of\n", 406 | " the last batch of the previous array, followed by num_unrollings new ones.\n", 407 | " \"\"\"\n", 408 | " batches = [self._last_batch]\n", 409 | " for step in xrange(self._num_unrollings):\n", 410 | " batches.append(self._next_batch())\n", 411 | " self._last_batch = batches[-1]\n", 412 | " return batches\n", 413 | "\n", 414 | "def characters(probabilities):\n", 415 | " \"\"\"Turn a 1-hot encoding or a probability distribution over the possible\n", 416 | " characters back into its (most likely) character representation.\"\"\"\n", 417 | " return [id2char(c) for c in np.argmax(probabilities, 1)]\n", 418 | "\n", 419 | "def batches2string(batches):\n", 420 | " \"\"\"Convert a sequence of batches back into their (most likely) string\n", 421 | " representation.\"\"\"\n", 422 | " s = [''] * batches[0].shape[0]\n", 423 | " for b in batches:\n", 424 | " s = [''.join(x) for x in zip(s, characters(b))]\n", 425 | " return s\n", 426 | "\n", 427 | "train_batches = BatchGenerator(train_text, batch_size, num_unrollings)\n", 428 | "valid_batches = BatchGenerator(valid_text, 1, 1)\n", 429 | "\n", 430 | "print batches2string(train_batches.next())\n", 431 | "print batches2string(train_batches.next())\n", 432 | "print batches2string(valid_batches.next())\n", 433 | "print batches2string(valid_batches.next())" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": null, 439 | "metadata": { 440 | "cellView": "both", 441 | "colab": { 442 | "autoexec": { 443 | "startup": false, 444 | "wait_interval": 0 445 | } 446 | }, 447 | "colab_type": "code", 448 | "collapsed": true, 449 | "id": "KyVd8FxT5QBc" 450 | }, 451 | "outputs": [], 452 | "source": [ 453 | "def logprob(predictions, labels):\n", 454 | " \"\"\"Log-probability of the true labels in a predicted batch.\"\"\"\n", 455 | " predictions[predictions < 1e-10] = 1e-10\n", 456 | " return np.sum(np.multiply(labels, -np.log(predictions))) / labels.shape[0]\n", 457 | "\n", 458 | "def sample_distribution(distribution):\n", 459 | " \"\"\"Sample one element from a distribution assumed to be an array of normalized\n", 460 | " probabilities.\n", 461 | " \"\"\"\n", 462 | " r = random.uniform(0, 1)\n", 463 | " s = 0\n", 464 | " for i in xrange(len(distribution)):\n", 465 | " s += distribution[i]\n", 466 | " if s >= r:\n", 467 | " return i\n", 468 | " return len(distribution) - 1\n", 469 | "\n", 470 | "def sample(prediction):\n", 471 | " \"\"\"Turn a (column) prediction into 1-hot encoded samples.\"\"\"\n", 472 | " p = np.zeros(shape=[1, vocabulary_size], dtype=np.float)\n", 473 | " p[0, sample_distribution(prediction[0])] = 1.0\n", 474 | " return p\n", 475 | "\n", 476 | "def random_distribution():\n", 477 | " \"\"\"Generate a random column of probabilities.\"\"\"\n", 478 | " b = np.random.uniform(0.0, 1.0, size=[1, vocabulary_size])\n", 479 | " return b/np.sum(b, 1)[:,None]" 480 | ] 481 | }, 482 | { 483 | "cell_type": "markdown", 484 |
"metadata": { 485 | "colab_type": "text", 486 | "id": "K8f67YXaDr4C" 487 | }, 488 | "source": [ 489 | "Simple LSTM Model." 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": null, 495 | "metadata": { 496 | "cellView": "both", 497 | "colab": { 498 | "autoexec": { 499 | "startup": false, 500 | "wait_interval": 0 501 | } 502 | }, 503 | "colab_type": "code", 504 | "collapsed": true, 505 | "id": "Q5rxZK6RDuGe" 506 | }, 507 | "outputs": [], 508 | "source": [ 509 | "num_nodes = 64\n", 510 | "\n", 511 | "graph = tf.Graph()\n", 512 | "with graph.as_default():\n", 513 | " \n", 514 | " # Parameters:\n", 515 | " # Input gate: input, previous output, and bias.\n", 516 | " ix = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))\n", 517 | " im = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))\n", 518 | " ib = tf.Variable(tf.zeros([1, num_nodes]))\n", 519 | " # Forget gate: input, previous output, and bias.\n", 520 | " fx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))\n", 521 | " fm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))\n", 522 | " fb = tf.Variable(tf.zeros([1, num_nodes]))\n", 523 | " # Memory cell: input, state and bias. \n", 524 | " cx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))\n", 525 | " cm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))\n", 526 | " cb = tf.Variable(tf.zeros([1, num_nodes]))\n", 527 | " # Output gate: input, previous output, and bias.\n", 528 | " ox = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))\n", 529 | " om = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))\n", 530 | " ob = tf.Variable(tf.zeros([1, num_nodes]))\n", 531 | " # Variables saving state across unrollings.\n", 532 | " saved_output = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)\n", 533 | " saved_state = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)\n", 534 | " # Classifier weights and biases.\n", 535 | " w = tf.Variable(tf.truncated_normal([num_nodes, vocabulary_size], -0.1, 0.1))\n", 536 | " b = tf.Variable(tf.zeros([vocabulary_size]))\n", 537 | " \n", 538 | " # Definition of the cell computation.\n", 539 | " def lstm_cell(i, o, state):\n", 540 | " \"\"\"Create a LSTM cell. 
See e.g.: http://arxiv.org/pdf/1402.1128v1.pdf\n", 541 | " Note that in this formulation, we omit the various connections between the\n", 542 | " previous state and the gates.\"\"\"\n", 543 | " input_gate = tf.sigmoid(tf.matmul(i, ix) + tf.matmul(o, im) + ib)\n", 544 | " forget_gate = tf.sigmoid(tf.matmul(i, fx) + tf.matmul(o, fm) + fb)\n", 545 | " update = tf.matmul(i, cx) + tf.matmul(o, cm) + cb\n", 546 | " state = forget_gate * state + input_gate * tf.tanh(update)\n", 547 | " output_gate = tf.sigmoid(tf.matmul(i, ox) + tf.matmul(o, om) + ob)\n", 548 | " return output_gate * tf.tanh(state), state\n", 549 | "\n", 550 | " # Input data.\n", 551 | " train_data = list()\n", 552 | " for _ in xrange(num_unrollings + 1):\n", 553 | " train_data.append(\n", 554 | " tf.placeholder(tf.float32, shape=[batch_size,vocabulary_size]))\n", 555 | " train_inputs = train_data[:num_unrollings]\n", 556 | " train_labels = train_data[1:] # labels are inputs shifted by one time step.\n", 557 | "\n", 558 | " # Unrolled LSTM loop.\n", 559 | " outputs = list()\n", 560 | " output = saved_output\n", 561 | " state = saved_state\n", 562 | " for i in train_inputs:\n", 563 | " output, state = lstm_cell(i, output, state)\n", 564 | " outputs.append(output)\n", 565 | "\n", 566 | " # State saving across unrollings.\n", 567 | " with tf.control_dependencies([saved_output.assign(output),\n", 568 | " saved_state.assign(state)]):\n", 569 | " # Classifier.\n", 570 | " logits = tf.nn.xw_plus_b(tf.concat(0, outputs), w, b)\n", 571 | " loss = tf.reduce_mean(\n", 572 | " tf.nn.softmax_cross_entropy_with_logits(\n", 573 | " logits, tf.concat(0, train_labels)))\n", 574 | "\n", 575 | " # Optimizer.\n", 576 | " global_step = tf.Variable(0)\n", 577 | " learning_rate = tf.train.exponential_decay(\n", 578 | " 10.0, global_step, 5000, 0.1, staircase=True)\n", 579 | " optimizer = tf.train.GradientDescentOptimizer(learning_rate)\n", 580 | " gradients, v = zip(*optimizer.compute_gradients(loss))\n", 581 | " gradients, _ = tf.clip_by_global_norm(gradients, 1.25)\n", 582 | " optimizer = optimizer.apply_gradients(\n", 583 | " zip(gradients, v), global_step=global_step)\n", 584 | "\n", 585 | " # Predictions.\n", 586 | " train_prediction = tf.nn.softmax(logits)\n", 587 | " \n", 588 | " # Sampling and validation eval: batch 1, no unrolling.\n", 589 | " sample_input = tf.placeholder(tf.float32, shape=[1, vocabulary_size])\n", 590 | " saved_sample_output = tf.Variable(tf.zeros([1, num_nodes]))\n", 591 | " saved_sample_state = tf.Variable(tf.zeros([1, num_nodes]))\n", 592 | " reset_sample_state = tf.group(\n", 593 | " saved_sample_output.assign(tf.zeros([1, num_nodes])),\n", 594 | " saved_sample_state.assign(tf.zeros([1, num_nodes])))\n", 595 | " sample_output, sample_state = lstm_cell(\n", 596 | " sample_input, saved_sample_output, saved_sample_state)\n", 597 | " with tf.control_dependencies([saved_sample_output.assign(sample_output),\n", 598 | " saved_sample_state.assign(sample_state)]):\n", 599 | " sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b))" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": null, 605 | "metadata": { 606 | "cellView": "both", 607 | "colab": { 608 | "autoexec": { 609 | "startup": false, 610 | "wait_interval": 0 611 | }, 612 | "output_extras": [ 613 | { 614 | "item_id": 41 615 | }, 616 | { 617 | "item_id": 80 618 | }, 619 | { 620 | "item_id": 126 621 | }, 622 | { 623 | "item_id": 144 624 | } 625 | ] 626 | }, 627 | "colab_type": "code", 628 | "collapsed": false, 629 | "executionInfo": 
{ 630 | "elapsed": 199909, 631 | "status": "ok", 632 | "timestamp": 1445965877333, 633 | "user": { 634 | "color": "#1FA15D", 635 | "displayName": "Vincent Vanhoucke", 636 | "isAnonymous": false, 637 | "isMe": true, 638 | "permissionId": "05076109866853157986", 639 | "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", 640 | "sessionId": "6f6f07b359200c46", 641 | "userId": "102167687554210253930" 642 | }, 643 | "user_tz": 420 644 | }, 645 | "id": "RD9zQCZTEaEm", 646 | "outputId": "5e868466-2532-4545-ce35-b403cf5d9de6" 647 | }, 648 | "outputs": [ 649 | { 650 | "name": "stdout", 651 | "output_type": "stream", 652 | "text": [ 653 | "Initialized\n", 654 | "Average loss at step 0 : 3.29904174805 learning rate: 10.0\n", 655 | "Minibatch perplexity: 27.09\n", 656 | "================================================================================\n", 657 | "srk dwmrnuldtbbgg tapootidtu xsciu sgokeguw hi ieicjq lq piaxhazvc s fht wjcvdlh\n", 658 | "lhrvallvbeqqquc dxd y siqvnle bzlyw nr rwhkalezo siie o deb e lpdg storq u nx o\n", 659 | "meieu nantiouie gdys qiuotblci loc hbiznauiccb cqzed acw l tsm adqxplku gn oaxet\n", 660 | "unvaouc oxchywdsjntdh zpklaejvxitsokeerloemee htphisb th eaeqseibumh aeeyj j orw\n", 661 | "ogmnictpycb whtup otnilnesxaedtekiosqet liwqarysmt arj flioiibtqekycbrrgoysj\n", 662 | "================================================================================\n", 663 | "Validation set perplexity: 19.99\n", 664 | "Average loss at step 100 : 2.59553678274 learning rate: 10.0\n", 665 | "Minibatch perplexity: 9.57\n", 666 | "Validation set perplexity: 10.60\n", 667 | "Average loss at step 200 : 2.24747137785 learning rate: 10.0\n", 668 | "Minibatch perplexity: 7.68\n", 669 | "Validation set perplexity: 8.84\n", 670 | "Average loss at step 300 : 2.09438110709 learning rate: 10.0\n", 671 | "Minibatch perplexity: 7.41\n", 672 | "Validation set perplexity: 8.13\n", 673 | "Average loss at step 400 : 1.99440989017 learning rate: 10.0\n", 674 | "Minibatch perplexity: 6.46\n", 675 | "Validation set perplexity: 7.58\n", 676 | "Average loss at step 500 : 1.9320810616 learning rate: 10.0\n", 677 | "Minibatch perplexity: 6.30\n", 678 | "Validation set perplexity: 6.88\n", 679 | "Average loss at step 600 : 1.90935629249 learning rate: 10.0\n", 680 | "Minibatch perplexity: 7.21\n", 681 | "Validation set perplexity: 6.91\n", 682 | "Average loss at step 700 : 1.85583009005 learning rate: 10.0\n", 683 | "Minibatch perplexity: 6.13\n", 684 | "Validation set perplexity: 6.60\n", 685 | "Average loss at step 800 : 1.82152368546 learning rate: 10.0\n", 686 | "Minibatch perplexity: 6.01\n", 687 | "Validation set perplexity: 6.37\n", 688 | "Average loss at step 900 : 1.83169809818 learning rate: 10.0\n", 689 | "Minibatch perplexity: 7.20\n", 690 | "Validation set perplexity: 6.23\n", 691 | "Average loss at step 1000 : 1.82217029214 learning rate: 10.0\n", 692 | "Minibatch perplexity: 6.73\n", 693 | "================================================================================\n", 694 | "le action b of the tert sy ofter selvorang previgned stischdy yocal chary the co\n", 695 | "le relganis networks partucy cetinning wilnchan sics rumeding a fulch laks oftes\n", 696 | "hian andoris ret the ecause bistory l pidect one eight five lack du that the ses\n", 697 | "aiv dromery buskocy becomer worils resism disele retery exterrationn of hide in \n", 698 | "mer miter y sught esfectur of the upission vain is werms is vul ugher compted by\n", 699 | 
"================================================================================\n", 700 | "Validation set perplexity: 6.07\n", 701 | "Average loss at step 1100 : 1.77301145077 learning rate: 10.0\n", 702 | "Minibatch perplexity: 6.03\n", 703 | "Validation set perplexity: 5.89\n", 704 | "Average loss at step 1200 : 1.75306463003 learning rate: 10.0\n", 705 | "Minibatch perplexity: 6.50\n", 706 | "Validation set perplexity: 5.61\n", 707 | "Average loss at step 1300 : 1.72937195778 learning rate: 10.0\n", 708 | "Minibatch perplexity: 5.00\n", 709 | "Validation set perplexity: 5.60\n", 710 | "Average loss at step 1400 : 1.74773373723 learning rate: 10.0\n", 711 | "Minibatch perplexity: 6.48\n", 712 | "Validation set perplexity: 5.66\n", 713 | "Average loss at step 1500 : 1.7368799901 learning rate: 10.0\n", 714 | "Minibatch perplexity: 5.22\n", 715 | "Validation set perplexity: 5.44\n", 716 | "Average loss at step 1600 : 1.74528762937 learning rate: 10.0\n", 717 | "Minibatch perplexity: 5.85\n", 718 | "Validation set perplexity: 5.33\n", 719 | "Average loss at step 1700 : 1.70881183743 learning rate: 10.0\n", 720 | "Minibatch perplexity: 5.33\n", 721 | "Validation set perplexity: 5.56\n", 722 | "Average loss at step 1800 : 1.67776108027 learning rate: 10.0\n", 723 | "Minibatch perplexity: 5.33\n", 724 | "Validation set perplexity: 5.29\n", 725 | "Average loss at step 1900 : 1.64935536742 learning rate: 10.0\n", 726 | "Minibatch perplexity: 5.29\n", 727 | "Validation set perplexity: 5.15\n", 728 | "Average loss at step 2000 : 1.69528644681 learning rate: 10.0\n", 729 | "Minibatch perplexity: 5.13\n", 730 | "================================================================================\n", 731 | "vers soqually have one five landwing to docial page kagan lower with ther batern\n", 732 | "ctor son alfortmandd tethre k skin the known purated to prooust caraying the fit\n", 733 | "je in beverb is the sournction bainedy wesce tu sture artualle lines digra forme\n", 734 | "m rousively haldio ourso ond anvary was for the seven solies hild buil s to te\n", 735 | "zall for is it is one nine eight eight one neval to the kime typer oene where he\n", 736 | "================================================================================\n", 737 | "Validation set perplexity: 5.25\n", 738 | "Average loss at step 2100 : 1.68808053017 learning rate: 10.0\n", 739 | "Minibatch perplexity: 5.17\n", 740 | "Validation set perplexity: 5.01\n", 741 | "Average loss at step 2200 : 1.68322490931 learning rate: 10.0\n", 742 | "Minibatch perplexity: 5.09\n", 743 | "Validation set perplexity: 5.15\n", 744 | "Average loss at step 2300 : 1.64465074301 learning rate: 10.0\n", 745 | "Minibatch perplexity: 5.51\n", 746 | "Validation set perplexity: 5.00\n", 747 | "Average loss at step 2400 : 1.66408578038 learning rate: 10.0\n", 748 | "Minibatch perplexity: 5.86\n", 749 | "Validation set perplexity: 4.80\n", 750 | "Average loss at step 2500 : 1.68515402555 learning rate: 10.0\n", 751 | "Minibatch perplexity: 5.75\n", 752 | "Validation set perplexity: 4.82\n", 753 | "Average loss at step 2600 : 1.65405208349 learning rate: 10.0\n", 754 | "Minibatch perplexity: 5.38\n", 755 | "Validation set perplexity: 4.85\n", 756 | "Average loss at step 2700 : 1.65706222177 learning rate: 10.0\n", 757 | "Minibatch perplexity: 5.46\n", 758 | "Validation set perplexity: 4.78\n", 759 | "Average loss at step 2800 : 1.65204829812 learning rate: 10.0\n", 760 | "Minibatch perplexity: 5.06\n", 761 | "Validation set perplexity: 4.64\n", 762 | 
"Average loss at step 2900 : 1.65107253551 learning rate: 10.0\n", 763 | "Minibatch perplexity: 5.00\n", 764 | "Validation set perplexity: 4.61\n", 765 | "Average loss at step 3000 : 1.6495274055 learning rate: 10.0\n", 766 | "Minibatch perplexity: 4.53\n", 767 | "================================================================================\n", 768 | "ject covered in belo one six six to finsh that all di rozial sime it a the lapse\n", 769 | "ble which the pullic bocades record r to sile dric two one four nine seven six f\n", 770 | " originally ame the playa ishaps the stotchational in a p dstambly name which as\n", 771 | "ore volum to bay riwer foreal in nuily operety can and auscham frooripm however \n", 772 | "kan traogey was lacous revision the mott coupofiteditey the trando insended frop\n", 773 | "================================================================================\n", 774 | "Validation set perplexity: 4.76\n", 775 | "Average loss at step 3100 : 1.63705502152 learning rate: 10.0\n", 776 | "Minibatch perplexity: 5.50\n", 777 | "Validation set perplexity: 4.76\n", 778 | "Average loss at step 3200 : 1.64740695596 learning rate: 10.0\n", 779 | "Minibatch perplexity: 4.84\n", 780 | "Validation set perplexity: 4.67\n", 781 | "Average loss at step 3300 : 1.64711504817 learning rate: 10.0\n", 782 | "Minibatch perplexity: 5.39\n", 783 | "Validation set perplexity: 4.57\n", 784 | "Average loss at step 3400 : 1.67113256454 learning rate: 10.0\n", 785 | "Minibatch perplexity: 5.56\n", 786 | "Validation set perplexity: 4.71\n", 787 | "Average loss at step 3500 : 1.65637169957 learning rate: 10.0\n", 788 | "Minibatch perplexity: 5.03\n", 789 | "Validation set perplexity: 4.80\n", 790 | "Average loss at step 3600 : 1.66601825476 learning rate: 10.0\n", 791 | "Minibatch perplexity: 4.63\n", 792 | "Validation set perplexity: 4.52\n", 793 | "Average loss at step 3700 : 1.65021387935 learning rate: 10.0\n", 794 | "Minibatch perplexity: 5.50\n", 795 | "Validation set perplexity: 4.56\n", 796 | "Average loss at step 3800 : 1.64481814981 learning rate: 10.0\n", 797 | "Minibatch perplexity: 4.60\n", 798 | "Validation set perplexity: 4.54\n", 799 | "Average loss at step 3900 : 1.642069453 learning rate: 10.0\n", 800 | "Minibatch perplexity: 4.91\n", 801 | "Validation set perplexity: 4.54\n", 802 | "Average loss at step 4000 : 1.65179730773 learning rate: 10.0\n", 803 | "Minibatch perplexity: 4.77\n", 804 | "================================================================================\n", 805 | "k s rasbonish roctes the nignese at heacle was sito of beho anarchys and with ro\n", 806 | "jusar two sue wletaus of chistical in causations d ow trancic bruthing ha laters\n", 807 | "de and speacy pulted yoftret worksy zeatlating to eight d had to ie bue seven si\n", 808 | "s fiction of the feelly constive suq flanch earlied curauking bjoventation agent\n", 809 | "quen s playing it calana our seopity also atbellisionaly comexing the revideve i\n", 810 | "================================================================================\n", 811 | "Validation set perplexity: 4.58\n", 812 | "Average loss at step 4100 : 1.63794238806 learning rate: 10.0\n", 813 | "Minibatch perplexity: 5.47\n", 814 | "Validation set perplexity: 4.79\n", 815 | "Average loss at step 4200 : 1.63822438836 learning rate: 10.0\n", 816 | "Minibatch perplexity: 5.30\n", 817 | "Validation set perplexity: 4.54\n", 818 | "Average loss at step 4300 : 1.61844664574 learning rate: 10.0\n", 819 | "Minibatch perplexity: 4.69\n", 820 | 
"Validation set perplexity: 4.54\n", 821 | "Average loss at step 4400 : 1.61255454302 learning rate: 10.0\n", 822 | "Minibatch perplexity: 4.67\n", 823 | "Validation set perplexity: 4.54\n", 824 | "Average loss at step 4500 : 1.61543365479 learning rate: 10.0\n", 825 | "Minibatch perplexity: 4.83\n", 826 | "Validation set perplexity: 4.69\n", 827 | "Average loss at step 4600 : 1.61607327104 learning rate: 10.0\n", 828 | "Minibatch perplexity: 5.18\n", 829 | "Validation set perplexity: 4.64\n", 830 | "Average loss at step 4700 : 1.62757282495 learning rate: 10.0\n", 831 | "Minibatch perplexity: 4.24\n", 832 | "Validation set perplexity: 4.66\n", 833 | "Average loss at step 4800 : 1.63222063541 learning rate: 10.0\n", 834 | "Minibatch perplexity: 5.30\n", 835 | "Validation set perplexity: 4.53\n", 836 | "Average loss at step 4900 : 1.63678096652 learning rate: 10.0\n", 837 | "Minibatch perplexity: 5.43\n", 838 | "Validation set perplexity: 4.64\n", 839 | "Average loss at step 5000 : 1.610340662 learning rate: 1.0\n", 840 | "Minibatch perplexity: 5.10\n", 841 | "================================================================================\n", 842 | "in b one onarbs revieds the kimiluge that fondhtic fnoto cre one nine zero zero \n", 843 | " of is it of marking panzia t had wap ironicaghni relly deah the omber b h menba\n", 844 | "ong messified it his the likdings ara subpore the a fames distaled self this int\n", 845 | "y advante authors the end languarle meit common tacing bevolitione and eight one\n", 846 | "zes that materly difild inllaring the fusts not panition assertian causecist bas\n", 847 | "================================================================================\n", 848 | "Validation set perplexity: 4.69\n", 849 | "Average loss at step 5100 : 1.60593637228 learning rate: 1.0\n", 850 | "Minibatch perplexity: 4.69\n", 851 | "Validation set perplexity: 4.47\n", 852 | "Average loss at step 5200 : 1.58993269444 learning rate: 1.0\n", 853 | "Minibatch perplexity: 4.65\n", 854 | "Validation set perplexity: 4.39\n", 855 | "Average loss at step 5300 : 1.57930587292 learning rate: 1.0\n", 856 | "Minibatch perplexity: 5.11\n", 857 | "Validation set perplexity: 4.39\n", 858 | "Average loss at step 5400 : 1.58022856832 learning rate: 1.0\n", 859 | "Minibatch perplexity: 5.19\n", 860 | "Validation set perplexity: 4.37\n", 861 | "Average loss at step 5500 : 1.56654450059 learning rate: 1.0\n", 862 | "Minibatch perplexity: 4.69\n", 863 | "Validation set perplexity: 4.33\n", 864 | "Average loss at step 5600 : 1.58013380885 learning rate: 1.0\n", 865 | "Minibatch perplexity: 5.13\n", 866 | "Validation set perplexity: 4.35\n", 867 | "Average loss at step 5700 : 1.56974959254 learning rate: 1.0\n", 868 | "Minibatch perplexity: 5.00\n", 869 | "Validation set perplexity: 4.34\n", 870 | "Average loss at step 5800 : 1.5839582932 learning rate: 1.0\n", 871 | "Minibatch perplexity: 4.88\n", 872 | "Validation set perplexity: 4.31\n", 873 | "Average loss at step 5900 : 1.57129439116 learning rate: 1.0\n", 874 | "Minibatch perplexity: 4.66\n", 875 | "Validation set perplexity: 4.32\n", 876 | "Average loss at step 6000 : 1.55144061089 learning rate: 1.0\n", 877 | "Minibatch perplexity: 4.55\n", 878 | "================================================================================\n", 879 | "utic clositical poopy stribe addi nixe one nine one zero zero eight zero b ha ex\n", 880 | "zerns b one internequiption of the secordy way anti proble akoping have fictiona\n", 881 | "phare united from has 
poporarly cities book ins sweden emperor a sass in origina\n", 882 | "quulk destrebinist and zeilazar and on low and by in science over country weilti\n", 883 | "x are holivia work missincis ons in the gages to starsle histon one icelanctrotu\n", 884 | "================================================================================\n", 885 | "Validation set perplexity: 4.30\n", 886 | "Average loss at step 6100 : 1.56450940847 learning rate: 1.0\n", 887 | "Minibatch perplexity: 4.77\n", 888 | "Validation set perplexity: 4.27\n", 889 | "Average loss at step 6200 : 1.53433164835 learning rate: 1.0\n", 890 | "Minibatch perplexity: 4.77\n", 891 | "Validation set perplexity: 4.27\n", 892 | "Average loss at step 6300 : 1.54773445129 learning rate: 1.0\n", 893 | "Minibatch perplexity: 4.76\n", 894 | "Validation set perplexity: 4.25\n", 895 | "Average loss at step 6400 : 1.54021131516 learning rate: 1.0\n", 896 | "Minibatch perplexity: 4.56\n", 897 | "Validation set perplexity: 4.24\n", 898 | "Average loss at step 6500 : 1.56153374553 learning rate: 1.0\n", 899 | "Minibatch perplexity: 5.43\n", 900 | "Validation set perplexity: 4.27\n", 901 | "Average loss at step 6600 : 1.59556478739 learning rate: 1.0\n", 902 | "Minibatch perplexity: 4.92\n", 903 | "Validation set perplexity: 4.28\n", 904 | "Average loss at step 6700 : 1.58076951623 learning rate: 1.0\n", 905 | "Minibatch perplexity: 4.77\n", 906 | "Validation set perplexity: 4.30\n", 907 | "Average loss at step 6800 : 1.6070714438 learning rate: 1.0\n", 908 | "Minibatch perplexity: 4.98\n", 909 | "Validation set perplexity: 4.28\n", 910 | "Average loss at step 6900 : 1.58413293839 learning rate: 1.0\n", 911 | "Minibatch perplexity: 4.61\n", 912 | "Validation set perplexity: 4.29\n", 913 | "Average loss at step 7000 : 1.57905534983 learning rate: 1.0\n", 914 | "Minibatch perplexity: 5.08\n", 915 | "================================================================================\n", 916 | "jague are officiencinels ored by film voon higherise haik one nine on the iffirc\n", 917 | "oshe provision that manned treatists on smalle bodariturmeristing the girto in s\n", 918 | "kis would softwenn mustapultmine truativersakys bersyim by s of confound esc bub\n", 919 | "ry of the using one four six blain ira mannom marencies g with fextificallise re\n", 920 | " one son vit even an conderouss to person romer i a lebapter at obiding are iuse\n", 921 | "================================================================================\n", 922 | "Validation set perplexity: 4.25\n" 923 | ] 924 | } 925 | ], 926 | "source": [ 927 | "num_steps = 7001\n", 928 | "summary_frequency = 100\n", 929 | "\n", 930 | "with tf.Session(graph=graph) as session:\n", 931 | " tf.global_variables_initializer().run()\n", 932 | " print 'Initialized'\n", 933 | " mean_loss = 0\n", 934 | " for step in xrange(num_steps):\n", 935 | " batches = train_batches.next()\n", 936 | " feed_dict = dict()\n", 937 | " for i in xrange(num_unrollings + 1):\n", 938 | " feed_dict[train_data[i]] = batches[i]\n", 939 | " _, l, predictions, lr = session.run(\n", 940 | " [optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict)\n", 941 | " mean_loss += l\n", 942 | " if step % summary_frequency == 0:\n", 943 | " if step > 0:\n", 944 | " mean_loss = mean_loss / summary_frequency\n", 945 | " # The mean loss is an estimate of the loss over the last few batches.\n", 946 | " print 'Average loss at step', step, ':', mean_loss, 'learning rate:', lr\n", 947 | " mean_loss = 0\n", 948 | " labels = 
np.concatenate(list(batches)[1:])\n", 949 | " print 'Minibatch perplexity: %.2f' % float(\n", 950 | " np.exp(logprob(predictions, labels)))\n", 951 | " if step % (summary_frequency * 10) == 0:\n", 952 | " # Generate some samples.\n", 953 | " print '=' * 80\n", 954 | " for _ in xrange(5):\n", 955 | " feed = sample(random_distribution())\n", 956 | " sentence = characters(feed)[0]\n", 957 | " reset_sample_state.run()\n", 958 | " for _ in xrange(79):\n", 959 | " prediction = sample_prediction.eval({sample_input: feed})\n", 960 | " feed = sample(prediction)\n", 961 | " sentence += characters(feed)[0]\n", 962 | " print sentence\n", 963 | " print '=' * 80\n", 964 | " # Measure validation set perplexity.\n", 965 | " reset_sample_state.run()\n", 966 | " valid_logprob = 0\n", 967 | " for _ in xrange(valid_size):\n", 968 | " b = valid_batches.next()\n", 969 | " predictions = sample_prediction.eval({sample_input: b[0]})\n", 970 | " valid_logprob = valid_logprob + logprob(predictions, b[1])\n", 971 | " print 'Validation set perplexity: %.2f' % float(np.exp(\n", 972 | " valid_logprob / valid_size))" 973 | ] 974 | }, 975 | { 976 | "cell_type": "markdown", 977 | "metadata": { 978 | "colab_type": "text", 979 | "id": "pl4vtmFfa5nn" 980 | }, 981 | "source": [ 982 | "---\n", 983 | "Problem 1\n", 984 | "---------\n", 985 | "\n", 986 | "You might have noticed that the definition of the LSTM cell involves 4 matrix multiplications with the input, and 4 matrix multiplications with the output. Simplify the expression by using a single matrix multiply for each, and variables that are 4 times larger.\n", 987 | "\n", 988 | "---" 989 | ] 990 | }, 991 | { 992 | "cell_type": "markdown", 993 | "metadata": { 994 | "colab_type": "text", 995 | "id": "4eErTCTybtph" 996 | }, 997 | "source": [ 998 | "---\n", 999 | "Problem 2\n", 1000 | "---------\n", 1001 | "\n", 1002 | "We want to train an LSTM over bigrams, that is, pairs of consecutive characters like 'ab' instead of single characters like 'a'. Since the number of possible bigrams is large, feeding them directly to the LSTM using 1-hot encodings will lead to a very sparse representation that is very wasteful computationally.\n", 1003 | "\n", 1004 | "a- Introduce an embedding lookup on the inputs, and feed the embeddings to the LSTM cell instead of the inputs themselves.\n", 1005 | "\n", 1006 | "b- Write a bigram-based LSTM, modeled on the character LSTM above.\n", 1007 | "\n", 1008 | "c- Introduce Dropout. For best practices on how to use Dropout in LSTMs, refer to this [article](http://arxiv.org/abs/1409.2329).\n", 1009 | "\n", 1010 | "---" 1011 | ] 1012 | }, 1013 | { 1014 | "cell_type": "markdown", 1015 | "metadata": { 1016 | "colab_type": "text", 1017 | "id": "Y5tapX3kpcqZ" 1018 | }, 1019 | "source": [ 1020 | "---\n", 1021 | "Problem 3\n", 1022 | "---------\n", 1023 | "\n", 1024 | "(difficult!)\n", 1025 | "\n", 1026 | "Write a sequence-to-sequence LSTM which mirrors all the words in a sentence. 
For example, if your input is:\n", 1027 | "\n", 1028 | " the quick brown fox\n", 1029 | " \n", 1030 | "the model should attempt to output:\n", 1031 | "\n", 1032 | " eht kciuq nworb xof\n", 1033 | " \n", 1034 | "Reference: http://arxiv.org/abs/1409.3215\n", 1035 | "\n", 1036 | "---" 1037 | ] 1038 | } 1039 | ], 1040 | "metadata": { 1041 | "colabVersion": "0.3.2", 1042 | "colab_default_view": {}, 1043 | "colab_views": {}, 1044 | "kernelspec": { 1045 | "display_name": "Python 2", 1046 | "language": "python", 1047 | "name": "python2" 1048 | }, 1049 | "language_info": { 1050 | "codemirror_mode": { 1051 | "name": "ipython", 1052 | "version": 2 1053 | }, 1054 | "file_extension": ".py", 1055 | "mimetype": "text/x-python", 1056 | "name": "python", 1057 | "nbconvert_exporter": "python", 1058 | "pygments_lexer": "ipython2", 1059 | "version": "2.7.12" 1060 | } 1061 | }, 1062 | "nbformat": 4, 1063 | "nbformat_minor": 0 1064 | } 1065 | -------------------------------------------------------------------------------- /test_files/notebooks/clipboards.py: -------------------------------------------------------------------------------- 1 | """ io on the clipboard """ 2 | from io import StringIO 3 | import warnings 4 | 5 | from pandas.core.dtypes.generic import ABCDataFrame 6 | 7 | from pandas import get_option, option_context 8 | 9 | 10 | def read_clipboard(sep=r"\s+", **kwargs): # pragma: no cover 11 | 12 | encoding = kwargs.pop("encoding", "utf-8") 13 | 14 | # only utf-8 is valid for passed value because that's what clipboard 15 | # supports 16 | if encoding is not None and encoding.lower().replace("-", "") != "utf8": 17 | raise NotImplementedError("reading from clipboard only supports utf-8 encoding") 18 | 19 | from pandas.io.clipboard import clipboard_get 20 | from pandas.io.parsers import read_csv 21 | 22 | text = clipboard_get() 23 | 24 | # Try to decode (if needed, as "text" might already be a string here). 25 | try: 26 | text = text.decode(kwargs.get("encoding") or get_option("display.encoding")) 27 | except AttributeError: 28 | pass 29 | 30 | # Excel copies into clipboard with \t separation 31 | # inspect no more than the first 10 lines; if they 32 | # all contain an equal number (>0) of tabs, infer 33 | # that this came from excel and set 'sep' accordingly 34 | lines = text[:10000].split("\n")[:-1][:10] 35 | 36 | # Need to remove leading white space, since read_csv 37 | # accepts: 38 | # a b 39 | # 0 1 2 40 | # 1 3 4 41 | 42 | counts = {x.lstrip().count("\t") for x in lines} 43 | if len(lines) > 1 and len(counts) == 1 and counts.pop() != 0: 44 | sep = "\t" 45 | 46 | # Edge case where sep is specified to be None, return to default 47 | if sep is None and kwargs.get("delim_whitespace") is None: 48 | sep = r"\s+" 49 | 50 | # Regex separator currently only works with python engine.
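# (separators longer than one character are treated as regular expressions)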
51 | # Default to python if separator is multi-character (regex) 52 | if len(sep) > 1 and kwargs.get("engine") is None: 53 | kwargs["engine"] = "python" 54 | elif len(sep) > 1 and kwargs.get("engine") == "c": 55 | warnings.warn( 56 | "read_clipboard with regex separator does not work properly with c engine" 57 | ) 58 | 59 | return read_csv(StringIO(text), sep=sep, **kwargs) 60 | 61 | 62 | def to_clipboard(obj, excel=True, sep=None, **kwargs): # pragma: no cover 63 | 64 | encoding = kwargs.pop("encoding", "utf-8") 65 | 66 | # testing if an invalid encoding is passed to clipboard 67 | if encoding is not None and encoding.lower().replace("-", "") != "utf8": 68 | raise ValueError("clipboard only supports utf-8 encoding") 69 | 70 | from pandas.io.clipboard import clipboard_set 71 | 72 | if excel is None: 73 | excel = True 74 | 75 | if excel: 76 | try: 77 | if sep is None: 78 | sep = "\t" 79 | buf = StringIO() 80 | 81 | # clipboard_set (pyperclip) expects unicode 82 | obj.to_csv(buf, sep=sep, encoding="utf-8", **kwargs) 83 | text = buf.getvalue() 84 | 85 | clipboard_set(text) 86 | return 87 | except TypeError: 88 | warnings.warn( 89 | "to_clipboard in excel mode requires a single character separator." 90 | ) 91 | elif sep is not None: 92 | warnings.warn("to_clipboard with excel=False ignores the sep argument") 93 | 94 | if isinstance(obj, ABCDataFrame): 95 | # str(df) has various unhelpful defaults, like truncation 96 | with option_context("display.max_colwidth", None): 97 | objstr = obj.to_string(**kwargs) 98 | else: 99 | objstr = str(obj) 100 | clipboard_set(objstr) 101 | -------------------------------------------------------------------------------- /test_files/random.txt: -------------------------------------------------------------------------------- 1 | Just a random 2 | 3 | Test file 4 | 5 | To Test -------------------------------------------------------------------------------- /test_files/simple_funcs/simple_funcs.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, TextIO 3 | 4 | a_string = """ 5 | 6 | This is a multi line string 7 | 8 | """ 9 | 10 | VAR = 50 # A Variable 11 | 12 | ##################### 13 | 14 | #### COMMENT ####### 15 | 16 | #################### 17 | 18 | # def add(a, b: int=10): 19 | def check_if_odd(num): 20 | """ 21 | Checks if number is odd 22 | 23 | Args: 24 | num : 25 | 26 | 27 | (Generated by docly) 28 | """ 29 | return True if num % 2 != 0 else False 30 | # return a * b 31 | 32 | 33 | def check_even_numbers_in_a_list (base_list) -> list: 34 | """ 35 | Checks that all numbers in a list are equal. 36 | 37 | Args: 38 | base_list : 39 | 40 | 41 | (Generated by docly) 42 | """ 43 | return [a for a in base_list if a % 2 == 0] 44 | 45 | 46 | def open_file(file_path) -> TextIO: 47 | """ 48 | open a file 49 | 50 | Args: 51 | file_path : 52 | 53 | 54 | (Generated by docly) 55 | """ 56 | return open(file_path, "r") 57 | 58 | 59 | def add_tensors (t, t1) -> Any: 60 | """ 61 | Add two tensors. 62 | 63 | Args: 64 | t : 65 | t1 : 66 | 67 | 68 | (Generated by docly) 69 | """ 70 | return t + t1 71 | 72 | 73 | def print_hello_greetings() -> None: 74 | """ 75 | Emulate greeting. 76 | 77 | (Generated by docly) 78 | """ 79 | print("Hello") 80 | 81 | 82 | def echo_name(name) -> str: 83 | """ 84 | Echo the given name. 
85 | 86 | Args: 87 | name : 88 | 89 | 90 | (Generated by docly) 91 | """ 92 | return f"Hello {name}" 93 | 94 | 95 | if __name__ == "__main__": 96 | print(add_tensors(12, 12)) 97 | 98 | 99 | def area_rectangle(length, width) -> Any: 100 | """ 101 | Return the area of a rectangle. 102 | 103 | Args: 104 | length : 105 | width : 106 | 107 | 108 | (Generated by docly) 109 | """ 110 | if length < 0 or width < 0: 111 | raise ValueError("only accepts non-negative values") 112 | return length * width 113 | -------------------------------------------------------------------------------- /test_files/test_config_file.ini: -------------------------------------------------------------------------------- 1 | [skipDirs] 2 | test_files/flask_files = * --------------------------------------------------------------------------------