├── .github
└── workflows
│ ├── main.yml
│ └── pypi-publish.yml
├── .gitignore
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
├── Makefile
├── conf.py
├── index.rst
├── installation.rst
├── make.bat
├── quick_start.rst
├── requirements.txt
├── tutorials
│ ├── remote_evaluation.rst
│ ├── speech_to_speech.rst
│ └── speech_to_text.rst
└── user_guide
│ ├── agent.rst
│ ├── dataloader.rst
│ ├── evaluator.rst
│ └── introduction.rst
├── examples
├── __init__.py
├── demo
│ └── silero_vad.py
├── quick_start
│ ├── Dockerfile
│ ├── agent_pipeline.py
│ ├── agent_with_configs.py
│ ├── agent_with_new_metrics.py
│ ├── dict.txt
│ ├── first_agent.py
│ ├── readme.md
│ ├── source.txt
│ ├── spm_detokenizer_agent.py
│ ├── spm_source.txt
│ ├── spm_target.txt
│ └── target.txt
├── speech_to_speech
│ ├── english_alternate_agent.py
│ ├── english_counter_agent.py
│ ├── eval.sh
│ ├── readme.md
│ ├── reference
│ │ ├── de.txt
│ │ ├── en.txt
│ │ ├── ja.txt
│ │ ├── tgt_lang.txt
│ │ └── zh.txt
│ ├── source.txt
│ └── test.wav
├── speech_to_speech_demo
│ ├── english_counter_pipeline.py
│ └── readme.md
├── speech_to_speech_text
│ └── tree_agent_pipeline.py
├── speech_to_text
│ ├── Dockerfile
│ ├── counter_in_tgt_lang_agent.py
│ ├── english_counter_agent.py
│ ├── eval.sh
│ ├── readme.md
│ ├── reference
│ │ ├── en.txt
│ │ ├── tgt_lang.txt
│ │ └── transcript.txt
│ ├── source.txt
│ ├── test.wav
│ └── whisper_waitk.py
└── speech_to_text_demo
│ ├── counter_in_tgt_lang_pipeline.py
│ ├── english_counter_pipeline.py
│ └── readme.md
├── setup.cfg
├── setup.py
└── simuleval
├── __init__.py
├── agents
├── __init__.py
├── actions.py
├── agent.py
├── pipeline.py
├── service.py
└── states.py
├── analysis
└── curve.py
├── cli.py
├── data
├── __init__.py
├── dataloader
│ ├── __init__.py
│ ├── dataloader.py
│ ├── s2t_dataloader.py
│ └── t2t_dataloader.py
└── segments.py
├── evaluator
├── __init__.py
├── evaluator.py
├── instance.py
├── remote.py
└── scorers
│ ├── __init__.py
│ ├── latency_scorer.py
│ └── quality_scorer.py
├── options.py
├── test
├── test_agent.py
├── test_agent_pipeline.py
├── test_evaluator.py
├── test_remote_evaluation.py
├── test_s2s.py
├── test_s2t.py
└── test_visualize.py
└── utils
├── __init__.py
├── agent.py
├── arguments.py
├── functional.py
├── slurm.py
└── visualize.py
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: build
5 |
6 | on:
7 | push:
8 | branches:
9 | - main
10 | pull_request:
11 |
12 | jobs:
13 | build:
14 | runs-on: ubuntu-latest
15 | strategy:
16 | matrix:
17 | python-version: ["3.8"]
18 |
19 | steps:
20 | - uses: actions/checkout@v2
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v2
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Install dependencies
26 | run: |
27 | sudo apt-get update
28 | sudo apt-get install libsndfile1
29 | sudo apt-get install portaudio19-dev
30 | python -m pip install --upgrade pip==24.0
31 | pip install flake8 pytest black
32 | pip install g2p-en
33 | pip install huggingface-hub
34 | pip install fairseq
35 | pip install sentencepiece
36 | pip install openai-whisper editdistance pyaudio silero-vad
37 | pip install -e .
38 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
39 | python -c "import nltk; nltk.download('averaged_perceptron_tagger_eng')"
40 | - name: Lint with black
41 | run: black --check --diff .
42 | - name: Lint with flake8
43 | run: |
44 | # stop the build if there are Python syntax errors or undefined names
45 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
46 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
47 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
48 | - name: Test with pytest
49 | run: |
50 | pytest simuleval/test/test_agent.py
51 | pytest simuleval/test/test_agent_pipeline.py
52 | pytest simuleval/test/test_evaluator.py
53 | pytest simuleval/test/test_remote_evaluation.py
54 | pytest simuleval/test/test_s2s.py
55 | pytest simuleval/test/test_visualize.py
56 |
--------------------------------------------------------------------------------
/.github/workflows/pypi-publish.yml:
--------------------------------------------------------------------------------
1 | name: PyPI Publish
2 |
3 | on:
4 | push:
5 | tags:
6 | - v*
7 | jobs:
8 | pypi_publish:
9 | name: Upload release to PyPI
10 | runs-on: ubuntu-latest
11 | environment:
12 | name: pypi
13 | url: https://pypi.org/p/simuleval
14 | permissions:
15 | id-token: write
16 |
17 | steps:
18 | - name: Checkout code
19 | uses: actions/checkout@v2
20 |
21 | - name: Set up Python
22 | uses: actions/setup-python@v2
23 | with:
24 | python-version: 3.x
25 |
26 | - name: Install dependencies
27 | run: pip install --upgrade pip setuptools wheel
28 |
29 | - name: Build package
30 | run: python setup.py sdist bdist_wheel
31 |
32 | - name: Publish package distributions to PyPI
33 | uses: pypa/gh-action-pypi-publish@release/v1
34 | # NOTE(review): removed invalid input — `args` is not an input of
35 | # pypa/gh-action-pypi-publish, and `--use-feature=fast-deploy` is not a supported option.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 | .vscode
140 |
141 | # Mac files
142 | .DS_Store
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | 1.0.0 (September 25, 2020)
2 |
3 | * Initial release.
4 |
5 | 1.0.1 (February 8, 2021)
6 |
7 | * Change CLI command
8 | * Change `simuleval-server` to `simuleval --server-only`
9 | * Change `simuleval-client` to `simuleval --client-only`
10 | * Fix some typos
11 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Open Source Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
6 |
7 | ## Our Standards
8 |
9 | Examples of behavior that contributes to creating a positive environment include:
10 |
11 | * Using welcoming and inclusive language
12 | * Being respectful of differing viewpoints and experiences
13 | * Gracefully accepting constructive criticism
14 | * Focusing on what is best for the community
15 | * Showing empathy towards other community members
16 | Examples of unacceptable behavior by participants include:
17 |
18 | * The use of sexualized language or imagery and unwelcome sexual attention or advances
19 | * Trolling, insulting/derogatory comments, and personal or political attacks
20 | * Public or private harassment
21 | * Publishing others’ private information, such as a physical or electronic address, without explicit permission
22 | * Other conduct which could reasonably be considered inappropriate in a professional setting
23 |
24 | ## Our Responsibilities
25 |
26 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
27 |
28 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
29 |
30 | ## Scope
31 |
32 | This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.
33 |
34 | ## Enforcement
35 |
36 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at opensource-conduct@fb.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.
37 |
38 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership.
39 |
40 | ## Attribution
41 |
42 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
43 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
44 |
45 | [homepage]: https://www.contributor-covenant.org
46 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Facebook AI SimulEval
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 |
5 | ## Pull Requests
6 | We actively welcome your pull requests.
7 |
8 | 1. Fork the repo and create your branch from `main`.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. If you've changed APIs, update the documentation.
11 | 4. Ensure the test suite passes.
12 | 5. Make sure your code lints.
13 |
14 | ## Issues
15 | We use GitHub issues to track public bugs. Please ensure your description is
16 | clear and has sufficient instructions to be able to reproduce the issue.
17 |
18 | ## License
19 | By contributing to Facebook AI SimulEval, you agree that your contributions will
20 | be licensed under the LICENSE file in the root directory of this source tree.
21 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SimulEval
2 | [](https://github.com/facebookresearch/SimulEval/actions)
3 |
4 | SimulEval is a general evaluation framework for simultaneous translation on text and speech. Full documentation can be found [here](https://simuleval.readthedocs.io/en/v1.1.0/).
5 |
6 | ## Installation
7 | ```
8 | git clone https://github.com/facebookresearch/SimulEval.git
9 | cd SimulEval
10 | pip install -e .
11 | ```
12 |
13 | ## Quick Start
14 | Following is the evaluation of a [dummy agent](examples/quick_start) which operates wait-k (k = 3) policy and generates random words until the length of the generated words is the same as the number of all the source words.
15 | ```shell
16 | cd examples/quick_start
17 | simuleval --source source.txt --target target.txt --agent first_agent.py
18 | ```
19 |
20 | # License
21 |
22 | SimulEval is licensed under Creative Commons BY-SA 4.0.
23 |
24 | # Citation
25 |
26 | Please cite as:
27 |
28 | ```bibtex
29 | @inproceedings{simuleval2020,
30 | title = {Simuleval: An evaluation toolkit for simultaneous translation},
31 | author = {Xutai Ma and Mohammad Javad Dousti and Changhan Wang and Jiatao Gu and Juan Pino},
32 | booktitle = {Proceedings of the EMNLP},
33 | year = {2020},
34 | }
35 | ```
36 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Configuration file for the Sphinx documentation builder.
8 | #
9 | # For the full list of built-in configuration values, see the documentation:
10 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
11 |
12 | # -- Project information -----------------------------------------------------
13 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
14 |
15 | project = "SimulEval"
16 | copyright = "Facebook AI Research (FAIR)"
17 | author = "Facebook AI Research (FAIR)"
18 | release = "1.1.0"
19 |
20 | # -- General configuration ---------------------------------------------------
21 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
22 |
23 | extensions = [
24 | "sphinx_rtd_theme",
25 | "sphinx.ext.autodoc",
26 | "sphinx.ext.intersphinx",
27 | "sphinx.ext.viewcode",
28 | "sphinx.ext.napoleon",
29 | "sphinxarg.ext",
30 | ]
31 |
32 | # templates_path = ['_templates']
33 | exclude_patterns = []
34 |
35 |
36 | # -- Options for HTML output -------------------------------------------------
37 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
38 |
39 | html_theme = "sphinx_rtd_theme"
40 | # html_static_path = ['_static']
41 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | SimulEval documentation
2 | =======================
3 |
4 | SimulEval is a general evaluation framework for simultaneous translation.
5 |
6 | .. toctree::
7 | :maxdepth: 2
8 | :glob:
9 | :caption: Get started
10 |
11 | installation
12 | quick_start
13 |
14 | .. toctree::
15 | :maxdepth: 2
16 | :caption: User's guide
17 |
18 | user_guide/introduction
19 | user_guide/agent
20 | user_guide/evaluator
21 | user_guide/dataloader
22 |
23 | .. toctree::
24 | :maxdepth: 2
25 | :caption: Tutorials
26 |
27 | tutorials/remote_evaluation
28 | tutorials/speech_to_text
29 | tutorials/speech_to_speech
30 |
31 |
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | From pip
5 | --------
6 |
7 | .. code-block:: bash
8 |
9 | pip install simuleval
10 |
11 |
12 | From source
13 | -----------
14 |
15 | .. code-block:: bash
16 |
17 | git clone https://github.com/facebookresearch/SimulEval.git
18 | cd SimulEval
19 | pip install -e .
20 |
21 |
22 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | :: Copyright (c) Facebook, Inc. and its affiliates.
2 | :: All rights reserved.
3 | ::
4 | :: This source code is licensed under the license found in the
5 | :: LICENSE file in the root directory of this source tree.
6 |
7 | @ECHO OFF
8 |
9 | pushd %~dp0
10 |
11 | REM Command file for Sphinx documentation
12 |
13 | if "%SPHINXBUILD%" == "" (
14 | set SPHINXBUILD=sphinx-build
15 | )
16 | set SOURCEDIR=source
17 | set BUILDDIR=build
18 |
19 | %SPHINXBUILD% >NUL 2>NUL
20 | if errorlevel 9009 (
21 | echo.
22 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
23 | echo.installed, then set the SPHINXBUILD environment variable to point
24 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
25 | echo.may add the Sphinx directory to PATH.
26 | echo.
27 | echo.If you don't have Sphinx installed, grab it from
28 | echo.https://www.sphinx-doc.org/
29 | exit /b 1
30 | )
31 |
32 | if "%1" == "" goto help
33 |
34 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
35 | goto end
36 |
37 | :help
38 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
39 |
40 | :end
41 | popd
42 |
--------------------------------------------------------------------------------
/docs/quick_start.rst:
--------------------------------------------------------------------------------
1 | .. _first-agent:
2 |
3 | Quick Start
4 | ===========
5 |
6 | This section will introduce a minimal example on how to use SimulEval for simultaneous translation evaluation.
7 | The code in the example can be found in :code:`examples/quick_start`.
8 |
9 | The agent in SimulEval is core for simultaneous evaluation.
10 | It's a carrier of user's simultaneous system.
11 | The user has to implement the agent based on their system for evaluation.
12 | The example simultaneous system is a dummy wait-k agent, which
13 |
14 | - Runs `wait-k <https://aclanthology.org/P19-1289/>`_ policy.
15 | - Generates random characters the policy decide to write.
16 | - Stops the generation k predictions after source input. For simplicity, we just set :code:`k=3` in this example.
17 |
18 | The implementation of this agent is shown as follow.
19 |
20 | .. literalinclude:: ../examples/quick_start/first_agent.py
21 | :language: python
22 | :lines: 6-
23 |
24 | There are two essential components for an agent:
25 |
26 | - :code:`states`: The attribute keeps track of the source and target information.
27 | - :code:`policy`: The method makes decisions when there is a new source segment.
28 |
29 | Once the agent is implemented and saved at :code:`first_agent.py`,
30 | run the following command for latency evaluation on:
31 |
32 | .. code-block:: bash
33 |
34 | simuleval --source source.txt --target target.txt --agent first_agent.py
35 |
36 | where :code:`--source` is the input file while :code:`--target` is the reference file.
37 |
38 | By default, the SimulEval will give the following output --- one quality and three latency metrics.
39 |
40 | .. code-block:: bash
41 |
42 | 2022-12-05 13:43:58 | INFO | simuleval.cli | Evaluate system: DummyWaitkTextAgent
43 | 2022-12-05 13:43:58 | INFO | simuleval.dataloader | Evaluating from text to text.
44 | 2022-12-05 13:43:58 | INFO | simuleval.sentence_level_evaluator | Results:
45 | BLEU AL AP DAL
46 | 1.541 3.0 0.688 3.0
47 |
48 | The average lagging is expected since we are running a wait-3 system where the source and target always have the same length.
49 | Notice that we have a very low yet random BLEU score. It's because we are randomly generating the output.
50 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx-argparse==0.4.0
2 | -e .
3 |
--------------------------------------------------------------------------------
/docs/tutorials/remote_evaluation.rst:
--------------------------------------------------------------------------------
1 | Remote Evaluation
2 | =================
3 |
4 | Stand Alone Agent
5 | -----------------
6 | The agent can run in stand alone mode,
7 | by using :code:`--standalone` option.
8 | SimulEval will kick off a server that hosts the agent.
9 | For instance, with the agent in :ref:`first-agent`,
10 |
11 | .. code-block:: bash
12 |
13 | > simuleval --standalone --remote-port 8888 --agent first_agent.py
14 | 2022-12-06 19:12:26 | INFO | simuleval.cli | Evaluate system: DummyWaitkTextAgent
15 | 2022-12-06 19:12:26 | INFO | simuleval.agent_server | Simultaneous Translation Server Started (process id 53902). Listening to port 8888
16 |
17 | |
18 | For custom speech to text transcription, you could also use the whisper agent in :ref:`speech-to-text`,
19 |
20 | .. code-block:: bash
21 |
22 | > simuleval --standalone --remote-port 8888 --agent whisper_waitk.py --waitk-lagging 3
23 | 2024-08-11 11:51:56 | INFO | simuleval.utils.agent | System will run on device: cpu. dtype: fp32
24 | 2024-08-11 11:51:56 | INFO | simuleval.agent_server | Simultaneous Translation Server Started (process id 38768). Listening to port 8888
25 |
26 | For detailed RESTful APIs, please see (TODO)
27 |
28 | Docker
29 | -----------------
30 | You can also use a docker image to run SimulEval.
31 | A minimal example of :code:`Dockerfile` is
32 |
33 | .. literalinclude:: ../../examples/quick_start/Dockerfile
34 | :language: docker
35 |
36 | Build and run the docker image:
37 |
38 | .. code-block:: bash
39 |
40 | cd examples/quick_start && docker build -t simuleval_agent .
41 | docker run -p 8888:8888 simuleval_agent:latest
42 |
43 | |
44 | The custom audio file speech to text :code:`Dockerfile` is
45 |
46 | .. literalinclude:: ../../examples/speech_to_text/Dockerfile
47 | :language: docker
48 |
49 | Build and run the docker image:
50 |
51 | .. code-block:: bash
52 |
53 | cd examples/speech_to_text && docker build -t simuleval-speech-to-text:1.0 .
54 | docker run -p 8888:8888 simuleval-speech-to-text:1.0
55 |
56 | Remote Evaluation
57 | ------------------
58 | If there is an agent server or docker image available,
59 | (let's say the one we just kicked off at localhost:8888).
60 | We can start a remote evaluator as follows. For simplicity we assume they are on the same machine.
61 |
62 | .. code-block:: bash
63 |
64 | simuleval --remote-eval --remote-port 8888 \
65 | --source source.txt --target target.txt \
66 | --source-type text --target-type text
67 |
68 | |
69 | For whisper agent's speech to text:
70 |
71 | .. code-block:: bash
72 |
73 | simuleval --remote-eval --remote-port 8888 \
74 | --source-segment-size 500 \
75 | --source source.txt --target reference/transcript.txt \
76 | --source-type speech --target-type text \
77 | --output output --quality-metrics WER
--------------------------------------------------------------------------------
/docs/tutorials/speech_to_speech.rst:
--------------------------------------------------------------------------------
1 | Speech-to-Speech
2 | ================
--------------------------------------------------------------------------------
/docs/tutorials/speech_to_text.rst:
--------------------------------------------------------------------------------
1 | .. _speech-to-text:
2 |
3 | Speech-to-Text
4 | ==============
5 |
6 | Whisper Agent
7 | -----------------
8 | Use whisper to evaluate custom audio for speech to text transcription.
9 | First, change directory to :code:`speech_to_text`:
10 |
11 | .. code-block:: bash
12 |
13 | cd examples/speech_to_text
14 |
15 | Then, run the example code:
16 |
17 | .. code-block:: bash
18 |
19 | simuleval \
20 | --agent whisper_waitk.py \
21 | --source-segment-size 500 \
22 | --waitk-lagging 3 \
23 | --source source.txt --target reference/transcript.txt \
24 | --output output --quality-metrics WER --visualize
25 |
26 | The optional :code:`--visualize` flag generates N graphs in the speech_to_text/output/visual directory, where N corresponds to the number of source audio lines provided.
27 |
28 | |
29 | In addition, it supports the :code:`--score-only` command, where it will read data from :code:`instances.log` without running inference, which saves time if you just want the scores.
30 |
31 | .. code-block:: bash
32 |
33 | simuleval --score-only --output output --visualize
--------------------------------------------------------------------------------
/docs/user_guide/agent.rst:
--------------------------------------------------------------------------------
1 | Agent
2 | =====
3 |
4 | To evaluate the simultaneous translation system,
5 | the users need to implement an agent class which operates the system logic.
6 | This section will introduce how to implement an agent.
7 |
8 | Source-Target Types
9 | -------------------
10 | First of all,
11 | we must declare the source and target types of the agent class.
12 | It can be done by inheriting from
13 |
14 | - One of the following four built-in agent types
15 |
16 | - :class:`simuleval.agents.TextToTextAgent`
17 | - :class:`simuleval.agents.SpeechToTextAgent`
18 | - :class:`simuleval.agents.TextToSpeechAgent`
19 | - :class:`simuleval.agents.SpeechToSpeechAgent`
20 |
21 | - Or :class:`simuleval.agents.GenericAgent`, with explicit declaration of :code:`source_type` and :code:`target_type`.
22 |
23 | The following two examples are equivalent.
24 |
25 | .. code-block:: python
26 |
27 | from simuleval.agents import GenericAgent
29 |
30 | class MySpeechToTextAgent(GenericAgent):
31 | source_type = "Speech"
32 | target_type = "Text"
33 | ....
34 |
35 | .. code-block:: python
36 |
37 | from simuleval.agents import SpeechToTextAgent
38 |
39 | class MySpeechToTextAgent(SpeechToTextAgent):
40 | ....
41 |
42 | .. _agent_policy:
43 |
44 | Policy
45 | ------
46 |
47 | The agent must have a :code:`policy` method which must return one of two actions, :code:`ReadAction` and :code:`WriteAction`.
48 | For example, an agent with a :code:`policy` method should look like this
49 |
50 | .. code-block:: python
51 |
52 | class MySpeechToTextAgent(SpeechToTextAgent):
53 | def policy(self):
54 | if do_we_need_more_input(self.states):
55 | return ReadAction()
56 | else:
57 | prediction = generate_a_token(self.states)
58 | finished = is_sentence_finished(self.states)
59 | return WriteAction(prediction, finished=finished)
60 |
61 |
62 | ..
63 | .. autoclass:: simuleval.agents.actions.WriteAction
64 |
65 | ..
66 | .. autoclass:: simuleval.agents.actions.ReadAction
67 |
68 | States
69 | ------------
70 | Each agent has the attribute :code:`states` to keep track of the progress of decoding.
71 | The :code:`states` attribute will be reset at the beginning of each sentence.
72 | SimulEval provides a built-in states class :class:`simuleval.agents.states.AgentStates`,
73 | which has some basic attributes such as source and target sequences.
74 | The users can also define customized states with :code:`Agent.build_states` method:
75 |
76 | .. code-block:: python
77 |
78 | from simuleval.agents.states import AgentStates
79 | from dataclasses import dataclass
80 |
81 | @dataclass
82 | class MyComplicatedStates(AgentStates):
83 | some_very_useful_variable: int
84 |
85 | def reset(self):
86 | super().reset()
87 | # also remember to reset the value
88 | self.some_very_useful_variable = 0
89 |
90 | class MySpeechToTextAgent(SpeechToTextAgent):
91 | def build_states(self):
92 | return MyComplicatedStates(0)
93 |
94 | def policy(self):
95 | some_very_useful_variable = self.states.some_very_useful_variable
96 | ...
97 | self.states.some_very_useful_variable = new_value
98 | ...
99 |
100 | ..
101 | .. autoclass:: simuleval.agents.states.AgentStates
102 | :members:
103 |
104 |
105 | Pipeline
106 | --------
107 | The simultaneous system can consist of several different components.
108 | For instance, a simultaneous speech-to-text translation can have a streaming automatic speech recognition system and simultaneous text-to-text translation system.
109 | SimulEval introduces the agent pipeline to support this function.
110 | The following is a minimal example.
111 | We concatenate two wait-k systems with different rates (:code:`k=2` and :code:`k=3`)
112 | Note that if there is more than one agent class defined,
113 | the :code:`@entrypoint` decorator has to be used to determine the entry point,
114 | or `--user-dir` and `--agent-class` must be specified. See `simuleval/test/first_agent.py`
115 | for an example.
116 |
117 | .. literalinclude:: ../../examples/quick_start/agent_pipeline.py
118 | :language: python
119 | :lines: 7-
120 |
121 | Customized Arguments
122 | -----------------------
123 |
124 | It is often the case that we need to pass some customized arguments for the system to configure different settings.
125 | The agent class has a built-in static method :code:`add_args` for this purpose.
126 | The following is an updated version of the dummy agent from :ref:`first-agent`.
127 |
128 | .. literalinclude:: ../../examples/quick_start/agent_with_configs.py
129 | :language: python
130 | :lines: 6-
131 |
132 | Then just simply pass the arguments through command line as follow.
133 |
134 | .. code-block:: bash
135 |
136 | simuleval \
137 | --source source.txt --target target.txt \ # data arguments
138 | --agent dummy_waitk_text_agent_v2.py \
139 | --waitk 3 --vocab data/dict.txt # agent arguments
140 |
141 | Load Agents from Python Class
142 | -----------------------------
143 |
144 | If you have the agent class in the python environment, for instance
145 |
146 | .. literalinclude:: ../../examples/quick_start/agent_with_configs.py
147 | :language: python
148 | :lines: 6-
149 |
150 | You can also start the evaluation with following command
151 |
152 | .. code-block:: bash
153 |
154 | simuleval \
--source source.txt --target target.txt \ # data arguments
156 | --agent-class DummyWaitkTextAgent \
157 | --waitk 3 --vocab data/dict.txt # agent arguments
158 |
159 |
160 | Load Agents from Directory
161 | --------------------------
162 |
163 | Agent can also be loaded from a directory, which will be referred to as system directory.
164 | The system directory should have everything required to start the agent. Again use the following agent as example
165 |
166 | .. literalinclude:: ../../examples/quick_start/agent_with_configs.py
167 | :language: python
168 | :lines: 6-
169 |
170 | and the system directory has
171 |
172 | .. code-block:: bash
173 |
174 | > ls ${system_dir}
175 | main.yaml dict.txt
176 |
177 | Where the `main.yaml` has all the command line options. The path will be the relative path to the `${system_dir}`.
178 |
179 | .. code-block:: yaml
180 |
181 | waitk: 3
182 | vocab: dict.txt
183 |
The agent can then be started as follows
185 |
186 | .. code-block:: bash
187 |
188 | simuleval \
--source source.txt --target target.txt \ # data arguments
190 | --system-dir ${system_dir}
191 |
192 | By default, the `main.yaml` will be read. You can also have multiple YAML files in the system directory and pass them through command line arguments
193 |
194 | .. code-block:: bash
195 | > ls ${system_dir}
196 | main.yaml dict.txt v1.yaml
197 |
198 | > simuleval \
--source source.txt --target target.txt \ # data arguments
200 | --system-dir ${system_dir} --system-config v1.yaml
--------------------------------------------------------------------------------
/docs/user_guide/dataloader.rst:
--------------------------------------------------------------------------------
1 | Dataloader
2 | ===========
3 | There are two ways to load data.
4 |
5 | .. autoclass:: simuleval.data.dataloader.GenericDataloader
--------------------------------------------------------------------------------
/docs/user_guide/evaluator.rst:
--------------------------------------------------------------------------------
1 | Evaluator
2 | =========
3 |
The evaluation in SimulEval is implemented as the Evaluator shown below.
5 | It runs on sentence level, and will score the translation on quality and latency.
6 | The user can use :code:`--quality-metrics` and :code:`--latency-metrics` to choose the metrics.
7 | The final results along with the logs will be saved at :code:`--output` if given.
8 |
9 | .. autoclass:: simuleval.evaluator.evaluator.SentenceLevelEvaluator
10 |
11 | Quality Scorers
12 | ---------------
13 |
14 | .. autoclass:: simuleval.evaluator.scorers.quality_scorer.SacreBLEUScorer
15 | .. autoclass:: simuleval.evaluator.scorers.quality_scorer.ASRSacreBLEUScorer
16 |
17 | Latency Scorers
18 | ---------------
19 |
20 | .. autoclass:: simuleval.evaluator.scorers.latency_scorer.ALScorer
21 | :members:
22 |
23 | .. autoclass:: simuleval.evaluator.scorers.latency_scorer.APScorer
24 | :members:
25 |
26 | .. autoclass:: simuleval.evaluator.scorers.latency_scorer.DALScorer
27 | :members:
28 |
29 | Customized Scorers
30 | ------------------
To add customized scorers, the user can use :code:`@register_latency_scorer` or :code:`@register_quality_scorer` to decorate a scorer class,
and use :code:`--quality-metrics` and :code:`--latency-metrics` to call the scorer. For example:
33 |
34 | .. literalinclude:: ../../examples/quick_start/agent_with_new_metrics.py
35 | :lines: 6-
36 |
37 | .. code-block:: bash
38 |
39 | > simuleval --source source.txt --target target.txt --agent agent_with_new_metrics.py --latency-metrics RTF
40 | 2022-12-06 12:56:01 | INFO | simuleval.cli | Evaluate system: DummyWaitkTextAgent
41 | 2022-12-06 12:56:01 | INFO | simuleval.dataloader | Evaluating from text to text.
42 | 2022-12-06 12:56:01 | INFO | simuleval.sentence_level_evaluator | Results:
43 | BLEU RTF
44 | 1.593 1.078
45 |
--------------------------------------------------------------------------------
/docs/user_guide/introduction.rst:
--------------------------------------------------------------------------------
1 | Introduction
2 | ============
Different from an offline translation system, the evaluation of simultaneous translation requires incremental decoding with a streaming input.
Simultaneous translation introduces a front-end / back-end setup, shown as follows.
5 |
6 | The back-end contains one or multiple user-defined agents which make decisions of whether to generate prediction at a certain point.
The agent can also be considered as a queue, where the inputs keep being pushed in and the policy decides the timing to pop the output.
8 |
The front-end, on the other side, represents the source of input and the recipient of the system prediction.
In deployment, the front-end can be a web page or a cell phone app.
In SimulEval, the front-end is the evaluator, which feeds streaming input to the back-end, receives predictions and tracks the delays.
The front-end and back-end can run separately for different purposes.
13 |
The evaluation process can be summarized as the following pseudocode
15 |
16 | .. code-block:: python
17 |
18 | for instance in evaluator.instances:
19 | while not instance.finished:
20 | input_segment = instance.send_source()
21 | prediction = agent.pushpop(input_segment)
22 | if prediction is not None:
23 | instance.receive_prediction(prediction)
24 |
results = [scorer.score() for scorer in evaluator.scorers]
26 |
27 |
28 |
The common usage of SimulEval is as follows
30 |
31 | .. code-block:: bash
32 |
33 | simuleval DATALOADER_OPTIONS EVALUATOR_OPTIONS --agent $AGENT_FILE AGENT_OPTIONS
34 |
35 | We will introduce the usage of the toolkit based on these three major components: Agent, Dataloader and Evaluator.
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/SimulEval/536de8253b82d805c9845440169a5010ff507357/examples/__init__.py
--------------------------------------------------------------------------------
/examples/demo/silero_vad.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import queue
3 | import time
4 | import torch
5 | import numpy as np
6 | import soundfile
7 | from argparse import Namespace, ArgumentParser
8 | from simuleval.agents import SpeechToSpeechAgent, AgentStates
9 | from simuleval.agents.actions import WriteAction, ReadAction
10 | from simuleval.data.segments import Segment, EmptySegment, SpeechSegment
11 |
12 | logger = logging.getLogger()
13 |
14 |
class SileroVADStates(AgentStates):
    """Agent states that run Silero VAD over incoming audio and split the
    stream into utterances based on accumulated silence.

    Speech chunks are buffered into ``input_queue`` for the agent's policy
    to consume; audio that arrives after the current utterance was closed
    is parked in ``next_input_queue`` until the queues are cleared.
    """

    def __init__(self, args):
        # Load the Silero VAD model and its helper utilities from torch.hub.
        self.model, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            force_reload=False,
            onnx=False,
        )

        (
            self.get_speech_timestamps,
            self.save_audio,
            self.read_audio,
            self.VADIterator,
            self.collect_chunks,
        ) = utils
        # Contiguous silence (ms) after which the current utterance is closed.
        self.silence_limit_ms = args.silence_limit_ms
        # Samples per VAD window (one speech probability per window).
        self.window_size_samples = args.window_size_samples
        # Samples per chunk pushed to input_queue.
        self.chunk_size_samples = args.chunk_size_samples
        self.sample_rate = args.sample_rate
        self.debug = args.debug
        # Optional wav handle used by debug_write_wav; attached externally.
        self.test_input_segments_wav = None
        self.debug_log(args)
        # Chunks of the current utterance.
        self.input_queue: queue.Queue[Segment] = queue.Queue()
        # Chunks of the next utterance, received before the current one has
        # been fully consumed downstream.
        self.next_input_queue: queue.Queue[Segment] = queue.Queue()
        super().__init__()

    def clear_queues(self):
        """Drop unconsumed chunks of the finished utterance and promote the
        next utterance's chunks into ``input_queue``."""
        self.debug_log(f"clearing {self.input_queue.qsize()} chunks")
        while not self.input_queue.empty():
            self.input_queue.get_nowait()
            self.input_queue.task_done()
        self.debug_log(f"moving {self.next_input_queue.qsize()} chunks")
        # move everything from next_input_queue to input_queue
        while not self.next_input_queue.empty():
            chunk = self.next_input_queue.get_nowait()
            self.next_input_queue.task_done()
            self.input_queue.put_nowait(chunk)

    def reset(self) -> None:
        """Reset per-utterance state: timestamps, silence counter, partial
        chunk buffer, the queues, and the VAD model's internal state."""
        super().reset()
        # TODO: in seamless_server, report latency for each new segment
        self.first_input_ts = None
        self.silence_acc_ms = 0
        self.input_chunk = np.empty(0, dtype=np.int16)
        # True until the first speech sample of the utterance is buffered.
        self.is_fresh_state = True
        self.clear_queues()
        self.model.reset_states()

    def get_speech_prob_from_np_float32(self, segment: np.ndarray):
        """Return one speech probability per full VAD window in `segment`;
        a trailing partial window is ignored."""
        t = torch.from_numpy(segment)
        speech_probs = []
        # print("len(t): ", len(t))
        for i in range(0, len(t), self.window_size_samples):
            chunk = t[i : i + self.window_size_samples]
            if len(chunk) < self.window_size_samples:
                break
            speech_prob = self.model(chunk, self.sample_rate).item()
            speech_probs.append(speech_prob)
        return speech_probs

    def debug_log(self, m):
        # Log only when --debug was set.
        if self.debug:
            logger.info(m)

    def process_speech(self, segment):
        """
        Process a full or partial speech chunk
        """
        queue = self.input_queue
        if self.source_finished:
            # current source is finished, but next speech starts to come in already
            self.debug_log("use next_input_queue")
            queue = self.next_input_queue

        # NOTE: we don't reset silence_acc_ms here so that once an utterance
        # becomes longer (accumulating more silence), it has a higher chance
        # of being segmented.
        # self.silence_acc_ms = 0

        if self.first_input_ts is None:
            self.first_input_ts = time.time() * 1000

        while len(segment) > 0:
            # add chunks to states.buffer
            i = self.chunk_size_samples - len(self.input_chunk)
            self.input_chunk = np.concatenate((self.input_chunk, segment[:i]))
            segment = segment[i:]
            self.is_fresh_state = False
            if len(self.input_chunk) == self.chunk_size_samples:
                queue.put_nowait(
                    SpeechSegment(content=self.input_chunk, finished=False)
                )
                self.input_chunk = np.empty(0, dtype=np.int16)

    def check_silence_acc(self):
        """Close the current utterance once enough silence accumulated:
        flush any partial chunk, then enqueue a finished EmptySegment."""
        if self.silence_acc_ms >= self.silence_limit_ms:
            self.silence_acc_ms = 0
            if self.input_chunk.size > 0:
                # flush partial input_chunk
                self.input_queue.put_nowait(
                    SpeechSegment(content=self.input_chunk, finished=True)
                )
                self.input_chunk = np.empty(0, dtype=np.int16)
            self.input_queue.put_nowait(EmptySegment(finished=True))
            self.source_finished = True

    def update_source(self, segment: np.ndarray):
        """Classify `segment` with the VAD and route it: silence accumulates
        toward closing the utterance, speech is buffered via process_speech."""
        speech_probs = self.get_speech_prob_from_np_float32(segment)
        chunk_size_ms = len(segment) * 1000 / self.sample_rate
        self.debug_log(
            f"{chunk_size_ms}, {len(speech_probs)} {[round(s, 2) for s in speech_probs]}"
        )
        window_size_ms = self.window_size_samples * 1000 / self.sample_rate
        if all(i <= 0.5 for i in speech_probs):
            # Entirely silent chunk (also true when segment is shorter than
            # one VAD window, since speech_probs is then empty).
            if self.source_finished:
                return
            self.debug_log("got silent chunk")
            if not self.is_fresh_state:
                self.silence_acc_ms += chunk_size_ms
                self.check_silence_acc()
            return
        elif speech_probs[-1] <= 0.5:
            self.debug_log("=== start of silence chunk")
            # beginning = speech, end = silence
            # pass to process_speech and accumulate silence
            self.process_speech(segment)
            # accumulate contiguous silence
            for i in range(len(speech_probs) - 1, -1, -1):
                if speech_probs[i] > 0.5:
                    break
                self.silence_acc_ms += window_size_ms
            self.check_silence_acc()
        elif speech_probs[0] <= 0.5:
            self.debug_log("=== start of speech chunk")
            # beginning = silence, end = speech
            # accumulate silence , pass next to process_speech
            for i in range(0, len(speech_probs)):
                if speech_probs[i] > 0.5:
                    break
                self.silence_acc_ms += window_size_ms
            self.check_silence_acc()
            self.process_speech(segment)
        else:
            self.debug_log("======== got speech chunk")
            self.process_speech(segment)

    def debug_write_wav(self, chunk):
        # Append `chunk` to the debug wav file when one has been attached.
        if self.test_input_segments_wav is not None:
            self.test_input_segments_wav.seek(0, soundfile.SEEK_END)
            self.test_input_segments_wav.write(chunk)
166 |
167 |
class SileroVADAgent(SpeechToSpeechAgent):
    """Pass-through speech agent that drains the VAD-segmented chunks
    buffered in its :class:`SileroVADStates` and emits them downstream."""

    def __init__(self, args: Namespace) -> None:
        super().__init__(args)
        self.chunk_size_samples = args.chunk_size_samples
        self.args = args

    @staticmethod
    def add_args(parser: ArgumentParser):
        """Register VAD-related command line options."""
        parser.add_argument(
            "--sample-rate",
            default=16000,
            type=float,
        )
        parser.add_argument(
            "--window-size-samples",
            default=512,  # sampling_rate // 1000 * 32 => 32 ms at 16000 sample rate
            type=int,
            help="Window size for passing samples to VAD",
        )
        parser.add_argument(
            "--chunk-size-samples",
            default=5120,  # sampling_rate // 1000 * 320 => 320 ms at 16000 sample rate
            type=int,
            help="Chunk size for passing samples to model",
        )
        parser.add_argument(
            "--silence-limit-ms",
            default=700,
            type=int,
            help="send EOS to the input_queue after this amount of silence",
        )
        # Bug fix: argparse `type=bool` treats ANY non-empty string as True,
        # so `--debug False` used to enable debug logging. Parse common
        # false-y spellings explicitly; the flag still takes a value and
        # still defaults to False, so the interface is unchanged.
        parser.add_argument(
            "--debug",
            default=False,
            type=lambda s: s.lower() not in ("false", "0", "no", ""),
            help="Enable debug logs",
        )

    def build_states(self) -> SileroVADStates:
        # States own the VAD model and all segmentation bookkeeping.
        return SileroVADStates(self.args)

    def policy(self, states: SileroVADStates):
        """Drain all buffered speech chunks and emit them as one segment;
        read more audio when nothing is buffered yet."""
        states.debug_log(
            f"queue size: {states.input_queue.qsize()}, input_chunk size: {len(states.input_chunk)}"
        )
        content = np.empty(0, dtype=np.int16)
        is_finished = states.source_finished
        while not states.input_queue.empty():
            chunk = states.input_queue.get_nowait()
            states.input_queue.task_done()
            content = np.concatenate((content, chunk.content))

        states.debug_write_wav(content)
        if is_finished:
            # Pad the debug recording with one second of zeros as a marker.
            states.debug_write_wav(np.zeros(16000))

        if len(content) == 0:  # empty queue
            if not states.source_finished:
                return ReadAction()
            else:
                # NOTE: this should never happen, this logic is a safeguard
                segment = EmptySegment(finished=True)
        else:
            segment = SpeechSegment(
                content=content.tolist(),
                finished=is_finished,
                sample_rate=states.sample_rate,
            )

        return WriteAction(segment, finished=is_finished)
238 |
--------------------------------------------------------------------------------
/examples/quick_start/Dockerfile:
--------------------------------------------------------------------------------
# Image that serves the quick-start dummy agent as a standalone server.
FROM python:3.8
# System packages needed to build and run SimulEval (libsndfile for audio I/O).
RUN apt-get update \
&& apt-get upgrade -y \
&& apt-get install -y \
&& apt-get -y install apt-utils gcc libpq-dev libsndfile-dev
RUN git clone https://github.com/facebookresearch/SimulEval.git
WORKDIR SimulEval
# Pin to the v1.1.0 release for a reproducible build.
RUN git checkout v1.1.0
RUN pip install -e .
# Start the agent in standalone mode, listening for remote evaluators on 8888.
CMD ["simuleval", "--standalone", "--remote-port", "8888", "--agent", "examples/quick_start/first_agent.py"]
11 |
--------------------------------------------------------------------------------
/examples/quick_start/agent_pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import random
8 | from simuleval.utils import entrypoint
9 | from simuleval.agents import TextToTextAgent
10 | from simuleval.agents.actions import ReadAction, WriteAction
11 | from simuleval.agents import AgentPipeline
12 |
13 |
class DummyWaitkTextAgent(TextToTextAgent):
    """Dummy wait-k text agent emitting random uppercase letters.

    Reads until it is `waitk` tokens behind the source, then writes one
    random token from `vocab` per step. Subclasses override `waitk`.
    """

    waitk = 0
    vocab = [chr(code) for code in range(ord("A"), ord("Z") + 1)]

    def policy(self):
        behind = len(self.states.source) - len(self.states.target)
        if behind < self.waitk and not self.states.source_finished:
            # Not enough source context yet: keep reading.
            return ReadAction()
        # Emit one random token; finish once we have (almost) caught up.
        token = random.choice(self.vocab)
        return WriteAction(token, finished=(behind <= 1))
27 |
28 |
class DummyWait2TextAgent(DummyWaitkTextAgent):
    # First pipeline stage: wait-2 variant of the dummy agent.
    waitk = 2
31 |
32 |
class DummyWait4TextAgent(DummyWaitkTextAgent):
    # Second pipeline stage: wait-4 variant of the dummy agent.
    waitk = 4
35 |
36 |
@entrypoint
class DummyPipeline(AgentPipeline):
    # Chain the wait-2 agent into the wait-4 agent. @entrypoint marks this
    # class as the system to evaluate, since this file defines several agents.
    pipeline = [DummyWait2TextAgent, DummyWait4TextAgent]
40 |
--------------------------------------------------------------------------------
/examples/quick_start/agent_with_configs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import random
8 | from simuleval.utils import entrypoint
9 | from simuleval.agents import TextToTextAgent
10 | from simuleval.agents.actions import ReadAction, WriteAction
11 | from argparse import Namespace, ArgumentParser
12 |
13 |
@entrypoint
class DummyWaitkTextAgent(TextToTextAgent):
    """Dummy wait-k agent configured from the command line.

    ``--waitk`` sets the lagging threshold and ``--vocab`` points to a
    text file with one candidate output token per line.
    """

    def __init__(self, args: Namespace):
        super().__init__(args)
        self.waitk = args.waitk
        # Load the candidate output tokens, one per line.
        with open(args.vocab) as vocab_file:
            self.vocab = [token.strip() for token in vocab_file]

    @staticmethod
    def add_args(parser: ArgumentParser):
        """Add customized command line arguments"""
        parser.add_argument("--waitk", type=int, default=3)
        parser.add_argument("--vocab", type=str)

    def policy(self):
        num_behind = len(self.states.source) - len(self.states.target)
        if num_behind < self.waitk and not self.states.source_finished:
            # Still within the wait-k window: ask for more source.
            return ReadAction()
        # Write one random vocab entry; stop once nearly caught up.
        return WriteAction(random.choice(self.vocab), finished=(num_behind <= 1))
40 |
--------------------------------------------------------------------------------
/examples/quick_start/agent_with_new_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import random
8 | from statistics import mean
9 | from simuleval.utils import entrypoint
10 | from simuleval.evaluator.scorers.latency_scorer import (
11 | register_latency_scorer,
12 | LatencyScorer,
13 | )
14 | from simuleval.agents import TextToTextAgent
15 | from simuleval.agents.actions import ReadAction, WriteAction
16 |
17 |
@register_latency_scorer("RTF")
class RTFScorer(LatencyScorer):
    """Real time factor

    Usage:
        --latency-metrics RTF
    """

    def __call__(self, instances) -> float:
        # Per-instance RTF is the last output delay divided by the source
        # length; the corpus score is the mean over all instances.
        return mean(
            instance.delays[-1] / instance.source_length
            for instance in instances.values()
        )
32 |
33 |
@entrypoint
class DummyWaitkTextAgent(TextToTextAgent):
    """Dummy wait-3 text agent that outputs random uppercase letters."""

    # Number of source tokens to read before the first write.
    waitk = 3
    # Candidate output tokens: the uppercase letters A-Z.
    vocab = [chr(i) for i in range(ord("A"), ord("Z") + 1)]

    def policy(self):
        # How far the target lags behind the source, in tokens.
        lagging = len(self.states.source) - len(self.states.target)

        if lagging >= self.waitk or self.states.source_finished:
            prediction = random.choice(self.vocab)

            # Finish once the target has (almost) caught up with the source.
            return WriteAction(prediction, finished=(lagging <= 1))
        else:
            return ReadAction()
48 |
--------------------------------------------------------------------------------
/examples/quick_start/dict.txt:
--------------------------------------------------------------------------------
1 | A
2 | B
3 | C
4 | D
5 | E
6 | F
7 | G
8 | H
9 | I
10 | J
11 | K
12 | L
13 | M
14 | N
15 | O
16 | P
17 | Q
18 | R
19 | S
20 | T
21 | U
22 | V
23 | W
24 | X
25 | Y
26 | Z
27 |
--------------------------------------------------------------------------------
/examples/quick_start/first_agent.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import random
8 | from simuleval.utils import entrypoint
9 | from simuleval.agents import TextToTextAgent
10 | from simuleval.agents.actions import ReadAction, WriteAction
11 |
12 |
@entrypoint
class DummyWaitkTextAgent(TextToTextAgent):
    """Minimal wait-3 text agent that outputs random uppercase letters."""

    # Number of source tokens to read before the first write.
    waitk = 3
    # Candidate output tokens: the uppercase letters A-Z.
    vocab = [chr(i) for i in range(ord("A"), ord("Z") + 1)]

    def policy(self):
        # How far the target lags behind the source, in tokens.
        lagging = len(self.states.source) - len(self.states.target)

        if lagging >= self.waitk or self.states.source_finished:
            prediction = random.choice(self.vocab)

            # Finish once the target has (almost) caught up with the source.
            return WriteAction(prediction, finished=(lagging <= 1))
        else:
            return ReadAction()
27 |
--------------------------------------------------------------------------------
/examples/quick_start/readme.md:
--------------------------------------------------------------------------------
1 | # Quick Start
2 | Following are some minimal examples to use SimulEval. More details can be found [here](https://simuleval.readthedocs.io/en/v1.1.0/quick_start.html).
3 |
4 | ## First Agent
5 | To evaluate a text-to-text wait-3 system with random output:
6 |
7 | ```
8 | > simuleval --source source.txt --target target.txt --agent first_agent.py
9 |
10 | 2022-12-05 13:43:58 | INFO | simuleval.cli | Evaluate system: DummyWaitkTextAgent
11 | 2022-12-05 13:43:58 | INFO | simuleval.dataloader | Evaluating from text to text.
12 | 2022-12-05 13:43:58 | INFO | simuleval.sentence_level_evaluator | Results:
13 | BLEU AL AP DAL
14 | 1.541 3.0 0.688 3.0
15 |
16 | ```
17 |
18 | ## Agent with Command Line Arguments
19 | ```
20 | simuleval --source source.txt --target target.txt --agent agent_with_configs.py --waitk 3 --vocab dict.txt
21 | ```
22 |
23 | ## Agent Pipeline
24 | ```
25 | simuleval --source source.txt --target target.txt --agent agent_pipeline.py
26 | ```
27 |
28 | ## Agent with New Metrics
29 | ```
30 | simuleval --source source.txt --target target.txt --agent agent_with_new_metrics.py
31 | ```
32 |
33 | ## Standalone Agent & Remote Evaluation
34 | Start an agent server:
35 | ```
36 | simuleval --standalone --remote-port 8888 --agent agent_with_new_metrics.py
37 | ```
38 | Or with docker
39 | ```
40 | docker build -t simuleval_agent .
41 | docker run -p 8888:8888 simuleval_agent:latest
42 | ```
43 |
44 | Start a remote evaluator:
45 | ```
46 | simuleval --remote-eval --source source.txt --target target.txt --source-type text --target-type text --remote-port 8888
47 | ```
48 |
--------------------------------------------------------------------------------
/examples/quick_start/source.txt:
--------------------------------------------------------------------------------
1 | Z U S N B Y X Q L O T
2 | M A J F P G O Y V R H M Z O T A
3 | M O W A O I D H H B O F
4 | N Q N I P C O H A A
5 | G B O J H P W C I A L V
6 | P T Z D E E N T B Y G Z R K
7 | F S H U K R W K S B R K M B B Q F C O U
8 | M H O L W Z G J Y X J B I
9 | A V B F E S F E W Q C S
10 | F N O I E Z B R S C V N S
11 |
--------------------------------------------------------------------------------
/examples/quick_start/spm_detokenizer_agent.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from fairseq.data.encoders import build_bpe
4 |
5 | from simuleval.agents import TextToTextAgent
6 | from simuleval.agents.actions import ReadAction, WriteAction
7 | from simuleval.agents.pipeline import AgentPipeline
8 | from simuleval.agents.states import AgentStates
9 |
10 |
class DummySegmentAgent(TextToTextAgent):
    """
    This agent just splits on space: it buffers source tokens and flushes
    them as one space-joined segment every `segment_k` tokens (or when the
    source is finished).
    """

    def __init__(self, args):
        super().__init__(args)
        self.segment_k = args.segment_k

    @classmethod
    def from_args(cls, args, **kwargs):
        return cls(args)

    # Fix: @staticmethod was missing; every other agent in this repo declares
    # add_args static, and without it an instance-level call would pass the
    # instance as `parser`. Class-level calls keep working unchanged.
    @staticmethod
    def add_args(parser: ArgumentParser):
        parser.add_argument(
            "--segment-k",
            type=int,
            help="Output segments with this many words",
            required=True,
        )

    def policy(self, states: AgentStates):
        # Flush once `segment_k` tokens are buffered, or at end of source.
        if len(states.source) == self.segment_k or states.source_finished:
            out = " ".join(states.source)
            states.source = []
            return WriteAction(out, finished=states.source_finished)
        return ReadAction()
38 |
39 |
class SentencePieceModelDetokenizerAgent(TextToTextAgent):
    """Detokenize a stream of sentencepiece tokens into words.

    By default a word is only emitted once the beginning of the next word
    has been seen; with ``--detokenize-only`` each incoming chunk is
    decoded and emitted immediately.
    """

    def __init__(self, args):
        super().__init__(args)
        # build_bpe reads args.bpe / args.sentencepiece_model to construct
        # fairseq's sentencepiece processor.
        self.args.bpe = "sentencepiece"
        spm_processor = build_bpe(self.args)
        self.spm_processor = spm_processor
        self.detokenize_only = args.detokenize_only

    @classmethod
    def from_args(cls, args, **kwargs):
        return cls(args)

    # Fix: @staticmethod was missing; every other agent in this repo declares
    # add_args static, and without it an instance-level call would pass the
    # instance as `parser`. Class-level calls keep working unchanged.
    @staticmethod
    def add_args(parser: ArgumentParser):
        parser.add_argument(
            "--sentencepiece-model",
            type=str,
            help="Path to sentencepiece model.",
            required=True,
        )
        parser.add_argument(
            "--detokenize-only",
            action="store_true",
            default=False,
            help=(
                "Run detokenization without waiting for new token. By default(False),"
                "wait for beginning of next word before finalizing the previous word"
            ),
        )

    def policy(self, states: AgentStates):
        # Decode everything buffered so far; partial words decode to "".
        possible_full_words = self.spm_processor.decode(
            " ".join([x for x in states.source])
        )

        if self.detokenize_only and len(states.source) > 0:
            states.source = []
            if len(possible_full_words) == 0 and not states.source_finished:
                return ReadAction()
            else:
                return WriteAction(possible_full_words, states.source_finished)

        if states.source_finished:
            return WriteAction(possible_full_words, True)
        elif len(possible_full_words.split()) > 1:
            # All words but the last are complete; keep the (possibly
            # partial) last word buffered in re-tokenized form.
            full_words, last_word = (
                possible_full_words.split()[:-1],
                possible_full_words.split()[-1],
            )
            # NOTE(review): encode() presumably returns a space-joined token
            # string; stored as a single-element list, the later " ".join
            # reproduces it intact — confirm against fairseq's bpe interface.
            states.source = [self.spm_processor.encode(last_word)]
            return WriteAction(" ".join(full_words), finished=False)
        else:
            return ReadAction()
92 |
93 |
class DummyPipeline(AgentPipeline):
    # Segment the source into k-word chunks, then detokenize them.
    pipeline = [DummySegmentAgent, SentencePieceModelDetokenizerAgent]
96 |
--------------------------------------------------------------------------------
/examples/quick_start/spm_source.txt:
--------------------------------------------------------------------------------
1 | ▁Let ' s ▁do ▁it ▁with out ▁hesitation .
--------------------------------------------------------------------------------
/examples/quick_start/spm_target.txt:
--------------------------------------------------------------------------------
1 | Let's do it without hesitation.
--------------------------------------------------------------------------------
/examples/quick_start/target.txt:
--------------------------------------------------------------------------------
1 | Z U S N B Y X Q L O T
2 | M A J F P G O Y V R H M Z O T A
3 | M O W A O I D H H B O F
4 | N Q N I P C O H A A
5 | G B O J H P W C I A L V
6 | P T Z D E E N T B Y G Z R K
7 | F S H U K R W K S B R K M B B Q F C O U
8 | M H O L W Z G J Y X J B I
9 | A V B F E S F E W Q C S
10 | F N O I E Z B R S C V N S
--------------------------------------------------------------------------------
/examples/speech_to_speech/english_alternate_agent.py:
--------------------------------------------------------------------------------
1 | from simuleval.utils import entrypoint
2 | from simuleval.data.segments import SpeechSegment
3 | from simuleval.agents import SpeechToSpeechAgent
4 | from simuleval.agents.actions import WriteAction, ReadAction
5 | from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
6 | from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
7 |
8 |
class TTSModel:
    """CPU wrapper around the FastSpeech2 English TTS model from the
    HuggingFace hub, with a HiFi-GAN vocoder."""

    def __init__(self):
        # Load the model ensemble and task; force fp32 and the HiFi-GAN vocoder.
        models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
            "facebook/fastspeech2-en-ljspeech",
            arg_overrides={"vocoder": "hifigan", "fp16": False},
        )
        TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
        self.tts_generator = task.build_generator(models, cfg)
        self.tts_task = task
        self.tts_model = models[0]
        self.tts_model.to("cpu")
        self.tts_generator.vocoder.to("cpu")

    def synthesize(self, text):
        """Synthesize `text`; returns (samples_as_list, sample_rate),
        or ([], 0) when the text yields an empty model input."""
        sample = TTSHubInterface.get_model_input(self.tts_task, text)
        if sample["net_input"]["src_lengths"][0] == 0:
            return [], 0
        # Move every tensor of the model input to CPU before generation.
        for key in sample["net_input"].keys():
            if sample["net_input"][key] is not None:
                sample["net_input"][key] = sample["net_input"][key].to("cpu")

        wav, rate = TTSHubInterface.get_prediction(
            self.tts_task, self.tts_model, self.tts_generator, sample
        )
        wav = wav.tolist()
        return wav, rate
35 |
36 |
@entrypoint
class EnglishAlternateAgent(SpeechToSpeechAgent):
    """
    Incrementally feed text to this offline Fastspeech2 TTS model,
    with an alternating speech pattern that is decrementing.
    """

    def __init__(self, args):
        super().__init__(args)
        self.wait_seconds = args.wait_seconds
        self.tts_model = TTSModel()

    @staticmethod
    def add_args(parser):
        parser.add_argument("--wait-seconds", default=1, type=int)

    def policy(self):
        # Fix: guard against division by zero before any audio has arrived
        # (source_sample_rate is 0 until the first segment), matching the
        # identical guard in english_counter_agent.py.
        if self.states.source_sample_rate == 0:
            # empty source, source_sample_rate not set yet
            length_in_seconds = 0
        else:
            length_in_seconds = round(
                len(self.states.source) / self.states.source_sample_rate
            )
        if not self.states.source_finished and length_in_seconds < self.wait_seconds:
            return ReadAction()
        # Alternate between "even" and "odd" phrases with a decrementing count.
        if length_in_seconds % 2 == 0:
            samples, fs = self.tts_model.synthesize(
                f"{8 - length_in_seconds} even even"
            )
        else:
            samples, fs = self.tts_model.synthesize(f"{8 - length_in_seconds} odd odd")

        # A SpeechSegment has to be returned for speech-to-speech translation system
        return WriteAction(
            SpeechSegment(
                content=samples,
                sample_rate=fs,
                finished=self.states.source_finished,
            ),
            finished=self.states.source_finished,
        )
75 |
--------------------------------------------------------------------------------
/examples/speech_to_speech/english_counter_agent.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from simuleval.agents.states import AgentStates
3 | from simuleval.utils import entrypoint
4 | from simuleval.data.segments import SpeechSegment
5 | from simuleval.agents import SpeechToSpeechAgent
6 | from simuleval.agents.actions import WriteAction, ReadAction
7 | from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
8 | from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
9 |
10 |
class TTSModel:
    """CPU wrapper around the FastSpeech2 English TTS model from the
    HuggingFace hub, with a HiFi-GAN vocoder."""

    def __init__(self):
        # Load the model ensemble and task; force fp32 and the HiFi-GAN vocoder.
        models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
            "facebook/fastspeech2-en-ljspeech",
            arg_overrides={"vocoder": "hifigan", "fp16": False},
        )
        TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
        self.tts_generator = task.build_generator(models, cfg)
        self.tts_task = task
        self.tts_model = models[0]
        self.tts_model.to("cpu")
        self.tts_generator.vocoder.to("cpu")

    def synthesize(self, text):
        """Synthesize `text`; returns (samples_as_list, sample_rate),
        or ([], 0) when the text yields an empty model input."""
        sample = TTSHubInterface.get_model_input(self.tts_task, text)
        if sample["net_input"]["src_lengths"][0] == 0:
            return [], 0
        # Move every tensor of the model input to CPU before generation.
        for key in sample["net_input"].keys():
            if sample["net_input"][key] is not None:
                sample["net_input"][key] = sample["net_input"][key].to("cpu")

        wav, rate = TTSHubInterface.get_prediction(
            self.tts_task, self.tts_model, self.tts_generator, sample
        )
        wav = wav.tolist()
        return wav, rate
37 |
38 |
@entrypoint
class EnglishSpeechCounter(SpeechToSpeechAgent):
    """Toy speech-to-speech agent: after an initial wait, synthesizes the
    audio "<elapsed-seconds> mississippi" for every received source chunk,
    using an offline FastSpeech2 TTS model."""

    def __init__(self, args):
        super().__init__(args)
        self.wait_seconds = args.wait_seconds
        self.tts_model = TTSModel()

    @staticmethod
    def add_args(parser):
        # Seconds of source audio to accumulate before the first output.
        parser.add_argument("--wait-seconds", default=1, type=int)

    def policy(self, states: Optional[AgentStates] = None):
        # Stateless invocation passes states explicitly; fall back to our own.
        states = self.states if states is None else states

        if states.source_sample_rate == 0:
            # No audio received yet, so the sample rate is still unknown.
            elapsed = 0
        else:
            elapsed = round(len(states.source) / states.source_sample_rate)

        if not states.source_finished and elapsed < self.wait_seconds:
            return ReadAction()

        waveform, rate = self.tts_model.synthesize(f"{elapsed} mississippi")

        # Speech output must be wrapped in a SpeechSegment.
        segment = SpeechSegment(
            content=waveform,
            sample_rate=rate,
            finished=states.source_finished,
        )
        return WriteAction(segment, finished=states.source_finished)
76 |
--------------------------------------------------------------------------------
/examples/speech_to_speech/eval.sh:
--------------------------------------------------------------------------------
1 | simuleval \
2 | --agent english_counter_agent.py --output output \
3 | --source source.txt --target reference/en.txt --source-segment-size 1000\
4 | --quality-metrics WHISPER_ASR_BLEU \
5 | --target-speech-lang en --transcript-lowercase --transcript-non-punctuation --whisper-model-size large \
6 | --latency-metrics StartOffset EndOffset ATD
7 |
--------------------------------------------------------------------------------
/examples/speech_to_speech/readme.md:
--------------------------------------------------------------------------------
1 | ## Simultaneous Speech-to-Speech Translation
2 |
3 | This tutorial provides a minimal example on how to evaluate a simultaneous speech-to-speech translation system.
4 |
5 | ### Requirements
6 |
7 | To run this example, the following package is required
8 |
9 | - [`whisper`](https://github.com/openai/whisper): for quality evaluation (`WHISPER_ASR_BLEU`).
10 |
11 | ### Agent
12 |
13 | The speech-to-speech agent ([english_counter_agent.py](english_counter_agent.py)) in this example is a counter, which generates a piece of audio every second after an initial wait.
The policy of the agent is shown below. The agent will wait for `self.wait_seconds` seconds,
15 | and generate the audio of `{length_in_seconds} mississippi` every second afterward.
16 |
17 | ```python
18 | def policy(self):
19 | length_in_seconds = round(
20 | len(self.states.source) / self.states.source_sample_rate
21 | )
22 | if not self.states.source_finished and length_in_seconds < self.wait_seconds:
23 | return ReadAction()
24 | print(length_in_seconds)
25 | samples, fs = self.tts_model.synthesize(f"{length_in_seconds} mississippi")
26 |
27 | # A SpeechSegment has to be returned for speech-to-speech translation system
28 | return WriteAction(
29 | SpeechSegment(
30 | content=samples,
31 | sample_rate=fs,
32 | finished=self.states.source_finished,
33 | ),
34 | finished=self.states.source_finished,
35 | )
36 | ```
37 |
38 | Notice that for speech output agent, the `WriteAction` has to contain a `SpeechSegment` class.
39 |
40 | ### Evaluation
41 |
42 | The following command will start an evaluation
43 |
44 | ```bash
45 | simuleval \
46 | --agent english_counter_agent.py --output output \
47 | --source source.txt --target reference/en.txt --source-segment-size 1000\
48 | --quality-metrics WHISPER_ASR_BLEU \
49 | --target-speech-lang en --transcript-lowercase --transcript-non-punctuation\
50 | --latency-metrics StartOffset EndOffset ATD
51 | ```
52 |
53 | For quality evaluation, we use ASR_BLEU, that is transcribing the speech output and compute BLEU score with the reference text. To use this feature, `whisper` has to be installed.
54 |
55 | We use three metrics for latency evaluation
56 |
- `StartOffset`: The starting offset of the translation compared with the source audio
- `EndOffset`: The ending offset of the translation compared with the source audio
59 | - `ATD`: Average Token Delay
60 |
61 | The results of the evaluation should be as following. The transcripts and alignments can be found in the `output` directory.
62 |
63 | ```
64 | WHISPER_ASR_BLEU StartOffset EndOffset ATD
65 | 100.0 1000.0 1490.703 1248.261
66 | ```
67 |
--------------------------------------------------------------------------------
/examples/speech_to_speech/reference/de.txt:
--------------------------------------------------------------------------------
1 | ein Mississippi zwei Mississippi drei Mississippi vier Mississippi fünf Mississippi sechs Mississippi sieben Mississippi
2 |
--------------------------------------------------------------------------------
/examples/speech_to_speech/reference/en.txt:
--------------------------------------------------------------------------------
1 | one mississippi two mississippi three mississippi four mississippi five mississippi six mississippi seven mississippi
2 |
--------------------------------------------------------------------------------
/examples/speech_to_speech/reference/ja.txt:
--------------------------------------------------------------------------------
1 | 1 ミシシッピ 2 ミシシッピ 3 ミシシッピ 4 ミシシッピ 5 ミシシッピ 6 ミシシッピ 7 ミシシッピ
2 |
--------------------------------------------------------------------------------
/examples/speech_to_speech/reference/tgt_lang.txt:
--------------------------------------------------------------------------------
1 | es
--------------------------------------------------------------------------------
/examples/speech_to_speech/reference/zh.txt:
--------------------------------------------------------------------------------
1 | 一密西西比二密西西比三密西西比四密西西比五密西西比六密西西比七密西西比
2 |
--------------------------------------------------------------------------------
/examples/speech_to_speech/source.txt:
--------------------------------------------------------------------------------
1 | test.wav
2 |
--------------------------------------------------------------------------------
/examples/speech_to_speech/test.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/SimulEval/536de8253b82d805c9845440169a5010ff507357/examples/speech_to_speech/test.wav
--------------------------------------------------------------------------------
/examples/speech_to_speech_demo/english_counter_pipeline.py:
--------------------------------------------------------------------------------
from simuleval.agents import AgentPipeline
from examples.demo.silero_vad import SileroVADAgent
from examples.speech_to_speech.english_counter_agent import EnglishSpeechCounter


class EnglishCounterAgentPipeline(AgentPipeline):
    """Demo pipeline: Silero VAD segments the incoming audio stream, then the
    counter agent synthesizes "<n> mississippi" speech for each segment."""

    pipeline = [
        SileroVADAgent,
        EnglishSpeechCounter,
    ]
11 |
--------------------------------------------------------------------------------
/examples/speech_to_speech_demo/readme.md:
--------------------------------------------------------------------------------
1 | Running the demo:
2 | 1. Create a directory for the dummy model: `models/$DUMMY_MODEL`
3 | 2. Create a new yaml file `models/$DUMMY_MODEL/vad_main.yaml`, with the following:
4 | ```
5 | agent_class: examples.speech_to_speech_demo.english_counter_pipeline.EnglishCounterAgentPipeline
6 | ```
7 | 3. Set the available agent in `SimulevalAgentDirectory.py` to `$DUMMY_MODEL`
8 | 4. Run `python app.py`
9 |
10 |
11 | - Note: If you get an ImportError for `examples.speech_to_speech_demo`, run `python -c "import examples; print(examples.__file__)"`. If the file is something like `$PREFIX/site-packages/examples/__init__.py`, run `rm -r $PREFIX/site-packages/examples` and try again.
--------------------------------------------------------------------------------
/examples/speech_to_speech_text/tree_agent_pipeline.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from simuleval.agents import TreeAgentPipeline
3 | from examples.speech_to_speech.english_counter_agent import (
4 | EnglishSpeechCounter as EnglishSpeechToSpeech,
5 | )
6 | from examples.speech_to_text.english_counter_agent import (
7 | EnglishSpeechCounter as EnglishSpeechToText,
8 | )
9 | from simuleval.agents.actions import WriteAction
10 | from simuleval.agents.agent import SpeechToTextAgent
11 | from simuleval.agents.states import AgentStates
12 |
13 |
class EnglishWait2SpeechToText(EnglishSpeechToText):
    """English speech-to-text counter fixed to a 2-second initial wait."""

    def __init__(self, args):
        # Override BEFORE calling super(): the parent copies args.wait_seconds
        # into self.wait_seconds inside its __init__, so assigning after the
        # super() call (as the original code did) had no effect.
        args.wait_seconds = 2
        super().__init__(args)

    @staticmethod
    def add_args(parser):
        # The parent already registered --wait-seconds; re-adding it on the
        # shared pipeline parser would raise a duplicate-argument error.
        pass
22 |
23 |
class DotSpeechToText(SpeechToTextAgent):
    """Trivial speech-to-text agent that writes "." on every call."""

    def policy(self, states: Optional[AgentStates] = None):
        # Fall back to the agent's own states when none are passed, matching
        # every other agent in these examples; the original dereferenced
        # `states` unconditionally and crashed on stateful (states=None) calls.
        if states is None:
            states = self.states
        return WriteAction(
            content=".",
            finished=states.source_finished,
        )
30 |
31 |
class EnglishWait2SpeechToSpeech(EnglishSpeechToSpeech):
    """English speech-to-speech counter fixed to a 2-second initial wait."""

    def __init__(self, args):
        # Override BEFORE calling super(): the parent copies args.wait_seconds
        # into self.wait_seconds inside its __init__, so assigning after the
        # super() call (as the original code did) had no effect.
        args.wait_seconds = 2
        super().__init__(args)

    @staticmethod
    def add_args(parser):
        # The parent already registered --wait-seconds; re-adding it on the
        # shared pipeline parser would raise a duplicate-argument error.
        pass
40 |
41 |
class DummyTreePipeline(TreeAgentPipeline):
    # pipeline is a dict, used to instantiate agents in from_args
    # Tree layout: one speech-to-speech root fanning out into a wait-2
    # speech-to-text leaf and a wait-2 speech-to-speech leaf.
    pipeline = {
        EnglishSpeechToSpeech: [EnglishWait2SpeechToText, EnglishWait2SpeechToSpeech],
        EnglishWait2SpeechToSpeech: [],
        EnglishWait2SpeechToText: [],
    }
49 |
50 |
class TemplateTreeAgentPipeline(TreeAgentPipeline):
    """Base class that instantiates the three agent classes a subclass lists
    in `pipeline` and wires them into a fixed tree shape."""

    def __init__(self, args) -> None:
        # cls.pipeline is consumed positionally:
        # [root agent, speech-to-speech leaf, speech-to-text leaf].
        speech_speech = self.pipeline[0](args)
        speech_speech_wait2 = self.pipeline[1](args)
        speech_text = self.pipeline[2](args)

        # The root fans out to both leaves; the leaves have no children.
        module_dict = {
            speech_speech: [speech_text, speech_speech_wait2],
            speech_speech_wait2: [],
            speech_text: [],
        }

        super().__init__(module_dict, args)

    @classmethod
    def from_args(cls, args):
        # Agents are instantiated in __init__, so from_args is a plain call.
        return cls(args)
68 |
69 |
class InstantiatedTreeAgentPipeline(TemplateTreeAgentPipeline):
    # Concrete classes consumed positionally by TemplateTreeAgentPipeline:
    # [root, speech-to-speech leaf, speech-to-text leaf].
    pipeline = [
        EnglishSpeechToSpeech,
        EnglishWait2SpeechToSpeech,
        EnglishWait2SpeechToText,
    ]
76 |
77 |
class AnotherInstantiatedTreeAgentPipeline(TemplateTreeAgentPipeline):
    # Same tree as InstantiatedTreeAgentPipeline but with the text leaf
    # replaced by the trivial dot-writer agent.
    pipeline = [
        EnglishSpeechToSpeech,
        EnglishWait2SpeechToSpeech,
        DotSpeechToText,  # swap the speech_text module in the pipeline
    ]
84 |
--------------------------------------------------------------------------------
/examples/speech_to_text/Dockerfile:
--------------------------------------------------------------------------------
# Standalone SimulEval agent server for the streaming Whisper wait-k example.
FROM python:3.8
# System packages: gcc for building wheels, libsndfile for soundfile audio I/O.
RUN apt-get update \
&& apt-get upgrade -y \
&& apt-get -y install apt-utils gcc libpq-dev libsndfile-dev
RUN pip install -U openai-whisper
RUN pip install -U editdistance
# Install SimulEval itself from source (editable install).
RUN git clone https://github.com/facebookresearch/SimulEval
WORKDIR /SimulEval/
RUN pip install -e .
WORKDIR /SimulEval/examples/speech_to_text/
# Serve the wait-k Whisper agent on port 8888 with a lagging of 3 segments.
CMD ["simuleval", "--standalone", "--remote-port", "8888", "--agent", "whisper_waitk.py", "--waitk-lagging", "3"]
12 |
--------------------------------------------------------------------------------
/examples/speech_to_text/counter_in_tgt_lang_agent.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from simuleval.agents.states import AgentStates
3 | from simuleval.utils import entrypoint
4 | from simuleval.agents import SpeechToTextAgent
5 | from simuleval.agents.actions import WriteAction, ReadAction
6 |
7 |
@entrypoint
class CounterInTargetLanguage(SpeechToTextAgent):
    """Emits the number of seconds of source audio received so far, followed
    by a unit word in the target language given by states.tgt_lang."""

    # Unit word appended after the second count, keyed by target language.
    _UNITS = {"en": "seconds", "es": "segundos", "de": "sekunden"}

    def __init__(self, args):
        super().__init__(args)
        self.wait_seconds = args.wait_seconds

    @staticmethod
    def add_args(parser):
        # Seconds of source audio to accumulate before the first output.
        parser.add_argument("--wait-seconds", default=1, type=int)

    def policy(self, states: Optional[AgentStates] = None):
        # Stateless invocation passes states explicitly; fall back to our own.
        states = self.states if states is None else states

        # A sample rate of 0 means no audio has arrived yet.
        seconds = (
            0
            if states.source_sample_rate == 0
            else round(len(states.source) / states.source_sample_rate)
        )

        if not states.source_finished and seconds < self.wait_seconds:
            return ReadAction()

        # Unknown target languages get no unit word (leaving the same
        # trailing space the original if/elif chain produced).
        text = f"{seconds} " + self._UNITS.get(states.tgt_lang, "")

        return WriteAction(
            content=text,
            finished=states.source_finished,
        )
49 |
--------------------------------------------------------------------------------
/examples/speech_to_text/english_counter_agent.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from simuleval.agents.states import AgentStates
3 | from simuleval.utils import entrypoint
4 | from simuleval.agents import SpeechToTextAgent
5 | from simuleval.agents.actions import WriteAction, ReadAction
6 |
7 |
@entrypoint
class EnglishSpeechCounter(SpeechToTextAgent):
    """Emits "<n> second" where n is the number of seconds of source audio
    received so far, after an initial wait of --wait-seconds."""

    def __init__(self, args):
        super().__init__(args)
        self.wait_seconds = args.wait_seconds

    @staticmethod
    def add_args(parser):
        # Seconds of source audio to accumulate before the first output.
        parser.add_argument("--wait-seconds", default=1, type=int)

    def policy(self, states: Optional[AgentStates] = None):
        # Stateless invocation passes states explicitly; fall back to our own.
        states = self.states if states is None else states

        if states.source_sample_rate == 0:
            # Nothing received yet; the sample rate is still unset.
            seconds = 0
        else:
            seconds = round(len(states.source) / states.source_sample_rate)

        still_waiting = not states.source_finished and seconds < self.wait_seconds
        if still_waiting:
            return ReadAction()

        return WriteAction(
            content=f"{seconds} second",
            finished=states.source_finished,
        )
39 |
--------------------------------------------------------------------------------
/examples/speech_to_text/eval.sh:
--------------------------------------------------------------------------------
1 | simuleval \
2 | --agent counter_in_tgt_lang_agent.py \
3 | --source-segment-size 1000 \
4 | --source source.txt --target reference/en.txt \
5 | --tgt-lang reference/tgt_lang.txt \
6 | --output output
7 |
--------------------------------------------------------------------------------
/examples/speech_to_text/readme.md:
--------------------------------------------------------------------------------
1 | ## Simultaneous Speech-to-Text Translation
2 |
3 | This tutorial provides a minimal example on how to evaluate a simultaneous speech-to-text translation system.
4 |
5 | ### Agent
6 |
The speech-to-text agent ([english_counter_agent.py](english_counter_agent.py)) in this example is a counter, which generates the number of seconds in text after waiting for `self.wait_seconds` seconds. The policy finishes when the source is finished.
8 |
9 | ```python
10 | def policy(self):
11 | length_in_seconds = round(
12 | len(self.states.source) / self.states.source_sample_rate
13 | )
14 | if not self.states.source_finished and length_in_seconds < self.wait_seconds:
15 | return ReadAction()
16 |
17 | prediction = f"{length_in_seconds} second"
18 |
19 | return WriteAction(
20 | content=prediction,
21 | finished=self.states.source_finished,
22 | )
23 | ```
24 |
25 | ### Evaluation
26 |
27 | The following command will start an evaluation
28 |
29 | ```bash
30 | simuleval \
31 | --agent english_counter_agent.py \
32 | --source-segment-size 1000 \
33 | --source source.txt --target reference/en.txt \
34 | --output output
35 | ```
36 |
37 | The results of the evaluation should be as following. The detailed results can be found in the `output` directory.
38 |
39 | ```
40 | BLEU LAAL AL AP DAL ATD
41 | 100.0 822.018 822.018 0.581 1061.271 2028.555
42 | ```
43 |
### Example Streaming ASR / S2T: Whisper Wait-K model
45 |
This section provides a more realistic model. [whisper_waitk.py](whisper_waitk.py) is a streaming ASR agent running the [wait-k](https://aclanthology.org/P19-1289/) policy on the [Whisper](https://github.com/openai/whisper) ASR model.
47 |
48 | ```bash
49 | simuleval \
50 | --agent whisper_waitk.py \
51 | --source-segment-size 500 \
52 | --waitk-lagging 3 \
53 | --source source.txt --target reference/transcript.txt \
54 | --output output --quality-metrics WER
55 | ```
56 |
57 | The results of the evaluation should be as following. The detailed results can be found in the `output` directory.
58 |
59 | ```
60 | WER LAAL AL AP DAL ATD
61 | 25.0 2353.772 2353.772 0.721 2491.04 2457.847
62 | ```
63 |
64 | This agent can also perform S2T task, by adding `--task translate`.
65 |
66 | ### Streaming Speech-to-Text Demo
67 |
68 | A streaming speech to text demo feature, taking input from user's microphone, sending it to Whisper's wait-k model, and displaying the prediction texts in the terminal.
69 |
70 | 1. Kick off a remote agent. More information [Remote_agent](../../docs/tutorials/remote_evaluation.rst)
71 | 2. Enter demo mode by providing a desired segment size (usually 500ms):
72 |
73 | ```bash
74 | simuleval --remote-eval --demo --source-segment-size 500 --remote-port 8888
75 | ```
76 |
77 | 3. Speak into the microphone and watch the live transcription!
78 | 4. Press ^c (Control C) to exit the program in terminal
79 |
--------------------------------------------------------------------------------
/examples/speech_to_text/reference/en.txt:
--------------------------------------------------------------------------------
1 | 1 second 2 second 3 second 4 second 5 second 6 second 7 second
2 |
--------------------------------------------------------------------------------
/examples/speech_to_text/reference/tgt_lang.txt:
--------------------------------------------------------------------------------
1 | es
--------------------------------------------------------------------------------
/examples/speech_to_text/reference/transcript.txt:
--------------------------------------------------------------------------------
1 | This is a synthesized audio file to test your simultaneous speech to text and to speech to speach translation system.
--------------------------------------------------------------------------------
/examples/speech_to_text/source.txt:
--------------------------------------------------------------------------------
1 | test.wav
2 |
--------------------------------------------------------------------------------
/examples/speech_to_text/test.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/SimulEval/536de8253b82d805c9845440169a5010ff507357/examples/speech_to_text/test.wav
--------------------------------------------------------------------------------
/examples/speech_to_text/whisper_waitk.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from simuleval.agents.states import AgentStates
3 | from simuleval.utils import entrypoint
4 | from simuleval.data.segments import SpeechSegment
5 | from simuleval.agents import SpeechToTextAgent
6 | from simuleval.agents.actions import WriteAction, ReadAction
7 |
8 | import whisper
9 | import numpy
10 |
11 |
@entrypoint
class WaitkWhisper(SpeechToTextAgent):
    """Streaming ASR/S2T agent running a wait-k policy on top of the offline
    Whisper model: the full received audio is re-decoded at every step, with
    the already-committed output supplied as a decoding prefix."""

    def __init__(self, args):
        super().__init__(args)
        self.waitk_lagging = args.waitk_lagging
        self.source_segment_size = args.source_segment_size
        self.source_language = args.source_language
        self.continuous_write = args.continuous_write
        self.model_size = args.model_size
        self.model = whisper.load_model(self.model_size)
        self.task = args.task
        if self.task == "translate":
            # Whisper's translate task outputs English, so the source must
            # not itself be English.
            assert (
                self.source_language != "en"
            ), "source language must be different from en for translation task"

    @staticmethod
    def add_args(parser):
        parser.add_argument("--waitk-lagging", default=1, type=int)
        parser.add_argument("--source-language", default="en", type=str)
        parser.add_argument(
            "--continuous-write",
            default=1,
            type=int,
            help="Max number of words to write at each step",
        )
        parser.add_argument("--model-size", default="tiny", type=str)
        parser.add_argument(
            "--task",
            default="transcribe",
            type=str,
            choices=["transcribe", "translate"],
        )

    def policy(self, states: Optional[AgentStates] = None):
        # Stateless invocation passes states explicitly; fall back to our own.
        states = self.states if states is None else states

        # Seconds of audio received so far (0 before the first chunk arrives).
        if states.source_sample_rate == 0:
            audio_seconds = 0
        else:
            audio_seconds = float(len(states.source)) / states.source_sample_rate

        # Wait-k: keep reading until k source segments have been received.
        if not states.source_finished:
            segments_received = audio_seconds * 1000 / self.source_segment_size
            if segments_received < self.waitk_lagging:
                return ReadAction()

        # Feed everything committed so far as a prefix so the model extends,
        # rather than rewrites, the previous output.
        options = whisper.DecodingOptions(
            prefix=" ".join(states.target),
            language=self.source_language,
            without_timestamps=True,
            fp16=False,
        )

        # Re-encode the whole audio received so far on every call.
        padded = whisper.pad_or_trim(numpy.array(states.source).astype("float32"))
        mel = whisper.log_mel_spectrogram(padded).to(self.model.device)
        words = self.model.decode(mel, options).text.split()

        # While streaming, commit at most continuous_write new-ish words.
        if not states.source_finished and self.continuous_write > 0:
            words = words[: self.continuous_write]

        return WriteAction(
            content=" ".join(words),
            finished=states.source_finished,
        )
88 |
--------------------------------------------------------------------------------
/examples/speech_to_text_demo/counter_in_tgt_lang_pipeline.py:
--------------------------------------------------------------------------------
from simuleval.agents import AgentPipeline
from examples.demo.silero_vad import SileroVADAgent

# Fixed import: the agent is defined in counter_in_tgt_lang_agent.py; the
# previous path (examples.speech_to_text.counter_in_tgt_lang) does not exist
# and raised ModuleNotFoundError on import.
from examples.speech_to_text.counter_in_tgt_lang_agent import CounterInTargetLanguage


class CounterInTargetLanguageAgentPipeline(AgentPipeline):
    """Demo pipeline: Silero VAD segments the incoming audio stream, then the
    counter agent writes the elapsed seconds in the target language."""

    pipeline = [
        SileroVADAgent,
        CounterInTargetLanguage,
    ]
11 |
--------------------------------------------------------------------------------
/examples/speech_to_text_demo/english_counter_pipeline.py:
--------------------------------------------------------------------------------
from simuleval.agents import AgentPipeline
from examples.demo.silero_vad import SileroVADAgent
from examples.speech_to_text.english_counter_agent import EnglishSpeechCounter


class EnglishCounterAgentPipeline(AgentPipeline):
    """Demo pipeline: Silero VAD segments the incoming audio stream, then the
    counter agent writes "<n> second" for each segment."""

    pipeline = [
        SileroVADAgent,
        EnglishSpeechCounter,
    ]
11 |
--------------------------------------------------------------------------------
/examples/speech_to_text_demo/readme.md:
--------------------------------------------------------------------------------
1 | Running the demo:
2 | 1. Create a directory for the dummy model: `models/$DUMMY_MODEL`
3 | 2. Create a new yaml file `models/$DUMMY_MODEL/vad_main.yaml`, with the following:
4 | ```
5 | agent_class: examples.speech_to_text_demo.english_counter_pipeline.EnglishCounterAgentPipeline
6 | ```
7 | 3. Set the available agent in `SimulevalAgentDirectory.py` to `$DUMMY_MODEL`
8 | 4. Run `python app.py`
9 |
10 |
11 | - Note: If you get an ImportError for `examples.speech_to_text_demo`, run `python -c "import examples; print(examples.__file__)"`. If the file is something like `$PREFIX/site-packages/examples/__init__.py`, run `rm -r $PREFIX/site-packages/examples` and try again.
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [tool:pytest]
2 | flake8-max-line-length = 127
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from setuptools import find_packages, setup

# The PyPI long description comes straight from the README.
with open("README.md", "r") as readme_file:
    long_description = readme_file.read()

setup(
    python_requires=">3.7.0",  # NOTE(review): likely meant ">=3.7" — confirm
    name="simuleval",
    # NOTE(review): use_scm_version=True (below) derives the version from git
    # tags via setuptools_scm, which makes this static version redundant.
    version="1.1.4",
    author="Xutai Ma",
    description="SimulEval: A Flexible Toolkit for Automated Machine Translation Evaluation",
    long_description=long_description,
    long_description_content_type="text/markdown",
    # "homepage"/"documentation" are not recognized setup() keywords and were
    # silently ignored; "url" and "project_urls" are the supported fields.
    url="https://github.com/facebookresearch/SimulEval.git",
    project_urls={
        "Documentation": "https://simuleval.readthedocs.io/en/v1.1.0/quick_start.html",
    },
    license="LICENSE",  # NOTE(review): usually a license NAME, e.g. "Apache-2.0"
    entry_points={
        "console_scripts": [
            "simuleval = simuleval.cli:main",
        ],
    },
    install_requires=[
        "pytest",
        "pytest-cov",
        "sacrebleu>=2.3.1",
        "tornado",
        "soundfile",
        "pandas",
        "requests",
        "pytest-flake8",
        "textgrid",
        "tqdm==4.64.1",
        "pyyaml",
        "bitarray==2.6.0",
        "yt-dlp",
        "pydub",
        "matplotlib",
    ],
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Development Status :: 4 - Beta",
    ],
    # Duplicate entries ("SimulEval", "Embedding Average") removed.
    keywords=[
        "SimulEval",
        "Machine Translation",
        "Evaluation",
        "Metrics",
        "BLEU",
        "TER",
        "METEOR",
        "chrF",
        "RIBES",
        "WMD",
        "Embedding Average",
        "Embedding Extrema",
        "Embedding Greedy",
        "SimulEval_Testing_Package_1",
        "facebookresearch",
        "facebook",
        "Meta-Evaluation",
    ],
    packages=find_packages(
        exclude=[
            "examples",
            "examples.*",
            "docs",
            "docs.*",
        ]
    ),
    setup_requires=["setuptools_scm"],
    use_scm_version=True,
)
83 |
--------------------------------------------------------------------------------
/simuleval/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/simuleval/agents/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .agent import ( # noqa
8 | GenericAgent,
9 | SpeechToTextAgent,
10 | SpeechToSpeechAgent,
11 | TextToSpeechAgent,
12 | TextToTextAgent,
13 | )
14 | from .states import AgentStates # noqa
15 | from .actions import Action, ReadAction, WriteAction # noqa
16 | from .pipeline import AgentPipeline, TreeAgentPipeline # noqa
17 |
--------------------------------------------------------------------------------
/simuleval/agents/actions.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Tuple, Union, List
8 | from dataclasses import dataclass
9 | from simuleval.data.segments import Segment
10 |
11 |
class Action:
    """
    Abstract Action class
    """

    def is_read(self) -> bool:
        """
        Whether the action is a read action

        Returns:
            bool: True if the action is a read action.
        """
        # The original `assert NotImplementedError` was a no-op (it asserts a
        # truthy class object) and silently returned None, which callers would
        # treat as "not a read". Raise so subclasses must override.
        raise NotImplementedError
25 |
26 |
class ReadAction(Action):
    """Returned by a policy that needs one more source segment.

    Construct with no arguments: ``ReadAction()``.
    """

    def __repr__(self) -> str:
        return "ReadAction()"

    def is_read(self) -> bool:
        # Reading is exactly what this action represents.
        return True
38 |
39 |
@dataclass
class WriteAction(Action):
    """Returned by a policy that has produced a prediction.

    Args:
        content (Union[str, List[float], Tuple[List[float], str]]):
            The prediction: a str for text, a list of samples for speech,
            and a (samples, text) tuple for speech_text.
        finished (bool): Indicates if current sentence is finished.
    """

    content: Union[str, List[float], Tuple[List[float], str]]
    finished: bool

    def __repr__(self) -> str:
        return f"WriteAction({self.content}, finished={self.finished})"

    def is_read(self) -> bool:
        # A write action carries output rather than requesting input.
        return False
63 |
--------------------------------------------------------------------------------
/simuleval/agents/agent.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from inspect import signature
8 | from argparse import Namespace, ArgumentParser
9 | from simuleval.data.segments import (
10 | Segment,
11 | SpeechTextSegment,
12 | TextSegment,
13 | SpeechSegment,
14 | EmptySegment,
15 | )
16 | from typing import Optional, List
17 | from .states import AgentStates
18 | from .actions import Action
19 |
20 |
# Maps an agent's declared modality string (source_type/target_type) to the
# Segment subclass used to wrap data of that modality.
SEGMENT_TYPE_DICT = {
    "text": TextSegment,
    "speech": SpeechSegment,
    "speech_text": SpeechTextSegment,
}
26 |
27 |
28 | class GenericAgent:
29 | """
30 | Generic Agent class.
31 | """
32 |
33 | source_type = None
34 | target_type = None
35 |
    def __init__(self, args: Optional[Namespace] = None) -> None:
        # Keep the parsed CLI namespace around when one is provided.
        if args is not None:
            self.args = args
        # Subclasses must declare their input/output modalities.
        assert self.source_type
        assert self.target_type
        self.device = "cpu"

        # Per-sentence state, reset on every new input sentence.
        self.states = self.build_states()
        self.reset()
45 |
    def build_states(self) -> AgentStates:
        """
        Build states instance for agent

        Returns:
            AgentStates: agent states
        """
        # Override to attach a custom AgentStates subclass to the agent.
        return AgentStates()
54 |
    def reset(self) -> None:
        """
        Reset the agent; called every time a new sentence comes in.
        Applies to stateful agents.
        """
        self.states.reset()
61 |
62 | def policy(self, states: Optional[AgentStates] = None) -> Action:
63 | """
64 | The policy to make decision every time
65 | when the system has new input.
66 | The function has to return an Action instance
67 |
68 | Args:
69 | states (Optional[AgentStates]): an optional states for stateless agent
70 |
71 | Returns:
72 | Action: The actions to make at certain point.
73 |
74 | .. note:
75 |
76 | WriteAction means that the system has a prediction.
77 | ReadAction means that the system needs more source.
78 | When states are provided, the agent will become stateless and ignore self.states.
79 | """
80 | assert NotImplementedError
81 |
82 | def push(
83 | self,
84 | source_segment: Segment,
85 | states: Optional[AgentStates] = None,
86 | upstream_states: Optional[List[AgentStates]] = None,
87 | ) -> None:
88 | """
89 | The function to process the incoming information.
90 |
91 | Args:
92 | source_info (dict): incoming information dictionary
93 | states (Optional[AgentStates]): an optional states for stateless agent
94 | """
95 | if states is None:
96 | states = self.states
97 |
98 | if upstream_states is None:
99 | upstream_states = []
100 |
101 | states.upstream_states = upstream_states
102 |
103 | states.update_config(source_segment.config)
104 | states.update_source(source_segment)
105 |
106 | def pop(self, states: Optional[AgentStates] = None) -> Segment:
107 | """
108 | The function to generate system output.
109 | By default, it first runs policy,
110 | and than returns the output segment.
111 | If the policy decide to read,
112 | it will return an empty segment.
113 |
114 | Args:
115 | states (Optional[AgentStates]): an optional states for stateless agent
116 |
117 | Returns:
118 | Segment: segment to return.
119 | """
120 | if len(signature(self.policy).parameters) == 0:
121 | is_stateless = False
122 | if states:
123 | raise RuntimeError("Feeding states to stateful agents.")
124 | else:
125 | is_stateless = True
126 |
127 | if states is None:
128 | states = self.states
129 |
130 | if states.target_finished:
131 | return EmptySegment(finished=True)
132 |
133 | if is_stateless:
134 | action = self.policy(states)
135 | else:
136 | action = self.policy()
137 |
138 | if not isinstance(action, Action):
139 | raise RuntimeError(
140 | f"The return value of {self.policy.__qualname__} is not an {Action.__qualname__} instance"
141 | )
142 | if action.is_read():
143 | return EmptySegment()
144 | else:
145 | if isinstance(action.content, Segment):
146 | return action.content
147 |
148 | segment = SEGMENT_TYPE_DICT[self.target_type](
149 | index=0, content=action.content, finished=action.finished
150 | )
151 | states.update_target(segment)
152 | return segment
153 |
154 | def pushpop(
155 | self,
156 | segment: Segment,
157 | states: Optional[AgentStates] = None,
158 | upstream_states: Optional[List[AgentStates]] = None,
159 | ) -> Segment:
160 | """
161 | Operate pop immediately after push.
162 |
163 | Args:
164 | segment (Segment): input segment
165 |
166 | Returns:
167 | Segment: output segment
168 | """
169 | self.push(segment, states, upstream_states)
170 | return self.pop(states)
171 |
172 | @staticmethod
173 | def add_args(parser: ArgumentParser):
174 | """
175 | Add agent arguments to parser.
176 | Has to be a static method.
177 |
178 | Args:
179 | parser (ArgumentParser): cli argument parser
180 | """
181 | pass
182 |
183 | @classmethod
184 | def from_args(cls, args):
185 | return cls(args)
186 |
187 | def to(self, device: str, *args, **kwargs) -> None:
188 | """
189 | Move agent to specified device.
190 |
191 | Args:
192 | device (str): Device to move agent to.
193 | """
194 | pass
195 |
196 | def __repr__(self) -> str:
197 | return f"{self.__class__.__name__}[{self.source_type} -> {self.target_type}]"
198 |
199 | def __str__(self) -> str:
200 | return self.__repr__()
201 |
202 |
class SpeechToTextAgent(GenericAgent):
    """
    Same as generic agent, but with explicit types
    speech -> text
    """

    source_type: str = "speech"
    target_type: str = "text"
    tgt_lang: Optional[str] = None  # optional default target language for this agent
212 |
213 |
class SpeechToSpeechAgent(GenericAgent):
    """
    Same as generic agent, but with explicit types
    speech -> speech
    (e.g. simultaneous speech translation with synthesized audio output).
    """

    source_type: str = "speech"
    target_type: str = "speech"
222 |
223 |
class TextToSpeechAgent(GenericAgent):
    """
    Same as generic agent, but with explicit types
    text -> speech
    (e.g. an incremental text-to-speech synthesis stage in a pipeline).
    """

    source_type: str = "text"
    target_type: str = "speech"
232 |
233 |
class TextToTextAgent(GenericAgent):
    """
    Same as generic agent, but with explicit types
    text -> text
    (e.g. simultaneous machine translation on token streams).
    """

    source_type: str = "text"
    target_type: str = "text"
242 |
--------------------------------------------------------------------------------
/simuleval/agents/service.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import os
6 | import json
7 | import logging
8 | from tornado import web, ioloop
9 | from simuleval.data.segments import segment_from_json_string
10 | from simuleval import options
11 |
12 | logger = logging.getLogger("simuleval.agent_server")
13 |
14 |
class SystemHandler(web.RequestHandler):
    # Base handler: stores a reference to the agent system and serves a
    # human-readable description of it on GET /.
    def initialize(self, system):
        self.system = system

    def get(self):
        self.write(json.dumps({"info": str(self.system)}))
21 |
22 |
class ResetHandle(SystemHandler):
    # POST /reset clears the system's decoding states before a new sentence.
    def post(self):
        self.system.reset()
26 |
27 |
class OutputHandler(SystemHandler):
    # GET /output pops the system's next output segment and returns it
    # serialized as JSON.
    def get(self):
        output_segment = self.system.pop()
        self.write(output_segment.json())
32 |
33 |
class InputHandler(SystemHandler):
    # PUT /input deserializes a JSON-encoded segment from the request body
    # and pushes it into the system.
    def put(self):
        segment = segment_from_json_string(self.request.body)
        self.system.push(segment)
38 |
39 |
def start_agent_service(system):
    """Expose ``system`` over HTTP (reset/input/output endpoints) and block forever."""
    # Only evaluator options are parsed here; unknown args are tolerated so the
    # same command line can also carry agent-specific flags.
    parser = options.general_parser()
    options.add_evaluator_args(parser)
    args, _ = parser.parse_known_args()
    app = web.Application(
        [
            (r"/reset", ResetHandle, {"system": system}),
            (r"/input", InputHandler, {"system": system}),
            (r"/output", OutputHandler, {"system": system}),
            (r"/", SystemHandler, {"system": system}),
        ],
        debug=False,
    )

    # 1 GiB buffer: speech segments serialized as JSON can be very large.
    app.listen(args.remote_port, max_buffer_size=1024**3)

    logger.info(
        f"Simultaneous Translation Server Started (process id {os.getpid()}). Listening to port {args.remote_port} "
    )
    ioloop.IOLoop.current().start()
60 |
--------------------------------------------------------------------------------
/simuleval/agents/states.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from simuleval.data.segments import Segment, TextSegment, EmptySegment, SpeechSegment
8 |
9 |
class AgentStates:
    """
    Tracker of the decoding progress.

    Attributes:
        source (list): current source sequence.
        target (list): current target sequence.
        source_finished (bool): if the source is finished.
        target_finished (bool): if the target is finished.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Reset Agent states"""
        self.source = []
        self.target = []
        self.source_finished = False
        self.target_finished = False
        self.source_sample_rate = 0
        self.target_sample_rate = 0
        self.tgt_lang = None
        self.upstream_states = []
        self.config = {}

    def update_config(self, config: dict):
        """Merge ``config`` into the current configuration (incoming keys win)."""
        self.config.update(config)

    def update_source(self, segment: Segment):
        """
        Update states from input segment

        Args:
            segment (~simuleval.agents.segments.Segment): input segment
        """
        self.source_finished = segment.finished
        if isinstance(segment, EmptySegment):
            return
        if isinstance(segment, TextSegment):
            self.source.append(segment.content)
            self.tgt_lang = segment.tgt_lang
            return
        if isinstance(segment, SpeechSegment):
            self.source.extend(segment.content)
            self.source_sample_rate = segment.sample_rate
            self.tgt_lang = segment.tgt_lang
            return
        raise NotImplementedError

    def update_target(self, segment: Segment):
        """
        Update states from output segment

        Args:
            segment (~simuleval.agents.segments.Segment): input segment
        """
        self.target_finished = segment.finished
        # A finished segment's content is intentionally not recorded,
        # matching the original behavior.
        if self.target_finished:
            return
        if isinstance(segment, EmptySegment):
            return
        if isinstance(segment, TextSegment):
            self.target.append(segment.content)
        elif isinstance(segment, SpeechSegment):
            self.target.extend(segment.content)
            self.target_sample_rate = segment.sample_rate
        else:
            raise NotImplementedError
78 |
--------------------------------------------------------------------------------
/simuleval/analysis/curve.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import json
8 | import pandas
9 | from pathlib import Path
10 | from typing import Dict, List, Union
11 |
12 |
class SimulEvalResults:
    """Read-only view over a SimulEval output directory and its ``scores`` file."""

    def __init__(self, path: Union[Path, str]) -> None:
        self.path = Path(path)
        scores_path = self.path / "scores"
        if scores_path.exists():
            self.is_finished = True
            with open(scores_path) as f:
                self.scores = json.load(f)
        else:
            # Run not finished (or scoring never ran): expose empty scores.
            self.is_finished = False
            self.scores = {}

    @property
    def quality(self) -> float:
        """BLEU score, or 0 when the run is unfinished or scores are null."""
        if not self.is_finished or self.scores is None:
            return 0
        return self.scores["Quality"]["BLEU"]

    @property
    def bleu(self) -> float:
        # Alias kept for readability at call sites.
        return self.quality

    @property
    def latency(self) -> Dict[str, float]:
        """Latency metrics dictionary; empty when the run is unfinished."""
        return self.scores["Latency"] if self.is_finished else {}

    @property
    def average_lagging(self):
        return self.latency.get("AL", 0)

    @property
    def average_lagging_ca(self):
        # Computation-aware average lagging.
        return self.latency.get("AL_CA", 0)

    @property
    def average_proportion(self):
        return self.latency.get("AP", 0)

    @property
    def name(self):
        # The run is identified by its directory name.
        return self.path.name
60 |
61 |
class S2SSimulEvalResults(SimulEvalResults):
    """Speech-to-speech results: average lagging is reported per word boundary."""

    def _word_al(self, key):
        # Shared lookup: latency[key]["AL"], defaulting to 0 at both levels.
        return self.latency.get(key, {}).get("AL", 0)

    @property
    def bow_average_lagging(self):
        # AL measured at the beginning of words.
        return self._word_al("BOW")

    @property
    def cow_average_lagging(self):
        return self._word_al("COW")

    @property
    def eow_average_lagging(self):
        # AL measured at the end of words.
        return self._word_al("EOW")
74 |
75 |
class QualityLatencyAnalyzer:
    """Collect per-run results and summarize quality/latency as a table."""

    def __init__(self) -> None:
        self.score_list: List[SimulEvalResults] = []

    def add_scores_from_path(self, path: Path):
        self.score_list.append(SimulEvalResults(path))

    @classmethod
    def from_paths(cls, path_list: List[Path]):
        """Build an analyzer preloaded with every path in ``path_list``."""
        analyzer = cls()
        for path in path_list:
            analyzer.add_scores_from_path(path)
        return analyzer

    def summarize(self):
        """Return a DataFrame (name, AL, AL(CA), AP, BLEU) sorted by AL;
        runs with BLEU == 0 (unfinished/failed) are skipped; AL in seconds."""
        rows = [
            [
                score.name,
                round(score.average_lagging / 1000, 2),
                round(score.average_lagging_ca / 1000, 2),
                round(score.average_proportion, 2),
                round(score.bleu, 2),
            ]
            for score in self.score_list
            if score.bleu != 0
        ]
        rows.sort(key=lambda row: row[1])
        return pandas.DataFrame(rows, columns=["name", "AL", "AL(CA)", "AP", "BLEU"])
106 |
107 |
class S2SQualityLatencyAnalyzer(QualityLatencyAnalyzer):
    """Speech-to-speech variant: reports word-boundary AL metrics instead."""

    def add_scores_from_path(self, path: Path):
        # Use the S2S result wrapper so BOW/COW/EOW lags are available.
        self.score_list.append(S2SSimulEvalResults(path))

    def summarize(self):
        """Return a DataFrame (name, BOW_AL, COW_AL, EOW_AL, BLEU) sorted by
        BOW_AL; runs with BLEU == 0 are skipped; lags in seconds."""
        rows = [
            [
                score.name,
                round(score.bow_average_lagging / 1000, 2),
                round(score.cow_average_lagging / 1000, 2),
                round(score.eow_average_lagging / 1000, 2),
                round(score.bleu, 2),
            ]
            for score in self.score_list
            if score.bleu != 0
        ]
        rows.sort(key=lambda row: row[1])
        return pandas.DataFrame(
            rows, columns=["name", "BOW_AL", "COW_AL", "EOW_AL", "BLEU"]
        )
130 |
--------------------------------------------------------------------------------
/simuleval/cli.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import sys
9 | from argparse import ArgumentParser
10 | from typing import Optional
11 |
12 | from simuleval import options
13 | from simuleval.agents import GenericAgent
14 | from simuleval.agents.service import start_agent_service
15 | from simuleval.evaluator import (
16 | SentenceLevelEvaluator,
17 | build_evaluator,
18 | build_remote_evaluator,
19 | )
20 | from simuleval.utils import EVALUATION_SYSTEM_LIST
21 | from simuleval.utils.agent import build_system_args
22 | from simuleval.utils.arguments import check_argument
23 | from simuleval.utils.slurm import submit_slurm_job
24 |
# Configure logging once at import time; logs go to stderr so stdout stays
# clean for evaluation results.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-8s | %(name)-16s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    stream=sys.stderr,
)


logger = logging.getLogger("simuleval.cli")
34 |
35 |
def main():
    """CLI entry point: dispatch to a special mode or run a local evaluation."""
    # Each special mode fully handles the run and returns early.
    if check_argument("remote_eval"):
        remote_evaluate()
        return

    if check_argument("score_only"):
        scoring()
        return

    if check_argument("slurm"):
        submit_slurm_job()
        return

    system, args = build_system_args()

    # Standalone mode: expose the system as a service instead of evaluating it.
    if check_argument("standalone"):
        start_agent_service(system)
        return

    # build evaluator
    evaluator = build_evaluator(args)

    # evaluate system
    evaluator(system)
60 |
61 |
def evaluate(
    system_class: GenericAgent,
    config_dict: Optional[dict] = None,
    parser: Optional[ArgumentParser] = None,
):
    """
    Programmatic entry point: evaluate ``system_class`` with the given config.

    Args:
        system_class (GenericAgent): agent system class to evaluate.
        config_dict (Optional[dict]): configuration overrides. A list value
            holds multiple candidate settings; only the first is used when
            probing the slurm flag.
        parser (Optional[ArgumentParser]): optional pre-built argument parser.
    """
    # Bug fix: the original signature used a mutable default (`config_dict:
    # dict = {}`), which is shared across calls and could leak state between
    # evaluations if any callee mutates it.
    if config_dict is None:
        config_dict = {}
    EVALUATION_SYSTEM_LIST.append(system_class)
    just_for_arg_check = {}
    for key, value in config_dict.items():
        if isinstance(value, list):
            just_for_arg_check[key] = value[0]
        else:
            just_for_arg_check[key] = value
    if check_argument("slurm", just_for_arg_check):
        submit_slurm_job(config_dict, parser)
        return

    system, args = build_system_args(config_dict, parser)

    # build evaluator
    evaluator = build_evaluator(args)

    # evaluate system
    evaluator(system)
85 |
86 |
def scoring():
    """Re-score an existing output directory without running a system."""
    parser = options.general_parser()
    options.add_evaluator_args(parser)
    options.add_scorer_args(parser)
    options.add_dataloader_args(parser)
    args = parser.parse_args()
    evaluator = SentenceLevelEvaluator.from_args(args)
    # Printing the results property triggers the actual score computation.
    print(evaluator.results)
95 |
96 |
def remote_evaluate():
    """Evaluate against an agent running remotely as a standalone service."""
    # build evaluator
    parser = options.general_parser()
    options.add_dataloader_args(parser)
    options.add_evaluator_args(parser)
    options.add_scorer_args(parser)
    args = parser.parse_args()
    evaluator = build_remote_evaluator(args)

    # evaluate system
    evaluator.remote_eval()
108 |
109 |
110 | if __name__ == "__main__":
111 | main()
112 |
--------------------------------------------------------------------------------
/simuleval/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataloader import build_dataloader # noqa
2 |
--------------------------------------------------------------------------------
/simuleval/data/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | from argparse import Namespace
9 |
10 | from .dataloader import ( # noqa
11 | GenericDataloader,
12 | register_dataloader,
13 | register_dataloader_class,
14 | SUPPORTED_MEDIUM,
15 | SUPPORTED_SOURCE_MEDIUM,
16 | SUPPORTED_TARGET_MEDIUM,
17 | DATALOADER_DICT,
18 | )
19 | from .t2t_dataloader import TextToTextDataloader # noqa
20 | from .s2t_dataloader import SpeechToTextDataloader # noqa
21 |
22 |
23 | logger = logging.getLogger("simuleval.dataloader")
24 |
25 |
def build_dataloader(args: Namespace) -> GenericDataloader:
    """Instantiate the dataloader selected by ``args``.

    An explicit ``--dataloader`` name takes priority; otherwise the loader is
    resolved as ``"<source_type>-to-<target_type>"`` (demo mode forces
    speech-to-text).
    """
    dataloader_key = getattr(args, "dataloader", None)
    if dataloader_key is not None:
        assert dataloader_key in DATALOADER_DICT, f"{dataloader_key} is not defined"
        logger.info(f"Evaluating from dataloader {dataloader_key}.")
        return DATALOADER_DICT[dataloader_key].from_args(args)

    if args.demo:
        args.source_type = "speech"
        args.target_type = "text"

    assert args.source_type in SUPPORTED_SOURCE_MEDIUM
    assert args.target_type in SUPPORTED_TARGET_MEDIUM

    logger.info(f"Evaluating from {args.source_type} to {args.target_type}.")
    composed_key = f"{args.source_type}-to-{args.target_type}"
    return DATALOADER_DICT[composed_key].from_args(args)
40 |
--------------------------------------------------------------------------------
/simuleval/data/dataloader/dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from abc import abstractmethod
8 | from argparse import ArgumentParser, Namespace
9 | from typing import Any, Dict, List, Optional, Union
10 |
# Media supported by the evaluation pipeline.
SUPPORTED_MEDIUM = ["text", "speech"]
SUPPORTED_SOURCE_MEDIUM = ["youtube", "text", "speech"]
SUPPORTED_TARGET_MEDIUM = ["text", "speech"]
# Global registry mapping dataloader names to their classes.
DATALOADER_DICT = {}


def register_dataloader(name):
    """Class decorator that registers the decorated dataloader under ``name``."""

    def decorator(cls):
        DATALOADER_DICT[name] = cls
        return cls

    return decorator


def register_dataloader_class(name, cls):
    """Imperatively register ``cls`` as the dataloader called ``name``."""
    DATALOADER_DICT[name] = cls
27 |
28 |
class GenericDataloader:
    """
    Load source and target data

    .. argparse::
        :ref: simuleval.options.add_data_args
        :passparser:
        :prog:

    """

    def __init__(
        self,
        source_list: List[str],
        target_list: Union[List[str], List[None]],
        tgt_lang_list: Optional[List[str]] = None,
    ) -> None:
        self.source_list = source_list
        self.target_list = target_list
        self.tgt_lang_list = tgt_lang_list
        # Source and target must be parallel.
        assert len(self.source_list) == len(self.target_list)

    def __len__(self):
        return len(self.source_list)

    def get_source(self, index: int) -> Any:
        """Return the preprocessed source item at ``index``."""
        return self.preprocess_source(self.source_list[index])

    def get_target(self, index: int) -> Any:
        """Return the preprocessed target item at ``index``."""
        return self.preprocess_target(self.target_list[index])

    def get_tgt_lang(self, index: int) -> Optional[str]:
        """Return the target language at ``index`` or None if unavailable."""
        lang_list = getattr(self, "tgt_lang_list", None)
        if lang_list is None or index >= len(lang_list):
            return None
        return lang_list[index]

    def __getitem__(self, index: int) -> Dict[str, Any]:
        return {
            "source": self.get_source(index),
            "target": self.get_target(index),
            "tgt_lang": self.get_tgt_lang(index),
        }

    def preprocess_source(self, source: Any) -> Any:
        # Medium-specific subclasses must implement this.
        raise NotImplementedError

    def preprocess_target(self, target: Any) -> Any:
        # Medium-specific subclasses must implement this.
        raise NotImplementedError

    @classmethod
    def from_args(cls, args: Namespace):
        return cls(args.source, args.target)

    @staticmethod
    def add_args(parser: ArgumentParser):
        """Register the data-related command line options."""
        parser.add_argument(
            "--source",
            type=str,
            help="Source file.",
        )
        parser.add_argument(
            "--target",
            type=str,
            help="Target file.",
        )
        parser.add_argument(
            "--source-type",
            type=str,
            choices=SUPPORTED_SOURCE_MEDIUM,
            help="Source Data type to evaluate.",
        )
        parser.add_argument(
            "--target-type",
            type=str,
            choices=SUPPORTED_TARGET_MEDIUM,
            help="Data type to evaluate.",
        )
        parser.add_argument(
            "--source-segment-size",
            type=int,
            default=1,
            help="Source segment size, For text the unit is # token, for speech is ms",
        )
        parser.add_argument(
            "--tgt-lang",
            type=str,
            default=None,
            help="Target language",
        )
121 |
122 |
class IterableDataloader:
    # Mixin interface for dataloaders that stream items sequentially
    # instead of supporting random access.
    cur_index: int  # index of the item currently being served

    @abstractmethod
    def __iter__(self):
        pass

    @abstractmethod
    def __next__(self):
        pass
133 |
--------------------------------------------------------------------------------
/simuleval/data/dataloader/s2t_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from __future__ import annotations
8 | from pathlib import Path
9 | from typing import List, Union, Optional
10 | from .dataloader import GenericDataloader
11 | from simuleval.data.dataloader import register_dataloader
12 | from argparse import Namespace
13 | from urllib.parse import urlparse, parse_qs
14 |
try:
    import yt_dlp as youtube_dl
    from pydub import AudioSegment
except ImportError:
    # Bug fix: the fallback previously assigned `yt_dlp` instead of the
    # alias `youtube_dl`, so `youtube_dl` stayed undefined and later
    # `assert youtube_dl is not None` raised NameError when the optional
    # dependency was missing.
    youtube_dl = AudioSegment = None

try:
    import soundfile

    IS_IMPORT_SOUNDFILE = True
except Exception:
    IS_IMPORT_SOUNDFILE = False
27 |
28 |
def download_youtube_video(url):
    """
    Download the audio track of a YouTube video as a 16 kHz mono WAV file.

    Args:
        url: full YouTube watch URL (must carry a ``v=`` query parameter).

    Returns:
        str: path of the resulting ``<video_id>.wav`` file.

    Raises:
        Exception: if the URL does not contain a video id.
    """

    def get_video_id(url):
        url_data = urlparse(url)
        query = parse_qs(url_data.query)
        video = query.get("v", [])
        if video:
            return video[0]
        # Typo fixed in the message ("unrecoginzed" -> "unrecognized").
        raise Exception("unrecognized url format.")

    # Renamed from `id` to avoid shadowing the builtin.
    video_id = get_video_id(url)
    name = f"{video_id}.wav"

    # Skip the download if the audio was already fetched previously.
    if not Path(name).exists():
        ydl_opts = {
            "format": "bestaudio/best",
            "postprocessors": [
                {
                    "key": "FFmpegExtractAudio",
                    "preferredcodec": "wav",
                    "preferredquality": "192",
                }
            ],
            "outtmpl": video_id,  # name the file after the video id
        }
        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url])

    # Normalize to mono 16 kHz in place.
    sound = AudioSegment.from_wav(name)
    sound = sound.set_channels(1).set_frame_rate(16000)
    sound.export(name, format="wav")
    return name
