├── .circleci └── config.yml ├── .gitignore ├── .ipython-startup.ipy ├── .pylintrc ├── Pipfile ├── Pipfile.lock ├── README.md ├── bin └── suplearn-clone ├── config.yml.example ├── ipython_config.py ├── requirements.txt ├── samples └── encoder-nn.py ├── scripts ├── create_submissions.py └── create_test_data.py ├── setup.py ├── sql └── create_tables.sql ├── suplearn_clone_detection ├── __init__.py ├── ast.py ├── ast_loader.py ├── ast_transformer.py ├── callbacks.py ├── cli.py ├── commands.py ├── config.py ├── database.py ├── dataset │ ├── __init__.py │ ├── generator.py │ ├── sequences.py │ └── util.py ├── detector.py ├── entities.py ├── evaluator.py ├── file_processor.py ├── layers.py ├── model.py ├── predictor.py ├── results_printer.py ├── settings.py ├── token_based │ ├── __init__.py │ ├── commands.py │ ├── skipgram_generator.py │ ├── util.py │ ├── vocab_item.py │ └── vocabulary_generator.py ├── trainer.py ├── util.py ├── vectorizer.py └── vocabulary.py └── tests ├── __init__.py ├── ast_loader_test.py ├── ast_test.py ├── ast_transformer_test.py ├── base.py ├── config_test.py ├── fixtures ├── asts.json ├── asts.txt ├── config.yml ├── submissions.json ├── vocab-100.tsv └── vocab-noid.tsv ├── util_test.py └── vocabulary_test.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build: 4 | docker: 5 | - image: circleci/python:3.6 6 | environment: 7 | PIPENV_VENV_IN_PROJECT: "1" 8 | working_directory: ~/suplearn-clone-detection 9 | 10 | steps: 11 | - checkout 12 | - restore_cache: 13 | keys: 14 | - v1-dependencies-{{ checksum "Pipfile.lock" }} 15 | - run: 16 | name: install dependencies 17 | command: | 18 | pipenv install 19 | pipenv run pip install tensorflow 20 | - save_cache: 21 | paths: 22 | - ./.venv 23 | key: v1-dependencies-{{ checksum "Pipfile.lock" }} 24 | - run: 25 | name: run tests 26 | command: pipenv run python setup.py test 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # End of https://www.gitignore.io/api/python 108 | 109 | /config.yml 110 | 111 | tmp/ 112 | 113 | /.vscode/ 114 | .noseids 115 | .merlin 116 | -------------------------------------------------------------------------------- /.ipython-startup.ipy: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from keras.models import load_model 4 | 5 | from sqlalchemy import create_engine 6 | 7 | from suplearn_clone_detection import ast_transformer, layers, entities 8 | from suplearn_clone_detection.database import Session 9 | from suplearn_clone_detection.config import Config 10 | from suplearn_clone_detection.ast_loader import ASTLoader 11 | 12 | config = Config.from_file("./config.yml") 13 | 14 | if config.generator.db_path: 15 | engine = create_engine(config.generator.db_path, echo=True) 16 | Session.configure(bind=engine) 17 | 18 | ast_transformers = ast_transformer.create_all(config.model.languages) 19 | 20 | sess = tf.InteractiveSession() 21 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # A comma-separated list of package or module names from where C extensions may 4 | # be loaded. Extensions are loading into the active Python interpreter and may 5 | # run arbitrary code 6 | extension-pkg-whitelist= 7 | 8 | # Add files or directories to the blacklist. They should be base names, not 9 | # paths. 10 | ignore=CVS 11 | 12 | # Add files or directories matching the regex patterns to the blacklist. The 13 | # regex matches against base names, not paths. 14 | ignore-patterns= 15 | 16 | # Python code to execute, usually for sys.path manipulation such as 17 | # pygtk.require(). 18 | #init-hook= 19 | 20 | # Use multiple processes to speed up Pylint. 21 | jobs=1 22 | 23 | # List of plugins (as comma separated values of python modules names) to load, 24 | # usually to register additional checkers. 25 | load-plugins= 26 | 27 | # Pickle collected data for later comparisons. 28 | persistent=yes 29 | 30 | # Specify a configuration file. 31 | #rcfile= 32 | 33 | # Allow loading of arbitrary C extensions. Extensions are imported into the 34 | # active Python interpreter and may run arbitrary code. 
35 | unsafe-load-any-extension=no 36 | 37 | 38 | [MESSAGES CONTROL] 39 | 40 | # Only show warnings with the listed confidence levels. Leave empty to show 41 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 42 | confidence= 43 | 44 | # Disable the message, report, category or checker with the given id(s). You 45 | # can either give multiple identifiers separated by comma (,) or put this 46 | # option multiple times (only on the command line, not in the configuration 47 | # file where it should appear only once).You can also use "--disable=all" to 48 | # disable everything first and then reenable specific checks. For example, if 49 | # you want to run only the similarities checker, you can use "--disable=all 50 | # --enable=similarities". If you want to run only the classes checker, but have 51 | # no Warning level messages displayed, use"--disable=all --enable=classes 52 | # --disable=W" 53 | disable=print-statement,parameter-unpacking,unpacking-in-except,old-raise-syntax,backtick,long-suffix,old-ne-operator,old-octal-literal,import-star-module-level,raw-checker-failed,bad-inline-option,locally-disabled,locally-enabled,file-ignored,suppressed-message,useless-suppression,deprecated-pragma,apply-builtin,basestring-builtin,buffer-builtin,cmp-builtin,coerce-builtin,execfile-builtin,file-builtin,long-builtin,raw_input-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,no-absolute-import,old-division,dict-iter-method,dict-view-method,next-method-called,metaclass-assignment,indexing-exception,raising-string,reload-builtin,oct-method,hex-method,nonzero-method,cmp-method,input-builtin,round-builtin,intern-builtin,unichr-builtin,map-builtin-not-iterating,zip-builtin-not-iterating,range-builtin-not-iterating,filter-builtin-not-iterating,using-cmp-argument,eq-without-hash,div-method,idiv-method,rdiv-method,exception-message-attribute,invalid-str-codec,sys-max-int,bad-python3-import,deprecated-string-function,deprecated-str-translate-call,missing-docstring,no-else-return,invalid-name,too-few-public-methods 54 | 55 | # Enable the message, report, category or checker with the given id(s). You can 56 | # either give multiple identifier separated by comma (,) or put this option 57 | # multiple time (only on the command line, not in the configuration file where 58 | # it should appear only once). See also the "--disable" option for examples. 59 | enable= 60 | 61 | 62 | [REPORTS] 63 | 64 | # Python expression which should return a note less than 10 (10 is the highest 65 | # note). You have access to the variables errors warning, statement which 66 | # respectively contain the number of errors / warnings messages and the total 67 | # number of statements analyzed. This is used by the global evaluation report 68 | # (RP0004). 69 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 70 | 71 | # Template used to display messages. This is a python new-style format string 72 | # used to format the message information. See doc for all details 73 | #msg-template= 74 | 75 | # Set the output format. Available formats are text, parseable, colorized, json 76 | # and msvs (visual studio).You can also give a reporter class, eg 77 | # mypackage.mymodule.MyReporterClass. 78 | output-format=text 79 | 80 | # Tells whether to display a full report or only the messages 81 | reports=no 82 | 83 | # Activate the evaluation score. 
84 | score=yes 85 | 86 | 87 | [REFACTORING] 88 | 89 | # Maximum number of nested blocks for function / method body 90 | max-nested-blocks=5 91 | 92 | 93 | [TYPECHECK] 94 | 95 | # List of decorators that produce context managers, such as 96 | # contextlib.contextmanager. Add to this list to register other decorators that 97 | # produce valid context managers. 98 | contextmanager-decorators=contextlib.contextmanager 99 | 100 | # List of members which are set dynamically and missed by pylint inference 101 | # system, and so shouldn't trigger E1101 when accessed. Python regular 102 | # expressions are accepted. 103 | generated-members= 104 | 105 | # Tells whether missing members accessed in mixin class should be ignored. A 106 | # mixin class is detected if its name ends with "mixin" (case insensitive). 107 | ignore-mixin-members=yes 108 | 109 | # This flag controls whether pylint should warn about no-member and similar 110 | # checks whenever an opaque object is returned when inferring. The inference 111 | # can return multiple potential results while evaluating a Python object, but 112 | # some branches might not be evaluated, which results in partial inference. In 113 | # that case, it might be useful to still emit no-member and other checks for 114 | # the rest of the inferred objects. 115 | ignore-on-opaque-inference=yes 116 | 117 | # List of class names for which member attributes should not be checked (useful 118 | # for classes with dynamically set attributes). This supports the use of 119 | # qualified names. 120 | ignored-classes=optparse.Values,thread._local,_thread._local,scoped_session 121 | 122 | # List of module names for which member attributes should not be checked 123 | # (useful for modules/projects where namespaces are manipulated during runtime 124 | # and thus existing member attributes cannot be deduced by static analysis. It 125 | # supports qualified module names, as well as Unix pattern matching. 126 | ignored-modules= 127 | 128 | # Show a hint with possible names when a member name was not found. The aspect 129 | # of finding the hint is based on edit distance. 130 | missing-member-hint=yes 131 | 132 | # The minimum edit distance a name should have in order to be considered a 133 | # similar match for a missing member name. 134 | missing-member-hint-distance=1 135 | 136 | # The total number of similar names that should be taken in consideration when 137 | # showing a hint for a missing member. 138 | missing-member-max-choices=1 139 | 140 | 141 | [SIMILARITIES] 142 | 143 | # Ignore comments when computing similarities. 144 | ignore-comments=yes 145 | 146 | # Ignore docstrings when computing similarities. 147 | ignore-docstrings=yes 148 | 149 | # Ignore imports when computing similarities. 150 | ignore-imports=no 151 | 152 | # Minimum lines number of a similarity. 153 | min-similarity-lines=4 154 | 155 | 156 | [FORMAT] 157 | 158 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 159 | expected-line-ending-format= 160 | 161 | # Regexp for a line that is allowed to be longer than the limit. 162 | ignore-long-lines=^\s*(# )??$ 163 | 164 | # Number of spaces of indent required inside a hanging or continued line. 165 | indent-after-paren=4 166 | 167 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 168 | # tab). 169 | indent-string=' ' 170 | 171 | # Maximum number of characters on a single line. 
172 | max-line-length=100 173 | 174 | # Maximum number of lines in a module 175 | max-module-lines=1000 176 | 177 | # List of optional constructs for which whitespace checking is disabled. `dict- 178 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 179 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 180 | # `empty-line` allows space-only lines. 181 | no-space-check=trailing-comma,dict-separator 182 | 183 | # Allow the body of a class to be on the same line as the declaration if body 184 | # contains single statement. 185 | single-line-class-stmt=no 186 | 187 | # Allow the body of an if to be on the same line as the test if there is no 188 | # else. 189 | single-line-if-stmt=no 190 | 191 | 192 | [LOGGING] 193 | 194 | # Logging modules to check that the string format arguments are in logging 195 | # function parameter format 196 | logging-modules=logging 197 | 198 | 199 | [VARIABLES] 200 | 201 | # List of additional names supposed to be defined in builtins. Remember that 202 | # you should avoid to define new builtins when possible. 203 | additional-builtins= 204 | 205 | # Tells whether unused global variables should be treated as a violation. 206 | allow-global-unused-variables=yes 207 | 208 | # List of strings which can identify a callback function by name. A callback 209 | # name must start or end with one of those strings. 210 | callbacks=cb_,_cb 211 | 212 | # A regular expression matching the name of dummy variables (i.e. expectedly 213 | # not used). 214 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ 215 | 216 | # Argument names that match this expression will be ignored. Default to name 217 | # with leading underscore 218 | ignored-argument-names=_.*|^ignored_|^unused_ 219 | 220 | # Tells whether we should check for unused import in __init__ files. 221 | init-import=no 222 | 223 | # List of qualified module names which can have objects that can redefine 224 | # builtins. 225 | redefining-builtins-modules=six.moves,future.builtins 226 | 227 | 228 | [BASIC] 229 | 230 | # Naming hint for argument names 231 | argument-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 232 | 233 | # Regular expression matching correct argument names 234 | argument-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 235 | 236 | # Naming hint for attribute names 237 | attr-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 238 | 239 | # Regular expression matching correct attribute names 240 | attr-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 241 | 242 | # Bad variable names which should always be refused, separated by a comma 243 | bad-names=foo,bar,baz,toto,tutu,tata 244 | 245 | # Naming hint for class attribute names 246 | class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ 247 | 248 | # Regular expression matching correct class attribute names 249 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ 250 | 251 | # Naming hint for class names 252 | class-name-hint=[A-Z_][a-zA-Z0-9]+$ 253 | 254 | # Regular expression matching correct class names 255 | class-rgx=[A-Z_][a-zA-Z0-9]+$ 256 | 257 | # Naming hint for constant names 258 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 259 | 260 | # Regular expression matching correct constant names 261 | const-rgx=(([a-zA-Z_][a-zA-Z0-9_]*)|(__.*__))$ 262 | 263 | # Minimum line length for functions/classes that require docstrings, shorter 264 | # ones are exempt. 
265 | docstring-min-length=-1 266 | 267 | # Naming hint for function names 268 | function-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 269 | 270 | # Regular expression matching correct function names 271 | function-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 272 | 273 | # Good variable names which should always be accepted, separated by a comma 274 | good-names=i,j,k,ex,Run,_,f,e,x 275 | 276 | # Include a hint for the correct naming format with invalid-name 277 | include-naming-hint=no 278 | 279 | # Naming hint for inline iteration names 280 | inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ 281 | 282 | # Regular expression matching correct inline iteration names 283 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ 284 | 285 | # Naming hint for method names 286 | method-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 287 | 288 | # Regular expression matching correct method names 289 | method-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 290 | 291 | # Naming hint for module names 292 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 293 | 294 | # Regular expression matching correct module names 295 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 296 | 297 | # Colon-delimited sets of names that determine each other's naming style when 298 | # the name regexes allow several styles. 299 | name-group= 300 | 301 | # Regular expression which should only match function or class names that do 302 | # not require a docstring. 303 | no-docstring-rgx=^_ 304 | 305 | # List of decorators that produce properties, such as abc.abstractproperty. Add 306 | # to this list to register other decorators that produce valid properties. 307 | property-classes=abc.abstractproperty 308 | 309 | # Naming hint for variable names 310 | variable-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 311 | 312 | # Regular expression matching correct variable names 313 | variable-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 314 | 315 | 316 | [MISCELLANEOUS] 317 | 318 | # List of note tags to take in consideration, separated by a comma. 319 | notes=FIXME,XXX,TODO 320 | 321 | 322 | [SPELLING] 323 | 324 | # Spelling dictionary name. Available dictionaries: none. To make it working 325 | # install python-enchant package. 326 | spelling-dict= 327 | 328 | # List of comma separated words that should not be checked. 329 | spelling-ignore-words= 330 | 331 | # A path to a file that contains private dictionary; one word per line. 332 | spelling-private-dict-file= 333 | 334 | # Tells whether to store unknown words to indicated private dictionary in 335 | # --spelling-private-dict-file option instead of raising a message. 336 | spelling-store-unknown-words=no 337 | 338 | 339 | [IMPORTS] 340 | 341 | # Allow wildcard imports from modules that define __all__. 342 | allow-wildcard-with-all=no 343 | 344 | # Analyse import fallback blocks. This can be used to support both Python 2 and 345 | # 3 compatible code, which means that the block might have code that exists 346 | # only in one or another interpreter, leading to false positives when analysed. 347 | analyse-fallback-blocks=no 348 | 349 | # Deprecated modules which should not be used, separated by a comma 350 | deprecated-modules=optparse,tkinter.tix 351 | 352 | # Create a graph of external dependencies in the given file (report RP0402 must 353 | # not be disabled) 354 | ext-import-graph= 355 | 356 | # Create a graph of every (i.e. 
internal and external) dependencies in the 357 | # given file (report RP0402 must not be disabled) 358 | import-graph= 359 | 360 | # Create a graph of internal dependencies in the given file (report RP0402 must 361 | # not be disabled) 362 | int-import-graph= 363 | 364 | # Force import order to recognize a module as part of the standard 365 | # compatibility libraries. 366 | known-standard-library= 367 | 368 | # Force import order to recognize a module as part of a third party library. 369 | known-third-party=enchant 370 | 371 | 372 | [DESIGN] 373 | 374 | # Maximum number of arguments for function / method 375 | max-args=5 376 | 377 | # Maximum number of attributes for a class (see R0902). 378 | max-attributes=7 379 | 380 | # Maximum number of boolean expressions in a if statement 381 | max-bool-expr=5 382 | 383 | # Maximum number of branch for function / method body 384 | max-branches=12 385 | 386 | # Maximum number of locals for function / method body 387 | max-locals=15 388 | 389 | # Maximum number of parents for a class (see R0901). 390 | max-parents=7 391 | 392 | # Maximum number of public methods for a class (see R0904). 393 | max-public-methods=20 394 | 395 | # Maximum number of return / yield for function / method body 396 | max-returns=6 397 | 398 | # Maximum number of statements in function / method body 399 | max-statements=50 400 | 401 | # Minimum number of public methods for a class (see R0903). 402 | min-public-methods=2 403 | 404 | 405 | [CLASSES] 406 | 407 | # List of method names used to declare (i.e. assign) instance attributes. 408 | defining-attr-methods=__init__,__new__,setUp 409 | 410 | # List of member names, which should be excluded from the protected access 411 | # warning. 412 | exclude-protected=_asdict,_fields,_replace,_source,_make 413 | 414 | # List of valid names for the first argument in a class method. 415 | valid-classmethod-first-arg=cls 416 | 417 | # List of valid names for the first argument in a metaclass class method. 418 | valid-metaclass-classmethod-first-arg=mcs 419 | 420 | 421 | [EXCEPTIONS] 422 | 423 | # Exceptions that will emit a warning when being caught. Defaults to 424 | # "Exception" 425 | overgeneral-exceptions=Exception 426 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.python.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [packages] 7 | "e1839a8" = {path = ".", editable = true} 8 | 9 | [dev-packages] 10 | ipython = "*" 11 | pylint = "*" 12 | nose = "*" 13 | rope = "*" 14 | 15 | [requires] 16 | python_version = "3.6" 17 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "2033709b5464cdd9c03a5861936b0bd8ab229d0afe442617e5b8c62427e62a2d" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3.6" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.python.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "cycler": { 20 | "hashes": [ 21 | "sha256:1d8a5ae1ff6c5cf9b93e8811e581232ad8920aeec647c37316ceac982b08cb2d", 22 | "sha256:cd7b2d1018258d7247a71425e9f26463dfb444d411c39569972f4ce586b0c9d8" 23 | ], 24 | "version": "==0.10.0" 25 | }, 26 | "e1839a8": { 27 | "editable": true, 28 | "path": "." 
29 | }, 30 | "h5py": { 31 | "hashes": [ 32 | "sha256:07ddea6bb649a257fc57ccae359a36d691b2ef8b9617971ae7d6f74ef6f67cad", 33 | "sha256:180a688311e826ff6ae6d3bda9b5c292b90b28787525ddfcb10a29d5ddcae2cc", 34 | "sha256:1be9cd57e74b24f836d0d2c34ae376ff2df704f40aa8815aa9113b5a860d467f", 35 | "sha256:1fad9aa32835230de77b31edd6980b7c202de7bb7d8384d1bcb47b5dd32c8c7c", 36 | "sha256:2258fca3533a3276fd86e9196326786f408a95748ac707c010fff265edf60342", 37 | "sha256:2b91c9117f2e7a2ef924bec41ac77e57567bec6731773373bf78eb4387b39a2a", 38 | "sha256:2ccb4f405059314829ebad1859d2c68e133a9d13ca7c3cc7a298a76a438fd09c", 39 | "sha256:2d137a1b2f529e58886b5865f6dec51cd96ea0671dd84cebc6dba5cd8c7d0a75", 40 | "sha256:40dd37cbf24ca3b935a8d6eb8960ec5d0381219f82317bdc40aa9e08b3fcc143", 41 | "sha256:478efa37b84a56061af5fcd286678331e873e216f6c5987cd31f9666edc2f157", 42 | "sha256:4a6e6cd8668fa453864f4f9e243460dcc2d41e79d14516b84f4ba74ebcc5b222", 43 | "sha256:52204972a02032d6a427addd37a24a22a2b97d4bce0850c84a6995db9c91926c", 44 | "sha256:537a60879485e5ce484ab4350c7bd8b3da4b531f9f82ef0a18780beabde98c90", 45 | "sha256:562045c57a2e47aca9c716ac8cd64448d4897c0f5fe456ab5a34b17c8b3907cb", 46 | "sha256:66609c48f8841357ced4291b7c9009518bb6e6fec449d91eb46aa417b6f5f4cf", 47 | "sha256:9d9fb861e10735c5c710fe18f34c69e470cf161a4ba38717b7dde21de2d33760", 48 | "sha256:9e0537058efea7547d976f9c00067f7193727bb41ce6b4733c52de35beaa46f5", 49 | "sha256:a314e5e98037ece52ad0b88b4e0d788ca554935268f3e9d293ca9bcd18611b42", 50 | "sha256:b7e1c42367513108c3615cf1a24a9d366fd93eb9d2d92085bafb3011b785e8a9", 51 | "sha256:bb990d8663dbeee22ce44135ffd65ab38bd23d6a689722a653cfbf2d18d46688", 52 | "sha256:c050791989cd9979fe57a770d4e323b2e67ef95800e89e7dc6ad3652b8ccd86f", 53 | "sha256:e1bfcfa2c425dc0f637d4edd858b94e400bbb5746dba324ace124d55fc21d3df", 54 | "sha256:e78f09a44fc9256b84c9df98edf7b6ead3b3da2e12bf2d1e00384960a6a78a1a" 55 | ], 56 | "version": "==2.7.1" 57 | }, 58 | "keras": { 59 | "hashes": [ 60 | "sha256:a7abff3e87b6706764fbbd0521ba8f54c48f0720fd0495860cd3d2fd67899d48", 61 | "sha256:c14af1081242c25617ade7eb62121d58d01f16e1e744bae9fc4f1f95a417716e" 62 | ], 63 | "version": "==2.1.6" 64 | }, 65 | "kiwisolver": { 66 | "hashes": [ 67 | "sha256:0ee4ed8b3ae8f5f712b0aa9ebd2858b5b232f1b9a96b0943dceb34df2a223bc3", 68 | "sha256:0f7f532f3c94e99545a29f4c3f05637f4d2713e7fd91b4dd8abfc18340b86cd5", 69 | "sha256:1a078f5dd7e99317098f0e0d490257fd0349d79363e8c923d5bb76428f318421", 70 | "sha256:1aa0b55a0eb1bd3fa82e704f44fb8f16e26702af1a073cc5030eea399e617b56", 71 | "sha256:2874060b91e131ceeff00574b7c2140749c9355817a4ed498e82a4ffa308ecbc", 72 | "sha256:379d97783ba8d2934d52221c833407f20ca287b36d949b4bba6c75274bcf6363", 73 | "sha256:3b791ddf2aefc56382aadc26ea5b352e86a2921e4e85c31c1f770f527eb06ce4", 74 | "sha256:4329008a167fac233e398e8a600d1b91539dc33c5a3eadee84c0d4b04d4494fa", 75 | "sha256:45813e0873bbb679334a161b28cb9606d9665e70561fd6caa8863e279b5e464b", 76 | "sha256:53a5b27e6b5717bdc0125338a822605084054c80f382051fb945d2c0e6899a20", 77 | "sha256:66f82819ff47fa67a11540da96966fb9245504b7f496034f534b81cacf333861", 78 | "sha256:79e5fe3ccd5144ae80777e12973027bd2f4f5e3ae8eb286cabe787bed9780138", 79 | "sha256:8b6a7b596ce1d2a6d93c3562f1178ebd3b7bb445b3b0dd33b09f9255e312a965", 80 | "sha256:9576cb63897fbfa69df60f994082c3f4b8e6adb49cccb60efb2a80a208e6f996", 81 | "sha256:95a25d9f3449046ecbe9065be8f8380c03c56081bc5d41fe0fb964aaa30b2195", 82 | "sha256:aaec1cfd94f4f3e9a25e144d5b0ed1eb8a9596ec36d7318a504d813412563a85", 83 | "sha256:acb673eecbae089ea3be3dcf75bfe45fc8d4dcdc951e27d8691887963cf421c7", 
84 | "sha256:b15bc8d2c2848a4a7c04f76c9b3dc3561e95d4dabc6b4f24bfabe5fd81a0b14f", 85 | "sha256:b1c240d565e977d80c0083404c01e4d59c5772c977fae2c483f100567f50847b", 86 | "sha256:ce3be5d520b4d2c3e5eeb4cd2ef62b9b9ab8ac6b6fedbaa0e39cdb6f50644278", 87 | "sha256:e0f910f84b35c36a3513b96d816e6442ae138862257ae18a0019d2fc67b041dc", 88 | "sha256:ea36e19ac0a483eea239320aef0bd40702404ff8c7e42179a2d9d36c5afcb55c", 89 | "sha256:f923406e6b32c86309261b8195e24e18b6a8801df0cfc7814ac44017bfcb3939" 90 | ], 91 | "version": "==1.0.1" 92 | }, 93 | "matplotlib": { 94 | "hashes": [ 95 | "sha256:07055eb872fa109bd88f599bdb52065704b2e22d475b67675f345d75d32038a0", 96 | "sha256:0f2f253d6d51f5ed52a819921f8a0a8e054ce0daefcfbc2557e1c433f14dc77d", 97 | "sha256:1ef9fd285334bd6b0495b6de9d56a39dc95081577f27bafabcf28e0d318bed31", 98 | "sha256:3fb2db66ef98246bafc04b4ef4e9b0e73c6369f38a29716844e939d197df816a", 99 | "sha256:3fd90b407d1ab0dae686a4200030ce305526ff20b85a443dc490d194114b2dfa", 100 | "sha256:45dac8589ef1721d7f2ab0f48f986694494dfcc5d13a3e43a5cb6c816276094e", 101 | "sha256:4bb10087e09629ba3f9b25b6c734fd3f99542f93d71c5b9c023f28cb377b43a9", 102 | "sha256:4dc7ef528aad21f22be85e95725234c5178c0f938e2228ca76640e5e84d8cde8", 103 | "sha256:4f6a516d5ef39128bb16af7457e73dde25c30625c4916d8fbd1cc7c14c55e691", 104 | "sha256:70f0e407fbe9e97f16597223269c849597047421af5eb8b60dbaca0382037e78", 105 | "sha256:7b3d03c876684618e2a2be6abeb8d3a033c3a1bb38a786f751199753ef6227e6", 106 | "sha256:8944d311ce37bee1ba0e41a9b58dcf330ffe0cf29d7654c3d07c572215da68ac", 107 | "sha256:8ff08eaa25c66383fe3b6c7eb288da3c22dcedc4b110a0b592b35f68d0e093b2", 108 | "sha256:9d12378d6a236aa38326e27f3a29427b63edce4ce325745785aec1a7535b1f85", 109 | "sha256:abfd3d9390eb4f2d82cbcaa3a5c2834c581329b64eccb7a071ed9d5df27424f7", 110 | "sha256:bc4d7481f0e8ec94cb1afc4a59905d6274b3b4c389aba7a2539e071766671735", 111 | "sha256:dc0ba2080fd0cfdd07b3458ee4324d35806733feb2b080838d7094731d3f73d9", 112 | "sha256:f26fba7fc68994ab2805d77e0695417f9377a00d36ba4248b5d0f1e5adb08d24" 113 | ], 114 | "version": "==2.2.2" 115 | }, 116 | "numpy": { 117 | "hashes": [ 118 | "sha256:0074d42e2cc333800bd09996223d40ec52e3b1ec0a5cab05dacc09b662c4c1ae", 119 | "sha256:034717bfef517858abc79324820a702dc6cd063effb9baab86533e8a78670689", 120 | "sha256:0db6301324d0568089663ef2701ad90ebac0e975742c97460e89366692bd0563", 121 | "sha256:1864d005b2eb7598063e35c320787d87730d864f40d6410f768fe4ea20672016", 122 | "sha256:46ce8323ca9384814c7645298b8b627b7d04ce97d6948ef02da357b2389d6972", 123 | "sha256:510863d606c932b41d2209e4de6157ab3fdf52001d3e4ad351103176d33c4b8b", 124 | "sha256:560e23a12e7599be8e8b67621396c5bc687fd54b48b890adbc71bc5a67333f86", 125 | "sha256:57dc6c22d59054542600fce6fae2d1189b9c50bafc1aab32e55f7efcc84a6c46", 126 | "sha256:760550fdf9d8ec7da9c4402a4afe6e25c0f184ae132011676298a6b636660b45", 127 | "sha256:8670067685051b49d1f2f66e396488064299fefca199c7c80b6ba0c639fedc98", 128 | "sha256:9016692c7d390f9d378fc88b7a799dc9caa7eb938163dda5276d3f3d6f75debf", 129 | "sha256:98ff275f1b5907490d26b30b6ff111ecf2de0254f0ab08833d8fe61aa2068a00", 130 | "sha256:9ccf4d5c9139b1e985db915039baa0610a7e4a45090580065f8d8cb801b7422f", 131 | "sha256:a8dbab311d4259de5eeaa5b4e83f5f8545e4808f9144e84c0f424a6ee55a7b98", 132 | "sha256:aaef1bea636b6e552bbc5dae0ada87d4f6046359daaa97a05a013b0169620f27", 133 | "sha256:b8987e30d9a0eb6635df9705a75cf8c4a2835590244baecf210163343bc65176", 134 | "sha256:c3fe23df6fe0898e788581753da453f877350058c5982e85a8972feeecb15309", 135 | "sha256:c5eb7254cfc4bd7a4330ad7e1f65b98343836865338c57b0e25c661e41d5cfd9", 136 | 
"sha256:c80fcf9b38c7f4df666150069b04abbd2fe42ae640703a6e1f128cda83b552b7", 137 | "sha256:e33baf50f2f6b7153ddb973601a11df852697fba4c08b34a5e0f39f66f8120e1", 138 | "sha256:e8578a62a8eaf552b95d62f630bb5dd071243ba1302bbff3e55ac48588508736", 139 | "sha256:f22b3206f1c561dd9110b93d144c6aaa4a9a354e3b07ad36030df3ea92c5bb5b", 140 | "sha256:f39afab5769b3aaa786634b94b4a23ef3c150bdda044e8a32a3fc16ddafe803b" 141 | ], 142 | "version": "==1.14.3" 143 | }, 144 | "pyparsing": { 145 | "hashes": [ 146 | "sha256:0832bcf47acd283788593e7a0f542407bd9550a55a8a8435214a1960e04bcb04", 147 | "sha256:281683241b25fe9b80ec9d66017485f6deff1af5cde372469134b56ca8447a07", 148 | "sha256:8f1e18d3fd36c6795bb7e02a39fd05c611ffc2596c1e0d995d34d67630426c18", 149 | "sha256:9e8143a3e15c13713506886badd96ca4b579a87fbdf49e550dbfc057d6cb218e", 150 | "sha256:b8b3117ed9bdf45e14dcc89345ce638ec7e0e29b2b579fa1ecf32ce45ebac8a5", 151 | "sha256:e4d45427c6e20a59bf4f88c639dcc03ce30d193112047f94012102f235853a58", 152 | "sha256:fee43f17a9c4087e7ed1605bd6df994c6173c1e977d7ade7b651292fab2bd010" 153 | ], 154 | "version": "==2.2.0" 155 | }, 156 | "python-dateutil": { 157 | "hashes": [ 158 | "sha256:1adb80e7a782c12e52ef9a8182bebeb73f1d7e24e374397af06fb4956c8dc5c0", 159 | "sha256:e27001de32f627c22380a688bcc43ce83504a7bc5da472209b4c70f02829f0b8" 160 | ], 161 | "version": "==2.7.3" 162 | }, 163 | "pytz": { 164 | "hashes": [ 165 | "sha256:65ae0c8101309c45772196b21b74c46b2e5d11b6275c45d251b150d5da334555", 166 | "sha256:c06425302f2cf668f1bba7a0a03f3c1d34d4ebeef2c72003da308b3947c7f749" 167 | ], 168 | "version": "==2018.4" 169 | }, 170 | "pyyaml": { 171 | "hashes": [ 172 | "sha256:0c507b7f74b3d2dd4d1322ec8a94794927305ab4cebbe89cc47fe5e81541e6e8", 173 | "sha256:16b20e970597e051997d90dc2cddc713a2876c47e3d92d59ee198700c5427736", 174 | "sha256:3262c96a1ca437e7e4763e2843746588a965426550f3797a79fca9c6199c431f", 175 | "sha256:326420cbb492172dec84b0f65c80942de6cedb5233c413dd824483989c000608", 176 | "sha256:4474f8ea030b5127225b8894d626bb66c01cda098d47a2b0d3429b6700af9fd8", 177 | "sha256:592766c6303207a20efc445587778322d7f73b161bd994f227adaa341ba212ab", 178 | "sha256:5ac82e411044fb129bae5cfbeb3ba626acb2af31a8d17d175004b70862a741a7", 179 | "sha256:5f84523c076ad14ff5e6c037fe1c89a7f73a3e04cf0377cb4d017014976433f3", 180 | "sha256:827dc04b8fa7d07c44de11fabbc888e627fa8293b695e0f99cb544fdfa1bf0d1", 181 | "sha256:b4c423ab23291d3945ac61346feeb9a0dc4184999ede5e7c43e1ffb975130ae6", 182 | "sha256:bc6bced57f826ca7cb5125a10b23fd0f2fff3b7c4701d64c439a300ce665fff8", 183 | "sha256:c01b880ec30b5a6e6aa67b09a2fe3fb30473008c85cd6a67359a1b15ed6d83a4", 184 | "sha256:ca233c64c6e40eaa6c66ef97058cdc80e8d0157a443655baa1b2966e812807ca", 185 | "sha256:e863072cdf4c72eebf179342c94e6989c67185842d9997960b3e69290b2fa269" 186 | ], 187 | "version": "==3.12" 188 | }, 189 | "scikit-learn": { 190 | "hashes": [ 191 | "sha256:13136c6e4f6b808569f7f59299d439b2cd718f85d72ea14b5b6077d44ebc7d17", 192 | "sha256:370919e3148253fd6552496c33a1e3d78290a336fc8d1b9349d9e9770fae6ec0", 193 | "sha256:3775cca4ce3f94508bb7c8a6b113044b78c16b0a30a5c169ddeb6b9fe57a8a72", 194 | "sha256:42f3c5bd893ed73bf47ccccf04dfb98fae743f397d688bb58c2238c0e6ec15d2", 195 | "sha256:56cfa19c31edf62e6414da0a337efee37a4af488b135640e67238786b9be6ab3", 196 | "sha256:5c9ff456d67ef9094e5ea272fff2be05d399a47fc30c6c8ed653b94bdf787bd1", 197 | "sha256:5ca0ad32ee04abe0d4ba02c8d89d501b4e5e0304bdf4d45c2e9875a735b323a0", 198 | "sha256:5db9e68a384ce80a17fc449d4d5d9b45025fe17cf468429599bf404eccb51049", 199 | 
"sha256:6e0899953611d0c47c0d49c5950082ab016b38811fced91cd2dcc889dd94f50a", 200 | "sha256:72c194c5092e921d6107a8de8a5adae58c35bbc54e030ba624b6f02fd823bb21", 201 | "sha256:871669cdb5b3481650fe3adff46eb97c455e30ecdc307eaf382ef90d4e2570ab", 202 | "sha256:873245b03361710f47c5410a050dc56ee8ae97b9f8dcc6e3a81521ca2b64ad10", 203 | "sha256:8b17fc29554c5c98d88142f895516a5bec2b6b61daa815e1193a64c868ad53d2", 204 | "sha256:95b155ef6bf829ddfba6026f100ba8e4218b7171ecab97b2163bc9e8d206848f", 205 | "sha256:a21cf8217e31a9e8e32c559246e05e6909981816152406945ae2e3e244dfcc1f", 206 | "sha256:a58746d4f389ea7df1d908dba8b52f709835f91c342f459a3ade5424330c69d1", 207 | "sha256:b2a10e2f9b73de10d8486f7a23549093436062b69139158802910a0f154aa53b", 208 | "sha256:ba3fd442ae1a46830789b3578867daaf2c8409dcca6bf192e30e85beeabbfc2f", 209 | "sha256:ce78bf4d10bd7e28807c36c6d2ab25a9934aaf80906ad987622a5e45627d91a2", 210 | "sha256:d384e6f9a055b7a43492f9d27779adb717eb5dcf78b0603b01d0f070a608d241", 211 | "sha256:d4da369614e55540c7e830ccdd17ab4fe5412ff8e803a4906d3ece393e2e3a63", 212 | "sha256:ddc1eb10138ae93c136cc4b5945d3977f302b5d693592a4731b2805a7d7f2a74", 213 | "sha256:e54a3dd1fe1f8124de90b93c48d120e6da2ea8df29b6895325df01ddc1bd8e26", 214 | "sha256:ee8c3b1898c728b6e5b5659c233f547700a1fea13ce876b6fe7d3434c70cc0e0", 215 | "sha256:f528c4b2bba652cf116f5cccf36f4db95a7f9cbfcd1ee549c4e8d0f8628783b5", 216 | "sha256:f9abae483f4d52acd6f660addb1b67e35dc5748655250af479de2ea6aefc6df0", 217 | "sha256:fdc39e89bd3466befb76dfc0c258d4ccad159df974954a87de3be5759172a067" 218 | ], 219 | "version": "==0.19.1" 220 | }, 221 | "scipy": { 222 | "hashes": [ 223 | "sha256:0611ee97296265af4a21164a5323f8c1b4e8e15c582d3dfa7610825900136bb7", 224 | "sha256:08237eda23fd8e4e54838258b124f1cd141379a5f281b0a234ca99b38918c07a", 225 | "sha256:0e645dbfc03f279e1946cf07c9c754c2a1859cb4a41c5f70b25f6b3a586b6dbd", 226 | "sha256:0e9bb7efe5f051ea7212555b290e784b82f21ffd0f655405ac4f87e288b730b3", 227 | "sha256:108c16640849e5827e7d51023efb3bd79244098c3f21e4897a1007720cb7ce37", 228 | "sha256:340ef70f5b0f4e2b4b43c8c8061165911bc6b2ad16f8de85d9774545e2c47463", 229 | "sha256:3ad73dfc6f82e494195144bd3a129c7241e761179b7cb5c07b9a0ede99c686f3", 230 | "sha256:3b243c77a822cd034dad53058d7c2abf80062aa6f4a32e9799c95d6391558631", 231 | "sha256:404a00314e85eca9d46b80929571b938e97a143b4f2ddc2b2b3c91a4c4ead9c5", 232 | "sha256:423b3ff76957d29d1cce1bc0d62ebaf9a3fdfaf62344e3fdec14619bb7b5ad3a", 233 | "sha256:698c6409da58686f2df3d6f815491fd5b4c2de6817a45379517c92366eea208f", 234 | "sha256:729f8f8363d32cebcb946de278324ab43d28096f36593be6281ca1ee86ce6559", 235 | "sha256:8190770146a4c8ed5d330d5b5ad1c76251c63349d25c96b3094875b930c44692", 236 | "sha256:878352408424dffaa695ffedf2f9f92844e116686923ed9aa8626fc30d32cfd1", 237 | "sha256:8f841bbc21d3dad2111a94c490fb0a591b8612ffea86b8e5571746ae76a3deac", 238 | "sha256:c22b27371b3866c92796e5d7907e914f0e58a36d3222c5d436ddd3f0e354227a", 239 | "sha256:d0cdd5658b49a722783b8b4f61a6f1f9c75042d0e29a30ccb6cacc9b25f6d9e2", 240 | "sha256:d8491d4784aceb1f100ddb8e31239c54e4afab8d607928a9f7ef2469ec35ae01", 241 | "sha256:dfc5080c38dde3f43d8fbb9c0539a7839683475226cf83e4b24363b227dfe552", 242 | "sha256:e24e22c8d98d3c704bb3410bce9b69e122a8de487ad3dbfe9985d154e5c03a40", 243 | "sha256:e7a01e53163818d56eabddcafdc2090e9daba178aad05516b20c6591c4811020", 244 | "sha256:ee677635393414930541a096fc8e61634304bb0153e4e02b75685b11eba14cae", 245 | "sha256:f0521af1b722265d824d6ad055acfe9bd3341765735c44b5a4d0069e189a0f40" 246 | ], 247 | "version": "==1.1.0" 248 | }, 249 | "six": { 250 | "hashes": [ 
251 | "sha256:70e8a77beed4562e7f14fe23a786b54f6296e34344c23bc42f07b15018ff98e9", 252 | "sha256:832dc0e10feb1aa2c68dcc57dbb658f1c7e65b9b61af69048abc87a2db00a0eb" 253 | ], 254 | "version": "==1.11.0" 255 | }, 256 | "sqlalchemy": { 257 | "hashes": [ 258 | "sha256:d6cda03b0187d6ed796ff70e87c9a7dce2c2c9650a7bc3c022cd331416853c31" 259 | ], 260 | "version": "==1.2.7" 261 | }, 262 | "tqdm": { 263 | "hashes": [ 264 | "sha256:9fc19da10d7c962613cbcb9cdced41230deb31d9e20332da84c96917ff534281", 265 | "sha256:ce205451a27b6050faed0bb2bcbea96c6a550f8c27cd2b5441d72e948113ad18" 266 | ], 267 | "version": "==4.23.3" 268 | } 269 | }, 270 | "develop": { 271 | "astroid": { 272 | "hashes": [ 273 | "sha256:35cfae47aac19c7b407b7095410e895e836f2285ccf1220336afba744cc4c5f2", 274 | "sha256:38186e481b65877fd8b1f9acc33e922109e983eb7b6e487bd4c71002134ad331" 275 | ], 276 | "version": "==1.6.3" 277 | }, 278 | "backcall": { 279 | "hashes": [ 280 | "sha256:38ecd85be2c1e78f77fd91700c76e14667dc21e2713b63876c0eb901196e01e4", 281 | "sha256:bbbf4b1e5cd2bdb08f915895b51081c041bac22394fdfcfdfbe9f14b77c08bf2" 282 | ], 283 | "version": "==0.1.0" 284 | }, 285 | "decorator": { 286 | "hashes": [ 287 | "sha256:2c51dff8ef3c447388fe5e4453d24a2bf128d3a4c32af3fabef1f01c6851ab82", 288 | "sha256:c39efa13fbdeb4506c476c9b3babf6a718da943dab7811c206005a4a956c080c" 289 | ], 290 | "version": "==4.3.0" 291 | }, 292 | "ipython": { 293 | "hashes": [ 294 | "sha256:a0c96853549b246991046f32d19db7140f5b1a644cc31f0dc1edc86713b7676f", 295 | "sha256:eca537aa61592aca2fef4adea12af8e42f5c335004dfa80c78caf80e8b525e5c" 296 | ], 297 | "index": "pypi", 298 | "version": "==6.4.0" 299 | }, 300 | "ipython-genutils": { 301 | "hashes": [ 302 | "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8", 303 | "sha256:eb2e116e75ecef9d4d228fdc66af54269afa26ab4463042e33785b887c628ba8" 304 | ], 305 | "version": "==0.2.0" 306 | }, 307 | "isort": { 308 | "hashes": [ 309 | "sha256:1153601da39a25b14ddc54955dbbacbb6b2d19135386699e2ad58517953b34af", 310 | "sha256:b9c40e9750f3d77e6e4d441d8b0266cf555e7cdabdcff33c4fd06366ca761ef8", 311 | "sha256:ec9ef8f4a9bc6f71eec99e1806bfa2de401650d996c59330782b89a5555c1497" 312 | ], 313 | "version": "==4.3.4" 314 | }, 315 | "jedi": { 316 | "hashes": [ 317 | "sha256:1972f694c6bc66a2fac8718299e2ab73011d653a6d8059790c3476d2353b99ad", 318 | "sha256:5861f6dc0c16e024cbb0044999f9cf8013b292c05f287df06d3d991a87a4eb89" 319 | ], 320 | "version": "==0.12.0" 321 | }, 322 | "lazy-object-proxy": { 323 | "hashes": [ 324 | "sha256:0ce34342b419bd8f018e6666bfef729aec3edf62345a53b537a4dcc115746a33", 325 | "sha256:1b668120716eb7ee21d8a38815e5eb3bb8211117d9a90b0f8e21722c0758cc39", 326 | "sha256:209615b0fe4624d79e50220ce3310ca1a9445fd8e6d3572a896e7f9146bbf019", 327 | "sha256:27bf62cb2b1a2068d443ff7097ee33393f8483b570b475db8ebf7e1cba64f088", 328 | "sha256:27ea6fd1c02dcc78172a82fc37fcc0992a94e4cecf53cb6d73f11749825bd98b", 329 | "sha256:2c1b21b44ac9beb0fc848d3993924147ba45c4ebc24be19825e57aabbe74a99e", 330 | "sha256:2df72ab12046a3496a92476020a1a0abf78b2a7db9ff4dc2036b8dd980203ae6", 331 | "sha256:320ffd3de9699d3892048baee45ebfbbf9388a7d65d832d7e580243ade426d2b", 332 | "sha256:50e3b9a464d5d08cc5227413db0d1c4707b6172e4d4d915c1c70e4de0bbff1f5", 333 | "sha256:5276db7ff62bb7b52f77f1f51ed58850e315154249aceb42e7f4c611f0f847ff", 334 | "sha256:61a6cf00dcb1a7f0c773ed4acc509cb636af2d6337a08f362413c76b2b47a8dd", 335 | "sha256:6ae6c4cb59f199d8827c5a07546b2ab7e85d262acaccaacd49b62f53f7c456f7", 336 | "sha256:7661d401d60d8bf15bb5da39e4dd72f5d764c5aff5a86ef52a042506e3e970ff", 
337 | "sha256:7bd527f36a605c914efca5d3d014170b2cb184723e423d26b1fb2fd9108e264d", 338 | "sha256:7cb54db3535c8686ea12e9535eb087d32421184eacc6939ef15ef50f83a5e7e2", 339 | "sha256:7f3a2d740291f7f2c111d86a1c4851b70fb000a6c8883a59660d95ad57b9df35", 340 | "sha256:81304b7d8e9c824d058087dcb89144842c8e0dea6d281c031f59f0acf66963d4", 341 | "sha256:933947e8b4fbe617a51528b09851685138b49d511af0b6c0da2539115d6d4514", 342 | "sha256:94223d7f060301b3a8c09c9b3bc3294b56b2188e7d8179c762a1cda72c979252", 343 | "sha256:ab3ca49afcb47058393b0122428358d2fbe0408cf99f1b58b295cfeb4ed39109", 344 | "sha256:bd6292f565ca46dee4e737ebcc20742e3b5be2b01556dafe169f6c65d088875f", 345 | "sha256:cb924aa3e4a3fb644d0c463cad5bc2572649a6a3f68a7f8e4fbe44aaa6d77e4c", 346 | "sha256:d0fc7a286feac9077ec52a927fc9fe8fe2fabab95426722be4c953c9a8bede92", 347 | "sha256:ddc34786490a6e4ec0a855d401034cbd1242ef186c20d79d2166d6a4bd449577", 348 | "sha256:e34b155e36fa9da7e1b7c738ed7767fc9491a62ec6af70fe9da4a057759edc2d", 349 | "sha256:e5b9e8f6bda48460b7b143c3821b21b452cb3a835e6bbd5dd33aa0c8d3f5137d", 350 | "sha256:e81ebf6c5ee9684be8f2c87563880f93eedd56dd2b6146d8a725b50b7e5adb0f", 351 | "sha256:eb91be369f945f10d3a49f5f9be8b3d0b93a4c2be8f8a5b83b0571b8123e0a7a", 352 | "sha256:f460d1ceb0e4a5dcb2a652db0904224f367c9b3c1470d5a7683c0480e582468b" 353 | ], 354 | "version": "==1.3.1" 355 | }, 356 | "mccabe": { 357 | "hashes": [ 358 | "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42", 359 | "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f" 360 | ], 361 | "version": "==0.6.1" 362 | }, 363 | "nose": { 364 | "hashes": [ 365 | "sha256:9ff7c6cc443f8c51994b34a667bbcf45afd6d945be7477b52e97516fd17c53ac", 366 | "sha256:dadcddc0aefbf99eea214e0f1232b94f2fa9bd98fa8353711dacb112bfcbbb2a", 367 | "sha256:f1bffef9cbc82628f6e7d7b40d7e255aefaa1adb6a1b1d26c69a8b79e6208a98" 368 | ], 369 | "index": "pypi", 370 | "version": "==1.3.7" 371 | }, 372 | "parso": { 373 | "hashes": [ 374 | "sha256:62bd6bf7f04ab5c817704ff513ef175328676471bdef3629d4bdd46626f75551", 375 | "sha256:a75a304d7090d2c67bd298091c14ef9d3d560e3c53de1c239617889f61d1d307" 376 | ], 377 | "version": "==0.2.0" 378 | }, 379 | "pexpect": { 380 | "hashes": [ 381 | "sha256:9783f4644a3ef8528a6f20374eeb434431a650c797ca6d8df0d81e30fffdfa24", 382 | "sha256:9f8eb3277716a01faafaba553d629d3d60a1a624c7cf45daa600d2148c30020c" 383 | ], 384 | "markers": "sys_platform != 'win32'", 385 | "version": "==4.5.0" 386 | }, 387 | "pickleshare": { 388 | "hashes": [ 389 | "sha256:84a9257227dfdd6fe1b4be1319096c20eb85ff1e82c7932f36efccfe1b09737b", 390 | "sha256:c9a2541f25aeabc070f12f452e1f2a8eae2abd51e1cd19e8430402bdf4c1d8b5" 391 | ], 392 | "version": "==0.7.4" 393 | }, 394 | "prompt-toolkit": { 395 | "hashes": [ 396 | "sha256:1df952620eccb399c53ebb359cc7d9a8d3a9538cb34c5a1344bdbeb29fbcc381", 397 | "sha256:3f473ae040ddaa52b52f97f6b4a493cfa9f5920c255a12dc56a7d34397a398a4", 398 | "sha256:858588f1983ca497f1cf4ffde01d978a3ea02b01c8a26a8bbc5cd2e66d816917" 399 | ], 400 | "version": "==1.0.15" 401 | }, 402 | "ptyprocess": { 403 | "hashes": [ 404 | "sha256:e64193f0047ad603b71f202332ab5527c5e52aa7c8b609704fc28c0dc20c4365", 405 | "sha256:e8c43b5eee76b2083a9badde89fd1bbce6c8942d1045146e100b7b5e014f4f1a" 406 | ], 407 | "version": "==0.5.2" 408 | }, 409 | "pygments": { 410 | "hashes": [ 411 | "sha256:78f3f434bcc5d6ee09020f92ba487f95ba50f1e3ef83ae96b9d5ffa1bab25c5d", 412 | "sha256:dbae1046def0efb574852fab9e90209b23f556367b5a320c0bcb871c77c3e8cc" 413 | ], 414 | "version": "==2.2.0" 415 | }, 416 | "pylint": { 417 | 
"hashes": [ 418 | "sha256:0b7e6b5d9f1d4e0b554b5d948f14ed7969e8cdf9a0120853e6e5af60813b18ab", 419 | "sha256:34738a82ab33cbd3bb6cd4cef823dbcabdd2b6b48a4e3a3054a2bbbf0c712be9" 420 | ], 421 | "index": "pypi", 422 | "version": "==1.8.4" 423 | }, 424 | "rope": { 425 | "hashes": [ 426 | "sha256:a09edfd2034fd50099a67822f9bd851fbd0f4e98d3b87519f6267b60e50d80d1" 427 | ], 428 | "index": "pypi", 429 | "version": "==0.10.7" 430 | }, 431 | "simplegeneric": { 432 | "hashes": [ 433 | "sha256:dc972e06094b9af5b855b3df4a646395e43d1c9d0d39ed345b7393560d0b9173" 434 | ], 435 | "version": "==0.8.1" 436 | }, 437 | "six": { 438 | "hashes": [ 439 | "sha256:70e8a77beed4562e7f14fe23a786b54f6296e34344c23bc42f07b15018ff98e9", 440 | "sha256:832dc0e10feb1aa2c68dcc57dbb658f1c7e65b9b61af69048abc87a2db00a0eb" 441 | ], 442 | "version": "==1.11.0" 443 | }, 444 | "traitlets": { 445 | "hashes": [ 446 | "sha256:9c4bd2d267b7153df9152698efb1050a5d84982d3384a37b2c1f7723ba3e7835", 447 | "sha256:c6cb5e6f57c5a9bdaa40fa71ce7b4af30298fbab9ece9815b5d995ab6217c7d9" 448 | ], 449 | "version": "==4.3.2" 450 | }, 451 | "wcwidth": { 452 | "hashes": [ 453 | "sha256:3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e", 454 | "sha256:f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c" 455 | ], 456 | "version": "==0.1.7" 457 | }, 458 | "wrapt": { 459 | "hashes": [ 460 | "sha256:d4d560d479f2c21e1b5443bbd15fe7ec4b37fe7e53d335d3b9b0a7b1226fe3c6" 461 | ], 462 | "version": "==1.10.11" 463 | } 464 | } 465 | } 466 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # suplearn-clone-detection 2 | 3 | [![CircleCI](https://circleci.com/gh/danhper/suplearn-clone-detection.svg?style=svg&circle-token=738ac3f3e6453f2beef09c2bf1a2e72d2a959ee0)](https://circleci.com/gh/tuvistavie/suplearn-clone-detection) 4 | 5 | ## Setup 6 | 7 | ``` 8 | pip install -r requirements.txt 9 | python setup.py develop 10 | ``` 11 | 12 | Note that Tensorflow needs to be installed separately using [these steps][tensorflow-install] 13 | 14 | ## Configuration 15 | 16 | First, copy [config.yml.example](./config.yml.example) to `config.yml` 17 | 18 | ``` 19 | cp config.yml.example config.yml 20 | ``` 21 | 22 | Then, modify the content of `config.yml`. The configuration file is 23 | self-documented and the most important parameters we used can be found in the paper. 24 | 25 | ## Dataset 26 | 27 | We train our model using a dataset with data extracted from the competitive programming website AtCoder: https://atcoder.jp. 28 | The dataset can be downloaded as an SQLite3 database : [java-python-clones.db.gz][cross-language-clones-db]. 29 | You will most likely need to decompress the database before using it. 30 | We also provide the raw data as a tarball but it should generally not be needed: [java-python-clones.tar.gz][cross-language-clones-tar]. 31 | The database contains both the text representation and the AST representation 32 | of the source code. All the data is in the `submissions` table. We describe 33 | the different rows of the table below. 
34 | 35 | Name | Type | Description 36 | -----|------|------------ 37 | id | INTEGER | Primary key for the submission 38 | url | VARCHAR(255) | URL of the problem on AtCoder 39 | contest_type | VARCHAR(64) | Contest type on AtCoder (beginner or regular) 40 | contest_id | INTEGER | Contest ID on AtCoder 41 | problem_id | INTEGER | Problem ID on AtCoder 42 | problem_title | VARCHAR(255) | Problem title on AtCoder (usually in Japanese) 43 | filename | VARCHAR(255) | Original path of the file 44 | language | VARCHAR(64) | Full name of the language used 45 | language_code | VARCHAR(64) | Short name of the language used 46 | source_length | INTEGER | Source length in bytes 47 | exec_time | INTEGER | Execution time in ms 48 | tokens_count | INTEGER | Number of tokens in the source 49 | source | TEXT | Source code of the submission 50 | ast | TEXT | JSON-encoded AST representation of the source code 51 | 52 | The database also contains a `samples` table, which should be populated 53 | using the `suplearn-clone generate-dataset` command. 54 | 55 | ## Usage 56 | 57 | The following steps assume the model has been configured in `config.yml`. 58 | 59 | ### Generating training samples 60 | 61 | Before training the model, the clone pairs for training/cross-validation/test must first be generated using the following command: 62 | 63 | ``` 64 | suplearn-clone generate-dataset -c /path/to/config.yml 65 | ``` 66 | 67 | ### Training the model 68 | 69 | Once the data is generated, the model can be trained 70 | simply by running the following command: 71 | 72 | ``` 73 | suplearn-clone train -c /path/to/config.yml 74 | ``` 75 | 76 | ### Testing the model 77 | 78 | The model can be evaluated on test data using the following command: 79 | 80 | ``` 81 | suplearn-clone evaluate -c /path/to/config.yml -m /path/to/model.h5 --data-type= -o results.json 82 | ``` 83 | 84 | Note that `config.yml` should be the same file as the one used for training. 85 | 86 | ## Using pre-trained embeddings 87 | 88 | Pre-trained embeddings can be used via the `model.languages.n.embeddings` 89 | setting in the configuration file. 90 | This repository does not provide any functionality to train embeddings. 91 | Please check the [bigcode-tools][bigcode-tools] repository for instructions 92 | on how to train embeddings. 93 | 94 | ## Citing the project 95 | 96 | If you are using this for academic work, we would be thankful if you could cite the following paper.
97 | 98 | ``` 99 | @inproceedings{Perez:2019:CCD:3341883.3341965, 100 | author = {Perez, Daniel and Chiba, Shigeru}, 101 | title = {Cross-language Clone Detection by Learning over Abstract Syntax Trees}, 102 | booktitle = {Proceedings of the 16th International Conference on Mining Software Repositories}, 103 | series = {MSR '19}, 104 | year = {2019}, 105 | location = {Montreal, Quebec, Canada}, 106 | pages = {518--528}, 107 | numpages = {11}, 108 | url = {https://doi.org/10.1109/MSR.2019.00078}, 109 | doi = {10.1109/MSR.2019.00078}, 110 | acmid = {3341965}, 111 | publisher = {IEEE Press}, 112 | address = {Piscataway, NJ, USA}, 113 | keywords = {clone detection, machine learning, source code representation}, 114 | } 115 | ``` 116 | 117 | 118 | [tensorflow-install]: https://www.tensorflow.org/install 119 | [cross-language-clones-db]: https://static.perez.sh/research/2019/cross-language-clone-detection/datasets/java-python-clones.db.gz 120 | [cross-language-clones-tar]: https://static.perez.sh/research/2019/cross-language-clone-detection/datasets/java-python-clones.tar.gz 121 | [bigcode-tools]: https://github.com/danhper/bigcode-tools 122 | -------------------------------------------------------------------------------- /bin/suplearn-clone: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from suplearn_clone_detection import cli 4 | 5 | 6 | cli.run() 7 | -------------------------------------------------------------------------------- /config.yml.example: -------------------------------------------------------------------------------- 1 | references: 2 | # default settings per language 3 | language_defaults: &language_defaults 4 | # offset when indexing the vocabulary 5 | # an offset of 1 is used to ignore the padding index 6 | vocabulary_offset: 1 7 | 8 | # class to use to linearize the AST 9 | # DFSTransformer seems to perform better than
BFSTransformer 10 | transformer_class_name: DFSTransformer 11 | 12 | # size of the token/node embeddings 13 | embeddings_dimension: 100 14 | 15 | # output dimensions of each stacked LSTM 16 | output_dimensions: [100, 50] 17 | 18 | # whether to use bidirectional LSTM or not 19 | bidirectional_encoding: true 20 | 21 | # the output dimension of the hash layer(s) 22 | hash_dims: [20] 23 | 24 | 25 | model: 26 | languages: 27 | # name of the language 28 | - name: java 29 | # load the defaults set above 30 | # settings can be overridden below if necessary 31 | <<: *language_defaults 32 | 33 | # path to the vocabulary for the language 34 | vocabulary: $HOME/workspaces/research/results/java/data/no-id.tsv 35 | 36 | # path to the pre-trained embeddings for the vocabulary 37 | embeddings: $HOME/workspaces/research/results/python/embeddings/noid-ch1-anc2-nosib-50d-lr001.npy 38 | # maximum length (in number of AST nodes) of an input 39 | max_length: 500 40 | - name: python 41 | <<: *language_defaults 42 | vocabulary: $HOME/workspaces/research/results/python/vocabulary/no-id.tsv 43 | # if no embeddings file is passed, embeddings are randomly initialized 44 | # embeddings: $HOME/workspaces/research/results/python/embeddings/noid-ch1-anc2-nosib-50d-lr001.npy 45 | max_length: 400 46 | 47 | # type of loss to use 48 | # this parameter is passed directly to keras 49 | loss: binary_crossentropy 50 | 51 | # how to merge the output of the two LSTMs 52 | # see paper for details 53 | merge_mode: bidistance 54 | 55 | # output dimension after merging the LSTM outputs 56 | merge_output_dim: 64 57 | 58 | # number and dimensions of dense layers after the LSTM 59 | dense_layers: [64, 32] 60 | 61 | # optimizer parameters 62 | optimizer: 63 | # this is passed to keras directly 64 | type: rmsprop 65 | 66 | generator: 67 | # The following three settings are only needed to generate the SQLite3 DB 68 | # path to the submissions metadata 69 | submissions_path: $HOME/workspaces/research/dataset/atcoder/submissions.json 70 | # file format of the AST 71 | file_format: multi_file 72 | # path to the files containing the ASTs 73 | asts_path: $HOME/workspaces/research/dataset/atcoder/asts/asts.json 74 | 75 | # path to the SQLite3 DB containing the data 76 | db_path: sqlite:///$HOME/workspaces/research/dataset/atcoder/atcoder.db 77 | 78 | # the split ratio for training/cross-validation/test 79 | split_ratio: [0.8, 0.1, 0.1] 80 | 81 | # the number of samples to generate for each problem 82 | samples_per_problem: 10 83 | 84 | # the maximum distance ratio between the length of the positive 85 | # and the negative sample 86 | # if too high, the network tends to overfit by learning the length 87 | negative_sample_distance: 0.2 88 | 89 | # the weights for positive and negative samples 90 | sample_weights: {0: 1.0, 1: 3.0} 91 | 92 | trainer: 93 | # size of a batch 94 | batch_size: 128 95 | # number of epochs 96 | epochs: 10 97 | 98 | # base directory to output results 99 | output_dir: ./tmp 100 | 101 | # whether to output data for tensorboard or not 102 | tensorboard_logs: True 103 | -------------------------------------------------------------------------------- /ipython_config.py: -------------------------------------------------------------------------------- 1 | # Configuration file for ipython.
2 | 3 | #------------------------------------------------------------------------------ 4 | # InteractiveShellApp(Configurable) configuration 5 | #------------------------------------------------------------------------------ 6 | 7 | ## A Mixin for applications that start InteractiveShell instances. 8 | # 9 | # Provides configurables for loading extensions and executing files as part of 10 | # configuring a Shell environment. 11 | # 12 | # The following methods should be called by the :meth:`initialize` method of the 13 | # subclass: 14 | # 15 | # - :meth:`init_path` 16 | # - :meth:`init_shell` (to be implemented by the subclass) 17 | # - :meth:`init_gui_pylab` 18 | # - :meth:`init_extensions` 19 | # - :meth:`init_code` 20 | 21 | ## Execute the given command string. 22 | #c.InteractiveShellApp.code_to_run = '' 23 | 24 | ## Run the file referenced by the PYTHONSTARTUP environment variable at IPython 25 | # startup. 26 | #c.InteractiveShellApp.exec_PYTHONSTARTUP = True 27 | 28 | ## List of files to run at IPython startup. 29 | #c.InteractiveShellApp.exec_files = [] 30 | 31 | ## lines of code to run at IPython startup. 32 | c.InteractiveShellApp.exec_lines = ['%autoreload 2'] 33 | 34 | ## A list of dotted module names of IPython extensions to load. 35 | c.InteractiveShellApp.extensions = ['autoreload'] 36 | 37 | ## A file to be run 38 | # c.InteractiveShellApp.file_to_run = '' 39 | c.InteractiveShellApp.exec_files = ['.ipython-startup.ipy'] 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | -------------------------------------------------------------------------------- /samples/encoder-nn.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from os import path 3 | 4 | import h5py 5 | import numpy as np 6 | import tensorflow as tf 7 | import keras.backend as K 8 | from keras.models import load_model 9 | 10 | from suplearn_clone_detection import layers, util 11 | from suplearn_clone_detection.predictor import Predictor 12 | 13 | #%% 14 | PROJECT_DIR = path.expanduser("~/Documents/organizations/tuvistavie/suplearn-clone-detection") 15 | ROOT_DIR = path.join(PROJECT_DIR, "tmp/20180425-2227") 16 | CONFIG_PATH = path.join(ROOT_DIR, "config.yml") 17 | MODEL_PATH = path.join(ROOT_DIR, "model.h5") 18 | EXAMPLES_PATH = path.join(PROJECT_DIR, "tmp/sample-data/java-java-dev.csv") 19 | 20 | #%% 21 | model = load_model(MODEL_PATH, custom_objects=layers.custom_objects) # type: layers.ModelWrapper 22 | 23 | left_encoder = model.inner_models[0] 24 | right_encoder = model.inner_models[1] 25 | distance_model = model.inner_models[2] 26 | print("MODEL LOADED") 27 | 28 | #%% 29 | 30 | left_encoder.summary() 31 | right_encoder.summary() 32 | distance_model.summary() 33 | 34 | #%% 35 | 36 | with open(EXAMPLES_PATH) as f: 37 | files = [] 38 | for line in f: 39 | pair = line.strip().split("," if "," in line else " ")[:2] 40 | files.append(tuple(path.join(filename) for filename in pair)) 41 | 42 | print("\n".join([" ".join(v) for v in files[:10]])) 43 | 44 | #%% 45 | predictor = Predictor.from_config(CONFIG_PATH, MODEL_PATH, {}) 46 | 47 | #%% 48 | to_predict = files[8:20] 49 | to_predict, input_data, _assumed_false = predictor._generate_vectors(to_predict) 50 | result = model.predict(input_data) 51 | result 52 | 53 | #%% 54 | 55 | sess = K.get_session() 56 | 57 | left = left_encoder(tf.constant(input_data[0])) 58 | right = 
right_encoder(tf.constant(input_data[1])) 59 | sess.run(distance_model([left, right])) 60 | 61 | #%% 62 | 63 | def read_inputs(filename, files): 64 | with h5py.File(path.join(ROOT_DIR, filename)) as f: 65 | return np.array([f[v][:] for v in files]) 66 | 67 | left_inputs = read_inputs("java-dev-files-left.h5", [v[0] for v in to_predict]) 68 | right_inputs = read_inputs("java-dev-files-right.h5", [v[1] for v in to_predict]) 69 | 70 | print(left_inputs[0], sess.run(left)[0]) 71 | print(right_inputs[0], sess.run(right)[0]) 72 | 73 | #%% 74 | 75 | def find_nearest_neighbors(input_array, files, batch_size=512): 76 | left_input = tf.reshape( 77 | tf.tile(tf.constant(input_array), [batch_size]), 78 | (batch_size, -1)) 79 | 80 | results = [] 81 | def run_batch(batch): 82 | filenames = [v[0] for v in batch] 83 | right_input = tf.constant(np.array([v[1][:] for v in batch])) 84 | res = distance_model([left_input[:len(batch)], right_input]) 85 | distances = sess.run(res) 86 | results.extend(zip(filenames, distances)) 87 | 88 | batch = [] 89 | def visitor(name, value): 90 | if not isinstance(value, h5py.Dataset): 91 | return 92 | batch.append((name, value)) 93 | if len(batch) >= batch_size: 94 | run_batch(batch) 95 | batch.clear() 96 | 97 | files.visititems(visitor) 98 | if batch: 99 | run_batch(batch) 100 | 101 | results = np.array(results, dtype=[("filename", "S256"), ("distance", float)]) 102 | return np.sort(results, order="distance")[::-1] 103 | 104 | with h5py.File(path.join(ROOT_DIR, "java-dev-files-right.h5")) as f: 105 | result = find_nearest_neighbors(left_inputs[0], f) 106 | 107 | print("results for {0}".format(to_predict[0][0])) 108 | print(result[:10]) 109 | 110 | 111 | # sess.run(tf.reshape(tf.tile(tf.constant(left_inputs[0]), [5]), (5, -1))) 112 | -------------------------------------------------------------------------------- /scripts/create_submissions.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os import path 3 | from typing import List 4 | 5 | from suplearn_clone_detection.config import Config 6 | from suplearn_clone_detection.ast_loader import ASTLoader 7 | from suplearn_clone_detection import entities, database 8 | from suplearn_clone_detection.database import Session 9 | 10 | 11 | SQL_PATH = path.realpath(path.join(path.dirname(path.dirname(__file__)), "sql")) 12 | KNOWN_LANGUAGES = ["java", "python"] 13 | 14 | 15 | class SubmissionCreator: 16 | def __init__(self, config: Config, known_languages: List[str] = None): 17 | if known_languages is None: 18 | known_languages = KNOWN_LANGUAGES 19 | self.config = config 20 | self.submissions_dir = path.dirname(self.config.generator.submissions_path) 21 | self.known_languages = known_languages 22 | self.ast_loader = ASTLoader( 23 | config.generator.asts_path, 24 | config.generator.filenames_path, 25 | config.generator.file_format 26 | ) 27 | 28 | def normalize_language(self, language): 29 | for known_lang in self.known_languages: 30 | if language.startswith(known_lang): 31 | return known_lang 32 | return None 33 | 34 | def get_source(self, submission_obj): 35 | filepath = path.join(self.submissions_dir, submission_obj["file"]) 36 | with open(filepath) as f: 37 | return f.read() 38 | 39 | def make_submission(self, submission_obj): 40 | ast = self.ast_loader.get_ast(submission_obj["file"]) 41 | return entities.Submission( 42 | id=submission_obj["id"], 43 | url=submission_obj["submission_url"], 44 | contest_type=submission_obj["contest_type"], 45 | 
contest_id=submission_obj["contest_id"], 46 | problem_id=submission_obj["problem_id"], 47 | problem_title=submission_obj["problem_title"], 48 | filename=path.basename(submission_obj["file"]), 49 | language=submission_obj["language"], 50 | language_code=self.normalize_language(submission_obj["language"]), 51 | source_length=submission_obj["source_length"], 52 | exec_time=submission_obj["exec_time"], 53 | tokens_count=len(ast), 54 | source=self.get_source(submission_obj), 55 | ast=json.dumps(ast), 56 | ) 57 | 58 | def load_submissions(self): 59 | submissions = [] 60 | with open(self.config.generator.submissions_path, "r") as f: 61 | for submission_obj in json.load(f): 62 | if self.ast_loader.has_file(submission_obj["file"]): 63 | submissions.append(self.make_submission(submission_obj)) 64 | return submissions 65 | 66 | @staticmethod 67 | def create_db(): 68 | sqlite_conn = Session.connection().connection.connection 69 | with open(path.join(SQL_PATH, "create_tables.sql")) as f: 70 | sqlite_conn.executescript(f.read()) 71 | Session.commit() 72 | 73 | def create_submission(self): 74 | submissions = self.load_submissions() 75 | Session.bulk_save_objects(submissions) 76 | Session.commit() 77 | 78 | 79 | def main(): 80 | config = Config.from_file("./config.yml") 81 | database.bind_db(config.generator.db_path) 82 | creator = SubmissionCreator(config) 83 | creator.create_db() 84 | creator.create_submission() 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /scripts/create_test_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from os import path 4 | import json 5 | import random 6 | import argparse 7 | 8 | parser = argparse.ArgumentParser(prog="create-test-data") 9 | parser.add_argument("-o", "--output", default="to_check_cross_lang.txt") 10 | parser.add_argument("--max-tokens", type=int, default=250) 11 | parser.add_argument("-n", "--projects-count", default=50, type=int) 12 | parser.add_argument("--min-files", default=4, type=int) 13 | parser.add_argument("--max-files", default=20, type=int) 14 | parser.add_argument("--contest-type", choices=["r", "b"]) 15 | args = parser.parse_args() 16 | 17 | 18 | with open("./asts/asts.jsonl") as f: 19 | all_asts = [json.loads(v) for v in f if v] 20 | 21 | with open("./asts/asts.txt") as f: 22 | all_files = [v.strip() for v in f if v] 23 | 24 | short_files = [name 25 | for name, ast in zip(all_files, all_asts) 26 | if len(ast) < args.max_tokens] 27 | 28 | grouped = {} 29 | 30 | for file in short_files: 31 | if args.contest_type and not file.startswith("src/{0}".format(args.contest_type)): 32 | continue 33 | grouped.setdefault(path.dirname(file), []) 34 | grouped[path.dirname(file)].append(file) 35 | 36 | to_check = set() 37 | for key in random.sample(list(grouped), args.projects_count): 38 | group = grouped[key] 39 | sample_count = min(random.randint(args.min_files, args.max_files), len(group)) 40 | for file in random.sample(group, sample_count): 41 | to_check.add(file) 42 | 43 | with open(args.output, "w") as f: 44 | for name in to_check: 45 | print(name, file=f) 46 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup( 5 | name="suplearn_clone_detection", 6 | version="0.1.0", 7 | packages=find_packages(), 8 | 
install_requires=[ 9 | "numpy", 10 | "keras", 11 | "pyyaml", 12 | "h5py", 13 | "scikit-learn", 14 | "tqdm", 15 | "matplotlib", 16 | "sqlalchemy", 17 | ] 18 | ) 19 | -------------------------------------------------------------------------------- /sql/create_tables.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE IF NOT EXISTS submissions ( 2 | id INTEGER PRIMARY KEY, 3 | url VARCHAR(255), 4 | contest_type VARCHAR(64) NOT NULL, 5 | contest_id INTEGER NOT NULL, 6 | problem_id INTEGER NOT NULL, 7 | problem_title VARCHAR(255), 8 | filename VARCHAR(255) NOT NULL, 9 | language VARCHAR(64), 10 | language_code VARCHAR(64) NOT NULL, 11 | source_length INTEGER, 12 | exec_time INTEGER, 13 | tokens_count INTEGER, 14 | source TEXT NOT NULL, 15 | ast TEXT NOT NULL 16 | ); 17 | 18 | CREATE INDEX IF NOT EXISTS contest_idx ON submissions (contest_id, contest_type); 19 | CREATE INDEX IF NOT EXISTS problem_idx ON submissions (contest_id, contest_type, problem_id); 20 | CREATE INDEX IF NOT EXISTS language_idx ON submissions (language_code); 21 | CREATE INDEX IF NOT EXISTS token_count_idx ON submissions (tokens_count); 22 | CREATE INDEX IF NOT EXISTS token_count_language_idx ON submissions (tokens_count, language_code); 23 | 24 | CREATE VIEW IF NOT EXISTS submissions_stats AS 25 | SELECT id, contest_id, contest_type, problem_id, problem_title, filename, 26 | language, language_code, source_length, exec_time, tokens_count, url 27 | FROM submissions; 28 | 29 | CREATE TABLE IF NOT EXISTS samples ( 30 | id INTEGER PRIMARY KEY, 31 | 32 | anchor_id INTEGER NOT NULL, 33 | positive_id INTEGER, 34 | negative_id INTEGER, 35 | 36 | dataset_name STRING NOT NULL, 37 | config_checksum STRING NOT NULL, 38 | 39 | FOREIGN KEY (anchor_id) REFERENCES submissions (id), 40 | FOREIGN KEY (positive_id) REFERENCES submissions (id), 41 | FOREIGN KEY (negative_id) REFERENCES submissions (id) 42 | ); 43 | 44 | CREATE INDEX IF NOT EXISTS dataset_name_idx ON samples (dataset_name); 45 | CREATE INDEX IF NOT EXISTS anchor_idx ON samples (anchor_id); 46 | CREATE INDEX IF NOT EXISTS positive_idx ON samples (positive_id); 47 | CREATE INDEX IF NOT EXISTS negative_idx ON samples (negative_id); 48 | CREATE INDEX IF NOT EXISTS config_checksum_idx ON samples (config_checksum); 49 | -------------------------------------------------------------------------------- /suplearn_clone_detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danhper/suplearn-clone-detection/81c485c1d7ef98c417427a6db5cae7ae2b2541ce/suplearn_clone_detection/__init__.py -------------------------------------------------------------------------------- /suplearn_clone_detection/ast.py: -------------------------------------------------------------------------------- 1 | class Node: 2 | def __init__(self, node_id, node_type, value=None): 3 | self.node_id = node_id 4 | self.type = node_type 5 | self.value = value 6 | self.children = [] 7 | 8 | def get(self, key): 9 | return getattr(self, key) 10 | 11 | def __getitem__(self, key): 12 | return self.get(key) 13 | 14 | def dfs(self, reverse=False): 15 | children = reversed(self.children) if reverse else self.children 16 | yield self 17 | for child in children: 18 | yield from child.dfs(reverse) 19 | 20 | def bfs(self): 21 | queue = [self] 22 | while queue: 23 | node = queue.pop() 24 | yield node 25 | for child in node.children: 26 | queue.insert(0, child) 27 | 28 | def from_list(list_ast): 29 | def
create_node(index): 30 | node_info = list_ast[index] 31 | node = Node(node_info.get("id"), node_info["type"], node_info.get("value")) 32 | for child_index in node_info.get("children", []): 33 | node.children.append(create_node(child_index)) 34 | return node 35 | return create_node(0) 36 | -------------------------------------------------------------------------------- /suplearn_clone_detection/ast_loader.py: -------------------------------------------------------------------------------- 1 | import random 2 | import logging 3 | import gzip 4 | import json 5 | from os import path 6 | 7 | 8 | def tautology(*_args, **_kwargs): 9 | return True 10 | 11 | 12 | class ASTLoader: 13 | def __init__(self, asts_path, filenames_path=None, file_format="multi_file"): 14 | if filenames_path is None: 15 | filenames_path = path.splitext(asts_path)[0] + ".txt" 16 | if file_format == "multi_file" and not path.exists(filenames_path): 17 | logging.warning("%s does not exist, falling back to single file format", filenames_path) 18 | file_format = "single_file" 19 | if file_format == "multi_file": 20 | if filenames_path is None: 21 | filenames_path = path.splitext(asts_path)[0] + ".txt" 22 | self._load_asts(asts_path) 23 | self._load_names(filenames_path) 24 | elif file_format == "single_file": 25 | self._load_single_file_format(asts_path) 26 | else: 27 | raise ValueError("unknown format {0}".format(file_format)) 28 | 29 | def _load_single_file_format(self, filepath): 30 | with self._open(filepath) as f: 31 | entries = [json.loads(row) for row in f] 32 | self.names = {entry["filename"]: index for (index, entry) in enumerate(entries)} 33 | self.asts = [entry["tokens"] for entry in entries] 34 | 35 | def _load_names(self, names_path): 36 | with self._open(names_path) as f: 37 | self.names = {filename.strip(): index for (index, filename) in enumerate(f)} 38 | 39 | def _load_asts(self, asts_path): 40 | with self._open(asts_path) as f: 41 | self.asts = [json.loads(ast) for ast in f] 42 | 43 | def get_ast(self, filename): 44 | return self.asts[self.names[filename]] 45 | 46 | def random_ast(self, predicate=tautology): 47 | keys = list(self.names.keys()) 48 | while True: 49 | name = random.choice(keys) 50 | ast = self.get_ast(name) 51 | if predicate(name, ast): 52 | return name, ast 53 | 54 | def has_file(self, filename): 55 | return filename in self.names 56 | 57 | @staticmethod 58 | def _open(filename): 59 | if filename.endswith(".gz"): 60 | return gzip.open(filename) 61 | else: 62 | return open(filename) 63 | -------------------------------------------------------------------------------- /suplearn_clone_detection/ast_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Type, List 2 | import sys 3 | 4 | import numpy as np 5 | 6 | from suplearn_clone_detection import ast 7 | from suplearn_clone_detection.config import LanguageConfig 8 | from suplearn_clone_detection.vocabulary import Vocabulary 9 | 10 | thismodule = sys.modules[__name__] 11 | 12 | 13 | class ASTTransformer: 14 | def __init__(self, lang, vocabulary, vocabulary_offset=0, input_length=None): 15 | self.language = lang 16 | self.vocabulary = vocabulary 17 | self.vocabulary_offset = np.int32(vocabulary_offset) 18 | self.input_length = input_length 19 | self.total_input_length = input_length 20 | 21 | def transform_ast(self, list_ast): 22 | raise NotImplementedError() 23 | 24 | def nodes_to_indexes(self, nodes): 25 | indexes = [self.node_index(node) for node in nodes] 26 | if not 
self.input_length: 27 | return indexes 28 | if len(indexes) > self.input_length: 29 | return False 30 | return self.pad(indexes) 31 | 32 | @property 33 | def split_input(self): 34 | return False 35 | 36 | def node_index(self, node): 37 | return self.vocabulary.index(node) + self.vocabulary_offset 38 | 39 | def pad(self, indexes, pad_value=np.int32(0)): 40 | return indexes + [pad_value] * (self.input_length - len(indexes)) 41 | 42 | 43 | class DFSTransformer(ASTTransformer): 44 | def transform_ast(self, list_ast): 45 | return self.nodes_to_indexes(list_ast) 46 | 47 | 48 | class BFSTransformer(ASTTransformer): 49 | def transform_ast(self, list_ast): 50 | ast_root = ast.from_list(list_ast) 51 | return self.nodes_to_indexes(ast_root.bfs()) 52 | 53 | 54 | class MultiTransformer(ASTTransformer): 55 | def __init__(self, lang, vocabulary, vocabulary_offset=0, input_length=None): 56 | super(MultiTransformer, self).__init__(lang, vocabulary, vocabulary_offset, input_length) 57 | if self.total_input_length: 58 | self.total_input_length *= 2 59 | 60 | @property 61 | def split_input(self): 62 | return True 63 | 64 | 65 | class DBFSTransformer(MultiTransformer): 66 | def transform_ast(self, list_ast): 67 | ast_root = ast.from_list(list_ast) 68 | return self.nodes_to_indexes(ast_root.dfs()) + \ 69 | self.nodes_to_indexes(ast_root.bfs()) 70 | 71 | 72 | class BiDFSTransformer(MultiTransformer): 73 | def transform_ast(self, list_ast): 74 | ast_root = ast.from_list(list_ast) 75 | return self.nodes_to_indexes(ast_root.dfs()) + \ 76 | self.nodes_to_indexes(ast_root.dfs(reverse=True)) 77 | 78 | 79 | def get_class(language_config: LanguageConfig) -> Type[ASTTransformer]: 80 | return getattr(thismodule, language_config.transformer_class_name) 81 | 82 | 83 | def create_all(languages: List[LanguageConfig]) -> List[ASTTransformer]: 84 | return [create(lang) for lang in languages] 85 | 86 | 87 | def create(language_config: LanguageConfig) -> ASTTransformer: 88 | vocab = Vocabulary.from_file(language_config.vocabulary) 89 | language_config.vocabulary_size = len(vocab) 90 | transformer_class = get_class(language_config) 91 | return transformer_class(language_config.name, 92 | vocab, 93 | vocabulary_offset=language_config.vocabulary_offset, 94 | input_length=language_config.input_length) 95 | -------------------------------------------------------------------------------- /suplearn_clone_detection/callbacks.py: -------------------------------------------------------------------------------- 1 | from keras.callbacks import Callback 2 | 3 | from suplearn_clone_detection import evaluator 4 | 5 | 6 | class ModelResultsTracker: 7 | def __init__(self, data, model, comparator=None): 8 | self.data = data 9 | self.model = model 10 | self._results_cache = {} 11 | self.best_epoch = -1 12 | self.best_results = None 13 | self.comparator = comparator 14 | self.evaluator = evaluator.Evaluator(self.model) 15 | if self.comparator is None: 16 | self.comparator = self.default_comparator 17 | 18 | def compute_results(self, epoch): 19 | if epoch in self._results_cache: 20 | return self._results_cache[epoch] 21 | results = self.evaluator.evaluate(self.data) 22 | if self.best_results is None or self.comparator(results, self.best_results): 23 | self.best_epoch = epoch 24 | self.best_results = results 25 | self._results_cache[epoch] = results 26 | return results 27 | 28 | def is_best_epoch(self, epoch): 29 | if not epoch in self._results_cache: 30 | self.compute_results(epoch) 31 | return self.best_epoch == epoch 32 | 33 | @staticmethod 34 | 
def default_comparator(current_results, best_results): 35 | return current_results["f1"] > best_results["f1"] 36 | 37 | 38 | class ModelEvaluator(Callback): 39 | def __init__(self, results_tracker, filepath=None, quiet=False, save_best_only=True): 40 | super(ModelEvaluator, self).__init__() 41 | self.filepath = filepath 42 | self.quiet = quiet 43 | self.save_best_only = save_best_only 44 | self.results_tracker = results_tracker 45 | self.best_results = None 46 | 47 | def on_epoch_end(self, epoch, logs=None): 48 | results = self.results_tracker.compute_results(epoch) 49 | if not self.quiet: 50 | print("\nDev set results") 51 | evaluator.output_results(results) 52 | 53 | if not self.filepath: 54 | return 55 | 56 | if not self.save_best_only or self.results_tracker.is_best_epoch(epoch): 57 | filepath = self.filepath.format(epoch=epoch) 58 | with open(filepath, "w") as f: 59 | evaluator.output_results(results, file=f) 60 | 61 | 62 | class ModelCheckpoint(Callback): 63 | def __init__(self, results_tracker, filepath, save_best_only=False): 64 | super(ModelCheckpoint, self).__init__() 65 | self.filepath = filepath 66 | self.save_best_only = save_best_only 67 | self.results_tracker = results_tracker 68 | self.best_results = None 69 | 70 | def on_epoch_end(self, epoch, logs=None): 71 | if not self.save_best_only or self.results_tracker.is_best_epoch(epoch): 72 | filepath = self.filepath.format(epoch=epoch) 73 | self.model.save(filepath, overwrite=True) 74 | -------------------------------------------------------------------------------- /suplearn_clone_detection/cli.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import logging 4 | 5 | import tensorflow as tf 6 | import keras.backend as K 7 | 8 | from suplearn_clone_detection import commands, settings 9 | from suplearn_clone_detection.token_based import commands as token_commands 10 | 11 | 12 | def create_token_parser(base_subparsers): 13 | token_parser = base_subparsers.add_parser("tokens", help="Token-based data creation commands") 14 | subparsers = token_parser.add_subparsers(dest="subcommand") 15 | 16 | create_vocab_parser = subparsers.add_parser( 17 | "create-vocabulary", help="Create vocabulary from tokens file") 18 | create_vocab_parser.add_argument("input", help="Input file containing tokens") 19 | create_vocab_parser.add_argument( 20 | "-o", "--output", help="Vocabulary output file", required=True) 21 | create_vocab_parser.add_argument( 22 | "--size", help="Maximum size of the vocabulary", type=int, default=10000) 23 | create_vocab_parser.add_argument( 24 | "--strip-values", 25 | help="Strip token values when building the vocabulary", 26 | default=True, 27 | action="store_false", 28 | dest="include_values") 29 | 30 | create_skipgram_data_parser = subparsers.add_parser( 31 | "skipgram-data", help="Create data to train skipgram model") 32 | create_skipgram_data_parser.add_argument("input", help="Input file containing tokens") 33 | create_skipgram_data_parser.add_argument( 34 | "-o", "--output", help="Output file for the skipgram training data", required=True) 35 | create_skipgram_data_parser.add_argument( 36 | "-v", "--vocabulary", help="Path to the vocabulary file", required=True) 37 | create_skipgram_data_parser.add_argument( 38 | "-w", "--window-size", help="Window size to generate context", type=int, default=2) 39 | 40 | 41 | def make_file_processor_parser(parser, output_required=False): 42 | parser.add_argument("file", help="file containing list of files to predict") 43 | parser.add_argument( 44
| "-b", "--files-base-dir", help="base directory for files in ") 45 | parser.add_argument( 46 | "-d", "--base-dir", help="base directory for model, config and output") 47 | parser.add_argument( 48 | "-c", "--config", help="config file for the model to evaluate", default="config.yml") 49 | parser.add_argument( 50 | "-m", "--model", help="path to the model to evaluate", default="model.h5") 51 | parser.add_argument( 52 | "--files-cache", help="file containing cached vectors for files") 53 | parser.add_argument( 54 | "--asts-path", help="file containing the JSON representation of the ASTs") 55 | parser.add_argument( 56 | "--filenames-path", help="file containing the filename path of the ASTs") 57 | parser.add_argument( 58 | "--batch-size", help="size of a batch", type=int) 59 | parser.add_argument( 60 | "-o", "--output", 61 | help="file where to save the output", required=output_required) 62 | 63 | 64 | def create_parser(): 65 | parser = argparse.ArgumentParser() 66 | 67 | parser.add_argument( 68 | "-q", "--quiet", help="reduce output", default=False, action="store_true") 69 | parser.add_argument( 70 | "--debug", help="enables debug", default=False, action="store_true") 71 | 72 | subparsers = parser.add_subparsers(dest="command") 73 | 74 | train_parser = subparsers.add_parser("train", help="Train the model") 75 | train_parser.add_argument( 76 | "-c", "--config", help="config file to train model", default="config.yml") 77 | 78 | 79 | generate_dataset_parser = subparsers.add_parser( 80 | "generate-dataset", help="Generate dataset for training/evaluating model") 81 | generate_dataset_parser.add_argument( 82 | "-c", "--config", help="config file path", default="config.yml") 83 | 84 | detect_clones_parser = subparsers.add_parser( 85 | "detect-clones", help="Detect clones in the given dataset") 86 | detect_clones_parser.add_argument("dataset", help="h5py dataset") 87 | detect_clones_parser.add_argument("-m", "--model", required=True, help="path to the model to compute distance") 88 | detect_clones_parser.add_argument("-o", "--output", required=True, help="output path") 89 | 90 | evaluate_parser = subparsers.add_parser("evaluate", help="Evaluate the model") 91 | evaluate_parser.add_argument( 92 | "-d", "--base-dir", help="base directory for model, config and output") 93 | evaluate_parser.add_argument( 94 | "--data-path", help="path of the data to use for evaulation (csv file)") 95 | evaluate_parser.add_argument( 96 | "--data-type", choices=["dev", "test"], default="dev", 97 | help="the type of data on which to evaluate the model") 98 | evaluate_parser.add_argument( 99 | "-c", "--config", help="config file for the model to evaluate", default="config.yml") 100 | evaluate_parser.add_argument( 101 | "-m", "--model", help="path to the model to evaluate", default="model.h5") 102 | evaluate_parser.add_argument( 103 | "-o", "--output", help="file where to save the output") 104 | evaluate_parser.add_argument( 105 | "-f", "--overwrite", help="overwrite the results output if file exists", 106 | default=False, action="store_true") 107 | 108 | 109 | evaluate_predictions_parser = subparsers.add_parser("evaluate-predictions", 110 | help="Evaluate predictions") 111 | evaluate_predictions_parser.add_argument("predictions", help="file containing predictions") 112 | evaluate_predictions_parser.add_argument("-o", "--output", help="output file") 113 | 114 | predict_parser = subparsers.add_parser("predict", help="Predict files") 115 | make_file_processor_parser(predict_parser) 116 | predict_parser.add_argument( 117 | 
"--max-size-diff", help="max size diff as a ratio between clones", type=float) 118 | 119 | vectorize_parser = subparsers.add_parser("vectorize", help="Vectorize files") 120 | make_file_processor_parser(vectorize_parser, output_required=True) 121 | vectorize_parser.add_argument("--encoder-index", type=int, 122 | help="The index of the encoder to use (only useful when both languages are the same)") 123 | 124 | show_results_parser = subparsers.add_parser("show-results", help="Show formatted results") 125 | show_results_parser.add_argument("filepath", help="file containing the results") 126 | show_results_parser.add_argument("metric", help="the metric to show") 127 | show_results_parser.add_argument("-o", "--output", help="output file to save the result") 128 | 129 | create_token_parser(subparsers) 130 | 131 | 132 | return parser 133 | 134 | 135 | app_parser = create_parser() 136 | 137 | 138 | def run_token_command(args): 139 | if args.subcommand == "create-vocabulary": 140 | token_commands.create_vocabulary( 141 | args.input, args.size, args.include_values, args.output) 142 | elif args.subcommand == "skipgram-data": 143 | token_commands.generate_skipgram_data( 144 | args.input, args.vocabulary, args.window_size, args.output) 145 | else: 146 | app_parser.error("no subcommand provided") 147 | 148 | 149 | def run_command(args): 150 | if args.command == "train": 151 | commands.train(args.config, args.quiet) 152 | elif args.command == "evaluate": 153 | commands.evaluate(vars(args)) 154 | elif args.command == "evaluate-predictions": 155 | commands.evaluate_predictions(vars(args)) 156 | elif args.command == "predict": 157 | commands.predict(vars(args)) 158 | elif args.command == "detect-clones": 159 | commands.detect_clones(vars(args)) 160 | elif args.command == "generate-dataset": 161 | commands.generate_dataset(args.config) 162 | elif args.command == "show-results": 163 | commands.show_results(args.filepath, args.metric, args.output) 164 | elif args.command == "tokens": 165 | run_token_command(args) 166 | elif args.command == "vectorize": 167 | commands.vectorize(vars(args)) 168 | else: 169 | app_parser.error("no command provided") 170 | 171 | 172 | def run(): 173 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=settings.TF_GPU_MAX_MEMORY_USAGE) 174 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 175 | K.set_session(sess) 176 | 177 | args = app_parser.parse_args() 178 | 179 | log_level = logging.INFO if args.quiet else logging.DEBUG 180 | logging.basicConfig(level=log_level, 181 | format="%(asctime)-15s %(levelname)s %(message)s") 182 | 183 | if args.debug: 184 | run_command(args) 185 | return 186 | 187 | try: 188 | run_command(args) 189 | except Exception as e: # pylint: disable=broad-except 190 | logging.error("failed: %s", e) 191 | sys.exit(1) 192 | -------------------------------------------------------------------------------- /suplearn_clone_detection/commands.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from os import path 3 | import logging 4 | 5 | import h5py 6 | from keras.models import load_model 7 | 8 | from suplearn_clone_detection import database, dataset, layers, evaluator 9 | from suplearn_clone_detection.config import Config 10 | from suplearn_clone_detection.dataset.generator import DatasetGenerator 11 | from suplearn_clone_detection.predictor import Predictor 12 | from suplearn_clone_detection.vectorizer import Vectorizer 13 | from suplearn_clone_detection.results_printer 
import ResultsPrinter 14 | from suplearn_clone_detection.trainer import Trainer 15 | from suplearn_clone_detection.detector import Detector 16 | 17 | 18 | def train(config_path: str, quiet: bool = False): 19 | trainer = Trainer(config_path, quiet) 20 | logging.debug("initializing trainer...") 21 | trainer.initialize() 22 | if not quiet: 23 | trainer.model.summary() 24 | trainer.train() 25 | 26 | data = dataset.get(trainer.config, "dev") 27 | ev = evaluator.Evaluator(trainer.model) 28 | results_file = path.join(trainer.output_dir, "results-dev.yml") 29 | results = ev.evaluate(data, output=results_file) 30 | if not quiet: 31 | evaluator.output_results(results) 32 | return results 33 | 34 | 35 | def evaluate(options: Dict[str, str]): 36 | options = process_options(options) 37 | config = load_and_process_config(options["config"]) 38 | 39 | if options.get("output", "") is None: 40 | val = "results-{0}.yml".format(options["data_type"]) 41 | options["output"] = path.join(options.get("base_dir", ""), val) 42 | 43 | ev = evaluator.Evaluator(options["model"]) 44 | data = dataset.get(config, options["data_type"]) 45 | overwrite = options.get("overwrite", False) 46 | results = ev.evaluate(data, output=options["output"], overwrite=overwrite) 47 | if not options.get("quiet", False): 48 | evaluator.output_results(results) 49 | return results 50 | 51 | 52 | def evaluate_predictions(options: Dict[str, str]): 53 | evaluator.evaluate_predictions(options["predictions"], options["output"]) 54 | 55 | 56 | def predict(options: Dict[str, str]): 57 | options = process_options(options) 58 | config = load_and_process_config(options["config"]) 59 | with open(options["file"], "r") as f: 60 | files_base_dir = options.get("files_base_dir") or "" 61 | files = [] 62 | for line in f: 63 | pair = line.strip().split("," if "," in line else " ")[:2] 64 | files.append(tuple(path.join(files_base_dir, filename) for filename in pair)) 65 | 66 | predictor = Predictor.from_config(config, options["model"], options) 67 | predictor.predict(files) 68 | 69 | if not options.get("quiet", False): 70 | print(predictor.formatted_predictions) 71 | 72 | if options.get("output"): 73 | with open(options["output"], "w") as f: 74 | f.write(predictor.formatted_predictions) 75 | 76 | 77 | def vectorize(options: Dict[str, str]): 78 | options = process_options(options) 79 | config = load_and_process_config(options["config"]) 80 | vectorizer = Vectorizer.from_config(config, options["model"], options) 81 | with open(options["file"]) as f: 82 | filenames = f.read().splitlines() 83 | vectorizer.process(filenames, options["output"]) 84 | 85 | 86 | def process_options(options: Dict[str, str]): 87 | options = options.copy() 88 | 89 | if not options.get("base_dir"): 90 | return options 91 | for key in ["config", "model"]: 92 | if options.get(key): 93 | options[key] = path.join(options["base_dir"], options[key]) 94 | 95 | for key in ["config", "model"]: 96 | if not path.isfile(options[key]): 97 | raise ValueError("cannot open {0}".format(options[key])) 98 | 99 | return options 100 | 101 | 102 | def detect_clones(options: dict): 103 | model = load_model(options["model"], custom_objects=layers.custom_objects) 104 | with h5py.File(options["dataset"]) as data: 105 | detector = Detector(model, data) 106 | predictions = detector.detect_clones() 107 | with open(options["output"], "w") as f: 108 | detector.output_prediction_results(predictions, f) 109 | 110 | 111 | def show_results(filepath: str, metric: str, output: str): 112 | printer = ResultsPrinter(filepath) 113 
| printer.show(metric, output) 114 | 115 | 116 | def generate_dataset(config_path: str): 117 | config = load_and_process_config(config_path) 118 | dataset_generator = DatasetGenerator(config) 119 | dataset_generator.create_samples() 120 | 121 | 122 | def load_and_process_config(config_path: str) -> Config: 123 | config = Config.from_file(config_path) 124 | if config.generator.db_path: 125 | database.bind_db(config.generator.db_path) 126 | return config 127 | -------------------------------------------------------------------------------- /suplearn_clone_detection/config.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import json 3 | import hashlib 4 | 5 | import yaml 6 | 7 | 8 | TRANSFORMER_MAPPING = { 9 | "FlatVectorIndexASTTransformer": "DFSTransformer", 10 | } 11 | 12 | 13 | class LanguageConfig: 14 | def __init__(self, config): 15 | self.name = config["name"] 16 | self.vocabulary = path.expandvars(config["vocabulary"]) 17 | self._vocabulary_size = None 18 | self.embeddings = path.expandvars(config.get("embeddings", "")) 19 | self.vocabulary_offset = config.get("vocabulary_offset", 0) 20 | self.input_length = config.get("input_length") 21 | self.max_length = config.get("max_length") 22 | self.embeddings_dimension = config["embeddings_dimension"] 23 | if "output_dimension" in config: 24 | self.output_dimensions = [config["output_dimension"]] 25 | else: 26 | self.output_dimensions = config["output_dimensions"] 27 | self.transformer_class_name = config.get("transformer_class_name", 28 | "DFSTransformer") 29 | if self.transformer_class_name in TRANSFORMER_MAPPING: 30 | self.transformer_class_name = TRANSFORMER_MAPPING[self.transformer_class_name] 31 | self.bidirectional_encoding = config.get("bidirectional_encoding", False) 32 | self.hash_dims = config.get("hash_dims", []) 33 | 34 | @property 35 | def vocabulary_size(self): 36 | if self._vocabulary_size is None: 37 | raise ValueError("vocabulary size needs to be set explicitly") 38 | return self._vocabulary_size 39 | 40 | @vocabulary_size.setter 41 | def vocabulary_size(self, value): 42 | self._vocabulary_size = value 43 | 44 | 45 | class ModelConfig: 46 | KNOWN_MERGE_MODES = [ 47 | "simple", 48 | "bidistance", 49 | "euclidean_distance", 50 | "euclidean_similarity", 51 | "cosine_similarity" 52 | ] 53 | 54 | def __init__(self, config): 55 | self.languages = [LanguageConfig(lang) for lang in config["languages"]] 56 | self.dense_layers = config.get("dense_layers", [64, 64]) 57 | self.optimizer = config.get("optimizer", {"type": "sgd"}) 58 | self.merge_mode = config.get("merge_mode", "simple") 59 | self.merge_output_dim = config.get("merge_output_dim", 64) 60 | self.use_output_nn = config.get("use_output_nn", True) 61 | if self.merge_mode not in self.KNOWN_MERGE_MODES: 62 | raise ValueError("unknown merge mode: {0}".format(self.merge_mode)) 63 | default_loss = "binary_crossentropy" if self.use_output_nn else "mse" 64 | self.loss = config.get("loss", default_loss) 65 | self.metrics = config.get("metrics", ["accuracy"]) 66 | self.normalization_value = 100 67 | 68 | 69 | class GeneratorConfig: 70 | def __init__(self, config): 71 | self.submissions_path = path.expandvars(config["submissions_path"]) 72 | self.asts_path = path.expandvars(config["asts_path"]) 73 | self.filenames_path = None 74 | self.db_path = path.expandvars(config.get("db_path", "")) 75 | if "filenames_path" in config: 76 | self.filenames_path = path.expandvars(config["filenames_path"]) 77 | self.file_format =
config.get("file_format", "multi_file") 78 | self.use_all_combinations = config.get("use_all_combinations", False) 79 | self.shuffle = config.get("shuffle", True) 80 | self.shuffle_before_epoch = config.get("shuffle_before_epoch", True) 81 | self.split_ratio = config.get("split_ratio", [0.8, 0.1, 0.1]) 82 | self.negative_samples = config.get("negative_samples", 1) 83 | self.negative_sample_candidates = config.get("negative_sample_candidates", 8) 84 | self.samples_per_problem = config.get("samples_per_problem", 1) 85 | self.class_weights = config.get("class_weights") 86 | self.negative_sample_distance = config.get("negative_sample_distance", 0.2) 87 | 88 | 89 | class TrainerConfig: 90 | def __init__(self, config): 91 | self.epochs = config["epochs"] 92 | self.batch_size = config.get("batch_size", 128) 93 | self.output_dir = path.expandvars(config.get("output_dir", "")) 94 | self.tensorboard_logs = config.get("tensorboard_logs", True) 95 | 96 | 97 | class Config: 98 | def __init__(self, config): 99 | self.model = ModelConfig(config["model"]) 100 | self.generator = GeneratorConfig(config["generator"]) 101 | self.trainer = TrainerConfig(config["trainer"]) 102 | 103 | @property 104 | def data_generation_config(self): 105 | return dict( 106 | languages=[v.name for v in self.model.languages], 107 | negative_sample_distance=self.generator.negative_sample_distance, 108 | samples_per_problem=self.generator.samples_per_problem, 109 | split_ratio=self.generator.split_ratio, 110 | ) 111 | 112 | def data_generation_checksum(self): 113 | h = hashlib.md5() 114 | h.update(json.dumps(self.data_generation_config, sort_keys=True).encode("utf-8")) 115 | return h.hexdigest() 116 | 117 | @classmethod 118 | def from_file(cls, filepath) -> "Config": 119 | with open(filepath) as f: 120 | return cls(yaml.load(f)) 121 | -------------------------------------------------------------------------------- /suplearn_clone_detection/database.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | from sqlalchemy import create_engine 4 | from sqlalchemy.orm import scoped_session, sessionmaker 5 | from sqlalchemy.orm.session import Session as DBSession 6 | from sqlalchemy.ext.declarative import declarative_base 7 | 8 | 9 | Session: DBSession = scoped_session(sessionmaker(autocommit=False, autoflush=False)) 10 | Base = declarative_base() 11 | Base.query = Session.query_property() 12 | 13 | 14 | def bind_db(db_url: str): 15 | engine = create_engine(db_url) 16 | Session.configure(bind=engine) 17 | 18 | 19 | @contextmanager 20 | def get_session(commit=False): 21 | sess = Session() 22 | try: 23 | yield sess 24 | if commit: 25 | sess.commit() 26 | except Exception: # pylint: disable=broad-except 27 | sess.rollback() 28 | finally: 29 | sess.close() 30 | -------------------------------------------------------------------------------- /suplearn_clone_detection/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from keras.utils import Sequence 2 | 3 | from suplearn_clone_detection.config import Config 4 | from suplearn_clone_detection.dataset.sequences import DevSequence, TestSequence 5 | 6 | 7 | def get(config: Config, data_type: str) -> Sequence: 8 | if data_type == "dev": 9 | return DevSequence(config) 10 | elif data_type == "test": 11 | return TestSequence(config) 12 | raise ValueError("cannot get {0} sequence".format(data_type)) 13 | 
-------------------------------------------------------------------------------- /suplearn_clone_detection/dataset/generator.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | import random 3 | import logging 4 | 5 | from suplearn_clone_detection import entities 6 | from suplearn_clone_detection.dataset import util 7 | from suplearn_clone_detection.config import Config 8 | from suplearn_clone_detection.database import Session 9 | 10 | 11 | class DatasetGenerator: 12 | def __init__(self, config: Config): 13 | self.config = config 14 | self.config_checksum = self.config.data_generation_checksum() 15 | 16 | def load_submissions(self, lang: str) -> Dict[str, List[entities.Submission]]: 17 | training_ratio, dev_ratio, _test_ratio = self.config.generator.split_ratio 18 | submissions = entities.Submission.query \ 19 | .filter(entities.Submission.language_code == lang) \ 20 | .all() 21 | random.shuffle(submissions) 22 | 23 | training_count = int(len(submissions) * training_ratio) 24 | dev_count = int(len(submissions) * dev_ratio) 25 | return dict( 26 | training=submissions[:training_count], 27 | dev=submissions[training_count:training_count+dev_count], 28 | test=submissions[training_count+dev_count:], 29 | ) 30 | 31 | def _create_sample(self, dataset_name: str, submission: entities.Submission, 32 | sorted_set: List[entities.Submission], 33 | positive_samples: List[entities.Submission]): 34 | if not positive_samples: 35 | return None, None 36 | positive_sample_idx = random.randrange(len(positive_samples)) 37 | positive_sample = positive_samples[positive_sample_idx] 38 | negative_samples = util.select_negative_candidates( 39 | sorted_set, positive_sample, 40 | self.config.generator.negative_sample_distance) 41 | if not negative_samples: 42 | return None, None 43 | negative_sample = random.choice(negative_samples) 44 | return entities.Sample( 45 | anchor_id=submission.id, 46 | anchor=submission, 47 | positive=positive_sample, 48 | positive_id=positive_sample.id, 49 | negative=negative_sample, 50 | negative_id=negative_sample.id, 51 | dataset_name=dataset_name, 52 | config_checksum=self.config_checksum, 53 | ), positive_sample_idx 54 | 55 | def create_set_samples(self, dataset_name: str, 56 | lang1_dataset: List[entities.Submission], 57 | lang2_dataset: List[entities.Submission]): 58 | samples = [] 59 | lang2_grouped_dataset = util.group_submissions(lang2_dataset) 60 | lang2_sorted_dataset = util.sort_dataset(lang2_dataset) 61 | for i, submission in enumerate(lang1_dataset): 62 | positive_submissions = lang2_grouped_dataset.get(submission.group_key) 63 | if not positive_submissions: 64 | continue 65 | positive_submissions = positive_submissions.copy() 66 | for _ in range(self.config.generator.samples_per_problem): 67 | sample, positive_idx = self._create_sample(dataset_name, submission, 68 | lang2_sorted_dataset, 69 | positive_submissions) 70 | if not sample: 71 | break 72 | del positive_submissions[positive_idx] 73 | samples.append(sample) 74 | if i % 1000 == 0 and sample: 75 | logging.info("%s-%s pairs progress - %s - %s/%s", 76 | submission.language_code, sample.negative.language_code, 77 | dataset_name, i, len(lang1_dataset)) 78 | return samples 79 | 80 | def create_lang_samples( 81 | self, lang1_datasets: Dict[str, List[entities.Submission]], 82 | lang2_datasets: Dict[str, List[entities.Submission]]) -> entities.Sample: 83 | samples = [] 84 | for dataset_name, dataset in lang1_datasets.items(): 85 | 
samples.extend(self.create_set_samples(dataset_name, dataset, lang2_datasets[dataset_name])) 86 | return samples 87 | 88 | def check_existing_samples(self): 89 | q = entities.Sample.query.filter_by(config_checksum=self.config_checksum) 90 | if q.first(): 91 | raise ValueError("samples already exists for checksum '{0}'. run " 92 | "\"DELETE FROM samples WHERE config_checksum='{0}'\" " 93 | "if you want to remove them".format(self.config_checksum)) 94 | 95 | def create_samples(self): 96 | self.check_existing_samples() 97 | languages = [v.name for v in self.config.model.languages] 98 | lang1_samples = self.load_submissions(languages[0]) 99 | if languages[0] != languages[1]: 100 | lang2_samples = self.load_submissions(languages[1]) 101 | else: 102 | lang2_samples = lang1_samples 103 | samples = self.create_lang_samples(lang1_samples, lang2_samples) 104 | if languages[0] != languages[1]: 105 | samples.extend(self.create_lang_samples(lang2_samples, lang1_samples)) 106 | Session.bulk_save_objects(samples) 107 | Session.commit() 108 | return len(samples) 109 | -------------------------------------------------------------------------------- /suplearn_clone_detection/dataset/sequences.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Iterable 2 | import math 3 | import json 4 | import random 5 | import logging 6 | from contextlib import contextmanager 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from keras.engine import Model 11 | from keras.utils import Sequence 12 | from keras.preprocessing.sequence import pad_sequences 13 | 14 | from sqlalchemy.orm import joinedload, aliased 15 | 16 | from suplearn_clone_detection.database import get_session 17 | from suplearn_clone_detection import entities, ast_transformer, util 18 | from suplearn_clone_detection.util import memoize 19 | from suplearn_clone_detection.config import Config 20 | 21 | 22 | Positive = aliased(entities.Submission) 23 | Anchor = aliased(entities.Submission) 24 | 25 | 26 | class SuplearnSequence(Sequence): 27 | def __init__(self, config: Config) -> None: 28 | self.config = config 29 | self.config_checksum = self.config.data_generation_checksum() 30 | transformers = ast_transformer.create_all(config.model.languages) 31 | self.ast_transformers = {tr.language: tr for tr in transformers} 32 | self.shuffle = False 33 | 34 | @property 35 | def dataset_name(self): 36 | raise NotImplementedError() 37 | 38 | def __getitem__(self, index): 39 | samples = self.get_samples(index) 40 | lang1_positive, lang2_positive = self.get_positive_pairs(samples) 41 | lang1_negative, lang2_negative = self.get_negative_pairs(samples) 42 | lang1_input = pad_sequences(lang1_positive + lang1_negative) 43 | lang2_input = pad_sequences(lang2_positive + lang2_negative) 44 | y = np.hstack([np.ones(len(lang1_positive), dtype=np.int32), 45 | np.zeros(len(lang1_negative), dtype=np.int32)]) 46 | if self.shuffle: 47 | shuffled_index = np.random.permutation(len(y)) 48 | X = [lang1_input[shuffled_index], lang2_input[shuffled_index]] 49 | y = y[shuffled_index] 50 | else: 51 | X = [lang1_input, lang2_input] 52 | return X, y 53 | 54 | def get_labels(self): 55 | y = np.array([]) 56 | for index in range(len(self)): 57 | _, batch_y = self[index] 58 | y = np.hstack([y, batch_y]) 59 | return y 60 | 61 | def get_positive_pairs(self, samples: List[entities.Sample]) -> Tuple[List[int], List[int]]: 62 | return self._get_pairs(samples, "positive") 63 | 64 | def get_negative_pairs(self, samples: 
List[entities.Sample]) -> Tuple[List[int], List[int]]: 65 | return self._get_pairs(samples, "negative") 66 | 67 | def _get_pairs(self, samples: List[entities.Sample], 68 | second_elem_key: str) -> Tuple[List[int], List[int]]: 69 | lang1 = [self.get_ast(sample.anchor) for sample in samples] 70 | lang2 = [self.get_ast(getattr(sample, second_elem_key)) for sample in samples] 71 | return lang1, lang2 72 | 73 | @memoize(lambda self, submission: (self, submission.id)) 74 | def get_ast(self, submission: entities.Submission): 75 | transformer = self.ast_transformers[submission.language_code] 76 | return transformer.transform_ast(json.loads(submission.ast)) 77 | 78 | @property 79 | def batch_size(self): 80 | return self.config.trainer.batch_size 81 | 82 | @contextmanager 83 | def db_query(self): 84 | conditions = dict(dataset_name=self.dataset_name, 85 | config_checksum=self.config_checksum) 86 | with get_session() as sess: 87 | yield sess.query(entities.Sample).filter_by(**conditions) 88 | 89 | def __len__(self): 90 | return math.ceil(self.count_samples() * 2 / self.batch_size) 91 | 92 | @memoize 93 | def get_samples(self, index): 94 | if index >= len(self): 95 | raise IndexError("sequence index out of range") 96 | samples_per_batch = self.batch_size // 2 97 | offset = samples_per_batch * index 98 | options = [joinedload(entities.Sample.anchor), 99 | joinedload(entities.Sample.positive), 100 | joinedload(entities.Sample.negative)] 101 | with self.db_query() as query: 102 | return query.options(*options).offset(offset).limit(samples_per_batch).all() 103 | 104 | @memoize 105 | def count_samples(self): 106 | with self.db_query() as query: 107 | return query.count() 108 | 109 | 110 | class TrainingSequence(SuplearnSequence): 111 | def __init__(self, model: Model, graph: tf.Graph, config: Config) -> None: 112 | super(TrainingSequence, self).__init__(config) 113 | self.model = model 114 | self.graph = graph 115 | self.shuffle = True 116 | 117 | @contextmanager 118 | def db_query(self): 119 | lang1_config, lang2_config = self.config.model.languages 120 | with super(TrainingSequence, self).db_query() as query: 121 | if lang1_config.max_length: 122 | query = query.join(Anchor, entities.Sample.anchor) \ 123 | .filter(Anchor.tokens_count <= lang1_config.max_length) 124 | if lang2_config.max_length: 125 | query = query.join(Positive, entities.Sample.positive) \ 126 | .filter(Positive.tokens_count <= lang2_config.max_length) 127 | yield query 128 | 129 | def get_negative_pairs(self, samples: List[entities.Sample]) \ 130 | -> Tuple[List[List[int]], List[List[int]]]: 131 | count_per_anchor = self.config.generator.negative_sample_candidates 132 | anchors = [sample.anchor for sample in samples] 133 | anchor_asts = [self.get_ast(anchor) for anchor in anchors] 134 | candidates = random.sample(self.candidates_pool(samples[0].negative.language_code), 135 | count_per_anchor * len(samples)) 136 | candidate_asts = [self.get_ast(submission) for submission in candidates] 137 | 138 | # input lengths: len(samples) * count_per_anchor 139 | lang1_input = pad_sequences([ast for ast in anchor_asts for _ in range(count_per_anchor)]) 140 | lang2_input = pad_sequences([ast for ast in candidate_asts]) 141 | 142 | with self.graph.as_default(): 143 | predictions = self.model.predict([lang1_input, lang2_input], 144 | batch_size=len(lang1_input)) 145 | negative_asts = self._collect_negative_asts(anchors, candidates, 146 | candidate_asts, predictions) 147 | 148 | return anchor_asts, negative_asts 149 | 150 | def 
_collect_negative_asts(self, anchors: List[entities.Submission], 151 | candidates: List[entities.Submission], 152 | candidate_asts: List[List[int]], 153 | predictions: Iterable[int]) -> List[List[int]]: 154 | count_per_anchor = self.config.generator.negative_sample_candidates 155 | negative_asts = [] 156 | for i, sample_predictions in enumerate(util.in_batch(predictions, count_per_anchor)): 157 | negative_ast = None 158 | max_prediction = -1 159 | base_index = i * count_per_anchor 160 | for j, prediction in enumerate(sample_predictions): 161 | if prediction > max_prediction and \ 162 | candidates[base_index + j].group_key != anchors[i].group_key: 163 | negative_ast = candidate_asts[base_index + j] 164 | max_prediction = prediction 165 | 166 | if not negative_ast: 167 | logging.warning("could not find a valid negative sample") 168 | negative_ast = candidate_asts[base_index] 169 | 170 | negative_asts.append(negative_ast) 171 | return negative_asts 172 | 173 | @memoize 174 | def candidates_pool(self, language_code: str): 175 | with get_session() as sess: 176 | query = sess.query(entities.Submission) \ 177 | .filter_by(language_code=language_code) 178 | max_length = self.config.model.languages[1].max_length 179 | if max_length: 180 | query = query.filter(entities.Submission.tokens_count <= max_length) 181 | return query.all() 182 | 183 | 184 | @property 185 | def dataset_name(self): 186 | return "training" 187 | 188 | 189 | class DevSequence(SuplearnSequence): 190 | @property 191 | def dataset_name(self): 192 | return "dev" 193 | 194 | 195 | class TestSequence(SuplearnSequence): 196 | @property 197 | def dataset_name(self): 198 | return "test" 199 | -------------------------------------------------------------------------------- /suplearn_clone_detection/dataset/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from typing import List, Optional, Dict, Tuple 4 | from suplearn_clone_detection import entities 5 | 6 | 7 | def find_submission_index(submissions: List[entities.Submission], 8 | tokens_count: int, rightmost=False) -> int: 9 | left = 0 10 | right = len(submissions) 11 | while left < right: 12 | middle = (left + right) // 2 13 | submission = submissions[middle] 14 | if submission.tokens_count < tokens_count or \ 15 | (submission.tokens_count == tokens_count and rightmost): 16 | left = middle + 1 17 | else: 18 | right = middle 19 | return left 20 | 21 | 22 | def select_negative_candidates( 23 | sorted_submissions: List[entities.Submission], 24 | positive_sample: entities.Submission, 25 | distance: int, 26 | min_candidates: Optional[int] = None) -> List[entities.Submission]: 27 | tokens_count = positive_sample.tokens_count 28 | tokens_diff = int(distance * tokens_count) 29 | left_index = find_submission_index(sorted_submissions, tokens_count - tokens_diff) 30 | right_index = find_submission_index(sorted_submissions, 31 | tokens_count + tokens_diff, rightmost=True) 32 | candidates = sorted_submissions[left_index:right_index] 33 | 34 | if not min_candidates: 35 | return candidates 36 | 37 | while len(candidates) < min_candidates: 38 | new_candidate = random.choice(sorted_submissions) 39 | if new_candidate.group_key != positive_sample.group_key and \ 40 | new_candidate not in candidates: 41 | candidates.append(new_candidate) 42 | return candidates 43 | 44 | 45 | def select_negative_sample( 46 | sorted_submissions: List[entities.Submission], 47 | positive_sample: entities.Submission, 48 | distance: int) -> entities.Submission: 
49 | candidates = select_negative_candidates(sorted_submissions, positive_sample, distance) 50 | while candidates: 51 | idx = random.randrange(len(candidates)) 52 | negative_sample = candidates[idx] 53 | if positive_sample.group_key != negative_sample.group_key: 54 | return negative_sample 55 | del candidates[idx] 56 | 57 | 58 | def group_submissions(submissions: List[entities.Submission]) \ 59 | -> Dict[Tuple[str, int, int], entities.Submission]: 60 | result = {} 61 | for submission in submissions: 62 | result.setdefault(submission.group_key, []) 63 | result[submission.group_key].append(submission) 64 | return result 65 | 66 | 67 | def sort_dataset(submissions: List[entities.Submission]) \ 68 | -> Tuple[List[entities.Submission], Dict[int, int]]: 69 | sorted_submissions = sorted(submissions, key=lambda x: x.tokens_count) 70 | return sorted_submissions 71 | -------------------------------------------------------------------------------- /suplearn_clone_detection/detector.py: -------------------------------------------------------------------------------- 1 | import math 2 | from keras.models import Model 3 | import h5py 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | 8 | from suplearn_clone_detection import util 9 | 10 | class Detector: 11 | def __init__(self, model: Model, dataset: h5py.Dataset): 12 | self.model = model 13 | self.dataset = dataset 14 | self.left_lang = self._get_lang(self.model.layers[0]) 15 | self.right_lang = self._get_lang(self.model.layers[1]) 16 | 17 | def _get_lang(self, layer): 18 | return layer.name.split("_")[1] 19 | 20 | # TODO: parallelize calls to model.predict 21 | def detect_clones(self, batch_size=1024): 22 | total_batches = self.batches_count(batch_size) 23 | for batch in tqdm(self.batch_iterator(batch_size), total=total_batches): 24 | left, right = self.get_inputs(batch) 25 | batch_predictions = self.model.predict([left, right]).reshape(len(batch)) 26 | yield from zip(batch, batch_predictions) 27 | 28 | @staticmethod 29 | def output_prediction_results(predictions, f): 30 | for (left, right), output in predictions: 31 | print("{0},{1},{2}".format(left, right, output), file=f) 32 | 33 | def get_inputs(self, batch): 34 | left, right = zip(*[(self.dataset[l].value, self.dataset[r].value) for l, r in batch]) 35 | return np.array(left), np.array(right) 36 | 37 | def batch_iterator(self, batch_size): 38 | pairs = util.hdf5_key_pairs(self.dataset, self.left_lang, self.right_lang) 39 | yield from util.in_batch(pairs, batch_size) 40 | 41 | def batches_count(self, batch_size): 42 | keys_count = len(util.hdf5_keys(self.dataset)) 43 | pairs_count = (keys_count * (keys_count - 1)) // 2 44 | return math.ceil(pairs_count / batch_size) 45 | -------------------------------------------------------------------------------- /suplearn_clone_detection/entities.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, ForeignKey 2 | from sqlalchemy.orm import relationship 3 | 4 | from suplearn_clone_detection.database import Base 5 | 6 | 7 | class Submission(Base): 8 | __tablename__ = "submissions" 9 | 10 | id: int = Column(Integer, primary_key=True) 11 | url: str = Column(String) 12 | contest_type: str = Column(String) 13 | contest_id: int = Column(Integer) 14 | problem_id: str = Column(String) 15 | problem_title: str = Column(String) 16 | filename: str = Column(String) 17 | language: str = Column(String) 18 | language_code: str = Column(String) 19 | source_length: int = Column(Integer) 20 | 
exec_time: int = Column(Integer) 21 | tokens_count: int = Column(Integer) 22 | source: str = Column(String) 23 | ast: str = Column(String) 24 | 25 | @property 26 | def path(self): 27 | return "{0}/{1}/{2}/{3}".format( 28 | self.contest_type, 29 | self.contest_id, 30 | self.problem_id, 31 | self.filename) 32 | 33 | @property 34 | def group_key(self): 35 | return (self.contest_type, self.contest_id, self.problem_id) 36 | 37 | def __repr__(self): 38 | return "Submission(path=\"{0}\")".format(self.path) 39 | 40 | 41 | class Sample(Base): 42 | __tablename__ = "samples" 43 | 44 | id: int = Column(Integer, primary_key=True) 45 | dataset_name: str = Column(String) 46 | config_checksum: str = Column(String) 47 | 48 | anchor_id: int = Column(Integer, ForeignKey("submissions.id")) 49 | positive_id: int = Column(Integer, ForeignKey("submissions.id")) 50 | negative_id: int = Column(Integer, ForeignKey("submissions.id")) 51 | 52 | anchor: Submission = relationship("Submission", 53 | foreign_keys=[anchor_id], 54 | cascade="expunge") 55 | positive: Submission = relationship("Submission", 56 | foreign_keys=[positive_id], 57 | cascade="expunge") 58 | negative: Submission = relationship("Submission", 59 | foreign_keys=[negative_id], 60 | cascade="expunge") 61 | 62 | def __repr__(self): 63 | return "Sample(anchor={0}, positive={1}, negative={2}, " \ 64 | "dataset_name=\"{3}\")".format(self.anchor, self.positive, 65 | self.negative, self.dataset_name) 66 | -------------------------------------------------------------------------------- /suplearn_clone_detection/evaluator.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Dict, Union 3 | from os import path 4 | import logging 5 | 6 | import yaml 7 | 8 | from keras.models import load_model, Model 9 | import numpy as np 10 | from sklearn import metrics 11 | 12 | from suplearn_clone_detection.layers import custom_objects 13 | from suplearn_clone_detection.dataset.sequences import SuplearnSequence 14 | 15 | 16 | 17 | def get_metrics(labels, predictions): 18 | precisions, recalls, _ = metrics.precision_recall_curve(labels, predictions) 19 | return { 20 | "samples_count": len(labels), 21 | "positive_samples_count": len([labels for t in labels if t == 1]), 22 | "accuracy": float(metrics.accuracy_score(labels, predictions)), 23 | "precision": float(metrics.precision_score(labels, predictions)), 24 | "recall": float(metrics.recall_score(labels, predictions)), 25 | "avg_precision": float(metrics.average_precision_score(labels, predictions)), 26 | "f1": float(metrics.f1_score(labels, predictions)), 27 | "pr_curve": dict(precision=precisions.tolist(), recall=recalls.tolist()) 28 | } 29 | 30 | 31 | def output_results(results: Dict[str, Dict[str, float]], file=sys.stdout): 32 | print(yaml.dump(results, default_flow_style=False), file=file, end="") 33 | 34 | 35 | def try_output_results(results, output: str = None, overwrite: bool = False): 36 | if output: 37 | if path.exists(output) and not overwrite: 38 | logging.warning("%s exists, skipping", output) 39 | else: 40 | with open(output, "w") as f: 41 | output_results(results, file=f) 42 | 43 | 44 | def evaluate_predictions(predictions_file: str, output: str = None): 45 | with open(predictions_file) as f: 46 | lines = [v.strip().split(",") for v in f if v] 47 | predictions = np.round([float(line[2]) for line in lines]) 48 | labels = np.array([int(path.dirname(f1) == path.dirname(f2)) for f1, f2, _ in lines]) 49 | results = get_metrics(labels, predictions) 
50 | if output: 51 | try_output_results(results, output) 52 | else: 53 | output_results(results) 54 | 55 | 56 | class Evaluator: 57 | def __init__(self, model: Union[Model, str]): 58 | if isinstance(model, str): 59 | model = load_model(model, custom_objects=custom_objects) 60 | self.model = model 61 | 62 | def evaluate(self, data: SuplearnSequence, output: str = None, 63 | overwrite: bool = False) -> dict: 64 | labels = data.get_labels() 65 | logging.info("running predictions with %s samples", len(labels)) 66 | prediction_probs = self.model.predict_generator(data) 67 | predictions = np.round(prediction_probs) 68 | results = get_metrics(labels, predictions) 69 | try_output_results(results, output, overwrite) 70 | return results 71 | -------------------------------------------------------------------------------- /suplearn_clone_detection/file_processor.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | import logging 4 | import subprocess 5 | from os import path 6 | 7 | from keras.models import load_model 8 | 9 | from suplearn_clone_detection.config import Config 10 | from suplearn_clone_detection.ast_loader import ASTLoader 11 | from suplearn_clone_detection.layers import custom_objects, ModelWrapper 12 | from suplearn_clone_detection import ast_transformer 13 | 14 | 15 | class FileProcessor: 16 | def __init__(self, config: Config, model: ModelWrapper, options: dict = None): 17 | if options is None: 18 | options = {} 19 | self.options = options 20 | self.config = config 21 | self.loader = self._make_ast_loader(config, options) 22 | self._files_cache = {} 23 | self.language_names = [lang.name for lang in self.config.model.languages] 24 | transformers = ast_transformer.create_all(self.config.model.languages) 25 | self.transformers = {t.language: t for t in transformers} 26 | self.model = model 27 | if self.options.get("files_cache"): 28 | with open(self.options["files_cache"], "rb") as f: 29 | self._files_cache = pickle.load(f) 30 | 31 | @staticmethod 32 | def _make_ast_loader(config: Config, options: dict): 33 | args = dict( 34 | asts_path=options.get("asts_path") or config.generator.asts_path, 35 | filenames_path=options.get("filenames_path") or config.generator.filenames_path 36 | ) 37 | if options.get("file_format"): 38 | args["file_format"] = options["file_format"] 39 | return ASTLoader(**args) 40 | 41 | def get_file_vector(self, filename, language): 42 | if filename not in self._files_cache: 43 | transformer = self.transformers[language] 44 | file_ast = self.get_file_ast(filename) 45 | self._files_cache[filename] = (file_ast, transformer.transform_ast(file_ast)) 46 | return self._files_cache[filename] 47 | 48 | def get_file_ast(self, filename): 49 | if self.loader.has_file(filename): 50 | return self.loader.get_ast(filename) 51 | _, ext = path.splitext(filename) 52 | executable = "bigcode-astgen-{0}".format(ext[1:]) 53 | logging.warning("%s AST not found, generating with %s", filename, executable) 54 | res = subprocess.run([executable, filename], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 55 | if res.returncode != 0: 56 | raise ValueError("got exit code {0}: {1}".format(res.returncode, res.stderr)) 57 | return json.loads(res.stdout) 58 | 59 | @classmethod 60 | def from_config(cls, config: Config, model_path: str, options: dict): 61 | model = load_model(model_path, custom_objects=custom_objects) 62 | return cls(config, model, options) 63 | 
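A minimal end-to-end sketch (not part of the repository) of how the Detector and the prediction evaluator above can be combined. The paths model.h5, vectors.h5, predictions.csv and results.yml are hypothetical, and the HDF5 file is assumed to hold one dataset per source file, keyed by filename, as written by the Vectorizer further down in this package.

import h5py
from keras.models import load_model

from suplearn_clone_detection.detector import Detector
from suplearn_clone_detection.evaluator import evaluate_predictions
from suplearn_clone_detection.layers import custom_objects

# load the trained siamese model together with the project's custom layers
model = load_model("model.h5", custom_objects=custom_objects)

with h5py.File("vectors.h5", "r") as dataset, open("predictions.csv", "w") as out:
    detector = Detector(model, dataset)
    # detect_clones yields ((left_file, right_file), score) pairs batch by batch
    Detector.output_prediction_results(detector.detect_clones(batch_size=512), out)

# labels are derived by comparing the directory of the two files in each pair
evaluate_predictions("predictions.csv", output="results.yml")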
-------------------------------------------------------------------------------- /suplearn_clone_detection/layers.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import copy 3 | from os import path 4 | 5 | from keras import activations, initializers, regularizers, constraints 6 | from keras.engine import Layer, InputSpec, Model 7 | from keras.layers.merge import _Merge, Dot 8 | from keras.layers.wrappers import Wrapper 9 | from keras.utils.generic_utils import has_arg 10 | import keras.backend as K 11 | 12 | 13 | 14 | class ModelWrapper(Model): 15 | @property 16 | def inner_models(self) -> List[Model]: 17 | return [v for v in self.layers if isinstance(v, Model)] 18 | 19 | def save(self, filepath, overwrite=True, include_optimizer=True): 20 | kwargs = dict(overwrite=overwrite, include_optimizer=include_optimizer) 21 | filename, ext = path.splitext(filepath) 22 | super(ModelWrapper, self).save(filepath, **kwargs) 23 | for model in self.inner_models: 24 | fullpath = f"{filename}-{model.name}{ext}" 25 | model.save(fullpath, **kwargs) 26 | 27 | def summary(self, line_length=None, positions=None, print_fn=print): 28 | kwargs = dict(line_length=line_length, positions=positions, 29 | print_fn=print_fn) 30 | for model in self.inner_models: 31 | print_fn(model.name) 32 | model.summary(**kwargs) 33 | print("Main model:") 34 | super(ModelWrapper, self).summary(**kwargs) 35 | 36 | 37 | class SplitInput(Wrapper): 38 | def __init__(self, layer, weights=None, **kwargs): 39 | super(SplitInput, self).__init__(layer, **kwargs) 40 | self.forward_layer = copy.copy(layer) 41 | self.backward_layer = layer.__class__.from_config(layer.get_config()) 42 | self.forward_layer.name = 'forward_' + self.forward_layer.name 43 | self.backward_layer.name = 'backward_' + self.backward_layer.name 44 | if weights: 45 | nw = len(weights) 46 | self.forward_layer.initial_weights = weights[:nw // 2] 47 | self.backward_layer.initial_weights = weights[nw // 2:] 48 | self.stateful = layer.stateful 49 | self.return_sequences = layer.return_sequences 50 | self.supports_masking = True 51 | 52 | def get_weights(self): 53 | return self.forward_layer.get_weights() + self.backward_layer.get_weights() 54 | 55 | def set_weights(self, weights): 56 | nw = len(weights) 57 | self.forward_layer.set_weights(weights[:nw // 2]) 58 | self.backward_layer.set_weights(weights[nw // 2:]) 59 | 60 | def compute_output_shape(self, input_shape): 61 | shape = list(input_shape) 62 | shape[1] //= 2 63 | output_shape = list(self.forward_layer.compute_output_shape(tuple(shape))) 64 | output_shape[1] *= 2 65 | return tuple(output_shape) 66 | 67 | def call(self, inputs, training=None, mask=None): 68 | kwargs = {} 69 | if has_arg(self.layer.call, 'training'): 70 | kwargs['training'] = training 71 | if has_arg(self.layer.call, 'mask'): 72 | kwargs['mask'] = mask 73 | 74 | ni = inputs.shape[1] 75 | y = self.forward_layer.call(inputs[:, :ni // 2], **kwargs) 76 | y_rev = self.backward_layer.call(inputs[:, ni // 2:], **kwargs) 77 | output = K.concatenate([y, y_rev]) 78 | 79 | # Properly set learning phase 80 | if self.layer.dropout + self.layer.recurrent_dropout > 0: 81 | output._uses_learning_phase = True 82 | return output 83 | 84 | def reset_states(self): 85 | self.forward_layer.reset_states() 86 | self.backward_layer.reset_states() 87 | 88 | def build(self, input_shape=None): 89 | with K.name_scope(self.forward_layer.name): 90 | self.forward_layer.build(input_shape) 91 | with 
K.name_scope(self.backward_layer.name): 92 | self.backward_layer.build(input_shape) 93 | self.built = True 94 | 95 | def compute_mask(self, inputs, mask=None): 96 | if self.return_sequences: 97 | return mask 98 | else: 99 | return None 100 | 101 | @property 102 | def trainable_weights(self): 103 | if hasattr(self.forward_layer, 'trainable_weights'): 104 | return (self.forward_layer.trainable_weights + 105 | self.backward_layer.trainable_weights) 106 | return [] 107 | 108 | @property 109 | def non_trainable_weights(self): 110 | if hasattr(self.forward_layer, 'non_trainable_weights'): 111 | return (self.forward_layer.non_trainable_weights + 112 | self.backward_layer.non_trainable_weights) 113 | return [] 114 | 115 | @property 116 | def updates(self): 117 | if hasattr(self.forward_layer, 'updates'): 118 | return self.forward_layer.updates + self.backward_layer.updates 119 | return [] 120 | 121 | @property 122 | def losses(self): 123 | if hasattr(self.forward_layer, 'losses'): 124 | return self.forward_layer.losses + self.backward_layer.losses 125 | return [] 126 | 127 | @property 128 | def constraints(self): 129 | constr = {} 130 | if hasattr(self.forward_layer, 'constraints'): 131 | constr.update(self.forward_layer.constraints) 132 | constr.update(self.backward_layer.constraints) 133 | return constr 134 | 135 | 136 | class _PairMerge(_Merge): 137 | def _check_merge_inputs(self, inputs): 138 | class_name = self.__class__.__name__ 139 | if len(inputs) != 2: 140 | raise ValueError('`{0}` layer should be called ' 141 | 'on exactly 2 inputs'.format(class_name)) 142 | if K.int_shape(inputs[0]) != K.int_shape(inputs[1]): 143 | raise ValueError('`{0}` layer should be called ' 144 | 'on inputs of the same shape'.format(class_name)) 145 | 146 | 147 | class AbsDiff(_PairMerge): 148 | def _merge_function(self, inputs): 149 | self._check_merge_inputs(inputs) 150 | return K.abs(inputs[0] - inputs[1]) 151 | 152 | 153 | class EuclideanDistance(_PairMerge): 154 | def __init__(self, *args, max_value=None, normalize=False, **kwargs): 155 | super(EuclideanDistance, self).__init__(*args, **kwargs) 156 | self.max_value = max_value 157 | self.normalize = normalize 158 | if self.normalize and not self.max_value: 159 | raise ValueError("max_value must be provided to normalize output") 160 | 161 | def _merge_function(self, inputs): 162 | self._check_merge_inputs(inputs) 163 | distance = norm(inputs[0] - inputs[1]) 164 | if self.max_value: 165 | distance = K.clip(distance, 0, self.max_value) 166 | if self.normalize: 167 | distance /= self.max_value 168 | return distance 169 | 170 | def compute_output_shape(self, input_shape): 171 | shape = list(input_shape) 172 | return (shape[0], 1) 173 | 174 | 175 | class EuclideanSimilarity(EuclideanDistance): 176 | def __init__(self, *args, max_value=None, **kwargs): 177 | super(EuclideanSimilarity, self).__init__( 178 | *args, max_value=max_value, normalize=True, **kwargs) 179 | if not max_value: 180 | raise ValueError("max_value must be provided") 181 | 182 | def _merge_function(self, inputs): 183 | distance = super(EuclideanSimilarity, self)._merge_function(inputs) 184 | return 1 - distance 185 | 186 | 187 | class CosineSimilarity(Dot): 188 | def __init__(self, *args, min_value=-1, **kwargs): 189 | # set default axis to 1 190 | if not args and "axes" not in kwargs: 191 | args = (1,) 192 | super(CosineSimilarity, self).__init__(*args, **kwargs) 193 | self.min_value = min_value 194 | 195 | def call(self, inputs): 196 | dot_product = super(CosineSimilarity, self).call(inputs) 197 
| magnitude = norm(inputs[0]) * norm(inputs[1]) 198 | result = dot_product / magnitude 199 | return K.clip(result, self.min_value, 1) 200 | 201 | def get_config(self): 202 | config = {"min_value": self.min_value} 203 | base_config = super(CosineSimilarity, self).get_config() 204 | return dict(list(base_config.items()) + list(config.items())) 205 | 206 | 207 | def norm(x): 208 | return K.sqrt(K.sum(K.pow(x, 2))) 209 | 210 | 211 | def abs_diff(inputs, **kwargs): 212 | return AbsDiff(**kwargs)(inputs) 213 | 214 | 215 | def euclidean_distance(inputs, **kwargs): 216 | return EuclideanDistance(**kwargs)(inputs) 217 | 218 | 219 | def euclidean_similarity(inputs, **kwargs): 220 | return EuclideanSimilarity(**kwargs)(inputs) 221 | 222 | 223 | def cosine_similarity(inputs, **kwargs): 224 | return CosineSimilarity(**kwargs)(inputs) 225 | 226 | 227 | class DenseMulti(Layer): 228 | def __init__(self, units, 229 | activation=None, 230 | use_bias=True, 231 | kernel_initializer='glorot_uniform', 232 | bias_initializer='zeros', 233 | kernel_regularizer=None, 234 | bias_regularizer=None, 235 | activity_regularizer=None, 236 | kernel_constraint=None, 237 | bias_constraint=None, 238 | **kwargs): 239 | if 'input_shape' not in kwargs and 'input_dim' in kwargs: 240 | kwargs['input_shape'] = (kwargs.pop('input_dim'),) 241 | super(DenseMulti, self).__init__(**kwargs) 242 | self.units = units 243 | self.kernels = [] 244 | self.activation = activations.get(activation) 245 | self.use_bias = use_bias 246 | self.kernel_initializer = initializers.get(kernel_initializer) 247 | self.bias_initializer = initializers.get(bias_initializer) 248 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 249 | self.bias_regularizer = regularizers.get(bias_regularizer) 250 | self.activity_regularizer = regularizers.get(activity_regularizer) 251 | self.kernel_constraint = constraints.get(kernel_constraint) 252 | self.bias_constraint = constraints.get(bias_constraint) 253 | self.input_spec = [] 254 | self.supports_masking = False 255 | self.bias = None 256 | 257 | def build(self, input_shape): 258 | if not isinstance(input_shape, list): 259 | raise ValueError('`DenseMulti` layer should be called ' 260 | 'on a list of inputs') 261 | assert len(input_shape) >= 2 262 | 263 | for i, shape in enumerate(input_shape): 264 | assert len(shape) == 2 265 | assert shape[0] == input_shape[0][0] 266 | 267 | input_dim = shape[-1] 268 | 269 | self.kernels.append(self.add_weight(shape=(input_dim, self.units), 270 | initializer=self.kernel_initializer, 271 | name='kernel-{0}'.format(i), 272 | regularizer=self.kernel_regularizer, 273 | constraint=self.kernel_constraint)) 274 | self.input_spec.append(InputSpec(min_ndim=2, axes={-1: input_dim})) 275 | if self.use_bias: 276 | self.bias = self.add_weight(shape=(self.units,), 277 | initializer=self.bias_initializer, 278 | name='bias', 279 | regularizer=self.bias_regularizer, 280 | constraint=self.bias_constraint) 281 | self.built = True 282 | 283 | def call(self, inputs): 284 | assert len(inputs) == len(self.kernels) 285 | 286 | output = K.dot(inputs[0], self.kernels[0]) 287 | for i in range(1, len(inputs)): 288 | output += K.dot(inputs[i], self.kernels[i]) 289 | if self.use_bias: 290 | output = K.bias_add(output, self.bias) 291 | if self.activation is not None: 292 | output = self.activation(output) 293 | return output 294 | 295 | def compute_output_shape(self, input_shape): 296 | if not isinstance(input_shape, list): 297 | raise ValueError('`DenseMulti` layer should be called ' 298 | 'on a list of 
inputs') 299 | for shape in input_shape: 300 | assert len(shape) == 2 301 | assert shape[0] == input_shape[0][0] 302 | assert input_shape[0][-1] 303 | output_shape = list(input_shape[0]) 304 | output_shape[-1] = self.units 305 | return tuple(output_shape) 306 | 307 | def get_config(self): 308 | config = { 309 | 'units': self.units, 310 | 'activation': activations.serialize(self.activation), 311 | 'use_bias': self.use_bias, 312 | 'kernel_initializer': initializers.serialize(self.kernel_initializer), 313 | 'bias_initializer': initializers.serialize(self.bias_initializer), 314 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 315 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 316 | 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 317 | 'kernel_constraint': constraints.serialize(self.kernel_constraint), 318 | 'bias_constraint': constraints.serialize(self.bias_constraint) 319 | } 320 | base_config = super(DenseMulti, self).get_config() 321 | return dict(list(base_config.items()) + list(config.items())) 322 | 323 | 324 | custom_objects = { 325 | "SplitInput": SplitInput, 326 | "AbsDiff": AbsDiff, 327 | "DenseMulti": DenseMulti, 328 | "ModelWrapper": ModelWrapper, 329 | "CosineSimilarity": CosineSimilarity, 330 | "EuclideanSimilarity": EuclideanSimilarity, 331 | } 332 | -------------------------------------------------------------------------------- /suplearn_clone_detection/model.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional 2 | import numpy as np 3 | 4 | from keras import optimizers 5 | from keras.models import Model, Input 6 | from keras.engine.topology import Layer 7 | from keras.layers import LSTM, Bidirectional, Embedding, concatenate, Dense, multiply 8 | 9 | 10 | from suplearn_clone_detection import ast_transformer 11 | from suplearn_clone_detection.config import LanguageConfig, ModelConfig 12 | from suplearn_clone_detection.layers import SplitInput, abs_diff, DenseMulti, \ 13 | euclidean_similarity, cosine_similarity, ModelWrapper 14 | 15 | 16 | def make_embeddings(lang_config: LanguageConfig, index: int): 17 | embedding_input_size = lang_config.vocabulary_size + lang_config.vocabulary_offset 18 | kwargs = dict( 19 | name="embedding_{0}_{1}".format(lang_config.name, index), 20 | mask_zero=True 21 | ) 22 | if lang_config.embeddings: 23 | weights = np.load(lang_config.embeddings) 24 | padding = np.zeros((lang_config.vocabulary_offset, lang_config.embeddings_dimension)) 25 | kwargs["weights"] = [np.vstack([padding, weights])] 26 | 27 | return Embedding(embedding_input_size, lang_config.embeddings_dimension, **kwargs) 28 | 29 | 30 | def create_lstm(output_dimension: int, 31 | lang_config: LanguageConfig, 32 | transformer: ast_transformer.ASTTransformer, 33 | index: int, 34 | position: int, 35 | return_sequences: bool) -> Layer: 36 | 37 | lstm = LSTM(output_dimension, 38 | return_sequences=return_sequences, 39 | name="lstm_{0}_{1}_{2}".format(lang_config.name, index, position)) 40 | 41 | if transformer.split_input: 42 | lstm = SplitInput( 43 | lstm, name="bidfs_lstm_{0}_{1}_{2}".format(lang_config.name, index, position)) 44 | 45 | if lang_config.bidirectional_encoding: 46 | if transformer.split_input: 47 | raise ValueError("bidirectional_encoding cannot be used with {0}".format( 48 | lang_config.transformer_class_name)) 49 | lstm = Bidirectional( 50 | lstm, name="bilstm_{0}_{1}_{2}".format(lang_config.name, index, position)) 51 | 52 | return lstm 
53 | 54 | 55 | def create_encoder(lang_config: LanguageConfig, index: int): 56 | transformer = ast_transformer.create(lang_config) 57 | ast_input = Input(shape=(None,), 58 | dtype="int32", name="input_{0}_{1}".format(lang_config.name, index)) 59 | 60 | x = make_embeddings(lang_config, index)(ast_input) 61 | 62 | for i, n in enumerate(lang_config.output_dimensions[:-1]): 63 | x = create_lstm(n, lang_config, transformer, 64 | index=index, position=i + 1, return_sequences=True)(x) 65 | 66 | x = create_lstm(lang_config.output_dimensions[-1], lang_config, transformer, 67 | index=index, 68 | position=len(lang_config.output_dimensions), 69 | return_sequences=False)(x) 70 | 71 | output_dimension = lang_config.output_dimensions[-1] 72 | if lang_config.bidirectional_encoding: 73 | output_dimension *= 2 74 | 75 | for i, dim in enumerate(lang_config.hash_dims): 76 | is_last = i == len(lang_config.hash_dims) - 1 77 | activation = None if is_last else "relu" 78 | x = Dense(dim, use_bias=not is_last, activation=activation, 79 | name="dense_{0}_{1}_{2}".format(lang_config.name, index, i))(x) 80 | if is_last: 81 | output_dimension = dim 82 | 83 | encoder = Model(inputs=ast_input, outputs=x, 84 | name="encoder_{0}_{1}".format(lang_config.name, index)) 85 | encoder.output_dimension = output_dimension 86 | return ast_input, encoder 87 | 88 | def make_merge_model(model_config: ModelConfig, input_lang1, input_lang2): 89 | if model_config.merge_mode == "simple": 90 | x = concatenate([input_lang1, input_lang2]) 91 | elif model_config.merge_mode == "bidistance": 92 | hx = multiply([input_lang1, input_lang2]) 93 | hp = abs_diff([input_lang1, input_lang2]) 94 | x = DenseMulti(model_config.merge_output_dim)([hx, hp]) 95 | elif model_config.merge_mode == "euclidean_similarity": 96 | x = euclidean_similarity([input_lang1, input_lang2], 97 | max_value=model_config.normalization_value) 98 | elif model_config.merge_mode == "cosine_similarity": 99 | x = cosine_similarity([input_lang1, input_lang2], min_value=0) 100 | else: 101 | raise ValueError("invalid merge mode") 102 | 103 | if model_config.use_output_nn: 104 | for i, layer_size in enumerate(model_config.dense_layers): 105 | name = "distance_dense_{0}".format(i) 106 | x = Dense(layer_size, activation="relu", name=name)(x) 107 | x = Dense(1, activation="sigmoid", name="main_output")(x) 108 | 109 | model = Model(inputs=[input_lang1, input_lang2], outputs=x, 110 | name="merge_model") 111 | 112 | return model 113 | 114 | 115 | def create_merge_input(lang_config: LanguageConfig, 116 | input_dimension: Tuple[Optional[int]], 117 | index: int): 118 | return Input(shape=(input_dimension,), 119 | name="encoded_{0}_{1}".format(lang_config.name, index)) 120 | 121 | 122 | def create_model(model_config: ModelConfig): 123 | lang1_config, lang2_config = model_config.languages 124 | 125 | input_lang1, encoder_lang1 = create_encoder(lang1_config, 1) 126 | input_lang2, encoder_lang2 = create_encoder(lang2_config, 2) 127 | 128 | lang1_merge_input = create_merge_input(lang1_config, encoder_lang1.output_dimension, 1) 129 | lang2_merge_input = create_merge_input(lang2_config, encoder_lang2.output_dimension, 2) 130 | merge_model = make_merge_model(model_config, lang1_merge_input, lang2_merge_input) 131 | 132 | output_lang1 = encoder_lang1(input_lang1) 133 | output_lang2 = encoder_lang2(input_lang2) 134 | 135 | result = merge_model([output_lang1, output_lang2]) 136 | 137 | model = ModelWrapper(inputs=[input_lang1, input_lang2], outputs=result) 138 | optimizer_class = getattr(optimizers, 
model_config.optimizer["type"]) 139 | optimizer = optimizer_class(**model_config.optimizer.get("options", {})) 140 | model.compile(optimizer=optimizer, loss=model_config.loss, metrics=model_config.metrics) 141 | 142 | return model 143 | -------------------------------------------------------------------------------- /suplearn_clone_detection/predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from suplearn_clone_detection.config import Config 7 | from suplearn_clone_detection.file_processor import FileProcessor 8 | 9 | 10 | class Predictor(FileProcessor): 11 | def __init__(self, config: Config, model: 'keras.models.Model', options: dict = None): 12 | super(Predictor, self).__init__(config, model, options) 13 | self._predictions = [] 14 | 15 | def predict(self, files: List[Tuple[str, str]]) -> List[float]: 16 | batch_size = self.options.get("batch_size") or self.config.trainer.batch_size 17 | predictions = {} 18 | for i in tqdm(range(len(files) // batch_size + 1)): 19 | batch_files = files[i * batch_size:(i + 1) * batch_size] 20 | to_predict, input_data, assumed_false = self._generate_vectors(batch_files) 21 | if not to_predict: 22 | continue 23 | for file_pair in assumed_false: 24 | predictions[file_pair] = 0. 25 | preds = self.model.predict(input_data, batch_size=batch_size) 26 | for file_pair, pred in zip(to_predict, preds): 27 | predictions[file_pair] = float(pred[0]) 28 | return self._save_predictions(files, predictions) 29 | 30 | def _below_size_threshold(self, lang1_ast, lang2_ast): 31 | max_size_diff = self.options.get("max_size_diff") 32 | if not max_size_diff: 33 | return True 34 | size_ratio = len(lang1_ast) / len(lang2_ast) 35 | return abs(1 - size_ratio) <= self.options["max_size_diff"] 36 | 37 | def _generate_vectors(self, files: List[Tuple[str, str]]) \ 38 | -> Tuple[List[Tuple[str, str]], List[np.array], List[Tuple[str, str]]]: 39 | lang1_vectors = [] 40 | lang2_vectors = [] 41 | to_predict = [] 42 | assumed_false = [] 43 | for (lang1_file, lang2_file) in files: 44 | lang1_ast, lang1_vec = self.get_file_vector(lang1_file, self.language_names[0]) 45 | lang2_ast, lang2_vec = self.get_file_vector(lang2_file, self.language_names[1]) 46 | if not lang1_vec or not lang2_vec: 47 | continue 48 | if self._below_size_threshold(lang1_ast, lang2_ast): 49 | lang1_vectors.append(lang1_vec) 50 | lang2_vectors.append(lang2_vec) 51 | to_predict.append((lang1_file, lang2_file)) 52 | else: 53 | assumed_false.append((lang1_file, lang2_file)) 54 | 55 | return to_predict, [np.array(lang1_vectors), np.array(lang2_vectors)], assumed_false 56 | 57 | def _save_predictions(self, files, predictions): 58 | ordered_predictions = [] 59 | for pair in files: 60 | if pair in predictions: 61 | ordered_predictions.append((pair, predictions[pair])) 62 | self._predictions += ordered_predictions 63 | return ordered_predictions 64 | 65 | @property 66 | def formatted_predictions(self): 67 | formatted_predictions = [] 68 | for ((lang1_file, lang2_file), prediction) in self._predictions: 69 | formatted = "{0},{1},{2}".format(lang1_file, lang2_file, prediction) 70 | formatted_predictions.append(formatted) 71 | return "\n".join(formatted_predictions) 72 | -------------------------------------------------------------------------------- /suplearn_clone_detection/results_printer.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 
import matplotlib.pyplot as plt 3 | 4 | 5 | METRIC_NAMES = { 6 | "accuracy": "Accuracy", 7 | "precision": "Precision", 8 | "recall": "Recall", 9 | "avg_precision": "Avg. Precision", 10 | "pr_curve": "Precision-Recall curve", 11 | "f1": "F1 score", 12 | } 13 | 14 | 15 | class ResultsPrinter: 16 | def __init__(self, results_path: str): 17 | with open(results_path) as f: 18 | self.results = yaml.load(f) 19 | 20 | def show(self, metric: str, output: str): 21 | method = getattr(self, "show_{0}".format(metric), None) 22 | if method: 23 | method(output) 24 | elif metric in self.results: 25 | print(self.results[metric]) 26 | else: 27 | raise ValueError("unknown metric {0}".format(metric)) 28 | 29 | def show_pr_curve(self, output: str): 30 | pr_curve = self.results["pr_curve"] 31 | precision, recall = pr_curve["precision"], pr_curve["recall"] 32 | 33 | plt.step(recall, precision, color="b", alpha=0.2, where="post") 34 | plt.fill_between(recall, precision, step="post", alpha=0.2, color="b") 35 | plt.xlabel("Recall") 36 | plt.ylabel("Precision") 37 | plt.ylim([0.0, 1.05]) 38 | plt.xlim([0.0, 1.0]) 39 | plt.title("Precision-Recall curve: AP={0:0.2f}".format(self.results["avg_precision"])) 40 | 41 | if output: 42 | plt.savefig(output) 43 | else: 44 | plt.show() 45 | 46 | def show_summary(self, output: str): 47 | results = self.results.copy() 48 | del results["pr_curve"] 49 | formatted_results = self.format_table(results) 50 | if output: 51 | with open(output, "w") as f: 52 | f.write(formatted_results) 53 | else: 54 | print(formatted_results, end="") 55 | 56 | @staticmethod 57 | def format_table(dict_object): 58 | key_width = max(len(k) for k in dict_object.keys()) + 1 59 | value_width = len("Value") 60 | 61 | separator = "+-{0}-+-{1}-+\n".format("-" * key_width, "-" * value_width) 62 | header = "| {0:>{key_width}} | {1:>{value_width}} |\n".format( 63 | "Metric", "Value", key_width=key_width, value_width=value_width) 64 | value_tpl = "| {0:>{key_width}} | {1:<{value_width}.2f} |\n" 65 | 66 | s = separator + header + separator 67 | for key, value in dict_object.items(): 68 | s += value_tpl.format(METRIC_NAMES[key], value, 69 | key_width=key_width, value_width=value_width) 70 | s += separator 71 | return s 72 | -------------------------------------------------------------------------------- /suplearn_clone_detection/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | TF_GPU_MAX_MEMORY_USAGE = float(os.environ.get("TF_GPU_MAX_MEMORY_USAGE", "0.9")) 5 | -------------------------------------------------------------------------------- /suplearn_clone_detection/token_based/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danhper/suplearn-clone-detection/81c485c1d7ef98c417427a6db5cae7ae2b2541ce/suplearn_clone_detection/token_based/__init__.py -------------------------------------------------------------------------------- /suplearn_clone_detection/token_based/commands.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | from suplearn_clone_detection.token_based import vocabulary_generator 3 | from suplearn_clone_detection.token_based.skipgram_generator import SkipgramGenerator 4 | 5 | 6 | def check_path(filepath: str): 7 | if not path.exists(filepath): 8 | raise ValueError("{0} does not exist".format(filepath)) 9 | 10 | 11 | def create_vocabulary(filepath: str, size: int, include_values: bool, output: str): 
12 | check_path(filepath) 13 | vocab = vocabulary_generator.generate_vocabulary(filepath, size, include_values) 14 | vocab.save(output) 15 | 16 | 17 | def generate_skipgram_data(filepath: str, vocabulary_path: str, window_size: int, output: str): 18 | check_path(filepath) 19 | check_path(vocabulary_path) 20 | skipgram_generator = SkipgramGenerator(filepath, vocabulary_path) 21 | skipgram_generator.generate_skipgram_data(window_size, output) 22 | -------------------------------------------------------------------------------- /suplearn_clone_detection/token_based/skipgram_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import gzip 4 | from typing import Tuple, List 5 | 6 | from suplearn_clone_detection.vocabulary import Vocabulary 7 | from suplearn_clone_detection.token_based import util 8 | 9 | 10 | class SkipgramGenerator: 11 | def __init__(self, tokens_path: str, vocabulary_path: str) -> None: 12 | self.tokens_path = tokens_path 13 | self.vocabulary = Vocabulary.from_file(vocabulary_path) 14 | 15 | def generate_skipgram_data(self, window_size: int, output: str) -> None: 16 | total_lines = util.get_lines_count(self.tokens_path) 17 | logging.info("generating skipgram data from %s files", total_lines) 18 | with util.open_file(self.tokens_path) as token_files, \ 19 | gzip.open(output, "wb") as output_file: 20 | for i, row in enumerate(token_files): 21 | tokens = json.loads(row) 22 | if isinstance(tokens, dict) and "tokens" in tokens: 23 | tokens = tokens["tokens"] 24 | for target, context in self._generate_context_pairs(tokens, window_size): 25 | output_file.write("{0},{1}".format(target, context).encode("utf-8")) 26 | output_file.write(b"\n") 27 | 28 | if i > 0 and i % 1000 == 0: 29 | logging.info("progress: %s/%s", i, total_lines) 30 | 31 | def _generate_context_pairs(self, tokens: List[dict], window_size: int) \ 32 | -> List[Tuple[int, int]]: 33 | for i, target in enumerate(tokens): 34 | target_index = self.vocabulary.index(target) 35 | for j in range(max(0, i - window_size), min(i + window_size + 1, len(tokens))): 36 | if i != j: 37 | context_index = self.vocabulary.index(tokens[j]) 38 | yield (target_index, context_index) 39 | -------------------------------------------------------------------------------- /suplearn_clone_detection/token_based/util.py: -------------------------------------------------------------------------------- 1 | from typing import TextIO 2 | import subprocess 3 | import gzip 4 | 5 | 6 | def get_lines_count(filepath: str) -> int: 7 | if filepath.endswith(".gz"): 8 | command = "cat {0} | gunzip | wc -l".format(filepath) 9 | else: 10 | command = "wc -l {0}".format(filepath) 11 | stdout, _stderr = subprocess.Popen(command, stdout=subprocess.PIPE, 12 | shell=True).communicate() 13 | return int(stdout.decode("utf-8").split(" ")[0]) 14 | 15 | 16 | def open_file(filepath: str) -> TextIO: 17 | if filepath.endswith(".gz"): 18 | return gzip.open(filepath) 19 | else: 20 | return open(filepath) 21 | -------------------------------------------------------------------------------- /suplearn_clone_detection/token_based/vocab_item.py: -------------------------------------------------------------------------------- 1 | from functools import total_ordering 2 | from typing import Optional 3 | 4 | 5 | @total_ordering 6 | class VocabItem: 7 | def __init__(self, token_type: str, token_value: Optional[str]): 8 | self.type = token_type 9 | self.value = token_value 10 | self.count = 0 11 | 12 | def 
__hash__(self): 13 | return hash((self.type, self.value)) 14 | 15 | def __eq__(self, other): 16 | return isinstance(other, VocabItem) and \ 17 | ((self.type, self.value, self.count) == (other.type, other.value, other.count)) 18 | 19 | def __lt__(self, other): 20 | if not isinstance(other, VocabItem): 21 | return NotImplemented 22 | if (self.value is None) == (other.value is None): 23 | return self.count < other.count 24 | return other.value is None 25 | 26 | def make_key(self, include_values: bool): 27 | key = (self.type,) 28 | if include_values: 29 | key += (self.value,) 30 | return key 31 | 32 | def make_token(self, index: int): 33 | meta_type = self.type.split(".")[0] 34 | return dict( 35 | id=index, 36 | type=self.type, 37 | metaType=meta_type, 38 | value=self.value, 39 | count=self.count 40 | ) 41 | -------------------------------------------------------------------------------- /suplearn_clone_detection/token_based/vocabulary_generator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Tuple, Optional, Dict 3 | 4 | import json 5 | from suplearn_clone_detection.vocabulary import Vocabulary 6 | from suplearn_clone_detection.token_based import util 7 | from suplearn_clone_detection.token_based.vocab_item import VocabItem 8 | 9 | 10 | def get_or_add_token(counts: Dict[Tuple[str, Optional[str]], VocabItem], 11 | token_type: str, token_value: Optional[str]) -> VocabItem: 12 | key = (token_type, token_value) 13 | if key in counts: 14 | return counts[key] 15 | item = VocabItem(token_type, token_value) 16 | counts[key] = item 17 | return item 18 | 19 | 20 | def generate_vocabulary(filepath: str, size: int, include_values: bool) -> Vocabulary: 21 | counts = {} 22 | total_lines = util.get_lines_count(filepath) 23 | logging.info("generating vocabulary from %s files", total_lines) 24 | with util.open_file(filepath) as f: 25 | for i, row in enumerate(f): 26 | tokens = json.loads(row) 27 | if isinstance(tokens, dict) and "tokens" in tokens: 28 | tokens = tokens["tokens"] 29 | for token in tokens: 30 | if include_values: 31 | get_or_add_token(counts, token["type"], token.get("value")).count += 1 32 | if not include_values or token.get("value") is not None: 33 | get_or_add_token(counts, token["type"], None).count += 1 34 | if i > 0 and i % 1000 == 0: 35 | logging.info("progress: %s/%s", i, total_lines) 36 | vocab_items = sorted(counts.values(), reverse=True)[:size] 37 | entries = {item.make_key(include_values): item.make_token(i) 38 | for i, item in enumerate(vocab_items)} 39 | return Vocabulary(entries=entries, has_values=include_values) 40 | -------------------------------------------------------------------------------- /suplearn_clone_detection/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | from datetime import datetime 4 | import logging 5 | 6 | import tensorflow as tf 7 | from keras.callbacks import TensorBoard 8 | 9 | import yaml 10 | 11 | from suplearn_clone_detection.config import Config 12 | from suplearn_clone_detection import database, callbacks, util 13 | from suplearn_clone_detection.dataset.sequences import TrainingSequence, DevSequence 14 | from suplearn_clone_detection.model import create_model 15 | 16 | 17 | class Trainer: 18 | def __init__(self, config_path: str, quiet: bool = False): 19 | with open(config_path) as f: 20 | self.raw_config = f.read() 21 | self.config = Config(yaml.load(self.raw_config)) 22 | 
database.bind_db(self.config.generator.db_path) 23 | self.model = None 24 | self.training_data = None 25 | self.dev_data = None 26 | self.quiet = quiet 27 | 28 | def initialize(self): 29 | self.model = create_model(self.config.model) 30 | graph = tf.get_default_graph() 31 | self.training_data = TrainingSequence(self.model, graph, self.config) 32 | self.dev_data = DevSequence(self.config) 33 | os.makedirs(self.output_dir) 34 | with open(path.join(self.output_dir, "config.yml"), "w") as f: 35 | f.write(self.raw_config) 36 | for transformer in self.training_data.ast_transformers.values(): 37 | vocab = transformer.vocabulary 38 | vocab.save(self._vocab_path(transformer.language), 39 | offset=transformer.vocabulary_offset) 40 | 41 | def train(self): 42 | logging.info("starting training, outputing to %s", self.output_dir) 43 | 44 | model_path = path.join(self.output_dir, "model-{epoch:02d}.h5") 45 | results_path = path.join(self.output_dir, "results-dev-{epoch:02d}.yml") 46 | 47 | results_tracker = callbacks.ModelResultsTracker(self.dev_data, self.model) 48 | checkpoint_callback = callbacks.ModelCheckpoint( 49 | results_tracker, model_path, save_best_only=True) 50 | evaluate_callback = callbacks.ModelEvaluator( 51 | results_tracker, results_path, quiet=self.quiet, save_best_only=True) 52 | model_callbacks = [checkpoint_callback, evaluate_callback] 53 | 54 | if self.config.trainer.tensorboard_logs: 55 | tensorboard_logs_path = path.join(self.output_dir, "tf-logs") 56 | metadata = {} 57 | for lang_config in self.config.model.languages: 58 | vocab_path = path.relpath(self._vocab_path(lang_config.name), 59 | tensorboard_logs_path) 60 | metadata["embedding_{0}".format(lang_config.name)] = vocab_path 61 | model_callbacks.append(TensorBoard(tensorboard_logs_path)) 62 | # TODO: restore embeddings 63 | # embeddings_freq=1, 64 | # embeddings_metadata=metadata)) 65 | 66 | self.model.fit_generator( 67 | self.training_data, 68 | validation_data=self.dev_data, 69 | epochs=self.config.trainer.epochs, 70 | callbacks=model_callbacks) 71 | 72 | def _vocab_path(self, lang): 73 | return path.join(self.output_dir, "vocab-{0}.tsv".format(lang)) 74 | 75 | @property 76 | @util.memoize 77 | def output_dir(self): 78 | output_dir = datetime.now().strftime("%Y%m%d-%H%M") 79 | return path.join(self.config.trainer.output_dir, output_dir) 80 | -------------------------------------------------------------------------------- /suplearn_clone_detection/util.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import functools 3 | 4 | import h5py 5 | 6 | 7 | def filename_language(filename, available_languages): 8 | _basename, ext = path.splitext(filename) 9 | for known_lang in available_languages: 10 | if known_lang.startswith(ext[1:]): 11 | return known_lang 12 | raise ValueError("no language found for {0}".format(filename)) 13 | 14 | 15 | def in_batch(iterable, batch_size): 16 | batch = [] 17 | for value in iterable: 18 | batch.append(value) 19 | if len(batch) == batch_size: 20 | yield batch 21 | batch = [] 22 | if batch: 23 | yield batch 24 | 25 | 26 | def group_by(iterable, key): 27 | iterator = iter(iterable) 28 | grouped = {} 29 | while True: 30 | try: 31 | value = next(iterator) 32 | value_key = key(value) 33 | grouped.setdefault(value_key, []) 34 | grouped[value_key].append(value) 35 | except StopIteration: 36 | break 37 | return grouped 38 | 39 | 40 | def memoize(make_key): 41 | def wrapped(f): 42 | memoized = {} 43 | 44 | def wrapper(*args, **kwargs): 45 | key = 
make_key(*args, **kwargs) 46 |             if key not in memoized: 47 |                 memoized[key] = f(*args, **kwargs) 48 |             return memoized[key] 49 | 50 |         return functools.update_wrapper(wrapper, f) 51 | 52 |     if not callable(make_key) or not hasattr(make_key, "__name__"): 53 |         raise ValueError("@memoize argument should be a lambda") 54 | 55 |     if make_key.__name__ == "<lambda>": 56 |         return wrapped 57 |     else: 58 |         f = make_key 59 |         make_key = lambda *args, **kwargs: (tuple(args), (tuple(kwargs.items()))) 60 |         return wrapped(f) 61 | 62 | 63 | def hdf5_keys(dataset: h5py.File): 64 |     keys = [] 65 |     def visitor(key, value): 66 |         if isinstance(value, h5py.Dataset): 67 |             keys.append(key) 68 |     dataset.visititems(visitor) 69 |     return keys 70 | 71 | def hdf5_key_pairs(dataset: h5py.File, lang_left: str, lang_right: str): 72 |     keys = hdf5_keys(dataset) 73 |     def has_correct_ext(filename, expected): 74 |         ext = path.splitext(filename)[1][1:] 75 |         return expected.startswith(ext) 76 |     for i, left in enumerate(keys): 77 |         for right in keys[i + 1:]: 78 |             if has_correct_ext(left, lang_left) and has_correct_ext(right, lang_right): 79 |                 yield (left, right) 80 | -------------------------------------------------------------------------------- /suplearn_clone_detection/vectorizer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import keras.backend as K 4 | from keras.preprocessing.sequence import pad_sequences 5 | import tensorflow as tf 6 | import h5py 7 | from tqdm import tqdm 8 | 9 | from suplearn_clone_detection.file_processor import FileProcessor 10 | from suplearn_clone_detection.config import Config 11 | from suplearn_clone_detection.layers import ModelWrapper 12 | from suplearn_clone_detection import util 13 | 14 | 15 | class Vectorizer(FileProcessor): 16 |     def __init__(self, config: Config, model: ModelWrapper, options: dict = None): 17 |         super(Vectorizer, self).__init__(config, model, options) 18 |         self.encoders = {} 19 |         for i, lang in enumerate(config.model.languages): 20 |             encoder_index = self.options.get("encoder_index") 21 |             if lang.name not in self.encoders and \ 22 |                     (encoder_index is None or encoder_index == i): 23 |                 self.encoders[lang.name] = model.inner_models[i] 24 | 25 |     def vectorize(self, filenames: List[str], language: str, sess: tf.Session): 26 |         input_filenames = [] 27 |         input_vectors = [] 28 |         for filename in filenames: 29 |             _ast, input_vector = self.get_file_vector(filename, language) 30 |             if input_vector: 31 |                 input_filenames.append(filename) 32 |                 input_vectors.append(input_vector) 33 |         model_input = tf.constant(pad_sequences(input_vectors)) 34 |         vectors = sess.run(self.encoders[language](model_input)) 35 |         return zip(input_filenames, vectors) 36 | 37 |     def process(self, input_filenames: List[str], output: str): 38 |         by_lang = self._group_filenames(input_filenames) 39 |         batch_size = self.options.get("batch_size") or self.config.trainer.batch_size 40 |         sess = K.get_session() 41 |         with h5py.File(output, "w") as f, \ 42 |                 tqdm(total=len(input_filenames)) as pbar: 43 |             for lang, lang_filenames in by_lang.items(): 44 |                 for filenames in util.in_batch(lang_filenames, batch_size): 45 |                     for filename, vector in self.vectorize(filenames, lang, sess): 46 |                         f.create_dataset(filename, data=vector) 47 |                     pbar.update(len(filenames)) 48 | 49 |     def _group_filenames(self, filenames): 50 |         key = lambda filename: util.filename_language(filename, self.encoders) 51 |         return util.group_by(filenames, key) 52 |
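A short usage sketch (not repository code) of the memoize helper in util.py above, showing both the bare form and the keyed form; the decorated functions below are made-up examples.

from suplearn_clone_detection import util

@util.memoize
def fibonacci(n):
    # cached on the full argument tuple
    return n if n < 2 else fibonacci(n - 1) + fibonacci(n - 2)

@util.memoize(lambda filename, **_kwargs: filename)
def count_lines(filename, encoding="utf-8"):
    # cached on the filename only, ignoring keyword arguments
    with open(filename, encoding=encoding) as f:
        return sum(1 for _ in f)

The keyed form is recognized by the key function's __name__, which is why the decorator checks for "<lambda>" before deciding whether its argument is the key function or the function being wrapped.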
-------------------------------------------------------------------------------- /suplearn_clone_detection/vocabulary.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | 5 | BASE_HEADERS = ["id", "type", "metaType", "count"] 6 | 7 | 8 | class Vocabulary: 9 | @classmethod 10 | def from_file(cls, filepath, fallback_empty_value=True): 11 | entries, has_values = cls._parse_file(filepath) 12 | vocab = Vocabulary( 13 | entries=entries, 14 | has_values=has_values, 15 | fallback_empty_value=fallback_empty_value) 16 | return vocab 17 | 18 | def __init__(self, entries=None, has_values=True, fallback_empty_value=True): 19 | self.headers = BASE_HEADERS.copy() 20 | if has_values: 21 | self.headers.append("value") 22 | self.has_values = has_values 23 | self.fallback_empty_value = fallback_empty_value 24 | if not entries: 25 | entries = {} 26 | self.entries = entries 27 | 28 | def __eq__(self, other): 29 | if not isinstance(other, Vocabulary): 30 | return False 31 | attrs = ["headers", "entries", "has_values", "fallback_empty_value"] 32 | return all(getattr(self, attr) == getattr(other, attr) for attr in attrs) 33 | 34 | @staticmethod 35 | def _parse_file(filepath): 36 | entries = {} 37 | with open(filepath, "r", newline="") as f: 38 | headers = next(f).strip().split("\t") 39 | assert BASE_HEADERS == headers[:len(BASE_HEADERS)] 40 | has_values = "value" in headers 41 | entries = {} 42 | for row in f: 43 | row = row.strip().split("\t") 44 | entry = dict(id=int(row[0]), type=row[1], metaType=row[2], count=int(row[3])) 45 | if has_values: 46 | entry["value"] = json.loads(row[4]) if len(row) > 4 and row[4] else None 47 | key = (entry["type"],) 48 | if has_values: 49 | key += (entry["value"],) 50 | entry["id"] = int(entry["id"]) 51 | entries[key] = entry 52 | return entries, has_values 53 | 54 | def __len__(self): 55 | return len(self.entries) 56 | 57 | def __getitem__(self, token): 58 | if isinstance(token, tuple): 59 | return self.entries[token] 60 | 61 | if not self.has_values: 62 | return self.entries[(token["type"],)] 63 | 64 | key = (token["type"], token.get("value")) 65 | result = self.entries.get(key) 66 | if result is not None: 67 | return result 68 | elif self.fallback_empty_value: 69 | return self.entries[(token["type"], None)] 70 | else: 71 | raise KeyError(key) 72 | 73 | def index(self, token): 74 | return np.int32(self[token]["id"]) 75 | 76 | def save(self, path, offset=0): 77 | with open(path, "w") as f: 78 | print("\t".join(self.headers), file=f) 79 | for i in range(offset): 80 | row = [str(i), "offset-{0}".format(i), "padding", "0"] 81 | if self.has_values: 82 | row.append('"padding"') 83 | print("\t".join(row), file=f) 84 | for entry in sorted(self.entries.values(), key=lambda x: x["id"]): 85 | row = [str(entry["id"] + offset), entry["type"], 86 | entry.get("metaType", "Other"), str(entry["count"])] 87 | if self.has_values: 88 | row.append(json.dumps(entry["value"]) if entry["value"] else "") 89 | print("\t".join(row), file=f) 90 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danhper/suplearn-clone-detection/81c485c1d7ef98c417427a6db5cae7ae2b2541ce/tests/__init__.py -------------------------------------------------------------------------------- /tests/ast_loader_test.py: 
-------------------------------------------------------------------------------- 1 | from tests.base import TestCase 2 | 3 | from suplearn_clone_detection.ast_loader import ASTLoader 4 | 5 | class ASTLoaderTest(TestCase): 6 | @classmethod 7 | def setUpClass(cls): 8 | cls.loader = ASTLoader(cls.fixture_path("asts.json")) 9 | 10 | def test_load_names(self): 11 | self.assertEqual(len(self.loader.names), 5) 12 | 13 | def test_load_asts(self): 14 | self.assertEqual(len(self.loader.asts), 5) 15 | 16 | def test_random_ast(self): 17 | name, ast = self.loader.random_ast() 18 | self.assertIsInstance(name, str) 19 | self.assertIsInstance(ast, list) 20 | name, _ast = self.loader.random_ast(lambda name, _: name.endswith(".py")) 21 | self.assertTrue(name.endswith(".py")) 22 | name, _ast = self.loader.random_ast(lambda name, _: name.endswith(".java")) 23 | self.assertTrue(name.endswith(".java")) 24 | -------------------------------------------------------------------------------- /tests/ast_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from tests.base import TestCase 4 | 5 | from suplearn_clone_detection import ast 6 | 7 | 8 | class AstTest(TestCase): 9 | ast_nodes = [{"type": "root", "children": [1, 2]}, 10 | {"type": "child1", "children": [3]}, 11 | {"type": "child2"}, 12 | {"type": "grand-child1"}] 13 | 14 | @classmethod 15 | def setUpClass(cls): 16 | with open(cls.fixture_path("asts.json")) as f: 17 | cls.asts = [json.loads(v) for v in f if v] 18 | 19 | def test_from_list_without_value(self): 20 | list_ast = [{"type": "foo"}] 21 | root = ast.from_list(list_ast) 22 | self.assertEqual(root.type, "foo") 23 | self.assertIsNone(root.value) 24 | self.assertEqual(len(root.children), 0) 25 | 26 | def test_from_list_with_value(self): 27 | list_ast = [{"type": "foo", "value": "bar"}] 28 | root = ast.from_list(list_ast) 29 | self.assertEqual(root.type, "foo") 30 | self.assertEqual(root.value, "bar") 31 | self.assertEqual(len(root.children), 0) 32 | 33 | def test_from_list_recursive(self): 34 | list_ast = [{"type": "foo", "value": "bar", "children": [1]}, 35 | {"type": "baz"}] 36 | root = ast.from_list(list_ast) 37 | self.assertEqual(root.type, "foo") 38 | self.assertEqual(root.value, "bar") 39 | self.assertEqual(len(root.children), 1) 40 | child = root.children[0] 41 | self.assertEqual(child.type, "baz") 42 | 43 | def test_from_list_complex(self): 44 | list_ast = self.asts[0] 45 | root = ast.from_list(list_ast) 46 | self.assertEqual(root.type, "CompilationUnit") 47 | 48 | def test_bfs(self): 49 | root = ast.from_list(self.ast_nodes) 50 | bfs_types = [node["type"] for node in root.bfs()] 51 | expected = ["root", "child1", "child2", "grand-child1"] 52 | self.assertEqual(expected, bfs_types) 53 | 54 | def test_dfs(self): 55 | root = ast.from_list(self.ast_nodes) 56 | dfs_types = [node["type"] for node in root.dfs()] 57 | expected = ["root", "child1", "grand-child1", "child2"] 58 | self.assertEqual(expected, dfs_types) 59 | 60 | def test_dfs_reverse(self): 61 | root = ast.from_list(self.ast_nodes) 62 | dfs_types = [node["type"] for node in root.dfs(reverse=True)] 63 | expected = ["root", "child2", "child1", "grand-child1"] 64 | self.assertEqual(expected, dfs_types) 65 | 66 | def _load_list_ast(self): 67 | with open(self.fixture_path("asts.json")) as f: 68 | return json.loads(next(f)) 69 | -------------------------------------------------------------------------------- /tests/ast_transformer_test.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | from suplearn_clone_detection.vocabulary import Vocabulary 5 | from suplearn_clone_detection.ast_transformer import DFSTransformer, BiDFSTransformer 6 | from tests.base import TestCase 7 | 8 | 9 | class TransformerTestCase(TestCase): 10 | @classmethod 11 | def setUpClass(cls): 12 | cls.vocabulary = Vocabulary.from_file(cls.fixture_path("vocab-noid.tsv")) 13 | with open(cls.fixture_path("asts.json"), "r") as f: 14 | cls.asts = [json.loads(v) for v in f.read().split("\n") if v] 15 | 16 | 17 | class DFSTransformerTest(TransformerTestCase): 18 | def test_split_input(self): 19 | self.assertFalse(DFSTransformer("lang", self.vocabulary).split_input) 20 | 21 | def test_input_length(self): 22 | transformer = DFSTransformer("lang", self.vocabulary, input_length=210) 23 | self.assertEqual(transformer.total_input_length, 210) 24 | 25 | def test_simple_transform(self): 26 | transformer = DFSTransformer("lang", self.vocabulary) 27 | result = transformer.transform_ast(self.asts[0]) 28 | self.assertEqual(len(self.asts[0]), len(result)) 29 | for index in result: 30 | self.assertIsInstance(index, np.int32) 31 | 32 | def test_transform_with_padding(self): 33 | transformer = DFSTransformer("lang", self.vocabulary, input_length=210) 34 | padded_result = transformer.transform_ast(self.asts[0]) # length: 206 35 | self.assertEqual(len(padded_result), 210) 36 | too_long_result = transformer.transform_ast(self.asts[2]) # length: 215 37 | self.assertFalse(too_long_result) 38 | 39 | def test_transform_with_base_index(self): 40 | transformer = DFSTransformer("lang", self.vocabulary) 41 | transformer_with_base_index = DFSTransformer("lang", self.vocabulary, vocabulary_offset=2) 42 | normal_result = transformer.transform_ast(self.asts[0]) 43 | changed_result = transformer_with_base_index.transform_ast(self.asts[0]) 44 | self.assertEqual([v + np.int32(2) for v in normal_result], changed_result) 45 | 46 | 47 | class BiDFSTransformerTest(TransformerTestCase): 48 | def test_split_input(self): 49 | self.assertTrue(BiDFSTransformer("lang", self.vocabulary).split_input) 50 | 51 | def test_input_length(self): 52 | transformer = BiDFSTransformer("lang", self.vocabulary, input_length=210) 53 | self.assertEqual(transformer.total_input_length, 420) 54 | 55 | def test_simple_transform(self): 56 | transformer = BiDFSTransformer("lang", self.vocabulary) 57 | result = transformer.transform_ast(self.asts[0]) 58 | self.assertEqual(len(self.asts[0]) * 2, len(result)) 59 | for index in result: 60 | self.assertIsInstance(index, np.int32) 61 | 62 | def test_transform_with_padding(self): 63 | transformer = BiDFSTransformer( 64 | "lang", self.vocabulary, input_length=210, vocabulary_offset=1) 65 | padded_result = transformer.transform_ast(self.asts[0]) # length: 206 66 | self.assertEqual(len(padded_result), 420) 67 | self.assertEqual(padded_result[208], 0) 68 | self.assertEqual(padded_result[416], 0) 69 | self.assertNotEqual(padded_result[210], 0) 70 | too_long_result = transformer.transform_ast(self.asts[2]) # length: 215 71 | self.assertFalse(too_long_result) 72 | -------------------------------------------------------------------------------- /tests/base.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase as BaseTestCase 2 | 3 | from os import path 4 | 5 | 6 | FIXTURES_PATH = path.join(path.dirname(__file__), "fixtures") 7 | 8 | 9 | class 
TestCase(BaseTestCase): 10 | @staticmethod 11 | def fixture_path(filename): 12 | return path.join(FIXTURES_PATH, filename) 13 | -------------------------------------------------------------------------------- /tests/config_test.py: -------------------------------------------------------------------------------- 1 | from tests.base import TestCase 2 | 3 | from suplearn_clone_detection.config import Config 4 | 5 | class ConfigTest(TestCase): 6 | @classmethod 7 | def load_config(cls): 8 | return Config.from_file(cls.fixture_path("config.yml")) 9 | 10 | @classmethod 11 | def setUpClass(cls): 12 | cls.config = cls.load_config() 13 | 14 | def test_config(self): 15 | config = self.config.model 16 | self.assertEqual(config.dense_layers, [64, 64]) 17 | 18 | def test_hash(self): 19 | self.assertEqual(self.config.data_generation_checksum(), self.load_config().data_generation_checksum()) 20 | -------------------------------------------------------------------------------- /tests/fixtures/asts.txt: -------------------------------------------------------------------------------- 1 | src/1/1/1454244.java 2 | src/1/1/1260853.py 3 | src/1/1/1400209.java 4 | src/1/0/1186318.py 5 | src/5/0/1384364.java 6 | -------------------------------------------------------------------------------- /tests/fixtures/config.yml: -------------------------------------------------------------------------------- 1 | references: 2 | language_defaults: &language_defaults 3 | embeddings_dimension: 100 4 | output_dimensions: [100] 5 | transformer_class_name: DFSTransformer 6 | vocabulary_offset: 1 7 | 8 | model: 9 | languages: 10 | - name: java 11 | <<: *language_defaults 12 | vocabulary: $HOME/workspaces/research/results/java/data/no-id.tsv 13 | input_length: 4 14 | - name: python 15 | <<: *language_defaults 16 | vocabulary: $HOME/workspaces/research/results/python/vocabulary/no-id.tsv 17 | input_length: 4 18 | 19 | dense_layers: [64, 64] 20 | optimizer: 21 | type: rmsprop 22 | 23 | generator: 24 | submissions_path: $HOME/workspaces/research/dataset/atcoder/submissions.json 25 | asts_path: $HOME/workspaces/research/dataset/atcoder/asts/asts.json 26 | use_all_combinations: false 27 | negative_samples: 1 28 | 29 | trainer: 30 | batch_size: 128 31 | epochs: 10 32 | -------------------------------------------------------------------------------- /tests/fixtures/submissions.json: -------------------------------------------------------------------------------- 1 | [{"contest_type": "r", "contest_id": 1, "problem_id": 1, "language": "java_1.7.0", "created_time": "2017/07/26 06:59:53 +0000", "problem_title": "B - \u30ea\u30e2\u30b3\u30f3", "score": 100, "username": "ssskurosuke", "status": "AC", "source_length": 752, "source_length_unit": "Byte", "exec_time": 89, "exec_time_unit": "ms", "submission_url": "http://arc001.contest.atcoder.jp/submissions/1454244", "id": "1454244", "file": "src/1/1/1454244.java"}, {"contest_type": "r", "contest_id": 1, "problem_id": 1, "language": "python3_3.4.3", "created_time": "2017/05/04 18:30:07 +0000", "problem_title": "B - \u30ea\u30e2\u30b3\u30f3", "score": 100, "username": "kakkey", "status": "AC", "source_length": 231, "source_length_unit": "Byte", "exec_time": 17, "exec_time_unit": "ms", "submission_url": "http://arc001.contest.atcoder.jp/submissions/1260853", "id": "1260853", "file": "src/1/1/1260853.py"}, {"contest_type": "r", "contest_id": 1, "problem_id": 1, "language": "python2_2.7.6", "created_time": "2017/08/31 07:10:06 +0000", "problem_title": "B - \u30ea\u30e2\u30b3\u30f3", "score": 100, 
"username": "prd_xxx", "status": "AC", "source_length": 98, "source_length_unit": "Byte", "exec_time": 12, "exec_time_unit": "ms", "submission_url": "http://arc001.contest.atcoder.jp/submissions/1554130", "id": "1554130", "file": "src/1/1/1554130.py"}, {"contest_type": "r", "contest_id": 1, "problem_id": 1, "language": "java_1.8.0", "created_time": "2017/07/03 05:03:15 +0000", "problem_title": "B - \u30ea\u30e2\u30b3\u30f3", "score": 100, "username": "mucarthur", "status": "AC", "source_length": 869, "source_length_unit": "Byte", "exec_time": 104, "exec_time_unit": "ms", "submission_url": "http://arc001.contest.atcoder.jp/submissions/1400209", "id": "1400209", "file": "src/1/1/1400209.java"}, {"contest_type": "r", "contest_id": 1, "problem_id": 0, "language": "python3_3.4.3", "created_time": "2017/03/27 08:46:00 +0000", "problem_title": "A - \u30bb\u30f3\u30bf\u30fc\u63a1\u70b9", "score": 100, "username": "maji_ji", "status": "AC", "source_length": 203, "source_length_unit": "Byte", "exec_time": 266, "exec_time_unit": "ms", "submission_url": "http://arc001.contest.atcoder.jp/submissions/1186318", "id": "1186318", "file": "src/1/0/1186318.py"}, {"contest_type": "b", "contest_id": 5, "problem_id": 0, "language": "java_1.8.0", "created_time": "2017/06/27 04:03:28 +0000", "problem_title": "A - \u5927\u597d\u304d\u9ad8\u6a4b\u541b", "score": 100, "username": "FVRChan", "status": "AC", "source_length": 489, "source_length_unit": "Byte", "exec_time": 97, "exec_time_unit": "ms", "submission_url": "http://arc005.contest.atcoder.jp/submissions/1384364", "id": "1384364", "file": "src/5/0/1384364.java"}, {"contest_type": "r", "contest_id": 15, "problem_id": 0, "language": "python3_3.4.3", "created_time": "2017/05/30 01:53:52 +0000", "problem_title": "A - Celsius \u3068 Fahrenheit", "score": 100, "username": "ssk0907", "status": "AC", "source_length": 43, "source_length_unit": "Byte", "exec_time": 18, "exec_time_unit": "ms", "submission_url": "http://arc015.contest.atcoder.jp/submissions/1317675", "id": "1317675", "file": "src/15/0/1317675.py"}] -------------------------------------------------------------------------------- /tests/fixtures/vocab-100.tsv: -------------------------------------------------------------------------------- 1 | id type metaType count value 2 | 0 SimpleName Other 1563985 3 | 1 NameExpr Expr 1215218 4 | 2 MethodCallExpr Expr 706234 5 | 3 ExpressionStmt Stmt 547284 6 | 4 ClassOrInterfaceType Type 528494 7 | 5 BlockStmt Stmt 293102 8 | 6 VariableDeclarator Other 230872 9 | 7 VariableDeclarationExpr Expr 183318 10 | 8 IntegerLiteralExpr Expr 161283 11 | 9 Parameter Other 147858 12 | 10 MethodDeclaration Declaration 142518 13 | 11 StringLiteralExpr Expr 122115 14 | 12 ObjectCreationExpr Expr 115276 15 | 13 FieldAccessExpr Expr 111752 16 | 14 BinaryExpr Expr 108859 17 | 15 PrimitiveType Type 93223 18 | 16 ReturnStmt Stmt 88510 19 | 17 VoidType Type 69288 20 | 18 IfStmt Stmt 63610 21 | 19 NullLiteralExpr Expr 58262 22 | 20 ArrayAccessExpr Expr 53728 23 | 21 IntegerLiteralExpr Expr 49226 "0" 24 | 22 AssignExpr Expr 48675 25 | 23 ArrayType Type 46292 26 | 24 FieldDeclaration Declaration 46220 27 | 25 AssignExpr Expr 44845 "=" 28 | 26 DoubleLiteralExpr Expr 43479 29 | 27 UnaryExpr Expr 42427 30 | 28 Name Other 39293 31 | 29 ArrayInitializerExpr Expr 37296 32 | 30 PrimitiveType Type 36777 "int" 33 | 31 SimpleName Other 36589 "assertEquals" 34 | 32 BinaryExpr Expr 32809 "+" 35 | 33 ThisExpr Expr 31920 36 | 34 CastExpr Expr 31502 37 | 35 SimpleName Other 31343 "String" 38 | 36 SimpleName 
Other 30718 "i" 39 | 37 IntegerLiteralExpr Expr 28650 "1" 40 | 38 EnclosedExpr Expr 27446 41 | 39 ArrayCreationLevel Other 27124 42 | 40 ArrayCreationExpr Expr 25010 43 | 41 PrimitiveType Type 23158 "double" 44 | 42 UnaryExpr Expr 21299 "-" 45 | 43 BooleanLiteralExpr Expr 20445 46 | 44 TryStmt Stmt 19628 47 | 45 CatchClause Other 18518 48 | 46 ClassOrInterfaceDeclaration Declaration 18386 49 | 47 ForStmt Stmt 17496 50 | 48 ThrowStmt Stmt 15518 51 | 49 ConstructorDeclaration Declaration 15088 52 | 50 CompilationUnit Other 14682 53 | 51 PackageDeclaration Declaration 14668 54 | 52 BinaryExpr Expr 14044 "==" 55 | 53 SimpleName Other 13422 "Assert" 56 | 54 IntegerLiteralExpr Expr 13183 "2" 57 | 55 SimpleName Other 13051 "length" 58 | 56 ClassExpr Expr 12874 59 | 57 SimpleName Other 12349 "assertTrue" 60 | 58 PrimitiveType Type 11562 "boolean" 61 | 59 SimpleName Other 11503 "value" 62 | 60 BooleanLiteralExpr Expr 11040 "false" 63 | 61 BinaryExpr Expr 11037 "<" 64 | 62 PrimitiveType Type 10750 "byte" 65 | 63 SimpleName Other 10593 "Object" 66 | 64 SimpleName Other 9862 "e" 67 | 65 UnaryExpr Expr 9722 "++" 68 | 66 BinaryExpr Expr 9668 "*" 69 | 67 BooleanLiteralExpr Expr 9405 "true" 70 | 68 ExplicitConstructorInvocationStmt Stmt 8928 71 | 69 SimpleName Other 8571 "Exception" 72 | 70 BinaryExpr Expr 8292 "!=" 73 | 71 BinaryExpr Expr 7905 "-" 74 | 72 TypeParameter Other 7858 75 | 73 SimpleName Other 7781 "result" 76 | 74 SimpleName Other 7690 "x" 77 | 75 IntegerLiteralExpr Expr 7452 "3" 78 | 76 SimpleName Other 7410 "add" 79 | 77 Name Other 7330 "commons" 80 | 78 Name Other 7330 "org" 81 | 79 Name Other 7330 "apache" 82 | 80 SimpleName Other 7131 "key" 83 | 81 SimpleName Other 7068 "IOException" 84 | 82 SimpleName Other 6693 "get" 85 | 83 CharLiteralExpr Expr 6579 86 | 84 SimpleName Other 6128 "name" 87 | 85 ForeachStmt Stmt 6124 88 | 86 SimpleName Other 6091 "size" 89 | 87 PrimitiveType Type 6058 "long" 90 | 88 SimpleName Other 6038 "append" 91 | 89 SimpleName Other 5688 "j" 92 | 90 UnaryExpr Expr 5672 "+" 93 | 91 ConditionalExpr Expr 5666 94 | 92 InstanceOfExpr Expr 5666 95 | 93 SimpleName Other 5473 "Integer" 96 | 94 SimpleName Other 5396 "fail" 97 | 95 SimpleName Other 5365 "toString" 98 | 96 SimpleName Other 5080 "T" 99 | 97 SimpleName Other 4992 "equals" 100 | 98 UnaryExpr Expr 4991 "!" 
101 | 99 IntegerLiteralExpr Expr 4988 "4"
102 | 
--------------------------------------------------------------------------------
/tests/fixtures/vocab-noid.tsv:
--------------------------------------------------------------------------------
1 | id type metaType count
2 | 0 SimpleName Other 1510104
3 | 1 NameExpr Expr 584187
4 | 2 MethodCallExpr Expr 334222
5 | 3 ClassOrInterfaceType Type 273741
6 | 4 ExpressionStmt Stmt 259191
7 | 5 Name Other 208614
8 | 6 IntegerLiteralExpr Expr 169151
9 | 7 BlockStmt Stmt 141148
10 | 8 StringLiteralExpr Expr 108152
11 | 9 VariableDeclarator Other 106969
12 | 10 BinaryExpr Expr 103880
13 | 11 PrimitiveType Type 95673
14 | 12 VariableDeclarationExpr Expr 85275
15 | 13 Parameter Other 69993
16 | 14 MethodDeclaration Declaration 69176
17 | 15 DoubleLiteralExpr Expr 53319
18 | 16 ObjectCreationExpr Expr 51755
19 | 17 FieldAccessExpr Expr 51501
20 | 18 AssignExpr Expr 48869
21 | 19 UnaryExpr Expr 47385
22 | 20 ReturnStmt Stmt 44433
23 | 21 ImportDeclaration Declaration 39289
24 | 22 VoidType Type 33183
25 | 23 IfStmt Stmt 31088
26 | 24 ArrayAccessExpr Expr 29595
27 | 25 NullLiteralExpr Expr 26203
28 | 26 ArrayType Type 24183
29 | 27 ArrayInitializerExpr Expr 21413
30 | 28 FieldDeclaration Declaration 21012
31 | 29 BooleanLiteralExpr Expr 19696
32 | 30 CastExpr Expr 16471
33 | 31 ArrayCreationLevel Other 14963
34 | 32 ThisExpr Expr 14532
35 | 33 EnclosedExpr Expr 14207
36 | 34 ArrayCreationExpr Expr 13718
37 | 35 ForStmt Stmt 9386
38 | 36 CatchClause Other 8967
39 | 37 TryStmt Stmt 8748
40 | 38 ClassOrInterfaceDeclaration Declaration 8391
41 | 39 ConstructorDeclaration Declaration 7792
42 | 40 ThrowStmt Stmt 7533
43 | 41 CompilationUnit Other 6677
44 | 42 PackageDeclaration Declaration 6650
45 | 43 CharLiteralExpr Expr 5921
46 | 44 ClassExpr Expr 5853
47 | 45 WildcardType Type 5587
48 | 46 LongLiteralExpr Expr 4894
49 | 47 ExplicitConstructorInvocationStmt Stmt 4780
50 | 48 SwitchEntryStmt Stmt 3549
51 | 49 InstanceOfExpr Expr 3230
52 | 50 ConditionalExpr Expr 2679
53 | 51 TypeParameter Other 2606
54 | 52 ForeachStmt Stmt 2413
55 | 53 SuperExpr Expr 1919
56 | 54 WhileStmt Stmt 1881
57 | 55 BreakStmt Stmt 1874
58 | 56 EnumConstantDeclaration Declaration 1121
59 | 57 SwitchStmt Stmt 443
60 | 58 SynchronizedStmt Stmt 404
61 | 59 ContinueStmt Stmt 362
62 | 60 DoStmt Stmt 141
63 | 61 InitializerDeclaration Declaration 138
64 | 62 EnumDeclaration Declaration 119
65 | 63 AnnotationMemberDeclaration Declaration 100
66 | 64 EmptyStmt Stmt 69
67 | 65 LabeledStmt Stmt 50
68 | 66 AssertStmt Stmt 39
69 | 67 AnnotationDeclaration Declaration 34
70 | 68 LocalClassDeclarationStmt Stmt 27
71 | 69 UnionType Type 4
72 | 70 LambdaExpr Expr 1
73 | 
--------------------------------------------------------------------------------
/tests/util_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from suplearn_clone_detection import util
4 | 
5 | 
6 | class UtilTest(unittest.TestCase):
7 |     def test_in_batch(self):
8 |         sample_list = [8, 4, 2, 1, 9, 2, 3, 5]  # 8 elements
9 |         batches = list(util.in_batch(sample_list, batch_size=2))
10 |         self.assertEqual(len(batches), 4)
11 |         self.assertListEqual(batches[0], [8, 4])
12 | 
13 |         # partial batch
14 |         sample_list.append(9)
15 |         batches = list(util.in_batch(sample_list, batch_size=2))
16 |         self.assertEqual(len(batches), 5)
17 |         self.assertListEqual(batches[-1], [9])
18 | 
19 |     def test_group_by(self):
20 |         sample_list = [(0, 1), (0, 2), (1, 2), (1, 8), (1, 4)]
21 |         grouped = util.group_by(sample_list, key=lambda x: x[0])
22 |         self.assertListEqual(list(grouped.keys()), [0, 1])
23 |         self.assertListEqual(grouped[0], [(0, v) for v in [1, 2]])
24 |         self.assertListEqual(grouped[1], [(1, v) for v in [2, 8, 4]])
25 | 
--------------------------------------------------------------------------------
/tests/vocabulary_test.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | from tests.base import TestCase
3 | 
4 | from suplearn_clone_detection.vocabulary import Vocabulary
5 | 
6 | 
7 | 
8 | class VocabularyTest(TestCase):
9 |     @classmethod
10 |     def setUpClass(cls):
11 |         cls.vocab_no_values = Vocabulary.from_file(cls.fixture_path("vocab-noid.tsv"))
12 |         cls.vocab_with_values = Vocabulary.from_file(
13 |             cls.fixture_path("vocab-100.tsv"), fallback_empty_value=False)
14 | 
15 |     def test_valid_access_no_value(self):
16 |         self.assertEqual(self.vocab_no_values.index({"type": "SimpleName"}), 0)
17 |         self.assertEqual(self.vocab_no_values.index({"type": "BinaryExpr"}), 10)
18 |         self.assertEqual(self.vocab_no_values.index({"type": "DoStmt"}), 60)
19 | 
20 |     def test_valid_access_with_value(self):
21 |         self.assertEqual(self.vocab_with_values.index({"type": "SimpleName"}), 0)
22 |         self.assertEqual(
23 |             self.vocab_with_values.index({"type": "IntegerLiteralExpr", "value": "0"}), 21)
24 |         self.assertEqual(
25 |             self.vocab_with_values.index({"type": "BooleanLiteralExpr", "value": "true"}), 67)
26 | 
27 |     def test_keyerror_no_value(self):
28 |         with self.assertRaises(KeyError):
29 |             _ = self.vocab_no_values.index({"type": "IDontExist"})
30 | 
31 |     def test_keyerror_with_value(self):
32 |         with self.assertRaises(KeyError):
33 |             _ = self.vocab_with_values.index({"type": "IDontExist"})
34 | 
35 |         with self.assertRaises(KeyError):
36 |             _ = self.vocab_with_values.index({"type": "SimpleName", "value": "dont-exist"})
37 | 
38 |     def test_save(self):
39 |         with tempfile.NamedTemporaryFile(prefix="suplearn-cc") as f:
40 |             self.vocab_with_values.save(f.name)
41 |             reloaded_vocab = Vocabulary.from_file(f.name, fallback_empty_value=False)
42 |             self.assertEqual(self.vocab_with_values, reloaded_vocab)
43 | 
--------------------------------------------------------------------------------
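The assertions in tests/util_test.py fix the behaviour expected of util.in_batch (batches of at most batch_size elements, with a shorter final batch) and util.group_by (a mapping that preserves first-seen key order). As a point of reference only, a minimal sketch that satisfies those assertions could look as follows; the real implementations live in suplearn_clone_detection/util.py, are not reproduced here, and may differ in details such as laziness or error handling.

from collections import OrderedDict
from itertools import islice


def in_batch(iterable, batch_size):
    # Yield consecutive batches of at most `batch_size` elements;
    # the final batch may be shorter, as test_in_batch expects.
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


def group_by(iterable, key):
    # Group elements by `key`, preserving first-seen key order to match
    # the ordering assertions in test_group_by. OrderedDict keeps this
    # explicit even on Python 3.6, which the CI image targets.
    grouped = OrderedDict()
    for element in iterable:
        grouped.setdefault(key(element), []).append(element)
    return grouped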
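Read together, tests/vocabulary_test.py and the two TSV fixtures pin down the Vocabulary lookup contract: index() maps an AST node dict (a "type" plus an optional "value") to the id column of the fixture, unknown entries raise KeyError when fallback_empty_value=False, and save()/from_file() round-trip to an equal object. The sketch below only uses calls exercised by those tests; the temporary output path is illustrative.

from suplearn_clone_detection.vocabulary import Vocabulary

# Strict lookups: with fallback_empty_value=False, a (type, value) pair
# that is not in the vocabulary raises KeyError (see test_keyerror_with_value).
vocab = Vocabulary.from_file("tests/fixtures/vocab-100.tsv",
                             fallback_empty_value=False)

assert vocab.index({"type": "SimpleName"}) == 0
assert vocab.index({"type": "IntegerLiteralExpr", "value": "0"}) == 21
assert vocab.index({"type": "BooleanLiteralExpr", "value": "true"}) == 67

# Round-trip: a saved vocabulary reloads as an equal object (see test_save).
vocab.save("/tmp/vocab-copy.tsv")
assert Vocabulary.from_file("/tmp/vocab-copy.tsv",
                            fallback_empty_value=False) == vocab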