├── examples └── .gitkeep ├── VERSION ├── src └── fairseq2 │ ├── py.typed │ ├── models │ ├── __init__.py │ ├── conformer │ │ └── __init__.py │ ├── utils │ │ └── __init__.py │ ├── transformer │ │ └── __init__.py │ ├── nllb │ │ └── __init__.py │ ├── mistral │ │ └── __init__.py │ ├── llama │ │ └── __init__.py │ ├── w2vbert │ │ └── __init__.py │ └── s2t_transformer │ │ └── __init__.py │ ├── nn │ ├── utils │ │ ├── __init__.py │ │ └── grad.py │ ├── transformer │ │ ├── norm_order.py │ │ └── layer_norm.py │ ├── ops.py │ └── __init__.py │ ├── utils │ ├── __init__.py │ ├── version.py │ └── rng.py │ ├── assets │ ├── error.py │ ├── cards │ │ ├── llama.yaml │ │ ├── mistral.yaml │ │ └── s2t_conformer.yaml │ └── __init__.py │ ├── __init__.py │ ├── typing.py │ ├── optim │ ├── __init__.py │ └── optimizer_base.py │ ├── data │ ├── typing.py │ ├── vocabulary_info.py │ ├── text │ │ ├── __init__.py │ │ └── text_reader.py │ ├── __init__.py │ └── cstring.py │ ├── generation │ └── __init__.py │ └── memory.py ├── tests ├── unit │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── audio │ │ │ ├── __init__.py │ │ │ └── test.ogg │ │ ├── text │ │ │ ├── __init__.py │ │ │ └── test.spm │ │ └── data_pipeline │ │ │ ├── __init__.py │ │ │ ├── test_constant.py │ │ │ ├── test_filter.py │ │ │ ├── test_count.py │ │ │ ├── test_collate.py │ │ │ └── test_read_sequence.py │ ├── nn │ │ ├── __init__.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── test_grad.py │ │ │ └── test_mask.py │ │ ├── transformer │ │ │ └── __init__.py │ │ ├── test_padding.py │ │ ├── test_functional.py │ │ └── test_module_list.py │ ├── optim │ │ └── __init__.py │ ├── utils │ │ └── __init__.py │ ├── assets │ │ └── __init__.py │ ├── generation │ │ └── __init__.py │ └── models │ │ ├── __init__.py │ │ └── utils │ │ ├── __init__.py │ │ └── test_arch_registry.py ├── integration │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── fbank.pt │ │ └── test_nllb.py │ └── generation │ │ └── __init__.py ├── __init__.py └── conftest.py ├── 
.github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── typo_doc_issue.md │ ├── question.md │ ├── feature_request.md │ └── bug_report.md ├── workflows │ ├── ci_lint_sh.yaml │ ├── ci_lint_py.yaml │ ├── ci_lint_cc.yaml │ ├── ci_build_doc.yaml │ ├── nightly.yaml │ ├── _lint.yaml │ ├── _lint_sh.yaml │ ├── _publish_pypi.yaml │ └── _publish_doc.yaml └── PULL_REQUEST_TEMPLATE.md ├── doc ├── .gitignore ├── bibliography.rst ├── requirements.txt ├── static │ └── img │ │ └── logo.png ├── templates │ ├── autosummary │ │ ├── data.rst │ │ ├── function.rst │ │ └── class.rst │ └── footer.html ├── reference │ ├── abc.rst │ ├── enums.rst │ ├── functions.rst │ ├── classes.rst │ └── all.rst ├── index.rst └── Makefile ├── fairseq2n ├── python │ ├── requirements-build.txt │ └── src │ │ └── fairseq2n │ │ └── bindings │ │ ├── data │ │ ├── text │ │ │ └── init.cc │ │ └── init.cc │ │ ├── init.cc │ │ ├── type_casters │ │ ├── py.h │ │ ├── string.h │ │ ├── data.h │ │ └── torch.h │ │ └── module.h ├── README.md ├── src │ ├── fairseq2n │ │ ├── api.h │ │ ├── data │ │ │ ├── data_source.cc │ │ │ ├── byte_stream.cc │ │ │ ├── text │ │ │ │ ├── sentencepiece │ │ │ │ │ ├── sentencepiece.h │ │ │ │ │ ├── sp_decoder.h │ │ │ │ │ └── sp_processor.h │ │ │ │ ├── detail │ │ │ │ │ └── utf.h │ │ │ │ ├── string_to_int_converter.h │ │ │ │ ├── text_reader.cc │ │ │ │ ├── string_splitter.h │ │ │ │ ├── string_to_tensor_converter.h │ │ │ │ ├── text_line_reader.h │ │ │ │ ├── string_to_int_converter.cc │ │ │ │ └── text_data_source.h │ │ │ ├── detail │ │ │ │ ├── file_system.h │ │ │ │ ├── tensor_helpers.h │ │ │ │ ├── thread.h │ │ │ │ ├── exception.h │ │ │ │ ├── file.cc │ │ │ │ └── file.h │ │ │ ├── memory_stream.cc │ │ │ ├── constant_data_source.cc │ │ │ ├── data_length_extractor.h │ │ │ ├── element_mapper.h │ │ │ ├── count_data_source.cc │ │ │ ├── memory_stream.h │ │ │ ├── list_data_source.cc │ │ │ ├── list_data_source.h │ │ │ ├── concat_data_source.h │ │ │ ├── tape.cc │ │ │ ├── count_data_source.h │ │ │ ├── constant_data_source.h │ 
│ │ ├── skip_data_source.cc │ │ │ ├── bucket_data_source.h │ │ │ ├── skip_data_source.h │ │ │ ├── data_source.h │ │ │ ├── file_stream.h │ │ │ ├── take_data_source.h │ │ │ ├── py.cc │ │ │ ├── take_data_source.cc │ │ │ ├── filter_data_source.h │ │ │ ├── shard_data_source.h │ │ │ ├── shard_data_source.cc │ │ │ ├── data.cc │ │ │ ├── concat_data_source.cc │ │ │ ├── byte_stream.h │ │ │ ├── round_robin_data_source.h │ │ │ ├── file_mapper.h │ │ │ ├── zip_file_data_source.h │ │ │ ├── sample_data_source.h │ │ │ ├── yield_from_data_source.h │ │ │ ├── shuffle_data_source.h │ │ │ ├── audio │ │ │ │ └── detail │ │ │ │ │ └── kaldi_fbank.h │ │ │ ├── composite_data_source.h │ │ │ ├── element_mapper.cc │ │ │ ├── map_data_source.h │ │ │ ├── bucket_by_length_data_source.h │ │ │ ├── zip_data_source.h │ │ │ ├── filter_data_source.cc │ │ │ ├── round_robin_data_source.cc │ │ │ ├── bucket_data_source.cc │ │ │ └── sample_data_source.cc │ │ ├── exception.cc │ │ ├── detail │ │ │ ├── error.h │ │ │ ├── parallel.h │ │ │ └── exception.h │ │ ├── config.h.in │ │ ├── utils │ │ │ ├── tensor.h │ │ │ └── cast.h │ │ ├── fmt.h │ │ ├── exception.h │ │ ├── float.h │ │ └── memory.cc │ └── fairseq2n-config.cmake.in ├── third-party │ ├── natsort.cmake │ ├── pybind11.cmake │ ├── CMakeLists.txt │ ├── natsort │ │ ├── CMakeLists.txt │ │ └── strnatcmp.h │ ├── gtest.cmake │ ├── fmt.cmake │ ├── zip.cmake │ ├── kaldi-native-fbank.cmake │ └── sentencepiece.cmake ├── LSan.supp ├── cmake │ └── modules │ │ ├── FindClangTidy.cmake │ │ └── FindSndFile.cmake ├── .clang-format ├── tests │ ├── data │ │ ├── test_immutable_string.cc │ │ └── detail │ │ │ └── test_lru_cache.cc │ ├── CMakeLists.txt │ ├── utils │ │ └── test_cast.cc │ └── test_float.cc └── .clang-tidy ├── requirements-devel.txt ├── .gitignore ├── ci ├── problem-matchers │ ├── black.json │ ├── isort.json │ ├── gcc.json │ ├── flake8.json │ └── mypy.json └── docker │ ├── manylinux_x86_64 │ ├── Dockerfile.cu116 │ ├── Dockerfile.cu117 │ ├── Dockerfile.cu118 │ ├── 
Dockerfile.cu121 │ ├── build-scripts │ │ ├── install-cuda-11.6.sh │ │ ├── install-cuda-11.7.sh │ │ ├── install-cuda-11.8.sh │ │ ├── install-cuda-12.1.sh │ │ └── install-llvm.sh │ └── Dockerfile.cpu │ └── build-manylinux-images.sh ├── CHANGELOG.md ├── .gitmodules ├── tools └── run-shellcheck.sh ├── LICENSE └── pyproject.toml /examples/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.2.0.dev0 2 | -------------------------------------------------------------------------------- /src/fairseq2/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/nn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/optim/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fairseq2/models/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /src/fairseq2/nn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fairseq2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/assets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/data/audio/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/data/text/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/generation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/nn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @cbalioglu 
2 | -------------------------------------------------------------------------------- /tests/integration/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/nn/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/generation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/data/data_pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /doc/.gitignore: -------------------------------------------------------------------------------- 1 | # Auto-Generated Sphinx Stub Files 2 | generated/ 3 | -------------------------------------------------------------------------------- /doc/bibliography.rst: -------------------------------------------------------------------------------- 1 | Bibliography 2 | ============ 3 | 4 | .. 
bibliography:: 5 | -------------------------------------------------------------------------------- /doc/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-rtd-theme~=1.2.2 2 | sphinx~=6.2.1 3 | sphinxcontrib-bibtex~=2.5.0 4 | -------------------------------------------------------------------------------- /doc/static/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vaibhavs10/fairseq2/main/doc/static/img/logo.png -------------------------------------------------------------------------------- /tests/unit/data/text/test.spm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vaibhavs10/fairseq2/main/tests/unit/data/text/test.spm -------------------------------------------------------------------------------- /tests/unit/data/audio/test.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vaibhavs10/fairseq2/main/tests/unit/data/audio/test.ogg -------------------------------------------------------------------------------- /tests/integration/models/fbank.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vaibhavs10/fairseq2/main/tests/integration/models/fbank.pt -------------------------------------------------------------------------------- /doc/templates/autosummary/data.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: {{ module }} 2 | 3 | {{ name | escape | underline }} 4 | 5 | .. autodata:: {{ name }} 6 | -------------------------------------------------------------------------------- /doc/templates/autosummary/function.rst: -------------------------------------------------------------------------------- 1 | .. 
currentmodule:: {{ module }} 2 | 3 | {{ name | escape | underline }} 4 | 5 | .. autofunction:: {{ name }} 6 | -------------------------------------------------------------------------------- /fairseq2n/python/requirements-build.txt: -------------------------------------------------------------------------------- 1 | cmake~=3.26 2 | ninja~=1.11 3 | packaging~=23.1 4 | pip~=23.2 5 | setuptools~=67.8 6 | tbb-devel==2021.8;platform_machine=='x86_64' 7 | wheel~=0.40 8 | -------------------------------------------------------------------------------- /doc/reference/abc.rst: -------------------------------------------------------------------------------- 1 | ABCs and Protocols 2 | ================== 3 | .. body 4 | 5 | .. currentmodule:: fairseq2 6 | 7 | .. autosummary:: 8 | :toctree: generated/abc 9 | :nosignatures: 10 | 11 | gang.Gang 12 | -------------------------------------------------------------------------------- /requirements-devel.txt: -------------------------------------------------------------------------------- 1 | black~=23.3 2 | flake8~=6.0 3 | flake8-pyi~=23.5 4 | flake8-pyproject~=1.2 5 | isort~=5.12 6 | mypy~=1.5.1 7 | pytest~=7.3 8 | shellcheck-py~=0.9 9 | types-PyYAML~=6.0 10 | types-setuptools~=67.8 11 | -------------------------------------------------------------------------------- /doc/reference/enums.rst: -------------------------------------------------------------------------------- 1 | Enums 2 | ===== 3 | .. body 4 | 5 | .. currentmodule:: fairseq2 6 | 7 | .. autosummary:: 8 | :toctree: generated/enums 9 | :nosignatures: 10 | 11 | nn.transformer.TransformerNormOrder 12 | -------------------------------------------------------------------------------- /doc/reference/functions.rst: -------------------------------------------------------------------------------- 1 | Functions 2 | ========= 3 | .. body 4 | 5 | .. currentmodule:: fairseq2 6 | 7 | .. 
autosummary:: 8 | :toctree: generated/functions 9 | :nosignatures: 10 | 11 | nn.utils.mask.to_float_mask 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Build Artifacts 2 | build/ 3 | 4 | # Core Dumps 5 | core 6 | 7 | # Byte-Compiled Modules 8 | __pycache__/ 9 | 10 | # Extension Modules 11 | *.so 12 | 13 | # Packaging Artifacts 14 | *.egg-info 15 | *.whl 16 | 17 | # IDEs and Tools 18 | .idea/ 19 | .gdb_history 20 | 21 | # Other 22 | .DS_Store 23 | -------------------------------------------------------------------------------- /doc/templates/footer.html: -------------------------------------------------------------------------------- 1 | {% extends '!footer.html' %} 2 | 3 | {% block extrafooter %} 4 |

Terms of Use, Privacy Policy

5 | 6 |

Copyright © Meta Platforms, Inc.

7 | {% endblock %} 8 | -------------------------------------------------------------------------------- /ci/problem-matchers/black.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "black", 5 | "severity": "error", 6 | "pattern": [ 7 | { 8 | "regexp": "^(would reformat) (.+)$", 9 | "file": 2, 10 | "message": 1 11 | } 12 | ] 13 | } 14 | ] 15 | } 16 | -------------------------------------------------------------------------------- /ci/problem-matchers/isort.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "isort", 5 | "severity": "error", 6 | "pattern": [ 7 | { 8 | "regexp": "^ERROR: (\\S+) (.*)$", 9 | "file": 1, 10 | "message": 2 11 | } 12 | ] 13 | } 14 | ] 15 | } 16 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import pytest 8 | 9 | pytest.register_assert_rewrite("tests.common") 10 | -------------------------------------------------------------------------------- /fairseq2n/README.md: -------------------------------------------------------------------------------- 1 | # fairseq2n: fairseq2 Native Library 2 | 3 | fairseq2n contains the native parts (i.e. C++, CUDA kernels) of fairseq2, while 4 | fairseq2 itself is a pure Python package. Unless you want to contribute some 5 | C++/CUDA work to fairseq2, you can safely ignore fairseq2n and consider it an 6 | implementation detail. 
7 | -------------------------------------------------------------------------------- /doc/templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: {{ module }} 2 | 3 | {{ name | escape | underline }} 4 | 5 | .. autoclass:: {{ name }} 6 | :members: 7 | :member-order: groupwise 8 | :class-doc-from: both 9 | :special-members: __call__, __iter__ 10 | :inherited-members: Module 11 | :show-inheritance: 12 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/api.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #define FAIRSEQ2_API __attribute__((visibility("default"))) 10 | -------------------------------------------------------------------------------- /src/fairseq2/assets/error.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | class AssetError(RuntimeError): 9 | """Raised when an asset operation fails.""" 10 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/typo_doc_issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Typo or Documentation Issue 3 | about: Report a typo or an issue related to documentation. 4 | labels: 'documentation, needs triage' 5 | --- 6 | 7 | For typos, please go ahead; fix the typo and submit a PR. 
For documentation issues, please describe the issue here and wait for approval before submitting a PR. 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Ask a question to the users and contributors. 4 | labels: 'question, needs triage' 5 | --- 6 | 7 | Please make sure that you first search existing issues and documentation before asking a question. If you cannot find an answer, be clear and concise. Ideally attach a minimal code sample if it is relevant to your question. 8 | -------------------------------------------------------------------------------- /doc/reference/classes.rst: -------------------------------------------------------------------------------- 1 | Classes 2 | ======= 3 | .. body 4 | 5 | .. currentmodule:: fairseq2 6 | 7 | .. autosummary:: 8 | :toctree: generated/classes 9 | :nosignatures: 10 | 11 | optim.lr_scheduler.CosineAnnealingLR 12 | optim.lr_scheduler.LRSchedulerBase 13 | optim.lr_scheduler.MyleLR 14 | optim.lr_scheduler.NoamLR 15 | optim.lr_scheduler.PolynomialDecayLR 16 | -------------------------------------------------------------------------------- /fairseq2n/third-party/natsort.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | macro(fairseq2n_add_natsort) 8 | add_subdirectory(${PROJECT_SOURCE_DIR}/third-party/natsort EXCLUDE_FROM_ALL) 9 | endmacro() 10 | -------------------------------------------------------------------------------- /ci/problem-matchers/gcc.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "gcc", 5 | "pattern": [ 6 | { 7 | "regexp": "^(.*):(\\d+):(\\d+):\\s+(?:fatal\\s+)?(warning|error):\\s+(.*)$", 8 | "file": 1, 9 | "line": 2, 10 | "column": 3, 11 | "severity": 4, 12 | "message": 5 13 | } 14 | ] 15 | } 16 | ] 17 | } 18 | -------------------------------------------------------------------------------- /fairseq2n/LSan.supp: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | leak:numpy 8 | leak:pybind11::cpp_function::initialize_generic 9 | leak:pybind11::cpp_function::make_function_record 10 | leak:PyInit_bindings 11 | leak:torch::jit 12 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/data_source.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include "fairseq2n/data/data_source.h" 8 | 9 | namespace fairseq2n { 10 | 11 | data_source::~data_source() = default; 12 | 13 | } 14 | -------------------------------------------------------------------------------- /ci/problem-matchers/flake8.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "flake8", 5 | "severity": "error", 6 | "pattern": [ 7 | { 8 | "regexp": "^([^:]*):(\\d+):(\\d+): ([WEF]\\d\\d\\d) (.*)$", 9 | "file": 1, 10 | "line": 2, 11 | "column": 3, 12 | "code": 4, 13 | "message": 5 14 | } 15 | ] 16 | } 17 | ] 18 | } 19 | -------------------------------------------------------------------------------- /ci/problem-matchers/mypy.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "mypy", 5 | "severity": "error", 6 | "pattern": [ 7 | { 8 | "regexp": "^([^:]*):(\\d+):(?:(\\d+):)? error: (.*?)(?: \\[(\\S+)\\])?$", 9 | "file": 1, 10 | "line": 2, 11 | "column": 3, 12 | "message": 4, 13 | "code": 5 14 | } 15 | ] 16 | } 17 | ] 18 | } 19 | -------------------------------------------------------------------------------- /doc/reference/all.rst: -------------------------------------------------------------------------------- 1 | :tocdepth: 1 2 | 3 | All 4 | === 5 | 6 | ABCs and Protocols 7 | ------------------ 8 | .. include:: abc.rst 9 | :start-after: .. body 10 | 11 | Classes 12 | ------- 13 | .. include:: classes.rst 14 | :start-after: .. body 15 | 16 | Enums 17 | ----- 18 | .. include:: enums.rst 19 | :start-after: .. body 20 | 21 | Functions 22 | --------- 23 | .. include:: functions.rst 24 | :start-after: .. body 25 | -------------------------------------------------------------------------------- /.github/workflows/ci_lint_sh.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | name: CI 8 | 9 | on: 10 | pull_request: 11 | paths: 12 | - '**.sh' 13 | 14 | jobs: 15 | lint_sh: 16 | name: Lint shell scripts 17 | uses: ./.github/workflows/_lint_sh.yaml 18 | -------------------------------------------------------------------------------- /fairseq2n/third-party/pybind11.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | macro(fairseq2n_add_pybind11) 8 | if(NOT TARGET pybind11::module) 9 | add_subdirectory(${PROJECT_SOURCE_DIR}/third-party/pybind11 EXCLUDE_FROM_ALL) 10 | endif() 11 | endmacro() 12 | -------------------------------------------------------------------------------- /fairseq2n/third-party/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | include(fmt.cmake) 8 | include(gtest.cmake) 9 | include(kaldi-native-fbank.cmake) 10 | include(natsort.cmake) 11 | include(pybind11.cmake) 12 | include(sentencepiece.cmake) 13 | include(zip.cmake) 14 | -------------------------------------------------------------------------------- /src/fairseq2/models/conformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from fairseq2.models.conformer.block import ConformerBlock as ConformerBlock 8 | from fairseq2.models.conformer.convolution import ( 9 | ConformerConvolution as ConformerConvolution, 10 | ) 11 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/exception.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "fairseq2n/exception.h" 8 | 9 | namespace fairseq2n { 10 | 11 | internal_error::~internal_error() = default; 12 | 13 | not_supported_error::~not_supported_error() = default; 14 | 15 | } 16 | -------------------------------------------------------------------------------- /.github/workflows/ci_lint_py.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | name: CI 8 | 9 | on: 10 | pull_request: 11 | paths: 12 | - '**.py' 13 | - '**.pyi' 14 | 15 | jobs: 16 | lint_py: 17 | name: Lint Python 18 | uses: ./.github/workflows/_lint_py.yaml 19 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/byte_stream.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 
3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "fairseq2n/data/byte_stream.h" 8 | 9 | namespace fairseq2n { 10 | 11 | byte_stream::~byte_stream() = default; 12 | 13 | byte_stream_error::~byte_stream_error() = default; 14 | 15 | } 16 | -------------------------------------------------------------------------------- /src/fairseq2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | # We import fairseq2n to report any initialization error eagerly. 9 | import fairseq2n 10 | 11 | __version__ = "0.2.0.dev0" 12 | 13 | 14 | # If ``True``, indicates that we are run under Sphinx. 15 | _DOC_MODE = False 16 | -------------------------------------------------------------------------------- /.github/workflows/ci_lint_cc.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | name: CI 8 | 9 | on: 10 | pull_request: 11 | paths: 12 | - '**.h' 13 | - '**.cc' 14 | - '**.cu' 15 | 16 | jobs: 17 | lint_cc: 18 | name: Lint C++ 19 | uses: ./.github/workflows/_lint_cc.yaml 20 | -------------------------------------------------------------------------------- /.github/workflows/ci_build_doc.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | name: CI 8 | 9 | on: 10 | pull_request: 11 | paths-ignore: 12 | - '**.md' 13 | - 'ci/**' 14 | 15 | jobs: 16 | build_doc: 17 | name: Build documentation 18 | uses: ./.github/workflows/_build_doc.yaml 19 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n-config.cmake.in: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | @PACKAGE_INIT@ 8 | 9 | include(CMakeFindDependencyMacro) 10 | 11 | find_dependency(Torch @TORCH_VERSION@) 12 | 13 | include(${CMAKE_CURRENT_LIST_DIR}/fairseq2-targets.cmake) 14 | 15 | check_required_components(fairseq2) 16 | -------------------------------------------------------------------------------- /.github/workflows/nightly.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | name: Nightly 8 | 9 | on: 10 | # At 1:15AM UTC on every Monday. 11 | schedule: 12 | - cron: '15 1 * * 1' 13 | 14 | jobs: 15 | release: 16 | uses: ./.github/workflows/release.yaml 17 | with: 18 | release_type: 'nightly' 19 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/text/sentencepiece/sentencepiece.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 
3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include "fairseq2n/data/text/sentencepiece/sp_decoder.h" 10 | #include "fairseq2n/data/text/sentencepiece/sp_encoder.h" 11 | #include "fairseq2n/data/text/sentencepiece/sp_model.h" 12 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to fairseq2 are documented in this file. 3 | 4 | The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 5 | 6 | ## [0.2] - TBD 7 | - Introduced LLaMA and LLaMA 2 models 8 | - Introduced GLUFeedForwardNetwork 9 | 10 | ## [0.1.1] - 2023-09-07 11 | - Improvements to the build system and CI pipelines 12 | - Improvements to the installation instructions and contribution guidelines 13 | 14 | ## [0.1.0] - 2023-08-23 15 | - Initial release 16 | -------------------------------------------------------------------------------- /fairseq2n/cmake/modules/FindClangTidy.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | include(FindPackageHandleStandardArgs) 8 | 9 | find_program(CLANG_TIDY_EXECUTABLE NAMES clang-tidy) 10 | 11 | mark_as_advanced(CLANG_TIDY_EXECUTABLE) 12 | 13 | find_package_handle_standard_args(ClangTidy REQUIRED_VARS CLANG_TIDY_EXECUTABLE) 14 | -------------------------------------------------------------------------------- /src/fairseq2/typing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from overrides import final 8 | from overrides import override as override # noqa: F401 9 | from torch import device, dtype 10 | from typing_extensions import TypeAlias 11 | 12 | finaloverride = final 13 | 14 | Device: TypeAlias = device 15 | 16 | DataType: TypeAlias = dtype 17 | -------------------------------------------------------------------------------- /ci/docker/manylinux_x86_64/Dockerfile.cu116: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | FROM ghcr.io/facebookresearch/fairseq2-ci-manylinux_x86_64:2-cpu 8 | 9 | # Install CUDA. 10 | COPY build-scripts/install-cuda-11.6.sh /build-scripts/ 11 | 12 | RUN /build-scripts/install-cuda-11.6.sh && rm -rf /build-scripts 13 | 14 | ENV PATH=/usr/local/cuda-11.6/bin:$PATH 15 | -------------------------------------------------------------------------------- /ci/docker/manylinux_x86_64/Dockerfile.cu117: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | FROM ghcr.io/facebookresearch/fairseq2-ci-manylinux_x86_64:2-cpu 8 | 9 | # Install CUDA. 
10 | COPY build-scripts/install-cuda-11.7.sh /build-scripts/ 11 | 12 | RUN /build-scripts/install-cuda-11.7.sh && rm -rf /build-scripts 13 | 14 | ENV PATH=/usr/local/cuda-11.7/bin:$PATH 15 | -------------------------------------------------------------------------------- /ci/docker/manylinux_x86_64/Dockerfile.cu118: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | FROM ghcr.io/facebookresearch/fairseq2-ci-manylinux_x86_64:2-cpu 8 | 9 | # Install CUDA. 10 | COPY build-scripts/install-cuda-11.8.sh /build-scripts/ 11 | 12 | RUN /build-scripts/install-cuda-11.8.sh && rm -rf /build-scripts 13 | 14 | ENV PATH=/usr/local/cuda-11.8/bin:$PATH 15 | -------------------------------------------------------------------------------- /ci/docker/manylinux_x86_64/Dockerfile.cu121: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | FROM ghcr.io/facebookresearch/fairseq2-ci-manylinux_x86_64:2-cpu 8 | 9 | # Install CUDA. 
10 | COPY build-scripts/install-cuda-12.1.sh /build-scripts/ 11 | 12 | RUN /build-scripts/install-cuda-12.1.sh && rm -rf /build-scripts 13 | 14 | ENV PATH=/usr/local/cuda-12.1/bin:$PATH 15 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/facebookresearch/fairseq2 2 | 3 | 4 | fairseq2 documentation 5 | ====================== 6 | 7 | fairseq2 is a sequence modeling toolkit that allows researchers and developers 8 | to train custom models for translation, summarization, language modeling, and 9 | other content generation tasks. 10 | 11 | .. toctree:: 12 | :caption: fairseq2 Reference 13 | :maxdepth: 1 14 | 15 | reference/data 16 | reference/all 17 | 18 | .. toctree:: 19 | :maxdepth: 1 20 | :caption: Misc 21 | 22 | bibliography 23 | -------------------------------------------------------------------------------- /fairseq2n/third-party/natsort/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | add_library(natsort OBJECT strnatcmp.c) 8 | 9 | set_target_properties(natsort PROPERTIES 10 | C_VISIBILITY_PRESET 11 | hidden 12 | POSITION_INDEPENDENT_CODE 13 | ON 14 | ) 15 | 16 | target_include_directories(natsort SYSTEM INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) 17 | -------------------------------------------------------------------------------- /src/fairseq2/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
// Returns the current value of `errno` as a `std::error_code` that belongs to
// the generic error category.
inline std::error_code
last_error() noexcept
{
    // Read `errno` once so that the reported value is explicit.
    const int code = errno;

    return {code, std::generic_category()};
}
class TestConstantOp:
    """Tests for the ``DataPipeline.constant()`` factory."""

    def test_op_works(self) -> None:
        """Check that a constant pipeline limited by ``take(10)`` yields ten
        copies of the example, both on the first pass and after a reset."""
        pipeline = DataPipeline.constant("foo").take(10).and_return()

        for _ in range(2):
            # The comparison has to be asserted. A bare ``==`` expression is
            # evaluated and thrown away, which means the test can never fail.
            assert list(pipeline) == ["foo"] * 10

            # Rewind so that the second iteration reads the pipeline again
            # from the start.
            pipeline.reset()