61 |
62 |
def load_list_from_file(file_path: Union[Path, str]) -> List[str]:
    """Return the lines of ``file_path`` stripped of surrounding whitespace."""
    with open(file_path) as handle:
        return list(map(str.strip, handle))
66 |
67 |
68 | @register_dataloader("speech-to-text")
69 | class SpeechToTextDataloader(GenericDataloader):
70 | def __init__(
71 | self,
72 | source_list: List[str],
73 | target_list: List[str],
74 | tgt_lang_list: Optional[List[str]] = None,
75 | ) -> None:
76 | super().__init__(source_list, target_list, tgt_lang_list)
77 |
78 | def preprocess_source(self, source: Union[Path, str]) -> List[float]:
79 | assert IS_IMPORT_SOUNDFILE, "Please make sure soundfile is properly installed."
80 | samples, _ = soundfile.read(source, dtype="float32")
81 | samples = samples.tolist()
82 | return samples
83 |
84 | def preprocess_target(self, target: str) -> str:
85 | return target
86 |
87 | def get_source_audio_info(self, index: int) -> soundfile._SoundFileInfo:
88 | return soundfile.info(self.get_source_audio_path(index))
89 |
90 | def get_source_audio_path(self, index: int):
91 | return self.source_list[index]
92 |
93 | @classmethod
94 | def from_files(
95 | cls,
96 | source: Union[Path, str],
97 | target: Union[Path, str],
98 | tgt_lang: Union[Path, str],
99 | ) -> SpeechToTextDataloader:
100 | source_list = load_list_from_file(source)
101 | target_list = load_list_from_file(target)
102 | tgt_lang_list = []
103 | if tgt_lang is not None:
104 | tgt_lang_list = load_list_from_file(tgt_lang)
105 | dataloader = cls(source_list, target_list, tgt_lang_list)
106 | return dataloader
107 |
108 | @classmethod
109 | def from_args(cls, args: Namespace):
110 | args.source_type = "speech"
111 | args.target_type = "text"
112 | if args.demo:
113 | return cls([], [], [])
114 | return cls.from_files(args.source, args.target, args.tgt_lang)
115 |
116 |
117 | @register_dataloader("speech-to-speech")
118 | class SpeechToSpeechDataloader(SpeechToTextDataloader):
119 | @classmethod
120 | def from_files(
121 | cls,
122 | source: Union[Path, str],
123 | target: Union[Path, str],
124 | tgt_lang: Union[Path, str, None] = None,
125 | ) -> SpeechToSpeechDataloader:
126 | source_list = load_list_from_file(source)
127 | target_list = load_list_from_file(target)
128 | tgt_lang_list = []
129 | if tgt_lang is not None:
130 | tgt_lang_list = load_list_from_file(tgt_lang)
131 | dataloader = cls(source_list, target_list, tgt_lang_list)
132 | return dataloader
133 |
134 | @classmethod
135 | def from_args(cls, args: Namespace):
136 | args.source_type = "speech"
137 | args.target_type = "speech"
138 | return cls.from_files(args.source, args.target, args.tgt_lang)
139 |
140 |
141 | @register_dataloader("youtube-to-text")
142 | class YoutubeToTextDataloader(SpeechToTextDataloader):
143 | @classmethod
144 | def from_youtube(
145 | cls, source: Union[Path, str], target: Union[Path, str]
146 | ) -> YoutubeToTextDataloader:
147 | assert AudioSegment is not None
148 | assert youtube_dl is not None
149 | source_list = [download_youtube_video(source)]
150 | target_list = [target]
151 | dataloader = cls(source_list, target_list)
152 | return dataloader
153 |
154 | @classmethod
155 | def from_args(cls, args: Namespace):
156 | args.source_type = "youtube"
157 | args.target_type = "text"
158 | return cls.from_youtube(args.source, args.target)
159 |
160 |
161 | @register_dataloader("youtube-to-speech")
162 | class YoutubeToSpeechDataloader(YoutubeToTextDataloader):
163 | @classmethod
164 | def from_args(cls, args: Namespace):
165 | args.source_type = "youtube"
166 | args.target_type = "speech"
167 | return cls.from_youtube(args.source, args.target)
168 |
--------------------------------------------------------------------------------
/simuleval/data/dataloader/t2t_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from __future__ import annotations
8 | from pathlib import Path
9 | from typing import Callable, List, Union, Optional
10 | from .dataloader import GenericDataloader
11 | from simuleval.data.dataloader import register_dataloader
12 | from argparse import Namespace
13 |
14 |
15 | @register_dataloader("text-to-text")
16 | class TextToTextDataloader(GenericDataloader):
17 | def __init__(
18 | self, source_list: List[str], target_list: Union[List[str], List[None]]
19 | ) -> None:
20 | super().__init__(source_list, target_list)
21 | self.source_splitter = lambda x: x.split()
22 | self.target_splitter = lambda x: x
23 |
24 | def set_source_splitter(self, function: Callable) -> None:
25 | # TODO, make is configurable
26 | self.splitter = function
27 |
28 | def preprocess_source(self, source: str) -> List:
29 | return self.source_splitter(source)
30 |
31 | def preprocess_target(self, target: str) -> List:
32 | return self.target_splitter(target)
33 |
34 | @classmethod
35 | def from_files(
36 | cls, source: Union[Path, str], target: Optional[Union[Path, str]]
37 | ) -> TextToTextDataloader:
38 | assert source
39 | with open(source) as f:
40 | source_list = f.readlines()
41 | if target:
42 | with open(target) as f:
43 | target_list = f.readlines()
44 | else:
45 | target_list = [None for _ in source_list]
46 | dataloader = cls(source_list, target_list)
47 | return dataloader
48 |
49 | @classmethod
50 | def from_args(cls, args: Namespace):
51 | args.source_type = "text"
52 | args.target_type = "text"
53 | return cls.from_files(args.source, args.target)
54 |
--------------------------------------------------------------------------------
/simuleval/data/segments.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import json
8 | from dataclasses import dataclass, field
9 | from typing import Union, Optional
10 |
11 |
@dataclass
class Segment:
    """
    Base unit of data exchanged between the evaluator and agents.

    Attributes:
        index: running index of the segment within a stream.
        content: payload (e.g. text or audio samples).
        finished: whether this is the last segment of the stream.
        is_empty: whether the segment carries no content.
        data_type: medium tag used for JSON round-tripping ("text", ...).
        tgt_lang: optional target-language hint.
        config: free-form configuration propagated with the segment.
    """

    index: int = 0
    content: list = field(default_factory=list)
    finished: bool = False
    is_empty: bool = False
    # Annotation fix: these fields default to None, so they are Optional.
    data_type: Optional[str] = None
    tgt_lang: Optional[str] = None
    config: dict = field(default_factory=dict)

    def json(self) -> str:
        """Serialize all fields to a JSON string."""
        info_dict = {attribute: value for attribute, value in self.__dict__.items()}
        return json.dumps(info_dict)

    @classmethod
    def from_json(cls, json_string: str):
        """Inverse of :meth:`json`: rebuild a segment from its JSON string."""
        return cls(**json.loads(json_string))
