├── .gitignore ├── .idea ├── codeprep.iml ├── dictionaries │ └── hlib.xml ├── inspectionProfiles │ └── Project_Default.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── .reuse └── dep5 ├── .travis.yml ├── LICENSES ├── Apache-2.0.txt └── CC0-1.0.txt ├── MANIFEST.in ├── README.md ├── codeprep ├── VERSION ├── __init__.py ├── __main__.py ├── api │ ├── __init__.py │ ├── common.py │ ├── corpus.py │ └── text.py ├── bpepkg │ ├── __init__.py │ ├── bpe_config.py │ ├── bpe_encode.py │ ├── bpe_learn.py │ ├── cache.py │ ├── merge.py │ └── wild_bpe.py ├── cli │ ├── __init__.py │ ├── impl.py │ ├── spec.py │ └── vocab.py ├── config.py ├── data │ ├── allamanis_dataset_metadata │ │ └── small-train-projects.txt │ └── bpe │ │ ├── 10k │ │ └── merges.txt │ │ ├── case │ │ ├── 0 │ │ │ └── merges.txt │ │ ├── 10k │ │ │ └── merges.txt │ │ ├── 1k │ │ │ └── merges.txt │ │ ├── 2k │ │ │ └── merges.txt │ │ └── 5k │ │ │ └── merges.txt │ │ └── nocase │ │ └── 0 │ │ └── merges.txt ├── dirutils.py ├── fileutils.py ├── logging.yaml ├── noneng.py ├── parse │ ├── __init__.py │ ├── core.py │ ├── matchers.py │ └── subtokens.py ├── pipeline │ ├── __init__.py │ ├── bpelearner.py │ ├── bperegistry.py │ ├── dataset.py │ ├── parse_projects.py │ ├── stages.py │ ├── to_repr.py │ ├── vocab.py │ └── vocabloader.py ├── prepconfig.py ├── preprocess │ ├── __init__.py │ ├── core.py │ ├── metadata.py │ ├── placeholders.py │ └── reprconfig.py ├── stemming.py ├── subtokens.py ├── tokens │ ├── __init__.py │ ├── containers.py │ ├── noneng.py │ ├── numeric.py │ ├── rootclasses.py │ ├── whitespace.py │ └── word.py └── util.py ├── reports └── bpe │ └── wild-bpe │ ├── v0.1_0.05mb_1bit.png │ ├── v0.1_0.05mb_1bit.png.license │ ├── v0.1_0.05mb_2bit.png │ ├── v0.1_0.05mb_2bit.png.license │ ├── v0.1_0.05mb_3bit.png │ ├── v0.1_0.05mb_3bit.png.license │ ├── v0.1_0.5mb_1bit.png │ ├── v0.1_0.5mb_1bit.png.license │ ├── v0.1_0.5mb_2bit.png │ ├── v0.1_0.5mb_2bit.png.license │ ├── v0.1_0.5mb_3bit.png │ ├── v0.1_0.5mb_3bit.png.license │ ├── v0.1_5mb_1bit.png │ ├── v0.1_5mb_1bit.png.license │ ├── v0.1_5mb_2bit.png │ ├── v0.1_5mb_2bit.png.license │ ├── v0.1_5mb_3bit.png │ └── v0.1_5mb_3bit.png.license ├── requirements-dev.txt ├── requirements.txt ├── setup.py ├── test-data └── test-corpus │ ├── jquery.min.js │ └── yahtzee │ ├── .reuse │ └── dep5 │ ├── LICENSES │ └── MIT.txt │ └── src │ ├── main │ └── java │ │ └── hlibbabii │ │ └── yahtzee │ │ ├── DiceValues.java │ │ ├── GameDemo.java │ │ ├── Player.java │ │ ├── combination │ │ ├── Chance.java │ │ ├── Combination.java │ │ ├── FullHouse.java │ │ ├── LargeStraight.java │ │ ├── NOfAKind.java │ │ ├── Numbers.java │ │ ├── SmallStraight.java │ │ ├── TwoPairs.java │ │ └── Yahtzee.java │ │ ├── gameplay │ │ ├── Decision.java │ │ ├── Game.java │ │ ├── GameStats.java │ │ ├── MoveResult.java │ │ ├── PlayerException.java │ │ └── PlayerStats.java │ │ ├── model │ │ └── DiceLayout.java │ │ ├── player │ │ └── DummyPlayer.java │ │ └── util │ │ └── RandomService.java │ └── test │ └── java │ └── hlibbabii │ └── yahtzee │ ├── combination │ ├── ChanceTest.java │ ├── FullHouseTest.java │ ├── LargeStraightTest.java │ ├── NOfAKindTest.java │ ├── SmallStraightTest.java │ └── TwoPairsTest.java │ └── model │ └── DiceLayoutTest.java ├── tests ├── __init__.py ├── api │ ├── __init__.py │ └── test_corpus.py ├── bpepkg │ ├── __init__.py │ ├── test_merge.py │ └── wild_bpe_performance.py ├── cli │ ├── __init__.py │ └── test_spec.py ├── infrastructure │ ├── __init__.py │ ├── test_bpelearner.py │ ├── test_bperegistry.py │ └── test_dataset.py ├── parse │ 
├── __init__.py │ ├── test_core.py │ └── test_subtokens.py ├── test_corpus_b2b.py ├── test_subword_separation.py └── test_to_repr.py └── tox.ini /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # Pycharm 107 | **/.idea/workspace.xml 108 | **/.idea/tasks.xml 109 | 110 | !test-data -------------------------------------------------------------------------------- /.idea/codeprep.iml: -------------------------------------------------------------------------------- [XML markup stripped in this dump] -------------------------------------------------------------------------------- /.idea/dictionaries/hlib.xml: -------------------------------------------------------------------------------- [XML markup stripped in this dump; the surviving dictionary words are: databunch, fastai, mult, numericalizer] -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- [XML markup stripped in this dump] -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- [XML markup stripped in this dump] -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- [XML markup stripped in this dump] -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- [XML markup stripped in this dump] -------------------------------------------------------------------------------- /.reuse/dep5: --------------------------------------------------------------------------------
1 | Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ 2 | Upstream-Name: codeprep 3 | Upstream-Contact: Hlib Babii 4 | Source: https://github.com/giganticode/codeprep 5 | 6 | # Sample paragraph, commented out: 7 | # 8 | Files: codeprep/data/**/*.txt codeprep/logging.yaml 9 | Copyright: 2020 Hlib Babii 10 | License: Apache-2.0 11 | 12 | Files: .idea/**/*.* .idea/*.* .gitignore .travis.yml MANIFEST.in requirements.txt requirements-dev.txt tox.ini codeprep/VERSION 13 | Copyright: 2020 Hlib Babii 14 | License: CC0-1.0 15 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: xenial # required for Python >= 3.7 2 | env: 3 | global: 4 | - CC_TEST_REPORTER_ID=$CODE_CLIMATE_REPORTER_ID 5 | language: python 6 | if: 'type = pull_request OR branch = master' 7 | python: 8 | - "3.6" 9 | - "3.7" 10 | jobs: 11 | include: 12 | - os: osx 13 | osx_image: xcode11.2 # Python 3.7.4 running on macOS 10.14.4 14 | language: shell 15 | python: "3.7" 16 | before_script: echo "Not sending code coverage to code climate on OSx" 17 | after_script: echo "Not sending code coverage to code climate on OSx" 18 | - os: windows 19 | language: sh 20 | python: "3.7" 21 | before_install: 22 | - choco install python3 --params "/InstallDir:C:\\Python" 23 | - export PATH="/c/Python:/c/Python/Scripts:$PATH" 24 | - python -m pip install --upgrade pip wheel 25 | before_script: echo "Not sending code coverage to code climate on Windows" 26 | after_script: echo "Not sending code coverage to code climate on Windows" 27 | install: 28 | - pip3 install -r requirements.txt 29 | - pip3 install -r requirements-dev.txt 30 | before_script: 31 | - curl -L https://codeclimate.com/downloads/test-reporter/test-reporter-latest-linux-amd64 > ./cc-test-reporter 32 | - chmod +x ./cc-test-reporter 33 | - ./cc-test-reporter before-build 34 | script: 35 | - cd $TRAVIS_BUILD_DIR 36 | - echo "Current directory is $(pwd)" 37 | - coverage run --concurrency=multiprocessing -m pytest --doctest-modules 38 | - coverage combine 39 | - coverage report -m --include=./\*\* --omit=tests/\*\*,.venv/\*\* --fail-under=70 40 | - coverage xml -i --include=./\*\* --omit=tests/\*\*,.venv/\*\* 41 | after_script: 42 | - if [[ "$TRAVIS_PULL_REQUEST" == "false" && "$TRAVIS_PYTHON_VERSION" == "3.6" ]]; then echo "Reporting coverage to Code Climate" && ./cc-test-reporter after-build --exit-code $TRAVIS_TEST_RESULT; fi 43 | -------------------------------------------------------------------------------- /LICENSES/CC0-1.0.txt: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES 4 | NOT PROVIDE LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE 5 | AN ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS INFORMATION 6 | ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES REGARDING THE USE 7 | OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED HEREUNDER, AND DISCLAIMS 8 | LIABILITY FOR DAMAGES RESULTING FROM THE USE OF THIS DOCUMENT OR THE INFORMATION 9 | OR WORKS PROVIDED HEREUNDER. 
10 | 11 | Statement of Purpose 12 | 13 | The laws of most jurisdictions throughout the world automatically confer exclusive 14 | Copyright and Related Rights (defined below) upon the creator and subsequent 15 | owner(s) (each and all, an "owner") of an original work of authorship and/or 16 | a database (each, a "Work"). 17 | 18 | Certain owners wish to permanently relinquish those rights to a Work for the 19 | purpose of contributing to a commons of creative, cultural and scientific 20 | works ("Commons") that the public can reliably and without fear of later claims 21 | of infringement build upon, modify, incorporate in other works, reuse and 22 | redistribute as freely as possible in any form whatsoever and for any purposes, 23 | including without limitation commercial purposes. These owners may contribute 24 | to the Commons to promote the ideal of a free culture and the further production 25 | of creative, cultural and scientific works, or to gain reputation or greater 26 | distribution for their Work in part through the use and efforts of others. 27 | 28 | For these and/or other purposes and motivations, and without any expectation 29 | of additional consideration or compensation, the person associating CC0 with 30 | a Work (the "Affirmer"), to the extent that he or she is an owner of Copyright 31 | and Related Rights in the Work, voluntarily elects to apply CC0 to the Work 32 | and publicly distribute the Work under its terms, with knowledge of his or 33 | her Copyright and Related Rights in the Work and the meaning and intended 34 | legal effect of CC0 on those rights. 35 | 36 | 1. Copyright and Related Rights. A Work made available under CC0 may be protected 37 | by copyright and related or neighboring rights ("Copyright and Related Rights"). 38 | Copyright and Related Rights include, but are not limited to, the following: 39 | 40 | i. the right to reproduce, adapt, distribute, perform, display, communicate, 41 | and translate a Work; 42 | 43 | ii. moral rights retained by the original author(s) and/or performer(s); 44 | 45 | iii. publicity and privacy rights pertaining to a person's image or likeness 46 | depicted in a Work; 47 | 48 | iv. rights protecting against unfair competition in regards to a Work, subject 49 | to the limitations in paragraph 4(a), below; 50 | 51 | v. rights protecting the extraction, dissemination, use and reuse of data 52 | in a Work; 53 | 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal protection 56 | of databases, and under any national implementation thereof, including any 57 | amended or successor version of such directive); and 58 | 59 | vii. other similar, equivalent or corresponding rights throughout the world 60 | based on applicable law or treaty, and any national implementations thereof. 61 | 62 | 2. Waiver. 
To the greatest extent permitted by, but not in contravention of, 63 | applicable law, Affirmer hereby overtly, fully, permanently, irrevocably and 64 | unconditionally waives, abandons, and surrenders all of Affirmer's Copyright 65 | and Related Rights and associated claims and causes of action, whether now 66 | known or unknown (including existing as well as future claims and causes of 67 | action), in the Work (i) in all territories worldwide, (ii) for the maximum 68 | duration provided by applicable law or treaty (including future time extensions), 69 | (iii) in any current or future medium and for any number of copies, and (iv) 70 | for any purpose whatsoever, including without limitation commercial, advertising 71 | or promotional purposes (the "Waiver"). Affirmer makes the Waiver for the 72 | benefit of each member of the public at large and to the detriment of Affirmer's 73 | heirs and successors, fully intending that such Waiver shall not be subject 74 | to revocation, rescission, cancellation, termination, or any other legal or 75 | equitable action to disrupt the quiet enjoyment of the Work by the public 76 | as contemplated by Affirmer's express Statement of Purpose. 77 | 78 | 3. Public License Fallback. Should any part of the Waiver for any reason be 79 | judged legally invalid or ineffective under applicable law, then the Waiver 80 | shall be preserved to the maximum extent permitted taking into account Affirmer's 81 | express Statement of Purpose. In addition, to the extent the Waiver is so 82 | judged Affirmer hereby grants to each affected person a royalty-free, non 83 | transferable, non sublicensable, non exclusive, irrevocable and unconditional 84 | license to exercise Affirmer's Copyright and Related Rights in the Work (i) 85 | in all territories worldwide, (ii) for the maximum duration provided by applicable 86 | law or treaty (including future time extensions), (iii) in any current or 87 | future medium and for any number of copies, and (iv) for any purpose whatsoever, 88 | including without limitation commercial, advertising or promotional purposes 89 | (the "License"). The License shall be deemed effective as of the date CC0 90 | was applied by Affirmer to the Work. Should any part of the License for any 91 | reason be judged legally invalid or ineffective under applicable law, such 92 | partial invalidity or ineffectiveness shall not invalidate the remainder of 93 | the License, and in such case Affirmer hereby affirms that he or she will 94 | not (i) exercise any of his or her remaining Copyright and Related Rights 95 | in the Work or (ii) assert any associated claims and causes of action with 96 | respect to the Work, in either case contrary to Affirmer's express Statement 97 | of Purpose. 98 | 99 | 4. Limitations and Disclaimers. 100 | 101 | a. No trademark or patent rights held by Affirmer are waived, abandoned, surrendered, 102 | licensed or otherwise affected by this document. 103 | 104 | b. Affirmer offers the Work as-is and makes no representations or warranties 105 | of any kind concerning the Work, express, implied, statutory or otherwise, 106 | including without limitation warranties of title, merchantability, fitness 107 | for a particular purpose, non infringement, or the absence of latent or other 108 | defects, accuracy, or the present or absence of errors, whether or not discoverable, 109 | all to the greatest extent permissible under applicable law. 110 | 111 | c. 
Affirmer disclaims responsibility for clearing rights of other persons 112 | that may apply to the Work or any use thereof, including without limitation 113 | any person's Copyright and Related Rights in the Work. Further, Affirmer disclaims 114 | responsibility for obtaining any necessary consents, permissions or other 115 | rights required for any use of the Work. 116 | 117 | d. Affirmer understands and acknowledges that Creative Commons is not a party 118 | to this document and has no duty or obligation with respect to this CC0 or 119 | use of the Work. 120 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft codeprep/data 2 | include codeprep/VERSION 3 | include codeprep/*.yaml 4 | include *.md 5 | include *.txt 6 | include LICENSE 7 | -------------------------------------------------------------------------------- /codeprep/VERSION: -------------------------------------------------------------------------------- 1 | 1.0.5 2 | -------------------------------------------------------------------------------- /codeprep/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import logging 6 | import logging.config 7 | import os 8 | import yaml 9 | 10 | from codeprep.config import root_package_dir, version 11 | 12 | 13 | def load_logging_config(): 14 | path = os.path.join(root_package_dir, 'logging.yaml') 15 | if os.path.exists(path): 16 | with open(path, 'rt') as f: 17 | logging_config = yaml.safe_load(f.read()) 18 | logging.config.dictConfig(logging_config) 19 | else: 20 | logging.basicConfig(level=logging.DEBUG) 21 | 22 | 23 | load_logging_config() 24 | 25 | logging.getLogger('matplotlib').setLevel(logging.INFO) 26 | logging.getLogger('Ronin').setLevel(logging.INFO) 27 | 28 | __version__ = version -------------------------------------------------------------------------------- /codeprep/__main__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import sys 6 | 7 | from codeprep.cli.spec import parse_and_run 8 | 9 | 10 | def main(): 11 | parse_and_run(sys.argv[1:]) 12 | 13 | 14 | if __name__ == '__main__': 15 | main() -------------------------------------------------------------------------------- /codeprep/api/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /codeprep/api/common.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import sys 6 | from typing import Optional 7 | 8 | from codeprep.prepconfig import PrepConfig, PrepParam, get_possible_str_values 9 | 10 | 11 | def create_split_value(split_type: str, bpe_codes_id: Optional[str] = None, full_strings: bool = False, 12 | split_numbers: bool = False, ronin: bool = False, stem: bool = False): 13 | if split_type == 'nosplit': 14 | return 'F' if full_strings else '0' 15 | elif split_type == 'chars': 16 | return '8' 17 | elif split_type == 'basic': 18 | if stem: 19 | return 's' 20 | 
elif ronin: 21 | return '3' 22 | elif split_numbers: 23 | return '2' 24 | else: 25 | return '1' 26 | elif split_type == 'bpe': 27 | if bpe_codes_id == '1k': 28 | return '5' 29 | elif bpe_codes_id == '5k': 30 | return '4' 31 | elif bpe_codes_id == '10k': 32 | return '6' 33 | else: 34 | return '9' 35 | else: 36 | raise AssertionError(f"Invalid split option: {split_type}") 37 | 38 | 39 | def create_str_value(no_str: bool, max_str_len: int) -> str: 40 | if no_str: 41 | return '0' 42 | if 0 <= max_str_len < 2: 43 | return '2' 44 | if 2 <= max_str_len < len(get_possible_str_values()): 45 | return get_possible_str_values()[max_str_len] 46 | else: 47 | return '1' 48 | 49 | 50 | def create_prep_config(spl_type: str, bpe_codes_id: Optional[str] = None, no_spaces: bool = False, 51 | no_unicode: bool = False, no_case: bool = False, no_com: bool = False, no_str: bool = False, 52 | full_strings: bool = False, max_str_length: int = sys.maxsize, split_numbers: bool = False, 53 | ronin: bool = False, stem: bool = False): 54 | return PrepConfig({ 55 | PrepParam.EN_ONLY: 'U' if no_unicode else 'u', 56 | PrepParam.COM: '0' if no_com else 'c', 57 | PrepParam.STR: create_str_value(no_str, max_str_length), 58 | PrepParam.SPLIT: create_split_value(spl_type, bpe_codes_id=bpe_codes_id, full_strings=full_strings, 59 | split_numbers=split_numbers, ronin=ronin, stem=stem), 60 | PrepParam.TABS_NEWLINES: '0' if no_spaces else 's', 61 | PrepParam.CASE: 'l' if no_case else 'u' 62 | }) -------------------------------------------------------------------------------- /codeprep/bpepkg/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /codeprep/bpepkg/bpe_config.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from enum import Enum 6 | from typing import Dict 7 | 8 | from codeprep.prepconfig import PrepConfig, PrepParam 9 | 10 | 11 | class BpeConfigNotSupported(Exception): 12 | pass 13 | 14 | 15 | class BpeParam(str, Enum): 16 | CASE: str = 'case' 17 | WORD_END: str = 'wordend' 18 | BASE: str = 'base' 19 | UNICODE: str = 'unicode' 20 | 21 | 22 | class BpeConfig(object): 23 | possible_param_values = { 24 | BpeParam.CASE: ['yes'], 25 | BpeParam.WORD_END: [True, False], 26 | BpeParam.BASE: ["all", "code", "java"], 27 | BpeParam.UNICODE: ['yes', 'no', 'bytes'], 28 | } 29 | 30 | @staticmethod 31 | def _check_param_number(n_passed_params: int): 32 | n_expected_params = len([i for i in BpeParam]) 33 | if n_passed_params != n_expected_params: 34 | raise ValueError(f'Expected {n_expected_params} params, got {n_passed_params}') 35 | 36 | @staticmethod 37 | def _check_invariants(params: Dict[BpeParam, str]): 38 | BpeConfig._check_param_number(len(params)) 39 | for pp in BpeParam: 40 | if params[pp] not in BpeConfig.possible_param_values[pp]: 41 | raise ValueError(f'Invalid value {params[pp]} for prep param {pp}, ' 42 | f'possible values are: {BpeConfig.possible_param_values[pp]}') 43 | 44 | def __init__(self, params: Dict[BpeParam, str]): 45 | BpeConfig._check_invariants(params) 46 | 47 | self.params = params 48 | 49 | def get_param_value(self, param: BpeParam) -> str: 50 | return self.params[param] 51 | 52 | def to_prep_config(self): 53 | return PrepConfig({ 54 | PrepParam.EN_ONLY: 'U' 
if self.get_param_value(BpeParam.UNICODE) == 'no' else 'u', 55 | PrepParam.COM: '0', 56 | PrepParam.STR: 'E', 57 | PrepParam.SPLIT: 'F', 58 | PrepParam.TABS_NEWLINES: 's', 59 | PrepParam.CASE: 'u' 60 | }) 61 | 62 | UNICODE_NO = 'nounicode' 63 | UNICODE_BYTES = 'bytes' 64 | CASE_NO = 'nocase' 65 | CASE_PREFIX = 'prefix' 66 | WORD_END = 'we' 67 | 68 | @staticmethod 69 | def from_suffix(suffix: str): 70 | if suffix.find(BpeConfig.CASE_NO) != -1: 71 | case = 'no' 72 | elif suffix.find(BpeConfig.CASE_PREFIX) != -1: 73 | case = 'prefix' 74 | else: 75 | case = 'yes' 76 | 77 | if suffix.find(BpeConfig.UNICODE_NO) != -1: 78 | unicode = 'no' 79 | elif suffix.find(BpeConfig.UNICODE_BYTES) != -1: 80 | unicode = 'bytes' 81 | else: 82 | unicode = 'yes' 83 | 84 | 85 | return BpeConfig({ 86 | BpeParam.CASE: case, 87 | BpeParam.WORD_END: suffix.find(BpeConfig.WORD_END) != -1, 88 | BpeParam.BASE: 'code', 89 | BpeParam.UNICODE: unicode, 90 | }) 91 | 92 | def to_suffix(self): 93 | """ 94 | >>> bpe_config = BpeConfig({ 95 | ... BpeParam.CASE: 'yes', 96 | ... BpeParam.WORD_END: False, 97 | ... BpeParam.BASE: 'all', 98 | ... BpeParam.UNICODE: 'yes' 99 | ... }) 100 | >>> bpe_config.to_suffix() 101 | '' 102 | 103 | >>> bpe_config = BpeConfig({ 104 | ... BpeParam.CASE: 'yes', 105 | ... BpeParam.WORD_END: True, 106 | ... BpeParam.BASE: 'all', 107 | ... BpeParam.UNICODE: 'no' 108 | ... }) 109 | >>> bpe_config.to_suffix() 110 | 'we_nounicode' 111 | 112 | >>> bpe_config = BpeConfig({ 113 | ... BpeParam.CASE: 'yes', 114 | ... BpeParam.WORD_END: False, 115 | ... BpeParam.BASE: 'all', 116 | ... BpeParam.UNICODE: 'bytes' 117 | ... }) 118 | >>> bpe_config.to_suffix() 119 | 'bytes' 120 | 121 | """ 122 | suffix_parts = [] 123 | 124 | if self.get_param_value(BpeParam.CASE) == 'no': 125 | suffix_parts.append(BpeConfig.CASE_NO) 126 | elif self.get_param_value(BpeParam.CASE) == 'prefix': 127 | suffix_parts.append(BpeConfig.CASE_PREFIX) 128 | 129 | if self.get_param_value(BpeParam.WORD_END): 130 | suffix_parts.append(BpeConfig.WORD_END) 131 | 132 | if self.get_param_value(BpeParam.UNICODE) == 'no': 133 | suffix_parts.append(BpeConfig.UNICODE_NO) 134 | elif self.get_param_value(BpeParam.UNICODE) == 'bytes': 135 | suffix_parts.append(BpeConfig.UNICODE_BYTES) 136 | 137 | return "_".join(suffix_parts) 138 | 139 | def __eq__(self, other): 140 | return self.params == other.params 141 | 142 | def __str__(self) -> str: 143 | parts = [str(self.params[k]) for k in BpeParam] 144 | return "_".join(parts) 145 | 146 | def __repr__(self): 147 | return str(self.params) -------------------------------------------------------------------------------- /codeprep/bpepkg/bpe_encode.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import logging 6 | import os 7 | import sys 8 | 9 | import argparse 10 | from typing import List, Dict 11 | 12 | from tqdm import tqdm 13 | 14 | from codeprep.bpepkg.merge import MergeList, read_merges 15 | from codeprep.config import DEFAULT_BPE_DIR 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class BpeData(object): 21 | def __init__(self, merges_cache=None, merges: MergeList=None): 22 | self.merges_cache = merges_cache 23 | self.merges = merges 24 | 25 | 26 | ESCAPE_CHAR = '@' 27 | 28 | ESCAPABLE_CHAR_LIST = [] + [ESCAPE_CHAR] 29 | 30 | 31 | def escape(word: str, merged: bool=False) -> str: 32 | word = word.replace(ESCAPE_CHAR, 2 * ESCAPE_CHAR) 33 | if merged: 34 | return 
f"{word}{ESCAPE_CHAR}" 35 | else: 36 | return f"{word} {ESCAPE_CHAR}" 37 | 38 | 39 | def unescape(parts: List[str]): 40 | if parts[-1][-1] != ESCAPE_CHAR: 41 | raise ValueError(f"There should be {ESCAPE_CHAR} at the end, however this is what was passed: {parts}") 42 | 43 | parts[-1] = parts[-1][:-1] 44 | return list(map(lambda p: p.replace(ESCAPE_CHAR + '@', ESCAPE_CHAR), parts)) 45 | 46 | 47 | def to_char_list(word: str): 48 | i = 0 49 | res = [] 50 | while i < len(word): 51 | if word[i] != ESCAPE_CHAR or i+1 == len(word): 52 | res.append(word[i]) 53 | i += 1 54 | elif word[i+1] in ESCAPABLE_CHAR_LIST: 55 | res.append(word[i:i+2]) 56 | i += 2 57 | else: 58 | raise ValueError(f"Illegal escape sequence: {word[i:i+2]}") 59 | return res 60 | 61 | 62 | def encode(words: Dict[str, int], merges: MergeList) -> Dict[str, int]: 63 | letters_list = {" ".join(to_char_list(k)): v for k, v in words.items()} 64 | 65 | new_letters_list = {} 66 | for letters, freq in letters_list.items(): 67 | subwords = letters.split(" ") 68 | 69 | show_bpe_progress_bar = False 70 | if len(subwords) > 5000: 71 | logger.warning(f'Encountered a string of length {len(subwords)}. It will take a while to bpe-encode it.') 72 | show_bpe_progress_bar = True 73 | 74 | if show_bpe_progress_bar: 75 | bpe_progress = tqdm(total=len(merges)) 76 | last_value = 0 77 | while True: 78 | merge_indices = [] 79 | merge_candidate_priority = sys.maxsize 80 | for i in range(len(subwords) - 1): 81 | merge_candidate = (subwords[i], subwords[i + 1]) 82 | if merge_candidate in merges: 83 | current_merge_candidate_priority = merges.get_priority(merge_candidate) 84 | if current_merge_candidate_priority < merge_candidate_priority: 85 | merge_candidate_priority = current_merge_candidate_priority 86 | merge_indices = [i] 87 | elif current_merge_candidate_priority == merge_candidate_priority: 88 | if not merge_indices or merge_indices[-1] != i - 1: 89 | merge_indices.append(i) 90 | 91 | if not merge_indices: 92 | break 93 | 94 | subwords_after_this_merge_round = [] 95 | start_idx = 0 96 | for merge_index in merge_indices: 97 | for i in range(start_idx, merge_index): 98 | subwords_after_this_merge_round.append(subwords[i]) 99 | subwords_after_this_merge_round.append(subwords[merge_index] + subwords[merge_index + 1]) 100 | start_idx = merge_index + 2 101 | for i in range(start_idx, len(subwords)): 102 | subwords_after_this_merge_round.append(subwords[i]) 103 | subwords = subwords_after_this_merge_round 104 | if show_bpe_progress_bar: 105 | bpe_progress.update(merge_candidate_priority - last_value) 106 | last_value = merge_candidate_priority 107 | if show_bpe_progress_bar: 108 | bpe_progress.update(len(merges) - last_value) 109 | bpe_progress.close() 110 | 111 | new_letters_list[" ".join(subwords)] = freq 112 | return new_letters_list 113 | 114 | 115 | def encode_word(word: str, merges: MergeList) -> List[str]: 116 | """ 117 | >>> merge_file = os.path.join(DEFAULT_BPE_DIR, '10k', 'merges.txt') 118 | >>> merges = read_merges(merge_file, 10000) 119 | 120 | >>> encode_word('this@@is_all_one_String@', merges) 121 | ['this', '@@', 'is_', 'all', '_', 'one', '_', 'String@'] 122 | 123 | >>> encode_word('aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa@', merges) 124 | ['aaaaaaaa', 'aaaaaaaa', 'aaaaaaaa', 'aaaaaaaa', 'aaaa', 'a', 'a@'] 125 | 126 | >>> encode_word('erererererererererererer@', merges) 127 | ['er', 'er', 'er', 'er', 'er', 'er', 'er', 'er', 'er', 'er', 'er', 'er@'] 128 | 129 | >>> encode_word('@', merges) 130 | ['@'] 131 | 132 | >>> encode_word('', merges) 133 | 
[''] 134 | 135 | >>> encode_word('split@', merges) 136 | ['split@'] 137 | 138 | >>> encode_word('aaa', merges) 139 | ['aa', 'a'] 140 | 141 | >>> encode_word('this\xa0is@@a@@@@bit@@@@larger\xa0stringwith\xa0some@@unicode@@possibly\xf7@', merges) 142 | ['this', '\\xa0', 'is', '@@', 'a', '@@', '@@', 'bit', '@@', '@@', 'l', 'arg', 'er', '\\xa0', 'string', \ 143 | 'with', '\\xa0', 's', 'ome', '@@', 'unic', 'ode', '@@', 'pos', 'si', 'b', 'ly', '÷', '@'] 144 | """ 145 | enc_word, _ = encode({word: 0}, merges).popitem() 146 | subwords = enc_word.split(" ") 147 | return subwords 148 | 149 | 150 | def get_bpe_subwords(word: str, bpe_data: BpeData) -> List[str]: 151 | merges = bpe_data.merges 152 | cache = bpe_data.merges_cache 153 | word = escape(word, merged=True) 154 | if word in cache: 155 | result = cache[word] 156 | else: 157 | result = encode_word(word, merges) 158 | 159 | return unescape(result) 160 | 161 | 162 | __all__ = [encode, encode_word] 163 | 164 | 165 | if __name__ == '__main__': 166 | arg_parser = argparse.ArgumentParser() 167 | arg_parser.add_argument('merges-file', action='store', help='path to file with merges') 168 | arg_parser.add_argument('word', action='store', help='word to encode', default='if') 169 | 170 | args = arg_parser.parse_args() 171 | 172 | merges = read_merges(args.merges_file) 173 | 174 | subwords = encode_word(args.word, merges) 175 | print(subwords) -------------------------------------------------------------------------------- /codeprep/bpepkg/bpe_learn.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import collections 6 | import logging 7 | 8 | import regex 9 | from tqdm import tqdm 10 | from typing import Dict, List, Tuple, Set 11 | 12 | from codeprep.bpepkg.merge import Merge, MergeList 13 | from codeprep.util import PriorityCounter 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | # ======== BPE algo itself 18 | 19 | 20 | def get_stats(split_base_vocab: Dict[str, int]) -> PriorityCounter: 21 | pairs = collections.defaultdict(int) 22 | for word, freq in split_base_vocab.items(): 23 | symbols = word.split(' ') 24 | for i in range(len(symbols) - 1): 25 | pairs [symbols[i], symbols[i + 1]] += freq 26 | return PriorityCounter(pairs) 27 | 28 | 29 | def merge_vocab(pair: Tuple[str, str], input_vocab: Dict[str, int]) -> Tuple[Dict[str, int], List]: 30 | """ 31 | >>> pair = ('w', 'o') 32 | >>> input_vocab = {'b i r d @': 3, 'w o r d @': 7, 'w o g @': 13} 33 | >>> new_vocab, new_pairs = merge_vocab(pair, input_vocab) 34 | >>> new_vocab 35 | {'b i r d @': 3, 'wo r d @': 7, 'wo g @': 13} 36 | >>> new_pairs 37 | [(('wo', 'r'), 7), (('o', 'r'), -7), (('wo', 'g'), 13), (('o', 'g'), -13)] 38 | """ 39 | output_vocab = {} 40 | concat_pair_with_space = ' '.join(pair) 41 | concat_pair_with_space_escaped = regex.escape(concat_pair_with_space) 42 | concat_pair = ''.join(pair) 43 | reg = regex.compile('(^|[^ ]+ )(' + concat_pair_with_space_escaped + ')( [^ ]+|$)') 44 | added_pairs = [] 45 | for word in input_vocab: 46 | word_occurences = input_vocab[word] 47 | match = reg.search(word) 48 | while match: 49 | # word changed 50 | if match.group(1) != '': 51 | subtoken_before = match.group(1)[:-1] 52 | added_pairs.append(((subtoken_before, concat_pair), word_occurences)) 53 | if pair != (subtoken_before, pair[0]): 54 | added_pairs.append(((subtoken_before, pair[0]), -word_occurences)) 55 | if match.group(3) != '': 56 | subtoken_after = 
match.group(3)[1:] 57 | added_pairs.append(((concat_pair, subtoken_after), word_occurences)) 58 | if pair != (pair[1], subtoken_after): 59 | added_pairs.append(((pair[1], subtoken_after), -word_occurences)) 60 | start, end = match.span(2) 61 | replacement = concat_pair 62 | word = word[:start] + replacement + word[end:] 63 | match = reg.search(word) 64 | output_vocab[word] = word_occurences 65 | return output_vocab, added_pairs 66 | 67 | 68 | def do_merges(vocab: Dict[str, int], n_merges: int) -> Tuple[Dict[str, int], MergeList]: 69 | """ 70 | Do `n_merges` bpe merges starting from vocabulary splittings `vocab` which were formed after applying `already_done_merges` merges 71 | 72 | :param vocab: base vocab splittings formed after applying `already_done_merges` in a format 73 | {"fix me@": 3242, "a b c@": 400} 74 | :param n_merges: number of bpe merges to be applied 75 | :param already_done_merges: merges which has already been applied in a format ["e @, f i", "fi x", "m e@"] 76 | 77 | :return: a tuple where the first elements is the resulting vocab splittings, 78 | the second one are all the merges done to reach those vocab splittings 79 | 80 | >>> input_vocab = { 81 | ... "b i r d @": 3, 82 | ... "w o r d @": 7, 83 | ... "w o g @": 13 84 | ... } 85 | 86 | >>> vocab, merges = do_merges(input_vocab, 10) 87 | >>> vocab 88 | {'bird@': 3, 'word@': 7, 'wog@': 13} 89 | >>> merges 90 | [('w', 'o'): (20, 0), ('g', '@'): (13, 1), ('wo', 'g@'): (13, 2), ('r', 'd'): (10, 3), ('rd', '@'): (10, 4), \ 91 | ('wo', 'rd@'): (7, 5), ('b', 'i'): (3, 6), ('bi', 'rd@'): (3, 7)] 92 | 93 | 94 | >>> input_vocab = {"a a a a a @": 3} 95 | 96 | >>> do_merges(input_vocab, 10) 97 | ({'aaaaa@': 3}, [('a', 'a'): (12, 0), ('a', '@'): (3, 1), ('aa', 'aa'): (3, 2), ('aaaa', 'a@'): (3, 3)]) 98 | 99 | >>> input_vocab = {"l a l a l a @": 3} 100 | >>> do_merges(input_vocab, 10) 101 | ({'lalala@': 3}, [('l', 'a'): (9, 0), ('la', 'la'): (6, 1), ('la', '@'): (3, 2), ('lala', 'la@'): (3, 3)]) 102 | 103 | """ 104 | merges = MergeList() 105 | pairs = get_stats(vocab) 106 | for i in tqdm(range(n_merges), total=n_merges): 107 | try: 108 | best, occurences = pairs.pop_pair() 109 | merges.append(Merge(best, freq=occurences, priority=i)) 110 | except KeyError: 111 | break 112 | vocab, added_pairs = merge_vocab(best, vocab) 113 | for p in added_pairs: 114 | pairs.add(*p) 115 | return vocab, merges 116 | 117 | # ======== Create auxiliary data structures. 
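# A minimal, illustrative sanity check of this module's pieces (an assumed usage
# sketch added for clarity, not code from the original repository): learn a few
# merges with do_merges() above, then build the subword cache with create_bpe_cache()
# below. The toy corpus counts are invented; as in the do_merges() doctests, words
# are given as space-separated characters and '@' marks the end of a word.
def _demo_learn_and_cache():
    from codeprep.bpepkg.bpe_encode import encode_word  # replay learned merges on a new word
    toy_vocab = {'l o w @': 5, 'l o w e r @': 2, 'n e w e s t @': 6}
    split_vocab, merges = do_merges(toy_vocab, 5)
    print(split_vocab)                     # the same words, now with coarser splittings
    print(create_bpe_cache(split_vocab))   # e.g. maps 'low@' to its current subword list
    print(encode_word('lowest@', merges))  # segments an unseen word with the learned merges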
118 | 119 | 120 | def create_bpe_cache(vocab: Dict[str, int]) -> Dict[str, List[str]]: 121 | merges_cache = {} 122 | for entry, _ in vocab.items(): 123 | subword_list = entry.split(' ') 124 | key = ''.join(subword_list) 125 | merges_cache[key] = subword_list 126 | return merges_cache 127 | 128 | 129 | def create_resulting_vocab(split_base_vocab: Dict[str, int]) -> Dict[str, int]: 130 | resulting_vocab = collections.defaultdict(int) 131 | for entry, frequency in split_base_vocab.items(): 132 | for subword in entry.split(" "): 133 | resulting_vocab[subword] += frequency 134 | return resulting_vocab 135 | 136 | # ============ 137 | 138 | 139 | def separate_vocabs(all_vocab: Dict[str, int], tokens_to_exclude: Set[str]) -> Tuple[Dict[str, int], Dict[str, int]]: 140 | main_vocab = {} 141 | other_vocab = {} 142 | for k, v in all_vocab.items(): 143 | if k not in tokens_to_exclude: 144 | main_vocab[k] = v 145 | else: 146 | other_vocab[k] = v 147 | return main_vocab, other_vocab -------------------------------------------------------------------------------- /codeprep/bpepkg/cache.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | """ 6 | >>> import tempfile 7 | >>> f = tempfile.NamedTemporaryFile(delete=False) 8 | >>> cache = {'ab': ['a', 'b'], '\\t\\xa0': ['\\t', '\\xa0']} 9 | >>> dump_bpe_cache(cache, f.name) 10 | >>> cache == read_bpe_cache(f.name) 11 | True 12 | """ 13 | from typing import List, Dict 14 | 15 | from codeprep.util import to_literal_str, to_non_literal_str 16 | 17 | KEY_VALUE_DELIM = '\t' 18 | VALUE_PARTS_DELIM = ' ' 19 | 20 | 21 | def read_bpe_cache(file: str) -> Dict[str, List[str]]: 22 | words = {} 23 | with open(file, 'r') as f: 24 | for line in f: 25 | line = line.rstrip('\n') 26 | splits = line.split(KEY_VALUE_DELIM) 27 | second_column = to_non_literal_str(splits[1]).split(VALUE_PARTS_DELIM) 28 | words[to_non_literal_str(splits[0])] = second_column 29 | return words 30 | 31 | 32 | def dump_bpe_cache(dct: Dict[str, List[str]], file: str) -> None: 33 | with open(file, 'w') as f: 34 | for word, subwords in dct.items(): 35 | a = to_literal_str(" ".join(subwords)) 36 | f.write(f'{to_literal_str(str(word))}{KEY_VALUE_DELIM}{a}\n') -------------------------------------------------------------------------------- /codeprep/bpepkg/merge.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import copy 6 | 7 | from typing import List, Tuple, Union, Optional, Iterator, Dict 8 | 9 | from codeprep.util import is_python_3_6_and_higher, to_literal_str, to_non_literal_str 10 | 11 | 12 | # TODO this class should be frozen 13 | class Merge(object): 14 | def __init__(self, pair: Tuple[str, str], freq: int = None, priority: int = None): 15 | self.pair = pair 16 | self.freq = freq 17 | self.priority = priority 18 | 19 | @classmethod 20 | def parse_file_entry(cls, line: str, priority: int) -> "Merge": 21 | try: 22 | spl = to_non_literal_str(line).split(" ") 23 | if len(spl) == 2: 24 | return cls((spl[0], spl[1]), priority=priority) 25 | else: 26 | return cls((spl[0], spl[1]), freq=int(spl[2]), priority=priority) 27 | except (IndexError, TypeError) as err: 28 | raise ValueError(f"Invalid merge entry format: {line}", err) 29 | 30 | def __str__(self): 31 | return self.__repr__() 32 | 33 | def __repr__(self): 34 | return f'{self.pair}: 
({self.freq}, {self.priority})' 35 | 36 | def __eq__(self, other): 37 | return self.__class__ == other.__class__ and self.pair == other.pair and self.priority == other.priority \ 38 | and self.freq == other.freq 39 | 40 | def __hash__(self): 41 | return hash((self.pair, self.priority, self.freq)) 42 | 43 | 44 | class MergeList(object): 45 | """ 46 | >>> merges = MergeList() 47 | >>> merges = merges.append(Merge(('a', 'b'), 34, 0)).append(Merge(('b', 'c'), 44, 1)) 48 | >>> [m for m in merges] 49 | [('a', 'b'): (34, 0), ('b', 'c'): (44, 1)] 50 | >>> len(merges) 51 | 2 52 | >>> merges[0] 53 | ('a', 'b'): (34, 0) 54 | >>> merges[1] 55 | ('b', 'c'): (44, 1) 56 | >>> merges[-1] 57 | ('b', 'c'): (44, 1) 58 | >>> merges[0:-1] 59 | [('a', 'b'): (34, 0)] 60 | >>> type(merges[0:-1]) 61 | <class 'list'> 62 | 63 | >>> merges[2] 64 | Traceback (most recent call last): 65 | ... 66 | IndexError: list index out of range 67 | 68 | >>> ('a', 'b') in merges 69 | True 70 | >>> ('a', 'x') in merges 71 | False 72 | 73 | >>> merge1 = Merge(('a', 'b'), 34, 0) 74 | >>> merge2 = Merge(('a', 'b'), 34, 0) 75 | >>> dct = {merge1: 3} 76 | >>> dct[merge2] 77 | 3 78 | 79 | >>> merges + MergeList().append(Merge(('d', 'e'), 84, 0)) 80 | [('a', 'b'): (34, 0), ('b', 'c'): (44, 1), ('d', 'e'): (84, 2)] 81 | >>> merges + [(('d', 'e'), 84, 1)] 82 | Traceback (most recent call last): 83 | ... 84 | TypeError: Cannot add to a MergeList 85 | 86 | >>> merges + merges 87 | Traceback (most recent call last): 88 | ... 89 | ValueError: It's only possible to add merges in priority order. The priority of the next merge should be 2 but is 3 90 | 91 | >>> merges.append(Merge(('x', 'y'), 34, 0)) 92 | Traceback (most recent call last): 93 | ... 94 | ValueError: It's only possible to add merges in priority order. The priority of the next merge should be 2 but is 0 95 | 96 | >>> merges = merges.append(Merge(('x', 'y'), 34)) 97 | >>> merges 98 | [('a', 'b'): (34, 0), ('b', 'c'): (44, 1), ('x', 'y'): (34, 2)] 99 | >>> merges.get_priority(('x', 'y')) 100 | 2 101 | """ 102 | def __init__(self): 103 | self.merges: Dict[Tuple[str, str], Merge] = {} 104 | 105 | def __contains__(self, item): 106 | return item in self.merges 107 | 108 | def __len__(self): 109 | return len(self.merges) 110 | 111 | def __iter__(self) -> Iterator[Merge]: 112 | return iter(self._get_sorted_merges()) 113 | 114 | def _get_sorted_merges(self) -> List[Merge]: 115 | if not is_python_3_6_and_higher(): 116 | # we cannot rely on dict order for python versions lower than 3.6 117 | raise NotImplementedError() 118 | 119 | return list(self.merges.values()) 120 | 121 | def __add__(self, other: 'MergeList'): 122 | if self.__class__ != other.__class__: 123 | raise TypeError(f"Cannot add {other.__class__} to a MergeList") 124 | 125 | new_merge_list = copy.deepcopy(self) 126 | other_copy = copy.deepcopy(other) 127 | first_list_len = len(new_merge_list) 128 | for merge in other_copy: 129 | merge.priority += first_list_len 130 | new_merge_list.append(merge) 131 | 132 | return new_merge_list 133 | 134 | def append(self, merge: Merge) -> 'MergeList': 135 | # along with the pair we save its priority and the number of its occurrences 136 | if merge.priority is None: 137 | merge.priority = len(self.merges) 138 | elif merge.priority != len(self.merges): 139 | raise ValueError(f"It's only possible to add merges in priority order.
" 140 | f"The priority of the next merge should be {len(self.merges)} but is {merge.priority}") 141 | 142 | self.merges[merge.pair] = merge 143 | return self 144 | 145 | def get_priority(self, pair: Tuple[str, str]) -> int: 146 | return self.merges[pair].priority 147 | 148 | def __getitem__(self, item) -> Union[List[Merge], Merge]: 149 | lst = self._get_sorted_merges() 150 | return lst[item] 151 | 152 | def __repr__(self): 153 | return repr(self[:]) 154 | 155 | def __eq__(self, other): 156 | return self.__class__ == other.__class__ and self[:] == other[:] 157 | 158 | 159 | def read_merges(file: str, n_merges: Optional[int] = None) -> MergeList: 160 | merges = MergeList() 161 | with open(file, 'r') as f: 162 | for idx, line in enumerate(f): 163 | if n_merges and idx >= n_merges: 164 | break 165 | line = line.rstrip('\n') 166 | merges.append(Merge.parse_file_entry(line, idx)) 167 | return merges 168 | 169 | 170 | def dump_merges(merges: MergeList, file: str): 171 | with open(file, 'w') as f: 172 | for merge in merges: 173 | f.write(f"{to_literal_str(' '.join(merge.pair))} {merge.freq}\n") -------------------------------------------------------------------------------- /codeprep/cli/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /codeprep/cli/impl.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import logging 6 | import os 7 | from typing import Dict, Optional, Any 8 | 9 | import sys 10 | 11 | import codeprep 12 | import codeprep.api.corpus 13 | import codeprep.api.text 14 | from codeprep.api.common import create_split_value, create_str_value 15 | from codeprep.bpepkg.bpe_config import BpeParam, BpeConfig 16 | from codeprep.pipeline import bpelearner 17 | from codeprep.pipeline.bperegistry import InvalidBpeCodesIdError, USER_PREDEFINED_BPE_CODES 18 | from codeprep.pipeline.dataset import Dataset, normalize_extension_string 19 | from codeprep.prepconfig import PrepConfig, PrepParam 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def set_log_level(args: Dict[str, str]) -> None: 25 | if args['--verbose']: 26 | logging.root.setLevel(logging.DEBUG) 27 | else: 28 | logging.root.setLevel(logging.ERROR) 29 | 30 | 31 | def get_option(args: Dict, option: str) -> Optional[Any]: 32 | return args[option] if option in args else None 33 | 34 | 35 | def is_option_true(args: Dict, option: str) -> bool: 36 | return bool(get_option(args, option)) 37 | 38 | 39 | def handle_learnbpe(args): 40 | set_log_level(args) 41 | path = os.path.abspath(args['--path']) 42 | bpe_config = create_bpe_config_from_args(args) 43 | n_merges = int(args['']) 44 | if args['--legacy']: 45 | parsed_extensions = normalize_extension_string(args['--ext']) 46 | if parsed_extensions and parsed_extensions != ['java']: 47 | print("Only --ext 'java' is supported when --legacy is specified") 48 | return 49 | else: 50 | extensions = 'java' 51 | else: 52 | extensions = args['--ext'] 53 | bpe_codes_id = args['--id'] 54 | dataset = Dataset.create(path, bpe_config.to_prep_config(), extensions, None, bpe_config) 55 | 56 | if not dataset.bpe_codes_id: 57 | dataset.assign_bpe_codes_id(bpe_config, predefined_bpe_codes_id=bpe_codes_id) 58 | elif bpe_codes_id: 59 | 
logger.warning(f"Ignoring passed bpe codes id: {bpe_codes_id}. " 60 | f"This dataset has already been assigned id: {dataset.bpe_codes_id}") 61 | 62 | bpelearner.run(dataset, n_merges, bpe_config) 63 | 64 | 65 | def handle_splitting(args: Dict) -> None: 66 | set_log_level(args) 67 | try: 68 | prep_config = create_prep_config_from_args(args) 69 | bpe_codes_id = get_option(args, '') or get_predefined_bpe_codes_id(args) 70 | if args['']: 71 | prep_text = codeprep.api.text.preprocess(args[''], prep_config, bpe_codes_id, 72 | extension=args['--ext']) 73 | print(prep_text) 74 | else: 75 | codeprep.api.corpus.preprocess_corpus(args['--path'], prep_config, bpe_codes_id, 76 | extensions=args['--ext'], 77 | output_path=args['--output-path'], 78 | calc_vocab=bool(args['--calc-vocab'])) 79 | except InvalidBpeCodesIdError as err: 80 | logger.error(err) 81 | return 82 | 83 | 84 | def create_bpe_config_from_args(run_options: Dict[str, str]) -> BpeConfig: 85 | if run_options['--no-unicode']: 86 | unicode = 'no' 87 | elif run_options['--bytes']: 88 | unicode = 'bytes' 89 | else: 90 | unicode = 'yes' 91 | return BpeConfig({ 92 | BpeParam.CASE: 'yes', 93 | BpeParam.WORD_END: run_options["--word-end"], 94 | BpeParam.BASE: 'java' if run_options['--legacy'] else 'code', 95 | BpeParam.UNICODE: unicode 96 | }) 97 | 98 | 99 | def create_prep_config_from_args(arguments: Dict) -> PrepConfig: 100 | max_str_length = get_option(arguments, '--max-str-length') 101 | max_str_length = int(max_str_length) if max_str_length is not None else sys.maxsize 102 | return PrepConfig({ 103 | PrepParam.EN_ONLY: 'U' if is_option_true(arguments, '--no-unicode') else 'u', 104 | PrepParam.COM: '0' if is_option_true(arguments, '--no-com') else 'c', 105 | PrepParam.STR: create_str_value(is_option_true(arguments, '--no-str'), max_str_length), 106 | PrepParam.SPLIT: create_split_value_from_args(arguments), 107 | PrepParam.TABS_NEWLINES: '0' if is_option_true(arguments, '--no-spaces') else 's', 108 | PrepParam.CASE: 'l' if is_option_true(arguments, '--no-case') else 'u', 109 | }) 110 | 111 | 112 | def get_predefined_bpe_codes_id(arguments: Dict) -> Optional[str]: 113 | for predefined_id in USER_PREDEFINED_BPE_CODES: 114 | if is_option_true(arguments, predefined_id): 115 | return predefined_id 116 | 117 | return '0' if is_option_true(arguments, 'chars') else None 118 | 119 | 120 | def create_split_value_from_args(arguments: Dict) -> str: 121 | if is_option_true(arguments, 'nosplit'): 122 | return create_split_value('nosplit', full_strings=is_option_true(arguments, '--full-strings')) 123 | elif is_option_true(arguments, 'chars'): 124 | return create_split_value('chars') 125 | elif is_option_true(arguments, 'basic'): 126 | return create_split_value('basic', 127 | split_numbers=is_option_true(arguments, '--split-numbers'), 128 | ronin=is_option_true(arguments, '--ronin'), 129 | stem=is_option_true(arguments, '--stem')) 130 | elif is_option_true(arguments, 'bpe'): 131 | return create_split_value('bpe', bpe_codes_id=get_predefined_bpe_codes_id(arguments)) 132 | else: 133 | raise AssertionError(f"Invalid split option: {arguments}") -------------------------------------------------------------------------------- /codeprep/cli/vocab.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import argparse 6 | import logging 7 | 8 | from codeprep.dirutils import walk 9 | from codeprep.pipeline.vocab import calc_vocab 10 | 11 | 
logger = logging.getLogger(__name__) 12 | 13 | if __name__ == '__main__': 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('path_to_dataset', action='store', help=f'path to dataset') 17 | parser.add_argument('output_dir', action='store', help=f'output dir') 18 | parser.add_argument('extension', action='store', help=f'extension') 19 | 20 | args = parser.parse_known_args() 21 | args = args[0] 22 | 23 | calc_vocab(args.path_to_dataset, walk(args.path_to_dataset.encode(), extension=args.extension.encode()), args.output_dir) -------------------------------------------------------------------------------- /codeprep/config.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import os 6 | 7 | import appdirs 8 | 9 | TRAIN_DIR = 'train' 10 | TEST_DIR = 'test' 11 | VALID_DIR = 'valid' 12 | 13 | CASE_DIR='case' 14 | NO_CASE_DIR='nocase' 15 | 16 | REPR_DIR = 'repr' 17 | PARSED_DIR = 'parsed' 18 | BPE_DIR = 'bpe' 19 | VOCAB_DIR = 'vocab' 20 | 21 | current_script_location = os.path.realpath(__file__) 22 | root_package_dir = os.path.dirname(current_script_location) 23 | data_dir = os.path.join(root_package_dir, 'data') 24 | test_data_dir = os.path.join(os.path.dirname(root_package_dir), 'test-data') 25 | 26 | app_name='codeprep' 27 | 28 | with open(os.path.join(root_package_dir, 'VERSION')) as version_file: 29 | version = version_file.read().strip() 30 | 31 | USER_CONFIG_DIR = appdirs.user_config_dir(app_name, appauthor=False, version=version) 32 | USER_CACHE_DIR = appdirs.user_cache_dir(app_name, appauthor=False, version=version) 33 | 34 | DEFAULT_FILE_LIST_DIR = os.path.join(USER_CACHE_DIR, 'file_lists') 35 | DEFAULT_PARSED_DATASETS_DIR = os.path.join(USER_CACHE_DIR, 'parsed_datasets') 36 | DEFAULT_PREP_DATASETS_DIR = os.path.join(USER_CACHE_DIR, 'prep_datasets') 37 | DEFAULT_BPE_DIR = os.path.join(data_dir, BPE_DIR) 38 | USER_BPE_DIR = os.path.join(USER_CONFIG_DIR, BPE_DIR) 39 | USER_VOCAB_DIR = os.path.join(USER_CONFIG_DIR, VOCAB_DIR) 40 | DEFAULT_BPE_CACHE_DIR = os.path.join(USER_CACHE_DIR, BPE_DIR) 41 | DEFAULT_CORPUS_SIZES_DIR = os.path.join(USER_CACHE_DIR, 'corpus_sizes') 42 | 43 | REWRITE_PARSED_FILE=False 44 | REWRITE_PREPROCESSED_FILE=False 45 | 46 | CHUNKSIZE=24 47 | LIMIT_FILES_ON_LAST_MODIFICATION_CHECK=1000 48 | LIMIT_FILES_SCANNING=50000 -------------------------------------------------------------------------------- /codeprep/data/allamanis_dataset_metadata/small-train-projects.txt: -------------------------------------------------------------------------------- 1 | AccelService 2 | agentcontest 3 | Ambience 4 | AndroidApp 5 | Android-CalendarView 6 | android-nice-guidance 7 | android_packages_apps_ClassicnerdWallpapers 8 | ANNIS 9 | Arboretum-Kiosk 10 | Atlantis 11 | BacksideUpdater 12 | bee-encode 13 | Birthday-Ring-Kata 14 | bookmarklet-and-chrome-addon 15 | BuilderGen 16 | Californium 17 | cassandra-tutorial 18 | CherryToolsMP 19 | Client 20 | codjo-gui-toolkit 21 | CommuniCase-spmode-mail-support-patch 22 | cookbook-domain 23 | CreativeStick 24 | cuke4duke 25 | DAVEtools 26 | device_allwinner_novo7a 27 | DockSoundRedirector 28 | Dspace-googleanalytics 29 | eclim 30 | elasticsearch-redis-river 31 | erjang 32 | expr 33 | festomat 34 | fluent-logger-java 35 | ftbt 36 | gatein-wsrp 37 | Gilded-Rose-Kata 38 | gora 39 | groovy-eclipse 40 | GWTCleaner 41 | hbasene 42 | hibernateuniversity-devoxx 43 | Hudson-Gerrit-Plugin 44 | imageassert 
45 | InterviewStreet 46 | izpack 47 | java-api 48 | javacodes 49 | jboss-aejb 50 | jdeb 51 | JHexView 52 | jogl-utils 53 | Json-lib 54 | karma 55 | labeled-test-groups-publisher-plugin 56 | Library 57 | LockSupport 58 | lzmajio 59 | markdownj 60 | maven-thrift-plugin 61 | metatypes 62 | minerhat 63 | modern-ee-app20 64 | mousefeed 65 | MWM-for-Android 66 | Nearby 67 | NFEGuardian-client 68 | oae-maven-plugin 69 | open311-android 70 | opposite_of_a_bloom_filter 71 | otertool 72 | payments4j 73 | phoneload 74 | platform_packages_apps_alarmclock 75 | play2-scheduled-job-demo 76 | port-allocator-maven-plugin 77 | programmer-rush-beta 78 | Pubtran-London 79 | rabbitosgi 80 | redak 81 | RestFixture 82 | Roll 83 | Saints-Robotics-Programming 84 | SCORMCloud_JavaDemoApp 85 | sensei 86 | sia2012 87 | SimpleWarps 88 | smartgrid 89 | Solr-Ranking-Plugin 90 | SpoutWallet 91 | Spring-MVC-sample-using-iBatis 92 | sql-processor 93 | storm-counts 94 | swagger-spring 95 | Tartiquizz 96 | TenVersion_Project 97 | ThreadPaint 98 | Toast-Plugin-for-Smart-Phone 99 | translator-aquacrop 100 | Twister 101 | UPC_POO_EDU 102 | Video-Display-Portlet 103 | vraptor-jasperreport 104 | websms 105 | Wool-Trees 106 | xml-utilities 107 | zest -------------------------------------------------------------------------------- /codeprep/data/bpe/case/0/merges.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/codeprep/data/bpe/case/0/merges.txt -------------------------------------------------------------------------------- /codeprep/data/bpe/nocase/0/merges.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/codeprep/data/bpe/nocase/0/merges.txt -------------------------------------------------------------------------------- /codeprep/dirutils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import logging 6 | import os 7 | from datetime import datetime 8 | 9 | from typing import Optional, List, Generator 10 | 11 | from codeprep.config import LIMIT_FILES_ON_LAST_MODIFICATION_CHECK 12 | from codeprep.fileutils import has_one_of_extensions 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def walk(path:bytes, extension: Optional[bytes] = None) -> Generator[bytes, None, None]: 18 | if os.path.isfile(path) and (not extension or path.endswith(extension)): 19 | yield path 20 | else: 21 | for root, dirs, files in os.walk(path): 22 | for file in files: 23 | if not extension or file.endswith(extension): 24 | yield os.path.join(root, file) 25 | 26 | 27 | def walk_and_save(path: str, dir_list_path: str, file_list_path: str, return_dirs_instead_of_regular_files: bool, 28 | extensions: Optional[List[str]]) -> Generator[bytes, None, None]: 29 | with open(dir_list_path, 'w') as d, open(file_list_path, 'w') as f: 30 | path_bin = path.encode() 31 | extensions_bin = list(map(lambda e: e.encode(), extensions)) if extensions else None 32 | # we want to list and store all the files a sequences of bytes to avoid problems with different encodings for filenames 33 | if os.path.isfile(path_bin): 34 | res = os.path.basename(path_bin) 35 | f.write(f'{res}\n') 36 | if not return_dirs_instead_of_regular_files: 37 | yield res 38 | else: 39 | for root, 
dirs, files in os.walk(path_bin, followlinks=True): 40 | # we pass bytes to os.walk -> the output are bytes as well 41 | for dir in dirs: 42 | bin_name = os.path.join(os.path.relpath(root, path_bin), dir) 43 | d.write(f'{bin_name}\n') 44 | if return_dirs_instead_of_regular_files: 45 | yield bin_name 46 | for file in files: 47 | bin_name = os.path.join(os.path.relpath(root, path_bin), file) 48 | if not extensions or has_one_of_extensions(bin_name, extensions_bin): 49 | if not os.path.islink(os.path.join(root, file)): 50 | f.write(f'{bin_name}\n') 51 | if not return_dirs_instead_of_regular_files: 52 | yield bin_name 53 | 54 | 55 | def get_dir_last_modification(path: str, limit: int = LIMIT_FILES_ON_LAST_MODIFICATION_CHECK) -> datetime: 56 | 57 | def walk_path(path): 58 | counter = 0 59 | if os.path.isfile(path) or len(os.listdir(path)) == 0: 60 | yield os.path.getmtime(path) 61 | else: 62 | for root, dirs, files in os.walk(path): 63 | for dir in dirs: 64 | if counter >= limit: 65 | return 66 | counter += 1 67 | yield os.path.getmtime(os.path.join(root, dir)) 68 | for file in files: 69 | if counter >= limit: 70 | return 71 | full_path = os.path.join(root, file) 72 | if not os.path.islink(full_path): 73 | counter += 1 74 | yield os.path.getmtime(full_path) 75 | 76 | mtime = max(walk_path(path)) 77 | return datetime.fromtimestamp(mtime) 78 | 79 | 80 | def get_timestamp(path: str) -> str: 81 | last_modif_time = get_dir_last_modification(path) 82 | return last_modif_time.strftime("%y-%m-%dT%H-%M-%S") -------------------------------------------------------------------------------- /codeprep/fileutils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import logging 6 | 7 | from typing import List, Tuple 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def has_one_of_extensions(name: bytes, extensions: List[bytes]) -> bool: 13 | """ 14 | >>> has_one_of_extensions(b'/home/abc.java', [b'java', b'c']) 15 | True 16 | 17 | >>> has_one_of_extensions(b'/home/abc.py', [b'java', b'c']) 18 | False 19 | 20 | >>> has_one_of_extensions(b'/home/abc.dtc', [b'java', b'c']) 21 | False 22 | 23 | >>> has_one_of_extensions(b'/home/abc.f.java.prep', [b'java.prep', b'c']) 24 | True 25 | 26 | >>> has_one_of_extensions(b'/home/abc.f.java.prep', [b'a.prep', b'c']) 27 | False 28 | 29 | """ 30 | for ext in extensions: 31 | if name.endswith(b'.' 
+ ext): 32 | return True 33 | return False 34 | 35 | 36 | def read_file_contents(file_path: bytes) -> Tuple[List[str], bytes]: 37 | try: 38 | return read_file_with_encoding(file_path, 'utf-8') 39 | except UnicodeDecodeError: 40 | try: 41 | return read_file_with_encoding(file_path, 'ISO-8859-1') 42 | except UnicodeDecodeError: 43 | logger.error(f"Unicode decode error in file: {file_path}") 44 | 45 | 46 | def read_file_with_encoding(file_path: bytes, encoding: str) -> Tuple[List[str], bytes]: 47 | with open(file_path, 'r', encoding=encoding) as f: 48 | return [line.rstrip('\n') for line in f], file_path -------------------------------------------------------------------------------- /codeprep/logging.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | disable_existing_loggers: False 3 | formatters: 4 | with_process_name: 5 | format: "%(asctime)s [%(name)s] %(levelname)s: %(message)s" 6 | 7 | handlers: 8 | console: 9 | class: logging.StreamHandler 10 | level: DEBUG 11 | formatter: with_process_name 12 | stream: ext://sys.stdout 13 | loggers: 14 | codeprep.infrastructure: 15 | level: DEBUG 16 | propagate: True 17 | codeprep.api: 18 | level: DEBUG 19 | propagate: True 20 | codeprep.cli: 21 | level: DEBUG 22 | propagate: True 23 | codeprep.vocab: 24 | level: DEBUG 25 | propagate: True 26 | codeprep.parse_projects: 27 | level: DEBUG 28 | propagate: True 29 | codeprep.to_repr: 30 | level: DEBUG 31 | propagate: True 32 | root: 33 | level: DEBUG 34 | handlers: [console] -------------------------------------------------------------------------------- /codeprep/noneng.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import logging 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def is_non_eng(word): 11 | return not __isascii(word) 12 | 13 | 14 | def __isascii(str): 15 | try: 16 | str.encode('ascii') 17 | return True 18 | except UnicodeEncodeError: 19 | return False 20 | 21 | 22 | def replace_non_ascii_seqs(word:str, placeholder: str) -> str: 23 | """ 24 | >>> replace_non_ascii_seqs("","\xf7") 25 | '' 26 | 27 | >>> replace_non_ascii_seqs("Ü", "\xf7") 28 | '\xf7' 29 | 30 | >>> replace_non_ascii_seqs("Üüø", "\xf7") 31 | '\xf7' 32 | 33 | >>> replace_non_ascii_seqs("abcd", "\xf7") 34 | 'abcd' 35 | 36 | >>> replace_non_ascii_seqs("aæbñńcdú", "\xf7") 37 | 'a\xf7b\xf7cd\xf7' 38 | 39 | >>> replace_non_ascii_seqs("any_string", "\xf7\xa0") 40 | Traceback (most recent call last): 41 | ... 
42 | ValueError: Placeholder should be a single character, but is ÷\xa0
43 |     """
44 |     if len(placeholder) != 1:
45 |         raise ValueError(f"Placeholder should be a single character, but is {placeholder}")
46 | 
47 |     new_word = ""
48 |     ongoing_non_ascii_seq = False
49 |     for ch in word:
50 |         if ord(ch) < 128:
51 |             if ongoing_non_ascii_seq:
52 |                 new_word += placeholder
53 |                 ongoing_non_ascii_seq = False
54 |             new_word += ch
55 |         else:
56 |             ongoing_non_ascii_seq = True
57 |     if ongoing_non_ascii_seq:
58 |         new_word += placeholder
59 | 
60 |     return new_word
--------------------------------------------------------------------------------
/codeprep/parse/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/codeprep/parse/core.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | 
5 | import logging
6 | from typing import Generator, List
7 | 
8 | from pygments import lex
9 | from pygments.lexers import get_lexer_by_name, guess_lexer
10 | from pygments.util import ClassNotFound
11 | 
12 | from codeprep.parse import matchers
13 | from codeprep.parse.matchers import DefaultMatcher
14 | from codeprep.tokens.rootclasses import ParsedToken
15 | 
16 | logger = logging.getLogger(__name__)
17 | 
18 | matchers = [
19 |     matchers.NewLineMatcher(),
20 |     matchers.TabMatcher(),
21 |     matchers.WhitespaceMatcher(),
22 |     matchers.OperatorMatcher(),
23 |     matchers.NumberMatchers(),
24 |     matchers.WordMatcher(),
25 |     matchers.GenericLiteralMatcher(),
26 |     matchers.KeywordMatcher(),
27 |     matchers.StringMatcher(),
28 |     matchers.OneLineCommentMatcher(),
29 |     matchers.MultiLineLineCommentMatcher(),
30 |     matchers.GenericTokenMatcher()
31 | ]
32 | 
33 | 
34 | def _convert(token, value: str) -> List[ParsedToken]:
35 |     for matcher in matchers:
36 |         if matcher.match(token, value):
37 |             return matcher.transform(value)
38 | 
39 |     if DefaultMatcher().match(token, value):
40 |         return DefaultMatcher().transform(value)
41 | 
42 |     assert False
43 | 
44 | 
45 | def convert_text(text: str, extension: str) -> Generator[ParsedToken, None, None]:
46 |     # The extension defaults to 'java', so a lexer lookup by name is always tried first;
47 |     # guess_lexer() is used only as a fallback when no lexer exists for that name.
48 |     extension = extension or 'java'
49 |     try:
50 |         lexer = get_lexer_by_name(extension)
51 |     except ClassNotFound as err:
52 |         logger.warning(err)
53 |         lexer = guess_lexer(text)
54 |     for token, value in lex(text, lexer):
55 |         for model_token in _convert(token, value):
56 |             yield model_token
--------------------------------------------------------------------------------
/codeprep/parse/matchers.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | 
5 | from typing import List
6 | 
7 | from pygments.token import Token
8 | 
9 | from codeprep.parse.subtokens import split_into_words, split_string
10 | from codeprep.tokens.containers import StringLiteral, OneLineComment, MultilineComment
11 | from codeprep.tokens.numeric import Number, Zero, One
12 | from codeprep.tokens.rootclasses import ParsedToken
13 | from codeprep.tokens.whitespace import NewLine, Tab
14 | from codeprep.tokens.word import KeyWord, Operator, Semicolon, OpeningCurlyBracket, ClosingCurlyBracket, OpeningBracket, \
15 |     ClosingBracket
16 | 
17 | 
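# All matchers in this module follow one duck-typed protocol, which
# codeprep.parse.core._convert relies on: match(token, value) decides whether the
# matcher handles a lexed (token_type, value) pair, and transform(value) turns the
# value into a list of parsed tokens. A minimal sketch of that protocol follows;
# TodoCommentMatcher is a hypothetical illustration only and is not registered in
# the matcher list in core.py.

class TodoCommentMatcher(object):
    def match(self, token, value: str) -> bool:
        # claim only one-line comments whose text mentions a TODO
        return token is Token.Comment.Single and 'TODO' in value

    def transform(self, value: str) -> List[ParsedToken]:
        return split_into_words(value)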
18 | class DefaultMatcher(object): 19 | def match(self, token, value: str) -> bool: 20 | return True 21 | 22 | def transform(self, value: str) -> List[ParsedToken]: 23 | return split_into_words(value) 24 | 25 | 26 | class GenericTokenMatcher(object): 27 | def match(self, token, value: str) -> bool: 28 | return token in Token.Generic 29 | 30 | def transform(self, value: str) -> List[ParsedToken]: 31 | return split_into_words(value) 32 | 33 | 34 | class StringMatcher(object): 35 | def match(self, token, value: str) -> bool: 36 | return token in Token.Literal.String 37 | 38 | def transform(self, value: str) -> List[StringLiteral]: 39 | return [StringLiteral(split_string(value), len(value))] 40 | 41 | 42 | class OneLineCommentMatcher(object): 43 | def match(self, token, value: str) -> bool: 44 | return token is Token.Comment.Single 45 | 46 | def transform(self, value: str) -> List[OneLineComment]: 47 | return [OneLineComment(split_into_words(value))] 48 | 49 | 50 | class MultiLineLineCommentMatcher(object): 51 | def match(self, token, value: str) -> bool: 52 | return token in Token.Comment and not token is Token.Comment.Single 53 | 54 | def transform(self, value: str) -> List[MultilineComment]: 55 | return [MultilineComment(split_into_words(value))] 56 | 57 | 58 | class WordMatcher(object): 59 | def match(self, token, value: str) -> bool: 60 | return token in Token.Name 61 | 62 | def transform(self, value: str) -> List[ParsedToken]: 63 | return split_into_words(value) 64 | 65 | 66 | class GenericLiteralMatcher(object): 67 | def match(self, token, value: str) -> bool: 68 | return token is Token.Literal or token is Token.Literal.Date 69 | 70 | def transform(self, value: str) -> List[ParsedToken]: 71 | return split_into_words(value) 72 | 73 | 74 | class KeywordMatcher(object): 75 | def match(self, token, value: str) -> bool: 76 | return token in Token.Keyword 77 | 78 | def transform(self, value: str) -> List[KeyWord]: 79 | return [KeyWord(value)] 80 | 81 | 82 | class NewLineMatcher(object): 83 | def match(self, token, value: str) -> bool: 84 | return value == '\n' 85 | 86 | def transform(self, value: str) -> List[NewLine]: 87 | return [NewLine()] 88 | 89 | 90 | class WhitespaceMatcher(object): 91 | def match(self, token, value: str) -> bool: 92 | return value.strip() == '' 93 | 94 | def transform(self, value: str) -> List[Tab]: 95 | return [Tab()] * (len(value) // 4) 96 | 97 | 98 | class TabMatcher(object): 99 | def match(self, token, value: str) -> bool: 100 | return value == '\t' 101 | 102 | def transform(self, value: str) -> List[Tab]: 103 | return [Tab()] 104 | 105 | 106 | class NumberMatchers(object): 107 | def match(self, token, value: str) -> bool: 108 | return token in Token.Literal.Number 109 | 110 | def transform(self, value: str) -> List[Number]: 111 | if value == '0': 112 | return [Zero()] 113 | elif value == '1': 114 | return [One()] 115 | else: 116 | return [Number(value)] 117 | 118 | 119 | class OperatorMatcher(object): 120 | def match(self, token, value: str): 121 | return token is Token.Operator or token in Token.Punctuation 122 | 123 | def transform(self, value: str) -> List[Operator]: 124 | if value == ';': 125 | return [Semicolon()] 126 | elif value == '{': 127 | return [OpeningCurlyBracket()] 128 | elif value == '}': 129 | return [ClosingCurlyBracket()] 130 | elif value == '(': 131 | return [OpeningBracket()] 132 | elif value == ')': 133 | return [ClosingBracket()] 134 | else: 135 | return [Operator(value)] 136 | 137 | 138 | class WordOperatorMatcher(object): 139 | def 
match(self, token, value: str): 140 | return token is Token.Operator.Word 141 | 142 | def transform(self, value: str) -> List[ParsedToken]: 143 | return split_into_words(value) -------------------------------------------------------------------------------- /codeprep/parse/subtokens.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from typing import List 6 | 7 | import regex 8 | 9 | from codeprep.noneng import is_non_eng 10 | from codeprep.tokens.containers import SplitContainer 11 | from codeprep.tokens.noneng import NonEng 12 | from codeprep.tokens.numeric import Number 13 | from codeprep.tokens.rootclasses import ParsedToken 14 | from codeprep.tokens.whitespace import NewLine, Tab, SpaceInString 15 | from codeprep.tokens.word import Underscore, Word, NonCodeChar 16 | 17 | 18 | def split_identifier(token: str) -> SplitContainer: 19 | parts = [m[0] for m in 20 | regex.finditer('(_|[0-9]+|[[:upper:]]?[[:lower:]]+|[[:upper:]]+(?![[:lower:]])|[^ ])', token)] 21 | 22 | processable_tokens = [Word.from_(p) if p != '_' else Underscore() for p in parts] 23 | split_container = SplitContainer(processable_tokens) 24 | return NonEng(split_container) if is_non_eng(token) else split_container 25 | 26 | 27 | # Using the same regexps SLP team uses to parse numbers in java code 28 | # https://github.com/SLP-team/SLP-Core/blob/master/src/main/java/slp/core/lexing/code/JavaLexer.java 29 | 30 | HEX_REGEX = "0x([0-9a-fA-F]+_)*[0-9a-fA-F]+[lL]?" 31 | BIN_REGEX = "0b([01]+_)*[01]+[lL]?" 32 | IR_REGEX = "([0-9]+_)*[0-9]+[lLfFdD]?" 33 | DBL_REGEXA = "[0-9]+\\.[0-9]+([eE][-+]?[0-9]+)?[fFdD]?" 34 | DBL_REGEXB = "[0-9]+\\.([eE][-+]?[0-9]+)?[fFdD]?" 35 | DBL_REGEXC = "\\.[0-9]+([eE][-+]?[0-9]+)?[fFdD]?" 36 | DBL_REGEXD = "[0-9]+[eE][-+]?[0-9]+[fFdD]?" 
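# Illustrative literals accepted by each family above, mirroring the is_number
# doctests below: HEX_REGEX matches "0x8AbCc006EfBd" and "0x56Dl"; BIN_REGEX
# matches "0b10101" and "0b0011L"; IR_REGEX matches "23450012" and "342424242l";
# DBL_REGEXA matches "0.2e+3D"; DBL_REGEXB matches "353535." and "23424.E-30F";
# DBL_REGEXC matches ".353535f" and ".002e-0f"; DBL_REGEXD matches exponent-only
# forms such as "2e8" (a made-up example consistent with the pattern).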
37 | 38 | NUMBER_PATTERN = f'({HEX_REGEX}|{BIN_REGEX}|{IR_REGEX}|{DBL_REGEXA}|{DBL_REGEXB}|{DBL_REGEXC}|{DBL_REGEXD})' 39 | 40 | 41 | def is_number(word: str) -> bool: 42 | """ 43 | >>> is_number("0") 44 | True 45 | 46 | >>> is_number("8") 47 | True 48 | 49 | >>> is_number("-5") 50 | False 51 | 52 | >>> is_number("23450012") 53 | True 54 | 55 | >>> is_number("283463L") 56 | True 57 | 58 | >>> is_number("342424242l") 59 | True 60 | 61 | >>> is_number("0.") 62 | True 63 | 64 | >>> is_number(".0") 65 | True 66 | 67 | >>> is_number(".0d") 68 | True 69 | 70 | >>> is_number("353535.") 71 | True 72 | 73 | >>> is_number("353535.D") 74 | True 75 | 76 | >>> is_number(".353535F") 77 | True 78 | 79 | >>> is_number(".353535f") 80 | True 81 | 82 | >>> is_number("0.2e+3D") 83 | True 84 | 85 | >>> is_number("23424.E-30F") 86 | True 87 | 88 | >>> is_number(".002e-0f") 89 | True 90 | 91 | >>> is_number("0b10101") 92 | True 93 | 94 | >>> is_number("0b0011L") # java -- not python 95 | True 96 | 97 | >>> is_number("0b0") 98 | True 99 | 100 | >>> is_number("0x8AbCc006EfBd") 101 | True 102 | 103 | >>> is_number("0xG12") 104 | False 105 | 106 | >>> is_number("0x56DL") 107 | True 108 | 109 | >>> is_number("0x56Dl") 110 | True 111 | """ 112 | return regex.fullmatch(NUMBER_PATTERN, word) is not None 113 | 114 | 115 | def to_parsed_token(token: str) -> ParsedToken: 116 | if token == '\n': 117 | return NewLine() 118 | elif token == '\t': 119 | return Tab() 120 | elif is_number(token): 121 | return Number(token) 122 | elif regex.fullmatch("\\w+", token): 123 | return split_identifier(token) 124 | else: 125 | return NonCodeChar(token) 126 | 127 | 128 | def split_string(token: str) -> List[ParsedToken]: 129 | """ 130 | >>> split_string(" var = 9.4\\t\\n") 131 | [ (n_chars=4), SplitContainer[Word(('var', none))], \ 132 | (n_chars=1), NonCodeChar(=), (n_chars=1), (9), \ 133 | NonCodeChar(.), (4), , ] 134 | """ 135 | res = [] 136 | arbitrary_whitespace = "( )+" 137 | for m in regex.finditer(f"(\\w+|[^ ]|{arbitrary_whitespace})", token): 138 | if regex.fullmatch(arbitrary_whitespace, m[0]): 139 | res.append(SpaceInString(n_chars=len(m[0]))) 140 | else: 141 | res.append(to_parsed_token(m[0])) 142 | return res 143 | 144 | 145 | def split_into_words(token: str) -> List[ParsedToken]: 146 | """ 147 | >>> split_into_words(" var = 9.4\\t\\n") 148 | [, SplitContainer[Word(('var', none))], NonCodeChar(=), (9), \ 149 | NonCodeChar(.), (4), , ] 150 | """ 151 | res = [] 152 | four_char_whitespace = " " * 4 153 | for m in regex.finditer(f"(\\w+|[^ ]|{four_char_whitespace})", token): 154 | if m[0] == four_char_whitespace: 155 | res.append(Tab()) 156 | else: 157 | res.append(to_parsed_token(m[0])) 158 | return res -------------------------------------------------------------------------------- /codeprep/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /codeprep/pipeline/bpelearner.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import logging 6 | import os 7 | from typing import Tuple, Dict, Set, Optional 8 | 9 | from codeprep.bpepkg.bpe_config import BpeConfig, BpeParam, BpeConfigNotSupported 10 | from codeprep.bpepkg.bpe_encode import escape 11 | from codeprep.bpepkg.bpe_learn 
import separate_vocabs, logger, do_merges, create_resulting_vocab, create_bpe_cache 12 | from codeprep.bpepkg.cache import dump_bpe_cache 13 | from codeprep.bpepkg.merge import MergeList, read_merges, dump_merges 14 | from codeprep.pipeline import stages 15 | from codeprep.pipeline.bperegistry import get_max_merges, MERGES_FILE_NAME, MERGES_CACHE_FILE_NAME, \ 16 | RESULTING_VOCAB_FILE_NAME, BPE_REASSEMBLED_VOCAB_FILE_NAME 17 | from codeprep.pipeline.dataset import Dataset 18 | from codeprep.pipeline.vocab import _dump_vocab_dict, _load_vocab_dict 19 | from codeprep.util import to_non_literal_str 20 | 21 | 22 | def get_base_vocab(dataset: Dataset) -> Tuple[Dict[str, int], Dict[str, int]]: 23 | stages.run_until_base_bpe_vocab(dataset) 24 | all_vocab = _load_vocab_dict(dataset.path_to_bpe_vocab_file) 25 | non_bpe_vocab = load_nonbpe_vocab(dataset) 26 | return separate_vocabs(all_vocab, non_bpe_vocab) 27 | 28 | 29 | def load_nonbpe_vocab(dataset: Dataset) -> Set[str]: 30 | non_bpe_vocab = set() 31 | with open(dataset.path_to_nonbpe_vocab_file, 'r') as f: 32 | for line in f: 33 | non_bpe_vocab.add(to_non_literal_str(line.rstrip('\n'))) 34 | return non_bpe_vocab 35 | 36 | 37 | def check_if_bpe_config_supported(bpe_config: BpeConfig): 38 | if bpe_config.get_param_value(BpeParam.UNICODE) == 'bytes': 39 | raise BpeConfigNotSupported('Byte-BPE is not yet supported') 40 | 41 | if bpe_config.get_param_value(BpeParam.WORD_END): 42 | raise BpeConfigNotSupported('BPE with word-end characters are not yet supported') 43 | 44 | if bpe_config.get_param_value(BpeParam.CASE) == 'prefix': 45 | raise BpeConfigNotSupported('BPE with case encoded in prefix is not yet supported') 46 | 47 | 48 | def prepare_vocabs(dataset: Dataset, dir_with_most_merges, starting_from_scratch): 49 | if starting_from_scratch: 50 | base_bpe_vocab, other_vocab = get_base_vocab(dataset) # TODO extract this into stages 51 | other_vocab = {escape(k, merged=True): v for k, v in other_vocab.items()} 52 | split_base_vocab = {escape(" ".join(k)): v for k, v in base_bpe_vocab.items()} 53 | else: 54 | path_to_bpe_vocab_file = os.path.join(dir_with_most_merges, BPE_REASSEMBLED_VOCAB_FILE_NAME) 55 | non_bpe_vocab = {escape(k, merged=True) for k in load_nonbpe_vocab(dataset)} 56 | split_base_vocab = _load_vocab_dict(path_to_bpe_vocab_file) 57 | split_base_vocab, other_vocab = separate_vocabs(split_base_vocab, non_bpe_vocab) 58 | 59 | return split_base_vocab, other_vocab 60 | 61 | 62 | def get_dir_with_most_merges(dataset_bpe_path, n_merges) -> Optional[str]: 63 | max_merges = get_max_merges(dataset_bpe_path, n_merges) 64 | if not max_merges: 65 | return None 66 | 67 | dir_with_most_merges = os.path.join(dataset_bpe_path, str(max_merges)) 68 | return dir_with_most_merges 69 | 70 | 71 | def save_results(split_base_vocab, merges, new_bpe_dir): 72 | 73 | os.makedirs(new_bpe_dir) 74 | 75 | resulting_vocab = create_resulting_vocab(split_base_vocab) 76 | resulting_vocab_sorted = sorted(resulting_vocab.items(), key=lambda x: x[1], reverse=True) 77 | _dump_vocab_dict(resulting_vocab_sorted, os.path.join(new_bpe_dir, RESULTING_VOCAB_FILE_NAME)) 78 | 79 | bpe_cache = create_bpe_cache(split_base_vocab) 80 | dump_bpe_cache(bpe_cache, os.path.join(new_bpe_dir, MERGES_CACHE_FILE_NAME)) 81 | 82 | dump_merges(merges, os.path.join(new_bpe_dir, MERGES_FILE_NAME)) 83 | _dump_vocab_dict(split_base_vocab.items(), os.path.join(new_bpe_dir, BPE_REASSEMBLED_VOCAB_FILE_NAME)) 84 | logger.info(f'Bpe output files are saved into {new_bpe_dir} folder') 85 | 86 | 87 | def 
run(dataset: Dataset, n_merges: int, bpe_config: BpeConfig) -> None:
88 | 
89 |     check_if_bpe_config_supported(bpe_config)
90 |     dataset_bpe_path = dataset.bpe_path
91 | 
92 |     dir_with_most_merges = get_dir_with_most_merges(dataset_bpe_path, n_merges)
93 | 
94 |     if dir_with_most_merges:
95 |         logger.info("Using existing merges...")
96 |         already_done_merges = read_merges(os.path.join(dir_with_most_merges, MERGES_FILE_NAME))
97 |     else:
98 |         logger.info("Starting encoding from scratch...")
99 |         already_done_merges = MergeList()
100 | 
101 |     split_base_vocab, other_vocab = prepare_vocabs(dataset, dir_with_most_merges,
102 |                                                    starting_from_scratch=not dir_with_most_merges)
103 | 
104 |     logger.info("Learning bpe codes...")
105 |     split_base_vocab, merges = do_merges(split_base_vocab, n_merges - len(already_done_merges))
106 |     for k, v in other_vocab.items():
107 |         split_base_vocab[k] = v
108 |     merges = already_done_merges + merges
109 | 
110 |     new_bpe_dir = os.path.join(dataset_bpe_path, str(len(merges)))
111 |     if os.path.exists(new_bpe_dir):
112 |         logger.info("Merges already learned!")
113 |         return
114 | 
115 |     save_results(split_base_vocab, merges, new_bpe_dir)
--------------------------------------------------------------------------------
/codeprep/pipeline/parse_projects.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | 
5 | import gzip
6 | import logging
7 | import os
8 | import pickle
9 | from multiprocessing.pool import Pool
10 | from typing import Tuple
11 | 
12 | from tqdm import tqdm
13 | 
14 | from codeprep.config import REWRITE_PARSED_FILE, CHUNKSIZE, LIMIT_FILES_SCANNING
15 | from codeprep.fileutils import read_file_contents
16 | from codeprep.pipeline.dataset import Dataset, NOT_FINISHED_EXTENSION
17 | from codeprep.parse.core import convert_text
18 | 
19 | logger = logging.getLogger(__name__)
20 | 
21 | 
22 | def preprocess_and_write(params: Tuple[bytes, bytes]) -> None:
23 |     src_file_path, dest_file_path = params
24 | 
25 |     dest_dirname = os.path.dirname(dest_file_path)
26 |     if not os.path.exists(dest_dirname):
27 |         os.makedirs(dest_dirname, exist_ok=True)
28 | 
29 |     if not REWRITE_PARSED_FILE and os.path.exists(dest_file_path):
30 |         logger.warning(f"File {dest_file_path} already exists! Doing nothing.")
31 |         return
32 | 
33 |     # Read the source file before creating the output file, so that a failed read
34 |     # does not leave an empty not-finished file behind.
35 |     try:
36 |         lines_from_file, path = read_file_contents(src_file_path)
37 |     except FileNotFoundError:
38 |         logger.error(f"File was found when scanning the directory, but cannot be read: {src_file_path}. "
39 |                      f"Invalid symlink? Ignoring...")
40 |         return
41 | 
42 |     not_finished_dest_file_path = dest_file_path + NOT_FINISHED_EXTENSION.encode()
43 |     with gzip.GzipFile(not_finished_dest_file_path, 'wb') as f:
44 |         extension = os.path.splitext(src_file_path)[1].decode()[1:]
45 |         parsed = [p for p in convert_text("\n".join(lines_from_file), extension)]
46 |         pickle.dump(parsed, f, pickle.HIGHEST_PROTOCOL)
47 | 
48 |     os.rename(not_finished_dest_file_path, dest_file_path)
49 | 
50 | 
51 | def params_generator(dataset: Dataset):
52 |     for input_file_path in dataset.original.file_iterator():
53 |         output_file_path = dataset.original.get_new_file_name(input_file_path, dataset.parsed)
54 |         yield (input_file_path, output_file_path)
55 | 
56 | 
57 | def run(dataset: Dataset) -> None:
58 |     logger.info(f"Getting files from {dataset.original.path}")
59 |     logger.info(f"Writing parsed files to {dataset.parsed.path}")
60 | 
61 |     if dataset.files_need_to_be_saved():
62 |         files_total = 0
63 |         for _ in dataset.get_all_files():
64 |             files_total += 1
65 |             print(f"Files scanned: {files_total}", end='\r')
66 |             if files_total > LIMIT_FILES_SCANNING:
67 |                 files_total = None
68 |                 logger.info(f"Total files to be preprocessed: {LIMIT_FILES_SCANNING}+")
69 |                 break
70 |     else:
71 |         files_total = len([f for f in dataset.get_all_files()])
72 |     with Pool() as pool:
73 |         it = pool.imap_unordered(preprocess_and_write, params_generator(dataset), chunksize=CHUNKSIZE)
74 |         for _ in tqdm(it, total=files_total):
75 |             pass
76 |     if not dataset.suppress_caching:
77 |         dataset.parsed.set_ready()
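# How this module is driven: codeprep.pipeline.stages.run_parsing (defined in
# stages.py below) calls parse_projects.run(dataset) only when the parsed cache
# is missing or outdated. run() then fans preprocess_and_write() out over a
# process pool, and each worker writes its pickle to a *.not_finished file that
# is renamed only after a successful dump, so an interrupted run never leaves a
# truncated result that looks complete.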
--------------------------------------------------------------------------------
/codeprep/pipeline/stages.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | 
5 | """This module runs the different stages of the preprocessing flow and makes sure not to rerun a stage if its results are already available.
6 | """
7 | import logging
8 | import os
9 | from typing import Optional
10 | 
11 | from codeprep.pipeline import parse_projects, to_repr
12 | from codeprep.pipeline.bperegistry import CustomBpeConfig
13 | from codeprep.pipeline.dataset import Dataset, is_path_ready, is_path_outdated, archive_path
14 | from codeprep.pipeline.vocab import calc_vocab
15 | 
16 | logger = logging.getLogger(__name__)
17 | 
18 | # TODO remove code duplication in the methods below
19 | def run_parsing(dataset: Dataset) -> None:
20 |     logger.info("Parsing...")
21 |     if not dataset.parsed.ready():
22 |         parse_projects.run(dataset)
23 |     elif dataset.parsed.is_outdated():
24 |         dataset.parsed.archive()
25 |         parse_projects.run(dataset)
26 |     else:
27 |         logger.info("Parsed dataset is up-to-date.")
28 | 
29 | 
30 | def run_until_preprocessing(dataset: Dataset, custom_bpe_config: Optional[CustomBpeConfig]=None) -> None:
31 |     run_parsing(dataset)
32 |     logger.info("Preprocessing...")
33 |     if not dataset.preprocessed.ready():
34 |         to_repr.run(dataset, custom_bpe_config)
35 |     elif dataset.preprocessed.is_outdated():
36 |         dataset.preprocessed.archive()
37 |         to_repr.run(dataset, custom_bpe_config)
38 |     else:
39 |         logger.info("Dataset is already preprocessed and up-to-date.")
40 | 
41 | 
42 | def run_until_base_bpe_vocab(dataset: Dataset, custom_bpe_config: Optional[CustomBpeConfig]=None) -> None:
43 |     run_until_preprocessing(dataset, custom_bpe_config)
44 |     logger.info("Computing base bpe vocab...")
45 |     if not is_path_ready(dataset.path_to_bpe_vocab_file):
46 |         calc_vocab(dataset.preprocessed.path, dataset.preprocessed.file_iterator(), dataset.base_bpe_vocab_path)
47 |     elif is_path_outdated(dataset.path_to_bpe_vocab_file):
48 |         archive_path(dataset.path_to_bpe_vocab_file)
49 |         calc_vocab(dataset.preprocessed.path, dataset.preprocessed.file_iterator(), dataset.base_bpe_vocab_path)
50 |     else:
51 |         logger.info("Vocabulary is already computed and up-to-date")
52 | 
53 | 
54 | def run_until_vocab(dataset: Dataset, custom_bpe_config: Optional[CustomBpeConfig]=None) -> None:
55 |     logger.info(f'Checking first whether the vocabulary file exists: {dataset.path_to_vocab_file}')
56 |     if os.path.exists(dataset.path_to_vocab_file):
57 |         logger.info("Vocabulary is already computed and up-to-date")
58 |         return
59 | 
60 |     if not is_path_ready(dataset.path_to_vocab_file):
61 |         run_until_preprocessing(dataset, custom_bpe_config)
62 |         logger.info("Computing vocab...")
63 |         calc_vocab(dataset.preprocessed.path, dataset.preprocessed.file_iterator(), dataset.vocab_path)
64 |     elif is_path_outdated(dataset.path_to_vocab_file):
65 |         run_until_preprocessing(dataset, custom_bpe_config)
66 |         logger.info("Computing vocab...")
67 |         archive_path(dataset.path_to_vocab_file)
68 |         calc_vocab(dataset.preprocessed.path, dataset.preprocessed.file_iterator(), dataset.vocab_path)
69 |     else:
70 |         raise AssertionError()
--------------------------------------------------------------------------------
/codeprep/pipeline/to_repr.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | 
5 | import gzip
6 | import logging
7 | import os
8 | import pickle
9 | import platform
10 | from multiprocessing.pool import Pool
11 | from typing import List, Tuple
12 | from typing import Optional
13 | 
14 | import time
15 | from tqdm import tqdm
16 | 
17 | from codeprep.bpepkg.bpe_encode import read_merges, BpeData
18 | from codeprep.bpepkg.cache import read_bpe_cache
19 | from
codeprep.config import DEFAULT_BPE_DIR, NO_CASE_DIR, CASE_DIR, DEFAULT_BPE_CACHE_DIR, REWRITE_PREPROCESSED_FILE, \ 20 | CHUNKSIZE, LIMIT_FILES_SCANNING 21 | from codeprep.pipeline import vocabloader 22 | from codeprep.pipeline.bperegistry import CustomBpeConfig 23 | from codeprep.pipeline.dataset import Dataset, NOT_FINISHED_EXTENSION 24 | from codeprep.prepconfig import PrepParam, PrepConfig 25 | from codeprep.preprocess.core import to_repr_list 26 | from codeprep.preprocess.metadata import PreprocessingMetadata 27 | from codeprep.preprocess.metadata import save_metadata 28 | from codeprep.preprocess.placeholders import placeholders 29 | from codeprep.tokens.rootclasses import ParsedToken 30 | from codeprep.tokens.word import SpecialToken 31 | from codeprep.util import to_literal_str 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | def get_global_bpe_data_if_available() -> Optional[BpeData]: 37 | return global_bpe_data if 'global_bpe_data' in globals() else None 38 | 39 | 40 | def insert_and_word_tokens(prep_list: List[str], metadata: PreprocessingMetadata) -> List[str]: 41 | list_copy = [elm for elm in prep_list] 42 | for index in metadata.word_boundaries[1:]: 43 | list_copy[index-1] += placeholders['compound_word_end'] 44 | return list_copy 45 | 46 | 47 | def to_repr(prep_config: PrepConfig, token_list: List[ParsedToken], 48 | bpe_data: Optional[BpeData] = None) -> Tuple[List[str], PreprocessingMetadata]: 49 | bpe_data = bpe_data or get_global_bpe_data_if_available() 50 | repr_list, metadata = to_repr_list(token_list, prep_config.get_repr_config(bpe_data)) 51 | if prep_config.is_bpe(): 52 | repr_list = insert_and_word_tokens(repr_list, metadata) 53 | return repr_list, metadata 54 | 55 | 56 | def to_token_str(tokens: List) -> str: 57 | return " ".join(map(lambda t: str(t), tokens)) 58 | 59 | 60 | def preprocess_and_write(params: Tuple[bytes, bytes, PrepConfig, str], bpe_data: Optional[BpeData] = None): 61 | src_file_path, dest_file_path, prep_config, part_nonbpe_vocab_folder = params 62 | 63 | dest_dirname = os.path.dirname(dest_file_path) 64 | if not os.path.exists(dest_dirname): 65 | os.makedirs(dest_dirname, exist_ok=True) 66 | 67 | if not REWRITE_PREPROCESSED_FILE and os.path.exists(dest_file_path): 68 | logger.warning(f"File {dest_file_path} already exists! 
Doing nothing.") 69 | return 70 | 71 | not_finished_dest_file_path = dest_file_path + NOT_FINISHED_EXTENSION.encode() 72 | with gzip.GzipFile(src_file_path, 'rb') as i, open(not_finished_dest_file_path, 'w') as o: 73 | token_list = pickle.load(i) 74 | bpe_data = get_global_bpe_data_if_available() if bpe_data is None else bpe_data 75 | repr, metadata = to_repr(prep_config, token_list + [SpecialToken(placeholders['ect'])], bpe_data) 76 | o.write(to_literal_str(to_token_str(repr)) + '\n') 77 | 78 | if part_nonbpe_vocab_folder: 79 | save_metadata(metadata, os.path.join(part_nonbpe_vocab_folder, f'{os.path.basename(dest_file_path)}_-_{time.time()}')) 80 | 81 | os.rename(not_finished_dest_file_path, dest_file_path) 82 | 83 | #TODO make this method independent of actual directory structure 84 | def init_bpe_data(prep_config: PrepConfig, custom_bpe_config: Optional[CustomBpeConfig], force_reinit: bool=True): 85 | if get_global_bpe_data_if_available() and not force_reinit: 86 | return # already initialized 87 | global global_bpe_data 88 | global_bpe_data = BpeData() 89 | if custom_bpe_config: 90 | logger.info(f'Using bpe merges file: {custom_bpe_config.codes_file}') 91 | if custom_bpe_config.can_use_cache_file(): 92 | global_bpe_data.merges_cache = read_bpe_cache(custom_bpe_config.cache_file) 93 | else: 94 | global_bpe_data.merges_cache = {} 95 | global_bpe_data.merges = read_merges(custom_bpe_config.codes_file, custom_bpe_config.n_merges) 96 | 97 | if custom_bpe_config.n_merges: 98 | logger.info(f'Using first {custom_bpe_config.n_merges} merges.') 99 | nonbpe_vocab = vocabloader.nonbpe(custom_bpe_config.merge_list_id) 100 | global_bpe_data.merges_cache.update({s: [s] for s in nonbpe_vocab}) 101 | else: 102 | bpe_n_merges_dict = {'4': '5k', '5': '1k', '6': '10k', '7': '20k', '8': '0'} 103 | bpe_n_merges = bpe_n_merges_dict[prep_config.get_param_value(PrepParam.SPLIT)] 104 | 105 | bpe_merges_file = os.path.join(DEFAULT_BPE_DIR, 106 | CASE_DIR if prep_config.get_param_value(PrepParam.CASE) == 'u' else NO_CASE_DIR, 107 | str(bpe_n_merges), 'merges.txt') 108 | bpe_merges_cache_file = os.path.join(DEFAULT_BPE_CACHE_DIR, 109 | CASE_DIR if prep_config.get_param_value(PrepParam.CASE) == 'u' else NO_CASE_DIR, 110 | str(bpe_n_merges), 'merges_cache.txt') 111 | if os.path.exists(bpe_merges_cache_file): 112 | global_bpe_data.merges_cache = read_bpe_cache(bpe_merges_cache_file) 113 | else: 114 | global_bpe_data.merges_cache = {} 115 | global_bpe_data.merges = read_merges(bpe_merges_file) 116 | 117 | 118 | def params_generator(dataset: Dataset, path_to_part_metadata: Optional[str]): 119 | for input_file_path in dataset.parsed.file_iterator(): 120 | output_file_path = dataset.parsed.get_new_file_name(input_file_path, dataset.preprocessed) 121 | yield (input_file_path, output_file_path, dataset.prep_config, path_to_part_metadata) 122 | 123 | 124 | def get_n_cpus_to_be_used(): 125 | system_platform = platform.system() 126 | n_cpus = 1 if system_platform in ['Windows', 'Darwin'] else os.cpu_count() or 1 127 | logger.info(f"Platform: {system_platform}, n cores to be used: {n_cpus}") 128 | return n_cpus 129 | 130 | 131 | def run(dataset: Dataset, custom_bpe_config: Optional[CustomBpeConfig]) -> None: 132 | path_to_parsed_dataset = dataset.parsed.path 133 | 134 | if not os.path.exists(path_to_parsed_dataset): 135 | logger.error(f"Dir does not exist: {path_to_parsed_dataset}") 136 | exit(3) 137 | logger.info(f"Reading parsed files from: {path_to_parsed_dataset}") 138 | 139 | if dataset.prep_config.is_bpe(): 140 | 
init_bpe_data(dataset.prep_config, custom_bpe_config) 141 | 142 | if not os.path.exists(dataset.path_to_nonbpe_vocab_file) and dataset.prep_config.is_base_bpe_config(): 143 | path_to_part_metadata = f'{dataset.path_to_nonbpe_vocab_file}_part' 144 | else: 145 | path_to_part_metadata = None 146 | if path_to_part_metadata and not os.path.exists(path_to_part_metadata): 147 | os.makedirs(path_to_part_metadata) 148 | 149 | logger.info(f"Writing preprocessed files to {dataset.preprocessed.path}") 150 | 151 | if dataset.files_need_to_be_saved(): 152 | files_total = 0 153 | for _ in dataset.get_all_files(): 154 | files_total += 1 155 | print(f"Files scanned: {files_total}", end='\r') 156 | if files_total > LIMIT_FILES_SCANNING: 157 | files_total = None 158 | logger.info(f"Total files to be preprocessed: {LIMIT_FILES_SCANNING}+") 159 | break 160 | else: 161 | files_total = len([f for f in dataset.get_all_files()]) 162 | n_cpus = get_n_cpus_to_be_used() 163 | if n_cpus > 1: 164 | with Pool(processes=n_cpus) as pool: 165 | it = pool.imap_unordered(preprocess_and_write, params_generator(dataset, path_to_part_metadata), chunksize=CHUNKSIZE) 166 | for _ in tqdm(it, total=files_total): 167 | pass 168 | else: 169 | for params in tqdm(params_generator(dataset, path_to_part_metadata), total=files_total): 170 | preprocess_and_write(params, get_global_bpe_data_if_available()) 171 | 172 | if path_to_part_metadata: 173 | vocabloader.gather_non_bpe_vocab(dataset) 174 | 175 | dataset.preprocessed.set_ready() -------------------------------------------------------------------------------- /codeprep/pipeline/vocabloader.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import logging 6 | import os 7 | import shutil 8 | from typing import Set, Dict 9 | 10 | from codeprep.pipeline.bperegistry import get_bpe_dir, get_base_vocab_dir, RESULTING_VOCAB_FILE_NAME 11 | from codeprep.pipeline.dataset import Dataset, NONBPE_VOCAB_FILENAME 12 | from codeprep.preprocess.placeholders import placeholders 13 | from codeprep.util import to_literal_str 14 | from codeprep.pipeline.vocab import _load_vocab_dict, _load_vocab_set, VOCAB_FILENAME 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | def all(merge_list_id: str) -> Dict[str, int]: 19 | bpe_dir = get_base_vocab_dir(merge_list_id) 20 | return _load_vocab_dict(os.path.join(bpe_dir, VOCAB_FILENAME)) 21 | 22 | 23 | def nonbpe(merge_list_id: str) -> Set[str]: 24 | bpe_dir = get_base_vocab_dir(merge_list_id) 25 | return _load_vocab_set(os.path.join(bpe_dir, NONBPE_VOCAB_FILENAME)) 26 | 27 | 28 | def base(merge_list_id: str) -> Dict[str, int]: 29 | all_vocab = all(merge_list_id) 30 | nonbpe_vocab = nonbpe(merge_list_id) 31 | for token in nonbpe_vocab: 32 | if token in all_vocab: 33 | del all_vocab[token] 34 | return all_vocab 35 | 36 | 37 | def bpe(merge_list_id: str, n_merges: int) -> Dict[str, int]: 38 | bpe_dir = get_bpe_dir(merge_list_id, n_merges) 39 | return _load_vocab_dict(os.path.join(bpe_dir, RESULTING_VOCAB_FILE_NAME)) 40 | 41 | 42 | def gather_non_bpe_vocab(dataset: Dataset): 43 | logger.info("Gathering non-bpe vocab...") 44 | part_nonbpe_vocab_dir = f'{dataset.path_to_nonbpe_vocab_file}_part' 45 | non_bpe_tokens: Set[str] = set() 46 | for idx, file in enumerate(os.listdir(part_nonbpe_vocab_dir)): 47 | if idx % 569 == 0: 48 | print(f'Files processed: {idx}', end='\r') 49 | 
non_bpe_tokens.update(_load_vocab_set(os.path.join(part_nonbpe_vocab_dir, file))) 50 | 51 | non_bpe_tokens.update(list(placeholders.values())) 52 | with open(dataset.path_to_nonbpe_vocab_file, 'w') as f: 53 | for token in non_bpe_tokens: 54 | f.write(f'{to_literal_str(token)}\n') 55 | shutil.rmtree(part_nonbpe_vocab_dir) -------------------------------------------------------------------------------- /codeprep/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /codeprep/preprocess/core.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from typing import Tuple, List, Sequence 6 | 7 | from codeprep.preprocess.metadata import PreprocessingMetadata 8 | from codeprep.preprocess.reprconfig import ReprConfig 9 | from codeprep.tokens.rootclasses import ParsedToken 10 | 11 | 12 | def to_repr_list(token_list: Sequence[ParsedToken], repr_config: ReprConfig) \ 13 | -> Tuple[List[str], PreprocessingMetadata]: 14 | repr_res = [] 15 | all_metadata = PreprocessingMetadata() 16 | for token in token_list: 17 | repr_token, metadata = torepr(token, repr_config) 18 | repr_res.extend(repr_token) 19 | all_metadata.update(metadata) 20 | return repr_res, all_metadata 21 | 22 | 23 | def torepr(token, repr_config) -> Tuple[List[str], PreprocessingMetadata]: 24 | clazz = type(token) 25 | if clazz == str: 26 | raise AssertionError('Strings are not allowed any more as a result of parsing') 27 | if clazz == list: 28 | return to_repr_list(token, repr_config) 29 | if repr_config and clazz in repr_config.types_to_be_repr: 30 | return token.preprocessed_repr(repr_config) 31 | else: 32 | non_prep, metadata = token.non_preprocessed_repr(repr_config) 33 | return (non_prep if isinstance(non_prep, list) else [non_prep]), metadata -------------------------------------------------------------------------------- /codeprep/preprocess/metadata.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import logging 6 | from typing import Set, Optional, List, Type, Tuple 7 | 8 | from codeprep.subtokens import is_terminal_subtoken 9 | from codeprep.util import to_literal_str 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class InvalidMetadataError(Exception): 15 | pass 16 | 17 | 18 | class PreprocessingMetadata(object): 19 | def __init__(self, 20 | nonprocessable_tokens: Optional[Set[str]] = None, 21 | word_boundaries: Optional[List[int]] = None, 22 | token_types: List[Type] = None): 23 | self.nonprocessable_tokens = nonprocessable_tokens or set() 24 | self.word_boundaries = word_boundaries or [0] 25 | self.token_types = token_types or [] 26 | 27 | self._check_invariants() 28 | 29 | def _check_invariants(self) -> None: 30 | assert len(self.word_boundaries) - 1 == len(self.token_types) 31 | 32 | def set_all_tokens_type(self, t: Type) -> None: 33 | self.token_types = [t] * (len(self.word_boundaries) -1) 34 | 35 | def update(self, preprocessing_metadata: 'PreprocessingMetadata') -> 'PreprocessingMetadata': 36 | """ 37 | >>> class TypeA: pass 38 | >>> class TypeB: pass 39 | >>> PreprocessingMetadata().update(PreprocessingMetadata()) 40 | 
(set(), [0], [])
41 | 
42 |         >>> PreprocessingMetadata({''}, [0, 2], [TypeA]).update(PreprocessingMetadata({''}, [0, 1, 2, 3], [TypeA, TypeA, TypeB]))
43 |         ({''}, [0, 2, 3, 4, 5], ['TypeA', 'TypeA', 'TypeA', 'TypeB'])
44 | 
45 |         >>> PreprocessingMetadata(set(), [0, 2], [TypeA]).update(PreprocessingMetadata(set(), [0, 3], [TypeB]))
46 |         (set(), [0, 2, 5], ['TypeA', 'TypeB'])
47 |         """
48 |         self.nonprocessable_tokens.update(preprocessing_metadata.nonprocessable_tokens)
49 | 
50 |         n_subtokens = self.word_boundaries.pop()
51 |         for boundary in preprocessing_metadata.word_boundaries:
52 |             self.word_boundaries.append(n_subtokens + boundary)
53 | 
54 |         self.token_types.extend(preprocessing_metadata.token_types)
55 | 
56 |         return self
57 | 
58 |     def __repr__(self):
59 |         return str((self.nonprocessable_tokens, self.word_boundaries, list(map(lambda x: x.__name__, self.token_types))))
60 | 
61 |     def __eq__(self, other):
62 |         return self.__class__ == other.__class__ \
63 |                and self.nonprocessable_tokens == other.nonprocessable_tokens \
64 |                and self.word_boundaries == other.word_boundaries \
65 |                and self.token_types == other.token_types
66 | 
67 | 
68 | def save_metadata(metadata: PreprocessingMetadata, save_to: bytes) -> None:
69 |     with open(save_to, 'w') as f:
70 |         for token in metadata.nonprocessable_tokens:
71 |             f.write(f'{to_literal_str(token)}\n')
72 | 
73 | 
74 | def check_metadata_validity(subwords: List[str], metadata: PreprocessingMetadata, use_only_token_end_chars=True) -> None:
75 |     word_boundaries = metadata.word_boundaries
76 |     if len(word_boundaries) == 0:
77 |         raise ValueError("Word boundaries list should contain at least the initial boundary 0!")
78 |     if len(subwords) != word_boundaries[-1]:
79 |         raise ValueError(f"The last word boundary should equal the number of subwords.\n"
80 |                          f"However, the subwords list has {len(subwords)} element(s), and "
81 |                          f"value {len(subwords)} is not found in word boundaries list: {word_boundaries}")
82 |     if word_boundaries[0] != 0:
83 |         raise ValueError('Word boundaries list must start with 0!')
84 | 
85 |     if use_only_token_end_chars:
86 |         for idx, token in enumerate(subwords):
87 |             end_according_to_data = is_terminal_subtoken(token)
88 |             end_according_to_metadata = (idx + 1) in metadata.word_boundaries
89 |             if end_according_to_data != end_according_to_metadata:
90 |                 error_context_start_index = idx - 20 if idx - 20 > 0 else 0
91 |                 error_context_end_index = idx + 20 if idx + 20 < len(subwords) else len(subwords) - 1
92 |                 raise AssertionError(f'Token {token} according to metadata is'
93 |                                      f'{"" if end_according_to_metadata else " NOT"} end-token. 
' 94 | f'Showing context: {subwords[error_context_start_index:error_context_end_index]}') 95 | 96 | 97 | def with_empty_metadata(tokens: List[str]) -> Tuple[List[str], PreprocessingMetadata]: 98 | return tokens, PreprocessingMetadata() 99 | 100 | 101 | def unwrap_single_string(tokens_and_metadata: Tuple[List[str], PreprocessingMetadata]) -> str: 102 | tokens = tokens_and_metadata[0] 103 | if isinstance(tokens, list) and len(tokens) == 1: 104 | return tokens[0] -------------------------------------------------------------------------------- /codeprep/preprocess/placeholders.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | placeholders = { 6 | 'comment': '', 7 | 'string_literal': '', 8 | 'word_start': '', 9 | 'word_end': '', 10 | 'capital': '', 11 | 'capitals': '', 12 | 'ect': '', 13 | 'non_eng': '', 14 | 'non_ascii_seq': '\xf7', 15 | 'non_eng_content': '', 16 | 'olc_end': '', 17 | 'compound_word_end': '', 18 | 'space_in_str': '\xa0' 19 | } -------------------------------------------------------------------------------- /codeprep/preprocess/reprconfig.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from typing import Optional, Callable, List 6 | 7 | from codeprep.bpepkg.bpe_encode import BpeData 8 | 9 | 10 | Splitter = Callable[[str, BpeData], List[str]] 11 | 12 | 13 | class ReprConfig(object): 14 | def __init__(self, types_to_be_repr, 15 | bpe_data: Optional[BpeData], 16 | should_lowercase: bool, 17 | number_splitter: Splitter, 18 | word_splitter: Optional[Splitter], 19 | full_strings: bool, 20 | max_str_length: int): 21 | self.types_to_be_repr = types_to_be_repr 22 | self.bpe_data = bpe_data 23 | self.should_lowercase = should_lowercase 24 | self.number_splitter = number_splitter 25 | self.word_splitter = word_splitter 26 | self.full_strings = full_strings 27 | self.max_str_length = max_str_length -------------------------------------------------------------------------------- /codeprep/stemming.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from nltk.stem import PorterStemmer 6 | 7 | stemmer = PorterStemmer() 8 | 9 | 10 | def stem(word: str): 11 | if not word: 12 | return word 13 | 14 | stemmed = stemmer.stem(word) 15 | if word.isupper(): 16 | return stemmed.upper() 17 | elif word[0].isupper(): 18 | return stemmed.capitalize() 19 | else: 20 | return stemmed -------------------------------------------------------------------------------- /codeprep/subtokens.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from typing import List, Callable, Any 6 | 7 | from codeprep.preprocess.placeholders import placeholders 8 | 9 | 10 | class TokenIterator(object): 11 | def __init__(self, subwords, word_boundaries, format, return_full_token_index): 12 | self.validate_word_boundaries(subwords, word_boundaries) 13 | 14 | self.subwords = subwords 15 | self.word_boundaries = word_boundaries 16 | self.format = format 17 | self.return_full_token_index = return_full_token_index 18 | 19 | def __iter__(self): 20 | return self 21 | 22 | @staticmethod 23 | def 
validate_word_boundaries(subwords: List[str], word_boundaries: List[int]) -> None:
24 |         if len(word_boundaries) == 0:
25 |             raise ValueError("Word boundaries list should contain at least the initial boundary 0!")
26 |         if len(subwords) != word_boundaries[-1]:
27 |             raise ValueError(f"The last word boundary should equal the number of subwords.\n"
28 |                              f"However, the subwords list has {len(subwords)} element(s), and "
29 |                              f"value {len(subwords)} is not found in word boundaries list: {word_boundaries}")
30 |         if word_boundaries[0] != 0:
31 |             raise ValueError('Word boundaries list must start with 0!')
32 | 
33 | 
34 | class SubtokenIterator(TokenIterator):
35 |     """
36 |     >>> [token for token in SubtokenIterator(['hi', 'the', 're'], [0, 1, 3])]
37 |     ['hi', 'the', 're']
38 | 
39 |     >>> [token for token in SubtokenIterator([1, 2, 3], [0, 1, 3], format=lambda s: str(s[0]))]
40 |     ['1', '2', '3']
41 | 
42 |     >>> [token for token in SubtokenIterator(['hi', 'the', 're'], [0, 1, 3], return_full_token_index=True)]
43 |     [(0, 'hi'), (1, 'the'), (1, 're')]
44 | 
45 |     >>> [token for token in SubtokenIterator(['hi'], [0])]
46 |     Traceback (most recent call last):
47 |     ...
48 |     ValueError: The last word boundary should equal the number of subwords.
49 |     However, the subwords list has 1 element(s), and value 1 is not found in word boundaries list: [0]
50 | 
51 |     >>> [token for token in SubtokenIterator(['hi'], [1])]
52 |     Traceback (most recent call last):
53 |     ...
54 |     ValueError: Word boundaries list must start with 0!
55 |     """
56 |     def __init__(self, subwords: List[Any],
57 |                  word_boundaries: List[int],
58 |                  format: Callable[[List[str]], Any] = lambda l: l[0],
59 |                  return_full_token_index: bool = False):
60 | 
61 |         super().__init__(subwords, word_boundaries, format, return_full_token_index)
62 | 
63 |         self.current_index = 0
64 |         self.current_full_word = 0
65 | 
66 |     def __next__(self):
67 |         if self.current_index >= len(self.subwords):
68 |             raise StopIteration
69 | 
70 |         value = [self.subwords[self.current_index]]
71 |         formatted_value = self.format(value)
72 |         result = (self.current_full_word, formatted_value) if self.return_full_token_index else formatted_value
73 | 
74 |         self.current_index += 1
75 |         if self.word_boundaries[self.current_full_word + 1] == self.current_index:
76 |             self.current_full_word += 1
77 | 
78 |         return result
79 | 
80 | 
81 | class FullTokenIterator(TokenIterator):
82 |     """
83 |     >>> [token for token in FullTokenIterator(['hi', 'the', 're'], [0, 1, 3])]
84 |     ['hi', 'there']
85 | 
86 |     >>> [token for token in FullTokenIterator(['hel', 'l', 'o'], [0, 3])]
87 |     ['hello']
88 | 
89 |     >>> [token for token in FullTokenIterator([1, 2, 4], [0, 2, 3], format=sum)]
90 |     [3, 4]
91 | 
92 |     >>> [token for token in FullTokenIterator(['hi', 'the', 're'], [0, 1, 3], return_full_token_index=True)]
93 |     [(0, 'hi'), (1, 'there')]
94 | 
95 |     >>> [token for token in FullTokenIterator([], [])]
96 |     Traceback (most recent call last):
97 |     ...
98 |     ValueError: Word boundaries list should contain at least the initial boundary 0!
99 | 
100 |     >>> [token for token in FullTokenIterator(['hi'], [0])]
101 |     Traceback (most recent call last):
102 |     ...
103 |     ValueError: The last word boundary should equal the number of subwords.
104 |     However, the subwords list has 1 element(s), and value 1 is not found in word boundaries list: [0]
105 | 
106 |     >>> [token for token in FullTokenIterator(['hi'], [1])]
107 |     Traceback (most recent call last):
108 |     ...
109 |     ValueError: Word boundaries list must start with 0! 
110 |     """
111 |     def __init__(self, subwords: List[Any],
112 |                  word_boundaries: List[int],
113 |                  format: Callable[[List[str]], Any] = lambda s: ''.join(s),
114 |                  return_full_token_index: bool = False):
115 |         super().__init__(subwords, word_boundaries, format, return_full_token_index)
116 | 
117 |         self.current_full_word = 0
118 | 
119 |     def __next__(self):
120 |         if self.current_full_word >= len(self.word_boundaries) - 1:
121 |             raise StopIteration
122 | 
123 |         word_start = self.word_boundaries[self.current_full_word]
124 |         word_end = self.word_boundaries[self.current_full_word + 1]
125 |         formatted_value = self.format(self.subwords[word_start:word_end])
126 |         result = (self.current_full_word, formatted_value) if self.return_full_token_index else formatted_value
127 | 
128 |         self.current_full_word += 1
129 | 
130 |         return result
131 | 
132 | 
133 | def is_terminal_subtoken(subtoken: str, use_token_end_chars: bool = True) -> bool:
134 |     if not use_token_end_chars:
135 |         raise NotImplementedError("Finding out if a subtoken is terminal for tokens represented with word-start and word-end tokens "
136 |                                   "is not yet implemented.")
137 | 
138 |     return subtoken.endswith(placeholders['compound_word_end'])
--------------------------------------------------------------------------------
/codeprep/tokens/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | 
5 | import importlib
6 | from os.path import dirname, basename, isfile, join
7 | import glob
8 | 
9 | modules = glob.glob(join(dirname(__file__), "*.py"))
10 | for f in modules:
11 |     if isfile(f) and not f.endswith('__init__.py'):
12 |         importlib.import_module(f'{__package__}.{basename(f)[:-3]}')
--------------------------------------------------------------------------------
/codeprep/tokens/containers.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | 
5 | from typing import List, Tuple, Union, Optional
6 | 
7 | from codeprep.noneng import replace_non_ascii_seqs
8 | from codeprep.preprocess.core import ReprConfig, torepr
9 | from codeprep.preprocess.metadata import PreprocessingMetadata
10 | from codeprep.preprocess.placeholders import placeholders
11 | from codeprep.tokens.rootclasses import ParsedToken, ParsedSubtoken
12 | from codeprep.tokens.whitespace import SpaceInString
13 | from codeprep.tokens.word import Word
14 | 
15 | 
16 | class ProcessableTokenContainer(ParsedToken):
17 |     def __init__(self, subtokens: Union[List[ParsedSubtoken], List[ParsedToken]]):
18 |         if isinstance(subtokens, list):
19 |             self.subtokens = subtokens
20 |         else:
21 |             raise AssertionError(f"Should be list but is: {subtokens}")
22 | 
23 |     def add(self, token):
24 |         self.subtokens.append(token)
25 | 
26 |     def get_subtokens(self):
27 |         return self.subtokens
28 | 
29 |     def __eq__(self, other):
30 |         return self.__class__ == other.__class__ and self.subtokens == other.subtokens
31 | 
32 |     def __repr__(self):
33 |         return f'{self.__class__.__name__}{self.subtokens}'
34 | 
35 |     def __str__(self):
36 |         return " ".join(self.non_preprocessed_repr()[0])
37 | 
38 | 
39 | def wrap_in_word_boundaries_if_necessary(res: List[str]) -> List[str]:
40 |     if len(res) == 1 or (len(res) == 2 and res[0] in [placeholders['capitals'], placeholders['capital']]):
41 |         return res
42 |     else:
43 |         return [placeholders['word_start']] + res + [placeholders['word_end']]
44 | 
45 | 
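# wrap_in_word_boundaries_if_necessary (above) leaves one-subword results, and a
# capitalization placeholder followed by a single subword, unwrapped; any longer
# split is bracketed with the word_start/word_end placeholders. A sketch of the
# three cases on hypothetical subword lists:
#
#   wrap_in_word_boundaries_if_necessary(['get'])
#       -> ['get']
#   wrap_in_word_boundaries_if_necessary([placeholders['capital'], 'get'])
#       -> [placeholders['capital'], 'get']
#   wrap_in_word_boundaries_if_necessary(['get', 'value'])
#       -> [placeholders['word_start'], 'get', 'value', placeholders['word_end']]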
class SplitContainer(ProcessableTokenContainer):
47 |     def __init__(self, subtokens: List[ParsedSubtoken]):
48 |         super().__init__(subtokens)
49 | 
50 |     def empty_repr(self):
51 |         return self.subtokens
52 | 
53 |     def __repr__(self):
54 |         return f'{self.__class__.__name__}{self.subtokens}'
55 | 
56 |     def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]:
57 |         nospl_str = ["".join(map(lambda s: torepr(s, repr_config)[0][0], self.subtokens))]
58 |         return self.wrap_in_metadata_for_full_word(nospl_str)
59 | 
60 |     def preprocessed_repr(self, repr_config) -> Tuple[List[str], PreprocessingMetadata]:
61 |         if repr_config.bpe_data:
62 |             return self.wrap_in_metadata_for_full_word(repr_config.word_splitter(str(self), repr_config.bpe_data))
63 |         res = []
64 |         all_metadata = PreprocessingMetadata()
65 |         for subtoken in self.subtokens:
66 |             r, metadata = torepr(subtoken, repr_config)
67 |             res.extend(r if isinstance(r, list) else [r])
68 |             all_metadata.update(metadata)
69 |         return self.wrap_in_metadata_for_full_word(wrap_in_word_boundaries_if_necessary(res), all_metadata.nonprocessable_tokens)
70 | 
71 |     @classmethod
72 |     def from_single_token(cls, token: str):
73 |         return cls([Word.from_(token)])
74 | 
75 | 
76 | class TextContainer(ProcessableTokenContainer):
77 | 
78 |     def __init__(self, tokens: List[ParsedToken]):
79 |         super().__init__(tokens)
80 |         for token in tokens:
81 |             if isinstance(token, ParsedSubtoken):
82 |                 raise TypeError(
83 |                     f"ParsedSubtokens cannot be part of a TextContainer, but one of the tokens passed was {type(token)} ({token})")
84 | 
85 |     def __repr__(self):
86 |         return f'{self.__class__.__name__}{self.subtokens}'
87 | 
88 |     def __eq__(self, other):
89 |         return self.__class__ == other.__class__ and self.subtokens == other.subtokens
90 | 
91 | 
92 | class Comment(TextContainer):
93 |     def __init__(self, tokens: List[ParsedToken]):
94 |         super().__init__(tokens)
95 | 
96 |     def preprocessed_repr(self, repr_config: ReprConfig) -> Tuple[List[str], PreprocessingMetadata]:
97 |         return self.wrap_in_metadata_for_full_word([placeholders['comment']])
98 | 
99 | 
100 | class OneLineComment(Comment):
101 |     def __init__(self, tokens: List[ParsedToken]):
102 |         super().__init__(tokens)
103 | 
104 |     def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]:
105 |         prep_tokens, metadata = torepr(self.subtokens, repr_config)
106 |         metadata.update(PreprocessingMetadata(word_boundaries=[0, 1], token_types=[OneLineComment]))
107 |         metadata.set_all_tokens_type(OneLineComment)
108 |         return prep_tokens + [placeholders['olc_end']], metadata
109 | 
110 | 
111 | class MultilineComment(Comment):
112 |     def __init__(self, tokens: List[ParsedToken]):
113 |         super().__init__(tokens)
114 | 
115 |     def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]:
116 |         prep_tokens, metadata = torepr(self.subtokens, repr_config)
117 |         metadata.set_all_tokens_type(MultilineComment)
118 |         return prep_tokens, metadata
119 | 
120 | 
121 | class StringLiteral(TextContainer):
122 |     def __init__(self, tokens: List[ParsedToken], length: int):
123 |         super().__init__(tokens)
124 |         self.length = length
125 | 
126 |     def _replace_non_ascii_seqs_if_necessary(self, repr_config: ReprConfig) -> str:
127 |         s = str(self)
128 |         if 'NonEng' in list(map(lambda x: x.__name__, repr_config.types_to_be_repr)):
129 |             s = placeholders["space_in_str"].join(map(lambda t: replace_non_ascii_seqs(t, 
placeholders['non_ascii_seq']), s.split(placeholders["space_in_str"]))) 130 | return s 131 | 132 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]: 133 | if not repr_config: # called by str() 134 | return self.wrap_in_metadata_for_full_word(["".join(map(lambda t: str(t), self.subtokens))]) 135 | elif self.length > repr_config.max_str_length: 136 | s = ['""'] if repr_config.full_strings else ['"', '"'] 137 | non_processable_tokens = set() if repr_config.full_strings else {'"'} 138 | return self.wrap_in_metadata_for_full_word(s, non_processable_tokens) 139 | elif repr_config.bpe_data: 140 | s = self._replace_non_ascii_seqs_if_necessary(repr_config) 141 | return self.wrap_in_metadata_for_full_word(repr_config.word_splitter(s, repr_config.bpe_data)) 142 | elif repr_config.full_strings: 143 | s = self._replace_non_ascii_seqs_if_necessary(repr_config) 144 | return self.wrap_in_metadata_for_full_word([s]) 145 | else: # here we don't keep full strings 146 | tokens, metadata = torepr(list(filter(lambda t: type(t) != SpaceInString, self.subtokens)), repr_config) 147 | metadata.set_all_tokens_type(StringLiteral) 148 | return tokens, metadata 149 | 150 | def preprocessed_repr(self, repr_config: ReprConfig) -> Tuple[List[str], PreprocessingMetadata]: 151 | return self.wrap_in_metadata_for_full_word([placeholders['string_literal']]) 152 | 153 | def __eq__(self, other): 154 | return super().__eq__(other) and self.length == other.length 155 | 156 | def __repr__(self): 157 | return super().__repr__() + f", length: {self.length}" -------------------------------------------------------------------------------- /codeprep/tokens/noneng.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from typing import List, Tuple, Optional 6 | 7 | from codeprep.noneng import replace_non_ascii_seqs 8 | from codeprep.preprocess.core import ReprConfig, torepr 9 | from codeprep.preprocess.metadata import PreprocessingMetadata 10 | from codeprep.preprocess.placeholders import placeholders 11 | from codeprep.tokens.containers import SplitContainer 12 | from codeprep.tokens.rootclasses import ParsedToken 13 | 14 | 15 | class NonEng(ParsedToken): 16 | def __init__(self, processable_token: SplitContainer): 17 | if not isinstance(processable_token, SplitContainer): 18 | raise ValueError(f"Only SplitContainer can be wrapped in {self.__class__}. 
Type passed: {type(processable_token)}") 19 | 20 | self.processable_token = processable_token 21 | 22 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]: 23 | return torepr(self.processable_token, repr_config) 24 | 25 | def preprocessed_repr(self, repr_config: ReprConfig) -> Tuple[List[str], PreprocessingMetadata]: 26 | if repr_config.bpe_data: 27 | token = replace_non_ascii_seqs(str(self.processable_token), placeholders['non_ascii_seq']) 28 | return torepr(SplitContainer.from_single_token(token), repr_config) 29 | else: 30 | return self.wrap_in_metadata_for_full_word([placeholders['non_eng']]) 31 | 32 | def __repr__(self): 33 | return f'{self.__class__.__name__}({self.processable_token.__repr__()})' 34 | 35 | def __str__(self): 36 | return str(self.processable_token) 37 | 38 | def __eq__(self, other): 39 | return self.__class__ == other.__class__ and self.processable_token == other.processable_token -------------------------------------------------------------------------------- /codeprep/tokens/numeric.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from typing import List, Tuple, Optional 6 | 7 | from codeprep.preprocess.core import ReprConfig 8 | from codeprep.preprocess.metadata import PreprocessingMetadata, unwrap_single_string 9 | from codeprep.preprocess.placeholders import placeholders 10 | from codeprep.tokens.rootclasses import ParsedToken 11 | 12 | 13 | class Number(ParsedToken): 14 | def __init__(self, val: str): 15 | self.val = val.lower() 16 | 17 | def __str__(self): 18 | return unwrap_single_string(self.non_preprocessed_repr()) 19 | 20 | def __repr__(self): 21 | return f'<{self.__class__.__name__}>({self.val})' 22 | 23 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]: 24 | return self.wrap_in_metadata_for_full_word([self.val]) 25 | 26 | def preprocessed_repr(self, repr_config: ReprConfig) -> Tuple[List[str], PreprocessingMetadata]: 27 | subwords = repr_config.number_splitter(self.non_preprocessed_repr()[0][0], repr_config.bpe_data) 28 | 29 | if len(subwords) > 1 and not repr_config.bpe_data: 30 | prep_number = [placeholders['word_start']] + subwords + [placeholders['word_end']] 31 | else: 32 | prep_number = subwords 33 | 34 | return self.wrap_in_metadata_for_full_word(prep_number) 35 | 36 | def __eq__(self, other): 37 | return self.__class__ == other.__class__ and self.val == other.val 38 | 39 | 40 | class One(Number): 41 | def __init__(self): 42 | super().__init__('1') 43 | 44 | 45 | class Zero(Number): 46 | def __init__(self): 47 | super().__init__('0') -------------------------------------------------------------------------------- /codeprep/tokens/rootclasses.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from typing import List, Tuple, Set, Optional 6 | 7 | from codeprep.preprocess.metadata import PreprocessingMetadata 8 | 9 | 10 | class ParsedToken(object): 11 | def wrap_in_metadata_for_full_word(self, tokens: List[str], non_proc: Optional[Set[str]] = None) \ 12 | -> Tuple[List[str], PreprocessingMetadata]: 13 | assert type(tokens) == list 14 | 15 | metadata = PreprocessingMetadata() 16 | metadata.nonprocessable_tokens = non_proc or set() 
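# A minimal illustration, assuming a split word renders as ['<w>', 'set', 'value', '</w>']
# (the placeholder strings are an assumption): the next line then records
# word_boundaries == [0, 4], i.e. one full word spanning all four subtokens.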
17 | metadata.word_boundaries = [0, len(tokens)] 18 | metadata.token_types = [type(self)] 19 | return tokens, metadata 20 | 21 | 22 | class ParsedSubtoken(object): 23 | pass -------------------------------------------------------------------------------- /codeprep/tokens/whitespace.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from typing import List, Tuple, Optional 6 | 7 | from codeprep.preprocess.core import ReprConfig 8 | from codeprep.preprocess.metadata import PreprocessingMetadata, unwrap_single_string 9 | from codeprep.preprocess.placeholders import placeholders 10 | from codeprep.tokens.rootclasses import ParsedToken 11 | 12 | NBSP = '\xa0' 13 | 14 | 15 | class Whitespace(ParsedToken): 16 | def __eq__(self, other): 17 | return other.__class__ == self.__class__ 18 | 19 | def __repr__(self): 20 | return f'<{self.__class__.__name__}>' 21 | 22 | def __str__(self): 23 | return unwrap_single_string(self.non_preprocessed_repr()) 24 | 25 | 26 | class NewLine(Whitespace): 27 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]: 28 | return self.wrap_in_metadata_for_full_word(["\n"], non_proc={"\n"}) 29 | 30 | def preprocessed_repr(self, repr_config: ReprConfig) -> Tuple[List[str], PreprocessingMetadata]: 31 | return [], PreprocessingMetadata() 32 | 33 | 34 | class Tab(Whitespace): 35 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]: 36 | return self.wrap_in_metadata_for_full_word(["\t"], non_proc={"\t"}) 37 | 38 | def preprocessed_repr(self, repr_config: ReprConfig) -> Tuple[List[str], PreprocessingMetadata]: 39 | return [], PreprocessingMetadata() 40 | 41 | 42 | class SpaceInString(Whitespace): 43 | 44 | def __init__(self, n_chars: int = 1): 45 | super().__init__() 46 | self.n_chars = n_chars 47 | 48 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]: 49 | return self.wrap_in_metadata_for_full_word([placeholders['space_in_str'] * self.n_chars]) 50 | 51 | def __repr__(self): 52 | return f'<{self.__class__.__name__}> (n_chars={self.n_chars})' 53 | 54 | def __eq__(self, other): 55 | return other.__class__ == self.__class__ and other.n_chars == self.n_chars -------------------------------------------------------------------------------- /codeprep/tokens/word.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from enum import Enum 6 | from typing import List, Tuple, Optional 7 | 8 | from codeprep.preprocess.core import ReprConfig 9 | from codeprep.preprocess.metadata import PreprocessingMetadata, with_empty_metadata, unwrap_single_string 10 | from codeprep.preprocess.placeholders import placeholders 11 | from codeprep.tokens.rootclasses import ParsedSubtoken, ParsedToken 12 | 13 | 14 | class Underscore(ParsedSubtoken): 15 | def __eq__(self, other): 16 | return other.__class__ == self.__class__ 17 | 18 | def __repr__(self): 19 | return f'<{self.__class__.__name__}>' 20 | 21 | def __str__(self): 22 | return self.non_preprocessed_repr()[0] 23 | 24 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[str, PreprocessingMetadata]: 25 | return with_empty_metadata("_") 26 | 27 | 28 | class 
Word(ParsedSubtoken): 29 | """ 30 | Invariants: 31 | str === str(Word.of(str)) 32 | """ 33 | 34 | class Capitalization(str, Enum): 35 | UNDEFINED: str = 'undefined' 36 | NONE = 'none' 37 | FIRST_LETTER = 'first_letter' 38 | ALL = 'all' 39 | 40 | def __repr__(self): 41 | return self.value 42 | 43 | def __init__(self, canonic_form: str, capitalization: Capitalization = Capitalization.UNDEFINED): 44 | Word._check_canonic_form_is_valid(canonic_form) 45 | 46 | self.canonic_form = canonic_form 47 | self.capitalization = capitalization 48 | 49 | def get_canonic_form(self) -> str: 50 | return self.canonic_form 51 | 52 | @staticmethod 53 | def _is_strictly_upper(s) -> bool: 54 | return s.isupper() and not s.lower().isupper() 55 | 56 | @staticmethod 57 | def _check_canonic_form_is_valid(canonic_form) -> None: 58 | if not isinstance(canonic_form, str) or Word._is_strictly_upper(canonic_form) \ 59 | or (canonic_form and Word._is_strictly_upper(canonic_form[0])): 60 | raise AssertionError(f"Bad canonic form: {canonic_form}") 61 | 62 | def __str__(self): 63 | return unwrap_single_string(self.non_preprocessed_repr()) 64 | 65 | def __with_capitalization_prefixes(self, subwords: List[str]) -> List[str]: 66 | if self.capitalization == Word.Capitalization.UNDEFINED or self.capitalization == Word.Capitalization.NONE: 67 | res = subwords 68 | elif self.capitalization == Word.Capitalization.FIRST_LETTER: 69 | res = [placeholders['capital']] + subwords 70 | elif self.capitalization == Word.Capitalization.ALL: 71 | res = [placeholders['capitals']] + subwords 72 | else: 73 | raise AssertionError(f"Unknown value: {self.capitalization}") 74 | return res 75 | 76 | def preprocessed_repr(self, repr_config: ReprConfig) -> Tuple[List[str], PreprocessingMetadata]: 77 | if repr_config.should_lowercase: 78 | subwords = repr_config.word_splitter(self.canonic_form, repr_config.bpe_data) 79 | subwords_with_prefix = self.__with_capitalization_prefixes(subwords) 80 | return with_empty_metadata(subwords_with_prefix) 81 | else: 82 | subwords = repr_config.word_splitter(self.__with_preserved_case(), repr_config.bpe_data) 83 | return with_empty_metadata(subwords) 84 | 85 | def __with_preserved_case(self) -> str: 86 | if self.capitalization == Word.Capitalization.UNDEFINED or self.capitalization == Word.Capitalization.NONE: 87 | return self.canonic_form 88 | elif self.capitalization == Word.Capitalization.FIRST_LETTER: 89 | return self.canonic_form.capitalize() 90 | elif self.capitalization == Word.Capitalization.ALL: 91 | return self.canonic_form.upper() 92 | else: 93 | raise AssertionError(f"Unknown value: {self.capitalization}") 94 | 95 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]: 96 | return with_empty_metadata([self.__with_preserved_case()]) 97 | 98 | def __repr__(self): 99 | return f'{self.__class__.__name__}({self.canonic_form, self.capitalization})' 100 | 101 | def __eq__(self, other): 102 | return self.__class__ == other.__class__ and self.canonic_form == other.canonic_form \ 103 | and self.capitalization == other.capitalization 104 | 105 | @classmethod 106 | def from_(cls, s: str): 107 | if not s: 108 | raise ValueError(f'A subword can be neither None nor of length zero. 
Value of the subword is {s}') 109 | 110 | if s.islower(): 111 | return cls(s, Word.Capitalization.NONE) 112 | elif s.isupper(): 113 | return cls(s.lower(), Word.Capitalization.ALL) 114 | elif s[0].isupper(): 115 | return cls(s[0].lower() + s[1:], Word.Capitalization.FIRST_LETTER) 116 | else: 117 | return cls(s, Word.Capitalization.UNDEFINED) 118 | 119 | 120 | class NonProcessibleToken(ParsedToken): 121 | def __init__(self, token: str): 122 | self.token = token 123 | 124 | def __eq__(self, other): 125 | return self.__class__ == other.__class__ and self.token == other.token 126 | 127 | def __repr__(self): 128 | return f'{self.__class__.__name__}({self.token})' 129 | 130 | def __str__(self): 131 | return self.token 132 | 133 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]: 134 | return self.wrap_in_metadata_for_full_word([self.token], non_proc={self.token}) 135 | 136 | 137 | class KeyWord(NonProcessibleToken): 138 | def __init__(self, token: str): 139 | super().__init__(token) 140 | 141 | 142 | class Operator(NonProcessibleToken): 143 | def __init__(self, token: str): 144 | super().__init__(token) 145 | 146 | 147 | class Semicolon(Operator): 148 | def __init__(self): 149 | super().__init__(';') 150 | 151 | 152 | class OpeningCurlyBracket(Operator): 153 | def __init__(self): 154 | super().__init__('{') 155 | 156 | 157 | class ClosingCurlyBracket(Operator): 158 | def __init__(self): 159 | super().__init__('}') 160 | 161 | 162 | class OpeningBracket(Operator): 163 | def __init__(self): 164 | super().__init__('(') 165 | 166 | 167 | class ClosingBracket(Operator): 168 | def __init__(self): 169 | super().__init__(')') 170 | 171 | 172 | class NonCodeChar(NonProcessibleToken): 173 | def __init__(self, token: str): 174 | super().__init__(token) 175 | 176 | 177 | class SpecialToken(NonProcessibleToken): 178 | def __init__(self, token: str): 179 | super().__init__(token) -------------------------------------------------------------------------------- /codeprep/util.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import multiprocessing 6 | from heapq import heappush, heappop, heapify 7 | 8 | import itertools 9 | from typing import Dict, Tuple, List, Optional, Generator 10 | 11 | 12 | def merge_dicts_(dict1, dict2) -> Tuple[Dict, List]: 13 | """ 14 | This method mutates and returns `dict1`: counts from `dict2` are merged into it, and words new to `dict1` are returned as a list. 15 | 16 | >>> dict1 = {"a": 3, "b": 4} 17 | >>> dict2 = {"b": 5, "c": 6} 18 | >>> merge_dicts_(dict1, dict2) 19 | ({'a': 3, 'b': 9, 'c': 6}, ['c']) 20 | >>> dict1 21 | {'a': 3, 'b': 9, 'c': 6} 22 | 23 | """ 24 | new_words = [] 25 | for k, v in dict2.items(): 26 | if k not in dict1: 27 | dict1[k] = v 28 | new_words.append(k) 29 | else: 30 | dict1[k] = dict1[k] + v 31 | return dict1, new_words 32 | 33 | 34 | class NonAtomicCounter(object): 35 | """ 36 | >>> counter = NonAtomicCounter(10) 37 | >>> counter.inc() 38 | 11 39 | >>> counter.dec() 40 | 10 41 | >>> counter.compare_and_dec(10) 42 | True 43 | >>> counter.value 44 | 9 45 | >>> counter.value = 20 46 | >>> counter.get_and_dec() 47 | 20 48 | >>> counter.value 49 | 19 50 | """ 51 | def __init__(self, v: int = 0): 52 | self.val = v 53 | 54 | def inc(self) -> int: 55 | self.val += 1 56 | return self.val 57 | 58 | def dec(self) -> int: 59 | self.val -= 1 60 | return self.val 61 | 62 | def compare_and_dec(self, v: int) -> bool: 63 | result = (self.val == v) 64 | self.dec() 65 | return result 66 | 67 | def get_and_dec(self) -> int: 68 | result = self.val 69 | self.dec() 70 | return result 71 | 72 | @property 73 | def value(self) -> int: 74 | return self.val 75 | 76 | @value.setter 77 | def value(self, v: int) -> None: 78 | self.val = v 79 | 80 | 81 | class AtomicInteger(object): 82 | def __init__(self, v=0): 83 | self._lock = multiprocessing.Lock() 84 | self._queue = multiprocessing.Queue() 85 | for i in range(v): 86 | self._queue.put(1) 87 | 88 | def inc(self): 89 | with self._lock: 90 | self._queue.put(1) 91 | return self._queue.qsize() 92 | 93 | def dec(self): 94 | with self._lock: 95 | self._queue.get() 96 | return self._queue.qsize() 97 | 98 | def compare_and_dec(self, val): 99 | with self._lock: 100 | result = self._queue.qsize() == val 101 | self._queue.get() 102 | return result 103 | 104 | def get_and_dec(self): 105 | with self._lock: 106 | result = self._queue.qsize() 107 | self._queue.get() 108 | return result 109 | 110 | @property 111 | def value(self): 112 | with self._lock: 113 | return self._queue.qsize() 114 | 115 | @value.setter 116 | def value(self, v): 117 | with self._lock: 118 | self._queue = multiprocessing.Queue() 119 | for i in range(v): 120 | self._queue.put(1) 121 | 122 | 123 | class PriorityCounter(object): 124 | REMOVED = '' # placeholder for a removed task 125 | 126 | def __init__(self, d: Dict, automatic_count: bool = True): 127 | self.counter = itertools.count() if automatic_count else None 128 | self.pq = [[(-value, next(self.counter)) if self.counter else (-value[0], value[1]), key] for key, value in d.items()] # list of entries arranged in a heap 129 | heapify(self.pq) 130 | self.entry_finder = {entry[1]: entry for entry in self.pq} # mapping of tasks to entries 131 | 132 | def add(self, pair, to_add: int, c: Optional[int] = None): 133 | 'Add a new task or update the priority of an existing task' 134 | if (self.counter is None) == (c is None): 135 | raise ValueError("Either counter should be set, or count argument should be passed!") 136 | count = next(self.counter) if self.counter else c 137 | to_add = -to_add 138 | if pair in self.entry_finder: 139 | entry = self.entry_finder[pair] 140 | to_add = entry[0][0] + to_add 141 | self.remove_task(pair) 142 | if to_add != 0: 143 | entry = [(to_add, count), pair] 144 | self.entry_finder[pair] = entry 145 | heappush(self.pq, entry) 146 | 147 | def remove_task(self, task): 148 | 'Mark an existing task as REMOVED. 
Raise KeyError if not found.' 149 | entry = self.entry_finder.pop(task) 150 | entry[-1] = PriorityCounter.REMOVED 151 | 152 | def pop_pair(self): 153 | 'Remove and return the lowest priority task. Raise KeyError if empty.' 154 | while self.pq: 155 | (priority, count), pair = heappop(self.pq) 156 | if pair is not PriorityCounter.REMOVED: 157 | del self.entry_finder[pair] 158 | return pair, -priority 159 | raise KeyError('pop from an empty priority queue') 160 | 161 | 162 | import sys 163 | from numbers import Number 164 | from collections import deque 165 | from collections.abc import Set, Mapping 166 | 167 | # From https://stackoverflow.com/a/30316760: 168 | def getsize(obj): 169 | zero_depth_bases = (str, bytes, Number, range, bytearray) 170 | iteritems = 'items' 171 | 172 | def _getsize(obj_0): 173 | """Recursively iterate to sum size of object & members.""" 174 | _seen_ids = set() 175 | 176 | def inner(obj): 177 | obj_id = id(obj) 178 | if obj_id in _seen_ids: 179 | return 0 180 | _seen_ids.add(obj_id) 181 | size = sys.getsizeof(obj) 182 | if isinstance(obj, zero_depth_bases): 183 | pass # bypass remaining control flow and return 184 | elif isinstance(obj, (tuple, list, Set, deque)): 185 | size += sum(inner(i) for i in obj) 186 | elif isinstance(obj, Mapping) or hasattr(obj, iteritems): 187 | size += sum(inner(k) + inner(v) for k, v in getattr(obj, iteritems)()) 188 | # Check for custom object instances - may subclass above too 189 | if hasattr(obj, '__dict__'): 190 | size += inner(vars(obj)) 191 | if hasattr(obj, '__slots__'): # can have __slots__ with __dict__ 192 | size += sum(inner(getattr(obj, s)) for s in obj.__slots__ if hasattr(obj, s)) 193 | return size 194 | 195 | return inner(obj_0) 196 | 197 | return _getsize(obj) 198 | 199 | 200 | def is_python_3_6_and_higher(): 201 | python_version = sys.version_info 202 | return python_version >= (3, 6) 203 | 204 | 205 | def create_chunk_generator(total: int, n_chunks: int) -> Generator[int, None, None]: 206 | min_elms_in_chunk = total // n_chunks 207 | for i in range(min_elms_in_chunk): 208 | for j in range(n_chunks): 209 | yield j 210 | for i in range(total % n_chunks): 211 | yield i 212 | 213 | 214 | def groupify(all: List, n_chunks: int) -> List[List]: 215 | """ 216 | >>> groupify([0, 1, 2, 3, 4, 5, 6], 3) 217 | [[0, 3, 6], [1, 4], [2, 5]] 218 | 219 | >>> groupify([0, 1, 2, 3, 4, 5, 6], 300) 220 | [[0], [1], [2], [3], [4], [5], [6]] 221 | 222 | >>> groupify([], 300) 223 | [] 224 | 225 | >>> result = groupify(list(range(100 * 1000 + 17)), 100) 226 | >>> len(result[0]) 227 | 1001 228 | >>> len(result[17]) 229 | 1000 230 | """ 231 | groups = [[] for _ in range(n_chunks if len(all) >= n_chunks else len(all))] 232 | chunk_gen = create_chunk_generator(len(all), n_chunks) 233 | for elm, label in zip(all, chunk_gen): 234 | groups[label].append(elm) 235 | return groups 236 | 237 | 238 | def to_non_literal_str(word: str) -> str: 239 | return word.encode().decode("unicode-escape") 240 | 241 | 242 | def to_literal_str(word: str) -> str: 243 | return word.encode("unicode-escape").decode() 244 | 245 | 246 | START_ERROR_COLOR = '\033[31m' 247 | END_ERROR_COLOR = '\033[0m' -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_0.05mb_1bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_0.05mb_1bit.png 
-------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_0.05mb_1bit.png.license: -------------------------------------------------------------------------------- 1 | SPDX-FileCopyrightText: 2020 Hlib Babii 2 | 3 | SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_0.05mb_2bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_0.05mb_2bit.png -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_0.05mb_2bit.png.license: -------------------------------------------------------------------------------- 1 | SPDX-FileCopyrightText: 2020 Hlib Babii 2 | 3 | SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_0.05mb_3bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_0.05mb_3bit.png -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_0.05mb_3bit.png.license: -------------------------------------------------------------------------------- 1 | SPDX-FileCopyrightText: 2020 Hlib Babii 2 | 3 | SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_0.5mb_1bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_0.5mb_1bit.png -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_0.5mb_1bit.png.license: -------------------------------------------------------------------------------- 1 | SPDX-FileCopyrightText: 2020 Hlib Babii 2 | 3 | SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_0.5mb_2bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_0.5mb_2bit.png -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_0.5mb_2bit.png.license: -------------------------------------------------------------------------------- 1 | SPDX-FileCopyrightText: 2020 Hlib Babii 2 | 3 | SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_0.5mb_3bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_0.5mb_3bit.png -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_0.5mb_3bit.png.license: -------------------------------------------------------------------------------- 1 | SPDX-FileCopyrightText: 2020 Hlib Babii 2 | 3 | SPDX-License-Identifier: Apache-2.0 
-------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_5mb_1bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_5mb_1bit.png -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_5mb_1bit.png.license: -------------------------------------------------------------------------------- 1 | SPDX-FileCopyrightText: 2020 Hlib Babii 2 | 3 | SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_5mb_2bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_5mb_2bit.png -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_5mb_2bit.png.license: -------------------------------------------------------------------------------- 1 | SPDX-FileCopyrightText: 2020 Hlib Babii 2 | 3 | SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_5mb_3bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_5mb_3bit.png -------------------------------------------------------------------------------- /reports/bpe/wild-bpe/v0.1_5mb_3bit.png.license: -------------------------------------------------------------------------------- 1 | SPDX-FileCopyrightText: 2020 Hlib Babii 2 | 3 | SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/casics/spiral.git@a10fcd3afbd4c28e917cc0056d6f646cc2e9bd44 2 | pytest==5.3.1 3 | pytest-mock==1.12.1 4 | coverage==4.5.4 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appdirs==1.4.4 2 | click==7.1.2 3 | dill==0.3.1.1 4 | docopt==0.6.2 5 | docopt-subcommands==3.0.0 6 | joblib==0.15.1 7 | jsons==1.1.2 8 | nltk==3.5 9 | Pygments==2.6.1 10 | PyYAML==5.4 11 | regex==2020.5.14 12 | tqdm==4.46.0 13 | typish==1.6.0 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import os 6 | 7 | from setuptools import setup, find_packages 8 | 9 | root_package_name = 'codeprep' 10 | 11 | 12 | def readme(): 13 | with open('README.md') as f: 14 | return f.read() 15 | 16 | 17 | def version(): 18 | with open(os.path.join(root_package_name, 'VERSION')) as version_file: 19 | return version_file.read().strip() 20 | 21 | 22 | setup(name='codeprep', 23 | version=version(), 24 | description='A toolkit for pre-processing large source code corpora', 25 | long_description=readme(), 26 | long_description_content_type="text/markdown", 27 | url='http://github.com/giganticode/codeprep', 28 
| author='Hlib Babii', 29 | author_email='hlibbabii@gmail.com', 30 | license='Apache-2.0', 31 | packages=find_packages(), 32 | classifiers=[ 33 | 'Development Status :: 3 - Alpha', 34 | 'Environment :: Console', 35 | 'Intended Audience :: Science/Research', 36 | 'License :: OSI Approved :: Apache Software License', 37 | 'Natural Language :: English', 38 | 'Programming Language :: Python :: 3.6', 39 | 'Programming Language :: Python :: 3.7', 40 | 'Operating System :: POSIX :: Linux', 41 | 'Operating System :: MacOS :: MacOS X', 42 | 'Operating System :: Microsoft :: Windows', 43 | 'Topic :: Software Development :: Pre-processors' 44 | ], 45 | python_requires='>=3.6', 46 | keywords='big large data source code corpus machine learning pre-processing nlp', 47 | install_requires=[ 48 | 'appdirs>=1.4, <2', 49 | 'dill>=0.3.1.1, <0.4', 50 | 'docopt>=0.6.2, <0.7', 51 | 'docopt-subcommands>=3.0.0, <4', 52 | 'jsons>=1.0, <2', 53 | 'nltk>=3.4.5, <4', 54 | 'Pygments>=2.5.2, <3', 55 | 'PyYAML>=5.1, <6', 56 | 'regex>=2019.11.1, <=2020.5.14', 57 | 'tqdm>=4.39, <5' 58 | ], 59 | entry_points={ 60 | 'console_scripts': [ 61 | f'codeprep = {root_package_name}.__main__:main' 62 | ] 63 | }, 64 | include_package_data=True, 65 | zip_safe=False) -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/.reuse/dep5: -------------------------------------------------------------------------------- 1 | Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ 2 | Upstream-Name: yahtzee 3 | Upstream-Contact: Hlib Babii 4 | Source: 5 | 6 | # Sample paragraph, commented out: 7 | # 8 | # Files: src/* 9 | # Copyright: $YEAR $NAME <$CONTACT> 10 | # License: ... 11 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/LICENSES/MIT.txt: -------------------------------------------------------------------------------- 1 | MIT License Copyright (c) 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is furnished 8 | to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice (including the next 11 | paragraph) shall be included in all copies or substantial portions of the 12 | Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 16 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS 17 | OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 18 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF 19 | OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/DiceValues.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee; 2 | 3 | import java.util.Iterator; 4 | 5 | public class DiceValues implements Iterable<Integer> { 6 | 7 | public static final Integer MIN_DICE_VALUE = 1; 8 | public static final Integer MAX_DICE_VALUE = 6; 9 | 10 | private Iterator<Integer> iterator; 11 | 12 | public DiceValues(Iterator<Integer> iterator) { 13 | this.iterator = iterator; 14 | } 15 | 16 | public static Iterable<Integer> getDescendingIterator() { 17 | return new DiceValues(new DescendingIterator()); 18 | } 19 | 20 | public static class DescendingIterator implements Iterator<Integer> { 21 | 22 | private Integer currentValue = MAX_DICE_VALUE; 23 | 24 | @Override 25 | public boolean hasNext() { 26 | return this.currentValue >= MIN_DICE_VALUE; 27 | } 28 | 29 | @Override 30 | public Integer next() { 31 | return this.currentValue--; 32 | } 33 | } 34 | 35 | @Override 36 | public Iterator<Integer> iterator() { 37 | return this.iterator; 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/GameDemo.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee; 2 | 3 | import hlibbabii.yahtzee.gameplay.Game; 4 | import hlibbabii.yahtzee.gameplay.GameStats; 5 | import hlibbabii.yahtzee.player.DummyPlayer; 6 | 7 | public class GameDemo { 8 | public static void main(String[] args) { 9 | Player player1 = new DummyPlayer(); 10 | Player player2 = new DummyPlayer(); 11 | 12 | Game game = new Game(player1, player2); 13 | 14 | GameStats gameStats = game.play(); 15 | System.out.println("Player1"); 16 | System.out.println(gameStats.getPlayerStats(player1)); 17 | System.out.println("Player2"); 18 | System.out.println(gameStats.getPlayerStats(player2)); 19 | 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/Player.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee; 2 | 3 | import hlibbabii.yahtzee.combination.Combination; 4 | import hlibbabii.yahtzee.gameplay.Decision; 5 | import hlibbabii.yahtzee.model.DiceLayout; 6 | 7 | import java.util.Set; 8 | 9 | public interface Player { 10 | Decision makeDecision(DiceLayout diceLayout, DiceLayout fixedDiceLayout, Set<Combination> availableCombinations, int rollsLeft); 11 | } 12 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/Chance.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | import hlibbabii.yahtzee.model.DiceLayout; 4 | 5 | import java.util.stream.IntStream; 6 | 7 | public class Chance extends Combination { 8 | public static final Chance CHANCE = new Chance(); 9 | 10 | @Override 11 | public int earnedScores(DiceLayout diceLayout) { 12 | return IntStream.of(diceLayout.toSortedNumbers()).sum(); 13 | } 14 | 15 | @Override 16 | public String toString() { 17 | return "Chance"; 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/Combination.java: 
-------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | import hlibbabii.yahtzee.model.DiceLayout; 4 | 5 | public abstract class Combination { 6 | 7 | public abstract int earnedScores(DiceLayout diceLayout); 8 | } 9 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/FullHouse.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | import hlibbabii.yahtzee.gameplay.Game; 4 | import hlibbabii.yahtzee.model.DiceLayout; 5 | 6 | import java.util.Collection; 7 | import java.util.stream.IntStream; 8 | 9 | public class FullHouse extends Combination { 10 | 11 | public static final FullHouse FULL_HOUSE = new FullHouse(); 12 | 13 | @Override 14 | public int earnedScores(DiceLayout diceLayout) { 15 | assert 5 == Game.N_DICE; 16 | 17 | Collection<Integer> counts = diceLayout.toCounts().values(); 18 | return counts.contains(2) && counts.contains(3) ? 19 | IntStream.of(diceLayout.toSortedNumbers()).sum() : 20 | 0; 21 | 22 | } 23 | 24 | @Override 25 | public String toString() { 26 | return "Full House"; 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/LargeStraight.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | import hlibbabii.yahtzee.DiceValues; 4 | import hlibbabii.yahtzee.model.DiceLayout; 5 | 6 | import java.util.Arrays; 7 | import java.util.stream.IntStream; 8 | 9 | public class LargeStraight extends Combination { 10 | 11 | public static final LargeStraight LARGE_STRAIGHT = new LargeStraight(); 12 | 13 | @Override 14 | public int earnedScores(DiceLayout diceLayout) { 15 | int[] sortedRolledNumbers = diceLayout.toSortedNumbers(); 16 | if (Arrays.equals(sortedRolledNumbers, IntStream.range(2, DiceValues.MAX_DICE_VALUE + 1).toArray())) { 17 | return IntStream.of(sortedRolledNumbers).sum(); 18 | } else { 19 | return 0; 20 | } 21 | } 22 | 23 | @Override 24 | public String toString() { 25 | return "Large Straight"; 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/NOfAKind.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | import hlibbabii.yahtzee.DiceValues; 4 | import hlibbabii.yahtzee.model.DiceLayout; 5 | 6 | public class NOfAKind extends Combination { 7 | 8 | public static final NOfAKind PAIR = new NOfAKind(2); 9 | public static final NOfAKind THREE_OF_A_KIND = new NOfAKind(3); 10 | public static final NOfAKind FOUR_OF_A_KIND = new NOfAKind(4); 11 | 12 | private int n; 13 | 14 | public NOfAKind(int n) { 15 | this.n = n; 16 | } 17 | 18 | @Override 19 | public int earnedScores(DiceLayout diceLayout) { 20 | for (Integer diceValue : DiceValues.getDescendingIterator()) { 21 | if (diceLayout.getCount(diceValue) >= this.n) { 22 | return diceValue * this.n; 23 | } 24 | } 25 | return 0; 26 | } 27 | 28 | @Override 29 | public String toString() { 30 | return n + " Of A Kind"; 31 | } 32 | } 33 | -------------------------------------------------------------------------------- 
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/Numbers.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | import hlibbabii.yahtzee.model.DiceLayout; 4 | 5 | public class Numbers extends Combination { 6 | 7 | public static final Combination ACES = new Numbers(1); 8 | public static final Combination TWOS = new Numbers(2); 9 | public static final Combination THREES = new Numbers(3); 10 | public static final Combination FOURS = new Numbers(4); 11 | public static final Combination FIVES = new Numbers(5); 12 | public static final Combination SIXES = new Numbers(6); 13 | 14 | private Integer number; 15 | 16 | public Numbers(Integer number) { 17 | this.number = number; 18 | } 19 | 20 | @Override 21 | public int earnedScores(DiceLayout diceLayout) { 22 | return diceLayout.getCount(this.number) * this.number; 23 | } 24 | 25 | @Override 26 | public String toString() { 27 | return "Number(" + number + ")"; 28 | } 29 | } 30 | 31 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/SmallStraight.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | import hlibbabii.yahtzee.DiceValues; 4 | import hlibbabii.yahtzee.model.DiceLayout; 5 | 6 | import java.util.Arrays; 7 | import java.util.stream.IntStream; 8 | 9 | public class SmallStraight extends Combination { 10 | 11 | public static final SmallStraight SMALL_STRAIGHT = new SmallStraight(); 12 | 13 | @Override 14 | public int earnedScores(DiceLayout diceLayout) { 15 | int[] sortedRolledNumbers = diceLayout.toSortedNumbers(); 16 | if (Arrays.equals(sortedRolledNumbers, IntStream.range(1, DiceValues.MAX_DICE_VALUE).toArray())) { 17 | return IntStream.of(sortedRolledNumbers).sum(); 18 | } else { 19 | return 0; 20 | } 21 | } 22 | 23 | @Override 24 | public String toString() { 25 | return "Small Straight"; 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/TwoPairs.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | import hlibbabii.yahtzee.DiceValues; 4 | import hlibbabii.yahtzee.model.DiceLayout; 5 | 6 | public class TwoPairs extends Combination { 7 | 8 | public static final TwoPairs TWO_PAIRS = new TwoPairs(); 9 | 10 | @Override 11 | public int earnedScores(DiceLayout diceLayout) { 12 | int sum = 0; 13 | for (Integer diceValue: DiceValues.getDescendingIterator()) { 14 | if (diceLayout.getCount(diceValue) >= 2) { 15 | if (sum == 0) { 16 | sum = diceValue * 2; 17 | } else { 18 | return sum + diceValue * 2; 19 | } 20 | } 21 | } 22 | return 0; 23 | } 24 | 25 | @Override 26 | public String toString() { 27 | return "Two Pairs"; 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/Yahtzee.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | import hlibbabii.yahtzee.model.DiceLayout; 4 | 5 | public class Yahtzee extends Combination { 6 | 7 | public static final Yahtzee YAHTZEE = new Yahtzee(); 8 | 9 | private static final int SCORES_FOR_YAHTZEE = 50; 10 | 11 
| @Override 12 | public int earnedScores(DiceLayout diceLayout) { 13 | return diceLayout.toCounts().keySet().size() == 1 ? SCORES_FOR_YAHTZEE : 0; 14 | } 15 | 16 | @Override 17 | public String toString() { 18 | return "Yahtzee"; 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/gameplay/Decision.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.gameplay; 2 | 3 | import hlibbabii.yahtzee.model.DiceLayout; 4 | import hlibbabii.yahtzee.combination.Combination; 5 | 6 | public class Decision { 7 | private DiceLayout fixedDiceLayout; 8 | private Combination combination; 9 | 10 | public Decision(Combination combination) { 11 | this.combination = combination; 12 | } 13 | 14 | public Decision(DiceLayout fixedDiceLayout) { 15 | this.fixedDiceLayout = fixedDiceLayout; 16 | } 17 | 18 | public static Decision fixLayoutAndRoll(DiceLayout fixedDiceLayout) { 19 | return new Decision(fixedDiceLayout); 20 | } 21 | 22 | public boolean isCombinationDecision() { 23 | return this.combination != null; 24 | } 25 | 26 | public static Decision decideCombination(Combination combination) { 27 | return new Decision(combination); 28 | } 29 | 30 | public Combination getCombination() { 31 | if (!this.isCombinationDecision()) { 32 | throw new AssertionError("To get the combination the decision has to be final!"); 33 | } 34 | 35 | return this.combination; 36 | } 37 | 38 | public DiceLayout getDiceDecidedToLeave() { 39 | return this.fixedDiceLayout; 40 | } 41 | } -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/gameplay/Game.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.gameplay; 2 | 3 | import hlibbabii.yahtzee.player.DummyPlayer; 4 | import hlibbabii.yahtzee.Player; 5 | import hlibbabii.yahtzee.model.DiceLayout; 6 | import hlibbabii.yahtzee.combination.Combination; 7 | import hlibbabii.yahtzee.util.RandomService; 8 | 9 | import java.util.*; 10 | import java.util.Map.Entry; 11 | 12 | public class Game { 13 | 14 | public static final int N_DICE = 5; 15 | private static final Integer NUMBER_OF_ATTEMPTS = 3; 16 | 17 | private final Player player1; 18 | private final Player player2; 19 | private final GameStats gameStats; 20 | 21 | 22 | public Game(Player player1, Player player2) { 23 | this.player1 = player1; 24 | this.player2 = player2; 25 | this.gameStats = new GameStats(player1, player2); 26 | } 27 | 28 | public static void main(String[] args) { 29 | GameStats gameStats = new Game(new DummyPlayer(), new DummyPlayer()).play(); 30 | System.out.println(gameStats.getFinalPoints()); 31 | } 32 | 33 | public GameStats play() { 34 | while (this.gameStats.combinationsAvailable()) { 35 | MoveResult moveResult1 = this.makeMove(this.player1); 36 | this.gameStats.addMoveResult(moveResult1); 37 | 38 | MoveResult moveResult2 = this.makeMove(this.player2); 39 | this.gameStats.addMoveResult(moveResult2); 40 | } 41 | return this.gameStats; 42 | } 43 | 44 | private MoveResult makeMove(Player player) { 45 | PlayerStats playerStats = this.gameStats.getPlayerStats(player); 46 | Set availableCombinations = playerStats.getAvailableCombinations(); 47 | 48 | DiceLayout fixedDiceLayout = DiceLayout.empty(); 49 | int howManyToRoll = N_DICE; 50 | for (int i = NUMBER_OF_ATTEMPTS; i > 0; i--) { 51 | DiceLayout 
nonFixedDiceLayout = this.roll(howManyToRoll); 52 | Decision decision; 53 | try { 54 | decision = player.makeDecision(nonFixedDiceLayout, fixedDiceLayout, availableCombinations, i - 1); 55 | } catch (Exception e) { 56 | throw new PlayerException(e); 57 | } 58 | if (decision.isCombinationDecision()) { 59 | return MoveResult.create(player, decision, nonFixedDiceLayout); 60 | } else { 61 | DiceLayout currentFixedDiceLayout = decision.getDiceDecidedToLeave(); 62 | this.checkFixedDiceLayoutValid(currentFixedDiceLayout, fixedDiceLayout); 63 | fixedDiceLayout = currentFixedDiceLayout; 64 | howManyToRoll = N_DICE - fixedDiceLayout.getSize(); 65 | } 66 | } 67 | throw new AssertionError("Combination decision should have already been made!"); 68 | } 69 | 70 | private void checkFixedDiceLayoutValid(DiceLayout newFixedDiceLayout, DiceLayout previousFixedDiceLayout) { 71 | for (Entry previousAlignmentEntry: previousFixedDiceLayout.toCounts().entrySet()) { 72 | if (previousAlignmentEntry.getValue() > newFixedDiceLayout.toCounts().get(previousAlignmentEntry.getKey())) { 73 | throw new AssertionError("The dice which once were left on the table cannot be rerolled later"); 74 | } 75 | } 76 | } 77 | 78 | public DiceLayout roll() { 79 | return this.roll(N_DICE); 80 | } 81 | 82 | private DiceLayout roll(int howMany) { 83 | RandomService random = new RandomService(); 84 | int[] arr = new int[howMany]; 85 | for (int i = 0; i < howMany; i++) { 86 | arr[i] = random.getRandomDiceNumber(); 87 | } 88 | return DiceLayout.fromNumbers(arr); 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/gameplay/GameStats.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.gameplay; 2 | 3 | import hlibbabii.yahtzee.Player; 4 | 5 | import java.util.HashMap; 6 | import java.util.Map; 7 | import java.util.stream.Collectors; 8 | 9 | public class GameStats { 10 | Map playerStatsMap; 11 | 12 | public GameStats(Player player1, Player player2) { 13 | this.playerStatsMap = new HashMap() {{ 14 | this.put(player1, new PlayerStats()); 15 | this.put(player2, new PlayerStats()); 16 | }}; 17 | } 18 | 19 | public boolean combinationsAvailable() { 20 | PlayerStats anyPlayerStats = this.playerStatsMap.values().iterator().next(); 21 | return !anyPlayerStats.getAvailableCombinations().isEmpty(); 22 | } 23 | 24 | public void addMoveResult(MoveResult moveResult) { 25 | Player player = moveResult.getPlayer(); 26 | PlayerStats playerStats = this.playerStatsMap.get(player); 27 | playerStats.put(moveResult.getCombination(), moveResult.getScore()); 28 | } 29 | 30 | public PlayerStats getPlayerStats(Player player) { 31 | return this.playerStatsMap.get(player); 32 | } 33 | 34 | public Map getFinalPoints() { 35 | return this.playerStatsMap.entrySet().stream() 36 | .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getFinalPoints())); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/gameplay/MoveResult.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.gameplay; 2 | 3 | import hlibbabii.yahtzee.model.DiceLayout; 4 | import hlibbabii.yahtzee.Player; 5 | import hlibbabii.yahtzee.combination.Combination; 6 | 7 | import java.util.Objects; 8 | 9 | public class MoveResult { 10 | private final Player player; 
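// Immutable record of one finished move: the player, the combination they settled on, and the score it earned.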
11 | private final Combination combination; 12 | private final int score; 13 | 14 | private MoveResult(Player player, Combination combination, int scores) { 15 | this.player = player; 16 | this.combination = combination; 17 | this.score = scores; 18 | } 19 | 20 | public static MoveResult create(Player player, Decision decision, DiceLayout nonFixedDiceLayout) { 21 | Combination combination = decision.getCombination(); 22 | int scores = combination.earnedScores(nonFixedDiceLayout); 23 | return new MoveResult(player, combination, scores); 24 | } 25 | 26 | public Player getPlayer() { 27 | return this.player; 28 | } 29 | 30 | public Combination getCombination() { 31 | return this.combination; 32 | } 33 | 34 | public int getScore() { 35 | return this.score; 36 | } 37 | 38 | @Override 39 | public boolean equals(Object o) { 40 | if (this == o) return true; 41 | if (o == null || this.getClass() != o.getClass()) return false; 42 | MoveResult that = (MoveResult) o; 43 | return this.score == that.score && 44 | Objects.equals(this.player, that.player) && 45 | Objects.equals(this.combination, that.combination); 46 | } 47 | 48 | @Override 49 | public int hashCode() { 50 | return Objects.hash(this.player, this.combination, this.score); 51 | } 52 | 53 | @Override 54 | public String toString() { 55 | return "MoveResult{" + 56 | "player=" + player + 57 | ", combination=" + combination + 58 | ", score=" + score + 59 | '}'; 60 | } 61 | } -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/gameplay/PlayerException.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.gameplay; 2 | 3 | public class PlayerException extends RuntimeException { 4 | public PlayerException(Exception e) { 5 | super(e); 6 | } 7 | } -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/gameplay/PlayerStats.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.gameplay; 2 | 3 | 4 | import hlibbabii.yahtzee.combination.*; 5 | 6 | import java.util.HashMap; 7 | import java.util.Map; 8 | import java.util.Map.Entry; 9 | import java.util.Set; 10 | import java.util.stream.Collectors; 11 | 12 | public class PlayerStats { 13 | 14 | private static final int NUMBERS_VALUE_REACH_TO_GET_BONUS = 63; 15 | private static final int BONUS_VALUE = 50; 16 | 17 | private Map combinations; 18 | 19 | private void initAllCombinations() { 20 | this.combinations = new HashMap<>(); 21 | /* Numbers*/ 22 | this.combinations.put(Numbers.ACES, null); 23 | this.combinations.put(Numbers.TWOS, null); 24 | this.combinations.put(Numbers.THREES, null); 25 | this.combinations.put(Numbers.FOURS, null); 26 | this.combinations.put(Numbers.FIVES, null); 27 | this.combinations.put(Numbers.SIXES, null); 28 | 29 | this.combinations.put(NOfAKind.PAIR, null); 30 | this.combinations.put(NOfAKind.THREE_OF_A_KIND, null); 31 | this.combinations.put(NOfAKind.FOUR_OF_A_KIND, null); 32 | 33 | this.combinations.put(FullHouse.FULL_HOUSE, null); 34 | this.combinations.put(TwoPairs.TWO_PAIRS, null); 35 | 36 | this.combinations.put(SmallStraight.SMALL_STRAIGHT, null); 37 | this.combinations.put(LargeStraight.LARGE_STRAIGHT, null); 38 | 39 | this.combinations.put(Chance.CHANCE, null); 40 | this.combinations.put(Yahtzee.YAHTZEE, null); 41 | } 42 | 43 | public PlayerStats() { 44 | 
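// Start with every combination unplayed: a null score marks a combination as still available.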
this.initAllCombinations(); 45 | } 46 | 47 | public Set getAvailableCombinations() { 48 | return this.combinations.entrySet().stream().filter(e -> e.getValue() == null).map(Entry::getKey).collect(Collectors.toSet()); 49 | } 50 | 51 | public void put(Combination combination, int score) { 52 | if (this.combinations.get(combination) != null) { 53 | throw new AssertionError(String.format("Combination %s has already been played", combination)); 54 | } 55 | this.combinations.put(combination, score); 56 | } 57 | 58 | private int getUpperSectionPoints() { 59 | return this.combinations.entrySet().stream().filter(e -> e.getKey() instanceof Numbers) 60 | .map(Entry::getValue).mapToInt(e -> e).sum(); 61 | } 62 | 63 | private int getLowerSectionPoints() { 64 | return this.combinations.entrySet().stream().filter(e -> !(e.getKey() instanceof Numbers)) 65 | .map(Entry::getValue).mapToInt(e -> e).sum(); 66 | } 67 | 68 | private boolean bonusEarned() { 69 | return this.getUpperSectionPoints() >= NUMBERS_VALUE_REACH_TO_GET_BONUS; 70 | } 71 | 72 | public Integer getFinalPoints() { 73 | return this.getLowerSectionPoints() + this.getUpperSectionPoints() + (this.bonusEarned() ? BONUS_VALUE : 0); 74 | } 75 | 76 | @Override 77 | public String toString() { 78 | return "PlayerStats{" + 79 | "combinations=" + combinations + 80 | '}'; 81 | } 82 | } -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/model/DiceLayout.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.model; 2 | 3 | import hlibbabii.yahtzee.DiceValues; 4 | import hlibbabii.yahtzee.gameplay.Game; 5 | 6 | import java.util.*; 7 | import java.util.stream.Collectors; 8 | import java.util.stream.IntStream; 9 | import java.util.stream.Stream; 10 | 11 | public class DiceLayout { 12 | 13 | private static final String INVALID_DICE_LAYOUT_MESSAGE_TEMPLATE = "Invalid dice layout: %s. %s"; 14 | 15 | private SortedMap valuesToOccurences; 16 | 17 | private DiceLayout(SortedMap valuesToOccurences) { 18 | this.valuesToOccurences = valuesToOccurences; 19 | 20 | this.checkAlignmentInvariants(valuesToOccurences); 21 | } 22 | 23 | private void checkAlignmentInvariants(Map valuesToOccurences) { 24 | if (this.getSize() > Game.N_DICE) { 25 | String reason = String.format("The number of dice cannot be more than %s.", Game.N_DICE); 26 | throw new IllegalArgumentException(String.format(INVALID_DICE_LAYOUT_MESSAGE_TEMPLATE, Arrays.toString(this.toSortedNumbers()), reason)); 27 | } 28 | for (Integer rolledNumber : valuesToOccurences.keySet()) { 29 | if (rolledNumber < DiceValues.MIN_DICE_VALUE || rolledNumber > DiceValues.MAX_DICE_VALUE) { 30 | String reason = "Rolled numbers contain invalid values."; 31 | throw new IllegalArgumentException(String.format(INVALID_DICE_LAYOUT_MESSAGE_TEMPLATE, Arrays.toString(this.toSortedNumbers()), reason)); 32 | } 33 | } 34 | } 35 | 36 | /* Dice layout creation options */ 37 | 38 | public static DiceLayout empty() { 39 | return new DiceLayout(new TreeMap<>()); 40 | } 41 | 42 | public static DiceLayout fromMap(SortedMap valuesToOccurences) { 43 | return new DiceLayout(valuesToOccurences); 44 | } 45 | 46 | private static DiceLayout fromStream(Stream stream) { 47 | return DiceLayout.fromMap(new TreeMap<>( 48 | stream.collect(Collectors.groupingBy((a) -> a, Collectors.summingInt((e)->1))) 49 | )); 50 | } 51 | 52 | public static DiceLayout fromNumbers(int... 
numbers) { 53 | return fromStream(IntStream.of(numbers).boxed()); 54 | } 55 | 56 | public static DiceLayout fromNumbers(List<Integer> rolledNumbers) { 57 | return fromStream(rolledNumbers.stream()); 58 | } 59 | 60 | /* Dice layout representation options */ 61 | 62 | public Map<Integer, Integer> toCounts() { 63 | return this.valuesToOccurences; 64 | } 65 | 66 | public int getCount(int number) { 67 | return toCounts().getOrDefault(number, 0); 68 | } 69 | 70 | public int[] toSortedNumbers() { 71 | return this.valuesToOccurences.entrySet().stream() 72 | .flatMapToInt((a) -> IntStream.range(0, a.getValue()).map((k) -> a.getKey())).toArray(); 73 | } 74 | 75 | /* */ 76 | 77 | public int getSize() { 78 | return this.valuesToOccurences.values().stream().mapToInt(e -> e).sum(); 79 | } 80 | 81 | @Override 82 | public boolean equals(Object o) { 83 | if (this == o) return true; 84 | if (o == null || this.getClass() != o.getClass()) return false; 85 | DiceLayout that = (DiceLayout) o; 86 | return Objects.equals(this.valuesToOccurences, that.valuesToOccurences); 87 | } 88 | 89 | @Override 90 | public int hashCode() { 91 | return Objects.hash(this.valuesToOccurences); 92 | } 93 | 94 | @Override 95 | public String toString() { 96 | return "DiceLayout{" + 97 | "valuesToOccurences=" + Arrays.toString(this.toSortedNumbers()) + 98 | '}'; 99 | } 100 | } -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/player/DummyPlayer.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.player; 2 | 3 | import hlibbabii.yahtzee.Player; 4 | import hlibbabii.yahtzee.gameplay.Decision; 5 | import hlibbabii.yahtzee.model.DiceLayout; 6 | import hlibbabii.yahtzee.combination.Combination; 7 | 8 | import java.util.Set; 9 | 10 | /** 11 | * This player rolls the dice as many times as he/she/it can and chooses the first combination which is 12 | * available (even if it will give 0 points and there are other combinations which could give more than 0 points).
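* While rolls remain, it keeps the dice that are already fixed and simply rolls again.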
13 | */ 14 | public class DummyPlayer implements Player { 15 | 16 | @Override 17 | public Decision makeDecision(DiceLayout diceLayout, DiceLayout fixedDiceLayout, Set<Combination> availableCombinations, int rollsLeft) { 18 | if (rollsLeft == 0) { 19 | return Decision.decideCombination(availableCombinations.iterator().next()); 20 | // availableCombinations.stream().map(c -> c.earnedScores(diceLayout)).max(Comparator.comparingInt(a->a)); 21 | } else { 22 | return Decision.fixLayoutAndRoll(fixedDiceLayout); 23 | } 24 | } 25 | } -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/util/RandomService.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.util; 2 | 3 | import java.util.Random; 4 | 5 | public class RandomService { 6 | 7 | private final Random rand; 8 | 9 | public RandomService() { 10 | this.rand = new Random(); 11 | } 12 | 13 | public int getRandomDiceNumber() { 14 | return this.rand.nextInt(6) + 1; 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/test/java/hlibbabii/yahtzee/combination/ChanceTest.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | import hlibbabii.yahtzee.model.DiceLayout; 4 | import org.junit.jupiter.api.Test; 5 | 6 | import static org.junit.jupiter.api.Assertions.*; 7 | 8 | class ChanceTest { 9 | @Test 10 | void testChance() { 11 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,3,3,6,6); 12 | int actual = Chance.CHANCE.earnedScores(diceLayout); 13 | assertEquals(19, actual); 14 | } 15 | } -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/test/java/hlibbabii/yahtzee/combination/FullHouseTest.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | import hlibbabii.yahtzee.model.DiceLayout; 4 | import org.junit.jupiter.api.Test; 5 | 6 | import static org.junit.jupiter.api.Assertions.*; 7 | 8 | class FullHouseTest { 9 | @Test 10 | void testFullHouseAllDistinct() { 11 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,3,4,5,6); 12 | int actual = FullHouse.FULL_HOUSE.earnedScores(diceLayout); 13 | assertEquals(0, actual); 14 | } 15 | 16 | @Test 17 | void testFullHouseYahtzeePresent() { 18 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,2,2,2,2); 19 | int actual = FullHouse.FULL_HOUSE.earnedScores(diceLayout); 20 | assertEquals(0, actual); 21 | } 22 | 23 | @Test 24 | void testFullHouseThreeOfAKindPresent() { 25 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,2,2,3,4); 26 | int actual = FullHouse.FULL_HOUSE.earnedScores(diceLayout); 27 | assertEquals(0, actual); 28 | } 29 | 30 | @Test 31 | void testFullHousePositive() { 32 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,2,2,3,3); 33 | int actual = FullHouse.FULL_HOUSE.earnedScores(diceLayout); 34 | assertEquals(12, actual); 35 | } 36 | 37 | } -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/test/java/hlibbabii/yahtzee/combination/LargeStraightTest.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | import hlibbabii.yahtzee.model.DiceLayout; 4 | import org.junit.jupiter.api.Test; 5 | 6 | import static
org.junit.jupiter.api.Assertions.*; 7 | 8 | class LargeStraightTest { 9 | 10 | @Test 11 | void earnedScoresSmallStraight() { 12 | DiceLayout diceLayout = DiceLayout.fromNumbers(1, 2, 3, 4, 5); 13 | int actual = new LargeStraight().earnedScores(diceLayout); 14 | assertEquals(0, actual); 15 | } 16 | 17 | @Test 18 | void earnedScoresLargeStraight() { 19 | DiceLayout diceLayout = DiceLayout.fromNumbers(2, 3, 4, 5, 6); 20 | int actual = new LargeStraight().earnedScores(diceLayout); 21 | assertEquals(20, actual); 22 | } 23 | 24 | @Test 25 | void earnedScoresNoStraight() { 26 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,1,1,2,3); 27 | int actual = new LargeStraight().earnedScores(diceLayout); 28 | assertEquals(0, actual); 29 | } 30 | 31 | } -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/test/java/hlibbabii/yahtzee/combination/NOfAKindTest.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | 4 | import hlibbabii.yahtzee.model.DiceLayout; 5 | import org.junit.jupiter.api.Assertions; 6 | import org.junit.jupiter.api.Test; 7 | 8 | class NOfAKindTest { 9 | @Test 10 | public void testNoPairs() { 11 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,2,3,6,5); 12 | int actual = NOfAKind.PAIR.earnedScores(diceLayout); 13 | Assertions.assertEquals(0, actual); 14 | } 15 | 16 | @Test 17 | public void testHighPair() { 18 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,2,3,5,5); 19 | int actual = NOfAKind.PAIR.earnedScores(diceLayout); 20 | Assertions.assertEquals(10, actual); 21 | } 22 | 23 | @Test 24 | public void testMediumPair() { 25 | DiceLayout diceLayout = DiceLayout.fromNumbers(6,1,2,2,5); 26 | int actual = NOfAKind.PAIR.earnedScores(diceLayout); 27 | Assertions.assertEquals(4, actual); 28 | } 29 | 30 | @Test 31 | public void testPairLow() { 32 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,2,3,6,1); 33 | int actual = NOfAKind.PAIR.earnedScores(diceLayout); 34 | Assertions.assertEquals(2, actual); 35 | } 36 | 37 | @Test 38 | public void testPairTwoPairsPresent() { 39 | DiceLayout diceLayout = DiceLayout.fromNumbers(6,5,3,3,5); 40 | int actual = NOfAKind.PAIR.earnedScores(diceLayout); 41 | Assertions.assertEquals(10, actual); 42 | } 43 | 44 | @Test 45 | public void testPairThreeOfAKindPresent() { 46 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,3,5,3,3); 47 | int actual = NOfAKind.PAIR.earnedScores(diceLayout); 48 | Assertions.assertEquals(6, actual); 49 | } 50 | 51 | @Test 52 | public void testPairFullHouseHighPresent() { 53 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,4,4,4,2); 54 | int actual = NOfAKind.PAIR.earnedScores(diceLayout); 55 | Assertions.assertEquals(8, actual); 56 | } 57 | 58 | @Test 59 | public void testPairFullHouseLowPresent() { 60 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,4,4,2,2); 61 | int actual = NOfAKind.PAIR.earnedScores(diceLayout); 62 | Assertions.assertEquals(8, actual); 63 | } 64 | } -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/test/java/hlibbabii/yahtzee/combination/SmallStraightTest.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | import hlibbabii.yahtzee.model.DiceLayout; 4 | import org.junit.jupiter.api.Test; 5 | 6 | import static org.junit.jupiter.api.Assertions.*; 7 | 8 | class SmallStraightTest { 9 | 10 | @Test 11 | void 
earnedScoresSmallStraight() { 12 | DiceLayout diceLayout = DiceLayout.fromNumbers(1, 2, 3, 4, 5); 13 | int actual = new SmallStraight().earnedScores(diceLayout); 14 | assertEquals(15, actual); 15 | } 16 | 17 | @Test 18 | void earnedScoresLargeStraight() { 19 | DiceLayout diceLayout = DiceLayout.fromNumbers(2, 3, 4, 5, 6); 20 | int actual = new SmallStraight().earnedScores(diceLayout); 21 | assertEquals(0, actual); 22 | } 23 | 24 | @Test 25 | void earnedScoresNoStraight() { 26 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,1,1,2,3); 27 | int actual = new SmallStraight().earnedScores(diceLayout); 28 | assertEquals(0, actual); 29 | } 30 | } -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/test/java/hlibbabii/yahtzee/combination/TwoPairsTest.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.combination; 2 | 3 | import hlibbabii.yahtzee.model.DiceLayout; 4 | import org.junit.jupiter.api.Test; 5 | 6 | import static hlibbabii.yahtzee.combination.TwoPairs.TWO_PAIRS; 7 | import static org.junit.jupiter.api.Assertions.*; 8 | 9 | class TwoPairsTest { 10 | 11 | @Test 12 | void testTwoPairsAllDistinct() { 13 | DiceLayout diceLayout = DiceLayout.fromNumbers(1, 2, 3, 5, 6); 14 | int actual = TWO_PAIRS.earnedScores(diceLayout); 15 | assertEquals(0, actual); 16 | } 17 | 18 | @Test 19 | void testTwoPairsOnePairPresent() { 20 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,2,2,3,4); 21 | int actual = TWO_PAIRS.earnedScores(diceLayout); 22 | assertEquals(0, actual); 23 | } 24 | 25 | @Test 26 | void testTwoPairsTwoPairsPresent() { 27 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,1,4,5,5); 28 | int actual = TWO_PAIRS.earnedScores(diceLayout); 29 | assertEquals(12, actual); 30 | } 31 | 32 | @Test 33 | void testTwoPairsFullHousePresent() { 34 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,2,2,6,6); 35 | int actual = TWO_PAIRS.earnedScores(diceLayout); 36 | assertEquals(16, actual); 37 | } 38 | 39 | @Test 40 | void testTwoPairsFourOfAKindPresent() { 41 | DiceLayout diceLayout = DiceLayout.fromNumbers(3,5,5,5,5); 42 | int actual = TWO_PAIRS.earnedScores(diceLayout); 43 | assertEquals(0, actual); 44 | } 45 | } -------------------------------------------------------------------------------- /test-data/test-corpus/yahtzee/src/test/java/hlibbabii/yahtzee/model/DiceLayoutTest.java: -------------------------------------------------------------------------------- 1 | package hlibbabii.yahtzee.model; 2 | 3 | import org.junit.jupiter.api.Test; 4 | 5 | import java.util.Map; 6 | 7 | import static org.junit.jupiter.api.Assertions.*; 8 | 9 | class DiceLayoutTest { 10 | 11 | @Test 12 | void testEmpty() { 13 | DiceLayout diceLayout = DiceLayout.empty(); 14 | 15 | assertEquals(0, diceLayout.toCounts().entrySet().size()); 16 | } 17 | 18 | @Test 19 | void testFromNumbers() { 20 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,2,3,4,4); 21 | 22 | int[] numbers = diceLayout.toSortedNumbers(); 23 | Map<Integer, Integer> counts = diceLayout.toCounts(); 24 | 25 | assertArrayEquals(new int[]{1, 2, 3, 4, 4}, numbers); 26 | assertEquals(2, (int) counts.get(4)); 27 | } 28 | 29 | } -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import os 6 | import sys 7 | 8 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(__file__)), '..'))
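# The sys.path insertion above makes the in-tree 'codeprep' package importable when the test suite runs from a plain source checkout.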
-------------------------------------------------------------------------------- /tests/api/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /tests/api/test_corpus.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import os 6 | from unittest import mock 7 | from unittest.mock import Mock 8 | 9 | from codeprep.api.corpus import preprocess_corpus 10 | from codeprep.prepconfig import PrepConfig, PrepParam 11 | 12 | PATH_TO_CUR_DIR_STUB = os.path.join('path', 'to', 'curdir') 13 | PATH_TO_DATASET_STUB = os.path.join('path', 'to', 'dataset') 14 | PATH_TO_OUTPUT_STUB = os.path.join('path', 'to', 'output') 15 | 16 | DEFAULT_PREP_CONFIG = PrepConfig({ 17 | PrepParam.EN_ONLY: 'u', 18 | PrepParam.COM: 'c', 19 | PrepParam.STR: '1', 20 | PrepParam.SPLIT: '0', 21 | PrepParam.TABS_NEWLINES: '0', 22 | PrepParam.CASE: 'u', 23 | }) 24 | 25 | 26 | @mock.patch('codeprep.api.corpus.Dataset', autospec=True) 27 | @mock.patch('codeprep.api.corpus.stages', autospec=True) 28 | @mock.patch('codeprep.cli.impl.os.getcwd', autospec=True, return_value=PATH_TO_CUR_DIR_STUB) 29 | def test_simple(os_mock, stages_mock, dataset_mock): 30 | # given 31 | dataset_mock.create = Mock(spec=dataset_mock, return_value=dataset_mock) 32 | 33 | 34 | # when 35 | preprocess_corpus(PATH_TO_DATASET_STUB, DEFAULT_PREP_CONFIG) 36 | 37 | # then 38 | dataset_mock.create.assert_called_with(PATH_TO_DATASET_STUB, DEFAULT_PREP_CONFIG, None, None, 39 | overriden_path_to_prep_dataset=PATH_TO_CUR_DIR_STUB, suppress_caching=False) 40 | stages_mock.run_until_preprocessing.assert_called_with(dataset_mock, None) 41 | 42 | 43 | @mock.patch('codeprep.api.corpus.Dataset', autospec=True) 44 | @mock.patch('codeprep.api.corpus.stages', autospec=True) 45 | @mock.patch('codeprep.cli.impl.os.getcwd', autospec=True, return_value=PATH_TO_CUR_DIR_STUB) 46 | def test_calc_vocab(os_mock, stages_mock, dataset_mock): 47 | # given 48 | dataset_mock.create = Mock(spec=dataset_mock, return_value=dataset_mock) 49 | 50 | # when 51 | preprocess_corpus(PATH_TO_DATASET_STUB, DEFAULT_PREP_CONFIG, calc_vocab=True, suppress_caching=True) 52 | 53 | # then 54 | dataset_mock.create.assert_called_with(PATH_TO_DATASET_STUB, DEFAULT_PREP_CONFIG, None, None, 55 | overriden_path_to_prep_dataset=PATH_TO_CUR_DIR_STUB, suppress_caching=True) 56 | stages_mock.run_until_vocab.assert_called_with(dataset_mock, None) 57 | 58 | 59 | @mock.patch('codeprep.api.corpus.Dataset', autospec=True) 60 | @mock.patch('codeprep.api.corpus.stages', autospec=True) 61 | def test_output(stages_mock, dataset_mock): 62 | # given 63 | dataset_mock.create = Mock(spec=dataset_mock, return_value=dataset_mock) 64 | 65 | # when 66 | preprocess_corpus(PATH_TO_DATASET_STUB, DEFAULT_PREP_CONFIG, output_path=PATH_TO_OUTPUT_STUB) 67 | 68 | # then 69 | dataset_mock.create.assert_called_with(PATH_TO_DATASET_STUB, DEFAULT_PREP_CONFIG, None, None, 70 | overriden_path_to_prep_dataset=PATH_TO_OUTPUT_STUB, suppress_caching=False) 71 | stages_mock.run_until_preprocessing.assert_called_with(dataset_mock, None) -------------------------------------------------------------------------------- /tests/bpepkg/__init__.py: 
-------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /tests/bpepkg/test_merge.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from unittest import mock 6 | 7 | from unittest.mock import MagicMock 8 | 9 | import pytest 10 | from pytest import fixture 11 | 12 | import codeprep 13 | from codeprep.bpepkg import merge 14 | from codeprep.bpepkg.merge import MergeList, Merge 15 | 16 | 17 | @fixture 18 | def file_handle_mock(mocker): 19 | mocker.patch('codeprep.bpepkg.merge.open') 20 | codeprep.bpepkg.merge.open.return_value = MagicMock(spec=['__enter__', '__exit__']) 21 | handle = codeprep.bpepkg.merge.open.return_value.__enter__.return_value 22 | return handle 23 | 24 | 25 | def test_read_merges(file_handle_mock): 26 | file_handle_mock.__iter__.return_value = iter(['a b 67', 'b c 34', 'c d 94']) 27 | 28 | actual = merge.read_merges('file', 2) 29 | expected = MergeList().append(Merge(('a', 'b'), 67, 0)).append(Merge(('b', 'c'), 34, 1)) 30 | 31 | assert expected == actual 32 | 33 | 34 | def test_read_merges_with_wrong_delimiter(file_handle_mock): 35 | with pytest.raises(ValueError): 36 | file_handle_mock.__iter__.return_value = iter(['a\tb\t67']) 37 | 38 | merge.read_merges('file') 39 | 40 | 41 | def test_dump_merges(file_handle_mock): 42 | merges = MergeList().append(Merge(('a', 'b'), 67, 0)).append(Merge(('b', 'c'), 34, 1)) 43 | merge.dump_merges(merges, 'file') 44 | 45 | file_handle_mock.write.assert_has_calls([ 46 | mock.call('a b 67\n'), 47 | mock.call('b c 34\n') 48 | ]) -------------------------------------------------------------------------------- /tests/bpepkg/wild_bpe_performance.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import random 6 | 7 | from typing import List 8 | 9 | from codeprep.bpepkg import wild_bpe 10 | from codeprep.bpepkg.wild_bpe import BpePerformanceStatsEntry, run 11 | 12 | 13 | def gen_performance_test_case(data_size_mb: float, entropy: int): 14 | char_list = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] 15 | for i in range(int(data_size_mb * (2 ** 20))): 16 | w = random.choice(char_list[:2 ** entropy]) 17 | for i in range(len(w)): 18 | yield w[i] 19 | 20 | 21 | def plotting_function(data_size_mb: float, entropy: int, version: str, 22 | performance_stats: List[BpePerformanceStatsEntry], 23 | final_show: bool = True): 24 | merges_done = list(map(lambda p: p.merges_done, performance_stats)) 25 | n_pq_entries = list(map(lambda p: p.n_priority_queue_entries, performance_stats)) 26 | n_index_entries = list(map(lambda p: p.n_index_entries, performance_stats)) 27 | location_index_obj_size = list(map(lambda p: p.location_index_obj_size, performance_stats)) 28 | neighbour_index_obj_size = list(map(lambda p: p.neighbour_index_obj_size, performance_stats)) 29 | priority_counter_obj_size = list(map(lambda p: p.priority_counter_obj_size, performance_stats)) 30 | total_index_size = [a + b + c for a, b, c in 31 | zip(location_index_obj_size, neighbour_index_obj_size, priority_counter_obj_size)] 32 | merge_time = list(map(lambda p: p.time_for_last_merge, performance_stats)) 33 | 34 | import matplotlib.pyplot as 
plt 35 | fig, splots = plt.subplots(nrows=3) 36 | splots[0].plot(merges_done, n_pq_entries, label='priority counter') 37 | splots[0].plot(merges_done, n_index_entries, label='indices') 38 | splots[0].set(ylabel='entries', 39 | title=f'Wild BPE version {version}, Data size: {data_size_mb} MB, entropy: {entropy} bit') 40 | splots[0].legend(loc='upper right') 41 | 42 | splots[1].plot(merges_done, location_index_obj_size, label='location index') 43 | splots[1].plot(merges_done, neighbour_index_obj_size, label='neighbour index') 44 | splots[1].plot(merges_done, priority_counter_obj_size, label='priority counter') 45 | splots[1].plot(merges_done, total_index_size, label='total') 46 | splots[1].set(xlabel='number of merges', ylabel='memory consumed (MB)') 47 | splots[1].legend(loc='upper right') 48 | 49 | splots[2].plot(merges_done, merge_time) 50 | splots[2].set(xlabel='number of merges', ylabel='time per merge (s)') 51 | 52 | fig.savefig(f'{version}_{data_size_mb}_{entropy}bit.png') 53 | if final_show: 54 | plt.show() 55 | 56 | 57 | def test_performance(): 58 | test_cases = [ 59 | {'mb': 0.05, 'entropy': 1}, 60 | {'mb': 0.05, 'entropy': 2}, 61 | {'mb': 0.05, 'entropy': 3}, 62 | {'mb': 0.5, 'entropy': 1}, 63 | {'mb': 0.5, 'entropy': 2}, 64 | {'mb': 0.5, 'entropy': 3}, 65 | {'mb': 5, 'entropy': 1}, 66 | {'mb': 5, 'entropy': 2}, 67 | {'mb': 5, 'entropy': 3}, 68 | ] 69 | stats_every_n = 50 70 | 71 | for test_case in test_cases: 72 | gen = run(gen_performance_test_case(test_case['mb'], test_case['entropy']), include_performance_stats_every_n_merges=stats_every_n) 73 | for i in range(1000): 74 | try: 75 | merge, occurences, stats = next(gen) 76 | except StopIteration: 77 | break  # the generator may be exhausted before 1000 merges; stop instead of spinning through the remaining iterations 78 | plotting_function(test_case['mb'], test_case['entropy'], wild_bpe.__version__, stats) 79 | 80 | 81 | if __name__ == '__main__': 82 | test_performance() -------------------------------------------------------------------------------- /tests/cli/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /tests/infrastructure/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /tests/infrastructure/test_bpelearner.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from unittest import mock 6 | 7 | import pytest 8 | 9 | from codeprep.bpepkg.bpe_config import BpeConfig, BpeParam, BpeConfigNotSupported 10 | from codeprep.pipeline.bpelearner import run 11 | 12 | 13 | @mock.patch('codeprep.pipeline.bpelearner.Dataset', autospec=True) 14 | def test_run_word_end(mocked_dataset): 15 | bpe_config = BpeConfig({ 16 | BpeParam.BASE: 'code', 17 | BpeParam.WORD_END: True, 18 | BpeParam.UNICODE: 'yes', 19 | BpeParam.CASE: 'yes' 20 | }) 21 | with pytest.raises(BpeConfigNotSupported): 22 | run(mocked_dataset, 1, bpe_config) 23 | 24 | 25 | @mock.patch('codeprep.pipeline.bpelearner.Dataset', autospec=True) 26 | def test_run_bytes_bpe(mocked_dataset): 27 | bpe_config = BpeConfig({ 28 | BpeParam.BASE: 'code', 29 | BpeParam.WORD_END: False, 30 | BpeParam.UNICODE: 'bytes', 31 | BpeParam.CASE:
'yes' 32 | }) 33 | with pytest.raises(BpeConfigNotSupported): 34 | run(mocked_dataset, 1, bpe_config) -------------------------------------------------------------------------------- /tests/infrastructure/test_bperegistry.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import os 6 | from unittest.mock import patch 7 | 8 | from codeprep.pipeline.bperegistry import get_max_merges, format_available_merge_list_ids, get_min_merges 9 | from codeprep.pipeline.dataset import create_new_id_from 10 | 11 | 12 | PATH_TO_DATASET_BPE_DIR_STUB = os.path.join('/', 'path', 'to', 'dataset', 'bpe', 'dir') 13 | PATH_TO_DATASET_STUB = os.path.join('/', 'path', 'to', 'dataset') 14 | HLIB_PATH = '/home/hlib/path' 15 | 16 | 17 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True) 18 | def test_with_predefined_id(bpe_config_mock): 19 | bpe_config_mock.to_suffix.return_value = '' 20 | assert create_new_id_from(PATH_TO_DATASET_STUB, bpe_config_mock, 'id23') == 'id23' 21 | 22 | 23 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True) 24 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True, return_value={}) 25 | def test_no_existing_bpe_codes(mock, bpe_config_mock): 26 | bpe_config_mock.to_suffix.return_value = '' 27 | assert create_new_id_from(PATH_TO_DATASET_STUB, bpe_config_mock) == 'dataset' 28 | 29 | 30 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True) 31 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True, 32 | return_value={'dataset': 10, 'dataset4': 20, 'dataset_3': 30}) 33 | def test_ids_for_same_dataset_exist(mock, bpe_config_mock): 34 | bpe_config_mock.to_suffix.return_value = '' 35 | assert create_new_id_from(PATH_TO_DATASET_STUB, bpe_config_mock) == 'dataset_4' 36 | 37 | 38 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True) 39 | def test_with_predefined_codes_id(bpe_config_mock): 40 | bpe_config_mock.to_suffix.return_value = "" 41 | assert create_new_id_from(HLIB_PATH, bpe_config_mock, 'my-id') == 'my-id' 42 | 43 | 44 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True) 45 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True, return_value="") 46 | def test_simple(mock, bpe_config_mock): 47 | # given 48 | bpe_config_mock.to_suffix.return_value = "" 49 | 50 | assert create_new_id_from(HLIB_PATH, bpe_config_mock) == 'path' 51 | 52 | 53 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True) 54 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True, 55 | return_value={'path': 1000}) 56 | def test_same_path_exists(mock, bpe_config_mock): 57 | # given 58 | bpe_config_mock.to_suffix.return_value = "" 59 | 60 | assert create_new_id_from(HLIB_PATH, bpe_config_mock) == 'path_1' 61 | 62 | 63 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True) 64 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True, 65 | return_value={'path': 1000, 'path_1': 2000}) 66 | def test_same_path_and_next_one_exist(mock, bpe_config_mock): 67 | # given 68 | bpe_config_mock.to_suffix.return_value = "" 69 | 70 | assert create_new_id_from(HLIB_PATH, bpe_config_mock) == 'path_2' 71 | 72 | 73 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True) 74 | 
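# Suffix numbering continues from the highest existing id ('path_28' here), so gaps are not reused.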
@patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True, 75 | return_value={'path': 1000, 'path_28': 2000}) 76 | def test_same_path_and_one_more_exist(mock, bpe_config_mock): 77 | # given 78 | bpe_config_mock.to_suffix.return_value = "" 79 | 80 | assert create_new_id_from(HLIB_PATH, bpe_config_mock) == 'path_29' 81 | 82 | 83 | @patch('codeprep.pipeline.bperegistry.os.walk', return_value=iter([('', [], [])])) 84 | def test_none(mocked_walk): 85 | assert get_max_merges('.') is None 86 | 87 | 88 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True, return_value={}) 89 | def test_no_available_merge_lists(bpe_registry_mock): 90 | assert format_available_merge_list_ids() == "" 91 | 92 | 93 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True, 94 | return_value={"a": 1000, "b": 500}) 95 | def test_format_available_merge_list_ids(mock): 96 | assert format_available_merge_list_ids() == "a-[1..1000]\nb-[1..500]\n" 97 | 98 | 99 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=[]) 100 | def test_max_no_folders(mock): 101 | assert get_max_merges(PATH_TO_DATASET_BPE_DIR_STUB) is None 102 | 103 | 104 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=[]) 105 | def test_min_no_folders(mock): 106 | assert get_min_merges(PATH_TO_DATASET_BPE_DIR_STUB) is None 107 | 108 | 109 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=['part_vocab']) 110 | def test_max_with_non_number_folder(mock): 111 | assert get_max_merges(PATH_TO_DATASET_BPE_DIR_STUB) is None 112 | 113 | 114 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=['part_vocab']) 115 | def test_min_with_non_number_folder(mock): 116 | assert get_min_merges(PATH_TO_DATASET_BPE_DIR_STUB) is None 117 | 118 | 119 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=['10', '20']) 120 | def test_max_all_folders_above_limit(mock): 121 | assert get_max_merges(PATH_TO_DATASET_BPE_DIR_STUB, 5) is None 122 | 123 | 124 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=['10', '20']) 125 | def test_min_all_folders_below_limit(mock): 126 | assert get_min_merges(PATH_TO_DATASET_BPE_DIR_STUB) == 10 127 | 128 | 129 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=['10', 'partvocab']) 130 | def test_max_one_folder_available(mock): 131 | assert get_max_merges(PATH_TO_DATASET_BPE_DIR_STUB) == 10 132 | 133 | 134 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=['10', 'partvocab']) 135 | def test_min_one_folder_available(mock): 136 | assert get_min_merges(PATH_TO_DATASET_BPE_DIR_STUB) == 10 137 | 138 | 139 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, 140 | return_value=['10', '20', '15', '30', 'partvocab']) 141 | def test_max_simple(mock): 142 | assert get_max_merges(PATH_TO_DATASET_BPE_DIR_STUB, 20) == 20 143 | 144 | 145 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, 146 | return_value=['10', '20', '15', '30', 'partvocab']) 147 | def test_min_simple(mock): 148 | assert get_min_merges(PATH_TO_DATASET_BPE_DIR_STUB, 15) == 15 -------------------------------------------------------------------------------- /tests/infrastructure/test_dataset.py:
-------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import os 6 | from unittest import mock 7 | 8 | from codeprep.bpepkg.bpe_config import BpeConfig, BpeParam 9 | from codeprep.config import USER_CONFIG_DIR, VOCAB_DIR, BPE_DIR, USER_CACHE_DIR 10 | from codeprep.pipeline.bperegistry import CustomBpeConfig 11 | from codeprep.pipeline.dataset import Dataset, SubDataset 12 | from codeprep.prepconfig import PrepConfig, PrepParam 13 | 14 | 15 | PATH_TO_DATASET_STUB = os.path.join('/', 'path', 'to', 'dataset') 16 | PARSED_DATASETS_DIR = os.path.join('/', 'parsed', 'datasets') 17 | PREP_DATASETS_DIR = os.path.join('/', 'prep', 'datasets') 18 | OVERRIDDEN_PATH = os.path.join('/', 'overridden', 'path') 19 | 20 | 21 | @mock.patch('os.path.exists', autospec=True, return_value=True) 22 | @mock.patch('codeprep.pipeline.dataset.get_timestamp', autospec=True, return_value="01_01_01") 23 | @mock.patch('codeprep.pipeline.dataset.DEFAULT_PARSED_DATASETS_DIR', PARSED_DATASETS_DIR) 24 | @mock.patch('codeprep.pipeline.dataset.DEFAULT_PREP_DATASETS_DIR', PREP_DATASETS_DIR) 25 | def test_non_bpe_split(get_timestamp_mock, os_exists_mock): 26 | prep_config = PrepConfig({ 27 | PrepParam.EN_ONLY: 'u', 28 | PrepParam.COM: 'c', 29 | PrepParam.STR: '1', 30 | PrepParam.SPLIT: '0', 31 | PrepParam.TABS_NEWLINES: 's', 32 | PrepParam.CASE: 'u' 33 | }) 34 | 35 | actual = Dataset.create(PATH_TO_DATASET_STUB, prep_config, None, None) 36 | 37 | assert PATH_TO_DATASET_STUB == actual._path 38 | assert prep_config == actual._prep_config 39 | assert actual._normalized_extension_list is None 40 | assert actual._custom_bpe_config is None 41 | assert actual._bpe_config is None 42 | assert '01_01_01' == actual._dataset_last_modified 43 | 44 | assert SubDataset(actual, PATH_TO_DATASET_STUB, '') == actual._original 45 | assert SubDataset(actual, os.path.join(PARSED_DATASETS_DIR, 'dataset_01_01_01'), '.parsed') == actual._parsed 46 | assert SubDataset(actual, os.path.join(PREP_DATASETS_DIR, 'dataset_01_01_01_-_uc10su'), '.prep') == actual._preprocessed 47 | 48 | 49 | @mock.patch('os.path.exists', autospec=True, return_value=True) 50 | @mock.patch('codeprep.pipeline.dataset.get_timestamp', autospec=True, return_value="01_01_01") 51 | @mock.patch('codeprep.pipeline.dataset.DEFAULT_PARSED_DATASETS_DIR', PARSED_DATASETS_DIR) 52 | @mock.patch('codeprep.pipeline.dataset.DEFAULT_PREP_DATASETS_DIR', PREP_DATASETS_DIR) 53 | def test_non_bpe_split_with_one_extension(get_timestamp_mock, os_exists_mock): 54 | prep_config = PrepConfig({ 55 | PrepParam.EN_ONLY: 'u', 56 | PrepParam.COM: 'c', 57 | PrepParam.STR: '1', 58 | PrepParam.SPLIT: '0', 59 | PrepParam.TABS_NEWLINES: 's', 60 | PrepParam.CASE: 'u' 61 | }) 62 | 63 | actual = Dataset.create(PATH_TO_DATASET_STUB, prep_config, "java", None) 64 | 65 | assert PATH_TO_DATASET_STUB == actual._path 66 | assert prep_config == actual._prep_config 67 | assert ['java'] == actual._normalized_extension_list 68 | assert actual._custom_bpe_config is None 69 | assert actual._bpe_config is None 70 | assert '01_01_01' == actual._dataset_last_modified 71 | 72 | assert SubDataset(actual, PATH_TO_DATASET_STUB, '') == actual._original 73 | assert SubDataset(actual, os.path.join(PARSED_DATASETS_DIR, 'dataset_01_01_01_java'), '.parsed') == actual._parsed 74 | assert SubDataset(actual, os.path.join(PREP_DATASETS_DIR, 'dataset_01_01_01_java_-_uc10su'), '.prep') == actual._preprocessed 75 | 76 | 77 |
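# The fully customized case follows: two extensions ('c|java'), a custom BPE merge list, and an overridden prep-output location.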
@mock.patch('os.path.exists', autospec=True, return_value=True) 78 | @mock.patch('codeprep.pipeline.dataset.get_timestamp', autospec=True, return_value="01_01_01") 79 | @mock.patch('codeprep.pipeline.dataset.DEFAULT_PARSED_DATASETS_DIR', PARSED_DATASETS_DIR) 80 | @mock.patch('codeprep.pipeline.dataset.DEFAULT_PREP_DATASETS_DIR', PREP_DATASETS_DIR) 81 | def test_all_custom(get_timestamp_mock, os_exists_mock): 82 | prep_config = PrepConfig({ 83 | PrepParam.EN_ONLY: 'u', 84 | PrepParam.COM: 'c', 85 | PrepParam.STR: '1', 86 | PrepParam.SPLIT: '0', 87 | PrepParam.TABS_NEWLINES: 's', 88 | PrepParam.CASE: 'u' 89 | }) 90 | bpe_config = BpeConfig({ 91 | BpeParam.CASE: 'yes', 92 | BpeParam.WORD_END: False, 93 | BpeParam.BASE: "code", 94 | BpeParam.UNICODE: "no", 95 | }) 96 | 97 | custom_bpe_config = CustomBpeConfig("id", 1000, "/codes/file", "/cache/file") 98 | actual = Dataset.create(PATH_TO_DATASET_STUB, prep_config, "c|java", custom_bpe_config, 99 | bpe_config, overriden_path_to_prep_dataset=OVERRIDDEN_PATH) 100 | 101 | assert PATH_TO_DATASET_STUB == actual._path 102 | assert prep_config == actual._prep_config 103 | assert ['c', 'java'] == actual._normalized_extension_list 104 | assert custom_bpe_config == actual._custom_bpe_config 105 | assert bpe_config == actual._bpe_config 106 | assert '01_01_01' == actual._dataset_last_modified 107 | 108 | assert SubDataset(actual, PATH_TO_DATASET_STUB, '') == actual.original 109 | assert SubDataset(actual, os.path.join(PARSED_DATASETS_DIR, 'dataset_01_01_01_c_java'), '.parsed') == actual.parsed 110 | assert SubDataset(actual, os.path.join(OVERRIDDEN_PATH, 'dataset_01_01_01_c_java_-_uc10su_id-1000_-_prep'), '.prep') == actual.preprocessed 111 | assert os.path.join(USER_CONFIG_DIR, VOCAB_DIR , 'dataset_01_01_01_c_java_-_U0EFsu') == actual.base_bpe_vocab_path 112 | assert os.path.join(USER_CONFIG_DIR, BPE_DIR , 'dataset_01_01_01_c_java_-_nounicode') == actual.bpe_path 113 | assert os.path.join(USER_CACHE_DIR, 'file_lists' , 'dataset_01_01_01_c_java') == actual.path_to_file_list_folder 114 | assert os.path.join(USER_CONFIG_DIR, VOCAB_DIR , 'dataset_01_01_01_c_java_-_uc10su_id-1000') == actual.vocab_path -------------------------------------------------------------------------------- /tests/parse/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /tests/parse/test_subtokens.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from codeprep.tokens.numeric import Number 6 | 7 | from codeprep.parse.matchers import split_into_words 8 | from codeprep.tokens.containers import SplitContainer 9 | from codeprep.tokens.whitespace import NewLine, SpaceInString 10 | from codeprep.tokens.word import Word, Underscore 11 | from codeprep.parse.subtokens import split_string 12 | 13 | 14 | def test_split_into_tokens(): 15 | actual = split_into_words("123\nAb2cd34Ef000GG j_89_J") 16 | 17 | expected = [Number('123'), 18 | NewLine(), 19 | SplitContainer([Word.from_('Ab'), Word.from_('2'), Word.from_('cd'), 20 | Word.from_('34'), Word.from_('Ef'), Word.from_('000'), Word.from_('GG')]), 21 | SplitContainer([Word.from_('j'), Underscore(), Word.from_('89'), Underscore(), Word.from_('J')])] 22 | 23 | assert expected == actual 24 | 25 | 26 | 
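# Unlike split_into_words() above, split_string() keeps runs of spaces as explicit SpaceInString tokens.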
def test_split_string(): 27 | actual = split_string("123\nAb2cd34Ef000GG j_89_J") 28 | 29 | expected = [Number('123'), 30 | NewLine(), 31 | SplitContainer([Word.from_('Ab'), Word.from_('2'), Word.from_('cd'), 32 | Word.from_('34'), Word.from_('Ef'), Word.from_('000'), Word.from_('GG')]), 33 | SpaceInString(5), 34 | SplitContainer([Word.from_('j'), Underscore(), Word.from_('89'), Underscore(), Word.from_('J')])] 35 | 36 | assert expected == actual -------------------------------------------------------------------------------- /tests/test_corpus_b2b.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import os 6 | import platform 7 | import shutil 8 | 9 | from codeprep.cli.spec import parse_and_run 10 | 11 | import codeprep.api.corpus as api 12 | from codeprep.config import root_package_dir 13 | 14 | PATH_TO_TEST_CORPUS = os.path.join(root_package_dir, '..', 'test-data', 'test-corpus') 15 | TEST_OUTPUT = os.path.join(root_package_dir, '..', 'test-output') 16 | 17 | 18 | def test_preprocess_with_different_options(): 19 | calc_vocab = True 20 | api.basic(path=PATH_TO_TEST_CORPUS, extensions="java", output_path=TEST_OUTPUT, calc_vocab=calc_vocab) 21 | api.basic(path=PATH_TO_TEST_CORPUS, extensions="java", split_numbers=True, ronin=True, stem=True, 22 | no_spaces=True, no_unicode=True, no_case=True, no_com=True, no_str=True, max_str_length=30, 23 | output_path=TEST_OUTPUT, calc_vocab=calc_vocab) 24 | api.chars(path=PATH_TO_TEST_CORPUS, extensions="java", output_path=TEST_OUTPUT, calc_vocab=calc_vocab) 25 | api.nosplit(path=PATH_TO_TEST_CORPUS, extensions="java", output_path=TEST_OUTPUT, calc_vocab=calc_vocab) 26 | api.bpe(path=PATH_TO_TEST_CORPUS, bpe_codes_id='10k', extensions="java", output_path=TEST_OUTPUT, calc_vocab=calc_vocab) 27 | 28 | 29 | def test_learn_bpe_codes(): 30 | if platform.system() != 'Darwin': 31 | parse_and_run(['learn-bpe', '100', '-p', PATH_TO_TEST_CORPUS, '-e', 'java']) 32 | parse_and_run(['learn-bpe', '150', '-p', PATH_TO_TEST_CORPUS, '-e', 'java']) 33 | 34 | api.bpe(path=PATH_TO_TEST_CORPUS, bpe_codes_id='test-corpus-130', extensions="java", output_path=TEST_OUTPUT) 35 | else: 36 | print('Skipping the test on OSx.') 37 | 38 | 39 | def teardown_function(function): 40 | print(f'Removing the outputs at: {TEST_OUTPUT}') 41 | if os.path.exists(TEST_OUTPUT): 42 | shutil.rmtree(TEST_OUTPUT) 43 | -------------------------------------------------------------------------------- /tests/test_subword_separation.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2020 Hlib Babii 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # import unittest 6 | # 7 | # from codeprep.bpepkg.bpe_encode import BpeData 8 | # from codeprep.parse.core import convert_text 9 | # from codeprep.parse.model.containers import SplitContainer 10 | # from codeprep.parse.model.numeric import Number 11 | # from codeprep.parse.model.placeholders import placeholders 12 | # from codeprep.parse.model.word import Underscore, Word 13 | # from codeprep.prepconfig import PrepConfig 14 | # from codeprep.to_repr import to_repr 15 | # 16 | # test_cases = { 17 | # "create": ( 18 | # [SplitContainer.from_single_token("create")], 19 | # ["create"], 20 | # ), 21 | # "Vector": ( 22 | # [SplitContainer.from_single_token("Vector")], 23 | # [placeholders["capital"], "vector"], 24 | # ), 25 | # "players": ( 26 | # 
[SplitContainer.from_single_token("players")], 27 | # [placeholders["word_start"], 'play', 'er', 's', placeholders["word_end"]] 28 | # ), 29 | # "0.345e+4": ( 30 | # [Number("0.345e+4")], 31 | # [placeholders["word_start"], "0.", "3", "4", "5", "e+", "4", placeholders["word_end"]] 32 | # ), 33 | # "bestPlayers": ( 34 | # [SplitContainer([Word.from_("best"), Word.from_("Players")])], 35 | # [placeholders["word_start"], "best", placeholders["capital"], 'play', "er", "s", placeholders["word_end"]] 36 | # ), 37 | # "test_BestPlayers": ( 38 | # [SplitContainer([Word.from_("test"), Underscore(), Word.from_("Best"), Word.from_("Players")])], 39 | # [placeholders["word_start"], "test", '_', placeholders["capital"], 40 | # "best", placeholders["capital"], 'play', "er", "s", placeholders["word_end"]] 41 | # ), 42 | # "test_BestPlayers_modified": ( 43 | # [SplitContainer( 44 | # [Word.from_("test"), Underscore(), Word.from_("Best"), Word.from_("Players"), Underscore(), 45 | # Word.from_("modified")] 46 | # )], 47 | # [placeholders["word_start"], "test", '_', placeholders["capital"], 48 | # "best", placeholders["capital"], 'play', "er", "s", '_', "mod", 49 | # "if", "ied", 50 | # placeholders["word_end"]] 51 | # ), 52 | # "N_PLAYERS_NUM": ( 53 | # [SplitContainer([Word.from_("N"), Underscore(), Word.from_("PLAYERS"), Underscore(), Word.from_("NUM")])], 54 | # [placeholders["word_start"], placeholders["capitals"], "n", '_', 55 | # placeholders["capitals"], "play", "er", "s", '_', placeholders["capitals"], 56 | # "num", placeholders["word_end"]] 57 | # ), 58 | # "_players": ( 59 | # [SplitContainer([Underscore(), (Word.from_("players"))])], 60 | # [placeholders['word_start'], '_', "play", "er", "s", placeholders['word_end']] 61 | # ), 62 | # } 63 | # 64 | # bpe_merges_cache = { 65 | # "players": ["play", "er", "s"], 66 | # "0.345e+4": ["0.", "3", "4", "5", "e+", "4"], 67 | # "modified": ["mod", "if", "ied"], 68 | # 69 | # "create": ["create"], 70 | # "vector": ["vector"], 71 | # "best": ["best"], 72 | # "test": ["test"], 73 | # "num": ["num"], 74 | # "user": ["user"], 75 | # "get": ["get"], 76 | # "nick": ["ni", "ck"], 77 | # "logger": ["logger"], 78 | # "info": ["info"] 79 | # } 80 | # 81 | # 82 | # class SubwordSeparation(unittest.TestCase): 83 | # def test(self): 84 | # for input, output_tuple in test_cases.items(): 85 | # parsed = [p for p in convert_text(input, "java")][:-1] 86 | # 87 | # self.assertEqual(output_tuple[0], parsed) 88 | # 89 | # repred, metadata = to_repr(PrepConfig.from_encoded_string('Uc140l'), parsed, BpeData(merges_cache=bpe_merges_cache)) 90 | # 91 | # self.assertEqual(output_tuple[1], repred) 92 | # 93 | # 94 | # if __name__ == '__main__': 95 | # unittest.main() -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # tox (https://tox.readthedocs.io/) is a tool for running tests 2 | # in multiple virtualenvs. This configuration file will run the 3 | # test suite on all supported python versions. To use it, "pip install tox" 4 | # and then run "tox" from this directory. 5 | 6 | [tox] 7 | envlist = py36 8 | 9 | [testenv] 10 | deps = git+https://github.com/casics/spiral 11 | 12 | commands = 13 | python -m unittest discover 14 | --------------------------------------------------------------------------------