2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "fairseq2n/data/memory_stream.h" 8 | 9 | namespace fairseq2n::detail { 10 | 11 | memory_block 12 | memory_stream::read_chunk() 13 | { 14 | return std::exchange(block_, {}); 15 | } 16 | 17 | void 18 | memory_stream::reset() 19 | { 20 | block_ = original_block_; 21 | } 22 | 23 | } // namespace fairseq2n::detail 24 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/text/detail/utf.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "fairseq2n/memory.h" 14 | 15 | namespace fairseq2n::detail { 16 | 17 | std::size_t 18 | compute_code_point_length(std::string_view s); 19 | 20 | std::string 21 | infer_bom_encoding(memory_span preamble) noexcept; 22 | 23 | } // namespace fairseq2n::detail 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Submit a request for a new feature. 4 | labels: 'enhancement, needs triage' 5 | --- 6 | 7 | **Is your feature request related to a problem? Please describe:** 8 | A clear and concise description of what the problem is. 9 | 10 | **Describe the solution you would like:** 11 | A clear and concise description of what you want to happen. 
12 | 13 | **Describe the alternatives you have considered:** 14 | A clear and concise description of any alternative solutions or features you have considered. 15 | 16 | **Additional Context:** 17 | Add any other context about the feature request here. 18 | -------------------------------------------------------------------------------- /src/fairseq2/assets/cards/llama.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | name: llama 8 | model_type: llama 9 | checkpoint: "https://ai.meta.com/llama/;gated=true" 10 | tokenizer: "https://ai.meta.com/llama/;gated=true" 11 | 12 | --- 13 | 14 | name: llama_7b 15 | base: llama 16 | model_arch: 7b 17 | 18 | --- 19 | 20 | name: llama2_7b 21 | base: llama 22 | model_arch: llama2_7b 23 | 24 | --- 25 | 26 | name: llama2_7b_chat 27 | base: llama 28 | model_arch: llama2_7b 29 | -------------------------------------------------------------------------------- /fairseq2n/python/src/fairseq2n/bindings/data/text/init.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include "fairseq2n/bindings/module.h" 8 | 9 | namespace py = pybind11; 10 | 11 | namespace fairseq2n { 12 | 13 | void 14 | def_text(py::module_ &data_module) 15 | { 16 | py::module_ m = data_module.def_submodule("text"); 17 | 18 | def_sentencepiece(m); 19 | 20 | def_text_reader(m); 21 | 22 | def_text_converters(m); 23 | } 24 | 25 | } // namespace fairseq2n 26 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /fairseq2n/third-party/gtest.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | macro(fairseq2n_add_gtest) 8 | if(NOT TARGET GTest::gtest_main) 9 | set(INSTALL_GTEST OFF) 10 | 11 | add_subdirectory(${PROJECT_SOURCE_DIR}/third-party/gtest EXCLUDE_FROM_ALL) 12 | 13 | # We depend on the phony torch_cxx11_abi target to ensure that we use 14 | # the same libstdc++ ABI as PyTorch. 
15 | target_link_libraries(gtest PRIVATE torch_cxx11_abi) 16 | endif() 17 | 18 | include(GoogleTest) 19 | endmacro() 20 | -------------------------------------------------------------------------------- /src/fairseq2/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from fairseq2.models.utils.arch_registry import ( 8 | ArchitectureRegistry as ArchitectureRegistry, 9 | ) 10 | from fairseq2.models.utils.generic_loaders import ConfigLoader as ConfigLoader 11 | from fairseq2.models.utils.generic_loaders import ModelLoader as ModelLoader 12 | from fairseq2.models.utils.generic_loaders import TokenizerLoader as TokenizerLoader 13 | from fairseq2.models.utils.generic_loaders import ( 14 | TokenizerLoaderBase as TokenizerLoaderBase, 15 | ) 16 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Submit a report to help us improve. 4 | labels: 'bug, needs triage' 5 | --- 6 | 7 | **Describe the bug:** 8 | A clear and concise description of what the bug is. 9 | 10 | **Describe how to reproduce:** 11 | Steps to reproduce the behavior. Ideally attach a minimal code sample to reproduce the described issue. 12 | 13 | **Describe the expected behavior:** 14 | A clear and concise description of what you expected to happen. 15 | 16 | **Environment:** 17 | At the very least, specify the versions of fairseq2, PyTorch, Python, and CUDA along with your operating system and, if relevant, GPU model. 18 | 19 | **Additional Context:** 20 | Add any other context about the bug here. 
21 | -------------------------------------------------------------------------------- /ci/docker/manylinux_x86_64/build-scripts/install-cuda-11.6.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -eo pipefail 10 | 11 | curl --location --fail --output cuda.run\ 12 | https://developer.download.nvidia.com/compute/cuda/11.6.0/local_installers/cuda_11.6.0_510.39.01_linux.run 13 | 14 | sh cuda.run --silent --toolkit --override --no-man-page 15 | 16 | rm cuda.run 17 | 18 | # We don't need Nsight. 19 | rm -rf /usr/local/cuda-11.6/nsight* 20 | 21 | # Add CUDA libraries to the lookup cache of the dynamic linker. 22 | ldconfig 23 | -------------------------------------------------------------------------------- /ci/docker/manylinux_x86_64/build-scripts/install-cuda-11.7.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -eo pipefail 10 | 11 | curl --location --fail --output cuda.run\ 12 | https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run 13 | 14 | sh cuda.run --silent --toolkit --override --no-man-page 15 | 16 | rm cuda.run 17 | 18 | # We don't need Nsight. 19 | rm -rf /usr/local/cuda-11.7/nsight* 20 | 21 | # Add CUDA libraries to the lookup cache of the dynamic linker. 
22 | ldconfig 23 | -------------------------------------------------------------------------------- /ci/docker/manylinux_x86_64/build-scripts/install-cuda-11.8.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -eo pipefail 10 | 11 | curl --location --fail --output cuda.run\ 12 | https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run 13 | 14 | sh cuda.run --silent --toolkit --override --no-man-page 15 | 16 | rm cuda.run 17 | 18 | # We don't need Nsight. 19 | rm -rf /usr/local/cuda-11.8/nsight* 20 | 21 | # Add CUDA libraries to the lookup cache of the dynamic linker. 22 | ldconfig 23 | -------------------------------------------------------------------------------- /ci/docker/manylinux_x86_64/build-scripts/install-cuda-12.1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -eo pipefail 10 | 11 | curl --location --fail --output cuda.run\ 12 | https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run 13 | 14 | sh cuda.run --silent --toolkit --override --no-man-page 15 | 16 | rm cuda.run 17 | 18 | # We don't need Nsight. 19 | rm -rf /usr/local/cuda-12.1/nsight* 20 | 21 | # Add CUDA libraries to the lookup cache of the dynamic linker. 
// Callable object that holds a numeric base (10 by default) and maps one
// `data` value to another through its call operator.
//
// NOTE(review): presumably parses a string element into an integer using
// `base_`; the definition of `operator()` is not part of this header, so
// confirm against string_to_int_converter.cc.
class FAIRSEQ2_API string_to_int_converter final {
public:
    // Stores `base` in `base_`.
    explicit
    string_to_int_converter(std::int16_t base = 10) noexcept
      : base_{base}
    {}

    // Takes `d` by rvalue reference and returns the resulting `data`.
    data
    operator()(data &&d) const;

private:
    // The base passed to the constructor.
    std::int16_t base_;
};
-------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/constant_data_source.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "fairseq2n/data/constant_data_source.h" 8 | 9 | #include "fairseq2n/data/data.h" 10 | 11 | namespace fairseq2n::detail { 12 | 13 | std::optional 14 | constant_data_source::next() 15 | { 16 | if (key_) 17 | return data_dict{{*key_, example_}}; 18 | 19 | return example_; 20 | } 21 | 22 | void 23 | constant_data_source::reset() 24 | {} 25 | 26 | void 27 | constant_data_source::record_position(tape &) const 28 | {} 29 | 30 | void 31 | constant_data_source::reload_position(tape &) 32 | {} 33 | 34 | } 35 | -------------------------------------------------------------------------------- /src/fairseq2/utils/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
def _get_torch_version() -> Version:
    """Return the parsed version of the imported PyTorch package.

    If ``torch.__version__`` cannot be parsed as a version, ``0.0.0`` is
    returned instead of propagating the error.
    """
    try:
        parsed = version.parse(torch.__version__)
    except InvalidVersion:
        return Version("0.0.0")
    else:
        return parsed


TORCH_VERSION: Final[Version] = _get_torch_version()