29 |
30 |
@dataclass
class EmptySegment(Segment):
    # Segment carrying no content, e.g. emitted when an agent decides to read.
    is_empty: bool = True
34 |
35 |
@dataclass
class TextSegment(Segment):
    # content is a single text chunk rather than a list.
    content: str = ""
    data_type: str = "text"
    tgt_lang: Optional[str] = None
41 |
42 |
@dataclass
class SpeechSegment(Segment):
    # Sample rate of the audio samples held in `content`; -1 means unset.
    sample_rate: int = -1
    data_type: str = "speech"
    tgt_lang: Optional[str] = None
48 |
49 |
@dataclass
class SpeechTextSegment:
    # Pairs a text segment with a speech segment for joint output.
    # NOTE(review): unlike the other segment types, this class does not
    # inherit from Segment, so it has no `json`/`from_json` methods even
    # though segment_from_json_string dispatches "speech_text" to
    # SpeechTextSegment.from_json — that call would raise AttributeError.
    # Confirm intended behavior before relying on JSON round-trips here.
    text_segment: Union[EmptySegment, TextSegment]
    speech_segment: Union[EmptySegment, SpeechSegment]
    data_type: str = "speech_text"
55 |
56 |
def segment_from_json_string(string: str):
    """Rebuild the appropriate Segment subclass from its JSON encoding,
    dispatching on the embedded "data_type" field (unknown types fall back
    to EmptySegment)."""
    data_type = json.loads(string)["data_type"]
    dispatch = {
        "text": TextSegment,
        "speech": SpeechSegment,
        "speech_text": SpeechTextSegment,
    }
    segment_cls = dispatch.get(data_type, EmptySegment)
    return segment_cls.from_json(string)
