├── docs ├── .nojekyll ├── .gitignore ├── _static │ └── .donotdelete ├── _templates │ └── .donotdelete ├── index.rst ├── README.md ├── Makefile ├── conf.py └── introduction.rst ├── sotastream ├── utils │ ├── __init__.py │ ├── phrases.py │ └── split.py ├── filters │ ├── __init__.py │ └── filters.py ├── augmentors │ ├── __init__.py │ └── augmentors.py ├── __main__.py ├── pipelines │ ├── default.py │ ├── __init__.py │ ├── example_pipeline.py │ ├── multistream_pipeline.py │ ├── mtdata_pipeline.py │ └── base.py ├── __init__.py ├── data.py └── cli.py ├── test ├── regression │ ├── .gitignore │ ├── data │ │ ├── wmt21.fr-de.tsv.gz │ │ └── wmt22.en-de.tsv.gz │ ├── Makefile │ ├── tests │ │ └── pipelines │ │ │ ├── test_default_pipeline.sh │ │ │ └── test_example_pipeline.sh │ ├── README.md │ ├── test_mtdata.py │ ├── test_multistream.py │ └── run.sh ├── test_filter.py ├── test_line.py ├── dummy_pipeline.py ├── test_pipeline.py └── test_augmentors.py ├── CHANGELOG.md ├── Makefile ├── LICENSE ├── .readthedocs.yaml ├── .gitignore ├── pyproject.toml └── README.md /docs/.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | api -------------------------------------------------------------------------------- /docs/_static/.donotdelete: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_templates/.donotdelete: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sotastream/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sotastream/filters/__init__.py: -------------------------------------------------------------------------------- 1 | from .filters import * 2 | -------------------------------------------------------------------------------- /sotastream/augmentors/__init__.py: -------------------------------------------------------------------------------- 1 | from .augmentors import * 2 | -------------------------------------------------------------------------------- /test/regression/.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *.diff 3 | *.log 4 | *.out 5 | data/trec -------------------------------------------------------------------------------- /sotastream/__main__.py: -------------------------------------------------------------------------------- 1 | from sotastream.cli import main 2 | 3 | main() 4 | -------------------------------------------------------------------------------- /test/regression/data/wmt21.fr-de.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marian-nmt/sotastream/HEAD/test/regression/data/wmt21.fr-de.tsv.gz -------------------------------------------------------------------------------- /test/regression/data/wmt22.en-de.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marian-nmt/sotastream/HEAD/test/regression/data/wmt22.en-de.tsv.gz 
-------------------------------------------------------------------------------- /test/regression/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all clean run setup 2 | .SECONDARY: 3 | 4 | run: 5 | bash ./run.sh 6 | 7 | setup: 8 | bash ./setup.sh 9 | 10 | all: setup run 11 | 12 | clean: 13 | git clean -x -d -f tests 14 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## [1.0.1] --- 2023-08-28 4 | 5 | ### Fixed 6 | - Moved random seed initialization from DataSource to Constructor 7 | - Read version from project file manually instead of via importlib, 8 | which created problems with Python 3.8 9 | 10 | ## [1.0.0] --- 2023-07-31 11 | 12 | Initial public release. 13 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: check reformat test regression 2 | 3 | check: checkformat test regression 4 | 5 | checkformat: 6 | python -m black --check . || (echo "Please run 'make reformat' to fix formatting issues" && exit 1) 7 | 8 | reformat: 9 | python -m black . 10 | 11 | # unit tests; ignore tests/regression 12 | test: 13 | python -m pytest test/ --ignore test/regression/ 14 | 15 | regression: 16 | python -m pytest test/regression 17 | cd test/regression && bash run.sh 18 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Sotastream documentation master file, created by 2 | sphinx-quickstart on Mon Jul 17 11:12:25 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Sotastream's documentation! 7 | ====================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 4 11 | :caption: Contents: 12 | 13 | introduction 14 | api/modules 15 | 16 | 17 | 18 | Indices and tables 19 | ================== 20 | 21 | * :ref:`genindex` 22 | * :ref:`modindex` 23 | * :ref:`search` 24 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Docs 2 | 3 | 4 | ## Build Docs 5 | ```bash 6 | pip install -U sphinx sphinx_rtd_theme 7 | make clean 8 | make html 9 | ``` 10 | 11 | 12 | 13 | ## Release Package to PyPI 14 | 15 | ```bash 16 | 17 | # run unit and regression tests 18 | make check 19 | 20 | pip install --upgrade build pip twine 21 | rm -rf dist/ 22 | python -m build --sdist --wheel -o dist/ 23 | 24 | # create your ~/.pypirc, if missing 25 | twine upload -r testpypi dist/* 26 | twine upload -r pypi dist/* 27 | 28 | ``` 29 | 30 | 31 | ## Update Docs 32 | 33 | Go to https://readthedocs.org/projects/sotastream/ and click/touch "Build" button. 
34 | -------------------------------------------------------------------------------- /test/regression/tests/pipelines/test_default_pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ##################################################################### 4 | # SUMMARY: Run the default pipeline with default options 5 | ##################################################################### 6 | 7 | # Exit on error 8 | set -eo pipefail 9 | 10 | # Test command 11 | $AB_PYTHON -B -m sotastream --seed 1111 -b 1 -n 1 \ 12 | default $AB_DATA/wmt22.en-de.tsv.gz \ 13 | | head -n 1000 > default.out 14 | 15 | # Compare with the expected output 16 | diff default.out default.expected > default.diff 17 | 18 | # Exit with success code 19 | exit 0 20 | -------------------------------------------------------------------------------- /sotastream/pipelines/default.py: -------------------------------------------------------------------------------- 1 | from sotastream.augmentors import DataSource, UTF8File 2 | 3 | from . import Pipeline, pipeline 4 | 5 | 6 | @pipeline('default') 7 | class DefaultPipeline(Pipeline): 8 | def __init__(self, parallel_data, **kwargs): 9 | super().__init__(**kwargs) 10 | 11 | self.stream = self.create_data_stream(parallel_data) 12 | 13 | @classmethod 14 | def get_data_sources_for_argparse(cls): 15 | return [('parallel_data', 'Path to parallel data (folder with .gz files, or compressed TSV)')] 16 | 17 | @classmethod 18 | def get_data_sources_default_weights(cls): 19 | return [1.0] 20 | -------------------------------------------------------------------------------- /test/regression/tests/pipelines/test_example_pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ##################################################################### 4 | # SUMMARY: Run 't2' pipeline with basic options 5 | ##################################################################### 6 | 7 | # Exit on error 8 | set -eo pipefail 9 | 10 | # Test command 11 | $AB_PYTHON -B -m sotastream --seed 1111 -b 1 -n 1 \ 12 | example $AB_DATA/wmt22.en-de.tsv.gz $AB_DATA/wmt21.fr-de.tsv.gz \ 13 | | head -n 1000 > example.out 14 | 15 | # Compare with the expected output 16 | diff example.out example.expected > example.diff 17 | 18 | # Exit with success code 19 | exit 0 20 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SPHINXAPIDOC ?= sphinx-apidoc 9 | SOURCEDIR = . 10 | BUILDDIR = build 11 | MODULE_NAME = sotastream 12 | CODEDIR = ../$(MODULE_NAME) 13 | 14 | # Put it first so that "make" without argument is like "make help". 15 | help: 16 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 17 | 18 | .PHONY: help Makefile 19 | 20 | # Catch-all target: route all unknown targets to Sphinx using the new 21 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
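# Note: on "make clean" the catch-all rule below also removes the generated
# api/ directory, which is recreated by sphinx-apidoc via run_apidoc() in conf.py.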
22 | 23 | %: Makefile 24 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 25 | @if [ $@ = "clean" ]; then rm -rf "$(SOURCEDIR)/api/"; fi 26 | -------------------------------------------------------------------------------- /sotastream/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | __version__ = "1.0.1" 4 | sys.dont_write_bytecode = True 5 | 6 | 7 | class Defaults: 8 | """ 9 | Default values for pipeline arguments 10 | """ 11 | 12 | BUFFER_SIZE = 1_000_000 13 | QUEUE_BUFFER_SIZE = 10_000 14 | SEPARATOR = " " 15 | DOC_SEPARATOR = " " 16 | SAMPLE_FILE = None 17 | SPM_MODEL = None 18 | SEED = 0 19 | MAX_TOKENS = 250 20 | SAMPLE_LENGTH = True 21 | QUIET = False 22 | NUM_PROCESSES = 16 23 | DOC_PROB = 0.0 24 | DOC_PROB_PARALLEL = 0.0 25 | SHUFFLE = True 26 | 27 | 28 | from .filters import * 29 | 30 | # BUG: note that this will result in import order bug: .augmentors.{doc,robustness}.* won't be available under .augmentors 31 | # from .augmentors import * 32 | from . import augmentors 33 | from .data import Line 34 | from .utils import * 35 | from .cli import main as cli_main 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /test/test_filter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import sys 5 | 6 | sys.dont_write_bytecode = True 7 | 8 | import pytest 9 | 10 | from sotastream.data import Line 11 | from sotastream.filters import * 12 | 13 | from test_augmentors import ToLines, TEST_CORPUS 14 | 15 | URL_CORPUS = [ 16 | "http://microsoft.com\thttp://microsoft.com", 17 | "No URL here\tNo URL here", 18 | "http://google.com in the US\thttp://microsoft.com in the US", 19 | ] 20 | 21 | 22 | from test_line import inputs 23 | 24 | 25 | @pytest.mark.parametrize("corpus", [inputs, TEST_CORPUS, TEST_CORPUS, URL_CORPUS]) 26 | def test_bitext_filter(corpus): 27 | """ 28 | Test that the bitext filter is working. It reduces a line object to just the first 29 | two fields. 
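For example, a three-field line such as "source\ttarget\tdocid" is reduced
to just "source\ttarget", while one- and two-field lines pass through unchanged.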
30 | """ 31 | 32 | for line, wholeline, bitextline in zip(corpus, ToLines(corpus), BitextFilter(ToLines(corpus))): 33 | length = len(line.split("\t")) 34 | 35 | assert len(wholeline) == length 36 | assert len(bitextline) == min(length, 2) 37 | -------------------------------------------------------------------------------- /test/test_line.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import sys 5 | 6 | sys.dont_write_bytecode = True 7 | 8 | import pytest 9 | 10 | from sotastream.data import Line 11 | 12 | from test_augmentors import TEST_CORPUS, ToLines 13 | from itertools import zip_longest 14 | from copy import copy 15 | 16 | 17 | def test_line(): 18 | text = "Das ist ein Test\tThis is a test." 19 | source, target = text.split("\t") 20 | 21 | line = Line(text) 22 | assert line[0] == source 23 | assert line[1] == target 24 | 25 | 26 | inputs = [ 27 | "", 28 | "\tJust the target side, please.", 29 | "Nur die Quellseite, bitte\t", 30 | "Just an ambiguous sentence that should be rendered as the source", 31 | "This has\tlots of\tfields\tthat do not get\tprinted", 32 | "Here are a\tnumber of\tfields that I hope will\tbe joined\t.", 33 | ] 34 | 35 | 36 | @pytest.mark.parametrize("text", inputs) 37 | def test_str(text): 38 | line = Line(text) 39 | assert str(line) == text 40 | assert len(line) == len(text.split("\t")) 41 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file for Sphinx projects 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | build: 8 | os: ubuntu-22.04 9 | tools: 10 | python: "3.11" 11 | # You can also specify other tool versions: 12 | 13 | # Build documentation in the "docs/" directory with Sphinx 14 | sphinx: 15 | configuration: docs/conf.py 16 | # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs 17 | # builder: "dirhtml" 18 | # Fail on all warnings to avoid broken references 19 | # fail_on_warning: true 20 | 21 | # Optionally build your docs in additional formats such as PDF and ePub 22 | # formats: 23 | # - pdf 24 | # - epub 25 | 26 | # Optional but recommended, declare the Python requirements required 27 | # to build your documentation 28 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 29 | # https://docs.readthedocs.io/en/stable/config-file/v2.html#packages 30 | python: 31 | install: 32 | - method: pip 33 | path: . 
34 | extra_requirements: 35 | - dev 36 | -------------------------------------------------------------------------------- /test/dummy_pipeline.py: -------------------------------------------------------------------------------- 1 | # this pipeline is created to test loading of custom/private pipelines from a runtime directory 2 | 3 | from sotastream.pipelines import Pipeline, pipeline 4 | from sotastream.pipelines.default import DefaultPipeline 5 | 6 | import logging as log 7 | 8 | log.basicConfig(level=log.INFO) 9 | 10 | 11 | @pipeline('dummy') 12 | class DummyPipeline(DefaultPipeline): 13 | def __init__(self, path, **kwargs): 14 | super().__init__(path, **kwargs) 15 | self.myarg = kwargs['myarg'] 16 | log.info(f'Loaded sample pipeline with myarg={self.myarg}') 17 | 18 | @classmethod 19 | def add_cli_args(cls, parser): 20 | super().add_cli_args(parser) 21 | parser.add_argument('--myarg', required=True, help='Sample pipeline argument (required)') 22 | 23 | 24 | @pipeline('dummy2') 25 | class DummyPipeline2(DefaultPipeline): 26 | def __init__(self, path, **kwargs): 27 | super().__init__(path, **kwargs) 28 | self.myarg2 = kwargs['myarg2'] 29 | log.info(f'Loaded sample pipeline with myarg={self.myarg2}') 30 | 31 | @classmethod 32 | def add_cli_args(cls, parser): 33 | super().add_cli_args(parser) 34 | parser.add_argument('--myarg2', required=True, help='Sample pipeline argument (required)') 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .pytest_cache/ 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # Unix 108 | *~ 109 | *.swp 110 | *.swo 111 | .history 112 | tmp* 113 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Path setup -------------------------------------------------------------- 7 | import sys 8 | from pathlib import Path 9 | 10 | DOCS_DIR = Path(__file__).parent.absolute() 11 | PROJECT_DIR = DOCS_DIR.parent 12 | # /docs/conf.py i.e two levels up ^ 13 | SRC_DIR = PROJECT_DIR / 'sotastream' 14 | 15 | sys.path.insert(0, str(PROJECT_DIR)) 16 | import sotastream # this import should work 17 | 18 | # -- Project information ----------------------------------------------------- 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 20 | 21 | project = 'Sotastream' 22 | copyright = '2023, Marian NMT' 23 | author = 'Marian NMT' 24 | 25 | # -- General configuration --------------------------------------------------- 26 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 27 | 28 | extensions = [ 29 | 'sphinx.ext.autodoc', 30 | 'sphinx.ext.mathjax', 31 | 'sphinx.ext.viewcode', 32 | 'sphinx.ext.autodoc', 33 | ] 34 | 35 | templates_path = ['_templates'] 36 | exclude_patterns = [] 37 | 38 | language = 'en' 39 | 40 | # -- Options for HTML output ------------------------------------------------- 41 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 42 | 43 | # html_theme = 'alabaster' 44 | html_theme = 'sphinx_rtd_theme' 45 | html_static_path = ['_static'] 46 | 47 | 48 | def run_apidoc(_): 49 | # from sphinx.apidoc import main # for older Sphinx <= 1.6 50 | from sphinx.ext.apidoc import main # for newer 51 | 52 | main(['-e', '-o', str(DOCS_DIR / 'api'), str(SRC_DIR), '--force']) 53 | 54 | 55 | def setup(app): 56 | app.connect('builder-inited', run_apidoc) 57 | -------------------------------------------------------------------------------- /sotastream/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import sys 4 | 
import logging as logger 5 | 6 | from pathlib import Path 7 | 8 | from .base import Pipeline, DocumentPipeline 9 | 10 | logger.basicConfig(level=logger.INFO) 11 | PIPELINES = {} # dict of pipeline name -> pipeline class 12 | 13 | 14 | def pipeline(name: str): 15 | """ 16 | Register a pipeline class to the PIPELINES dict. 17 | The name is used to index the class object, e.g., from the command line. 18 | 19 | :param name: name of component i.e., pipeline name e.g, "t1" 20 | :return: a decorator. 21 | """ 22 | assert name not in PIPELINES, f"Pipeline {name} already taken by {PIPELINES[name]}" 23 | 24 | def decorator(cls): 25 | PIPELINES[name] = cls 26 | return cls 27 | 28 | return decorator 29 | 30 | 31 | """ 32 | Load all modules in the current directory matching the pattern "*_pipeline.py". 33 | """ 34 | modules = Path(__file__).parent.glob("*.py") 35 | __all__ = [f.name.replace('.py', '') for f in modules if f.is_file() and not f.name.startswith('__')] 36 | 37 | FAIL_ON_ERROR = os.environ.get('SOTASTREAM_FAIL_ON_ERROR', False) 38 | 39 | for module_name in __all__: 40 | try: 41 | importlib.import_module(f'.{module_name}', __package__) 42 | except Exception as ex: 43 | logger.error(f'Unable to load {module_name}: {ex}') 44 | if FAIL_ON_ERROR: 45 | raise ex 46 | 47 | for path in list(Path(os.getcwd()).glob('*_pipeline.py')): 48 | module_name = path.name.replace('.py', '') 49 | if module_name in sys.modules: 50 | raise Exception( 51 | f'Module name {module_name} from {path} collides with an already imported module.\ 52 | This state might lead to hard-to-find bugs. Please rename your module.' 53 | ) 54 | try: 55 | importlib.import_module(module_name, __package__) 56 | except: 57 | logger.error( 58 | f'Error while importing {path}. \ 59 | Double check that you have installed all the required libraries.' 60 | ) 61 | raise 62 | -------------------------------------------------------------------------------- /sotastream/pipelines/example_pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Callable 3 | from functools import partial 4 | 5 | from sotastream.augmentors import * 6 | from sotastream import Defaults 7 | from sotastream.filters import BitextFilter 8 | 9 | from . 
import DocumentPipeline, pipeline 10 | 11 | logger = logging.getLogger(f"sotastream") 12 | 13 | 14 | @pipeline("example") 15 | class ExamplePipeline(DocumentPipeline): 16 | description = "Example pipeline with two data streams" 17 | 18 | def __init__(self, parallel_data, backtrans_data, **kwargs): 19 | super().__init__(**kwargs) 20 | 21 | parallel = self.create_data_stream(parallel_data, processor=ReadAndAugment) 22 | backtrans = self.create_data_stream(backtrans_data, processor=partial(ReadAndAugment, tag="")) 23 | 24 | stream = Mixer([parallel, backtrans], self.mix_weights) 25 | self.stream = BitextFilter(stream) # removes all but fields 0 and 1 26 | 27 | @classmethod 28 | def get_data_sources_for_argparse(cls): 29 | return [ 30 | ('parallel_data', 'Path to parallel data (folder with .gz files, or compressed TSV)'), 31 | ('backtrans_data', 'Path to backtranslation data (folder with .gz files, or compressed TSV)'), 32 | ] 33 | 34 | @classmethod 35 | def get_data_sources_default_weights(cls): 36 | return [0.5, 0.5] 37 | 38 | 39 | def LowerCase(stream): 40 | for line in stream: 41 | line[0] = line[0].lower() 42 | yield line 43 | 44 | 45 | def TitleCase(stream): 46 | for line in stream: 47 | line[0] = line[0].title() 48 | line[1] = line[1].title() 49 | yield line 50 | 51 | 52 | def TagData(stream, tag): 53 | for line in stream: 54 | line[0] = f"{tag} {line}" 55 | yield line 56 | 57 | 58 | def ReadAndAugment(path: str, tag: str = None): 59 | """ 60 | Opens a file as a stream and passes it through an augmentor. 61 | """ 62 | stream = UTF8File(path) 63 | 64 | # Randomly mix in case 65 | stream = Mixer( 66 | [ 67 | stream, 68 | LowerCase(stream), 69 | TitleCase(stream), 70 | ], 71 | [0.95, 0.04, 0.01], 72 | ) 73 | 74 | if tag is not None: 75 | stream = TagData(stream, tag) 76 | 77 | return stream 78 | -------------------------------------------------------------------------------- /test/test_pipeline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | At the moment, this just checks that the pipelines actually run, 4 | and that they return two fields only; it doesn't test that they 5 | are correct in terms of the distributions they return. But this at 6 | least prevents breaking the pipeline builds. 7 | """ 8 | 9 | 10 | import sys 11 | 12 | sys.dont_write_bytecode = True 13 | 14 | import pytest 15 | import tempfile 16 | from typing import List 17 | 18 | from sotastream import Defaults 19 | from sotastream.data import Line 20 | from sotastream.pipelines import Pipeline 21 | from sotastream.augmentors import * 22 | 23 | from test_augmentors import TEST_CORPUS, ToLines 24 | 25 | 26 | PIPELINES = [ 27 | ("default", [TEST_CORPUS]), # parallel only 28 | ("example", [TEST_CORPUS, TEST_CORPUS]), 29 | ] 30 | 31 | 32 | def create_pipeline(pipeline_name, data_sources: List[List[str]]): 33 | """ 34 | Creates a pipeline by creating temporary files from each of the data_sources, 35 | since DataSource expects file paths. 
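:param pipeline_name: name of a registered pipeline (e.g. "default" or "example").
:param data_sources: one list of raw TSV lines per data source; each list is
    written to a temporary .gz file that the pipeline can read.
:return: a tuple of the constructed pipeline and the temporary directory holding the data.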
36 | """ 37 | tmpdir = tempfile.TemporaryDirectory() 38 | 39 | data_files = [] 40 | for source in data_sources: 41 | tmpfile = tempfile.NamedTemporaryFile("wt", suffix=".gz", dir=tmpdir.name, delete=False) 42 | data_files.append(tmpdir.name) 43 | with gzip.open(tmpfile.name, "wt") as outfh: 44 | for line in source: 45 | print(line, file=outfh) 46 | 47 | args = { 48 | "buffer_size": 2, 49 | "separator": Defaults.SEPARATOR, 50 | "doc_separator": Defaults.DOC_SEPARATOR, 51 | "max_tokens": 250, 52 | "sample_length": True, 53 | "doc_prob": 1.0, 54 | "doc_prob_parallel": 0.0, 55 | "mix_weights": [1] * len(data_files), 56 | "augment": False, 57 | 'data_sources': data_files, 58 | } 59 | 60 | pipeline = Pipeline.create(pipeline_name, *data_files, **args) 61 | 62 | return pipeline, tmpdir 63 | 64 | 65 | def cleanup_pipeline(tmpdir): 66 | tmpdir.cleanup() 67 | 68 | 69 | @pytest.mark.parametrize("name, data_sources", PIPELINES) 70 | def test_pipeline(name, data_sources): 71 | pipeline, data_files = create_pipeline(name, data_sources) 72 | 73 | for lineno, line in enumerate(pipeline, 1): 74 | if lineno > 10: 75 | break 76 | 77 | cleanup_pipeline(data_files) 78 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "sotastream" 3 | dynamic = ["version"] # see [tool.setuptools.dynamic] below 4 | description = """Sotastream is a command line tool that augments a batch of text and produces infinite stream of records.""" 5 | readme = "README.md" 6 | requires-python = ">=3.6" 7 | license = { file = "LICENSE.txt" } 8 | keywords = [ 9 | "data augmentation", 10 | "machine translation", 11 | "natural language processing", 12 | "text processing", 13 | "text augmentation", 14 | "machine learning", 15 | "deep learning", 16 | "artificial intelligence", 17 | ] 18 | 19 | authors = [ 20 | { name = "Text MT @ Microsoft Translator", email = "marcinjd@microsoft.com" }, 21 | ] 22 | 23 | maintainers = [ 24 | { name = "Thamme Gowda", email = "thammegowda@microsoft.com" }, 25 | { name = "Roman Grundkiewicz", email = "roman.grundkiewicz@microsoft.com" }, 26 | { name = "Matt Post", email = "mattpost@microsoft.com" }, 27 | ] 28 | 29 | classifiers = [ 30 | "Development Status :: 4 - Beta", 31 | "Programming Language :: Python", 32 | ] 33 | 34 | dependencies = [ 35 | "titlecase", 36 | "infinibatch", 37 | "sentencepiece", 38 | "mtdata >= 0.4.0", 39 | ] 40 | 41 | [project.optional-dependencies] 42 | dev = ["black", "sphinx", "sphinx_rtd_theme"] 43 | test = ["pytest < 5.0.0", "pytest-cov[all]"] 44 | 45 | [project.urls] 46 | homepage = "https://github.com/marian-nmt/sotastream" 47 | documentation = "https://github.com/marian-nmt/sotastream" 48 | repository = "https://github.com/marian-nmt/sotastream" 49 | #changelog = "" 50 | 51 | [project.scripts] 52 | sotastream = "sotastream.cli:main" 53 | 54 | # all the above are project metadata, below is configuration for the build system 55 | # there are many build systems: setuptools, flit, poetry, etc. 
56 | # we use setuptools here (because we are familiar with setup.py which uses it) 57 | [build-system] 58 | requires = ["setuptools", "wheel"] 59 | build-backend = "setuptools.build_meta" 60 | 61 | [tool.setuptools.dynamic] 62 | version = {attr = "sotastream.__version__"} 63 | 64 | [tool.setuptools.packages.find] 65 | #where = ["src"] # ["."] by default 66 | include = ["sotastream*"] # ["*"] by default 67 | exclude = ["tests*", "tmp*", "build*", "dist*"] # empty by default 68 | ##################### 69 | 70 | [tool.black] 71 | line-length = 110 72 | target-version = ['py37', 'py38', 'py39'] 73 | include = '\.pyi?$' 74 | skip-string-normalization = true 75 | 76 | [tool.pytest.ini_options] 77 | addopts = " -v" 78 | -------------------------------------------------------------------------------- /sotastream/filters/filters.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | import logging 4 | 5 | logger = logging.getLogger(f"sotastream") 6 | 7 | 8 | def SkipBlanks(lines, fields=[0, 1]): 9 | """ 10 | Skips lines that are blank in any of the requested fields. 11 | Also zeroes out the third field if present (to reset docid). 12 | This is important for training document models, where a blank field can teach the model 13 | to drop / add sentences. 14 | 15 | :param lines: The data stream 16 | :param fields: fields to check for blankness 17 | """ 18 | skipped_prev = False 19 | for line in lines: 20 | for fieldno in fields: 21 | if fieldno >= len(line) or line[fieldno] is None or line[fieldno] == "": 22 | skipped_prev = True 23 | break 24 | else: 25 | # If we skipped the previous line, we invalidate the current document ID 26 | if skipped_prev and len(fields) >= 3: 27 | fields[2] = 0 28 | skipped_prev = False 29 | 30 | yield line 31 | 32 | 33 | def BitextFilter(lines, end_range=2): 34 | """ 35 | Removes all fields up to end_range. 36 | 37 | :param lines: the stream of input lines 38 | :param end_range: One higher than the last 0-index field number that should be included. 39 | """ 40 | for line in lines: 41 | line.fields = line.fields[0:end_range] 42 | yield line 43 | 44 | 45 | def MatchFilter(lines, pattern=r'[\=\+\#\@\^\~\<\>]', fields=[0, 1], invert=False): 46 | for line in lines: 47 | if len(line) < 2: 48 | logger.debug(f"MatchFilter: bad line: {line}") 49 | continue 50 | 51 | if len(fields) != 2: 52 | raise IndexError("need to specify two field indices for matching") 53 | 54 | f1 = line[fields[0]] 55 | f2 = line[fields[1]] 56 | 57 | criterion = sorted(re.findall(pattern, f1)) == sorted(re.findall(pattern, f2)) 58 | if (not invert and criterion) or (invert and not criterion): 59 | yield line 60 | 61 | 62 | def RegexFilter(lines, pattern, fields=[0, 1], invert=False): 63 | """ 64 | Removes a line if the pattern is found in one or more fields. 
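:param lines: the stream of input lines
:param pattern: the regular expression to search for
:param fields: the field indices to check
:param invert: if True, keep only lines where the pattern is found in every checked field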
65 | """ 66 | regex = re.compile(pattern) 67 | for line in lines: 68 | if len(line) < len(fields): 69 | logger.debug(f"RegexFilter: bad line: {line}") 70 | continue 71 | 72 | founds = [regex.search(line[field]) for field in fields] 73 | if (not invert and not any(founds)) or (invert and all(founds)): 74 | yield line 75 | -------------------------------------------------------------------------------- /test/regression/README.md: -------------------------------------------------------------------------------- 1 | # sotastream: regression tests 2 | 3 | ## Introduction 4 | 5 | The framework has been adopted from [Marian NMT regression 6 | tests](https://github.com/marian-nmt/marian-regression-tests). 7 | A regression test is a bash script with file prefix name `test_*` in `tests` 8 | directory that calls sotastream and produces outputs that is then compared 9 | against the expected output. 10 | 11 | ## Usage 12 | 13 | After setting up sotastream, download the data used in regression tests: 14 | 15 | ```bash 16 | export AZURE_STORAGE_SAS_TOKEN='paste-here-a-valid-sas-token' 17 | bash setup.sh 18 | ``` 19 | 20 | This needs to be done once after new data is added. The SAS token for 21 | `https://romang.blob.core.windows.net/augmentibatch` (Internal Consumption) you 22 | can generate yourself from Azure Portal or ask _rogrundk_. To do this, 23 | 24 | 1. Navigate to portal.azure.com 25 | 2. Find romang's storage account 26 | 3. Click on "Containers" in the lefthand navbar 27 | 4. Click on "augmentibatch" in the file navigator 28 | 5. Click on "Shared access tokens" in the lefthand navbar 29 | 6. Make sure you have "read" and "list" permissions 30 | 7. Click "Generate SAS token and URL" 31 | 32 | Instead, you can also download the `data` folder manually via Microsoft Azure Storage Explorer. 33 | 34 | Then run all tests: 35 | 36 | ```bash 37 | bash run.sh 38 | ``` 39 | 40 | This will display an output similar to this: 41 | 42 | ``` 43 | [02/15/2023 03:04:47] Running on mt-gpu-008 as process 459328 44 | [02/15/2023 03:04:47] Python: python3 45 | [02/15/2023 03:04:47] Augmentibatch dir: /home/rogrundk/train/sotastream 46 | [02/15/2023 03:04:47] Data dir: /home/rogrundk/train/sotastream/test/regression/data 47 | [02/15/2023 03:04:47] Time out: 5m 48 | [02/15/2023 03:04:47] Checking directory: tests 49 | [02/15/2023 03:04:48] Checking directory: tests/basic 50 | [02/15/2023 03:04:48] Running tests/basic/test_default_pipeline.sh ... OK 51 | [02/15/2023 03:04:49] Test took 00:00:1.320s 52 | [02/15/2023 03:04:49] Running tests/basic/test_t1_pipeline.sh ... OK 53 | [02/15/2023 03:04:51] Test took 00:00:1.902s 54 | --------------------- 55 | Ran 2 tests in 00:00:3.493s, 2 passed, 0 skipped, 0 failed 56 | OK 57 | ``` 58 | 59 | See `run.sh` for more exacution examples and command-line arguments. 60 | 61 | By default the script travers all subdirectories of `tests` and runs each bash 62 | script with the file name format of `test_*.sh`. Directories and files with 63 | `_` prefix are ignored. Detailed outputs for each test are stored in 64 | `test_*.sh.log`. 65 | 66 | ## Adding new regression tests 67 | 68 | See existing tests in `tests/basic` for examples. 
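A new test script follows the same pattern as the existing ones: run sotastream with a
fixed `--seed`, take a fixed-size prefix of its output, and `diff` it against a committed
`*.expected` file. Below is a minimal sketch modeled on
`tests/pipelines/test_default_pipeline.sh`; the `my_test` names and the chosen pipeline/data
are placeholders, and the matching `my_test.expected` file must be committed next to the script.

```bash
#!/bin/bash

#####################################################################
# SUMMARY: Sketch of a new regression test
#####################################################################

# Exit on error
set -eo pipefail

# Test command (same --seed/-b/-n settings as the existing tests)
$AB_PYTHON -B -m sotastream --seed 1111 -b 1 -n 1 \
    default $AB_DATA/wmt21.fr-de.tsv.gz \
    | head -n 1000 > my_test.out

# Compare with the expected output
diff my_test.out my_test.expected > my_test.diff

# Exit with success code
exit 0
```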
69 | 70 | -------------------------------------------------------------------------------- /sotastream/utils/phrases.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | class PhraseSpanExtractor: 5 | """Re-implementation of phrase span extraction algorithm from Moses""" 6 | 7 | def __init__(self, srcSpans, trgSpans, alignment, maxLength=7): 8 | self.srcSpans = srcSpans 9 | self.trgSpans = trgSpans 10 | self.alignment = alignment 11 | self.maxLength = maxLength 12 | 13 | self.srcLength = len(srcSpans) 14 | self.trgLength = len(trgSpans) 15 | self.phrases = [] 16 | self.marked = set([q for _, q in alignment]) 17 | 18 | def extract(self, srcStart, srcEnd, trgStart, trgEnd): 19 | if trgEnd == -1: 20 | return [] 21 | for p, q in self.alignment: 22 | if trgStart <= q <= trgEnd and (p < srcStart or p > srcEnd): 23 | return [] 24 | E = [] 25 | ts = trgStart 26 | while True: 27 | te = trgEnd 28 | while True: 29 | if te - ts < self.maxLength: 30 | E.append(((srcStart, srcEnd), (ts, te))) 31 | else: 32 | break 33 | te += 1 34 | if te in self.marked or te >= self.trgLength: 35 | break 36 | ts -= 1 37 | if ts in self.marked or ts < 0: 38 | break 39 | return E 40 | 41 | def computePhraseSpans(self): 42 | for srcStart in range(self.srcLength): 43 | for srcEnd in range(srcStart, self.srcLength): 44 | if srcEnd - srcStart >= self.maxLength: 45 | break 46 | trgStart = self.trgLength - 1 47 | trgEnd = -1 48 | for p, q in self.alignment: 49 | if srcStart <= p <= srcEnd: 50 | trgStart = min(q, trgStart) 51 | trgEnd = max(q, trgEnd) 52 | E = self.extract(srcStart, srcEnd, trgStart, trgEnd) 53 | for p in E: 54 | (sb, se), (tb, te) = p 55 | self.phrases.append( 56 | ( 57 | (self.srcSpans[sb][0], self.srcSpans[se][1]), 58 | (self.trgSpans[tb][0], self.trgSpans[te][1]), 59 | ) 60 | ) 61 | 62 | def samplePhraseSpans(self, k=1): 63 | k = min(k, len(self.phrases)) 64 | if k: 65 | return random.choices( 66 | self.phrases, weights=[2 / (s[1] - s[0] + t[1] - t[0] + 2) for s, t in self.phrases], k=k 67 | ) 68 | else: 69 | return [] 70 | 71 | def getPhraseSpans(self): 72 | return self.phrases 73 | -------------------------------------------------------------------------------- /sotastream/pipelines/multistream_pipeline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import logging 4 | from pathlib import Path 5 | from typing import List, Tuple 6 | 7 | from sotastream.augmentors import DataSource, Mixer, UTF8File 8 | from sotastream.pipelines import Pipeline, pipeline 9 | 10 | logger = logging.getLogger(f"sotastream") 11 | 12 | 13 | @pipeline("multistream") 14 | class MultiStreamPipeline(Pipeline): 15 | """Pipeline for mixing multiple (or variable) number of datasources. 16 | 17 | This pipeline takes one more more data paths and mixes them together as given by --mix-weights parameter (default: equal ratios i.e. balance the sources). 18 | Example usecase: classification task, where each data stream is per class (default mix ratio is to balance classes) 19 | """ 20 | 21 | def __init__(self, paths: List[Path], ext: str, mix_weights: List = None, **kwargs): 22 | """Pipeline for mixing variable number of data sources. 23 | 24 | :param paths: paths of data files to mix. 25 | :param ext: extension of chunked files inside data files specified in paths 26 | :param mix_weights: weights of data files in mixing. Should be one weight per input path. 
If None, all data files are mixed with equal weights. 27 | :param **kwargs: see Pipeline class for more arguments 28 | """ 29 | if mix_weights: 30 | if len(mix_weights) != len(paths): 31 | raise ValueError( 32 | f'--mix-weights should have one weight per data source; Given {len(paths)} data sources but {len(mix_weights)} weight(s).' 33 | ) 34 | else: 35 | mix_weights = [1.0] * len(paths) 36 | # data_sources has paths as a nested list, so we remove it and pass paths list itself 37 | kwargs.pop('data_sources', None) 38 | super().__init__(mix_weights=mix_weights, data_sources=paths, **kwargs) 39 | 40 | assert paths 41 | assert len(paths) == len(self.mix_weights) 42 | assert abs(1 - sum(self.mix_weights)) <= 1e-6, f'{self.mix_weights} = {sum(self.mix_weights)} != 1.0' 43 | 44 | TsvChunkReader = functools.partial(DataSource, ext=ext, buffer_size=self.buffer_size, seed=self.seed) 45 | logger.info('Mixing data from paths:\n * ' + '\n * '.join([str(path) for path in paths])) 46 | streams = [TsvChunkReader(path, processChunk=UTF8File) for path in paths] 47 | if len(paths) == 1: 48 | pipeline = streams[0] 49 | else: 50 | pipeline = Mixer(streams, self.mix_weights) 51 | self.stream = pipeline 52 | 53 | @classmethod 54 | def get_data_sources_for_argparse(cls) -> List[Tuple]: 55 | return [ 56 | ( 57 | 'paths', 58 | '''Dataset paths (i.e. sub datasets) to mix. Mixture weights can be specified with --mix-weights, 59 | one per path and in the same order as paths (Default: equal ratios). 60 | Each path should be a directory with chunked files ending with suffix given by --ext argument.''', 61 | '+', 62 | ), 63 | ] 64 | 65 | @classmethod 66 | def get_data_sources_default_weights(cls): 67 | # we dont know how many sources will be provided until runtime CLI parsing 68 | return ['+'] 69 | 70 | @classmethod 71 | def add_cli_args(cls, parser: argparse.ArgumentParser): 72 | super().add_cli_args(parser) 73 | parser.add_argument( 74 | '--ext', 75 | '-e', 76 | type=str, 77 | default='.tsv', 78 | help='Extensions of chunked files inside data directories.\n Default: .tsv. ' 79 | 'For gzip compressed files set .gz', 80 | ) 81 | -------------------------------------------------------------------------------- /test/regression/test_mtdata.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import subprocess 3 | import logging as log 4 | import pytest 5 | 6 | log.getLogger(name=__name__).setLevel(log.INFO) 7 | 8 | 9 | EXPECTED = """\ 10 | ಮತ್ತು ಜನಸಂಖ್ಯಾ ಸ್ಥಿರತಾ ಕೋಶ (ಜೆ.ಎಸ್.ಕೆ.)ಯನ್ನು ಮುಚ್ಚುವ ಮತ್ತು ಈ ಕಾರ್ಯಗಳನ್ನು ಆರೋಗ್ಯ ಮತ್ತು ಕುಟುಂಬ ಕಲ್ಯಾಣ ಇಲಾಖೆ (ಡಿಓಎಚ್ಎಫ್.ಡಬ್ಲು) ನಲ್ಲಿ ನಿಯೋಜಿಸಲು ಉದ್ದೇಶಿಸಿರುವ ಪ್ರಸ್ತಾಪಕ್ಕೆ ತನ್ನ ಅನುಮೋದನೆ ನೀಡಿದೆ. The Union Cabinet chaired by Prime Minister Shri Narendra Modi has approved the proposal for closure of Autonomous Bodies, namely, Rashtriya Arogya Nidhi (RAN) and Jansankhya Sthirata Kosh (JSK) and the functions are proposed to be vested in Department of Health & Family Welfare (DoHFW). 11 | ಆರೋಗ್ಯ ಮತ್ತು ಕುಟುಂಬ ಕಲ್ಯಾಣ ಇಲಾಖೆಯ ಅಡಿಯಲ್ಲಿರುವ ಸ್ವಾಯತ್ತ ಕಾಯಗಳ ತರ್ಕಬದ್ಧೀಕರಣವು ಈ ಕಾಯಗಳ ಚಾಲ್ತಿಯಲ್ಲಿರುವ ಅಂಗರಚನೆಯಂತೆ ಅಂತರ ಸಚಿವಾಲಯದ ಸಮಾಲೋಚನೆ ಮತ್ತು ಪರಾಮರ್ಶೆಯನ್ನು ಒಳಗೊಂಡಿರುತ್ತದೆ. The rationalization of Autonomous Bodies under Department of Health & Family Welfare will involve inter-ministerial consultations and review of existing bye laws of these bodies. 12 | Steigt Gold auf 10.000 Dollar? $10,000 Gold? 13 | ಇದರ ಜಾರಿಗೆ ಕಾಲಮಿತಿ ಒಂದು ವರ್ಷವಾಗಿರುತ್ತದೆ. 
The time frame for implementation is one year, 14 | ನಿಯೋಜಿತ ಕೇಂದ್ರ ಸರ್ಕಾರದ ಆಸ್ಪತ್ರೆಗಳಲ್ಲಿ ಬಡ ರೋಗಿಗಳು ಪಡೆಯುವ ಚಿಕಿತ್ಸೆಗೆ ವೈದ್ಯಕೀಯ ನೆರವು ಒದಗಿಸಲು ರಾಷ್ಟ್ರೀಯ ಆರೋಗ್ಯ ನಿಧಿ (ಆರ್.ಎ.ಎನ್) ಯನ್ನು ನೋಂದಾಯಿತ ಸೊಸೈಟಿಯಾಗಿ ಸ್ಥಾಪಿಸಲಾಗಿದೆ. Rashtriya Arogya Nidhi (RAN) was set up as a registered society to provide financial medical assistance to poor patients receiving treatment in designated central government hospitals. 15 | ಪ್ರಕರಣಗಳ ಆಧಾರದ ಮೇಲೆ ನೆರವು ಒದಗಿಸಲು ಅಂಥ ಆಸ್ಪತ್ರೆಗಳ ವೈದ್ಯಕೀಯ ಸೂಪರಿಂಟೆಂಡೆಂಟ್ ಗಳ ಬಳಿ ಮುಂಗಡವನ್ನು ಇಡಲಾಗಿದೆ. An advance is placed with the Medical Superintendents of such hospitals who then provide assistance on a case to case basis. 16 | ಡಿಓಎಚ್.ಎಫ್.ಡಬ್ಲ್ಯು. ಆಸ್ಪತ್ರೆಗಳಿಗೆ ನಿಧಿ ಒದಗಿಸುವುದರಿಂದ, ಸಹಾಯಧನವನ್ನು ಇಲಾಖೆಯಿಂದ ನೇರವಾಗಿ ಆಸ್ಪತ್ರೆಗಳಿಗೇ ನೀಡಲಾಗುತ್ತಿದೆ. Since the DoHFW provides funds to the hospitals, the grants can be given from the Department to the hospital directly. 17 | ಆರ್.ಎ.ಎನ್ ಸೊಸೈಟಿಯ ಆಡಳಿತ ಮಂಡಳಿಗಳು ಸೊಸೈಟಿಗಳ ನೋಂದಣಿ ಕಾಯಿದೆ, 1860 (ಎಸ್.ಆರ್.ಎ) ಯ ನಿಬಂಧನೆಗಳ ಪ್ರಕಾರ ಸ್ವಾಯತ್ತ ಕಾಯವಾಗಿ (ಎಬಿ) ವಿಸರ್ಜಿಸಬಹುದಾಗಿರುತ್ತದೆ. Managing Committee of RAN Society will meet to dissolve the Autonomous Body (AB) as per provisions of Societies Registration Act, 1860 (SRA). 18 | SAN FRANCISCO – Es war noch nie leicht, ein rationales Gespräch über den Wert von Gold zu führen. SAN FRANCISCO – It has never been easy to have a rational conversation about the value of gold. 19 | ಇದರ ಜೊತೆಗೆ ಆರೋಗ್ಯ ಸಚಿವಾಲಯದ ಕ್ಯಾನ್ಸರ್ ರೋಗಿಗಳ ನಿಧಿ (ಎಚ್.ಎಂ.ಸಿ.ಪಿ.ಎಫ್.)ಯನ್ನು ಕೂಡ ಇಲಾಖೆಗೆ ವರ್ಗಾಯಿಸಲಾಗುತ್ತದೆ. In addition to this, Health Minister’s Cancer Patient Fund (HMCPF) shall also be transferred to the Department.""" 20 | 21 | 22 | def test_mtdata(): 23 | """ 24 | Test MTData pipeline. 25 | Starts a subprocess and reads its output. 26 | """ 27 | 28 | try: 29 | from mtdata.data import INDEX 30 | 31 | print("Import successful") 32 | except ImportError: 33 | pytest.skip("mtdata is unavailable") 34 | 35 | base_cmd = f'{sys.executable} -m sotastream -n 1 -q 1000 -b 1000 --seed 43' 36 | cmd = f'{base_cmd} mtdata -lp mul-eng Statmt-news_commentary-16-deu-eng Statmt-pmindia-1-eng-kan --mix-weights 1 2' 37 | log.info(f'Running command: {cmd}') 38 | proc = subprocess.Popen( 39 | cmd, 40 | shell=True, 41 | stdout=subprocess.PIPE, 42 | stderr=sys.stderr, 43 | stdin=subprocess.DEVNULL, 44 | text=True, 45 | bufsize=1, 46 | ) 47 | try: 48 | expected = EXPECTED.split('\n') 49 | recieved = [] 50 | i = 0 51 | for line in proc.stdout: 52 | recieved.append(line.rstrip('\n')) 53 | i += 1 54 | if i >= len(expected): 55 | break 56 | assert recieved == expected 57 | finally: 58 | proc.terminate() 59 | -------------------------------------------------------------------------------- /sotastream/data.py: -------------------------------------------------------------------------------- 1 | from . import Defaults 2 | 3 | from typing import List, Optional 4 | 5 | 6 | class Line: 7 | """ 8 | A Line object represents a line containined fields. The string representation 9 | is typically delimited by tabs, and internally we use fields. The fields can 10 | represent any parallel corpus. Typically, they are source, target, and metadata. 11 | """ 12 | 13 | # Define slots for efficiency. This avoids the use of a dict 14 | # for each instance, which is a big memory savings. 15 | # https://docs.python.org/3/reference/datamodel.html#slots 16 | # https://stackoverflow.com/questions/472000/usage-of-slots 17 | __slots__ = "fields" 18 | 19 | def __init__(self, rawLine=None, fields=[]) -> None: 20 | """ 21 | Initializes a new Line object from rawLine or fields. 
22 | Preference is give to rawLine (if non-None). 23 | 24 | Example usage: 25 | 26 | line = Line("Das ist ein Test.\tThis is a test.") 27 | line[1] 28 | -> 'This is a test.' 29 | 30 | If rawLine is not defined, fields will be used. 31 | 32 | :param rawLine: The raw input line, tab-delimited. 33 | :param fields: A list of fields directly. 34 | """ 35 | if rawLine is not None: 36 | self.fields = [field.rstrip("\r\n ") for field in rawLine.split("\t")] 37 | elif fields is None: 38 | self.fields = [] 39 | else: 40 | self.fields = [field for field in fields] 41 | 42 | def __str__(self): 43 | """ 44 | Only join the first and second fields. 45 | This assumes a canonical output format of "{source}\t{target}". 46 | If you want to print metadata or have other semantics for these 47 | fields, you'll have to roll it yourself. 48 | """ 49 | return "\t".join(self.fields) 50 | 51 | def __len__(self): 52 | """The length is the number of non-None fields.""" 53 | return len(self.fields) 54 | 55 | def __getitem__(self, i): 56 | """Return the ith field.""" 57 | if isinstance(i, tuple): 58 | return self.fields[i[0] : i[1] : i[2]] 59 | return self.fields[i] 60 | 61 | def __setitem__(self, i, value): 62 | """Set the ith field.""" 63 | while i >= len(self.fields): 64 | self.fields.append("") 65 | self.fields[i] = value 66 | 67 | def __eq__(self, other): 68 | return isinstance(other, Line) and self.fields == other.fields 69 | 70 | def __hash__(self): 71 | """Makes the object hashable.""" 72 | return hash(tuple(self.fields)) 73 | 74 | def __copy__(self): 75 | return Line(fields=self.fields) 76 | 77 | @staticmethod 78 | def join(lines: List["Line"], separator=Defaults.DOC_SEPARATOR, end_range=2): 79 | """ 80 | Joins columns of lines together using the specified separator. 81 | Quits at column end_range - 1. 82 | 83 | Example input: join([Line("a\tb\t1"), Line("d\te\t1")], separator="|", end_range=2) 84 | Example output: Line("a|d\tb|e") 85 | 86 | :param lines: the list of Line objects to join. 87 | :param separator: the separator to use. 88 | :param end_range: the column to stop at. 89 | 90 | :return: a new Line object. 91 | """ 92 | fields = [] 93 | for i in range(end_range): 94 | fields.append(separator.join([line[i] for line in lines])) 95 | 96 | return Line(fields=fields) 97 | 98 | def append(self, other: "Line", fields: Optional[List[int]] = None, separator=Defaults.SEPARATOR): 99 | """ 100 | Append field-wise, on the specified fields. 101 | If the current Line has fewer fields than the Line being appended, 102 | it is padded to match. 103 | 104 | :param other: the Line object to append. 105 | :param fields: the list of fields to append (None means all fields). 106 | """ 107 | if fields is None: 108 | fields = list(range(len(other))) 109 | 110 | for i in fields: 111 | # skip non-existent fields (protects against caller listing too many fields) 112 | if i >= len(other): 113 | break 114 | while i >= len(self): 115 | self.fields.append("") 116 | 117 | if self[i] == "": 118 | self[i] = other[i] 119 | else: 120 | self[i] += separator + other[i] 121 | -------------------------------------------------------------------------------- /docs/introduction.rst: -------------------------------------------------------------------------------- 1 | 2 | Sotastream 3 | ========== 4 | 5 | Introduction 6 | ------------ 7 | 8 | Sotastream is a tool for data augmentation for training pipeline. 
It 9 | uses `infinibatch `_ internally to generate an infinite stream of 10 | shuffled training data and provides a means for on-the-fly data 11 | manipulation, augmentation, mixing, and sampling. 12 | 13 | 14 | 15 | 16 | Setup 17 | ----- 18 | 19 | To install from PyPI (https://pypi.org/project/sotastream/) 20 | 21 | 22 | .. code:: bash 23 | 24 | pip install sotastream 25 | 26 | 27 | *Developer Setup:* 28 | 29 | .. code:: bash 30 | 31 | # To begin, clone the repository: 32 | git clone https://github.com/marian-nmt/sotastream 33 | cd sotastream 34 | # option 1: 35 | python -m pip install . 36 | # option 2: install in --editable mode 37 | python -m pip install -e . 38 | 39 | 40 | *Entry points* 41 | * As a module: `python -m sotastream` 42 | * As a bin in your $PATH: `sotastream` 43 | 44 | 45 | Development 46 | ----------- 47 | 48 | Install development tools 49 | 50 | .. code:: bash 51 | 52 | python -m pip install -e .[dev,test] # editable mode 53 | 54 | Editable mode (``-e / --editable``) is recommended for development 55 | purposes, ``pip`` creates symbolic link to your source code in a way 56 | that any edits made are reflected directly to the installed package. 57 | ``[dev,test]`` installs depencies for development and tests which 58 | includes ``black``, ``pytest`` etc. 59 | 60 | We use ``black`` to reformat code to a common code style. 61 | 62 | .. code:: bash 63 | 64 | make reformat 65 | 66 | Before creating any pull requests, run 67 | 68 | .. code:: bash 69 | 70 | make check # runs reformatter and tests 71 | 72 | Running tests 73 | ------------- 74 | 75 | .. code:: bash 76 | 77 | make test # run unit tests 78 | make regression # run regression tests 79 | 80 | See ``Makefile`` for more details. 81 | 82 | Usage examples 83 | -------------- 84 | 85 | A folder like ``split/parallel`` contains training data in tsv format 86 | (``srctgt``) split into ``*.gz`` files of around 100,000 lines for 87 | better shuffling. The below will output an infinite stream of data 88 | generated from the gzipped files in these folders, according to the 89 | “wmt” recipe found in ``sotastream/pipelines/example_pipeline.py``. 90 | 91 | :: 92 | 93 | python -m sotastream example split/parallel split/backtrans 94 | 95 | You can also provide compressed TSV files directly, in which case 96 | sotastream will split them to checksummed folders under 97 | ``/tmp/sotastream/{checksum}``: 98 | 99 | :: 100 | 101 | python -m sotastream example parallel.tsv.gz backtrans.tsv.gz 102 | 103 | There are currently two main pipelines: “default”, and “wmt”. 104 | These vary according to the data sources they take as well as the other options 105 | available to them. 106 | 107 | There are global options that control behavioral aspects such as 108 | splitting and parallelization, and also pipeline-specific arguments. You 109 | can see these by running 110 | 111 | :: 112 | 113 | # see global options 114 | python -m sotastream -h 115 | 116 | # see default pipeline options 117 | python -m sotastream default -h 118 | 119 | # see wmt pipeline options 120 | python -m sotastream wmt -h 121 | 122 | Don't cross the streams! 123 | ------------------------ 124 | 125 | Sotastream workflows build a directed acyclic graph (DAG) consisting of 126 | cascades of generators that pass through mutable lines from the graph 127 | inputs to the pipeline output. 
Since each step provides transformations 128 | and manipulations of each input line, the only requirement is that 129 | modifications along separate branches must not be merged into a single 130 | node in the graph, or at least, that great care should be taken when 131 | doing so. An example is the Mixer, which does not actually merge 132 | modifications from alternate branches, but instead selects across 133 | multiple incoming branches using a provided probability distribution. 134 | 135 | Custom/private pipelines from own (private) directory 136 | ===================================================== 137 | 138 | You can create a custom pipeline by adding a file in the current 139 | (invocation) directory with a file name matching the pattern 140 | "*_pipeline.py". This should follow the interface defined in 141 | ``sotastream/pipelines``, namely: 142 | 143 | - Call ``@pipeline("name")`` to give your pipeline a name. This name 144 | must not conflict with existing names. 145 | - Inherit from ``Pipeline`` base class from ``sotastream.pipeline``. 146 | For document pipelines, use ``DocumentPipeline`` as base class. 147 | 148 | You can find some examples in ``test/dummy_pipeline.py``, as well as the 149 | real examples in ``sotastream/pipelines``. 150 | -------------------------------------------------------------------------------- /test/regression/test_multistream.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pytest 3 | from collections import defaultdict 4 | import subprocess 5 | import sys 6 | import logging as log 7 | from typing import List 8 | 9 | log.basicConfig(level=log.INFO) 10 | data_dir = Path(__file__).parent / "data/trec" 11 | 12 | EXPECTED = """LOC country What European country is home to the beer-producing city of Budweis ? 13 | HUM ind Who was the last woman executed in England ? 14 | ABBR exp What does the word LASER mean ? 15 | HUM ind What player squats an average of 3 times during a baseball doubleheader ? 16 | LOC other Where can I find a world atlas map online at no charge ? 17 | NUM date When did French revolutionaries storm the Bastille ? 18 | LOC other On what avenue is the original Saks department store located ? 19 | ENTY color What color is the cross on Switzerland 's flag ? 20 | ABBR exp What does SHIELD stand for ? 21 | LOC country What country has the best defensive position in the board game Diplomacy ? 22 | ENTY other What five cards make up a perfect Cribbage hand ? 23 | NUM count How many people die from snakebite poisoning in the U.S. per year ? 24 | ABBR exp What does pH stand for ? 25 | ENTY body Which leg does a cat move with its left front leg when walking - its left rear or right rear leg ? 26 | DESC def What is a disaccharide ? 27 | ENTY symbol What sign is The Water Carrier the zodiacal symbol for ? 28 | HUM ind What well-known actor is the father of star Alan Alda ? 29 | NUM count How many airline schools are there in the U.S. ? 30 | LOC city What is the capital of Burkina Faso ? 31 | HUM ind Who were the four famous founders of United Artists ?""" 32 | 33 | 34 | def get_data_paths() -> List[str]: 35 | flag = data_dir / '._OK' 36 | if not flag.exists(): 37 | url = "https://cogcomp.seas.upenn.edu/Data/QA/QC/train_5500.label" 38 | try: 39 | import requests 40 | except ImportError: 41 | pytest.skip("requests is unavailable. 
`pip install requests` to enable this test.") 42 | 43 | lines = requests.get(url).text.splitlines(keepends=False) 44 | data_parsed = defaultdict(list) 45 | for line in lines: 46 | label, text = line.split(' ', maxsplit=1) 47 | label, sublabel = label.split(':') 48 | data_parsed[label].append(f'{label}\t{sublabel}\t{text}') 49 | for label, texts in data_parsed.items(): 50 | label_dir = data_dir / label 51 | label_dir.mkdir(parents=True, exist_ok=True) 52 | with open(label_dir / 'part00.tsv', 'wt') as outfh: 53 | for text in texts: 54 | print(text, file=outfh) 55 | flag.touch() 56 | 57 | sub_dirs = set() 58 | for file in data_dir.glob('*/*.tsv'): 59 | sub_dirs.add(str(file.parent)) 60 | 61 | return list(sorted(sub_dirs)) 62 | 63 | 64 | def test_multistream(): 65 | """ 66 | Test varargs pipeline. 67 | Starts a subprocess and reads its output. 68 | """ 69 | paths = get_data_paths() 70 | assert len(paths) > 1 71 | base_cmd = f'{sys.executable} -m sotastream -n 1 -q 1000 -b 1000 --seed 43' 72 | cmd = f'{base_cmd} multistream {" ".join(paths)}' 73 | log.info(f'Running command: {cmd}') 74 | proc = subprocess.Popen( 75 | cmd, 76 | shell=True, 77 | stdout=subprocess.PIPE, 78 | stderr=sys.stderr, 79 | stdin=subprocess.DEVNULL, 80 | text=True, 81 | bufsize=1, 82 | ) 83 | 84 | try: 85 | expected = EXPECTED.split('\n') 86 | recieved = [] 87 | # collect 50_000 lines and compute stats 88 | stats = defaultdict(int) 89 | max_stats = 50_000 90 | i = 0 91 | for line in proc.stdout: 92 | i += 1 93 | label = line.split('\t')[0] 94 | if i <= len(expected): 95 | recieved.append(line.rstrip('\n')) 96 | if i == len(expected): 97 | if recieved != expected: 98 | expected = "\n".join(expected) 99 | recieved = "\n".join(recieved) 100 | log.error(f'##Expected:\n{expected}\n\n##Got:\n{recieved}') 101 | pytest.fail('Expected and recieved lines do not match.') 102 | stats[label] += 1 103 | if i > max_stats: 104 | break 105 | # all paths are used 106 | assert len(stats) == len(paths), f'Expected {len(paths)} paths to be used, got {len(stats)}.' 107 | majority, minority = max(stats.values()), min(stats.values()) 108 | assert majority > 1 109 | # all paths are used ~equally; allow 5% difference between majority and minority 110 | assert ( 111 | abs(majority - minority) < 0.05 * majority 112 | ), f'Paths are not used equally. majority: {majority}, minority: {minority}. Diff: {100 * (majority-minority)/majority:.2f}% difference.' 113 | finally: 114 | proc.terminate() 115 | -------------------------------------------------------------------------------- /sotastream/utils/split.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import datetime 4 | import gzip 5 | import hashlib 6 | import logging 7 | import os 8 | import shutil 9 | import subprocess 10 | import time 11 | 12 | from pathlib import Path 13 | from typing import Type 14 | 15 | from sotastream.pipelines import PIPELINES 16 | 17 | 18 | logger = logging.getLogger(f"sotastream") 19 | 20 | 21 | # The block size to use when compute MD5 hashes 22 | MD5_BLOCK_SIZE = 8192 23 | 24 | 25 | def split_file_into_chunks( 26 | filepath: str, 27 | tmpdir: str = "/tmp/sotastream", 28 | split_size: int = 10000, 29 | native: bool = False, 30 | overwrite: bool = False, 31 | ) -> Path: 32 | """ 33 | Splits a file into compressed chunks under a directory. 34 | The location will be in a directory named by the file's checksum, within the 35 | provided temporary directory. Results are cached, providing for quick restarting. 
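For example, with the default tmpdir an input file whose MD5 checksum is <md5sum> is split
into /tmp/sotastream/<md5sum>/part.00000.gz, part.00001.gz, and so on; a ".done" marker
file is written there once splitting completes.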
36 | 37 | :param filepath: The input file path 38 | :param tmpdir: The top-level temporary directory to write to 39 | :param split_size: The size of each chunk in lines 40 | :param native: If True, use Python to split, instead of a subshell 41 | :return: The directory where the chunks are stored, as a Path object 42 | """ 43 | start_time = time.perf_counter() 44 | 45 | split_func = split_native if native else split_subshell 46 | 47 | # Compute the checksum 48 | md5sum = compute_md5(filepath) 49 | logger.info(f"md5sum({filepath}) = {md5sum} computed in {time.perf_counter() - start_time:.1f}s") 50 | 51 | # Check if we already have the file split 52 | destdir = Path(tmpdir) / md5sum 53 | donefile = destdir / ".done" 54 | if destdir.exists() and overwrite: 55 | logger.info(f"Removing existing split directory {destdir}") 56 | shutil.rmtree(destdir) 57 | elif donefile.exists(): 58 | logger.info(f"Using cached splitting of {filepath} (checksum: {md5sum})") 59 | return destdir 60 | 61 | # If not, split the file 62 | logger.info(f"Splitting file {filepath} to {tmpdir}...") 63 | destdir.mkdir(parents=True, exist_ok=True) 64 | start_time = time.perf_counter() 65 | split_func(filepath, destdir, split_size) 66 | logger.info(f"File {filepath} splitting took {time.perf_counter() - start_time:.1f}s") 67 | 68 | with open(donefile, "w") as outfh: 69 | print(f"{filepath} finished splitting {datetime.datetime.now()}", file=outfh) 70 | 71 | return destdir 72 | 73 | 74 | def split_native(filepath: str, destdir: Path, split_size: int): 75 | """ 76 | Split directly in Python by reading the file. 77 | This version is slower than the subshell version. 78 | 79 | :param filepath: The input file path 80 | :param destdir: The output directory 81 | :param split_size: The size of each chunk in lines 82 | """ 83 | 84 | def get_chunkpath(index=0): 85 | outfh = smart_open(destdir / f"part.{index:05d}.gz", "wt") 86 | index += 1 87 | return index, outfh 88 | 89 | with smart_open(filepath) as infh: 90 | chunkno, outfh = get_chunkpath() 91 | logger.info(f"Splitting {filepath} to {destdir}") 92 | for lineno, line in enumerate(infh, 1): 93 | line = line.rstrip("\r\n") 94 | if lineno % split_size == 0: 95 | if outfh is not None: 96 | outfh.close() 97 | chunkno, outfh = get_chunkpath(chunkno) 98 | print(line, file=outfh) 99 | outfh.close() 100 | 101 | 102 | def split_subshell(filepath: str, destdir: Path, split_size: int): 103 | """ 104 | Split using a subshell (~8x faster). 105 | 106 | :param filepath: The input file path 107 | :param destdir: The output directory 108 | :param split_size: The size of each chunk in lines 109 | """ 110 | cmd = f"pigz -cd {filepath} | sed 's/\r//g' | split -d -a5 -l {split_size} --filter 'pigz > $FILE.gz' - {destdir}/part." 111 | logger.info(cmd) 112 | subprocess.run(cmd, shell=True, check=True) 113 | 114 | 115 | def smart_open(filepath: str, mode: str = "rt", encoding: str = "utf-8"): 116 | """Convenience function for reading and writing compressed or plain text files. 117 | 118 | :param filepath: The file to read. 119 | :param mode: The file mode (read, write). 120 | :param encoding: The file encoding. 121 | :return: a file handle. 122 | """ 123 | if Path(filepath).suffix == ".gz": 124 | return gzip.open(filepath, mode=mode, encoding=encoding, newline="\n") 125 | return open(filepath, mode=mode, encoding=encoding, newline="\n") 126 | 127 | 128 | def compute_md5(filepath: str): 129 | """Computes an MD5 checksum over a file. 130 | Note that binary reading in this way is as fast as a subshell call. 
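    A small usage sketch (the path is hypothetical)::

        digest = compute_md5("corpus.en-de.tsv.gz")  # 32-character hex string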
131 | 132 | :param filepath: The file path as as string 133 | :return: The checksum as a hexdigest. 134 | """ 135 | with open(filepath, "rb") as f: 136 | m = hashlib.md5() 137 | while chunk := f.read(MD5_BLOCK_SIZE): 138 | m.update(chunk) 139 | return m.hexdigest() 140 | 141 | 142 | if __name__ == "__main__": 143 | import argparse 144 | 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument( 147 | "infile", 148 | type=str, 149 | help="Path to TSV input file containing source, target, and (optionally) docid fields", 150 | ) 151 | parser.add_argument("--numlines", "-l", type=int, default=10000) 152 | parser.add_argument("--prefix-dir", "-p", default="/tmp/sotastream") 153 | args = parser.parse_args() 154 | 155 | logger.basicConfig(level=logging.INFO) 156 | 157 | split_file_into_chunks(args.infile, tmpdir=args.prefix_dir, split_size=args.numlines) 158 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sotastream 2 | [![image](http://img.shields.io/pypi/v/sotastream.svg)](https://pypi.python.org/pypi/sotastream/) 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE) 4 | [![Read the Docs](https://img.shields.io/readthedocs/sotastream.svg)](https://sotastream.readthedocs.io/) 5 | 6 | 7 | Sotastream is a tool for data augmentation for training 8 | pipeline. It uses `infinibatch` internally to generate an infinite 9 | stream of shuffled training data and provides a means for on-the-fly 10 | data manipulation, augmentation, mixing, and sampling. 11 | 12 | 13 | ## Setup 14 | 15 | To install from PyPI (https://pypi.org/project/sotastream/) 16 | ```bash 17 | pip install sotastream 18 | ``` 19 | 20 | *Developer Setup:* 21 | 22 | ```bash 23 | # To begin, clone the repository: 24 | git clone https://github.com/marian-nmt/sotastream 25 | cd sotastream 26 | # option 1: 27 | python -m pip install . 28 | # option 2: install in --editable mode 29 | python -m pip install -e . 30 | ``` 31 | 32 | *Entry points* 33 | * As a module: `python -m sotastream` 34 | * As a bin in your $PATH: `sotastream` 35 | 36 | ## Development 37 | 38 | Install development tools 39 | ```bash 40 | python -m pip install -e .[dev,test] # editable mode 41 | ``` 42 | Editable mode (`-e / --editable`) is recommended for development purposes, `pip` creates symbolic link to your source code in a way that any edits made are reflected directly to the installed package. `[dev,test]` installs depencies for development and tests which includes `black`, `pytest` etc. 43 | 44 | We use `black` to reformat code to a common code style. 45 | ```bash 46 | make reformat 47 | ``` 48 | 49 | Before creating any pull requests, run 50 | ```bash 51 | make check # runs reformatter and tests 52 | ``` 53 | 54 | ## Running tests 55 | 56 | ```bash 57 | make test # run unit tests 58 | make regression # run regression tests 59 | ``` 60 | 61 | See `Makefile` for more details. 62 | 63 | 64 | ## Usage examples 65 | 66 | A folder like `split/parallel` contains training data in tsv format (`srctgt`) split into 67 | `*.gz` files of around 100,000 lines for better shuffling. The below will output an infinite 68 | stream of data generated from the gzipped files in these folders, according to the "wmt" recipe 69 | found in `sotastream/pipelines/example_pipeline.py`. 
70 | 71 | ``` 72 | python -m sotastream example split/parallel split/backtrans 73 | ``` 74 | You can also provide compressed TSV files directly, in which case sotastream will split them 75 | to checksummed folders under `/tmp/sotastream/{checksum}`: 76 | 77 | ``` 78 | python -m sotastream example parallel.tsv.gz backtrans.tsv.gz 79 | ``` 80 | 81 | There are currently two main pipelines: "default", and "wmt". These vary according to 82 | the data sources they take as well as the other options available to them. 83 | 84 | There are global options that control behavioral aspects such as splitting and parallelization, 85 | and also pipeline-specific arguments. You can see these by running 86 | 87 | ``` 88 | # see global options 89 | python -m sotastream -h 90 | 91 | # see default pipeline options 92 | python -m sotastream default -h 93 | 94 | # see wmt pipeline options 95 | python -m sotastream wmt -h 96 | ``` 97 | 98 | ## Don't cross the streams! 99 | 100 | Sotastream workflows build a directed acyclic graph (DAG) 101 | consisting of cascades of generators that pass through mutable lines 102 | from the graph inputs to the pipeline output. Since each step provides 103 | transformations and manipulations of each input line, the only 104 | requirement is that modifications along separate branches must not be 105 | merged into a single node in the graph, or at least, that great care 106 | should be taken when doing so. An example is the Mixer, which 107 | does not actually merge modifications from alternate branches, but instead 108 | selects across multiple incoming branches using a provided probability 109 | distribution. 110 | 111 | # Custom/private pipelines from own (private) directory 112 | 113 | You can create a custom pipeline by adding a file in the current (invocation) 114 | directory with a file name matching the pattern "*_pipeline.py". This should 115 | follow the interface defined in `sotastream/pipelines`, namely: 116 | 117 | * Call `@pipeline("name")` to give your pipeline a name. This name must not conflict with existing names. 118 | * Inherit from `Pipeline` base class from `sotastream.pipeline`. For document pipelines, use `DocumentPipeline` as base class. 119 | 120 | You can find some examples in `test/dummy_pipeline.py`, as well as the real examples in `sotastream/pipelines`. 121 | 122 | # Authors 123 | 124 | Sotastream is developed by _TextMT Team_ @ Microsoft Translator. 
125 | 126 | If you use this tool, please cite: 127 | Paper link: https://arxiv.org/abs/2308.07489 | https://aclanthology.org/2023.nlposs-1.13/ 128 | 129 | 130 | ```bibtex 131 | @inproceedings{post-etal-2023-sotastream, 132 | title = "{SOTASTREAM}: A Streaming Approach to Machine Translation Training", 133 | author = "Post, Matt and 134 | Gowda, Thamme and 135 | Grundkiewicz, Roman and 136 | Khayrallah, Huda and 137 | Jain, Rohit and 138 | Junczys-Dowmunt, Marcin", 139 | editor = "Tan, Liling and 140 | Milajevs, Dmitrijs and 141 | Chauhan, Geeticka and 142 | Gwinnup, Jeremy and 143 | Rippeth, Elijah", 144 | booktitle = "Proceedings of the 3rd Workshop for Natural Language Processing Open Source Software (NLP-OSS 2023)", 145 | month = dec, 146 | year = "2023", 147 | address = "Singapore, Singapore", 148 | publisher = "Empirical Methods in Natural Language Processing", 149 | url = "https://aclanthology.org/2023.nlposs-1.13", 150 | pages = "110--119", 151 | } 152 | ``` 153 | -------------------------------------------------------------------------------- /sotastream/pipelines/mtdata_pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Tuple, Iterator, Union, List, Optional 3 | import random 4 | 5 | from sotastream.data import Line 6 | from sotastream.augmentors import Mixer 7 | from sotastream.filters import BitextFilter 8 | from sotastream.pipelines import Pipeline, pipeline 9 | 10 | logger = logging.getLogger(f"sotastream") 11 | 12 | 13 | @pipeline("mtdata") 14 | class MTDataPipeline(Pipeline): 15 | """Pipeline to mix datasets from mtdata. 16 | 17 | To install mtdata, run `pip install mtdata`, or visit https://github.com/thammegowda/mtdata 18 | To see the list of available datasets, run `mtdata list -id -l -` where - 19 | are language pairs. 20 | 21 | Example #1: 22 | sotastream mtdata -lp en-de Statmt-news_commentary-16-deu-eng Statmt-europarl-10-deu-eng 23 | 24 | Example #2: 25 | sotastream mtdata -lp en-de Statmt-news_commentary-16-deu-eng Statmt-europarl-10-deu-eng --mix-weights 1 2 26 | 27 | Example #3: 28 | sotastream mtdata -lp en-de Statmt-news_commentary-16-deu-eng,Statmt-europarl-10-deu-eng 29 | 30 | Example #1 mixes two datasets with equal weights (i.e., 1:1). 31 | Example #2 mixes two datasets with 1:2 ratio respectively. 32 | Example #3 simply concatenates both datasets separated by comma into a single dataset. 33 | Therefore, the resulting mixture weights are proportional to the number of segments in each dataset. 34 | 35 | The `--langs|-lp -` argument is used to enforce compatibility between the specified datasets and ensure correct ordering of source and target languages 36 | """ 37 | 38 | def __init__( 39 | self, 40 | data_ids: List[str], 41 | mix_weights: Optional[List[float]] = None, 42 | langs: Tuple[str, str] = None, 43 | **kwargs, 44 | ): 45 | """Initialize mtdata pipeline. 46 | 47 | :param data_ids: List of mtdata IDs 48 | :param mix_weights: Mixture weights, defaults to None (i.e., equal weights) 49 | :param langs: Tuple of source and target language codes to enforce compatibility with specified dataset ids, 50 | defaults to None (not enforced) 51 | """ 52 | if not mix_weights: 53 | mix_weights = [1.0] * len(data_ids) 54 | kwargs.pop('data_sources', None) 55 | super().__init__(mix_weights=mix_weights, data_sources=data_ids, **kwargs) 56 | assert len(data_ids) == len( 57 | self.mix_weights 58 | ), f'Expected {len(mix_weights)} weights, got {len(data_ids)}. 
See --mix-weights argument' 59 | 60 | random.seed(self.seed) 61 | if self.num_workers > 1: 62 | logger.warning(f'num_workers > 1 is not supported for MTData pipeline.') 63 | 64 | data_sources = [] 65 | for data_id in data_ids: 66 | dids = data_id.split(',') # allow comma-separated list of dataset IDs 67 | data_sources.append(MTDataSource(dids, langs=langs)) 68 | 69 | if len(data_sources) > 1: 70 | stream = Mixer(data_sources, self.mix_weights) 71 | else: 72 | stream = data_sources[0] 73 | self.stream = BitextFilter(stream) # removes all but fields 0 and 1 74 | 75 | @classmethod 76 | def get_data_sources_for_argparse(cls): 77 | help_msg = '''MTData dataset IDs which are of format Group-name-version-lang1-lang2 78 | E.g. "Statmt-news_commentary-16-deu-eng" 79 | Run "mtdata list -id -l -" to list all available dataset IDs for any - language pair. 80 | ''' 81 | return [('data_ids', help_msg, '+')] 82 | 83 | @classmethod 84 | def get_data_sources_default_weights(cls): 85 | # we dont know how many sources will be provided until runtime CLI parsing 86 | return ['+'] 87 | 88 | @classmethod 89 | def add_cli_args(cls, parser): 90 | super().add_cli_args(parser) 91 | 92 | def LangPair(txt) -> Tuple[str, str]: 93 | """Parse language pair from CLI argument.""" 94 | pair = txt.split('-') 95 | assert len(pair) == 2, f'Expected 2 languages src-tgt, got {len(pair)}' 96 | return tuple(pair) 97 | 98 | parser.add_argument( 99 | '--langs', 100 | '-lp', 101 | required=True, 102 | metavar='SRC-TGT', 103 | type=LangPair, 104 | help='''Source and language order, e.g. "deu-eng". Ensures the correct order of the fields in the output. 105 | As per mtdata, language code 'mul' is special and meant for multilingual datasets. 106 | E.g. "mul-en" is compatible for x->en datasets, where as "en-mul" is for en->x for any x.''', 107 | ) 108 | 109 | 110 | def MTDataSource( 111 | dids: Union[str, List[str]], 112 | langs=None, 113 | progress_bar=False, 114 | ) -> Iterator[Line]: 115 | """MTData dataset iterator. 116 | 117 | :param dids: either a single dataset ID or a list of dataset ID. 118 | IDs are of form Group-name-version-lang1-lang2 e.g. "Statmt-news_commentary-16-deu-eng" 119 | :param langs: source-target language order, e.g. 
"deu-eng" 120 | :progress_bar: whether to show progress bar 121 | :return: Line objects 122 | """ 123 | from mtdata.data import INDEX, Cache, Parser, DatasetId 124 | from mtdata import cache_dir as CACHE_DIR, pbar_man 125 | from mtdata.iso.bcp47 import bcp47, BCP47Tag 126 | 127 | pbar_man.enabled = bool(progress_bar) 128 | 129 | if langs: # check compatibility 130 | assert len(langs) == 2, f'Expected 2 languages, got {langs}' 131 | langs = (bcp47(langs[0]), bcp47(langs[1])) 132 | 133 | data_spec = [] 134 | for did in dids: 135 | did = DatasetId.parse(did) 136 | assert did in INDEX, f'Unknown dataset ID: {did}' 137 | 138 | is_swap = False 139 | if langs: 140 | is_compat, is_swap = BCP47Tag.check_compat_swap(langs, did.langs) 141 | if not is_compat: 142 | langs_txt = '-'.join(map(str, langs)) 143 | raise ValueError(f'{did} is not compatible with {langs_txt}.') 144 | entry = INDEX[did] 145 | path = Cache(CACHE_DIR).get_entry(entry) 146 | parser = Parser(path, ext=entry.in_ext or None, ent=entry) 147 | data_spec.append([did, parser, is_swap]) 148 | count = 0 149 | delim = '\t' 150 | while True: 151 | for did, parser, is_swap in data_spec: 152 | for rec in parser.read_segs(): 153 | if isinstance(rec, (list, tuple)): 154 | fields = [col.replace(delim, ' ').replace('\n', ' ').strip() for col in rec] 155 | else: 156 | fields = rec.split(delim) 157 | assert len(fields) >= 2, f'Expected 2 fields, got {len(fields)}' 158 | fields = fields[:2] 159 | if is_swap: 160 | fields = [fields[1], fields[0]] 161 | yield Line(fields=fields) 162 | count += 1 163 | -------------------------------------------------------------------------------- /sotastream/augmentors/augmentors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gzip 3 | import string 4 | import random 5 | import logging 6 | from typing import Iterator, Iterable, Callable 7 | from subprocess import Popen, PIPE 8 | 9 | import titlecase 10 | from infinibatch.datasets import chunked_dataset_iterator 11 | 12 | from sotastream.data import Line 13 | from sotastream import Defaults 14 | 15 | 16 | logger = logging.getLogger(f"sotastream") 17 | 18 | 19 | def UTF8File(path: str) -> Iterator[str]: 20 | """ 21 | Opens a file and returns a stream of Line objects. 22 | """ 23 | with open(path, "rb") as f: 24 | data = f.read() 25 | if path.endswith('.gz'): 26 | data = gzip.decompress(data) 27 | 28 | for line in data.decode(encoding='utf-8').splitlines(): 29 | yield Line(line) 30 | 31 | 32 | def enumerate_files(dir: str, ext: str): 33 | return [ 34 | os.path.join(dir, path.name) 35 | for path in os.scandir(dir) 36 | if path.is_file() and (ext is None or path.name.endswith(ext)) 37 | ] 38 | 39 | 40 | def DataSource( 41 | path: str, 42 | processChunk: Callable = UTF8File, 43 | ext: str = ".gz", 44 | buffer_size: int = Defaults.BUFFER_SIZE, 45 | seed: int = 1234, 46 | shuffle: bool = True, 47 | worker_id: int = 0, 48 | num_workers: int = 1, 49 | ): 50 | """ 51 | Creates an infinibatch data source from a directory of files that all 52 | have extension {ext}. 
53 | 54 | :param path: directory containing chunks 55 | :param processChunk: function to call on each chunk 56 | :param ext: the file extension to glob over 57 | :param buffer_size: how many lines infinibatch loads into memory at a time 58 | :param seed: the random seed 59 | :param shuffle: whether to shuffle results across shards 60 | :param worker_id: For multiprocessing, this worker's ID (0-based) 61 | :param num_workers: For multiprocessing, the number of workers 62 | """ 63 | 64 | # This is used to ensure that infinibatch iterators (a) differ on each node 65 | # and (b) see the same data in the same order, when called multiple times. 66 | # However, having multiple workers on a single node breaks this, because they 67 | # write to their shared queue in an unpredictable order. To fix this, we'd have 68 | # to do round-robin on the queue. 69 | if "OMPI_COMM_WORLD_SIZE" in os.environ: 70 | num_instances = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 71 | instance_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 72 | logger.info(f"Opening path {path} on instance {instance_rank} out of {num_instances} instances") 73 | else: 74 | num_instances = 1 75 | instance_rank = 0 76 | logger.info(f"Opening path {path}") 77 | 78 | # Worker ID i will only see every ith chunk 79 | chunk_file_paths = [] 80 | total_chunks = 0 81 | subpaths = enumerate_files(path, ext) 82 | for pathno, subpath in enumerate(subpaths): 83 | total_chunks += 1 84 | if len(subpaths) < num_workers or pathno % num_workers == worker_id: 85 | chunk_file_paths.append(subpath) 86 | chunk_file_paths.sort() # make sure file order is always the same, independent of OS 87 | 88 | logger.info(f"Worker {worker_id} gets {len(chunk_file_paths)} / {total_chunks} segments in path {path}") 89 | ds = chunked_dataset_iterator( 90 | chunk_refs=chunk_file_paths, 91 | read_chunk_fn=processChunk, 92 | shuffle=shuffle, 93 | buffer_size=buffer_size, 94 | seed=seed, 95 | use_windowed=False, 96 | num_instances=num_instances, 97 | instance_rank=instance_rank, 98 | ) 99 | 100 | return ds 101 | 102 | 103 | class Mixer: 104 | def __init__(self, iterators, probs): 105 | self.iterators = iterators 106 | self.probs = probs 107 | 108 | def __iter__(self): 109 | return self 110 | 111 | def __next__(self): 112 | draw = random.uniform(0, 1) 113 | prob_sum = 0 114 | for i, prob in enumerate(self.probs): 115 | prob_sum += prob 116 | if draw <= prob_sum: 117 | return next(self.iterators[i]) 118 | 119 | return next(self.iterators[0]) # default 120 | 121 | 122 | def Identity(lines): 123 | for line in lines: 124 | yield line 125 | 126 | 127 | def Append(lines, functor): 128 | for line in lines: 129 | line[len(line)] = functor(line) 130 | yield line 131 | 132 | 133 | def canBeUppercased(inputString): 134 | """Check if the input string can be plausibly uppercased (is the uppercased version different from the non-uppercased one). 135 | We randomly sample 10 chars (with repetition if needed) which should be good enough. Note, this is rather meant as a quick 136 | way to identify if a script has casing rather than if a particular string in a script with casing can be uppercased. Both 137 | may be caught.""" 138 | if not inputString: 139 | return False 140 | randChars = "".join(random.choices(inputString, k=10)) 141 | return randChars.upper() != randChars 142 | 143 | 144 | def canBeLowercased(inputString): 145 | """Check if the input string can be plausibly lowercased (is the lowercased version different from the non-lowercased one). 
146 | We randomly sample 10 chars (with repetition if needed) which should be good enough. Note, this is rather meant as a quick 147 | way to identify if a script has casing rather than if a particular string in a script with casing can be lowercased. Both 148 | may be caught.""" 149 | if not inputString: 150 | return False 151 | randChars = "".join(random.choices(inputString, k=10)) 152 | return randChars.lower() != randChars 153 | 154 | 155 | def ToUpper(lines, fields=[0, 1], check=None): 156 | """Uppercases all specified fields. If check is set to a field id it conditions the uppercasing 157 | of the entire set on the fact if the checked field can be plausibly uppercased. This is used for 158 | things like Chinese source that has no case and would result in random target casing during inference""" 159 | for line in lines: 160 | if check is None or canBeUppercased(line[check]): 161 | for field in fields: 162 | line[field] = line[field].upper() 163 | yield line 164 | 165 | 166 | def ToLower(lines, fields=[0, 1], check=None): 167 | """Lowercases all specified fields. If check is set to a field id it conditions the lowercasing 168 | of the entire set on the fact if the checked field can be plausibly lowercased.""" 169 | for line in lines: 170 | if check is None or canBeLowercased(line[check]): 171 | for field in fields: 172 | line[field] = line[field].lower() 173 | yield line 174 | 175 | 176 | def ToTitle(lines, fields=[0, 1], check=None): 177 | """Titlecases all specified fields. If check is set to a field id it conditions the titlecasing 178 | of the entire set on the fact if the checked field can be plausibly uppercased.""" 179 | for line in lines: 180 | if check is None or canBeUppercased(line[check]): 181 | for field in fields: 182 | line[field] = titlecase.titlecase(line[field]) 183 | yield line 184 | 185 | 186 | def Tagger(lines, tag="", fields=[0]): 187 | for line in lines: 188 | for field in fields: 189 | line[field] = tag + line[field] 190 | yield line 191 | 192 | 193 | def Copy(lines, from_field=1, to_field=0): 194 | for line in lines: 195 | line[to_field] = line[from_field] 196 | yield line 197 | 198 | 199 | def CopySource(lines): 200 | """Copy source field to target.""" 201 | return Copy(lines, 0, 1) 202 | 203 | 204 | def Multiply(lines, n=2): 205 | """Makes n copies of the underlying object.""" 206 | for line in lines: 207 | for field in range(1, n): 208 | line[field] = line[0] 209 | yield line 210 | 211 | 212 | def JustSourceTarget(lines): 213 | """Removes all but fields 0 and 1""" 214 | for line in lines: 215 | yield Line(str(line)) 216 | 217 | 218 | def SPMEncoder(lines, spm_model): 219 | """Runs the SPM encoder on fields 0 and 1""" 220 | for line in lines: 221 | line[0:2] = list(map(lambda x: " ".join(x), spm_model.encode(line[0:2], out_type=str))) 222 | yield line 223 | 224 | 225 | def SPMDecoder(lines, spm_model): 226 | """SPM decodes fields 0 and 1""" 227 | for line in lines: 228 | line[0:2] = spm_model.decode(list(map(str.split, line[0:2]))) 229 | yield line 230 | -------------------------------------------------------------------------------- /test/regression/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Regression tests for sotastream. 
Invocation examples: 4 | # ./run.sh 5 | # ./run.sh tests/path/to/dir 6 | # ./run.sh tests/path/to/test_pipeline.sh 7 | # ./run.sh previous.log 8 | 9 | # Environment variables: 10 | # - SOTASTREAM - path to the root directory of sotastream 11 | # - DATA - path to the directory with data, default: ./data 12 | # - PYTHON - path to python command 13 | # - TIMEOUT - maximum duration for execution of a single test in the format 14 | # accepted by the timeout command; set to 0 to disable 15 | 16 | SHELL=/bin/bash 17 | 18 | export LC_ALL=C.UTF-8 19 | export AB_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 20 | 21 | RUN_LOGS="$AB_ROOT/previous.log.tmp" # Logging file for log and logn commands 22 | rm -f $RUN_LOGS 23 | 24 | # Needed so that previous.log is not overwritten when it is provided as an argument 25 | function cleanup { 26 | test -s "$RUN_LOGS" && mv "$RUN_LOGS" "$AB_ROOT/previous.log" 27 | } 28 | trap cleanup EXIT 29 | 30 | function log { 31 | echo "[$(date '+%m/%d/%Y %T')] $@" | tee -a $RUN_LOGS 32 | } 33 | 34 | function logn { 35 | echo -n "[$(date '+%m/%d/%Y %T')] $@" | tee -a $RUN_LOGS 36 | } 37 | 38 | function loge { 39 | echo $@ | tee -a $RUN_LOGS 40 | } 41 | 42 | ##################################################################### 43 | log "Running on $(hostname) as process $$" 44 | 45 | export AB_PYTHON="${PYTHON:-python3}" 46 | log "Python: $AB_PYTHON" 47 | export AB_SOTASTREAM="$( realpath "${SOTASTREAM:-$AB_ROOT/../../}" )" 48 | log "sotastream dir: $AB_SOTASTREAM" 49 | export AB_DATA="$( realpath ${DATA:-$AB_ROOT/data} )" 50 | log "Data dir: $AB_DATA" 51 | 52 | # Add sotastream root directory to PYTHONPATH 53 | export PYTHONPATH+=:$AB_SOTASTREAM 54 | 55 | # Time out 56 | export AB_TIMEOUT=${TIMEOUT:-5m} # the default time out is 5 minutes, see `man timeout` 57 | cmd_timeout="" 58 | if [ $AB_TIMEOUT != "0" ]; then 59 | cmd_timeout="timeout $AB_TIMEOUT" 60 | fi 61 | 62 | log "Time out: $AB_TIMEOUT" 63 | 64 | 65 | # Exit codes 66 | export EXIT_CODE_SUCCESS=0 67 | export EXIT_CODE_SKIP=100 68 | export EXIT_CODE_TIMEOUT=124 # Exit code returned by the timeout command if timed out 69 | 70 | function format_time { 71 | dt=$(python -c "print($2 - $1)" 2>/dev/null) 72 | dh=$(python -c "print(int($dt/3600))" 2>/dev/null) 73 | dt2=$(python -c "print($dt-3600*$dh)" 2>/dev/null) 74 | dm=$(python -c "print(int($dt2/60))" 2>/dev/null) 75 | ds=$(python -c "print($dt2-60*$dm)" 2>/dev/null) 76 | LANG=C printf "%02d:%02d:%02.3fs" $dh $dm $ds 77 | } 78 | 79 | 80 | ############################################################################### 81 | # Default directory with all regression tests 82 | test_prefixes=tests 83 | 84 | if [ $# -ge 1 ]; then 85 | test_prefixes= 86 | for arg in "$@"; do 87 | # A log file with paths to test files 88 | if [[ "$arg" = *.log ]]; then 89 | # Extract tests from .log file 90 | args=$(cat $arg | grep -vP '^\[' | grep '/test_.*\.sh' | grep -v '/_' | sed 's/^ *- *//' | tr '\n' ' ' | sed 's/ *$//') 91 | test_prefixes="$test_prefixes $args" 92 | # A hash tag 93 | elif [[ "$arg" = '#'* ]]; then 94 | # Find all tests with the given hash tag 95 | tag=${arg:1} 96 | args=$(find tests -name '*test_*.sh' | xargs -I{} grep -H "^ *# *TAGS:.* $tag" {} | cut -f1 -d:) 97 | test_prefixes="$test_prefixes $args" 98 | # A test file or directory name 99 | else 100 | test_prefixes="$test_prefixes $arg" 101 | fi 102 | done 103 | fi 104 | 105 | # Check if the variable is empty or contains only spaces 106 | if [[ -z "${test_prefixes// }" ]]; then 107 | log "Error: no tests 
found in the specified input(s): $@" 108 | exit 1 109 | fi 110 | 111 | # Extract all subdirectories, which will be traversed to look for regression tests 112 | test_dirs=$(find $test_prefixes -type d | grep -v "/_" | cat) 113 | 114 | if grep -q "/test_.*\.sh" <<< "$test_prefixes"; then 115 | test_files=$(printf '%s\n' $test_prefixes | sed 's!*/!!') 116 | test_dirs=$(printf '%s\n' $test_prefixes | xargs -I{} dirname {} | grep -v "/_" | sort | uniq) 117 | fi 118 | 119 | 120 | ############################################################################### 121 | success=true 122 | count_all=0 123 | count_failed=0 124 | count_passed=0 125 | count_skipped=0 126 | count_timedout=0 127 | 128 | declare -a tests_failed 129 | declare -a tests_skipped 130 | declare -a tests_timedout 131 | 132 | time_start=$(date +%s.%N) 133 | 134 | # Traverse test directories 135 | cd $AB_ROOT 136 | for test_dir in $test_dirs 137 | do 138 | log "Checking directory: $test_dir" 139 | nosetup=false 140 | 141 | # Run setup script if exists 142 | if [ -e $test_dir/setup.sh ]; then 143 | log "Running setup script" 144 | 145 | cd $test_dir 146 | $cmd_timeout $SHELL -v setup.sh &> setup.log 147 | if [ $? -ne 0 ]; then 148 | log "Warning: setup script returns a non-success exit code" 149 | success=false 150 | nosetup=true 151 | else 152 | rm setup.log 153 | fi 154 | cd $AB_ROOT 155 | fi 156 | 157 | # Run tests 158 | for test_path in $(ls -A $test_dir/test_*.sh 2>/dev/null) 159 | do 160 | test_file=$(basename $test_path) 161 | test_name="${test_file%.*}" 162 | 163 | # In non-traverse mode skip tests if not requested 164 | if [[ -n "$test_files" && $test_files != *"$test_file"* ]]; then 165 | continue 166 | fi 167 | test_time_start=$(date +%s.%N) 168 | ((++count_all)) 169 | 170 | # Tests are executed from their directory 171 | cd $test_dir 172 | 173 | # Skip tests if setup failed 174 | logn "Running $test_path ... " 175 | if [ "$nosetup" = true ]; then 176 | ((++count_skipped)) 177 | tests_skipped+=($test_path) 178 | loge " skipped" 179 | cd $AB_ROOT 180 | continue; 181 | fi 182 | 183 | # Run test 184 | # Note: all output gets written to stderr (very very few cases write to stdout) 185 | $cmd_timeout $SHELL -x $test_file 2> $test_file.log 1>&2 186 | exit_code=$? 187 | 188 | # Check exit code 189 | if [ $exit_code -eq $EXIT_CODE_SUCCESS ]; then 190 | ((++count_passed)) 191 | loge " OK" 192 | elif [ $exit_code -eq $EXIT_CODE_SKIP ]; then 193 | ((++count_skipped)) 194 | tests_skipped+=($test_path) 195 | loge " skipped" 196 | elif [ $exit_code -eq $EXIT_CODE_TIMEOUT ]; then 197 | ((++count_timedout)) 198 | tests_timedout+=($test_path) 199 | # Add a comment to the test log file that it timed out 200 | echo "The test timed out after $TIMEOUT" >> $test_file.log 201 | # A timed out test is a failed test 202 | ((++count_failed)) 203 | loge " timed out" 204 | success=false 205 | else 206 | ((++count_failed)) 207 | tests_failed+=($test_path) 208 | loge " failed" 209 | success=false 210 | fi 211 | 212 | # Report time 213 | test_time_end=$(date +%s.%N) 214 | test_time=$(format_time $test_time_start $test_time_end) 215 | log "Test took $test_time" 216 | 217 | cd $AB_ROOT 218 | done 219 | cd $AB_ROOT 220 | 221 | # Run teardown script if exists 222 | if [ -e $test_dir/teardown.sh ]; then 223 | log "Running teardown script" 224 | 225 | cd $test_dir 226 | $cmd_timeout $SHELL teardown.sh &> teardown.log 227 | if [ $? 
-ne 0 ]; then 228 | log "Warning: teardown script returns a non-success exit code" 229 | success=false 230 | else 231 | rm teardown.log 232 | fi 233 | cd $AB_ROOT 234 | fi 235 | done 236 | 237 | time_end=$(date +%s.%N) 238 | time_total=$(format_time $time_start $time_end) 239 | 240 | 241 | ############################################################################### 242 | # Print skipped and failed tests 243 | if [ -n "$tests_skipped" ] || [ -n "$tests_failed" ] || [ -n "$tests_timedout" ]; then 244 | loge "---------------------" 245 | fi 246 | [[ -z "$tests_skipped" ]] || loge "Skipped:" 247 | for test_name in "${tests_skipped[@]}"; do 248 | loge "- $test_name" 249 | done 250 | [[ -z "$tests_failed" ]] || loge "Failed:" 251 | for test_name in "${tests_failed[@]}"; do 252 | loge "- $test_name" 253 | done 254 | [[ -z "$tests_timedout" ]] || loge "Timed out:" 255 | for test_name in "${tests_timedout[@]}"; do 256 | loge "- $test_name" 257 | done 258 | [[ -z "$tests_failed" ]] || echo "Logs:" 259 | for test_name in "${tests_failed[@]}"; do 260 | echo "- $(realpath $test_name | sed 's/\.sh/.sh.log/')" 261 | done 262 | 263 | 264 | ############################################################################### 265 | # Print summary 266 | loge "---------------------" 267 | loge -n "Ran $count_all tests in $time_total, $count_passed passed, $count_skipped skipped, $count_failed failed" 268 | [ -n "$tests_timedout" ] && loge -n " (incl. $count_timedout timed out)" 269 | loge "" 270 | 271 | # Return exit code 272 | if $success && [ $count_all -gt 0 ]; then 273 | loge "OK" 274 | exit 0 275 | else 276 | loge "FAILED" 277 | exit 1 278 | fi 279 | -------------------------------------------------------------------------------- /sotastream/cli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | 5 | # This might have to do with functioning on mounted Azure blobs 6 | sys.dont_write_bytecode = True 7 | 8 | import argparse 9 | import logging 10 | import json 11 | import os 12 | import time 13 | 14 | from collections import defaultdict 15 | from multiprocessing import Pipe, Process 16 | from typing import Type 17 | 18 | from . import __version__, Defaults 19 | from .utils.split import split_file_into_chunks 20 | from .pipelines import Pipeline, PIPELINES 21 | 22 | # Use seed in logger for when multiple are running 23 | logger = logging.getLogger(f"sotastream") 24 | 25 | USER = os.environ.get('USER', os.environ.get('USERNAME', 'nouser')) 26 | 27 | 28 | def adjustSeed(seed, local_num_instances, local_instance_rank): 29 | """ 30 | Adjust seed for infinibatch such that each instance gets a different one based on process number and MPI 31 | coordinates. 
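    A quick sketch of the intent (values are illustrative; the exact seeds depend on tuple hashing)::

        s0 = adjustSeed(1234, local_num_instances=4, local_instance_rank=0)
        s1 = adjustSeed(1234, local_num_instances=4, local_instance_rank=1)
        # s0 and s1 are expected to differ, so each worker shuffles its shards differently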
32 | """ 33 | if seed == 0: 34 | seed = round(time.time() * 1000) # the current time in milliseconds 35 | 36 | mpi_num_instances = 1 37 | mpi_instance_rank = 0 38 | 39 | # these variables are set automatically by mpirun when used inside an MPI world 40 | if "OMPI_COMM_WORLD_SIZE" in os.environ: 41 | mpi_num_instances = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 42 | mpi_instance_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 43 | 44 | # hash-combine seed with local process number and rank and MPI process number and rank 45 | hashed_seed = hash((seed, local_num_instances, local_instance_rank, mpi_num_instances, mpi_instance_rank)) 46 | 47 | logger.info( 48 | f"Computed seed {hashed_seed} from original seed {seed} and instance info: ({local_num_instances}, {local_instance_rank}, {mpi_num_instances}, {mpi_instance_rank})" 49 | ) 50 | return hashed_seed 51 | 52 | 53 | def run_pipeline_process(conn, args, seed, worker_id, num_workers): 54 | """ 55 | Runs a pipeline in a single subprocess. Each subprocess writes to 56 | the pipe (conn) after it has seen the specified number (args.queue_buffer_size) 57 | of lines. 58 | """ 59 | 60 | kwargs = {k: v for k, v in vars(args).items() if not (k in ["pipeline", "seed"])} 61 | 62 | # These environment variables are used in the subprocesses to determine which worker they are 63 | os.environ["SOTASTREAM_WORKER_ID"] = str(worker_id) 64 | os.environ["SOTASTREAM_WORKER_COUNT"] = str(num_workers) 65 | pipeline = Pipeline.create(args.pipeline, seed=seed, **kwargs) 66 | 67 | try: 68 | lines = [] 69 | for line in pipeline: 70 | lines.append(str(line)) 71 | if len(lines) >= min(args.queue_buffer_size, args.buffer_size): 72 | conn.send(lines) 73 | lines = [] 74 | if lines: 75 | conn.send(lines) 76 | finally: 77 | conn.close() 78 | 79 | 80 | def add_global_args(parser: argparse.ArgumentParser): 81 | """ 82 | Add global arguments to the parser. These appear before the pipeline argument and are available 83 | to all pipelines. 84 | 85 | :param parser: The parser to add the options to. 
86 | """ 87 | parser.add_argument( 88 | "--log-rate", "-lr", type=int, default=0, metavar="N", help="Log every Nth instance (0=off)" 89 | ) 90 | parser.add_argument( 91 | "--log-first", 92 | "-lf", 93 | type=int, 94 | default=5, 95 | metavar="N", 96 | help="Log first N instances (default: %(default)s)", 97 | ) 98 | parser.add_argument("--sample-file", type=argparse.FileType("tw"), help="Where to log samples") 99 | parser.add_argument( 100 | '--buffer-size', 101 | '-b', 102 | help='Number of lines infinibatch will load into memory', 103 | type=int, 104 | default=Defaults.BUFFER_SIZE, 105 | ) 106 | parser.add_argument( 107 | '--queue-buffer-size', 108 | '-q', 109 | help='Queue buffer size', 110 | type=int, 111 | default=Defaults.QUEUE_BUFFER_SIZE, 112 | ) 113 | parser.add_argument( 114 | '--seed', 115 | '-s', 116 | help='Random seed (default 0 uses time for initialization)', 117 | type=int, 118 | default=Defaults.SEED, 119 | ) 120 | parser.add_argument( 121 | '--num-processes', 122 | '-n', 123 | help='Number of processes to use for better throughput', 124 | type=int, 125 | default=Defaults.NUM_PROCESSES, 126 | ) 127 | parser.add_argument('--version', '-V', action='version', version='sotastream {}'.format(__version__)) 128 | parser.add_argument( 129 | "--split-tmpdir", 130 | default=f"/tmp/sotastream-{USER}", 131 | help="Base temporary directory to use when splitting data files", 132 | ) 133 | parser.add_argument("--quiet", action="store_true", help="Suppress logging output") 134 | 135 | 136 | def maybe_split_files(args): 137 | """Split data files into smaller files in a temporary directory 138 | 139 | This function updates args inplace: it replaces .gz paths (if any) with split dirs. 140 | 141 | Args: 142 | args: CLI args object from argparse 143 | """ 144 | # Look up the class for the pipeline, and get the named list of arguments 145 | PipelineClass: Type['Pipeline'] = PIPELINES[args.pipeline] 146 | args_dict = vars(args) 147 | data_source_params = PipelineClass.get_data_sources_for_argparse() 148 | # Use the name to get the path from the runtime args object 149 | data_sources = [(x[0], args_dict[x[0]]) for x in data_source_params] 150 | for name, path in data_sources: 151 | # For any path that is a .gz file, split it into chunks. 152 | # Directories that were pre-split are left as-is. 153 | if not isinstance(path, str): 154 | logger.warning(f"Skipping {name}={path} because it is {type(path)}, but str expected") 155 | continue 156 | if not os.path.isdir(path) and path.endswith(".gz"): 157 | splitdir = split_file_into_chunks(path, tmpdir=args.split_tmpdir, split_size=args.buffer_size) 158 | setattr(args, name, splitdir) 159 | # Inject a keyword argument 'data_sources' that contains all data sources 160 | setattr(args, 'data_sources', [path for name, path in data_sources]) 161 | 162 | 163 | def main(): 164 | stats = defaultdict(int) 165 | stats['start_time'] = time.time() 166 | # Get the list of available pipelines 167 | parser = argparse.ArgumentParser( 168 | prog='sotastream', 169 | description='Command line wrapper for augmentation pipelines', 170 | formatter_class=argparse.RawTextHelpFormatter, 171 | epilog='''\n\nTo load additional pipelines create (or symlink) *_pipeline.py files from current directory.''', 172 | ) 173 | add_global_args(parser) 174 | 175 | # Each pipeline is a different subcommand with its own arguments. 176 | sub_parsers = parser.add_subparsers( 177 | dest='pipeline', 178 | required=True, 179 | metavar="pipeline", 180 | help="The pipeline to run. 
Available pipelines:\n- " + "\n- ".join(sorted(PIPELINES.keys())), 181 | ) 182 | for pipeline_name, pipeline_class in PIPELINES.items(): 183 | # Create a sub-parser and add the pipeline's arguments to it. 184 | sub_parser = sub_parsers.add_parser( 185 | pipeline_name, description=pipeline_class.__doc__, formatter_class=argparse.RawTextHelpFormatter 186 | ) 187 | pipeline_class.add_cli_args(sub_parser) 188 | 189 | args = parser.parse_args() 190 | logLevel = logging.CRITICAL if args.quiet else logging.INFO 191 | logging.basicConfig(level=logLevel) 192 | 193 | maybe_split_files(args) 194 | 195 | N = args.num_processes 196 | 197 | pipes = [Pipe() for i in range(N)] 198 | processes = [ 199 | Process(target=run_pipeline_process, args=(pipes[i][1], args, adjustSeed(args.seed, N, i), i, N)) 200 | for i in range(N) 201 | ] 202 | for p in processes: 203 | p.start() 204 | 205 | overhead_time = time.time() 206 | 207 | lineno = 0 208 | num_fields = defaultdict(int) 209 | try: 210 | # round-robin across the pipes forever 211 | while True: 212 | for pipe in pipes: 213 | # To avoid pickling (and the associated timing costs), lines 214 | # are transmitted as strings, not Line objects. 215 | lines = pipe[0].recv() 216 | for line in lines: 217 | fields = line.split("\t") 218 | num_fields[len(fields)] += 1 219 | print(line) 220 | lineno += 1 221 | 222 | if (args.log_rate > 0 and lineno % args.log_rate == 0) or lineno <= args.log_first: 223 | if args.sample_file: 224 | print(line, file=args.sample_file) 225 | else: 226 | logger.info(f"SAMPLE {lineno}: {line}") 227 | except BrokenPipeError: # this is not really an error, just means that the receiving process has ended 228 | # Python flushes standard streams on exit; redirect remaining output 229 | # to devnull to avoid another BrokenPipeError at shutdown 230 | devnull = os.open(os.devnull, os.O_WRONLY) 231 | os.dup2(devnull, sys.stdout.fileno()) 232 | finally: 233 | # Looks like the process that we are piping to is done, let's wrap things up 234 | for p in processes: 235 | p.terminate() 236 | 237 | stats['end_time'] = time.time() 238 | stats['lines_produced'] = f'{lineno:,}' 239 | stats['num_fields'] = num_fields 240 | total_time = stats['end_time'] - stats['start_time'] 241 | stats['overhead_time'] = overhead_time - stats['start_time'] 242 | stats['total_time'] = f"{total_time:,.3f} sec" 243 | stats['yield_rate'] = f"{lineno / total_time:,.2f} lines/sec" 244 | stats['yield_rate_sans_overhead'] = f"{lineno / (stats['end_time'] - overhead_time):,.2f} lines/sec" 245 | logger.info('Summary: ' + json.dumps(stats, indent=2)) 246 | 247 | 248 | if __name__ == "__main__": 249 | main() 250 | -------------------------------------------------------------------------------- /sotastream/pipelines/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | import itertools 3 | import logging 4 | import random 5 | import os 6 | 7 | from sotastream import Defaults 8 | from sotastream.augmentors import DataSource, UTF8File 9 | from sentencepiece import SentencePieceProcessor 10 | from typing import List, Tuple, Callable 11 | 12 | logger = logging.getLogger(f"sotastream") 13 | 14 | 15 | class Pipeline(ABC): 16 | """Pipeline base class 17 | 18 | * To add a pipeline, extend this class and add @pipeline("pipeline_name") decorator 19 | * CLI arguments 20 | * To specify the data source names for argparse, override `get_data_sources_for_argparse` class method. 
21 | * To specify default mixing weights for argparse, override `get_data_sources_default_weights` class method. 22 | * To specify any other CLI arguments override `add_cli_args` class method. 23 | * Don't forget the @classmethod decorator and cls as first argument. 24 | * The filename should match the pattern {name}_pipeline.py, where {name} is the name used in the decorator. 25 | * All CLI arguments are passed as arguments to constrctor. 26 | * Refer to default.py or t1_pipeline.py for example pipelines. 27 | """ 28 | 29 | def __init__(self, **kwargs) -> None: 30 | # Store the data sources. The length will vary based on how many the subclass expects. 31 | self.data_sources = kwargs.get('data_sources', []) 32 | 33 | spm_file = kwargs.get("spm", None) 34 | if spm_file: 35 | self.spm_model = SentencePieceProcessor(model_file=spm_file) 36 | else: 37 | logger.warning("Creating pipeline without an SPM model") 38 | self.spm_model = None 39 | self.sample_file = kwargs.get("sample_file") 40 | self.buffer_size = kwargs.get("buffer_size", Defaults.BUFFER_SIZE) 41 | self.queue_buffer_size = kwargs.get("queue_buffer_size", Defaults.QUEUE_BUFFER_SIZE) 42 | self.is_quiet = kwargs.get("quiet", Defaults.QUIET) 43 | self.seed = kwargs.get("seed", Defaults.SEED) 44 | self.max_tokens = kwargs.get("max_tokens", Defaults.MAX_TOKENS) 45 | self.sample_length = kwargs.get("sample_length", Defaults.SAMPLE_LENGTH) 46 | self.separator = kwargs.get("separator", Defaults.SEPARATOR) 47 | self.shuffle = not kwargs.get("no_shuffle", not Defaults.SHUFFLE) 48 | 49 | random.seed(self.seed) 50 | 51 | # These are set in the environment of the caller when multiprocessing is enabled. 52 | # Each sub-process gets a distinct worker ID and knows the total number of workers. 53 | # These values are used to allocate the shards of a data source in a round-robin 54 | # fashion, such that each subprocess has 1/Nth of the total data, ensuring that 55 | # reading from them in round-robin fashion produces a permutation over the training 56 | # data. You can see that these values are used below in the create_data_stream() 57 | # function, a wrapper around the underlying infinibatch data structure. 58 | self.worker_id = int(os.environ.get("SOTASTREAM_WORKER_ID", "0")) 59 | self.num_workers = int(os.environ.get("SOTASTREAM_WORKER_COUNT", "1")) 60 | 61 | # sanity-check mix_weights and normalize 62 | self.mix_weights = kwargs.get("mix_weights", self.get_data_sources_default_weights()) 63 | if any([w < 0 for w in self.mix_weights]): 64 | raise ValueError("Mix weights must be non-negative") 65 | self.mix_weights = [w / sum(self.mix_weights) for w in self.mix_weights] 66 | if len(self.data_sources) != len(self.mix_weights): 67 | raise Exception( 68 | f'Data sources does not match weights: {self.data_sources} != {self.mix_weights} ' 69 | ) 70 | 71 | # log a message with the mix weights and the data source names 72 | # e.g., INFO:sotastream:Using mix weights: 99.75% (parallel_data) 0.25% (garbage_data) 73 | mix_weight_message = "Using mix weights:\n\t" + "\n\t".join( 74 | f"{w*100:.5g}% : {path}) ({arg and arg[0] or ''})" 75 | for w, arg, path in itertools.zip_longest( 76 | self.mix_weights, self.get_data_sources_for_argparse(), self.data_sources 77 | ) 78 | ) 79 | logger.info(mix_weight_message) 80 | 81 | self.stream = None # to be initialized in subclass 82 | 83 | @classmethod 84 | def add_cli_args(cls, parser): 85 | """ 86 | Add CLI arguments to pipeline specific subparser. 
87 | These arguments are shared across all pipelines and appear after the pipeline name in the CLI. 88 | For global args that appear before the pipeline name, see sotastream.cli.add_cli_args 89 | """ 90 | 91 | # Validate the default weights provided in the pipeline 92 | data_sources = cls.get_data_sources_for_argparse() 93 | mix_weights = cls.get_data_sources_default_weights() 94 | if len(mix_weights) != len(data_sources): 95 | raise ValueError( 96 | f"Number of data sources ({len(data_sources)}) must match number of weights ({len(mix_weights)})" 97 | ) 98 | 99 | for arg_spec in cls.get_data_sources_for_argparse(): 100 | name, desc = arg_spec[:2] 101 | nargs = None 102 | if len(arg_spec) > 2: 103 | nargs = arg_spec[2] 104 | parser.add_argument(name, help=desc, nargs=nargs) 105 | 106 | parser.add_argument("--spm", help="SPM model (for more accurate length calculation") 107 | parser.add_argument( 108 | "--separator", 109 | default=" ", 110 | help="String to use when joining sentences for data augmentation (default: '%(default)s').", 111 | ) 112 | parser.add_argument( 113 | "--max-joined-tokens", 114 | "--max-tokens", 115 | "-m", 116 | dest="max_tokens", 117 | type=int, 118 | default=Defaults.MAX_TOKENS, 119 | help="Maximum number of tokens to join", 120 | ) 121 | 122 | if len(mix_weights) == 1 and mix_weights[0] in ('+', '*'): 123 | mix_weights_default = None 124 | mix_weights_nargs = mix_weights[0] 125 | else: 126 | mix_weights_default = mix_weights 127 | mix_weights_nargs = len(mix_weights) 128 | 129 | parser.add_argument( 130 | "--mix-weights", 131 | "-w", 132 | type=float, 133 | metavar="WEIGHT", 134 | nargs=mix_weights_nargs, # validate the number of weights provided by the user 135 | default=mix_weights_default, 136 | help="Weights to use when mixing data sources (will be normalized if don't sum to 1.0) (default: %(default)s)", 137 | ) 138 | 139 | def create_data_stream( 140 | self, data_path, processor: Callable = UTF8File, buffer_size: int = None, ext: str = ".gz" 141 | ): 142 | """ 143 | Wrapper around data source creation to allow for easy overriding in subclasses. 144 | 145 | The worker ID and number of workers is passed to the DataSource class, which uses 146 | them to select the subset of shards this process will have access to. 147 | 148 | :param data_path: Path to data source 149 | :param processor: Augmentor processor function to apply to each chunk 150 | :param buffer_size: The buffer size to use 151 | :param ext: The extension of the data source 152 | """ 153 | return DataSource( 154 | data_path, 155 | processChunk=processor, 156 | ext=ext, 157 | buffer_size=buffer_size or self.buffer_size, 158 | seed=self.seed, 159 | worker_id=self.worker_id, 160 | num_workers=self.num_workers, 161 | ) 162 | 163 | @classmethod 164 | def get_data_sources_for_argparse(cls) -> List[Tuple[str, str]]: 165 | """ 166 | This returns a list of (name, description) pairs for each data source. 167 | This is used to instantiate the argparse subcommand with named positional arguments. 168 | These are not the actual instantiated data paths; for that, each class has 169 | The function name is quite verbose in order to minimize confusion. 
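        For illustration, a hypothetical two-source pipeline might override this as::

            return [
                ('parallel_data', 'Folder of .gz chunks (or a compressed TSV) with parallel data'),
                ('backtrans_data', 'Folder of .gz chunks (or a compressed TSV) with backtranslated data'),
            ]

        An optional third element in a tuple (e.g. '+') is passed to argparse as nargs,
        as done in the mtdata pipeline.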
170 | 171 | Returns: 172 | List[Tuple]: List of (name, description) 173 | """ 174 | return [ 175 | ( 176 | 'data', 177 | 'Path to data source (a pre-split folder containing .gz files, or a single compressed TSV)', 178 | ) 179 | ] 180 | 181 | @classmethod 182 | def get_data_sources_default_weights(cls) -> List[float]: 183 | """ 184 | A list of floats corresponding to the number of data sources and specifying the mixture weights among them. 185 | These will be provided to the argparse subcommand as the default values for the --mix-weights argument. 186 | To get the actual instantiated values, use self.mix_weights. 187 | The function is named in an overly explicit way to avoid confusion between these two sources. 188 | """ 189 | return [1.0] 190 | 191 | def __iter__(self): 192 | return self 193 | 194 | def __next__(self): 195 | return next(self.stream) 196 | 197 | @staticmethod 198 | def create(name: str, *args, **kwargs): 199 | """ 200 | Create an instance of Pipeline for a given pipeline name 201 | """ 202 | from . import PIPELINES 203 | 204 | assert name in PIPELINES, f'No pipeline with name {name} found' 205 | return PIPELINES[name](*args, **kwargs) 206 | 207 | 208 | class DocumentPipeline(Pipeline): 209 | """ 210 | Extends Pipeline base with document-level CLI args. 211 | """ 212 | 213 | def __init__(self, **kwargs) -> None: 214 | super().__init__(**kwargs) 215 | self.doc_separator = kwargs.get("doc_separator", Defaults.DOC_SEPARATOR) 216 | self.doc_prob = kwargs.get("doc_prob", Defaults.DOC_PROB) 217 | self.doc_prob_parallel = kwargs.get("doc_prob_parallel", Defaults.DOC_PROB_PARALLEL) 218 | 219 | @classmethod 220 | def add_cli_args(cls, parser): 221 | """ 222 | Add document-specific arguments. 223 | """ 224 | super().add_cli_args(parser) 225 | 226 | parser.add_argument( 227 | "--doc-separator", 228 | default=Defaults.DOC_SEPARATOR, 229 | help="Sentence joiner token to use when building docs (default: '%(default)s').", 230 | ) 231 | parser.add_argument( 232 | "--doc-prob-parallel", 233 | type=float, 234 | default=Defaults.DOC_PROB, 235 | help="Probability of creating a doc from parallel data", 236 | ) 237 | parser.add_argument( 238 | "--doc-prob", 239 | type=float, 240 | default=Defaults.DOC_PROB_PARALLEL, 241 | help="Probability of creating a doc from backtrans data", 242 | ) 243 | -------------------------------------------------------------------------------- /test/test_augmentors.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | import os 5 | import pytest 6 | import random 7 | from typing import Iterable, List 8 | 9 | from sotastream.data import Line 10 | from sotastream.augmentors import * 11 | 12 | from collections import Counter 13 | 14 | TEST_CORPUS = [ 15 | "München 1856: Vier Karten, die Ihren Blick auf die Stadt verändern Munich 1856: Four maps that will change your view of the city", 16 | "Eine Irren-Anstalt, wo sich heute Jugendliche begegnen sollen. A mental asylum, where today young people are said to meet.", 17 | "Eine Gruftkapelle, wo nun für den S-Bahn-Tunnel gegraben wird. A crypt chapel, where they are now digging tunnels for the S-Bahn.", 18 | "Kleingärtner bewirtschaften den einstigen Grund von Bauern. Allotment holders cultivate the soil of former farmers.", 19 | "Die älteste offizielle Karte Münchens fördert spannende Geschichten zu Tage. The oldest official map of Munich brings captivating stories to light.", 20 | "Es nervt, wenn Landkarten nicht aktuell sind. 
It is annoying when geographical maps are not up-to-date.", 21 | "Das kennt jeder, der sich schon mal aufregen musste, weil das Auto-Navi statt einer Umgehungsstraße eine grüne Wiese anzeigte. Anyone who has ever got worked up because the car's sat-nav is showing a green field instead of a bypass knows that.", 22 | "Die historischen Landkarten des digitalen Bayern-Atlases, ein Angebot des Geoportals Bayern der Staatsregierung, sind alles andere als aktuell - doch gerade deshalb sehr aufschlussreich. The historical maps of the digital BayernAtlas, an offering from the State Government's Geoportal Bayern, are anything but up-to-date – and yet it is precisely for this reason that they are so informative.", 23 | "Besonders wenn man sie mit aktuellen Online-Karten vergleicht. Especially when one compares them with current online maps.", 24 | "Dann wird deutlich, wie sich Städte und Gemeinden im Verbreitungsgebiet des Münchner Merkur seit dem 19. Jahrhundert verändert haben. Then it becomes clear how the towns and municipalities in the distribution area of Munich's Merkur newspaper have changed since the 19th century.", 25 | ] 26 | 27 | 28 | # ToLines should be a generator since everything else is generators or iterators 29 | def ToLines(lines: List[str]) -> Iterable[Line]: 30 | for x in lines: 31 | yield Line(x) 32 | 33 | 34 | def test_line_creation(): 35 | lines = ToLines(TEST_CORPUS) 36 | 37 | for lineno, line in enumerate(lines): 38 | assert len(line) == 2, f"wrong length for {lineno}: {line.fields[2]}" 39 | 40 | 41 | # we need to call ToLines here twice to create a new generator, otherwise zipping would advance the same generator twice 42 | def test_identity(): 43 | lines1 = ToLines(TEST_CORPUS) 44 | lines2 = ToLines(TEST_CORPUS) 45 | 46 | for line, modline in zip(lines1, Identity(lines2)): 47 | assert line == modline 48 | 49 | 50 | def test_tolower(): 51 | for line in ToLower(ToLines(TEST_CORPUS)): 52 | assert line[0].islower() and line[1].islower() 53 | 54 | # Make sure lowering source works 55 | for line in ToLower(ToLines(TEST_CORPUS), fields=[0]): 56 | assert line[0].islower() and not line[1].islower() 57 | 58 | # Lower target 59 | for line in ToLower(ToLines(TEST_CORPUS), fields=[1]): 60 | assert not line[0].islower() and line[1].islower() 61 | 62 | 63 | def test_toupper(): 64 | for line in ToUpper(ToLines(TEST_CORPUS)): 65 | assert line[0].isupper() and line[1].isupper() 66 | 67 | # Make sure uppering source works 68 | for line in ToUpper(ToLines(TEST_CORPUS), fields=[0]): 69 | assert line[0].isupper() and not line[1].isupper() 70 | 71 | # Upper target 72 | for line in ToUpper(ToLines(TEST_CORPUS), fields=[1]): 73 | assert not line[0].isupper() and line[1].isupper() 74 | 75 | # Only uppercase if field `check` can be uppercased 76 | for line in ToUpper(ToLines(["她是囚犯还是老板? Is she prisoner or boss?"]), fields=[0, 1], check=1): 77 | assert not line[0].isupper() and line[1].isupper() 78 | 79 | for line in ToUpper(ToLines(["她是囚犯还是老板? Is she prisoner or boss?"]), fields=[0, 1], check=0): 80 | assert not line[0].isupper() and not line[1].isupper() 81 | 82 | for line in ToUpper(ToLines(["她是囚犯还是老板? 
Is she prisoner or boss?"]), fields=[0, 1]): 83 | assert not line[0].isupper() and line[1].isupper() 84 | 85 | 86 | @pytest.mark.parametrize("n", range(2, 5)) 87 | def test_multiply(n): 88 | for line in Multiply(ToLines(TEST_CORPUS), n=n): 89 | for i in range(1, n): 90 | assert line[0] == line[i] 91 | 92 | 93 | ###################################################################################### 94 | 95 | DIACRITICIZED_CORPUS = [ 96 | "ąćęłńóśźż ĄĆĘŁŃÓŚŹŻ\tąćęłńóśźż", 97 | "áčďéě íňóřš ťúůýž ÁČĎÉĚ ÍŇÓŘŠ ŤÚŮÝŽ\táčďéě íňóřš ťúůýž", 98 | "äöü ÄÖÜ\täöü", 99 | "ç é âêîôû àèìòù ëïü Ç É ÂÊÎÔÛ ÀÈÌÒÙ ËÏÜ\tç é âêîôû àèìòù ëïü", 100 | "áéíóú ñ ü ÁÉÍÓÚ Ñ Ü\táéíóú ñ ü", 101 | ] 102 | 103 | TRAILING_PUNCT_CORPUS = [ 104 | "This is a test sentence.\t这是一个测试语句。", 105 | "This is a test sentence!\t这是一个测试语句!", 106 | "This is a test sentence.\tこれはテスト文です。", 107 | "This is a test sentence;\tこれはテスト文です!", 108 | "This is a test sentence.\t테스트 문장입니다.", 109 | "This is a test sentence;\tهذه جملة اختبار.", 110 | "这是一个测试语句。\tThis is a test sentence." "这是一个测试语句!\tThis is a test sentence!", 111 | "これはテスト文です。\tThis is a test sentence.", 112 | "これはテスト文です!\tThis is a test sentence;", 113 | "테스트 문장입니다.\tThis is a test sentence.", 114 | "هذه جملة اختبار.\tThis is a test sentence;", 115 | ] 116 | 117 | 118 | ###################################################################################### 119 | 120 | # char-based alignments for input above 121 | TEST_CORPUS_ALN = [ 122 | [ 123 | ((0, 9), (0, 7)), 124 | ((9, 10), (7, 8)), 125 | ((10, 11), (8, 9)), 126 | ((11, 12), (9, 10)), 127 | ((12, 13), (10, 11)), 128 | ((13, 15), (11, 13)), 129 | ((15, 20), (13, 18)), 130 | ((20, 26), (18, 23)), 131 | ((26, 28), (18, 23)), 132 | ((28, 32), (23, 28)), 133 | ((32, 38), (28, 33)), 134 | ((38, 44), (33, 40)), 135 | ((44, 48), (40, 45)), 136 | ((48, 52), (45, 50)), 137 | ((52, 58), (50, 53)), 138 | ((58, 68), (53, 57)), 139 | ], 140 | [ 141 | ((0, 5), (0, 2)), 142 | ((5, 8), (9, 11)), 143 | ((8, 10), (11, 13)), 144 | ((10, 11), (11, 13)), 145 | ((11, 18), (15, 17)), 146 | ((18, 20), (15, 17)), 147 | ((18, 20), (17, 23)), 148 | ((20, 23), (23, 29)), 149 | ((23, 28), (29, 35)), 150 | ((28, 34), (35, 42)), 151 | ((34, 45), (42, 46)), 152 | ((46, 54), (46, 51)), 153 | ((54, 55), (51, 54)), 154 | ((55, 61), (54, 58)), 155 | ((61, 62), (58, 59)), 156 | ], 157 | [ 158 | ((0, 5), (0, 2)), 159 | ((5, 6), (2, 7)), 160 | ((6, 10), (7, 8)), 161 | ((10, 17), (8, 14)), 162 | ((17, 19), (14, 16)), 163 | ((19, 22), (14, 16)), 164 | ((22, 26), (16, 22)), 165 | ((26, 31), (22, 27)), 166 | ((31, 35), (27, 31)), 167 | ((35, 36), (35, 38)), 168 | ((35, 36), (59, 60)), 169 | ((36, 37), (38, 43)), 170 | ((37, 41), (61, 65)), 171 | ((41, 42), (49, 51)), 172 | ((41, 42), (60, 61)), 173 | ((42, 48), (43, 49)), 174 | ((42, 48), (55, 59)), 175 | ((48, 49), (7, 8)), 176 | ((58, 62), (61, 65)), 177 | ((62, 63), (65, 66)), 178 | ], 179 | [ 180 | ((0, 5), (0, 2)), 181 | ((5, 6), (2, 5)), 182 | ((6, 9), (2, 5)), 183 | ((9, 10), (5, 10)), 184 | ((10, 14), (10, 18)), 185 | ((14, 16), (10, 18)), 186 | ((14, 16), (27, 28)), 187 | ((16, 26), (18, 27)), 188 | ((26, 29), (27, 28)), 189 | ((29, 33), (28, 32)), 190 | ((33, 36), (28, 32)), 191 | ((36, 38), (32, 37)), 192 | ((38, 43), (37, 40)), 193 | ((43, 49), (40, 47)), 194 | ((49, 53), (40, 47)), 195 | ((53, 59), (47, 54)), 196 | ((59, 60), (54, 55)), 197 | ], 198 | [ 199 | ((0, 4), (0, 4)), 200 | ((4, 13), (4, 11)), 201 | ((13, 24), (11, 20)), 202 | ((24, 30), (20, 24)), 203 | ((30, 38), (24, 27)), 204 | ((38, 40), (27, 34)), 
205 | ((38, 40), (34, 41)), 206 | ((40, 49), (41, 44)), 207 | ((49, 57), (44, 47)), 208 | ((57, 59), (47, 53)), 209 | ((59, 71), (53, 61)), 210 | ((71, 74), (61, 64)), 211 | ((74, 78), (64, 69)), 212 | ((78, 79), (69, 70)), 213 | ], 214 | [ 215 | ((0, 3), (0, 3)), 216 | ((3, 7), (3, 6)), 217 | ((7, 8), (6, 15)), 218 | ((7, 8), (15, 20)), 219 | ((8, 10), (20, 33)), 220 | ((10, 15), (33, 38)), 221 | ((15, 19), (38, 42)), 222 | ((19, 26), (42, 46)), 223 | ((26, 32), (46, 48)), 224 | ((32, 39), (48, 49)), 225 | ((32, 39), (49, 51)), 226 | ((39, 40), (51, 52)), 227 | ((40, 44), (52, 56)), 228 | ((44, 45), (56, 57)), 229 | ], 230 | [ 231 | ((0, 4), (0, 7)), 232 | ((4, 9), (7, 11)), 233 | ((10, 15), (11, 15)), 234 | ((15, 17), (15, 20)), 235 | ((17, 21), (20, 24)), 236 | ((21, 26), (24, 31)), 237 | ((26, 32), (31, 34)), 238 | ((66, 67), (55, 56)), 239 | ((71, 72), (59, 60)), 240 | ((101, 106), (93, 96)), 241 | ((106, 113), (96, 98)), 242 | ((113, 117), (100, 105)), 243 | ((117, 119), (100, 105)), 244 | ((119, 121), (105, 111)), 245 | ((121, 127), (111, 115)), 246 | ((127, 128), (115, 116)), 247 | ], 248 | [ 249 | ((42, 48), (35, 41)), 250 | ((48, 49), (125, 126)), 251 | ((49, 54), (41, 46)), 252 | ((56, 58), (46, 48)), 253 | ((74, 77), (88, 91)), 254 | ((77, 83), (91, 98)), 255 | ((83, 85), (86, 88)), 256 | ((185, 186), (206, 207)), 257 | ], 258 | [ 259 | ((0, 10), (0, 11)), 260 | ((10, 15), (11, 16)), 261 | ((15, 19), (16, 20)), 262 | ((19, 23), (16, 20)), 263 | ((23, 27), (20, 27)), 264 | ((27, 37), (27, 29)), 265 | ((37, 43), (29, 34)), 266 | ((43, 44), (34, 39)), 267 | ((44, 51), (39, 47)), 268 | ((51, 54), (39, 47)), 269 | ((54, 60), (47, 54)), 270 | ((60, 61), (54, 58)), 271 | ((61, 62), (58, 59)), 272 | ], 273 | [ 274 | ((0, 5), (0, 5)), 275 | ((5, 10), (5, 8)), 276 | ((10, 18), (8, 16)), 277 | ((18, 20), (16, 22)), 278 | ((20, 24), (22, 26)), 279 | ((24, 29), (26, 30)), 280 | ((87, 90), (92, 95)), 281 | ((90, 93), (95, 98)), 282 | ((93, 94), (98, 99)), 283 | ((103, 104), (132, 133)), 284 | ((104, 105), (128, 132)), 285 | ((104, 105), (133, 134)), 286 | ((107, 119), (133, 134)), 287 | ((119, 130), (134, 137)), 288 | ((130, 135), (137, 144)), 289 | ((135, 136), (144, 145)), 290 | ], 291 | ] 292 | 293 | # output for corpus above after phrasefixing, randomly seeded by string hash of input, hence should be deterministic 294 | TEST_CORPUS_PHRASEFIX_OUT = [ 295 | "München 1856(phrasefix)(#2)(#2)(#phrasefix) Vier (phrasefix)(#8)(#9)(#phrasefix) Blick auf (phrasefix)(#2)(#8)(#phrasefix) Munich 1856(phrasefix)(#2)(#2)(#phrasefix) Four (phrasefix)(#8)(#9)(#phrasefix) change your (phrasefix)(#2)(#8)(#phrasefix) city", 296 | "(phrasefix)(#1)(#3)(#phrasefix) (phrasefix)(#2)(#6)(#phrasefix) wo sich heute (phrasefix)(#7)(#9)(#phrasefix) begegnen sollen. (phrasefix)(#1)(#3)(#phrasefix) (phrasefix)(#2)(#6)(#phrasefix) today young people (phrasefix)(#7)(#9)(#phrasefix) said to meet.", 297 | "(phrasefix)(#2)(#8)(#phrasefix) Gruftkapelle, wo (phrasefix)(#2)(#0)(#phrasefix) (phrasefix)(#8)(#2)(#phrasefix) S-Bahn-Tunnel gegraben wird. (phrasefix)(#2)(#8)(#phrasefix) crypt chapel, (phrasefix)(#2)(#0)(#phrasefix) (phrasefix)(#8)(#2)(#phrasefix) digging tunnels for the S-Bahn.", 298 | "Kleingärtner bewirtschaften (phrasefix)(#8)(#7)(#phrasefix)(phrasefix)(#5)(#2)(#phrasefix) Allotment holders cultivate (phrasefix)(#8)(#7)(#phrasefix)(phrasefix)(#5)(#2)(#phrasefix)", 299 | "Die älteste offizielle Karte Münchens fördert spannende Geschichten zu (phrasefix)(#0)(#3)(#phrasefix). 
The oldest official map of Munich brings captivating stories to (phrasefix)(#0)(#3)(#phrasefix).", 300 | "Es nervt, (phrasefix)(#2)(#0)(#phrasefix) nicht aktuell (phrasefix)(#5)(#8)(#phrasefix) It is annoying when geographical (phrasefix)(#2)(#0)(#phrasefix) up-to(phrasefix)(#5)(#8)(#phrasefix)", 301 | "Das kennt jeder, der sich schon mal aufregen musste, weil das (phrasefix)(#4)(#6)(#phrasefix) einer Umgehungsstraße eine grüne Wiese anzeigte. Anyone who has ever got worked up because the (phrasefix)(#4)(#6)(#phrasefix)nav is showing a green field instead of a bypass knows that.", 302 | "Die historischen Landkarten des digitalen Bayern(phrasefix)(#9)(#7)(#phrasefix)Atlases(phrasefix)(#9)(#1)(#phrasefix) ein Angebot (phrasefix)(#6)(#5)(#phrasefix), sind alles andere als aktuell - doch gerade deshalb sehr aufschlussreich. The historical maps of the digital BayernAtlas(phrasefix)(#9)(#1)(#phrasefix) an offering from the State (phrasefix)(#6)(#5)(#phrasefix) Bayern, are anything but (phrasefix)(#9)(#7)(#phrasefix) – and yet it is precisely for this reason that they are so informative.", 303 | "(phrasefix)(#8)(#3)(#phrasefix) (phrasefix)(#2)(#3)(#phrasefix) (phrasefix)(#7)(#2)(#phrasefix)(phrasefix)(#1)(#6)(#phrasefix)Karten vergleicht(phrasefix)(#2)(#0)(#phrasefix) (phrasefix)(#8)(#3)(#phrasefix) (phrasefix)(#2)(#3)(#phrasefix) (phrasefix)(#7)(#2)(#phrasefix) (phrasefix)(#1)(#6)(#phrasefix) current online maps(phrasefix)(#2)(#0)(#phrasefix)", 304 | "Dann wird deutlich, wie sich Städte und Gemeinden im Verbreitungsgebiet des Münchner (phrasefix)(#2)(#9)(#phrasefix) 19. Jahrhundert verändert haben. Then it becomes clear how the towns and municipalities in the distribution area of Munich's (phrasefix)(#2)(#9)(#phrasefix) changed since the 19th century.", 305 | ] 306 | 307 | 308 | # fake aligner object to be used in PhraseFixer test 309 | class FakeAligner: 310 | def __init__(self, alignment): 311 | self.alignment = alignment 312 | self.lineNo = 0 313 | 314 | # This produces a sequence of character span pairs of aligned char ranges from raw src to trg string 315 | def align_char_ranges(self, src, trg): 316 | ret = self.alignment[self.lineNo] 317 | self.lineNo += 1 318 | return ret 319 | 320 | 321 | def test_mixer(num_trials=100000): 322 | """Ensures that the mixer chooses from streams evenly.""" 323 | 324 | def gen(token): 325 | while True: 326 | yield token 327 | 328 | mixer = Mixer([gen("a"), gen("b"), gen("c")], [1 / 3, 1 / 3, 1 / 3]) 329 | 330 | counter = Counter() 331 | for i, value in enumerate(mixer): 332 | counter[value] += 1 333 | if i > num_trials: 334 | break 335 | 336 | values = counter.values() 337 | assert max(values) - min(values) <= 0.01 * num_trials 338 | --------------------------------------------------------------------------------
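The `Mixer` exercised in `test_mixer` above can also be driven directly as a plain iterator. The sketch below assumes only the call shape shown in that test, `Mixer(streams, weights)`; the stream contents, the 90/10 weighting, and the sample size are illustrative choices, not values taken from the repository.

```python
# Minimal usage sketch of Mixer, mirroring the pattern in test_mixer above.
# Assumes only what the test demonstrates: Mixer(list_of_streams, list_of_weights)
# is an iterator that draws items from the given streams according to the weights.
from itertools import islice

from sotastream.augmentors import Mixer


def constant_stream(token):
    """Infinite generator that keeps yielding the same token."""
    while True:
        yield token


# Two hypothetical streams mixed with uneven (illustrative) weights.
mixer = Mixer([constant_stream("parallel"), constant_stream("backtrans")], [0.9, 0.1])

# Take a small sample; on average about nine "parallel" items per "backtrans" item.
sample = list(islice(mixer, 20))
print(sample)
```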