def is_pt2_or_greater() -> bool:
    """Return ``True`` if the version of PyTorch is 2.0 or greater."""
    major = TORCH_VERSION.major

    return major >= 2
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | namespace fairseq2n { 13 | 14 | constexpr std::int32_t version_major = @PROJECT_VERSION_MAJOR@; 15 | constexpr std::int32_t version_minor = @PROJECT_VERSION_MINOR@; 16 | constexpr std::int32_t version_patch = @PROJECT_VERSION_PATCH@; 17 | 18 | constexpr std::optional cuda_version_major = @CUDA_VERSION_MAJOR@; 19 | constexpr std::optional cuda_version_minor = @CUDA_VERSION_MINOR@; 20 | 21 | constexpr bool supports_cuda = cuda_version_major.has_value(); 22 | 23 | } // namespace fairseq2n 24 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/data_length_extractor.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "fairseq2n/api.h" 14 | #include "fairseq2n/data/element_selector.h" 15 | 16 | namespace fairseq2n { 17 | 18 | class data; 19 | 20 | class FAIRSEQ2_API data_length_extractor { 21 | public: 22 | explicit 23 | data_length_extractor(std::optional maybe_selector); 24 | 25 | std::size_t 26 | operator()(const data &d) const; 27 | 28 | private: 29 | std::optional maybe_selector_{}; 30 | }; 31 | 32 | } // namespace fairseq2n 33 | -------------------------------------------------------------------------------- /ci/docker/build-manylinux-images.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 9 | set -eo pipefail 10 | 11 | repo=ghcr.io/facebookresearch 12 | 13 | arch=x86_64 14 | 15 | version=2 16 | 17 | declare -a variants=(cpu cu116 cu117 cu118 cu121) 18 | 19 | for variant in "${variants[@]}"; do 20 | docker build\ 21 | --network host\ 22 | --tag $repo/fairseq2-ci-manylinux_$arch:$version-$variant\ 23 | --file manylinux_$arch/Dockerfile.$variant\ 24 | manylinux_$arch/ 25 | done 26 | 27 | for variant in "${variants[@]}"; do 28 | docker push $repo/fairseq2-ci-manylinux_$arch:$version-$variant 29 | done 30 | 31 | docker logout ghcr.io 32 | -------------------------------------------------------------------------------- /fairseq2n/.clang-format: -------------------------------------------------------------------------------- 1 | Language: Cpp 2 | BasedOnStyle: LLVM 3 | AccessModifierOffset: -4 4 | AllowShortFunctionsOnASingleLine: None 5 | AllowShortLambdasOnASingleLine: Empty 6 | AlwaysBreakAfterReturnType: All 7 | AlwaysBreakTemplateDeclarations: Yes 8 | AttributeMacros: [] 9 | BitFieldColonSpacing: After 10 | BreakBeforeBraces: Custom 11 | BraceWrapping: 12 | AfterFunction: true 13 | SplitEmptyFunction: false 14 | SplitEmptyRecord: false 15 | SplitEmptyNamespace: false 16 | ColumnLimit: 100 17 | ForEachMacros: [] 18 | IncludeIsMainRegex: "^(test_)?" 19 | IndentGotoLabels: false 20 | IndentWidth: 4 21 | PointerAlignment: Right 22 | RawStringFormats: [] 23 | SpaceAfterCStyleCast: true 24 | SpacesBeforeTrailingComments: 2 25 | SpacesInContainerLiterals: false 26 | StatementAttributeLikeMacros: [] 27 | StatementMacros: [] 28 | WhitespaceSensitiveMacros: 29 | - static_assert 30 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/element_mapper.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 
3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "fairseq2n/api.h" 14 | #include "fairseq2n/data/data_pipeline.h" 15 | #include "fairseq2n/data/element_selector.h" 16 | 17 | namespace fairseq2n { 18 | 19 | class FAIRSEQ2_API element_mapper { 20 | public: 21 | explicit 22 | element_mapper(map_fn fn, std::optional maybe_selector = {}); 23 | 24 | data 25 | operator()(data &&d); 26 | 27 | private: 28 | map_fn map_fn_; 29 | std::optional maybe_selector_{}; 30 | }; 31 | 32 | } // namespace fairseq2n 33 | -------------------------------------------------------------------------------- /ci/docker/manylinux_x86_64/Dockerfile.cpu: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | FROM quay.io/pypa/manylinux2014_x86_64 8 | 9 | # Install system dependencies. 10 | RUN yum --assumeyes install\ 11 | devtoolset-10-lib{asan,lsan,ubsan,tsan}-devel lib{sndfile,png,jpeg}-devel &&\ 12 | yum clean all 13 | 14 | # Install Ninja. 15 | RUN pipx install --pip-args=--no-cache-dir ninja 16 | 17 | # Install LLVM. 18 | COPY build-scripts/install-llvm.sh /build-scripts/ 19 | 20 | RUN /build-scripts/install-llvm.sh && rm -rf /build-scripts 21 | 22 | # Path to sanitizer libs. Used by the CI tests. 23 | ENV LIBASAN=/usr/lib64/libasan.so.6 24 | ENV LIBTSAN=/usr/lib64/libtsan.so.0 25 | 26 | CMD ["/bin/bash"] 27 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/utils/tensor.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | namespace fairseq2n::detail { 15 | 16 | template 17 | inline at::Tensor 18 | make_tensor_from_vector( 19 | const std::vector &src, 20 | const std::initializer_list &shape) noexcept 21 | { 22 | auto storage = std::make_shared>(src); 23 | 24 | return at::from_blob( 25 | storage->data(), 26 | c10::ArrayRef(shape), 27 | [storage](void*) mutable { storage.reset(); } 28 | ); 29 | } 30 | 31 | } // namespace fairseq2n::detail 32 | -------------------------------------------------------------------------------- /src/fairseq2/nn/transformer/norm_order.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from enum import Enum 8 | 9 | 10 | class TransformerNormOrder(Enum): 11 | """Specifies the Layer Normalization order.""" 12 | 13 | POST = 0 14 | """Apply Layer Normalization after each layer's residual connection as 15 | described in :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`.""" 16 | 17 | PRE = 1 18 | """Apply Layer Normalization at the beginning of each layer as described in 19 | :cite:t:`https://doi.org/10.48550/arxiv.2002.04745`.""" 20 | 21 | PRE_WITH_NORMFORMER = 2 22 | """Apply Layer Normalization as described in 23 | :cite:t:`https://doi.org/10.48550/arxiv.2110.09456`.""" 24 | -------------------------------------------------------------------------------- /ci/docker/manylinux_x86_64/build-scripts/install-llvm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -eo pipefail 10 | 11 | git clone --depth 1 --recurse-submodules --shallow-submodules --branch llvmorg-15.0.3\ 12 | https://github.com/llvm/llvm-project.git /llvm 13 | 14 | cmake\ 15 | -GNinja\ 16 | -S /llvm/llvm\ 17 | -B /llvm-build\ 18 | -DCMAKE_BUILD_TYPE=Release\ 19 | -DLLVM_ENABLE_PROJECTS="clang;clang-tools-extra;openmp"\ 20 | -DLLVM_TARGETS_TO_BUILD=host\ 21 | -Wno-dev 22 | 23 | cmake --build /llvm-build && cmake --install /llvm-build 24 | 25 | cp /llvm/clang/tools/clang-format/git-clang-format /usr/local/bin 26 | 27 | rm -rf /llvm /llvm-build 28 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/count_data_source.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 
3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "fairseq2n/data/count_data_source.h" 8 | 9 | #include "fairseq2n/data/data.h" 10 | 11 | namespace fairseq2n::detail { 12 | 13 | std::optional 14 | count_data_source::next() 15 | { 16 | if (key_) 17 | return data_dict{{*key_, counter_++}}; 18 | 19 | return counter_++; 20 | } 21 | 22 | void 23 | count_data_source::reset() 24 | { 25 | counter_ = start_; 26 | } 27 | 28 | void 29 | count_data_source::record_position(tape &t) const 30 | { 31 | t.record(counter_); 32 | } 33 | 34 | void 35 | count_data_source::reload_position(tape &t) 36 | { 37 | counter_ = t.read(); 38 | } 39 | 40 | } // namespace fairseq2n::detail 41 | -------------------------------------------------------------------------------- /src/fairseq2/data/typing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os # noqa: F401 8 | from typing import Any, Union 9 | 10 | from typing_extensions import TypeAlias, TypeGuard 11 | 12 | from fairseq2.data.cstring import CString 13 | 14 | # A type alias for pathnames as recommended in PEP 519. 15 | PathLike: TypeAlias = Union[str, CString, "os.PathLike[str]"] 16 | 17 | 18 | # A convenience type alias for strings since most of our data APIs accept both 19 | # `str` and `CString`. 
20 | StringLike: TypeAlias = Union[str, CString] 21 | 22 | 23 | def is_string_like(s: Any) -> TypeGuard[StringLike]: 24 | """Return ``True`` if ``s`` is of type ``str`` or :class:`CString`.""" 25 | return isinstance(s, (str, CString)) 26 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/memory_stream.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | #include "fairseq2n/memory.h" 12 | #include "fairseq2n/data/byte_stream.h" 13 | 14 | namespace fairseq2n::detail { 15 | 16 | class memory_stream final : public byte_stream { 17 | public: 18 | explicit 19 | memory_stream(memory_block block) noexcept 20 | : block_{std::move(block)} 21 | { 22 | original_block_ = block_; 23 | } 24 | 25 | memory_block 26 | read_chunk() override; 27 | 28 | void 29 | reset() override; 30 | 31 | private: 32 | memory_block block_; 33 | memory_block original_block_; 34 | }; 35 | 36 | } // namespace fairseq2n::detail 37 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/list_data_source.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include "fairseq2n/data/list_data_source.h" 8 | 9 | #include 10 | 11 | namespace fairseq2n::detail { 12 | 13 | std::optional 14 | list_data_source::next() 15 | { 16 | if (iter_ == list_.end()) 17 | return std::nullopt; 18 | 19 | return *iter_++; 20 | } 21 | 22 | void 23 | list_data_source::reset() 24 | { 25 | iter_ = list_.begin(); 26 | } 27 | 28 | void 29 | list_data_source::record_position(tape &t) const 30 | { 31 | t.record(iter_ - list_.begin()); 32 | } 33 | 34 | void 35 | list_data_source::reload_position(tape &t) 36 | { 37 | iter_ = list_.begin() + t.read(); 38 | } 39 | 40 | } // namespace fairseq2n::detail 41 | -------------------------------------------------------------------------------- /src/fairseq2/assets/cards/mistral.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | name: mistral_7b 8 | model_type: mistral 9 | model_arch: 7b 10 | checkpoint: "https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar;path=mistral-7B-v0.1%2Fconsolidated.00.pth" 11 | tokenizer: "https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar;path=mistral-7B-v0.1%2Ftokenizer.model" 12 | 13 | --- 14 | 15 | name: mistral_7b_instruct 16 | model_type: mistral 17 | model_arch: 7b 18 | checkpoint: "https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-instruct-v0.1b.tar;path=Mistral-7B-instruct-v0.1%2Fconsolidated.00.pth" 19 | tokenizer: "https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-instruct-v0.1b.tar;path=Mistral-7B-instruct-v0.1%2Ftokenizer.model" 20 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/text/text_reader.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "fairseq2n/data/text/text_reader.h" 8 | 9 | #include 10 | 11 | #include "fairseq2n/data/data_pipeline.h" 12 | #include "fairseq2n/data/text/text_data_source.h" 13 | 14 | using namespace fairseq2n::detail; 15 | 16 | namespace fairseq2n { 17 | 18 | data_pipeline_builder 19 | read_text(std::string pathname, text_options opts) 20 | { 21 | auto factory = [pathname = std::move(pathname), opts = std::move(opts)]() mutable 22 | { 23 | return std::make_unique(std::move(pathname), std::move(opts)); 24 | }; 25 | 26 | return data_pipeline_builder{std::move(factory)}; 27 | } 28 | 29 | } // namespace fairseq2n 30 | -------------------------------------------------------------------------------- /fairseq2n/third-party/fmt.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | macro(fairseq2n_add_fmt) 8 | if(NOT TARGET fmt::fmt) 9 | set(FMT_SYSTEM_HEADERS ON) 10 | 11 | add_subdirectory(${PROJECT_SOURCE_DIR}/third-party/fmt EXCLUDE_FROM_ALL) 12 | 13 | target_compile_features(fmt PRIVATE cxx_std_17) 14 | 15 | set_target_properties(fmt PROPERTIES 16 | CXX_VISIBILITY_PRESET 17 | hidden 18 | POSITION_INDEPENDENT_CODE 19 | ON 20 | ) 21 | 22 | # We depend on the phony torch_cxx11_abi target to ensure that we use 23 | # the same libstdc++ ABI as PyTorch. 
24 | target_link_libraries(fmt PRIVATE torch_cxx11_abi) 25 | endif() 26 | endmacro() 27 | -------------------------------------------------------------------------------- /src/fairseq2/models/nllb/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from fairseq2.models.nllb.builder import NllbBuilder as NllbBuilder 8 | from fairseq2.models.nllb.builder import NllbConfig as NllbConfig 9 | from fairseq2.models.nllb.builder import create_nllb_model as create_nllb_model 10 | from fairseq2.models.nllb.builder import nllb_arch as nllb_arch 11 | from fairseq2.models.nllb.builder import nllb_archs as nllb_archs 12 | from fairseq2.models.nllb.loader import load_nllb_config as load_nllb_config 13 | from fairseq2.models.nllb.loader import load_nllb_model as load_nllb_model 14 | from fairseq2.models.nllb.loader import load_nllb_tokenizer as load_nllb_tokenizer 15 | from fairseq2.models.nllb.tokenizer import NllbTokenizer as NllbTokenizer 16 | -------------------------------------------------------------------------------- /src/fairseq2/data/vocabulary_info.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from dataclasses import dataclass 8 | from typing import Optional 9 | 10 | 11 | @dataclass 12 | class VocabularyInfo: 13 | """Describes the vocabulary used by a tokenizer.""" 14 | 15 | size: int 16 | """The size of the vocabulary.""" 17 | 18 | unk_idx: Optional[int] 19 | """The index of the symbol that represents an unknown element.""" 20 | 21 | bos_idx: Optional[int] 22 | """The index of the symbol that represents the beginning of a sequence.""" 23 | 24 | eos_idx: Optional[int] 25 | """The index of the symbol that represents the end of a sequence.""" 26 | 27 | pad_idx: Optional[int] 28 | """The index of the symbol that is used to pad a sequence.""" 29 | -------------------------------------------------------------------------------- /fairseq2n/tests/data/test_immutable_string.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | using namespace fairseq2n; 15 | 16 | // Also see the Python tests. 17 | TEST(test_immutable_string, constructor_throws_exception_when_string_is_invalid_utf8) 18 | { 19 | immutable_string s{"\xfe\xfe\xff\xff"}; 20 | 21 | EXPECT_THROW(s.get_code_point_length(), invalid_utf8_error); 22 | } 23 | 24 | TEST(test_immutable_string, copy_constructor_works) 25 | { 26 | immutable_string s1 = "foo"; 27 | 28 | immutable_string s2 = s1; // NOLINT(performance-unnecessary-copy-initialization) 29 | 30 | EXPECT_EQ(s1.data(), s2.data()); 31 | } 32 | -------------------------------------------------------------------------------- /src/fairseq2/models/mistral/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from fairseq2.models.mistral.builder import MistralBuilder as MistralBuilder 8 | from fairseq2.models.mistral.builder import MistralConfig as MistralConfig 9 | from fairseq2.models.mistral.builder import create_mistral_model as create_mistral_model 10 | from fairseq2.models.mistral.builder import mistral_archs as mistral_archs 11 | from fairseq2.models.mistral.loader import load_mistral_config as load_mistral_config 12 | from fairseq2.models.mistral.loader import load_mistral_model as load_mistral_model 13 | from fairseq2.models.mistral.loader import ( 14 | load_mistral_tokenizer as load_mistral_tokenizer, 15 | ) 16 | from fairseq2.models.mistral.tokenizer import MistralTokenizer as MistralTokenizer 17 | -------------------------------------------------------------------------------- /tools/run-shellcheck.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 9 | set -eo pipefail 10 | 11 | function print_usage 12 | { 13 | echo "Usage: run-shellcheck PATHNAME" 14 | } 15 | 16 | function exit_with_usage 17 | { 18 | print_usage >&1 19 | 20 | exit 0 21 | } 22 | 23 | function exit_with_error 24 | { 25 | print_usage >&2 26 | 27 | exit 1 28 | } 29 | 30 | 31 | if [[ $# -ne 1 ]]; then 32 | exit_with_error 33 | fi 34 | 35 | if [[ $1 == -h || $1 == --help ]]; then 36 | exit_with_usage 37 | fi 38 | 39 | find "$1" \( -type d \( -name '.?*' -or -name 'build' -or -name 'third-party' \) -prune \) -or\ 40 | \( -type f -name '*.sh' -print0 \) | 41 | xargs --no-run-if-empty --null shellcheck -f gcc --severity=warning --norc 42 | -------------------------------------------------------------------------------- /src/fairseq2/models/llama/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from fairseq2.models.llama.builder import LLaMABuilder as LLaMABuilder 8 | from fairseq2.models.llama.builder import LLaMAConfig as LLaMAConfig 9 | from fairseq2.models.llama.builder import create_llama_model as create_llama_model 10 | from fairseq2.models.llama.builder import get_llama_lora_config as get_llama_lora_config 11 | from fairseq2.models.llama.builder import llama_archs as llama_archs 12 | from fairseq2.models.llama.loader import load_llama_config as load_llama_config 13 | from fairseq2.models.llama.loader import load_llama_model as load_llama_model 14 | from fairseq2.models.llama.loader import load_llama_tokenizer as load_llama_tokenizer 15 | from fairseq2.models.llama.tokenizer import LLaMATokenizer as LLaMATokenizer 16 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/list_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include "fairseq2n/data/data_source.h" 13 | 14 | namespace fairseq2n::detail { 15 | 16 | class list_data_source final : public data_source { 17 | public: 18 | explicit 19 | list_data_source(data_list &&list) noexcept 20 | : list_(std::move(list)), iter_{list_.begin()} 21 | {} 22 | 23 | std::optional 24 | next() override; 25 | 26 | void 27 | reset() override; 28 | 29 | void 30 | record_position(tape &t) const override; 31 | 32 | void 33 | reload_position(tape &t) override; 34 | 35 | private: 36 | data_list list_; 37 | data_list::iterator iter_; 38 | }; 39 | 40 | } // namespace fairseq2n::detail 41 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/fmt.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | namespace fairseq2n { 10 | 11 | template 12 | struct repr {}; // disabled (i.e. poisoned) 13 | 14 | } // namespace fairseq2n 15 | 16 | 17 | // Allow the use of this header file even if libfmt is not available. 18 | #if __has_include() 19 | 20 | #include 21 | #include 22 | 23 | #include 24 | #include 25 | 26 | template 27 | struct fmt::formatter, T>, char>> 28 | : fmt::formatter { 29 | 30 | auto 31 | format(const T &t, format_context &ctx) const 32 | { 33 | return formatter::format(fairseq2n::repr{}(t), ctx); 34 | } 35 | }; 36 | 37 | #endif 38 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/text/string_splitter.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 
3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "fairseq2n/api.h" 15 | #include "fairseq2n/data/data.h" 16 | 17 | namespace fairseq2n { 18 | 19 | class FAIRSEQ2_API string_splitter final { 20 | public: 21 | explicit 22 | string_splitter( 23 | char separator = '\t', 24 | std::vector names = {}, 25 | std::vector indices = {}, 26 | bool exclude = false); 27 | 28 | data 29 | operator()(data &&d) const; 30 | 31 | private: 32 | char separator_; 33 | std::vector names_; 34 | std::vector indices_; 35 | bool exclude_; 36 | }; 37 | 38 | } // namespace fairseq2n 39 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/concat_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "fairseq2n/data/data_pipeline.h" 14 | #include "fairseq2n/data/data_source.h" 15 | 16 | namespace fairseq2n::detail { 17 | 18 | class concat_data_source final : public data_source { 19 | public: 20 | explicit 21 | concat_data_source( 22 | std::vector &&pipelines); 23 | 24 | std::optional 25 | next() override; 26 | 27 | void 28 | reset() override; 29 | 30 | void 31 | record_position(tape &t) const override; 32 | 33 | void 34 | reload_position(tape &t) override; 35 | 36 | private: 37 | std::vector pipelines_; 38 | }; 39 | 40 | } // namespace fairseq2n::detail 41 | -------------------------------------------------------------------------------- /tests/unit/nn/test_padding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | 9 | from fairseq2.nn.padding import to_padding_mask 10 | from tests.common import assert_equal, device 11 | 12 | 13 | def test_to_padding_mask_works() -> None: 14 | seq_lens = torch.tensor([4, 2, 0, 5], device=device, dtype=torch.int32) 15 | 16 | mask = to_padding_mask(seq_lens, 6) 17 | 18 | # fmt: off 19 | expected_mask = torch.tensor( 20 | [ 21 | [True, True, True, True, False, False], 22 | [True, True, False, False, False, False], 23 | [False, False, False, False, False, False], 24 | [True, True, True, True, True, False], 25 | ], 26 | device=device, dtype=torch.bool 27 | ) 28 | # fmt: on 29 | 30 | assert mask is not None 31 | 32 | assert_equal(mask, expected_mask) 33 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/tape.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "fairseq2n/data/tape.h" 8 | 9 | #include "fairseq2n/detail/exception.h" 10 | 11 | using namespace fairseq2n::detail; 12 | 13 | namespace fairseq2n { 14 | 15 | void 16 | tape::record_data(const data &d) 17 | { 18 | if (iter_ != storage_.end()) 19 | throw_("New data can only be recorded to the end of the tape."); 20 | 21 | storage_.push_back(d); 22 | 23 | // The iterator is invalid because of the `push_back()` call; we should not 24 | // increment it. 
25 | iter_ = storage_.end(); 26 | } 27 | 28 | data 29 | tape::read_data() 30 | { 31 | if (iter_ == storage_.end()) 32 | throw_corrupt(); 33 | 34 | return *iter_++; 35 | } 36 | 37 | corrupt_tape_error::~corrupt_tape_error() = default; 38 | 39 | } // namespace fairseq2n 40 | -------------------------------------------------------------------------------- /tests/unit/nn/test_functional.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import pytest 8 | import torch 9 | from torch.nn.functional import cross_entropy, log_softmax 10 | 11 | from fairseq2.nn.functional import nll_loss 12 | from tests.common import assert_close, device 13 | 14 | 15 | @pytest.mark.parametrize("reduction", ["none", "sum"]) 16 | def test_nll_loss_computes_loss_correctly(reduction: str) -> None: 17 | logits = torch.randn((8, 16, 32), device=device) 18 | 19 | targets = torch.randint(low=0, high=32, size=(8, 16), device=device) 20 | 21 | loss1 = cross_entropy( 22 | logits.transpose(1, 2), targets, ignore_index=1, reduction=reduction 23 | ) 24 | 25 | lprobs = log_softmax(logits, dim=-1) 26 | 27 | loss2 = nll_loss(lprobs, targets, pad_idx=1, reduction=reduction) # type: ignore[arg-type] 28 | 29 | assert_close(loss1, loss2) 30 | -------------------------------------------------------------------------------- /fairseq2n/third-party/zip.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | macro(fairseq2n_add_zip) 8 | if(NOT TARGET zip::zip) 9 | set(backup_build_shared_libs ${BUILD_SHARED_LIBS}) 10 | 11 | # Force the library to be static. 12 | set(BUILD_SHARED_LIBS FALSE) 13 | 14 | add_subdirectory(${PROJECT_SOURCE_DIR}/third-party/zip EXCLUDE_FROM_ALL) 15 | 16 | # Revert. 17 | set(BUILD_SHARED_LIBS ${backup_build_shared_libs}) 18 | 19 | unset(backup_build_shared_libs) 20 | endif() 21 | 22 | if(NOT TARGET kuba-zip) 23 | add_library(kuba-zip INTERFACE) 24 | 25 | target_link_libraries(kuba-zip INTERFACE zip::zip) 26 | 27 | target_include_directories(kuba-zip SYSTEM 28 | INTERFACE 29 | ${PROJECT_SOURCE_DIR}/third-party/zip/src 30 | ) 31 | endif() 32 | endmacro() 33 | -------------------------------------------------------------------------------- /src/fairseq2/models/w2vbert/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | from fairseq2.models.w2vbert.builder import W2VBertBuilder as W2VBertBuilder 9 | from fairseq2.models.w2vbert.builder import W2VBertConfig as W2VBertConfig 10 | from fairseq2.models.w2vbert.builder import create_w2vbert_model as create_w2vbert_model 11 | from fairseq2.models.w2vbert.builder import w2vbert_arch as w2vbert_arch 12 | from fairseq2.models.w2vbert.builder import w2vbert_archs as w2vbert_archs 13 | from fairseq2.models.w2vbert.loader import load_w2vbert_config as load_w2vbert_config 14 | from fairseq2.models.w2vbert.loader import load_w2vbert_model as load_w2vbert_model 15 | from fairseq2.models.w2vbert.model import W2VBertLoss as W2VBertLoss 16 | from fairseq2.models.w2vbert.model import W2VBertModel as W2VBertModel 17 | from fairseq2.models.w2vbert.model import W2VBertOutput as W2VBertOutput 18 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/count_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | #include "fairseq2n/data/data_source.h" 12 | 13 | namespace fairseq2n::detail { 14 | 15 | class count_data_source final : public data_source { 16 | public: 17 | explicit 18 | count_data_source(std::int64_t start, std::optional key) noexcept 19 | : start_{start}, counter_{start}, key_{std::move(key)} 20 | {} 21 | 22 | std::optional 23 | next() override; 24 | 25 | void 26 | reset() override; 27 | 28 | void 29 | record_position(tape &t) const override; 30 | 31 | void 32 | reload_position(tape &t) override; 33 | 34 | private: 35 | std::int64_t start_; 36 | std::int64_t counter_; 37 | std::optional key_; 38 | }; 39 | 40 | } // namespace fairseq2n::detail 41 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/constant_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | #include "fairseq2n/data/data.h" 12 | #include "fairseq2n/data/data_source.h" 13 | 14 | namespace fairseq2n::detail { 15 | 16 | class constant_data_source final : public data_source { 17 | public: 18 | explicit 19 | constant_data_source(data &&example, std::optional key) noexcept 20 | : example_{std::move(example)}, key_{std::move(key)} 21 | {} 22 | 23 | std::optional 24 | next() override; 25 | 26 | void 27 | reset() override; 28 | 29 | void 30 | record_position(tape &t) const override; 31 | 32 | void 33 | reload_position(tape &t) override; 34 | 35 | private: 36 | data example_; 37 | std::optional key_; 38 | }; 39 | 40 | } // namespace fairseq2n::detail 41 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | **What does this PR do? Please describe:** 2 | A summary of the change or the issue that is fixed. 3 | 4 | Fixes #{issue number} 5 | 6 | **Does your PR introduce any breaking changes? If yes, please list them:** 7 | List of all backwards-incompatible changes. 8 | 9 | **Check list:** 10 | - [ ] Was the content of this PR **discussed and approved** via a GitHub issue? (no need for typos or documentation improvements) 11 | - [ ] Did you read the [contributor guideline](https://github.com/facebookresearch/fairseq2/blob/main/CONTRIBUTING.md)? 12 | - [ ] Did you make sure that your **PR does only one thing** instead of bundling different changes together? 13 | - [ ] Did you make sure to **update the documentation** with your changes? (if necessary) 14 | - [ ] Did you write any **new necessary tests**? 15 | - [ ] Did you verify new and **existing tests pass** locally with your changes? 16 | - [ ] Did you **update the [CHANGELOG](https://github.com/facebookresearch/fairseq2/blob/main/CHANGELOG.md)**? 
(no need for typos, documentation, or minor internal changes) 17 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/detail/parallel.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | #ifdef FAIRSEQ2N_USE_TBB 12 | #include 13 | #endif 14 | 15 | namespace fairseq2n::detail { 16 | 17 | template 18 | void 19 | parallel_for(const std::function &fn, T begin, T end) 20 | { 21 | #ifdef FAIRSEQ2N_USE_TBB 22 | tbb::blocked_range range{begin, end}; 23 | 24 | tbb::parallel_for( 25 | range, [&fn](const tbb::blocked_range &r) 26 | { 27 | fn(r.begin(), r.end()); 28 | }); 29 | #else 30 | // TODO: Use OpenMP! 31 | fn(begin, end); 32 | #endif 33 | } 34 | 35 | template 36 | inline void 37 | parallel_for(const std::function &fn, T size) 38 | { 39 | parallel_for(fn, T{}, size); 40 | } 41 | 42 | } // namespace fairseq2n::detail 43 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/exception.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | #include "fairseq2n/api.h" 12 | 13 | namespace fairseq2n { 14 | 15 | class FAIRSEQ2_API internal_error : public std::runtime_error { 16 | public: 17 | using std::runtime_error::runtime_error; 18 | 19 | public: 20 | internal_error(const internal_error &) = default; 21 | internal_error &operator=(const internal_error &) = default; 22 | 23 | ~internal_error() override; 24 | }; 25 | 26 | class FAIRSEQ2_API not_supported_error : public std::domain_error { 27 | public: 28 | using std::domain_error::domain_error; 29 | 30 | public: 31 | not_supported_error(const not_supported_error &) = default; 32 | not_supported_error &operator=(const not_supported_error &) = default; 33 | 34 | ~not_supported_error() override; 35 | }; 36 | 37 | } // namespace fairseq2n 38 | -------------------------------------------------------------------------------- /src/fairseq2/assets/cards/s2t_conformer.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | name: s2t_conformer_covost_st_en_de 8 | model_type: s2t_transformer 9 | model_arch: conformer_medium 10 | task: translation 11 | target_langs: [de] 12 | checkpoint: "https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/abs_asr_pt_avg_last_10_checkpoint.pt" 13 | tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_vocab_char.zip;path=spm_char.model" 14 | 15 | --- 16 | 17 | name: s2t_conformer_covost_st_en_de_rel_pos 18 | model_type: s2t_transformer 19 | model_arch: conformer_medium 20 | model_config: 21 | use_relative_pos: true 22 | task: translation 23 | target_langs: [de] 24 | checkpoint: "https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/rel_pos_asr_pt_avg_last_10_checkpoint.pt" 25 | tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_vocab_char.zip;path=spm_char.model" 26 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/text/sentencepiece/sp_decoder.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | #include "fairseq2n/api.h" 15 | #include "fairseq2n/data/data.h" 16 | #include "fairseq2n/data/immutable_string.h" 17 | 18 | namespace fairseq2n { 19 | 20 | class sp_model; 21 | 22 | class FAIRSEQ2_API sp_decoder final { 23 | public: 24 | explicit 25 | sp_decoder(std::shared_ptr model, bool reverse = false) noexcept; 26 | 27 | data 28 | operator()(data &&d) const; 29 | 30 | data 31 | decode_from_tokens(data &&d) const; 32 | 33 | private: 34 | template 35 | immutable_string 36 | decode(const at::Tensor &tensor) const; 37 | 38 | private: 39 | std::shared_ptr model_; 40 | bool reverse_; 41 | }; 42 | 43 | } // namespace fairseq2n 44 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/skip_data_source.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include "fairseq2n/data/skip_data_source.h" 8 | 9 | namespace fairseq2n::detail { 10 | 11 | std::optional 12 | skip_data_source::next() 13 | { 14 | if (!skip_) { 15 | for (std::size_t i = 0; i < num_examples_; i++) 16 | if (!inner_->next()) 17 | break; 18 | 19 | skip_ = true; 20 | } 21 | 22 | return inner_->next(); 23 | } 24 | 25 | void 26 | skip_data_source::reset() 27 | { 28 | skip_ = false; 29 | 30 | inner_->reset(); 31 | } 32 | 33 | void 34 | skip_data_source::record_position(tape &t) const 35 | { 36 | t.record(skip_); 37 | 38 | inner_->record_position(t); 39 | } 40 | 41 | void 42 | skip_data_source::reload_position(tape &t) 43 | { 44 | skip_ = t.read(); 45 | 46 | inner_->reload_position(t); 47 | } 48 | 49 | } // namespace fairseq2n::detail 50 | -------------------------------------------------------------------------------- /tests/unit/nn/utils/test_grad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import pytest 8 | import torch 9 | 10 | from fairseq2.nn.utils.grad import scale_grad 11 | from tests.common import assert_close, device 12 | 13 | 14 | def test_scale_grad_scales_gradient_correctly() -> None: 15 | a = torch.full((10, 10), 2.0, device=device, requires_grad=True) 16 | 17 | b = scale_grad(a, 0.1) 18 | 19 | c = b**3.0 20 | 21 | g = torch.autograd.grad(c, a, grad_outputs=torch.ones_like(b)) 22 | 23 | expected_grad = torch.full((10, 10), 1.2, device=device) 24 | 25 | assert_close(g[0], expected_grad) 26 | 27 | 28 | def test_scale_grad_raises_error_if_tensor_is_non_float() -> None: 29 | a = torch.ones((2, 2), dtype=torch.int32) 30 | 31 | with pytest.raises( 32 | TypeError, 33 | match=r"^`x` must be a float tensor, but is of type `torch\.int32` instead\.$", 34 | ): 35 | scale_grad(a, 1.0) 36 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/bucket_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "fairseq2n/data/data_source.h" 14 | 15 | namespace fairseq2n::detail { 16 | 17 | class bucket_data_source final : public data_source { 18 | public: 19 | explicit 20 | bucket_data_source( 21 | std::unique_ptr &&inner, 22 | std::size_t bucket_size, 23 | bool drop_remainder) noexcept; 24 | 25 | std::optional 26 | next() override; 27 | 28 | void 29 | reset() override; 30 | 31 | void 32 | record_position(tape &t) const override; 33 | 34 | void 35 | reload_position(tape &t) override; 36 | 37 | private: 38 | std::unique_ptr inner_; 39 | std::size_t bucket_size_; 40 | bool drop_remainder_; 41 | }; 42 | 43 | } // namespace fairseq2n::detail 44 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/float.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | namespace fairseq2n { 14 | 15 | using float32 = float; 16 | using float64 = double; 17 | 18 | namespace detail { 19 | 20 | template 21 | struct rel {}; 22 | 23 | template <> 24 | struct rel { 25 | static constexpr float32 value = 0.0001F; 26 | }; 27 | 28 | template <> 29 | struct rel { 30 | static constexpr float64 value = 0.0001; 31 | }; 32 | 33 | } // namespace detail 34 | 35 | // `T` must be a floating-point type. 
36 | template >> 37 | inline constexpr bool 38 | are_close(T lhs, T rhs, T rel = detail::rel::value) noexcept 39 | { 40 | return std::abs(rhs - lhs) < rel * std::max(std::abs(lhs), std::abs(rhs)); 41 | } 42 | 43 | } // namespace fairseq2n 44 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/skip_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "fairseq2n/data/data_source.h" 14 | 15 | namespace fairseq2n::detail { 16 | 17 | class skip_data_source final : public data_source { 18 | public: 19 | explicit 20 | skip_data_source(std::unique_ptr &&inner, std::size_t num_examples) noexcept 21 | : inner_{std::move(inner)}, num_examples_{num_examples} 22 | {} 23 | 24 | std::optional 25 | next() override; 26 | 27 | void 28 | reset() override; 29 | 30 | void 31 | record_position(tape &t) const override; 32 | 33 | void 34 | reload_position(tape &t) override; 35 | 36 | private: 37 | std::unique_ptr inner_; 38 | std::size_t num_examples_; 39 | bool skip_ = false; 40 | }; 41 | 42 | } // namespace fairseq2n::detail 43 | -------------------------------------------------------------------------------- /fairseq2n/tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # ------------------------------------------------------------ 8 | # Target: tests 9 | # ------------------------------------------------------------ 10 | 11 | add_executable(tests) 12 | 13 | set_property(TARGET tests PROPERTY OUTPUT_NAME run-tests) 14 | 15 | target_sources(tests 16 | PRIVATE 17 | test_float.cc 18 | test_memory.cc 19 | test_span.cc 20 | data/test_immutable_string.cc 21 | data/test_tape.cc 22 | data/detail/test_lru_cache.cc 23 | utils/test_cast.cc 24 | ) 25 | 26 | fairseq2n_set_compile_options(tests) 27 | 28 | target_link_libraries(tests PRIVATE GTest::gtest_main fairseq2n) 29 | 30 | fairseq2n_set_link_options(tests) 31 | 32 | # By default, GTest discovery runs as a post-build step and fails if the targets 33 | # are built with sanitizers enabled. 34 | gtest_discover_tests(tests DISCOVERY_MODE PRE_TEST) 35 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include "fairseq2n/api.h" 13 | #include "fairseq2n/data/data.h" 14 | #include "fairseq2n/data/tape.h" 15 | 16 | namespace fairseq2n { 17 | 18 | class FAIRSEQ2_API data_source { 19 | public: 20 | data_source() noexcept = default; 21 | 22 | data_source(const data_source &) = default; 23 | data_source &operator=(const data_source &) = default; 24 | 25 | data_source(data_source &&) = default; 26 | data_source &operator=(data_source &&) = default; 27 | 28 | virtual 29 | ~data_source(); 30 | 31 | virtual std::optional 32 | next() = 0; 33 | 34 | virtual void 35 | reset() = 0; 36 | 37 | virtual void 38 | record_position(tape &t) const = 0; 39 | 40 | virtual void 41 | reload_position(tape &t) = 0; 42 | }; 43 | 44 | } // namespace fairseq2n 45 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/file_stream.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include "fairseq2n/memory.h" 13 | #include "fairseq2n/data/byte_stream.h" 14 | #include "fairseq2n/data/detail/file.h" 15 | 16 | namespace fairseq2n::detail { 17 | 18 | class file_stream final : public byte_stream { 19 | public: 20 | explicit 21 | file_stream(file_desc &&fd, std::string pathname, std::size_t chunk_size) noexcept; 22 | 23 | private: 24 | void 25 | hint_sequential_file() noexcept; 26 | 27 | public: 28 | memory_block 29 | read_chunk() override; 30 | 31 | void 32 | reset() override; 33 | 34 | private: 35 | std::size_t 36 | fill_chunk(writable_memory_span chunk); 37 | 38 | private: 39 | file_desc fd_; 40 | std::string pathname_; 41 | std::size_t chunk_size_; 42 | bool is_eod_ = false; 43 | }; 44 | 45 | } // namespace fairseq2n::detail 46 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/take_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "fairseq2n/data/data_source.h" 14 | 15 | namespace fairseq2n::detail { 16 | 17 | class take_data_source final : public data_source { 18 | public: 19 | explicit 20 | take_data_source(std::unique_ptr &&inner, std::size_t num_examples) noexcept 21 | : inner_{std::move(inner)}, num_examples_{num_examples} 22 | {} 23 | 24 | std::optional 25 | next() override; 26 | 27 | void 28 | reset() override; 29 | 30 | void 31 | record_position(tape &t) const override; 32 | 33 | void 34 | reload_position(tape &t) override; 35 | 36 | private: 37 | std::unique_ptr inner_; 38 | std::size_t num_examples_; 39 | std::size_t num_examples_read_ = 0; 40 | }; 41 | 42 | } // namespace fairseq2n::detail 43 | -------------------------------------------------------------------------------- /src/fairseq2/nn/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torch import Tensor 8 | 9 | 10 | def repeat_interleave(x: Tensor, dim: int, repeat: int) -> Tensor: 11 | """Repeat elements of a tensor. 12 | 13 | :param x: 14 | The input tensor. 15 | :param dim: 16 | The dimension along which to repeat values. 17 | :param repeat: 18 | The number of repetitions. 19 | 20 | :returns: 21 | The repeated tensor which has the same shape as input, except along the 22 | given axis. 23 | 24 | .. note:: 25 | This is a lightweight version of :func:`torch.repeat_interleave` that 26 | is faster for repetitions along a single dimension. 
27 | """ 28 | if repeat == 1: 29 | return x 30 | 31 | shape = [-1] * (x.ndim + 1) 32 | 33 | if dim < 0: 34 | dim += x.ndim 35 | 36 | shape[dim + 1] = repeat 37 | 38 | return x.unsqueeze(dim + 1).expand(shape).flatten(dim, dim + 1) 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/py.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 
3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "fairseq2n/data/py.h" 8 | 9 | #include 10 | 11 | using namespace fairseq2n::detail; 12 | 13 | namespace fairseq2n { 14 | namespace detail { 15 | namespace { 16 | 17 | void (*inc_ref_fn_)(py_object &) noexcept = nullptr; 18 | void (*dec_ref_fn_)(py_object &) noexcept = nullptr; 19 | 20 | } // namespace 21 | 22 | void 23 | register_py_interpreter( 24 | void (*inc_ref_fn)(py_object &) noexcept, 25 | void (*dec_ref_fn)(py_object &) noexcept) 26 | { 27 | inc_ref_fn_ = inc_ref_fn; 28 | dec_ref_fn_ = dec_ref_fn; 29 | } 30 | 31 | } // namespace detail 32 | 33 | void 34 | py_object::inc_ref() noexcept 35 | { 36 | assert(inc_ref_fn_ != nullptr); 37 | 38 | inc_ref_fn_(*this); 39 | } 40 | 41 | void 42 | py_object::dec_ref() noexcept 43 | { 44 | assert(dec_ref_fn_ != nullptr); 45 | 46 | dec_ref_fn_(*this); 47 | } 48 | 49 | } // namespace fairseq2n 50 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/text/string_to_tensor_converter.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include "fairseq2n/api.h" 17 | #include "fairseq2n/data/data.h" 18 | 19 | namespace fairseq2n { 20 | 21 | class immutable_string; 22 | 23 | class FAIRSEQ2_API string_to_tensor_converter final { 24 | public: 25 | explicit 26 | string_to_tensor_converter( 27 | std::vector size = {}, std::optional maybe_dtype = {}); 28 | 29 | data 30 | operator()(data &&d) const; 31 | 32 | private: 33 | template 34 | void 35 | fill_storage(at::Tensor &tensor, const std::vector &strings) const; 36 | 37 | private: 38 | std::vector size_; 39 | at::ScalarType dtype_; 40 | }; 41 | 42 | } // namespace fairseq2n 43 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/take_data_source.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include "fairseq2n/data/take_data_source.h" 8 | 9 | namespace fairseq2n::detail { 10 | 11 | std::optional 12 | take_data_source::next() 13 | { 14 | if (num_examples_read_ == num_examples_) 15 | return std::nullopt; 16 | 17 | std::optional maybe_example = inner_->next(); 18 | if (maybe_example) 19 | num_examples_read_++; 20 | 21 | return maybe_example; 22 | } 23 | 24 | void 25 | take_data_source::reset() 26 | { 27 | num_examples_read_ = 0; 28 | 29 | inner_->reset(); 30 | } 31 | 32 | void 33 | take_data_source::record_position(tape &t) const 34 | { 35 | t.record(num_examples_read_); 36 | 37 | inner_->record_position(t); 38 | } 39 | 40 | void 41 | take_data_source::reload_position(tape &t) 42 | { 43 | num_examples_read_ = t.read(); 44 | 45 | inner_->reload_position(t); 46 | } 47 | 48 | } // namespace fairseq2n::detail 49 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/filter_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include "fairseq2n/data/data_pipeline.h" 13 | #include "fairseq2n/data/data_source.h" 14 | 15 | namespace fairseq2n::detail { 16 | 17 | class filter_data_source final : public data_source { 18 | public : 19 | explicit 20 | filter_data_source(std::unique_ptr &&inner, predicate_fn &&fn) noexcept 21 | : inner_{std::move(inner)}, predicate_fn_{std::move(fn)} 22 | {} 23 | 24 | std::optional 25 | next() override; 26 | 27 | void 28 | reset() override; 29 | 30 | void 31 | record_position(tape &t) const override; 32 | 33 | void 34 | reload_position(tape &t) override; 35 | 36 | private: 37 | bool 38 | invoke_function(data &example); 39 | 40 | private: 41 | std::unique_ptr inner_; 42 | predicate_fn predicate_fn_; 43 | }; 44 | 45 | } // namespace fairseq2n::detail 46 | -------------------------------------------------------------------------------- /fairseq2n/python/src/fairseq2n/bindings/data/init.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include "fairseq2n/bindings/module.h" 8 | 9 | #include 10 | 11 | namespace py = pybind11; 12 | 13 | namespace fairseq2n { 14 | namespace { 15 | 16 | void 17 | inc_ref(py_object &obj) noexcept // NOLINT(bugprone-exception-escape) 18 | { 19 | py::gil_scoped_acquire gil{}; 20 | 21 | Py_IncRef(static_cast(obj.ptr())); 22 | } 23 | 24 | void 25 | dec_ref(py_object &obj) noexcept // NOLINT(bugprone-exception-escape) 26 | { 27 | py::gil_scoped_acquire gil{}; 28 | 29 | Py_DecRef(static_cast(obj.ptr())); 30 | } 31 | 32 | } // namespace 33 | 34 | void 35 | def_data(py::module_ &base) 36 | { 37 | detail::register_py_interpreter(inc_ref, dec_ref); 38 | 39 | py::module_ m = base.def_submodule("data"); 40 | 41 | def_audio(m); 42 | 43 | def_data_pipeline(m); 44 | 45 | def_string(m); 46 | 47 | def_text(m); 48 | } 49 | 50 | } // namespace fairseq2n 51 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/shard_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "fairseq2n/data/data_source.h" 14 | 15 | namespace fairseq2n::detail { 16 | 17 | class shard_data_source final : public data_source { 18 | public: 19 | explicit 20 | shard_data_source( 21 | std::unique_ptr &&inner, 22 | std::size_t shard_idx, 23 | std::size_t num_shards) noexcept 24 | : inner_{std::move(inner)}, shard_idx_{shard_idx}, num_shards_{num_shards} 25 | {} 26 | 27 | std::optional 28 | next() override; 29 | 30 | void 31 | reset() override; 32 | 33 | void 34 | record_position(tape &t) const override; 35 | 36 | void 37 | reload_position(tape &t) override; 38 | 39 | private: 40 | std::unique_ptr inner_; 41 | std::size_t shard_idx_; 42 | std::size_t num_shards_; 43 | }; 44 | 45 | } // namespace fairseq2n::detail 46 | -------------------------------------------------------------------------------- /fairseq2n/tests/utils/test_cast.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | #include 15 | 16 | using namespace fairseq2n; 17 | using namespace fairseq2n::detail; 18 | 19 | TEST(test_cast, maybe_narrow_works_when_value_is_within_range) 20 | { 21 | std::int64_t a = 100; 22 | std::int32_t b = 0; 23 | 24 | EXPECT_TRUE(maybe_narrow(a, b)); 25 | 26 | EXPECT_EQ(b, 100); 27 | 28 | float64 c = 12.0; 29 | float32 d = 0; 30 | 31 | EXPECT_TRUE(maybe_narrow(c, d)); 32 | 33 | EXPECT_EQ(d, 12.0); 34 | } 35 | 36 | TEST(test_cast, maybe_narrow_works_when_value_is_out_of_range) 37 | { 38 | std::int64_t a = std::numeric_limits::max(); 39 | std::int32_t b = 0; 40 | 41 | EXPECT_FALSE(maybe_narrow(a, b)); 42 | 43 | float64 c = std::numeric_limits::max(); 44 | float32 d = 0; 45 | 46 | EXPECT_FALSE(maybe_narrow(c, d)); 47 | } 48 | -------------------------------------------------------------------------------- /src/fairseq2/optim/optimizer_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from abc import ABC, abstractmethod 8 | from typing import Callable, Optional 9 | 10 | import torch 11 | from torch.optim import Optimizer 12 | 13 | 14 | class OptimizerBase(ABC, Optimizer): 15 | """Represents the base class for all optimizers.""" 16 | 17 | def step( # type: ignore[override] 18 | self, closure: Optional[Callable[[], float]] = None 19 | ) -> Optional[float]: 20 | loss = None 21 | 22 | prev_grad = torch.is_grad_enabled() 23 | 24 | try: 25 | torch.set_grad_enabled(self.defaults["differentiable"]) 26 | 27 | if closure is not None: 28 | with torch.enable_grad(): 29 | loss = closure() 30 | 31 | self._do_step() 32 | finally: 33 | torch.set_grad_enabled(prev_grad) 34 | 35 | return loss 36 | 37 | @abstractmethod 38 | def _do_step(self) -> None: 39 | """Perform a single optimization step.""" 40 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/detail/exception.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | namespace fairseq2n::detail { 16 | 17 | template 18 | [[noreturn]] inline void 19 | throw_(fmt::format_string format, Args &&...args) 20 | { 21 | throw E{fmt::vformat(format, fmt::make_format_args(args...))}; 22 | } 23 | 24 | template 25 | [[noreturn]] inline void 26 | throw_with_nested(fmt::format_string format, Args &&...args) 27 | { 28 | std::throw_with_nested(E{fmt::vformat(format, fmt::make_format_args(args...))}); 29 | } 30 | 31 | template 32 | [[noreturn]] inline void 33 | throw_system_error(std::error_code ec, fmt::format_string format, Args &&...args) 34 | { 35 | throw std::system_error{ec, fmt::vformat(format, fmt::make_format_args(args...))}; 36 | } 37 | 38 | } // namespace fairseq2n::detail 39 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/shard_data_source.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include "fairseq2n/data/shard_data_source.h" 8 | 9 | namespace fairseq2n::detail { 10 | 11 | std::optional 12 | shard_data_source::next() 13 | { 14 | for (std::size_t i = 0; i < shard_idx_; i++) 15 | if (!inner_->next()) 16 | return std::nullopt; 17 | 18 | std::optional maybe_example = inner_->next(); 19 | if (!maybe_example) 20 | return std::nullopt; 21 | 22 | for (std::size_t i = 0; i < num_shards_ - shard_idx_ - 1; i++) 23 | if (!inner_->next()) 24 | return std::nullopt; 25 | 26 | return maybe_example; 27 | } 28 | 29 | void 30 | shard_data_source::reset() 31 | { 32 | inner_->reset(); 33 | } 34 | 35 | void 36 | shard_data_source::record_position(tape &t) const 37 | { 38 | inner_->record_position(t); 39 | } 40 | 41 | void 42 | shard_data_source::reload_position(tape &t) 43 | { 44 | inner_->reload_position(t); 45 | } 46 | 47 | } // namespace fairseq2n::detail 48 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/text/text_line_reader.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include "fairseq2n/span.h" 13 | #include "fairseq2n/data/byte_stream.h" 14 | #include "fairseq2n/data/record_reader.h" 15 | #include "fairseq2n/data/text/text_reader.h" 16 | 17 | namespace fairseq2n::detail { 18 | 19 | class text_line_reader final : public record_reader { 20 | public: 21 | explicit 22 | text_line_reader(std::unique_ptr &&stream, line_ending le) 23 | : record_reader{std::move(stream)}, line_ending_{le} 24 | {} 25 | 26 | line_ending 27 | actual_line_ending() const noexcept 28 | { 29 | return line_ending_; 30 | } 31 | 32 | private: 33 | std::optional 34 | maybe_find_record_end(memory_span chunk, bool first_chunk) override; 35 | 36 | bool 37 | infer_line_ending(span chars); 38 | 39 | private: 40 | line_ending line_ending_; 41 | }; 42 | 43 | } // namespace fairseq2n::detail 44 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/data.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include "fairseq2n/data/data.h" 8 | 9 | #include 10 | 11 | #include "fairseq2n/detail/exception.h" 12 | 13 | using namespace fairseq2n::detail; 14 | 15 | namespace fairseq2n { 16 | 17 | std::string 18 | repr::operator()(data_type dt) const 19 | { 20 | switch (dt) { 21 | case data_type::bool_: 22 | return "bool"; 23 | case data_type::int_: 24 | return "int"; 25 | case data_type::float_: 26 | return "float"; 27 | case data_type::string: 28 | return "string"; 29 | case data_type::tensor: 30 | return "torch.Tensor"; 31 | case data_type::memory_block: 32 | return "memory_block"; 33 | case data_type::list: 34 | return "list"; 35 | case data_type::dict: 36 | return "dict"; 37 | case data_type::pyobj: 38 | return "pyobj"; 39 | }; 40 | 41 | throw_("`dt` is not a valid data type."); 42 | } 43 | 44 | } // namespace fairseq2n 45 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/detail/tensor_helpers.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include "fairseq2n/memory.h" 16 | 17 | namespace fairseq2n::detail { 18 | 19 | inline memory_span 20 | get_raw_storage(const at::Tensor &tensor) 21 | { 22 | const at::Storage &storage = tensor.storage(); 23 | 24 | return memory_span{static_cast(storage.data()), storage.nbytes()}; 25 | } 26 | 27 | inline writable_memory_span 28 | get_raw_mutable_storage(const at::Tensor &tensor) 29 | { 30 | const at::Storage &storage = tensor.storage(); 31 | 32 | return writable_memory_span{ 33 | #if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 1) 34 | static_cast(storage.data()), storage.nbytes()}; 35 | #else 36 | static_cast(storage.mutable_data()), storage.nbytes()}; 37 | #endif 38 | } 39 | 40 | } // namespace fairseq2n::detail 41 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["packaging~=23.1", "setuptools~=67.8", "wheel~=0.40"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.black] 6 | extend-exclude = "third-party" 7 | 8 | [tool.flake8] 9 | extend_exclude = ["third-party"] 10 | extend_ignore = ["E", "Y"] # Black 11 | per-file-ignores = [ 12 | "__init__.py:F401", 13 | ] 14 | 15 | [tool.isort] 16 | extend_skip = "third-party" 17 | profile = "black" 18 | 19 | [tool.mypy] 20 | disable_error_code = "type-abstract,typeddict-unknown-key" 21 | disallow_untyped_calls = false 22 | disallow_untyped_decorators = false 23 | exclude = ["build", "third-party", "^setup\\.py$"] 24 | ignore_missing_imports = true 25 | python_version = 3.8 26 | show_error_codes = true 27 | show_error_context = true 28 | strict = true 29 | warn_unused_configs = false 30 | warn_unused_ignores = false 31 | 32 | [tool.pytest.ini_options] 33 | minversion = "7.1" 34 | testpaths = ["tests"] 35 | filterwarnings = [ 36 
| "ignore:distutils Version classes are deprecated:DeprecationWarning", 37 | "ignore:pkg_resources is deprecated:DeprecationWarning", 38 | "ignore:Deprecated call to `pkg_resources.declare_namespace:DeprecationWarning", 39 | "ignore:To copy construct from a tensor:UserWarning", 40 | ] 41 | -------------------------------------------------------------------------------- /fairseq2n/python/src/fairseq2n/bindings/init.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "fairseq2n/bindings/module.h" 8 | 9 | #include 10 | 11 | namespace py = pybind11; 12 | 13 | namespace fairseq2n { 14 | 15 | PYBIND11_MODULE(bindings, m) 16 | { 17 | py::options opts{}; 18 | opts.disable_function_signatures(); 19 | 20 | m.def( 21 | "_supports_cuda", 22 | [] 23 | { 24 | return supports_cuda; 25 | }); 26 | 27 | // See https://github.com/llvm/llvm-project/issues/57123. 28 | #pragma clang diagnostic push 29 | #pragma clang diagnostic ignored "-Wunreachable-code-return" 30 | 31 | m.def( 32 | "_cuda_version", 33 | [] 34 | { 35 | if constexpr (cuda_version_major) 36 | return py::make_tuple(*cuda_version_major, *cuda_version_minor); 37 | else 38 | return py::none(); 39 | }); 40 | 41 | #pragma clang diagnostic pop 42 | 43 | def_data(m); 44 | 45 | def_memory(m); 46 | } 47 | 48 | } // namespace fairseq2n 49 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/concat_data_source.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 
3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "fairseq2n/data/concat_data_source.h" 8 | #include 9 | 10 | namespace fairseq2n::detail { 11 | 12 | concat_data_source::concat_data_source(std::vector &&pipelines) 13 | : pipelines_(std::move(pipelines)) 14 | {} 15 | 16 | std::optional 17 | concat_data_source::next() 18 | { 19 | std::optional d; 20 | for (auto &p : pipelines_) { 21 | d = p.next(); 22 | if (d) 23 | return d; 24 | } 25 | return {}; 26 | } 27 | 28 | void concat_data_source::reset() 29 | { 30 | for (auto &pipeline : pipelines_) 31 | pipeline.reset(); 32 | } 33 | 34 | void concat_data_source::record_position(tape &t) const 35 | { 36 | for (auto &pipeline : pipelines_) 37 | pipeline.record_position(t); 38 | } 39 | 40 | void concat_data_source::reload_position(tape &t) 41 | { 42 | for (auto &pipeline : pipelines_) 43 | pipeline.reload_position(t); 44 | } 45 | 46 | } // namespace fairseq2n::detail 47 | 48 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/detail/thread.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | namespace fairseq2n::detail { 17 | 18 | template 19 | std::thread 20 | start_thread(Func &&f, Args &&...args) 21 | { 22 | ::sigset_t mask{}; 23 | ::sigset_t original_mask{}; 24 | 25 | sigfillset(&mask); 26 | 27 | // Block all async signals in the new thread. 
28 | int result = ::pthread_sigmask(SIG_SETMASK, &mask, &original_mask); 29 | if (result != 0) 30 | throw std::system_error{result, std::generic_category()}; 31 | 32 | std::thread t{std::forward(f), std::forward(args)...}; 33 | 34 | // Restore the signal mask. 35 | result = ::pthread_sigmask(SIG_SETMASK, &original_mask, nullptr); 36 | if (result != 0) 37 | throw std::system_error{result, std::generic_category()}; 38 | 39 | return t; 40 | } 41 | 42 | } // namespace fairseq2n::detail 43 | -------------------------------------------------------------------------------- /fairseq2n/python/src/fairseq2n/bindings/type_casters/py.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | #include 12 | 13 | namespace pybind11::detail { 14 | 15 | template <> 16 | struct type_caster { 17 | PYBIND11_TYPE_CASTER(fairseq2n::py_object, const_name("Any")); 18 | 19 | public: 20 | bool 21 | load(handle src, bool) 22 | { 23 | value = fairseq2n::py_object{src.ptr()}; 24 | 25 | return true; 26 | } 27 | 28 | static handle 29 | cast(const fairseq2n::py_object &src, return_value_policy, handle) 30 | { 31 | auto ptr = static_cast(src.ptr()); 32 | 33 | handle h{ptr}; 34 | 35 | h.inc_ref(); 36 | 37 | return h; 38 | } 39 | 40 | static handle 41 | cast(fairseq2n::py_object &&src, return_value_policy, handle) 42 | { 43 | auto ptr = static_cast(src.release()); 44 | 45 | return handle{ptr}; 46 | } 47 | }; 48 | 49 | } // namespace pybind11::detail 50 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/memory.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "fairseq2n/memory.h" 8 | 9 | #include 10 | 11 | #include 12 | 13 | using namespace fairseq2n::detail; 14 | 15 | namespace fairseq2n { 16 | namespace detail { 17 | namespace { 18 | 19 | void 20 | deallocate(const void *addr, std::size_t, void *) noexcept 21 | { 22 | if (addr != nullptr) 23 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) 24 | ::operator delete(const_cast(addr)); 25 | } 26 | 27 | } // namespace 28 | } // namespace detail 29 | 30 | writable_memory_block 31 | allocate_memory(std::size_t size) 32 | { 33 | void *addr = ::operator new(size); 34 | 35 | return writable_memory_block{static_cast(addr), size, nullptr, deallocate}; 36 | } 37 | 38 | writable_memory_block 39 | copy_memory(memory_span source) 40 | { 41 | writable_memory_block target = allocate_memory(source.size()); 42 | 43 | std::copy(source.begin(), source.end(), target.begin()); 44 | 45 | return target; 46 | } 47 | 48 | } // namespace fairseq2n 49 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/byte_stream.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | #include "fairseq2n/api.h" 12 | #include "fairseq2n/memory.h" 13 | 14 | namespace fairseq2n { 15 | 16 | class FAIRSEQ2_API byte_stream { 17 | public: 18 | byte_stream() noexcept = default; 19 | 20 | byte_stream(const byte_stream &) = default; 21 | byte_stream &operator=(const byte_stream &) = default; 22 | 23 | byte_stream(byte_stream &&) = default; 24 | byte_stream &operator=(byte_stream &&) = default; 25 | 26 | virtual 27 | ~byte_stream(); 28 | 29 | virtual memory_block 30 | read_chunk() = 0; 31 | 32 | virtual void 33 | reset() = 0; 34 | }; 35 | 36 | class FAIRSEQ2_API byte_stream_error : public std::runtime_error { 37 | public: 38 | using std::runtime_error::runtime_error; 39 | 40 | public: 41 | byte_stream_error(const byte_stream_error &) = default; 42 | byte_stream_error &operator=(const byte_stream_error &) = default; 43 | 44 | ~byte_stream_error() override; 45 | }; 46 | 47 | } // namespace fairseq2n 48 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/round_robin_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "fairseq2n/data/data_pipeline.h" 14 | #include "fairseq2n/data/data_source.h" 15 | #include "fairseq2n/data/composite_data_source.h" 16 | 17 | namespace fairseq2n::detail { 18 | 19 | class round_robin_data_source final : public data_source { 20 | public: 21 | explicit 22 | round_robin_data_source(std::vector &&pipelines, bool stop_at_shortest); 23 | 24 | std::optional 25 | next() override; 26 | 27 | void 28 | reset() override; 29 | 30 | void 31 | record_position(tape &t) const override; 32 | 33 | void 34 | reload_position(tape &t) override; 35 | 36 | private: 37 | std::optional 38 | next_in_pipeline(std::size_t pipeline_idx); 39 | 40 | private: 41 | std::unique_ptr inner_; 42 | std::size_t pipeline_idx_; 43 | std::size_t pipelines_count_; 44 | }; 45 | 46 | } // namespace fairseq2n::detail 47 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/file_mapper.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "fairseq2n/api.h" 16 | #include "fairseq2n/memory.h" 17 | #include "fairseq2n/data/data.h" 18 | #include "fairseq2n/data/detail/lru_cache.h" 19 | 20 | namespace fairseq2n { 21 | 22 | class immutable_string; 23 | 24 | class FAIRSEQ2_API file_mapper { 25 | static constexpr std::size_t default_cached_fd_count = 100; 26 | 27 | public: 28 | explicit 29 | file_mapper( 30 | std::optional maybe_root_dir, 31 | std::optional maybe_cached_fd_count = {}) noexcept; 32 | 33 | data 34 | operator()(data &&d) const; 35 | 36 | private: 37 | memory_block 38 | get_memory_map(const immutable_string &pathname) const; 39 | 40 | private: 41 | std::filesystem::path root_dir_{}; 42 | mutable std::mutex cache_mutex_{}; 43 | mutable detail::lru_cache cache_; 44 | }; 45 | 46 | } // namespace fairseq2n 47 | -------------------------------------------------------------------------------- /src/fairseq2/nn/transformer/layer_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional, Protocol 8 | 9 | from fairseq2.nn.normalization import LayerNorm, StandardLayerNorm 10 | from fairseq2.typing import DataType, Device 11 | 12 | 13 | class LayerNormFactory(Protocol): 14 | """Constructs instances of :class:`LayerNorm`.""" 15 | 16 | def __call__( 17 | self, 18 | model_dim: int, 19 | *, 20 | device: Optional[Device] = None, 21 | dtype: Optional[DataType] = None, 22 | ) -> LayerNorm: 23 | """ 24 | :param model_dim: 25 | The dimensionality of the model. 26 | :param device: 27 | The device on which to initialize the module. 28 | :param dtype: 29 | The data type of the module. 
import random