67 |
--------------------------------------------------------------------------------
/simuleval/evaluator/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .evaluator import SentenceLevelEvaluator
8 | from .remote import RemoteEvaluator
9 | from .remote import DemoRemote
10 |
11 |
def build_evaluator(args):
    # Factory kept for API symmetry; currently always sentence-level.
    return SentenceLevelEvaluator.from_args(args)
14 |
15 |
def build_remote_evaluator(args):
    # Demo mode wraps the evaluator in DemoRemote; otherwise a plain
    # RemoteEvaluator client is used.
    if args.demo:
        return DemoRemote(build_evaluator(args))
    return RemoteEvaluator(build_evaluator(args))
20 |
--------------------------------------------------------------------------------
/simuleval/evaluator/evaluator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import contextlib
8 | import json
9 | import logging
10 | import numbers
11 | import os
12 | from argparse import Namespace
13 | from pathlib import Path
14 | from typing import Dict, Generator, Optional
15 |
16 | import pandas
17 | import yaml
18 | from simuleval.data.dataloader import GenericDataloader, build_dataloader
19 | from simuleval.data.dataloader.dataloader import IterableDataloader
20 | from tqdm import tqdm
21 |
22 | from .instance import INSTANCE_TYPE_DICT, LogInstance
23 | from .scorers import get_scorer_class
24 | from .scorers.latency_scorer import LatencyScorer
25 | from .scorers.quality_scorer import QualityScorer
26 | from ..utils.visualize import Visualize
27 |
28 | try:
29 | import sentencepiece
30 |
31 | IS_IMPORT_SPM = True
32 | except Exception:
33 | IS_IMPORT_SPM = False
34 |
35 |
36 | logger = logging.getLogger("simuleval.sentence_level_evaluator")
37 |
38 |
class SentenceLevelEvaluator(object):
    """
    Sentence Level evaluator. It iterates over sentence pairs and runs evaluation.


    .. code-block:: python

        for instance in self.maybe_tqdm(self.instances.values()):
            agent.reset()
            while not instance.finish_prediction:
                input_segment = instance.send_source(self.source_segment_size)
                output_segment = agent.pushpop(input_segment)
                instance.receive_prediction(output_segment)


    Attributes:
        instances: collections of sentence pairs. Instances also keep track of delays.
        latency_scorers (List[~simuleval.scorers.latency_scorer.LatencyScorer]): Scorers for latency evaluation.
        quality_scorers (List[~simuleval.scorers.latency_scorer.QualityScorer]): Scorers for quality evaluation.
        output: output directory

    Evaluator related command line arguments:

    .. argparse::
        :ref: simuleval.options.add_evaluator_args
        :passparser:
        :prog:
    """

    def __init__(
        self,
        dataloader: Optional[GenericDataloader],
        quality_scorers: Dict[str, QualityScorer],
        latency_scorers: Dict[str, LatencyScorer],
        args: Namespace,
    ) -> None:
        """Set up instances, output directory, resume logic and the iterator.

        ``dataloader`` may be None when ``--score-only`` is set (scoring is
        then done from a previously written ``instances.log``).
        """
        self.dataloader = dataloader
        self.quality_scorers = quality_scorers
        self.latency_scorers = latency_scorers
        self.instances = {}

        self.args = args
        self.output = Path(args.output) if args.output else None
        self.score_only = args.score_only
        self.no_scoring = args.no_scoring
        self.source_segment_size = getattr(args, "source_segment_size", 1)
        self.source_type = getattr(args, "source_type", None)
        self.target_type = getattr(args, "target_type", None)
        self.visualize = args.visualize

        # Optional sentencepiece model used when latency is measured in
        # spm units instead of words/characters.
        self.target_spm_model = None
        if args.eval_latency_unit == "spm":
            assert args.eval_latency_spm_model
            assert IS_IMPORT_SPM
            self.target_spm_model = sentencepiece.SentencePieceProcessor(
                model_file=args.eval_latency_spm_model
            )

        # When source/target types were not given on the command line,
        # recover them from the config.yaml of a previous run in `output`.
        if (
            self.source_type is None
            and self.target_type is None
            and self.output is not None
        ):
            with open(self.output / "config.yaml") as f:
                configs = yaml.safe_load(f)
                self.source_type = configs["source_type"]
                self.target_type = configs["target_type"]

        assert self.source_type
        assert self.target_type

        # Persist the resolved types so a later --score-only run can find them.
        if self.output is not None:
            os.makedirs(self.output, exist_ok=True)
            with open(self.output / "config.yaml", "w") as f:
                yaml.dump(
                    {"source_type": self.source_type, "target_type": self.target_type},
                    f,
                    default_flow_style=False,
                )

        # Instance class is keyed by the "source-target" modality pair.
        self.instance_class = INSTANCE_TYPE_DICT[
            f"{self.source_type}-{self.target_type}"
        ]
        self.start_index = getattr(args, "start_index", 0)
        self.end_index = getattr(args, "end_index", -1)

        if not self.score_only:
            if self.output:
                if (
                    self.args.continue_unfinished
                    and (self.output / "instances.log").exists()
                ):
                    # Resume: scan to the last logged line and continue
                    # from the next index.
                    with open(self.output / "instances.log", "r") as f:
                        line = None
                        for line in f:  # noqa
                            pass
                        if line is not None:
                            last_info = json.loads(line.strip())
                            self.start_index = last_info["index"] + 1
                else:
                    # Fresh run: truncate/create the instance log.
                    self.output.mkdir(exist_ok=True, parents=True)
                    open(self.output / "instances.log", "w").close()
            if self.end_index < 0:
                # Default: evaluate through the end of the dataloader.
                assert self.dataloader is not None
                self.end_index = len(self.dataloader)

        self.build_instances()

        # Iterable dataloaders are consumed directly; otherwise iterate the
        # pre-built instances.
        iterable = self.instances.values()
        if isinstance(self.dataloader, IterableDataloader):
            iterable = self.dataloader

        if not self.args.no_progress_bar and not self.score_only:
            self.iterator = tqdm(
                iterable,
                initial=self.start_index,
            )
        else:
            self.iterator = iterable

    def write_log(self, instance):
        """Append one instance summary (JSON line) to instances.log."""
        if self.output is not None:
            with open(self.output / "instances.log", "a") as f:
                f.write(json.dumps(instance.summarize()) + "\n")

    def build_instances(self):
        """Build instances from the log (score-only) or the dataloader."""
        if self.score_only:
            self.build_instances_from_log()
        else:
            self.build_instances_from_dataloader()

    def build_instances_from_log(self):
        """Rebuild instances by replaying instances.log from a previous run."""
        self.instances = {}
        if self.output is not None:
            with open(self.output / "instances.log", "r") as f:
                for line in f:
                    instance = LogInstance(line.strip(), self.args.eval_latency_unit)
                    # Keys are re-based so the first replayed instance is 0.
                    index = instance.index - self.start_index
                    self.instances[index] = instance
                    self.instances[index].set_target_spm_model(self.target_spm_model)

    def build_instances_from_dataloader(self):
        """Create one instance per dataset index in [start_index, end_index)."""
        # Iterable dataloaders build their instances lazily in __call__.
        if isinstance(self.dataloader, IterableDataloader):
            return

        for i in self.get_indices():
            self.instances[i] = self.instance_class(i, self.dataloader, self.args)
            self.instances[i].set_target_spm_model(self.target_spm_model)

    def __len__(self) -> int:
        # NOTE(review): can be negative if end_index was never resolved
        # (score-only with end_index=-1) — callers appear to rely on
        # get_indices() handling that case.
        return self.end_index - self.start_index

    def get_indices(self) -> Generator:
        """Yield the dataset indices to evaluate, resolving end_index lazily."""
        if self.end_index < 0:
            self.end_index = max(self.instances.keys()) + 1

        if self.start_index > self.end_index:
            return []

        for index in range(self.start_index, self.end_index):
            yield index

    @property
    def quality(self) -> Dict[str, float]:
        """Quality scores keyed by metric name, computed over all instances."""
        return {
            name: scorer(self.instances)
            for name, scorer in self.quality_scorers.items()
        }

    @property
    def latency(self) -> Dict[str, float]:
        """Latency scores keyed by metric name, computed over all instances."""
        return {
            name: scorer(self.instances)
            for name, scorer in self.latency_scorers.items()
        }

    @property
    def results(self):
        """All scores as a one-row DataFrame (numbers rounded to 3 decimals)."""
        scores = {**self.quality, **self.latency}
        new_scores = {}
        for name, value in scores.items():
            if isinstance(value, numbers.Number):
                value = round(value, 3)
            new_scores[name] = [value]

        df = pandas.DataFrame(new_scores)
        # NOTE(review): visualization as a property side effect — __call__
        # also invokes make_visual() after scoring, so graphs may be
        # generated twice; confirm this is intended.
        if self.output and self.visualize:
            self.make_visual()
        return df

    def dump_results(self) -> None:
        """Write scores.tsv (when output is set) and print the score table."""
        results = self.results
        if self.output:
            results.to_csv(self.output / "scores.tsv", sep="\t", index=False)

        logger.info("Results:")
        print(results.to_string(index=False))

    def dump_metrics(self) -> None:
        """Write per-instance metrics to metrics.tsv (when output is set)."""
        metrics = pandas.DataFrame([ins.metrics for ins in self.instances.values()])
        metrics = metrics.round(3)
        if self.output:
            metrics.to_csv(self.output / "metrics.tsv", sep="\t", index=False)

    def is_finished(self, instance) -> bool:
        """True when the instance has consumed/produced everything.

        Prefers ``source_finished_reading`` when the instance exposes it,
        since finish_prediction can be set by intermediate components before
        the source is exhausted (see the comment in __call__).
        """
        if hasattr(instance, "source_finished_reading"):
            return instance.source_finished_reading
        return instance.finish_prediction

    def make_visual(self):
        """Render one graph per logged instance from instances.log."""
        with open(self.output / "instances.log", "r") as file:
            for line in file:
                # Load data & index
                data = json.loads(line)
                index = data.get("index", 0)

                # Create object & graph
                visualize = Visualize(data, index, self.output)
                visualize.make_graph()

    def __call__(self, system):
        """Run the system over all instances, log summaries, then score."""
        # Append to instances.log while evaluating; nullcontext when there
        # is no output directory.
        with (
            open(self.output / "instances.log", "a")
            if self.output
            else contextlib.nullcontext()
        ) as file:
            system.reset()
            for sample in self.iterator:
                # For iterable dataloaders the instance is built on the fly
                # at the loader's current index; otherwise `sample` already
                # is the instance.
                instance = (
                    self.instance_class(
                        self.dataloader.cur_index, self.dataloader, self.args
                    )
                    if isinstance(self.dataloader, IterableDataloader)
                    else sample
                )
                while not self.is_finished(instance):
                    input_segment = instance.send_source(self.source_segment_size)
                    output_segment = system.pushpop(input_segment)
                    instance.receive_prediction(output_segment)
                    if instance.finish_prediction:
                        # if instance.finish_prediction were set by the reader,
                        # source_finished_reading will be set as well. If it is
                        # set by any of the intermediate components, then we didn't
                        # end yet. We are going to clear the state and continue
                        # processing the rest of the input.
                        system.reset()

                if not self.score_only and self.output:
                    file.write(json.dumps(instance.summarize()) + "\n")

        # Re-read the log so scoring sees exactly what was written.
        if self.output:
            self.build_instances_from_log()
        if not self.no_scoring:
            self.dump_results()
            self.dump_metrics()
        if self.output and self.visualize:
            self.make_visual()

    @classmethod
    def from_args(cls, args):
        """Factory: build dataloader (unless score-only) and all scorers."""
        if not args.score_only:
            dataloader = build_dataloader(args)
        else:
            dataloader = None

        latency_scorers = {}
        use_ref_len = not args.no_use_ref_len
        for name in args.latency_metrics:
            latency_scorers[name] = get_scorer_class("latency", name).from_args(args)
            if args.computation_aware:
                # Additional computation-aware variant of each latency metric.
                latency_scorers[name + "_CA"] = get_scorer_class("latency", name)(
                    computation_aware=True, use_ref_len=use_ref_len
                )

        quality_scorers = {}
        for name in args.quality_metrics:
            quality_scorers[name] = get_scorer_class("quality", name).from_args(args)

        return cls(dataloader, quality_scorers, latency_scorers, args)
