├── .github └── workflows │ ├── main.yml │ └── pypi-publish.yml ├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── conf.py ├── index.rst ├── installation.rst ├── make.bat ├── quick_start.rst ├── requirements.txt ├── tutorials │ ├── remote_evaluation.rst │ ├── speech_to_speech.rst │ └── speech_to_text.rst └── user_guide │ ├── agent.rst │ ├── dataloader.rst │ ├── evaluator.rst │ └── introduction.rst ├── examples ├── __init__.py ├── demo │ └── silero_vad.py ├── quick_start │ ├── Dockerfile │ ├── agent_pipeline.py │ ├── agent_with_configs.py │ ├── agent_with_new_metrics.py │ ├── dict.txt │ ├── first_agent.py │ ├── readme.md │ ├── source.txt │ ├── spm_detokenizer_agent.py │ ├── spm_source.txt │ ├── spm_target.txt │ └── target.txt ├── speech_to_speech │ ├── english_alternate_agent.py │ ├── english_counter_agent.py │ ├── eval.sh │ ├── readme.md │ ├── reference │ │ ├── de.txt │ │ ├── en.txt │ │ ├── ja.txt │ │ ├── tgt_lang.txt │ │ └── zh.txt │ ├── source.txt │ └── test.wav ├── speech_to_speech_demo │ ├── english_counter_pipeline.py │ └── readme.md ├── speech_to_speech_text │ └── tree_agent_pipeline.py ├── speech_to_text │ ├── Dockerfile │ ├── counter_in_tgt_lang_agent.py │ ├── english_counter_agent.py │ ├── eval.sh │ ├── readme.md │ ├── reference │ │ ├── en.txt │ │ ├── tgt_lang.txt │ │ └── transcript.txt │ ├── source.txt │ ├── test.wav │ └── whisper_waitk.py └── speech_to_text_demo │ ├── counter_in_tgt_lang_pipeline.py │ ├── english_counter_pipeline.py │ └── readme.md ├── setup.cfg ├── setup.py └── simuleval ├── __init__.py ├── agents ├── __init__.py ├── actions.py ├── agent.py ├── pipeline.py ├── service.py └── states.py ├── analysis └── curve.py ├── cli.py ├── data ├── __init__.py ├── dataloader │ ├── __init__.py │ ├── dataloader.py │ ├── s2t_dataloader.py │ └── t2t_dataloader.py └── segments.py ├── evaluator ├── __init__.py ├── evaluator.py ├── instance.py ├── remote.py └── scorers │ ├── 
__init__.py │ ├── latency_scorer.py │ └── quality_scorer.py ├── options.py ├── test ├── test_agent.py ├── test_agent_pipeline.py ├── test_evaluator.py ├── test_remote_evaluation.py ├── test_s2s.py ├── test_s2t.py └── test_visualize.py └── utils ├── __init__.py ├── agent.py ├── arguments.py ├── functional.py ├── slurm.py └── visualize.py /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: build 5 | 6 | on: 7 | push: 8 | branches: 9 | - main 10 | pull_request: 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | python-version: [3.8] 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version}} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | sudo apt-get update 28 | sudo apt-get install libsndfile1 29 | sudo apt-get install portaudio19-dev 30 | python -m pip install --upgrade pip==24.0 31 | pip install flake8 pytest black 32 | pip install g2p-en 33 | pip install huggingface-hub 34 | pip install fairseq 35 | pip install sentencepiece 36 | pip install openai-whisper editdistance pyaudio silero-vad 37 | pip install -e . 38 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 39 | python -c "import nltk; nltk.download('averaged_perceptron_tagger_eng')" 40 | - name: Lint with black 41 | run: black --check --diff . 42 | - name: Lint with flake8 43 | run: | 44 | # stop the build if there are Python syntax errors or undefined names 45 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 46 | # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide 47 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 48 | - name: Test with pytest 49 | run: | 50 | pytest simuleval/test/test_agent.py 51 | pytest simuleval/test/test_agent_pipeline.py 52 | pytest simuleval/test/test_evaluator.py 53 | pytest simuleval/test/test_remote_evaluation.py 54 | pytest simuleval/test/test_s2s.py 55 | pytest simuleval/test/test_visualize.py 56 | -------------------------------------------------------------------------------- /.github/workflows/pypi-publish.yml: -------------------------------------------------------------------------------- 1 | name: PyPI Publish 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* 7 | jobs: 8 | pypi_publish: 9 | name: Upload release to PyPI 10 | runs-on: ubuntu-latest 11 | environment: 12 | name: pypi 13 | url: https://pypi.org/p/simuleval 14 | permissions: 15 | id-token: write 16 | 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v2 20 | 21 | - name: Set up Python 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: 3.x 25 | 26 | - name: Install dependencies 27 | run: pip install --upgrade pip setuptools wheel 28 | 29 | - name: Build package 30 | run: python setup.py sdist bdist_wheel 31 | 32 | - name: Publish package distributions to PyPI 33 | uses: pypa/gh-action-pypi-publish@release/v1 34 | with: 35 | args: "--use-feature=fast-deploy" -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 
29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | .vscode 140 | 141 | # Mac files 142 | .DS_Store -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | 1.0.0 (September 25, 2020) 2 | 3 | * Initial release. 4 | 5 | 1.0.1 (February 8, 2021) 6 | 7 | * Change CLI command 8 | * Change `simuleval-server` to `simuleval --server-only` 9 | * Change `simuleval-client` to `simuleval --client-only` 10 | * Fix some typos 11 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Open Source Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 
6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | Using welcoming and inclusive language 12 | Being respectful of differing viewpoints and experiences 13 | Gracefully accepting constructive criticism 14 | Focusing on what is best for the community 15 | Showing empathy towards other community members 16 | Examples of unacceptable behavior by participants include: 17 | 18 | The use of sexualized language or imagery and unwelcome sexual attention or advances 19 | Trolling, insulting/derogatory comments, and personal or political attacks 20 | Public or private harassment 21 | Publishing others’ private information, such as a physical or electronic address, without explicit permission 22 | Other conduct which could reasonably be considered inappropriate in a professional setting 23 | 24 | ## Our Responsibilities 25 | 26 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 27 | 28 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 29 | 30 | ## Scope 31 | 32 | This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 
33 | 34 | ## Enforcement 35 | 36 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at opensource-conduct@fb.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 37 | 38 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership. 39 | 40 | ## Attribution 41 | 42 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 43 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 44 | 45 | [homepage]: https://www.contributor-covenant.org 46 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Facebook AI SimulEval 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 14 | ## Issues 15 | We use GitHub issues to track public bugs. Please ensure your description is 16 | clear and has sufficient instructions to be able to reproduce the issue. 17 | 18 | ## License 19 | By contributing to Facebook AI SimulEval, you agree that your contributions will 20 | be licensed under the LICENSE file in the root directory of this source tree. 
21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SimulEval 2 | [![](https://github.com/facebookresearch/SimulEval/workflows/build/badge.svg)](https://github.com/facebookresearch/SimulEval/actions) 3 | 4 | SimulEval is a general evaluation framework for simultaneous translation on text and speech. Full documentation can be found [here](https://simuleval.readthedocs.io/en/v1.1.0/). 5 | 6 | ## Installation 7 | ``` 8 | git clone https://github.com/facebookresearch/SimulEval.git 9 | cd SimulEval 10 | pip install -e . 11 | ``` 12 | 13 | ## Quick Start 14 | Following is the evaluation of a [dummy agent](examples/quick_start) which operates wait-k (k = 3) policy and generates random words until the length of the generated words is the same as the number of all the source words. 15 | ```shell 16 | cd examples/quick_start 17 | simuleval --source source.txt --target target.txt --agent first_agent.py 18 | ``` 19 | 20 | # License 21 | 22 | SimulEval is licensed under Creative Commons BY-SA 4.0. 23 | 24 | # Citation 25 | 26 | Please cite as: 27 | 28 | ```bibtex 29 | @inproceedings{simuleval2020, 30 | title = {Simuleval: An evaluation toolkit for simultaneous translation}, 31 | author = {Xutai Ma, Mohammad Javad Dousti, Changhan Wang, Jiatao Gu, Juan Pino}, 32 | booktitle = {Proceedings of the EMNLP}, 33 | year = {2020}, 34 | } 35 | ``` 36 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Configuration file for the Sphinx documentation builder. 8 | # 9 | # For the full list of built-in configuration values, see the documentation: 10 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 11 | 12 | # -- Project information ----------------------------------------------------- 13 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 14 | 15 | project = "SimulEval" 16 | copyright = "Facebook AI Research (FAIR)" 17 | author = "Facebook AI Research (FAIR)" 18 | release = "1.1.0" 19 | 20 | # -- General configuration --------------------------------------------------- 21 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 22 | 23 | extensions = [ 24 | "sphinx_rtd_theme", 25 | "sphinx.ext.autodoc", 26 | "sphinx.ext.intersphinx", 27 | "sphinx.ext.viewcode", 28 | "sphinx.ext.napoleon", 29 | "sphinxarg.ext", 30 | ] 31 | 32 | # templates_path = ['_templates'] 33 | exclude_patterns = [] 34 | 35 | 36 | # -- Options for HTML output ------------------------------------------------- 37 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 38 | 39 | html_theme = "sphinx_rtd_theme" 40 | # 
html_static_path = ['_static'] 41 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | SimulEval documentation 2 | ======================= 3 | 4 | SimulEval is a general evaluation framework for simultaneous translation. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :glob: 9 | :caption: Get started 10 | 11 | installation 12 | quick_start 13 | 14 | .. toctree:: 15 | :maxdepth: 2 16 | :caption: User's guide 17 | 18 | user_guide/introduction 19 | user_guide/agent 20 | user_guide/evaluator 21 | user_guide/dataloader 22 | 23 | .. toctree:: 24 | :maxdepth: 2 25 | :caption: Tutorials 26 | 27 | tutorials/remote_evaluation 28 | tutorials/speech_to_text 29 | tutorials/speech_to_speech 30 | 31 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | From pip 5 | -------- 6 | 7 | .. code-block:: bash 8 | 9 | pip install simuleval 10 | 11 | 12 | From source 13 | ----------- 14 | 15 | .. code-block:: bash 16 | 17 | git clone https://github.com/facebookresearch/SimulEval.git 18 | cd SimulEval 19 | pip install -e . 20 | 21 | 22 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | :: Copyright (c) Facebook, Inc. and its affiliates. 2 | :: All rights reserved. 3 | :: 4 | :: This source code is licensed under the license found in the 5 | :: LICENSE file in the root directory of this source tree. 
6 | 7 | @ECHO OFF 8 | 9 | pushd %~dp0 10 | 11 | REM Command file for Sphinx documentation 12 | 13 | if "%SPHINXBUILD%" == "" ( 14 | set SPHINXBUILD=sphinx-build 15 | ) 16 | set SOURCEDIR=source 17 | set BUILDDIR=build 18 | 19 | %SPHINXBUILD% >NUL 2>NUL 20 | if errorlevel 9009 ( 21 | echo. 22 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 23 | echo.installed, then set the SPHINXBUILD environment variable to point 24 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 25 | echo.may add the Sphinx directory to PATH. 26 | echo. 27 | echo.If you don't have Sphinx installed, grab it from 28 | echo.https://www.sphinx-doc.org/ 29 | exit /b 1 30 | ) 31 | 32 | if "%1" == "" goto help 33 | 34 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 35 | goto end 36 | 37 | :help 38 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 39 | 40 | :end 41 | popd 42 | -------------------------------------------------------------------------------- /docs/quick_start.rst: -------------------------------------------------------------------------------- 1 | .. _first-agent: 2 | 3 | Quick Start 4 | =========== 5 | 6 | This section will introduce a minimal example on how to use SimulEval for simultaneous translation evaluation. 7 | The code in the example can be found in :code:`examples/quick_start`. 8 | 9 | The agent in SimulEval is core for simultaneous evaluation. 10 | It's a carrier of user's simultaneous system. 11 | The user has to implement the agent based on their system for evaluation. 12 | The example simultaneous system is a dummy wait-k agent, which 13 | 14 | - Runs `wait-k `_ policy. 15 | - Generates random characters the policy decide to write. 16 | - Stops the generation k predictions after source input. For simplicity, we just set :code:`k=3` in this example. 17 | 18 | The implementation of this agent is shown as follow. 19 | 20 | .. 
literalinclude:: ../examples/quick_start/first_agent.py 21 | :language: python 22 | :lines: 6- 23 | 24 | There two essential components for an agent: 25 | 26 | - :code:`states`: The attribute keeps track of the source and target information. 27 | - :code:`policy`: The method makes decisions when the there is a new source segment. 28 | 29 | Once the agent is implemented and saved at :code:`first_agent.py`, 30 | run the following command for latency evaluation on: 31 | 32 | .. code-block:: bash 33 | 34 | simuleval --source source.txt --reference target.txt --agent first_agent.py 35 | 36 | where :code:`--source` is the input file while :code:`--target` is the reference file. 37 | 38 | By default, the SimulEval will give the following output --- one quality and three latency metrics. 39 | 40 | .. code-block:: bash 41 | 42 | 2022-12-05 13:43:58 | INFO | simuleval.cli | Evaluate system: DummyWaitkTextAgent 43 | 2022-12-05 13:43:58 | INFO | simuleval.dataloader | Evaluating from text to text. 44 | 2022-12-05 13:43:58 | INFO | simuleval.sentence_level_evaluator | Results: 45 | BLEU AL AP DAL 46 | 1.541 3.0 0.688 3.0 47 | 48 | The average lagging is expected since we are running an wait-3 system where the source and target always have the same length. 49 | Notice that we have a very low yet random BLEU score. It's because we are randomly generate the output. 50 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-argparse==0.4.0 2 | -e . 3 | -------------------------------------------------------------------------------- /docs/tutorials/remote_evaluation.rst: -------------------------------------------------------------------------------- 1 | Remote Evaluation 2 | ================= 3 | 4 | Stand Alone Agent 5 | ----------------- 6 | The agent can run in stand alone mode, 7 | by using :code:`--standalone` option. 
SimulEval will kick off a server that hosts the agent.
54 | docker run -p 8888:8888 simuleval-speech-to-text:1.0 55 | 56 | Remote Evaluation 57 | ------------------ 58 | If there is an agent server or docker image available, 59 | (let's say the one we just kickoff at localhost:8888) 60 | We can start a remote evaluator as follow. For simplicity we assume they are on the same machine 61 | 62 | .. code-block:: bash 63 | 64 | simuleval --remote-eval --remote-port 8888 \ 65 | --source source.txt --target target.txt \ 66 | --source-type text --target-type text 67 | 68 | | 69 | For whisper agent's speech to text: 70 | 71 | .. code-block:: bash 72 | 73 | simuleval --remote-eval --remote-port 8888 \ 74 | --source-segment-size 500 \ 75 | --source source.txt --target reference/transcript.txt \ 76 | --source-type speech --target-type text \ 77 | --output output --quality-metrics WER -------------------------------------------------------------------------------- /docs/tutorials/speech_to_speech.rst: -------------------------------------------------------------------------------- 1 | Speech-to-Speech 2 | ================ -------------------------------------------------------------------------------- /docs/tutorials/speech_to_text.rst: -------------------------------------------------------------------------------- 1 | .. _speech-to-text: 2 | 3 | Speech-to-Text 4 | ============== 5 | 6 | Whisper Agent 7 | ----------------- 8 | Use whisper to evaluate custom audio for speech to text transcription. 9 | First, change directory to :code:`speech_to_text`: 10 | 11 | .. code-block:: bash 12 | 13 | cd examples/speech-to-text 14 | 15 | Then, run the example code: 16 | 17 | .. 
code-block:: bash 18 | 19 | simuleval \ 20 | --agent whisper_waitk.py \ 21 | --source-segment-size 500 \ 22 | --waitk-lagging 3 \ 23 | --source source.txt --target reference/transcript.txt \ 24 | --output output --quality-metrics WER --visualize 25 | 26 | The optional :code:`--visualize` tag generates N number of graphs in speech_to_text/output/visual directory where N corresponds to the number of source audio provided. An example graph can be seen `here `_. 27 | 28 | | 29 | In addition, it supports the :code:`--score-only` command, where it will read data from :code:`instances.log` without running inference, which saves time if you just want the scores. 30 | 31 | .. code-block:: bash 32 | 33 | simuleval --score-only --output output --visualize -------------------------------------------------------------------------------- /docs/user_guide/agent.rst: -------------------------------------------------------------------------------- 1 | Agent 2 | ===== 3 | 4 | To evaluate the simultaneous translation system, 5 | the users need to implement agent class which operate the system logics. 6 | This section will introduce how to implement an agent. 7 | 8 | Source-Target Types 9 | ------------------- 10 | First of all, 11 | we must declare the source and target types of the agent class. 12 | It can be done by inheriting from 13 | 14 | - One of the following four built-in agent types 15 | 16 | - :class:`simuleval.agents.TextToTextAgent` 17 | - :class:`simuleval.agents.SpeechToTextAgent` 18 | - :class:`simuleval.agents.TextToSpeechAgent` 19 | - :class:`simuleval.agents.SpeechToSpeechAgent` 20 | 21 | - Or :class:`simuleval.agents.GenericAgent`, with explicit declaration of :code:`source_type` and :code:`target_type`. 22 | 23 | The follow two examples are equivalent. 24 | 25 | .. 
code-block:: python 26 | 27 | from simuleval import simuleval 28 | from simuleval.agents import GenericAgent 29 | 30 | class MySpeechToTextAgent(GenericAgent): 31 | source_type = "Speech" 32 | target_type = "Text" 33 | .... 34 | 35 | .. code-block:: python 36 | 37 | from simuleval.agents import SpeechToSpeechAgent 38 | 39 | class MySpeechToTextAgent(SpeechToSpeechAgent): 40 | .... 41 | 42 | .. _agent_policy: 43 | 44 | Policy 45 | ------ 46 | 47 | The agent must have a :code:`policy` method which must return one of two actions, :code:`ReadAction` and :code:`WriteAction`. 48 | For example, an agent with a :code:`policy` method should look like this 49 | 50 | .. code-block:: python 51 | 52 | class MySpeechToTextAgent(SpeechToSpeechAgent): 53 | def policy(self): 54 | if do_we_need_more_input(self.states): 55 | return ReadAction() 56 | else: 57 | prediction = generate_a_token(self.states) 58 | finished = is_sentence_finished(self.states) 59 | return WriteAction(prediction, finished=finished) 60 | 61 | 62 | .. 63 | .. autoclass:: simuleval.agents.actions.WriteAction 64 | 65 | .. 66 | .. autoclass:: simuleval.agents.actions.ReadAction 67 | 68 | States 69 | ------------ 70 | Each agent has the attribute the :code:`states` to keep track of the progress of decoding. 71 | The :code:`states` attribute will be reset at the beginning of each sentence. 72 | SimulEval provide an built-in states :class:`simuleval.agents.states.AgentStates`, 73 | which has some basic attributes such source and target sequences. 74 | The users can also define customized states with :code:`Agent.build_states` method: 75 | 76 | .. 
code-block:: python 77 | 78 | from simuleval.agents.states import AgentStates 79 | from dataclasses import dataclass 80 | 81 | @dataclass 82 | class MyComplicatedStates(AgentStates) 83 | some_very_useful_variable: int 84 | 85 | def reset(self): 86 | super().reset() 87 | # also remember to reset the value 88 | some_very_useful_variable = 0 89 | 90 | class MySpeechToTextAgent(SpeechToSpeechAgent): 91 | def build_states(self): 92 | return MyComplicatedStates(0) 93 | 94 | def policy(self): 95 | some_very_useful_variable = self.states.some_very_useful_variable 96 | ... 97 | self.states.some_very_useful_variable = new_value 98 | ... 99 | 100 | .. 101 | .. autoclass:: simuleval.agents.states.AgentStates 102 | :members: 103 | 104 | 105 | Pipeline 106 | -------- 107 | The simultaneous system can consist several different components. 108 | For instance, a simultaneous speech-to-text translation can have a streaming automatic speech recognition system and simultaneous text-to-text translation system. 109 | SimulEval introduces the agent pipeline to support this function. 110 | The following is a minimal example. 111 | We concatenate two wait-k systems with different rates (:code:`k=2` and :code:`k=3`) 112 | Note that if there are more than one agent class define, 113 | the :code:`@entrypoint` decorator has to be used to determine the entry point, 114 | or `--user-dir` and `--agent-class` must be specified. See `simuleval/test/first_agent.py` 115 | for an example. 116 | 117 | .. literalinclude:: ../../examples/quick_start/agent_pipeline.py 118 | :language: python 119 | :lines: 7- 120 | 121 | Customized Arguments 122 | ----------------------- 123 | 124 | It is often the case that we need to pass some customized arguments for the system to configure different settings. 125 | The agent class has a built-in static method :code:`add_args` for this purpose. 126 | The following is an updated version of the dummy agent from :ref:`first-agent`. 127 | 128 | .. 
literalinclude:: ../../examples/quick_start/agent_with_configs.py 129 | :language: python 130 | :lines: 6- 131 | 132 | Then just simply pass the arguments through command line as follow. 133 | 134 | .. code-block:: bash 135 | 136 | simuleval \ 137 | --source source.txt --source target.txt \ # data arguments 138 | --agent dummy_waitk_text_agent_v2.py \ 139 | --waitk 3 --vocab data/dict.txt # agent arguments 140 | 141 | Load Agents from Python Class 142 | ----------------------------- 143 | 144 | If you have the agent class in the python environment, for instance 145 | 146 | .. literalinclude:: ../../examples/quick_start/agent_with_configs.py 147 | :language: python 148 | :lines: 6- 149 | 150 | You can also start the evaluation with following command 151 | 152 | .. code-block:: bash 153 | 154 | simuleval \ 155 | --source source.txt --source target.txt \ # data arguments 156 | --agent-class DummyWaitkTextAgent \ 157 | --waitk 3 --vocab data/dict.txt # agent arguments 158 | 159 | 160 | Load Agents from Directory 161 | -------------------------- 162 | 163 | Agent can also be loaded from a directory, which will be referred to as system directory. 164 | The system directory should have everything required to start the agent. Again use the following agent as example 165 | 166 | .. literalinclude:: ../../examples/quick_start/agent_with_configs.py 167 | :language: python 168 | :lines: 6- 169 | 170 | and the system directory has 171 | 172 | .. code-block:: bash 173 | 174 | > ls ${system_dir} 175 | main.yaml dict.txt 176 | 177 | Where the `main.yaml` has all the command line options. The path will be the relative path to the `${system_dir}`. 178 | 179 | .. code-block:: yaml 180 | 181 | waitk: 3 182 | vocab: dict.txt 183 | 184 | The agent can then be started as following 185 | 186 | .. code-block:: bash 187 | 188 | simuleval \ 189 | --source source.txt --source target.txt \ # data arguments 190 | --system-dir ${system_dir} 191 | 192 | By default, the `main.yaml` will be read. 
--source source.txt --target target.txt \ # data arguments
The evaluation in SimulEval is implemented as the Evaluator shown below.
Simultaneous evaluation introduces a front-end / back-end setup, shown as follows.
| self.save_audio, 27 | self.read_audio, 28 | self.VADIterator, 29 | self.collect_chunks, 30 | ) = utils 31 | self.silence_limit_ms = args.silence_limit_ms 32 | self.window_size_samples = args.window_size_samples 33 | self.chunk_size_samples = args.chunk_size_samples 34 | self.sample_rate = args.sample_rate 35 | self.debug = args.debug 36 | self.test_input_segments_wav = None 37 | self.debug_log(args) 38 | self.input_queue: queue.Queue[Segment] = queue.Queue() 39 | self.next_input_queue: queue.Queue[Segment] = queue.Queue() 40 | super().__init__() 41 | 42 | def clear_queues(self): 43 | self.debug_log(f"clearing {self.input_queue.qsize()} chunks") 44 | while not self.input_queue.empty(): 45 | self.input_queue.get_nowait() 46 | self.input_queue.task_done() 47 | self.debug_log(f"moving {self.next_input_queue.qsize()} chunks") 48 | # move everything from next_input_queue to input_queue 49 | while not self.next_input_queue.empty(): 50 | chunk = self.next_input_queue.get_nowait() 51 | self.next_input_queue.task_done() 52 | self.input_queue.put_nowait(chunk) 53 | 54 | def reset(self) -> None: 55 | super().reset() 56 | # TODO: in seamless_server, report latency for each new segment 57 | self.first_input_ts = None 58 | self.silence_acc_ms = 0 59 | self.input_chunk = np.empty(0, dtype=np.int16) 60 | self.is_fresh_state = True 61 | self.clear_queues() 62 | self.model.reset_states() 63 | 64 | def get_speech_prob_from_np_float32(self, segment: np.ndarray): 65 | t = torch.from_numpy(segment) 66 | speech_probs = [] 67 | # print("len(t): ", len(t)) 68 | for i in range(0, len(t), self.window_size_samples): 69 | chunk = t[i : i + self.window_size_samples] 70 | if len(chunk) < self.window_size_samples: 71 | break 72 | speech_prob = self.model(chunk, self.sample_rate).item() 73 | speech_probs.append(speech_prob) 74 | return speech_probs 75 | 76 | def debug_log(self, m): 77 | if self.debug: 78 | logger.info(m) 79 | 80 | def process_speech(self, segment): 81 | """ 82 | Process a full or 
partial speech chunk 83 | """ 84 | queue = self.input_queue 85 | if self.source_finished: 86 | # current source is finished, but next speech starts to come in already 87 | self.debug_log("use next_input_queue") 88 | queue = self.next_input_queue 89 | 90 | # NOTE: we don't reset silence_acc_ms here so that once an utterance 91 | # becomes longer (accumulating more silence), it has a higher chance 92 | # of being segmented. 93 | # self.silence_acc_ms = 0 94 | 95 | if self.first_input_ts is None: 96 | self.first_input_ts = time.time() * 1000 97 | 98 | while len(segment) > 0: 99 | # add chunks to states.buffer 100 | i = self.chunk_size_samples - len(self.input_chunk) 101 | self.input_chunk = np.concatenate((self.input_chunk, segment[:i])) 102 | segment = segment[i:] 103 | self.is_fresh_state = False 104 | if len(self.input_chunk) == self.chunk_size_samples: 105 | queue.put_nowait( 106 | SpeechSegment(content=self.input_chunk, finished=False) 107 | ) 108 | self.input_chunk = np.empty(0, dtype=np.int16) 109 | 110 | def check_silence_acc(self): 111 | if self.silence_acc_ms >= self.silence_limit_ms: 112 | self.silence_acc_ms = 0 113 | if self.input_chunk.size > 0: 114 | # flush partial input_chunk 115 | self.input_queue.put_nowait( 116 | SpeechSegment(content=self.input_chunk, finished=True) 117 | ) 118 | self.input_chunk = np.empty(0, dtype=np.int16) 119 | self.input_queue.put_nowait(EmptySegment(finished=True)) 120 | self.source_finished = True 121 | 122 | def update_source(self, segment: np.ndarray): 123 | speech_probs = self.get_speech_prob_from_np_float32(segment) 124 | chunk_size_ms = len(segment) * 1000 / self.sample_rate 125 | self.debug_log( 126 | f"{chunk_size_ms}, {len(speech_probs)} {[round(s, 2) for s in speech_probs]}" 127 | ) 128 | window_size_ms = self.window_size_samples * 1000 / self.sample_rate 129 | if all(i <= 0.5 for i in speech_probs): 130 | if self.source_finished: 131 | return 132 | self.debug_log("got silent chunk") 133 | if not 
self.is_fresh_state: 134 | self.silence_acc_ms += chunk_size_ms 135 | self.check_silence_acc() 136 | return 137 | elif speech_probs[-1] <= 0.5: 138 | self.debug_log("=== start of silence chunk") 139 | # beginning = speech, end = silence 140 | # pass to process_speech and accumulate silence 141 | self.process_speech(segment) 142 | # accumulate contiguous silence 143 | for i in range(len(speech_probs) - 1, -1, -1): 144 | if speech_probs[i] > 0.5: 145 | break 146 | self.silence_acc_ms += window_size_ms 147 | self.check_silence_acc() 148 | elif speech_probs[0] <= 0.5: 149 | self.debug_log("=== start of speech chunk") 150 | # beginning = silence, end = speech 151 | # accumulate silence , pass next to process_speech 152 | for i in range(0, len(speech_probs)): 153 | if speech_probs[i] > 0.5: 154 | break 155 | self.silence_acc_ms += window_size_ms 156 | self.check_silence_acc() 157 | self.process_speech(segment) 158 | else: 159 | self.debug_log("======== got speech chunk") 160 | self.process_speech(segment) 161 | 162 | def debug_write_wav(self, chunk): 163 | if self.test_input_segments_wav is not None: 164 | self.test_input_segments_wav.seek(0, soundfile.SEEK_END) 165 | self.test_input_segments_wav.write(chunk) 166 | 167 | 168 | class SileroVADAgent(SpeechToSpeechAgent): 169 | def __init__(self, args: Namespace) -> None: 170 | super().__init__(args) 171 | self.chunk_size_samples = args.chunk_size_samples 172 | self.args = args 173 | 174 | @staticmethod 175 | def add_args(parser: ArgumentParser): 176 | parser.add_argument( 177 | "--sample-rate", 178 | default=16000, 179 | type=float, 180 | ) 181 | parser.add_argument( 182 | "--window-size-samples", 183 | default=512, # sampling_rate // 1000 * 32 => 32 ms at 16000 sample rate 184 | type=int, 185 | help="Window size for passing samples to VAD", 186 | ) 187 | parser.add_argument( 188 | "--chunk-size-samples", 189 | default=5120, # sampling_rate // 1000 * 320 => 320 ms at 16000 sample rate 190 | type=int, 191 | help="Chunk size 
for passing samples to model", 192 | ) 193 | parser.add_argument( 194 | "--silence-limit-ms", 195 | default=700, 196 | type=int, 197 | help="send EOS to the input_queue after this amount of silence", 198 | ) 199 | parser.add_argument( 200 | "--debug", 201 | default=False, 202 | type=bool, 203 | help="Enable debug logs", 204 | ) 205 | 206 | def build_states(self) -> SileroVADStates: 207 | return SileroVADStates(self.args) 208 | 209 | def policy(self, states: SileroVADStates): 210 | states.debug_log( 211 | f"queue size: {states.input_queue.qsize()}, input_chunk size: {len(states.input_chunk)}" 212 | ) 213 | content = np.empty(0, dtype=np.int16) 214 | is_finished = states.source_finished 215 | while not states.input_queue.empty(): 216 | chunk = states.input_queue.get_nowait() 217 | states.input_queue.task_done() 218 | content = np.concatenate((content, chunk.content)) 219 | 220 | states.debug_write_wav(content) 221 | if is_finished: 222 | states.debug_write_wav(np.zeros(16000)) 223 | 224 | if len(content) == 0: # empty queue 225 | if not states.source_finished: 226 | return ReadAction() 227 | else: 228 | # NOTE: this should never happen, this logic is a safeguard 229 | segment = EmptySegment(finished=True) 230 | else: 231 | segment = SpeechSegment( 232 | content=content.tolist(), 233 | finished=is_finished, 234 | sample_rate=states.sample_rate, 235 | ) 236 | 237 | return WriteAction(segment, finished=is_finished) 238 | -------------------------------------------------------------------------------- /examples/quick_start/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8 2 | RUN apt-get update \ 3 | && apt-get upgrade -y \ 4 | && apt-get install -y \ 5 | && apt-get -y install apt-utils gcc libpq-dev libsndfile-dev 6 | RUN git clone https://github.com/facebookresearch/SimulEval.git 7 | WORKDIR SimulEval 8 | RUN git checkout v1.1.0 9 | RUN pip install -e . 
10 | CMD ["simuleval", "--standalone", "--remote-port", "8888", "--agent", "examples/quick_start/first_agent.py"] 11 | -------------------------------------------------------------------------------- /examples/quick_start/agent_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import random 8 | from simuleval.utils import entrypoint 9 | from simuleval.agents import TextToTextAgent 10 | from simuleval.agents.actions import ReadAction, WriteAction 11 | from simuleval.agents import AgentPipeline 12 | 13 | 14 | class DummyWaitkTextAgent(TextToTextAgent): 15 | waitk = 0 16 | vocab = [chr(i) for i in range(ord("A"), ord("Z") + 1)] 17 | 18 | def policy(self): 19 | lagging = len(self.states.source) - len(self.states.target) 20 | 21 | if lagging >= self.waitk or self.states.source_finished: 22 | prediction = random.choice(self.vocab) 23 | 24 | return WriteAction(prediction, finished=(lagging <= 1)) 25 | else: 26 | return ReadAction() 27 | 28 | 29 | class DummyWait2TextAgent(DummyWaitkTextAgent): 30 | waitk = 2 31 | 32 | 33 | class DummyWait4TextAgent(DummyWaitkTextAgent): 34 | waitk = 4 35 | 36 | 37 | @entrypoint 38 | class DummyPipeline(AgentPipeline): 39 | pipeline = [DummyWait2TextAgent, DummyWait4TextAgent] 40 | -------------------------------------------------------------------------------- /examples/quick_start/agent_with_configs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import random 8 | from simuleval.utils import entrypoint 9 | from simuleval.agents import TextToTextAgent 10 | from simuleval.agents.actions import ReadAction, WriteAction 11 | from argparse import Namespace, ArgumentParser 12 | 13 | 14 | @entrypoint 15 | class DummyWaitkTextAgent(TextToTextAgent): 16 | def __init__(self, args: Namespace): 17 | """Initialize your agent here. 18 | For example loading model, vocab, etc 19 | """ 20 | super().__init__(args) 21 | self.waitk = args.waitk 22 | with open(args.vocab) as f: 23 | self.vocab = [line.strip() for line in f] 24 | 25 | @staticmethod 26 | def add_args(parser: ArgumentParser): 27 | """Add customized command line arguments""" 28 | parser.add_argument("--waitk", type=int, default=3) 29 | parser.add_argument("--vocab", type=str) 30 | 31 | def policy(self): 32 | lagging = len(self.states.source) - len(self.states.target) 33 | 34 | if lagging >= self.waitk or self.states.source_finished: 35 | prediction = random.choice(self.vocab) 36 | 37 | return WriteAction(prediction, finished=(lagging <= 1)) 38 | else: 39 | return ReadAction() 40 | -------------------------------------------------------------------------------- /examples/quick_start/agent_with_new_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import random 8 | from statistics import mean 9 | from simuleval.utils import entrypoint 10 | from simuleval.evaluator.scorers.latency_scorer import ( 11 | register_latency_scorer, 12 | LatencyScorer, 13 | ) 14 | from simuleval.agents import TextToTextAgent 15 | from simuleval.agents.actions import ReadAction, WriteAction 16 | 17 | 18 | @register_latency_scorer("RTF") 19 | class RTFScorer(LatencyScorer): 20 | """Real time factor 21 | 22 | Usage: 23 | --latency-metrics RTF 24 | """ 25 | 26 | def __call__(self, instances) -> float: 27 | scores = [] 28 | for ins in instances.values(): 29 | scores.append(ins.delays[-1] / ins.source_length) 30 | 31 | return mean(scores) 32 | 33 | 34 | @entrypoint 35 | class DummyWaitkTextAgent(TextToTextAgent): 36 | waitk = 3 37 | vocab = [chr(i) for i in range(ord("A"), ord("Z") + 1)] 38 | 39 | def policy(self): 40 | lagging = len(self.states.source) - len(self.states.target) 41 | 42 | if lagging >= self.waitk or self.states.source_finished: 43 | prediction = random.choice(self.vocab) 44 | 45 | return WriteAction(prediction, finished=(lagging <= 1)) 46 | else: 47 | return ReadAction() 48 | -------------------------------------------------------------------------------- /examples/quick_start/dict.txt: -------------------------------------------------------------------------------- 1 | A 2 | B 3 | C 4 | D 5 | E 6 | F 7 | G 8 | H 9 | I 10 | J 11 | K 12 | L 13 | M 14 | N 15 | O 16 | P 17 | Q 18 | R 19 | S 20 | T 21 | U 22 | V 23 | W 24 | X 25 | Y 26 | Z 27 | -------------------------------------------------------------------------------- /examples/quick_start/first_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import random 8 | from simuleval.utils import entrypoint 9 | from simuleval.agents import TextToTextAgent 10 | from simuleval.agents.actions import ReadAction, WriteAction 11 | 12 | 13 | @entrypoint 14 | class DummyWaitkTextAgent(TextToTextAgent): 15 | waitk = 3 16 | vocab = [chr(i) for i in range(ord("A"), ord("Z") + 1)] 17 | 18 | def policy(self): 19 | lagging = len(self.states.source) - len(self.states.target) 20 | 21 | if lagging >= self.waitk or self.states.source_finished: 22 | prediction = random.choice(self.vocab) 23 | 24 | return WriteAction(prediction, finished=(lagging <= 1)) 25 | else: 26 | return ReadAction() 27 | -------------------------------------------------------------------------------- /examples/quick_start/readme.md: -------------------------------------------------------------------------------- 1 | # Quick Start 2 | Following are some minimal examples to use SimulEval. More details can be found [here](https://simuleval.readthedocs.io/en/v1.1.0/quick_start.html). 3 | 4 | ## First Agent 5 | To evaluate a text-to-text wait-3 system with random output: 6 | 7 | ``` 8 | > simuleval --source source.txt --target target.txt --agent first_agent.py 9 | 10 | 2022-12-05 13:43:58 | INFO | simuleval.cli | Evaluate system: DummyWaitkTextAgent 11 | 2022-12-05 13:43:58 | INFO | simuleval.dataloader | Evaluating from text to text. 
12 | 2022-12-05 13:43:58 | INFO | simuleval.sentence_level_evaluator | Results: 13 | BLEU AL AP DAL 14 | 1.541 3.0 0.688 3.0 15 | 16 | ``` 17 | 18 | ## Agent with Command Line Arguments 19 | ``` 20 | simuleval --source source.txt --target target.txt --agent agent_with_configs.py --waitk 3 --vocab dict.txt 21 | ``` 22 | 23 | ## Agent Pipeline 24 | ``` 25 | simuleval --source source.txt --target target.txt --agent agent_pipeline.py 26 | ``` 27 | 28 | ## Agent with New Metrics 29 | ``` 30 | simuleval --source source.txt --target target.txt --agent agent_with_new_metrics.py 31 | ``` 32 | 33 | ## Standalone Agent & Remote Evaluation 34 | Start an agent server: 35 | ``` 36 | simuleval --standalone --remote-port 8888 --agent agent_with_new_metrics.py 37 | ``` 38 | Or with docker 39 | ``` 40 | docker build -t simuleval_agent . 41 | docker run -p 8888:8888 simuleval_agent:latest 42 | ``` 43 | 44 | Start a remote evaluator: 45 | ``` 46 | simuleval --remote-eval --source source.txt --target target.txt --source-type text --target-type text --remote-port 8888 47 | ``` 48 | -------------------------------------------------------------------------------- /examples/quick_start/source.txt: -------------------------------------------------------------------------------- 1 | Z U S N B Y X Q L O T 2 | M A J F P G O Y V R H M Z O T A 3 | M O W A O I D H H B O F 4 | N Q N I P C O H A A 5 | G B O J H P W C I A L V 6 | P T Z D E E N T B Y G Z R K 7 | F S H U K R W K S B R K M B B Q F C O U 8 | M H O L W Z G J Y X J B I 9 | A V B F E S F E W Q C S 10 | F N O I E Z B R S C V N S 11 | -------------------------------------------------------------------------------- /examples/quick_start/spm_detokenizer_agent.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from fairseq.data.encoders import build_bpe 4 | 5 | from simuleval.agents import TextToTextAgent 6 | from simuleval.agents.actions import ReadAction, 
WriteAction 7 | from simuleval.agents.pipeline import AgentPipeline 8 | from simuleval.agents.states import AgentStates 9 | 10 | 11 | class DummySegmentAgent(TextToTextAgent): 12 | """ 13 | This agent just splits on space 14 | """ 15 | 16 | def __init__(self, args): 17 | super().__init__(args) 18 | self.segment_k = args.segment_k 19 | 20 | @classmethod 21 | def from_args(cls, args, **kwargs): 22 | return cls(args) 23 | 24 | def add_args(parser: ArgumentParser): 25 | parser.add_argument( 26 | "--segment-k", 27 | type=int, 28 | help="Output segments with this many words", 29 | required=True, 30 | ) 31 | 32 | def policy(self, states: AgentStates): 33 | if len(states.source) == self.segment_k or states.source_finished: 34 | out = " ".join(states.source) 35 | states.source = [] 36 | return WriteAction(out, finished=states.source_finished) 37 | return ReadAction() 38 | 39 | 40 | class SentencePieceModelDetokenizerAgent(TextToTextAgent): 41 | def __init__(self, args): 42 | super().__init__(args) 43 | self.args.bpe = "sentencepiece" 44 | spm_processor = build_bpe(self.args) 45 | self.spm_processor = spm_processor 46 | self.detokenize_only = args.detokenize_only 47 | 48 | @classmethod 49 | def from_args(cls, args, **kwargs): 50 | return cls(args) 51 | 52 | def add_args(parser: ArgumentParser): 53 | parser.add_argument( 54 | "--sentencepiece-model", 55 | type=str, 56 | help="Path to sentencepiece model.", 57 | required=True, 58 | ) 59 | parser.add_argument( 60 | "--detokenize-only", 61 | action="store_true", 62 | default=False, 63 | help=( 64 | "Run detokenization without waiting for new token. 
By default(False)," 65 | "wait for beginning of next word before finalizing the previous word" 66 | ), 67 | ) 68 | 69 | def policy(self, states: AgentStates): 70 | possible_full_words = self.spm_processor.decode( 71 | " ".join([x for x in states.source]) 72 | ) 73 | 74 | if self.detokenize_only and len(states.source) > 0: 75 | states.source = [] 76 | if len(possible_full_words) == 0 and not states.source_finished: 77 | return ReadAction() 78 | else: 79 | return WriteAction(possible_full_words, states.source_finished) 80 | 81 | if states.source_finished: 82 | return WriteAction(possible_full_words, True) 83 | elif len(possible_full_words.split()) > 1: 84 | full_words, last_word = ( 85 | possible_full_words.split()[:-1], 86 | possible_full_words.split()[-1], 87 | ) 88 | states.source = [self.spm_processor.encode(last_word)] 89 | return WriteAction(" ".join(full_words), finished=False) 90 | else: 91 | return ReadAction() 92 | 93 | 94 | class DummyPipeline(AgentPipeline): 95 | pipeline = [DummySegmentAgent, SentencePieceModelDetokenizerAgent] 96 | -------------------------------------------------------------------------------- /examples/quick_start/spm_source.txt: -------------------------------------------------------------------------------- 1 | ▁Let ' s ▁do ▁it ▁with out ▁hesitation . -------------------------------------------------------------------------------- /examples/quick_start/spm_target.txt: -------------------------------------------------------------------------------- 1 | Let's do it without hesitation. 
-------------------------------------------------------------------------------- /examples/quick_start/target.txt: -------------------------------------------------------------------------------- 1 | Z U S N B Y X Q L O T 2 | M A J F P G O Y V R H M Z O T A 3 | M O W A O I D H H B O F 4 | N Q N I P C O H A A 5 | G B O J H P W C I A L V 6 | P T Z D E E N T B Y G Z R K 7 | F S H U K R W K S B R K M B B Q F C O U 8 | M H O L W Z G J Y X J B I 9 | A V B F E S F E W Q C S 10 | F N O I E Z B R S C V N S -------------------------------------------------------------------------------- /examples/speech_to_speech/english_alternate_agent.py: -------------------------------------------------------------------------------- 1 | from simuleval.utils import entrypoint 2 | from simuleval.data.segments import SpeechSegment 3 | from simuleval.agents import SpeechToSpeechAgent 4 | from simuleval.agents.actions import WriteAction, ReadAction 5 | from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub 6 | from fairseq.models.text_to_speech.hub_interface import TTSHubInterface 7 | 8 | 9 | class TTSModel: 10 | def __init__(self): 11 | models, cfg, task = load_model_ensemble_and_task_from_hf_hub( 12 | "facebook/fastspeech2-en-ljspeech", 13 | arg_overrides={"vocoder": "hifigan", "fp16": False}, 14 | ) 15 | TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg) 16 | self.tts_generator = task.build_generator(models, cfg) 17 | self.tts_task = task 18 | self.tts_model = models[0] 19 | self.tts_model.to("cpu") 20 | self.tts_generator.vocoder.to("cpu") 21 | 22 | def synthesize(self, text): 23 | sample = TTSHubInterface.get_model_input(self.tts_task, text) 24 | if sample["net_input"]["src_lengths"][0] == 0: 25 | return [], 0 26 | for key in sample["net_input"].keys(): 27 | if sample["net_input"][key] is not None: 28 | sample["net_input"][key] = sample["net_input"][key].to("cpu") 29 | 30 | wav, rate = TTSHubInterface.get_prediction( 31 | self.tts_task, self.tts_model, 
self.tts_generator, sample 32 | ) 33 | wav = wav.tolist() 34 | return wav, rate 35 | 36 | 37 | @entrypoint 38 | class EnglishAlternateAgent(SpeechToSpeechAgent): 39 | """ 40 | Incrementally feed text to this offline Fastspeech2 TTS model, 41 | with an alternating speech pattern that is decrementing. 42 | """ 43 | 44 | def __init__(self, args): 45 | super().__init__(args) 46 | self.wait_seconds = args.wait_seconds 47 | self.tts_model = TTSModel() 48 | 49 | @staticmethod 50 | def add_args(parser): 51 | parser.add_argument("--wait-seconds", default=1, type=int) 52 | 53 | def policy(self): 54 | length_in_seconds = round( 55 | len(self.states.source) / self.states.source_sample_rate 56 | ) 57 | if not self.states.source_finished and length_in_seconds < self.wait_seconds: 58 | return ReadAction() 59 | if length_in_seconds % 2 == 0: 60 | samples, fs = self.tts_model.synthesize( 61 | f"{8 - length_in_seconds} even even" 62 | ) 63 | else: 64 | samples, fs = self.tts_model.synthesize(f"{8 - length_in_seconds} odd odd") 65 | 66 | # A SpeechSegment has to be returned for speech-to-speech translation system 67 | return WriteAction( 68 | SpeechSegment( 69 | content=samples, 70 | sample_rate=fs, 71 | finished=self.states.source_finished, 72 | ), 73 | finished=self.states.source_finished, 74 | ) 75 | -------------------------------------------------------------------------------- /examples/speech_to_speech/english_counter_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from simuleval.agents.states import AgentStates 3 | from simuleval.utils import entrypoint 4 | from simuleval.data.segments import SpeechSegment 5 | from simuleval.agents import SpeechToSpeechAgent 6 | from simuleval.agents.actions import WriteAction, ReadAction 7 | from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub 8 | from fairseq.models.text_to_speech.hub_interface import TTSHubInterface 9 | 10 | 11 | class 
TTSModel: 12 | def __init__(self): 13 | models, cfg, task = load_model_ensemble_and_task_from_hf_hub( 14 | "facebook/fastspeech2-en-ljspeech", 15 | arg_overrides={"vocoder": "hifigan", "fp16": False}, 16 | ) 17 | TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg) 18 | self.tts_generator = task.build_generator(models, cfg) 19 | self.tts_task = task 20 | self.tts_model = models[0] 21 | self.tts_model.to("cpu") 22 | self.tts_generator.vocoder.to("cpu") 23 | 24 | def synthesize(self, text): 25 | sample = TTSHubInterface.get_model_input(self.tts_task, text) 26 | if sample["net_input"]["src_lengths"][0] == 0: 27 | return [], 0 28 | for key in sample["net_input"].keys(): 29 | if sample["net_input"][key] is not None: 30 | sample["net_input"][key] = sample["net_input"][key].to("cpu") 31 | 32 | wav, rate = TTSHubInterface.get_prediction( 33 | self.tts_task, self.tts_model, self.tts_generator, sample 34 | ) 35 | wav = wav.tolist() 36 | return wav, rate 37 | 38 | 39 | @entrypoint 40 | class EnglishSpeechCounter(SpeechToSpeechAgent): 41 | """ 42 | Incrementally feed text to this offline Fastspeech2 TTS model, 43 | with a minimum numbers of phonemes every chunk. 
44 | """ 45 | 46 | def __init__(self, args): 47 | super().__init__(args) 48 | self.wait_seconds = args.wait_seconds 49 | self.tts_model = TTSModel() 50 | 51 | @staticmethod 52 | def add_args(parser): 53 | parser.add_argument("--wait-seconds", default=1, type=int) 54 | 55 | def policy(self, states: Optional[AgentStates] = None): 56 | if states is None: 57 | states = self.states 58 | if states.source_sample_rate == 0: 59 | # empty source, source_sample_rate not set yet 60 | length_in_seconds = 0 61 | else: 62 | length_in_seconds = round(len(states.source) / states.source_sample_rate) 63 | if not states.source_finished and length_in_seconds < self.wait_seconds: 64 | return ReadAction() 65 | samples, fs = self.tts_model.synthesize(f"{length_in_seconds} mississippi") 66 | 67 | # A SpeechSegment has to be returned for speech-to-speech translation system 68 | return WriteAction( 69 | SpeechSegment( 70 | content=samples, 71 | sample_rate=fs, 72 | finished=states.source_finished, 73 | ), 74 | finished=states.source_finished, 75 | ) 76 | -------------------------------------------------------------------------------- /examples/speech_to_speech/eval.sh: -------------------------------------------------------------------------------- 1 | simuleval \ 2 | --agent english_counter_agent.py --output output \ 3 | --source source.txt --target reference/en.txt --source-segment-size 1000\ 4 | --quality-metrics WHISPER_ASR_BLEU \ 5 | --target-speech-lang en --transcript-lowercase --transcript-non-punctuation --whisper-model-size large \ 6 | --latency-metrics StartOffset EndOffset ATD 7 | -------------------------------------------------------------------------------- /examples/speech_to_speech/readme.md: -------------------------------------------------------------------------------- 1 | ## Simultaneous Speech-to-Speech Translation 2 | 3 | This tutorial provides a minimal example on how to evaluate a simultaneous speech-to-speech translation system. 
4 | 5 | ### Requirements 6 | 7 | To run this example, the following package is required 8 | 9 | - [`whisper`](https://github.com/openai/whisper): for quality evaluation (`WHISPER_ASR_BLEU`). 10 | 11 | ### Agent 12 | 13 | The speech-to-speech agent ([english_counter_agent.py](english_counter_agent.py)) in this example is a counter, which generates a piece of audio every second after an initial wait. 14 | The policy of the agent is show follow. The agent will wait for `self.wait_seconds` seconds, 15 | and generate the audio of `{length_in_seconds} mississippi` every second afterward. 16 | 17 | ```python 18 | def policy(self): 19 | length_in_seconds = round( 20 | len(self.states.source) / self.states.source_sample_rate 21 | ) 22 | if not self.states.source_finished and length_in_seconds < self.wait_seconds: 23 | return ReadAction() 24 | print(length_in_seconds) 25 | samples, fs = self.tts_model.synthesize(f"{length_in_seconds} mississippi") 26 | 27 | # A SpeechSegment has to be returned for speech-to-speech translation system 28 | return WriteAction( 29 | SpeechSegment( 30 | content=samples, 31 | sample_rate=fs, 32 | finished=self.states.source_finished, 33 | ), 34 | finished=self.states.source_finished, 35 | ) 36 | ``` 37 | 38 | Notice that for speech output agent, the `WriteAction` has to contain a `SpeechSegment` class. 39 | 40 | ### Evaluation 41 | 42 | The following command will start an evaluation 43 | 44 | ```bash 45 | simuleval \ 46 | --agent english_counter_agent.py --output output \ 47 | --source source.txt --target reference/en.txt --source-segment-size 1000\ 48 | --quality-metrics WHISPER_ASR_BLEU \ 49 | --target-speech-lang en --transcript-lowercase --transcript-non-punctuation\ 50 | --latency-metrics StartOffset EndOffset ATD 51 | ``` 52 | 53 | For quality evaluation, we use ASR_BLEU, that is transcribing the speech output and compute BLEU score with the reference text. To use this feature, `whisper` has to be installed. 
54 | 55 | We use three metrics for latency evaluation 56 | 57 | - `StartOffset`: The starting offset of translation comparing with source audio 58 | - `EndOffset`: The ending offset of translation comparing with source audio 59 | - `ATD`: Average Token Delay 60 | 61 | The results of the evaluation should be as following. The transcripts and alignments can be found in the `output` directory. 62 | 63 | ``` 64 | WHISPER_ASR_BLEU StartOffset EndOffset ATD 65 | 100.0 1000.0 1490.703 1248.261 66 | ``` 67 | -------------------------------------------------------------------------------- /examples/speech_to_speech/reference/de.txt: -------------------------------------------------------------------------------- 1 | ein Mississippi zwei Mississippi drei Mississippi vier Mississippi fünf Mississippi sechs Mississippi sieben Mississippi 2 | -------------------------------------------------------------------------------- /examples/speech_to_speech/reference/en.txt: -------------------------------------------------------------------------------- 1 | one mississippi two mississippi three mississippi four mississippi five mississippi six mississippi seven mississippi 2 | -------------------------------------------------------------------------------- /examples/speech_to_speech/reference/ja.txt: -------------------------------------------------------------------------------- 1 | 1 ミシシッピ 2 ミシシッピ 3 ミシシッピ 4 ミシシッピ 5 ミシシッピ 6 ミシシッピ 7 ミシシッピ 2 | -------------------------------------------------------------------------------- /examples/speech_to_speech/reference/tgt_lang.txt: -------------------------------------------------------------------------------- 1 | es -------------------------------------------------------------------------------- /examples/speech_to_speech/reference/zh.txt: -------------------------------------------------------------------------------- 1 | 一密西西比二密西西比三密西西比四密西西比五密西西比六密西西比七密西西比 2 | -------------------------------------------------------------------------------- 
/examples/speech_to_speech/source.txt: -------------------------------------------------------------------------------- 1 | test.wav 2 | -------------------------------------------------------------------------------- /examples/speech_to_speech/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SimulEval/536de8253b82d805c9845440169a5010ff507357/examples/speech_to_speech/test.wav -------------------------------------------------------------------------------- /examples/speech_to_speech_demo/english_counter_pipeline.py: -------------------------------------------------------------------------------- 1 | from simuleval.agents import AgentPipeline 2 | from examples.demo.silero_vad import SileroVADAgent 3 | from examples.speech_to_speech.english_counter_agent import EnglishSpeechCounter 4 | 5 | 6 | class EnglishCounterAgentPipeline(AgentPipeline): 7 | pipeline = [ 8 | SileroVADAgent, 9 | EnglishSpeechCounter, 10 | ] 11 | -------------------------------------------------------------------------------- /examples/speech_to_speech_demo/readme.md: -------------------------------------------------------------------------------- 1 | Running the demo: 2 | 1. Create a directory for the dummy model: `models/$DUMMY_MODEL` 3 | 2. Create a new yaml file `models/$DUMMY_MODEL/vad_main.yaml`, with the following: 4 | ``` 5 | agent_class: examples.speech_to_speech_demo.english_counter_pipeline.EnglishCounterAgentPipeline 6 | ``` 7 | 3. Set the available agent in `SimulevalAgentDirectory.py` to `$DUMMY_MODEL` 8 | 4. Run `python app.py` 9 | 10 | 11 | - Note: If you get an ImportError for `examples.speech_to_speech_demo`, run `python -c "import examples; print(examples.__file__)"`. If the file is something like `$PREFIX/site-packages/examples/__init__.py`, run `rm -r $PREFIX/site-packages/examples` and try again. 
-------------------------------------------------------------------------------- /examples/speech_to_speech_text/tree_agent_pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from simuleval.agents import TreeAgentPipeline 3 | from examples.speech_to_speech.english_counter_agent import ( 4 | EnglishSpeechCounter as EnglishSpeechToSpeech, 5 | ) 6 | from examples.speech_to_text.english_counter_agent import ( 7 | EnglishSpeechCounter as EnglishSpeechToText, 8 | ) 9 | from simuleval.agents.actions import WriteAction 10 | from simuleval.agents.agent import SpeechToTextAgent 11 | from simuleval.agents.states import AgentStates 12 | 13 | 14 | class EnglishWait2SpeechToText(EnglishSpeechToText): 15 | def __init__(self, args): 16 | super().__init__(args) 17 | args.wait_seconds = 2 18 | 19 | @staticmethod 20 | def add_args(parser): 21 | pass 22 | 23 | 24 | class DotSpeechToText(SpeechToTextAgent): 25 | def policy(self, states: Optional[AgentStates] = None): 26 | return WriteAction( 27 | content=".", 28 | finished=states.source_finished, 29 | ) 30 | 31 | 32 | class EnglishWait2SpeechToSpeech(EnglishSpeechToSpeech): 33 | def __init__(self, args): 34 | super().__init__(args) 35 | args.wait_seconds = 2 36 | 37 | @staticmethod 38 | def add_args(parser): 39 | pass 40 | 41 | 42 | class DummyTreePipeline(TreeAgentPipeline): 43 | # pipeline is a dict, used to instantiate agents in from_args 44 | pipeline = { 45 | EnglishSpeechToSpeech: [EnglishWait2SpeechToText, EnglishWait2SpeechToSpeech], 46 | EnglishWait2SpeechToSpeech: [], 47 | EnglishWait2SpeechToText: [], 48 | } 49 | 50 | 51 | class TemplateTreeAgentPipeline(TreeAgentPipeline): 52 | def __init__(self, args) -> None: 53 | speech_speech = self.pipeline[0](args) 54 | speech_speech_wait2 = self.pipeline[1](args) 55 | speech_text = self.pipeline[2](args) 56 | 57 | module_dict = { 58 | speech_speech: [speech_text, speech_speech_wait2], 59 | 
speech_speech_wait2: [], 60 | speech_text: [], 61 | } 62 | 63 | super().__init__(module_dict, args) 64 | 65 | @classmethod 66 | def from_args(cls, args): 67 | return cls(args) 68 | 69 | 70 | class InstantiatedTreeAgentPipeline(TemplateTreeAgentPipeline): 71 | pipeline = [ 72 | EnglishSpeechToSpeech, 73 | EnglishWait2SpeechToSpeech, 74 | EnglishWait2SpeechToText, 75 | ] 76 | 77 | 78 | class AnotherInstantiatedTreeAgentPipeline(TemplateTreeAgentPipeline): 79 | pipeline = [ 80 | EnglishSpeechToSpeech, 81 | EnglishWait2SpeechToSpeech, 82 | DotSpeechToText, # swap the speech_text module in the pipeline 83 | ] 84 | -------------------------------------------------------------------------------- /examples/speech_to_text/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8 2 | RUN apt-get update \ 3 | && apt-get upgrade -y \ 4 | && apt-get -y install apt-utils gcc libpq-dev libsndfile-dev 5 | RUN pip install -U openai-whisper 6 | RUN pip install -U editdistance 7 | RUN git clone https://github.com/facebookresearch/SimulEval 8 | WORKDIR /SimulEval/ 9 | RUN pip install -e . 
10 | WORKDIR /SimulEval/examples/speech_to_text/ 11 | CMD ["simuleval", "--standalone", "--remote-port", "8888", "--agent", "whisper_waitk.py", "--waitk-lagging", "3"] 12 | -------------------------------------------------------------------------------- /examples/speech_to_text/counter_in_tgt_lang_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from simuleval.agents.states import AgentStates 3 | from simuleval.utils import entrypoint 4 | from simuleval.agents import SpeechToTextAgent 5 | from simuleval.agents.actions import WriteAction, ReadAction 6 | 7 | 8 | @entrypoint 9 | class CounterInTargetLanguage(SpeechToTextAgent): 10 | """ 11 | The agent generate the number of seconds from an input audio and output it in the target language text 12 | """ 13 | 14 | def __init__(self, args): 15 | super().__init__(args) 16 | self.wait_seconds = args.wait_seconds 17 | 18 | @staticmethod 19 | def add_args(parser): 20 | parser.add_argument("--wait-seconds", default=1, type=int) 21 | 22 | def policy(self, states: Optional[AgentStates] = None): 23 | if states is None: 24 | states = self.states 25 | if states.source_sample_rate == 0: 26 | # empty source, source_sample_rate not set yet 27 | length_in_seconds = 0 28 | else: 29 | length_in_seconds = round(len(states.source) / states.source_sample_rate) 30 | 31 | if not states.source_finished and length_in_seconds < self.wait_seconds: 32 | return ReadAction() 33 | 34 | prediction = f"{length_in_seconds} " 35 | tgt_lang = states.tgt_lang 36 | if tgt_lang == "en": 37 | prediction += "seconds" 38 | elif tgt_lang == "es": 39 | prediction += "segundos" 40 | elif tgt_lang == "de": 41 | prediction += "sekunden" 42 | else: 43 | prediction += "" 44 | 45 | return WriteAction( 46 | content=prediction, 47 | finished=states.source_finished, 48 | ) 49 | -------------------------------------------------------------------------------- 
/examples/speech_to_text/english_counter_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from simuleval.agents.states import AgentStates 3 | from simuleval.utils import entrypoint 4 | from simuleval.agents import SpeechToTextAgent 5 | from simuleval.agents.actions import WriteAction, ReadAction 6 | 7 | 8 | @entrypoint 9 | class EnglishSpeechCounter(SpeechToTextAgent): 10 | """ 11 | The agent generate the number of seconds from an input audio. 12 | """ 13 | 14 | def __init__(self, args): 15 | super().__init__(args) 16 | self.wait_seconds = args.wait_seconds 17 | 18 | @staticmethod 19 | def add_args(parser): 20 | parser.add_argument("--wait-seconds", default=1, type=int) 21 | 22 | def policy(self, states: Optional[AgentStates] = None): 23 | if states is None: 24 | states = self.states 25 | if states.source_sample_rate == 0: 26 | # empty source, source_sample_rate not set yet 27 | length_in_seconds = 0 28 | else: 29 | length_in_seconds = round(len(states.source) / states.source_sample_rate) 30 | if not states.source_finished and length_in_seconds < self.wait_seconds: 31 | return ReadAction() 32 | 33 | prediction = f"{length_in_seconds} second" 34 | 35 | return WriteAction( 36 | content=prediction, 37 | finished=states.source_finished, 38 | ) 39 | -------------------------------------------------------------------------------- /examples/speech_to_text/eval.sh: -------------------------------------------------------------------------------- 1 | simuleval \ 2 | --agent counter_in_tgt_lang_agent.py \ 3 | --source-segment-size 1000 \ 4 | --source source.txt --target reference/en.txt \ 5 | --tgt-lang reference/tgt_lang.txt \ 6 | --output output 7 | -------------------------------------------------------------------------------- /examples/speech_to_text/readme.md: -------------------------------------------------------------------------------- 1 | ## Simultaneous Speech-to-Text Translation 2 | 3 | 
This tutorial provides a minimal example on how to evaluate a simultaneous speech-to-text translation system. 4 | 5 | ### Agent 6 | 7 | The speech-to-text agent ([english_counter_agent.py](english_counter_agent.py)) in this example is a counter, which generates number of seconds in text, after waiting for `self.wait_seconds` seconds. The policy finishes when the source is finished. 8 | 9 | ```python 10 | def policy(self): 11 | length_in_seconds = round( 12 | len(self.states.source) / self.states.source_sample_rate 13 | ) 14 | if not self.states.source_finished and length_in_seconds < self.wait_seconds: 15 | return ReadAction() 16 | 17 | prediction = f"{length_in_seconds} second" 18 | 19 | return WriteAction( 20 | content=prediction, 21 | finished=self.states.source_finished, 22 | ) 23 | ``` 24 | 25 | ### Evaluation 26 | 27 | The following command will start an evaluation 28 | 29 | ```bash 30 | simuleval \ 31 | --agent english_counter_agent.py \ 32 | --source-segment-size 1000 \ 33 | --source source.txt --target reference/en.txt \ 34 | --output output 35 | ``` 36 | 37 | The results of the evaluation should be as following. The detailed results can be found in the `output` directory. 38 | 39 | ``` 40 | BLEU LAAL AL AP DAL ATD 41 | 100.0 822.018 822.018 0.581 1061.271 2028.555 42 | ``` 43 | 44 | ### Example Streaming ASR / S2T: Whipser Wait-K model 45 | 46 | This section provide a more realistic model. [whisper_waitk.py](whisper_waitk.py) is a streaming ASR agent running [wait-k](https://aclanthology.org/P19-1289/) policy on the [Whisper](https://github.com/openai/whisper) ASR model 47 | 48 | ```bash 49 | simuleval \ 50 | --agent whisper_waitk.py \ 51 | --source-segment-size 500 \ 52 | --waitk-lagging 3 \ 53 | --source source.txt --target reference/transcript.txt \ 54 | --output output --quality-metrics WER 55 | ``` 56 | 57 | The results of the evaluation should be as following. The detailed results can be found in the `output` directory. 
58 | 59 | ``` 60 | WER LAAL AL AP DAL ATD 61 | 25.0 2353.772 2353.772 0.721 2491.04 2457.847 62 | ``` 63 | 64 | This agent can also perform S2T task, by adding `--task translate`. 65 | 66 | ### Streaming Speech-to-Text Demo 67 | 68 | A streaming speech to text demo feature, taking input from user's microphone, sending it to Whisper's wait-k model, and displaying the prediction texts in the terminal. 69 | 70 | 1. Kick off a remote agent. More information [Remote_agent](../../docs/tutorials/remote_evaluation.rst) 71 | 2. Enter demo mode by providing a desired segment size (usually 500ms): 72 | 73 | ```bash 74 | simuleval --remote-eval --demo --source-segment-size 500 --remote-port 8888 75 | ``` 76 | 77 | 3. Speak into the microphone and watch the live transcription! 78 | 4. Press ^c (Control C) to exit the program in terminal 79 | -------------------------------------------------------------------------------- /examples/speech_to_text/reference/en.txt: -------------------------------------------------------------------------------- 1 | 1 second 2 second 3 second 4 second 5 second 6 second 7 second 2 | -------------------------------------------------------------------------------- /examples/speech_to_text/reference/tgt_lang.txt: -------------------------------------------------------------------------------- 1 | es -------------------------------------------------------------------------------- /examples/speech_to_text/reference/transcript.txt: -------------------------------------------------------------------------------- 1 | This is a synthesized audio file to test your simultaneous speech to text and to speech to speach translation system. 
-------------------------------------------------------------------------------- /examples/speech_to_text/source.txt: -------------------------------------------------------------------------------- 1 | test.wav 2 | -------------------------------------------------------------------------------- /examples/speech_to_text/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SimulEval/536de8253b82d805c9845440169a5010ff507357/examples/speech_to_text/test.wav -------------------------------------------------------------------------------- /examples/speech_to_text/whisper_waitk.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from simuleval.agents.states import AgentStates 3 | from simuleval.utils import entrypoint 4 | from simuleval.data.segments import SpeechSegment 5 | from simuleval.agents import SpeechToTextAgent 6 | from simuleval.agents.actions import WriteAction, ReadAction 7 | 8 | import whisper 9 | import numpy 10 | 11 | 12 | @entrypoint 13 | class WaitkWhisper(SpeechToTextAgent): 14 | """ 15 | The agent generate the number of seconds from an input audio. 
16 | """ 17 | 18 | def __init__(self, args): 19 | super().__init__(args) 20 | self.waitk_lagging = args.waitk_lagging 21 | self.source_segment_size = args.source_segment_size 22 | self.source_language = args.source_language 23 | self.continuous_write = args.continuous_write 24 | self.model_size = args.model_size 25 | self.model = whisper.load_model(self.model_size) 26 | self.task = args.task 27 | if self.task == "translate": 28 | assert ( 29 | self.source_language != "en" 30 | ), "source language must be different from en for translation task" 31 | 32 | @staticmethod 33 | def add_args(parser): 34 | parser.add_argument("--waitk-lagging", default=1, type=int) 35 | parser.add_argument("--source-language", default="en", type=str) 36 | parser.add_argument( 37 | "--continuous-write", 38 | default=1, 39 | type=int, 40 | help="Max number of words to write at each step", 41 | ) 42 | parser.add_argument("--model-size", default="tiny", type=str) 43 | parser.add_argument( 44 | "--task", 45 | default="transcribe", 46 | type=str, 47 | choices=["transcribe", "translate"], 48 | ) 49 | 50 | def policy(self, states: Optional[AgentStates] = None): 51 | if states is None: 52 | states = self.states 53 | 54 | if states.source_sample_rate == 0: 55 | # empty source, source_sample_rate not set yet 56 | length_in_seconds = 0 57 | else: 58 | length_in_seconds = float(len(states.source)) / states.source_sample_rate 59 | 60 | if not states.source_finished: 61 | if ( 62 | length_in_seconds * 1000 / self.source_segment_size 63 | ) < self.waitk_lagging: 64 | return ReadAction() 65 | 66 | previous_translation = " ".join(states.target) 67 | # We use the previous translation as a prefix. 68 | options = whisper.DecodingOptions( 69 | prefix=previous_translation, 70 | language=self.source_language, 71 | without_timestamps=True, 72 | fp16=False, 73 | ) 74 | 75 | # We encode the whole audio to get the full transcription each time a new audio chunk is received. 
76 | audio = whisper.pad_or_trim(numpy.array(states.source).astype("float32")) 77 | mel = whisper.log_mel_spectrogram(audio).to(self.model.device) 78 | output = self.model.decode(mel, options) 79 | prediction = output.text.split() 80 | 81 | if not states.source_finished and self.continuous_write > 0: 82 | prediction = prediction[: self.continuous_write] 83 | 84 | return WriteAction( 85 | content=" ".join(prediction), 86 | finished=states.source_finished, 87 | ) 88 | -------------------------------------------------------------------------------- /examples/speech_to_text_demo/counter_in_tgt_lang_pipeline.py: -------------------------------------------------------------------------------- 1 | from simuleval.agents import AgentPipeline 2 | from examples.demo.silero_vad import SileroVADAgent 3 | from examples.speech_to_text.counter_in_tgt_lang import CounterInTargetLanguage 4 | 5 | 6 | class CounterInTargetLanguageAgentPipeline(AgentPipeline): 7 | pipeline = [ 8 | SileroVADAgent, 9 | CounterInTargetLanguage, 10 | ] 11 | -------------------------------------------------------------------------------- /examples/speech_to_text_demo/english_counter_pipeline.py: -------------------------------------------------------------------------------- 1 | from simuleval.agents import AgentPipeline 2 | from examples.demo.silero_vad import SileroVADAgent 3 | from examples.speech_to_text.english_counter_agent import EnglishSpeechCounter 4 | 5 | 6 | class EnglishCounterAgentPipeline(AgentPipeline): 7 | pipeline = [ 8 | SileroVADAgent, 9 | EnglishSpeechCounter, 10 | ] 11 | -------------------------------------------------------------------------------- /examples/speech_to_text_demo/readme.md: -------------------------------------------------------------------------------- 1 | Running the demo: 2 | 1. Create a directory for the dummy model: `models/$DUMMY_MODEL` 3 | 2. 
Create a new yaml file `models/$DUMMY_MODEL/vad_main.yaml`, with the following: 4 | ``` 5 | agent_class: examples.speech_to_text_demo.english_counter_pipeline.EnglishCounterAgentPipeline 6 | ``` 7 | 3. Set the available agent in `SimulevalAgentDirectory.py` to `$DUMMY_MODEL` 8 | 4. Run `python app.py` 9 | 10 | 11 | - Note: If you get an ImportError for `examples.speech_to_text_demo`, run `python -c "import examples; print(examples.__file__)"`. If the file is something like `$PREFIX/site-packages/examples/__init__.py`, run `rm -r $PREFIX/site-packages/examples` and try again. -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | flake8-max-line-length = 127 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from setuptools import find_packages, setup 8 | 9 | with open("README.md", "r") as readme_file: 10 | long_description = readme_file.read() 11 | 12 | setup( 13 | python_requires=">3.7.0", 14 | name="simuleval", 15 | version="1.1.4", 16 | author="Xutai Ma", 17 | description="SimulEval: A Flexible Toolkit for Automated Machine Translation Evaluation", 18 | long_description=long_description, 19 | long_description_content_type="text/markdown", 20 | homepage="https://github.com/facebookresearch/SimulEval.git", 21 | documentation="https://simuleval.readthedocs.io/en/v1.1.0/quick_start.html", 22 | license="LICENSE", 23 | entry_points={ 24 | "console_scripts": [ 25 | "simuleval = simuleval.cli:main", 26 | ], 27 | }, 28 | install_requires=[ 29 | "pytest", 30 | "pytest-cov", 31 | "sacrebleu>=2.3.1", 32 | "tornado", 33 | "soundfile", 34 | "pandas", 35 | "requests", 36 | "pytest-flake8", 37 | "textgrid", 38 | "tqdm==4.64.1", 39 | "pyyaml", 40 | "bitarray==2.6.0", 41 | "yt-dlp", 42 | "pydub", 43 | "matplotlib", 44 | ], 45 | classifiers=[ 46 | "Programming Language :: Python :: 3", 47 | "License :: OSI Approved :: Apache Software License", 48 | "Operating System :: OS Independent", 49 | "Development Status :: 4 - Beta", 50 | ], 51 | keywords=[ 52 | "SimulEval", 53 | "Machine Translation", 54 | "Evaluation", 55 | "Metrics", 56 | "BLEU", 57 | "TER", 58 | "METEOR", 59 | "chrF", 60 | "RIBES", 61 | "WMD", 62 | "Embedding Average", 63 | "Embedding Extrema", 64 | "Embedding Greedy", 65 | "Embedding Average", 66 | "SimulEval", 67 | "SimulEval_Testing_Package_1", 68 | "facebookresearch", 69 | "facebook", 70 | "Meta-Evaluation", 71 | ], 72 | packages=find_packages( 73 | exclude=[ 74 | "examples", 75 | "examples.*", 76 | "docs", 77 | "docs.*", 78 | ] 79 | ), 80 | setup_requires=["setuptools_scm"], 81 | use_scm_version=True, 82 | ) 83 | -------------------------------------------------------------------------------- /simuleval/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /simuleval/agents/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .agent import ( # noqa 8 | GenericAgent, 9 | SpeechToTextAgent, 10 | SpeechToSpeechAgent, 11 | TextToSpeechAgent, 12 | TextToTextAgent, 13 | ) 14 | from .states import AgentStates # noqa 15 | from .actions import Action, ReadAction, WriteAction # noqa 16 | from .pipeline import AgentPipeline, TreeAgentPipeline # noqa 17 | -------------------------------------------------------------------------------- /simuleval/agents/actions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Tuple, Union, List 8 | from dataclasses import dataclass 9 | from simuleval.data.segments import Segment 10 | 11 | 12 | class Action: 13 | """ 14 | Abstract Action class 15 | """ 16 | 17 | def is_read(self) -> bool: 18 | """ 19 | Whether the action is a read action 20 | 21 | Returns: 22 | bool: True if the action is a read action. 23 | """ 24 | assert NotImplementedError 25 | 26 | 27 | class ReadAction(Action): 28 | """ 29 | Action to return when policy decide to read one more source segment. 
30 | The only way to use it is to return :code:`ReadAction()` 31 | """ 32 | 33 | def is_read(self) -> bool: 34 | return True 35 | 36 | def __repr__(self) -> str: 37 | return "ReadAction()" 38 | 39 | 40 | @dataclass 41 | class WriteAction(Action): 42 | """ 43 | Action to return when policy decide to generate a prediction 44 | 45 | Args: 46 | content (Union[str, List[float], Tuple[List[float], str]]): The prediction. 47 | finished (bool): Indicates if current sentence is finished. 48 | 49 | .. note:: 50 | For text the prediction a str; for speech, it's a list. 51 | For speech_text, it's a tuple[list, str] 52 | 53 | """ 54 | 55 | content: Union[str, List[float], Tuple[List[float], str]] 56 | finished: bool 57 | 58 | def is_read(self) -> bool: 59 | return False 60 | 61 | def __repr__(self) -> str: 62 | return f"WriteAction({self.content}, finished={self.finished})" 63 | -------------------------------------------------------------------------------- /simuleval/agents/agent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from inspect import signature 8 | from argparse import Namespace, ArgumentParser 9 | from simuleval.data.segments import ( 10 | Segment, 11 | SpeechTextSegment, 12 | TextSegment, 13 | SpeechSegment, 14 | EmptySegment, 15 | ) 16 | from typing import Optional, List 17 | from .states import AgentStates 18 | from .actions import Action 19 | 20 | 21 | SEGMENT_TYPE_DICT = { 22 | "text": TextSegment, 23 | "speech": SpeechSegment, 24 | "speech_text": SpeechTextSegment, 25 | } 26 | 27 | 28 | class GenericAgent: 29 | """ 30 | Generic Agent class. 
31 | """ 32 | 33 | source_type = None 34 | target_type = None 35 | 36 | def __init__(self, args: Optional[Namespace] = None) -> None: 37 | if args is not None: 38 | self.args = args 39 | assert self.source_type 40 | assert self.target_type 41 | self.device = "cpu" 42 | 43 | self.states = self.build_states() 44 | self.reset() 45 | 46 | def build_states(self) -> AgentStates: 47 | """ 48 | Build states instance for agent 49 | 50 | Returns: 51 | AgentStates: agent states 52 | """ 53 | return AgentStates() 54 | 55 | def reset(self) -> None: 56 | """ 57 | Reset agent, called every time when a new sentence coming in. 58 | Applies for stateful agents. 59 | """ 60 | self.states.reset() 61 | 62 | def policy(self, states: Optional[AgentStates] = None) -> Action: 63 | """ 64 | The policy to make decision every time 65 | when the system has new input. 66 | The function has to return an Action instance 67 | 68 | Args: 69 | states (Optional[AgentStates]): an optional states for stateless agent 70 | 71 | Returns: 72 | Action: The actions to make at certain point. 73 | 74 | .. note: 75 | 76 | WriteAction means that the system has a prediction. 77 | ReadAction means that the system needs more source. 78 | When states are provided, the agent will become stateless and ignore self.states. 79 | """ 80 | assert NotImplementedError 81 | 82 | def push( 83 | self, 84 | source_segment: Segment, 85 | states: Optional[AgentStates] = None, 86 | upstream_states: Optional[List[AgentStates]] = None, 87 | ) -> None: 88 | """ 89 | The function to process the incoming information. 
90 | 91 | Args: 92 | source_info (dict): incoming information dictionary 93 | states (Optional[AgentStates]): an optional states for stateless agent 94 | """ 95 | if states is None: 96 | states = self.states 97 | 98 | if upstream_states is None: 99 | upstream_states = [] 100 | 101 | states.upstream_states = upstream_states 102 | 103 | states.update_config(source_segment.config) 104 | states.update_source(source_segment) 105 | 106 | def pop(self, states: Optional[AgentStates] = None) -> Segment: 107 | """ 108 | The function to generate system output. 109 | By default, it first runs policy, 110 | and than returns the output segment. 111 | If the policy decide to read, 112 | it will return an empty segment. 113 | 114 | Args: 115 | states (Optional[AgentStates]): an optional states for stateless agent 116 | 117 | Returns: 118 | Segment: segment to return. 119 | """ 120 | if len(signature(self.policy).parameters) == 0: 121 | is_stateless = False 122 | if states: 123 | raise RuntimeError("Feeding states to stateful agents.") 124 | else: 125 | is_stateless = True 126 | 127 | if states is None: 128 | states = self.states 129 | 130 | if states.target_finished: 131 | return EmptySegment(finished=True) 132 | 133 | if is_stateless: 134 | action = self.policy(states) 135 | else: 136 | action = self.policy() 137 | 138 | if not isinstance(action, Action): 139 | raise RuntimeError( 140 | f"The return value of {self.policy.__qualname__} is not an {Action.__qualname__} instance" 141 | ) 142 | if action.is_read(): 143 | return EmptySegment() 144 | else: 145 | if isinstance(action.content, Segment): 146 | return action.content 147 | 148 | segment = SEGMENT_TYPE_DICT[self.target_type]( 149 | index=0, content=action.content, finished=action.finished 150 | ) 151 | states.update_target(segment) 152 | return segment 153 | 154 | def pushpop( 155 | self, 156 | segment: Segment, 157 | states: Optional[AgentStates] = None, 158 | upstream_states: Optional[List[AgentStates]] = None, 159 | ) -> 
Segment: 160 | """ 161 | Operate pop immediately after push. 162 | 163 | Args: 164 | segment (Segment): input segment 165 | 166 | Returns: 167 | Segment: output segment 168 | """ 169 | self.push(segment, states, upstream_states) 170 | return self.pop(states) 171 | 172 | @staticmethod 173 | def add_args(parser: ArgumentParser): 174 | """ 175 | Add agent arguments to parser. 176 | Has to be a static method. 177 | 178 | Args: 179 | parser (ArgumentParser): cli argument parser 180 | """ 181 | pass 182 | 183 | @classmethod 184 | def from_args(cls, args): 185 | return cls(args) 186 | 187 | def to(self, device: str, *args, **kwargs) -> None: 188 | """ 189 | Move agent to specified device. 190 | 191 | Args: 192 | device (str): Device to move agent to. 193 | """ 194 | pass 195 | 196 | def __repr__(self) -> str: 197 | return f"{self.__class__.__name__}[{self.source_type} -> {self.target_type}]" 198 | 199 | def __str__(self) -> str: 200 | return self.__repr__() 201 | 202 | 203 | class SpeechToTextAgent(GenericAgent): 204 | """ 205 | Same as generic agent, but with explicit types 206 | speech -> text 207 | """ 208 | 209 | source_type: str = "speech" 210 | target_type: str = "text" 211 | tgt_lang: Optional[str] = None 212 | 213 | 214 | class SpeechToSpeechAgent(GenericAgent): 215 | """ 216 | Same as generic agent, but with explicit types 217 | speech -> speech 218 | """ 219 | 220 | source_type: str = "speech" 221 | target_type: str = "speech" 222 | 223 | 224 | class TextToSpeechAgent(GenericAgent): 225 | """ 226 | Same as generic agent, but with explicit types 227 | text -> speech 228 | """ 229 | 230 | source_type: str = "text" 231 | target_type: str = "speech" 232 | 233 | 234 | class TextToTextAgent(GenericAgent): 235 | """ 236 | Same as generic agent, but with explicit types 237 | text -> text 238 | """ 239 | 240 | source_type: str = "text" 241 | target_type: str = "text" 242 | -------------------------------------------------------------------------------- 
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import json
import logging
from tornado import web, ioloop
from simuleval.data.segments import segment_from_json_string
from simuleval import options

logger = logging.getLogger("simuleval.agent_server")


class SystemHandler(web.RequestHandler):
    """Base handler; GET / replies with a JSON description of the system."""

    def initialize(self, system):
        # Tornado passes the route-table kwargs ({"system": system}) here.
        self.system = system

    def get(self):
        self.write(json.dumps({"info": str(self.system)}))


class ResetHandle(SystemHandler):
    """POST /reset clears the agent states for a new sentence."""

    def post(self):
        self.system.reset()


class OutputHandler(SystemHandler):
    """GET /output runs the agent policy once and returns its output segment."""

    def get(self):
        output_segment = self.system.pop()
        self.write(output_segment.json())


class InputHandler(SystemHandler):
    """PUT /input feeds one JSON-serialized input segment to the agent."""

    def put(self):
        segment = segment_from_json_string(self.request.body)
        self.system.push(segment)


def start_agent_service(system):
    """
    Expose *system* over HTTP on ``--remote-port`` and block forever.

    Args:
        system: an agent (or pipeline) exposing push/pop/reset.
    """
    parser = options.general_parser()
    options.add_evaluator_args(parser)
    # parse_known_args: argv may also carry agent-specific flags.
    args, _ = parser.parse_known_args()
    app = web.Application(
        [
            (r"/reset", ResetHandle, {"system": system}),
            (r"/input", InputHandler, {"system": system}),
            (r"/output", OutputHandler, {"system": system}),
            (r"/", SystemHandler, {"system": system}),
        ],
        debug=False,
    )

    # 1 GiB buffer: speech segments serialized as JSON float lists get large.
    app.listen(args.remote_port, max_buffer_size=1024**3)

    logger.info(
        f"Simultaneous Translation Server Started (process id {os.getpid()}). Listening to port {args.remote_port} "
    )
    ioloop.IOLoop.current().start()


# --- simuleval/agents/states.py ------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from simuleval.data.segments import Segment, TextSegment, EmptySegment, SpeechSegment


class AgentStates:
    """
    Tracker of the decoding progress.

    Attributes:
        source (list): current source sequence.
        target (list): current target sequence.
        source_finished (bool): if the source is finished.
        target_finished (bool): if the target is finished.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Reset Agent states"""
        self.source = []
        self.target = []
        self.source_finished = False
        self.target_finished = False
        self.source_sample_rate = 0
        self.target_sample_rate = 0
        self.tgt_lang = None
        self.upstream_states = []
        self.config = {}

    def update_config(self, config: dict):
        # Merge (last-writer-wins) incoming per-segment config into the state.
        for k in config.keys():
            self.config[k] = config[k]

    def update_source(self, segment: Segment):
        """
        Update states from input segment

        Args:
            segment (~simuleval.agents.segments.Segment): input segment
        """
        self.source_finished = segment.finished
        if isinstance(segment, EmptySegment):
            return
        elif isinstance(segment, TextSegment):
            # Text sources accumulate token by token.
            self.source.append(segment.content)
            self.tgt_lang = segment.tgt_lang
        elif isinstance(segment, SpeechSegment):
            # Speech sources concatenate sample lists.
            self.source += segment.content
            self.source_sample_rate = segment.sample_rate
            self.tgt_lang = segment.tgt_lang
        else:
            raise NotImplementedError
NotImplementedError 59 | 60 | def update_target(self, segment: Segment): 61 | """ 62 | Update states from output segment 63 | 64 | Args: 65 | segment (~simuleval.agents.segments.Segment): input segment 66 | """ 67 | self.target_finished = segment.finished 68 | if not self.target_finished: 69 | if isinstance(segment, EmptySegment): 70 | return 71 | elif isinstance(segment, TextSegment): 72 | self.target.append(segment.content) 73 | elif isinstance(segment, SpeechSegment): 74 | self.target += segment.content 75 | self.target_sample_rate = segment.sample_rate 76 | else: 77 | raise NotImplementedError 78 | -------------------------------------------------------------------------------- /simuleval/analysis/curve.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
import json
import pandas
from pathlib import Path
from typing import Dict, List, Union


class SimulEvalResults:
    """Read the scores of one SimulEval output directory (possibly unfinished)."""

    def __init__(self, path: Union[Path, str]) -> None:
        self.path = Path(path)
        scores_path = self.path / "scores"
        if scores_path.exists():
            self.is_finished = True
            # Reuse the computed path instead of rebuilding it.
            with open(scores_path) as f:
                self.scores = json.load(f)
        else:
            self.is_finished = False
            self.scores = {}

    @property
    def quality(self) -> float:
        """BLEU score, or 0 when the run is unfinished or the key is missing."""
        if not self.is_finished or not self.scores:
            return 0
        # .get() chain: a partially written scores file must not raise KeyError.
        return self.scores.get("Quality", {}).get("BLEU", 0)

    @property
    def bleu(self) -> float:
        # Alias kept for backward compatibility.
        return self.quality

    @property
    def latency(self) -> Dict[str, float]:
        """Latency sub-dictionary; {} when unfinished or missing."""
        if self.is_finished:
            return self.scores.get("Latency", {})
        else:
            return {}

    @property
    def average_lagging(self):
        return self.latency.get("AL", 0)

    @property
    def average_lagging_ca(self):
        # Computation-aware AL.
        return self.latency.get("AL_CA", 0)

    @property
    def average_proportion(self):
        return self.latency.get("AP", 0)

    @property
    def name(self):
        # Directory name, used as the row label in summaries.
        return self.path.name


class S2SSimulEvalResults(SimulEvalResults):
    """Speech-to-speech results: latency is nested per word-boundary type."""

    @property
    def bow_average_lagging(self):
        # Beginning-of-word AL.
        return self.latency.get("BOW", {}).get("AL", 0)

    @property
    def cow_average_lagging(self):
        return self.latency.get("COW", {}).get("AL", 0)

    @property
    def eow_average_lagging(self):
        return self.latency.get("EOW", {}).get("AL", 0)


class QualityLatencyAnalyzer:
    """Aggregate several result directories into a quality/latency table."""

    def __init__(self) -> None:
        self.score_list: List[SimulEvalResults] = []

    def add_scores_from_path(self, path: Path):
        self.score_list.append(SimulEvalResults(path))

    @classmethod
    def from_paths(cls, path_list: List[Path]):
        analyzer = cls()
        for path in path_list:
            analyzer.add_scores_from_path(path)
        return analyzer

    def summarize(self):
        """Return a DataFrame (name, AL, AL(CA), AP, BLEU), sorted by AL in seconds."""
        results = []
        for score in self.score_list:
            if score.bleu == 0:
                # Skip unfinished or unreadable runs.
                continue
            results.append(
                [
                    score.name,
                    round(score.average_lagging / 1000, 2),  # ms -> s
                    round(score.average_lagging_ca / 1000, 2),
                    round(score.average_proportion, 2),
                    round(score.bleu, 2),
                ]
            )
        results.sort(key=lambda x: x[1])
        return pandas.DataFrame(results, columns=["name", "AL", "AL(CA)", "AP", "BLEU"])


class S2SQualityLatencyAnalyzer(QualityLatencyAnalyzer):
    """Speech-to-speech variant reporting BOW/COW/EOW average lagging."""

    def add_scores_from_path(self, path: Path):
        self.score_list.append(S2SSimulEvalResults(path))

    def summarize(self):
        results = []
        for score in self.score_list:
            if score.bleu == 0:
                continue
            results.append(
                [
                    score.name,
                    round(score.bow_average_lagging / 1000, 2),
                    round(score.cow_average_lagging / 1000, 2),
                    round(score.eow_average_lagging / 1000, 2),
                    round(score.bleu, 2),
                ]
            )
        results.sort(key=lambda x: x[1])
        return pandas.DataFrame(
            results, columns=["name", "BOW_AL", "COW_AL", "EOW_AL", "BLEU"]
        )
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import sys
from argparse import ArgumentParser
from typing import Optional

from simuleval import options
from simuleval.agents import GenericAgent
from simuleval.agents.service import start_agent_service
from simuleval.evaluator import (
    SentenceLevelEvaluator,
    build_evaluator,
    build_remote_evaluator,
)
from simuleval.utils import EVALUATION_SYSTEM_LIST
from simuleval.utils.agent import build_system_args
from simuleval.utils.arguments import check_argument
from simuleval.utils.slurm import submit_slurm_job

logging.basicConfig(
    format="%(asctime)s | %(levelname)-8s | %(name)-16s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    stream=sys.stderr,
)


logger = logging.getLogger("simuleval.cli")


def main():
    """CLI entry point: dispatch on mode flags, then build and run evaluation."""
    # Each early-return branch is a dedicated mode; the first matching flag wins.
    if check_argument("remote_eval"):
        remote_evaluate()
        return

    if check_argument("score_only"):
        scoring()
        return

    if check_argument("slurm"):
        submit_slurm_job()
        return

    system, args = build_system_args()

    if check_argument("standalone"):
        # Serve the agent over HTTP instead of evaluating it locally.
        start_agent_service(system)
        return

    # build evaluator
    evaluator = build_evaluator(args)

    # evaluate system
    evaluator(system)


def evaluate(
    system_class: GenericAgent,
    config_dict: dict = {},
    parser: Optional[ArgumentParser] = None,
):
    """
    Programmatic evaluation entry point for *system_class*.

    Args:
        system_class (GenericAgent): agent class to evaluate.
        config_dict (dict): configuration overrides; a list value is treated
            as a sweep (only its first element is used for the slurm check).
        parser (Optional[ArgumentParser]): extra CLI arguments to honor.
    """
    # NOTE(review): the mutable default `{}` is safe here because config_dict
    # is only read, never mutated.
    EVALUATION_SYSTEM_LIST.append(system_class)
    just_for_arg_check = {}
    for key, value in config_dict.items():
        if isinstance(value, list):
            just_for_arg_check[key] = value[0]
        else:
            just_for_arg_check[key] = value
    if check_argument("slurm", just_for_arg_check):
        submit_slurm_job(config_dict, parser)
        return

    system, args = build_system_args(config_dict, parser)

    # build evaluator
    evaluator = build_evaluator(args)

    # evaluate system
    evaluator(system)


def scoring():
    """Re-score an existing output directory (no decoding)."""
    parser = options.general_parser()
    options.add_evaluator_args(parser)
    options.add_scorer_args(parser)
    options.add_dataloader_args(parser)
    args = parser.parse_args()
    evaluator = SentenceLevelEvaluator.from_args(args)
    print(evaluator.results)


def remote_evaluate():
    """Evaluate an agent served remotely (see start_agent_service)."""
    # build evaluator
    parser = options.general_parser()
    options.add_dataloader_args(parser)
    options.add_evaluator_args(parser)
    options.add_scorer_args(parser)
    args = parser.parse_args()
    evaluator = build_remote_evaluator(args)

    # evaluate system
    evaluator.remote_eval()


if __name__ == "__main__":
    main()


# --- simuleval/data/__init__.py ------------------------------------------
from .dataloader import build_dataloader  # noqa
import logging
from argparse import Namespace

from .dataloader import (  # noqa
    GenericDataloader,
    register_dataloader,
    register_dataloader_class,
    SUPPORTED_MEDIUM,
    SUPPORTED_SOURCE_MEDIUM,
    SUPPORTED_TARGET_MEDIUM,
    DATALOADER_DICT,
)
from .t2t_dataloader import TextToTextDataloader  # noqa
from .s2t_dataloader import SpeechToTextDataloader  # noqa


logger = logging.getLogger("simuleval.dataloader")


def build_dataloader(args: Namespace) -> GenericDataloader:
    """
    Build a dataloader from parsed CLI args.

    An explicit ``--dataloader`` key wins; otherwise the loader is chosen
    from ``source_type``/``target_type`` (forced to speech->text in demo mode).

    Args:
        args (Namespace): parsed command line arguments.

    Returns:
        GenericDataloader: the registered dataloader instance.
    """
    dataloader_key = getattr(args, "dataloader", None)
    if dataloader_key is not None:
        assert dataloader_key in DATALOADER_DICT, f"{dataloader_key} is not defined"
        logger.info(f"Evaluating from dataloader {dataloader_key}.")
        return DATALOADER_DICT[dataloader_key].from_args(args)
    # Robustness fix: use getattr (as done for `dataloader` above) so a
    # Namespace without a `demo` flag does not raise AttributeError.
    if getattr(args, "demo", False):
        args.source_type = "speech"
        args.target_type = "text"
    assert args.source_type in SUPPORTED_SOURCE_MEDIUM
    assert args.target_type in SUPPORTED_TARGET_MEDIUM

    logger.info(f"Evaluating from {args.source_type} to {args.target_type}.")
    return DATALOADER_DICT[f"{args.source_type}-to-{args.target_type}"].from_args(args)
from abc import abstractmethod
from argparse import ArgumentParser, Namespace
from typing import Any, Dict, List, Optional, Union

SUPPORTED_MEDIUM = ["text", "speech"]
SUPPORTED_SOURCE_MEDIUM = ["youtube", "text", "speech"]
SUPPORTED_TARGET_MEDIUM = ["text", "speech"]
DATALOADER_DICT = {}


def register_dataloader(name):
    """Class decorator: register the decorated dataloader under *name*."""

    def wrapper(klass):
        DATALOADER_DICT[name] = klass
        return klass

    return wrapper


def register_dataloader_class(name, cls):
    """Imperative variant of :func:`register_dataloader`."""
    DATALOADER_DICT[name] = cls


class GenericDataloader:
    """
    Load source and target data

    .. argparse::
        :ref: simuleval.options.add_data_args
        :passparser:
        :prog:

    """

    def __init__(
        self,
        source_list: List[str],
        target_list: Union[List[str], List[None]],
        tgt_lang_list: Optional[List[str]] = None,
    ) -> None:
        self.source_list = source_list
        self.target_list = target_list
        self.tgt_lang_list = tgt_lang_list
        # Source and target must be parallel corpora.
        assert len(self.source_list) == len(self.target_list)

    def __len__(self):
        return len(self.source_list)

    def get_source(self, index: int) -> Any:
        """Return the preprocessed source item at *index*."""
        return self.preprocess_source(self.source_list[index])

    def get_target(self, index: int) -> Any:
        """Return the preprocessed target item at *index*."""
        return self.preprocess_target(self.target_list[index])

    def get_tgt_lang(self, index: int) -> Optional[str]:
        """Return the target language at *index*, or None when unavailable."""
        langs = getattr(self, "tgt_lang_list", None)
        if langs is None or index >= len(langs):
            return None
        return langs[index]

    def __getitem__(self, index: int) -> Dict[str, Any]:
        return {
            "source": self.get_source(index),
            "target": self.get_target(index),
            "tgt_lang": self.get_tgt_lang(index),
        }

    def preprocess_source(self, source: Any) -> Any:
        # Subclasses define how a raw source entry becomes model input.
        raise NotImplementedError

    def preprocess_target(self, target: Any) -> Any:
        # Subclasses define how a raw target entry becomes a reference.
        raise NotImplementedError

    @classmethod
    def from_args(cls, args: Namespace):
        """Alternate constructor from parsed CLI arguments."""
        return cls(args.source, args.target)

    @staticmethod
    def add_args(parser: ArgumentParser):
        """Register the data-related command line arguments on *parser*."""
        parser.add_argument("--source", type=str, help="Source file.")
        parser.add_argument("--target", type=str, help="Target file.")
        parser.add_argument(
            "--source-type",
            type=str,
            choices=SUPPORTED_SOURCE_MEDIUM,
            help="Source Data type to evaluate.",
        )
        parser.add_argument(
            "--target-type",
            type=str,
            choices=SUPPORTED_TARGET_MEDIUM,
            help="Data type to evaluate.",
        )
        parser.add_argument(
            "--source-segment-size",
            type=int,
            default=1,
            help="Source segment size, For text the unit is # token, for speech is ms",
        )
        parser.add_argument(
            "--tgt-lang",
            type=str,
            default=None,
            help="Target language",
        )


class IterableDataloader:
    """Marker base class for streaming dataloaders iterated directly."""

    cur_index: int

    @abstractmethod
    def __iter__(self):
        pass

    @abstractmethod
    def __next__(self):
        pass
6 | 7 | from __future__ import annotations 8 | from pathlib import Path 9 | from typing import List, Union, Optional 10 | from .dataloader import GenericDataloader 11 | from simuleval.data.dataloader import register_dataloader 12 | from argparse import Namespace 13 | from urllib.parse import urlparse, parse_qs 14 | 15 | try: 16 | import yt_dlp as youtube_dl 17 | from pydub import AudioSegment 18 | except ImportError: 19 | yt_dlp = AudioSegment = None 20 | 21 | try: 22 | import soundfile 23 | 24 | IS_IMPORT_SOUNDFILE = True 25 | except Exception: 26 | IS_IMPORT_SOUNDFILE = False 27 | 28 | 29 | def download_youtube_video(url): 30 | def get_video_id(url): 31 | url_data = urlparse(url) 32 | query = parse_qs(url_data.query) 33 | video = query.get("v", []) 34 | if video: 35 | return video[0] 36 | else: 37 | raise Exception("unrecoginzed url format.") 38 | 39 | id = get_video_id(url) 40 | name = f"{id}.wav" 41 | 42 | if not Path(name).exists(): 43 | ydl_opts = { 44 | "format": "bestaudio/best", 45 | "postprocessors": [ 46 | { 47 | "key": "FFmpegExtractAudio", 48 | "preferredcodec": "wav", 49 | "preferredquality": "192", 50 | } 51 | ], 52 | "outtmpl": id, # name the file "downloaded_video" with original extension 53 | } 54 | with youtube_dl.YoutubeDL(ydl_opts) as ydl: 55 | ydl.download([url]) 56 | 57 | sound = AudioSegment.from_wav(name) 58 | sound = sound.set_channels(1).set_frame_rate(16000) 59 | sound.export(name, format="wav") 60 | return name 61 | 62 | 63 | def load_list_from_file(file_path: Union[Path, str]) -> List[str]: 64 | with open(file_path) as f: 65 | return [line.strip() for line in f] 66 | 67 | 68 | @register_dataloader("speech-to-text") 69 | class SpeechToTextDataloader(GenericDataloader): 70 | def __init__( 71 | self, 72 | source_list: List[str], 73 | target_list: List[str], 74 | tgt_lang_list: Optional[List[str]] = None, 75 | ) -> None: 76 | super().__init__(source_list, target_list, tgt_lang_list) 77 | 78 | def preprocess_source(self, source: Union[Path, 
str]) -> List[float]: 79 | assert IS_IMPORT_SOUNDFILE, "Please make sure soundfile is properly installed." 80 | samples, _ = soundfile.read(source, dtype="float32") 81 | samples = samples.tolist() 82 | return samples 83 | 84 | def preprocess_target(self, target: str) -> str: 85 | return target 86 | 87 | def get_source_audio_info(self, index: int) -> soundfile._SoundFileInfo: 88 | return soundfile.info(self.get_source_audio_path(index)) 89 | 90 | def get_source_audio_path(self, index: int): 91 | return self.source_list[index] 92 | 93 | @classmethod 94 | def from_files( 95 | cls, 96 | source: Union[Path, str], 97 | target: Union[Path, str], 98 | tgt_lang: Union[Path, str], 99 | ) -> SpeechToTextDataloader: 100 | source_list = load_list_from_file(source) 101 | target_list = load_list_from_file(target) 102 | tgt_lang_list = [] 103 | if tgt_lang is not None: 104 | tgt_lang_list = load_list_from_file(tgt_lang) 105 | dataloader = cls(source_list, target_list, tgt_lang_list) 106 | return dataloader 107 | 108 | @classmethod 109 | def from_args(cls, args: Namespace): 110 | args.source_type = "speech" 111 | args.target_type = "text" 112 | if args.demo: 113 | return cls([], [], []) 114 | return cls.from_files(args.source, args.target, args.tgt_lang) 115 | 116 | 117 | @register_dataloader("speech-to-speech") 118 | class SpeechToSpeechDataloader(SpeechToTextDataloader): 119 | @classmethod 120 | def from_files( 121 | cls, 122 | source: Union[Path, str], 123 | target: Union[Path, str], 124 | tgt_lang: Union[Path, str, None] = None, 125 | ) -> SpeechToSpeechDataloader: 126 | source_list = load_list_from_file(source) 127 | target_list = load_list_from_file(target) 128 | tgt_lang_list = [] 129 | if tgt_lang is not None: 130 | tgt_lang_list = load_list_from_file(tgt_lang) 131 | dataloader = cls(source_list, target_list, tgt_lang_list) 132 | return dataloader 133 | 134 | @classmethod 135 | def from_args(cls, args: Namespace): 136 | args.source_type = "speech" 137 | args.target_type = 
"speech" 138 | return cls.from_files(args.source, args.target, args.tgt_lang) 139 | 140 | 141 | @register_dataloader("youtube-to-text") 142 | class YoutubeToTextDataloader(SpeechToTextDataloader): 143 | @classmethod 144 | def from_youtube( 145 | cls, source: Union[Path, str], target: Union[Path, str] 146 | ) -> YoutubeToTextDataloader: 147 | assert AudioSegment is not None 148 | assert youtube_dl is not None 149 | source_list = [download_youtube_video(source)] 150 | target_list = [target] 151 | dataloader = cls(source_list, target_list) 152 | return dataloader 153 | 154 | @classmethod 155 | def from_args(cls, args: Namespace): 156 | args.source_type = "youtube" 157 | args.target_type = "text" 158 | return cls.from_youtube(args.source, args.target) 159 | 160 | 161 | @register_dataloader("youtube-to-speech") 162 | class YoutubeToSpeechDataloader(YoutubeToTextDataloader): 163 | @classmethod 164 | def from_args(cls, args: Namespace): 165 | args.source_type = "youtube" 166 | args.target_type = "speech" 167 | return cls.from_youtube(args.source, args.target) 168 | -------------------------------------------------------------------------------- /simuleval/data/dataloader/t2t_dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from __future__ import annotations 8 | from pathlib import Path 9 | from typing import Callable, List, Union, Optional 10 | from .dataloader import GenericDataloader 11 | from simuleval.data.dataloader import register_dataloader 12 | from argparse import Namespace 13 | 14 | 15 | @register_dataloader("text-to-text") 16 | class TextToTextDataloader(GenericDataloader): 17 | def __init__( 18 | self, source_list: List[str], target_list: Union[List[str], List[None]] 19 | ) -> None: 20 | super().__init__(source_list, target_list) 21 | self.source_splitter = lambda x: x.split() 22 | self.target_splitter = lambda x: x 23 | 24 | def set_source_splitter(self, function: Callable) -> None: 25 | # TODO, make is configurable 26 | self.splitter = function 27 | 28 | def preprocess_source(self, source: str) -> List: 29 | return self.source_splitter(source) 30 | 31 | def preprocess_target(self, target: str) -> List: 32 | return self.target_splitter(target) 33 | 34 | @classmethod 35 | def from_files( 36 | cls, source: Union[Path, str], target: Optional[Union[Path, str]] 37 | ) -> TextToTextDataloader: 38 | assert source 39 | with open(source) as f: 40 | source_list = f.readlines() 41 | if target: 42 | with open(target) as f: 43 | target_list = f.readlines() 44 | else: 45 | target_list = [None for _ in source_list] 46 | dataloader = cls(source_list, target_list) 47 | return dataloader 48 | 49 | @classmethod 50 | def from_args(cls, args: Namespace): 51 | args.source_type = "text" 52 | args.target_type = "text" 53 | return cls.from_files(args.source, args.target) 54 | -------------------------------------------------------------------------------- /simuleval/data/segments.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
import json
from dataclasses import dataclass, field
from typing import Union, Optional


@dataclass
class Segment:
    """One unit of streaming input/output exchanged between agents."""

    index: int = 0
    content: list = field(default_factory=list)
    finished: bool = False
    is_empty: bool = False
    data_type: Optional[str] = None
    tgt_lang: Optional[str] = None
    config: dict = field(default_factory=dict)

    def json(self) -> str:
        """Serialize all fields to a JSON string."""
        info_dict = {attribute: value for attribute, value in self.__dict__.items()}
        return json.dumps(info_dict)

    @classmethod
    def from_json(cls, json_string: str):
        """Inverse of :meth:`json`."""
        return cls(**json.loads(json_string))


@dataclass
class EmptySegment(Segment):
    is_empty: bool = True


@dataclass
class TextSegment(Segment):
    content: str = ""
    data_type: str = "text"
    tgt_lang: Optional[str] = None


@dataclass
class SpeechSegment(Segment):
    sample_rate: int = -1
    data_type: str = "speech"
    tgt_lang: Optional[str] = None


@dataclass
class SpeechTextSegment:
    """Paired speech + text output segment."""

    text_segment: Union[EmptySegment, TextSegment]
    speech_segment: Union[EmptySegment, SpeechSegment]
    data_type: str = "speech_text"

    # Bug fix: segment_from_json_string dispatches "speech_text" to
    # SpeechTextSegment.from_json, but this class is not a Segment subclass
    # and previously had neither json() nor from_json(), so deserializing a
    # speech_text segment raised AttributeError. Provide both, handling the
    # nested segments explicitly.
    def json(self) -> str:
        return json.dumps(
            {
                "text_segment": json.loads(self.text_segment.json()),
                "speech_segment": json.loads(self.speech_segment.json()),
                "data_type": self.data_type,
            }
        )

    @classmethod
    def from_json(cls, json_string: str):
        info_dict = json.loads(json_string)
        return cls(
            text_segment=segment_from_json_string(
                json.dumps(info_dict["text_segment"])
            ),
            speech_segment=segment_from_json_string(
                json.dumps(info_dict["speech_segment"])
            ),
            data_type=info_dict.get("data_type", "speech_text"),
        )


def segment_from_json_string(string: str):
    """Deserialize a segment, dispatching on its ``data_type`` field."""
    info_dict = json.loads(string)
    if info_dict["data_type"] == "text":
        return TextSegment.from_json(string)
    elif info_dict["data_type"] == "speech":
        return SpeechSegment.from_json(string)
    elif info_dict["data_type"] == "speech_text":
        return SpeechTextSegment.from_json(string)
    else:
        # Unknown/None data_type round-trips as an empty segment.
        return EmptySegment.from_json(string)
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .evaluator import SentenceLevelEvaluator
from .remote import RemoteEvaluator
from .remote import DemoRemote


def build_evaluator(args):
    """Create a sentence-level evaluator from parsed CLI args."""
    return SentenceLevelEvaluator.from_args(args)


def build_remote_evaluator(args):
    """Create a remote evaluator; demo mode selects the interactive client."""
    evaluator_cls = DemoRemote if args.demo else RemoteEvaluator
    return evaluator_cls(build_evaluator(args))
    code-block:: python

        for instance in self.maybe_tqdm(self.instances.values()):
            agent.reset()
            while not instance.finish_prediction:
                input_segment = instance.send_source(self.source_segment_size)
                output_segment = agent.pushpop(input_segment)
                instance.receive_prediction(output_segment)


    Attributes:
        instances: collections of sentence pairs. Instances also keep track of delays.
        latency_scorers (List[~simuleval.scorers.latency_scorer.LatencyScorer]): Scorers for latency evaluation.
        quality_scorers (List[~simuleval.scorers.latency_scorer.QualityScorer]): Scorers for quality evaluation.
        output: output directory

    Evaluator related command line arguments:

    .. argparse::
        :ref: simuleval.options.add_evaluator_args
        :passparser:
        :prog:
    """

    def __init__(
        self,
        dataloader: Optional[GenericDataloader],
        quality_scorers: Dict[str, QualityScorer],
        latency_scorers: Dict[str, LatencyScorer],
        args: Namespace,
    ) -> None:
        self.dataloader = dataloader
        self.quality_scorers = quality_scorers
        self.latency_scorers = latency_scorers
        self.instances = {}

        self.args = args
        self.output = Path(args.output) if args.output else None
        self.score_only = args.score_only
        self.no_scoring = args.no_scoring
        self.source_segment_size = getattr(args, "source_segment_size", 1)
        self.source_type = getattr(args, "source_type", None)
        self.target_type = getattr(args, "target_type", None)
        self.visualize = args.visualize

        # Optional sentencepiece model used when latency is measured in
        # spm-token units.
        self.target_spm_model = None
        if args.eval_latency_unit == "spm":
            assert args.eval_latency_spm_model
            assert IS_IMPORT_SPM
            self.target_spm_model = sentencepiece.SentencePieceProcessor(
                model_file=args.eval_latency_spm_model
            )

        # When types are not given on the command line (e.g. score-only
        # reruns), recover them from the config.yaml of the original run.
        if (
            self.source_type is None
            and self.target_type is None
            and self.output is not None
        ):
            with open(self.output / "config.yaml") as f:
                configs = yaml.safe_load(f)
                self.source_type = configs["source_type"]
                self.target_type = configs["target_type"]

        assert self.source_type
        assert self.target_type

        # Persist the resolved types so a later score-only run can read them.
        if self.output is not None:
            os.makedirs(self.output, exist_ok=True)
            with open(self.output / "config.yaml", "w") as f:
                yaml.dump(
                    {"source_type": self.source_type, "target_type": self.target_type},
                    f,
                    default_flow_style=False,
                )

        self.instance_class = INSTANCE_TYPE_DICT[
            f"{self.source_type}-{self.target_type}"
        ]
        self.start_index = getattr(args, "start_index", 0)
        self.end_index = getattr(args, "end_index", -1)

        if not self.score_only:
            if self.output:
                if (
                    self.args.continue_unfinished
                    and (self.output / "instances.log").exists()
                ):
                    # Resume: scan to the last logged instance and restart
                    # right after it.
                    with open(self.output / "instances.log", "r") as f:
                        line = None
                        for line in f:  # noqa
                            pass
                        if line is not None:
                            last_info = json.loads(line.strip())
                            self.start_index = last_info["index"] + 1
                else:
                    self.output.mkdir(exist_ok=True, parents=True)
                    open(self.output / "instances.log", "w").close()
            if self.end_index < 0:
                # -1 means "to the end of the dataset".
                assert self.dataloader is not None
                self.end_index = len(self.dataloader)

        self.build_instances()

        # Streaming dataloaders are iterated directly; otherwise iterate the
        # prebuilt instances.
        iterable = self.instances.values()
        if isinstance(self.dataloader, IterableDataloader):
            iterable = self.dataloader

        if not self.args.no_progress_bar and not self.score_only:
            self.iterator = tqdm(
                iterable,
                initial=self.start_index,
            )
        else:
            self.iterator = iterable

    def write_log(self, instance):
        # Append one JSON line per finished instance (enables resuming).
        if self.output is not None:
            with open(self.output / "instances.log", "a") as f:
                f.write(json.dumps(instance.summarize()) + "\n")

    def build_instances(self):
        if self.score_only:
| self.build_instances_from_log() 167 | else: 168 | self.build_instances_from_dataloader() 169 | 170 | def build_instances_from_log(self): 171 | self.instances = {} 172 | if self.output is not None: 173 | with open(self.output / "instances.log", "r") as f: 174 | for line in f: 175 | instance = LogInstance(line.strip(), self.args.eval_latency_unit) 176 | index = instance.index - self.start_index 177 | self.instances[index] = instance 178 | self.instances[index].set_target_spm_model(self.target_spm_model) 179 | 180 | def build_instances_from_dataloader(self): 181 | if isinstance(self.dataloader, IterableDataloader): 182 | return 183 | 184 | for i in self.get_indices(): 185 | self.instances[i] = self.instance_class(i, self.dataloader, self.args) 186 | self.instances[i].set_target_spm_model(self.target_spm_model) 187 | 188 | def __len__(self) -> int: 189 | return self.end_index - self.start_index 190 | 191 | def get_indices(self) -> Generator: 192 | if self.end_index < 0: 193 | self.end_index = max(self.instances.keys()) + 1 194 | 195 | if self.start_index > self.end_index: 196 | return [] 197 | 198 | for index in range(self.start_index, self.end_index): 199 | yield index 200 | 201 | @property 202 | def quality(self) -> Dict[str, float]: 203 | return { 204 | name: scorer(self.instances) 205 | for name, scorer in self.quality_scorers.items() 206 | } 207 | 208 | @property 209 | def latency(self) -> Dict[str, float]: 210 | return { 211 | name: scorer(self.instances) 212 | for name, scorer in self.latency_scorers.items() 213 | } 214 | 215 | @property 216 | def results(self): 217 | scores = {**self.quality, **self.latency} 218 | new_scores = {} 219 | for name, value in scores.items(): 220 | if isinstance(value, numbers.Number): 221 | value = round(value, 3) 222 | new_scores[name] = [value] 223 | 224 | df = pandas.DataFrame(new_scores) 225 | if self.output and self.visualize: 226 | self.make_visual() 227 | return df 228 | 229 | def dump_results(self) -> None: 230 | results = 
self.results 231 | if self.output: 232 | results.to_csv(self.output / "scores.tsv", sep="\t", index=False) 233 | 234 | logger.info("Results:") 235 | print(results.to_string(index=False)) 236 | 237 | def dump_metrics(self) -> None: 238 | metrics = pandas.DataFrame([ins.metrics for ins in self.instances.values()]) 239 | metrics = metrics.round(3) 240 | if self.output: 241 | metrics.to_csv(self.output / "metrics.tsv", sep="\t", index=False) 242 | 243 | def is_finished(self, instance) -> bool: 244 | if hasattr(instance, "source_finished_reading"): 245 | return instance.source_finished_reading 246 | return instance.finish_prediction 247 | 248 | def make_visual(self): 249 | with open(self.output / "instances.log", "r") as file: 250 | for line in file: 251 | # Load data & index 252 | data = json.loads(line) 253 | index = data.get("index", 0) 254 | 255 | # Create object & graph 256 | visualize = Visualize(data, index, self.output) 257 | visualize.make_graph() 258 | 259 | def __call__(self, system): 260 | with ( 261 | open(self.output / "instances.log", "a") 262 | if self.output 263 | else contextlib.nullcontext() 264 | ) as file: 265 | system.reset() 266 | for sample in self.iterator: 267 | instance = ( 268 | self.instance_class( 269 | self.dataloader.cur_index, self.dataloader, self.args 270 | ) 271 | if isinstance(self.dataloader, IterableDataloader) 272 | else sample 273 | ) 274 | while not self.is_finished(instance): 275 | input_segment = instance.send_source(self.source_segment_size) 276 | output_segment = system.pushpop(input_segment) 277 | instance.receive_prediction(output_segment) 278 | if instance.finish_prediction: 279 | # if instance.finish_prediction where set by the reader, 280 | # source_finished_reading will be set as well. If it is 281 | # set by any of the intermediate components, then we didn't 282 | # end yet. We are going to clear the state and continue 283 | # processing the rest of the input. 
284 | system.reset() 285 | 286 | if not self.score_only and self.output: 287 | file.write(json.dumps(instance.summarize()) + "\n") 288 | 289 | if self.output: 290 | self.build_instances_from_log() 291 | if not self.no_scoring: 292 | self.dump_results() 293 | self.dump_metrics() 294 | if self.output and self.visualize: 295 | self.make_visual() 296 | 297 | @classmethod 298 | def from_args(cls, args): 299 | if not args.score_only: 300 | dataloader = build_dataloader(args) 301 | else: 302 | dataloader = None 303 | 304 | latency_scorers = {} 305 | use_ref_len = not args.no_use_ref_len 306 | for name in args.latency_metrics: 307 | latency_scorers[name] = get_scorer_class("latency", name).from_args(args) 308 | if args.computation_aware: 309 | latency_scorers[name + "_CA"] = get_scorer_class("latency", name)( 310 | computation_aware=True, use_ref_len=use_ref_len 311 | ) 312 | 313 | quality_scorers = {} 314 | for name in args.quality_metrics: 315 | quality_scorers[name] = get_scorer_class("quality", name).from_args(args) 316 | 317 | return cls(dataloader, quality_scorers, latency_scorers, args) 318 | -------------------------------------------------------------------------------- /simuleval/evaluator/remote.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import sys 8 | import logging 9 | import threading 10 | import time 11 | from queue import Queue 12 | import numpy as np 13 | 14 | try: 15 | import wave 16 | import pyaudio 17 | from silero_vad import load_silero_vad, read_audio, get_speech_timestamps 18 | except: 19 | wave, pyaudio, load_silero_vad, read_audio, get_speech_timestamps = [ 20 | None for _ in range(5) 21 | ] 22 | 23 | from simuleval.data.segments import ( 24 | Segment, 25 | segment_from_json_string, 26 | SpeechSegment, 27 | EmptySegment, 28 | ) 29 | from simuleval.evaluator import SentenceLevelEvaluator 30 | import requests 31 | 32 | logger = logging.getLogger("simuleval.remote_evaluator") 33 | 34 | 35 | class RemoteEvaluator: 36 | def __init__(self, evaluator: SentenceLevelEvaluator) -> None: 37 | self.evaluator = evaluator 38 | self.address = evaluator.args.remote_address 39 | self.port = evaluator.args.remote_port 40 | self.source_segment_size = evaluator.args.source_segment_size 41 | self.base_url = f"http://{self.address}:{self.port}" 42 | 43 | def send_source(self, segment: Segment): 44 | url = f"{self.base_url}/input" 45 | requests.put(url, data=segment.json()) 46 | 47 | def receive_prediction(self) -> Segment: 48 | url = f"{self.base_url}/output" 49 | r = requests.get(url) 50 | return segment_from_json_string(r.text) 51 | 52 | def system_reset(self): 53 | requests.post(f"{self.base_url}/reset") 54 | 55 | def results(self): 56 | return self.evaluator.results() 57 | 58 | def remote_eval(self): 59 | for instance in self.evaluator.iterator: 60 | self.system_reset() 61 | while not instance.finish_prediction: 62 | self.send_source(instance.send_source(self.source_segment_size)) 63 | # instance.py line 275, returns a segment object with all the floats in the 500 ms range 64 | 65 | output_segment = self.receive_prediction() 66 | # gets the prediction in text! like "This"... 67 | # refreshes each time. 
"This" for the 1st, "is" for the second 68 | 69 | instance.receive_prediction(output_segment) 70 | # instance.py line 190 71 | # processes data, gets in a prediction list with ["This", "is"] on 2nd iteration 72 | self.evaluator.write_log(instance) 73 | 74 | self.evaluator.dump_results() 75 | 76 | 77 | class DemoRemote(RemoteEvaluator): 78 | def __init__(self, evaluator: SentenceLevelEvaluator) -> None: 79 | if None in [wave, pyaudio, load_silero_vad, read_audio, get_speech_timestamps]: 80 | raise Exception( 81 | "Please install wave, pyaudio, and silero_vad to run the demo" 82 | ) 83 | super().__init__(evaluator) 84 | self.float_array = np.asarray([]) 85 | self.sample_rate = 16000 86 | self.finished = False 87 | self.queue = Queue(maxsize=0) 88 | self.VADmodel = load_silero_vad() 89 | self.silence_count = 0 90 | 91 | def record_audio(self): 92 | CHUNK = 1024 93 | FORMAT = pyaudio.paInt16 94 | CHANNELS = 1 if sys.platform == "darwin" else 2 95 | RATE = self.sample_rate 96 | RECORD_SECONDS = 10000 # Indefinite time 97 | 98 | with wave.open(f"output.wav", "wb") as wf: 99 | p = pyaudio.PyAudio() 100 | wf.setnchannels(CHANNELS) 101 | wf.setsampwidth(p.get_sample_size(FORMAT)) 102 | wf.setframerate(RATE) 103 | 104 | stream = p.open(format=FORMAT, channels=CHANNELS, rate=RATE, input=True) 105 | 106 | all_data = bytearray() 107 | start = time.time() 108 | for _ in range(0, round(RATE // CHUNK * RECORD_SECONDS)): 109 | data = stream.read(CHUNK) 110 | wf.writeframes(data) 111 | all_data += data 112 | if time.time() - start > (self.source_segment_size / 1000.0): 113 | self.queue.put(all_data) 114 | all_data = bytearray() 115 | start = time.time() 116 | 117 | self.queue.put(all_data) 118 | stream.close() 119 | p.terminate() 120 | self.finished = True 121 | 122 | def remote_eval(self): 123 | # Initialization 124 | self.system_reset() 125 | recording = threading.Thread(target=self.record_audio) 126 | recording.start() 127 | 128 | # Start recording 129 | print("Recording...") 130 
| while not self.finished or not self.queue.empty(): 131 | data = byte_to_float(self.queue.get()).tolist() 132 | # VAD 133 | speech_timestamps = get_speech_timestamps( 134 | audio=data, model=self.VADmodel, sampling_rate=self.sample_rate 135 | ) 136 | 137 | if len(speech_timestamps) != 0: # has audio 138 | self.silence_count = 0 139 | else: 140 | self.silence_count += 1 141 | 142 | if self.silence_count <= 4: 143 | segment = SpeechSegment( 144 | index=self.source_segment_size, 145 | content=data, 146 | sample_rate=self.sample_rate, 147 | finished=False, 148 | ) 149 | self.send_source(segment) 150 | output_segment = self.receive_prediction() 151 | if len(output_segment.content) == 0: 152 | continue 153 | prediction_list = str(output_segment.content.replace(" ", "")) 154 | print(prediction_list, end=" ") 155 | sys.stdout.flush() 156 | 157 | else: 158 | segment = SpeechSegment( 159 | index=self.source_segment_size, 160 | content=[0.0 for _ in range(8192)], 161 | sample_rate=self.sample_rate, 162 | finished=True, 163 | ) 164 | self.send_source(segment) 165 | output_segment = self.receive_prediction() 166 | self.silence_count = 0 167 | self.system_reset() 168 | 169 | 170 | def pcm2float(sig, dtype="float32"): 171 | sig = np.asarray(sig) 172 | if sig.dtype.kind not in "iu": 173 | raise TypeError("'sig' must be an array of integers") 174 | dtype = np.dtype(dtype) 175 | if dtype.kind != "f": 176 | raise TypeError("'dtype' must be a floating point type") 177 | 178 | # pcm (16 bit) min = -32768, max = 32767, map it to -1 to 1 by dividing by max (32767) 179 | i = np.iinfo(sig.dtype) 180 | abs_max = 2 ** (i.bits - 1) 181 | offset = i.min + abs_max 182 | return (sig.astype(dtype) - offset) / abs_max 183 | 184 | 185 | def byte_to_float(byte): 186 | # byte -> int16(PCM_16) -> float32 187 | return pcm2float(np.frombuffer(byte, dtype=np.int16), dtype="float32") 188 | -------------------------------------------------------------------------------- 
/simuleval/evaluator/scorers/__init__.py: -------------------------------------------------------------------------------- 1 | from .latency_scorer import LATENCY_SCORERS_DICT 2 | from .quality_scorer import QUALITY_SCORERS_DICT 3 | 4 | 5 | def get_scorer_class(scorer_type, name): 6 | if scorer_type == "quality": 7 | scorer_dict = QUALITY_SCORERS_DICT 8 | else: 9 | scorer_dict = LATENCY_SCORERS_DICT 10 | 11 | if name not in scorer_dict: 12 | raise RuntimeError(f"No {scorer_type} metric called {name}") 13 | 14 | return scorer_dict[name] 15 | -------------------------------------------------------------------------------- /simuleval/evaluator/scorers/quality_scorer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import string 9 | import subprocess 10 | from pathlib import Path 11 | from typing import Dict 12 | 13 | import sacrebleu 14 | import tqdm 15 | from sacrebleu.metrics.bleu import BLEU 16 | 17 | QUALITY_SCORERS_DICT = {} 18 | 19 | 20 | def register_quality_scorer(name): 21 | def register(cls): 22 | QUALITY_SCORERS_DICT[name] = cls 23 | return cls 24 | 25 | return register 26 | 27 | 28 | class QualityScorer: 29 | def __init__(self) -> None: 30 | pass 31 | 32 | def __call__(self, instances: Dict) -> float: 33 | raise NotImplementedError 34 | 35 | @staticmethod 36 | def add_args(parser): 37 | pass 38 | 39 | 40 | def add_sacrebleu_args(parser): 41 | parser.add_argument( 42 | "--sacrebleu-tokenizer", 43 | type=str, 44 | default=sacrebleu.metrics.METRICS["BLEU"].TOKENIZER_DEFAULT, 45 | choices=sacrebleu.metrics.METRICS["BLEU"].TOKENIZERS, 46 | help="Tokenizer in sacrebleu", 47 | ) 48 | 49 | 50 | @register_quality_scorer("WER") 51 | class WERScorer(QualityScorer): 52 | """ 53 | Compute Word 
Error Rate (WER) 54 | 55 | Usage: 56 | :code:`--quality-metrics WER` 57 | """ 58 | 59 | def __init__(self, args) -> None: 60 | super().__init__() 61 | try: 62 | import editdistance as ed 63 | except ImportError: 64 | raise ImportError("Please install editdistance to use WER scorer") 65 | self.logger = logging.getLogger("simuleval.scorer.wer") 66 | self.logger.warning("WER scorer only support language with spaces.") 67 | self.logger.warning( 68 | "Current WER scorer is on raw text (un-tokenized with punctuations)." 69 | ) 70 | self.ed = ed 71 | 72 | def __call__(self, instances: Dict) -> float: 73 | distance = 0 74 | ref_length = 0 75 | for ins in instances.values(): 76 | distance += self.ed.eval(ins.prediction.split(), ins.reference.split()) 77 | ref_length += len(ins.reference.split()) 78 | if ref_length == 0: 79 | self.logger.warning("Reference length is 0. Return WER as 0.") 80 | return 0 81 | 82 | return 100.0 * distance / ref_length 83 | 84 | @classmethod 85 | def from_args(cls, args): 86 | return cls(args) 87 | 88 | 89 | @register_quality_scorer("BLEU") 90 | class SacreBLEUScorer(QualityScorer): 91 | """ 92 | SacreBLEU Scorer 93 | 94 | Usage: 95 | :code:`--quality-metrics BLEU` 96 | 97 | Additional command line arguments: 98 | 99 | .. 
argparse:: 100 | :ref: simuleval.evaluator.scorers.quality_scorer.add_sacrebleu_args 101 | :passparser: 102 | :prog: 103 | """ 104 | 105 | def __init__(self, tokenizer: str = "13a") -> None: 106 | super().__init__() 107 | self.logger = logging.getLogger("simuleval.scorer.bleu") 108 | self.tokenizer = tokenizer 109 | 110 | def __call__(self, instances: Dict) -> float: 111 | try: 112 | return ( 113 | BLEU(tokenize=self.tokenizer) 114 | .corpus_score( 115 | [ins.prediction for ins in instances.values()], 116 | [[ins.reference for ins in instances.values()]], 117 | ) 118 | .score 119 | ) 120 | except Exception as e: 121 | self.logger.error(str(e)) 122 | return 0 123 | 124 | @staticmethod 125 | def add_args(parser): 126 | add_sacrebleu_args(parser) 127 | 128 | @classmethod 129 | def from_args(cls, args): 130 | return cls(args.sacrebleu_tokenizer) 131 | 132 | 133 | @register_quality_scorer("ASR_BLEU") 134 | class ASRSacreBLEUScorer(QualityScorer): 135 | """ 136 | ASR + SacreBLEU Scorer (BETA version) 137 | 138 | Usage: 139 | :code:`--quality-metrics ASR_BLEU` 140 | 141 | Additional command line arguments: 142 | 143 | .. argparse:: 144 | :ref: simuleval.evaluator.scorers.quality_scorer.add_sacrebleu_args 145 | :passparser: 146 | :prog: 147 | """ 148 | 149 | def __init__(self, tokenizer: str = "13a", target_lang: str = "en") -> None: 150 | super().__init__() 151 | self.logger = logging.getLogger("simuleval.scorer.asr_bleu") 152 | self.tokenizer = tokenizer 153 | self.target_lang = target_lang 154 | 155 | def __call__(self, instances: Dict) -> float: 156 | transcripts = self.asr_transcribe(instances) 157 | score = ( 158 | BLEU(tokenize=self.tokenizer) 159 | .corpus_score( 160 | transcripts, 161 | [[ins.reference for ins in instances.values()]], 162 | ) 163 | .score 164 | ) 165 | return score 166 | 167 | def asr_transcribe(self, instances): 168 | self.logger.warn("Beta feature: Evaluating speech output. 
Faieseq is required.") 169 | try: 170 | import fairseq 171 | 172 | fairseq_path = Path(fairseq.__path__[0]).parent # type: ignore 173 | except Exception: 174 | self.logger.warn("Please install fairseq.") 175 | return ["" for _ in instances.keys()] 176 | 177 | wav_dir = Path(instances[0].prediction).absolute().parent 178 | root_dir = wav_dir.parent 179 | transcripts_path = root_dir / "asr_transcripts.txt" 180 | asr_cmd_bash_path = root_dir / "asr_cmd.bash" 181 | 182 | # This is a dummy reference. The bleu score will be compute separately. 183 | reference_path = root_dir / "instances.log" 184 | 185 | fairseq_asr_bleu_cmd = "\n".join( 186 | [ 187 | f"cd {fairseq_path.as_posix()}/examples/speech_to_speech/asr_bleu/", 188 | " ".join( 189 | [ 190 | "python compute_asr_bleu.py", 191 | f"--reference_path {reference_path.as_posix()}", 192 | f"--lang {self.target_lang}", 193 | f"--audio_dirpath {wav_dir.as_posix()}", 194 | "--reference_format txt", 195 | f"--transcripts_path {(root_dir / 'asr_transcripts.txt').as_posix()}", 196 | ] 197 | ), 198 | ] 199 | ) 200 | with open(asr_cmd_bash_path, "w") as f: 201 | f.write(fairseq_asr_bleu_cmd + "\n") 202 | 203 | process = subprocess.Popen(["bash", asr_cmd_bash_path], stdout=subprocess.PIPE) 204 | _, stderr = process.communicate() 205 | 206 | if process.returncode != 0: 207 | self.logger.error("ASR on target speech failed:") 208 | self.logger.error(str(stderr) + "\n") 209 | return ["" for _ in instances.keys()] 210 | 211 | with open(transcripts_path, "r") as f: 212 | transcripts = [line.strip() for line in f] 213 | 214 | for idx, item in enumerate(transcripts): 215 | with open(wav_dir / f"{idx}_pred.txt", "w") as f: 216 | f.write(item.lower() + "\n") 217 | 218 | return transcripts 219 | 220 | @staticmethod 221 | def add_args(parser): 222 | add_sacrebleu_args(parser) 223 | parser.add_argument( 224 | "--target-speech-lang", 225 | type=str, 226 | default="en", 227 | help="The language of target speech", 228 | ) 229 | 230 | @classmethod 
231 | def from_args(cls, args): 232 | return cls(args.sacrebleu_tokenizer, args.target_speech_lang) 233 | 234 | 235 | PUNCTUATIONS_EXCLUDE_APOSTROPHE = ( 236 | string.punctuation.replace("'", "") + "¡¨«°³º»¿‘“”…♪♫ˆᵉ™,ʾ˚" 237 | ) 238 | PUNCTUATIONS_TO_SPACE = "-/–·—•" 239 | 240 | 241 | def remove_punctuations(text, punctuations=string.punctuation): 242 | text = text.translate( 243 | str.maketrans(PUNCTUATIONS_TO_SPACE, " " * len(PUNCTUATIONS_TO_SPACE)) 244 | ) 245 | return text.translate(str.maketrans("", "", punctuations)) 246 | 247 | 248 | @register_quality_scorer("WHISPER_ASR_BLEU") 249 | class WhisperASRSacreBLEUScorer(QualityScorer): 250 | """ 251 | Whisper ASR + SacreBLEU Scorer with whisper model 252 | 253 | Usage: 254 | :code:`--quality-metrics ASR_BLEU` 255 | 256 | Additional command line arguments: 257 | 258 | .. argparse:: 259 | :ref: simuleval.evaluator.scorers.quality_scorer.add_sacrebleu_args 260 | :passparser: 261 | :prog: 262 | """ 263 | 264 | def __init__( 265 | self, 266 | tokenizer: str = "13a", 267 | target_lang: str = "en", 268 | model_size: str = "base", 269 | temperature: float = 0.0, 270 | lowercase: bool = False, 271 | remove_punctuations: bool = False, 272 | ) -> None: 273 | super().__init__() 274 | self.logger = logging.getLogger("simuleval.scorer.whisper_asr_bleu") 275 | self.tokenizer = tokenizer 276 | self.target_lang = target_lang 277 | self.model_size = model_size 278 | self.temperature = temperature 279 | self.lowercase = lowercase 280 | self.remove_punctuations = remove_punctuations 281 | 282 | def __call__(self, instances: Dict) -> float: 283 | transcripts = self.asr_transcribe(instances) 284 | score = ( 285 | BLEU(tokenize=self.tokenizer) 286 | .corpus_score( 287 | transcripts, 288 | [[ins.reference for ins in instances.values()]], 289 | ) 290 | .score 291 | ) 292 | return score 293 | 294 | def asr_transcribe(self, instances): 295 | self.logger.info( 296 | "Evaluating speech output by ASR BLEU. whisper and sacrebleu are required." 
297 | ) 298 | self.logger.info("Configs:") 299 | self.logger.info(f"tokenizer = {self.tokenizer}") 300 | self.logger.info(f"target_lang = {self.target_lang}") 301 | self.logger.info(f"model_size = {self.model_size}") 302 | self.logger.info(f"temperature = {self.temperature}") 303 | self.logger.info(f"lowercase = {self.lowercase}") 304 | self.logger.info(f"remove_punctuations = {self.remove_punctuations}") 305 | try: 306 | import whisper 307 | except Exception: 308 | self.logger.warn("Please install whisper.") 309 | return ["" for _ in instances.keys()] 310 | 311 | model = whisper.load_model(self.model_size) 312 | wav_dir = Path(instances[0].prediction).absolute().parent 313 | 314 | transcripts = [] 315 | for index in tqdm.tqdm(instances.keys()): 316 | wav_path = wav_dir / f"{index}_pred.wav" 317 | if wav_path.exists(): 318 | result = model.transcribe( 319 | wav_path.as_posix(), 320 | language=self.target_lang, 321 | temperature=self.temperature, 322 | ) 323 | text = result["text"] 324 | assert type(text) == str 325 | if self.lowercase: 326 | text = text.lower() 327 | if self.remove_punctuations: 328 | text = remove_punctuations(text) 329 | transcripts.append(text.strip()) 330 | else: 331 | transcripts.append("") 332 | 333 | root_dir = wav_dir.parent 334 | transcripts_path = root_dir / "asr_transcripts.txt" 335 | with open(transcripts_path, "w") as f: 336 | for line in transcripts: 337 | f.write(line + "\n") 338 | 339 | return transcripts 340 | 341 | @staticmethod 342 | def add_args(parser): 343 | add_sacrebleu_args(parser) 344 | parser.add_argument( 345 | "--target-speech-lang", 346 | type=str, 347 | default="en", 348 | help="The language of target speech", 349 | ) 350 | parser.add_argument( 351 | "--whisper-model-size", 352 | type=str, 353 | default="large", 354 | help="The size of whisper asr model", 355 | ) 356 | parser.add_argument( 357 | "--whisper-model-temperature", 358 | type=float, 359 | default=0.0, 360 | help="If temperature > 0.0, the decoding will 
perform sampling", 361 | ) 362 | parser.add_argument( 363 | "--transcript-lowercase", 364 | action="store_true", 365 | help="Lowercase the whisper output", 366 | ) 367 | parser.add_argument( 368 | "--transcript-non-punctuation", 369 | action="store_true", 370 | help="Remove punctuations in the whisper output", 371 | ) 372 | 373 | @classmethod 374 | def from_args(cls, args): 375 | return cls( 376 | args.sacrebleu_tokenizer, 377 | args.target_speech_lang, 378 | args.whisper_model_size, 379 | args.whisper_model_temperature, 380 | args.transcript_lowercase, 381 | args.transcript_non_punctuation, 382 | ) 383 | -------------------------------------------------------------------------------- /simuleval/options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import argparse 8 | import importlib 9 | import logging 10 | import os 11 | import sys 12 | from typing import List, Optional 13 | 14 | from simuleval.data.dataloader import ( 15 | DATALOADER_DICT, 16 | GenericDataloader, 17 | register_dataloader_class, 18 | ) 19 | from simuleval.evaluator.scorers import get_scorer_class 20 | 21 | 22 | def add_dataloader_args( 23 | parser: argparse.ArgumentParser, cli_argument_list: Optional[List[str]] = None 24 | ): 25 | if cli_argument_list is None: 26 | args, _ = parser.parse_known_args() 27 | else: 28 | args, _ = parser.parse_known_args(cli_argument_list) 29 | 30 | if args.dataloader_class: 31 | dataloader_module = importlib.import_module( 32 | ".".join(args.dataloader_class.split(".")[:-1]) 33 | ) 34 | dataloader_class = getattr( 35 | dataloader_module, args.dataloader_class.split(".")[-1] 36 | ) 37 | register_dataloader_class(args.dataloader, dataloader_class) 38 | 39 | dataloader_class = DATALOADER_DICT.get(args.dataloader) 40 | if dataloader_class is None: 41 | dataloader_class = GenericDataloader 42 | dataloader_class.add_args(parser) 43 | 44 | 45 | def add_evaluator_args(parser: argparse.ArgumentParser): 46 | parser.add_argument( 47 | "--quality-metrics", 48 | nargs="+", 49 | default=["BLEU"], 50 | help="Quality metrics", 51 | ) 52 | parser.add_argument( 53 | "--latency-metrics", 54 | nargs="+", 55 | default=["LAAL", "AL", "AP", "DAL", "ATD"], 56 | help="Latency metrics", 57 | ) 58 | parser.add_argument( 59 | "--continue-unfinished", 60 | action="store_true", 61 | default=False, 62 | help="Continue the experiments in output dir.", 63 | ) 64 | parser.add_argument( 65 | "--computation-aware", 66 | action="store_true", 67 | default=False, 68 | help="Include computational latency.", 69 | ) 70 | parser.add_argument( 71 | "--no-use-ref-len", 72 | action="store_true", 73 | default=False, 74 | help="Include computational latency.", 75 | ) 76 | parser.add_argument( 77 | "--eval-latency-unit", 78 | type=str, 79 | 
default="word", 80 | choices=["word", "char", "spm"], 81 | help="Basic unit used for latency calculation, choose from " 82 | "words (detokenized) and characters.", 83 | ) 84 | parser.add_argument( 85 | "--eval-latency-spm-model", 86 | type=str, 87 | default=None, 88 | help="Pass the spm model path if the eval_latency_unit is spm.", 89 | ) 90 | parser.add_argument( 91 | "--remote-address", 92 | default="localhost", 93 | help="Address to client backend", 94 | ) 95 | parser.add_argument( 96 | "--remote-port", 97 | default=12321, 98 | help="Port to client backend", 99 | ) 100 | parser.add_argument( 101 | "--no-progress-bar", 102 | action="store_true", 103 | default=False, 104 | help="Do not use progress bar", 105 | ) 106 | parser.add_argument( 107 | "--start-index", 108 | type=int, 109 | default=0, 110 | help="Start index for evaluation.", 111 | ) 112 | parser.add_argument( 113 | "--end-index", 114 | type=int, 115 | default=-1, 116 | help="The last index for evaluation.", 117 | ) 118 | parser.add_argument( 119 | "--output", 120 | type=str, 121 | default=None, 122 | help="Output directory. 
def add_scorer_args(
    parser: argparse.ArgumentParser, cli_argument_list: Optional[List[str]] = None
):
    """Register CLI arguments for every requested latency/quality scorer.

    The parser is probed with ``parse_known_args`` for the already-registered
    ``--latency-metrics`` / ``--quality-metrics`` selections, then each selected
    scorer class contributes its own arguments.

    Args:
        parser: parser to extend in place.
        cli_argument_list: explicit argv to probe; defaults to ``sys.argv``.
    """
    if cli_argument_list is None:
        args, _ = parser.parse_known_args()
    else:
        args, _ = parser.parse_known_args(cli_argument_list)

    for metric in args.latency_metrics:
        get_scorer_class("latency", metric).add_args(parser)

    for metric in args.quality_metrics:
        get_scorer_class("quality", metric).add_args(parser)


def import_user_module(module_path):
    """Import a user-supplied python module (e.g. custom agents) by path.

    The module's parent directory is temporarily prepended to ``sys.path``.
    Fix: the path entry is now removed in a ``finally`` block, so a failing
    import no longer permanently pollutes ``sys.path``.
    """
    module_path = os.path.abspath(module_path)
    module_parent, module_name = os.path.split(module_path)

    sys.path.insert(0, module_parent)
    try:
        importlib.import_module(module_name)
    finally:
        sys.path.pop(0)


def general_parser(
    config_dict: Optional[dict] = None,
    parser: Optional[argparse.ArgumentParser] = None,
):
    """Build (or extend) the top-level SimulEval argument parser.

    Args:
        config_dict: currently unused here; kept for interface compatibility
            with callers that pass it (see ``build_system_args``).
        parser: existing parser to extend; a fresh one is created when None.

    Returns:
        The parser with all general SimulEval options registered.
    """
    if parser is None:
        parser = argparse.ArgumentParser(
            add_help=False,
            description="SimulEval - Simultaneous Evaluation CLI",
            conflict_handler="resolve",
        )

    parser.add_argument(
        "--user-dir",
        default=None,
        help="path to a python module containing custom agents",
    )
    # Import the user module immediately so that any @entrypoint-decorated
    # agents are registered before the remaining options are parsed.
    args, _ = parser.parse_known_args()
    if args.user_dir is not None:
        import_user_module(args.user_dir)

    parser.add_argument(
        "--remote-eval",
        action="store_true",
        help="Evaluate a standalone agent",
    )
    parser.add_argument(
        "--standalone",
        action="store_true",
        help="",
    )
    parser.add_argument(
        "--slurm", action="store_true", default=False, help="Use slurm."
    )
    parser.add_argument("--agent", default=None, help="Agent file")
    parser.add_argument(
        "--agent-class",
        default=None,
        help="The full string of class of the agent.",
    )
    parser.add_argument(
        "--system-dir",
        default=None,
        help="Directory that contains everything to start the simultaneous system.",
    )
    parser.add_argument(
        "--system-config",
        default="main.yaml",
        help="Name of the config yaml of the system configs.",
    )
    parser.add_argument("--dataloader", default=None, help="Dataloader to use")
    parser.add_argument(
        "--dataloader-class", default=None, help="Dataloader class to use"
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="info",
        # Accept the lower-cased standard logging level names.
        choices=[x.lower() for x in logging._levelToName.values()],
        help="Log level.",
    )
    # --score-only and --no-scoring are contradictory; let argparse enforce it.
    scoring_arg_group = parser.add_mutually_exclusive_group()
    scoring_arg_group.add_argument(
        "--score-only",
        action="store_true",
        default=False,
        help="Only score the inference file.",
    )
    scoring_arg_group.add_argument(
        "--no-scoring",
        action="store_true",
        help="No scoring after inference",
    )
    parser.add_argument(
        "--device", type=str, default="cpu", help="Device to run the model."
    )
    # --dtype supersedes the legacy --fp16 flag; they must not be combined.
    dtype_arg_group = parser.add_mutually_exclusive_group()
    dtype_arg_group.add_argument(
        "--dtype",
        choices=["fp16", "fp32"],
        type=str,
        help=(
            "Choose between half-precision (fp16) and single precision (fp32) floating point formats."
            + " Prefer this over the fp16 flag."
        ),
    )
    dtype_arg_group.add_argument(
        "--fp16", action="store_true", default=False, help="Use fp16."
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        default=False,
        help="Create visualization graphs",
    )
    parser.add_argument(
        "--demo",
        action="store_true",
        default=False,
        help="Live remote speech to text demonstration in terminal",
    )

    return parser


def add_slurm_args(parser):
    """Register slurm-submission options (partition, job name, time limit)."""
    parser.add_argument("--slurm-partition", default="", help="Slurm partition.")
    parser.add_argument("--slurm-job-name", default="simuleval", help="Slurm job name.")
    # Fix: the help text previously read "Slurm partition." (copy-paste error).
    parser.add_argument("--slurm-time", default="2:00:00", help="Slurm time limit.")


# ---- simuleval/test/test_agent.py ----

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
import tempfile
from pathlib import Path
import urllib.request
from argparse import Namespace

import pytest

import simuleval.cli as cli
from simuleval.agents import TextToTextAgent
from simuleval.agents.actions import ReadAction, WriteAction
from simuleval.data.segments import TextSegment
import logging

logger = logging.getLogger()


# Repository root (two levels above simuleval/test/); placed on sys.path so
# the `examples` package imported below resolves regardless of cwd.
ROOT_PATH = Path(__file__).parents[2]
sys.path.insert(0, str(ROOT_PATH))  # may be needed for import from `examples`

from examples.quick_start.spm_detokenizer_agent import (
    SentencePieceModelDetokenizerAgent,
)


def test_agent(root_path=ROOT_PATH):
    """CLI smoke test: run the quick-start dummy wait-k text agent end to end.

    Passes if ``cli.main()`` completes without raising.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        # The CLI reads sys.argv, so the argv tail is patched in place.
        cli.sys.argv[1:] = [
            "--user-dir",
            os.path.join(root_path, "examples"),
            "--agent-class",
            "examples.quick_start.first_agent.DummyWaitkTextAgent",
            "--source",
            os.path.join(root_path, "examples", "quick_start", "source.txt"),
            "--target",
            os.path.join(root_path, "examples", "quick_start", "target.txt"),
            "--output",
            tmpdirname,
        ]
        cli.main()


def test_statelss_agent(root_path=ROOT_PATH):
    # NOTE(review): "statelss" is a typo for "stateless"; renaming would
    # change the collected test id, so it is only flagged here.
    class DummyWaitkTextAgent(TextToTextAgent):
        waitk = 0
        vocab = [chr(i) for i in range(ord("A"), ord("Z") + 1)]

        def policy(self, states=None):
            # Stateless call path: an explicitly passed ``states`` overrides
            # the agent's internal ones.
            if states is None:
                states = self.states

            lagging = len(states.source) - len(states.target)

            if lagging >= self.waitk or states.source_finished:
                prediction = self.vocab[len(states.source)]

                return WriteAction(prediction, finished=(lagging <= 1))
            else:
                return ReadAction()

    args = None
    agent_stateless = DummyWaitkTextAgent.from_args(args)
    agent_state = agent_stateless.build_states()
    agent_stateful = DummyWaitkTextAgent.from_args(args)

    # Feeding identical segments through an externally held state object and
    # through the agent's internal state must yield identical outputs.
    for _ in range(10):
        segment = TextSegment(0, "A")
        output_1 = agent_stateless.pushpop(segment, agent_state)
        output_2 = agent_stateful.pushpop(segment)
        assert output_1.content == output_2.content


@pytest.mark.parametrize("detokenize_only", [True, False])
def test_spm_detokenizer_agent(detokenize_only):
    """Feed SPM-tokenized chunks through the detokenizer agent and check both
    the emitted text and the per-word delays (segment index of emission).

    NOTE(review): downloads the tokenizer model from Hugging Face, so this
    test requires network access.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        tokenizer_file = f"{tmpdirname}/tokenizer.model"
        tokenizer_url = "https://huggingface.co/facebook/seamless-m4t-large/resolve/main/tokenizer.model"
        urllib.request.urlretrieve(tokenizer_url, tokenizer_file)

        args = Namespace()
        args.sentencepiece_model = tokenizer_file
        args.detokenize_only = detokenize_only

        output = []
        delays = []
        agent = SentencePieceModelDetokenizerAgent.from_args(args)
        agent_state = agent.build_states()
        segments = [
            TextSegment(0, "▁Let ' s"),
            TextSegment(1, "▁do ▁it ▁with"),
            TextSegment(2, "out ▁hesitation .", finished=True),
        ]
        for i, segment in enumerate(segments):
            output_segment = agent.pushpop(segment, agent_state)
            if not output_segment.is_empty:
                output.append(output_segment.content)
                # One delay entry (the segment index) per emitted word.
                delays += [i] * len(output_segment.content.split())
        if detokenize_only:
            # Pure detokenization emits per input segment, even mid-word.
            assert output == ["Let's", "do it with", "out hesitation."]
            assert delays == [0, 1, 1, 1, 2, 2]
        else:
            # Word-boundary mode holds partial words until they complete.
            assert output == ["Let's do it", "without hesitation."]
            assert delays == [1, 1, 1, 2, 2]


@pytest.mark.parametrize("detokenize_only", [True, False])
def test_spm_detokenizer_agent_pipeline(detokenize_only, root_path=ROOT_PATH):
    """Run the detokenizer inside a pipeline via the CLI on SPM test data.

    NOTE(review): also requires network access for the tokenizer download.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        tokenizer_file = f"{tmpdirname}/tokenizer.model"
        tokenizer_url = "https://huggingface.co/facebook/seamless-m4t-large/resolve/main/tokenizer.model"
        urllib.request.urlretrieve(tokenizer_url, tokenizer_file)

        cli.sys.argv[1:] = [
            "--user-dir",
            os.path.join(root_path, "examples"),
            "--agent-class",
            "examples.quick_start.spm_detokenizer_agent.DummyPipeline",
            "--source",
            os.path.join(root_path, "examples", "quick_start", "spm_source.txt"),
            "--target",
            os.path.join(root_path, "examples", "quick_start", "spm_target.txt"),
            "--output",
            tmpdirname,
            "--segment-k",
            "3",
            "--sentencepiece-model",
            tokenizer_file,
        ]
        if detokenize_only:
            cli.sys.argv.append("--detokenize-only")
        cli.main()


# ---- simuleval/test/test_agent_pipeline.py ----

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from pathlib import Path
import tempfile

import simuleval.cli as cli
from simuleval.agents import AgentPipeline, TextToTextAgent
from simuleval.agents.actions import ReadAction, WriteAction
from simuleval.data.segments import TextSegment

ROOT_PATH = Path(__file__).parents[2]


def test_pipeline_cmd(root_path=ROOT_PATH):
    """CLI smoke test for a pipeline agent loaded from a file path."""
    # NOTE: When importing --agent we use import_file, thus need to specify
    # --agent-class as agents.DummyPipeline
    cli.sys.argv[1:] = [
        "--agent",
        os.path.join(root_path, "examples", "quick_start", "agent_pipeline.py"),
        "--user-dir",
        os.path.join(root_path, "examples"),
        "--agent-class",
        "agents.DummyPipeline",
        "--source",
        os.path.join(root_path, "examples", "quick_start", "source.txt"),
        "--target",
        os.path.join(root_path, "examples", "quick_start", "target.txt"),
    ]
    cli.main()


def test_tree_pipeline_cmd(root_path=ROOT_PATH):
    """CLI smoke test for a tree-shaped (speech->speech+text) pipeline."""
    args_path = Path.joinpath(root_path, "examples", "speech_to_speech")
    # The agent expects to run from the speech_to_speech example directory.
    os.chdir(args_path)
    with tempfile.TemporaryDirectory() as tmpdirname:
        cli.sys.argv[1:] = [
            "--agent-class",
            "examples.speech_to_speech_text.tree_agent_pipeline.DummyTreePipeline",
            "--user-dir",
            os.path.join(root_path, "examples"),
            "--source",
            os.path.join(root_path, "examples", "speech_to_speech", "source.txt"),
            "--target",
            os.path.join(
                root_path, "examples", "speech_to_text", "reference", "en.txt"
            ),
            "--source-segment-size",
            "320",
            "--output-index",
            "0",
            "--output",
            tmpdirname,
        ]
        cli.main()
def test_instantiated_tree_pipeline_cmd(root_path=ROOT_PATH):
    """Same CLI run as the tree-pipeline test, but loading a pre-instantiated
    tree pipeline class instead of one built from args."""
    args_path = Path.joinpath(root_path, "examples", "speech_to_speech")
    # The agent expects to run from the speech_to_speech example directory.
    os.chdir(args_path)
    with tempfile.TemporaryDirectory() as tmpdirname:
        cli.sys.argv[1:] = [
            "--agent-class",
            "examples.speech_to_speech_text.tree_agent_pipeline.AnotherInstantiatedTreeAgentPipeline",
            "--user-dir",
            os.path.join(root_path, "examples"),
            "--source",
            os.path.join(root_path, "examples", "speech_to_speech", "source.txt"),
            "--target",
            os.path.join(
                root_path, "examples", "speech_to_text", "reference", "en.txt"
            ),
            "--source-segment-size",
            "320",
            "--output-index",
            "0",
            "--output",
            tmpdirname,
        ]
        cli.main()


def test_pipeline():
    """A two-stage wait-k pipeline must behave identically whether driven via
    ``pushpop`` or via separate ``push``/``pop`` calls."""
    class DummyWaitkTextAgent(TextToTextAgent):
        waitk = 0
        vocab = [chr(i) for i in range(ord("A"), ord("Z") + 1)]

        def policy(self):
            lagging = len(self.states.source) - len(self.states.target)

            if lagging >= self.waitk or self.states.source_finished:
                prediction = self.vocab[len(self.states.source)]

                return WriteAction(prediction, finished=(lagging <= 1))
            else:
                return ReadAction()

    class DummyWait2TextAgent(DummyWaitkTextAgent):
        waitk = 2

    class DummyWait4TextAgent(DummyWaitkTextAgent):
        waitk = 4

    class DummyPipeline(AgentPipeline):
        # Stage order: wait-2 feeds into wait-4.
        pipeline = [DummyWait2TextAgent, DummyWait4TextAgent]

    args = None
    agent_1 = DummyPipeline.from_args(args)
    agent_2 = DummyPipeline.from_args(args)
    for _ in range(10):
        segment = TextSegment(0, "A")
        output_1 = agent_1.pushpop(segment)
        agent_2.push(segment)
        output_2 = agent_2.pop()
        assert output_1.content == output_2.content


# ---- simuleval/test/test_evaluator.py ----

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import tempfile
from pathlib import Path

import simuleval.cli as cli

ROOT_PATH = Path(__file__).parents[2]


def test_score_only(root_path=ROOT_PATH):
    """Run a normal evaluation, then re-score the produced output directory
    with --score-only (no inference on the second pass)."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        cli.sys.argv[1:] = [
            "--user-dir",
            os.path.join(root_path, "examples"),
            "--agent-class",
            "examples.quick_start.first_agent.DummyWaitkTextAgent",
            "--source",
            os.path.join(root_path, "examples", "quick_start", "source.txt"),
            "--target",
            os.path.join(root_path, "examples", "quick_start", "target.txt"),
            "--output",
            tmpdirname,
        ]
        cli.main()
        # Second pass: score the instances.log written by the first pass.
        cli.sys.argv[1:] = ["--score-only", "--output", tmpdirname]
        cli.main()


# ---- simuleval/test/test_remote_evaluation.py ----

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import tempfile 9 | import time 10 | from multiprocessing import Process 11 | from pathlib import Path 12 | 13 | import simuleval.cli as cli 14 | from simuleval.utils.functional import find_free_port 15 | 16 | ROOT_PATH = Path(__file__).parents[2] 17 | 18 | 19 | def p1(port, root_path): 20 | cli.sys.argv[1:] = [ 21 | "--standalone", 22 | "--remote-port", 23 | str(port), 24 | "--user-dir", 25 | os.path.join(root_path, "examples"), 26 | "--agent-class", 27 | "examples.quick_start.first_agent.DummyWaitkTextAgent", 28 | ] 29 | cli.main() 30 | time.sleep(5) 31 | 32 | 33 | def p2(port, root_path): 34 | with tempfile.TemporaryDirectory() as tmpdirname: 35 | cli.sys.argv[1:] = [ 36 | "--remote-eval", 37 | "--remote-port", 38 | str(port), 39 | "--source", 40 | os.path.join(root_path, "examples", "quick_start", "source.txt"), 41 | "--target", 42 | os.path.join(root_path, "examples", "quick_start", "target.txt"), 43 | "--dataloader", 44 | "text-to-text", 45 | "--output", 46 | tmpdirname, 47 | ] 48 | cli.main() 49 | 50 | 51 | def test_remote_eval(root_path=ROOT_PATH): 52 | port = find_free_port() 53 | 54 | p_1 = Process(target=p1, args=(port, root_path)) 55 | p_1.start() 56 | 57 | p_2 = Process(target=p2, args=(port, root_path)) 58 | p_2.start() 59 | 60 | p_1.kill() 61 | p_2.kill() 62 | -------------------------------------------------------------------------------- /simuleval/test/test_s2s.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
import os
import tempfile
from pathlib import Path
from typing import Optional
from simuleval.agents.states import AgentStates

import simuleval.cli as cli
from simuleval.agents import SpeechToSpeechAgent
from simuleval.agents.actions import ReadAction, WriteAction
from simuleval.data.segments import SpeechSegment

ROOT_PATH = Path(__file__).parents[2]


def test_s2s(root_path=ROOT_PATH):
    """CLI smoke test: run the alternate speech-to-speech example agent.

    Passes if ``cli.main()`` completes without raising.
    """
    args_path = Path.joinpath(root_path, "examples", "speech_to_speech")
    # The agent expects to run from the speech_to_speech example directory.
    os.chdir(args_path)
    with tempfile.TemporaryDirectory() as tmpdirname:
        cli.sys.argv[1:] = [
            "--agent",
            os.path.join(
                root_path, "examples", "speech_to_speech", "english_alternate_agent.py"
            ),
            "--user-dir",
            os.path.join(root_path, "examples"),
            "--agent-class",
            "agents.EnglishAlternateAgent",
            "--source-segment-size",
            "1000",
            "--source",
            os.path.join(root_path, "examples", "speech_to_speech", "source.txt"),
            "--target",
            os.path.join(root_path, "examples", "speech_to_speech", "reference/en.txt"),
            "--output",
            tmpdirname,
            "--tgt-lang",
            os.path.join(
                root_path, "examples", "speech_to_speech", "reference/tgt_lang.txt"
            ),
        ]
        cli.main()


def test_stateless_agent(root_path=ROOT_PATH):
    """Stateless (externally held states) and stateful runs must agree."""
    class EnglishAlternateAgent(SpeechToSpeechAgent):
        waitk = 0
        wait_seconds = 3
        vocab = [chr(i) for i in range(ord("A"), ord("Z") + 1)]

        def policy(self, states: Optional[AgentStates] = None):
            # Fix: the original read ``states = states`` (a no-op), so the
            # stateless call path crashed on ``None``; fall back to the
            # agent's own states like the analogous S2T test does.
            if states is None:
                states = self.states

            length_in_seconds = round(len(states.source) / states.source_sample_rate)
            # Fix: consult the *passed-in* states (not self.states) so the
            # stateless path behaves identically to the stateful one.
            if (
                not states.source_finished
                and length_in_seconds < self.wait_seconds
            ):
                return ReadAction()

            # NOTE(review): relies on self.tts_model being provided by
            # SpeechToSpeechAgent (not defined in this subclass) — confirm.
            if length_in_seconds % 2 == 0:
                samples, fs = self.tts_model.synthesize(
                    f"{8 - length_in_seconds} even even"
                )
            else:
                samples, fs = self.tts_model.synthesize(
                    f"{8 - length_in_seconds} odd odd"
                )

            prediction = f"{length_in_seconds} second"

            # NOTE(review): passing both a positional segment and content=
            # assumes WriteAction accepts them as distinct parameters — confirm
            # against simuleval.agents.actions.WriteAction's signature.
            return WriteAction(
                SpeechSegment(
                    content=samples,
                    sample_rate=fs,
                    finished=states.source_finished,
                ),
                content=prediction,
                finished=states.source_finished,
            )

    args = None
    agent_stateless = EnglishAlternateAgent.from_args(args)
    agent_state = agent_stateless.build_states()
    agent_stateful = EnglishAlternateAgent.from_args(args)

    for _ in range(10):
        segment = SpeechSegment(0, "A")
        output_1 = agent_stateless.pushpop(segment, agent_state)
        output_2 = agent_stateful.pushpop(segment)
        assert output_1.content == output_2.content


# ---- simuleval/test/test_s2t.py ----

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import tempfile
from pathlib import Path

import simuleval.cli as cli
from simuleval.agents import SpeechToTextAgent
from simuleval.agents.actions import ReadAction, WriteAction
from simuleval.data.segments import SpeechSegment
from simuleval.evaluator.instance import LogInstance

ROOT_PATH = Path(__file__).parents[2]

# Both CLI tests below run the same counter agent file on the same data and
# expect the same Spanish counter transcript; only the agent class differs.
_EXPECTED_PREDICTION = (
    "1 segundos 2 segundos 3 segundos 4 segundos"
    " 5 segundos 6 segundos 7 segundos"
)


def _run_s2t_counter_test(agent_class, root_path):
    """Shared driver: run ``counter_in_tgt_lang_agent.py`` through the CLI and
    assert every logged prediction equals the expected counter output.

    Extracted because test_s2t and test_s2t_with_tgt_lang were identical
    except for the --agent-class value.
    """
    args_path = Path.joinpath(root_path, "examples", "speech_to_text")
    # The agent expects to run from the speech_to_text example directory.
    os.chdir(args_path)
    with tempfile.TemporaryDirectory() as tmpdirname:
        cli.sys.argv[1:] = [
            "--agent",
            os.path.join(
                root_path, "examples", "speech_to_text", "counter_in_tgt_lang_agent.py"
            ),
            "--user-dir",
            os.path.join(root_path, "examples"),
            "--agent-class",
            agent_class,
            "--source-segment-size",
            "1000",
            "--source",
            os.path.join(root_path, "examples", "speech_to_text", "source.txt"),
            "--target",
            os.path.join(root_path, "examples", "speech_to_text", "reference/en.txt"),
            "--output",
            tmpdirname,
            "--tgt-lang",
            os.path.join(
                root_path, "examples", "speech_to_text", "reference/tgt_lang.txt"
            ),
        ]
        cli.main()

        with open(os.path.join(tmpdirname, "instances.log"), "r") as f:
            for line in f:
                instance = LogInstance(line.strip())
                assert instance.prediction == _EXPECTED_PREDICTION


def test_s2t(root_path=ROOT_PATH):
    """CLI test with the plain EnglishSpeechCounter agent class."""
    _run_s2t_counter_test("agents.EnglishSpeechCounter", root_path)


def test_statelss_agent(root_path=ROOT_PATH):
    # NOTE(review): "statelss" is a typo for "stateless"; renaming would
    # change the collected test id, so it is only flagged here.
    class EnglishSpeechCounter(SpeechToTextAgent):
        wait_seconds = 3

        def policy(self, states=None):
            # Stateless call path: an explicit ``states`` overrides the
            # agent's internal ones.
            if states is None:
                states = self.states

            length_in_seconds = round(len(states.source) / states.source_sample_rate)
            if not states.source_finished and length_in_seconds < self.wait_seconds:
                return ReadAction()

            prediction = f"{length_in_seconds} second"

            return WriteAction(
                content=prediction,
                finished=states.source_finished,
            )

    args = None
    agent_stateless = EnglishSpeechCounter.from_args(args)
    agent_state = agent_stateless.build_states()
    agent_stateful = EnglishSpeechCounter.from_args(args)

    # Externally held state and internal state must produce the same output.
    for _ in range(10):
        segment = SpeechSegment(0, "A")
        output_1 = agent_stateless.pushpop(segment, agent_state)
        output_2 = agent_stateful.pushpop(segment)
        assert output_1.content == output_2.content


def test_s2t_with_tgt_lang(root_path=ROOT_PATH):
    """CLI test with the CounterInTargetLanguage agent class."""
    _run_s2t_counter_test("agents.CounterInTargetLanguage", root_path)


# ---- simuleval/test/test_visualize.py ----

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import tempfile
from pathlib import Path
import simuleval.cli as cli
import shutil
import json

ROOT_PATH = Path(__file__).parents[2]


def test_visualize(root_path=ROOT_PATH):
    """Run the whisper wait-k agent through the CLI with --visualize and check
    that one PNG per source line is produced under output/visual."""
    args_path = Path.joinpath(root_path, "examples", "speech_to_text")
    os.chdir(args_path)
    with tempfile.TemporaryDirectory() as tmpdirname:
        # NOTE(review): tmpdirname is unused; output goes to the relative
        # "output" directory and is removed below with shutil.rmtree.
        cli.sys.argv[1:] = [
            "--agent",
            os.path.join(root_path, "examples", "speech_to_text", "whisper_waitk.py"),
            "--source-segment-size",
            "500",
            "--waitk-lagging",
            "3",
            "--source",
            os.path.join(root_path, "examples", "speech_to_text", "source.txt"),
            "--target",
            os.path.join(
                root_path, "examples", "speech_to_text", "reference/transcript.txt"
            ),
            "--output",
            "output",
            "--quality-metrics",
            "WER",
            "--visualize",
        ]
        cli.main()

        visual_folder_path = os.path.join("output", "visual")
        source_path = os.path.join(
            root_path, "examples", "speech_to_text", "source.txt"
        )
        source_length = 0

        with open(source_path, "r") as f:
            source_length = len(f.readlines())
        # One visualization image is expected per evaluated instance.
        images = list(Path(visual_folder_path).glob("*.png"))
        assert len(images) == source_length
        shutil.rmtree("output")


def test_visualize_score_only(root_path=ROOT_PATH):
    """--score-only + --visualize on a hand-written instances.log/config.yaml:
    graphs must be rebuilt from the log alone, with no inference pass."""
    args_path = Path.joinpath(root_path, "examples", "speech_to_text")
    os.chdir(args_path)

    # Create sample instances.log and config.yaml in output directory
    output = Path("output")
    output.mkdir()
    os.chdir(output)
    with open("config.yaml", "w") as config:
        config.write("source_type: speech\n")
        config.write("target_type: speech")
    with open("instances.log", "w") as instances:
        # A single pre-recorded instance: 19 predicted words with their
        # per-word delays/elapsed times (ms) against a 6.85 s source.
        json.dump(
            {
                "index": 0,
                "prediction": "This is a synthesized audio file to test your simultaneous speech, to speak to speech, to speak translation system.",
                "delays": [
                    1500.0,
                    2000.0,
                    2500.0,
                    3000.0,
                    3500.0,
                    4000.0,
                    4500.0,
                    5000.0,
                    5500.0,
                    6000.0,
                    6500.0,
                    6849.886621315192,
                    6849.886621315192,
                    6849.886621315192,
                    6849.886621315192,
                    6849.886621315192,
                    6849.886621315192,
                    6849.886621315192,
                    6849.886621315192,
                ],
                "elapsed": [
                    1947.3278522491455,
                    2592.338800430298,
                    3256.8109035491943,
                    3900.0539779663086,
                    4561.986684799194,
                    5216.205835342407,
                    5874.6888637542725,
                    6526.906728744507,
                    7193.655729293823,
                    7852.792739868164,
                    8539.628744125366,
                    9043.279374916267,
                    9043.279374916267,
                    9043.279374916267,
                    9043.279374916267,
                    9043.279374916267,
                    9043.279374916267,
                    9043.279374916267,
                    9043.279374916267,
                ],
                "prediction_length": 19,
                "reference": "This is a synthesized audio file to test your simultaneous speech to text and to speech to speach translation system.",
                "source": [
                    "test.wav",
                    "samplerate: 22050 Hz",
                    "channels: 1",
                    "duration: 6.850 s",
                    "format: WAV (Microsoft) [WAV]",
                    "subtype: Signed 16 bit PCM [PCM_16]",
                ],
                "source_length": 6849.886621315192,
            },
            instances,
        )

    os.chdir(args_path)

    with tempfile.TemporaryDirectory() as tmpdirname:
        # NOTE(review): tmpdirname is unused here as well.
        cli.sys.argv[1:] = ["--score-only", "--output", "output", "--visualize"]
        cli.main()

        visual_folder_path = os.path.join("output", "visual")
        source_path = os.path.join(
            root_path, "examples", "speech_to_text", "source.txt"
        )
        source_length = 0

        with open(source_path, "r") as f:
            source_length = len(f.readlines())
        images = list(Path(visual_folder_path).glob("*.png"))
        assert len(images) == source_length
        shutil.rmtree("output")


# ---- simuleval/utils/__init__.py ----

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .agent import build_system_from_dir, EVALUATION_SYSTEM_LIST  # noqa F401


def entrypoint(klass):
    """Class decorator registering ``klass`` as an evaluable system."""
    EVALUATION_SYSTEM_LIST.append(klass)
    return klass


# ---- simuleval/utils/agent.py ----

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import logging
import os
import sys
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Optional, Tuple, Union

import yaml
from simuleval import options
from simuleval.agents import GenericAgent
from simuleval.utils.arguments import check_argument, cli_argument_list

# Registry of agent classes selected for evaluation; populated by the
# @entrypoint decorator and by the lookups below.
EVALUATION_SYSTEM_LIST = []

logger = logging.getLogger("simuleval.utils.agent")


def import_file(file_path):
    """Execute a python file under the module name "agents" so the classes it
    defines (and any @entrypoint registrations it performs) become available."""
    spec = importlib.util.spec_from_file_location("agents", file_path)
    agent_modules = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(agent_modules)


def get_agent_class(config_dict: Optional[dict] = None) -> GenericAgent:
    """Resolve the agent class to evaluate from CLI/config arguments.

    Candidates are collected from --agent-class, --system-dir, and any
    @entrypoint registrations triggered by importing --agent.

    Raises:
        RuntimeError: if no candidate was registered, or several were and
            --agent-class does not disambiguate them.
    """
    class_name = check_argument("agent_class", config_dict)

    if class_name is not None:
        # When --agent is also given, the class is expected to come from that
        # file's import below rather than a dotted-path import here.
        if not check_argument("agent"):
            EVALUATION_SYSTEM_LIST.append(get_agent_class_from_string(class_name))

    system_dir = check_argument("system_dir")
    config_name = check_argument("system_config")

    if system_dir is not None:
        EVALUATION_SYSTEM_LIST.append(get_agent_class_from_dir(system_dir, config_name))

    agent_file = check_argument("agent")
    if agent_file is not None:
        import_file(agent_file)

    if len(EVALUATION_SYSTEM_LIST) == 0:
        raise RuntimeError(
            "Please use @entrypoint decorator to indicate the system you want to evaluate."
        )
    if len(EVALUATION_SYSTEM_LIST) > 1:
        if class_name is None:
            raise RuntimeError(
                "--agent-class must be specified if more than one system."
            )
        # if both --agent-class and --agent:
        for system in EVALUATION_SYSTEM_LIST:
            if f"{system.__module__}.{system.__name__}" == class_name:
                return system
        raise RuntimeError(
            f"--agent-class {class_name} not found in system list: {EVALUATION_SYSTEM_LIST}"
        )
    return EVALUATION_SYSTEM_LIST[0]


def get_system_config(path: Union[Path, str], config_name) -> dict:
    """Load ``path / config_name`` as YAML; exit the process on parse errors."""
    path = Path(path)
    with open(path / config_name, "r") as f:
        try:
            config_dict = yaml.safe_load(f)
        except yaml.YAMLError as exc:
            logging.error(f"Failed to load configs from {path / config_name}.")
            logging.error(exc)
            sys.exit(1)
    return config_dict


def get_agent_class_from_string(class_name: str) -> GenericAgent:
    """Import ``pkg.module.ClassName`` and return the class object."""
    try:
        agent_module = importlib.import_module(".".join(class_name.split(".")[:-1]))
        agent_class = getattr(agent_module, class_name.split(".")[-1])
    except Exception as e:
        logger.error(f"Not able to load {class_name}. Try setting --user-dir?")
        raise e
    return agent_class


def get_agent_class_from_dir(
    path: Union[Path, str], config_name: str = "main.yaml"
) -> GenericAgent:
    """Read the system config in ``path`` and resolve its ``agent_class``."""
    config_dict = get_system_config(path, config_name)
    assert "agent_class" in config_dict
    class_name = config_dict["agent_class"]
    return get_agent_class_from_string(class_name)


def build_system_from_dir(
    path: Union[Path, str],
    config_name: str = "main.yaml",
    overwrite_config_dict: Optional[dict] = None,
) -> GenericAgent:
    """Instantiate a system from a directory containing a config yaml.

    Args:
        path: system directory (also temporarily made the cwd while the
            agent is constructed, so relative paths in the config resolve).
        config_name: name of the yaml file inside ``path``.
        overwrite_config_dict: entries overriding the loaded config.
    """
    path = Path(path)
    config_dict = get_system_config(path, config_name)
    if overwrite_config_dict is not None:
        # Fix: iterating a dict directly yields only keys; ``.items()`` is
        # required to unpack (key, value) pairs (the old code raised at runtime).
        for key, value in overwrite_config_dict.items():
            config_dict[key] = value
    agent_class = get_agent_class_from_dir(path, config_name)

    parser = options.general_parser()
    agent_class.add_args(parser)
    args, _ = parser.parse_known_args(cli_argument_list(config_dict))
    sys.path.append(path.as_posix())

    # Build the agent with the system dir as cwd, then restore.
    cur_dir = os.getcwd()
    os.chdir(path.as_posix())
    system = agent_class.from_args(args)
    os.chdir(cur_dir)
    return system


def build_system_args(
    config_dict: Optional[dict] = None,
    parser: Optional[ArgumentParser] = None,
) -> Tuple[GenericAgent, Namespace]:
    """Build the evaluation system plus the fully parsed argument namespace.

    Returns:
        (system, args) where ``args`` has been augmented with the system's
        source/target types.
    """
    parser = options.general_parser(config_dict, parser)
    cli_arguments = cli_argument_list(config_dict)
    options.add_evaluator_args(parser)
    options.add_scorer_args(parser, cli_arguments)
    options.add_slurm_args(parser)
    options.add_dataloader_args(parser, cli_arguments)

    if check_argument("system_dir"):
        system = build_system_from_dir(
            check_argument("system_dir"), check_argument("system_config"), config_dict
        )
    else:
        system_class = get_agent_class(config_dict)
        system_class.add_args(parser)
        add_command_helper_arg(parser)
        args, _ = parser.parse_known_args(cli_argument_list(config_dict))
        system = system_class.from_args(args)

    # Strict parse now that every argument group has been registered.
    args = parser.parse_args(cli_argument_list(config_dict))

    # --dtype wins; the legacy --fp16 flag is honored otherwise.
    dtype = args.dtype if args.dtype else "fp16" if args.fp16 else "fp32"
    logger.info(f"System will run on device: {args.device}. dtype: {dtype}")
    system.to(args.device, fp16=(dtype == "fp16"))

    args.source_type = system.source_type
    args.target_type = system.target_type
    return system, args


def add_command_helper_arg(parser: ArgumentParser) -> None:
    """Register -h/--help manually (the parser is built with add_help=False)."""
    parser.add_argument(
        "-h",
        "--help",
        action="help",
        default="==SUPPRESS==",
        help=("show this help message and exit"),
    )


# ---- simuleval/utils/arguments.py ----

import sys
from typing import Optional
from simuleval import options


def cli_argument_list(config_dict: Optional[dict]):
    """Merge ``config_dict`` into the process argv tail as CLI tokens.

    CLI-provided flags win over config entries. Booleans become bare flags.
    NOTE: values are whitespace-split, so a multi-token value (e.g. "AL LAAL")
    expands into several argv tokens; values containing spaces cannot be
    passed intact through this path.
    """
    if config_dict is None:
        return sys.argv[1:]
    else:
        string = ""
        for key, value in config_dict.items():
            if f"--{key.replace('_', '-')}" in sys.argv:
                continue

            if type(value) is not bool:
                string += f" --{key.replace('_', '-')} {value}"
            else:
                string += f" --{key.replace('_', '-')}"
        return sys.argv[1:] + string.split()


def check_argument(name: str, config_dict: Optional[dict] = None):
    """Peek at a general-parser argument value without a full parse."""
    parser = options.general_parser()
    args, _ = parser.parse_known_args(cli_argument_list(config_dict))
    return getattr(args, name)


# ---- simuleval/utils/functional.py ----

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import closing
import socket


def find_free_port() -> int:
    """Return a TCP port number that was free at the time of the call.

    NOTE(review): inherently racy — the port is released when the socket
    closes, so another process may bind it before the caller does. Fine for
    test/bootstrap use, not a guarantee.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # Port 0 asks the OS to pick any free ephemeral port.
        s.bind(("", 0))
        # Allow quick rebinding of the discovered port after this probe
        # socket is closed (avoids TIME_WAIT interference).
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]


# ---- simuleval/utils/slurm.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os
import re
import subprocess
import sys
from argparse import ArgumentParser
from typing import Dict, List, Optional  # NOTE(review): List appears unused here

from simuleval import options
from simuleval.utils.agent import get_agent_class
from simuleval.utils.arguments import cli_argument_list

logger = logging.getLogger("simuleval.slurm")


def mkdir_output_dir(path: str) -> bool:
    """Create ``path`` (with parents) if needed.

    Returns True on success. On any failure the error is logged and False is
    returned instead of raising, so callers can decide how to proceed.
    """
    try:
        os.makedirs(path, exist_ok=True)
        return True
    except BaseException as be:
        logger.error(f"Failed to write results to {path}.")
        logger.error(be)
        logger.error("Skip writing predictions.")
        return False


def submit_slurm_job(
    config_dict: Optional[Dict] = None, parser: Optional[ArgumentParser] = None
) -> None:
    """Generate an sbatch script from the current CLI invocation and submit it.

    List-valued entries in ``config_dict`` are expanded into a parameter sweep,
    submitted as a SLURM job array (one array task per configuration). The
    script is written to ``<output>/script.sh`` and submitted via ``sbatch``.

    NOTE(review): the signature allows ``config_dict=None`` but the code below
    immediately calls ``config_dict.items()`` — passing None raises
    AttributeError. Confirm callers always supply a dict.
    """
    if config_dict is not None and "slurm" in config_dict:
        raise RuntimeError("--slurm is only available as a CLI argument")

    # One sub-list per swept key: [[key, v1], [key, v2], ...].
    sweep_options = [
        [[key, v] for v in value]
        for key, value in config_dict.items()
        if isinstance(value, list)
    ]
    sweep_config_dict_list = []
    if len(sweep_options) > 0:
        # Cartesian product over all swept keys -> one config dict per array task.
        for option_list in itertools.product(*sweep_options):
            sweep_config_dict_list.append({k: v for k, v in option_list})

    # Remove swept keys from the base config; they are re-injected per array
    # task through the JobArrayConfigs bash array below.
    for x in sweep_options:
        if x[0][0] in config_dict:
            del config_dict[x[0][0]]

    cli_arguments = cli_argument_list(config_dict)
    parser = options.general_parser(config_dict, parser)
    options.add_evaluator_args(parser)
    options.add_scorer_args(parser, cli_arguments)
    options.add_slurm_args(parser)
    options.add_dataloader_args(parser, cli_arguments)
    system_class = get_agent_class(config_dict)
    system_class.add_args(parser)
    args = parser.parse_args(cli_argument_list(config_dict))
    args.output = os.path.abspath(args.output)
    assert mkdir_output_dir(args.output)

    # Fall back to the running script itself as the agent file.
    if args.agent is None:
        args.agent = sys.argv[0]

    # Snapshot the agent file so the queued job is immune to later edits.
    # NOTE(review): shell-string copy; breaks on paths with spaces — a
    # shutil.copy would be safer. Left as-is (doc-only change).
    os.system(f"cp {args.agent} {args.output}/agent.py")
    # Re-quote the original command line: bare numbers and --flags pass
    # through, everything else gets double quotes for the generated script.
    _args = [sys.argv[0]]
    for arg in sys.argv[1:]:
        if str(arg).isdigit() or str(arg).startswith("--"):
            _args.append(arg)
        else:
            _args.append(f'"{arg}"')
    command = " ".join(_args).strip()
    # Strip all --slurm* options (and their values) — the job itself must not
    # try to resubmit.
    command = re.sub(r"(--slurm\S*(\s+[^-]\S+)*)", "", command).strip()
    if subprocess.check_output(["which", "simuleval"]).decode().strip() in command:
        # Invoked via the `simuleval` entry point: repoint --agent at the
        # snapshot copy.
        command = re.sub(
            r"--agent\s+\S+", f"--agent {args.output}/agent.py", command
        ).strip()
    else:
        # Invoked as `python some_agent.py ...`: replace the .py path itself.
        # Attention: not fully tested!
        command = re.sub(
            r"[^\"'\s]+\.py", f"{os.path.abspath(args.output)}/agent.py", command
        ).strip()

    sweep_command = ""
    sbatch_job_array_head = ""
    job_array_configs = ""

    if len(sweep_config_dict_list) > 0:
        # Bash associative array mapping array-task id -> extra CLI options.
        job_array_configs = "declare -A JobArrayConfigs\n"
        for i, sub_config_dict in enumerate(sweep_config_dict_list):
            sub_config_string = " ".join(
                [f"--{k.replace('_', '-')} {v}" for k, v in sub_config_dict.items()]
            )
            job_array_configs += f'JobArrayConfigs[{i}]="{sub_config_string}"\n'

        # Second array: a dotted value string used as the per-task results
        # subdirectory name.
        job_array_configs += "\ndeclare -A JobArrayString\n"
        for i, sub_config_dict in enumerate(sweep_config_dict_list):
            sub_config_string = ".".join([str(v) for k, v in sub_config_dict.items()])
            job_array_configs += f'JobArrayString[{i}]="{sub_config_string}"\n'

        sweep_command = "${JobArrayConfigs[$SLURM_ARRAY_TASK_ID]}"
        sbatch_job_array_head = f"#SBATCH --array=0-{len(sweep_config_dict_list) - 1}"
        output_dir = (
            f"{args.output}" + "/results/${JobArrayString[$SLURM_ARRAY_TASK_ID]}"
        )
        # %A = array master job id, %a = array task id.
        log_path = f"{args.output}/logs/slurm-%A_%a.log"

    else:
        output_dir = args.output
        log_path = f"{args.output}/slurm-%j.log"

    # Point --output at the (possibly per-task) results directory.
    if "--output" in command:
        command = re.sub(r"--output\s+\S+", f"--output {output_dir}", command).strip()
    else:
        command += f" --output {output_dir}"

    # Cosmetic: one option per line in the generated script.
    command = command.replace("--", "\\\n\t--")
    script = f"""#!/bin/bash
#SBATCH --time={args.slurm_time}
#SBATCH --partition={args.slurm_partition}
#SBATCH --nodes=1
#SBATCH --gpus-per-node=1
#SBATCH --ntasks-per-node=8
#SBATCH --output="{log_path}"
#SBATCH --job-name="{args.slurm_job_name}"
{sbatch_job_array_head}

{job_array_configs}

mkdir -p {args.output}/logs
cd {os.path.abspath(args.output)}

GPU_ID=$SLURM_LOCALID

# Change to local a gpu id for debugging, e.g.
# GPU_ID=0


CUDA_VISIBLE_DEVICES=$GPU_ID {command} {sweep_command}
"""
    script_file = os.path.join(args.output, "script.sh")
    with open(script_file, "w") as f:
        f.writelines(script)

    # Submit and surface sbatch's stdout/stderr through the logger.
    process = subprocess.Popen(
        ["sbatch", script_file],
        stderr=subprocess.PIPE,
        stdout=subprocess.PIPE,
    )
    stdout, stderr = process.communicate()
    logger.info("Using slurm.")
    logger.info(f"sbatch stdout: {stdout.decode('utf-8').strip()}")
    stderr = stderr.decode("utf-8").strip()
    if len(stderr) > 0:
        logger.info(f"sbatch stderr: {stderr.strip()}")


# ---- simuleval/utils/visualize.py ----
import soundfile
import matplotlib.patches as patches
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path


class Visualize:
    """Render one evaluation instance as a delay/word staircase plus waveform."""

    def __init__(self, data, index, path):
        # data: one instance record (prediction, delays, source, ...) —
        # presumably a parsed instances.log entry; verify against caller.
        self.data = data
        # index: instance number, used in the output image filename.
        self.index = index
        # path: output directory of the evaluation run.
        self.path = path

    def make_graph(self):
        """Build and save the two-panel figure for this instance."""
        data = self.data

        ##### 1st PLOT #####
        # Organize all data
        words = np.arange(1, data.get("prediction_length", 0) + 1)
        # NOTE(review): the "N/A delays data" string default would crash the
        # comprehension below — effectively delays are required here.
        delays = [x / 1000 for x in data.get("delays", "N/A delays data")]
        data_points = [
            (delays[i], words[i]) for i in range(len(delays))
        ]  # /1000 to make it from ms to s
        data_points.insert(0, (0, 0))
        prediction_word_list = data.get("prediction", "N/A predition data").split(" ")

        # Insert points to create the staircase effect
        # NOTE(review): the else-break stops at the first pair with equal
        # delays — confirm that truncation is intended.
        i = 0
        while i < len(data_points) - 1:
            if data_points[i][0] != data_points[i + 1][0]:
                data_points.insert(i + 1, (data_points[i + 1][0], data_points[i][1]))
                i += 2
            else:
                break

        # Create a new figure
        fig = plt.figure(figsize=(10, 16), dpi=300)
        # Two stacked axes: staircase on top (ax1), waveform below (ax2).
        ax1, ax2 = fig.subplots(2, gridspec_kw={"hspace": 0.1})

        # Plot the points and connect them with lines
        x, y = zip(*data_points)
        ax1.plot(list(x), list(y), marker="o", color="black")

        # Draw arrows between points
        # Horizontal segments (same word count) are "read" steps in blue;
        # vertical segments are "write" steps in red.
        red_arrow_indices = []
        for i in range(len(data_points) - 1):
            color = "blue" if y[i] == y[i + 1] else "red"
            if color == "red":
                red_arrow_indices.append(i)
            ax1.annotate(
                "",
                xy=(x[i + 1], y[i + 1]),
                xytext=(x[i], y[i]),
                arrowprops=dict(
                    facecolor=color, edgecolor=color, shrink=0, width=1, headwidth=7
                ),
            )

        # Annotate blue arrows with words from prediction_word_list
        # NOTE(review): despite the comment above, this labels the RED (write)
        # arrows — it iterates red_arrow_indices.
        for idx, word in zip(red_arrow_indices, prediction_word_list):
            ax1.text(
                x[idx] + 0.1,  # x position of the annotation
                (y[idx] + y[idx + 1]) / 2,  # y position of the annotation
                word,  # text to annotate
                va="center",  # vertical alignment
                fontsize=9,  # font size
                color="black",  # text color
            )

        # Custom legend on the right side
        legend_x = 0.85  # X-coordinate for the legend in axes coordinates
        legend_y = 0.9  # Starting Y-coordinate for the legend

        # Red downward arrow
        red_arrow = patches.FancyArrow(
            legend_x + 0.03,
            legend_y,
            0,
            -0.04,
            width=0.01,
            head_width=0.03,
            head_length=0.03,
            color="red",
            transform=ax1.transAxes,
        )
        ax1.add_patch(red_arrow)
        ax1.text(
            legend_x + 0.05,
            legend_y - 0.03,
            "Write",
            color="red",
            transform=ax1.transAxes,
        )

        # Blue rightward arrow
        blue_arrow = patches.FancyArrow(
            legend_x - 0.05,
            legend_y,
            0.05,
            0,
            width=0.01,
            head_width=0.03,
            head_length=0.03,
            color="blue",
            transform=ax1.transAxes,
        )
        ax1.add_patch(blue_arrow)
        ax1.text(
            legend_x - 0.05,
            legend_y + 0.02,
            "Read",
            color="blue",
            transform=ax1.transAxes,
        )

        # Flip the y-axis
        ax1.invert_yaxis()
        # Set the limits of the axes to start at (0,0)
        ax1.set_xlim(left=0)
        ax1.set_xlim(right=delays[-1])
        ax1.set_ylim(top=0)

        # Set the grid
        ax1.grid(True)

        # Add labels and title
        ax1.set_xlabel("Source / delays (s)")
        ax1.set_ylabel("Target / number of words")
        plt.suptitle("SimulEval", fontsize=20)
        reference = data.get("reference", "N/A reference data")
        subtitle = f'Reference: "{reference}"'
        ax1.set_title(subtitle, fontsize=10, pad=20)
        # NOTE(review): tick computation reads delays[1] - delays[0]; it
        # assumes at least two distinct delays — confirm upstream guarantees.
        ax1.set_xticks(
            [
                i * (delays[1] - delays[0])
                for i in range(0, round(delays[0] / (delays[1] - delays[0])))
            ]
            + delays
        )
        ax1.set_yticks(words)

        # Additional text at the bottom
        # (label, value) pairs rendered in two columns above the plot.
        additional_text = [
            ("Source", data.get("source", "N/A source data")[0]),
            (
                "Prediction length",
                data.get("prediction_length", "N/A prediction length"),
            ),
            (
                "Sample rate",
                data.get("source", "N/A source data")[1].split(":")[1].strip(),
            ),
            (
                "Duration",
                data.get("source", "N/A source data")[3].split(":")[1].strip(),
            ),
        ]
        start_y = 1.2
        line_space = 0.05
        for i in range(len(additional_text)):
            if i % 2 == 0:
                # Left column (even entries).
                ax1.text(
                    0,
                    start_y,
                    additional_text[i][0],
                    ha="left",
                    va="top",
                    transform=ax1.transAxes,
                    fontsize=10,
                    bbox=dict(facecolor="white", alpha=0.5),
                )
                ax1.text(
                    0.3,
                    start_y,
                    additional_text[i][1],
                    ha="right",
                    va="top",
                    transform=ax1.transAxes,
                    fontsize=10,
                )
            else:
                # Right column (odd entries).
                ax1.text(
                    0.65,
                    start_y,
                    additional_text[i][0],
                    ha="left",
                    va="top",
                    transform=ax1.transAxes,
                    fontsize=10,
                    bbox=dict(facecolor="white", alpha=0.5),
                )
                ax1.text(
                    0.85,
                    start_y,
                    additional_text[i][1],
                    ha="right",
                    va="top",
                    transform=ax1.transAxes,
                    fontsize=10,
                )
            start_y -= line_space

        ##### 2nd PLOT #####
        # Load the audio file data
        # Audio path is resolved relative to the parent of the output dir —
        # presumably where the source .wav list lives; verify against caller.
        example_path = Path(self.path).parent
        audio_path = example_path / data["source"][0]
        audio_data, rate = soundfile.read(audio_path)
        length = audio_data.shape[0] / rate
        time = np.linspace(0, length, audio_data.shape[0])

        # Make the figure
        # Share the x-range and ticks with ax1 so the panels line up.
        ax2.set_xlim(left=0)
        ax2.set_xlim(right=delays[-1])
        ax2.set_xticks(
            [
                i * (delays[1] - delays[0])
                for i in range(0, round(delays[0] / (delays[1] - delays[0])))
            ]
            + delays
        )

        # Plot the waveform
        ax2.plot(time, audio_data)
        ax2.set_xlabel("Delays (s)")
        ax2.set_ylabel("Amplitude")

        # Set the grid
        ax2.grid(True)

        # Write to output/visual/graph_.png
        img_path = Path(self.path) / "visual"
        img_path.mkdir(exist_ok=True, parents=True)
        ## here we use / instead of os.path.join because img_path is a Path object which cannot be joined with this function
        # NOTE(review): the figure is never closed after saving; repeated
        # calls accumulate open figures (matplotlib warns past 20).
        plt.savefig(img_path / ("graph" + str(self.index)))