import numpy as np
import torch


def seed(value: int) -> None:
    """Set RNG seed for ``random``, ``np.random``, and ``torch``.

    The range of ``value`` is validated before any generator is touched, so an
    invalid seed never leaves the generators partially reseeded.

    :param value:
        The new seed. Must be greater than or equal to 0 and less than 2^32.

    :raises ValueError:
        If ``value`` is negative, or greater than or equal to 2^32.
    """
    # ``np.random.seed`` only accepts values in [0, 2^32); reject anything else
    # up front instead of failing after ``random`` has already been reseeded.
    if value < 0 or value >= 1 << 32:
        raise ValueError(
            f"`value` must be greater than or equal to 0 and less than 2^32, but is {value} instead."
        )

    random.seed(value)

    np.random.seed(value)

    torch.manual_seed(value)


def use_deterministic(value: bool, warn_only: bool = False) -> None:
    """Set whether PyTorch must use deterministic algorithms.

    Besides toggling ``torch.use_deterministic_algorithms``, this sets
    ``torch.backends.cudnn.benchmark`` to the opposite of ``value``.

    :param value:
        If ``True``, uses deterministic algorithms.
    :param warn_only:
        If ``True``, operations that do not have a deterministic implementation
        will raise a warning instead of an error.
    """
    torch.backends.cudnn.benchmark = not value

    torch.use_deterministic_algorithms(value, warn_only=warn_only)
