├── .gitignore
├── .idea
│   ├── codeprep.iml
│   ├── dictionaries
│   │   └── hlib.xml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── .reuse
│   └── dep5
├── .travis.yml
├── LICENSES
│   ├── Apache-2.0.txt
│   └── CC0-1.0.txt
├── MANIFEST.in
├── README.md
├── codeprep
│   ├── VERSION
│   ├── __init__.py
│   ├── __main__.py
│   ├── api
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── corpus.py
│   │   └── text.py
│   ├── bpepkg
│   │   ├── __init__.py
│   │   ├── bpe_config.py
│   │   ├── bpe_encode.py
│   │   ├── bpe_learn.py
│   │   ├── cache.py
│   │   ├── merge.py
│   │   └── wild_bpe.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── impl.py
│   │   ├── spec.py
│   │   └── vocab.py
│   ├── config.py
│   ├── data
│   │   ├── allamanis_dataset_metadata
│   │   │   └── small-train-projects.txt
│   │   └── bpe
│   │       ├── 10k
│   │       │   └── merges.txt
│   │       ├── case
│   │       │   ├── 0
│   │       │   │   └── merges.txt
│   │       │   ├── 10k
│   │       │   │   └── merges.txt
│   │       │   ├── 1k
│   │       │   │   └── merges.txt
│   │       │   ├── 2k
│   │       │   │   └── merges.txt
│   │       │   └── 5k
│   │       │       └── merges.txt
│   │       └── nocase
│   │           └── 0
│   │               └── merges.txt
│   ├── dirutils.py
│   ├── fileutils.py
│   ├── logging.yaml
│   ├── noneng.py
│   ├── parse
│   │   ├── __init__.py
│   │   ├── core.py
│   │   ├── matchers.py
│   │   └── subtokens.py
│   ├── pipeline
│   │   ├── __init__.py
│   │   ├── bpelearner.py
│   │   ├── bperegistry.py
│   │   ├── dataset.py
│   │   ├── parse_projects.py
│   │   ├── stages.py
│   │   ├── to_repr.py
│   │   ├── vocab.py
│   │   └── vocabloader.py
│   ├── prepconfig.py
│   ├── preprocess
│   │   ├── __init__.py
│   │   ├── core.py
│   │   ├── metadata.py
│   │   ├── placeholders.py
│   │   └── reprconfig.py
│   ├── stemming.py
│   ├── subtokens.py
│   ├── tokens
│   │   ├── __init__.py
│   │   ├── containers.py
│   │   ├── noneng.py
│   │   ├── numeric.py
│   │   ├── rootclasses.py
│   │   ├── whitespace.py
│   │   └── word.py
│   └── util.py
├── reports
│   └── bpe
│       └── wild-bpe
│           ├── v0.1_0.05mb_1bit.png
│           ├── v0.1_0.05mb_1bit.png.license
│           ├── v0.1_0.05mb_2bit.png
│           ├── v0.1_0.05mb_2bit.png.license
│           ├── v0.1_0.05mb_3bit.png
│           ├── v0.1_0.05mb_3bit.png.license
│           ├── v0.1_0.5mb_1bit.png
│           ├── v0.1_0.5mb_1bit.png.license
│           ├── v0.1_0.5mb_2bit.png
│           ├── v0.1_0.5mb_2bit.png.license
│           ├── v0.1_0.5mb_3bit.png
│           ├── v0.1_0.5mb_3bit.png.license
│           ├── v0.1_5mb_1bit.png
│           ├── v0.1_5mb_1bit.png.license
│           ├── v0.1_5mb_2bit.png
│           ├── v0.1_5mb_2bit.png.license
│           ├── v0.1_5mb_3bit.png
│           └── v0.1_5mb_3bit.png.license
├── requirements-dev.txt
├── requirements.txt
├── setup.py
├── test-data
│   └── test-corpus
│       ├── jquery.min.js
│       └── yahtzee
│           ├── .reuse
│           │   └── dep5
│           ├── LICENSES
│           │   └── MIT.txt
│           └── src
│               ├── main
│               │   └── java
│               │       └── hlibbabii
│               │           └── yahtzee
│               │               ├── DiceValues.java
│               │               ├── GameDemo.java
│               │               ├── Player.java
│               │               ├── combination
│               │               │   ├── Chance.java
│               │               │   ├── Combination.java
│               │               │   ├── FullHouse.java
│               │               │   ├── LargeStraight.java
│               │               │   ├── NOfAKind.java
│               │               │   ├── Numbers.java
│               │               │   ├── SmallStraight.java
│               │               │   ├── TwoPairs.java
│               │               │   └── Yahtzee.java
│               │               ├── gameplay
│               │               │   ├── Decision.java
│               │               │   ├── Game.java
│               │               │   ├── GameStats.java
│               │               │   ├── MoveResult.java
│               │               │   ├── PlayerException.java
│               │               │   └── PlayerStats.java
│               │               ├── model
│               │               │   └── DiceLayout.java
│               │               ├── player
│               │               │   └── DummyPlayer.java
│               │               └── util
│               │                   └── RandomService.java
│               └── test
│                   └── java
│                       └── hlibbabii
│                           └── yahtzee
│                               ├── combination
│                               │   ├── ChanceTest.java
│                               │   ├── FullHouseTest.java
│                               │   ├── LargeStraightTest.java
│                               │   ├── NOfAKindTest.java
│                               │   ├── SmallStraightTest.java
│                               │   └── TwoPairsTest.java
│                               └── model
│                                   └── DiceLayoutTest.java
├── tests
│   ├── __init__.py
│   ├── api
│   │   ├── __init__.py
│   │   └── test_corpus.py
│   ├── bpepkg
│   │   ├── __init__.py
│   │   ├── test_merge.py
│   │   └── wild_bpe_performance.py
│   ├── cli
│   │   ├── __init__.py
│   │   └── test_spec.py
│   ├── infrastructure
│   │   ├── __init__.py
│   │   ├── test_bpelearner.py
│   │   ├── test_bperegistry.py
│   │   └── test_dataset.py
│   ├── parse
│   │   ├── __init__.py
│   │   ├── test_core.py
│   │   └── test_subtokens.py
│   ├── test_corpus_b2b.py
│   ├── test_subword_separation.py
│   └── test_to_repr.py
└── tox.ini
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # Pycharm
107 | **/.idea/workspace.xml
108 | **/.idea/tasks.xml
109 |
110 | !test-data
--------------------------------------------------------------------------------
/.idea/codeprep.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/dictionaries/hlib.xml:
--------------------------------------------------------------------------------
 1 | <component name="ProjectDictionaryState">
 2 |   <dictionary name="hlib">
 3 |     <words>
 4 |       <w>databunch</w>
 5 |       <w>fastai</w>
 6 |       <w>mult</w>
 7 |       <w>numericalizer</w>
 8 |     </words>
 9 |   </dictionary>
10 | </component>
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.reuse/dep5:
--------------------------------------------------------------------------------
1 | Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/
2 | Upstream-Name: codeprep
3 | Upstream-Contact: Hlib Babii
4 | Source: https://github.com/giganticode/codeprep
5 |
 6 | # Licensing of bundled data and configuration files:
 7 | #
8 | Files: codeprep/data/**/*.txt codeprep/logging.yaml
9 | Copyright: 2020 Hlib Babii
10 | License: Apache-2.0
11 |
12 | Files: .idea/**/*.* .idea/*.* .gitignore .travis.yml MANIFEST.in requirements.txt requirements-dev.txt tox.ini codeprep/VERSION
13 | Copyright: 2020 Hlib Babii
14 | License: CC0-1.0
15 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | dist: xenial # required for Python >= 3.7
2 | env:
3 | global:
4 | - CC_TEST_REPORTER_ID=$CODE_CLIMATE_REPORTER_ID
5 | language: python
6 | if: 'type = pull_request OR branch = master'
7 | python:
8 | - "3.6"
9 | - "3.7"
10 | jobs:
11 | include:
12 | - os: osx
13 | osx_image: xcode11.2 # Python 3.7.4 running on macOS 10.14.4
14 | language: shell
15 | python: "3.7"
16 | before_script: echo "Not sending code coverage to code climate on OSx"
17 | after_script: echo "Not sending code coverage to code climate on OSx"
18 | - os: windows
19 | language: sh
20 | python: "3.7"
21 | before_install:
22 | - choco install python3 --params "/InstallDir:C:\\Python"
23 | - export PATH="/c/Python:/c/Python/Scripts:$PATH"
24 | - python -m pip install --upgrade pip wheel
25 | before_script: echo "Not sending code coverage to code climate on Windows"
26 | after_script: echo "Not sending code coverage to code climate on Windows"
27 | install:
28 | - pip3 install -r requirements.txt
29 | - pip3 install -r requirements-dev.txt
30 | before_script:
31 | - curl -L https://codeclimate.com/downloads/test-reporter/test-reporter-latest-linux-amd64 > ./cc-test-reporter
32 | - chmod +x ./cc-test-reporter
33 | - ./cc-test-reporter before-build
34 | script:
35 | - cd $TRAVIS_BUILD_DIR
36 | - echo "Current directory is $(pwd)"
37 | - coverage run --concurrency=multiprocessing -m pytest --doctest-modules
38 | - coverage combine
39 | - coverage report -m --include=./\*\* --omit=tests/\*\*,.venv/\*\* --fail-under=70
40 | - coverage xml -i --include=./\*\* --omit=tests/\*\*,.venv/\*\*
41 | after_script:
42 | - if [[ "$TRAVIS_PULL_REQUEST" == "false" && "$TRAVIS_PYTHON_VERSION" == "3.6" ]]; then echo "Reporting coverage to Code Climate" && ./cc-test-reporter after-build --exit-code $TRAVIS_TEST_RESULT; fi
43 |
--------------------------------------------------------------------------------
/LICENSES/CC0-1.0.txt:
--------------------------------------------------------------------------------
1 | Creative Commons Legal Code
2 |
3 | CC0 1.0 Universal CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES
4 | NOT PROVIDE LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE
5 | AN ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS INFORMATION
6 | ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES REGARDING THE USE
7 | OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED HEREUNDER, AND DISCLAIMS
8 | LIABILITY FOR DAMAGES RESULTING FROM THE USE OF THIS DOCUMENT OR THE INFORMATION
9 | OR WORKS PROVIDED HEREUNDER.
10 |
11 | Statement of Purpose
12 |
13 | The laws of most jurisdictions throughout the world automatically confer exclusive
14 | Copyright and Related Rights (defined below) upon the creator and subsequent
15 | owner(s) (each and all, an "owner") of an original work of authorship and/or
16 | a database (each, a "Work").
17 |
18 | Certain owners wish to permanently relinquish those rights to a Work for the
19 | purpose of contributing to a commons of creative, cultural and scientific
20 | works ("Commons") that the public can reliably and without fear of later claims
21 | of infringement build upon, modify, incorporate in other works, reuse and
22 | redistribute as freely as possible in any form whatsoever and for any purposes,
23 | including without limitation commercial purposes. These owners may contribute
24 | to the Commons to promote the ideal of a free culture and the further production
25 | of creative, cultural and scientific works, or to gain reputation or greater
26 | distribution for their Work in part through the use and efforts of others.
27 |
28 | For these and/or other purposes and motivations, and without any expectation
29 | of additional consideration or compensation, the person associating CC0 with
30 | a Work (the "Affirmer"), to the extent that he or she is an owner of Copyright
31 | and Related Rights in the Work, voluntarily elects to apply CC0 to the Work
32 | and publicly distribute the Work under its terms, with knowledge of his or
33 | her Copyright and Related Rights in the Work and the meaning and intended
34 | legal effect of CC0 on those rights.
35 |
36 | 1. Copyright and Related Rights. A Work made available under CC0 may be protected
37 | by copyright and related or neighboring rights ("Copyright and Related Rights").
38 | Copyright and Related Rights include, but are not limited to, the following:
39 |
40 | i. the right to reproduce, adapt, distribute, perform, display, communicate,
41 | and translate a Work;
42 |
43 | ii. moral rights retained by the original author(s) and/or performer(s);
44 |
45 | iii. publicity and privacy rights pertaining to a person's image or likeness
46 | depicted in a Work;
47 |
48 | iv. rights protecting against unfair competition in regards to a Work, subject
49 | to the limitations in paragraph 4(a), below;
50 |
51 | v. rights protecting the extraction, dissemination, use and reuse of data
52 | in a Work;
53 |
54 | vi. database rights (such as those arising under Directive 96/9/EC of the
55 | European Parliament and of the Council of 11 March 1996 on the legal protection
56 | of databases, and under any national implementation thereof, including any
57 | amended or successor version of such directive); and
58 |
59 | vii. other similar, equivalent or corresponding rights throughout the world
60 | based on applicable law or treaty, and any national implementations thereof.
61 |
62 | 2. Waiver. To the greatest extent permitted by, but not in contravention of,
63 | applicable law, Affirmer hereby overtly, fully, permanently, irrevocably and
64 | unconditionally waives, abandons, and surrenders all of Affirmer's Copyright
65 | and Related Rights and associated claims and causes of action, whether now
66 | known or unknown (including existing as well as future claims and causes of
67 | action), in the Work (i) in all territories worldwide, (ii) for the maximum
68 | duration provided by applicable law or treaty (including future time extensions),
69 | (iii) in any current or future medium and for any number of copies, and (iv)
70 | for any purpose whatsoever, including without limitation commercial, advertising
71 | or promotional purposes (the "Waiver"). Affirmer makes the Waiver for the
72 | benefit of each member of the public at large and to the detriment of Affirmer's
73 | heirs and successors, fully intending that such Waiver shall not be subject
74 | to revocation, rescission, cancellation, termination, or any other legal or
75 | equitable action to disrupt the quiet enjoyment of the Work by the public
76 | as contemplated by Affirmer's express Statement of Purpose.
77 |
78 | 3. Public License Fallback. Should any part of the Waiver for any reason be
79 | judged legally invalid or ineffective under applicable law, then the Waiver
80 | shall be preserved to the maximum extent permitted taking into account Affirmer's
81 | express Statement of Purpose. In addition, to the extent the Waiver is so
82 | judged Affirmer hereby grants to each affected person a royalty-free, non
83 | transferable, non sublicensable, non exclusive, irrevocable and unconditional
84 | license to exercise Affirmer's Copyright and Related Rights in the Work (i)
85 | in all territories worldwide, (ii) for the maximum duration provided by applicable
86 | law or treaty (including future time extensions), (iii) in any current or
87 | future medium and for any number of copies, and (iv) for any purpose whatsoever,
88 | including without limitation commercial, advertising or promotional purposes
89 | (the "License"). The License shall be deemed effective as of the date CC0
90 | was applied by Affirmer to the Work. Should any part of the License for any
91 | reason be judged legally invalid or ineffective under applicable law, such
92 | partial invalidity or ineffectiveness shall not invalidate the remainder of
93 | the License, and in such case Affirmer hereby affirms that he or she will
94 | not (i) exercise any of his or her remaining Copyright and Related Rights
95 | in the Work or (ii) assert any associated claims and causes of action with
96 | respect to the Work, in either case contrary to Affirmer's express Statement
97 | of Purpose.
98 |
99 | 4. Limitations and Disclaimers.
100 |
101 | a. No trademark or patent rights held by Affirmer are waived, abandoned, surrendered,
102 | licensed or otherwise affected by this document.
103 |
104 | b. Affirmer offers the Work as-is and makes no representations or warranties
105 | of any kind concerning the Work, express, implied, statutory or otherwise,
106 | including without limitation warranties of title, merchantability, fitness
107 | for a particular purpose, non infringement, or the absence of latent or other
108 | defects, accuracy, or the present or absence of errors, whether or not discoverable,
109 | all to the greatest extent permissible under applicable law.
110 |
111 | c. Affirmer disclaims responsibility for clearing rights of other persons
112 | that may apply to the Work or any use thereof, including without limitation
113 | any person's Copyright and Related Rights in the Work. Further, Affirmer disclaims
114 | responsibility for obtaining any necessary consents, permissions or other
115 | rights required for any use of the Work.
116 |
117 | d. Affirmer understands and acknowledges that Creative Commons is not a party
118 | to this document and has no duty or obligation with respect to this CC0 or
119 | use of the Work.
120 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | graft codeprep/data
2 | include codeprep/VERSION
3 | include codeprep/*.yaml
4 | include *.md
5 | include *.txt
6 | include LICENSE
7 |
--------------------------------------------------------------------------------
/codeprep/VERSION:
--------------------------------------------------------------------------------
1 | 1.0.5
2 |
--------------------------------------------------------------------------------
/codeprep/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import logging
6 | import logging.config
7 | import os
8 | import yaml
9 |
10 | from codeprep.config import root_package_dir, version
11 |
12 |
13 | def load_logging_config():
14 | path = os.path.join(root_package_dir, 'logging.yaml')
15 | if os.path.exists(path):
16 | with open(path, 'rt') as f:
17 | logging_config = yaml.safe_load(f.read())
18 | logging.config.dictConfig(logging_config)
19 | else:
20 | logging.basicConfig(level=logging.DEBUG)
21 |
22 |
23 | load_logging_config()
24 |
25 | logging.getLogger('matplotlib').setLevel(logging.INFO)
26 | logging.getLogger('Ronin').setLevel(logging.INFO)
27 |
28 | __version__ = version
--------------------------------------------------------------------------------
/codeprep/__main__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import sys
6 |
7 | from codeprep.cli.spec import parse_and_run
8 |
9 |
10 | def main():
11 | parse_and_run(sys.argv[1:])
12 |
13 |
14 | if __name__ == '__main__':
15 | main()
--------------------------------------------------------------------------------
/codeprep/api/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/codeprep/api/common.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import sys
6 | from typing import Optional
7 |
8 | from codeprep.prepconfig import PrepConfig, PrepParam, get_possible_str_values
9 |
10 |
11 | def create_split_value(split_type: str, bpe_codes_id: Optional[str] = None, full_strings: bool = False,
12 | split_numbers: bool = False, ronin: bool = False, stem: bool = False):
13 | if split_type == 'nosplit':
14 | return 'F' if full_strings else '0'
15 | elif split_type == 'chars':
16 | return '8'
17 | elif split_type == 'basic':
18 | if stem:
19 | return 's'
20 | elif ronin:
21 | return '3'
22 | elif split_numbers:
23 | return '2'
24 | else:
25 | return '1'
26 | elif split_type == 'bpe':
27 | if bpe_codes_id == '1k':
28 | return '5'
29 | elif bpe_codes_id == '5k':
30 | return '4'
31 | elif bpe_codes_id == '10k':
32 | return '6'
33 | else:
34 | return '9'
35 | else:
36 | raise AssertionError(f"Invalid split option: {split_type}")
37 |
38 |
39 | def create_str_value(no_str: bool, max_str_len: int) -> str:
40 | if no_str:
41 | return '0'
42 | if 0 <= max_str_len < 2:
43 | return '2'
44 | if 2 <= max_str_len < len(get_possible_str_values()):
45 | return get_possible_str_values()[max_str_len]
46 | else:
47 | return '1'
48 |
49 |
50 | def create_prep_config(spl_type: str, bpe_codes_id: Optional[str] = None, no_spaces: bool = False,
51 | no_unicode: bool = False, no_case: bool = False, no_com: bool = False, no_str: bool = False,
52 | full_strings: bool = False, max_str_length: int = sys.maxsize, split_numbers: bool = False,
53 | ronin: bool = False, stem: bool = False):
54 | return PrepConfig({
55 | PrepParam.EN_ONLY: 'U' if no_unicode else 'u',
56 | PrepParam.COM: '0' if no_com else 'c',
57 | PrepParam.STR: create_str_value(no_str, max_str_length),
58 | PrepParam.SPLIT: create_split_value(spl_type, bpe_codes_id=bpe_codes_id, full_strings=full_strings,
59 | split_numbers=split_numbers, ronin=ronin, stem=stem),
60 | PrepParam.TABS_NEWLINES: '0' if no_spaces else 's',
61 | PrepParam.CASE: 'l' if no_case else 'u'
62 | })
--------------------------------------------------------------------------------
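
A short usage sketch for create_prep_config above (illustrative; '10k' refers to the bundled BPE merges shown in the file tree):

    from codeprep.api.common import create_prep_config

    # BPE splitting with the bundled 10k merges, comments removed:
    config = create_prep_config('bpe', bpe_codes_id='10k', no_com=True)
    # Per the mappings above: SPLIT -> '6', COM -> '0', STR -> '1',
    # EN_ONLY -> 'u', TABS_NEWLINES -> 's', CASE -> 'u'.
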
/codeprep/bpepkg/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/codeprep/bpepkg/bpe_config.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from enum import Enum
6 | from typing import Dict
7 |
8 | from codeprep.prepconfig import PrepConfig, PrepParam
9 |
10 |
11 | class BpeConfigNotSupported(Exception):
12 | pass
13 |
14 |
15 | class BpeParam(str, Enum):
16 | CASE: str = 'case'
17 | WORD_END: str = 'wordend'
18 | BASE: str = 'base'
19 | UNICODE: str = 'unicode'
20 |
21 |
22 | class BpeConfig(object):
23 | possible_param_values = {
24 | BpeParam.CASE: ['yes'],
25 | BpeParam.WORD_END: [True, False],
26 | BpeParam.BASE: ["all", "code", "java"],
27 | BpeParam.UNICODE: ['yes', 'no', 'bytes'],
28 | }
29 |
30 | @staticmethod
31 | def _check_param_number(n_passed_params: int):
32 | n_expected_params = len([i for i in BpeParam])
33 | if n_passed_params != n_expected_params:
34 | raise ValueError(f'Expected {n_expected_params} params, got {n_passed_params}')
35 |
36 | @staticmethod
37 | def _check_invariants(params: Dict[BpeParam, str]):
38 | BpeConfig._check_param_number(len(params))
39 | for pp in BpeParam:
40 | if params[pp] not in BpeConfig.possible_param_values[pp]:
41 | raise ValueError(f'Invalid value {params[pp]} for prep param {pp}, '
42 | f'possible values are: {BpeConfig.possible_param_values[pp]}')
43 |
44 | def __init__(self, params: Dict[BpeParam, str]):
45 | BpeConfig._check_invariants(params)
46 |
47 | self.params = params
48 |
49 | def get_param_value(self, param: BpeParam) -> str:
50 | return self.params[param]
51 |
52 | def to_prep_config(self):
53 | return PrepConfig({
54 | PrepParam.EN_ONLY: 'U' if self.get_param_value(BpeParam.UNICODE) == 'no' else 'u',
55 | PrepParam.COM: '0',
56 | PrepParam.STR: 'E',
57 | PrepParam.SPLIT: 'F',
58 | PrepParam.TABS_NEWLINES: 's',
59 | PrepParam.CASE: 'u'
60 | })
61 |
62 | UNICODE_NO = 'nounicode'
63 | UNICODE_BYTES = 'bytes'
64 | CASE_NO = 'nocase'
65 | CASE_PREFIX = 'prefix'
66 | WORD_END = 'we'
67 |
68 | @staticmethod
69 | def from_suffix(suffix: str):
70 | if suffix.find(BpeConfig.CASE_NO) != -1:
71 | case = 'no'
72 | elif suffix.find(BpeConfig.CASE_PREFIX) != -1:
73 | case = 'prefix'
74 | else:
75 | case = 'yes'
76 |
77 | if suffix.find(BpeConfig.UNICODE_NO) != -1:
78 | unicode = 'no'
79 | elif suffix.find(BpeConfig.UNICODE_BYTES) != -1:
80 | unicode = 'bytes'
81 | else:
82 | unicode = 'yes'
83 |
84 |
85 | return BpeConfig({
86 | BpeParam.CASE: case,
87 | BpeParam.WORD_END: suffix.find(BpeConfig.WORD_END) != -1,
88 | BpeParam.BASE: 'code',
89 | BpeParam.UNICODE: unicode,
90 | })
91 |
92 | def to_suffix(self):
93 | """
94 | >>> bpe_config = BpeConfig({
95 | ... BpeParam.CASE: 'yes',
96 | ... BpeParam.WORD_END: False,
97 | ... BpeParam.BASE: 'all',
98 | ... BpeParam.UNICODE: 'yes'
99 | ... })
100 | >>> bpe_config.to_suffix()
101 | ''
102 |
103 | >>> bpe_config = BpeConfig({
104 | ... BpeParam.CASE: 'yes',
105 | ... BpeParam.WORD_END: True,
106 | ... BpeParam.BASE: 'all',
107 | ... BpeParam.UNICODE: 'no'
108 | ... })
109 | >>> bpe_config.to_suffix()
110 | 'we_nounicode'
111 |
112 | >>> bpe_config = BpeConfig({
113 | ... BpeParam.CASE: 'yes',
114 | ... BpeParam.WORD_END: False,
115 | ... BpeParam.BASE: 'all',
116 | ... BpeParam.UNICODE: 'bytes'
117 | ... })
118 | >>> bpe_config.to_suffix()
119 | 'bytes'
120 |
121 | """
122 | suffix_parts = []
123 |
124 | if self.get_param_value(BpeParam.CASE) == 'no':
125 | suffix_parts.append(BpeConfig.CASE_NO)
126 | elif self.get_param_value(BpeParam.CASE) == 'prefix':
127 | suffix_parts.append(BpeConfig.CASE_PREFIX)
128 |
129 | if self.get_param_value(BpeParam.WORD_END):
130 | suffix_parts.append(BpeConfig.WORD_END)
131 |
132 | if self.get_param_value(BpeParam.UNICODE) == 'no':
133 | suffix_parts.append(BpeConfig.UNICODE_NO)
134 | elif self.get_param_value(BpeParam.UNICODE) == 'bytes':
135 | suffix_parts.append(BpeConfig.UNICODE_BYTES)
136 |
137 | return "_".join(suffix_parts)
138 |
139 | def __eq__(self, other):
140 | return self.params == other.params
141 |
142 | def __str__(self) -> str:
143 | parts = [str(self.params[k]) for k in BpeParam]
144 | return "_".join(parts)
145 |
146 | def __repr__(self):
147 | return str(self.params)
--------------------------------------------------------------------------------
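
A quick round-trip sketch for the suffix encoding above (behaviour follows directly from from_suffix and to_suffix):

    from codeprep.bpepkg.bpe_config import BpeConfig, BpeParam

    config = BpeConfig.from_suffix('we_nounicode')
    assert config.get_param_value(BpeParam.WORD_END) is True
    assert config.get_param_value(BpeParam.UNICODE) == 'no'
    assert config.to_suffix() == 'we_nounicode'
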
/codeprep/bpepkg/bpe_encode.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import logging
6 | import os
7 | import sys
8 |
9 | import argparse
10 | from typing import List, Dict
11 |
12 | from tqdm import tqdm
13 |
14 | from codeprep.bpepkg.merge import MergeList, read_merges
15 | from codeprep.config import DEFAULT_BPE_DIR
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | class BpeData(object):
21 | def __init__(self, merges_cache=None, merges: MergeList=None):
22 | self.merges_cache = merges_cache
23 | self.merges = merges
24 |
25 |
26 | ESCAPE_CHAR = '@'
27 |
28 | ESCAPABLE_CHAR_LIST = [ESCAPE_CHAR]
29 |
30 |
31 | def escape(word: str, merged: bool=False) -> str:
32 | word = word.replace(ESCAPE_CHAR, 2 * ESCAPE_CHAR)
33 | if merged:
34 | return f"{word}{ESCAPE_CHAR}"
35 | else:
36 | return f"{word} {ESCAPE_CHAR}"
37 |
38 |
39 | def unescape(parts: List[str]):
40 | if parts[-1][-1] != ESCAPE_CHAR:
41 | raise ValueError(f"There should be {ESCAPE_CHAR} at the end, however this is what was passed: {parts}")
42 |
43 | parts[-1] = parts[-1][:-1]
44 | return list(map(lambda p: p.replace(2 * ESCAPE_CHAR, ESCAPE_CHAR), parts))
45 |
46 |
47 | def to_char_list(word: str):
48 | i = 0
49 | res = []
50 | while i < len(word):
51 | if word[i] != ESCAPE_CHAR or i+1 == len(word):
52 | res.append(word[i])
53 | i += 1
54 | elif word[i+1] in ESCAPABLE_CHAR_LIST:
55 | res.append(word[i:i+2])
56 | i += 2
57 | else:
58 | raise ValueError(f"Illegal escape sequence: {word[i:i+2]}")
59 | return res
60 |
61 |
62 | def encode(words: Dict[str, int], merges: MergeList) -> Dict[str, int]:
63 | letters_list = {" ".join(to_char_list(k)): v for k, v in words.items()}
64 |
65 | new_letters_list = {}
66 | for letters, freq in letters_list.items():
67 | subwords = letters.split(" ")
68 |
69 | show_bpe_progress_bar = False
70 | if len(subwords) > 5000:
71 | logger.warning(f'Encountered a string of length {len(subwords)}. It will take a while to bpe-encode it.')
72 | show_bpe_progress_bar = True
73 |
74 | if show_bpe_progress_bar:
75 | bpe_progress = tqdm(total=len(merges))
76 | last_value = 0
77 | while True:
78 | merge_indices = []
79 | merge_candidate_priority = sys.maxsize
80 | for i in range(len(subwords) - 1):
81 | merge_candidate = (subwords[i], subwords[i + 1])
82 | if merge_candidate in merges:
83 | current_merge_candidate_priority = merges.get_priority(merge_candidate)
84 | if current_merge_candidate_priority < merge_candidate_priority:
85 | merge_candidate_priority = current_merge_candidate_priority
86 | merge_indices = [i]
87 | elif current_merge_candidate_priority == merge_candidate_priority:
88 | if not merge_indices or merge_indices[-1] != i - 1:
89 | merge_indices.append(i)
90 |
91 | if not merge_indices:
92 | break
93 |
94 | subwords_after_this_merge_round = []
95 | start_idx = 0
96 | for merge_index in merge_indices:
97 | for i in range(start_idx, merge_index):
98 | subwords_after_this_merge_round.append(subwords[i])
99 | subwords_after_this_merge_round.append(subwords[merge_index] + subwords[merge_index + 1])
100 | start_idx = merge_index + 2
101 | for i in range(start_idx, len(subwords)):
102 | subwords_after_this_merge_round.append(subwords[i])
103 | subwords = subwords_after_this_merge_round
104 | if show_bpe_progress_bar:
105 | bpe_progress.update(merge_candidate_priority - last_value)
106 | last_value = merge_candidate_priority
107 | if show_bpe_progress_bar:
108 | bpe_progress.update(len(merges) - last_value)
109 | bpe_progress.close()
110 |
111 | new_letters_list[" ".join(subwords)] = freq
112 | return new_letters_list
113 |
114 |
115 | def encode_word(word: str, merges: MergeList) -> List[str]:
116 | """
117 | >>> merge_file = os.path.join(DEFAULT_BPE_DIR, '10k', 'merges.txt')
118 | >>> merges = read_merges(merge_file, 10000)
119 |
120 | >>> encode_word('this@@is_all_one_String@', merges)
121 | ['this', '@@', 'is_', 'all', '_', 'one', '_', 'String@']
122 |
123 | >>> encode_word('aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa@', merges)
124 | ['aaaaaaaa', 'aaaaaaaa', 'aaaaaaaa', 'aaaaaaaa', 'aaaa', 'a', 'a@']
125 |
126 | >>> encode_word('erererererererererererer@', merges)
127 | ['er', 'er', 'er', 'er', 'er', 'er', 'er', 'er', 'er', 'er', 'er', 'er@']
128 |
129 | >>> encode_word('@', merges)
130 | ['@']
131 |
132 | >>> encode_word('', merges)
133 | ['']
134 |
135 | >>> encode_word('split@', merges)
136 | ['split@']
137 |
138 | >>> encode_word('aaa', merges)
139 | ['aa', 'a']
140 |
141 | >>> encode_word('this\xa0is@@a@@@@bit@@@@larger\xa0stringwith\xa0some@@unicode@@possibly\xf7@', merges)
142 | ['this', '\\xa0', 'is', '@@', 'a', '@@', '@@', 'bit', '@@', '@@', 'l', 'arg', 'er', '\\xa0', 'string', \
143 | 'with', '\\xa0', 's', 'ome', '@@', 'unic', 'ode', '@@', 'pos', 'si', 'b', 'ly', '÷', '@']
144 | """
145 | enc_word, _ = encode({word: 0}, merges).popitem()
146 | subwords = enc_word.split(" ")
147 | return subwords
148 |
149 |
150 | def get_bpe_subwords(word: str, bpe_data: BpeData) -> List[str]:
151 | merges = bpe_data.merges
152 | cache = bpe_data.merges_cache
153 | word = escape(word, merged=True)
154 | if word in cache:
155 | result = cache[word]
156 | else:
157 | result = encode_word(word, merges)
158 |
159 | return unescape(result)
160 |
161 |
162 | __all__ = ['encode', 'encode_word']
163 |
164 |
165 | if __name__ == '__main__':
166 | arg_parser = argparse.ArgumentParser()
167 | arg_parser.add_argument('merges_file', action='store', help='path to file with merges')
168 | arg_parser.add_argument('word', action='store', nargs='?', help='word to encode', default='if')
169 |
170 | args = arg_parser.parse_args()
171 |
172 | merges = read_merges(args.merges_file)
173 |
174 | subwords = encode_word(args.word, merges)
175 | print(subwords)
--------------------------------------------------------------------------------
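
A minimal driver for the functions above, assuming the bundled 10k merges are present; get_bpe_subwords handles the escaping and unescaping internally:

    import os

    from codeprep.bpepkg.bpe_encode import BpeData, get_bpe_subwords
    from codeprep.bpepkg.merge import read_merges
    from codeprep.config import DEFAULT_BPE_DIR

    merges = read_merges(os.path.join(DEFAULT_BPE_DIR, '10k', 'merges.txt'), 10000)
    bpe_data = BpeData(merges_cache={}, merges=merges)  # empty cache: always encode
    print(get_bpe_subwords('getFullName', bpe_data))    # exact split depends on the merges
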
/codeprep/bpepkg/bpe_learn.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import collections
6 | import logging
7 |
8 | import regex
9 | from tqdm import tqdm
10 | from typing import Dict, List, Tuple, Set
11 |
12 | from codeprep.bpepkg.merge import Merge, MergeList
13 | from codeprep.util import PriorityCounter
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | # ======== BPE algo itself
18 |
19 |
20 | def get_stats(split_base_vocab: Dict[str, int]) -> PriorityCounter:
21 | pairs = collections.defaultdict(int)
22 | for word, freq in split_base_vocab.items():
23 | symbols = word.split(' ')
24 | for i in range(len(symbols) - 1):
25 | pairs[symbols[i], symbols[i + 1]] += freq
26 | return PriorityCounter(pairs)
27 |
28 |
29 | def merge_vocab(pair: Tuple[str, str], input_vocab: Dict[str, int]) -> Tuple[Dict[str, int], List]:
30 | """
31 | >>> pair = ('w', 'o')
32 | >>> input_vocab = {'b i r d @': 3, 'w o r d @': 7, 'w o g @': 13}
33 | >>> new_vocab, new_pairs = merge_vocab(pair, input_vocab)
34 | >>> new_vocab
35 | {'b i r d @': 3, 'wo r d @': 7, 'wo g @': 13}
36 | >>> new_pairs
37 | [(('wo', 'r'), 7), (('o', 'r'), -7), (('wo', 'g'), 13), (('o', 'g'), -13)]
38 | """
39 | output_vocab = {}
40 | concat_pair_with_space = ' '.join(pair)
41 | concat_pair_with_space_escaped = regex.escape(concat_pair_with_space)
42 | concat_pair = ''.join(pair)
43 | reg = regex.compile('(^|[^ ]+ )(' + concat_pair_with_space_escaped + ')( [^ ]+|$)')
44 | added_pairs = []
45 | for word in input_vocab:
46 | word_occurences = input_vocab[word]
47 | match = reg.search(word)
48 | while match:
49 | # word changed
50 | if match.group(1) != '':
51 | subtoken_before = match.group(1)[:-1]
52 | added_pairs.append(((subtoken_before, concat_pair), word_occurences))
53 | if pair != (subtoken_before, pair[0]):
54 | added_pairs.append(((subtoken_before, pair[0]), -word_occurences))
55 | if match.group(3) != '':
56 | subtoken_after = match.group(3)[1:]
57 | added_pairs.append(((concat_pair, subtoken_after), word_occurences))
58 | if pair != (pair[1], subtoken_after):
59 | added_pairs.append(((pair[1], subtoken_after), -word_occurences))
60 | start, end = match.span(2)
61 | replacement = concat_pair
62 | word = word[:start] + replacement + word[end:]
63 | match = reg.search(word)
64 | output_vocab[word] = word_occurences
65 | return output_vocab, added_pairs
66 |
67 |
68 | def do_merges(vocab: Dict[str, int], n_merges: int) -> Tuple[Dict[str, int], MergeList]:
69 | """
70 | Do `n_merges` BPE merges starting from the vocabulary splittings `vocab`.
71 |
72 | :param vocab: base vocab splittings in the format
73 | {"fix me@": 3242, "a b c@": 400}
74 | :param n_merges: number of BPE merges to be applied
75 |
76 | :return: a tuple where the first element is the resulting vocab splittings,
77 | the second one is the list of all merges done to reach those vocab splittings
78 |
79 |
80 | >>> input_vocab = {
81 | ... "b i r d @": 3,
82 | ... "w o r d @": 7,
83 | ... "w o g @": 13
84 | ... }
85 |
86 | >>> vocab, merges = do_merges(input_vocab, 10)
87 | >>> vocab
88 | {'bird@': 3, 'word@': 7, 'wog@': 13}
89 | >>> merges
90 | [('w', 'o'): (20, 0), ('g', '@'): (13, 1), ('wo', 'g@'): (13, 2), ('r', 'd'): (10, 3), ('rd', '@'): (10, 4), \
91 | ('wo', 'rd@'): (7, 5), ('b', 'i'): (3, 6), ('bi', 'rd@'): (3, 7)]
92 |
93 |
94 | >>> input_vocab = {"a a a a a @": 3}
95 |
96 | >>> do_merges(input_vocab, 10)
97 | ({'aaaaa@': 3}, [('a', 'a'): (12, 0), ('a', '@'): (3, 1), ('aa', 'aa'): (3, 2), ('aaaa', 'a@'): (3, 3)])
98 |
99 | >>> input_vocab = {"l a l a l a @": 3}
100 | >>> do_merges(input_vocab, 10)
101 | ({'lalala@': 3}, [('l', 'a'): (9, 0), ('la', 'la'): (6, 1), ('la', '@'): (3, 2), ('lala', 'la@'): (3, 3)])
102 |
103 | """
104 | merges = MergeList()
105 | pairs = get_stats(vocab)
106 | for i in tqdm(range(n_merges), total=n_merges):
107 | try:
108 | best, occurences = pairs.pop_pair()
109 | merges.append(Merge(best, freq=occurences, priority=i))
110 | except KeyError:
111 | break
112 | vocab, added_pairs = merge_vocab(best, vocab)
113 | for p in added_pairs:
114 | pairs.add(*p)
115 | return vocab, merges
116 |
117 | # ======== Create auxiliary data structures.
118 |
119 |
120 | def create_bpe_cache(vocab: Dict[str, int]) -> Dict[str, List[str]]:
121 | merges_cache = {}
122 | for entry, _ in vocab.items():
123 | subword_list = entry.split(' ')
124 | key = ''.join(subword_list)
125 | merges_cache[key] = subword_list
126 | return merges_cache
127 |
128 |
129 | def create_resulting_vocab(split_base_vocab: Dict[str, int]) -> Dict[str, int]:
130 | resulting_vocab = collections.defaultdict(int)
131 | for entry, frequency in split_base_vocab.items():
132 | for subword in entry.split(" "):
133 | resulting_vocab[subword] += frequency
134 | return resulting_vocab
135 |
136 | # ============
137 |
138 |
139 | def separate_vocabs(all_vocab: Dict[str, int], tokens_to_exclude: Set[str]) -> Tuple[Dict[str, int], Dict[str, int]]:
140 | main_vocab = {}
141 | other_vocab = {}
142 | for k, v in all_vocab.items():
143 | if k not in tokens_to_exclude:
144 | main_vocab[k] = v
145 | else:
146 | other_vocab[k] = v
147 | return main_vocab, other_vocab
--------------------------------------------------------------------------------
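
Tying the pieces above together, a sketch that learns merges from a toy vocabulary and builds the auxiliary structures (frequencies are made up):

    from codeprep.bpepkg.bpe_learn import create_bpe_cache, create_resulting_vocab, do_merges

    input_vocab = {'w o r d @': 7, 'w o g @': 13}        # space-separated splittings
    split_vocab, merges = do_merges(input_vocab, 10)     # -> {'word@': 7, 'wog@': 13}
    cache = create_bpe_cache(split_vocab)                # {'word@': ['word@'], 'wog@': ['wog@']}
    subword_freqs = create_resulting_vocab(split_vocab)  # per-subword frequencies
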
/codeprep/bpepkg/cache.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | """
6 | >>> import tempfile
7 | >>> f = tempfile.NamedTemporaryFile(delete=False)
8 | >>> cache = {'ab': ['a', 'b'], '\\t\\xa0': ['\\t', '\\xa0']}
9 | >>> dump_bpe_cache(cache, f.name)
10 | >>> cache == read_bpe_cache(f.name)
11 | True
12 | """
13 | from typing import List, Dict
14 |
15 | from codeprep.util import to_literal_str, to_non_literal_str
16 |
17 | KEY_VALUE_DELIM = '\t'
18 | VALUE_PARTS_DELIM = ' '
19 |
20 |
21 | def read_bpe_cache(file: str) -> Dict[str, List[str]]:
22 | words = {}
23 | with open(file, 'r') as f:
24 | for line in f:
25 | line = line.rstrip('\n')
26 | splits = line.split(KEY_VALUE_DELIM)
27 | second_column = to_non_literal_str(splits[1]).split(VALUE_PARTS_DELIM)
28 | words[to_non_literal_str(splits[0])] = second_column
29 | return words
30 |
31 |
32 | def dump_bpe_cache(dct: Dict[str, List[str]], file: str) -> None:
33 | with open(file, 'w') as f:
34 | for word, subwords in dct.items():
35 | a = to_literal_str(" ".join(subwords))
36 | f.write(f'{to_literal_str(str(word))}{KEY_VALUE_DELIM}{a}\n')
--------------------------------------------------------------------------------
/codeprep/bpepkg/merge.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import copy
6 |
7 | from typing import List, Tuple, Union, Optional, Iterator, Dict
8 |
9 | from codeprep.util import is_python_3_6_and_higher, to_literal_str, to_non_literal_str
10 |
11 |
12 | # TODO this class should be frozen
13 | class Merge(object):
14 | def __init__(self, pair: Tuple[str, str], freq: int = None, priority: int = None):
15 | self.pair = pair
16 | self.freq = freq
17 | self.priority = priority
18 |
19 | @classmethod
20 | def parse_file_entry(cls, line: str, priority: int) -> "Merge":
21 | try:
22 | spl = to_non_literal_str(line).split(" ")
23 | if len(spl) == 2:
24 | return cls((spl[0], spl[1]), priority=priority)
25 | else:
26 | return cls((spl[0], spl[1]), freq=int(spl[2]), priority=priority)
27 | except (IndexError, TypeError) as err:
28 | raise ValueError(f"Invalid merge entry format: {line}", err)
29 |
30 | def __str__(self):
31 | return self.__repr__()
32 |
33 | def __repr__(self):
34 | return f'{self.pair}: ({self.freq}, {self.priority})'
35 |
36 | def __eq__(self, other):
37 | return self.__class__ == other.__class__ and self.pair == other.pair and self.priority == other.priority \
38 | and self.freq == other.freq
39 |
40 | def __hash__(self):
41 | return hash((self.pair, self.priority, self.freq))
42 |
43 |
44 | class MergeList(object):
45 | """
46 | >>> merges = MergeList()
47 | >>> merges = merges.append(Merge(('a', 'b'), 34, 0)).append(Merge(('b', 'c'), 44, 1))
48 | >>> [m for m in merges]
49 | [('a', 'b'): (34, 0), ('b', 'c'): (44, 1)]
50 | >>> len(merges)
51 | 2
52 | >>> merges[0]
53 | ('a', 'b'): (34, 0)
54 | >>> merges[1]
55 | ('b', 'c'): (44, 1)
56 | >>> merges[-1]
57 | ('b', 'c'): (44, 1)
58 | >>> merges[0:-1]
59 | [('a', 'b'): (34, 0)]
60 | >>> type(merges[0:-1])
61 | <class 'list'>
62 |
63 | >>> merges[2]
64 | Traceback (most recent call last):
65 | ...
66 | IndexError: list index out of range
67 |
68 | >>> ('a', 'b') in merges
69 | True
70 | >>> ('a', 'x') in merges
71 | False
72 |
73 | >>> merge1 = Merge(('a', 'b'), 34, 0)
74 | >>> merge2 = Merge(('a', 'b'), 34, 0)
75 | >>> dct = {merge1: 3}
76 | >>> dct[merge2]
77 | 3
78 |
79 | >>> merges + MergeList().append(Merge(('d', 'e'), 84, 0))
80 | [('a', 'b'): (34, 0), ('b', 'c'): (44, 1), ('d', 'e'): (84, 2)]
81 | >>> merges + [(('d', 'e'), 84, 1)]
82 | Traceback (most recent call last):
83 | ...
84 | TypeError: Cannot add to a MergeList
85 |
86 | >>> merges + merges
87 | Traceback (most recent call last):
88 | ...
89 | ValueError: It's only possible to add merges in priority order. The priority of the next merge should be 2 but is 3
90 |
91 | >>> merges.append(Merge(('x', 'y'), 34, 0))
92 | Traceback (most recent call last):
93 | ...
94 | ValueError: It's only possible to add merges in priority order. The priority of the next merge should be 2 but is 0
95 |
96 | >>> merges = merges.append(Merge(('x', 'y'), 34))
97 | >>> merges
98 | [('a', 'b'): (34, 0), ('b', 'c'): (44, 1), ('x', 'y'): (34, 2)]
99 | >>> merges.get_priority(('x', 'y'))
100 | 2
101 | """
102 | def __init__(self):
103 | self.merges: Dict[Tuple[str, str], Merge] = {}
104 |
105 | def __contains__(self, item):
106 | return item in self.merges
107 |
108 | def __len__(self):
109 | return len(self.merges)
110 |
111 | def __iter__(self) -> Iterator[Merge]:
112 | return iter(self._get_sorted_merges())
113 |
114 | def _get_sorted_merges(self) -> List[Merge]:
115 | if not is_python_3_6_and_higher():
116 | # we cannot rely on dict order for python versions lower than 3.6
117 | raise NotImplementedError()
118 |
119 | return list(self.merges.values())
120 |
121 | def __add__(self, other: 'MergeList'):
122 | if self.__class__ != other.__class__:
123 | raise TypeError(f"Cannot add {other.__class__} to a MergeList")
124 |
125 | new_merge_list = copy.deepcopy(self)
126 | other_copy = copy.deepcopy(other)
127 | first_list_len = len(new_merge_list)
128 | for merge in other_copy:
129 | merge.priority += first_list_len
130 | new_merge_list.append(merge)
131 |
132 | return new_merge_list
133 |
134 | def append(self, merge: Merge) -> 'MergeList':
135 | # along with the pair we save its priority and the number of its occurrences
136 | if merge.priority is None:
137 | merge.priority = len(self.merges)
138 | elif merge.priority != len(self.merges):
139 | raise ValueError(f"It's only possible to add merges in priority order. "
140 | f"The priority of the next merge should be {len(self.merges)} but is {merge.priority}")
141 |
142 | self.merges[merge.pair] = merge
143 | return self
144 |
145 | def get_priority(self, pair: Tuple[str, str]) -> int:
146 | return self.merges[pair].priority
147 |
148 | def __getitem__(self, item) -> Union[List[Merge], Merge]:
149 | lst = self._get_sorted_merges()
150 | return lst[item]
151 |
152 | def __repr__(self):
153 | return repr(self[:])
154 |
155 | def __eq__(self, other):
156 | return self.__class__ == other.__class__ and self[:] == other[:]
157 |
158 |
159 | def read_merges(file: str, n_merges: Optional[int] = None) -> MergeList:
160 | merges = MergeList()
161 | with open(file, 'r') as f:
162 | for idx, line in enumerate(f):
163 | if n_merges and idx >= n_merges:
164 | break
165 | line = line.rstrip('\n')
166 | merges.append(Merge.parse_file_entry(line, idx))
167 | return merges
168 |
169 |
170 | def dump_merges(merges: MergeList, file: str):
171 | with open(file, 'w') as f:
172 | for merge in merges:
173 | f.write(f"{to_literal_str(' '.join(merge.pair))} {merge.freq}\n")
--------------------------------------------------------------------------------
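
A round-trip sketch for the serialization helpers above (dump_merges writes one "left right freq" entry per line; read_merges re-derives priorities from line order):

    import tempfile

    from codeprep.bpepkg.merge import Merge, MergeList, dump_merges, read_merges

    merges = MergeList().append(Merge(('a', 'b'), freq=34)).append(Merge(('ab', 'c'), freq=20))
    with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
        path = f.name
    dump_merges(merges, path)           # writes: "a b 34" and "ab c 20"
    assert read_merges(path) == merges
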
/codeprep/cli/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/codeprep/cli/impl.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import logging
6 | import os
7 | from typing import Dict, Optional, Any
8 |
9 | import sys
10 |
11 | import codeprep
12 | import codeprep.api.corpus
13 | import codeprep.api.text
14 | from codeprep.api.common import create_split_value, create_str_value
15 | from codeprep.bpepkg.bpe_config import BpeParam, BpeConfig
16 | from codeprep.pipeline import bpelearner
17 | from codeprep.pipeline.bperegistry import InvalidBpeCodesIdError, USER_PREDEFINED_BPE_CODES
18 | from codeprep.pipeline.dataset import Dataset, normalize_extension_string
19 | from codeprep.prepconfig import PrepConfig, PrepParam
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | def set_log_level(args: Dict[str, str]) -> None:
25 | if args['--verbose']:
26 | logging.root.setLevel(logging.DEBUG)
27 | else:
28 | logging.root.setLevel(logging.ERROR)
29 |
30 |
31 | def get_option(args: Dict, option: str) -> Optional[Any]:
32 | return args[option] if option in args else None
33 |
34 |
35 | def is_option_true(args: Dict, option: str) -> bool:
36 | return bool(get_option(args, option))
37 |
38 |
39 | def handle_learnbpe(args):
40 | set_log_level(args)
41 | path = os.path.abspath(args['--path'])
42 | bpe_config = create_bpe_config_from_args(args)
43 | n_merges = int(args['<n-merges>'])
44 | if args['--legacy']:
45 | parsed_extensions = normalize_extension_string(args['--ext'])
46 | if parsed_extensions and parsed_extensions != ['java']:
47 | print("Only --ext 'java' is supported when --legacy is specified")
48 | return
49 | else:
50 | extensions = 'java'
51 | else:
52 | extensions = args['--ext']
53 | bpe_codes_id = args['--id']
54 | dataset = Dataset.create(path, bpe_config.to_prep_config(), extensions, None, bpe_config)
55 |
56 | if not dataset.bpe_codes_id:
57 | dataset.assign_bpe_codes_id(bpe_config, predefined_bpe_codes_id=bpe_codes_id)
58 | elif bpe_codes_id:
59 | logger.warning(f"Ignoring passed bpe codes id: {bpe_codes_id}. "
60 | f"This dataset has already been assigned id: {dataset.bpe_codes_id}")
61 |
62 | bpelearner.run(dataset, n_merges, bpe_config)
63 |
64 |
65 | def handle_splitting(args: Dict) -> None:
66 | set_log_level(args)
67 | try:
68 | prep_config = create_prep_config_from_args(args)
69 | bpe_codes_id = get_option(args, '<bpe-codes-id>') or get_predefined_bpe_codes_id(args)
70 | if args['<text>']:
71 | prep_text = codeprep.api.text.preprocess(args['<text>'], prep_config, bpe_codes_id,
72 | extension=args['--ext'])
73 | print(prep_text)
74 | else:
75 | codeprep.api.corpus.preprocess_corpus(args['--path'], prep_config, bpe_codes_id,
76 | extensions=args['--ext'],
77 | output_path=args['--output-path'],
78 | calc_vocab=bool(args['--calc-vocab']))
79 | except InvalidBpeCodesIdError as err:
80 | logger.error(err)
81 | return
82 |
83 |
84 | def create_bpe_config_from_args(run_options: Dict[str, str]) -> BpeConfig:
85 | if run_options['--no-unicode']:
86 | unicode = 'no'
87 | elif run_options['--bytes']:
88 | unicode = 'bytes'
89 | else:
90 | unicode = 'yes'
91 | return BpeConfig({
92 | BpeParam.CASE: 'yes',
93 | BpeParam.WORD_END: run_options["--word-end"],
94 | BpeParam.BASE: 'java' if run_options['--legacy'] else 'code',
95 | BpeParam.UNICODE: unicode
96 | })
97 |
98 |
99 | def create_prep_config_from_args(arguments: Dict) -> PrepConfig:
100 | max_str_length = get_option(arguments, '--max-str-length')
101 | max_str_length = int(max_str_length) if max_str_length is not None else sys.maxsize
102 | return PrepConfig({
103 | PrepParam.EN_ONLY: 'U' if is_option_true(arguments, '--no-unicode') else 'u',
104 | PrepParam.COM: '0' if is_option_true(arguments, '--no-com') else 'c',
105 | PrepParam.STR: create_str_value(is_option_true(arguments, '--no-str'), max_str_length),
106 | PrepParam.SPLIT: create_split_value_from_args(arguments),
107 | PrepParam.TABS_NEWLINES: '0' if is_option_true(arguments, '--no-spaces') else 's',
108 | PrepParam.CASE: 'l' if is_option_true(arguments, '--no-case') else 'u',
109 | })
110 |
111 |
112 | def get_predefined_bpe_codes_id(arguments: Dict) -> Optional[str]:
113 | for predefined_id in USER_PREDEFINED_BPE_CODES:
114 | if is_option_true(arguments, predefined_id):
115 | return predefined_id
116 |
117 | return '0' if is_option_true(arguments, 'chars') else None
118 |
119 |
120 | def create_split_value_from_args(arguments: Dict) -> str:
121 | if is_option_true(arguments, 'nosplit'):
122 | return create_split_value('nosplit', full_strings=is_option_true(arguments, '--full-strings'))
123 | elif is_option_true(arguments, 'chars'):
124 | return create_split_value('chars')
125 | elif is_option_true(arguments, 'basic'):
126 | return create_split_value('basic',
127 | split_numbers=is_option_true(arguments, '--split-numbers'),
128 | ronin=is_option_true(arguments, '--ronin'),
129 | stem=is_option_true(arguments, '--stem'))
130 | elif is_option_true(arguments, 'bpe'):
131 | return create_split_value('bpe', bpe_codes_id=get_predefined_bpe_codes_id(arguments))
132 | else:
133 | raise AssertionError(f"Invalid split option: {arguments}")
--------------------------------------------------------------------------------
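
For illustration, a hypothetical docopt-style argument dict and the prep config it produces (options missing from the dict simply read as falsy via get_option):

    from codeprep.cli.impl import create_prep_config_from_args

    args = {'basic': True, '--split-numbers': True, '--no-com': True}
    config = create_prep_config_from_args(args)
    # SPLIT resolves to '2' (basic splitting plus number splitting), COM to '0';
    # the remaining params take their defaults.
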
/codeprep/cli/vocab.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import argparse
6 | import logging
7 |
8 | from codeprep.dirutils import walk
9 | from codeprep.pipeline.vocab import calc_vocab
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | if __name__ == '__main__':
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('path_to_dataset', action='store', help='path to dataset')
17 | parser.add_argument('output_dir', action='store', help='output dir')
18 | parser.add_argument('extension', action='store', help='extension')
19 |
20 | args = parser.parse_known_args()
21 | args = args[0]
22 |
23 | calc_vocab(args.path_to_dataset, walk(args.path_to_dataset.encode(), extension=args.extension.encode()), args.output_dir)
--------------------------------------------------------------------------------
/codeprep/config.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import os
6 |
7 | import appdirs
8 |
9 | TRAIN_DIR = 'train'
10 | TEST_DIR = 'test'
11 | VALID_DIR = 'valid'
12 |
13 | CASE_DIR='case'
14 | NO_CASE_DIR='nocase'
15 |
16 | REPR_DIR = 'repr'
17 | PARSED_DIR = 'parsed'
18 | BPE_DIR = 'bpe'
19 | VOCAB_DIR = 'vocab'
20 |
21 | current_script_location = os.path.realpath(__file__)
22 | root_package_dir = os.path.dirname(current_script_location)
23 | data_dir = os.path.join(root_package_dir, 'data')
24 | test_data_dir = os.path.join(os.path.dirname(root_package_dir), 'test-data')
25 |
26 | app_name='codeprep'
27 |
28 | with open(os.path.join(root_package_dir, 'VERSION')) as version_file:
29 | version = version_file.read().strip()
30 |
31 | USER_CONFIG_DIR = appdirs.user_config_dir(app_name, appauthor=False, version=version)
32 | USER_CACHE_DIR = appdirs.user_cache_dir(app_name, appauthor=False, version=version)
33 |
34 | DEFAULT_FILE_LIST_DIR = os.path.join(USER_CACHE_DIR, 'file_lists')
35 | DEFAULT_PARSED_DATASETS_DIR = os.path.join(USER_CACHE_DIR, 'parsed_datasets')
36 | DEFAULT_PREP_DATASETS_DIR = os.path.join(USER_CACHE_DIR, 'prep_datasets')
37 | DEFAULT_BPE_DIR = os.path.join(data_dir, BPE_DIR)
38 | USER_BPE_DIR = os.path.join(USER_CONFIG_DIR, BPE_DIR)
39 | USER_VOCAB_DIR = os.path.join(USER_CONFIG_DIR, VOCAB_DIR)
40 | DEFAULT_BPE_CACHE_DIR = os.path.join(USER_CACHE_DIR, BPE_DIR)
41 | DEFAULT_CORPUS_SIZES_DIR = os.path.join(USER_CACHE_DIR, 'corpus_sizes')
42 |
43 | REWRITE_PARSED_FILE=False
44 | REWRITE_PREPROCESSED_FILE=False
45 |
46 | CHUNKSIZE=24
47 | LIMIT_FILES_ON_LAST_MODIFICATION_CHECK=1000
48 | LIMIT_FILES_SCANNING=50000
--------------------------------------------------------------------------------
/codeprep/data/allamanis_dataset_metadata/small-train-projects.txt:
--------------------------------------------------------------------------------
1 | AccelService
2 | agentcontest
3 | Ambience
4 | AndroidApp
5 | Android-CalendarView
6 | android-nice-guidance
7 | android_packages_apps_ClassicnerdWallpapers
8 | ANNIS
9 | Arboretum-Kiosk
10 | Atlantis
11 | BacksideUpdater
12 | bee-encode
13 | Birthday-Ring-Kata
14 | bookmarklet-and-chrome-addon
15 | BuilderGen
16 | Californium
17 | cassandra-tutorial
18 | CherryToolsMP
19 | Client
20 | codjo-gui-toolkit
21 | CommuniCase-spmode-mail-support-patch
22 | cookbook-domain
23 | CreativeStick
24 | cuke4duke
25 | DAVEtools
26 | device_allwinner_novo7a
27 | DockSoundRedirector
28 | Dspace-googleanalytics
29 | eclim
30 | elasticsearch-redis-river
31 | erjang
32 | expr
33 | festomat
34 | fluent-logger-java
35 | ftbt
36 | gatein-wsrp
37 | Gilded-Rose-Kata
38 | gora
39 | groovy-eclipse
40 | GWTCleaner
41 | hbasene
42 | hibernateuniversity-devoxx
43 | Hudson-Gerrit-Plugin
44 | imageassert
45 | InterviewStreet
46 | izpack
47 | java-api
48 | javacodes
49 | jboss-aejb
50 | jdeb
51 | JHexView
52 | jogl-utils
53 | Json-lib
54 | karma
55 | labeled-test-groups-publisher-plugin
56 | Library
57 | LockSupport
58 | lzmajio
59 | markdownj
60 | maven-thrift-plugin
61 | metatypes
62 | minerhat
63 | modern-ee-app20
64 | mousefeed
65 | MWM-for-Android
66 | Nearby
67 | NFEGuardian-client
68 | oae-maven-plugin
69 | open311-android
70 | opposite_of_a_bloom_filter
71 | otertool
72 | payments4j
73 | phoneload
74 | platform_packages_apps_alarmclock
75 | play2-scheduled-job-demo
76 | port-allocator-maven-plugin
77 | programmer-rush-beta
78 | Pubtran-London
79 | rabbitosgi
80 | redak
81 | RestFixture
82 | Roll
83 | Saints-Robotics-Programming
84 | SCORMCloud_JavaDemoApp
85 | sensei
86 | sia2012
87 | SimpleWarps
88 | smartgrid
89 | Solr-Ranking-Plugin
90 | SpoutWallet
91 | Spring-MVC-sample-using-iBatis
92 | sql-processor
93 | storm-counts
94 | swagger-spring
95 | Tartiquizz
96 | TenVersion_Project
97 | ThreadPaint
98 | Toast-Plugin-for-Smart-Phone
99 | translator-aquacrop
100 | Twister
101 | UPC_POO_EDU
102 | Video-Display-Portlet
103 | vraptor-jasperreport
104 | websms
105 | Wool-Trees
106 | xml-utilities
107 | zest
--------------------------------------------------------------------------------
/codeprep/data/bpe/case/0/merges.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/codeprep/data/bpe/case/0/merges.txt
--------------------------------------------------------------------------------
/codeprep/data/bpe/nocase/0/merges.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/codeprep/data/bpe/nocase/0/merges.txt
--------------------------------------------------------------------------------
/codeprep/dirutils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import logging
6 | import os
7 | from datetime import datetime
8 |
9 | from typing import Optional, List, Generator
10 |
11 | from codeprep.config import LIMIT_FILES_ON_LAST_MODIFICATION_CHECK
12 | from codeprep.fileutils import has_one_of_extensions
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | def walk(path: bytes, extension: Optional[bytes] = None) -> Generator[bytes, None, None]:
18 | if os.path.isfile(path) and (not extension or path.endswith(extension)):
19 | yield path
20 | else:
21 | for root, dirs, files in os.walk(path):
22 | for file in files:
23 | if not extension or file.endswith(extension):
24 | yield os.path.join(root, file)
25 |
26 |
27 | def walk_and_save(path: str, dir_list_path: str, file_list_path: str, return_dirs_instead_of_regular_files: bool,
28 | extensions: Optional[List[str]]) -> Generator[bytes, None, None]:
29 | with open(dir_list_path, 'w') as d, open(file_list_path, 'w') as f:
30 | path_bin = path.encode()
31 | extensions_bin = list(map(lambda e: e.encode(), extensions)) if extensions else None
32 |         # we want to list and store all the files as sequences of bytes to avoid problems with different encodings for filenames
33 | if os.path.isfile(path_bin):
34 | res = os.path.basename(path_bin)
35 | f.write(f'{res}\n')
36 | if not return_dirs_instead_of_regular_files:
37 | yield res
38 | else:
39 | for root, dirs, files in os.walk(path_bin, followlinks=True):
40 | # we pass bytes to os.walk -> the output are bytes as well
41 | for dir in dirs:
42 | bin_name = os.path.join(os.path.relpath(root, path_bin), dir)
43 | d.write(f'{bin_name}\n')
44 | if return_dirs_instead_of_regular_files:
45 | yield bin_name
46 | for file in files:
47 | bin_name = os.path.join(os.path.relpath(root, path_bin), file)
48 | if not extensions or has_one_of_extensions(bin_name, extensions_bin):
49 | if not os.path.islink(os.path.join(root, file)):
50 | f.write(f'{bin_name}\n')
51 | if not return_dirs_instead_of_regular_files:
52 | yield bin_name
53 |
54 |
55 | def get_dir_last_modification(path: str, limit: int = LIMIT_FILES_ON_LAST_MODIFICATION_CHECK) -> datetime:
56 |
57 | def walk_path(path):
58 | counter = 0
59 | if os.path.isfile(path) or len(os.listdir(path)) == 0:
60 | yield os.path.getmtime(path)
61 | else:
62 | for root, dirs, files in os.walk(path):
63 | for dir in dirs:
64 | if counter >= limit:
65 | return
66 | counter += 1
67 | yield os.path.getmtime(os.path.join(root, dir))
68 | for file in files:
69 | if counter >= limit:
70 | return
71 | full_path = os.path.join(root, file)
72 | if not os.path.islink(full_path):
73 | counter += 1
74 | yield os.path.getmtime(full_path)
75 |
76 | mtime = max(walk_path(path))
77 | return datetime.fromtimestamp(mtime)
78 |
79 |
80 | def get_timestamp(path: str) -> str:
81 | last_modif_time = get_dir_last_modification(path)
82 | return last_modif_time.strftime("%y-%m-%dT%H-%M-%S")
--------------------------------------------------------------------------------
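
A minimal usage sketch of the two entry points above (illustrative, not part of the repo; the corpus path is an assumption). walk yields paths as bytes because the pipeline stores filenames as raw bytes to sidestep filename-encoding issues, while get_timestamp condenses the most recent modification time into a sortable string.

    from codeprep.dirutils import walk, get_timestamp

    for file_path in walk(b'/path/to/corpus', extension=b'.java'):
        print(file_path.decode(errors='replace'))

    print(get_timestamp('/path/to/corpus'))  # e.g. '20-03-01T12-00-00'
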
/codeprep/fileutils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import logging
6 |
7 | from typing import List, Tuple
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def has_one_of_extensions(name: bytes, extensions: List[bytes]) -> bool:
13 | """
14 | >>> has_one_of_extensions(b'/home/abc.java', [b'java', b'c'])
15 | True
16 |
17 | >>> has_one_of_extensions(b'/home/abc.py', [b'java', b'c'])
18 | False
19 |
20 | >>> has_one_of_extensions(b'/home/abc.dtc', [b'java', b'c'])
21 | False
22 |
23 | >>> has_one_of_extensions(b'/home/abc.f.java.prep', [b'java.prep', b'c'])
24 | True
25 |
26 | >>> has_one_of_extensions(b'/home/abc.f.java.prep', [b'a.prep', b'c'])
27 | False
28 |
29 | """
30 | for ext in extensions:
31 | if name.endswith(b'.' + ext):
32 | return True
33 | return False
34 |
35 |
36 | def read_file_contents(file_path: bytes) -> Tuple[List[str], bytes]:
37 | try:
38 | return read_file_with_encoding(file_path, 'utf-8')
39 | except UnicodeDecodeError:
40 | try:
41 | return read_file_with_encoding(file_path, 'ISO-8859-1')
42 | except UnicodeDecodeError:
43 | logger.error(f"Unicode decode error in file: {file_path}")
44 |
45 |
46 | def read_file_with_encoding(file_path: bytes, encoding: str) -> Tuple[List[str], bytes]:
47 | with open(file_path, 'r', encoding=encoding) as f:
48 | return [line.rstrip('\n') for line in f], file_path
--------------------------------------------------------------------------------
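
A usage sketch for the encoding fallback above (the file path is illustrative). Note that ISO-8859-1 can decode any byte sequence, so the second except UnicodeDecodeError branch is effectively unreachable in practice; undecodable UTF-8 is silently read as Latin-1 characters instead.

    from codeprep.fileutils import read_file_contents

    lines, path = read_file_contents(b'/path/to/File.java')
    print(f'{len(lines)} lines read from {path}')
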
/codeprep/logging.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 | disable_existing_loggers: False
3 | formatters:
4 | with_process_name:
5 | format: "%(asctime)s [%(name)s] %(levelname)s: %(message)s"
6 |
7 | handlers:
8 | console:
9 | class: logging.StreamHandler
10 | level: DEBUG
11 | formatter: with_process_name
12 | stream: ext://sys.stdout
13 | loggers:
14 | codeprep.infrastructure:
15 | level: DEBUG
16 | propagate: True
17 | codeprep.api:
18 | level: DEBUG
19 | propagate: True
20 | codeprep.cli:
21 | level: DEBUG
22 | propagate: True
23 | codeprep.vocab:
24 | level: DEBUG
25 | propagate: True
26 | codeprep.parse_projects:
27 | level: DEBUG
28 | propagate: True
29 | codeprep.to_repr:
30 | level: DEBUG
31 | propagate: True
32 | root:
33 | level: DEBUG
34 | handlers: [console]
--------------------------------------------------------------------------------
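
A minimal sketch of how such a dictConfig-style YAML is typically applied (assuming PyYAML is available; the loading code itself is not shown in this file):

    import logging.config

    import yaml

    with open('codeprep/logging.yaml') as f:
        logging.config.dictConfig(yaml.safe_load(f))

    logging.getLogger('codeprep.api').debug('logging configured')
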
/codeprep/noneng.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import logging
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | def is_non_eng(word):
11 | return not __isascii(word)
12 |
13 |
14 | def __isascii(s):
15 |     try:
16 |         s.encode('ascii')
17 | return True
18 | except UnicodeEncodeError:
19 | return False
20 |
21 |
22 | def replace_non_ascii_seqs(word: str, placeholder: str) -> str:
23 | """
24 | >>> replace_non_ascii_seqs("","\xf7")
25 | ''
26 |
27 | >>> replace_non_ascii_seqs("Ü", "\xf7")
28 | '\xf7'
29 |
30 | >>> replace_non_ascii_seqs("Üüø", "\xf7")
31 | '\xf7'
32 |
33 | >>> replace_non_ascii_seqs("abcd", "\xf7")
34 | 'abcd'
35 |
36 | >>> replace_non_ascii_seqs("aæbñńcdú", "\xf7")
37 | 'a\xf7b\xf7cd\xf7'
38 |
39 | >>> replace_non_ascii_seqs("any_string", "\xf7\xa0")
40 | Traceback (most recent call last):
41 | ...
42 | ValueError: Placeholder should be a single character, but is ÷\xa0
43 | """
44 | if len(placeholder) != 1:
45 | raise ValueError(f"Placeholder should be a single character, but is {placeholder}")
46 |
47 | new_word = ""
48 | ongoing_non_ascii_seq = False
49 | for ch in word:
50 | if ord(ch) < 128:
51 | if ongoing_non_ascii_seq:
52 | new_word += placeholder
53 | ongoing_non_ascii_seq = False
54 | new_word += ch
55 | else:
56 | ongoing_non_ascii_seq = True
57 | if ongoing_non_ascii_seq:
58 | new_word += placeholder
59 |
60 | return new_word
--------------------------------------------------------------------------------
/codeprep/parse/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/codeprep/parse/core.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import logging
6 | from typing import Generator, List
7 |
8 | from pygments import lex
9 | from pygments.lexers import get_lexer_by_name, guess_lexer
10 | from pygments.util import ClassNotFound
11 |
12 | from codeprep.parse import matchers
13 | from codeprep.parse.matchers import DefaultMatcher
14 | from codeprep.tokens.rootclasses import ParsedToken
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 | matchers = [
19 | matchers.NewLineMatcher(),
20 | matchers.TabMatcher(),
21 | matchers.WhitespaceMatcher(),
22 | matchers.OperatorMatcher(),
23 | matchers.NumberMatchers(),
24 | matchers.WordMatcher(),
25 | matchers.GenericLiteralMatcher(),
26 | matchers.KeywordMatcher(),
27 | matchers.StringMatcher(),
28 | matchers.OneLineCommentMatcher(),
29 | matchers.MultiLineLineCommentMatcher(),
30 | matchers.GenericTokenMatcher()
31 | ]
32 |
33 |
34 | def _convert(token, value: str) -> List[ParsedToken]:
35 | for matcher in matchers:
36 | if matcher.match(token, value):
37 | return matcher.transform(value)
38 |
39 | if DefaultMatcher().match(token, value):
40 | return DefaultMatcher().transform(value)
41 |
42 | assert False
43 |
44 |
45 | def convert_text(text: str, extension: str) -> Generator[ParsedToken, None, None]:
46 |     # 'java' is the default when no extension is given
47 |     extension = extension or 'java'
48 |     try:
49 |         lexer = get_lexer_by_name(extension)
50 |     except ClassNotFound as err:
51 |         logger.warning(err)
52 |         # fall back to content-based lexer detection
53 |         lexer = guess_lexer(text)
54 | 
55 | for token, value in lex(text, lexer):
56 | model_tokens = _convert(token, value)
57 | for mr in model_tokens:
58 | yield mr
--------------------------------------------------------------------------------
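
convert_text lexes the text with Pygments and funnels every (token, value) pair through the matcher chain, yielding ParsedToken objects one by one. A usage sketch (the snippet passed in is illustrative):

    from codeprep.parse.core import convert_text

    for parsed_token in convert_text('int i = 0; // init', 'java'):
        print(type(parsed_token).__name__)
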
/codeprep/parse/matchers.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from typing import List
6 |
7 | from pygments.token import Token
8 |
9 | from codeprep.parse.subtokens import split_into_words, split_string
10 | from codeprep.tokens.containers import StringLiteral, OneLineComment, MultilineComment
11 | from codeprep.tokens.numeric import Number, Zero, One
12 | from codeprep.tokens.rootclasses import ParsedToken
13 | from codeprep.tokens.whitespace import NewLine, Tab
14 | from codeprep.tokens.word import KeyWord, Operator, Semicolon, OpeningCurlyBracket, ClosingCurlyBracket, OpeningBracket, \
15 | ClosingBracket
16 |
17 |
18 | class DefaultMatcher(object):
19 | def match(self, token, value: str) -> bool:
20 | return True
21 |
22 | def transform(self, value: str) -> List[ParsedToken]:
23 | return split_into_words(value)
24 |
25 |
26 | class GenericTokenMatcher(object):
27 | def match(self, token, value: str) -> bool:
28 | return token in Token.Generic
29 |
30 | def transform(self, value: str) -> List[ParsedToken]:
31 | return split_into_words(value)
32 |
33 |
34 | class StringMatcher(object):
35 | def match(self, token, value: str) -> bool:
36 | return token in Token.Literal.String
37 |
38 | def transform(self, value: str) -> List[StringLiteral]:
39 | return [StringLiteral(split_string(value), len(value))]
40 |
41 |
42 | class OneLineCommentMatcher(object):
43 | def match(self, token, value: str) -> bool:
44 | return token is Token.Comment.Single
45 |
46 | def transform(self, value: str) -> List[OneLineComment]:
47 | return [OneLineComment(split_into_words(value))]
48 |
49 |
50 | class MultiLineLineCommentMatcher(object):
51 | def match(self, token, value: str) -> bool:
52 |         return token in Token.Comment and token is not Token.Comment.Single
53 |
54 | def transform(self, value: str) -> List[MultilineComment]:
55 | return [MultilineComment(split_into_words(value))]
56 |
57 |
58 | class WordMatcher(object):
59 | def match(self, token, value: str) -> bool:
60 | return token in Token.Name
61 |
62 | def transform(self, value: str) -> List[ParsedToken]:
63 | return split_into_words(value)
64 |
65 |
66 | class GenericLiteralMatcher(object):
67 | def match(self, token, value: str) -> bool:
68 | return token is Token.Literal or token is Token.Literal.Date
69 |
70 | def transform(self, value: str) -> List[ParsedToken]:
71 | return split_into_words(value)
72 |
73 |
74 | class KeywordMatcher(object):
75 | def match(self, token, value: str) -> bool:
76 | return token in Token.Keyword
77 |
78 | def transform(self, value: str) -> List[KeyWord]:
79 | return [KeyWord(value)]
80 |
81 |
82 | class NewLineMatcher(object):
83 | def match(self, token, value: str) -> bool:
84 | return value == '\n'
85 |
86 | def transform(self, value: str) -> List[NewLine]:
87 | return [NewLine()]
88 |
89 |
90 | class WhitespaceMatcher(object):
91 | def match(self, token, value: str) -> bool:
92 | return value.strip() == ''
93 |
94 | def transform(self, value: str) -> List[Tab]:
95 | return [Tab()] * (len(value) // 4)
96 |
97 |
98 | class TabMatcher(object):
99 | def match(self, token, value: str) -> bool:
100 | return value == '\t'
101 |
102 | def transform(self, value: str) -> List[Tab]:
103 | return [Tab()]
104 |
105 |
106 | class NumberMatchers(object):
107 | def match(self, token, value: str) -> bool:
108 | return token in Token.Literal.Number
109 |
110 | def transform(self, value: str) -> List[Number]:
111 | if value == '0':
112 | return [Zero()]
113 | elif value == '1':
114 | return [One()]
115 | else:
116 | return [Number(value)]
117 |
118 |
119 | class OperatorMatcher(object):
120 | def match(self, token, value: str):
121 | return token is Token.Operator or token in Token.Punctuation
122 |
123 | def transform(self, value: str) -> List[Operator]:
124 | if value == ';':
125 | return [Semicolon()]
126 | elif value == '{':
127 | return [OpeningCurlyBracket()]
128 | elif value == '}':
129 | return [ClosingCurlyBracket()]
130 | elif value == '(':
131 | return [OpeningBracket()]
132 | elif value == ')':
133 | return [ClosingBracket()]
134 | else:
135 | return [Operator(value)]
136 |
137 |
138 | class WordOperatorMatcher(object):
139 | def match(self, token, value: str):
140 | return token is Token.Operator.Word
141 |
142 | def transform(self, value: str) -> List[ParsedToken]:
143 | return split_into_words(value)
--------------------------------------------------------------------------------
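
Each matcher follows the same implicit protocol: match(token, value) -> bool decides whether the matcher claims a Pygments token, and transform(value) converts the claimed value into ParsedToken objects; the first matcher in the list in parse/core.py that matches wins, so ordering matters. A hypothetical matcher following the same shape (not part of the repo):

    from typing import List

    from pygments.token import Token

    from codeprep.tokens.word import KeyWord


    class ConstantKeywordMatcher(object):
        def match(self, token, value: str) -> bool:
            return token is Token.Keyword.Constant

        def transform(self, value: str) -> List[KeyWord]:
            return [KeyWord(value)]
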
/codeprep/parse/subtokens.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from typing import List
6 |
7 | import regex
8 |
9 | from codeprep.noneng import is_non_eng
10 | from codeprep.tokens.containers import SplitContainer
11 | from codeprep.tokens.noneng import NonEng
12 | from codeprep.tokens.numeric import Number
13 | from codeprep.tokens.rootclasses import ParsedToken
14 | from codeprep.tokens.whitespace import NewLine, Tab, SpaceInString
15 | from codeprep.tokens.word import Underscore, Word, NonCodeChar
16 |
17 |
18 | def split_identifier(token: str) -> SplitContainer:
19 | parts = [m[0] for m in
20 | regex.finditer('(_|[0-9]+|[[:upper:]]?[[:lower:]]+|[[:upper:]]+(?![[:lower:]])|[^ ])', token)]
21 |
22 | processable_tokens = [Word.from_(p) if p != '_' else Underscore() for p in parts]
23 | split_container = SplitContainer(processable_tokens)
24 | return NonEng(split_container) if is_non_eng(token) else split_container
25 |
26 |
27 | # Using the same regexps SLP team uses to parse numbers in java code
28 | # https://github.com/SLP-team/SLP-Core/blob/master/src/main/java/slp/core/lexing/code/JavaLexer.java
29 |
30 | HEX_REGEX = "0x([0-9a-fA-F]+_)*[0-9a-fA-F]+[lL]?"
31 | BIN_REGEX = "0b([01]+_)*[01]+[lL]?"
32 | IR_REGEX = "([0-9]+_)*[0-9]+[lLfFdD]?"
33 | DBL_REGEXA = "[0-9]+\\.[0-9]+([eE][-+]?[0-9]+)?[fFdD]?"
34 | DBL_REGEXB = "[0-9]+\\.([eE][-+]?[0-9]+)?[fFdD]?"
35 | DBL_REGEXC = "\\.[0-9]+([eE][-+]?[0-9]+)?[fFdD]?"
36 | DBL_REGEXD = "[0-9]+[eE][-+]?[0-9]+[fFdD]?"
37 |
38 | NUMBER_PATTERN = f'({HEX_REGEX}|{BIN_REGEX}|{IR_REGEX}|{DBL_REGEXA}|{DBL_REGEXB}|{DBL_REGEXC}|{DBL_REGEXD})'
39 |
40 |
41 | def is_number(word: str) -> bool:
42 | """
43 | >>> is_number("0")
44 | True
45 |
46 | >>> is_number("8")
47 | True
48 |
49 | >>> is_number("-5")
50 | False
51 |
52 | >>> is_number("23450012")
53 | True
54 |
55 | >>> is_number("283463L")
56 | True
57 |
58 | >>> is_number("342424242l")
59 | True
60 |
61 | >>> is_number("0.")
62 | True
63 |
64 | >>> is_number(".0")
65 | True
66 |
67 | >>> is_number(".0d")
68 | True
69 |
70 | >>> is_number("353535.")
71 | True
72 |
73 | >>> is_number("353535.D")
74 | True
75 |
76 | >>> is_number(".353535F")
77 | True
78 |
79 | >>> is_number(".353535f")
80 | True
81 |
82 | >>> is_number("0.2e+3D")
83 | True
84 |
85 | >>> is_number("23424.E-30F")
86 | True
87 |
88 | >>> is_number(".002e-0f")
89 | True
90 |
91 | >>> is_number("0b10101")
92 | True
93 |
94 | >>> is_number("0b0011L") # java -- not python
95 | True
96 |
97 | >>> is_number("0b0")
98 | True
99 |
100 | >>> is_number("0x8AbCc006EfBd")
101 | True
102 |
103 | >>> is_number("0xG12")
104 | False
105 |
106 | >>> is_number("0x56DL")
107 | True
108 |
109 | >>> is_number("0x56Dl")
110 | True
111 | """
112 | return regex.fullmatch(NUMBER_PATTERN, word) is not None
113 |
114 |
115 | def to_parsed_token(token: str) -> ParsedToken:
116 | if token == '\n':
117 | return NewLine()
118 | elif token == '\t':
119 | return Tab()
120 | elif is_number(token):
121 | return Number(token)
122 | elif regex.fullmatch("\\w+", token):
123 | return split_identifier(token)
124 | else:
125 | return NonCodeChar(token)
126 |
127 |
128 | def split_string(token: str) -> List[ParsedToken]:
129 | """
130 | >>> split_string(" var = 9.4\\t\\n")
131 | [ (n_chars=4), SplitContainer[Word(('var', none))], \
132 | (n_chars=1), NonCodeChar(=), (n_chars=1), (9), \
133 | NonCodeChar(.), (4), , ]
134 | """
135 | res = []
136 | arbitrary_whitespace = "( )+"
137 | for m in regex.finditer(f"(\\w+|[^ ]|{arbitrary_whitespace})", token):
138 | if regex.fullmatch(arbitrary_whitespace, m[0]):
139 | res.append(SpaceInString(n_chars=len(m[0])))
140 | else:
141 | res.append(to_parsed_token(m[0]))
142 | return res
143 |
144 |
145 | def split_into_words(token: str) -> List[ParsedToken]:
146 | """
147 | >>> split_into_words(" var = 9.4\\t\\n")
148 | [, SplitContainer[Word(('var', none))], NonCodeChar(=), (9), \
149 | NonCodeChar(.), (4), , ]
150 | """
151 | res = []
152 | four_char_whitespace = " " * 4
153 | for m in regex.finditer(f"(\\w+|[^ ]|{four_char_whitespace})", token):
154 | if m[0] == four_char_whitespace:
155 | res.append(Tab())
156 | else:
157 | res.append(to_parsed_token(m[0]))
158 | return res
--------------------------------------------------------------------------------
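
The identifier-splitting regex in split_identifier separates camelCase humps, uppercase runs, digit runs, and underscores. Its behavior in isolation (the identifiers are illustrative):

    import regex

    PARTS = '(_|[0-9]+|[[:upper:]]?[[:lower:]]+|[[:upper:]]+(?![[:lower:]])|[^ ])'

    print([m[0] for m in regex.finditer(PARTS, 'parseHTTPResponse2')])
    # ['parse', 'HTTP', 'Response', '2']
    print([m[0] for m in regex.finditer(PARTS, 'max_value')])
    # ['max', '_', 'value']
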
/codeprep/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/codeprep/pipeline/bpelearner.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import logging
6 | import os
7 | from typing import Tuple, Dict, Set, Optional
8 |
9 | from codeprep.bpepkg.bpe_config import BpeConfig, BpeParam, BpeConfigNotSupported
10 | from codeprep.bpepkg.bpe_encode import escape
11 | from codeprep.bpepkg.bpe_learn import separate_vocabs, logger, do_merges, create_resulting_vocab, create_bpe_cache
12 | from codeprep.bpepkg.cache import dump_bpe_cache
13 | from codeprep.bpepkg.merge import MergeList, read_merges, dump_merges
14 | from codeprep.pipeline import stages
15 | from codeprep.pipeline.bperegistry import get_max_merges, MERGES_FILE_NAME, MERGES_CACHE_FILE_NAME, \
16 | RESULTING_VOCAB_FILE_NAME, BPE_REASSEMBLED_VOCAB_FILE_NAME
17 | from codeprep.pipeline.dataset import Dataset
18 | from codeprep.pipeline.vocab import _dump_vocab_dict, _load_vocab_dict
19 | from codeprep.util import to_non_literal_str
20 |
21 |
22 | def get_base_vocab(dataset: Dataset) -> Tuple[Dict[str, int], Dict[str, int]]:
23 | stages.run_until_base_bpe_vocab(dataset)
24 | all_vocab = _load_vocab_dict(dataset.path_to_bpe_vocab_file)
25 | non_bpe_vocab = load_nonbpe_vocab(dataset)
26 | return separate_vocabs(all_vocab, non_bpe_vocab)
27 |
28 |
29 | def load_nonbpe_vocab(dataset: Dataset) -> Set[str]:
30 | non_bpe_vocab = set()
31 | with open(dataset.path_to_nonbpe_vocab_file, 'r') as f:
32 | for line in f:
33 | non_bpe_vocab.add(to_non_literal_str(line.rstrip('\n')))
34 | return non_bpe_vocab
35 |
36 |
37 | def check_if_bpe_config_supported(bpe_config: BpeConfig):
38 | if bpe_config.get_param_value(BpeParam.UNICODE) == 'bytes':
39 | raise BpeConfigNotSupported('Byte-BPE is not yet supported')
40 |
41 | if bpe_config.get_param_value(BpeParam.WORD_END):
42 |         raise BpeConfigNotSupported('BPE with word-end characters is not yet supported')
43 |
44 | if bpe_config.get_param_value(BpeParam.CASE) == 'prefix':
45 | raise BpeConfigNotSupported('BPE with case encoded in prefix is not yet supported')
46 |
47 |
48 | def prepare_vocabs(dataset: Dataset, dir_with_most_merges, starting_from_scratch):
49 | if starting_from_scratch:
50 | base_bpe_vocab, other_vocab = get_base_vocab(dataset) # TODO extract this into stages
51 | other_vocab = {escape(k, merged=True): v for k, v in other_vocab.items()}
52 | split_base_vocab = {escape(" ".join(k)): v for k, v in base_bpe_vocab.items()}
53 | else:
54 | path_to_bpe_vocab_file = os.path.join(dir_with_most_merges, BPE_REASSEMBLED_VOCAB_FILE_NAME)
55 | non_bpe_vocab = {escape(k, merged=True) for k in load_nonbpe_vocab(dataset)}
56 | split_base_vocab = _load_vocab_dict(path_to_bpe_vocab_file)
57 | split_base_vocab, other_vocab = separate_vocabs(split_base_vocab, non_bpe_vocab)
58 |
59 | return split_base_vocab, other_vocab
60 |
61 |
62 | def get_dir_with_most_merges(dataset_bpe_path, n_merges) -> Optional[str]:
63 | max_merges = get_max_merges(dataset_bpe_path, n_merges)
64 | if not max_merges:
65 | return None
66 |
67 | dir_with_most_merges = os.path.join(dataset_bpe_path, str(max_merges))
68 | return dir_with_most_merges
69 |
70 |
71 | def save_results(split_base_vocab, merges, new_bpe_dir):
72 |
73 | os.makedirs(new_bpe_dir)
74 |
75 | resulting_vocab = create_resulting_vocab(split_base_vocab)
76 | resulting_vocab_sorted = sorted(resulting_vocab.items(), key=lambda x: x[1], reverse=True)
77 | _dump_vocab_dict(resulting_vocab_sorted, os.path.join(new_bpe_dir, RESULTING_VOCAB_FILE_NAME))
78 |
79 | bpe_cache = create_bpe_cache(split_base_vocab)
80 | dump_bpe_cache(bpe_cache, os.path.join(new_bpe_dir, MERGES_CACHE_FILE_NAME))
81 |
82 | dump_merges(merges, os.path.join(new_bpe_dir, MERGES_FILE_NAME))
83 | _dump_vocab_dict(split_base_vocab.items(), os.path.join(new_bpe_dir, BPE_REASSEMBLED_VOCAB_FILE_NAME))
84 | logger.info(f'Bpe output files are saved into {new_bpe_dir} folder')
85 |
86 |
87 | def run(dataset: Dataset, n_merges: int, bpe_config: BpeConfig) -> None:
88 |
89 | check_if_bpe_config_supported(bpe_config)
90 | dataset_bpe_path = dataset.bpe_path
91 |
92 | dir_with_most_merges = get_dir_with_most_merges(dataset_bpe_path, n_merges)
93 |
94 | if dir_with_most_merges:
95 | logger.info("Using existing merges...")
96 | already_done_merges = read_merges(os.path.join(dir_with_most_merges, MERGES_FILE_NAME))
97 | else:
98 | logger.info("Starting encoding from scratch. ..")
99 | already_done_merges = MergeList()
100 |
101 | split_base_vocab, other_vocab = prepare_vocabs(dataset, dir_with_most_merges,
102 | starting_from_scratch=not dir_with_most_merges)
103 |
104 | logger.info("Learning bpe codes...")
105 | split_base_vocab, merges = do_merges(split_base_vocab, n_merges - len(already_done_merges))
106 | for k, v in other_vocab.items():
107 | split_base_vocab[k] = v
108 | merges = already_done_merges + merges
109 |
110 | new_bpe_dir = os.path.join(dataset_bpe_path, str(len(merges)))
111 | if os.path.exists(new_bpe_dir):
112 |         logger.info("Merges already learned!")
113 | return
114 |
115 | save_results(split_base_vocab, merges, new_bpe_dir)
--------------------------------------------------------------------------------
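
This module is pipeline glue: the pair counting and merging themselves live in codeprep.bpepkg.bpe_learn.do_merges. As a self-contained illustration of what one BPE merge step does to a split vocabulary (a toy simplification, not the repo's implementation):

    from collections import Counter
    from typing import Dict, Tuple


    def most_frequent_pair(split_vocab: Dict[str, int]) -> Tuple[str, str]:
        pair_counts = Counter()
        for word, freq in split_vocab.items():
            symbols = word.split(' ')
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += freq
        return pair_counts.most_common(1)[0][0]


    def apply_merge(split_vocab: Dict[str, int], pair: Tuple[str, str]) -> Dict[str, int]:
        # naive textual replace; safe here because all symbols are single characters
        return {word.replace(' '.join(pair), ''.join(pair)): freq
                for word, freq in split_vocab.items()}


    vocab = {'l o w': 5, 'l o w e r': 2, 'n e w e s t': 6}
    pair = most_frequent_pair(vocab)  # ('w', 'e'), occurring 8 times
    print(apply_merge(vocab, pair))   # {'l o w': 5, 'l o we r': 2, 'n e we s t': 6}
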
/codeprep/pipeline/parse_projects.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import gzip
6 | import logging
7 | import os
8 | import pickle
9 | from multiprocessing.pool import Pool
10 | from typing import Tuple
11 |
12 | from tqdm import tqdm
13 |
14 | from codeprep.config import REWRITE_PARSED_FILE, CHUNKSIZE, LIMIT_FILES_SCANNING
15 | from codeprep.fileutils import read_file_contents
16 | from codeprep.pipeline.dataset import Dataset, NOT_FINISHED_EXTENSION
17 | from codeprep.parse.core import convert_text
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | def preprocess_and_write(params: Tuple[bytes, bytes]) -> None:
23 | src_file_path, dest_file_path = params
24 |
25 | dest_dirname = os.path.dirname(dest_file_path)
26 | if not os.path.exists(dest_dirname):
27 | os.makedirs(dest_dirname, exist_ok=True)
28 |
29 | if not REWRITE_PARSED_FILE and os.path.exists(dest_file_path):
30 | logger.warning(f"File {dest_file_path} already exists! Doing nothing.")
31 | return
32 |
33 | not_finished_dest_file_path = dest_file_path + NOT_FINISHED_EXTENSION.encode()
34 | with gzip.GzipFile(not_finished_dest_file_path, 'wb') as f:
35 | try:
36 | lines_from_file, path = read_file_contents(src_file_path)
37 | except FileNotFoundError:
38 | logger.error(f"File was found when scanning the directory, but cannot be read: {src_file_path}. "
39 | f"Invalid symlink? Ignoring ...")
40 | return
41 |         extension = os.path.splitext(src_file_path)[1].decode()[1:]
42 |         parsed = list(convert_text("\n".join(lines_from_file), extension))
43 | pickle.dump(parsed, f, pickle.HIGHEST_PROTOCOL)
44 |
45 | os.rename(not_finished_dest_file_path, dest_file_path)
46 |
47 |
48 | def params_generator(dataset: Dataset):
49 | for input_file_path in dataset.original.file_iterator():
50 | output_file_path = dataset.original.get_new_file_name(input_file_path, dataset.parsed)
51 | yield (input_file_path, output_file_path)
52 |
53 |
54 | def run(dataset: Dataset) -> None:
55 | logger.info(f"Getting files from {dataset.original.path}")
56 | logger.info(f"Writing preprocessed files to {dataset.parsed.path}")
57 |
58 | if dataset.files_need_to_be_saved():
59 | files_total = 0
60 | for _ in dataset.get_all_files():
61 | files_total += 1
62 | print(f"Files scanned: {files_total}", end='\r')
63 | if files_total > LIMIT_FILES_SCANNING:
64 | files_total = None
65 | logger.info(f"Total files to be preprocessed: {LIMIT_FILES_SCANNING}+")
66 | break
67 | else:
68 |         files_total = sum(1 for _ in dataset.get_all_files())
69 | with Pool() as pool:
70 | it = pool.imap_unordered(preprocess_and_write, params_generator(dataset), chunksize=CHUNKSIZE)
71 | for _ in tqdm(it, total=files_total):
72 | pass
73 | if not dataset.suppress_caching:
74 | dataset.parsed.set_ready()
--------------------------------------------------------------------------------
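
preprocess_and_write relies on the write-then-rename idiom: output goes to a temporary name carrying NOT_FINISHED_EXTENSION and is renamed only once fully written, so an interrupted run never leaves a truncated file that looks complete. The idiom in isolation (the suffix below is an illustrative stand-in):

    import os


    def write_atomically(dest_path: str, data: bytes, tmp_suffix: str = '.not_finished') -> None:
        tmp_path = dest_path + tmp_suffix
        with open(tmp_path, 'wb') as f:
            f.write(data)
        # os.rename is atomic within a filesystem, so readers see either
        # the old state or the complete new file, never a partial write
        os.rename(tmp_path, dest_path)
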
/codeprep/pipeline/stages.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | """This module runs different stages of preprocessing flow and makes sure not to rerun a stage if its results are already available.
6 | """
7 | import logging
8 | import os
9 | from typing import Optional
10 |
11 | from codeprep.pipeline import parse_projects, to_repr
12 | from codeprep.pipeline.bperegistry import CustomBpeConfig
13 | from codeprep.pipeline.dataset import Dataset, is_path_ready, is_path_outdated, archive_path
14 | from codeprep.pipeline.vocab import calc_vocab
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 | #TODO remove code duplication in methods below
19 | def run_parsing(dataset: Dataset) -> None:
20 | logger.info("Parsing...")
21 | if not dataset.parsed.ready():
22 | parse_projects.run(dataset)
23 | elif dataset.parsed.is_outdated():
24 | dataset.parsed.archive()
25 | parse_projects.run(dataset)
26 | else:
27 | logger.info("Parsed dataset is up-to-date.")
28 |
29 |
30 | def run_until_preprocessing(dataset: Dataset, custom_bpe_config: Optional[CustomBpeConfig]=None) -> None:
31 | run_parsing(dataset)
32 | logger.info("Preprocessing...")
33 | if not dataset.preprocessed.ready():
34 | to_repr.run(dataset, custom_bpe_config)
35 | elif dataset.preprocessed.is_outdated():
36 | dataset.preprocessed.archive()
37 | to_repr.run(dataset, custom_bpe_config)
38 | else:
39 | logger.info(f"Dataset is already preprocessed and up-to-date.")
40 |
41 |
42 | def run_until_base_bpe_vocab(dataset: Dataset, custom_bpe_config: Optional[CustomBpeConfig]=None) -> None:
43 | run_until_preprocessing(dataset, custom_bpe_config)
44 | logger.info("Computing base bpe vocab...")
45 | if not is_path_ready(dataset.path_to_bpe_vocab_file):
46 | calc_vocab(dataset.preprocessed.path, dataset.preprocessed.file_iterator(), dataset.base_bpe_vocab_path)
47 | elif is_path_outdated(dataset.path_to_bpe_vocab_file):
48 | archive_path(dataset.path_to_bpe_vocab_file)
49 | calc_vocab(dataset.preprocessed.path, dataset.preprocessed.file_iterator(), dataset.base_bpe_vocab_path)
50 | else:
51 | logger.info("Vocabulary is already computed and up-to-date")
52 |
53 |
54 | def run_until_vocab(dataset: Dataset, custom_bpe_config: Optional[CustomBpeConfig]=None) -> None:
55 | logger.info(f'Checking first if vocabulary file exists: {dataset.path_to_vocab_file}')
56 | if os.path.exists(dataset.path_to_vocab_file):
57 | logger.info("Vocabulary is already computed and up-to-date")
58 | return
59 |
60 | if not is_path_ready(dataset.path_to_vocab_file):
61 | run_until_preprocessing(dataset, custom_bpe_config)
62 | logger.info("Computing vocab...")
63 | calc_vocab(dataset.preprocessed.path, dataset.preprocessed.file_iterator(), dataset.vocab_path)
64 | elif is_path_outdated(dataset.path_to_vocab_file):
65 | run_until_preprocessing(dataset, custom_bpe_config)
66 | logger.info("Computing vocab...")
67 |         archive_path(dataset.path_to_vocab_file)
68 | calc_vocab(dataset.preprocessed.path, dataset.preprocessed.file_iterator(), dataset.vocab_path)
69 | else:
70 | raise AssertionError()
--------------------------------------------------------------------------------
/codeprep/pipeline/to_repr.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import gzip
6 | import logging
7 | import os
8 | import pickle
9 | import platform
10 | from multiprocessing.pool import Pool
11 | from typing import List, Tuple
12 | from typing import Optional
13 |
14 | import time
15 | from tqdm import tqdm
16 |
17 | from codeprep.bpepkg.bpe_encode import read_merges, BpeData
18 | from codeprep.bpepkg.cache import read_bpe_cache
19 | from codeprep.config import DEFAULT_BPE_DIR, NO_CASE_DIR, CASE_DIR, DEFAULT_BPE_CACHE_DIR, REWRITE_PREPROCESSED_FILE, \
20 | CHUNKSIZE, LIMIT_FILES_SCANNING
21 | from codeprep.pipeline import vocabloader
22 | from codeprep.pipeline.bperegistry import CustomBpeConfig
23 | from codeprep.pipeline.dataset import Dataset, NOT_FINISHED_EXTENSION
24 | from codeprep.prepconfig import PrepParam, PrepConfig
25 | from codeprep.preprocess.core import to_repr_list
26 | from codeprep.preprocess.metadata import PreprocessingMetadata
27 | from codeprep.preprocess.metadata import save_metadata
28 | from codeprep.preprocess.placeholders import placeholders
29 | from codeprep.tokens.rootclasses import ParsedToken
30 | from codeprep.tokens.word import SpecialToken
31 | from codeprep.util import to_literal_str
32 |
33 | logger = logging.getLogger(__name__)
34 |
35 |
36 | def get_global_bpe_data_if_available() -> Optional[BpeData]:
37 | return global_bpe_data if 'global_bpe_data' in globals() else None
38 |
39 |
40 | def insert_and_word_tokens(prep_list: List[str], metadata: PreprocessingMetadata) -> List[str]:
41 | list_copy = [elm for elm in prep_list]
42 | for index in metadata.word_boundaries[1:]:
43 | list_copy[index-1] += placeholders['compound_word_end']
44 | return list_copy
45 |
46 |
47 | def to_repr(prep_config: PrepConfig, token_list: List[ParsedToken],
48 | bpe_data: Optional[BpeData] = None) -> Tuple[List[str], PreprocessingMetadata]:
49 | bpe_data = bpe_data or get_global_bpe_data_if_available()
50 | repr_list, metadata = to_repr_list(token_list, prep_config.get_repr_config(bpe_data))
51 | if prep_config.is_bpe():
52 | repr_list = insert_and_word_tokens(repr_list, metadata)
53 | return repr_list, metadata
54 |
55 |
56 | def to_token_str(tokens: List) -> str:
57 | return " ".join(map(lambda t: str(t), tokens))
58 |
59 |
60 | def preprocess_and_write(params: Tuple[bytes, bytes, PrepConfig, str], bpe_data: Optional[BpeData] = None):
61 | src_file_path, dest_file_path, prep_config, part_nonbpe_vocab_folder = params
62 |
63 | dest_dirname = os.path.dirname(dest_file_path)
64 | if not os.path.exists(dest_dirname):
65 | os.makedirs(dest_dirname, exist_ok=True)
66 |
67 | if not REWRITE_PREPROCESSED_FILE and os.path.exists(dest_file_path):
68 | logger.warning(f"File {dest_file_path} already exists! Doing nothing.")
69 | return
70 |
71 | not_finished_dest_file_path = dest_file_path + NOT_FINISHED_EXTENSION.encode()
72 | with gzip.GzipFile(src_file_path, 'rb') as i, open(not_finished_dest_file_path, 'w') as o:
73 | token_list = pickle.load(i)
74 | bpe_data = get_global_bpe_data_if_available() if bpe_data is None else bpe_data
75 |         prep_tokens, metadata = to_repr(prep_config, token_list + [SpecialToken(placeholders['ect'])], bpe_data)
76 |         o.write(to_literal_str(to_token_str(prep_tokens)) + '\n')
77 |
78 | if part_nonbpe_vocab_folder:
79 | save_metadata(metadata, os.path.join(part_nonbpe_vocab_folder, f'{os.path.basename(dest_file_path)}_-_{time.time()}'))
80 |
81 | os.rename(not_finished_dest_file_path, dest_file_path)
82 |
83 | #TODO make this method independent of actual directory structure
84 | def init_bpe_data(prep_config: PrepConfig, custom_bpe_config: Optional[CustomBpeConfig], force_reinit: bool=True):
85 | if get_global_bpe_data_if_available() and not force_reinit:
86 | return # already initialized
87 | global global_bpe_data
88 | global_bpe_data = BpeData()
89 | if custom_bpe_config:
90 | logger.info(f'Using bpe merges file: {custom_bpe_config.codes_file}')
91 | if custom_bpe_config.can_use_cache_file():
92 | global_bpe_data.merges_cache = read_bpe_cache(custom_bpe_config.cache_file)
93 | else:
94 | global_bpe_data.merges_cache = {}
95 | global_bpe_data.merges = read_merges(custom_bpe_config.codes_file, custom_bpe_config.n_merges)
96 |
97 | if custom_bpe_config.n_merges:
98 | logger.info(f'Using first {custom_bpe_config.n_merges} merges.')
99 | nonbpe_vocab = vocabloader.nonbpe(custom_bpe_config.merge_list_id)
100 | global_bpe_data.merges_cache.update({s: [s] for s in nonbpe_vocab})
101 | else:
102 | bpe_n_merges_dict = {'4': '5k', '5': '1k', '6': '10k', '7': '20k', '8': '0'}
103 | bpe_n_merges = bpe_n_merges_dict[prep_config.get_param_value(PrepParam.SPLIT)]
104 |
105 | bpe_merges_file = os.path.join(DEFAULT_BPE_DIR,
106 | CASE_DIR if prep_config.get_param_value(PrepParam.CASE) == 'u' else NO_CASE_DIR,
107 | str(bpe_n_merges), 'merges.txt')
108 | bpe_merges_cache_file = os.path.join(DEFAULT_BPE_CACHE_DIR,
109 | CASE_DIR if prep_config.get_param_value(PrepParam.CASE) == 'u' else NO_CASE_DIR,
110 | str(bpe_n_merges), 'merges_cache.txt')
111 | if os.path.exists(bpe_merges_cache_file):
112 | global_bpe_data.merges_cache = read_bpe_cache(bpe_merges_cache_file)
113 | else:
114 | global_bpe_data.merges_cache = {}
115 | global_bpe_data.merges = read_merges(bpe_merges_file)
116 |
117 |
118 | def params_generator(dataset: Dataset, path_to_part_metadata: Optional[str]):
119 | for input_file_path in dataset.parsed.file_iterator():
120 | output_file_path = dataset.parsed.get_new_file_name(input_file_path, dataset.preprocessed)
121 | yield (input_file_path, output_file_path, dataset.prep_config, path_to_part_metadata)
122 |
123 |
124 | def get_n_cpus_to_be_used():
125 | system_platform = platform.system()
126 | n_cpus = 1 if system_platform in ['Windows', 'Darwin'] else os.cpu_count() or 1
127 | logger.info(f"Platform: {system_platform}, n cores to be used: {n_cpus}")
128 | return n_cpus
129 |
130 |
131 | def run(dataset: Dataset, custom_bpe_config: Optional[CustomBpeConfig]) -> None:
132 | path_to_parsed_dataset = dataset.parsed.path
133 |
134 | if not os.path.exists(path_to_parsed_dataset):
135 | logger.error(f"Dir does not exist: {path_to_parsed_dataset}")
136 | exit(3)
137 | logger.info(f"Reading parsed files from: {path_to_parsed_dataset}")
138 |
139 | if dataset.prep_config.is_bpe():
140 | init_bpe_data(dataset.prep_config, custom_bpe_config)
141 |
142 | if not os.path.exists(dataset.path_to_nonbpe_vocab_file) and dataset.prep_config.is_base_bpe_config():
143 | path_to_part_metadata = f'{dataset.path_to_nonbpe_vocab_file}_part'
144 | else:
145 | path_to_part_metadata = None
146 | if path_to_part_metadata and not os.path.exists(path_to_part_metadata):
147 | os.makedirs(path_to_part_metadata)
148 |
149 | logger.info(f"Writing preprocessed files to {dataset.preprocessed.path}")
150 |
151 | if dataset.files_need_to_be_saved():
152 | files_total = 0
153 | for _ in dataset.get_all_files():
154 | files_total += 1
155 | print(f"Files scanned: {files_total}", end='\r')
156 | if files_total > LIMIT_FILES_SCANNING:
157 | files_total = None
158 | logger.info(f"Total files to be preprocessed: {LIMIT_FILES_SCANNING}+")
159 | break
160 | else:
161 |         files_total = sum(1 for _ in dataset.get_all_files())
162 | n_cpus = get_n_cpus_to_be_used()
163 | if n_cpus > 1:
164 | with Pool(processes=n_cpus) as pool:
165 | it = pool.imap_unordered(preprocess_and_write, params_generator(dataset, path_to_part_metadata), chunksize=CHUNKSIZE)
166 | for _ in tqdm(it, total=files_total):
167 | pass
168 | else:
169 | for params in tqdm(params_generator(dataset, path_to_part_metadata), total=files_total):
170 | preprocess_and_write(params, get_global_bpe_data_if_available())
171 |
172 | if path_to_part_metadata:
173 | vocabloader.gather_non_bpe_vocab(dataset)
174 |
175 | dataset.preprocessed.set_ready()
--------------------------------------------------------------------------------
/codeprep/pipeline/vocabloader.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import logging
6 | import os
7 | import shutil
8 | from typing import Set, Dict
9 |
10 | from codeprep.pipeline.bperegistry import get_bpe_dir, get_base_vocab_dir, RESULTING_VOCAB_FILE_NAME
11 | from codeprep.pipeline.dataset import Dataset, NONBPE_VOCAB_FILENAME
12 | from codeprep.preprocess.placeholders import placeholders
13 | from codeprep.util import to_literal_str
14 | from codeprep.pipeline.vocab import _load_vocab_dict, _load_vocab_set, VOCAB_FILENAME
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 | def all(merge_list_id: str) -> Dict[str, int]:
19 | bpe_dir = get_base_vocab_dir(merge_list_id)
20 | return _load_vocab_dict(os.path.join(bpe_dir, VOCAB_FILENAME))
21 |
22 |
23 | def nonbpe(merge_list_id: str) -> Set[str]:
24 | bpe_dir = get_base_vocab_dir(merge_list_id)
25 | return _load_vocab_set(os.path.join(bpe_dir, NONBPE_VOCAB_FILENAME))
26 |
27 |
28 | def base(merge_list_id: str) -> Dict[str, int]:
29 | all_vocab = all(merge_list_id)
30 | nonbpe_vocab = nonbpe(merge_list_id)
31 | for token in nonbpe_vocab:
32 | if token in all_vocab:
33 | del all_vocab[token]
34 | return all_vocab
35 |
36 |
37 | def bpe(merge_list_id: str, n_merges: int) -> Dict[str, int]:
38 | bpe_dir = get_bpe_dir(merge_list_id, n_merges)
39 | return _load_vocab_dict(os.path.join(bpe_dir, RESULTING_VOCAB_FILE_NAME))
40 |
41 |
42 | def gather_non_bpe_vocab(dataset: Dataset):
43 | logger.info("Gathering non-bpe vocab...")
44 | part_nonbpe_vocab_dir = f'{dataset.path_to_nonbpe_vocab_file}_part'
45 | non_bpe_tokens: Set[str] = set()
46 | for idx, file in enumerate(os.listdir(part_nonbpe_vocab_dir)):
47 | if idx % 569 == 0:
48 | print(f'Files processed: {idx}', end='\r')
49 | non_bpe_tokens.update(_load_vocab_set(os.path.join(part_nonbpe_vocab_dir, file)))
50 |
51 | non_bpe_tokens.update(list(placeholders.values()))
52 | with open(dataset.path_to_nonbpe_vocab_file, 'w') as f:
53 | for token in non_bpe_tokens:
54 | f.write(f'{to_literal_str(token)}\n')
55 | shutil.rmtree(part_nonbpe_vocab_dir)
--------------------------------------------------------------------------------
/codeprep/preprocess/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/codeprep/preprocess/core.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from typing import Tuple, List, Sequence
6 |
7 | from codeprep.preprocess.metadata import PreprocessingMetadata
8 | from codeprep.preprocess.reprconfig import ReprConfig
9 | from codeprep.tokens.rootclasses import ParsedToken
10 |
11 |
12 | def to_repr_list(token_list: Sequence[ParsedToken], repr_config: ReprConfig) \
13 | -> Tuple[List[str], PreprocessingMetadata]:
14 | repr_res = []
15 | all_metadata = PreprocessingMetadata()
16 | for token in token_list:
17 | repr_token, metadata = torepr(token, repr_config)
18 | repr_res.extend(repr_token)
19 | all_metadata.update(metadata)
20 | return repr_res, all_metadata
21 |
22 |
23 | def torepr(token, repr_config) -> Tuple[List[str], PreprocessingMetadata]:
24 | clazz = type(token)
25 | if clazz == str:
26 | raise AssertionError('Strings are not allowed any more as a result of parsing')
27 | if clazz == list:
28 | return to_repr_list(token, repr_config)
29 | if repr_config and clazz in repr_config.types_to_be_repr:
30 | return token.preprocessed_repr(repr_config)
31 | else:
32 | non_prep, metadata = token.non_preprocessed_repr(repr_config)
33 | return (non_prep if isinstance(non_prep, list) else [non_prep]), metadata
--------------------------------------------------------------------------------
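
The dispatch in torepr is type-driven: a token is fully preprocessed only when its class appears in ReprConfig.types_to_be_repr; otherwise its plain representation is kept. A stripped-down illustration of the same pattern (toy classes, not the repo's):

    class Num:
        def preprocessed_repr(self):
            return ['<num>']

        def non_preprocessed_repr(self):
            return ['42']


    def torepr_toy(token, types_to_be_repr):
        if type(token) in types_to_be_repr:
            return token.preprocessed_repr()
        return token.non_preprocessed_repr()


    print(torepr_toy(Num(), {Num}))   # ['<num>']
    print(torepr_toy(Num(), set()))   # ['42']
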
/codeprep/preprocess/metadata.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import logging
6 | from typing import Set, Optional, List, Type, Tuple
7 |
8 | from codeprep.subtokens import is_terminal_subtoken
9 | from codeprep.util import to_literal_str
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class InvalidMetadataError(Exception):
15 | pass
16 |
17 |
18 | class PreprocessingMetadata(object):
19 | def __init__(self,
20 | nonprocessable_tokens: Optional[Set[str]] = None,
21 | word_boundaries: Optional[List[int]] = None,
22 | token_types: List[Type] = None):
23 | self.nonprocessable_tokens = nonprocessable_tokens or set()
24 | self.word_boundaries = word_boundaries or [0]
25 | self.token_types = token_types or []
26 |
27 | self._check_invariants()
28 |
29 | def _check_invariants(self) -> None:
30 | assert len(self.word_boundaries) - 1 == len(self.token_types)
31 |
32 | def set_all_tokens_type(self, t: Type) -> None:
33 |         self.token_types = [t] * (len(self.word_boundaries) - 1)
34 |
35 | def update(self, preprocessing_metadata: 'PreprocessingMetadata') -> 'PreprocessingMetadata':
36 | """
37 | >>> class TypeA: pass
38 | >>> class TypeB: pass
39 | >>> PreprocessingMetadata().update(PreprocessingMetadata())
40 | (set(), [0], [])
41 |
42 | >>> PreprocessingMetadata({''}, [0, 2], [TypeA]).update(PreprocessingMetadata({''}, [0, 1, 2, 3], [TypeA, TypeA, TypeB]))
43 | ({''}, [0, 2, 3, 4, 5], ['TypeA', 'TypeA', 'TypeA', 'TypeB'])
44 |
45 | >>> PreprocessingMetadata(set(), [0, 2], [TypeA]).update(PreprocessingMetadata(set(), [0, 3], [TypeB]))
46 | (set(), [0, 2, 5], ['TypeA', 'TypeB'])
47 | """
48 | self.nonprocessable_tokens.update(preprocessing_metadata.nonprocessable_tokens)
49 |
50 | n_subtokens = self.word_boundaries.pop()
51 | for boundary in preprocessing_metadata.word_boundaries:
52 | self.word_boundaries.append(n_subtokens + boundary)
53 |
54 | self.token_types.extend(preprocessing_metadata.token_types)
55 |
56 | return self
57 |
58 | def __repr__(self):
59 | return str((self.nonprocessable_tokens, self.word_boundaries, list(map(lambda x: x.__name__, self.token_types))))
60 |
61 | def __eq__(self, other):
62 | return self.__class__ == other.__class__ \
63 | and self.nonprocessable_tokens == other.nonprocessable_tokens \
64 | and self.word_boundaries == other.word_boundaries \
65 | and self.token_types == other.token_types
66 |
67 |
68 | def save_metadata(metadata: PreprocessingMetadata, save_to: bytes) -> None:
69 | with open(save_to, 'w') as f:
70 | for token in metadata.nonprocessable_tokens:
71 | f.write(f'{to_literal_str(token)}\n')
72 |
73 |
74 | def check_metadata_validity(subwords: List[str], metadata: PreprocessingMetadata, use_only_token_end_chars=True) -> None:
75 | word_boundaries = metadata.word_boundaries
76 | if len(word_boundaries) == 0:
77 | raise ValueError("Word boundaries list should contain at least 0!")
78 | if len(subwords) != word_boundaries[-1]:
79 | raise ValueError(f"Word boundaries list should contain the indices of the last word.\n"
80 | f"However, the subword entropies list has {len(subwords)} elements, and "
81 | f"value {len(subwords)} is not found in word boundaries list: {word_boundaries}")
82 | if word_boundaries[0] != 0:
83 | raise ValueError('Word boundaries list must start with 0!')
84 |
85 | if use_only_token_end_chars:
86 | for idx, token in enumerate(subwords):
87 | end_according_to_data = is_terminal_subtoken(token)
88 | end_according_to_metadata = (idx + 1) in metadata.word_boundaries
89 | if end_according_to_data != end_according_to_metadata:
90 | error_context_start_index = idx - 20 if idx - 20 > 0 else 0
91 | error_context_end_index = idx + 20 if idx + 20 < len(subwords) else len(subwords) - 1
92 |                 raise AssertionError(f'Token {token} according to metadata is'
93 |                                      f'{"" if end_according_to_metadata else " NOT"} an end-token. '
94 |                                      f'Showing context: {subwords[error_context_start_index:error_context_end_index]}')
95 |
96 |
97 | def with_empty_metadata(tokens: List[str]) -> Tuple[List[str], PreprocessingMetadata]:
98 | return tokens, PreprocessingMetadata()
99 |
100 |
101 | def unwrap_single_string(tokens_and_metadata: Tuple[List[str], PreprocessingMetadata]) -> str:
102 | tokens = tokens_and_metadata[0]
103 | if isinstance(tokens, list) and len(tokens) == 1:
104 | return tokens[0]
--------------------------------------------------------------------------------
/codeprep/preprocess/placeholders.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | placeholders = {
6 | 'comment': '',
7 | 'string_literal': '',
8 | 'word_start': '',
9 | 'word_end': '',
10 | 'capital': '',
11 | 'capitals': '',
12 | 'ect': '',
13 | 'non_eng': '',
14 | 'non_ascii_seq': '\xf7',
15 | 'non_eng_content': '',
16 | 'olc_end': '',
17 | 'compound_word_end': '',
18 | 'space_in_str': '\xa0'
19 | }
--------------------------------------------------------------------------------
/codeprep/preprocess/reprconfig.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from typing import Optional, Callable, List
6 |
7 | from codeprep.bpepkg.bpe_encode import BpeData
8 |
9 |
10 | Splitter = Callable[[str, BpeData], List[str]]
11 |
12 |
13 | class ReprConfig(object):
14 | def __init__(self, types_to_be_repr,
15 | bpe_data: Optional[BpeData],
16 | should_lowercase: bool,
17 | number_splitter: Splitter,
18 | word_splitter: Optional[Splitter],
19 | full_strings: bool,
20 | max_str_length: int):
21 | self.types_to_be_repr = types_to_be_repr
22 | self.bpe_data = bpe_data
23 | self.should_lowercase = should_lowercase
24 | self.number_splitter = number_splitter
25 | self.word_splitter = word_splitter
26 | self.full_strings = full_strings
27 | self.max_str_length = max_str_length
--------------------------------------------------------------------------------
/codeprep/stemming.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from nltk.stem import PorterStemmer
6 |
7 | stemmer = PorterStemmer()
8 |
9 |
10 | def stem(word: str):
11 | if not word:
12 | return word
13 |
14 | stemmed = stemmer.stem(word)
15 | if word.isupper():
16 | return stemmed.upper()
17 | elif word[0].isupper():
18 | return stemmed.capitalize()
19 | else:
20 | return stemmed
--------------------------------------------------------------------------------
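
A usage sketch: the wrapper restores the original casing pattern after stemming (NLTK's PorterStemmer lowercases its input by default), e.g.:

    from codeprep.stemming import stem

    print(stem('running'))  # 'run'
    print(stem('Running'))  # 'Run'
    print(stem('RUNNING'))  # 'RUN'
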
/codeprep/subtokens.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from typing import List, Callable, Any
6 |
7 | from codeprep.preprocess.placeholders import placeholders
8 |
9 |
10 | class TokenIterator(object):
11 | def __init__(self, subwords, word_boundaries, format, return_full_token_index):
12 | self.validate_word_boundaries(subwords, word_boundaries)
13 |
14 | self.subwords = subwords
15 | self.word_boundaries = word_boundaries
16 | self.format = format
17 | self.return_full_token_index = return_full_token_index
18 |
19 | def __iter__(self):
20 | return self
21 |
22 | @staticmethod
23 | def validate_word_boundaries(subwords: List[str], word_boundaries: List[int]) -> None:
24 | if len(word_boundaries) == 0:
25 | raise ValueError("Word boundaries list should contain at least 0!")
26 | if len(subwords) != word_boundaries[-1]:
27 | raise ValueError(f"Word boundaries list should contain the indices of the last word.\n"
28 | f"However, the subword entropies list has {len(subwords)} elements, and "
29 | f"value {len(subwords)} is not found in word boundaries list: {word_boundaries}")
30 | if word_boundaries[0] != 0:
31 | raise ValueError('Word boundaries list must start with 0!')
32 |
33 |
34 | class SubtokenIterator(TokenIterator):
35 | """
36 | >>> [token for token in SubtokenIterator(['hi', 'the', 're'], [0, 1, 3])]
37 | ['hi', 'the', 're']
38 |
39 | >>> [token for token in SubtokenIterator([1, 2, 3], [0, 1, 3], format=lambda s: str(s[0]))]
40 | ['1', '2', '3']
41 |
42 | >>> [token for token in SubtokenIterator(['hi', 'the', 're'], [0, 1, 3], return_full_token_index=True)]
43 | [(0, 'hi'), (1, 'the'), (1, 're')]
44 |
45 | >>> [token for token in SubtokenIterator(['hi'], [0])]
46 | Traceback (most recent call last):
47 | ...
48 | ValueError: Word boundaries list should contain the indices of the last word.
49 | However, the subword entropies list has 1 elements, and value 1 is not found in word boundaries list: [0]
50 |
51 | >>> [token for token in SubtokenIterator(['hi'], [1])]
52 | Traceback (most recent call last):
53 | ...
54 | ValueError: Word boundaries list must start with 0!
55 | """
56 | def __init__(self, subwords: List[Any],
57 | word_boundaries: List[int],
58 | format: Callable[[List[str]], Any] = lambda l: l[0],
59 | return_full_token_index: bool = False):
60 |
61 | super().__init__(subwords, word_boundaries, format, return_full_token_index)
62 |
63 | self.current_index = 0
64 | self.current_full_word = 0
65 |
66 | def __next__(self):
67 | if self.current_index >= len(self.subwords):
68 | raise StopIteration
69 |
70 | value = [self.subwords[self.current_index]]
71 | formatted_value = self.format(value)
72 | result = (self.current_full_word, formatted_value) if self.return_full_token_index else formatted_value
73 |
74 | self.current_index += 1
75 | if self.word_boundaries[self.current_full_word + 1] == self.current_index:
76 | self.current_full_word += 1
77 |
78 | return result
79 |
80 |
81 | class FullTokenIterator(TokenIterator):
82 | """
83 | >>> [token for token in FullTokenIterator(['hi', 'the', 're'], [0, 1, 3])]
84 | ['hi', 'there']
85 |
86 | >>> [token for token in FullTokenIterator(['hel', 'l', 'o'], [0, 3])]
87 | ['hello']
88 |
89 | >>> [token for token in FullTokenIterator([1, 2, 4], [0, 2, 3], format=sum)]
90 | [3, 4]
91 |
92 | >>> [token for token in FullTokenIterator(['hi', 'the', 're'], [0, 1, 3], return_full_token_index=True)]
93 | [(0, 'hi'), (1, 'there')]
94 |
95 | >>> [token for token in FullTokenIterator([], [])]
96 | Traceback (most recent call last):
97 | ...
98 | ValueError: Word boundaries list should contain at least 0!
99 |
100 | >>> [token for token in FullTokenIterator(['hi'], [0])]
101 | Traceback (most recent call last):
102 | ...
103 | ValueError: Word boundaries list should contain the indices of the last word.
104 | However, the subword entropies list has 1 elements, and value 1 is not found in word boundaries list: [0]
105 |
106 | >>> [token for token in FullTokenIterator(['hi'], [1])]
107 | Traceback (most recent call last):
108 | ...
109 | ValueError: Word boundaries list must start with 0!
110 | """
111 | def __init__(self, subwords: List[Any],
112 | word_boundaries: List[int],
113 | format: Callable[[List[str]], Any] = lambda s: ''.join(s),
114 | return_full_token_index: bool = False):
115 | super().__init__(subwords, word_boundaries, format, return_full_token_index)
116 |
117 | self.current_full_word = 0
118 |
119 | def __next__(self):
120 | if self.current_full_word >= len(self.word_boundaries) - 1:
121 | raise StopIteration
122 |
123 | word_start = self.word_boundaries[self.current_full_word]
124 | word_end = self.word_boundaries[self.current_full_word + 1]
125 | formatted_value = self.format(self.subwords[word_start:word_end])
126 | result = (self.current_full_word, formatted_value) if self.return_full_token_index else formatted_value
127 |
128 | self.current_full_word += 1
129 |
130 | return result
131 |
132 |
133 | def is_terminal_subtoken(subtoken: str, use_token_end_chars: bool = True) -> bool:
134 | if not use_token_end_chars:
135 |         raise NotImplementedError("Finding out if a subtoken is terminal for tokens represented with "
136 |                                   "word-start and word-end tokens is not yet implemented.")
137 |
138 | return subtoken.endswith(placeholders['compound_word_end'])
--------------------------------------------------------------------------------
/codeprep/tokens/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import importlib
6 | from os.path import dirname, basename, isfile, join
7 | import glob
8 |
9 | modules = glob.glob(join(dirname(__file__), "*.py"))
10 | for f in modules:
11 | if isfile(f) and not f.endswith('__init__.py'):
12 | importlib.import_module(f'{__package__}.{basename(f)[:-3]}')
--------------------------------------------------------------------------------
/codeprep/tokens/containers.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from typing import List, Tuple, Union, Optional
6 |
7 | from codeprep.noneng import replace_non_ascii_seqs
8 | from codeprep.preprocess.core import ReprConfig, torepr
9 | from codeprep.preprocess.metadata import PreprocessingMetadata
10 | from codeprep.preprocess.placeholders import placeholders
11 | from codeprep.tokens.rootclasses import ParsedToken, ParsedSubtoken
12 | from codeprep.tokens.whitespace import SpaceInString
13 | from codeprep.tokens.word import Word
14 |
15 |
16 | class ProcessableTokenContainer(ParsedToken):
17 | def __init__(self, subtokens: Union[List[ParsedSubtoken], List[ParsedToken]]):
18 | if isinstance(subtokens, list):
19 | self.subtokens = subtokens
20 | else:
21 | raise AssertionError(f"Should be list but is: {subtokens}")
22 |
23 | def add(self, token):
24 | self.subtokens.append(token)
25 |
26 | def get_subtokens(self):
27 | return self.subtokens
28 |
29 | def __eq__(self, other):
30 | return self.__class__ == other.__class__ and self.subtokens == other.subtokens
31 |
32 | def __repr__(self):
33 | return f'{self.__class__.__name__}{self.subtokens}'
34 |
35 | def __str__(self):
36 | return " ".join(self.non_preprocessed_repr()[0])
37 |
38 |
39 | def wrap_in_word_boundaries_if_necessary(res: List[str]) -> List[str]:
40 | if len(res) == 1 or (len(res) == 2 and res[0] in [placeholders['capitals'], placeholders['capital']]):
41 | return res
42 | else:
43 | return [placeholders['word_start']] + res + [placeholders['word_end']]
44 |
45 |
46 | class SplitContainer(ProcessableTokenContainer):
47 | def __init__(self, subtokens: List[ParsedSubtoken]):
48 | super().__init__(subtokens)
49 |
50 | def empty_repr(self):
51 | return self.subtokens
52 |
53 | def __repr__(self):
54 | return f'{self.__class__.__name__}{self.subtokens}'
55 |
56 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]:
57 | nospl_str = ["".join(map(lambda s: torepr(s, repr_config)[0][0], self.subtokens))]
58 | return self.wrap_in_metadata_for_full_word(nospl_str)
59 |
60 | def preprocessed_repr(self, repr_config) -> Tuple[List[str], PreprocessingMetadata]:
61 | if repr_config.bpe_data:
62 | return self.wrap_in_metadata_for_full_word(repr_config.word_splitter(str(self), repr_config.bpe_data))
63 | res = []
64 | all_metadata = PreprocessingMetadata()
65 | for subtoken in self.subtokens:
66 | r, metadata = torepr(subtoken, repr_config)
67 | res.extend(r if isinstance(r, list) else [r])
68 | all_metadata.update(metadata)
69 | return self.wrap_in_metadata_for_full_word(wrap_in_word_boundaries_if_necessary(res), all_metadata.nonprocessable_tokens)
70 |
71 | @classmethod
72 | def from_single_token(cls, token: str):
73 | return cls([Word.from_(token)])
74 |
75 |
76 | class TextContainer(ProcessableTokenContainer):
77 |
78 | def __init__(self, tokens: List[ParsedToken]):
79 | super().__init__(tokens)
80 | for token in tokens:
81 | if isinstance(token, ParsedSubtoken):
82 | raise TypeError(
83 | f"ParsedSubtokens cannot be part of a TextContainer, but one of the tokens passed was {type(token)} ({token})")
84 |
85 | def __repr__(self):
86 | return f'{self.__class__.__name__}{self.subtokens}'
87 |
88 | def __eq__(self, other):
89 | return self.__class__ == other.__class__ and self.subtokens == other.subtokens
90 |
91 |
92 | class Comment(TextContainer):
93 | def __init__(self, tokens: List[ParsedToken]):
94 | super().__init__(tokens)
95 |
96 | def preprocessed_repr(self, repr_config: ReprConfig) -> Tuple[List[str], PreprocessingMetadata]:
97 | return self.wrap_in_metadata_for_full_word([placeholders['comment']])
98 |
99 |
100 | class OneLineComment(Comment):
101 | def __init__(self, tokens: List[ParsedToken]):
102 | super().__init__(tokens)
103 |
104 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]:
105 | prep_tokens, metadata = torepr(self.subtokens, repr_config)
106 | metadata.update(PreprocessingMetadata(word_boundaries=[0, 1], token_types=[OneLineComment]))
107 | metadata.set_all_tokens_type(OneLineComment)
108 | return prep_tokens + [placeholders['olc_end']], metadata
109 |
110 |
111 | class MultilineComment(Comment):
112 | def __init__(self, tokens: List[ParsedToken]):
113 | super().__init__(tokens)
114 |
115 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]:
116 | prep_tokens, metadata = torepr(self.subtokens, repr_config)
117 | metadata.set_all_tokens_type(MultilineComment)
118 | return prep_tokens, metadata
119 |
120 |
121 | class StringLiteral(TextContainer):
122 | def __init__(self, tokens: List[ParsedToken], length: int):
123 | super().__init__(tokens)
124 | self.length = length
125 |
126 | def _replace_non_ascii_seqs_if_necessary(self, repr_config: ReprConfig) -> str:
127 | s = str(self)
128 | if 'NonEng' in list(map(lambda x: x.__name__, repr_config.types_to_be_repr)):
129 | s = placeholders["space_in_str"].join(map(lambda t: replace_non_ascii_seqs(t, placeholders['non_ascii_seq']), s.split(placeholders["space_in_str"])))
130 | return s
131 |
132 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]:
133 | if not repr_config:  # called by str()
134 | return self.wrap_in_metadata_for_full_word(["".join(map(lambda t: str(t), self.subtokens))])
135 | elif self.length > repr_config.max_str_length:
136 | s = ['""'] if repr_config.full_strings else ['"', '"']
137 | non_processable_tokens = set() if repr_config.full_strings else {'"'}
138 | return self.wrap_in_metadata_for_full_word(s, non_processable_tokens)
139 | elif repr_config.bpe_data:
140 | s = self._replace_non_ascii_seqs_if_necessary(repr_config)
141 | return self.wrap_in_metadata_for_full_word(repr_config.word_splitter(s, repr_config.bpe_data))
142 | elif repr_config.full_strings:
143 | s = self._replace_non_ascii_seqs_if_necessary(repr_config)
144 | return self.wrap_in_metadata_for_full_word([s])
145 | else:  # here we don't keep full strings
146 | tokens, metadata = torepr(list(filter(lambda t: type(t) != SpaceInString, self.subtokens)), repr_config)
147 | metadata.set_all_tokens_type(StringLiteral)
148 | return tokens, metadata
149 |
150 | def preprocessed_repr(self, repr_config: ReprConfig) -> Tuple[List[str], PreprocessingMetadata]:
151 | return self.wrap_in_metadata_for_full_word([placeholders['string_literal']])
152 |
153 | def __eq__(self, other):
154 | return super().__eq__(other) and self.length == other.length
155 |
156 | def __repr__(self):
157 | return super().__repr__() + f', length: {self.length}'
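A usage sketch for SplitContainer (torepr, defined in codeprep/preprocess/core.py, dispatches between the two repr methods depending on which token types the ReprConfig selects for preprocessing; placeholder strings are again assumed for illustration):

    sc = SplitContainer([Word.from_('get'), Word.from_('Value')])
    str(sc)   # 'getValue' — non-preprocessed: subtokens glued back together
    # preprocessed with lowercasing and no BPE data, the output would be the
    # subwords wrapped in word boundaries: ['<w>', 'get', '<Cap>', 'value', '</w>']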
--------------------------------------------------------------------------------
/codeprep/tokens/noneng.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from typing import List, Tuple, Optional
6 |
7 | from codeprep.noneng import replace_non_ascii_seqs
8 | from codeprep.preprocess.core import ReprConfig, torepr
9 | from codeprep.preprocess.metadata import PreprocessingMetadata
10 | from codeprep.preprocess.placeholders import placeholders
11 | from codeprep.tokens.containers import SplitContainer
12 | from codeprep.tokens.rootclasses import ParsedToken
13 |
14 |
15 | class NonEng(ParsedToken):
16 | def __init__(self, processable_token: SplitContainer):
17 | if not isinstance(processable_token, SplitContainer):
18 | raise ValueError(f"Only SplitContainer can be wrapped in {self.__class__}. Type passed: {type(processable_token)}")
19 |
20 | self.processable_token = processable_token
21 |
22 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]:
23 | return torepr(self.processable_token, repr_config)
24 |
25 | def preprocessed_repr(self, repr_config: ReprConfig) -> Tuple[List[str], PreprocessingMetadata]:
26 | if repr_config.bpe_data:
27 | token = replace_non_ascii_seqs(str(self.processable_token), placeholders['non_ascii_seq'])
28 | return torepr(SplitContainer.from_single_token(token), repr_config)
29 | else:
30 | return self.wrap_in_metadata_for_full_word([placeholders['non_eng']])
31 |
32 | def __repr__(self):
33 | return f'{self.__class__.__name__}({self.processable_token.__repr__()})'
34 |
35 | def __str__(self):
36 | return str(self.processable_token)
37 |
38 | def __eq__(self, other):
39 | return self.__class__ == other.__class__ and self.processable_token == other.processable_token
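A behaviour sketch for NonEng (the placeholder value shown is an assumption for illustration):

    ne = NonEng(SplitContainer.from_single_token('übersetzen'))
    # non-preprocessed: delegates to the wrapped SplitContainer unchanged
    # preprocessed, no BPE data: collapses to [placeholders['non_eng']], e.g. ['<non-eng>']
    # preprocessed, with BPE data: non-ASCII sequences are masked first, then BPE-split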
--------------------------------------------------------------------------------
/codeprep/tokens/numeric.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from typing import List, Tuple, Optional
6 |
7 | from codeprep.preprocess.core import ReprConfig
8 | from codeprep.preprocess.metadata import PreprocessingMetadata, unwrap_single_string
9 | from codeprep.preprocess.placeholders import placeholders
10 | from codeprep.tokens.rootclasses import ParsedToken
11 |
12 |
13 | class Number(ParsedToken):
14 | def __init__(self, val: str):
15 | self.val = val.lower()
16 |
17 | def __str__(self):
18 | return unwrap_single_string(self.non_preprocessed_repr())
19 |
20 | def __repr__(self):
21 | return f'<{self.__class__.__name__}>({self.val})'
22 |
23 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]:
24 | return self.wrap_in_metadata_for_full_word([self.val])
25 |
26 | def preprocessed_repr(self, repr_config: ReprConfig) -> Tuple[List[str], PreprocessingMetadata]:
27 | subwords = repr_config.number_splitter(self.non_preprocessed_repr()[0][0], repr_config.bpe_data)
28 |
29 | if len(subwords) > 1 and not repr_config.bpe_data:
30 | prep_number = [placeholders['word_start']] + subwords + [placeholders['word_end']]
31 | else:
32 | prep_number = subwords
33 |
34 | return self.wrap_in_metadata_for_full_word(prep_number)
35 |
36 | def __eq__(self, other):
37 | return self.__class__ == other.__class__ and self.val == other.val
38 |
39 |
40 | class One(Number):
41 | def __init__(self):
42 | super().__init__('1')
43 |
44 |
45 | class Zero(Number):
46 | def __init__(self):
47 | super().__init__('0')
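A sketch of the two branches of Number.preprocessed_repr (the subword split shown is an assumed output of ReprConfig.number_splitter; '<w>'/'</w>' stand in for the real word-boundary placeholders):

    # number_splitter('0.5', None) == ['0', '.', '5']   (assumption)
    # non-BPE, several subwords:           ['<w>', '0', '.', '5', '</w>']
    # BPE data set, or a single subword:   the subwords are emitted as-is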
--------------------------------------------------------------------------------
/codeprep/tokens/rootclasses.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from typing import List, Tuple, Set, Optional
6 |
7 | from codeprep.preprocess.metadata import PreprocessingMetadata
8 |
9 |
10 | class ParsedToken(object):
11 | def wrap_in_metadata_for_full_word(self, tokens: List[str], non_proc: Optional[Set[str]] = None) \
12 | -> Tuple[List[str], PreprocessingMetadata]:
13 | assert type(tokens) == list
14 |
15 | metadata = PreprocessingMetadata()
16 | metadata.nonprocessable_tokens = non_proc or set()
17 | metadata.word_boundaries = [0, len(tokens)]
18 | metadata.token_types = [type(self)]
19 | return tokens, metadata
20 |
21 |
22 | class ParsedSubtoken(object):
23 | pass
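What wrap_in_metadata_for_full_word records, read directly off the code above (AnyToken stands for an arbitrary ParsedToken subclass):

    tokens, metadata = AnyToken().wrap_in_metadata_for_full_word(['a', 'b'])
    # tokens                    == ['a', 'b']
    # metadata.word_boundaries  == [0, 2]      — both tokens form one full word
    # metadata.token_types      == [AnyToken]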
--------------------------------------------------------------------------------
/codeprep/tokens/whitespace.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from typing import List, Tuple, Optional
6 |
7 | from codeprep.preprocess.core import ReprConfig
8 | from codeprep.preprocess.metadata import PreprocessingMetadata, unwrap_single_string
9 | from codeprep.preprocess.placeholders import placeholders
10 | from codeprep.tokens.rootclasses import ParsedToken
11 |
12 | NBSP = '\xa0'
13 |
14 |
15 | class Whitespace(ParsedToken):
16 | def __eq__(self, other):
17 | return other.__class__ == self.__class__
18 |
19 | def __repr__(self):
20 | return f'<{self.__class__.__name__}>'
21 |
22 | def __str__(self):
23 | return unwrap_single_string(self.non_preprocessed_repr())
24 |
25 |
26 | class NewLine(Whitespace):
27 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]:
28 | return self.wrap_in_metadata_for_full_word(["\n"], non_proc={"\n"})
29 |
30 | def preprocessed_repr(self, repr_config: ReprConfig) -> Tuple[List[str], PreprocessingMetadata]:
31 | return [], PreprocessingMetadata()
32 |
33 |
34 | class Tab(Whitespace):
35 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]:
36 | return self.wrap_in_metadata_for_full_word(["\t"], non_proc={"\t"})
37 |
38 | def preprocessed_repr(self, repr_config: ReprConfig) -> Tuple[List[str], PreprocessingMetadata]:
39 | return [], PreprocessingMetadata()
40 |
41 |
42 | class SpaceInString(Whitespace):
43 |
44 | def __init__(self, n_chars: int = 1):
45 | super().__init__()
46 | self.n_chars = n_chars
47 |
48 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]:
49 | return self.wrap_in_metadata_for_full_word([placeholders['space_in_str'] * self.n_chars])
50 |
51 | def __repr__(self):
52 | return f'<{self.__class__.__name__}> (n_chars={self.n_chars})'
53 |
54 | def __eq__(self, other):
55 | return other.__class__ == self.__class__ and other.n_chars == self.n_chars
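Note the asymmetry for NewLine and Tab: they survive only the non-preprocessed representation; once preprocessed they disappear entirely (cfg below is any ReprConfig instance):

    NewLine().non_preprocessed_repr()    # (['\n'], metadata marking '\n' non-processable)
    NewLine().preprocessed_repr(cfg)     # ([], PreprocessingMetadata()) — dropped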
--------------------------------------------------------------------------------
/codeprep/tokens/word.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from enum import Enum
6 | from typing import List, Tuple, Optional
7 |
8 | from codeprep.preprocess.core import ReprConfig
9 | from codeprep.preprocess.metadata import PreprocessingMetadata, with_empty_metadata, unwrap_single_string
10 | from codeprep.preprocess.placeholders import placeholders
11 | from codeprep.tokens.rootclasses import ParsedSubtoken, ParsedToken
12 |
13 |
14 | class Underscore(ParsedSubtoken):
15 | def __eq__(self, other):
16 | return other.__class__ == self.__class__
17 |
18 | def __repr__(self):
19 | return f'<{self.__class__.__name__}>'
20 |
21 | def __str__(self):
22 | return self.non_preprocessed_repr()[0]
23 |
24 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[str, PreprocessingMetadata]:
25 | return with_empty_metadata("_")
26 |
27 |
28 | class Word(ParsedSubtoken):
29 | """
30 | Invariant:
31 | s == str(Word.from_(s)) for any non-empty string s
32 | """
33 |
34 | class Capitalization(str, Enum):
35 | UNDEFINED = 'undefined'
36 | NONE = 'none'
37 | FIRST_LETTER = 'first_letter'
38 | ALL = 'all'
39 |
40 | def __repr__(self):
41 | return self.value
42 |
43 | def __init__(self, canonic_form: str, capitalization: Capitalization = Capitalization.UNDEFINED):
44 | Word._check_canonic_form_is_valid(canonic_form)
45 |
46 | self.canonic_form = canonic_form
47 | self.capitalization = capitalization
48 |
49 | def get_canonic_form(self) -> str:
50 | return self.canonic_form
51 |
52 | @staticmethod
53 | def _is_strictly_upper(s) -> bool:
54 | return s.isupper() and not s.lower().isupper()
55 |
56 | @staticmethod
57 | def _check_canonic_form_is_valid(canonic_form) -> None:
58 | if not isinstance(canonic_form, str) or Word._is_strictly_upper(canonic_form) \
59 | or (canonic_form and Word._is_strictly_upper(canonic_form[0])):
60 | raise AssertionError(f"Bad canonic form: {canonic_form}")
61 |
62 | def __str__(self):
63 | return unwrap_single_string(self.non_preprocessed_repr())
64 |
65 | def __with_capitalization_prefixes(self, subwords: List[str]) -> List[str]:
66 | if self.capitalization == Word.Capitalization.UNDEFINED or self.capitalization == Word.Capitalization.NONE:
67 | res = subwords
68 | elif self.capitalization == Word.Capitalization.FIRST_LETTER:
69 | res = [placeholders['capital']] + subwords
70 | elif self.capitalization == Word.Capitalization.ALL:
71 | res = [placeholders['capitals']] + subwords
72 | else:
73 | raise AssertionError(f"Unknown value: {self.capitalization}")
74 | return res
75 |
76 | def preprocessed_repr(self, repr_config: ReprConfig) -> Tuple[List[str], PreprocessingMetadata]:
77 | if repr_config.should_lowercase:
78 | subwords = repr_config.word_splitter(self.canonic_form, repr_config.bpe_data)
79 | subwords_with_prefix = self.__with_capitalization_prefixes(subwords)
80 | return with_empty_metadata(subwords_with_prefix)
81 | else:
82 | subwords = repr_config.word_splitter(self.__with_preserved_case(), repr_config.bpe_data)
83 | return with_empty_metadata(subwords)
84 |
85 | def __with_preserved_case(self) -> str:
86 | if self.capitalization == Word.Capitalization.UNDEFINED or self.capitalization == Word.Capitalization.NONE:
87 | return self.canonic_form
88 | elif self.capitalization == Word.Capitalization.FIRST_LETTER:
89 | return self.canonic_form.capitalize()
90 | elif self.capitalization == Word.Capitalization.ALL:
91 | return self.canonic_form.upper()
92 | else:
93 | raise AssertionError(f"Unknown value: {self.capitalization}")
94 |
95 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]:
96 | return with_empty_metadata([self.__with_preserved_case()])
97 |
98 | def __repr__(self):
99 | return f'{self.__class__.__name__}({self.canonic_form, self.capitalization})'
100 |
101 | def __eq__(self, other):
102 | return self.__class__ == other.__class__ and self.canonic_form == other.canonic_form \
103 | and self.capitalization == other.capitalization
104 |
105 | @classmethod
106 | def from_(cls, s: str):
107 | if not s:
108 | raise ValueError(f'A subword can be neither None nor of length zero. Value of the subword is {s}')
109 |
110 | if s.islower() or not s:
111 | return cls(s, Word.Capitalization.NONE)
112 | elif s.isupper():
113 | return cls(s.lower(), Word.Capitalization.ALL)
114 | elif s[0].isupper():
115 | return cls(s[0].lower() + s[1:], Word.Capitalization.FIRST_LETTER)
116 | else:
117 | return cls(s, Word.Capitalization.UNDEFINED)
118 |
119 |
120 | class NonProcessibleToken(ParsedToken):
121 | def __init__(self, token: str):
122 | self.token = token
123 |
124 | def __eq__(self, other):
125 | return self.__class__ == other.__class__ and self.token == other.token
126 |
127 | def __repr__(self):
128 | return f'{self.__class__.__name__}({self.token})'
129 |
130 | def __str__(self):
131 | return self.token
132 |
133 | def non_preprocessed_repr(self, repr_config: Optional[ReprConfig] = None) -> Tuple[List[str], PreprocessingMetadata]:
134 | return self.wrap_in_metadata_for_full_word([self.token], non_proc={self.token})
135 |
136 |
137 | class KeyWord(NonProcessibleToken):
138 | def __init__(self, token: str):
139 | super().__init__(token)
140 |
141 |
142 | class Operator(NonProcessibleToken):
143 | def __init__(self, token: str):
144 | super().__init__(token)
145 |
146 |
147 | class Semicolon(Operator):
148 | def __init__(self):
149 | super().__init__(';')
150 |
151 |
152 | class OpeningCurlyBracket(Operator):
153 | def __init__(self):
154 | super().__init__('{')
155 |
156 |
157 | class ClosingCurlyBracket(Operator):
158 | def __init__(self):
159 | super().__init__('}')
160 |
161 |
162 | class OpeningBracket(Operator):
163 | def __init__(self):
164 | super().__init__('(')
165 |
166 |
167 | class ClosingBracket(Operator):
168 | def __init__(self):
169 | super().__init__(')')
170 |
171 |
172 | class NonCodeChar(NonProcessibleToken):
173 | def __init__(self, token: str):
174 | super().__init__(token)
175 |
176 |
177 | class SpecialToken(NonProcessibleToken):
178 | def __init__(self, token: str):
179 | super().__init__(token)
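Round-trip examples for Word.from_, matching the invariant in the Word docstring:

    str(Word.from_('value'))   # 'value' — Capitalization.NONE
    str(Word.from_('Value'))   # 'Value' — stored as 'value' + FIRST_LETTER
    str(Word.from_('VALUE'))   # 'VALUE' — stored as 'value' + ALL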
--------------------------------------------------------------------------------
/codeprep/util.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import multiprocessing
6 | from heapq import heappush, heappop, heapify
7 |
8 | import itertools
9 | from typing import Dict, Tuple, List, Optional, Generator
10 |
11 |
12 | def merge_dicts_(dict1, dict2) -> Tuple[Dict, List]:
13 | """
14 | Merges `dict2` into `dict1` in place: counts for keys present in both are summed; keys new to `dict1` are added and also returned as the list of new words.
15 |
16 | >>> dict1 = {"a": 3, "b": 4}
17 | >>> dict2 = {"b": 5, "c": 6}
18 | >>> merge_dicts_(dict1, dict2)
19 | ({'a': 3, 'b': 9, 'c': 6}, ['c'])
20 | >>> dict1
21 | {'a': 3, 'b': 9, 'c': 6}
22 |
23 | """
24 | new_words = []
25 | for k, v in dict2.items():
26 | if k not in dict1:
27 | dict1[k] = v
28 | new_words.append(k)
29 | else:
30 | dict1[k] = dict1[k] + v
31 | return dict1, new_words
32 |
33 |
34 | class NonAtomicCounter(object):
35 | """
36 | >>> counter = NonAtomicCounter(10)
37 | >>> counter.inc()
38 | 11
39 | >>> counter.dec()
40 | 10
41 | >>> counter.compare_and_dec(10)
42 | True
43 | >>> counter.value
44 | 9
45 | >>> counter.value = 20
46 | >>> counter.get_and_dec()
47 | 20
48 | >>> counter.value
49 | 19
50 | """
51 | def __init__(self, v: int=0):
52 | self.val = v
53 |
54 | def inc(self) -> int:
55 | self.val += 1
56 | return self.val
57 |
58 | def dec(self) -> int:
59 | self.val -= 1
60 | return self.val
61 |
62 | def compare_and_dec(self, v: int) -> bool:
63 | result = (self.val == v)
64 | self.dec()
65 | return result
66 |
67 | def get_and_dec(self) -> int:
68 | result = self.val
69 | self.dec()
70 | return result
71 |
72 | @property
73 | def value(self) -> int:
74 | return self.val
75 |
76 | @value.setter
77 | def value(self, v: int) -> None:
78 | self.val = v
79 |
80 |
81 | class AtomicInteger(object):
82 | def __init__(self, v=0):
83 | self._lock = multiprocessing.Lock()
84 | self._queue = multiprocessing.Queue()
85 | for i in range(v):
86 | self._queue.put(1)
87 |
88 | def inc(self):
89 | with self._lock:
90 | self._queue.put(1)
91 | return self._queue.qsize()
92 |
93 | def dec(self):
94 | with self._lock:
95 | self._queue.get()
96 | return self._queue.qsize()
97 |
98 | def compare_and_dec(self, val):
99 | with self._lock:
100 | result = self._queue.qsize() == val
101 | self._queue.get()
102 | return result
103 |
104 | def get_and_dec(self):
105 | with self._lock:
106 | result = self._queue.qsize()
107 | self._queue.get()
108 | return result
109 |
110 | @property
111 | def value(self):
112 | with self._lock:
113 | return self._queue.qsize()
114 |
115 | @value.setter
116 | def value(self, v):
117 | with self._lock:
118 | self._queue = multiprocessing.Queue()
119 | for i in range(v):
120 | self._queue.put(1)
121 |
122 |
123 | class PriorityCounter(object):
124 | REMOVED = '' # placeholder for a removed task
125 |
126 | def __init__(self, d: Dict, automatic_count: bool=True):
127 | self.counter = itertools.count() if automatic_count else None
128 | self.pq = [[(-value, next(self.counter)) if self.counter else (-value[0], value[1]), key] for key, value in d.items()] # list of entries arranged in a heap
129 | heapify(self.pq)
130 | self.entry_finder = {entry[1]: entry for entry in self.pq} # mapping of tasks to entries
131 |
132 | def add(self, pair, to_add: int, c: Optional[int]=None):
133 | 'Add a new task or update the priority of an existing task'
134 | if (self.counter is None) == (c is None):
135 | raise ValueError("Either counter should be set, or count argument should be passed!")
136 | count = next(self.counter) if self.counter else c
137 | to_add = -to_add
138 | if pair in self.entry_finder:
139 | entry = self.entry_finder[pair]
140 | to_add = entry[0][0] + to_add
141 | self.remove_task(pair)
142 | if to_add != 0:
143 | entry = [(to_add, count), pair]
144 | self.entry_finder[pair] = entry
145 | heappush(self.pq, entry)
146 |
147 | def remove_task(self, task):
148 | 'Mark an existing task as REMOVED. Raise KeyError if not found.'
149 | entry = self.entry_finder.pop(task)
150 | entry[-1] = PriorityCounter.REMOVED
151 |
152 | def pop_pair(self):
153 | 'Remove and return the lowest priority task. Raise KeyError if empty.'
154 | while self.pq:
155 | (priority, count), pair = heappop(self.pq)
156 | if pair is not PriorityCounter.REMOVED:
157 | del self.entry_finder[pair]
158 | return pair, -priority
159 | raise KeyError('pop from an empty priority queue')
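A doctest-style sketch of PriorityCounter in its automatic-count mode; the behaviour follows from the code above (the pair with the highest count pops first, and add() merges counts for an existing pair):

    >>> pc = PriorityCounter({'ab': 3, 'cd': 5})
    >>> pc.pop_pair()
    ('cd', 5)
    >>> pc.add('ab', 2)
    >>> pc.pop_pair()
    ('ab', 5)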
160 |
161 |
162 | import sys
163 | from numbers import Number
164 | from collections import Set, Mapping, deque  # Set/Mapping live in collections.abc; removed from collections in Python 3.10
165 |
166 |
167 | # From https://stackoverflow.com/a/30316760:
168 | def getsize(obj):
169 | zero_depth_bases = (str, bytes, Number, range, bytearray)
170 | iteritems = 'items'
171 |
172 | def _getsize(obj_0):
173 | """Recursively iterate to sum size of object & members."""
174 | _seen_ids = set()
175 |
176 | def inner(obj):
177 | obj_id = id(obj)
178 | if obj_id in _seen_ids:
179 | return 0
180 | _seen_ids.add(obj_id)
181 | size = sys.getsizeof(obj)
182 | if isinstance(obj, zero_depth_bases):
183 | pass # bypass remaining control flow and return
184 | elif isinstance(obj, (tuple, list, Set, deque)):
185 | size += sum(inner(i) for i in obj)
186 | elif isinstance(obj, Mapping) or hasattr(obj, iteritems):
187 | size += sum(inner(k) + inner(v) for k, v in getattr(obj, iteritems)())
188 | # Check for custom object instances - may subclass above too
189 | if hasattr(obj, '__dict__'):
190 | size += inner(vars(obj))
191 | if hasattr(obj, '__slots__'): # can have __slots__ with __dict__
192 | size += sum(inner(getattr(obj, s)) for s in obj.__slots__ if hasattr(obj, s))
193 | return size
194 |
195 | return inner(obj_0)
196 |
197 | return _getsize(obj)
198 |
199 |
200 | def is_python_3_6_and_higher():
201 | python_version = sys.version_info
202 | return python_version >= (3, 6)
203 |
204 |
205 | def create_chunk_generator(total: int, n_chunks: int) -> Generator[int, None, None]:
206 | min_elms_in_chunk = total // n_chunks
207 | for i in range(min_elms_in_chunk):
208 | for j in range(n_chunks):
209 | yield j
210 | for i in range(total % n_chunks):
211 | yield i
212 |
213 |
214 | def groupify(items: List, n_chunks: int) -> List[List]:
215 | """
216 | >>> groupify([0, 1, 2, 3, 4, 5, 6], 3)
217 | [[0, 3, 6], [1, 4], [2, 5]]
218 |
219 | >>> groupify([0, 1, 2, 3, 4, 5, 6], 300)
220 | [[0], [1], [2], [3], [4], [5], [6]]
221 |
222 | >>> groupify([], 300)
223 | []
224 |
225 | >>> result = groupify(list(range(100 * 1000 + 17)), 100)
226 | >>> len(result[0])
227 | 1001
228 | >>> len(result[17])
229 | 1000
230 | """
231 | groups = [[] for _ in range(n_chunks if len(items) >= n_chunks else len(items))]
232 | chunk_gen = create_chunk_generator(len(items), n_chunks)
233 | for elm, label in zip(items, chunk_gen):
234 | groups[label].append(elm)
235 | return groups
236 |
237 |
238 | def to_non_literal_str(word: str) -> str:
239 | return word.encode().decode("unicode-escape")
240 |
241 |
242 | def to_literal_str(word: str) -> str:
243 | return word.encode("unicode-escape").decode()
244 |
245 |
246 | START_ERROR_COLOR = '\033[31m'
247 | END_ERROR_COLOR = '\033[0m'
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_0.05mb_1bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_0.05mb_1bit.png
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_0.05mb_1bit.png.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: 2020 Hlib Babii
2 |
3 | SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_0.05mb_2bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_0.05mb_2bit.png
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_0.05mb_2bit.png.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: 2020 Hlib Babii
2 |
3 | SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_0.05mb_3bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_0.05mb_3bit.png
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_0.05mb_3bit.png.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: 2020 Hlib Babii
2 |
3 | SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_0.5mb_1bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_0.5mb_1bit.png
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_0.5mb_1bit.png.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: 2020 Hlib Babii
2 |
3 | SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_0.5mb_2bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_0.5mb_2bit.png
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_0.5mb_2bit.png.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: 2020 Hlib Babii
2 |
3 | SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_0.5mb_3bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_0.5mb_3bit.png
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_0.5mb_3bit.png.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: 2020 Hlib Babii
2 |
3 | SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_5mb_1bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_5mb_1bit.png
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_5mb_1bit.png.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: 2020 Hlib Babii
2 |
3 | SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_5mb_2bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_5mb_2bit.png
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_5mb_2bit.png.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: 2020 Hlib Babii
2 |
3 | SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_5mb_3bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/giganticode/codeprep/0f41307f7a9ad545e5ec0cc9552a0144328f2422/reports/bpe/wild-bpe/v0.1_5mb_3bit.png
--------------------------------------------------------------------------------
/reports/bpe/wild-bpe/v0.1_5mb_3bit.png.license:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: 2020 Hlib Babii
2 |
3 | SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/casics/spiral.git@a10fcd3afbd4c28e917cc0056d6f646cc2e9bd44
2 | pytest==5.3.1
3 | pytest-mock==1.12.1
4 | coverage==4.5.4
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | appdirs==1.4.4
2 | click==7.1.2
3 | dill==0.3.1.1
4 | docopt==0.6.2
5 | docopt-subcommands==3.0.0
6 | joblib==0.15.1
7 | jsons==1.1.2
8 | nltk==3.5
9 | Pygments==2.6.1
10 | PyYAML==5.4
11 | regex==2020.5.14
12 | tqdm==4.46.0
13 | typish==1.6.0
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import os
6 |
7 | from setuptools import setup, find_packages
8 |
9 | root_package_name = 'codeprep'
10 |
11 |
12 | def readme():
13 | with open('README.md') as f:
14 | return f.read()
15 |
16 |
17 | def version():
18 | with open(os.path.join(root_package_name, 'VERSION')) as version_file:
19 | return version_file.read().strip()
20 |
21 |
22 | setup(name='codeprep',
23 | version=version(),
24 | description='A toolkit for pre-processing large source code corpora',
25 | long_description=readme(),
26 | long_description_content_type="text/markdown",
27 | url='http://github.com/giganticode/codeprep',
28 | author='Hlib Babii',
29 | author_email='hlibbabii@gmail.com',
30 | license='Apache-2.0',
31 | packages=find_packages(),
32 | classifiers=[
33 | 'Development Status :: 3 - Alpha',
34 | 'Environment :: Console',
35 | 'Intended Audience :: Science/Research',
36 | 'License :: OSI Approved :: Apache Software License',
37 | 'Natural Language :: English',
38 | 'Programming Language :: Python :: 3.6',
39 | 'Programming Language :: Python :: 3.7',
40 | 'Operating System :: POSIX :: Linux',
41 | 'Operating System :: MacOS :: MacOS X',
42 | 'Operating System :: Microsoft :: Windows',
43 | 'Topic :: Software Development :: Pre-processors'
44 | ],
45 | python_requires='>=3.6',
46 | keywords='big large data source code corpus machine learning pre-processing nlp',
47 | install_requires=[
48 | 'appdirs>=1.4, <2',
49 | 'dill>=0.3.1.1, <0.4',
50 | 'docopt>=0.6.2, <0.7',
51 | 'docopt-subcommands>=3.0.0, <4',
52 | 'jsons>=1.0, <2',
53 | 'nltk>=3.4.5, <4',
54 | 'Pygments>=2.5.2, <3',
55 | 'PyYAML>=5.1, <6',
56 | 'regex>=2019.11.1, <=2020.5.14',
57 | 'tqdm>=4.39, <5'
58 | ],
59 | entry_points={
60 | 'console_scripts': [
61 | f'codeprep = {root_package_name}.__main__:main'
62 | ]
63 | },
64 | include_package_data=True,
65 | zip_safe=False)
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/.reuse/dep5:
--------------------------------------------------------------------------------
1 | Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/
2 | Upstream-Name: yahtzee
3 | Upstream-Contact: Hlib Babii
4 | Source:
5 |
6 | # Sample paragraph, commented out:
7 | #
8 | # Files: src/*
9 | # Copyright: $YEAR $NAME <$CONTACT>
10 | # License: ...
11 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/LICENSES/MIT.txt:
--------------------------------------------------------------------------------
1 | MIT License Copyright (c)
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is furnished
8 | to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice (including the next
11 | paragraph) shall be included in all copies or substantial portions of the
12 | Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
16 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS
17 | OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
18 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF
19 | OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/DiceValues.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee;
2 |
3 | import java.util.Iterator;
4 |
5 | public class DiceValues implements Iterable<Integer> {
6 |
7 | public static final Integer MIN_DICE_VALUE = 1;
8 | public static final Integer MAX_DICE_VALUE = 6;
9 |
10 | private Iterator<Integer> iterator;
11 |
12 | public DiceValues(Iterator<Integer> iterator) {
13 | this.iterator = iterator;
14 | }
15 |
16 | public static Iterable<? extends Integer> getDescendingIterator() {
17 | return new DiceValues(new DescendingIterator());
18 | }
19 |
20 | public static class DescendingIterator implements Iterator<Integer> {
21 |
22 | private Integer currentValue = MAX_DICE_VALUE;
23 |
24 | @Override
25 | public boolean hasNext() {
26 | return this.currentValue >= MIN_DICE_VALUE;
27 | }
28 |
29 | @Override
30 | public Integer next() {
31 | return this.currentValue--;
32 | }
33 | }
34 |
35 | @Override
36 | public Iterator<Integer> iterator() {
37 | return this.iterator;
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/GameDemo.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee;
2 |
3 | import hlibbabii.yahtzee.gameplay.Game;
4 | import hlibbabii.yahtzee.gameplay.GameStats;
5 | import hlibbabii.yahtzee.player.DummyPlayer;
6 |
7 | public class GameDemo {
8 | public static void main(String[] args) {
9 | Player player1 = new DummyPlayer();
10 | Player player2 = new DummyPlayer();
11 |
12 | Game game = new Game(player1, player2);
13 |
14 | GameStats gameStats = game.play();
15 | System.out.println("Player1");
16 | System.out.println(gameStats.getPlayerStats(player1));
17 | System.out.println("Player2");
18 | System.out.println(gameStats.getPlayerStats(player2));
19 |
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/Player.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee;
2 |
3 | import hlibbabii.yahtzee.combination.Combination;
4 | import hlibbabii.yahtzee.gameplay.Decision;
5 | import hlibbabii.yahtzee.model.DiceLayout;
6 |
7 | import java.util.Set;
8 |
9 | public interface Player {
10 | Decision makeDecision(DiceLayout diceLayout, DiceLayout fixedDiceLayout, Set<Combination> availableCombinations, int rollsLeft);
11 | }
12 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/Chance.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 | import hlibbabii.yahtzee.model.DiceLayout;
4 |
5 | import java.util.stream.IntStream;
6 |
7 | public class Chance extends Combination {
8 | public static final Chance CHANCE = new Chance();
9 |
10 | @Override
11 | public int earnedScores(DiceLayout diceLayout) {
12 | return IntStream.of(diceLayout.toSortedNumbers()).sum();
13 | }
14 |
15 | @Override
16 | public String toString() {
17 | return "Chance";
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/Combination.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 | import hlibbabii.yahtzee.model.DiceLayout;
4 |
5 | public abstract class Combination {
6 |
7 | public abstract int earnedScores(DiceLayout diceLayout);
8 | }
9 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/FullHouse.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 | import hlibbabii.yahtzee.gameplay.Game;
4 | import hlibbabii.yahtzee.model.DiceLayout;
5 |
6 | import java.util.Collection;
7 | import java.util.stream.IntStream;
8 |
9 | public class FullHouse extends Combination {
10 |
11 | public static final FullHouse FULL_HOUSE = new FullHouse();
12 |
13 | @Override
14 | public int earnedScores(DiceLayout diceLayout) {
15 | assert 5 == Game.N_DICE;
16 |
17 | Collection<Integer> counts = diceLayout.toCounts().values();
18 | return counts.contains(2) && counts.contains(3)?
19 | IntStream.of(diceLayout.toSortedNumbers()).sum():
20 | 0;
21 |
22 | }
23 |
24 | @Override
25 | public String toString() {
26 | return "Full House";
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/LargeStraight.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 | import hlibbabii.yahtzee.DiceValues;
4 | import hlibbabii.yahtzee.model.DiceLayout;
5 |
6 | import java.util.Arrays;
7 | import java.util.stream.IntStream;
8 |
9 | public class LargeStraight extends Combination {
10 |
11 | public static final LargeStraight LARGE_STRAIGHT = new LargeStraight();
12 |
13 | @Override
14 | public int earnedScores(DiceLayout diceLayout) {
15 | int[] sortedRolledNumbers = diceLayout.toSortedNumbers();
16 | if (Arrays.equals(sortedRolledNumbers, IntStream.range(2, DiceValues.MAX_DICE_VALUE + 1).toArray())) {
17 | return IntStream.of(sortedRolledNumbers).sum();
18 | } else {
19 | return 0;
20 | }
21 | }
22 |
23 | @Override
24 | public String toString() {
25 | return "Large Straight";
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/NOfAKind.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 | import hlibbabii.yahtzee.DiceValues;
4 | import hlibbabii.yahtzee.model.DiceLayout;
5 |
6 | public class NOfAKind extends Combination {
7 |
8 | public static final NOfAKind PAIR = new NOfAKind(2);
9 | public static final NOfAKind THREE_OF_A_KIND = new NOfAKind(3);
10 | public static final NOfAKind FOUR_OF_A_KIND = new NOfAKind(4);
11 |
12 | private int n;
13 |
14 | public NOfAKind(int n) {
15 | this.n = n;
16 | }
17 |
18 | @Override
19 | public int earnedScores(DiceLayout diceLayout) {
20 | for (Integer diceValue : DiceValues.getDescendingIterator()) {
21 | if (diceLayout.getCount(diceValue) >= this.n) {
22 | return diceValue * this.n;
23 | }
24 | }
25 | return 0;
26 | }
27 |
28 | @Override
29 | public String toString() {
30 | return n + " Of A Kind";
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/Numbers.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 | import hlibbabii.yahtzee.model.DiceLayout;
4 |
5 | public class Numbers extends Combination {
6 |
7 | public static final Combination ACES = new Numbers(1);
8 | public static final Combination TWOS = new Numbers(2);
9 | public static final Combination THREES = new Numbers(3);
10 | public static final Combination FOURS = new Numbers(4);
11 | public static final Combination FIVES = new Numbers(5);
12 | public static final Combination SIXES = new Numbers(6);
13 |
14 | private Integer number;
15 |
16 | public Numbers(Integer number) {
17 | this.number = number;
18 | }
19 |
20 | @Override
21 | public int earnedScores(DiceLayout diceLayout) {
22 | return diceLayout.getCount(this.number) * this.number;
23 | }
24 |
25 | @Override
26 | public String toString() {
27 | return "Number(" + number + ")";
28 | }
29 | }
30 |
31 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/SmallStraight.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 | import hlibbabii.yahtzee.DiceValues;
4 | import hlibbabii.yahtzee.model.DiceLayout;
5 |
6 | import java.util.Arrays;
7 | import java.util.stream.IntStream;
8 |
9 | public class SmallStraight extends Combination {
10 |
11 | public static final SmallStraight SMALL_STRAIGHT = new SmallStraight();
12 |
13 | @Override
14 | public int earnedScores(DiceLayout diceLayout) {
15 | int[] sortedRolledNumbers = diceLayout.toSortedNumbers();
16 | if (Arrays.equals(sortedRolledNumbers, IntStream.range(1, DiceValues.MAX_DICE_VALUE).toArray())) {
17 | return IntStream.of(sortedRolledNumbers).sum();
18 | } else {
19 | return 0;
20 | }
21 | }
22 |
23 | @Override
24 | public String toString() {
25 | return "Small Straight";
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/TwoPairs.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 | import hlibbabii.yahtzee.DiceValues;
4 | import hlibbabii.yahtzee.model.DiceLayout;
5 |
6 | public class TwoPairs extends Combination {
7 |
8 | public static final TwoPairs TWO_PAIRS = new TwoPairs();
9 |
10 | @Override
11 | public int earnedScores(DiceLayout diceLayout) {
12 | int sum = 0;
13 | for (Integer diceValue: DiceValues.getDescendingIterator()) {
14 | if (diceLayout.getCount(diceValue) >= 2) {
15 | if (sum == 0) {
16 | sum = diceValue * 2;
17 | } else {
18 | return sum + diceValue * 2;
19 | }
20 | }
21 | }
22 | return 0;
23 | }
24 |
25 | @Override
26 | public String toString() {
27 | return "Two Pairs";
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/combination/Yahtzee.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 | import hlibbabii.yahtzee.model.DiceLayout;
4 |
5 | public class Yahtzee extends Combination {
6 |
7 | public static final Yahtzee YAHTZEE = new Yahtzee();
8 |
9 | private static final int SCORES_FOR_YAHTZEE = 50;
10 |
11 | @Override
12 | public int earnedScores(DiceLayout diceLayout) {
13 | return diceLayout.toCounts().keySet().size() == 1 ? SCORES_FOR_YAHTZEE : 0;
14 | }
15 |
16 | @Override
17 | public String toString() {
18 | return "Yahtzee";
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/gameplay/Decision.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.gameplay;
2 |
3 | import hlibbabii.yahtzee.model.DiceLayout;
4 | import hlibbabii.yahtzee.combination.Combination;
5 |
6 | public class Decision {
7 | private DiceLayout fixedDiceLayout;
8 | private Combination combination;
9 |
10 | public Decision(Combination combination) {
11 | this.combination = combination;
12 | }
13 |
14 | public Decision(DiceLayout fixedDiceLayout) {
15 | this.fixedDiceLayout = fixedDiceLayout;
16 | }
17 |
18 | public static Decision fixLayoutAndRoll(DiceLayout fixedDiceLayout) {
19 | return new Decision(fixedDiceLayout);
20 | }
21 |
22 | public boolean isCombinationDecision() {
23 | return this.combination != null;
24 | }
25 |
26 | public static Decision decideCombination(Combination combination) {
27 | return new Decision(combination);
28 | }
29 |
30 | public Combination getCombination() {
31 | if (!this.isCombinationDecision()) {
32 | throw new AssertionError("To get the combination the decision has to be final!");
33 | }
34 |
35 | return this.combination;
36 | }
37 |
38 | public DiceLayout getDiceDecidedToLeave() {
39 | return this.fixedDiceLayout;
40 | }
41 | }
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/gameplay/Game.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.gameplay;
2 |
3 | import hlibbabii.yahtzee.player.DummyPlayer;
4 | import hlibbabii.yahtzee.Player;
5 | import hlibbabii.yahtzee.model.DiceLayout;
6 | import hlibbabii.yahtzee.combination.Combination;
7 | import hlibbabii.yahtzee.util.RandomService;
8 |
9 | import java.util.*;
10 | import java.util.Map.Entry;
11 |
12 | public class Game {
13 |
14 | public static final int N_DICE = 5;
15 | private static final Integer NUMBER_OF_ATTEMPTS = 3;
16 |
17 | private final Player player1;
18 | private final Player player2;
19 | private final GameStats gameStats;
20 |
21 |
22 | public Game(Player player1, Player player2) {
23 | this.player1 = player1;
24 | this.player2 = player2;
25 | this.gameStats = new GameStats(player1, player2);
26 | }
27 |
28 | public static void main(String[] args) {
29 | GameStats gameStats = new Game(new DummyPlayer(), new DummyPlayer()).play();
30 | System.out.println(gameStats.getFinalPoints());
31 | }
32 |
33 | public GameStats play() {
34 | while (this.gameStats.combinationsAvailable()) {
35 | MoveResult moveResult1 = this.makeMove(this.player1);
36 | this.gameStats.addMoveResult(moveResult1);
37 |
38 | MoveResult moveResult2 = this.makeMove(this.player2);
39 | this.gameStats.addMoveResult(moveResult2);
40 | }
41 | return this.gameStats;
42 | }
43 |
44 | private MoveResult makeMove(Player player) {
45 | PlayerStats playerStats = this.gameStats.getPlayerStats(player);
46 | Set<Combination> availableCombinations = playerStats.getAvailableCombinations();
47 |
48 | DiceLayout fixedDiceLayout = DiceLayout.empty();
49 | int howManyToRoll = N_DICE;
50 | for (int i = NUMBER_OF_ATTEMPTS; i > 0; i--) {
51 | DiceLayout nonFixedDiceLayout = this.roll(howManyToRoll);
52 | Decision decision;
53 | try {
54 | decision = player.makeDecision(nonFixedDiceLayout, fixedDiceLayout, availableCombinations, i - 1);
55 | } catch (Exception e) {
56 | throw new PlayerException(e);
57 | }
58 | if (decision.isCombinationDecision()) {
59 | return MoveResult.create(player, decision, nonFixedDiceLayout);
60 | } else {
61 | DiceLayout currentFixedDiceLayout = decision.getDiceDecidedToLeave();
62 | this.checkFixedDiceLayoutValid(currentFixedDiceLayout, fixedDiceLayout);
63 | fixedDiceLayout = currentFixedDiceLayout;
64 | howManyToRoll = N_DICE - fixedDiceLayout.getSize();
65 | }
66 | }
67 | throw new AssertionError("Combination decision should have already been made!");
68 | }
69 |
70 | private void checkFixedDiceLayoutValid(DiceLayout newFixedDiceLayout, DiceLayout previousFixedDiceLayout) {
71 | for (Entry<Integer, Integer> previousAlignmentEntry: previousFixedDiceLayout.toCounts().entrySet()) {
72 | if (previousAlignmentEntry.getValue() > newFixedDiceLayout.toCounts().get(previousAlignmentEntry.getKey())) {
73 | throw new AssertionError("The dice which once were left on the table cannot be rerolled later");
74 | }
75 | }
76 | }
77 |
78 | public DiceLayout roll() {
79 | return this.roll(N_DICE);
80 | }
81 |
82 | private DiceLayout roll(int howMany) {
83 | RandomService random = new RandomService();
84 | int[] arr = new int[howMany];
85 | for (int i = 0; i < howMany; i++) {
86 | arr[i] = random.getRandomDiceNumber();
87 | }
88 | return DiceLayout.fromNumbers(arr);
89 | }
90 | }
91 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/gameplay/GameStats.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.gameplay;
2 |
3 | import hlibbabii.yahtzee.Player;
4 |
5 | import java.util.HashMap;
6 | import java.util.Map;
7 | import java.util.stream.Collectors;
8 |
9 | public class GameStats {
10 | Map<Player, PlayerStats> playerStatsMap;
11 |
12 | public GameStats(Player player1, Player player2) {
13 | this.playerStatsMap = new HashMap<Player, PlayerStats>() {{
14 | this.put(player1, new PlayerStats());
15 | this.put(player2, new PlayerStats());
16 | }};
17 | }
18 |
19 | public boolean combinationsAvailable() {
20 | PlayerStats anyPlayerStats = this.playerStatsMap.values().iterator().next();
21 | return !anyPlayerStats.getAvailableCombinations().isEmpty();
22 | }
23 |
24 | public void addMoveResult(MoveResult moveResult) {
25 | Player player = moveResult.getPlayer();
26 | PlayerStats playerStats = this.playerStatsMap.get(player);
27 | playerStats.put(moveResult.getCombination(), moveResult.getScore());
28 | }
29 |
30 | public PlayerStats getPlayerStats(Player player) {
31 | return this.playerStatsMap.get(player);
32 | }
33 |
34 | public Map<Player, Integer> getFinalPoints() {
35 | return this.playerStatsMap.entrySet().stream()
36 | .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getFinalPoints()));
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/gameplay/MoveResult.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.gameplay;
2 |
3 | import hlibbabii.yahtzee.model.DiceLayout;
4 | import hlibbabii.yahtzee.Player;
5 | import hlibbabii.yahtzee.combination.Combination;
6 |
7 | import java.util.Objects;
8 |
9 | public class MoveResult {
10 | private final Player player;
11 | private final Combination combination;
12 | private final int score;
13 |
14 | private MoveResult(Player player, Combination combination, int scores) {
15 | this.player = player;
16 | this.combination = combination;
17 | this.score = scores;
18 | }
19 |
20 | public static MoveResult create(Player player, Decision decision, DiceLayout nonFixedDiceLayout) {
21 | Combination combination = decision.getCombination();
22 | int scores = combination.earnedScores(nonFixedDiceLayout);
23 | return new MoveResult(player, combination, scores);
24 | }
25 |
26 | public Player getPlayer() {
27 | return this.player;
28 | }
29 |
30 | public Combination getCombination() {
31 | return this.combination;
32 | }
33 |
34 | public int getScore() {
35 | return this.score;
36 | }
37 |
38 | @Override
39 | public boolean equals(Object o) {
40 | if (this == o) return true;
41 | if (o == null || this.getClass() != o.getClass()) return false;
42 | MoveResult that = (MoveResult) o;
43 | return this.score == that.score &&
44 | Objects.equals(this.player, that.player) &&
45 | Objects.equals(this.combination, that.combination);
46 | }
47 |
48 | @Override
49 | public int hashCode() {
50 | return Objects.hash(this.player, this.combination, this.score);
51 | }
52 |
53 | @Override
54 | public String toString() {
55 | return "MoveResult{" +
56 | "player=" + player +
57 | ", combination=" + combination +
58 | ", score=" + score +
59 | '}';
60 | }
61 | }
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/gameplay/PlayerException.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.gameplay;
2 |
3 | public class PlayerException extends RuntimeException {
4 | public PlayerException(Exception e) {
5 | super(e);
6 | }
7 | }
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/gameplay/PlayerStats.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.gameplay;
2 |
3 |
4 | import hlibbabii.yahtzee.combination.*;
5 |
6 | import java.util.HashMap;
7 | import java.util.Map;
8 | import java.util.Map.Entry;
9 | import java.util.Set;
10 | import java.util.stream.Collectors;
11 |
12 | public class PlayerStats {
13 |
14 | private static final int NUMBERS_VALUE_REACH_TO_GET_BONUS = 63;
15 | private static final int BONUS_VALUE = 50;
16 |
17 | private Map<Combination, Integer> combinations;
18 |
19 | private void initAllCombinations() {
20 | this.combinations = new HashMap<>();
21 | /* Numbers*/
22 | this.combinations.put(Numbers.ACES, null);
23 | this.combinations.put(Numbers.TWOS, null);
24 | this.combinations.put(Numbers.THREES, null);
25 | this.combinations.put(Numbers.FOURS, null);
26 | this.combinations.put(Numbers.FIVES, null);
27 | this.combinations.put(Numbers.SIXES, null);
28 |
29 | this.combinations.put(NOfAKind.PAIR, null);
30 | this.combinations.put(NOfAKind.THREE_OF_A_KIND, null);
31 | this.combinations.put(NOfAKind.FOUR_OF_A_KIND, null);
32 |
33 | this.combinations.put(FullHouse.FULL_HOUSE, null);
34 | this.combinations.put(TwoPairs.TWO_PAIRS, null);
35 |
36 | this.combinations.put(SmallStraight.SMALL_STRAIGHT, null);
37 | this.combinations.put(LargeStraight.LARGE_STRAIGHT, null);
38 |
39 | this.combinations.put(Chance.CHANCE, null);
40 | this.combinations.put(Yahtzee.YAHTZEE, null);
41 | }
42 |
43 | public PlayerStats() {
44 | this.initAllCombinations();
45 | }
46 |
47 |     public Set<Combination> getAvailableCombinations() {
48 | return this.combinations.entrySet().stream().filter(e -> e.getValue() == null).map(Entry::getKey).collect(Collectors.toSet());
49 | }
50 |
51 | public void put(Combination combination, int score) {
52 | if (this.combinations.get(combination) != null) {
53 | throw new AssertionError(String.format("Combination %s has already been played", combination));
54 | }
55 | this.combinations.put(combination, score);
56 | }
57 |
58 | private int getUpperSectionPoints() {
59 | return this.combinations.entrySet().stream().filter(e -> e.getKey() instanceof Numbers)
60 | .map(Entry::getValue).mapToInt(e -> e).sum();
61 | }
62 |
63 | private int getLowerSectionPoints() {
64 | return this.combinations.entrySet().stream().filter(e -> !(e.getKey() instanceof Numbers))
65 | .map(Entry::getValue).mapToInt(e -> e).sum();
66 | }
67 |
68 | private boolean bonusEarned() {
69 | return this.getUpperSectionPoints() >= NUMBERS_VALUE_REACH_TO_GET_BONUS;
70 | }
71 |
72 | public Integer getFinalPoints() {
73 | return this.getLowerSectionPoints() + this.getUpperSectionPoints() + (this.bonusEarned() ? BONUS_VALUE : 0);
74 | }
75 |
76 | @Override
77 | public String toString() {
78 | return "PlayerStats{" +
79 | "combinations=" + combinations +
80 | '}';
81 | }
82 | }
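
A note on the two constants above: the 63-point bonus threshold corresponds to rolling three of each face value, since 3 * (1 + 2 + 3 + 4 + 5 + 6) = 63; once getUpperSectionPoints() reaches it, getFinalPoints() adds the 50-point BONUS_VALUE on top of the upper- and lower-section totals.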
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/model/DiceLayout.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.model;
2 |
3 | import hlibbabii.yahtzee.DiceValues;
4 | import hlibbabii.yahtzee.gameplay.Game;
5 |
6 | import java.util.*;
7 | import java.util.stream.Collectors;
8 | import java.util.stream.IntStream;
9 | import java.util.stream.Stream;
10 |
11 | public class DiceLayout {
12 |
13 | private static final String INVALID_DICE_LAYOUT_MESSAGE_TEMPLATE = "Invalid dice layout: %s. %s";
14 |
15 |     private SortedMap<Integer, Integer> valuesToOccurences;
16 | 
17 |     private DiceLayout(SortedMap<Integer, Integer> valuesToOccurences) {
18 | this.valuesToOccurences = valuesToOccurences;
19 |
20 | this.checkAlignmentInvariants(valuesToOccurences);
21 | }
22 |
23 |     private void checkAlignmentInvariants(Map<Integer, Integer> valuesToOccurences) {
24 | if (this.getSize() > Game.N_DICE) {
25 | String reason = String.format("The number of dice cannot be more than %s.", Game.N_DICE);
26 | throw new IllegalArgumentException(String.format(INVALID_DICE_LAYOUT_MESSAGE_TEMPLATE, Arrays.toString(this.toSortedNumbers()), reason));
27 | }
28 | for (Integer rolledNumber : valuesToOccurences.keySet()) {
29 | if (rolledNumber < DiceValues.MIN_DICE_VALUE || rolledNumber > DiceValues.MAX_DICE_VALUE) {
30 | String reason = "Rolled numbers contain invalid values.";
31 | throw new IllegalArgumentException(String.format(INVALID_DICE_LAYOUT_MESSAGE_TEMPLATE, Arrays.toString(this.toSortedNumbers()), reason));
32 | }
33 | }
34 | }
35 |
36 | /* Dice layout creation options */
37 |
38 | public static DiceLayout empty() {
39 | return new DiceLayout(new TreeMap<>());
40 | }
41 |
42 |     public static DiceLayout fromMap(SortedMap<Integer, Integer> valuesToOccurences) {
43 | return new DiceLayout(valuesToOccurences);
44 | }
45 |
46 |     private static DiceLayout fromStream(Stream<Integer> stream) {
47 | return DiceLayout.fromMap(new TreeMap<>(
48 | stream.collect(Collectors.groupingBy((a) -> a, Collectors.summingInt((e)->1)))
49 | ));
50 | }
51 |
52 | public static DiceLayout fromNumbers(int... numbers) {
53 | return fromStream(IntStream.of(numbers).boxed());
54 | }
55 |
56 |     public static DiceLayout fromNumbers(List<Integer> rolledNumbers) {
57 | return fromStream(rolledNumbers.stream());
58 | }
59 |
60 | /* Dice layout representation options */
61 |
62 |     public Map<Integer, Integer> toCounts() {
63 | return this.valuesToOccurences;
64 | }
65 |
66 | public int getCount(int number) {
67 | return toCounts().getOrDefault(number, 0);
68 | }
69 |
70 | public int[] toSortedNumbers() {
71 | return this.valuesToOccurences.entrySet().stream()
72 | .flatMapToInt((a) -> IntStream.range(0, a.getValue()).map((k) -> a.getKey())).toArray();
73 | }
74 |
75 | /* */
76 |
77 | public int getSize() {
78 | return this.valuesToOccurences.values().stream().mapToInt(e -> e).sum();
79 | }
80 |
81 | @Override
82 | public boolean equals(Object o) {
83 | if (this == o) return true;
84 | if (o == null || this.getClass() != o.getClass()) return false;
85 | DiceLayout that = (DiceLayout) o;
86 | return Objects.equals(this.valuesToOccurences, that.valuesToOccurences);
87 | }
88 |
89 | @Override
90 | public int hashCode() {
91 | return Objects.hash(this.valuesToOccurences);
92 | }
93 |
94 | @Override
95 | public String toString() {
96 | return "DiceLayout{" +
97 | "valuesToOccurences=" + Arrays.toString(this.toSortedNumbers()) +
98 | '}';
99 | }
100 | }
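
A worked example of the representation above: DiceLayout.fromNumbers(1, 3, 3, 6, 6) groups equal values, so toCounts() returns {1=1, 3=2, 6=2}, toSortedNumbers() expands those counts back into [1, 3, 3, 6, 6], getCount(3) is 2, and getSize() sums the occurrences to 5.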
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/player/DummyPlayer.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.player;
2 |
3 | import hlibbabii.yahtzee.Player;
4 | import hlibbabii.yahtzee.gameplay.Decision;
5 | import hlibbabii.yahtzee.model.DiceLayout;
6 | import hlibbabii.yahtzee.combination.Combination;
7 |
8 | import java.util.Set;
9 |
10 | /**
11 |  * This player rolls the dice as many times as it can and chooses the first available combination
12 |  * (even if it yields 0 points while other available combinations could score more).
13 | */
14 | public class DummyPlayer implements Player {
15 |
16 | @Override
17 |     public Decision makeDecision(DiceLayout diceLayout, DiceLayout fixedDiceLayout, Set<Combination> availableCombinations, int rollsLeft) {
18 | if (rollsLeft == 0) {
19 | return Decision.decideCombination(availableCombinations.iterator().next());
20 | // availableCombinations.stream().map(c -> c.earnedScores(diceLayout)).max(Comparator.comparingInt(a->a));
21 | } else {
22 | return Decision.fixLayoutAndRoll(fixedDiceLayout);
23 | }
24 | }
25 | }
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/main/java/hlibbabii/yahtzee/util/RandomService.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.util;
2 |
3 | import java.util.Random;
4 |
5 | public class RandomService {
6 |
7 | private final Random rand;
8 |
9 | public RandomService() {
10 | this.rand = new Random();
11 | }
12 |
13 | public int getRandomDiceNumber() {
14 | return this.rand.nextInt(6) + 1;
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/test/java/hlibbabii/yahtzee/combination/ChanceTest.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 | import hlibbabii.yahtzee.model.DiceLayout;
4 | import org.junit.jupiter.api.Test;
5 |
6 | import static org.junit.jupiter.api.Assertions.*;
7 |
8 | class ChanceTest {
9 | @Test
10 | void testChance() {
11 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,3,3,6,6);
12 | int actual = Chance.CHANCE.earnedScores(diceLayout);
13 | assertEquals(19, actual);
14 | }
15 | }
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/test/java/hlibbabii/yahtzee/combination/FullHouseTest.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 | import hlibbabii.yahtzee.model.DiceLayout;
4 | import org.junit.jupiter.api.Test;
5 |
6 | import static org.junit.jupiter.api.Assertions.*;
7 |
8 | class FullHouseTest {
9 | @Test
10 | void testFullHouseAllDistinct() {
11 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,3,4,5,6);
12 | int actual = FullHouse.FULL_HOUSE.earnedScores(diceLayout);
13 | assertEquals(0, actual);
14 | }
15 |
16 | @Test
17 | void testFullHouseYahtzeePresent() {
18 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,2,2,2,2);
19 | int actual = FullHouse.FULL_HOUSE.earnedScores(diceLayout);
20 | assertEquals(0, actual);
21 | }
22 |
23 | @Test
24 | void testFullHouseThreeOfAKindPresent() {
25 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,2,2,3,4);
26 | int actual = FullHouse.FULL_HOUSE.earnedScores(diceLayout);
27 | assertEquals(0, actual);
28 | }
29 |
30 | @Test
31 | void testFullHousePositive() {
32 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,2,2,3,3);
33 | int actual = FullHouse.FULL_HOUSE.earnedScores(diceLayout);
34 | assertEquals(12, actual);
35 | }
36 |
37 | }
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/test/java/hlibbabii/yahtzee/combination/LargeStraightTest.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 | import hlibbabii.yahtzee.model.DiceLayout;
4 | import org.junit.jupiter.api.Test;
5 |
6 | import static org.junit.jupiter.api.Assertions.*;
7 |
8 | class LargeStraightTest {
9 |
10 | @Test
11 | void earnedScoresSmallStraight() {
12 | DiceLayout diceLayout = DiceLayout.fromNumbers(1, 2, 3, 4, 5);
13 | int actual = new LargeStraight().earnedScores(diceLayout);
14 | assertEquals(0, actual);
15 | }
16 |
17 | @Test
18 | void earnedScoresLargeStraight() {
19 | DiceLayout diceLayout = DiceLayout.fromNumbers(2, 3, 4, 5, 6);
20 | int actual = new LargeStraight().earnedScores(diceLayout);
21 | assertEquals(20, actual);
22 | }
23 |
24 | @Test
25 | void earnedScoresNoStraight() {
26 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,1,1,2,3);
27 | int actual = new LargeStraight().earnedScores(diceLayout);
28 | assertEquals(0, actual);
29 | }
30 |
31 | }
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/test/java/hlibbabii/yahtzee/combination/NOfAKindTest.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 |
4 | import hlibbabii.yahtzee.model.DiceLayout;
5 | import org.junit.jupiter.api.Assertions;
6 | import org.junit.jupiter.api.Test;
7 |
8 | class NOfAKindTest {
9 | @Test
10 | public void testNoPairs() {
11 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,2,3,6,5);
12 | int actual = NOfAKind.PAIR.earnedScores(diceLayout);
13 | Assertions.assertEquals(0, actual);
14 | }
15 |
16 | @Test
17 | public void testHighPair() {
18 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,2,3,5,5);
19 | int actual = NOfAKind.PAIR.earnedScores(diceLayout);
20 | Assertions.assertEquals(10, actual);
21 | }
22 |
23 | @Test
24 | public void testMediumPair() {
25 | DiceLayout diceLayout = DiceLayout.fromNumbers(6,1,2,2,5);
26 | int actual = NOfAKind.PAIR.earnedScores(diceLayout);
27 | Assertions.assertEquals(4, actual);
28 | }
29 |
30 | @Test
31 | public void testPairLow() {
32 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,2,3,6,1);
33 | int actual = NOfAKind.PAIR.earnedScores(diceLayout);
34 | Assertions.assertEquals(2, actual);
35 | }
36 |
37 | @Test
38 | public void testPairTwoPairsPresent() {
39 | DiceLayout diceLayout = DiceLayout.fromNumbers(6,5,3,3,5);
40 | int actual = NOfAKind.PAIR.earnedScores(diceLayout);
41 | Assertions.assertEquals(10, actual);
42 | }
43 |
44 | @Test
45 | public void testPairThreeOfAKindPresent() {
46 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,3,5,3,3);
47 | int actual = NOfAKind.PAIR.earnedScores(diceLayout);
48 | Assertions.assertEquals(6, actual);
49 | }
50 |
51 | @Test
52 | public void testPairFullHouseHighPresent() {
53 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,4,4,4,2);
54 | int actual = NOfAKind.PAIR.earnedScores(diceLayout);
55 | Assertions.assertEquals(8, actual);
56 | }
57 |
58 | @Test
59 | public void testPairFullHouseLowPresent() {
60 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,4,4,2,2);
61 | int actual = NOfAKind.PAIR.earnedScores(diceLayout);
62 | Assertions.assertEquals(8, actual);
63 | }
64 | }
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/test/java/hlibbabii/yahtzee/combination/SmallStraightTest.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 | import hlibbabii.yahtzee.model.DiceLayout;
4 | import org.junit.jupiter.api.Test;
5 |
6 | import static org.junit.jupiter.api.Assertions.*;
7 |
8 | class SmallStraightTest {
9 |
10 | @Test
11 | void earnedScoresSmallStraight() {
12 | DiceLayout diceLayout = DiceLayout.fromNumbers(1, 2, 3, 4, 5);
13 | int actual = new SmallStraight().earnedScores(diceLayout);
14 | assertEquals(15, actual);
15 | }
16 |
17 | @Test
18 | void earnedScoresLargeStraight() {
19 | DiceLayout diceLayout = DiceLayout.fromNumbers(2, 3, 4, 5, 6);
20 | int actual = new SmallStraight().earnedScores(diceLayout);
21 | assertEquals(0, actual);
22 | }
23 |
24 | @Test
25 | void earnedScoresNoStraight() {
26 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,1,1,2,3);
27 | int actual = new SmallStraight().earnedScores(diceLayout);
28 | assertEquals(0, actual);
29 | }
30 | }
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/test/java/hlibbabii/yahtzee/combination/TwoPairsTest.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.combination;
2 |
3 | import hlibbabii.yahtzee.model.DiceLayout;
4 | import org.junit.jupiter.api.Test;
5 |
6 | import static hlibbabii.yahtzee.combination.TwoPairs.TWO_PAIRS;
7 | import static org.junit.jupiter.api.Assertions.*;
8 |
9 | class TwoPairsTest {
10 |
11 | @Test
12 | void testTwoPairsAllDistinct() {
13 | DiceLayout diceLayout = DiceLayout.fromNumbers(1, 2, 3, 5, 6);
14 | int actual = TWO_PAIRS.earnedScores(diceLayout);
15 | assertEquals(0, actual);
16 | }
17 |
18 | @Test
19 | void testTwoPairsOnePairPresent() {
20 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,2,2,3,4);
21 | int actual = TWO_PAIRS.earnedScores(diceLayout);
22 | assertEquals(0, actual);
23 | }
24 |
25 | @Test
26 | void testTwoPairsTwoPairsPresent() {
27 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,1,4,5,5);
28 | int actual = TWO_PAIRS.earnedScores(diceLayout);
29 | assertEquals(12, actual);
30 | }
31 |
32 | @Test
33 | void testTwoPairsFullHousePresent() {
34 | DiceLayout diceLayout = DiceLayout.fromNumbers(2,2,2,6,6);
35 | int actual = TWO_PAIRS.earnedScores(diceLayout);
36 | assertEquals(16, actual);
37 | }
38 |
39 | @Test
40 |     void testTwoPairsFourOfAKindPresent() {
41 | DiceLayout diceLayout = DiceLayout.fromNumbers(3,5,5,5,5);
42 | int actual = TWO_PAIRS.earnedScores(diceLayout);
43 | assertEquals(0, actual);
44 | }
45 | }
--------------------------------------------------------------------------------
/test-data/test-corpus/yahtzee/src/test/java/hlibbabii/yahtzee/model/DiceLayoutTest.java:
--------------------------------------------------------------------------------
1 | package hlibbabii.yahtzee.model;
2 |
3 | import org.junit.jupiter.api.Test;
4 |
5 | import java.util.Map;
6 |
7 | import static org.junit.jupiter.api.Assertions.*;
8 |
9 | class DiceLayoutTest {
10 |
11 | @Test
12 | void testEmpty() {
13 | DiceLayout diceLayout = DiceLayout.empty();
14 |
15 | assertEquals(0, diceLayout.toCounts().entrySet().size());
16 | }
17 |
18 | @Test
19 | void testFromNumbers() {
20 | DiceLayout diceLayout = DiceLayout.fromNumbers(1,2,3,4,4);
21 |
22 | int[] numbers = diceLayout.toSortedNumbers();
23 |         Map<Integer, Integer> counts = diceLayout.toCounts();
24 | 
25 |         assertArrayEquals(new int[]{1, 2, 3, 4, 4}, numbers);
26 |         assertEquals(2, (int) counts.get(4));
27 |     }
28 | 
29 | }
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import os
6 | import sys
7 |
8 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(__file__)), '..'))
--------------------------------------------------------------------------------
/tests/api/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/tests/api/test_corpus.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import os
6 | from unittest import mock
7 | from unittest.mock import Mock
8 |
9 | from codeprep.api.corpus import preprocess_corpus
10 | from codeprep.prepconfig import PrepConfig, PrepParam
11 |
12 | PATH_TO_CUR_DIR_STUB = os.path.join('path', 'to', 'curdir')
13 | PATH_TO_DATASET_STUB = os.path.join('path', 'to', 'dataset')
14 | PATH_TO_OUTPUT_STUB = os.path.join('path', 'to', 'output')
15 |
16 | DEFAULT_PREP_CONFIG = PrepConfig({
17 | PrepParam.EN_ONLY: 'u',
18 | PrepParam.COM: 'c',
19 | PrepParam.STR: '1',
20 | PrepParam.SPLIT: '0',
21 | PrepParam.TABS_NEWLINES: '0',
22 | PrepParam.CASE: 'u',
23 | })
24 |
25 |
26 | @mock.patch('codeprep.api.corpus.Dataset', autospec=True)
27 | @mock.patch('codeprep.api.corpus.stages', autospec=True)
28 | @mock.patch('codeprep.cli.impl.os.getcwd', autospec=True, return_value=PATH_TO_CUR_DIR_STUB)
29 | def test_simple(os_mock, stages_mock, dataset_mock):
30 | # given
31 | dataset_mock.create = Mock(spec=dataset_mock, return_value=dataset_mock)
32 |
33 |
34 | # when
35 | preprocess_corpus(PATH_TO_DATASET_STUB, DEFAULT_PREP_CONFIG)
36 |
37 | # then
38 | dataset_mock.create.assert_called_with(PATH_TO_DATASET_STUB, DEFAULT_PREP_CONFIG, None, None,
39 | overriden_path_to_prep_dataset=PATH_TO_CUR_DIR_STUB, suppress_caching=False)
40 | stages_mock.run_until_preprocessing.assert_called_with(dataset_mock, None)
41 |
42 |
43 | @mock.patch('codeprep.api.corpus.Dataset', autospec=True)
44 | @mock.patch('codeprep.api.corpus.stages', autospec=True)
45 | @mock.patch('codeprep.cli.impl.os.getcwd', autospec=True, return_value=PATH_TO_CUR_DIR_STUB)
46 | def test_calc_vocab(os_mock, stages_mock, dataset_mock):
47 | # given
48 | dataset_mock.create = Mock(spec=dataset_mock, return_value=dataset_mock)
49 |
50 | # when
51 | preprocess_corpus(PATH_TO_DATASET_STUB, DEFAULT_PREP_CONFIG, calc_vocab=True, suppress_caching=True)
52 |
53 | # then
54 | dataset_mock.create.assert_called_with(PATH_TO_DATASET_STUB, DEFAULT_PREP_CONFIG, None, None,
55 | overriden_path_to_prep_dataset=PATH_TO_CUR_DIR_STUB, suppress_caching=True)
56 | stages_mock.run_until_vocab.assert_called_with(dataset_mock, None)
57 |
58 |
59 | @mock.patch('codeprep.api.corpus.Dataset', autospec=True)
60 | @mock.patch('codeprep.api.corpus.stages', autospec=True)
61 | def test_output(stages_mock, dataset_mock):
62 | # given
63 | dataset_mock.create = Mock(spec=dataset_mock, return_value=dataset_mock)
64 |
65 | # when
66 | preprocess_corpus(PATH_TO_DATASET_STUB, DEFAULT_PREP_CONFIG, output_path=PATH_TO_OUTPUT_STUB)
67 |
68 | # then
69 | dataset_mock.create.assert_called_with(PATH_TO_DATASET_STUB, DEFAULT_PREP_CONFIG, None, None,
70 | overriden_path_to_prep_dataset=PATH_TO_OUTPUT_STUB, suppress_caching=False)
71 | stages_mock.run_until_preprocessing.assert_called_with(dataset_mock, None)
--------------------------------------------------------------------------------
/tests/bpepkg/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/tests/bpepkg/test_merge.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from unittest import mock
6 |
7 | from unittest.mock import MagicMock
8 |
9 | import pytest
10 | from pytest import fixture
11 |
12 | import codeprep
13 | from codeprep.bpepkg import merge
14 | from codeprep.bpepkg.merge import MergeList, Merge
15 |
16 |
17 | @fixture
18 | def file_handle_mock(mocker):
19 | mocker.patch('codeprep.bpepkg.merge.open')
20 | codeprep.bpepkg.merge.open.return_value = MagicMock(spec=['__enter__', '__exit__'])
21 | handle = codeprep.bpepkg.merge.open.return_value.__enter__.return_value
22 | return handle
23 |
24 |
25 | def test_read_merges(file_handle_mock):
26 | file_handle_mock.__iter__.return_value = iter(['a b 67', 'b c 34', 'c d 94'])
27 |
28 | actual = merge.read_merges('file', 2)
29 | expected = MergeList().append(Merge(('a', 'b'), 67, 0)).append(Merge(('b', 'c'), 34, 1))
30 |
31 | assert expected == actual
32 |
33 |
34 | def test_read_merges_with_wrong_delimiter(file_handle_mock):
35 |     file_handle_mock.__iter__.return_value = iter(['a\tb\t67'])
36 | 
37 |     with pytest.raises(ValueError):
38 |         merge.read_merges('file')
39 |
40 |
41 | def test_dump_merges(file_handle_mock):
42 | merges = MergeList().append(Merge(('a', 'b'), 67, 0)).append(Merge(('b', 'c'), 34, 1))
43 | merge.dump_merges(merges, 'file')
44 |
45 | file_handle_mock.write.assert_has_calls([
46 | mock.call('a b 67\n'),
47 | mock.call('b c 34\n')
48 | ])
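
The tests above pin down the on-disk merge format: each line reads "<left> <right> <count>", and a merge's priority is its zero-based line index. A minimal sketch of that convention (parse_merge_lines is a hypothetical helper mirroring only what the tests assert, not the library's read_merges implementation):

    from codeprep.bpepkg.merge import Merge, MergeList

    def parse_merge_lines(lines):
        # Hypothetical helper: token pair, occurrence count, priority = line index.
        merges = MergeList()
        for priority, line in enumerate(lines):
            left, right, count = line.split(' ')
            merges.append(Merge((left, right), int(count), priority))
        return merges

    assert parse_merge_lines(['a b 67', 'b c 34']) == \
           MergeList().append(Merge(('a', 'b'), 67, 0)).append(Merge(('b', 'c'), 34, 1))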
--------------------------------------------------------------------------------
/tests/bpepkg/wild_bpe_performance.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import random
6 |
7 | from typing import List
8 |
9 | from codeprep.bpepkg import wild_bpe
10 | from codeprep.bpepkg.wild_bpe import BpePerformanceStatsEntry, run
11 |
12 |
13 | def gen_performance_test_case(data_size_mb: float, entropy: int):
14 | char_list = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
15 | for i in range(int(data_size_mb * (2 ** 20))):
16 | w = random.choice(char_list[:2 ** entropy])
17 | for i in range(len(w)):
18 | yield w[i]
19 |
20 |
21 | def plotting_function(data_size_mb: float, entropy: int, version: str,
22 | performance_stats: List[BpePerformanceStatsEntry],
23 | final_show: bool = True):
24 | merges_done = list(map(lambda p: p.merges_done, performance_stats))
25 | n_pq_entries = list(map(lambda p: p.n_priority_queue_entries, performance_stats))
26 | n_index_entries = list(map(lambda p: p.n_index_entries, performance_stats))
27 | location_index_obj_size = list(map(lambda p: p.location_index_obj_size, performance_stats))
28 | neighbour_index_obj_size = list(map(lambda p: p.neighbour_index_obj_size, performance_stats))
29 | priority_counter_obj_size = list(map(lambda p: p.priority_counter_obj_size, performance_stats))
30 | total_index_size = [a + b + c for a, b, c in
31 | zip(location_index_obj_size, neighbour_index_obj_size, priority_counter_obj_size)]
32 | merge_time = list(map(lambda p: p.time_for_last_merge, performance_stats))
33 |
34 | import matplotlib.pyplot as plt
35 | fig, splots = plt.subplots(nrows=3)
36 | splots[0].plot(merges_done, n_pq_entries, label='priority counter')
37 | splots[0].plot(merges_done, n_index_entries, label='indices')
38 | splots[0].set(ylabel='entries',
39 | title=f'Wild BPE version {version}, Data size: {data_size_mb} MB, entropy: {entropy} bit')
40 | splots[0].legend(loc='upper right')
41 |
42 | splots[1].plot(merges_done, location_index_obj_size, label='location index')
43 | splots[1].plot(merges_done, neighbour_index_obj_size, label='neighbour index')
44 | splots[1].plot(merges_done, priority_counter_obj_size, label='priority counter')
45 | splots[1].plot(merges_done, total_index_size, label='total')
46 | splots[1].set(xlabel='number of merges', ylabel='memory consumed (MB)')
47 | splots[1].legend(loc='upper right')
48 |
49 | splots[2].plot(merges_done, merge_time)
50 | splots[2].set(xlabel='number of merges', ylabel='time per merge (s)')
51 |
52 | fig.savefig(f'{version}_{data_size_mb}_{entropy}bit.png')
53 | if final_show:
54 | plt.show()
55 |
56 |
57 | def test_performance():
58 | test_cases = [
59 | {'mb': 0.05, 'entropy': 1},
60 | {'mb': 0.05, 'entropy': 2},
61 | {'mb': 0.05, 'entropy': 3},
62 | {'mb': 0.5, 'entropy': 1},
63 | {'mb': 0.5, 'entropy': 2},
64 | {'mb': 0.5, 'entropy': 3},
65 | {'mb': 5, 'entropy': 1},
66 | {'mb': 5, 'entropy': 2},
67 | {'mb': 5, 'entropy': 3},
68 | ]
69 | stats_every_n = 50
70 |
71 | for test_case in test_cases:
72 | gen = run(gen_performance_test_case(test_case['mb'], test_case['entropy']), include_performance_stats_every_n_merges=stats_every_n)
73 | for i in range(1000):
74 | try:
75 | merge, occurences, stats = next(gen)
76 |             except StopIteration:
77 |                 break
78 | plotting_function(test_case['mb'], test_case['entropy'], wild_bpe.__version__, stats)
79 |
80 |
81 | if __name__ == '__main__':
82 | test_performance()
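
For clarity on what the generator above feeds into BPE: gen_performance_test_case emits roughly data_size_mb * 2**20 single-character symbols drawn from an alphabet of 2**entropy letters, so "entropy" is the number of random bits per character. A small illustrative check, assuming it is run next to the module above:

    # Illustrative only: 0.001 MB of 2-bit-entropy data uses the 4-letter alphabet.
    chars = list(gen_performance_test_case(0.001, 2))
    assert len(chars) == int(0.001 * 2 ** 20)
    assert set(chars) <= {'a', 'b', 'c', 'd'}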
--------------------------------------------------------------------------------
/tests/cli/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/tests/infrastructure/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/tests/infrastructure/test_bpelearner.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from unittest import mock
6 |
7 | import pytest
8 |
9 | from codeprep.bpepkg.bpe_config import BpeConfig, BpeParam, BpeConfigNotSupported
10 | from codeprep.pipeline.bpelearner import run
11 |
12 |
13 | @mock.patch('codeprep.pipeline.bpelearner.Dataset', autospec=True)
14 | def test_run_word_end(mocked_dataset):
15 | bpe_config = BpeConfig({
16 | BpeParam.BASE: 'code',
17 | BpeParam.WORD_END: True,
18 | BpeParam.UNICODE: 'yes',
19 | BpeParam.CASE: 'yes'
20 | })
21 | with pytest.raises(BpeConfigNotSupported):
22 | run(mocked_dataset, 1, bpe_config)
23 |
24 |
25 | @mock.patch('codeprep.pipeline.bpelearner.Dataset', autospec=True)
26 | def test_run_bytes_bpe(mocked_dataset):
27 | bpe_config = BpeConfig({
28 | BpeParam.BASE: 'code',
29 | BpeParam.WORD_END: False,
30 | BpeParam.UNICODE: 'bytes',
31 | BpeParam.CASE: 'yes'
32 | })
33 | with pytest.raises(BpeConfigNotSupported):
34 | run(mocked_dataset, 1, bpe_config)
--------------------------------------------------------------------------------
/tests/infrastructure/test_bperegistry.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import os
6 | from unittest.mock import patch
7 |
8 | from codeprep.pipeline.bperegistry import get_max_merges, format_available_merge_list_ids, get_min_merges
9 | from codeprep.pipeline.dataset import create_new_id_from
10 |
11 |
12 | PATH_TO_DATASET_BPE_DIR_STUB = os.path.join('/', 'path', 'to', 'dataset', 'bpe', 'dir')
13 | PATH_TO_DATASET_STUB = os.path.join('/', 'path', 'to', 'dataset')
14 | HLIB_PATH = '/home/hlib/path'
15 |
16 |
17 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True)
18 | def test_with_predefined_id(bpe_config_mock):
19 | bpe_config_mock.to_suffix.return_value = ''
20 | assert create_new_id_from(PATH_TO_DATASET_STUB, bpe_config_mock, 'id23') == 'id23'
21 |
22 |
23 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True)
24 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True, return_value={})
25 | def test_no_existing_bpe_codes(mock, bpe_config_mock):
26 | bpe_config_mock.to_suffix.return_value = ''
27 | assert create_new_id_from(PATH_TO_DATASET_STUB, bpe_config_mock) == 'dataset'
28 |
29 |
30 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True)
31 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True,
32 | return_value={'dataset': 10, 'dataset4': 20, 'dataset_3': 30})
33 | def test_ids_for_same_dataset_exist(mock, bpe_config_mock):
34 | bpe_config_mock.to_suffix.return_value = ''
35 | assert create_new_id_from(PATH_TO_DATASET_STUB, bpe_config_mock) == 'dataset_4'
36 |
37 |
38 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True)
39 | def test_with_predefined_codes_id(bpe_config_mock):
40 | bpe_config_mock.to_suffix.return_value = ""
41 | assert create_new_id_from(HLIB_PATH, bpe_config_mock, 'my-id') == 'my-id'
42 |
43 |
44 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True)
45 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True, return_value="")
46 | def test_simple(mock, bpe_config_mock):
47 | # given
48 | bpe_config_mock.to_suffix.return_value = ""
49 |
50 | assert create_new_id_from(HLIB_PATH, bpe_config_mock) == 'path'
51 |
52 |
53 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True)
54 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True,
55 | return_value={'path': 1000})
56 | def test_same_path_exists(mock, bpe_config_mock):
57 | # given
58 | bpe_config_mock.to_suffix.return_value = ""
59 |
60 | assert create_new_id_from(HLIB_PATH, bpe_config_mock) == 'path_1'
61 |
62 |
63 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True)
64 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True,
65 | return_value={'path': 1000, 'path_1': 2000})
66 | def test_same_path_and_next_one_exist(mock, bpe_config_mock):
67 | # given
68 | bpe_config_mock.to_suffix.return_value = ""
69 |
70 | assert create_new_id_from(HLIB_PATH, bpe_config_mock) == 'path_2'
71 |
72 |
73 | @patch("codeprep.bpepkg.bpe_config.BpeConfig", autospec=True)
74 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True,
75 | return_value={'path': 1000, 'path_28': 2000})
76 | def test_same_path_and_one_more_exist(mock, bpe_config_mock):
77 | # given
78 | bpe_config_mock.to_suffix.return_value = ""
79 |
80 | assert create_new_id_from(HLIB_PATH, bpe_config_mock) == 'path_29'
81 |
82 |
83 | @patch('codeprep.pipeline.bperegistry.os.walk', return_value=iter([('', [], [])]))
84 | def test_none(mocked_walk):
85 | assert get_max_merges('.') is None
86 |
87 |
88 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True, return_value={})
89 | def test_no_available_merge_lists(bpe_registry_mock):
90 | assert format_available_merge_list_ids() == ""
91 |
92 |
93 | @patch('codeprep.pipeline.bperegistry._get_all_custom_bpe_codes_and_max_merges', autospec=True,
94 | return_value={"a": 1000, "b": 500})
95 | def test_format_available_merge_list_ids(mock):
96 | assert format_available_merge_list_ids() == "a-[1..1000]\nb-[1..500]\n"
97 |
98 |
99 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=[])
100 | def test_max_no_folders(mock):
101 | assert get_max_merges(PATH_TO_DATASET_BPE_DIR_STUB) is None
102 |
103 |
104 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=[])
105 | def test_min_no_folders(mock):
106 | assert get_min_merges(PATH_TO_DATASET_BPE_DIR_STUB) is None
107 |
108 |
109 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=['part_vocab'])
110 | def test_max_with_non_number_folder(mock):
111 | assert get_max_merges(PATH_TO_DATASET_BPE_DIR_STUB) is None
112 |
113 |
114 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=['part_vocab'])
115 | def test_min_with_non_number_folder(mock):
116 | assert get_min_merges(PATH_TO_DATASET_BPE_DIR_STUB) is None
117 |
118 |
119 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=['10', '20'])
120 | def test_max_all_folders_above_limit(mock):
121 | assert get_max_merges(PATH_TO_DATASET_BPE_DIR_STUB, 5) is None
122 |
123 |
124 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=['10', '20'])
125 | def test_min_all_folders_below_limit(mock):
126 |     assert get_min_merges(PATH_TO_DATASET_BPE_DIR_STUB, 30) is None
127 |
128 |
129 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=['10', 'partvocab'])
130 | def test_max_one_folder_available(mock):
131 | assert get_max_merges(PATH_TO_DATASET_BPE_DIR_STUB) == 10
132 |
133 |
134 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True, return_value=['10', 'partvocab'])
135 | def test_min_one_folder_available(mock):
136 |     assert get_min_merges(PATH_TO_DATASET_BPE_DIR_STUB) == 10
137 |
138 |
139 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True,
140 | return_value=['10', '20', '15', '30', 'partvocab'])
141 | def test_max_simple(mock):
142 | assert get_max_merges(PATH_TO_DATASET_BPE_DIR_STUB, 20) == 20
143 |
144 |
145 | @patch('codeprep.pipeline.bperegistry._get_all_bpe_merges_dirs', autospec=True,
146 | return_value=['10', '20', '15', '30', 'partvocab'])
147 | def test_min_simple(mock):
148 | assert get_min_merges(PATH_TO_DATASET_BPE_DIR_STUB, 15) == 15
--------------------------------------------------------------------------------
/tests/infrastructure/test_dataset.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import os
6 | from unittest import mock
7 |
8 | from codeprep.bpepkg.bpe_config import BpeConfig, BpeParam
9 | from codeprep.config import USER_CONFIG_DIR, VOCAB_DIR, BPE_DIR, USER_CACHE_DIR
10 | from codeprep.pipeline.bperegistry import CustomBpeConfig
11 | from codeprep.pipeline.dataset import Dataset, SubDataset
12 | from codeprep.prepconfig import PrepConfig, PrepParam
13 |
14 |
15 | PATH_TO_DATASET_STUB = os.path.join('/', 'path', 'to', 'dataset')
16 | PARSED_DATASETS_DIR = os.path.join('/', 'parsed', 'datasets')
17 | PREP_DATASETS_DIR = os.path.join('/', 'prep', 'datasets')
18 | OVERRIDDEN_PATH = os.path.join('/', 'overridden', 'path')
19 |
20 |
21 | @mock.patch('os.path.exists', autospec=True, return_value=True)
22 | @mock.patch('codeprep.pipeline.dataset.get_timestamp', autospec=True, return_value="01_01_01")
23 | @mock.patch('codeprep.pipeline.dataset.DEFAULT_PARSED_DATASETS_DIR', PARSED_DATASETS_DIR)
24 | @mock.patch('codeprep.pipeline.dataset.DEFAULT_PREP_DATASETS_DIR', PREP_DATASETS_DIR)
25 | def test_non_bpe_split(get_timestamp_mock, os_exists_mock):
26 | prep_config = PrepConfig({
27 | PrepParam.EN_ONLY: 'u',
28 | PrepParam.COM: 'c',
29 | PrepParam.STR: '1',
30 | PrepParam.SPLIT: '0',
31 | PrepParam.TABS_NEWLINES: 's',
32 | PrepParam.CASE: 'u'
33 | })
34 |
35 | actual = Dataset.create(PATH_TO_DATASET_STUB, prep_config, None, None)
36 |
37 | assert PATH_TO_DATASET_STUB == actual._path
38 | assert prep_config == actual._prep_config
39 | assert actual._normalized_extension_list is None
40 | assert actual._custom_bpe_config is None
41 | assert actual._bpe_config is None
42 |     assert '01_01_01' == actual._dataset_last_modified
43 |
44 | assert SubDataset(actual, PATH_TO_DATASET_STUB, '') == actual._original
45 | assert SubDataset(actual, os.path.join(PARSED_DATASETS_DIR, 'dataset_01_01_01'), '.parsed') == actual._parsed
46 | assert SubDataset(actual, os.path.join(PREP_DATASETS_DIR, 'dataset_01_01_01_-_uc10su'), '.prep') == actual._preprocessed
47 |
48 |
49 | @mock.patch('os.path.exists', autospec=True, return_value=True)
50 | @mock.patch('codeprep.pipeline.dataset.get_timestamp', autospec=True, return_value="01_01_01")
51 | @mock.patch('codeprep.pipeline.dataset.DEFAULT_PARSED_DATASETS_DIR', PARSED_DATASETS_DIR)
52 | @mock.patch('codeprep.pipeline.dataset.DEFAULT_PREP_DATASETS_DIR', PREP_DATASETS_DIR)
53 | def test_non_bpe_split_with_one_extension(get_timestamp_mock, os_exists_mock):
54 | prep_config = PrepConfig({
55 | PrepParam.EN_ONLY: 'u',
56 | PrepParam.COM: 'c',
57 | PrepParam.STR: '1',
58 | PrepParam.SPLIT: '0',
59 | PrepParam.TABS_NEWLINES: 's',
60 | PrepParam.CASE: 'u'
61 | })
62 |
63 | actual = Dataset.create(PATH_TO_DATASET_STUB, prep_config, "java", None)
64 |
65 | assert PATH_TO_DATASET_STUB == actual._path
66 | assert prep_config == actual._prep_config
67 | assert ['java'] == actual._normalized_extension_list
68 | assert actual._custom_bpe_config is None
69 | assert actual._bpe_config is None
70 | assert '01_01_01' == actual._dataset_last_modified
71 |
72 |     assert SubDataset(actual, PATH_TO_DATASET_STUB, '') == actual._original
73 |     assert SubDataset(actual, os.path.join(PARSED_DATASETS_DIR, 'dataset_01_01_01_java'), '.parsed') == actual._parsed
74 |     assert SubDataset(actual, os.path.join(PREP_DATASETS_DIR, 'dataset_01_01_01_java_-_uc10su'), '.prep') == actual._preprocessed
75 |
76 |
77 | @mock.patch('os.path.exists', autospec=True, return_value=True)
78 | @mock.patch('codeprep.pipeline.dataset.get_timestamp', autospec=True, return_value="01_01_01")
79 | @mock.patch('codeprep.pipeline.dataset.DEFAULT_PARSED_DATASETS_DIR', PARSED_DATASETS_DIR)
80 | @mock.patch('codeprep.pipeline.dataset.DEFAULT_PREP_DATASETS_DIR', PREP_DATASETS_DIR)
81 | def test_all_custom(get_timestamp_mock, os_exists_mock):
82 | prep_config = PrepConfig({
83 | PrepParam.EN_ONLY: 'u',
84 | PrepParam.COM: 'c',
85 | PrepParam.STR: '1',
86 | PrepParam.SPLIT: '0',
87 | PrepParam.TABS_NEWLINES: 's',
88 | PrepParam.CASE: 'u'
89 | })
90 | bpe_config = BpeConfig({
91 | BpeParam.CASE: 'yes',
92 | BpeParam.WORD_END: False,
93 | BpeParam.BASE: "code",
94 | BpeParam.UNICODE: "no",
95 | })
96 |
97 | custom_bpe_config = CustomBpeConfig("id", 1000, "/codes/file", "/cache/file")
98 | actual = Dataset.create(PATH_TO_DATASET_STUB, prep_config, "c|java", custom_bpe_config,
99 | bpe_config, overriden_path_to_prep_dataset=OVERRIDDEN_PATH)
100 |
101 | assert PATH_TO_DATASET_STUB == actual._path
102 | assert prep_config == actual._prep_config
103 | assert ['c', 'java'] == actual._normalized_extension_list
104 | assert custom_bpe_config == actual._custom_bpe_config
105 | assert bpe_config == actual._bpe_config
106 | assert '01_01_01' == actual._dataset_last_modified
107 |
108 | assert SubDataset(actual, PATH_TO_DATASET_STUB, '') == actual.original
109 | assert SubDataset(actual, os.path.join(PARSED_DATASETS_DIR, 'dataset_01_01_01_c_java'), '.parsed') == actual.parsed
110 | assert SubDataset(actual, os.path.join(OVERRIDDEN_PATH, 'dataset_01_01_01_c_java_-_uc10su_id-1000_-_prep'), '.prep') == actual.preprocessed
111 |     assert os.path.join(USER_CONFIG_DIR, VOCAB_DIR, 'dataset_01_01_01_c_java_-_U0EFsu') == actual.base_bpe_vocab_path
112 |     assert os.path.join(USER_CONFIG_DIR, BPE_DIR, 'dataset_01_01_01_c_java_-_nounicode') == actual.bpe_path
113 |     assert os.path.join(USER_CACHE_DIR, 'file_lists', 'dataset_01_01_01_c_java') == actual.path_to_file_list_folder
114 |     assert os.path.join(USER_CONFIG_DIR, VOCAB_DIR, 'dataset_01_01_01_c_java_-_uc10su_id-1000') == actual.vocab_path
--------------------------------------------------------------------------------
/tests/parse/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/tests/parse/test_subtokens.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from codeprep.tokens.numeric import Number
6 |
7 | from codeprep.parse.matchers import split_into_words
8 | from codeprep.tokens.containers import SplitContainer
9 | from codeprep.tokens.whitespace import NewLine, SpaceInString
10 | from codeprep.tokens.word import Word, Underscore
11 | from codeprep.parse.subtokens import split_string
12 |
13 |
14 | def test_split_into_tokens():
15 | actual = split_into_words("123\nAb2cd34Ef000GG j_89_J")
16 |
17 | expected = [Number('123'),
18 | NewLine(),
19 | SplitContainer([Word.from_('Ab'), Word.from_('2'), Word.from_('cd'),
20 | Word.from_('34'), Word.from_('Ef'), Word.from_('000'), Word.from_('GG')]),
21 | SplitContainer([Word.from_('j'), Underscore(), Word.from_('89'), Underscore(), Word.from_('J')])]
22 |
23 | assert expected == actual
24 |
25 |
26 | def test_split_string():
27 | actual = split_string("123\nAb2cd34Ef000GG j_89_J")
28 |
29 | expected = [Number('123'),
30 | NewLine(),
31 | SplitContainer([Word.from_('Ab'), Word.from_('2'), Word.from_('cd'),
32 | Word.from_('34'), Word.from_('Ef'), Word.from_('000'), Word.from_('GG')]),
33 | SpaceInString(5),
34 | SplitContainer([Word.from_('j'), Underscore(), Word.from_('89'), Underscore(), Word.from_('J')])]
35 |
36 | assert expected == actual
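
The pair of tests above suggests the one behavioral difference between the two entry points: split_into_words drops the run of spaces between tokens, while split_string keeps it as an explicit SpaceInString element in the output sequence.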
--------------------------------------------------------------------------------
/tests/test_corpus_b2b.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import os
6 | import platform
7 | import shutil
8 |
9 | from codeprep.cli.spec import parse_and_run
10 |
11 | import codeprep.api.corpus as api
12 | from codeprep.config import root_package_dir
13 |
14 | PATH_TO_TEST_CORPUS = os.path.join(root_package_dir, '..', 'test-data', 'test-corpus')
15 | TEST_OUTPUT = os.path.join(root_package_dir, '..', 'test-output')
16 |
17 |
18 | def test_preprocess_with_different_options():
19 | calc_vocab = True
20 | api.basic(path=PATH_TO_TEST_CORPUS, extensions="java", output_path=TEST_OUTPUT, calc_vocab=calc_vocab)
21 | api.basic(path=PATH_TO_TEST_CORPUS, extensions="java", split_numbers=True, ronin=True, stem=True,
22 | no_spaces=True, no_unicode=True, no_case=True, no_com=True, no_str=True, max_str_length=30,
23 | output_path=TEST_OUTPUT, calc_vocab=calc_vocab)
24 | api.chars(path=PATH_TO_TEST_CORPUS, extensions="java", output_path=TEST_OUTPUT, calc_vocab=calc_vocab)
25 | api.nosplit(path=PATH_TO_TEST_CORPUS, extensions="java", output_path=TEST_OUTPUT, calc_vocab=calc_vocab)
26 | api.bpe(path=PATH_TO_TEST_CORPUS, bpe_codes_id='10k', extensions="java", output_path=TEST_OUTPUT, calc_vocab=calc_vocab)
27 |
28 |
29 | def test_learn_bpe_codes():
30 | if platform.system() != 'Darwin':
31 | parse_and_run(['learn-bpe', '100', '-p', PATH_TO_TEST_CORPUS, '-e', 'java'])
32 | parse_and_run(['learn-bpe', '150', '-p', PATH_TO_TEST_CORPUS, '-e', 'java'])
33 |
34 | api.bpe(path=PATH_TO_TEST_CORPUS, bpe_codes_id='test-corpus-130', extensions="java", output_path=TEST_OUTPUT)
35 | else:
36 |         print('Skipping the test on macOS.')
37 |
38 |
39 | def teardown_function(function):
40 | print(f'Removing the outputs at: {TEST_OUTPUT}')
41 | if os.path.exists(TEST_OUTPUT):
42 | shutil.rmtree(TEST_OUTPUT)
43 |
--------------------------------------------------------------------------------
/tests/test_subword_separation.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2020 Hlib Babii
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | # import unittest
6 | #
7 | # from codeprep.bpepkg.bpe_encode import BpeData
8 | # from codeprep.parse.core import convert_text
9 | # from codeprep.parse.model.containers import SplitContainer
10 | # from codeprep.parse.model.numeric import Number
11 | # from codeprep.parse.model.placeholders import placeholders
12 | # from codeprep.parse.model.word import Underscore, Word
13 | # from codeprep.prepconfig import PrepConfig
14 | # from codeprep.to_repr import to_repr
15 | #
16 | # test_cases = {
17 | # "create": (
18 | # [SplitContainer.from_single_token("create")],
19 | # ["create"],
20 | # ),
21 | # "Vector": (
22 | # [SplitContainer.from_single_token("Vector")],
23 | # [placeholders["capital"], "vector"],
24 | # ),
25 | # "players": (
26 | # [SplitContainer.from_single_token("players")],
27 | # [placeholders["word_start"], 'play', 'er', 's', placeholders["word_end"]]
28 | # ),
29 | # "0.345e+4": (
30 | # [Number("0.345e+4")],
31 | # [placeholders["word_start"], "0.", "3", "4", "5", "e+", "4", placeholders["word_end"]]
32 | # ),
33 | # "bestPlayers": (
34 | # [SplitContainer([Word.from_("best"), Word.from_("Players")])],
35 | # [placeholders["word_start"], "best", placeholders["capital"], 'play', "er", "s", placeholders["word_end"]]
36 | # ),
37 | # "test_BestPlayers": (
38 | # [SplitContainer([Word.from_("test"), Underscore(), Word.from_("Best"), Word.from_("Players")])],
39 | # [placeholders["word_start"], "test", '_', placeholders["capital"],
40 | # "best", placeholders["capital"], 'play', "er", "s", placeholders["word_end"]]
41 | # ),
42 | # "test_BestPlayers_modified": (
43 | # [SplitContainer(
44 | # [Word.from_("test"), Underscore(), Word.from_("Best"), Word.from_("Players"), Underscore(),
45 | # Word.from_("modified")]
46 | # )],
47 | # [placeholders["word_start"], "test", '_', placeholders["capital"],
48 | # "best", placeholders["capital"], 'play', "er", "s", '_', "mod",
49 | # "if", "ied",
50 | # placeholders["word_end"]]
51 | # ),
52 | # "N_PLAYERS_NUM": (
53 | # [SplitContainer([Word.from_("N"), Underscore(), Word.from_("PLAYERS"), Underscore(), Word.from_("NUM")])],
54 | # [placeholders["word_start"], placeholders["capitals"], "n", '_',
55 | # placeholders["capitals"], "play", "er", "s", '_', placeholders["capitals"],
56 | # "num", placeholders["word_end"]]
57 | # ),
58 | # "_players": (
59 | # [SplitContainer([Underscore(), (Word.from_("players"))])],
60 | # [placeholders['word_start'], '_', "play", "er", "s", placeholders['word_end']]
61 | # ),
62 | # }
63 | #
64 | # bpe_merges_cache = {
65 | # "players": ["play", "er", "s"],
66 | # "0.345e+4": ["0.", "3", "4", "5", "e+", "4"],
67 | # "modified": ["mod", "if", "ied"],
68 | #
69 | # "create": ["create"],
70 | # "vector": ["vector"],
71 | # "best": ["best"],
72 | # "test": ["test"],
73 | # "num": ["num"],
74 | # "user": ["user"],
75 | # "get": ["get"],
76 | # "nick": ["ni", "ck"],
77 | # "logger": ["logger"],
78 | # "info": ["info"]
79 | # }
80 | #
81 | #
82 | # class SubwordSeparation(unittest.TestCase):
83 | # def test(self):
84 | # for input, output_tuple in test_cases.items():
85 | # parsed = [p for p in convert_text(input, "java")][:-1]
86 | #
87 | # self.assertEqual(output_tuple[0], parsed)
88 | #
89 | # repred, metadata = to_repr(PrepConfig.from_encoded_string('Uc140l'), parsed, BpeData(merges_cache=bpe_merges_cache))
90 | #
91 | # self.assertEqual(output_tuple[1], repred)
92 | #
93 | #
94 | # if __name__ == '__main__':
95 | # unittest.main()
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | # tox (https://tox.readthedocs.io/) is a tool for running tests
2 | # in multiple virtualenvs. This configuration file will run the
3 | # test suite on all supported python versions. To use it, "pip install tox"
4 | # and then run "tox" from this directory.
5 |
6 | [tox]
7 | envlist = py36
8 |
9 | [testenv]
10 | deps = git+https://github.com/casics/spiral
11 |
12 | commands =
13 | python -m unittest discover
14 |
--------------------------------------------------------------------------------