318 |
--------------------------------------------------------------------------------
/simuleval/evaluator/remote.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import sys
8 | import logging
9 | import threading
10 | import time
11 | from queue import Queue
12 | import numpy as np
13 |
14 | try:
15 | import wave
16 | import pyaudio
17 | from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
18 | except:
19 | wave, pyaudio, load_silero_vad, read_audio, get_speech_timestamps = [
20 | None for _ in range(5)
21 | ]
22 |
23 | from simuleval.data.segments import (
24 | Segment,
25 | segment_from_json_string,
26 | SpeechSegment,
27 | EmptySegment,
28 | )
29 | from simuleval.evaluator import SentenceLevelEvaluator
30 | import requests
31 |
32 | logger = logging.getLogger("simuleval.remote_evaluator")
33 |
34 |
class RemoteEvaluator:
    """Evaluates a remote agent over HTTP, driven by a local evaluator.

    The wrapped :class:`SentenceLevelEvaluator` supplies the instances and
    the scoring; this class shuttles segments to/from the agent server at
    ``http://{remote_address}:{remote_port}``.
    """

    def __init__(self, evaluator: SentenceLevelEvaluator) -> None:
        self.evaluator = evaluator
        self.address = evaluator.args.remote_address
        self.port = evaluator.args.remote_port
        self.source_segment_size = evaluator.args.source_segment_size
        self.base_url = f"http://{self.address}:{self.port}"

    def send_source(self, segment: Segment):
        """PUT one source segment (JSON-serialized) to the agent's /input."""
        url = f"{self.base_url}/input"
        requests.put(url, data=segment.json())

    def receive_prediction(self) -> Segment:
        """GET the agent's next output segment from /output and deserialize."""
        url = f"{self.base_url}/output"
        r = requests.get(url)
        return segment_from_json_string(r.text)

    def system_reset(self):
        """Ask the agent server to reset its internal states via /reset."""
        requests.post(f"{self.base_url}/reset")

    def results(self):
        """Return the evaluator's score table.

        Bug fix: ``results`` is a property on SentenceLevelEvaluator, so the
        previous ``self.evaluator.results()`` attempted to *call* the
        returned DataFrame and raised TypeError.
        """
        return self.evaluator.results

    def remote_eval(self):
        """Evaluate every instance against the remote agent, then score.

        For each instance: reset the agent, stream source segments of
        ``source_segment_size`` until the prediction is finished, feeding
        every returned output segment back into the instance, and finally
        log the instance summary.
        """
        for instance in self.evaluator.iterator:
            self.system_reset()
            while not instance.finish_prediction:
                self.send_source(instance.send_source(self.source_segment_size))
                output_segment = self.receive_prediction()
                instance.receive_prediction(output_segment)
            self.evaluator.write_log(instance)

        self.evaluator.dump_results()