import pytest

from fairseq2.data import DataPipelineError, read_sequence


class TestFilterOp:
    """Tests for the ``filter`` operator of a data pipeline."""

    def test_op_works(self) -> None:
        def fn(d: int) -> bool:
            return d % 2 == 1

        pipeline = read_sequence([1, 2, 3, 4, 5, 6, 7, 8, 9]).filter(fn).and_return()

        # Iterate twice to check that the pipeline yields the same elements
        # again after it has been reset.
        for _ in range(2):
            assert list(pipeline) == [1, 3, 5, 7, 9]

            pipeline.reset()

    def test_op_raises_nested_error_when_callable_fails(self) -> None:
        def fn(d: int) -> bool:
            if d == 3:
                raise ValueError("filter error")

            return True

        pipeline = read_sequence([1, 2, 3, 4]).filter(fn).and_return()

        # The elements themselves are irrelevant here; the pipeline is only
        # drained until the callable fails.
        with pytest.raises(DataPipelineError) as exc_info:
            for _ in pipeline:
                pass

        # The error raised by the callable must be attached as the cause.
        cause = exc_info.value.__cause__

        assert isinstance(cause, ValueError)

        assert str(cause) == "filter error"
6 | 7 | from fairseq2.data.text.converters import StrSplitter as StrSplitter 8 | from fairseq2.data.text.converters import StrToIntConverter as StrToIntConverter 9 | from fairseq2.data.text.converters import StrToTensorConverter as StrToTensorConverter 10 | from fairseq2.data.text.sentencepiece import ( 11 | SentencePieceDecoder as SentencePieceDecoder, 12 | ) 13 | from fairseq2.data.text.sentencepiece import ( 14 | SentencePieceEncoder as SentencePieceEncoder, 15 | ) 16 | from fairseq2.data.text.sentencepiece import SentencePieceModel as SentencePieceModel 17 | from fairseq2.data.text.sentencepiece import ( 18 | vocabulary_from_sentencepiece as vocabulary_from_sentencepiece, 19 | ) 20 | from fairseq2.data.text.text_reader import LineEnding as LineEnding 21 | from fairseq2.data.text.text_reader import read_text as read_text 22 | from fairseq2.data.text.text_tokenizer import TextTokenDecoder as TextTokenDecoder 23 | from fairseq2.data.text.text_tokenizer import TextTokenEncoder as TextTokenEncoder 24 | from fairseq2.data.text.text_tokenizer import TextTokenizer as TextTokenizer 25 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/zip_file_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | #include "fairseq2n/memory.h" 17 | #include "fairseq2n/span.h" 18 | #include "fairseq2n/data/data_source.h" 19 | 20 | namespace fairseq2n::detail { 21 | 22 | class zip_file_data_source final : public data_source { 23 | public: 24 | explicit 25 | zip_file_data_source(std::string &&pathname); 26 | 27 | std::optional 28 | next() override; 29 | 30 | void 31 | reset() override; 32 | 33 | void 34 | record_position(tape &t) const override; 35 | 36 | void 37 | reload_position(tape &t) override; 38 | 39 | private: 40 | memory_block 41 | next_line(); 42 | 43 | [[noreturn]] void 44 | handle_error(); 45 | 46 | [[noreturn]] void 47 | throw_read_failure(); 48 | 49 | private: 50 | std::string pathname_; 51 | zip_t* zip_reader_; 52 | std::size_t num_entries_; 53 | std::size_t num_files_read_ = 0; 54 | }; 55 | 56 | } // namespace fairseq2n::detail 57 | -------------------------------------------------------------------------------- /fairseq2n/cmake/modules/FindSndFile.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
include(FindPackageHandleStandardArgs)

# First try a package configuration file; if one is found, there is nothing
# else to do.
find_package(SndFile QUIET CONFIG)
if(SndFile_FOUND)
    find_package_handle_standard_args(SndFile CONFIG_MODE)

    return()
endif()

# Otherwise fall back to a manual search, using pkg-config (when available) to
# provide the search hints and the version.
find_package(PkgConfig QUIET)
if(PKG_CONFIG_FOUND)
    pkg_check_modules(SndFile QUIET sndfile)
endif()

find_library(SndFile_LIBRARY sndfile HINTS ${SndFile_LIBRARY_DIRS})

find_path(SndFile_INCLUDE_DIR sndfile.h HINTS ${SndFile_INCLUDE_DIRS})

mark_as_advanced(SndFile_LIBRARY SndFile_INCLUDE_DIR)

find_package_handle_standard_args(SndFile
    REQUIRED_VARS
        SndFile_LIBRARY SndFile_INCLUDE_DIR
    VERSION_VAR
        SndFile_VERSION
)

if(NOT SndFile_FOUND)
    return()
endif()

if(NOT TARGET SndFile::sndfile)
    # `find_library()` can resolve to a static archive as well as to a shared
    # object, so do not claim a specific library type for the imported target.
    add_library(SndFile::sndfile UNKNOWN IMPORTED)

    # Quote the paths so that locations containing spaces or semicolons are
    # passed through as a single value.
    set_property(TARGET SndFile::sndfile PROPERTY IMPORTED_LOCATION "${SndFile_LIBRARY}")

    target_include_directories(SndFile::sndfile INTERFACE "${SndFile_INCLUDE_DIR}")