75 |
76 |
class DemoRemote(RemoteEvaluator):
    """Interactive microphone demo: records live audio, gates it with
    silero VAD, and streams it to the remote agent, printing predictions.

    Requires the optional ``wave``, ``pyaudio`` and ``silero_vad``
    dependencies (set to None at import time when unavailable).
    """

    def __init__(self, evaluator: SentenceLevelEvaluator) -> None:
        if None in [wave, pyaudio, load_silero_vad, read_audio, get_speech_timestamps]:
            raise Exception(
                "Please install wave, pyaudio, and silero_vad to run the demo"
            )
        super().__init__(evaluator)
        self.float_array = np.asarray([])
        self.sample_rate = 16000  # Hz, fixed capture rate for the demo
        self.finished = False  # set by the recording thread when it stops
        # Unbounded queue of raw PCM chunks handed from the recording thread
        # to the evaluation loop.
        self.queue = Queue(maxsize=0)
        self.VADmodel = load_silero_vad()
        # Consecutive source-segment-sized chunks with no detected speech.
        self.silence_count = 0

    def record_audio(self):
        """Capture microphone audio, mirror it to output.wav, and enqueue
        one bytearray per ``source_segment_size`` milliseconds.

        Runs on a background thread; sets ``self.finished`` when done.
        """
        CHUNK = 1024
        FORMAT = pyaudio.paInt16
        CHANNELS = 1 if sys.platform == "darwin" else 2
        RATE = self.sample_rate
        RECORD_SECONDS = 10000  # Indefinite time

        with wave.open(f"output.wav", "wb") as wf:
            p = pyaudio.PyAudio()
            wf.setnchannels(CHANNELS)
            wf.setsampwidth(p.get_sample_size(FORMAT))
            wf.setframerate(RATE)

            stream = p.open(format=FORMAT, channels=CHANNELS, rate=RATE, input=True)

            all_data = bytearray()
            start = time.time()
            for _ in range(0, round(RATE // CHUNK * RECORD_SECONDS)):
                data = stream.read(CHUNK)
                wf.writeframes(data)
                all_data += data
                # Flush a segment to the queue roughly every
                # source_segment_size milliseconds of wall-clock time.
                if time.time() - start > (self.source_segment_size / 1000.0):
                    self.queue.put(all_data)
                    all_data = bytearray()
                    start = time.time()

            self.queue.put(all_data)
            stream.close()
            p.terminate()
            self.finished = True

    def remote_eval(self):
        """Consume queued audio, send speech through VAD gating to the
        remote agent, and print its incremental text predictions.

        After 5 consecutive silent segments (silence_count > 4) a final,
        zero-filled segment with finished=True is sent and the agent is
        reset for the next utterance.
        """
        # Initialization
        self.system_reset()
        recording = threading.Thread(target=self.record_audio)
        recording.start()

        # Start recording
        print("Recording...")
        while not self.finished or not self.queue.empty():
            data = byte_to_float(self.queue.get()).tolist()
            # VAD
            speech_timestamps = get_speech_timestamps(
                audio=data, model=self.VADmodel, sampling_rate=self.sample_rate
            )

            if len(speech_timestamps) != 0:  # has audio
                self.silence_count = 0
            else:
                self.silence_count += 1

            if self.silence_count <= 4:
                # NOTE(review): `index` is set to the segment size here, not
                # a running segment counter — confirm the server ignores it.
                segment = SpeechSegment(
                    index=self.source_segment_size,
                    content=data,
                    sample_rate=self.sample_rate,
                    finished=False,
                )
                self.send_source(segment)
                output_segment = self.receive_prediction()
                if len(output_segment.content) == 0:
                    continue
                prediction_list = str(output_segment.content.replace(" ", ""))
                print(prediction_list, end=" ")
                sys.stdout.flush()

            else:
                # Silence limit reached: flush the utterance with a short
                # all-zero "finished" segment, then reset for the next one.
                segment = SpeechSegment(
                    index=self.source_segment_size,
                    content=[0.0 for _ in range(8192)],
                    sample_rate=self.sample_rate,
                    finished=True,
                )
                self.send_source(segment)
                output_segment = self.receive_prediction()
                self.silence_count = 0
                self.system_reset()
168 |
169 |
def pcm2float(sig, dtype="float32"):
    """Convert an integer PCM array to floats in approximately [-1, 1).

    Signed 16-bit input maps -32768 -> -1.0 and 32767 -> 32767/32768 (the
    divisor is 2**(bits-1) = 32768, so +1.0 is never quite reached);
    unsigned input is first re-centered around zero via the offset.

    Raises:
        TypeError: if ``sig`` is not an integer array or ``dtype`` is not
            a floating point type.
    """
    samples = np.asarray(sig)
    if samples.dtype.kind not in "iu":
        raise TypeError("'sig' must be an array of integers")
    target_dtype = np.dtype(dtype)
    if target_dtype.kind != "f":
        raise TypeError("'dtype' must be a floating point type")

    type_info = np.iinfo(samples.dtype)
    abs_max = 2 ** (type_info.bits - 1)
    # Zero for signed types; re-centers unsigned types around zero.
    offset = type_info.min + abs_max
    return (samples.astype(target_dtype) - offset) / abs_max
183 |
184 |
def byte_to_float(byte):
    """Decode raw little-endian 16-bit PCM bytes into a float32 array.

    Pipeline: bytes -> int16 (PCM_16) -> float32 in roughly [-1, 1).
    """
    pcm_samples = np.frombuffer(byte, dtype=np.int16)
    return pcm2float(pcm_samples, dtype="float32")
188 |
--------------------------------------------------------------------------------
/simuleval/evaluator/scorers/__init__.py:
--------------------------------------------------------------------------------
1 | from .latency_scorer import LATENCY_SCORERS_DICT
2 | from .quality_scorer import QUALITY_SCORERS_DICT
3 |
4 |
def get_scorer_class(scorer_type, name):
    """Look up a registered scorer class.

    ``scorer_type == "quality"`` selects the quality registry; any other
    value falls back to the latency registry.

    Raises:
        RuntimeError: if no scorer named ``name`` is registered.
    """
    registry = QUALITY_SCORERS_DICT if scorer_type == "quality" else LATENCY_SCORERS_DICT

    if name not in registry:
        raise RuntimeError(f"No {scorer_type} metric called {name}")

    return registry[name]
15 |
--------------------------------------------------------------------------------
/simuleval/evaluator/scorers/quality_scorer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import string
9 | import subprocess
10 | from pathlib import Path
11 | from typing import Dict
12 |
13 | import sacrebleu
14 | import tqdm
15 | from sacrebleu.metrics.bleu import BLEU
16 |
17 | QUALITY_SCORERS_DICT = {}
18 |
19 |
def register_quality_scorer(name):
    """Class decorator registering a quality scorer under ``name``.

    The decorated class is stored in ``QUALITY_SCORERS_DICT`` (for lookup
    via ``get_scorer_class``) and returned unchanged.
    """

    def _decorator(scorer_cls):
        QUALITY_SCORERS_DICT[name] = scorer_cls
        return scorer_cls

    return _decorator
26 |
27 |
class QualityScorer:
    """Abstract base class for quality scorers.

    Subclasses implement ``__call__`` to compute a corpus-level score over
    a dict of instances, and may override ``add_args`` to register extra
    command line options.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, instances: Dict) -> float:
        """Compute the quality score over all instances. Must be overridden."""
        raise NotImplementedError

    @staticmethod
    def add_args(parser):
        """Register scorer-specific CLI arguments (no-op by default)."""
        pass
38 |
39 |
def add_sacrebleu_args(parser):
    """Register the ``--sacrebleu-tokenizer`` option shared by all
    sacrebleu-based scorers, defaulting to sacrebleu's own default."""
    parser.add_argument(
        "--sacrebleu-tokenizer",
        type=str,
        default=sacrebleu.metrics.METRICS["BLEU"].TOKENIZER_DEFAULT,
        choices=sacrebleu.metrics.METRICS["BLEU"].TOKENIZERS,
        help="Tokenizer in sacrebleu",
    )
48 |
49 |
@register_quality_scorer("WER")
class WERScorer(QualityScorer):
    """
    Compute Word Error Rate (WER)

    Usage:
        :code:`--quality-metrics WER`
    """

    def __init__(self, args) -> None:
        super().__init__()
        try:
            import editdistance as ed
        except ImportError:
            raise ImportError("Please install editdistance to use WER scorer")
        self.logger = logging.getLogger("simuleval.scorer.wer")
        self.logger.warning("WER scorer only support language with spaces.")
        self.logger.warning(
            "Current WER scorer is on raw text (un-tokenized with punctuations)."
        )
        self.ed = ed

    def __call__(self, instances: Dict) -> float:
        """Corpus WER in percent: total edit distance over total reference words."""
        all_instances = list(instances.values())
        total_distance = sum(
            self.ed.eval(ins.prediction.split(), ins.reference.split())
            for ins in all_instances
        )
        total_ref_words = sum(len(ins.reference.split()) for ins in all_instances)

        if total_ref_words == 0:
            self.logger.warning("Reference length is 0. Return WER as 0.")
            return 0

        return 100.0 * total_distance / total_ref_words

    @classmethod
    def from_args(cls, args):
        return cls(args)
87 |
88 |
@register_quality_scorer("BLEU")
class SacreBLEUScorer(QualityScorer):
    """
    SacreBLEU Scorer

    Usage:
        :code:`--quality-metrics BLEU`

    Additional command line arguments:

    .. argparse::
        :ref: simuleval.evaluator.scorers.quality_scorer.add_sacrebleu_args
        :passparser:
        :prog:
    """

    def __init__(self, tokenizer: str = "13a") -> None:
        super().__init__()
        self.logger = logging.getLogger("simuleval.scorer.bleu")
        self.tokenizer = tokenizer

    def __call__(self, instances: Dict) -> float:
        """Corpus BLEU over all predictions; logs and returns 0 on failure."""
        try:
            hypotheses = [ins.prediction for ins in instances.values()]
            references = [[ins.reference for ins in instances.values()]]
            metric = BLEU(tokenize=self.tokenizer)
            return metric.corpus_score(hypotheses, references).score
        except Exception as e:
            self.logger.error(str(e))
            return 0

    @staticmethod
    def add_args(parser):
        add_sacrebleu_args(parser)

    @classmethod
    def from_args(cls, args):
        return cls(args.sacrebleu_tokenizer)
131 |
132 |
@register_quality_scorer("ASR_BLEU")
class ASRSacreBLEUScorer(QualityScorer):
    """
    ASR + SacreBLEU Scorer (BETA version)

    Transcribes predicted speech with fairseq's compute_asr_bleu script,
    then scores the transcripts against the references with SacreBLEU.

    Usage:
        :code:`--quality-metrics ASR_BLEU`

    Additional command line arguments:

    .. argparse::
        :ref: simuleval.evaluator.scorers.quality_scorer.add_sacrebleu_args
        :passparser:
        :prog:
    """

    def __init__(self, tokenizer: str = "13a", target_lang: str = "en") -> None:
        super().__init__()
        self.logger = logging.getLogger("simuleval.scorer.asr_bleu")
        self.tokenizer = tokenizer  # sacrebleu tokenizer name
        self.target_lang = target_lang  # language passed to the ASR pipeline

    def __call__(self, instances: Dict) -> float:
        """Transcribe all predicted wavs and return the corpus BLEU score."""
        transcripts = self.asr_transcribe(instances)
        score = (
            BLEU(tokenize=self.tokenizer)
            .corpus_score(
                transcripts,
                [[ins.reference for ins in instances.values()]],
            )
            .score
        )
        return score

    def asr_transcribe(self, instances):
        """Run fairseq's ASR script over the predicted wav directory.

        Returns one transcript per instance, or empty strings for every
        instance when fairseq is missing or the subprocess fails.
        """
        # Fixed: logger.warn is deprecated (use warning) and the message
        # misspelled "Fairseq".
        self.logger.warning(
            "Beta feature: Evaluating speech output. Fairseq is required."
        )
        try:
            import fairseq

            fairseq_path = Path(fairseq.__path__[0]).parent  # type: ignore
        except Exception:
            self.logger.warning("Please install fairseq.")
            return ["" for _ in instances.keys()]

        # NOTE(review): assumes instances has integer key 0 and that
        # predictions are paths to wav files — confirm against instance.py.
        wav_dir = Path(instances[0].prediction).absolute().parent
        root_dir = wav_dir.parent
        transcripts_path = root_dir / "asr_transcripts.txt"
        asr_cmd_bash_path = root_dir / "asr_cmd.bash"

        # This is a dummy reference. The bleu score will be computed separately.
        reference_path = root_dir / "instances.log"

        fairseq_asr_bleu_cmd = "\n".join(
            [
                f"cd {fairseq_path.as_posix()}/examples/speech_to_speech/asr_bleu/",
                " ".join(
                    [
                        "python compute_asr_bleu.py",
                        f"--reference_path {reference_path.as_posix()}",
                        f"--lang {self.target_lang}",
                        f"--audio_dirpath {wav_dir.as_posix()}",
                        "--reference_format txt",
                        f"--transcripts_path {transcripts_path.as_posix()}",
                    ]
                ),
            ]
        )
        with open(asr_cmd_bash_path, "w") as f:
            f.write(fairseq_asr_bleu_cmd + "\n")

        # Fixed: stderr was not piped before, so the error log below always
        # printed "None" on failure.
        process = subprocess.Popen(
            ["bash", asr_cmd_bash_path],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        _, stderr = process.communicate()

        if process.returncode != 0:
            self.logger.error("ASR on target speech failed:")
            self.logger.error(stderr.decode(errors="replace") + "\n")
            return ["" for _ in instances.keys()]

        with open(transcripts_path, "r") as f:
            transcripts = [line.strip() for line in f]

        # Write a lowercase transcript file next to each predicted wav.
        for idx, item in enumerate(transcripts):
            with open(wav_dir / f"{idx}_pred.txt", "w") as f:
                f.write(item.lower() + "\n")

        return transcripts

    @staticmethod
    def add_args(parser):
        add_sacrebleu_args(parser)
        parser.add_argument(
            "--target-speech-lang",
            type=str,
            default="en",
            help="The language of target speech",
        )

    @classmethod
    def from_args(cls, args):
        return cls(args.sacrebleu_tokenizer, args.target_speech_lang)
233 |
234 |
PUNCTUATIONS_EXCLUDE_APOSTROPHE = (
    string.punctuation.replace("'", "") + "¡¨«°³º»¿‘“”…♪♫ˆᵉ™,ʾ˚"
)
PUNCTUATIONS_TO_SPACE = "-/–·—•"


def remove_punctuations(text, punctuations=string.punctuation):
    """Strip punctuation from ``text``.

    Characters in ``PUNCTUATIONS_TO_SPACE`` (dashes, slashes, bullets) are
    first replaced by spaces; every character in ``punctuations`` is then
    deleted outright.
    """
    space_table = str.maketrans({ch: " " for ch in PUNCTUATIONS_TO_SPACE})
    delete_table = str.maketrans("", "", punctuations)
    return text.translate(space_table).translate(delete_table)
246 |
247 |
@register_quality_scorer("WHISPER_ASR_BLEU")
class WhisperASRSacreBLEUScorer(QualityScorer):
    """
    Whisper ASR + SacreBLEU Scorer with whisper model

    Transcribes each predicted wav locally with OpenAI Whisper, then scores
    the transcripts against the references with SacreBLEU.

    Usage:
        :code:`--quality-metrics WHISPER_ASR_BLEU`

    Additional command line arguments:

    .. argparse::
        :ref: simuleval.evaluator.scorers.quality_scorer.add_sacrebleu_args
        :passparser:
        :prog:
    """

    def __init__(
        self,
        tokenizer: str = "13a",
        target_lang: str = "en",
        model_size: str = "base",
        temperature: float = 0.0,
        lowercase: bool = False,
        remove_punctuations: bool = False,
    ) -> None:
        super().__init__()
        self.logger = logging.getLogger("simuleval.scorer.whisper_asr_bleu")
        self.tokenizer = tokenizer  # sacrebleu tokenizer name
        self.target_lang = target_lang  # language hint for whisper
        self.model_size = model_size  # whisper checkpoint size
        self.temperature = temperature  # >0 enables sampling in decoding
        self.lowercase = lowercase  # lowercase transcripts before scoring
        self.remove_punctuations = remove_punctuations  # strip punctuation

    def __call__(self, instances: Dict) -> float:
        """Transcribe all predicted wavs and return the corpus BLEU score."""
        transcripts = self.asr_transcribe(instances)
        score = (
            BLEU(tokenize=self.tokenizer)
            .corpus_score(
                transcripts,
                [[ins.reference for ins in instances.values()]],
            )
            .score
        )
        return score

    def asr_transcribe(self, instances):
        """Transcribe ``{index}_pred.wav`` for each instance with whisper.

        Missing wavs yield empty transcripts. All transcripts are also
        written to ``asr_transcripts.txt`` next to the wav directory.
        Returns a list with one transcript per instance.
        """
        self.logger.info(
            "Evaluating speech output by ASR BLEU. whisper and sacrebleu are required."
        )
        self.logger.info("Configs:")
        self.logger.info(f"tokenizer = {self.tokenizer}")
        self.logger.info(f"target_lang = {self.target_lang}")
        self.logger.info(f"model_size = {self.model_size}")
        self.logger.info(f"temperature = {self.temperature}")
        self.logger.info(f"lowercase = {self.lowercase}")
        self.logger.info(f"remove_punctuations = {self.remove_punctuations}")
        try:
            import whisper
        except Exception:
            # Fixed: logger.warn is deprecated; use warning.
            self.logger.warning("Please install whisper.")
            return ["" for _ in instances.keys()]

        model = whisper.load_model(self.model_size)
        # NOTE(review): assumes instances has integer key 0 and that
        # predictions are paths to wav files — confirm against instance.py.
        wav_dir = Path(instances[0].prediction).absolute().parent

        transcripts = []
        for index in tqdm.tqdm(instances.keys()):
            wav_path = wav_dir / f"{index}_pred.wav"
            if wav_path.exists():
                result = model.transcribe(
                    wav_path.as_posix(),
                    language=self.target_lang,
                    temperature=self.temperature,
                )
                text = result["text"]
                # isinstance is the idiomatic type check (was type(...) ==).
                assert isinstance(text, str)
                if self.lowercase:
                    text = text.lower()
                if self.remove_punctuations:
                    text = remove_punctuations(text)
                transcripts.append(text.strip())
            else:
                # No wav produced for this instance: score an empty line.
                transcripts.append("")

        root_dir = wav_dir.parent
        transcripts_path = root_dir / "asr_transcripts.txt"
        with open(transcripts_path, "w") as f:
            for line in transcripts:
                f.write(line + "\n")

        return transcripts

    @staticmethod
    def add_args(parser):
        add_sacrebleu_args(parser)
        parser.add_argument(
            "--target-speech-lang",
            type=str,
            default="en",
            help="The language of target speech",
        )
        parser.add_argument(
            "--whisper-model-size",
            type=str,
            default="large",
            help="The size of whisper asr model",
        )
        parser.add_argument(
            "--whisper-model-temperature",
            type=float,
            default=0.0,
            help="If temperature > 0.0, the decoding will perform sampling",
        )
        parser.add_argument(
            "--transcript-lowercase",
            action="store_true",
            help="Lowercase the whisper output",
        )
        parser.add_argument(
            "--transcript-non-punctuation",
            action="store_true",
            help="Remove punctuations in the whisper output",
        )

    @classmethod
    def from_args(cls, args):
        return cls(
            args.sacrebleu_tokenizer,
            args.target_speech_lang,
            args.whisper_model_size,
            args.whisper_model_temperature,
            args.transcript_lowercase,
            args.transcript_non_punctuation,
        )
383 |
--------------------------------------------------------------------------------
/simuleval/options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import importlib
9 | import logging
10 | import os
11 | import sys
12 | from typing import List, Optional
13 |
14 | from simuleval.data.dataloader import (
15 | DATALOADER_DICT,
16 | GenericDataloader,
17 | register_dataloader_class,
18 | )
19 | from simuleval.evaluator.scorers import get_scorer_class
20 |
21 |
def add_dataloader_args(
    parser: argparse.ArgumentParser, cli_argument_list: Optional[List[str]] = None
):
    """Register the CLI arguments of the selected dataloader class.

    Pre-parses known args (argparse treats ``args=None`` as ``sys.argv``),
    optionally imports and registers a user-specified dataloader class,
    then lets the resolved class add its own arguments.
    """
    args, _ = parser.parse_known_args(cli_argument_list)

    if args.dataloader_class:
        # "pkg.mod.Class" -> import "pkg.mod", fetch attribute "Class".
        module_name, _, class_name = args.dataloader_class.rpartition(".")
        dataloader_module = importlib.import_module(module_name)
        register_dataloader_class(
            args.dataloader, getattr(dataloader_module, class_name)
        )

    dataloader_class = DATALOADER_DICT.get(args.dataloader)
    if dataloader_class is None:
        dataloader_class = GenericDataloader
    dataloader_class.add_args(parser)
43 |
44 |
def add_evaluator_args(parser: argparse.ArgumentParser):
    """Register all evaluator-related command line arguments.

    Fixes: the ``--no-use-ref-len`` help text was a copy-paste of the
    ``--computation-aware`` help; ``--remote-port`` now parses as int.
    """
    parser.add_argument(
        "--quality-metrics",
        nargs="+",
        default=["BLEU"],
        help="Quality metrics",
    )
    parser.add_argument(
        "--latency-metrics",
        nargs="+",
        default=["LAAL", "AL", "AP", "DAL", "ATD"],
        help="Latency metrics",
    )
    parser.add_argument(
        "--continue-unfinished",
        action="store_true",
        default=False,
        help="Continue the experiments in output dir.",
    )
    parser.add_argument(
        "--computation-aware",
        action="store_true",
        default=False,
        help="Include computational latency.",
    )
    parser.add_argument(
        "--no-use-ref-len",
        action="store_true",
        default=False,
        help="Do not use the reference length for latency computation.",
    )
    parser.add_argument(
        "--eval-latency-unit",
        type=str,
        default="word",
        choices=["word", "char", "spm"],
        help="Basic unit used for latency calculation, choose from "
        "words (detokenized) and characters.",
    )
    parser.add_argument(
        "--eval-latency-spm-model",
        type=str,
        default=None,
        help="Pass the spm model path if the eval_latency_unit is spm.",
    )
    parser.add_argument(
        "--remote-address",
        default="localhost",
        help="Address to client backend",
    )
    parser.add_argument(
        "--remote-port",
        type=int,
        default=12321,
        help="Port to client backend",
    )
    parser.add_argument(
        "--no-progress-bar",
        action="store_true",
        default=False,
        help="Do not use progress bar",
    )
    parser.add_argument(
        "--start-index",
        type=int,
        default=0,
        help="Start index for evaluation.",
    )
    parser.add_argument(
        "--end-index",
        type=int,
        default=-1,
        help="The last index for evaluation.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output directory. Required if using iterable dataloader.",
    )
124 |
125 |
def add_scorer_args(
    parser: argparse.ArgumentParser, cli_argument_list: Optional[List[str]] = None
):
    """Register the extra CLI arguments of every selected scorer.

    Pre-parses known args (argparse treats ``args=None`` as ``sys.argv``)
    to discover which latency/quality metrics were requested, then lets
    each scorer class add its own options.
    """
    args, _ = parser.parse_known_args(cli_argument_list)

    for category, metric_names in (
        ("latency", args.latency_metrics),
        ("quality", args.quality_metrics),
    ):
        for metric in metric_names:
            get_scorer_class(category, metric).add_args(parser)
139 |
140 |
def import_user_module(module_path):
    """Import a python module (e.g. custom agents) from an arbitrary path.

    Temporarily prepends the module's parent directory to ``sys.path`` so the
    import machinery can find it.

    Args:
        module_path: filesystem path to the module or package directory
            (without the ``.py`` suffix for a single file).
    """
    module_path = os.path.abspath(module_path)
    module_parent, module_name = os.path.split(module_path)

    sys.path.insert(0, module_parent)
    try:
        importlib.import_module(module_name)
    finally:
        # Restore the search path even when the import raises, so a bad
        # --user-dir does not permanently pollute sys.path.
        sys.path.pop(0)
148 |
149 |
def general_parser(
    config_dict: Optional[dict] = None,
    parser: Optional[argparse.ArgumentParser] = None,
):
    """Build (or extend) the top-level SimulEval argument parser.

    Args:
        config_dict: accepted for interface symmetry with other option
            helpers. NOTE(review): it is never read in this function —
            confirm whether it should feed ``parse_known_args`` below.
        parser: an existing parser to extend; a new one is created when None.

    Returns:
        The parser with all general SimulEval arguments registered.
    """
    if parser is None:
        # add_help=False so agents can register their own -h later
        # (see add_command_helper_arg); conflict_handler="resolve" lets
        # agent-added options override these defaults.
        parser = argparse.ArgumentParser(
            add_help=False,
            description="SimulEval - Simultaneous Evaluation CLI",
            conflict_handler="resolve",
        )

    parser.add_argument(
        "--user-dir",
        default=None,
        help="path to a python module containing custom agents",
    )
    # Parse eagerly so the user module is imported before any later
    # option registration that may depend on it (e.g. @entrypoint agents).
    args, _ = parser.parse_known_args()
    if args.user_dir is not None:
        import_user_module(args.user_dir)

    parser.add_argument(
        "--remote-eval",
        action="store_true",
        help="Evaluate a standalone agent",
    )
    parser.add_argument(
        "--standalone",
        action="store_true",
        help="",
    )
    parser.add_argument(
        "--slurm", action="store_true", default=False, help="Use slurm."
    )
    parser.add_argument("--agent", default=None, help="Agent file")
    parser.add_argument(
        "--agent-class",
        default=None,
        help="The full string of class of the agent.",
    )
    parser.add_argument(
        "--system-dir",
        default=None,
        help="Directory that contains everything to start the simultaneous system.",
    )
    parser.add_argument(
        "--system-config",
        default="main.yaml",
        help="Name of the config yaml of the system configs.",
    )
    parser.add_argument("--dataloader", default=None, help="Dataloader to use")
    parser.add_argument(
        "--dataloader-class", default=None, help="Dataloader class to use"
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="info",
        # NOTE(review): logging._levelToName is a private CPython detail;
        # it yields the standard level names (debug/info/warning/...).
        choices=[x.lower() for x in logging._levelToName.values()],
        help="Log level.",
    )
    # --score-only and --no-scoring are mutually exclusive by construction.
    scoring_arg_group = parser.add_mutually_exclusive_group()
    scoring_arg_group.add_argument(
        "--score-only",
        action="store_true",
        default=False,
        help="Only score the inference file.",
    )
    scoring_arg_group.add_argument(
        "--no-scoring",
        action="store_true",
        help="No scoring after inference",
    )
    parser.add_argument(
        "--device", type=str, default="cpu", help="Device to run the model."
    )
    # --dtype supersedes the legacy --fp16 flag; they cannot be combined.
    dtype_arg_group = parser.add_mutually_exclusive_group()
    dtype_arg_group.add_argument(
        "--dtype",
        choices=["fp16", "fp32"],
        type=str,
        help=(
            "Choose between half-precision (fp16) and single precision (fp32) floating point formats."
            + " Prefer this over the fp16 flag."
        ),
    )
    dtype_arg_group.add_argument(
        "--fp16", action="store_true", default=False, help="Use fp16."
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        default=False,
        help="Create visualization graphs",
    )
    parser.add_argument(
        "--demo",
        action="store_true",
        default=False,
        help="Live remote speech to text demonstration in terminal",
    )

    return parser
252 |
253 |
def add_slurm_args(parser):
    """Register slurm-related CLI arguments on ``parser``."""
    parser.add_argument("--slurm-partition", default="", help="Slurm partition.")
    parser.add_argument("--slurm-job-name", default="simuleval", help="Slurm job name.")
    # Fixed copy-pasted help text: this option is the time limit, not the partition.
    parser.add_argument("--slurm-time", default="2:00:00", help="Slurm job time limit.")
258 |
--------------------------------------------------------------------------------
/simuleval/test/test_agent.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import sys
9 | import tempfile
10 | from pathlib import Path
11 | import urllib.request
12 | from argparse import Namespace
13 |
14 | import pytest
15 |
16 | import simuleval.cli as cli
17 | from simuleval.agents import TextToTextAgent
18 | from simuleval.agents.actions import ReadAction, WriteAction
19 | from simuleval.data.segments import TextSegment
20 | import logging
21 |
22 | logger = logging.getLogger()
23 |
24 |
25 | ROOT_PATH = Path(__file__).parents[2]
26 | sys.path.insert(0, str(ROOT_PATH)) # may be needed for import from `examples`
27 |
28 | from examples.quick_start.spm_detokenizer_agent import (
29 | SentencePieceModelDetokenizerAgent,
30 | )
31 |
32 |
def test_agent(root_path=ROOT_PATH):
    """Run the dummy wait-k text agent end to end through the CLI."""
    quick_start = os.path.join(root_path, "examples", "quick_start")
    with tempfile.TemporaryDirectory() as tmpdirname:
        cli.sys.argv[1:] = [
            "--user-dir", os.path.join(root_path, "examples"),
            "--agent-class", "examples.quick_start.first_agent.DummyWaitkTextAgent",
            "--source", os.path.join(quick_start, "source.txt"),
            "--target", os.path.join(quick_start, "target.txt"),
            "--output", tmpdirname,
        ]
        cli.main()
48 |
49 |
def test_statelss_agent(root_path=ROOT_PATH):
    """A stateless agent driven with external states must match a stateful one."""

    class DummyWaitkTextAgent(TextToTextAgent):
        waitk = 0
        vocab = [chr(code) for code in range(ord("A"), ord("Z") + 1)]

        def policy(self, states=None):
            # Fall back to the agent's own states when none are supplied.
            if states is None:
                states = self.states

            lagging = len(states.source) - len(states.target)
            # Guard-first form of: write iff lagging >= waitk or source finished.
            if lagging < self.waitk and not states.source_finished:
                return ReadAction()
            prediction = self.vocab[len(states.source)]
            return WriteAction(prediction, finished=(lagging <= 1))

    stateless = DummyWaitkTextAgent.from_args(None)
    external_states = stateless.build_states()
    stateful = DummyWaitkTextAgent.from_args(None)

    for _ in range(10):
        segment = TextSegment(0, "A")
        assert (
            stateless.pushpop(segment, external_states).content
            == stateful.pushpop(segment).content
        )
78 |
79 |
@pytest.mark.parametrize("detokenize_only", [True, False])
def test_spm_detokenizer_agent(detokenize_only):
    """Feed SPM pieces through the detokenizer agent and check text and delays."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        tokenizer_file = f"{tmpdirname}/tokenizer.model"
        tokenizer_url = "https://huggingface.co/facebook/seamless-m4t-large/resolve/main/tokenizer.model"
        urllib.request.urlretrieve(tokenizer_url, tokenizer_file)

        args = Namespace(
            sentencepiece_model=tokenizer_file,
            detokenize_only=detokenize_only,
        )

        agent = SentencePieceModelDetokenizerAgent.from_args(args)
        agent_state = agent.build_states()
        segments = [
            TextSegment(0, "▁Let ' s"),
            TextSegment(1, "▁do ▁it ▁with"),
            TextSegment(2, "out ▁hesitation .", finished=True),
        ]

        output, delays = [], []
        for step, segment in enumerate(segments):
            out = agent.pushpop(segment, agent_state)
            if not out.is_empty:
                output.append(out.content)
                delays.extend([step] * len(out.content.split()))

        if detokenize_only:
            assert output == ["Let's", "do it with", "out hesitation."]
            assert delays == [0, 1, 1, 1, 2, 2]
        else:
            assert output == ["Let's do it", "without hesitation."]
            assert delays == [1, 1, 1, 2, 2]
111 |
112 |
@pytest.mark.parametrize("detokenize_only", [True, False])
def test_spm_detokenizer_agent_pipeline(detokenize_only, root_path=ROOT_PATH):
    """Run the SPM detokenizer pipeline end to end through the CLI."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        tokenizer_file = f"{tmpdirname}/tokenizer.model"
        tokenizer_url = "https://huggingface.co/facebook/seamless-m4t-large/resolve/main/tokenizer.model"
        urllib.request.urlretrieve(tokenizer_url, tokenizer_file)

        quick_start = os.path.join(root_path, "examples", "quick_start")
        argv = [
            "--user-dir", os.path.join(root_path, "examples"),
            "--agent-class", "examples.quick_start.spm_detokenizer_agent.DummyPipeline",
            "--source", os.path.join(quick_start, "spm_source.txt"),
            "--target", os.path.join(quick_start, "spm_target.txt"),
            "--output", tmpdirname,
            "--segment-k", "3",
            "--sentencepiece-model", tokenizer_file,
        ]
        if detokenize_only:
            argv.append("--detokenize-only")
        cli.sys.argv[1:] = argv
        cli.main()
139 |
--------------------------------------------------------------------------------
/simuleval/test/test_agent_pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | from pathlib import Path
9 | import tempfile
10 |
11 | import simuleval.cli as cli
12 | from simuleval.agents import AgentPipeline, TextToTextAgent
13 | from simuleval.agents.actions import ReadAction, WriteAction
14 | from simuleval.data.segments import TextSegment
15 |
16 | ROOT_PATH = Path(__file__).parents[2]
17 |
18 |
def test_pipeline_cmd(root_path=ROOT_PATH):
    """Run DummyPipeline through the CLI via an --agent file import."""
    # NOTE: When importing --agent we use import_file, thus need to specify
    # --agent-class as agents.DummyPipeline
    quick_start = os.path.join(root_path, "examples", "quick_start")
    cli.sys.argv[1:] = [
        "--agent", os.path.join(quick_start, "agent_pipeline.py"),
        "--user-dir", os.path.join(root_path, "examples"),
        "--agent-class", "agents.DummyPipeline",
        "--source", os.path.join(quick_start, "source.txt"),
        "--target", os.path.join(quick_start, "target.txt"),
    ]
    cli.main()
35 |
36 |
def test_tree_pipeline_cmd(root_path=ROOT_PATH):
    """Evaluate the tree-shaped agent pipeline through the CLI."""
    os.chdir(Path.joinpath(root_path, "examples", "speech_to_speech"))
    examples = os.path.join(root_path, "examples")
    with tempfile.TemporaryDirectory() as tmpdirname:
        argv = []
        for flag, value in [
            ("--agent-class",
             "examples.speech_to_speech_text.tree_agent_pipeline.DummyTreePipeline"),
            ("--user-dir", examples),
            ("--source", os.path.join(examples, "speech_to_speech", "source.txt")),
            ("--target", os.path.join(examples, "speech_to_text", "reference", "en.txt")),
            ("--source-segment-size", "320"),
            ("--output-index", "0"),
            ("--output", tmpdirname),
        ]:
            argv += [flag, value]
        cli.sys.argv[1:] = argv
        cli.main()
60 |
61 |
def test_instantiated_tree_pipeline_cmd(root_path=ROOT_PATH):
    """Evaluate a pre-instantiated tree agent pipeline through the CLI."""
    os.chdir(Path.joinpath(root_path, "examples", "speech_to_speech"))
    examples = os.path.join(root_path, "examples")
    with tempfile.TemporaryDirectory() as tmpdirname:
        argv = []
        for flag, value in [
            ("--agent-class",
             "examples.speech_to_speech_text.tree_agent_pipeline.AnotherInstantiatedTreeAgentPipeline"),
            ("--user-dir", examples),
            ("--source", os.path.join(examples, "speech_to_speech", "source.txt")),
            ("--target", os.path.join(examples, "speech_to_text", "reference", "en.txt")),
            ("--source-segment-size", "320"),
            ("--output-index", "0"),
            ("--output", tmpdirname),
        ]:
            argv += [flag, value]
        cli.sys.argv[1:] = argv
        cli.main()
85 |
86 |
def test_pipeline():
    """Identical pipelines must agree whether driven by pushpop or push/pop."""

    class DummyWaitkTextAgent(TextToTextAgent):
        waitk = 0
        vocab = [chr(code) for code in range(ord("A"), ord("Z") + 1)]

        def policy(self):
            lagging = len(self.states.source) - len(self.states.target)
            # Guard-first form of: write iff lagging >= waitk or source finished.
            if lagging < self.waitk and not self.states.source_finished:
                return ReadAction()
            prediction = self.vocab[len(self.states.source)]
            return WriteAction(prediction, finished=(lagging <= 1))

    class DummyWait2TextAgent(DummyWaitkTextAgent):
        waitk = 2

    class DummyWait4TextAgent(DummyWaitkTextAgent):
        waitk = 4

    class DummyPipeline(AgentPipeline):
        pipeline = [DummyWait2TextAgent, DummyWait4TextAgent]

    pushpop_pipeline = DummyPipeline.from_args(None)
    push_pop_pipeline = DummyPipeline.from_args(None)

    for _ in range(10):
        segment = TextSegment(0, "A")
        combined = pushpop_pipeline.pushpop(segment)
        push_pop_pipeline.push(segment)
        separate = push_pop_pipeline.pop()
        assert combined.content == separate.content
120 |
--------------------------------------------------------------------------------
/simuleval/test/test_evaluator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import tempfile
9 | from pathlib import Path
10 |
11 | import simuleval.cli as cli
12 |
13 | ROOT_PATH = Path(__file__).parents[2]
14 |
15 |
def test_score_only(root_path=ROOT_PATH):
    """Run an evaluation, then re-score the produced output directory."""
    quick_start = os.path.join(root_path, "examples", "quick_start")
    with tempfile.TemporaryDirectory() as tmpdirname:
        cli.sys.argv[1:] = [
            "--user-dir", os.path.join(root_path, "examples"),
            "--agent-class", "examples.quick_start.first_agent.DummyWaitkTextAgent",
            "--source", os.path.join(quick_start, "source.txt"),
            "--target", os.path.join(quick_start, "target.txt"),
            "--output", tmpdirname,
        ]
        cli.main()
        # Second pass: score the already-generated instances without inference.
        cli.sys.argv[1:] = ["--score-only", "--output", tmpdirname]
        cli.main()
33 |
--------------------------------------------------------------------------------
/simuleval/test/test_remote_evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import tempfile
9 | import time
10 | from multiprocessing import Process
11 | from pathlib import Path
12 |
13 | import simuleval.cli as cli
14 | from simuleval.utils.functional import find_free_port
15 |
16 | ROOT_PATH = Path(__file__).parents[2]
17 |
18 |
def p1(port, root_path):
    """Server process: serve a standalone dummy agent on ``port``."""
    argv = [
        "--standalone",
        "--remote-port", str(port),
        "--user-dir", os.path.join(root_path, "examples"),
        "--agent-class", "examples.quick_start.first_agent.DummyWaitkTextAgent",
    ]
    cli.sys.argv[1:] = argv
    cli.main()
    # Keep the process around briefly after serving.
    time.sleep(5)
31 |
32 |
def p2(port, root_path):
    """Client process: run a remote evaluation against the agent on ``port``."""
    quick_start = os.path.join(root_path, "examples", "quick_start")
    with tempfile.TemporaryDirectory() as tmpdirname:
        cli.sys.argv[1:] = [
            "--remote-eval",
            "--remote-port", str(port),
            "--source", os.path.join(quick_start, "source.txt"),
            "--target", os.path.join(quick_start, "target.txt"),
            "--dataloader", "text-to-text",
            "--output", tmpdirname,
        ]
        cli.main()
49 |
50 |
def test_remote_eval(root_path=ROOT_PATH):
    """Start a standalone agent server and run a remote-eval client against it.

    BUG FIX: the original killed both processes immediately after starting
    them, so the client/server exchange was never actually exercised. Now the
    client is given time to complete before the server is torn down.
    """
    port = find_free_port()

    server = Process(target=p1, args=(port, root_path))
    server.start()

    client = Process(target=p2, args=(port, root_path))
    client.start()

    try:
        # Bounded wait so a hung exchange cannot stall the test suite forever.
        client.join(timeout=60)
    finally:
        server.kill()
        if client.is_alive():
            client.kill()
62 |
--------------------------------------------------------------------------------
/simuleval/test/test_s2s.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import tempfile
9 | from pathlib import Path
10 | from typing import Optional
11 | from simuleval.agents.states import AgentStates
12 |
13 | import simuleval.cli as cli
14 | from simuleval.agents import SpeechToSpeechAgent
15 | from simuleval.agents.actions import ReadAction, WriteAction
16 | from simuleval.data.segments import SpeechSegment
17 |
18 | ROOT_PATH = Path(__file__).parents[2]
19 |
20 |
def test_s2s(root_path=ROOT_PATH):
    """Run the alternating English speech-to-speech agent via the CLI."""
    s2s_dir = os.path.join(root_path, "examples", "speech_to_speech")
    os.chdir(Path.joinpath(root_path, "examples", "speech_to_speech"))
    with tempfile.TemporaryDirectory() as tmpdirname:
        cli.sys.argv[1:] = [
            "--agent", os.path.join(s2s_dir, "english_alternate_agent.py"),
            "--user-dir", os.path.join(root_path, "examples"),
            "--agent-class", "agents.EnglishAlternateAgent",
            "--source-segment-size", "1000",
            "--source", os.path.join(s2s_dir, "source.txt"),
            "--target", os.path.join(s2s_dir, "reference/en.txt"),
            "--output", tmpdirname,
            "--tgt-lang", os.path.join(s2s_dir, "reference/tgt_lang.txt"),
        ]
        cli.main()
48 |
49 |
def test_stateless_agent(root_path=ROOT_PATH):
    """A stateless S2S agent driven with external states must match a stateful one."""

    class EnglishAlternateAgent(SpeechToSpeechAgent):
        waitk = 0
        wait_seconds = 3
        vocab = [chr(i) for i in range(ord("A"), ord("Z") + 1)]

        def policy(self, states: Optional[AgentStates] = None):
            if states is None:
                # BUG FIX: the original assigned ``states = states`` (a no-op),
                # leaving ``states`` as None and crashing the stateless path.
                states = self.states

            length_in_seconds = round(len(states.source) / states.source_sample_rate)
            # Consistently use the passed-in states; the original mixed in
            # ``self.states``, which defeats stateless operation.
            if (
                not states.source_finished
                and length_in_seconds < self.wait_seconds
            ):
                return ReadAction()

            if length_in_seconds % 2 == 0:
                samples, fs = self.tts_model.synthesize(
                    f"{8 - length_in_seconds} even even"
                )
            else:
                samples, fs = self.tts_model.synthesize(
                    f"{8 - length_in_seconds} odd odd"
                )

            prediction = f"{length_in_seconds} second"

            return WriteAction(
                SpeechSegment(
                    content=samples,
                    sample_rate=fs,
                    finished=states.source_finished,
                ),
                content=prediction,
                finished=states.source_finished,
            )

    args = None
    agent_stateless = EnglishAlternateAgent.from_args(args)
    agent_state = agent_stateless.build_states()
    agent_stateful = EnglishAlternateAgent.from_args(args)

    for _ in range(10):
        segment = SpeechSegment(0, "A")
        output_1 = agent_stateless.pushpop(segment, agent_state)
        output_2 = agent_stateful.pushpop(segment)
        assert output_1.content == output_2.content
98 |
--------------------------------------------------------------------------------
/simuleval/test/test_s2t.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import tempfile
9 | from pathlib import Path
10 |
11 | import simuleval.cli as cli
12 | from simuleval.agents import SpeechToTextAgent
13 | from simuleval.agents.actions import ReadAction, WriteAction
14 | from simuleval.data.segments import SpeechSegment
15 | from simuleval.evaluator.instance import LogInstance
16 |
17 | ROOT_PATH = Path(__file__).parents[2]
18 |
19 |
def test_s2t(root_path=ROOT_PATH):
    """Run the speech-to-text counter agent via the CLI and verify its output."""
    s2t_dir = os.path.join(root_path, "examples", "speech_to_text")
    os.chdir(Path.joinpath(root_path, "examples", "speech_to_text"))
    with tempfile.TemporaryDirectory() as tmpdirname:
        cli.sys.argv[1:] = [
            "--agent", os.path.join(s2t_dir, "counter_in_tgt_lang_agent.py"),
            "--user-dir", os.path.join(root_path, "examples"),
            "--agent-class", "agents.EnglishSpeechCounter",
            "--source-segment-size", "1000",
            "--source", os.path.join(s2t_dir, "source.txt"),
            "--target", os.path.join(s2t_dir, "reference/en.txt"),
            "--output", tmpdirname,
            "--tgt-lang", os.path.join(s2t_dir, "reference/tgt_lang.txt"),
        ]
        cli.main()

        expected = (
            "1 segundos 2 segundos 3 segundos 4 segundos "
            "5 segundos 6 segundos 7 segundos"
        )
        with open(os.path.join(tmpdirname, "instances.log"), "r") as f:
            for line in f:
                assert LogInstance(line.strip()).prediction == expected
55 |
56 |
def test_statelss_agent(root_path=ROOT_PATH):
    """Stateless and stateful speech-counter agents must produce the same output."""

    class EnglishSpeechCounter(SpeechToTextAgent):
        wait_seconds = 3

        def policy(self, states=None):
            if states is None:
                states = self.states

            length_in_seconds = round(len(states.source) / states.source_sample_rate)
            # De Morgan inversion of: read iff not finished and too short.
            if states.source_finished or length_in_seconds >= self.wait_seconds:
                return WriteAction(
                    content=f"{length_in_seconds} second",
                    finished=states.source_finished,
                )
            return ReadAction()

    stateless = EnglishSpeechCounter.from_args(None)
    external_states = stateless.build_states()
    stateful = EnglishSpeechCounter.from_args(None)

    for _ in range(10):
        segment = SpeechSegment(0, "A")
        assert (
            stateless.pushpop(segment, external_states).content
            == stateful.pushpop(segment).content
        )
86 |
87 |
def test_s2t_with_tgt_lang(root_path=ROOT_PATH):
    """Run the target-language counter agent via the CLI and verify its output."""
    s2t_dir = os.path.join(root_path, "examples", "speech_to_text")
    os.chdir(Path.joinpath(root_path, "examples", "speech_to_text"))
    with tempfile.TemporaryDirectory() as tmpdirname:
        cli.sys.argv[1:] = [
            "--agent", os.path.join(s2t_dir, "counter_in_tgt_lang_agent.py"),
            "--user-dir", os.path.join(root_path, "examples"),
            "--agent-class", "agents.CounterInTargetLanguage",
            "--source-segment-size", "1000",
            "--source", os.path.join(s2t_dir, "source.txt"),
            "--target", os.path.join(s2t_dir, "reference/en.txt"),
            "--output", tmpdirname,
            "--tgt-lang", os.path.join(s2t_dir, "reference/tgt_lang.txt"),
        ]
        cli.main()

        expected = (
            "1 segundos 2 segundos 3 segundos 4 segundos "
            "5 segundos 6 segundos 7 segundos"
        )
        with open(os.path.join(tmpdirname, "instances.log"), "r") as f:
            for line in f:
                assert LogInstance(line.strip()).prediction == expected
123 |
--------------------------------------------------------------------------------
/simuleval/test/test_visualize.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import tempfile
9 | from pathlib import Path
10 | import simuleval.cli as cli
11 | import shutil
12 | import json
13 |
14 | ROOT_PATH = Path(__file__).parents[2]
15 |
16 |
def test_visualize(root_path=ROOT_PATH):
    """Run whisper wait-k with --visualize and count the produced graphs."""
    s2t_dir = os.path.join(root_path, "examples", "speech_to_text")
    os.chdir(Path.joinpath(root_path, "examples", "speech_to_text"))
    with tempfile.TemporaryDirectory() as tmpdirname:
        cli.sys.argv[1:] = [
            "--agent", os.path.join(s2t_dir, "whisper_waitk.py"),
            "--source-segment-size", "500",
            "--waitk-lagging", "3",
            "--source", os.path.join(s2t_dir, "source.txt"),
            "--target", os.path.join(s2t_dir, "reference/transcript.txt"),
            "--output", "output",
            "--quality-metrics", "WER",
            "--visualize",
        ]
        cli.main()

        # One visualization png is expected per source line.
        with open(os.path.join(s2t_dir, "source.txt"), "r") as f:
            source_length = len(f.readlines())

        images = list(Path(os.path.join("output", "visual")).glob("*.png"))
        assert len(images) == source_length
        shutil.rmtree("output")
53 |
54 |
def test_visualize_score_only(root_path=ROOT_PATH):
    """Score a hand-written output directory with --visualize and count graphs."""
    args_path = Path.joinpath(root_path, "examples", "speech_to_text")
    os.chdir(args_path)

    # Create sample instances.log and config.yaml in output directory
    output = Path("output")
    output.mkdir()
    os.chdir(output)
    with open("config.yaml", "w") as config:
        config.write("source_type: speech\n")
        config.write("target_type: speech")
    with open("instances.log", "w") as instances:
        # Synthetic single-instance log: delays/elapsed are per output token
        # (19 entries matching prediction_length).
        json.dump(
            {
                "index": 0,
                "prediction": "This is a synthesized audio file to test your simultaneous speech, to speak to speech, to speak translation system.",
                "delays": [
                    1500.0,
                    2000.0,
                    2500.0,
                    3000.0,
                    3500.0,
                    4000.0,
                    4500.0,
                    5000.0,
                    5500.0,
                    6000.0,
                    6500.0,
                    6849.886621315192,
                    6849.886621315192,
                    6849.886621315192,
                    6849.886621315192,
                    6849.886621315192,
                    6849.886621315192,
                    6849.886621315192,
                    6849.886621315192,
                ],
                "elapsed": [
                    1947.3278522491455,
                    2592.338800430298,
                    3256.8109035491943,
                    3900.0539779663086,
                    4561.986684799194,
                    5216.205835342407,
                    5874.6888637542725,
                    6526.906728744507,
                    7193.655729293823,
                    7852.792739868164,
                    8539.628744125366,
                    9043.279374916267,
                    9043.279374916267,
                    9043.279374916267,
                    9043.279374916267,
                    9043.279374916267,
                    9043.279374916267,
                    9043.279374916267,
                    9043.279374916267,
                ],
                "prediction_length": 19,
                "reference": "This is a synthesized audio file to test your simultaneous speech to text and to speech to speach translation system.",
                "source": [
                    "test.wav",
                    "samplerate: 22050 Hz",
                    "channels: 1",
                    "duration: 6.850 s",
                    "format: WAV (Microsoft) [WAV]",
                    "subtype: Signed 16 bit PCM [PCM_16]",
                ],
                "source_length": 6849.886621315192,
            },
            instances,
        )

    os.chdir(args_path)

    # NOTE(review): tmpdirname is unused — the run writes to "output" instead;
    # the with-block only creates and removes an unused temp directory.
    with tempfile.TemporaryDirectory() as tmpdirname:
        cli.sys.argv[1:] = ["--score-only", "--output", "output", "--visualize"]
        cli.main()

        visual_folder_path = os.path.join("output", "visual")
        source_path = os.path.join(
            root_path, "examples", "speech_to_text", "source.txt"
        )
        source_length = 0

        # One visualization png is expected per source line.
        with open(source_path, "r") as f:
            source_length = len(f.readlines())
        images = list(Path(visual_folder_path).glob("*.png"))
        assert len(images) == source_length
        shutil.rmtree("output")
145 |
--------------------------------------------------------------------------------
/simuleval/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .agent import build_system_from_dir, EVALUATION_SYSTEM_LIST # noqa F401
8 |
9 |
def entrypoint(klass):
    """Class decorator registering ``klass`` as an evaluable system.

    Appends the class to the global ``EVALUATION_SYSTEM_LIST`` (consulted when
    resolving which agent system to evaluate) and returns it unchanged.
    """
    EVALUATION_SYSTEM_LIST.append(klass)
    return klass
13 |
--------------------------------------------------------------------------------
/simuleval/utils/agent.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | import importlib
9 | import logging
10 | import os
11 | import sys
12 | from argparse import ArgumentParser, Namespace
13 | from pathlib import Path
14 | from typing import Optional, Tuple, Union
15 |
16 | import yaml
17 | from simuleval import options
18 | from simuleval.agents import GenericAgent
19 | from simuleval.utils.arguments import check_argument, cli_argument_list
20 |
21 | EVALUATION_SYSTEM_LIST = []
22 |
23 | logger = logging.getLogger("simuleval.utils.agent")
24 |
25 |
def import_file(file_path):
    # Execute ``file_path`` as a module named "agents" so classes defined in it
    # can be referenced as ``agents.<ClassName>`` via --agent-class; importing
    # triggers any @entrypoint registrations at module top level.
    # NOTE(review): the module is not inserted into sys.modules — presumably
    # only the import side effects are needed; confirm before relying on
    # ``import agents`` elsewhere.
    spec = importlib.util.spec_from_file_location("agents", file_path)
    agent_modules = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(agent_modules)
30 |
31 |
def get_agent_class(config_dict: Optional[dict] = None) -> GenericAgent:
    """Resolve the single agent system class to evaluate.

    Candidates accumulate in the module-global ``EVALUATION_SYSTEM_LIST``:
    from --agent-class (when no --agent file is given), from --system-dir's
    config, and from @entrypoint registrations triggered by importing the
    --agent file. Exactly one candidate must remain, or --agent-class must
    disambiguate.

    Raises:
        RuntimeError: when no candidate is registered, or several are and
            --agent-class does not select one of them.
    """
    class_name = check_argument("agent_class", config_dict)

    # --agent-class alone resolves directly; when --agent is also given, the
    # file import below is expected to register the class instead.
    if class_name is not None:
        if not check_argument("agent"):
            EVALUATION_SYSTEM_LIST.append(get_agent_class_from_string(class_name))

    system_dir = check_argument("system_dir")
    config_name = check_argument("system_config")

    if system_dir is not None:
        EVALUATION_SYSTEM_LIST.append(get_agent_class_from_dir(system_dir, config_name))

    # Importing the agent file runs its @entrypoint decorators, which append
    # to EVALUATION_SYSTEM_LIST as a side effect.
    agent_file = check_argument("agent")
    if agent_file is not None:
        import_file(agent_file)

    if len(EVALUATION_SYSTEM_LIST) == 0:
        raise RuntimeError(
            "Please use @entrypoint decorator to indicate the system you want to evaluate."
        )
    if len(EVALUATION_SYSTEM_LIST) > 1:
        if class_name is None:
            raise RuntimeError(
                "--agent-class must be specified if more than one system."
            )
        # if both --agent-class and --agent:
        for system in EVALUATION_SYSTEM_LIST:
            if f"{system.__module__}.{system.__name__}" == class_name:
                return system
        raise RuntimeError(
            f"--agent-class {class_name} not found in system list: {EVALUATION_SYSTEM_LIST}"
        )
    return EVALUATION_SYSTEM_LIST[0]
66 |
67 |
def get_system_config(path: Union[Path, str], config_name) -> dict:
    """Load and parse the system's yaml config file.

    Args:
        path: directory containing the config file.
        config_name: file name of the yaml config inside ``path``.

    Returns:
        The parsed configuration dictionary.

    Exits the process (status 1) when the yaml cannot be parsed.
    """
    path = Path(path)
    with open(path / config_name, "r") as f:
        try:
            config_dict = yaml.safe_load(f)
        except yaml.YAMLError as exc:
            # Use the module's named logger; the original called the root
            # ``logging`` module directly, inconsistent with the rest of
            # this file.
            logger.error(f"Failed to load configs from {path / config_name}.")
            logger.error(exc)
            sys.exit(1)
    return config_dict
78 |
79 |
def get_agent_class_from_string(class_name: str) -> GenericAgent:
    """Resolve a fully qualified ``package.module.Class`` string to the class."""
    module_name, _, attribute = class_name.rpartition(".")
    try:
        return getattr(importlib.import_module(module_name), attribute)
    except Exception as e:
        logger.error(f"Not able to load {class_name}. Try setting --user-dir?")
        raise e
88 |
89 |
def get_agent_class_from_dir(
    path: Union[Path, str], config_name: str = "main.yaml"
) -> GenericAgent:
    """Read ``agent_class`` from a system directory's config and resolve it."""
    config_dict = get_system_config(path, config_name)
    # The config is required to name the agent class to instantiate.
    assert "agent_class" in config_dict
    return get_agent_class_from_string(config_dict["agent_class"])
97 |
98 |
def build_system_from_dir(
    path: Union[Path, str],
    config_name: str = "main.yaml",
    overwrite_config_dict: Optional[dict] = None,
) -> GenericAgent:
    """Instantiate the agent system described by a system directory.

    Args:
        path: directory containing the system (config, checkpoints, ...).
        config_name: yaml config file name inside ``path``.
        overwrite_config_dict: optional key/value overrides applied on top of
            the loaded config.

    Returns:
        The instantiated agent system.
    """
    path = Path(path)
    config_dict = get_system_config(path, config_name)
    if overwrite_config_dict is not None:
        # BUG FIX: iterating a dict directly yields keys only, so the original
        # ``for key, value in overwrite_config_dict`` raised a ValueError
        # (or silently mis-unpacked two-character string keys).
        for key, value in overwrite_config_dict.items():
            config_dict[key] = value
    agent_class = get_agent_class_from_dir(path, config_name)

    parser = options.general_parser()
    agent_class.add_args(parser)
    args, _ = parser.parse_known_args(cli_argument_list(config_dict))
    sys.path.append(path.as_posix())

    # Build from inside the system directory so relative paths in the config
    # resolve; always restore the previous working directory, even when
    # from_args raises.
    cur_dir = os.getcwd()
    os.chdir(path.as_posix())
    try:
        system = agent_class.from_args(args)
    finally:
        os.chdir(cur_dir)
    return system
121 |
122 |
def build_system_args(
    config_dict: Optional[dict] = None,
    parser: Optional[ArgumentParser] = None,
) -> Tuple[GenericAgent, Namespace]:
    """Build the agent system together with its fully parsed arguments.

    Assembles the general parser plus evaluator/scorer/slurm/dataloader
    option groups, instantiates the system (from --system-dir or from the
    resolved agent class), moves it to the requested device/dtype, and
    returns it with the final argument namespace.
    """
    parser = options.general_parser(config_dict, parser)
    cli_arguments = cli_argument_list(config_dict)
    options.add_evaluator_args(parser)
    options.add_scorer_args(parser, cli_arguments)
    options.add_slurm_args(parser)
    options.add_dataloader_args(parser, cli_arguments)

    if check_argument("system_dir"):
        system = build_system_from_dir(
            check_argument("system_dir"), check_argument("system_config"), config_dict
        )
    else:
        # Agent-specific options only exist once the class is known, so
        # registration and a tolerant parse happen before the final parse.
        system_class = get_agent_class(config_dict)
        system_class.add_args(parser)
        add_command_helper_arg(parser)
        args, _ = parser.parse_known_args(cli_argument_list(config_dict))
        system = system_class.from_args(args)

    # Final strict parse now that every option group is registered.
    args = parser.parse_args(cli_argument_list(config_dict))

    # --dtype wins; the legacy --fp16 flag is honored otherwise.
    dtype = args.dtype if args.dtype else "fp16" if args.fp16 else "fp32"
    logger.info(f"System will run on device: {args.device}. dtype: {dtype}")
    system.to(args.device, fp16=(dtype == "fp16"))

    # Propagate the system's modalities so the dataloader can match them.
    args.source_type = system.source_type
    args.target_type = system.target_type
    return system, args
154 |
155 |
def add_command_helper_arg(parser: ArgumentParser) -> None:
    """Attach the standard ``-h``/``--help`` flag to ``parser``.

    Intended for parsers created with ``add_help=False`` so help can be
    registered after all agent-specific arguments are in place.
    """
    parser.add_argument(
        "-h", "--help",
        action="help",
        default="==SUPPRESS==",
        help="show this help message and exit",
    )
164 |
--------------------------------------------------------------------------------
/simuleval/utils/arguments.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Optional
3 | from simuleval import options
4 |
5 |
def cli_argument_list(config_dict: Optional[dict]):
    """Merge ``config_dict`` entries into the command-line argument list.

    Args:
        config_dict: Optional mapping of option names (underscored) to values.

    Returns:
        ``sys.argv[1:]`` extended with ``--key value`` pairs (or a bare
        ``--key`` flag for ``True`` booleans).  Entries whose flag already
        appears on the command line are skipped, so explicit CLI arguments
        always take precedence over the config.

    Fixes vs. the previous version: a ``False`` boolean no longer emits its
    flag (which would have *enabled* a store_true option), and each value is
    appended as a single token instead of being re-split on whitespace.
    """
    if config_dict is None:
        return sys.argv[1:]
    extra_args = []
    for key, value in config_dict.items():
        flag = f"--{key.replace('_', '-')}"
        if flag in sys.argv:
            # Explicit command-line argument wins over the config entry.
            continue
        if isinstance(value, bool):
            # Booleans are store_true-style flags: emit only when True.
            if value:
                extra_args.append(flag)
        else:
            extra_args.append(flag)
            extra_args.append(str(value))
    return sys.argv[1:] + extra_args
20 |
21 |
def check_argument(name: str, config_dict: Optional[dict] = None):
    """Peek at one general-parser argument without full CLI validation.

    Builds a fresh general parser, parses only the known arguments, and
    returns the value bound to ``name``.
    """
    general_args, _ = options.general_parser().parse_known_args(
        cli_argument_list(config_dict)
    )
    return getattr(general_args, name)
26 |
--------------------------------------------------------------------------------
/simuleval/utils/functional.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from contextlib import closing
8 | import socket
9 |
10 |
def find_free_port():
    """Return a TCP port number that was free at the time of the call.

    Binds to port 0 so the OS assigns an unused port, reads it back, and
    closes the socket.  The usual race applies: the port may be taken again
    before the caller binds it.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # SO_REUSEADDR must be set *before* bind() to affect it; the previous
        # version set it afterwards, where it was a no-op.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        return s.getsockname()[1]
16 |
--------------------------------------------------------------------------------
/simuleval/utils/slurm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import itertools
8 | import logging
9 | import os
10 | import re
11 | import subprocess
12 | import sys
13 | from argparse import ArgumentParser
14 | from typing import Dict, List, Optional
15 |
16 | from simuleval import options
17 | from simuleval.utils.agent import get_agent_class
18 | from simuleval.utils.arguments import cli_argument_list
19 |
20 | logger = logging.getLogger("simuleval.slurm")
21 |
22 |
def mkdir_output_dir(path: str) -> bool:
    """Create ``path`` (including parents) for results; return success.

    Failures are logged and swallowed so evaluation can continue without
    writing predictions.
    """
    try:
        os.makedirs(path, exist_ok=True)
    except BaseException as error:
        logger.error(f"Failed to write results to {path}.")
        logger.error(error)
        logger.error("Skip writing predictions.")
        return False
    return True
32 |
33 |
def submit_slurm_job(
    config_dict: Optional[Dict] = None, parser: Optional[ArgumentParser] = None
) -> None:
    """Write an sbatch script replaying the current command and submit it.

    List-valued entries in ``config_dict`` define a parameter sweep that is
    run as a slurm job array, with one result sub-directory per combination.

    Args:
        config_dict: Optional configuration mapping; list values are swept.
        parser: Optional base parser, extended with evaluator/scorer/slurm/
            dataloader and agent-specific arguments.

    Raises:
        RuntimeError: If ``slurm`` appears in ``config_dict`` (the flag is
            CLI-only).
    """
    if config_dict is not None and "slurm" in config_dict:
        raise RuntimeError("--slurm is only available as a CLI argument")

    # Bug fix: config_dict defaults to None; guard the .items() call so a
    # plain CLI invocation (no config) no longer crashes with AttributeError.
    sweep_options = [
        [[key, v] for v in value]
        for key, value in (config_dict or {}).items()
        if isinstance(value, list)
    ]
    sweep_config_dict_list = []
    if len(sweep_options) > 0:
        # Cartesian product over all list-valued options -> one dict per job.
        for option_list in itertools.product(*sweep_options):
            sweep_config_dict_list.append({k: v for k, v in option_list})

        # Swept keys are injected per-job below; drop them from the base config.
        for x in sweep_options:
            if x[0][0] in config_dict:
                del config_dict[x[0][0]]

    cli_arguments = cli_argument_list(config_dict)
    parser = options.general_parser(config_dict, parser)
    options.add_evaluator_args(parser)
    options.add_scorer_args(parser, cli_arguments)
    options.add_slurm_args(parser)
    options.add_dataloader_args(parser, cli_arguments)
    system_class = get_agent_class(config_dict)
    system_class.add_args(parser)
    args = parser.parse_args(cli_argument_list(config_dict))
    args.output = os.path.abspath(args.output)
    assert mkdir_output_dir(args.output)

    if args.agent is None:
        args.agent = sys.argv[0]

    # Snapshot the agent source next to the results for reproducibility.
    os.system(f"cp {args.agent} {args.output}/agent.py")
    _args = [sys.argv[0]]
    for arg in sys.argv[1:]:
        if str(arg).isdigit() or str(arg).startswith("--"):
            _args.append(arg)
        else:
            _args.append(f'"{arg}"')
    command = " ".join(_args).strip()
    # Strip the --slurm* flags (and their values) from the replayed command.
    command = re.sub(r"(--slurm\S*(\s+[^-]\S+)*)", "", command).strip()
    if subprocess.check_output(["which", "simuleval"]).decode().strip() in command:
        command = re.sub(
            r"--agent\s+\S+", f"--agent {args.output}/agent.py", command
        ).strip()
    else:
        # Attention: not fully tested!
        command = re.sub(
            r"[^\"'\s]+\.py", f"{os.path.abspath(args.output)}/agent.py", command
        ).strip()

    sweep_command = ""
    sbatch_job_array_head = ""
    job_array_configs = ""

    if len(sweep_config_dict_list) > 0:
        # Bash associative arrays map SLURM_ARRAY_TASK_ID to the extra CLI
        # arguments and to a human-readable result sub-directory name.
        job_array_configs = "declare -A JobArrayConfigs\n"
        for i, sub_config_dict in enumerate(sweep_config_dict_list):
            sub_config_string = " ".join(
                [f"--{k.replace('_', '-')} {v}" for k, v in sub_config_dict.items()]
            )
            job_array_configs += f'JobArrayConfigs[{i}]="{sub_config_string}"\n'

        job_array_configs += "\ndeclare -A JobArrayString\n"
        for i, sub_config_dict in enumerate(sweep_config_dict_list):
            sub_config_string = ".".join([str(v) for k, v in sub_config_dict.items()])
            job_array_configs += f'JobArrayString[{i}]="{sub_config_string}"\n'

        sweep_command = "${JobArrayConfigs[$SLURM_ARRAY_TASK_ID]}"
        sbatch_job_array_head = f"#SBATCH --array=0-{len(sweep_config_dict_list) - 1}"
        output_dir = (
            f"{args.output}" + "/results/${JobArrayString[$SLURM_ARRAY_TASK_ID]}"
        )
        log_path = f"{args.output}/logs/slurm-%A_%a.log"

    else:
        output_dir = args.output
        log_path = f"{args.output}/slurm-%j.log"

    if "--output" in command:
        command = re.sub(r"--output\s+\S+", f"--output {output_dir}", command).strip()
    else:
        command += f" --output {output_dir}"

    # One option per line in the generated script for readability.
    command = command.replace("--", "\\\n\t--")
    script = f"""#!/bin/bash
#SBATCH --time={args.slurm_time}
#SBATCH --partition={args.slurm_partition}
#SBATCH --nodes=1
#SBATCH --gpus-per-node=1
#SBATCH --ntasks-per-node=8
#SBATCH --output="{log_path}"
#SBATCH --job-name="{args.slurm_job_name}"
{sbatch_job_array_head}

{job_array_configs}

mkdir -p {args.output}/logs
cd {os.path.abspath(args.output)}

GPU_ID=$SLURM_LOCALID

# Change to local a gpu id for debugging, e.g.
# GPU_ID=0


CUDA_VISIBLE_DEVICES=$GPU_ID {command} {sweep_command}
"""
    script_file = os.path.join(args.output, "script.sh")
    with open(script_file, "w") as f:
        f.writelines(script)

    process = subprocess.Popen(
        ["sbatch", script_file],
        stderr=subprocess.PIPE,
        stdout=subprocess.PIPE,
    )
    stdout, stderr = process.communicate()
    logger.info("Using slurm.")
    logger.info(f"sbatch stdout: {stdout.decode('utf-8').strip()}")
    stderr = stderr.decode("utf-8").strip()
    if len(stderr) > 0:
        logger.info(f"sbatch stderr: {stderr.strip()}")
160 |
--------------------------------------------------------------------------------
/simuleval/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import soundfile
2 | import matplotlib.patches as patches
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import os
6 | from pathlib import Path
7 |
8 |
class Visualize:
    """Render a per-sample visualization of a simultaneous-translation run.

    Produces a two-panel figure: a source-time vs. target-words staircase
    (read/write arrows annotated with predicted words) on top, and the
    source audio waveform below, then saves it under ``<path>/visual/``.
    """

    def __init__(self, data, index, path):
        # data: one instance record — presumably a parsed instances.log entry
        # with "delays", "prediction", "source", ... keys; TODO confirm.
        self.data = data
        # index: sample index, used to name the saved image.
        self.index = index
        # path: output directory of the evaluation run.
        self.path = path

    def make_graph(self):
        """Build and save the delay-staircase + waveform figure for one sample."""
        data = self.data

        ##### 1st PLOT #####
        # Organize all data
        # NOTE(review): the string defaults ("N/A delays data", ...) would
        # crash the arithmetic below if the keys were missing — they act as
        # markers only; confirm the keys are always present.
        words = np.arange(1, data.get("prediction_length", 0) + 1)
        delays = [x / 1000 for x in data.get("delays", "N/A delays data")]
        data_points = [
            (delays[i], words[i]) for i in range(len(delays))
        ]  # /1000 to make it from ms to s
        data_points.insert(0, (0, 0))
        prediction_word_list = data.get("prediction", "N/A predition data").split(" ")

        # Insert points to create the staircase effect
        i = 0
        while i < len(data_points) - 1:
            if data_points[i][0] != data_points[i + 1][0]:
                data_points.insert(i + 1, (data_points[i + 1][0], data_points[i][1]))
                i += 2
            else:
                break

        # Create a new figure
        fig = plt.figure(figsize=(10, 16), dpi=300)
        ax1, ax2 = fig.subplots(2, gridspec_kw={"hspace": 0.1})

        # Plot the points and connect them with lines
        x, y = zip(*data_points)
        ax1.plot(list(x), list(y), marker="o", color="black")

        # Draw arrows between points: blue = read (horizontal), red = write
        # (vertical step to a new target word).
        red_arrow_indices = []
        for i in range(len(data_points) - 1):
            color = "blue" if y[i] == y[i + 1] else "red"
            if color == "red":
                red_arrow_indices.append(i)
            ax1.annotate(
                "",
                xy=(x[i + 1], y[i + 1]),
                xytext=(x[i], y[i]),
                arrowprops=dict(
                    facecolor=color, edgecolor=color, shrink=0, width=1, headwidth=7
                ),
            )

        # Annotate red (write) arrows with words from prediction_word_list
        for idx, word in zip(red_arrow_indices, prediction_word_list):
            ax1.text(
                x[idx] + 0.1,  # x position of the annotation
                (y[idx] + y[idx + 1]) / 2,  # y position of the annotation
                word,  # text to annotate
                va="center",  # vertical alignment
                fontsize=9,  # font size
                color="black",  # text color
            )

        # Custom legend on the right side
        legend_x = 0.85  # X-coordinate for the legend in axes coordinates
        legend_y = 0.9  # Starting Y-coordinate for the legend

        # Red downward arrow
        red_arrow = patches.FancyArrow(
            legend_x + 0.03,
            legend_y,
            0,
            -0.04,
            width=0.01,
            head_width=0.03,
            head_length=0.03,
            color="red",
            transform=ax1.transAxes,
        )
        ax1.add_patch(red_arrow)
        ax1.text(
            legend_x + 0.05,
            legend_y - 0.03,
            "Write",
            color="red",
            transform=ax1.transAxes,
        )

        # Blue rightward arrow
        blue_arrow = patches.FancyArrow(
            legend_x - 0.05,
            legend_y,
            0.05,
            0,
            width=0.01,
            head_width=0.03,
            head_length=0.03,
            color="blue",
            transform=ax1.transAxes,
        )
        ax1.add_patch(blue_arrow)
        ax1.text(
            legend_x - 0.05,
            legend_y + 0.02,
            "Read",
            color="blue",
            transform=ax1.transAxes,
        )

        # Flip the y-axis
        ax1.invert_yaxis()
        # Set the limits of the axes to start at (0,0)
        ax1.set_xlim(left=0)
        ax1.set_xlim(right=delays[-1])
        ax1.set_ylim(top=0)

        # Set the grid
        ax1.grid(True)

        # Add labels and title
        ax1.set_xlabel("Source / delays (s)")
        ax1.set_ylabel("Target / number of words")
        plt.suptitle("SimulEval", fontsize=20)
        reference = data.get("reference", "N/A reference data")
        subtitle = f'Reference: "{reference}"'
        ax1.set_title(subtitle, fontsize=10, pad=20)
        # NOTE(review): the range bound uses delays[0]; delays[-1] may have
        # been intended to tile ticks across the whole axis — confirm.
        ax1.set_xticks(
            [
                i * (delays[1] - delays[0])
                for i in range(0, round(delays[0] / (delays[1] - delays[0])))
            ]
            + delays
        )
        ax1.set_yticks(words)

        # Additional text at the bottom: (label, value) pairs laid out in two
        # columns above the first plot.
        additional_text = [
            ("Source", data.get("source", "N/A source data")[0]),
            (
                "Prediction length",
                data.get("prediction_length", "N/A prediction length"),
            ),
            (
                "Sample rate",
                data.get("source", "N/A source data")[1].split(":")[1].strip(),
            ),
            (
                "Duration",
                data.get("source", "N/A source data")[3].split(":")[1].strip(),
            ),
        ]
        start_y = 1.2
        line_space = 0.05
        for i in range(len(additional_text)):
            if i % 2 == 0:
                ax1.text(
                    0,
                    start_y,
                    additional_text[i][0],
                    ha="left",
                    va="top",
                    transform=ax1.transAxes,
                    fontsize=10,
                    bbox=dict(facecolor="white", alpha=0.5),
                )
                ax1.text(
                    0.3,
                    start_y,
                    additional_text[i][1],
                    ha="right",
                    va="top",
                    transform=ax1.transAxes,
                    fontsize=10,
                )
            else:
                ax1.text(
                    0.65,
                    start_y,
                    additional_text[i][0],
                    ha="left",
                    va="top",
                    transform=ax1.transAxes,
                    fontsize=10,
                    bbox=dict(facecolor="white", alpha=0.5),
                )
                ax1.text(
                    0.85,
                    start_y,
                    additional_text[i][1],
                    ha="right",
                    va="top",
                    transform=ax1.transAxes,
                    fontsize=10,
                )
                start_y -= line_space

        ##### 2nd PLOT #####
        # Load the audio file data
        # The audio path is resolved relative to the parent of the run's
        # output directory.
        example_path = Path(self.path).parent
        audio_path = example_path / data["source"][0]
        audio_data, rate = soundfile.read(audio_path)
        length = audio_data.shape[0] / rate
        time = np.linspace(0, length, audio_data.shape[0])

        # Make the figure: share the same x range/ticks as the staircase plot
        ax2.set_xlim(left=0)
        ax2.set_xlim(right=delays[-1])
        ax2.set_xticks(
            [
                i * (delays[1] - delays[0])
                for i in range(0, round(delays[0] / (delays[1] - delays[0])))
            ]
            + delays
        )

        # Plot the waveform
        ax2.plot(time, audio_data)
        ax2.set_xlabel("Delays (s)")
        ax2.set_ylabel("Amplitude")

        # Set the grid
        ax2.grid(True)

        # Write to output/visual/graph_.png
        img_path = Path(self.path) / "visual"
        img_path.mkdir(exist_ok=True, parents=True)
        ## here we use / instead of os.path.join because img_path is a Path object which cannot be joined with this function
        plt.savefig(img_path / ("graph" + str(self.index)))
236 |
--------------------------------------------------------------------------------