endif()
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | 15 | #include "fairseq2n/data/data_pipeline.h" 16 | #include "fairseq2n/data/data_source.h" 17 | #include "fairseq2n/data/composite_data_source.h" 18 | 19 | namespace fairseq2n::detail { 20 | 21 | /// @brief sample from a list of datasources 22 | class sample_data_source final : public data_source { 23 | public: 24 | explicit 25 | sample_data_source(std::vector &&pipelines, std::vector &&weights, bool stop_at_shortest); 26 | 27 | std::optional 28 | next() override; 29 | 30 | void 31 | reset() override; 32 | 33 | void 34 | record_position(tape &t) const override; 35 | 36 | void 37 | reload_position(tape &t) override; 38 | 39 | private: 40 | std::size_t 41 | next_index(); 42 | 43 | private: 44 | std::unique_ptr inner_; 45 | 46 | at::Generator generator_; 47 | at::Tensor weights_; 48 | }; 49 | 50 | } // namespace fairseq2::detail 51 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/yield_from_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include "fairseq2n/data/data_pipeline.h" 13 | #include "fairseq2n/data/data_source.h" 14 | 15 | namespace fairseq2n::detail { 16 | 17 | class yield_from_data_source final : public data_source { 18 | public: 19 | explicit 20 | yield_from_data_source(std::unique_ptr &&inner, yield_fn &&fn) noexcept 21 | : inner_{std::move(inner)}, yield_fn_{std::move(fn)} 22 | {} 23 | 24 | std::optional 25 | next() override; 26 | 27 | void 28 | reset() override; 29 | 30 | void 31 | record_position(tape &t) const override; 32 | 33 | void 34 | reload_position(tape &t) override; 35 | 36 | private: 37 | bool 38 | load_next_data_pipeline(); 39 | 40 | data_pipeline 41 | invoke_function(data &example); 42 | 43 | private: 44 | std::unique_ptr inner_; 45 | yield_fn yield_fn_; 46 | std::optional maybe_current_example_{}; 47 | data_pipeline data_pipeline_{}; 48 | }; 49 | 50 | } // namespace fairseq2n::detail 51 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/shuffle_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | #include "fairseq2n/data/data_source.h" 16 | 17 | namespace fairseq2n::detail { 18 | 19 | class shuffle_data_source final : public data_source { 20 | static constexpr std::size_t max_pre_alloc_size_ = 100'000; 21 | 22 | public: 23 | explicit 24 | shuffle_data_source( 25 | std::unique_ptr &&inner, std::size_t shuffle_window, bool strict) noexcept; 26 | 27 | std::optional 28 | next() override; 29 | 30 | void 31 | reset() override; 32 | 33 | void 34 | record_position(tape &t) const override; 35 | 36 | void 37 | reload_position(tape &t) override; 38 | 39 | private: 40 | std::size_t 41 | random_index(); 42 | 43 | private: 44 | std::unique_ptr inner_; 45 | data_list buffer_{}; 46 | std::size_t shuffle_window_; 47 | at::Generator generator_; 48 | bool strict_; 49 | bool fill_buffer_ = true; 50 | }; 51 | 52 | } // namespace fairseq2n::detail 53 | -------------------------------------------------------------------------------- /fairseq2n/tests/test_float.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include 8 | 9 | #include 10 | 11 | using namespace fairseq2n; 12 | 13 | TEST(test_cast, are_close_works_when_inputs_are_equal) 14 | { 15 | float32 a = 1.0F; 16 | float32 b = 1.0F; 17 | 18 | EXPECT_TRUE(are_close(a, b)); 19 | 20 | float64 c = 1.0; 21 | float64 d = 1.0; 22 | 23 | EXPECT_TRUE(are_close(c, d)); 24 | } 25 | 26 | TEST(test_cast, are_close_works_when_inputs_are_within_relative_distance) 27 | { 28 | float32 a = 3.0F; 29 | // This is the maximum tolerance we have for the relative difference 30 | // between the two numbers. 
31 | float32 b = 3.0F + (a * 0.0001F); 32 | 33 | EXPECT_TRUE(are_close(a, b)); 34 | 35 | // This number should be treated as equal. 36 | float32 c = 3.0F + (a * 0.000001F); 37 | 38 | EXPECT_TRUE(are_close(a, c)); 39 | } 40 | 41 | TEST(test_cast, are_close_works_when_inputs_are_outside_of_relative_distance) 42 | { 43 | float32 a = 3.0F; 44 | // This is beyond our tolerance threshold. 45 | float32 b = 3.0F + (a * 0.00011F); 46 | 47 | EXPECT_FALSE(are_close(a, b)); 48 | } 49 | -------------------------------------------------------------------------------- /fairseq2n/python/src/fairseq2n/bindings/module.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "fairseq2n/bindings/type_casters/data.h" 15 | #include "fairseq2n/bindings/type_casters/map_fn.h" 16 | #include "fairseq2n/bindings/type_casters/py.h" 17 | #include "fairseq2n/bindings/type_casters/string.h" 18 | #include "fairseq2n/bindings/type_casters/torch.h" 19 | 20 | namespace fairseq2n { 21 | 22 | void 23 | def_audio(pybind11::module_ &data_module); 24 | 25 | void 26 | def_data(pybind11::module_ &base_module); 27 | 28 | void 29 | def_data_pipeline(pybind11::module_ &data_module); 30 | 31 | void 32 | def_memory(pybind11::module_ &base_module); 33 | 34 | void 35 | def_sentencepiece(pybind11::module_ &text_module); 36 | 37 | void 38 | def_string(pybind11::module_ &data_module); 39 | 40 | void 41 | def_text(pybind11::module_ &data_module); 42 | 43 | void 44 | def_text_converters(pybind11::module_ &text_module); 45 | 46 | void 47 | def_text_reader(pybind11::module_ &text_module); 48 | 49 | } // namespace fairseq2n 50 | 
-------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/audio/detail/kaldi_fbank.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "fairseq2n/float.h" 14 | 15 | namespace fairseq2n::detail { 16 | 17 | class kaldi_fbank_computer { 18 | friend class kaldi_fbank_compute_op; 19 | 20 | public: 21 | explicit 22 | kaldi_fbank_computer(const knf::FbankOptions &opts); 23 | 24 | at::Tensor 25 | compute(const at::Tensor &waveform, bool pin_memory); 26 | 27 | float32 28 | sample_rate() const noexcept 29 | { 30 | return opts_->frame_opts.samp_freq; 31 | } 32 | 33 | private: 34 | knf::FbankComputer & 35 | native() noexcept 36 | { 37 | return native_; 38 | } 39 | 40 | const knf::FeatureWindowFunction & 41 | window_fn() const noexcept 42 | { 43 | return window_fn_; 44 | } 45 | 46 | private: 47 | knf::FbankComputer native_; 48 | knf::FeatureWindowFunction window_fn_; 49 | const knf::FbankOptions *opts_; 50 | }; 51 | 52 | } // namespace fairseq2n::detail 53 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/composite_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | #include "fairseq2n/data/data_pipeline.h" 12 | 13 | 14 | using index_generator_fn = std::function; 15 | 16 | namespace fairseq2n::detail { 17 | 18 | class composite_data_source final : public data_source { 19 | public: 20 | explicit 21 | composite_data_source(std::vector &&pipelines, index_generator_fn &&index_gen_fn, bool stop_at_shortest); 22 | 23 | std::optional 24 | next() override; 25 | 26 | void 27 | reset() override; 28 | 29 | void 30 | record_position(tape &t) const override; 31 | 32 | void 33 | reload_position(tape &t) override; 34 | 35 | private: 36 | std::optional 37 | next_in_pipeline(std::size_t pipeline_idx); 38 | 39 | bool 40 | eod(); 41 | 42 | private: 43 | std::vector pipelines_; 44 | index_generator_fn next_index_gen_; 45 | std::vector> buffer_{}; 46 | std::vector is_epoch_done_; 47 | bool is_eod_ = false; 48 | bool stop_at_shortest_; 49 | }; 50 | 51 | } // namespace fairseq2n::detail 52 | -------------------------------------------------------------------------------- /src/fairseq2/generation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from fairseq2.generation.beam_search import BeamSearch as BeamSearch 8 | from fairseq2.generation.beam_search import StandardBeamSearch as StandardBeamSearch 9 | from fairseq2.generation.sequence_generator import Hypothesis as Hypothesis 10 | from fairseq2.generation.sequence_generator import Seq2SeqGenerator as Seq2SeqGenerator 11 | from fairseq2.generation.sequence_generator import ( 12 | SequenceGeneratorOptions as SequenceGeneratorOptions, 13 | ) 14 | from fairseq2.generation.sequence_generator import ( 15 | SequenceGeneratorOutput as SequenceGeneratorOutput, 16 | ) 17 | from fairseq2.generation.step_processor import ( 18 | BannedSequenceProcessor as BannedSequenceProcessor, 19 | ) 20 | from fairseq2.generation.step_processor import ( 21 | NGramRepeatBlockProcessor as NGramRepeatBlockProcessor, 22 | ) 23 | from fairseq2.generation.step_processor import StepProcessor as StepProcessor 24 | from fairseq2.generation.text import SequenceToTextGenerator as SequenceToTextGenerator 25 | from fairseq2.generation.text import SequenceToTextOutput as SequenceToTextOutput 26 | from fairseq2.generation.text import TextTranslator as TextTranslator 27 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/detail/exception.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | #include "fairseq2n/data/data.h" 16 | #include "fairseq2n/data/data_pipeline.h" 17 | 18 | namespace fairseq2n::detail { 19 | 20 | template 21 | [[noreturn]] inline void 22 | throw_data_pipeline_error( 23 | std::optional maybe_example, 24 | bool recoverable, 25 | fmt::format_string format, Args &&...args) 26 | { 27 | throw data_pipeline_error{ 28 | fmt::vformat(format, fmt::make_format_args(args...)), 29 | std::move(maybe_example), 30 | recoverable}; 31 | } 32 | 33 | template 34 | [[noreturn]] inline void 35 | throw_data_pipeline_error_with_nested( 36 | std::optional maybe_example, 37 | bool recoverable, 38 | fmt::format_string format, Args &&...args) 39 | { 40 | std::throw_with_nested(data_pipeline_error{ 41 | fmt::vformat(format, fmt::make_format_args(args...)), 42 | std::move(maybe_example), 43 | recoverable}); 44 | } 45 | 46 | } // namespace fairseq2n::detail 47 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/element_mapper.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include "fairseq2n/data/element_mapper.h" 8 | 9 | #include 10 | #include 11 | 12 | #include "fairseq2n/fmt.h" 13 | #include "fairseq2n/detail/exception.h" 14 | 15 | using namespace fairseq2n::detail; 16 | 17 | namespace fairseq2n { 18 | 19 | element_mapper::element_mapper(map_fn fn, std::optional maybe_selector) 20 | : map_fn_{std::move(fn)} 21 | { 22 | if (maybe_selector) 23 | maybe_selector_ = element_selector{*std::move(maybe_selector)}; 24 | } 25 | data 26 | element_mapper::operator()(data &&d) 27 | { 28 | if (!maybe_selector_) 29 | return map_fn_(std::move(d)); 30 | 31 | maybe_selector_->visit(d, [this](data &element, element_path_ref path) 32 | { 33 | try { 34 | element = map_fn_(std::move(element)); 35 | } catch (const std::exception &) { 36 | throw_with_nested( 37 | "The map function has failed while processing the path '{}' of the input data. See nested exception for details.", path); 38 | } 39 | }); 40 | 41 | return std::move(d); 42 | } 43 | 44 | } // namespace fairseq2n 45 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/map_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "fairseq2n/data/data_pipeline.h" 16 | #include "fairseq2n/data/data_source.h" 17 | 18 | namespace fairseq2n::detail { 19 | 20 | class map_data_source final : public data_source { 21 | public: 22 | explicit 23 | map_data_source( 24 | std::unique_ptr &&inner, map_fn &&fn, std::size_t num_parallel_calls); 25 | 26 | std::optional 27 | next() override; 28 | 29 | void 30 | reset() override; 31 | 32 | void 33 | record_position(tape &t) const override; 34 | 35 | void 36 | reload_position(tape &t) override; 37 | 38 | private: 39 | bool 40 | fill_buffer(); 41 | 42 | std::optional 43 | invoke_function(data &&example); 44 | 45 | private: 46 | std::unique_ptr inner_; 47 | map_fn map_fn_; 48 | std::size_t num_parallel_calls_; 49 | std::vector> buffer_{}; 50 | std::vector>::iterator buffer_iter_{}; 51 | }; 52 | 53 | } // namespace fairseq2n::detail 54 | -------------------------------------------------------------------------------- /src/fairseq2/nn/utils/grad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Tuple 8 | 9 | from torch import Tensor 10 | from torch.autograd import Function 11 | 12 | 13 | def scale_grad(x: Tensor, scale: float) -> Tensor: 14 | """Scale the gradient of ``x`` during backpropagation. 15 | 16 | This might be used to, for example, allow one part of a model to learn at a 17 | lower rate than the rest. 18 | 19 | :param x: 20 | The input tensor. 21 | :param scale: 22 | The scale factor of the gradient. 
23 | """ 24 | return _GradScaler.apply(x, scale) # type: ignore[no-any-return] 25 | 26 | 27 | class _GradScaler(Function): 28 | @staticmethod 29 | def forward(ctx: Any, x: Tensor, scale: float) -> Tensor: # type: ignore[override] 30 | if not x.dtype.is_floating_point: 31 | raise TypeError( 32 | f"`x` must be a float tensor, but is of type `{x.dtype}` instead." 33 | ) 34 | 35 | ctx.scale = scale 36 | 37 | return x.clone().detach().requires_grad_(True) 38 | 39 | @staticmethod 40 | def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, None]: # type: ignore[override] 41 | return grad_output * ctx.scale, None 42 | -------------------------------------------------------------------------------- /src/fairseq2/memory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from array import array 8 | from typing import TYPE_CHECKING, Optional, Union, overload 9 | 10 | from typing_extensions import TypeAlias 11 | 12 | from fairseq2 import _DOC_MODE 13 | 14 | Buffer: TypeAlias = Union[bytes, bytearray, memoryview, array] 15 | 16 | if TYPE_CHECKING or _DOC_MODE: 17 | 18 | class MemoryBlock: 19 | """Represents a contiguous block of read-only memory.""" 20 | 21 | @overload 22 | def __init__(self) -> None: 23 | ... 24 | 25 | @overload 26 | def __init__(self, buffer: Buffer, copy: bool = False) -> None: 27 | ... 28 | 29 | def __init__(self, buffer: Optional[Buffer] = None, copy: bool = False) -> None: 30 | """ 31 | :param buffer: 32 | An object that supports the Python buffer protocol. 33 | :param copy: 34 | If ``True``, copies ``buffer``. 35 | """ 36 | 37 | def __len__(self) -> int: 38 | ... 39 | 40 | def __bytes__(self) -> bytes: 41 | ... 
42 | 43 | else: 44 | from fairseq2n.bindings.memory import MemoryBlock as MemoryBlock 45 | 46 | MemoryBlock.__module__ = __name__ 47 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/text/string_to_int_converter.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "fairseq2n/data/text/string_to_int_converter.h" 8 | 9 | #include 10 | 11 | #include "fairseq2n/fmt.h" 12 | #include "fairseq2n/utils/string.h" 13 | #include "fairseq2n/data/immutable_string.h" 14 | #include "fairseq2n/detail/exception.h" 15 | 16 | using namespace fairseq2n::detail; 17 | 18 | namespace fairseq2n { 19 | 20 | data 21 | string_to_int_converter::operator()(data &&d) const 22 | { 23 | if (!d.is_string()) 24 | throw_( 25 | "The input data must be of type `string`, but is of type `{}` instead.", d.type()); 26 | 27 | const immutable_string &s = d.as_string(); 28 | 29 | try { 30 | return from_string(s, base_); 31 | } catch (const std::out_of_range &) { 32 | throw_( 33 | "The input string must represent a 64-bit integer, but is '{}' instead, which is out of range.", s); 34 | } catch (const std::invalid_argument &) { 35 | throw_( 36 | "The input string must represent a 64-bit integer, but is '{}' instead.", s); 37 | } 38 | } 39 | 40 | } // namespace fairseq2n 41 | -------------------------------------------------------------------------------- /.github/workflows/_lint_sh.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | on: 8 | workflow_call: 9 | inputs: 10 | py: 11 | type: string 12 | default: '3.11' 13 | 14 | jobs: 15 | lint: 16 | name: Lint 17 | runs-on: 18 | labels: 4-core-ubuntu 19 | container: 20 | image: ghcr.io/facebookresearch/fairseq2-ci-manylinux_x86_64:2-cpu 21 | defaults: 22 | run: 23 | shell: bash 24 | steps: 25 | - name: Check-out the repository 26 | uses: actions/checkout@v3 27 | - name: Create the Python virtual environment 28 | run: | 29 | python${{ inputs.py }} -m venv ~/venv 30 | 31 | echo ~/venv/bin >> "$GITHUB_PATH" 32 | - name: Install requirements 33 | id: setup 34 | run: | 35 | pip install --requirement requirements-devel.txt 36 | - name: Run shellcheck 37 | if: success() || (failure() && steps.setup.outcome == 'success') 38 | run: | 39 | echo "::add-matcher::./ci/problem-matchers/gcc.json" 40 | 41 | function remove_matcher 42 | { 43 | echo "::remove-matcher owner=gcc::" 44 | } 45 | 46 | trap remove_matcher EXIT 47 | 48 | tools/run-shellcheck.sh . 49 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/bucket_by_length_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "fairseq2n/data/data_pipeline.h" 16 | #include "fairseq2n/data/data_source.h" 17 | 18 | namespace fairseq2n::detail { 19 | 20 | class bucket_by_length_data_source final : public data_source { 21 | public: 22 | explicit 23 | bucket_by_length_data_source( 24 | std::unique_ptr &&inner, 25 | std::vector> &&bucket_sizes, 26 | data_length_fn &&fn, 27 | bool drop_remainder); 28 | 29 | std::optional 30 | next() override; 31 | 32 | void 33 | reset() override; 34 | 35 | void 36 | record_position(tape &t) const override; 37 | 38 | void 39 | reload_position(tape &t) override; 40 | 41 | private: 42 | std::unique_ptr inner_; 43 | std::vector> bucket_sizes_; 44 | std::size_t max_data_len_; 45 | data_length_fn data_length_fn_; 46 | bool drop_remainder_; 47 | std::vector buckets_{}; 48 | }; 49 | 50 | } // namespace fairseq2n::detail 51 | -------------------------------------------------------------------------------- /src/fairseq2/data/text/text_reader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from enum import Enum 8 | from typing import TYPE_CHECKING, Optional 9 | 10 | from fairseq2 import _DOC_MODE 11 | from fairseq2.data.data_pipeline import DataPipelineBuilder 12 | from fairseq2.data.typing import PathLike, StringLike 13 | 14 | if TYPE_CHECKING or _DOC_MODE: 15 | 16 | class LineEnding(Enum): 17 | INFER = 0 18 | LF = 1 19 | CRLF = 2 20 | 21 | def read_text( 22 | pathname: PathLike, 23 | encoding: Optional[StringLike] = None, 24 | line_ending: LineEnding = LineEnding.INFER, 25 | ltrim: bool = False, 26 | rtrim: bool = False, 27 | skip_empty: bool = False, 28 | memory_map: bool = False, 29 | block_size: Optional[int] = None, 30 | ) -> DataPipelineBuilder: 31 | """Open a text file and return a data pipeline reading lines one by one.""" 32 | ... 33 | 34 | else: 35 | from fairseq2n.bindings.data.text.text_reader import LineEnding as LineEnding 36 | from fairseq2n.bindings.data.text.text_reader import read_text as read_text 37 | 38 | def _set_module_name() -> None: 39 | for t in [LineEnding, read_text]: 40 | t.__module__ = __name__ 41 | 42 | _set_module_name() 43 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/zip_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "fairseq2n/data/data_pipeline.h" 14 | #include "fairseq2n/data/data_source.h" 15 | 16 | namespace fairseq2n::detail { 17 | 18 | class zip_data_source final : public data_source { 19 | public: 20 | explicit 21 | zip_data_source( 22 | std::vector &&pipelines, 23 | std::vector &&names, 24 | bool zip_to_shortest, 25 | bool flatten, 26 | bool disable_parallelism) noexcept; 27 | 28 | std::optional 29 | next() override; 30 | 31 | void 32 | reset() override; 33 | 34 | void 35 | record_position(tape &t) const override; 36 | 37 | void 38 | reload_position(tape &t) override; 39 | 40 | private: 41 | static std::optional 42 | flatten_to_dict(data_list &zip); 43 | 44 | static std::optional 45 | flatten_to_list(data_list &zip); 46 | 47 | private: 48 | std::vector pipelines_; 49 | std::vector names_; 50 | bool zip_to_shortest_; 51 | bool flatten_; 52 | bool disable_parallelism_; 53 | }; 54 | 55 | } // namespace fairseq2n::detail 56 | -------------------------------------------------------------------------------- /src/fairseq2/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from fairseq2.nn.embedding import Embedding as Embedding 8 | from fairseq2.nn.embedding import StandardEmbedding as StandardEmbedding 9 | from fairseq2.nn.embedding import init_scaled_embedding as init_scaled_embedding 10 | from fairseq2.nn.incremental_state import IncrementalState as IncrementalState 11 | from fairseq2.nn.incremental_state import IncrementalStateBag as IncrementalStateBag 12 | from fairseq2.nn.module_list import ModuleList as ModuleList 13 | from fairseq2.nn.normalization import LayerNorm as LayerNorm 14 | from fairseq2.nn.normalization import RMSNorm as RMSNorm 15 | from fairseq2.nn.normalization import StandardLayerNorm as StandardLayerNorm 16 | from fairseq2.nn.position_encoder import ( 17 | LearnedPositionEncoder as LearnedPositionEncoder, 18 | ) 19 | from fairseq2.nn.position_encoder import PositionEncoder as PositionEncoder 20 | from fairseq2.nn.position_encoder import RotaryEncoder as RotaryEncoder 21 | from fairseq2.nn.position_encoder import ( 22 | SinusoidalPositionEncoder as SinusoidalPositionEncoder, 23 | ) 24 | from fairseq2.nn.projection import Linear as Linear 25 | from fairseq2.nn.projection import Projection as Projection 26 | from fairseq2.nn.projection import TiedProjection as TiedProjection 27 | -------------------------------------------------------------------------------- /fairseq2n/third-party/kaldi-native-fbank.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | macro(fairseq2n_add_kaldi_native_fbank) 8 | set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) 9 | 10 | set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF) 11 | set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF) 12 | 13 | set(backup_build_shared_libs ${BUILD_SHARED_LIBS}) 14 | 15 | # Force the library to be static. 
16 | set(BUILD_SHARED_LIBS FALSE) 17 | 18 | add_subdirectory(${PROJECT_SOURCE_DIR}/third-party/kaldi-native-fbank EXCLUDE_FROM_ALL) 19 | 20 | # Revert. 21 | set(BUILD_SHARED_LIBS ${backup_build_shared_libs}) 22 | 23 | unset(backup_build_shared_libs) 24 | 25 | set_target_properties(kaldi-native-fbank-core PROPERTIES 26 | CXX_VISIBILITY_PRESET 27 | hidden 28 | POSITION_INDEPENDENT_CODE 29 | ON 30 | ) 31 | 32 | target_include_directories(kaldi-native-fbank-core SYSTEM 33 | PUBLIC 34 | ${PROJECT_SOURCE_DIR}/third-party/kaldi-native-fbank 35 | ) 36 | 37 | # We depend on the phony torch_cxx11_abi target to ensure that we use 38 | # the same libstdc++ ABI as PyTorch. 39 | target_link_libraries(kaldi-native-fbank-core PRIVATE torch_cxx11_abi) 40 | 41 | add_library(kaldi-native-fbank::core ALIAS kaldi-native-fbank-core) 42 | endmacro() 43 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/filter_data_source.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include "fairseq2n/data/filter_data_source.h" 8 | 9 | #include 10 | 11 | #include "fairseq2n/data/detail/exception.h" 12 | 13 | namespace fairseq2n::detail { 14 | 15 | std::optional 16 | filter_data_source::next() 17 | { 18 | while (std::optional maybe_example = inner_->next()) 19 | if (invoke_function(*maybe_example)) 20 | return maybe_example; 21 | 22 | return std::nullopt; 23 | } 24 | 25 | void 26 | filter_data_source::reset() 27 | { 28 | inner_->reset(); 29 | } 30 | 31 | void 32 | filter_data_source::record_position(tape &t) const 33 | { 34 | inner_->record_position(t); 35 | } 36 | 37 | void 38 | filter_data_source::reload_position(tape &t) 39 | { 40 | inner_->reload_position(t); 41 | } 42 | 43 | bool 44 | filter_data_source::invoke_function(data &example) 45 | { 46 | try { 47 | return predicate_fn_(example); 48 | } catch (const data_pipeline_error &) { 49 | throw; 50 | } catch (const std::exception &) { 51 | throw_data_pipeline_error_with_nested(std::move(example), /*recoverable=*/true, 52 | "The filter operation has failed. See nested exception for details."); 53 | } 54 | } 55 | 56 | } // fairseq2n::detail 57 | -------------------------------------------------------------------------------- /fairseq2n/third-party/natsort/strnatcmp.h: -------------------------------------------------------------------------------- 1 | /* -*- mode: c; c-file-style: "k&r" -*- 2 | 3 | strnatcmp.c -- Perform 'natural order' comparisons of strings in C. 4 | Copyright (C) 2000, 2004 by Martin Pool 5 | 6 | This software is provided 'as-is', without any express or implied 7 | warranty. In no event will the authors be held liable for any damages 8 | arising from the use of this software. 9 | 10 | Permission is granted to anyone to use this software for any purpose, 11 | including commercial applications, and to alter it and redistribute it 12 | freely, subject to the following restrictions: 13 | 14 | 1. 
The origin of this software must not be misrepresented; you must not 15 | claim that you wrote the original software. If you use this software 16 | in a product, an acknowledgment in the product documentation would be 17 | appreciated but is not required. 18 | 2. Altered source versions must be plainly marked as such, and must not be 19 | misrepresented as being the original software. 20 | 3. This notice may not be removed or altered from any source distribution. 21 | */ 22 | 23 | #ifdef __cplusplus 24 | extern "C" { 25 | #endif 26 | 27 | /* CUSTOMIZATION SECTION 28 | * 29 | * You can change this typedef, but must then also change the inline 30 | * functions in strnatcmp.c */ 31 | typedef char nat_char; 32 | 33 | int strnatcmp(nat_char const *a, nat_char const *b); 34 | int strnatcasecmp(nat_char const *a, nat_char const *b); 35 | 36 | #ifdef __cplusplus 37 | } 38 | #endif 39 | -------------------------------------------------------------------------------- /tests/unit/nn/test_module_list.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import pytest 8 | from torch.nn import Linear 9 | 10 | from fairseq2.nn import ModuleList 11 | 12 | 13 | class TestModuleList: 14 | def test_iter_returns_no_modules_when_drop_p_is_one(self) -> None: 15 | modules = [Linear(10, 10), Linear(10, 10), Linear(10, 10), Linear(10, 10)] 16 | 17 | m = ModuleList(modules, drop_p=1.0) 18 | 19 | with pytest.raises(StopIteration): 20 | next(m.drop_iter()) 21 | 22 | def test_iter_returns_all_modules_when_drop_p_is_zero(self) -> None: 23 | modules = [Linear(10, 10), Linear(10, 10), Linear(10, 10), Linear(10, 10)] 24 | 25 | m = ModuleList(modules) 26 | 27 | count = 0 28 | 29 | for m1, m2 in zip(m.drop_iter(), modules): 30 | assert m1 is m2 31 | 32 | count += 1 33 | 34 | assert count == len(modules) 35 | 36 | def test_iter_returns_all_modules_in_eval(self) -> None: 37 | modules = [Linear(10, 10), Linear(10, 10), Linear(10, 10), Linear(10, 10)] 38 | 39 | m = ModuleList(modules, drop_p=1.0) 40 | 41 | m.eval() 42 | 43 | count = 0 44 | 45 | for m1, m2 in zip(m.drop_iter(), modules): 46 | assert m1 is m2 47 | 48 | count += 1 49 | 50 | assert count == len(modules) 51 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/round_robin_data_source.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include "fairseq2n/data/round_robin_data_source.h" 8 | 9 | namespace fairseq2n::detail { 10 | 11 | round_robin_data_source::round_robin_data_source(std::vector &&pipelines, bool stop_at_shortest) 12 | { 13 | pipelines_count_ = pipelines.size(); 14 | pipeline_idx_ = 0; 15 | 16 | auto gen = [this]() 17 | { 18 | pipeline_idx_ %= pipelines_count_; 19 | 20 | return pipeline_idx_++; 21 | }; 22 | 23 | inner_ = std::make_unique(std::move(pipelines), std::move(gen), stop_at_shortest); 24 | } 25 | 26 | std::optional 27 | round_robin_data_source::next() 28 | { 29 | auto output = inner_->next(); 30 | if (!output) 31 | return std::nullopt; 32 | 33 | return output; 34 | } 35 | 36 | void 37 | round_robin_data_source::reset() 38 | { 39 | inner_->reset(); 40 | 41 | pipeline_idx_ = 0; 42 | } 43 | 44 | void 45 | round_robin_data_source::record_position(tape &t) const 46 | { 47 | inner_->record_position(t); 48 | 49 | t.record(pipeline_idx_); 50 | } 51 | 52 | void 53 | round_robin_data_source::reload_position(tape &t) 54 | { 55 | inner_->reload_position(t); 56 | 57 | pipeline_idx_ = t.read(); 58 | } 59 | 60 | } // namespace fairseq2n::detail 61 | -------------------------------------------------------------------------------- /fairseq2n/python/src/fairseq2n/bindings/type_casters/string.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | namespace pybind11::detail { 15 | 16 | template <> 17 | struct type_caster { 18 | PYBIND11_TYPE_CASTER(std::string, const_name("str")); 19 | 20 | private: 21 | using inner_caster = string_caster; 22 | 23 | public: 24 | bool 25 | load(handle src, bool convert); 26 | 27 | static handle 28 | cast(const std::string &src, return_value_policy policy, handle parent) 29 | { 30 | return inner_caster::cast(src, policy, parent); 31 | } 32 | 33 | private: 34 | inner_caster inner_caster_{}; 35 | }; 36 | 37 | template <> 38 | struct type_caster { 39 | PYBIND11_TYPE_CASTER(std::string_view, const_name("str")); 40 | 41 | private: 42 | using inner_caster = string_caster; 43 | 44 | public: 45 | bool 46 | load(handle src, bool convert); 47 | 48 | static handle 49 | cast(const std::string_view &src, return_value_policy policy, handle parent) 50 | { 51 | return inner_caster::cast(src, policy, parent); 52 | } 53 | 54 | private: 55 | inner_caster inner_caster_{}; 56 | }; 57 | 58 | } // namespace pybind11::detail 59 | -------------------------------------------------------------------------------- /fairseq2n/python/src/fairseq2n/bindings/type_casters/data.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | #include 12 | 13 | namespace pybind11::detail { 14 | 15 | template <> 16 | struct type_caster { 17 | PYBIND11_TYPE_CASTER(fairseq2n::data, const_name("Any")); 18 | 19 | public: 20 | bool 21 | load(handle src, bool) 22 | { 23 | value = cast_from_py(src); 24 | 25 | return true; 26 | } 27 | 28 | static handle 29 | cast(const fairseq2n::data &src, return_value_policy, handle) 30 | { 31 | object obj = cast_from_cc(src); 32 | 33 | return obj.release(); 34 | } 35 | 36 | static handle 37 | cast(fairseq2n::data &&src, return_value_policy, handle) 38 | { 39 | object obj = cast_from_cc(std::move(src)); 40 | 41 | return obj.release(); 42 | } 43 | 44 | private: 45 | static fairseq2n::data 46 | cast_from_py(handle src); 47 | 48 | template 49 | static object 50 | cast_from_cc(T &&src); 51 | }; 52 | 53 | extern template 54 | object 55 | type_caster::cast_from_cc(const fairseq2n::data &src); 56 | 57 | extern template 58 | object 59 | type_caster::cast_from_cc(fairseq2n::data &&src); 60 | 61 | } // namespace pybind11::detail 62 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from argparse import ArgumentTypeError 8 | from pathlib import Path 9 | from typing import cast 10 | 11 | import pytest 12 | 13 | import tests.common 14 | from fairseq2.typing import Device 15 | 16 | 17 | def parse_device_arg(value: str) -> Device: 18 | try: 19 | return Device(value) 20 | except RuntimeError: 21 | raise ArgumentTypeError(f"'{value}' is not a valid device name.") 22 | 23 | 24 | def pytest_addoption(parser: pytest.Parser) -> None: 25 | # fmt: off 26 | parser.addoption( 27 | "--device", default="cpu", type=parse_device_arg, 28 | help="device on which to run tests (default: %(default)s)", 29 | ) 30 | parser.addoption( 31 | "--integration", default=False, action="store_true", 32 | help="whether to run the integration tests", 33 | ) 34 | # fmt: on 35 | 36 | 37 | def pytest_sessionstart(session: pytest.Session) -> None: 38 | tests.common.device = cast(Device, session.config.getoption("device")) 39 | 40 | 41 | def pytest_ignore_collect( 42 | collection_path: Path, path: None, config: pytest.Config 43 | ) -> bool: 44 | if "integration" in collection_path.parts: 45 | # Ignore integration tests unless we run `pytest --integration`. 46 | return not cast(bool, config.getoption("integration")) 47 | 48 | return False 49 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/text/text_data_source.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "fairseq2n/memory.h" 15 | #include "fairseq2n/span.h" 16 | #include "fairseq2n/data/data_source.h" 17 | #include "fairseq2n/data/text/text_line_reader.h" 18 | #include "fairseq2n/data/text/text_reader.h" 19 | 20 | namespace fairseq2n::detail { 21 | 22 | class text_data_source final : public data_source { 23 | public: 24 | explicit 25 | text_data_source(std::string &&pathname, text_options &&opts); 26 | 27 | std::optional 28 | next() override; 29 | 30 | void 31 | reset() override; 32 | 33 | void 34 | record_position(tape &t) const override; 35 | 36 | void 37 | reload_position(tape &t) override; 38 | 39 | private: 40 | std::unique_ptr 41 | make_text_line_reader(); 42 | 43 | memory_block 44 | read_next_line(); 45 | 46 | bool 47 | is_empty(memory_span line) const; 48 | 49 | [[noreturn]] void 50 | handle_error(); 51 | 52 | [[noreturn]] void 53 | throw_read_failure(); 54 | 55 | private: 56 | std::string pathname_; 57 | text_options opts_; 58 | std::unique_ptr line_reader_; 59 | std::size_t num_lines_read_ = 0; 60 | }; 61 | 62 | } // namespace fairseq2n::detail 63 | -------------------------------------------------------------------------------- /tests/integration/models/test_nllb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Final 8 | 9 | import torch 10 | 11 | from fairseq2.generation import TextTranslator 12 | from fairseq2.models.nllb import load_nllb_model, load_nllb_tokenizer 13 | from tests.common import device 14 | 15 | ENG_SENTENCE: Final = "On Monday, scientists from the Stanford University School of Medicine announced the invention of a new diagnostic tool that can sort cells by type: a tiny printable chip that can be manufactured using standard inkjet printers for possibly about one U.S. cent each." 16 | DEU_SENTENCE: Final = "Am Montag kündigten Wissenschaftler der Medizinischen Fakultät der Universität Stanford die Erfindung eines neuen Diagnosetools an, das Zellen nach Typ sortieren kann: Ein winziger druckbarer Chip, der mit standardmäßigen Inkjet-Drucker für möglicherweise etwa einen US-Cent pro Stück hergestellt werden kann." 17 | 18 | 19 | def test_load_dense_distill_600m() -> None: 20 | model_name = "nllb-200_dense_distill_600m" 21 | 22 | model = load_nllb_model( 23 | model_name, device=device, dtype=torch.float32, progress=False 24 | ) 25 | 26 | tokenizer = load_nllb_tokenizer(model_name, progress=False) 27 | 28 | translator = TextTranslator( 29 | model, tokenizer, source_lang="eng_Latn", target_lang="deu_Latn" 30 | ) 31 | 32 | assert translator([ENG_SENTENCE]) == [DEU_SENTENCE] 33 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/bucket_data_source.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include "fairseq2n/data/bucket_data_source.h" 8 | 9 | #include 10 | 11 | namespace fairseq2n::detail { 12 | 13 | bucket_data_source::bucket_data_source( 14 | std::unique_ptr &&inner, std::size_t bucket_size, bool drop_remainder) noexcept 15 | : inner_{std::move(inner)}, bucket_size_{bucket_size}, drop_remainder_{drop_remainder} 16 | {} 17 | 18 | std::optional 19 | bucket_data_source::next() 20 | { 21 | data_list output{}; 22 | 23 | output.reserve(bucket_size_); 24 | 25 | for (std::size_t i = 0; i < bucket_size_; ++i) { 26 | std::optional maybe_example = inner_->next(); 27 | if (!maybe_example) 28 | break; 29 | 30 | output.push_back(*std::move(maybe_example)); 31 | } 32 | 33 | if (output.empty()) 34 | return std::nullopt; 35 | 36 | if (drop_remainder_ && output.size() < bucket_size_) 37 | return std::nullopt; 38 | 39 | return output; 40 | } 41 | 42 | void 43 | bucket_data_source::reset() 44 | { 45 | inner_->reset(); 46 | } 47 | 48 | void 49 | bucket_data_source::record_position(tape &t) const 50 | { 51 | inner_->record_position(t); 52 | } 53 | 54 | void 55 | bucket_data_source::reload_position(tape &t) 56 | { 57 | inner_->reload_position(t); 58 | } 59 | 60 | } // namespace fairseq2n::detail 61 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/detail/file.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include "fairseq2n/data/detail/file.h" 8 | 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | #include "fairseq2n/memory.h" 15 | #include "fairseq2n/detail/error.h" 16 | #include "fairseq2n/detail/exception.h" 17 | 18 | namespace fairseq2n::detail { 19 | namespace { 20 | 21 | void 22 | mmap_deallocate(const void *addr, std::size_t size, void *) noexcept 23 | { 24 | if (addr != nullptr) 25 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) 26 | ::munmap(const_cast(addr), size); 27 | } 28 | 29 | } // namespace 30 | 31 | memory_block 32 | memory_map_file(const file_desc &fd, std::string_view pathname) 33 | { 34 | struct ::stat buf{}; 35 | if (::fstat(fd.get(), &buf) == -1) 36 | throw_system_error(last_error(), 37 | "The file size of '{}' cannot be determined", pathname); 38 | 39 | auto size = static_cast(buf.st_size); 40 | if (size == 0) 41 | return memory_block{}; 42 | 43 | void *addr = ::mmap(nullptr, size, PROT_READ, MAP_PRIVATE, fd.get(), 0); 44 | if (addr == MAP_FAILED) 45 | throw_system_error(last_error(), 46 | "'{}' cannot be memory mapped", pathname); 47 | 48 | return memory_block{static_cast(addr), size, nullptr, mmap_deallocate}; 49 | } 50 | 51 | } // namespace fairseq2n::detail 52 | -------------------------------------------------------------------------------- /src/fairseq2/models/s2t_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from fairseq2.models.s2t_transformer.builder import ( 8 | S2TTransformerBuilder as S2TTransformerBuilder, 9 | ) 10 | from fairseq2.models.s2t_transformer.builder import ( 11 | S2TTransformerConfig as S2TTransformerConfig, 12 | ) 13 | from fairseq2.models.s2t_transformer.builder import ( 14 | create_s2t_transformer_model as create_s2t_transformer_model, 15 | ) 16 | from fairseq2.models.s2t_transformer.builder import ( 17 | s2t_transformer_arch as s2t_transformer_arch, 18 | ) 19 | from fairseq2.models.s2t_transformer.builder import ( 20 | s2t_transformer_archs as s2t_transformer_archs, 21 | ) 22 | from fairseq2.models.s2t_transformer.feature_extractor import ( 23 | Conv1dFbankSubsampler as Conv1dFbankSubsampler, 24 | ) 25 | from fairseq2.models.s2t_transformer.frontend import ( 26 | S2TTransformerFrontend as S2TTransformerFrontend, 27 | ) 28 | from fairseq2.models.s2t_transformer.loader import ( 29 | load_s2t_transformer_config as load_s2t_transformer_config, 30 | ) 31 | from fairseq2.models.s2t_transformer.loader import ( 32 | load_s2t_transformer_model as load_s2t_transformer_model, 33 | ) 34 | from fairseq2.models.s2t_transformer.loader import ( 35 | load_s2t_transformer_tokenizer as load_s2t_transformer_tokenizer, 36 | ) 37 | from fairseq2.models.s2t_transformer.tokenizer import ( 38 | S2TTransformerTokenizer as S2TTransformerTokenizer, 39 | ) 40 | -------------------------------------------------------------------------------- /fairseq2n/tests/data/detail/test_lru_cache.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include 8 | 9 | #include 10 | 11 | #include 12 | 13 | #include 14 | 15 | using namespace fairseq2n::detail; 16 | 17 | TEST(test_lru_cache, add_and_maybe_get_work) 18 | { 19 | lru_cache cache{/*capacity=*/5}; 20 | 21 | cache.add("1", 1); 22 | cache.add("2", 2); 23 | cache.add("3", 3); 24 | cache.add("4", 4); 25 | cache.add("5", 5); 26 | 27 | // Mark 1 as the most-recently-used entry. 28 | EXPECT_NE(cache.maybe_get("1"), nullptr); 29 | 30 | EXPECT_EQ(cache.size(), 5); 31 | 32 | // Exceed capacity; we expect 2 to be evicted. 33 | cache.add("6", 6); 34 | 35 | EXPECT_EQ(cache.size(), 5); 36 | 37 | EXPECT_EQ(cache.maybe_get("2"), nullptr); 38 | 39 | // Exceed capacity again; we expect 3 to be evicted. 40 | cache.add("7", 7); 41 | 42 | EXPECT_EQ(cache.size(), 5); 43 | 44 | EXPECT_EQ(cache.maybe_get("3"), nullptr); 45 | 46 | // This time touch 4, and mark it as most-recently-used. 47 | EXPECT_NE(cache.maybe_get("4"), nullptr); 48 | 49 | // Exceed capacity again; we expect 5 to be evicted. 50 | cache.add("8", 8); 51 | 52 | EXPECT_EQ(cache.size(), 5); 53 | 54 | EXPECT_EQ(cache.maybe_get("5"), nullptr); 55 | 56 | cache.clear(); 57 | 58 | EXPECT_EQ(cache.size(), 0); 59 | } 60 | -------------------------------------------------------------------------------- /tests/unit/data/data_pipeline/test_count.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import pytest 8 | 9 | from fairseq2.data import DataPipeline 10 | 11 | 12 | class TestCountOp: 13 | def test_op_works(self) -> None: 14 | pipeline = DataPipeline.count(start=4).take(10).and_return() 15 | 16 | for _ in range(2): 17 | list(pipeline) == list(range(4, 15)) 18 | 19 | pipeline.reset() 20 | 21 | def test_op_saves_and_restores_its_state(self) -> None: 22 | pipeline = DataPipeline.count(start=4).take(10).and_return() 23 | 24 | d = None 25 | 26 | it = iter(pipeline) 27 | 28 | # Move the the fifth example. 29 | for _ in range(5): 30 | d = next(it) 31 | 32 | assert d == 8 33 | 34 | state_dict = pipeline.state_dict() 35 | 36 | # Read a few examples before we roll back. 37 | for _ in range(4): 38 | d = next(it) 39 | 40 | assert d == 12 41 | 42 | # Expected to roll back to the fifth example. 43 | pipeline.load_state_dict(state_dict) 44 | 45 | # Move to EOD. 46 | for _ in range(5): 47 | d = next(it) 48 | 49 | assert d == 13 50 | 51 | state_dict = pipeline.state_dict() 52 | 53 | pipeline.reset() 54 | 55 | # Expected to be EOD. 56 | pipeline.load_state_dict(state_dict) 57 | 58 | with pytest.raises(StopIteration): 59 | next(iter(pipeline)) 60 | -------------------------------------------------------------------------------- /src/fairseq2/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from fairseq2.data.cstring import CString as CString 8 | from fairseq2.data.data_pipeline import ByteStreamError as ByteStreamError 9 | from fairseq2.data.data_pipeline import CollateOptionsOverride as CollateOptionsOverride 10 | from fairseq2.data.data_pipeline import Collater as Collater 11 | from fairseq2.data.data_pipeline import DataPipeline as DataPipeline 12 | from fairseq2.data.data_pipeline import DataPipelineBuilder as DataPipelineBuilder 13 | from fairseq2.data.data_pipeline import DataPipelineError as DataPipelineError 14 | from fairseq2.data.data_pipeline import FileMapper as FileMapper 15 | from fairseq2.data.data_pipeline import FileMapperOutput as FileMapperOutput 16 | from fairseq2.data.data_pipeline import RecordError as RecordError 17 | from fairseq2.data.data_pipeline import SequenceData as SequenceData 18 | from fairseq2.data.data_pipeline import ( 19 | get_last_failed_example as get_last_failed_example, 20 | ) 21 | from fairseq2.data.data_pipeline import list_files as list_files 22 | from fairseq2.data.data_pipeline import read_sequence as read_sequence 23 | from fairseq2.data.data_pipeline import read_zipped_records as read_zipped_records 24 | from fairseq2.data.typing import PathLike as PathLike 25 | from fairseq2.data.typing import StringLike as StringLike 26 | from fairseq2.data.typing import is_string_like as is_string_like 27 | from fairseq2.data.vocabulary_info import VocabularyInfo as VocabularyInfo 28 | -------------------------------------------------------------------------------- /fairseq2n/.clang-tidy: -------------------------------------------------------------------------------- 1 | Checks: > 2 | *, 3 | -abseil-*, 4 | -altera-*, 5 | -android-*, 6 | -bugprone-easily-swappable-parameters, 7 | -cert-dcl21-cpp, 8 | -cppcoreguidelines-avoid-magic-numbers, 9 | -cppcoreguidelines-avoid-non-const-global-variables, 10 | -cppcoreguidelines-non-private-member-variables-in-classes, 11 | -cppcoreguidelines-prefer-member-initializer, 
12 | -cppcoreguidelines-pro-bounds-array-to-pointer-decay, 13 | -cppcoreguidelines-pro-bounds-pointer-arithmetic, 14 | -cppcoreguidelines-pro-type-reinterpret-cast, 15 | -cppcoreguidelines-pro-type-vararg, 16 | -darwin-*, 17 | -facebook-*, 18 | -fuchsia-*, 19 | -google-*, 20 | -hicpp-*, 21 | -linuxkernel-*, 22 | -llvm-*, 23 | -llvmlibc-*, 24 | -misc-confusable-identifiers, 25 | -misc-const-correctness, 26 | -misc-no-recursion, 27 | -misc-non-private-member-variables-in-classes, 28 | -modernize-use-equals-default, 29 | -modernize-use-nodiscard, 30 | -modernize-use-trailing-return-type, 31 | -readability-braces-around-statements, 32 | -readability-else-after-return, 33 | -readability-function-cognitive-complexity, 34 | -readability-identifier-length, 35 | -readability-isolate-declaration, 36 | -readability-magic-numbers, 37 | -readability-misleading-indentation, 38 | -readability-named-parameter, 39 | -readability-qualified-auto, 40 | -readability-redundant-access-specifiers, 41 | -readability-suspicious-call-argument, 42 | -zircon-* 43 | CheckOptions: 44 | - key: cppcoreguidelines-special-member-functions.AllowMissingMoveFunctions 45 | value: true 46 | HeaderFilterRegex: '.*' 47 | -------------------------------------------------------------------------------- /fairseq2n/python/src/fairseq2n/bindings/type_casters/torch.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | namespace pybind11 { 16 | 17 | template <> 18 | bool 19 | isinstance(handle obj); 20 | 21 | namespace detail { 22 | 23 | template <> 24 | struct type_caster { 25 | PYBIND11_TYPE_CASTER(at::Tensor, const_name("torch.Tensor")); 26 | 27 | public: 28 | bool 29 | load(handle src, bool); 30 | 31 | static handle 32 | cast(const at::Tensor &src, return_value_policy, handle); 33 | }; 34 | 35 | 36 | template <> 37 | struct type_caster { 38 | PYBIND11_TYPE_CASTER(at::Device, const_name("torch.device")); 39 | 40 | public: 41 | type_caster() noexcept 42 | : value{at::kCPU} 43 | {} 44 | 45 | bool 46 | load(handle src, bool); 47 | 48 | static handle 49 | cast(const at::Device &src, return_value_policy, handle); 50 | }; 51 | 52 | template <> 53 | struct type_caster { 54 | PYBIND11_TYPE_CASTER(at::ScalarType, const_name("torch.dtype")); 55 | 56 | public: 57 | bool 58 | load(handle src, bool); 59 | 60 | static handle 61 | cast(const at::ScalarType &src, return_value_policy, handle); 62 | 63 | private: 64 | static handle 65 | get_dtype(const char *type_name); 66 | }; 67 | 68 | } // namespace detail 69 | } // namespace pybind11 70 | -------------------------------------------------------------------------------- /fairseq2n/third-party/sentencepiece.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | macro(fairseq2n_add_sentencepiece) 8 | set(CMAKE_POLICY_DEFAULT_CMP0063 NEW) 9 | set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) 10 | 11 | # Do not build the shared library. 
12 | set(SPM_ENABLE_SHARED OFF) 13 | 14 | add_subdirectory(${PROJECT_SOURCE_DIR}/third-party/sentencepiece EXCLUDE_FROM_ALL) 15 | 16 | target_compile_features(sentencepiece-static PRIVATE cxx_std_17) 17 | 18 | if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") 19 | # See https://github.com/protocolbuffers/protobuf/issues/6419. 20 | target_compile_options(sentencepiece-static PRIVATE -Wno-stringop-overflow) 21 | endif() 22 | 23 | if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") 24 | target_compile_options(sentencepiece-static PRIVATE -Wno-deprecated-declarations) 25 | endif() 26 | 27 | set_target_properties(sentencepiece-static PROPERTIES 28 | CXX_VISIBILITY_PRESET 29 | hidden 30 | POSITION_INDEPENDENT_CODE 31 | ON 32 | ) 33 | 34 | target_include_directories(sentencepiece-static SYSTEM 35 | PUBLIC 36 | ${PROJECT_SOURCE_DIR}/third-party 37 | ${PROJECT_SOURCE_DIR}/third-party/sentencepiece/third_party/protobuf-lite 38 | ) 39 | 40 | # We depend on the phony torch_cxx11_abi target to ensure that we use the 41 | # same libstdc++ ABI as PyTorch. 42 | target_link_libraries(sentencepiece-static PRIVATE torch_cxx11_abi) 43 | endmacro() 44 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/utils/cast.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "fairseq2n/float.h" 14 | 15 | namespace fairseq2n::detail { 16 | 17 | template 18 | inline constexpr auto 19 | ssize(const Container &container) noexcept 20 | { 21 | return static_cast(container.size()); 22 | } 23 | 24 | template 25 | inline constexpr T 26 | conditional_cast(U value) noexcept 27 | { 28 | if constexpr (std::is_same_v) 29 | return value; 30 | else 31 | return static_cast(value); 32 | } 33 | 34 | template 35 | inline constexpr bool 36 | are_equal(const T &lhs, const T &rhs) noexcept 37 | { 38 | if constexpr (std::is_floating_point_v) 39 | return are_close(lhs, rhs); 40 | else 41 | return lhs == rhs; 42 | } 43 | 44 | template 45 | inline constexpr bool 46 | maybe_narrow(U u, T &t) noexcept 47 | { 48 | if constexpr (std::is_same_v) { 49 | t = u; 50 | 51 | return true; 52 | } else { 53 | t = static_cast(u); 54 | 55 | if constexpr (std::is_signed_v == std::is_signed_v) 56 | return are_equal(static_cast(t), u); 57 | else 58 | return are_equal(static_cast(t), u) && (t < T{}) == (u < U{}); 59 | } 60 | } 61 | 62 | } // namespace fairseq2n::detail 63 | -------------------------------------------------------------------------------- /src/fairseq2/assets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from fairseq2.assets.card import AssetCard as AssetCard 8 | from fairseq2.assets.card import AssetCardError as AssetCardError 9 | from fairseq2.assets.card import ( 10 | AssetCardFieldNotFoundError as AssetCardFieldNotFoundError, 11 | ) 12 | from fairseq2.assets.download_manager import AssetDownloadError as AssetDownloadError 13 | from fairseq2.assets.download_manager import ( 14 | AssetDownloadManager as AssetDownloadManager, 15 | ) 16 | from fairseq2.assets.download_manager import ( 17 | InProcAssetDownloadManager as InProcAssetDownloadManager, 18 | ) 19 | from fairseq2.assets.download_manager import download_manager as download_manager 20 | from fairseq2.assets.error import AssetError as AssetError 21 | from fairseq2.assets.metadata_provider import AssetMetadataError as AssetMetadataError 22 | from fairseq2.assets.metadata_provider import ( 23 | AssetMetadataProvider as AssetMetadataProvider, 24 | ) 25 | from fairseq2.assets.metadata_provider import AssetNotFoundError as AssetNotFoundError 26 | from fairseq2.assets.metadata_provider import ( 27 | FileAssetMetadataProvider as FileAssetMetadataProvider, 28 | ) 29 | from fairseq2.assets.metadata_provider import ( 30 | InProcAssetMetadataProvider as InProcAssetMetadataProvider, 31 | ) 32 | from fairseq2.assets.store import AssetStore as AssetStore 33 | from fairseq2.assets.store import ProviderBackedAssetStore as ProviderBackedAssetStore 34 | from fairseq2.assets.store import asset_store as asset_store 35 | -------------------------------------------------------------------------------- /.github/workflows/_publish_pypi.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | on: 8 | workflow_call: 9 | inputs: 10 | os: 11 | type: string 12 | required: true 13 | torch: 14 | type: string 15 | required: true 16 | py: 17 | type: string 18 | required: true 19 | variant: 20 | type: string 21 | default: 'cpu' 22 | arch: 23 | type: string 24 | default: 'x86_64' 25 | 26 | jobs: 27 | publish: 28 | name: Publish 29 | runs-on: ubuntu-latest 30 | defaults: 31 | run: 32 | shell: bash 33 | permissions: 34 | # Needed to interact with GitHub's OIDC Token endpoint. 35 | id-token: write 36 | steps: 37 | - name: Download wheels from staging 38 | uses: actions/download-artifact@v3 39 | with: 40 | name: pypi-pt${{ inputs.torch }}-py${{ inputs.py }}-${{ inputs.os }}_${{ inputs.arch }}-${{ inputs.variant }}-nosan 41 | path: artifacts/ 42 | - name: Publish fairseq2n 43 | uses: pypa/gh-action-pypi-publish@release/v1 44 | with: 45 | packages-dir: artifacts/fairseq2n/python/build/wheelhouse 46 | - name: Publish fairseq2 47 | uses: pypa/gh-action-pypi-publish@release/v1 48 | with: 49 | packages-dir: artifacts/build/wheelhouse 50 | # Multiple build variants will attempt to publish the same fairseq2 51 | # package. Ignore all but the first one. 52 | skip-existing: true 53 | -------------------------------------------------------------------------------- /tests/unit/data/data_pipeline/test_collate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import pytest 8 | import torch 9 | 10 | from fairseq2.data import Collater, read_sequence 11 | from tests.common import assert_equal, device 12 | 13 | 14 | class TestCollateOp: 15 | @pytest.mark.parametrize("pad_to_multiple", [1, 2, 3, 8]) 16 | def test_op_works(self, pad_to_multiple: int) -> None: 17 | pad_value = 3 18 | 19 | bucket1 = [ 20 | torch.full((4, 2), 0, device=device, dtype=torch.int64), 21 | torch.full((4, 2), 1, device=device, dtype=torch.int64), 22 | torch.full((4, 2), 2, device=device, dtype=torch.int64), 23 | ] 24 | 25 | bucket2 = [ 26 | [{"foo1": 0, "foo2": 1}, {"foo3": 2, "foo4": 3}], 27 | [{"foo1": 4, "foo2": 5}, {"foo3": 6, "foo4": 7}], 28 | [{"foo1": 8, "foo2": 9}, {"foo3": 0, "foo4": 1}], 29 | ] 30 | 31 | seq = [bucket1, bucket2] 32 | 33 | pipeline = read_sequence(seq).collate(pad_value, pad_to_multiple).and_return() 34 | 35 | output1, output2 = list(pipeline) 36 | 37 | collater = Collater(pad_value, pad_to_multiple) 38 | 39 | expected_output1 = collater(bucket1) 40 | expected_output2 = collater(bucket2) 41 | 42 | assert_equal(output1["seqs"], expected_output1["seqs"]) 43 | assert_equal(output1["seq_lens"], expected_output1["seq_lens"]) 44 | 45 | assert output1["is_ragged"] == expected_output1["is_ragged"] 46 | 47 | assert output2 == expected_output2 48 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/sample_data_source.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include "fairseq2n/data/sample_data_source.h" 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "fairseq2n/utils/tensor.h" 14 | 15 | namespace fairseq2n::detail { 16 | 17 | sample_data_source::sample_data_source(std::vector &&pipelines, std::vector &&weights, bool stop_at_shortest) 18 | { 19 | weights_ = make_tensor_from_vector(weights, { static_cast(pipelines.size()) }); 20 | generator_ = at::globalContext().defaultGenerator(c10::DeviceType::CPU); 21 | 22 | auto gen = [this]() 23 | { 24 | auto result = at::multinomial(weights_, 1, false, generator_) 25 | .item(); 26 | 27 | return static_cast(result); 28 | }; 29 | 30 | inner_ = std::make_unique(std::move(pipelines), std::move(gen), stop_at_shortest); 31 | } 32 | 33 | std::optional 34 | sample_data_source::next() 35 | { 36 | auto output = inner_->next(); 37 | if (!output) 38 | return std::nullopt; 39 | 40 | return output; 41 | } 42 | 43 | void 44 | sample_data_source::reset() 45 | { 46 | inner_->reset(); 47 | } 48 | 49 | void 50 | sample_data_source::record_position(tape &t) const 51 | { 52 | inner_->record_position(t); 53 | } 54 | 55 | void 56 | sample_data_source::reload_position(tape &t) 57 | { 58 | inner_->reload_position(t); 59 | } 60 | 61 | } // namespace fairseq2n::detail 62 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/text/sentencepiece/sp_processor.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | #include "fairseq2n/data/text/sentencepiece/sp_model.h" 19 | 20 | namespace fairseq2n::detail { 21 | 22 | class sp_processor { 23 | public: 24 | static std::unique_ptr 25 | from_serialized(std::string_view serialized); 26 | 27 | private: 28 | sp_processor(std::unique_ptr &&proto); 29 | 30 | public: 31 | explicit 32 | sp_processor(std::string_view model_pathname, sp_model_options &&opts); 33 | 34 | sentencepiece::ImmutableSentencePieceText 35 | encode(std::string_view text) const; 36 | 37 | sentencepiece::ImmutableSentencePieceText 38 | sample(std::string_view text, std::int32_t nbest_size, float alpha) const; 39 | 40 | std::string 41 | decode(const std::vector &tokens) const; 42 | 43 | std::int32_t 44 | token_to_index(std::string_view token) const; 45 | 46 | std::string_view 47 | index_to_token(std::int32_t idx) const; 48 | 49 | std::string 50 | serialize() const; 51 | 52 | public: 53 | std::int32_t unk_idx; 54 | std::int32_t bos_idx; 55 | std::int32_t eos_idx; 56 | std::int32_t pad_idx; 57 | 58 | std::size_t vocabulary_size; 59 | 60 | private: 61 | std::unique_ptr native_; 62 | }; 63 | 64 | } // namespace fairseq2n::detail 65 | -------------------------------------------------------------------------------- /fairseq2n/src/fairseq2n/data/detail/file.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | #include "fairseq2n/memory.h" 15 | 16 | namespace fairseq2n::detail { 17 | 18 | constexpr int invalid_fd = -1; 19 | 20 | class file_desc { 21 | public: 22 | file_desc() noexcept = default; 23 | 24 | file_desc(int fd) noexcept 25 | : fd_{fd} 26 | {} 27 | 28 | file_desc(const file_desc &) = delete; 29 | file_desc &operator=(const file_desc &) = delete; 30 | 31 | file_desc(file_desc &&other) noexcept 32 | : fd_{other.fd_} 33 | { 34 | other.fd_ = invalid_fd; 35 | } 36 | 37 | file_desc & 38 | operator=(file_desc &&other) noexcept 39 | { 40 | close_fd(); 41 | 42 | fd_ = std::exchange(other.fd_, invalid_fd); 43 | 44 | return *this; 45 | } 46 | 47 | ~file_desc() 48 | { 49 | close_fd(); 50 | } 51 | 52 | int 53 | get() const noexcept 54 | { 55 | return fd_; 56 | } 57 | 58 | private: 59 | void 60 | close_fd() noexcept 61 | { 62 | if (fd_ == invalid_fd) 63 | return; 64 | 65 | ::close(fd_); 66 | 67 | fd_ = invalid_fd; 68 | } 69 | 70 | private: 71 | int fd_ = invalid_fd; 72 | }; 73 | 74 | inline bool 75 | operator==(const file_desc &lhs, const file_desc &rhs) noexcept 76 | { 77 | return lhs.get() == rhs.get(); 78 | } 79 | 80 | inline bool 81 | operator!=(const file_desc &lhs, const file_desc &rhs) noexcept 82 | { 83 | return lhs.get() != rhs.get(); 84 | } 85 | 86 | memory_block 87 | memory_map_file(const file_desc &fd, std::string_view pathname); 88 | 89 | } 90 | -------------------------------------------------------------------------------- /tests/unit/data/data_pipeline/test_read_sequence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import pytest 8 | 9 | from fairseq2.data import read_sequence 10 | 11 | 12 | class TestReadSequenceOp: 13 | def test_op_works(self) -> None: 14 | seq = list(range(1, 10)) 15 | 16 | pipeline = read_sequence(seq).and_return() 17 | 18 | for _ in range(2): 19 | assert list(pipeline) == seq 20 | 21 | pipeline.reset() 22 | 23 | def test_op_works_when_input_sequence_is_empty(self) -> None: 24 | pipeline = read_sequence([]).and_return() 25 | 26 | for _ in range(2): 27 | assert list(pipeline) == [] 28 | 29 | pipeline.reset() 30 | 31 | def test_op_saves_and_restores_its_state(self) -> None: 32 | seq = list(range(1, 10)) 33 | 34 | pipeline = read_sequence(seq).and_return() 35 | 36 | d = None 37 | 38 | it = iter(pipeline) 39 | 40 | # Move to the second example. 41 | for _ in range(2): 42 | d = next(it) 43 | 44 | assert d == 2 45 | 46 | state_dict = pipeline.state_dict() 47 | 48 | # Read a few examples before we roll back. 49 | for _ in range(4): 50 | d = next(it) 51 | 52 | assert d == 6 53 | 54 | # Expected to roll back to the second example. 55 | pipeline.load_state_dict(state_dict) 56 | 57 | # Move to EOD. 58 | for _ in range(7): 59 | d = next(it) 60 | 61 | assert d == 9 62 | 63 | state_dict = pipeline.state_dict() 64 | 65 | pipeline.reset() 66 | 67 | # Expected to be EOD. 68 | pipeline.load_state_dict(state_dict) 69 | 70 | with pytest.raises(StopIteration): 71 | next(iter(pipeline)) 72 | -------------------------------------------------------------------------------- /tests/unit/nn/utils/test_mask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import pytest 8 | import torch 9 | 10 | from fairseq2.nn.utils.mask import compute_row_mask 11 | from tests.common import device 12 | 13 | 14 | def test_compute_row_mask_works() -> None: 15 | shape = (32, 512) 16 | 17 | mask = compute_row_mask(shape, span_len=10, max_mask_prob=0.65, device=device) 18 | 19 | assert mask is not None 20 | 21 | num_masked = torch.count_nonzero(mask, dim=-1) 22 | 23 | assert num_masked[0] > 0 24 | assert num_masked[0] < 512 25 | 26 | assert mask.shape == shape 27 | assert mask.device == device 28 | assert mask.dtype == torch.bool 29 | 30 | assert (num_masked == num_masked[0]).all() == True 31 | 32 | 33 | def test_compute_row_mask_works_when_row_lens_is_specified() -> None: 34 | shape = (4, 16) 35 | 36 | row_lens = torch.tensor([16, 14, 15, 16], device="cpu") 37 | 38 | mask = compute_row_mask( 39 | shape, span_len=4, max_mask_prob=1.0, device=device, row_lens=row_lens 40 | ) 41 | 42 | assert mask is not None 43 | 44 | assert mask.shape == shape 45 | assert mask.device == device 46 | assert mask.dtype == torch.bool 47 | 48 | assert mask.any() 49 | 50 | 51 | def test_compute_row_mask_raises_error_when_row_length_is_smaller_than_span_len() -> ( 52 | None 53 | ): 54 | shape = (4, 16) 55 | 56 | row_lens = torch.tensor([16, 8, 5, 3], device=device) 57 | 58 | with pytest.raises( 59 | ValueError, 60 | match=r"^All lengths in `row_lens` must be greater than `span_len` \(4\), but at least one length is smaller\. row_lens: tensor", 61 | ): 62 | compute_row_mask( 63 | shape, span_len=4, max_mask_prob=1.0, row_lens=row_lens, device=device 64 | ) 65 | -------------------------------------------------------------------------------- /.github/workflows/_publish_doc.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | on: 8 | workflow_call: 9 | 10 | jobs: 11 | publish: 12 | name: Publish 13 | runs-on: ubuntu-latest 14 | defaults: 15 | run: 16 | shell: bash 17 | steps: 18 | - name: Download documentation from staging 19 | uses: actions/download-artifact@v3 20 | with: 21 | name: doc 22 | path: ~/doc/ 23 | - name: Check-out the gh-pages branch of the repository 24 | uses: actions/checkout@v3 25 | with: 26 | ref: gh-pages 27 | - name: Set up Git 28 | run: | 29 | git config user.name "github-actions" 30 | git config user.email "github-actions@github.com" 31 | - name: Commit and push the documentation 32 | run: | 33 | version=$(cat ~/doc/VERSION) 34 | 35 | if [[ $version == *.dev* ]]; then 36 | doc_dir=nightly 37 | else 38 | # Ignore pre-release segment for directory name. 39 | mmm_version=$(echo $version | grep --only-matching --extended-regexp '^([0-9]+\.)*[0-9]+' -) 40 | 41 | doc_dir=$mmm_version 42 | 43 | # If we have a stable release, update the 'stable' symlink. 44 | if [[ $version == $mmm_version ]]; then 45 | ln --symbolic --no-target-directory --force $doc_dir stable 46 | fi 47 | fi 48 | 49 | rsync --recursive --delete-after ~/doc/ $doc_dir 50 | 51 | git add --all 52 | 53 | # Push a commit only if there are changes in the branch. 54 | if ! git diff --staged --quiet; then 55 | git commit --message\ 56 | "Generated from $(git rev-parse --short "$GITHUB_SHA")" 57 | 58 | git push 59 | fi 60 | -------------------------------------------------------------------------------- /tests/unit/models/utils/test_arch_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
import pytest

from fairseq2.models.utils.arch_registry import ArchitectureRegistry


class TestArchitectureRegistry:
    """Unit tests for ``ArchitectureRegistry``."""

    def test_register_works(self) -> None:
        """Configurations registered under a name can be retrieved by it."""
        registry = ArchitectureRegistry[str]("model")

        registry.register("arch1", lambda: "config1")
        registry.register("arch2", lambda: "config2")

        first = registry.get_config("arch1")
        second = registry.get_config("arch2")

        assert first == "config1"
        assert second == "config2"

    def test_names_works(self) -> None:
        """``names()`` reports exactly the set of registered names."""
        registry = ArchitectureRegistry[str]("model")

        expected_names = {"arch1", "arch2", "arch3"}

        for name in expected_names:
            registry.register(name, lambda: "config")

        assert registry.names() == expected_names

    def test_register_raises_error_when_architecture_is_already_registered(
        self,
    ) -> None:
        """Registering the same name twice raises ``ValueError``."""
        registry = ArchitectureRegistry[str]("model")

        registry.register("arch", lambda: "config")

        expected_message = (
            r"^The architecture name 'arch' is already registered for 'model'\.$"
        )

        with pytest.raises(ValueError, match=expected_message):
            registry.register("arch", lambda: "config")

    def test_get_config_raises_error_when_architecture_is_not_registered(self) -> None:
        """Looking up an unknown name raises ``ValueError``."""
        registry = ArchitectureRegistry[str]("model")

        expected_message = (
            r"^The registry of 'model' does not contain an architecture named 'foo'\.$"
        )

        with pytest.raises(ValueError, match=expected_message):
            registry.get_config("foo")


# --------------------------------------------------------------------------------
# /src/fairseq2/data/cstring.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import TYPE_CHECKING, List, Optional, overload

from fairseq2 import _DOC_MODE

# The class below is a declaration-only stub that is used while type checking
# or when `_DOC_MODE` is set (presumably during the documentation build --
# confirm). In every other case `CString` is imported from the native
# `fairseq2n` bindings in the `else` branch.
if TYPE_CHECKING or _DOC_MODE:

    class CString:
        """
        Represents an immutable UTF-8 string that supports zero-copy marshalling
        between Python and native code.
        """

        # The two overloads mirror the accepted call forms: no argument, or a
        # single Python string.
        @overload
        def __init__(self) -> None:
            ...

        @overload
        def __init__(self, s: str) -> None:
            ...

        def __init__(self, s: Optional[str] = None) -> None:
            """
            :param s:
                The Python string to copy.
            """

        # NOTE(review): the unit of the returned length (bytes vs. code points)
        # is not visible from this stub -- confirm against the native binding.
        def __len__(self) -> int:
            ...

        def __eq__(self, other: object) -> bool:
            ...

        def __ne__(self, other: object) -> bool:
            ...

        def __hash__(self) -> int:
            ...

        def __bytes__(self) -> bytes:
            ...

        def strip(self) -> "CString":
            """Return a copy of this string with no whitespace at the beginning and end."""

        def lstrip(self) -> "CString":
            """Return a copy of this string with no whitespace at the beginning."""

        def rstrip(self) -> "CString":
            """Return a copy of this string with no whitespace at the end."""

        def split(self, sep: Optional[str] = None) -> List["CString"]:
            """Return a list of the words in string using sep as the delimiter string."""

else:
    from fairseq2n.bindings.data.string import CString as CString

    # Report this module as the defining module of the native class instead of
    # the bindings module it is imported from.
    CString.__module__ = __name__

# --------------------------------------------------------------------------------