├── tests ├── __init__.py ├── utils │ ├── __init__.py │ └── test_utils.py ├── benchmark │ ├── __init__.py │ ├── profiler_configs │ │ ├── causal.yaml │ │ └── causal_macct.yaml │ ├── benchmark_utils.py │ ├── profiler_utils.py │ ├── README.md │ ├── benchmark_tokenizer_skip.py │ └── cycle_trainer_profiling.py ├── command │ ├── __init__.py │ └── test_cli_utils.py ├── configs │ ├── __init__.py │ └── test_model_config.py ├── trainer │ ├── __init__.py │ └── test_prepare_cycle_inputs.py ├── task_processors │ ├── __init__.py │ ├── utils.py │ ├── test_ner_processors.py │ └── test_translation_processors.py ├── testing_utils │ ├── __init__.py │ ├── custom_marks.py │ ├── model_registry.py │ └── test_model_registry.py ├── sample_data │ ├── conll2003 │ │ ├── dataset_dict.json │ │ ├── test │ │ │ ├── data-00000-of-00001.arrow │ │ │ ├── state.json │ │ │ └── dataset_info.json │ │ ├── train │ │ │ ├── data-00000-of-00001.arrow │ │ │ ├── state.json │ │ │ └── dataset_info.json │ │ └── validation │ │ │ ├── data-00000-of-00001.arrow │ │ │ ├── state.json │ │ │ └── dataset_info.json │ ├── wmt14 │ │ ├── dataset_dict.json │ │ ├── test │ │ │ ├── data-00000-of-00001.arrow │ │ │ ├── state.json │ │ │ └── dataset_info.json │ │ ├── train │ │ │ ├── data-00000-of-00001.arrow │ │ │ ├── state.json │ │ │ └── dataset_info.json │ │ └── validation │ │ │ ├── data-00000-of-00001.arrow │ │ │ ├── state.json │ │ │ └── dataset_info.json │ └── generate_sample_data.py ├── models_to_test.yaml └── test_examples.py ├── docs ├── examples │ ├── index.md │ ├── cycle_ner.md │ └── wmt_2014.md ├── api_reference │ ├── configuration.md │ ├── cycles.md │ ├── task_processors.md │ └── cycle_trainer.md ├── conceptual_reference │ ├── index.md │ ├── cycle_consistency_training.md │ ├── task_processors.md │ └── macct.md ├── assets │ ├── icicle_snakeviz.png │ └── sunburst_snakeviz.png ├── installation.md ├── index.md ├── usage.md └── performance.md ├── trufflehog.yaml ├── src └── cycleformers │ ├── command │ ├── __init__.py │ └── 
cli_utils.py │ ├── cycles │ ├── __init__.py │ └── utils.py │ ├── data_config.py │ ├── import_utils.py │ ├── __init__.py │ ├── task_processors │ ├── __init__.py │ ├── base.py │ ├── translation.py │ └── ner.py │ ├── cycle_trainer_utils.py │ ├── cycle_training_arguments.py │ ├── exceptions.py │ ├── model_config.py │ ├── utils.py │ └── trainer_callback.py ├── .github ├── workflows │ ├── trufflehog.yml │ ├── pull-request.yml │ ├── build-docs.yml │ ├── release.yml │ └── checks.yml ├── ISSUE_TEMPLATE │ ├── question.md │ ├── feature_request.md │ └── bug_report.md ├── actions │ └── setup-poetry │ │ └── action.yml ├── badges │ ├── build.svg │ └── coverage.svg └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── examples ├── configs │ ├── seq2seq.yaml │ ├── causal.yaml │ ├── causal_macct.yaml │ └── seq2seq_macct.yaml ├── translation_wmt14 │ ├── prepare_data.py │ └── train.py └── cycle_ner │ └── train.py ├── .pre-commit-config.yaml ├── Makefile ├── mkdocs.yml ├── CONTRIBUTING.md ├── README.md └── pyproject.toml /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/command/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/trainer/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/task_processors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/testing_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/examples/index.md: -------------------------------------------------------------------------------- 1 | 🚧 Under Construction 🚧 -------------------------------------------------------------------------------- /docs/examples/cycle_ner.md: -------------------------------------------------------------------------------- 1 | 🚧 Under Construction 🚧 -------------------------------------------------------------------------------- /docs/examples/wmt_2014.md: -------------------------------------------------------------------------------- 1 | 🚧 Under Construction 🚧 -------------------------------------------------------------------------------- /docs/api_reference/configuration.md: -------------------------------------------------------------------------------- 1 | 🚧 Under Construction 🚧 -------------------------------------------------------------------------------- /trufflehog.yaml: -------------------------------------------------------------------------------- 1 | detectors: 2 | SentryToken: 3 | disabled: true 4 | -------------------------------------------------------------------------------- /tests/sample_data/conll2003/dataset_dict.json: -------------------------------------------------------------------------------- 1 | {"splits": ["train", "validation", "test"]} -------------------------------------------------------------------------------- /tests/sample_data/wmt14/dataset_dict.json: 
-------------------------------------------------------------------------------- 1 | {"splits": ["train", "validation", "test"]} -------------------------------------------------------------------------------- /docs/conceptual_reference/index.md: -------------------------------------------------------------------------------- 1 | Fundamental concepts are accessible via the left sidebar navigation. -------------------------------------------------------------------------------- /tests/task_processors/utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | class TaskProcessorTestMixin: # Empty placeholder mixin (presumably for shared task-processor test helpers); defines nothing yet. 5 | pass 6 | -------------------------------------------------------------------------------- /docs/assets/icicle_snakeviz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/docs/assets/icicle_snakeviz.png -------------------------------------------------------------------------------- /docs/assets/sunburst_snakeviz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/docs/assets/sunburst_snakeviz.png -------------------------------------------------------------------------------- /src/cycleformers/command/__init__.py: -------------------------------------------------------------------------------- 1 | from .cli_utils import CfArgumentParser 2 | 3 | 4 | __all__ = ["CfArgumentParser"] 5 | -------------------------------------------------------------------------------- /docs/api_reference/cycles.md: -------------------------------------------------------------------------------- 1 | # Cycle Implementations 2 | 3 | ## Causal to Causal Tokenizer Bypass 4 | 5 | ::: cycleformers.cycles._prepare_causal_skip_cycle_inputs --------------------------------------------------------------------------------
/tests/sample_data/wmt14/test/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/tests/sample_data/wmt14/test/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /tests/sample_data/wmt14/train/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/tests/sample_data/wmt14/train/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /tests/sample_data/conll2003/test/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/tests/sample_data/conll2003/test/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /tests/sample_data/conll2003/train/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/tests/sample_data/conll2003/train/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /tests/sample_data/wmt14/validation/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/tests/sample_data/wmt14/validation/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /tests/sample_data/conll2003/validation/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/tests/sample_data/conll2003/validation/data-00000-of-00001.arrow 
-------------------------------------------------------------------------------- /docs/api_reference/task_processors.md: -------------------------------------------------------------------------------- 1 | # Task Processors 2 | 3 | ::: cycleformers.task_processors.base.BaseProcessor 4 | 5 | 6 | ## NER 7 | 8 | ### CONLL2003Processor 9 | 10 | ::: cycleformers.task_processors.ner.CONLL2003Processor 11 | -------------------------------------------------------------------------------- /docs/api_reference/cycle_trainer.md: -------------------------------------------------------------------------------- 1 | # CycleTrainer 2 | 3 | ::: cycleformers.cycle_trainer.CycleTrainer 4 | 5 | 6 | ### Methods 7 | 8 | ::: cycleformers.cycle_trainer.CycleTrainer.train 9 | 10 | ::: cycleformers.cycle_trainer.CycleTrainer._cycle_step 11 | -------------------------------------------------------------------------------- /src/cycleformers/cycles/__init__.py: -------------------------------------------------------------------------------- 1 | from .cycles import _default_prepare_cycle_inputs, _prepare_causal_skip_cycle_inputs 2 | from .utils import PrepareCycleInputsNotSet 3 | 4 | 5 | __all__ = ["_prepare_causal_skip_cycle_inputs", "_default_prepare_cycle_inputs", "PrepareCycleInputsNotSet"] 6 | -------------------------------------------------------------------------------- /tests/models_to_test.yaml: -------------------------------------------------------------------------------- 1 | # Causal Models 2 | tiny-llama-3.1: 3 | repo_id: "trl-internal-testing/tiny-LlamaForCausalLM-3.1" 4 | 5 | qwen-2.5: 6 | repo_id: "Qwen/Qwen2.5-0.5B" 7 | description: "Needs to be replaced with a tiny model at some point" 8 | 9 | 10 | # Seq2Seq Models 11 | tiny-t5: 12 | repo_id: "google/t5-efficient-tiny" 13 | -------------------------------------------------------------------------------- /tests/sample_data/conll2003/test/state.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "9b1ccf21d530db94", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": "test[:20]" 13 | } -------------------------------------------------------------------------------- /tests/sample_data/wmt14/test/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "85d5b6c8f838535a", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": "test[:20]" 13 | } -------------------------------------------------------------------------------- /tests/sample_data/wmt14/train/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "f2070ded563ec01c", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": "train[:100]" 13 | } -------------------------------------------------------------------------------- /tests/sample_data/conll2003/train/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "35806e9fdfaf6cd6", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": "train[:100]" 13 | } -------------------------------------------------------------------------------- /tests/sample_data/wmt14/validation/state.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "60eb8be03cba3fa7", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": "validation[:20]" 13 | } -------------------------------------------------------------------------------- /.github/workflows/trufflehog.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | 4 | name: Secret Leaks 5 | 6 | jobs: 7 | trufflehog: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout code 11 | uses: actions/checkout@v4 12 | with: 13 | fetch-depth: 0 14 | - name: Secret Scanning 15 | uses: trufflesecurity/trufflehog@main -------------------------------------------------------------------------------- /tests/sample_data/conll2003/validation/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "229dfaf63b462a8d", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": "validation[:20]" 13 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | *.so 5 | .Python 6 | .env 7 | .venv 8 | env/ 9 | venv/ 10 | ENV/ 11 | dist/ 12 | build/ 13 | *.egg-info/ 14 | .coverage 15 | htmlcov/ 16 | .pytest_cache/ 17 | .vscode/ 18 | .idea 19 | site/ 20 | wandb/ 21 | examples/**/data/ 22 | misc/ 23 | benchmark_results/ 24 | .mypy_cache/ 25 | .ruff_cache/ 26 | profiles/ 27 | *cache* -------------------------------------------------------------------------------- /docs/installation.md: 
-------------------------------------------------------------------------------- 1 | Cycleformers is available on PyPi: 2 | 3 | ```bash 4 | pip install cycleformers 5 | ``` 6 | 7 | The module has only been tested on Linux. I have no plans to support Windows or MacOS at this time. 8 | 9 | ## Development Setup 10 | 11 | To build the module locally for development and bugfixing, I recommend the following command which initialises a .env, runs poetry install and sets up the pre-commit hooks: 12 | 13 | ```bash 14 | make init 15 | ``` -------------------------------------------------------------------------------- /src/cycleformers/cycles/utils.py: -------------------------------------------------------------------------------- 1 | class PrepareCycleInputsNotSet(Exception): 2 | """Exception raised when class has not properly set""" 3 | 4 | def __init__(self, msg: str | None = None): 5 | self.msg = ( 6 | msg 7 | or "_prepare_cycle_inputs is not corrently set. If you are modifying CycleTrainer.__init__ " 8 | "make sure to call self.set_cycle_inputs_fn(), optionally with a valid function." 
9 | ) 10 | super().__init__(self.msg) 11 | -------------------------------------------------------------------------------- /examples/configs/seq2seq.yaml: -------------------------------------------------------------------------------- 1 | ### ONLY REQUIRED ARGUMENTS ### 2 | output_dir: "./outputs" 3 | ############################### 4 | 5 | # === Training Arguments === # 6 | per_device_train_batch_size: 2 7 | gradient_accumulation_steps: 8 8 | per_device_eval_batch_size: 2 9 | logging_strategy: steps 10 | logging_steps: 1 11 | 12 | 13 | # === Model Configs === # 14 | model_name_or_path: "google/flan-t5-base" 15 | 16 | # Optimisation 17 | # attn_implementation: "flash_attention_2" 18 | # use_liger_kernel: true -------------------------------------------------------------------------------- /.github/workflows/pull-request.yml: -------------------------------------------------------------------------------- 1 | name: Pull Request 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | permissions: 9 | pull-requests: read 10 | contents: write 11 | 12 | jobs: 13 | check-pr-title: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: amannn/action-semantic-pull-request@v5 17 | env: 18 | GITHUB_TOKEN: ${{ github.token }} 19 | 20 | checks: 21 | uses: ./.github/workflows/checks.yml 22 | 23 | build-docs: 24 | uses: ./.github/workflows/build-docs.yml -------------------------------------------------------------------------------- /tests/testing_utils/custom_marks.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from cycleformers.import_utils import is_liger_kernel_available, is_peft_available 4 | 5 | 6 | def requires_peft(func): 7 | """Decorator to skip test if PEFT is not available""" 8 | return pytest.mark.skipif(not is_peft_available(), reason="PEFT not available")(func) 9 | 10 | 11 | def requires_liger_kernel(func): 12 | """Decorator to skip test if Liger kernel is not available""" 13 | return 
pytest.mark.skipif(not is_liger_kernel_available(), reason="Liger kernel not available")(func) 14 | -------------------------------------------------------------------------------- /src/cycleformers/data_config.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class DataConfig: 7 | """ 8 | Arguments pertaining to what data we are going to input our model for training and eval. 9 | """ 10 | 11 | dataset_name: str | None = None 12 | dataset_config_name: str | None = None 13 | text_column: str = "text" 14 | formatting_func: Callable | None = None 15 | max_seq_length: int | None = None 16 | remove_unused_columns: bool = True 17 | preprocessing_num_workers: int | None = None 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Ask a question about this project 4 | title: '[QUESTION] ' 5 | labels: question 6 | assignees: '' 7 | --- 8 | 9 | **Question** 10 | State your question clearly and provide as much context as possible. 11 | 12 | **Context** 13 | Provide any relevant context about what you're trying to achieve: 14 | - What have you already tried? 15 | - What documentation have you already consulted? 16 | - What research have you done? 17 | 18 | **Additional Information** 19 | Add any other relevant information or screenshots about your question here. 
-------------------------------------------------------------------------------- /tests/benchmark/profiler_configs/causal.yaml: -------------------------------------------------------------------------------- 1 | ### ONLY REQUIRED ARGUMENTS ### 2 | output_dir: "/tmp/profiler_outputs" 3 | ############################### 4 | 5 | # === Training Arguments === # 6 | per_device_train_batch_size: 2 7 | gradient_accumulation_steps: 8 8 | per_device_eval_batch_size: 2 9 | logging_strategy: steps 10 | logging_steps: 1 11 | report_to: "tensorboard" 12 | 13 | 14 | max_steps: 100 15 | 16 | 17 | # === Model Configs === # 18 | model_name_or_path: "Qwen/Qwen2.5-1.5B" 19 | 20 | 21 | # Optimisation 22 | # attn_implementation: "flash_attention_2" 23 | # use_liger_kernel: true -------------------------------------------------------------------------------- /examples/configs/causal.yaml: -------------------------------------------------------------------------------- 1 | ### ONLY REQUIRED ARGUMENTS ### 2 | output_dir: "./outputs" 3 | ############################### 4 | 5 | # === Training Arguments === # 6 | per_device_train_batch_size: 1 7 | gradient_accumulation_steps: 16 8 | per_device_eval_batch_size: 1 9 | gradient_checkpointing: true 10 | logging_strategy: steps 11 | logging_steps: 1 12 | 13 | 14 | # === Model Configs === # 15 | A_model_name_or_path: "Qwen/Qwen2.5-1.5B" 16 | A_torch_dtype: "bfloat16" 17 | 18 | B_model_name_or_path: "Qwen/Qwen2.5-1.5B" 19 | B_torch_dtype: "bfloat16" 20 | 21 | 22 | # Optimisation 23 | # attn_implementation: "flash_attention_2" 24 | # use_liger_kernel: true -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_install_hook_types: [ pre-commit, pre-push ] 2 | 3 | repos: 4 | - repo: local 5 | hooks: 6 | - id: make-format 7 | name: Run formatting 8 | entry: make 9 | args: [ format ] 10 | language: system 11 | 
pass_filenames: false 12 | 13 | - id: make-lint 14 | name: Run linting 15 | entry: make 16 | args: [ lint ] 17 | language: system 18 | pass_filenames: false 19 | 20 | - id: build-docs 21 | name: Build documentation 22 | entry: make 23 | args: [ build-docs ] 24 | language: system 25 | pass_filenames: false 26 | stages: [ pre-push ] 27 | -------------------------------------------------------------------------------- /examples/configs/causal_macct.yaml: -------------------------------------------------------------------------------- 1 | ### ONLY REQUIRED ARGUMENTS ### 2 | output_dir: "./outputs" 3 | ############################### 4 | 5 | # === Training Arguments === # 6 | per_device_train_batch_size: 4 7 | gradient_accumulation_steps: 4 8 | per_device_eval_batch_size: 4 9 | logging_strategy: steps 10 | logging_steps: 1 11 | 12 | 13 | # === Model Configs === # 14 | model_name_or_path: "Qwen/Qwen2.5-3B" 15 | 16 | # Lora 17 | use_peft: true 18 | lora_r: 32 19 | lora_alpha: 64 20 | use_rslora: true 21 | lora_dropout: 0.05 22 | lora_target_modules: "all-linear" 23 | lora_task_type: "CAUSAL_LM" 24 | 25 | # Optimisation 26 | # attn_implementation: "flash_attention_2" 27 | # use_liger_kernel: true -------------------------------------------------------------------------------- /src/cycleformers/import_utils.py: -------------------------------------------------------------------------------- 1 | from transformers.utils.import_utils import _is_package_available 2 | 3 | 4 | _liger_kernel_available = _is_package_available("liger-kernel") 5 | _flash_attn_available = _is_package_available("flash-attn") 6 | _peft_available = _is_package_available("peft") 7 | _wandb_available = _is_package_available("wandb") 8 | 9 | 10 | def is_liger_kernel_available() -> bool: 11 | return _liger_kernel_available 12 | 13 | 14 | def is_flash_attn_available() -> bool: 15 | return _flash_attn_available 16 | 17 | 18 | def is_peft_available() -> bool: 19 | return _peft_available 20 | 21 | 22 | def 
is_wandb_available() -> bool: 23 | return _wandb_available 24 | -------------------------------------------------------------------------------- /examples/configs/seq2seq_macct.yaml: -------------------------------------------------------------------------------- 1 | ### ONLY REQUIRED ARGUMENTS ### 2 | output_dir: "./outputs" 3 | ############################### 4 | 5 | # === Training Arguments === # 6 | per_device_train_batch_size: 2 7 | gradient_accumulation_steps: 8 8 | per_device_eval_batch_size: 2 9 | logging_strategy: steps 10 | logging_steps: 1 11 | 12 | 13 | # === Model Configs === # 14 | model_name_or_path: "google/flan-t5-base" 15 | 16 | # Lora 17 | use_peft: true 18 | lora_r: 32 19 | lora_alpha: 64 20 | use_rslora: true 21 | lora_dropout: 0.05 22 | lora_target_modules: "all-linear" 23 | lora_task_type: "CAUSAL_LM" 24 | 25 | # Optimisation 26 | # attn_implementation: "flash_attention_2" 27 | # use_liger_kernel: true -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Suggest an idea for this project 4 | title: '[FEATURE] ' 5 | labels: enhancement 6 | assignees: '' 7 | --- 8 | 9 | **Is your feature request related to a problem? Please describe.** 10 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 11 | 12 | **Describe the solution you'd like** 13 | A clear and concise description of what you want to happen. 14 | 15 | **Describe alternatives you've considered** 16 | A clear and concise description of any alternative solutions or features you've considered. 17 | 18 | **Additional context** 19 | Add any other context or screenshots about the feature request here. 
-------------------------------------------------------------------------------- /.github/workflows/build-docs.yml: -------------------------------------------------------------------------------- 1 | name: Build Documentation 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | deploy: 7 | type: boolean 8 | description: "If true, the docs will be deployed." 9 | default: false 10 | 11 | jobs: 12 | build-docs: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | - uses: ./.github/actions/setup-poetry 18 | with: 19 | python-version: 3.11 20 | 21 | - name: Build & maybe deploy documentation 22 | run: | 23 | poetry run mkdocs build --verbose --clean 24 | if ${{ inputs.deploy }}; then 25 | poetry run mkdocs gh-deploy --force 26 | fi 27 | -------------------------------------------------------------------------------- /tests/benchmark/profiler_configs/causal_macct.yaml: -------------------------------------------------------------------------------- 1 | ### ONLY REQUIRED ARGUMENTS ### 2 | output_dir: "/tmp/profiler_outputs" 3 | ############################### 4 | 5 | # === Training Arguments === # 6 | use_macct: true 7 | 8 | per_device_train_batch_size: 1 9 | gradient_accumulation_steps: 16 10 | per_device_eval_batch_size: 1 11 | logging_strategy: steps 12 | logging_steps: 1 13 | report_to: "tensorboard" 14 | 15 | 16 | max_steps: 100 17 | 18 | 19 | # === Model Configs === # 20 | model_name_or_path: "Qwen/Qwen2.5-0.5B" 21 | 22 | 23 | # === LoRA Configs === # 24 | lora_r: 16 25 | lora_alpha: 32 26 | lora_dropout: 0.05 27 | task_type: "CAUSAL_LM" 28 | use_rslora: true 29 | 30 | 31 | # Optimisation 32 | # attn_implementation: "flash_attention_2" 33 | # use_liger_kernel: true -------------------------------------------------------------------------------- /src/cycleformers/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | 3 | from .command import CfArgumentParser 4 | from 
.cycle_trainer import CycleTrainer 5 | from .cycle_training_arguments import CycleTrainingArguments 6 | from .data_config import DataConfig 7 | from .exceptions import InvalidCycleKeyError, MACCTModelError, MissingModelError 8 | from .model_config import ModelConfig, ModelConfigA, ModelConfigB, merge_configs 9 | from .utils import DEFAULT_SEP_SEQ 10 | 11 | 12 | __all__ = [ 13 | "CycleTrainer", 14 | "CycleTrainingArguments", 15 | "ModelConfig", 16 | "ModelConfigA", 17 | "ModelConfigB", 18 | "DataConfig", 19 | "MACCTModelError", 20 | "MissingModelError", 21 | "InvalidCycleKeyError", 22 | "DEFAULT_SEP_SEQ", 23 | "CfArgumentParser", 24 | "merge_configs", 25 | ] 26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Create a report to help us improve 4 | title: '[BUG] ' 5 | labels: bug 6 | assignees: '' 7 | --- 8 | 9 | **Describe the bug** 10 | A clear and concise description of what the bug is. 11 | 12 | **To Reproduce** 13 | Steps to reproduce the behavior: 14 | 1. Go to '...' 15 | 2. Click on '....' 16 | 3. Scroll down to '....' 17 | 4. See error 18 | 19 | **Expected behavior** 20 | A clear and concise description of what you expected to happen. 21 | 22 | **Screenshots** 23 | If applicable, add screenshots to help explain your problem. 24 | 25 | **Environment (please complete the following information):** 26 | - OS: [e.g. iOS, Windows 10] 27 | - Browser: [e.g. chrome, safari] 28 | - Version: [e.g. 22] 29 | 30 | **Additional context** 31 | Add any other context about the problem here. 
-------------------------------------------------------------------------------- /src/cycleformers/task_processors/__init__.py: -------------------------------------------------------------------------------- 1 | """Processors module for transforming various dataset formats into cycleformers-compatible format. 2 | 3 | This module provides a flexible API for converting different dataset formats into cycleformers-compatible format. 4 | Each processor handles a specific dataset format or task type (e.g., NER, machine translation, etc.) while sharing 5 | common functionality through the base class. 6 | """ 7 | 8 | from .base import BaseProcessor, ProcessorConfig 9 | from .ner import CONLL2003Processor, CONLL2003ProcessorConfig 10 | from .translation import TranslationProcessor, TranslationProcessorConfig 11 | 12 | 13 | __all__ = [ 14 | "BaseProcessor", 15 | "CONLL2003Processor", 16 | "CONLL2003ProcessorConfig", 17 | "ProcessorConfig", 18 | "TranslationProcessor", 19 | "TranslationProcessorConfig", 20 | ] 21 | -------------------------------------------------------------------------------- /.github/actions/setup-poetry/action.yml: -------------------------------------------------------------------------------- 1 | name: 'Set up Poetry and install' 2 | description: 'Set up a specific version of Poetry and install dependencies using caching.' 3 | inputs: 4 | python-version: 5 | description: "Version range or exact version of Python or PyPy to use, using SemVer's version range syntax." 6 | default: 3.11 7 | install-dependencies: 8 | description: "If true, dependencies will be installed." 
9 | default: true 10 | 11 | runs: 12 | using: 'composite' 13 | steps: 14 | - name: Install poetry 15 | run: pipx install poetry==1.8.3 16 | shell: bash 17 | 18 | - uses: actions/setup-python@v5 19 | with: 20 | python-version: ${{ inputs.python-version }} 21 | cache: 'poetry' 22 | 23 | - name: Install dependencies 24 | if: ${{ inputs.install-dependencies }} 25 | run: poetry install --all-extras 26 | shell: bash -------------------------------------------------------------------------------- /tests/sample_data/generate_sample_data.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from datasets import DatasetDict, load_dataset 4 | 5 | 6 | CURR_DIR = Path(__file__).parent 7 | 8 | dataset_configs = { 9 | "wmt14": {"name": "wmt14", "config": "de-en"}, # Specify language pair 10 | "conll2003": {"name": "conll2003", "config": None}, 11 | } 12 | 13 | for dataset_info in dataset_configs.values(): 14 | dataset_dict = DatasetDict() 15 | for split in ["train[:100]", "validation[:20]", "test[:20]"]: 16 | # Use streaming to avoid downloading entire dataset 17 | dataset = load_dataset( 18 | dataset_info["name"], 19 | dataset_info["config"], 20 | split=split, 21 | ) 22 | dataset_dict[split.split("[")[0]] = dataset 23 | 24 | # Save the small sample dataset 25 | output_path = CURR_DIR / f"{dataset_info['name']}" 26 | dataset_dict.save_to_disk(output_path) 27 | -------------------------------------------------------------------------------- /tests/benchmark/benchmark_utils.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | 4 | from transformers import PreTrainedTokenizerBase 5 | 6 | from cycleformers import CycleTrainer, CycleTrainingArguments 7 | 8 | 9 | @dataclass 10 | class BenchmarkConfig: 11 | config_name: str 12 | output_dir: str = "benchmark_results/" 13 | 14 | 15 | class Benchmark(ABC): 16 | 
def __init__(self, config: BenchmarkConfig, *args, **kwargs): 17 | self.config = config 18 | 19 | @abstractmethod 20 | def run(self, *args, **kwargs): 21 | raise NotImplementedError 22 | 23 | 24 | class MockCycleTrainer(CycleTrainer): 25 | """Lightweight mock of CycleTrainer for benchmarking""" 26 | 27 | def __init__(self, tokenizer: PreTrainedTokenizerBase): 28 | self.args = CycleTrainingArguments(output_dir="./tmp") 29 | self.tokenizer_A = self.tokenizer_B = tokenizer 30 | self.sep_seq = "\n\n" 31 | -------------------------------------------------------------------------------- /src/cycleformers/cycle_trainer_utils.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from typing import Any 3 | 4 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, PreTrainedModel 5 | 6 | from cycleformers.exceptions import CycleModelError 7 | 8 | 9 | def load_model(model_path: str | PathLike[str], **model_init_kwargs: dict[str, Any]) -> PreTrainedModel: 10 | auto_config = AutoConfig.from_pretrained(model_path) 11 | if "ForCausalLM" in auto_config.model_type: 12 | model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs) 13 | elif auto_config.is_encoder_decoder: 14 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path, **model_init_kwargs) 15 | else: 16 | raise CycleModelError( 17 | "Unsupported or unrecognised model type. Make sure the provided model is either " 18 | "CausalLM or Seq2SeqLM. If you are using a custom model, you may need to pass the instantiated model to " 19 | "CycleTrainer." 
20 | ) 21 | 22 | # TODO: Handle quantisation 23 | return model 24 | -------------------------------------------------------------------------------- /.github/badges/build.svg: -------------------------------------------------------------------------------- 1 | tests: 109tests109 -------------------------------------------------------------------------------- /.github/badges/coverage.svg: -------------------------------------------------------------------------------- 1 | coverage: 75.90%coverage75.90% -------------------------------------------------------------------------------- /src/cycleformers/cycle_training_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import TYPE_CHECKING, Any 3 | 4 | from transformers.training_args import TrainingArguments 5 | 6 | 7 | if TYPE_CHECKING: 8 | from .model_config import ModelConfig 9 | 10 | 11 | @dataclass 12 | class CycleTrainingArguments(TrainingArguments): 13 | """Will eventually contain cycle-specific arguments""" 14 | 15 | use_macct: bool = False 16 | report_to: list[str] | None = None 17 | sep_seq: str = field(default="") # Separator sequence 18 | model_init_kwargs: dict[str, Any] = field(default_factory=dict) 19 | 20 | def __post_init__(self): 21 | super().__post_init__() 22 | self._model_config = None 23 | 24 | @property 25 | def model_config(self) -> "ModelConfig": 26 | return self._model_config 27 | 28 | @model_config.setter 29 | def model_config(self, value: "ModelConfig"): 30 | self._model_config = value 31 | 32 | 33 | @dataclass 34 | class ModelTrainingArguments: 35 | """Will eventually contain model-specific arguments""" 36 | 37 | pass 38 | 39 | 40 | __all__ = ["CycleTrainingArguments", "ModelTrainingArguments"] 41 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 
name: Publish to PyPi 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | # Reusable workflows can only be called as a job (jobs.<job_id>.uses), never as a step. 8 | checks: 9 | name: Code quality checks & tests 10 | uses: ./.github/workflows/checks.yml 11 | permissions: 12 | contents: write # checks.yml requests contents: write; a called workflow cannot exceed its caller 13 | 14 | release: 15 | name: Create Release 16 | needs: checks # only release once quality checks and tests have passed 17 | runs-on: ubuntu-latest 18 | environment: 19 | name: pypi 20 | url: https://pypi.org/project/cycleformers 21 | permissions: 22 | id-token: write 23 | contents: write 24 | issues: write 25 | pull-requests: write 26 | 27 | steps: 28 | - uses: actions/checkout@v4 29 | with: 30 | fetch-depth: 0 31 | 32 | - uses: ./.github/actions/setup-poetry 33 | with: 34 | python-version: 3.11 35 | 36 | - name: Python Semantic Release 37 | env: 38 | GH_TOKEN: ${{ github.token }} 39 | run: | 40 | pip install python-semantic-release 41 | git config user.name github-actions 42 | git config user.email github-actions@github.com 43 | semantic-release version 44 | semantic-release publish 45 | 46 | documentation: 47 | name: Build & deploy documentation 48 | uses: ./.github/workflows/build-docs.yml 49 | with: 50 | deploy: true 51 | -------------------------------------------------------------------------------- /docs/conceptual_reference/cycle_consistency_training.md: -------------------------------------------------------------------------------- 1 | ## Cycle-consistency Training 2 | 3 | Cycle-consistency training (CCT) creates a closed feedback loop between source and target domains by linking two models together during training. Each model implements the inverse function of the other, i.e. $f(g(x)) = x$; for example, one model translates English to German and the other translates German to English. Both models are jointly trained using a single optimiser on the round-trip reconstruction of input data from a non-parallel dataset. 
Cycle-consistency loss is often used as a secondary loss component, as in [Zhu et al. 2017](https://arxiv.org/abs/1703.10593), which also uses an adversarial loss.
4 | 5 | 📈 GRAPH TO BE INSERTED HERE 6 | 7 | For tasks such as text generation, we must sample discrete tokens from the model's continuous output distribution, which breaks the gradient flow. In these settings, $f(g(x))$ is non-differentiable, preventing us from using the standard CCT loss. Whereas CCT enforces a closed loop within each training batch, iterative back-translation (IBT) sidesteps this problem by using one model to generate synthetic parallel data that the other model then uses as input. Each cycle therefore has a separate loss function and optimiser, alternating between training each model ([Guo et al. 2020](https://arxiv.org/abs/2006.04702)). 8 | 9 | 📈 GRAPH TO BE INSERTED HERE 10 | 11 | 12 | -------------------------------------------------------------------------------- /docs/conceptual_reference/task_processors.md: -------------------------------------------------------------------------------- 1 | # Task Processors 2 | 3 | 🚧 This section is under construction. 🚧 4 | 5 | When using task processors, which splits are loaded, and how large they are, can be controlled via the 🤗 Datasets `slice split` syntax. For example, the following loads the first 100 samples from the train split of the WMT14 dataset and the first 30 samples from the test split. 6 | 7 | ```python 8 | from datasets import load_dataset 9 | from cycleformers.task_processors.translation import TranslationProcessor, TranslationProcessorConfig 10 | 11 | config = TranslationProcessorConfig( 12 | dataset_name="wmt14", 13 | dataset_config_name="de-en", 14 | split=["train[:100]", "test[:30]"], 15 | ) 16 | dataset_A, dataset_B = TranslationProcessor(config).process() 17 | ``` 18 | 19 | More information on the syntax for `slice split` can be found in the [datasets documentation](https://huggingface.co/docs/datasets/loading#slice-splits).
20 | 21 | ## BaseProcessor 22 | 23 | ::: src.cycleformers.task_processors.base.BaseProcessor 24 | 25 | ::: src.cycleformers.task_processors.base.ProcessorConfig 26 | 27 | 28 | ## Named-Entity Recognition (NER) 29 | 30 | ::: src.cycleformers.task_processors.ner.CONLL2003Processor 31 | 32 | ::: src.cycleformers.task_processors.ner.CONLL2003ProcessorConfig 33 | 34 | 35 | 36 | ### Helper Functions 37 | 38 | ::: src.cycleformers.task_processors.ner.reconstruct_sentence 39 | 40 | ::: src.cycleformers.task_processors.ner.ner_to_sequences 41 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Cycleformers 2 | 3 |
4 | 5 | [![Python](https://img.shields.io/badge/python-3.11-blue.svg)](https://www.python.org/downloads/) 6 | [![PyPI](https://img.shields.io/pypi/v/cycleformers)](https://pypi.org/project/cycleformers/) 7 | [![License: CC BY 4.0](https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg)](https://creativecommons.org/licenses/by/4.0/) 8 | [![Coverage](.github/badges/coverage.svg)](https://codecov.io/gh/wrmthorne/cycleformers) 9 | [![Build Status](.github/badges/build.svg)](https://github.com/wrmthorne/cycleformers/actions/workflows) 10 | 11 |
12 | 13 | A Python library for efficient cycle-consistency training of transformer models. Cycleformers simplifies iterative back-translation with support for both causal and seq2seq architectures. We also implement Multi-Adapter Cycle-Consistency Training (MACCT), which trains LoRA adapters on a frozen base model, enabling `7.5x` larger model capacity for the same memory footprint. 14 | 15 | ## Features 16 | 17 | - 🤗 Seamless integration with Hugging Face Transformers 18 | - 🚀 PEFT/LoRA support for memory-efficient training 19 | - 🤖 Compatible with both causal and seq2seq models 20 | - 🔥 Optimized for various hardware configurations 21 | 22 | ## Documentation 23 | 24 | - [Conceptual Reference](conceptual_reference/task_processors.md) 25 | - [API Reference](api_reference/task_processors.md) 26 | - [Examples](examples/index.md) 27 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # What does this PR do? 2 | 3 | 12 | 13 | 14 | 15 | Fixes # (issue) 16 | 17 | 18 | ## Before submitting 19 | - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). 20 | - [ ] Did you read the [contributor guideline](https://github.com/wrmthorne/cycleformers/blob/main/CONTRIBUTING.md)? 21 | - [ ] Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case. 22 | - [ ] Did you make sure to update the documentation with your changes? Here are the 23 | [documentation guidelines](https://github.com/wrmthorne/cycleformers/tree/main/docs). 24 | - [ ] Did you write any new necessary tests? 25 | 26 | 27 | ## Who can review? 28 | 29 | Please tag @wrmthorne for review. Requests can be merged once one review has been provided and all checks have passed.
30 | -------------------------------------------------------------------------------- /examples/translation_wmt14/prepare_data.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from datasets import DatasetDict, load_dataset 4 | 5 | 6 | def prepare_wmt14_en_de_datasets(num_samples=5000, test_size=100): 7 | dataset = load_dataset("wmt/wmt14", split="train") 8 | dataset = dataset.map( 9 | lambda x: {"english": x["translation"]["en"], "german": x["translation"]["de"]} 10 | ).remove_columns(["translation"]) 11 | 12 | # Pull out sample of paired data for evaluation 13 | dataset = dataset.train_test_split(test_size=test_size) 14 | 15 | en_train = ( 16 | dataset["train"] 17 | .select_columns("english") 18 | .shuffle(seed=42) # Shuffle to simulate not having any exact translations 19 | .select(range(num_samples)) 20 | .rename_column("english", "text") 21 | ) 22 | en_eval = dataset["test"].rename_columns({"english": "text", "german": "labels"}) 23 | dataset_en = DatasetDict({"train": en_train, "test": en_eval}) 24 | 25 | de_train = ( 26 | dataset["train"] 27 | .select_columns("german") 28 | .shuffle(seed=0) # Shuffle to simulate not having any exact translations 29 | .select(range(num_samples)) 30 | .rename_column("german", "text") 31 | ) 32 | de_eval = dataset["test"].rename_columns({"german": "text", "english": "labels"}) 33 | dataset_de = DatasetDict({"train": de_train, "test": de_eval}) 34 | 35 | return dataset_en, dataset_de 36 | 37 | 38 | if __name__ == "__main__": 39 | dataset_en, dataset_de = prepare_wmt14_en_de_datasets() 40 | 41 | Path("./data").mkdir(exist_ok=True) 42 | 43 | dataset_en.save_to_disk("./data/en") 44 | dataset_de.save_to_disk("./data/de") 45 | -------------------------------------------------------------------------------- /tests/benchmark/profiler_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import 
wraps 3 | 4 | import torch 5 | 6 | 7 | def record_function_wrapper(name=None): 8 | """Decorator to profile a function's execution time and memory usage. 9 | 10 | Args: 11 | name (str, optional): Custom name for the profiling label. Defaults to function name. 12 | """ 13 | 14 | def decorator(func): 15 | @wraps(func) 16 | def wrapper(self, *args, **kwargs): 17 | step_start_time = time.perf_counter() 18 | 19 | # Clear cache before step if CUDA is available 20 | if torch.cuda.is_available(): 21 | torch.cuda.empty_cache() 22 | 23 | profile_name = name or f"## {func.__name__} ##" 24 | with torch.profiler.record_function(profile_name): 25 | result = func(self, *args, **kwargs) 26 | 27 | # Track timing and memory if the object has profiling_stats 28 | if hasattr(self, "profiling_stats"): 29 | step_duration = time.perf_counter() - step_start_time 30 | self.profiling_stats["step_times"].append(step_duration) 31 | 32 | # Log GPU memory 33 | if hasattr(self, "_log_gpu_memory"): 34 | current_memory, max_memory = self._log_gpu_memory() or (0, 0) 35 | 36 | # Update metrics if result is a dict 37 | if isinstance(result, dict): 38 | result.update( 39 | { 40 | "gpu_memory_used": current_memory, 41 | "gpu_memory_max": max_memory, 42 | } 43 | ) 44 | 45 | return result 46 | 47 | return wrapper 48 | 49 | return decorator 50 | -------------------------------------------------------------------------------- /.github/workflows/checks.yml: -------------------------------------------------------------------------------- 1 | name: Code quality checks & tests 2 | on: 3 | workflow_call: 4 | 5 | permissions: 6 | contents: write 7 | 8 | jobs: 9 | checks: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: [3.11, 3.12] 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - uses: ./.github/actions/setup-poetry 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | 22 | # Cheapest check first 23 | - name: Run audit 24 | run: make audit 25 | 26 | - name: Run 
quality checks 27 | run: make quality 28 | 29 | - name: Run doctests 30 | run: poetry run python -m doctest **/*.py 31 | 32 | - name: Run tests with coverage 33 | run: > 34 | poetry run pytest 35 | --cov=src/cycleformers 36 | --cov=examples 37 | --cov-report=xml 38 | --cov-report=term-missing tests/ 39 | tests/ 40 | --all 41 | --instafail 42 | -n auto 43 | --junitxml=junit.xml 44 | # --cov-fail-under=80 45 | 46 | - name: Generate coverage badge 47 | run: | 48 | mkdir -p .github/badges 49 | poetry run genbadge coverage -i coverage.xml -o .github/badges/coverage.svg 50 | poetry run genbadge tests -i junit.xml -o .github/badges/build.svg 51 | 52 | - name: Commit badge 53 | if: github.event_name == 'pull_request' 54 | run: | 55 | git config --local user.email "github-actions[bot]@users.noreply.github.com" 56 | git config --local user.name "github-actions[bot]" 57 | # Fetch and checkout the PR branch 58 | git fetch origin ${{ github.head_ref }} 59 | git checkout -f ${{ github.head_ref }} 60 | git add .github/badges/coverage.svg .github/badges/build.svg 61 | git commit -m "Update build and coverage badges" || exit 0 62 | git push origin ${{ github.head_ref }} 63 | -------------------------------------------------------------------------------- /tests/benchmark/README.md: -------------------------------------------------------------------------------- 1 | # Profiling Process 2 | 3 | ## Compute and Runtime Performance 4 | 5 | 1) Set up the appropriate config yaml in `profiler_configs/` and make any modifications to the script as needed. 6 | 2) Run the script as you would any of the examples e.g. 7 | ```bash 8 | python cycle_trainer_profiling.py profiler_configs/causal.yaml 9 | ``` 10 | 11 | 3) The script will generate a folder in `cycle_trainer_profiles/` with the start date and time of the run. 
12 | 4) Snakeviz can be used to visualise the flame graph for the run: 13 | ```bash 14 | # If working on remote machine 15 | ssh -L 8080:localhost:8080 @ 16 | 17 | # Run snakeviz 18 | snakeviz cycle_trainer_profiles//profile.prof --server 19 | >> flame graph can be viewed at http://localhost:8080/ 20 | ``` 21 | 22 | ### The Sunburst View 23 | 24 | Read from the center outwards where each ring represents a level of function calls. The size (arc length) of each segment shows how much time that function took. The colors help distinguish different segments. Inner segments contain the functions that called them (outer segments). 25 | 26 | ![Sunburst View](../../docs/assets/sunburst_snakeviz.png) 27 | 28 | ### Icicle View 29 | 30 | Each row represents a level in the call stack where the width of each bar represents the time taken by that function. Functions at the top call the functions below them. The full width represents 100% of the execution time. 31 | 32 | ![Icicle View](../../docs/assets/icicle_snakeviz.png) 33 | 34 | 35 | ## Memory Usage 36 | 37 | Memory stats are recorded while computing the runtime stats. To view the CUDA memory snapshots, visit the [pytorch memory viz webpage](https://pytorch.org/memory_viz) and drag and drop in your `cuda_memory_snapshots.pickle` file. 38 | 39 | 40 | ### References 41 | - [Understanding GPU Memory 1: Visualizing All Allocations over Time](https://pytorch.org/blog/understanding-gpu-memory-1/) 42 | - [Understanding GPU Memory 2: Finding and Removing Reference Cycles](https://pytorch.org/blog/understanding-gpu-memory-2/) -------------------------------------------------------------------------------- /src/cycleformers/exceptions.py: -------------------------------------------------------------------------------- 1 | class CycleModelError(Exception): 2 | """Exception raised when models are not properly configured for CycleTrainer. 
3 | 4 | This exception is raised when models are either not provided or are invalid for use with CycleTrainer. 5 | """ 6 | 7 | def __init__(self, message="CycleTrainer is missing valid models for training."): 8 | self.message = message 9 | super().__init__(self.message) 10 | 11 | 12 | class MissingModelError(CycleModelError): 13 | """Exception raised when a model is not provided for CycleTrainer. 14 | 15 | This exception is raised when a model is not provided for CycleTrainer. 16 | """ 17 | 18 | default_message = ( 19 | "CycleTrainer requires two models to train but only received a single model. " 20 | "To train using adapter-switching, set `args.use_macct = True`. Otherwise, pass two separate models " 21 | "for cycle training." 22 | ) 23 | 24 | def __init__(self, message=None): 25 | self.message = message or self.default_message 26 | super().__init__(self.message) 27 | 28 | 29 | class MACCTModelError(Exception): 30 | """Exception raised when MACCT models are not properly configured for CycleTrainer. 31 | 32 | This exception is raised when MACCT models are either not provided or are invalid for use with CycleTrainer. 33 | """ 34 | 35 | def __init__(self, message="There is something wrong with the MACCT model or configuration provided."): 36 | self.message = message 37 | super().__init__(self.message) 38 | 39 | 40 | class InvalidCycleKeyError(Exception): 41 | """Exception raised when an invalid model key is provided to CycleTrainer. 42 | 43 | This exception is raised when a cycle key other than 'A' or 'B' is provided when configuring or accessing 44 | models in CycleTrainer. 45 | """ 46 | 47 | def __init__(self, message="Invalid cycle key provided. 
Only 'A' and 'B' are valid keys for cycle training."): 48 | self.message = message 49 | super().__init__(self.message) 50 | -------------------------------------------------------------------------------- /tests/sample_data/wmt14/test/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": "parquet", 3 | "citation": "", 4 | "config_name": "de-en", 5 | "dataset_name": "wmt14", 6 | "dataset_size": 1359920533, 7 | "description": "", 8 | "download_checksums": { 9 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00000-of-00003.parquet": { 10 | "num_bytes": 279527247, 11 | "checksum": null 12 | }, 13 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00001-of-00003.parquet": { 14 | "num_bytes": 264970514, 15 | "checksum": null 16 | }, 17 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00002-of-00003.parquet": { 18 | "num_bytes": 272987200, 19 | "checksum": null 20 | }, 21 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/validation-00000-of-00001.parquet": { 22 | "num_bytes": 473798, 23 | "checksum": null 24 | }, 25 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/test-00000-of-00001.parquet": { 26 | "num_bytes": 508753, 27 | "checksum": null 28 | } 29 | }, 30 | "download_size": 818467512, 31 | "features": { 32 | "translation": { 33 | "languages": [ 34 | "de", 35 | "en" 36 | ], 37 | "_type": "Translation" 38 | } 39 | }, 40 | "homepage": "", 41 | "license": "", 42 | "size_in_bytes": 2178388045, 43 | "splits": { 44 | "train": { 45 | "name": "train", 46 | "num_bytes": 1358406800, 47 | "num_examples": 4508785, 48 | "shard_lengths": [ 49 | 1521929, 50 | 1721928, 51 | 1264928 52 | ], 53 | "dataset_name": "wmt14" 54 | }, 55 | "validation": { 56 | "name": "validation", 57 | "num_bytes": 736407, 58 | "num_examples": 3000, 59 | "dataset_name": "wmt14" 60 | }, 61 | "test": { 62 | "name": 
"test", 63 | "num_bytes": 777326, 64 | "num_examples": 3003, 65 | "dataset_name": "wmt14" 66 | } 67 | }, 68 | "version": { 69 | "version_str": "0.0.0", 70 | "major": 0, 71 | "minor": 0, 72 | "patch": 0 73 | } 74 | } -------------------------------------------------------------------------------- /tests/sample_data/wmt14/train/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": "parquet", 3 | "citation": "", 4 | "config_name": "de-en", 5 | "dataset_name": "wmt14", 6 | "dataset_size": 1359920533, 7 | "description": "", 8 | "download_checksums": { 9 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00000-of-00003.parquet": { 10 | "num_bytes": 279527247, 11 | "checksum": null 12 | }, 13 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00001-of-00003.parquet": { 14 | "num_bytes": 264970514, 15 | "checksum": null 16 | }, 17 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00002-of-00003.parquet": { 18 | "num_bytes": 272987200, 19 | "checksum": null 20 | }, 21 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/validation-00000-of-00001.parquet": { 22 | "num_bytes": 473798, 23 | "checksum": null 24 | }, 25 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/test-00000-of-00001.parquet": { 26 | "num_bytes": 508753, 27 | "checksum": null 28 | } 29 | }, 30 | "download_size": 818467512, 31 | "features": { 32 | "translation": { 33 | "languages": [ 34 | "de", 35 | "en" 36 | ], 37 | "_type": "Translation" 38 | } 39 | }, 40 | "homepage": "", 41 | "license": "", 42 | "size_in_bytes": 2178388045, 43 | "splits": { 44 | "train": { 45 | "name": "train", 46 | "num_bytes": 1358406800, 47 | "num_examples": 4508785, 48 | "shard_lengths": [ 49 | 1521929, 50 | 1721928, 51 | 1264928 52 | ], 53 | "dataset_name": "wmt14" 54 | }, 55 | "validation": { 56 | "name": "validation", 57 | "num_bytes": 736407, 58 | 
"num_examples": 3000, 59 | "dataset_name": "wmt14" 60 | }, 61 | "test": { 62 | "name": "test", 63 | "num_bytes": 777326, 64 | "num_examples": 3003, 65 | "dataset_name": "wmt14" 66 | } 67 | }, 68 | "version": { 69 | "version_str": "0.0.0", 70 | "major": 0, 71 | "minor": 0, 72 | "patch": 0 73 | } 74 | } -------------------------------------------------------------------------------- /tests/sample_data/wmt14/validation/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": "parquet", 3 | "citation": "", 4 | "config_name": "de-en", 5 | "dataset_name": "wmt14", 6 | "dataset_size": 1359920533, 7 | "description": "", 8 | "download_checksums": { 9 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00000-of-00003.parquet": { 10 | "num_bytes": 279527247, 11 | "checksum": null 12 | }, 13 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00001-of-00003.parquet": { 14 | "num_bytes": 264970514, 15 | "checksum": null 16 | }, 17 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00002-of-00003.parquet": { 18 | "num_bytes": 272987200, 19 | "checksum": null 20 | }, 21 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/validation-00000-of-00001.parquet": { 22 | "num_bytes": 473798, 23 | "checksum": null 24 | }, 25 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/test-00000-of-00001.parquet": { 26 | "num_bytes": 508753, 27 | "checksum": null 28 | } 29 | }, 30 | "download_size": 818467512, 31 | "features": { 32 | "translation": { 33 | "languages": [ 34 | "de", 35 | "en" 36 | ], 37 | "_type": "Translation" 38 | } 39 | }, 40 | "homepage": "", 41 | "license": "", 42 | "size_in_bytes": 2178388045, 43 | "splits": { 44 | "train": { 45 | "name": "train", 46 | "num_bytes": 1358406800, 47 | "num_examples": 4508785, 48 | "shard_lengths": [ 49 | 1521929, 50 | 1721928, 51 | 1264928 52 | ], 53 | "dataset_name": 
"wmt14" 54 | }, 55 | "validation": { 56 | "name": "validation", 57 | "num_bytes": 736407, 58 | "num_examples": 3000, 59 | "dataset_name": "wmt14" 60 | }, 61 | "test": { 62 | "name": "test", 63 | "num_bytes": 777326, 64 | "num_examples": 3003, 65 | "dataset_name": "wmt14" 66 | } 67 | }, 68 | "version": { 69 | "version_str": "0.0.0", 70 | "major": 0, 71 | "minor": 0, 72 | "patch": 0 73 | } 74 | } -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # https://medium.com/@dkraczkowski/crafting-a-ci-pipeline-my-experience-with-github-actions-python-aws-c67f428adee8 2 | -include .env 3 | SOURCE_DIR = src 4 | TEST_DIR = tests 5 | EXAMPLE_DIR = examples 6 | PROJECT_DIRS = $(SOURCE_DIR) $(TEST_DIR) $(EXAMPLE_DIR) 7 | PWD := $(dir $(abspath $(firstword $(MAKEFILE_LIST)))) 8 | PROJECT_NAME ?= Cycleformers 9 | PROJECT_VERSION ?= v$(shell poetry version -s) 10 | PYTHON_VERSION ?= 3.11 11 | .DEFAULT_GOAL := all 12 | 13 | .PHONY: init-env init check-toml lint-src format lint quality audit test test-all all clean info build-docs 14 | 15 | init-env: 16 | @if [ ! -f .env ]; then \ 17 | echo "Creating .env file..."; \ 18 | echo "PROJECT_NAME=${PROJECT_NAME}" > .env; \ 19 | echo "PYTHON_VERSION=${PYTHON_VERSION}" >> .env; \ 20 | echo "export PYTHONPATH=${SOURCE_DIR}" >> .env; \ 21 | else \ 22 | echo "using existing .env file..."; \ 23 | fi 24 | 25 | init: init-env 26 | @echo "Installing dependencies..." 
27 | poetry install 28 | poetry run pre-commit install 29 | 30 | -check-toml: 31 | poetry check 32 | 33 | -reformat-src: 34 | poetry run ruff format $(PROJECT_DIRS) 35 | poetry run ruff check --select I --fix $(PROJECT_DIRS) 36 | 37 | -lint-src: 38 | poetry run ruff check --fix $(SOURCE_DIR) 39 | poetry run mypy --install-types --show-error-codes --non-interactive $(SOURCE_DIR) 40 | 41 | format: -check-toml -reformat-src 42 | 43 | lint: -lint-src 44 | 45 | quality: 46 | @poetry run python -c "from cycleformers import *" || (echo "Import failure. Unprotected import?"; exit 1) 47 | @make format 48 | @make lint 49 | 50 | audit: 51 | poetry run bandit -r $(SOURCE_DIR) -x $(TEST_DIR) 52 | 53 | test: 54 | poetry run pytest $(TEST_DIR) 55 | 56 | test-all: 57 | poetry run pytest $(TEST_DIR) -v --slow --meta -n auto --instafail 58 | 59 | build-docs: 60 | poetry run mkdocs build 61 | 62 | all: audit quality test build-docs 63 | 64 | clean: 65 | rm -rf dist/ 66 | rm -rf build/ 67 | rm -rf *.egg-info 68 | find . -type d -name '__pycache__' -exec rm -rf {} + 69 | find . -type d -name '.pytest_cache' -exec rm -rf {} + 70 | find . -type d -name '.mypy_cache' -exec rm -rf {} + 71 | 72 | info: 73 | @echo "Project name: ${PROJECT_NAME}" 74 | @echo "Project version: ${PROJECT_VERSION}" 75 | @echo "Python version: ${PYTHON_VERSION}" -------------------------------------------------------------------------------- /docs/usage.md: -------------------------------------------------------------------------------- 1 | ### Training 2 | 3 | The `CycleTrainer` class is an extension but significant redesign of the 🤗 Transformers trainer, designed to abstract away the specifics of training while remaining configurable. Both Seq2Seq and Causal architectures are supported, each able to train via PEFT adapter swapping for memory efficient configurations. Check the [docs] for [usage] details and [examples]. 
4 | 5 | To train using two identical models, the following sample code can be used along with two datasets: 6 | 7 | ```python 8 | from cycleformers import CycleTrainer, CycleTrainingArguments 9 | 10 | model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto") 11 | tokenizer = AutoTokenizer.from_pretrained("gpt2") 12 | 13 | args = CycleTrainingArguments(output_dir="gpt2-cct") 14 | trainer = CycleTrainer( 15 | args, 16 | models = model, 17 | tokenizers = tokenizer, 18 | train_dataset_A = dataset_A, 19 | train_dataset_B = dataset_B 20 | ) 21 | trainer.train() 22 | ``` 23 | 24 | Any two models can be combined for completely customisable training (🚧 currently both models must be seq2seq or both causal, so mixed pairs like the one below are not yet supported): 25 | 26 | ```python 27 | model_A = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto") 28 | model_B = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base", device_map="auto") 29 | tokenizer_A = AutoTokenizer.from_pretrained("gpt2") 30 | tokenizer_B = AutoTokenizer.from_pretrained("google/flan-t5-base") 31 | 32 | trainer = CycleTrainer( 33 | args, 34 | models = { 35 | "A": model_A, 36 | "B": model_B 37 | }, 38 | tokenizers = { 39 | "A": tokenizer_A, 40 | "B": tokenizer_B 41 | }, 42 | train_dataset_A = dataset_A, 43 | train_dataset_B = dataset_B 44 | ) 45 | ``` 46 | 47 | ### Multi-Adapter Cycle-Consistency Training (MACCT) 48 | 49 | The `CycleTrainer` class can also accept a single base model and train two PEFT adapters on top of it, switching between them to emulate the two-model setup.
This allows for the training of `7.5x larger models` for the same memory footprint: 50 | 51 | ```python 52 | peft_config = PeftConfig( 53 | task_type="CAUSAL_LM", 54 | r=16, 55 | lora_alpha=32, 56 | target_modules="all-linear", 57 | inference_mode=False, 58 | bias="none" 59 | ) 60 | 61 | args = CycleTrainingArguments(output_dir="gpt2-macct") 62 | trainer = CycleTrainer( 63 | args, 64 | model = model, 65 | tokenizer = tokenizer, 66 | peft_configs = peft_config # Or same A, B dict 67 | ) 68 | ``` -------------------------------------------------------------------------------- /src/cycleformers/command/cli_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from collections.abc import Iterable 4 | 5 | import yaml 6 | from transformers.hf_argparser import DataClassType, HfArgumentParser 7 | 8 | 9 | VALID_TASKS = ["train"] 10 | 11 | 12 | class CfArgumentParser(HfArgumentParser): 13 | def __init__( 14 | self, 15 | dataclass_types: list[DataClassType] | None = None, 16 | task: str | None = None, 17 | **kwargs, 18 | ): 19 | # Make sure dataclass_types is an iterable 20 | if dataclass_types is None: 21 | dataclass_types = [] 22 | elif not isinstance(dataclass_types, Iterable): 23 | dataclass_types = [dataclass_types] 24 | 25 | self.task = task 26 | super().__init__(dataclass_types=dataclass_types, **kwargs) 27 | 28 | def _parse_yaml_config(self, config_file: str): 29 | with open(config_file, "r") as f: 30 | config = yaml.safe_load(f) 31 | 32 | # Convert config YAMLs to be A_ and B_ 33 | config_a = {f"A_{k}": v for k, v in config.pop("A", {}).items()} 34 | config_b = {f"B_{k}": v for k, v in config.pop("B", {}).items()} 35 | 36 | file_args = [] 37 | for c in [config, config_a, config_b]: 38 | for k, v in c.items(): 39 | v = json.dumps(v) if isinstance(v, dict) else str(v) 40 | 41 | if f"--{k}" in file_args: 42 | raise ValueError(f"Duplicate argument {k} found in config files") 43 | 44 | 
file_args.extend([f"--{k}", v]) 45 | return file_args 46 | 47 | def parse_args_and_config(self): 48 | args = sys.argv[1:] 49 | 50 | # Handle task argument 51 | if not self.task and len(args) == 0: 52 | raise ValueError(f"No task provided. Task must be one of {VALID_TASKS}.") 53 | 54 | # Task is already set 55 | if self.task and args[0] in VALID_TASKS: 56 | raise ValueError(f"Task already set by script to {self.task}. Try again without {args[0]}.") 57 | # Task is not set and arg is a valid task 58 | elif not self.task: 59 | if args[0] in VALID_TASKS: 60 | self.task = args.pop(0) 61 | else: 62 | raise ValueError(f"Task must be one of {VALID_TASKS}, got {args[0]}") 63 | 64 | if len(args) > 0 and (args[0].endswith(".yaml") or args[0].endswith(".yml")): 65 | config_file = args.pop(0) 66 | file_args = self._parse_yaml_config(config_file) 67 | args = file_args + args 68 | 69 | return self.parse_args_into_dataclasses(args) 70 | -------------------------------------------------------------------------------- /tests/task_processors/test_ner_processors.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from cycleformers.task_processors.ner import CONLL2003Processor, CONLL2003ProcessorConfig 4 | 5 | 6 | SAMPLES_DIR = Path(__file__).parent.parent / "sample_data" 7 | 8 | 9 | class TestCONLL2003Processor: 10 | def test_preprocess(self, monkeypatch, temp_dir): 11 | """Test preprocessing of CONLL2003 dataset.""" 12 | monkeypatch.setenv("HF_DATASETS_CACHE", str(temp_dir)) 13 | 14 | config = CONLL2003ProcessorConfig( 15 | dataset_name=SAMPLES_DIR / "conll2003", 16 | cache_dir=str(temp_dir), 17 | dataset_seed=42, # Fixed seed for reproducible tests 18 | ) 19 | processor = CONLL2003Processor(config) 20 | dataset_A, dataset_B = processor.process() 21 | 22 | # Check that all splits are preserved 23 | assert set(dataset_A.keys()) == {"train", "validation", "test"} 24 | assert set(dataset_B.keys()) == {"train", 
"validation", "test"} 25 | 26 | # Check that training data is non-parallel (no labels) 27 | assert "label" not in dataset_A["train"].column_names 28 | assert "label" not in dataset_B["train"].column_names 29 | 30 | # Verify training splits are properly shuffled and unaligned 31 | train_texts_A = dataset_A["train"]["text"] 32 | train_texts_B = dataset_B["train"]["text"] 33 | 34 | # Check no exact alignment exists 35 | assert not any( 36 | all(a == b for a, b in zip(train_texts_A, train_texts_B[i:] + train_texts_B[:i])) 37 | for i in range(len(train_texts_B)) 38 | ), "Training splits appear to be aligned with some offset" 39 | 40 | # Verify that shuffling is deterministic with the same seed 41 | config2 = CONLL2003ProcessorConfig( 42 | dataset_name=SAMPLES_DIR / "conll2003", 43 | dataset_seed=42, 44 | cache_dir=str(temp_dir), 45 | ) 46 | processor2 = CONLL2003Processor(config2) 47 | dataset_A2, dataset_B2 = processor2.process() 48 | 49 | assert list(dataset_A["train"]["text"]) == list(dataset_A2["train"]["text"]) 50 | assert list(dataset_B["train"]["text"]) == list(dataset_B2["train"]["text"]) 51 | 52 | # Check that evaluation splits maintain parallel data 53 | for key in ["validation", "test"]: 54 | assert dataset_A[key]["text"] == dataset_B[key]["label"] 55 | assert dataset_A[key]["label"] == dataset_B[key]["text"] 56 | assert dataset_A[key]["text"] != dataset_B[key]["text"] 57 | assert dataset_A[key]["label"] != dataset_B[key]["label"] 58 | -------------------------------------------------------------------------------- /docs/conceptual_reference/macct.md: -------------------------------------------------------------------------------- 1 | # Multi-Adapter Cycle-Consistency Training (MACCT) 2 | 3 | Multi-Adapter Cycle-Consistency Training (MACCT) is an implementation of the Cycle-Consistency Training (CCT) training method using [PEFT](https://huggingface.co/docs/peft/en/index) LoRA adapters ([hu et al. 
2022](https://openreview.net/forum?id=nZeVKeeFYf9)) in place of full model weights to learn the A->B and B->A translation mappings. MACCT shares a single frozen base model between both directions; because only the small adapters are trained, the number of optimizer states is greatly reduced, which significantly shrinks the memory footprint. So much so that we can load a frozen base model that is `~7.5x larger` than either model in the full fine-tuning case. 4 | 5 | [A HELPFUL DIAGRAM WILL GO HERE] 6 | 7 | 8 | ??? "How are these figures calculated?" 9 | 10 | Assuming a restricted case where we only look at static memory, i.e. model weights and optimizer states, the memory savings are significant. We assume in the dual-model ($DMCCT$) case that both models are the same; likewise, in multi-adapter CCT ($MACCT$) we assume both LoRA adapters are configured identically. In all cases we use the AdamW optimizer, as it is the most popular; it keeps two 32-bit states per trained parameter: 11 | 12 | With $\theta$ as the foundational model parameters, $\phi$ as the LoRA adapter parameters, $p$ as the number of bits per parameter ${\theta}_i$, $q$ as the number of bits per parameter ${\phi}_i$, and $r = |\phi| / |\theta|$ as the ratio of LoRA adapter size $|\phi|$ to base model size $|\theta|$, we can derive the memory savings as follows: 13 | 14 | $$ 15 | \begin{aligned} 16 | DMCCT =& \left[ 2 \text{ models} * |\theta| \text{ params} * p \text{ bits} \right] + \left[ 2 \text{ models} * |\theta| \text{ params} * 2 \text{ states} * (4*8) \text{ bits} \right] \\ 17 | =& 2|\theta|(p + 64) 18 | \end{aligned} \tag{1} 19 | $$ 20 | 21 | $$ 22 | \begin{aligned} 23 | MACCT =& \left[ |\theta| \text{ params} * p \text{ bits} \right] + \left[ 2 \text{ loras} * (|\theta| \text{ params} * r) * \left( q \text{ bits} + 2 \text{ states} * (4*8) \text{ bits} \right) \right] \\ 24 | =& |\theta|(p + 2r(q + 64)) 25 | \end{aligned} \tag{2} 26 | $$ 27 | 28 | Assuming $p = 16$, $q = 32$ (LoRA adapters are trained in 32-bit by default), $r = 0.03$ (adapters ~3% of the base model size), and a 1B parameter model for each translation direction in $DMCCT$: 29 | 30 | $$ 31 |
\begin{aligned} 32 | DMCCT =& 2 * 1e9 * (16 + 64) \\ 33 | =& 1.6e11 \text{ bits} \\ 34 | \\ 35 | DMCCT =& MACCT \\ 36 | 1.6e11 \text{ bits} =& N * (16 + 2 * 0.03 * (32 + 64)) \\ 37 | N =& \frac{1.6e11}{21.76} \\ 38 | \approx& 7.35B \text{ params} \\ 39 | \end{aligned} 40 | $$ 41 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Cycleformers 2 | site_url: https://wrmthorne.github.io/cycleformers 3 | repo_name: wrmthorne/cycleformers 4 | repo_url: https://github.com/wrmthorne/cycleformers 5 | 6 | theme: 7 | name: material 8 | palette: 9 | # Palette toggle for automatic mode 10 | - media: "(prefers-color-scheme)" 11 | scheme: default 12 | primary: black 13 | toggle: 14 | icon: material/brightness-auto 15 | name: Switch to light mode 16 | 17 | # Palette toggle for light mode 18 | - media: "(prefers-color-scheme: light)" 19 | scheme: default 20 | primary: black 21 | toggle: 22 | icon: material/brightness-7 23 | name: Switch to dark mode 24 | 25 | # Palette toggle for dark mode 26 | - media: "(prefers-color-scheme: dark)" 27 | scheme: slate 28 | primary: black 29 | toggle: 30 | icon: material/brightness-4 31 | name: Switch to system preference 32 | 33 | features: 34 | - navigation.top 35 | - navigation.tabs 36 | - navigation.path 37 | - navigation.footer 38 | - navigation.instant 39 | - navigation.sections 40 | - navigation.tracking 41 | - navigation.instant.prefetch 42 | - navigation.instant.progress 43 | - content.code.annotate 44 | - content.code.copy 45 | 46 | nav: 47 | - Get Started: 48 | - Cycleformers: index.md 49 | - Installation: installation.md 50 | - Usage: usage.md 51 | - Performance: performance.md 52 | - Conceptual Reference: 53 | - Conceptual Overview: conceptual_reference/index.md 54 | - Cycle-Consistency Training: conceptual_reference/cycle_consistency_training.md 55 | - MACCT: conceptual_reference/macct.md 56 | - 
Examples: 57 | - Example: examples/index.md 58 | - CycleNER: examples/cycle_ner.md 59 | - WMT2014: examples/wmt_2014.md 60 | - API Reference: 61 | - CycleTrainer: api_reference/cycle_trainer.md 62 | - Configuration: api_reference/configuration.md 63 | - Cycles: api_reference/cycles.md 64 | - Task Processors: api_reference/task_processors.md 65 | 66 | markdown_extensions: 67 | # admonition + details render the collapsible `??? "..."` blocks used in the docs 68 | - admonition 69 | - pymdownx.details 70 | # arithmatex (generic mode) wraps `$...$` / `$$...$$` math for MathJax 71 | - pymdownx.arithmatex: 72 | generic: true 73 | - pymdownx.highlight: 74 | anchor_linenums: true 75 | line_spans: __span 76 | pygments_lang_class: true 77 | - pymdownx.superfences: 78 | custom_fences: 79 | - name: mermaid 80 | class: mermaid 81 | format: !!python/name:pymdownx.superfences.fence_code_format 82 | 83 | # NOTE(review): with navigation.instant enabled, pages reached via instant navigation may also need a 84 | # MathJax re-typeset hook (see the Material for MkDocs math setup guide) - confirm when building the docs 85 | extra_javascript: 86 | - https://unpkg.com/mathjax@3/es5/tex-mml-chtml.js 87 | 88 | plugins: 89 | # Declaring `plugins` disables the default search plugin, so re-enable it explicitly 90 | - search 91 | - mkdocstrings: 92 | default_handler: python 93 | handlers: 94 | python: 95 | paths: [src] # Directory that contains the cycleformers package (src layout) 96 | options: 97 | show_source: true 98 | show_root_heading: true 99 | heading_level: 1 -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributing Guide 3 | 4 | We are interested in external contributions to Cycleformers, particularly from those experienced in automated testing of language models, and with multi-device training. 5 | 6 | Please contribute code via pull requests. For small bug fixes or improvements, please feel free to submit a PR directly to the main branch. 
For larger bugs or features, please open an issue for discussion first. If you would like to request a feature, please also open an issue. For any PRs, please ensure that you have added tests and documentation where appropriate. 7 | 8 | All new features should be developed on a new branch and use [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/) prefixes.
The following are recommended: 9 | 10 | - `fix:` for bug fixes 11 | - `feat:` for new features 12 | - `build:` for changes that affect the build system or external dependencies 13 | - `refactor:` for code refactoring 14 | - `ci:` for changes to CI configuration files and scripts 15 | - `docs:` for documentation changes 16 | - `perf:` for performance improvements 17 | - `test:` for adding tests 18 | - `chore:` for other changes that don't fit into the other categories 19 | 20 | Although this rule is not strictly enforced for commits, pull requests must be titled with a valid conventional commit prefix or the merge pipeline will fail. Pull requests should be made to main and will not be released immediately. Changes will be accumulated on main until a release is decided. The release workflow will handle versioning, publishing, changelog generation, etc. and is triggered by a release workflow dispatch event, selecting `main` as the branch. 21 | 22 | ## Development Setup 23 | 24 | We use Poetry to manage dependencies on Python 3.11. To install Poetry, follow the documentation here: https://python-poetry.org/docs/master/#installing-with-the-official-installer 25 | 26 | We recommend the following command to set up the build environment for the module: 27 | 28 | ```bash 29 | make init 30 | ``` 31 | 32 | Many other useful commands are available in the Makefile which will automate formatting, linting, testing and building the documentation: 33 | 34 | - `make format` automatically fixes code style issues and standardizes formatting across the codebase. 35 | - `make lint` checks for code quality issues, potential bugs, and style violations. 36 | - `make audit` performs security vulnerability scanning to identify potential security risks. 37 | - `make test` executes the automated test suite to verify code functionality. 38 | - `make build-docs` generates the HTML documentation from source files based on the configuration in `mkdocs.yml`.
39 | - `make all` runs all of the above commands in sequence. 40 | - `make clean` removes build and distribution artifacts. 41 | - `make info` displays information about the project, including the project name, version, and Python version. 42 | 43 | 44 | ## Project Procedures 45 | 46 | -------------------------------------------------------------------------------- /tests/test_examples.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import tempfile 3 | from pathlib import Path 4 | 5 | import pytest 6 | import yaml 7 | 8 | 9 | SAMPLES_DIR = Path(__file__).parent / "sample_data" 10 | 11 | yaml_base = { 12 | "per_device_train_batch_size": 1, 13 | "gradient_accumulation_steps": 1, 14 | "per_device_eval_batch_size": 1, 15 | "save_strategy": "steps", 16 | "max_steps": 3, 17 | "eval_steps": 1, 18 | "save_steps": 1, 19 | "logging_strategy": "steps", 20 | "logging_steps": 1, 21 | } 22 | 23 | lora_config = { 24 | "use_peft": True, 25 | "lora_r": 8, 26 | "lora_alpha": 16, 27 | } 28 | 29 | causal_yaml = { 30 | "model_name_or_path": "trl-internal-testing/tiny-LlamaForCausalLM-3.1", 31 | "A": { 32 | "model_name_or_path": "trl-internal-testing/tiny-LlamaForCausalLM-3.1", 33 | }, 34 | "B": { 35 | "model_name_or_path": "trl-internal-testing/tiny-LlamaForCausalLM-3.1", 36 | }, 37 | } | yaml_base 38 | 39 | seq2seq_yaml = { 40 | "model_name_or_path": "google/flan-t5-small", 41 | "A": { 42 | "model_name_or_path": "google/flan-t5-small", 43 | }, 44 | "B": { 45 | "model_name_or_path": "google/flan-t5-small", 46 | }, 47 | } | yaml_base 48 | 49 | causal_yaml_macct = {**lora_config, "lora_task_type": "CAUSAL_LM"} | causal_yaml 50 | 51 | seq2seq_yaml_macct = {**lora_config, "lora_task_type": "SEQ_2_SEQ_LM"} | seq2seq_yaml 52 | 53 | 54 | @pytest.mark.slow 55 | @pytest.mark.requires_gpu 56 | @pytest.mark.parametrize( 57 | "example_script,dataset_name", 58 | [("cycle_ner/train.py", SAMPLES_DIR / "conll2003"), 
("translation_wmt14/train.py", SAMPLES_DIR / "wmt14")], 59 | ) 60 | @pytest.mark.parametrize( 61 | "config_yaml", 62 | [ 63 | ("causal-base", "causal.yaml", causal_yaml), 64 | ("seq2seq-base", "seq2seq.yaml", seq2seq_yaml), 65 | ("causal-macct", "causal-macct.yaml", causal_yaml_macct), 66 | ("seq2seq-macct", "seq2seq-macct.yaml", seq2seq_yaml_macct), 67 | ], 68 | ) 69 | def test_cycle_ner(example_script, dataset_name, config_yaml, temp_dir): 70 | out_dirname, filename, config = config_yaml 71 | out_dir = Path(temp_dir) / out_dirname 72 | 73 | config["output_dir"] = str(out_dir) 74 | config["dataset_name"] = str(dataset_name) 75 | yaml_file = Path(temp_dir) / filename 76 | with open(yaml_file, "w") as f: 77 | yaml.dump(config, f) 78 | 79 | project_root = Path(__file__) 80 | while project_root.name != "cycleformers": 81 | project_root = project_root.parent 82 | 83 | command = f"poetry run python {project_root}/examples/{example_script} {yaml_file}" 84 | 85 | result = subprocess.run(command, shell=True, capture_output=True, text=True) 86 | # Check if the process completed successfully 87 | assert result.returncode == 0, f"Process failed with error: {result.stderr}" 88 | 89 | # Check for model checkpoitns 90 | checkpoint_dir = Path(config["output_dir"]) 91 | assert checkpoint_dir.exists(), "Checkpoint directory was not created" 92 | assert any(checkpoint_dir.iterdir()), "No checkpoints were saved" 93 | -------------------------------------------------------------------------------- /docs/performance.md: -------------------------------------------------------------------------------- 1 | # Performance 2 | 3 | ## 1 Potential Ways to Improve Evaluation Metric Performance 4 | 5 | 6 | lsLoRA ([Kalajdzievski, Damjan 2023](https://huggingface.co/papers/2312.03732)) scales adapters during each forward pass by `lora_alpha/math.sqrt(r)` which stabilizes performance at higher ranks $r$ ([rsLoRA docs](https://huggingface.co/papers/2312.03732)). 7 | 8 | DoRA ([Liu et al. 
2024](https://arxiv.org/abs/2402.09353)) decomposes weight updates into magnitude and direction, which the authors show is better correlated with full fine-tuning loss signals. This technique is particularly useful at low ranks but can incur a significant speed penalty. Significant performance gains can be made at the expense of higher VRAM usage by using `ephemeral_gpu_offload=True`. More info can be found in the [DoRA docs](https://huggingface.co/docs/peft/developer_guides/lora#weight-decomposed-low-rank-adaptation-dora). 9 | 10 | 11 | LoRA+ ([Hayou et al. 2024](https://arxiv.org/abs/2402.12354)) is an optimisation strategy for LoRA that allows for different learning rates for adapter matrices A and B. This can increase fine-tuning speed by up to 2x and *potentially* boost performance on some tasks by 1-2% ([LoRA+ docs](https://arxiv.org/abs/2402.12354)). 12 | 13 | 14 | ## 2 Potential Ways to Improve Throughput 15 | 16 | ### 2.1 Tokenization 17 | 18 | Wherever possible, we recommend using the same model for each direction of translation; at the very least, we recommend using models that share a tokenizer, i.e. are from the same generation of the same model family. 19 | 20 | Sending tokens from the GPU back to the CPU to detokenize, manipulating them as strings, then re-tokenizing and sending them back to the GPU is costly. This is particularly significant for causal models that require sequences to be split and concatenated to produce the correct input_ids for training. We can skip this overhead by manipulating tokens as tensors on the GPU if we know that the tokenizers are compatible. 21 | 22 | Just having a compatible tokenizer is not always sufficient. If you perform any custom processing of synthetic samples before they are given to the training model, such as applying a chat template, you may find that the tokenization overhead is unavoidable. 23 | 24 | ### 2.2 Specific Attention Kernels and Implementations 25 | 26 | Flash Attention 2 ([Dao et al.
2023](https://arxiv.org/abs/2205.14135)) is a drop-in replacement for the standard attention mechanism that significantly reduces memory usage and increases throughput. It can be installed via `pip install flash-attn` (see their [github](https://github.com/Dao-AILab/flash-attention) if having installation issues) and setting `attn_implementation="flash_attention_2"` in your model config. 27 | 28 | The Liger Kernel ([Hsu et al. 2024](https://arxiv.org/abs/2410.10989)) is a set of Triton kernels for efficient training of transformer models. They claim up to 20% throughput increase and 60% reduction in GPU memory usage. It can be installed via `pip install liger-kernel` and enabled by setting `use_liger_kernel=True` in your model config. 29 | 30 | These methods can be combined to see a very significant throughput increase and reduction in VRAM usage. 31 | 32 | ```python 33 | model = AutoModelForCausalLM.from_pretrained( 34 | # ... 35 | attn_implementation="flash_attention_2", 36 | use_liger_kernel=True 37 | ) 38 | ``` 39 | 40 | 🚧 # TODO: Test whether you can compile, use optimum, etc. on base model when using MACCT. -------------------------------------------------------------------------------- /src/cycleformers/task_processors/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from os import PathLike 4 | from pathlib import Path 5 | 6 | from datasets import DatasetDict, load_dataset, load_from_disk 7 | 8 | 9 | @dataclass 10 | class ProcessorConfig: 11 | """Configuration class for dataset processors. 12 | 13 | This class defines the common configuration parameters used across different processors. 14 | Each specific processor can extend this with additional task-specific parameters. 
15 | """ 16 | 17 | dataset_name: str | PathLike[str] | None = None 18 | eval_split_ratio: float = 0.2 19 | dataset_seed: int = 42 20 | cache_dir: str | None = None 21 | # split: list[str] | str | None = None FIXME: not a key feature right now - taking too much time 22 | max_samples: int | None = None 23 | streaming: bool = False # FIXME: not a key feature right now - taking too much time 24 | 25 | 26 | class BaseProcessor(ABC): 27 | """Base class for dataset processors. 28 | 29 | This abstract base class defines the interface and common functionality for all dataset processors. 30 | Each specific format/task processor should inherit from this class and implement the required methods. 31 | 32 | The processor handles: 33 | 1. Loading source datasets 34 | 2. Converting to cycleformers format by splitting into two separate datasets A and B 35 | 3. Applying common transformations (train/val splitting, shuffling) 36 | 37 | If the dataset already has a train/val/test split, those splits will be preserved and only train will be made 38 | non-parallel. 39 | 40 | Args: 41 | config: Configuration object controlling processor behavior. If not provided, uses default CONFIG_CLS. 42 | 43 | Example: 44 | >>> config = ProcessorConfig(eval_split_ratio=0.2, dataset_seed=42) 45 | >>> processor = ConcreteProcessor(config) 46 | >>> dataset_A, dataset_B = processor.process() 47 | >>> print(dataset_A.keys()) 48 | dict_keys(['train', 'test']) 49 | """ 50 | 51 | def __init__(self, config: ProcessorConfig = ProcessorConfig()): 52 | self.config = config 53 | 54 | def load(self) -> DatasetDict: 55 | """Load the source dataset. Override for custom loading logic.""" 56 | if self.config.dataset_name is None: 57 | raise ValueError("No dataset name was provided. 
Cannot load `None`.") 58 | 59 | if Path(self.config.dataset_name).exists(): 60 | return load_from_disk(self.config.dataset_name) 61 | 62 | return load_dataset( 63 | self.config.dataset_name, 64 | cache_dir=self.config.cache_dir, 65 | streaming=self.config.streaming, 66 | ) 67 | 68 | @abstractmethod 69 | def preprocess(self, dataset: DatasetDict) -> DatasetDict: 70 | """Preprocess the dataset into two separate datasets A and B.""" 71 | raise NotImplementedError 72 | 73 | def process(self) -> DatasetDict: 74 | """Process the dataset into two separate datasets A and B. 75 | 76 | Returns: 77 | Tuple[DatasetDict, DatasetDict]: Two datasets A and B, each containing 'train' and 'test' splits 78 | """ 79 | dataset = self.load() 80 | 81 | if not isinstance(dataset, DatasetDict): 82 | dataset = DatasetDict({"train": dataset}) 83 | 84 | if dataset.keys() == ["train"]: 85 | dataset = dataset.train_test_split(test_size=self.config.eval_split_ratio, seed=self.config.dataset_seed) 86 | 87 | dataset_A, dataset_B = self.preprocess(dataset) 88 | 89 | dataset_A["train"] = dataset_A["train"].shuffle(seed=self.config.dataset_seed) 90 | dataset_B["train"] = dataset_B["train"].shuffle(seed=self.config.dataset_seed + 1) 91 | 92 | return dataset_A, dataset_B 93 | -------------------------------------------------------------------------------- /examples/cycle_ner/train.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer 2 | 3 | from cycleformers import ( 4 | CfArgumentParser, 5 | CycleTrainer, 6 | CycleTrainingArguments, 7 | ModelConfig, 8 | ModelConfigA, 9 | ModelConfigB, 10 | merge_configs, 11 | ) 12 | from cycleformers.import_utils import is_liger_kernel_available 13 | from cycleformers.task_processors.ner import CONLL2003Processor, CONLL2003ProcessorConfig 14 | from cycleformers.utils import VALID_LIGER_MODELS, get_peft_config 15 | 16 | 17 | if 
is_liger_kernel_available(): 18 | from liger_kernel.transformers import AutoLigerKernelForCausalLM 19 | 20 | 21 | def get_model_and_tokenizer(model_config, training_args): 22 | """Initialize model and tokenizer from config""" 23 | config = AutoConfig.from_pretrained( 24 | model_config.model_name_or_path, 25 | trust_remote_code=model_config.trust_remote_code, 26 | ) 27 | config.use_cache = False 28 | 29 | model_kwargs = {} 30 | 31 | if not config.is_encoder_decoder: 32 | if is_liger_kernel_available() and model_config.use_liger and config.model_type in VALID_LIGER_MODELS: 33 | model_class = AutoLigerKernelForCausalLM 34 | model_kwargs["use_liger_kernel"] = training_args.use_liger_kernel 35 | else: 36 | model_class = AutoModelForCausalLM 37 | else: 38 | model_class = AutoModelForSeq2SeqLM 39 | 40 | model = model_class.from_pretrained( 41 | model_config.model_name_or_path, 42 | revision=model_config.model_revision, 43 | config=config, 44 | trust_remote_code=model_config.trust_remote_code, 45 | attn_implementation=model_config.attn_implementation, 46 | torch_dtype=model_config.torch_dtype, 47 | device_map="auto", 48 | ) 49 | 50 | if training_args.gradient_checkpointing: 51 | model.enable_input_require_grads() 52 | 53 | # Print the actual dtype of the first parameter 54 | print(f"Model weights dtype: {next(model.parameters()).dtype}") 55 | 56 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True) 57 | return model, tokenizer 58 | 59 | 60 | def main(): 61 | parser = CfArgumentParser( 62 | (CycleTrainingArguments, ModelConfig, ModelConfigA, ModelConfigB, CONLL2003ProcessorConfig), task="train" 63 | ) 64 | args, model_config_base, model_config_A, model_config_B, conll_config = parser.parse_args_and_config() 65 | model_config_base = merge_configs(model_config_base, model_config_A, model_config_B) 66 | args.model_config = model_config_base 67 | 68 | task_processor = CONLL2003Processor(conll_config) 69 | dataset_A, dataset_B = 
task_processor.process() 70 | 71 | # Get model A using merged A config 72 | model_A, tokenizer_A = get_model_and_tokenizer(args.model_config.A, args) 73 | 74 | # Train by adapter swapping 75 | if not args.use_macct: 76 | # Get model B using merged B config 77 | model_B, tokenizer_B = get_model_and_tokenizer(args.model_config.B, args) 78 | models = {"A": model_A, "B": model_B} 79 | tokenizers = {"A": tokenizer_A, "B": tokenizer_B} if tokenizer_A != tokenizer_B else tokenizer_A 80 | else: 81 | models = model_A 82 | tokenizers = tokenizer_A 83 | 84 | trainer = CycleTrainer( 85 | args=args, 86 | models=models, 87 | tokenizers=tokenizers, 88 | train_dataset_A=dataset_A["train"], 89 | train_dataset_B=dataset_B["train"], 90 | eval_dataset_A=dataset_A["eval"] if not args.eval_strategy == "no" else None, 91 | eval_dataset_B=dataset_B["eval"] if not args.eval_strategy == "no" else None, 92 | peft_configs=get_peft_config(model_config_base), 93 | ) 94 | 95 | trainer.train() 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cycleformers 2 | 3 |
4 | 5 | [![Python](https://img.shields.io/badge/python-3.11-blue.svg)](https://www.python.org/downloads/) 6 | [![PyPI](https://img.shields.io/pypi/v/cycleformers)](https://pypi.org/project/cycleformers/) 7 | [![License: CC BY 4.0](https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg)](https://creativecommons.org/licenses/by/4.0/) 8 | [![Coverage](.github/badges/coverage.svg)](https://codecov.io/gh/wrmthorne/cycleformers) 9 | [![Build Status](.github/badges/build.svg)](https://github.com/wrmthorne/cycleformers/actions/workflows) 10 | 11 |
12 | 13 | A Python library for efficient cycle-consistency training of transformer models. Cycleformers simplifies iterative back-translation with support for both causal and seq2seq architectures. We also implement Multi-Adapter Cycle-Consistency Training (MACCT), enabling training of LoRA adapters on a frozen base model for `7.5x` larger model capacity for the same memory footprint. 14 | 15 | ## Features 16 | 17 | - 🤗 Seamless integration with Hugging Face Transformers 18 | - 🚀 PEFT/LoRA support for memory-efficient training 19 | - 🤖 Compatible with both causal and seq2seq models 20 | - 🔥 Optimized for various hardware configurations 21 | 22 | 23 | ## Quick Tour 24 | 25 | ### Installation 26 | 27 | ```bash 28 | pip install cycleformers 29 | ``` 30 | 31 | ### Training 32 | 33 | The `CycleTrainer` class is an extension, and a significant redesign, of the 🤗 Transformers trainer, designed to abstract away the specifics of training while remaining configurable. Both Seq2Seq and Causal architectures are supported, each able to train via PEFT adapter swapping for memory-efficient configurations. Check the [docs](https://wrmthorne.github.io/cycleformers/) for [usage](https://wrmthorne.github.io/cycleformers/usage/) details and [examples](https://wrmthorne.github.io/cycleformers/examples/). 
34 | 35 | To train using two identical models, the following sample code can be used along with two datasets: 36 | 37 | ```python 38 | from cycleformers import CycleTrainer, CycleTrainingArguments 39 | 40 | model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto") 41 | tokenizer = AutoTokenizer.from_pretrained("gpt2") 42 | 43 | args = CycleTrainingArguments(output_dir="gpt2-cct") 44 | trainer = CycleTrainer( 45 | args, 46 | models = model, 47 | tokenizers = tokenizer, 48 | train_dataset_A = dataset_A, 49 | train_dataset_B = dataset_B 50 | ) 51 | trainer.train() 52 | ``` 53 | 54 | Any two models can be combined for completely customisable training (🚧 mixing architectures as in the example below is not yet supported; both models must currently be seq2seq or both causal): 55 | 56 | ```python 57 | model_A = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto") 58 | model_B = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base", device_map="auto") 59 | tokenizer_A = AutoTokenizer.from_pretrained("gpt2") 60 | tokenizer_B = AutoTokenizer.from_pretrained("google/flan-t5-base") 61 | 62 | trainer = CycleTrainer( 63 | args, 64 | models = { 65 | "A": model_A, 66 | "B": model_B 67 | }, 68 | tokenizers = { 69 | "A": tokenizer_A, 70 | "B": tokenizer_B 71 | }, 72 | train_dataset_A = dataset_A, 73 | train_dataset_B = dataset_B 74 | ) 75 | ``` 76 | 77 | ### Multi-Adapter Cycle-Consistency Training (MACCT) 78 | 79 | The `CycleTrainer` class can also accept a single base model and train two PEFT adapters on top of it, switching between them to emulate the two-model setup.
This allows for the training of `7.5x larger models` for the same memory footprint: 80 | 81 | ```python 82 | peft_config = PeftConfig( 83 | task_type="CAUSAL_LM", 84 | r=16, 85 | lora_alpha=32, 86 | target_modules="all-linear", 87 | inference_mode=False, 88 | bias="none" 89 | ) 90 | 91 | args = CycleTrainingArguments(output_dir="gpt2-macct") 92 | trainer = CycleTrainer( 93 | args, 94 | model = model, 95 | tokenizer = tokenizer, 96 | peft_configs = peft_config # Or same A, B dict 97 | ) 98 | ``` 99 | 100 | 101 | 102 | ## Citing 103 | 104 | If you use Cycleformers in your research, please cite: 105 | 106 | ```bibtex 107 | add once zenodo/paper citation is available 108 | ``` -------------------------------------------------------------------------------- /examples/translation_wmt14/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | from datasets import load_from_disk 5 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer 6 | 7 | from cycleformers import ( 8 | CfArgumentParser, 9 | CycleTrainer, 10 | CycleTrainingArguments, 11 | ModelConfig, 12 | ModelConfigA, 13 | ModelConfigB, 14 | merge_configs, 15 | ) 16 | from cycleformers.import_utils import is_liger_kernel_available 17 | from cycleformers.task_processors.translation import TranslationProcessor, TranslationProcessorConfig 18 | from cycleformers.utils import VALID_LIGER_MODELS, get_peft_config 19 | 20 | 21 | if is_liger_kernel_available(): 22 | from liger_kernel.transformers import AutoLigerKernelForCausalLM 23 | 24 | 25 | def get_model_and_tokenizer(model_config, training_args): 26 | """Initialize model and tokenizer from config""" 27 | config = AutoConfig.from_pretrained( 28 | model_config.model_name_or_path, 29 | trust_remote_code=model_config.trust_remote_code, 30 | ) 31 | config.use_cache = False 32 | 33 | model_kwargs = {} 34 | 35 | if not config.is_encoder_decoder: 36 | 
if is_liger_kernel_available() and model_config.use_liger and config.model_type in VALID_LIGER_MODELS: 37 | model_class = AutoLigerKernelForCausalLM 38 | model_kwargs["use_liger_kernel"] = training_args.use_liger_kernel 39 | else: 40 | model_class = AutoModelForCausalLM 41 | else: 42 | model_class = AutoModelForSeq2SeqLM 43 | 44 | model = model_class.from_pretrained( 45 | model_config.model_name_or_path, 46 | revision=model_config.model_revision, 47 | config=config, 48 | trust_remote_code=model_config.trust_remote_code, 49 | attn_implementation=model_config.attn_implementation, 50 | torch_dtype=model_config.torch_dtype, 51 | device_map="auto", 52 | **model_kwargs, 53 | ) 54 | 55 | if training_args.gradient_checkpointing: 56 | model.enable_input_require_grads() 57 | 58 | # Print the actual dtype of the first parameter 59 | print(f"Model weights dtype: {next(model.parameters()).dtype}") 60 | 61 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True) 62 | return model, tokenizer 63 | 64 | 65 | def main(): 66 | parser = CfArgumentParser( 67 | (CycleTrainingArguments, ModelConfig, ModelConfigA, ModelConfigB, TranslationProcessorConfig), task="train" 68 | ) 69 | args, model_config_base, model_config_A, model_config_B, translation_config = parser.parse_args_and_config() 70 | model_config_base = merge_configs(model_config_base, model_config_A, model_config_B) 71 | args.model_config = model_config_base 72 | 73 | task_processor = TranslationProcessor(translation_config) 74 | dataset_A, dataset_B = task_processor.process() 75 | 76 | # Get model A using merged A config 77 | model_A, tokenizer_A = get_model_and_tokenizer(args.model_config.A, args) 78 | 79 | # Train by adapter swapping 80 | if not args.use_macct: 81 | # Get model B using merged B config 82 | model_B, tokenizer_B = get_model_and_tokenizer(args.model_config.B, args) 83 | models = {"A": model_A, "B": model_B} 84 | tokenizers = {"A": tokenizer_A, "B": tokenizer_B} if tokenizer_A 
!= tokenizer_B else tokenizer_A 85 | else: 86 | models = model_A 87 | tokenizers = tokenizer_A 88 | 89 | trainer = CycleTrainer( 90 | args=args, 91 | models=models, 92 | tokenizers=tokenizers, 93 | train_dataset_A=dataset_A["train"], 94 | train_dataset_B=dataset_B["train"], 95 | eval_dataset_A=dataset_A["test"] if not args.eval_strategy == "no" else None, 96 | eval_dataset_B=dataset_B["test"] if not args.eval_strategy == "no" else None, 97 | peft_configs=get_peft_config(model_config_base), 98 | ) 99 | 100 | trainer.train() 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /tests/configs/test_model_config.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from cycleformers.model_config import ModelConfig, ModelConfigA, ModelConfigB, merge_configs 4 | 5 | 6 | class TestMergeConfigs: 7 | @pytest.mark.parametrize( 8 | "base_config, config_a, config_b, expected_a, expected_b", 9 | [ 10 | # Test case 1: Basic override of model names 11 | ( 12 | ModelConfig(model_name_or_path="base"), 13 | ModelConfigA(A_model_name_or_path="model_a"), 14 | ModelConfigB(B_model_name_or_path="model_b"), 15 | "model_a", 16 | "model_b", 17 | ), 18 | # Test case 2: Override with default values (should not override) 19 | ( 20 | ModelConfig(lora_r=64), 21 | ModelConfigA(A_lora_r=16), # default value 22 | ModelConfigB(B_lora_r=32), 23 | 64, # should keep base value since A uses default 24 | 32, 25 | ), 26 | # Test case 3: Multiple field overrides 27 | ( 28 | ModelConfig(model_name_or_path="base", lora_r=16, lora_alpha=32), 29 | ModelConfigA(A_model_name_or_path="model_a", A_lora_r=64, A_lora_alpha=128), 30 | ModelConfigB(B_model_name_or_path="model_b", B_lora_r=32, B_lora_alpha=64), 31 | (64, 128), # (lora_r, lora_alpha) 32 | (32, 64), 33 | ), 34 | ], 35 | ids=["basic_override", "default_value_handling", "multiple_fields"], 36 | ) 37 | def 
test_merge_configs_parametrized(self, base_config, config_a, config_b, expected_a, expected_b): 38 | result = merge_configs(base_config, config_a, config_b) 39 | 40 | if isinstance(expected_a, tuple): 41 | # Handle multiple field test case 42 | assert result.A.lora_r == expected_a[0] 43 | assert result.A.lora_alpha == expected_a[1] 44 | assert result.B.lora_r == expected_b[0] 45 | assert result.B.lora_alpha == expected_b[1] 46 | else: 47 | # Handle single field test cases 48 | if isinstance(expected_a, str): 49 | assert result.A.model_name_or_path == expected_a 50 | assert result.B.model_name_or_path == expected_b 51 | else: 52 | assert result.A.lora_r == expected_a 53 | assert result.B.lora_r == expected_b 54 | 55 | def test_merge_configs_preserves_base_values(self): 56 | base = ModelConfig(model_name_or_path="base", lora_r=32, trust_remote_code=True) 57 | config_a = ModelConfigA(A_model_name_or_path="model_a") 58 | config_b = ModelConfigB(B_model_name_or_path="model_b") 59 | 60 | result = merge_configs(base, config_a, config_b) 61 | 62 | # Check that base values are preserved 63 | assert result.A.lora_r == 32 64 | assert result.A.trust_remote_code is True 65 | assert result.B.lora_r == 32 66 | assert result.B.trust_remote_code is True 67 | 68 | def test_merge_configs_list_handling(self): 69 | base = ModelConfig(lora_target_modules=["query", "value"]) 70 | config_a = ModelConfigA(A_lora_target_modules=["key"]) 71 | config_b = ModelConfigB(B_lora_target_modules=["output"]) 72 | 73 | result = merge_configs(base, config_a, config_b) 74 | 75 | assert result.A.lora_target_modules == ["key"] 76 | assert result.B.lora_target_modules == ["output"] 77 | 78 | def test_merge_configs_original_unmodified(self): 79 | base = ModelConfig(model_name_or_path="base", lora_r=32) 80 | config_a = ModelConfigA(A_model_name_or_path="model_a", A_lora_r=64) 81 | config_b = ModelConfigB(B_model_name_or_path="model_b", B_lora_r=128) 82 | 83 | # Store original values 84 | original_base_model 
= base.model_name_or_path 85 | original_base_lora_r = base.lora_r 86 | 87 | merge_configs(base, config_a, config_b) 88 | 89 | # Check original configs weren't modified 90 | assert base.model_name_or_path == original_base_model 91 | assert base.lora_r == original_base_lora_r 92 | -------------------------------------------------------------------------------- /tests/utils/test_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | import pytest 5 | 6 | from cycleformers.utils import prefixed_view 7 | 8 | 9 | # Test fixtures 10 | @dataclass 11 | class SimpleConfig: 12 | name: str | None = None 13 | value: int = 42 14 | 15 | 16 | @dataclass 17 | class ComplexConfig: 18 | name: str 19 | values: list[int] = field(default_factory=list) 20 | optional: Optional[str] = None 21 | 22 | 23 | @dataclass(frozen=True) 24 | class FrozenConfig: 25 | name: str 26 | value: int = 42 27 | 28 | 29 | class TestPrefixedView: 30 | def test_basic_creation(self): 31 | @prefixed_view(SimpleConfig, "test_") 32 | class TestConfig: 33 | pass 34 | 35 | assert TestConfig.__annotations__ == {"test_name": str | None, "test_value": int} 36 | 37 | @pytest.mark.parametrize( 38 | "input_values,expected", 39 | [ 40 | ({"test_name": "example"}, SimpleConfig(name="example", value=42)), 41 | ({"test_name": "test", "test_value": 100}, SimpleConfig(name="test", value=100)), 42 | ({}, SimpleConfig(name=None, value=42)), 43 | ], 44 | ) 45 | def test_valid_instances(self, input_values, expected): 46 | @prefixed_view(SimpleConfig, "test_") 47 | class TestConfig: 48 | pass 49 | 50 | instance = TestConfig(**input_values) 51 | assert instance == expected 52 | 53 | @pytest.mark.parametrize( 54 | "prefix", 55 | [ 56 | "test_", 57 | "_", 58 | "PREFIX_", 59 | "A_", 60 | ], 61 | ) 62 | def test_prefix_styles(self, prefix): 63 | @prefixed_view(SimpleConfig, prefix) 64 | class TestConfig: 65 | pass 66 | 
67 | # Verify prefixed attributes exist 68 | assert all(name.startswith(prefix) for name in TestConfig.__annotations__) 69 | 70 | def test_default_factory(self): 71 | @prefixed_view(ComplexConfig, "test_") 72 | class TestConfig: 73 | pass 74 | 75 | instance = TestConfig(test_name="test") 76 | assert instance.values == [] 77 | 78 | instance = TestConfig(test_name="test", test_values=[1, 2, 3]) 79 | assert instance.values == [1, 2, 3] 80 | 81 | def test_frozen_dataclass(self): 82 | @prefixed_view(FrozenConfig, "test_") 83 | class TestConfig: 84 | pass 85 | 86 | instance = TestConfig(test_name="test") 87 | assert instance.name == "test" 88 | assert instance.value == 42 89 | 90 | @pytest.mark.parametrize( 91 | "invalid_base,prefix,expected_error", 92 | [ 93 | (None, "test_", TypeError), # Missing base class 94 | (dict, "test_", TypeError), # not a dataclass 95 | ], 96 | ) 97 | def test_creation_errors(self, invalid_base, prefix, expected_error): 98 | with pytest.raises(expected_error): 99 | 100 | @prefixed_view(invalid_base, prefix) 101 | class TestConfig: 102 | pass 103 | 104 | def test_nested_dataclass(self): 105 | @dataclass 106 | class NestedConfig: 107 | config: SimpleConfig 108 | name: str 109 | 110 | @prefixed_view(NestedConfig, "test_") 111 | class TestNestedConfig: 112 | pass 113 | 114 | nested = SimpleConfig(name="nested") 115 | instance = TestNestedConfig(test_config=nested, test_name="test") 116 | assert instance.config == nested 117 | assert instance.name == "test" 118 | 119 | def test_inheritance(self): 120 | @dataclass 121 | class BaseConfig: 122 | name: str 123 | 124 | @dataclass 125 | class ChildConfig(BaseConfig): 126 | value: int = 42 127 | 128 | @prefixed_view(ChildConfig, "test_") 129 | class TestConfig: 130 | pass 131 | 132 | # Should include both inherited and child fields 133 | assert set(TestConfig.__annotations__.keys()) == {"test_name", "test_value"} 134 | -------------------------------------------------------------------------------- 
/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "cycleformers" 3 | packages = [ 4 | { from = "src", include = "cycleformers" }, 5 | ] 6 | version = "0.1.0" 7 | description = "A comprehensive implementation of the cycle-consistency training paradigm, extending the Huggingface Transformers trainer API to accommodate arbitrary combinations of generative models." 8 | authors = ["William Thorne "] 9 | license = "Attribution 4.0 International" 10 | readme = "README.md" 11 | 12 | [tool.poetry.dependencies] 13 | python = ">=3.11,<3.13" 14 | transformers = "4.46" 15 | torch = "^2.5.0" 16 | datasets = "^3.0.2" 17 | typing-extensions = {version = "^4.12.2", python = ">=3.8,<3.12"} 18 | peft = "^0.13.2" # TODO: Add importutils is_peft_available methods and remove dependency 19 | 20 | [tool.poetry.group.dev.dependencies] 21 | bandit = "^1.8.0" 22 | pre-commit = "^4.0.1" 23 | ruff = "^0.7.1" 24 | mypy = "^1.13.0" 25 | semantic-release = "^0.1.0" 26 | genbadge = {extras = ["build", "coverage"], version = "^1.1.1"} 27 | pytest = "^8.3.3" 28 | pytest-cov = "^6.0.0" 29 | pytest-benchmark = "^5.1.0" 30 | pytest-xdist = "^3.6.1" 31 | pytest-sugar = "^1.0.0" 32 | pytest-instafail = "^0.5.0" 33 | pytest-flakefinder = "^1.1.0" 34 | pytest-picked = "^0.5.1" 35 | pytest-random-order = "^1.1.1" 36 | 37 | [tool.poetry.group.profiling.dependencies] 38 | memory-profiler = "^0.61.0" 39 | snakeviz = "^2.2.2" 40 | torch-tb-profiler = "^0.4.3" 41 | tensorboard-plugin-profile = "^2.18.0" 42 | 43 | [tool.poetry.group.docs.dependencies] 44 | mkdocs = "^1.6.1" 45 | mkdocstrings = {extras = ["python"], version = "^0.26.2"} 46 | mkdocs-material = "^9.5.43" 47 | pygments = "^2.18.0" 48 | 49 | [tool.ruff] 50 | # Exclude a variety of commonly ignored directories. 
51 | exclude = [ 52 | ".bzr", 53 | ".direnv", 54 | ".eggs", 55 | ".git", 56 | ".git-rewrite", 57 | ".hg", 58 | ".ipynb_checkpoints", 59 | ".mypy_cache", 60 | ".nox", 61 | ".pants.d", 62 | ".pyenv", 63 | ".pytest_cache", 64 | ".pytype", 65 | ".ruff_cache", 66 | ".svn", 67 | ".tox", 68 | ".venv", 69 | ".vscode", 70 | "__pypackages__", 71 | "_build", 72 | "buck-out", 73 | "build", 74 | "dist", 75 | "node_modules", 76 | "site-packages", 77 | "venv", 78 | "misc", 79 | ] 80 | 81 | # Same as Transformers. 82 | line-length = 119 83 | indent-width = 4 84 | 85 | target-version = "py311" 86 | 87 | [tool.ruff.lint] 88 | # Same as Transformers. 89 | ignore = ["C901", "E501", "E741", "F402", "F823" ] 90 | select = [ 91 | "C", 92 | "E", 93 | "F", 94 | "I", 95 | "W", 96 | "UP006", # Use built-in types instead of typing ones (list vs List) 97 | "UP007", # Use X | Y instead of Union[X, Y] / Optional[X] 98 | "UP008", # Use super() without arguments 99 | "UP009", # Remove unnecessary UTF-8 encoding declarations 100 | "UP035", # Replace deprecated imports (e.g. typing.Callable -> collections.abc.Callable) 101 | ] 102 | 103 | [tool.ruff.lint.isort] 104 | lines-after-imports = 2 105 | known-first-party = ["cycleformers"] 106 | 107 | [tool.ruff.format] 108 | # Same as Transformers. 109 | quote-style = "double" 110 | indent-style = "space" 111 | skip-magic-trailing-comma = false 112 | line-ending = "auto" 113 | 114 | [tool.ruff.lint.per-file-ignores] 115 | # Ignore import violations in all `__init__.py` files.
116 | "__init__.py" = ["F401"] 117 | 118 | [[tool.mypy.overrides]] 119 | module = "accelerate.*,datasets.*,transformers.*" 120 | ignore_missing_imports = true 121 | 122 | [tool.pytest.ini_options] 123 | markers = [ 124 | "slow: mark test as slow to run (run with --slow to enable)", 125 | "meta: mark test as a test of the testing framework (run with --meta to enable)", 126 | "requires_gpu: mark test as requiring a GPU or taking far too long to run on a CPU (skip with --no-gpu)" 127 | ] 128 | 129 | [tool.coverage.run] 130 | data_file = ".coverage" 131 | source = ["src/cycleformers"] 132 | omit = [ 133 | "tests/*", 134 | "misc/*", 135 | ] 136 | 137 | [tool.coverage.report] 138 | exclude_lines = [ 139 | "pragma: no cover", 140 | "def __repr__", 141 | "raise NotImplementedError", 142 | ] 143 | 144 | [tool.semantic_release] 145 | version_variable = [ 146 | "src/cycleformers/__init__.py", 147 | "pyproject.toml:version", 148 | ] 149 | 150 | parser_angular_allowed_types = "build, chore, ci, docs, feat, fix, perf, style, refactor, test" 151 | parser_angular_minor_types = "feat" 152 | parser_angular_patch_types = "fix, perf" 153 | 154 | [build-system] 155 | requires = ["poetry-core"] 156 | build-backend = "poetry.core.masonry.api" 157 | -------------------------------------------------------------------------------- /tests/sample_data/conll2003/test/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": "conll2003", 3 | "citation": "@inproceedings{tjong-kim-sang-de-meulder-2003-introduction,\n title = \"Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition\",\n author = \"Tjong Kim Sang, Erik F. 
and\n De Meulder, Fien\",\n booktitle = \"Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003\",\n year = \"2003\",\n url = \"https://www.aclweb.org/anthology/W03-0419\",\n pages = \"142--147\",\n}\n", 4 | "config_name": "conll2003", 5 | "dataset_name": "conll2003", 6 | "dataset_size": 10252622, 7 | "description": "The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on\nfour types of named entities: persons, locations, organizations and names of miscellaneous entities that do\nnot belong to the previous three groups.\n\nThe CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on\na separate line and there is an empty line after each sentence. The first item on each line is a word, the second\na part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags\nand the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only\nif two phrases of the same type immediately follow each other, the first word of the second phrase will have tag\nB-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. 
Note the dataset uses IOB2\ntagging scheme, whereas the original dataset uses IOB1.\n\nFor more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419\n", 8 | "download_checksums": { 9 | "https://data.deepai.org/conll2003.zip": { 10 | "num_bytes": 982975, 11 | "checksum": null 12 | } 13 | }, 14 | "download_size": 982975, 15 | "features": { 16 | "id": { 17 | "dtype": "string", 18 | "_type": "Value" 19 | }, 20 | "tokens": { 21 | "feature": { 22 | "dtype": "string", 23 | "_type": "Value" 24 | }, 25 | "_type": "Sequence" 26 | }, 27 | "pos_tags": { 28 | "feature": { 29 | "names": [ 30 | "\"", 31 | "''", 32 | "#", 33 | "$", 34 | "(", 35 | ")", 36 | ",", 37 | ".", 38 | ":", 39 | "``", 40 | "CC", 41 | "CD", 42 | "DT", 43 | "EX", 44 | "FW", 45 | "IN", 46 | "JJ", 47 | "JJR", 48 | "JJS", 49 | "LS", 50 | "MD", 51 | "NN", 52 | "NNP", 53 | "NNPS", 54 | "NNS", 55 | "NN|SYM", 56 | "PDT", 57 | "POS", 58 | "PRP", 59 | "PRP$", 60 | "RB", 61 | "RBR", 62 | "RBS", 63 | "RP", 64 | "SYM", 65 | "TO", 66 | "UH", 67 | "VB", 68 | "VBD", 69 | "VBG", 70 | "VBN", 71 | "VBP", 72 | "VBZ", 73 | "WDT", 74 | "WP", 75 | "WP$", 76 | "WRB" 77 | ], 78 | "_type": "ClassLabel" 79 | }, 80 | "_type": "Sequence" 81 | }, 82 | "chunk_tags": { 83 | "feature": { 84 | "names": [ 85 | "O", 86 | "B-ADJP", 87 | "I-ADJP", 88 | "B-ADVP", 89 | "I-ADVP", 90 | "B-CONJP", 91 | "I-CONJP", 92 | "B-INTJ", 93 | "I-INTJ", 94 | "B-LST", 95 | "I-LST", 96 | "B-NP", 97 | "I-NP", 98 | "B-PP", 99 | "I-PP", 100 | "B-PRT", 101 | "I-PRT", 102 | "B-SBAR", 103 | "I-SBAR", 104 | "B-UCP", 105 | "I-UCP", 106 | "B-VP", 107 | "I-VP" 108 | ], 109 | "_type": "ClassLabel" 110 | }, 111 | "_type": "Sequence" 112 | }, 113 | "ner_tags": { 114 | "feature": { 115 | "names": [ 116 | "O", 117 | "B-PER", 118 | "I-PER", 119 | "B-ORG", 120 | "I-ORG", 121 | "B-LOC", 122 | "I-LOC", 123 | "B-MISC", 124 | "I-MISC" 125 | ], 126 | "_type": "ClassLabel" 127 | }, 128 | "_type": "Sequence" 129 | } 130 | }, 131 | 
"homepage": "https://www.aclweb.org/anthology/W03-0419/", 132 | "license": "", 133 | "size_in_bytes": 11235597, 134 | "splits": { 135 | "train": { 136 | "name": "train", 137 | "num_bytes": 6931345, 138 | "num_examples": 14041, 139 | "dataset_name": "conll2003" 140 | }, 141 | "validation": { 142 | "name": "validation", 143 | "num_bytes": 1739223, 144 | "num_examples": 3250, 145 | "dataset_name": "conll2003" 146 | }, 147 | "test": { 148 | "name": "test", 149 | "num_bytes": 1582054, 150 | "num_examples": 3453, 151 | "dataset_name": "conll2003" 152 | } 153 | }, 154 | "version": { 155 | "version_str": "1.0.0", 156 | "major": 1, 157 | "minor": 0, 158 | "patch": 0 159 | } 160 | } -------------------------------------------------------------------------------- /tests/sample_data/conll2003/train/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": "conll2003", 3 | "citation": "@inproceedings{tjong-kim-sang-de-meulder-2003-introduction,\n title = \"Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition\",\n author = \"Tjong Kim Sang, Erik F. and\n De Meulder, Fien\",\n booktitle = \"Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003\",\n year = \"2003\",\n url = \"https://www.aclweb.org/anthology/W03-0419\",\n pages = \"142--147\",\n}\n", 4 | "config_name": "conll2003", 5 | "dataset_name": "conll2003", 6 | "dataset_size": 10252622, 7 | "description": "The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on\nfour types of named entities: persons, locations, organizations and names of miscellaneous entities that do\nnot belong to the previous three groups.\n\nThe CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on\na separate line and there is an empty line after each sentence. 
The first item on each line is a word, the second\na part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags\nand the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only\nif two phrases of the same type immediately follow each other, the first word of the second phrase will have tag\nB-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2\ntagging scheme, whereas the original dataset uses IOB1.\n\nFor more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419\n", 8 | "download_checksums": { 9 | "https://data.deepai.org/conll2003.zip": { 10 | "num_bytes": 982975, 11 | "checksum": null 12 | } 13 | }, 14 | "download_size": 982975, 15 | "features": { 16 | "id": { 17 | "dtype": "string", 18 | "_type": "Value" 19 | }, 20 | "tokens": { 21 | "feature": { 22 | "dtype": "string", 23 | "_type": "Value" 24 | }, 25 | "_type": "Sequence" 26 | }, 27 | "pos_tags": { 28 | "feature": { 29 | "names": [ 30 | "\"", 31 | "''", 32 | "#", 33 | "$", 34 | "(", 35 | ")", 36 | ",", 37 | ".", 38 | ":", 39 | "``", 40 | "CC", 41 | "CD", 42 | "DT", 43 | "EX", 44 | "FW", 45 | "IN", 46 | "JJ", 47 | "JJR", 48 | "JJS", 49 | "LS", 50 | "MD", 51 | "NN", 52 | "NNP", 53 | "NNPS", 54 | "NNS", 55 | "NN|SYM", 56 | "PDT", 57 | "POS", 58 | "PRP", 59 | "PRP$", 60 | "RB", 61 | "RBR", 62 | "RBS", 63 | "RP", 64 | "SYM", 65 | "TO", 66 | "UH", 67 | "VB", 68 | "VBD", 69 | "VBG", 70 | "VBN", 71 | "VBP", 72 | "VBZ", 73 | "WDT", 74 | "WP", 75 | "WP$", 76 | "WRB" 77 | ], 78 | "_type": "ClassLabel" 79 | }, 80 | "_type": "Sequence" 81 | }, 82 | "chunk_tags": { 83 | "feature": { 84 | "names": [ 85 | "O", 86 | "B-ADJP", 87 | "I-ADJP", 88 | "B-ADVP", 89 | "I-ADVP", 90 | "B-CONJP", 91 | "I-CONJP", 92 | "B-INTJ", 93 | "I-INTJ", 94 | "B-LST", 95 | "I-LST", 96 | "B-NP", 97 | "I-NP", 98 | "B-PP", 99 | "I-PP", 100 | 
"B-PRT", 101 | "I-PRT", 102 | "B-SBAR", 103 | "I-SBAR", 104 | "B-UCP", 105 | "I-UCP", 106 | "B-VP", 107 | "I-VP" 108 | ], 109 | "_type": "ClassLabel" 110 | }, 111 | "_type": "Sequence" 112 | }, 113 | "ner_tags": { 114 | "feature": { 115 | "names": [ 116 | "O", 117 | "B-PER", 118 | "I-PER", 119 | "B-ORG", 120 | "I-ORG", 121 | "B-LOC", 122 | "I-LOC", 123 | "B-MISC", 124 | "I-MISC" 125 | ], 126 | "_type": "ClassLabel" 127 | }, 128 | "_type": "Sequence" 129 | } 130 | }, 131 | "homepage": "https://www.aclweb.org/anthology/W03-0419/", 132 | "license": "", 133 | "size_in_bytes": 11235597, 134 | "splits": { 135 | "train": { 136 | "name": "train", 137 | "num_bytes": 6931345, 138 | "num_examples": 14041, 139 | "dataset_name": "conll2003" 140 | }, 141 | "validation": { 142 | "name": "validation", 143 | "num_bytes": 1739223, 144 | "num_examples": 3250, 145 | "dataset_name": "conll2003" 146 | }, 147 | "test": { 148 | "name": "test", 149 | "num_bytes": 1582054, 150 | "num_examples": 3453, 151 | "dataset_name": "conll2003" 152 | } 153 | }, 154 | "version": { 155 | "version_str": "1.0.0", 156 | "major": 1, 157 | "minor": 0, 158 | "patch": 0 159 | } 160 | } -------------------------------------------------------------------------------- /tests/sample_data/conll2003/validation/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": "conll2003", 3 | "citation": "@inproceedings{tjong-kim-sang-de-meulder-2003-introduction,\n title = \"Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition\",\n author = \"Tjong Kim Sang, Erik F. 
and\n De Meulder, Fien\",\n booktitle = \"Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003\",\n year = \"2003\",\n url = \"https://www.aclweb.org/anthology/W03-0419\",\n pages = \"142--147\",\n}\n", 4 | "config_name": "conll2003", 5 | "dataset_name": "conll2003", 6 | "dataset_size": 10252622, 7 | "description": "The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on\nfour types of named entities: persons, locations, organizations and names of miscellaneous entities that do\nnot belong to the previous three groups.\n\nThe CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on\na separate line and there is an empty line after each sentence. The first item on each line is a word, the second\na part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags\nand the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only\nif two phrases of the same type immediately follow each other, the first word of the second phrase will have tag\nB-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. 
Note the dataset uses IOB2\ntagging scheme, whereas the original dataset uses IOB1.\n\nFor more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419\n", 8 | "download_checksums": { 9 | "https://data.deepai.org/conll2003.zip": { 10 | "num_bytes": 982975, 11 | "checksum": null 12 | } 13 | }, 14 | "download_size": 982975, 15 | "features": { 16 | "id": { 17 | "dtype": "string", 18 | "_type": "Value" 19 | }, 20 | "tokens": { 21 | "feature": { 22 | "dtype": "string", 23 | "_type": "Value" 24 | }, 25 | "_type": "Sequence" 26 | }, 27 | "pos_tags": { 28 | "feature": { 29 | "names": [ 30 | "\"", 31 | "''", 32 | "#", 33 | "$", 34 | "(", 35 | ")", 36 | ",", 37 | ".", 38 | ":", 39 | "``", 40 | "CC", 41 | "CD", 42 | "DT", 43 | "EX", 44 | "FW", 45 | "IN", 46 | "JJ", 47 | "JJR", 48 | "JJS", 49 | "LS", 50 | "MD", 51 | "NN", 52 | "NNP", 53 | "NNPS", 54 | "NNS", 55 | "NN|SYM", 56 | "PDT", 57 | "POS", 58 | "PRP", 59 | "PRP$", 60 | "RB", 61 | "RBR", 62 | "RBS", 63 | "RP", 64 | "SYM", 65 | "TO", 66 | "UH", 67 | "VB", 68 | "VBD", 69 | "VBG", 70 | "VBN", 71 | "VBP", 72 | "VBZ", 73 | "WDT", 74 | "WP", 75 | "WP$", 76 | "WRB" 77 | ], 78 | "_type": "ClassLabel" 79 | }, 80 | "_type": "Sequence" 81 | }, 82 | "chunk_tags": { 83 | "feature": { 84 | "names": [ 85 | "O", 86 | "B-ADJP", 87 | "I-ADJP", 88 | "B-ADVP", 89 | "I-ADVP", 90 | "B-CONJP", 91 | "I-CONJP", 92 | "B-INTJ", 93 | "I-INTJ", 94 | "B-LST", 95 | "I-LST", 96 | "B-NP", 97 | "I-NP", 98 | "B-PP", 99 | "I-PP", 100 | "B-PRT", 101 | "I-PRT", 102 | "B-SBAR", 103 | "I-SBAR", 104 | "B-UCP", 105 | "I-UCP", 106 | "B-VP", 107 | "I-VP" 108 | ], 109 | "_type": "ClassLabel" 110 | }, 111 | "_type": "Sequence" 112 | }, 113 | "ner_tags": { 114 | "feature": { 115 | "names": [ 116 | "O", 117 | "B-PER", 118 | "I-PER", 119 | "B-ORG", 120 | "I-ORG", 121 | "B-LOC", 122 | "I-LOC", 123 | "B-MISC", 124 | "I-MISC" 125 | ], 126 | "_type": "ClassLabel" 127 | }, 128 | "_type": "Sequence" 129 | } 130 | }, 131 | 
"homepage": "https://www.aclweb.org/anthology/W03-0419/", 132 | "license": "", 133 | "size_in_bytes": 11235597, 134 | "splits": { 135 | "train": { 136 | "name": "train", 137 | "num_bytes": 6931345, 138 | "num_examples": 14041, 139 | "dataset_name": "conll2003" 140 | }, 141 | "validation": { 142 | "name": "validation", 143 | "num_bytes": 1739223, 144 | "num_examples": 3250, 145 | "dataset_name": "conll2003" 146 | }, 147 | "test": { 148 | "name": "test", 149 | "num_bytes": 1582054, 150 | "num_examples": 3453, 151 | "dataset_name": "conll2003" 152 | } 153 | }, 154 | "version": { 155 | "version_str": "1.0.0", 156 | "major": 1, 157 | "minor": 0, 158 | "patch": 0 159 | } 160 | } -------------------------------------------------------------------------------- /src/cycleformers/task_processors/translation.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from dataclasses import dataclass 3 | 4 | from datasets import DatasetDict, IterableDatasetDict 5 | 6 | from .base import BaseProcessor, ProcessorConfig 7 | 8 | 9 | @dataclass 10 | class TranslationProcessorConfig(ProcessorConfig): 11 | """Configuration class for translation dataset processors. 12 | 13 | This class extends the base ProcessorConfig with translation-specific parameters. 14 | 15 | Args: 16 | dataset_name (str): HuggingFace dataset name/path. Defaults to "wmt14". 17 | dataset_config_name (str): Specific configuration of the dataset to load. 18 | For WMT14, must be one of ['cs-en', 'de-en', 'fr-en', 'hi-en', 'ru-en']. 19 | Defaults to "de-en". 20 | source_lang (str): Source language code (e.g., "en" for English). Defaults to "en". 21 | target_lang (str): Target language code (e.g., "de" for German). Defaults to "de". 22 | source_column (str): Column name containing source text. Defaults to "translation". 23 | target_column (str | None): Column name containing target text. If None, uses source_column. Defaults to None. 
24 | preprocessing_fn (callable | None): Optional function to preprocess raw dataset entries. 25 | Should take a dataset entry and return a dict with 'source' and 'target' keys. 26 | Defaults to None. 27 | 28 | Example: 29 | >>> config = TranslationProcessorConfig( 30 | ... dataset_name="wmt14", 31 | ... dataset_config_name="de-en", 32 | ... source_lang="en", 33 | ... target_lang="de" 34 | ... ) 35 | >>> processor = TranslationProcessor(config) 36 | """ 37 | 38 | dataset_name: str = "wmt14" 39 | dataset_config_name: str = "de-en" # Required for WMT14 40 | source_lang: str = "en" 41 | target_lang: str = "de" 42 | source_column: str = "translation" 43 | target_column: str | None = None 44 | preprocessing_fn: Callable | None = None 45 | 46 | 47 | class TranslationProcessor(BaseProcessor): 48 | """Processor for translation datasets. 49 | 50 | This processor handles translation datasets in various formats and converts them into 51 | cycleformers-compatible format. It supports both: 52 | - Standard parallel corpora (source -> target) 53 | - Back-translation style training (target -> source) 54 | 55 | The processor: 56 | 1. Loads the translation dataset (streaming for large datasets) 57 | 2. Extracts source and target text 58 | 3. Creates two complementary datasets for cycle training 59 | 60 | Args: 61 | config (TranslationProcessorConfig): Configuration object controlling processor behavior. 62 | Includes settings like dataset name, language pairs, and column names. 63 | 64 | Example: 65 | >>> config = TranslationProcessorConfig( 66 | ... dataset_name="wmt14", 67 | ... dataset_config_name="de-en", 68 | ... source_lang="en", 69 | ... target_lang="de" 70 | ... 
) 71 | >>> processor = TranslationProcessor(config) 72 | >>> dataset_A, dataset_B = processor.process() 73 | >>> print(dataset_A["train"][0]) 74 | {'text': 'The cat sat on the mat.'} 75 | >>> print(dataset_B["train"][0]) 76 | {'text': 'Die Katze saß auf der Matte.'} 77 | """ 78 | 79 | def __init__(self, config: TranslationProcessorConfig = TranslationProcessorConfig()): 80 | super().__init__(config) 81 | self.config: TranslationProcessorConfig = config 82 | 83 | def _extract_text_pair(self, example: dict) -> dict: 84 | """Extract source and target text from a dataset example. 85 | 86 | This method handles different dataset formats: 87 | 1. Nested dictionary format (e.g., {'translation': {'en': '...', 'de': '...'}}) 88 | 2. Flat dictionary format (e.g., {'source': '...', 'target': '...'}) 89 | 3. Custom formats via preprocessing_fn 90 | 91 | Args: 92 | example (dict): A single example from the dataset 93 | 94 | Returns: 95 | dict: Dictionary with 'source' and 'target' text 96 | """ 97 | if self.config.preprocessing_fn is not None: 98 | return self.config.preprocessing_fn(example) 99 | 100 | source_col = self.config.source_column 101 | target_col = self.config.target_column or source_col 102 | 103 | if isinstance(example[source_col], dict): 104 | # Handle nested dictionary format (e.g., WMT datasets) 105 | return { 106 | "source": example[source_col][self.config.source_lang], 107 | "target": example[target_col][self.config.target_lang], 108 | } 109 | else: 110 | return { 111 | "source": example[source_col], 112 | "target": example[target_col], 113 | } 114 | 115 | def preprocess(self, dataset: DatasetDict | IterableDatasetDict) -> tuple[DatasetDict, DatasetDict]: 116 | """Preprocess the dataset into two separate datasets for cycle training. 
117 | 118 | Args: 119 | dataset (DatasetDict | IterableDatasetDict): The raw dataset containing translation pairs 120 | 121 | Returns: 122 | tuple[DatasetDict, DatasetDict]: Two datasets: 123 | - Dataset A: Source language texts 124 | - Dataset B: Target language texts 125 | Each containing 'train' and 'test' splits with parallel data in test 126 | """ 127 | original_cols = set(dataset["train"].column_names) 128 | dataset = dataset.map(self._extract_text_pair) 129 | dataset = dataset.remove_columns(original_cols - {"source", "target"}) 130 | 131 | dataset_A = dataset.map(lambda x: {"text": x["source"], "label": x["target"]}) 132 | dataset_B = dataset.map(lambda x: {"text": x["target"], "label": x["source"]}) 133 | 134 | dataset_A["train"] = dataset_A["train"].remove_columns(["label"]) 135 | dataset_B["train"] = dataset_B["train"].remove_columns(["label"]) 136 | 137 | return dataset_A, dataset_B 138 | -------------------------------------------------------------------------------- /src/cycleformers/model_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | from .utils import prefixed_view 5 | 6 | 7 | @dataclass 8 | class ModelConfig: 9 | """https://github.com/huggingface/trl/blob/main/trl/trainer/model_config.py 10 | 11 | Parameters: 12 | model_name_or_path (`Optional[str]`, *optional*, defaults to `None`): 13 | Model checkpoint for weights initialization. 14 | model_revision (`str`, *optional*, defaults to `"main"`): 15 | Specific model version to use. It can be a branch name, a tag name, or a commit id. 16 | torch_dtype (`Optional[Literal["auto", "bfloat16", "float16", "float32"]]`, *optional*, defaults to `None`): 17 | Override the default `torch.dtype` and load the model under this dtype. 
Possible values are 18 | 19 | - `"bfloat16"`: `torch.bfloat16` 20 | - `"float16"`: `torch.float16` 21 | - `"float32"`: `torch.float32` 22 | - `"auto"`: Automatically derive the dtype from the model's weights. 23 | 24 | trust_remote_code (`bool`, *optional*, defaults to `False`): 25 | Whether to allow for custom models defined on the Hub in their own modeling files. This option should only 26 | be set to `True` for repositories you trust and in which you have read the code, as it will execute code 27 | present on the Hub on your local machine. 28 | attn_implementation (`Optional[str]`, *optional*, defaults to `None`): 29 | Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in which case 30 | you must install this manually by running `pip install flash-attn --no-build-isolation`. 31 | use_peft (`bool`, *optional*, defaults to `False`): 32 | Whether to use PEFT for training. 33 | lora_r (`int`, *optional*, defaults to `16`): 34 | LoRA R value. 35 | lora_alpha (`int`, *optional*, defaults to `32`): 36 | LoRA alpha. 37 | lora_dropout (`float`, *optional*, defaults to `0.05`): 38 | LoRA dropout. 39 | lora_target_modules (`Optional[Union[str, list[str]]]`, *optional*, defaults to `None`): 40 | LoRA target modules. 41 | lora_modules_to_save (`Optional[list[str]]`, *optional*, defaults to `None`): 42 | Model layers to unfreeze & train. 43 | lora_task_type (`str`, *optional*, defaults to `"CAUSAL_LM"`): 44 | Task type to pass for LoRA (use `"SEQ_CLS"` for reward modeling). 
45 | use_rslora (`bool`, *optional*, defaults to `False`): 46 | Whether to use Rank-Stabilized LoRA, which sets the adapter scaling factor to `lora_alpha/√r`, instead of 47 | the original default value of `lora_alpha/r`.""" 48 | 49 | model_name_or_path: str | None = None 50 | model_revision: str = "main" 51 | torch_dtype: Literal["auto", "bfloat16", "float16", "float32"] | None = None 52 | trust_remote_code: bool = False 53 | attn_implementation: str | None = None 54 | use_peft: bool = False 55 | lora_r: int = 16 56 | lora_alpha: int = 32 57 | lora_dropout: float = 0.05 58 | lora_target_modules: list[str] | None = None 59 | lora_modules_to_save: list[str] | None = None 60 | lora_task_type: str = "CAUSAL_LM" 61 | use_rslora: bool = False 62 | use_dora: bool = False 63 | 64 | def __post_init__(self): 65 | self._A = None 66 | self._B = None 67 | 68 | @property 69 | def A(self) -> "ModelConfig": 70 | return self._A 71 | 72 | @A.setter 73 | def A(self, value: "ModelConfig"): 74 | self._A = value 75 | 76 | @property 77 | def B(self) -> "ModelConfig": 78 | return self._B 79 | 80 | @B.setter 81 | def B(self, value: "ModelConfig"): 82 | self._B = value 83 | 84 | 85 | @dataclass 86 | @prefixed_view(ModelConfig, "A_") 87 | class ModelConfigA: 88 | pass 89 | 90 | 91 | @dataclass 92 | @prefixed_view(ModelConfig, "B_") 93 | class ModelConfigB: 94 | pass 95 | 96 | 97 | def merge_configs(base_config: ModelConfig, config_a: ModelConfigA, config_b: ModelConfigB) -> ModelConfig: 98 | """Merge configs, with A/B specific values overriding base values, unless they're defaults. 
99 | 100 | Args: 101 | base_config (ModelConfig): Base configuration with default values 102 | config_a (ModelConfigA): Model A specific configuration that may override base values 103 | config_b (ModelConfigB): Model B specific configuration that may override base values 104 | 105 | Returns: 106 | ModelConfig: The base config with A and B specific configs merged in 107 | 108 | Example: 109 | >>> base = ModelConfig(model_name="base", lora_r=32) 110 | >>> a = ModelConfigA(A_model_name="model_a", A_lora_r=64) 111 | >>> b = ModelConfigB(B_model_name="model_b") 112 | >>> merged = merge_configs(base, a, b) 113 | >>> merged.A.model_name 114 | 'model_a' 115 | >>> merged.A.lora_r 116 | 64 117 | >>> merged.B.model_name 118 | 'model_b' 119 | >>> merged.B.lora_r 120 | 32 121 | """ 122 | # Create copies to avoid modifying originals 123 | merged_a = ModelConfig(**{k: getattr(base_config, k) for k in base_config.__dataclass_fields__}) 124 | merged_b = ModelConfig(**{k: getattr(base_config, k) for k in base_config.__dataclass_fields__}) 125 | 126 | # Create a default config to check against 127 | default_config = ModelConfig() 128 | 129 | # Override with A-specific values, but only if they're not defaults 130 | for field in base_config.__dataclass_fields__: 131 | if hasattr(config_a, field): 132 | config_a_value = getattr(config_a, field) 133 | # Only override if the A-specific value is different from default 134 | if config_a_value != getattr(default_config, field): 135 | setattr(merged_a, field, config_a_value) 136 | 137 | # Override with B-specific values, but only if they're not defaults 138 | for field in base_config.__dataclass_fields__: 139 | if hasattr(config_b, field): 140 | config_b_value = getattr(config_b, field) 141 | # Only override if the B-specific value is different from default 142 | if config_b_value != getattr(default_config, field): 143 | setattr(merged_b, field, config_b_value) 144 | 145 | base_config.A = merged_a 146 | base_config.B = merged_b 147 | return 
base_config 148 | 149 | 150 | __all__ = ["ModelConfig", "ModelConfigA", "ModelConfigB", "merge_configs"] 151 | -------------------------------------------------------------------------------- /tests/benchmark/benchmark_tokenizer_skip.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import timeit 4 | from dataclasses import dataclass, field 5 | from datetime import datetime 6 | from pathlib import Path 7 | from random import Random 8 | 9 | import torch 10 | from benchmark_utils import Benchmark, BenchmarkConfig, MockCycleTrainer 11 | from datasets import load_dataset 12 | from torch.utils.data import DataLoader 13 | from transformers import AutoModelForCausalLM, AutoTokenizer 14 | 15 | from cycleformers.cycles import _default_prepare_cycle_inputs, _prepare_causal_skip_cycle_inputs 16 | from cycleformers.utils import DEFAULT_SEP_SEQ 17 | 18 | 19 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 20 | 21 | 22 | @dataclass 23 | class TokenPrepBenchmarkConfig(BenchmarkConfig): 24 | config_name: str = "token_prep" 25 | dataset_name: str 26 | model_A_name: str 27 | model_B_name: str 28 | cycle_name: str | None = None 29 | num_samples: int = 500 30 | batch_sizes: list[int] = field(default_factory=lambda: [1, 8, 32, 128]) 31 | devices: list[str] = field(default_factory=lambda: ["cpu", "cuda"]) 32 | seed: int = field(default_factory=lambda: int(datetime.now().timestamp())) 33 | 34 | 35 | class TokenPreparationBenchmark: 36 | def __init__(self, config: TokenPrepBenchmarkConfig): 37 | self.config = config 38 | self.random = Random(config.seed) 39 | 40 | # Initialize tokenizer 41 | self.tokenizer = AutoTokenizer.from_pretrained(config.model_A_name) 42 | if self.tokenizer.pad_token is None: 43 | self.tokenizer.pad_token = self.tokenizer.eos_token 44 | 45 | self.model_A = AutoModelForCausalLM.from_pretrained(config.model_A_name) 46 | 47 | self.save_dir = Path(__file__).parent / "benchmark_results" / "token_prep" 
48 | self.prepare_dataset() 49 | 50 | def prepare_dataset(self): 51 | # FIXME: Currently only supports datasets with "instruction", "response" as columns 52 | self.dataset = load_dataset(self.config.dataset_name, split="train").select(range(self.config.num_samples)) 53 | self.dataset = self.dataset.map( 54 | lambda x: { 55 | "synth_ids": self.tokenizer( 56 | x["instruction"] + DEFAULT_SEP_SEQ + x["response"], padding=False 57 | ).input_ids, 58 | "real_ids": self.tokenizer(x["instruction"], padding=False).input_ids, 59 | } 60 | ) 61 | 62 | def manual_right_pad_batch(self, batch_ids: list[list[int]], pad_token_id: int) -> torch.Tensor: 63 | """Right pad a batch of token IDs to the maximum length in the batch.""" 64 | max_length = max(len(ids) for ids in batch_ids) 65 | padded_batch = [] 66 | 67 | for ids in batch_ids: 68 | padding_length = max_length - len(ids) 69 | padded_ids = ids + [pad_token_id] * padding_length 70 | padded_batch.append(padded_ids) 71 | 72 | return torch.tensor(padded_batch) 73 | 74 | def get_dataloader(self, batch_size: int, device: str): 75 | generator = torch.Generator(device=device) 76 | generator.manual_seed(self.config.seed) 77 | 78 | def collate_fn(batch): 79 | real_ids = [item["real_ids"] for item in batch] 80 | synth_ids = [item["synth_ids"] for item in batch] 81 | 82 | # Manually pad both sequences 83 | real_ids_padded = self.manual_right_pad_batch(real_ids, self.tokenizer.pad_token_id) 84 | synth_ids_padded = self.manual_right_pad_batch(synth_ids, self.tokenizer.pad_token_id) 85 | 86 | return real_ids_padded, synth_ids_padded 87 | 88 | dataloader = DataLoader( 89 | self.dataset, 90 | batch_size=batch_size, 91 | shuffle=True, 92 | num_workers=4, 93 | pin_memory=True, 94 | # generator=generator, 95 | collate_fn=collate_fn, 96 | ) 97 | return dataloader 98 | 99 | def _save_metrics(self, metrics: dict): 100 | run_dir = self.save_dir / f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" 101 | run_dir.mkdir(parents=True, exist_ok=True) 102 | 
103 | with open(run_dir / f"metrics.json", "w") as f: 104 | json.dump(metrics, f) 105 | 106 | def run(self, implementations: list[str] | None = None, n_runs: int = 5): 107 | """Run benchmarks across all configurations""" 108 | 109 | results = [] 110 | 111 | for batch_size in self.config.batch_sizes: 112 | for device in self.config.devices: 113 | if device == "cuda" and not torch.cuda.is_available(): 114 | continue 115 | 116 | print(f"\nBenchmarking batch_size={batch_size} on {device}") 117 | 118 | # Prepare data 119 | dataloader = self.get_dataloader(batch_size, device) 120 | 121 | for impl in implementations: 122 | # Time the implementation 123 | 124 | def run_impl(): 125 | outputs = [] 126 | for batch in dataloader: 127 | real_ids, synth_ids = batch 128 | 129 | output = impl( 130 | real_input_ids=real_ids.to(device), 131 | synth_input_ids=synth_ids.to(device), 132 | model_gen=self.model_A, 133 | model_train=self.model_A, 134 | tokenizer_gen=self.tokenizer, 135 | tokenizer_train=self.tokenizer, 136 | cycle_name="A", 137 | ) 138 | outputs.append(output) 139 | return outputs 140 | 141 | # Time runs 142 | if device == "cuda": 143 | torch.cuda.synchronize() 144 | 145 | timer = timeit.Timer(run_impl) 146 | times = timer.repeat(repeat=n_runs, number=5) 147 | 148 | metrics = {"mean": sum(times) / len(times), "min": min(times), "max": max(times)} 149 | 150 | results.append({"implementation": impl, "batch_size": batch_size, "device": device, **metrics}) 151 | 152 | print(f"\n{impl}:") 153 | print(f"Mean: {metrics['mean']*1000:.2f}ms") 154 | print(f"Min: {metrics['min']*1000:.2f}ms") 155 | print(f"Max: {metrics['max']*1000:.2f}ms") 156 | 157 | return results 158 | 159 | 160 | if __name__ == "__main__": 161 | config = TokenPrepBenchmarkConfig( 162 | dataset_name="MBZUAI/LaMini-instruction", 163 | model_A_name="trl-internal-testing/tiny-LlamaForCausalLM-3.1", 164 | model_B_name="trl-internal-testing/tiny-LlamaForCausalLM-3.1", 165 | ) 166 | 167 | benchmark = 
TokenPreparationBenchmark(config=config) 168 | benchmark.run(implementations=[_default_prepare_cycle_inputs, _prepare_causal_skip_cycle_inputs]) 169 | -------------------------------------------------------------------------------- /tests/testing_utils/model_registry.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from dataclasses import dataclass 3 | from enum import Enum, auto 4 | from pathlib import Path 5 | from typing import Callable, Union 6 | 7 | import yaml 8 | from transformers import AutoConfig, PretrainedConfig 9 | 10 | 11 | class CapabilityExpression: 12 | """Expression for filtering models based on their capabilities.""" 13 | 14 | def __init__(self, condition: Callable[[set["ModelCapability"]], bool]): 15 | self.condition = condition 16 | 17 | def evaluate(self, capabilities: set["ModelCapability"]) -> bool: 18 | return self.condition(capabilities) 19 | 20 | def __and__(self, other: "CapabilityExpression") -> "CapabilityExpression": 21 | return CapabilityExpression(lambda caps: self.condition(caps) and other.condition(caps)) 22 | 23 | def __or__(self, other: "CapabilityExpression") -> "CapabilityExpression": 24 | return CapabilityExpression(lambda caps: self.condition(caps) or other.condition(caps)) 25 | 26 | def __invert__(self) -> "CapabilityExpression": 27 | return CapabilityExpression(lambda caps: not self.condition(caps)) 28 | 29 | 30 | class ModelCapability(Enum): 31 | SEQ2SEQ_LM = auto() 32 | CAUSAL_LM = auto() 33 | 34 | @classmethod 35 | def from_str(cls, name: str) -> "ModelCapability": 36 | return cls[name] 37 | 38 | def to_expression(self) -> "CapabilityExpression": 39 | return CapabilityExpression(lambda caps: self in caps) 40 | 41 | def __and__(self, other: Union["ModelCapability", "CapabilityExpression"]) -> "CapabilityExpression": 42 | return self.to_expression() & (other.to_expression() if isinstance(other, ModelCapability) else other) 43 | 44 | def 
__or__(self, other: Union["ModelCapability", "CapabilityExpression"]) -> "CapabilityExpression": 45 | return self.to_expression() | (other.to_expression() if isinstance(other, ModelCapability) else other) 46 | 47 | def __invert__(self) -> "CapabilityExpression": 48 | return ~self.to_expression() 49 | 50 | 51 | def infer_capabilities_from_config(config: PretrainedConfig) -> set[ModelCapability]: 52 | """Infer model capabilities from HuggingFace config. 53 | 54 | Args: 55 | config: HuggingFace model config 56 | 57 | Returns: 58 | Set of model capabilities inferred from the config. 59 | """ 60 | capabilities = set() 61 | 62 | # Architecture-based capabilities 63 | if hasattr(config, "is_encoder_decoder") and config.is_encoder_decoder: 64 | capabilities.add(ModelCapability.SEQ2SEQ_LM) 65 | 66 | if hasattr(config, "architectures") and config.architectures: 67 | if any("ForCausalLM" in arch for arch in config.architectures): 68 | capabilities.add(ModelCapability.CAUSAL_LM) 69 | 70 | return capabilities 71 | 72 | 73 | @dataclass 74 | class ModelSpec: 75 | name: str 76 | repo_id: str 77 | capabilities: set[ModelCapability] 78 | config: PretrainedConfig 79 | description: str = "" # Any notes that are important to remember about the model 80 | 81 | def __hash__(self) -> int: 82 | return hash((self.name, self.repo_id)) 83 | 84 | def __eq__(self, other: object) -> bool: 85 | if not isinstance(other, ModelSpec): 86 | return NotImplemented 87 | return self.name == other.name and self.repo_id == other.repo_id 88 | 89 | @classmethod 90 | def from_hub(cls, name: str, repo_id: str) -> "ModelSpec": 91 | config = AutoConfig.from_pretrained(repo_id) 92 | capabilities = infer_capabilities_from_config(config) 93 | return cls(name=name, repo_id=repo_id, capabilities=capabilities, config=config) 94 | 95 | 96 | class ModelRegistry: 97 | def __init__(self, registry_path: Path): 98 | self._models: dict[str, ModelSpec] = {} 99 | self.load_registry(registry_path) 100 | 101 | def 
load_registry(self, registry_path: Path): 102 | with registry_path.open() as f: 103 | registry_dict = yaml.safe_load(f) 104 | 105 | for name, spec in registry_dict.items(): 106 | self._models[name] = ModelSpec.from_hub(name=name, repo_id=spec["repo_id"]) 107 | 108 | def get_matching_models( 109 | self, 110 | capability_expr: ModelCapability | CapabilityExpression | None = None, 111 | model_names: list[str] | str | None = None, 112 | ) -> list[ModelSpec]: 113 | """Get models matching capability expression AND model names. 114 | 115 | Args: 116 | capability_expr: Capability expression to match. Can be a ModelCapability or a CapabilityExpression. 117 | If None, no capability filtering is applied. 118 | model_names: List of model names to match. Can be a single string or a list of strings. 119 | If None, no name filtering is applied. 120 | 121 | Returns: 122 | List of models matching both the capability expression and model names. 123 | If both capability_expr and model_names are None, returns all models. 124 | 125 | Examples: 126 | >>> registry = ModelRegistry(Path("models_to_test.yaml")) 127 | >>> # Find models that are NOT seq2seq 128 | >>> models = registry.get_matching_models(~ModelCapability.SEQ2SEQ_LM) 129 | 130 | >>> # Find models that are causal LM 131 | >>> models = registry.get_matching_models(ModelCapability.CAUSAL_LM) 132 | 133 | >>> # Find models that are either causal LM or seq2seq 134 | >>> models = registry.get_matching_models( 135 | ... ModelCapability.CAUSAL_LM | ModelCapability.SEQ2SEQ_LM 136 | ... ) 137 | 138 | >>> # Find models that are both causal LM and seq2seq (empty list) 139 | >>> models = registry.get_matching_models( 140 | ... ModelCapability.CAUSAL_LM & ModelCapability.SEQ2SEQ_LM 141 | ... ) 142 | 143 | >>> # Get specific models by name 144 | >>> models = registry.get_matching_models(model_names=["tiny-llama", "tiny-t5"]) 145 | 146 | >>> # Combine capability and name filters 147 | >>> models = registry.get_matching_models( 148 | ... 
capability_expr=ModelCapability.SEQ2SEQ_LM, 149 | ... model_names=["tiny-t5"] 150 | ... ) 151 | 152 | >>> # Return tiny-llama-3.1 153 | >>> models = registry.get_matching_models(model_names="tiny-llama-3.1") 154 | 155 | >>> # Get all models 156 | >>> models = registry.get_matching_models() 157 | """ 158 | if capability_expr is None and model_names is None: 159 | return list(self._models.values()) 160 | 161 | if isinstance(capability_expr, ModelCapability): 162 | capability_expr = capability_expr.to_expression() 163 | 164 | if isinstance(model_names, str): 165 | model_names = [model_names] 166 | 167 | matches = [] 168 | for spec in self._models.values(): 169 | if capability_expr is not None: 170 | if not capability_expr.evaluate(spec.capabilities): 171 | continue 172 | 173 | if model_names is not None: 174 | if spec.name not in model_names: 175 | continue 176 | 177 | matches.append(spec) 178 | 179 | return matches 180 | -------------------------------------------------------------------------------- /tests/task_processors/test_translation_processors.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | from datasets import Dataset, DatasetDict 5 | 6 | from cycleformers.task_processors.translation import TranslationProcessor, TranslationProcessorConfig 7 | 8 | 9 | SAMPLES_DIR = Path(__file__).parent.parent / "sample_data" 10 | 11 | 12 | class TestTranslationProcessor: 13 | def test_preprocess_wmt_format(self, monkeypatch, temp_dir): 14 | """Test preprocessing of WMT-style datasets with nested translation dictionaries.""" 15 | monkeypatch.setenv("HF_DATASETS_CACHE", str(temp_dir)) 16 | 17 | config = TranslationProcessorConfig( 18 | dataset_name=SAMPLES_DIR / "wmt14", 19 | dataset_config_name="de-en", # Required for WMT14 20 | dataset_seed=42, # Fixed seed for reproducible tests 21 | cache_dir=str(temp_dir), 22 | ) 23 | processor = TranslationProcessor(config) 24 | dataset_A, 
dataset_B = processor.process() 25 | 26 | # Check that all splits are preserved 27 | assert set(dataset_A.keys()) == {"train", "validation", "test"} 28 | assert set(dataset_B.keys()) == {"train", "validation", "test"} 29 | 30 | # Check that training data is non-parallel (no labels) 31 | assert "label" not in dataset_A["train"].column_names 32 | assert "label" not in dataset_B["train"].column_names 33 | 34 | # Verify training splits are properly shuffled and unaligned 35 | train_texts_A = dataset_A["train"]["text"] 36 | train_texts_B = dataset_B["train"]["text"] 37 | 38 | # Check no exact alignment exists 39 | assert not any( 40 | all(a == b for a, b in zip(train_texts_A, train_texts_B[i:] + train_texts_B[:i])) 41 | for i in range(len(train_texts_B)) 42 | ), "Training splits appear to be aligned with some offset" 43 | 44 | # Verify that shuffling is deterministic with the same seed 45 | config2 = TranslationProcessorConfig( 46 | dataset_name=SAMPLES_DIR / "wmt14", 47 | dataset_config_name="de-en", 48 | dataset_seed=42, 49 | cache_dir=str(temp_dir), 50 | ) 51 | processor2 = TranslationProcessor(config2) 52 | dataset_A2, dataset_B2 = processor2.process() 53 | 54 | assert list(dataset_A["train"]["text"]) == list(dataset_A2["train"]["text"]) 55 | assert list(dataset_B["train"]["text"]) == list(dataset_B2["train"]["text"]) 56 | 57 | # Check that evaluation splits maintain parallel data 58 | for key in ["validation", "test"]: 59 | assert dataset_A[key]["text"] == dataset_B[key]["label"] 60 | assert dataset_A[key]["label"] == dataset_B[key]["text"] 61 | assert dataset_A[key]["text"] != dataset_B[key]["text"] 62 | assert dataset_A[key]["label"] != dataset_B[key]["label"] 63 | 64 | def test_preprocess_flat_format(self): 65 | """Test preprocessing of datasets with flat source/target columns.""" 66 | data = { 67 | "train": { 68 | "source": ["Hello", "World", "Good", "Morning"], 69 | "target": ["Hallo", "Welt", "Gut", "Morgen"], 70 | }, 71 | "test": { 72 | "source": ["Test"], 
73 | "target": ["Test"], 74 | }, 75 | } 76 | dataset = DatasetDict({split: Dataset.from_dict(data) for split, data in data.items()}) 77 | 78 | config = TranslationProcessorConfig( 79 | dataset_name=SAMPLES_DIR / "wmt14", 80 | source_column="source", 81 | target_column="target", 82 | dataset_seed=42, # Fixed seed for reproducible tests 83 | ) 84 | processor = TranslationProcessor(config) 85 | dataset_A, dataset_B = processor.preprocess(dataset) 86 | 87 | # Check basic structure 88 | assert len(dataset_A.keys()) == len(dataset.keys()) 89 | assert len(dataset_B.keys()) == len(dataset.keys()) 90 | 91 | # Verify training splits are properly shuffled and unaligned 92 | train_texts_A = dataset_A["train"]["text"] 93 | train_texts_B = dataset_B["train"]["text"] 94 | 95 | # Check no exact alignment exists 96 | assert not any( 97 | all(a == b for a, b in zip(train_texts_A, train_texts_B[i:] + train_texts_B[:i])) 98 | for i in range(len(train_texts_B)) 99 | ), "Training splits appear to be aligned with some offset" 100 | 101 | # Verify that shuffling is deterministic with the same seed 102 | config2 = TranslationProcessorConfig( 103 | source_column="source", 104 | target_column="target", 105 | dataset_seed=42, 106 | ) 107 | processor2 = TranslationProcessor(config2) 108 | dataset_A2, dataset_B2 = processor2.preprocess(dataset) 109 | 110 | assert list(dataset_A["train"]["text"]) == list(dataset_A2["train"]["text"]) 111 | assert list(dataset_B["train"]["text"]) == list(dataset_B2["train"]["text"]) 112 | 113 | # Check evaluation data is parallel 114 | assert dataset_A["test"][0]["text"] == "Test" 115 | assert dataset_A["test"][0]["label"] == "Test" 116 | 117 | def test_custom_preprocessing(self): 118 | """Test preprocessing with a custom preprocessing function.""" 119 | data = { 120 | "train": { 121 | "text": [ 122 | "en: Hello || de: Hallo", 123 | "en: World || de: Welt", 124 | "en: Good || de: Gut", 125 | "en: Morning || de: Morgen", 126 | ], 127 | }, 128 | "test": { 129 | 
"text": ["en: Test || de: Test"], 130 | }, 131 | } 132 | dataset = DatasetDict({split: Dataset.from_dict(data) for split, data in data.items()}) 133 | 134 | def custom_preprocessor(example): 135 | en, de = example["text"].split("||") 136 | return { 137 | "source": en.split(": ")[1].strip(), 138 | "target": de.split(": ")[1].strip(), 139 | } 140 | 141 | config = TranslationProcessorConfig( 142 | dataset_name=SAMPLES_DIR / "wmt14", 143 | preprocessing_fn=custom_preprocessor, 144 | dataset_seed=42, # Fixed seed for reproducible tests 145 | ) 146 | processor = TranslationProcessor(config) 147 | dataset_A, dataset_B = processor.preprocess(dataset) 148 | 149 | # Verify training splits are properly shuffled and unaligned 150 | train_texts_A = dataset_A["train"]["text"] 151 | train_texts_B = dataset_B["train"]["text"] 152 | 153 | # Check no exact alignment exists 154 | assert not any( 155 | all(a == b for a, b in zip(train_texts_A, train_texts_B[i:] + train_texts_B[:i])) 156 | for i in range(len(train_texts_B)) 157 | ), "Training splits appear to be aligned with some offset" 158 | 159 | # Verify that shuffling is deterministic with the same seed 160 | config2 = TranslationProcessorConfig( 161 | preprocessing_fn=custom_preprocessor, 162 | dataset_seed=42, 163 | ) 164 | processor2 = TranslationProcessor(config2) 165 | dataset_A2, dataset_B2 = processor2.preprocess(dataset) 166 | 167 | assert list(dataset_A["train"]["text"]) == list(dataset_A2["train"]["text"]) 168 | assert list(dataset_B["train"]["text"]) == list(dataset_B2["train"]["text"]) 169 | 170 | # Check evaluation data is parallel 171 | assert dataset_A["test"][0]["text"] == "Test" 172 | assert dataset_A["test"][0]["label"] == "Test" 173 | -------------------------------------------------------------------------------- /src/cycleformers/task_processors/ner.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | 
from datasets import DatasetDict 5 | 6 | from .base import BaseProcessor, ProcessorConfig 7 | 8 | 9 | tag_to_idx = {"O": 0, "B-PER": 1, "I-PER": 2, "B-ORG": 3, "I-ORG": 4, "B-LOC": 5, "I-LOC": 6, "B-MISC": 7, "I-MISC": 8} 10 | idx_to_tag = {idx: tag for tag, idx in tag_to_idx.items()} 11 | tag_to_string = {"PER": "person", "ORG": "organisation", "LOC": "location", "MISC": "miscellaneous"} 12 | string_to_tag = {string: tag for tag, string in tag_to_string.items()} 13 | 14 | 15 | @dataclass 16 | class CONLL2003ProcessorConfig(ProcessorConfig): 17 | """Configuration class for CONLL2003 dataset processor. 18 | 19 | This class extends the base ProcessorConfig with CONLL2003-specific parameters. 20 | 21 | Args: 22 | dataset_name (str): HuggingFace dataset name/path. Defaults to "eriktks/conll2003". 23 | preset (Literal["entity_seqs"] | None): Processing preset to use. Currently only supports "entity_seqs" 24 | which extracts sequences containing named entities. Defaults to "entity_seqs". 25 | sep_token (str): Token used to separate elements in processed sequences. Defaults to "|". 26 | 27 | Example: 28 | >>> config = CONLL2003ProcessorConfig( 29 | ... dataset_name="conll2003", 30 | ... preset="entity_seqs", 31 | ... sep_token="|" 32 | ... ) 33 | >>> processor = CONLL2003Processor(config) 34 | """ 35 | 36 | dataset_name: str = "eriktks/conll2003" 37 | preset: Literal["entity_seqs"] | None = "entity_seqs" 38 | # trust_remote_code: bool = False 39 | sep_token: str = "|" 40 | 41 | 42 | def reconstruct_sentence(tokens: list[str]) -> str: 43 | """Reconstructs a sentence from CoNLL-2003 tokens while handling whitespace correctly. 44 | 45 | This function takes a list of tokens from the CoNLL-2003 format and reconstructs them into a properly formatted sentence 46 | by applying standard English language spacing rules: 47 | 48 | - No spaces before punctuation marks (,.!?:;) 49 | - Spaces between regular words 50 | - Proper handling of contractions (don't, I'm, etc.) 
51 | - Correct spacing around quotation marks 52 | 53 | Args: 54 | tokens (list[str]): List of tokens from the CoNLL-2003 dataset 55 | 56 | Returns: 57 | str: The reconstructed sentence with proper spacing and punctuation 58 | 59 | Example: 60 | >>> tokens = ["I", "don", "'t", "like", "New", "York", "."] 61 | >>> reconstruct_sentence(tokens) 62 | "I don't like New York." 63 | """ 64 | if not tokens: 65 | return "" 66 | 67 | # Define punctuation categories 68 | closing_puncts = set(",.!?:;") 69 | opening_quotes = set('"' '"') 70 | closing_quotes = set('"' '"') 71 | contractions = set("'s 're 've 'm 't 'll 'd".split()) 72 | 73 | result = [] 74 | for i, token in enumerate(tokens): 75 | if i == 0: 76 | result.append(token) 77 | continue 78 | 79 | prev_token = tokens[i - 1] 80 | 81 | needs_space = not ( 82 | token in closing_puncts 83 | or token.lower() in contractions 84 | or prev_token in opening_quotes 85 | or token in closing_quotes 86 | ) 87 | 88 | if needs_space: 89 | result.append(" " + token) 90 | else: 91 | result.append(token) 92 | 93 | return "".join(result) 94 | 95 | 96 | def ner_to_sequences(tokens: list[str], tags: list[int], sep_token: str) -> str: 97 | """Convert a list of tokens and their corresponding tags to a sequence of entity types. 98 | 99 | This function takes tokens and their NER tags and converts them into a sequence format 100 | suitable for sequence-to-sequence training. It follows the BIO2 tagging scheme where: 101 | - B- prefix indicates beginning of an entity 102 | - I- prefix indicates inside/continuation of an entity 103 | - O indicates outside any entity (not used in output) 104 | 105 | Args: 106 | tokens (list[str]): List of string tokens from a single sentence 107 | tags (list[int]): List of integer tags corresponding to the tokens using BIO2 scheme 108 | sep_token (str): Separator token to use between entity and its type 109 | 110 | Returns: 111 | str: A string containing entities and their types separated by the sep_token. 
112 | For example: "John Smith | person Google | organization" 113 | 114 | Raises: 115 | ValueError: If an I- tag appears without a preceding B- tag (invalid BIO2 sequence) 116 | 117 | Example: 118 | >>> tokens = ["John", "Smith", "works", "at", "Google", "."] 119 | >>> tags = [1, 2, 0, 0, 7, 0] # B-PER, I-PER, O, O, B-ORG, O 120 | >>> ner_to_sequences(tokens, tags, " | ") 121 | 'John Smith | person Google | organization' 122 | """ 123 | compound_tokens = [] 124 | for token, tag in zip(tokens, tags): 125 | if tag in [1, 3, 5, 7]: 126 | tag_string = tag_to_string[idx_to_tag[tag].split("-")[-1]] 127 | compound_tokens.append([token, sep_token, tag_string]) 128 | elif tag in [2, 4, 6, 8]: 129 | if not compound_tokens: 130 | raise ValueError("Missing B-tag before I-tag. Please use BIO2 tagging scheme, not BIO1.") 131 | 132 | compound_tokens[-1].insert(-2, token) 133 | 134 | return f" {sep_token} ".join([" ".join(token_tags) for token_tags in compound_tokens]) 135 | 136 | 137 | class CONLL2003Processor(BaseProcessor): 138 | """Processor for the CONLL2003 Named Entity Recognition dataset. 139 | 140 | This processor handles the CONLL2003 dataset which contains text annotated with named entities. 141 | It converts the dataset into two formats: 142 | - Dataset A: Raw text -> Entity sequences 143 | - Dataset B: Entity sequences -> Raw text 144 | 145 | The processor: 146 | 1. Loads the CONLL2003 dataset 147 | 2. Converts the token-level NER annotations into sequence format 148 | 3. Creates two complementary datasets for cycle training 149 | 150 | Args: 151 | config (CONLL2003ProcessorConfig): Configuration object controlling processor behavior. 152 | Includes settings like separator token between entities and their types. 
153 | 154 | Example: 155 | >>> config = CONLL2003ProcessorConfig(sep_token=" | ") 156 | >>> processor = CONLL2003Processor(config) 157 | >>> dataset_A, dataset_B = processor.process() 158 | >>> print(dataset_A["train"][`0`]) 159 | {'text': 'John Smith works at Google.'} 160 | >>> print(dataset_B["train"][`0`]) 161 | {'text': 'John Smith | person Google | organization'} 162 | """ 163 | 164 | def __init__(self, config: CONLL2003ProcessorConfig = CONLL2003ProcessorConfig()): 165 | super().__init__(config) 166 | # Ensure formatting of sep token is correct 167 | self.config: CONLL2003ProcessorConfig = config # type annotation for config 168 | self.sep_token = config.sep_token.strip() 169 | 170 | def preprocess(self, dataset: DatasetDict) -> tuple[DatasetDict, DatasetDict]: 171 | original_cols = dataset["train"].column_names 172 | dataset = dataset.map( 173 | lambda x: { 174 | "sentence": reconstruct_sentence(x["tokens"]), 175 | "entity_seq": ner_to_sequences(x["tokens"], x["ner_tags"], self.sep_token), 176 | } 177 | ).remove_columns(original_cols) 178 | 179 | dataset_A = dataset.map(lambda x: {"text": x["sentence"], "label": x["entity_seq"]}) 180 | dataset_B = dataset.map(lambda x: {"text": x["entity_seq"], "label": x["sentence"]}) 181 | 182 | dataset_A["train"] = dataset_A["train"].remove_columns(["label"]) 183 | dataset_B["train"] = dataset_B["train"].remove_columns(["label"]) 184 | 185 | return dataset_A, dataset_B 186 | -------------------------------------------------------------------------------- /tests/testing_utils/test_model_registry.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import yaml 3 | from transformers import AutoConfig, PretrainedConfig 4 | 5 | from tests.testing_utils.model_registry import ( 6 | CapabilityExpression, 7 | ModelCapability, 8 | ModelRegistry, 9 | infer_capabilities_from_config, 10 | ) 11 | 12 | 13 | @pytest.fixture 14 | def sample_registry_file(temp_dir): 15 | registry_path 
= temp_dir / "Verify_models.yaml" 16 | registry_content = { 17 | "tiny-llama": { 18 | "repo_id": "fake-repo/tiny-llama", 19 | }, 20 | "tiny-t5": { 21 | "repo_id": "fake-repo/tiny-t5", 22 | }, 23 | } 24 | 25 | with registry_path.open("w") as f: 26 | yaml.dump(registry_content, f) 27 | 28 | return registry_path 29 | 30 | 31 | @pytest.fixture 32 | def mock_config_test(monkeypatch): 33 | def mock_from_pretrained(repo_id): 34 | if "llama" in repo_id.lower(): 35 | config = PretrainedConfig() 36 | config.architectures = ["LlamaForCausalLM"] 37 | config.model_type = "llama" 38 | return config 39 | elif "t5" in repo_id.lower(): 40 | config = PretrainedConfig() 41 | config.architectures = ["T5ForConditionalGeneration"] 42 | config.is_encoder_decoder = True 43 | config.model_type = "t5" 44 | return config 45 | raise ValueError(f"Unknown model: {repo_id}") 46 | 47 | monkeypatch.setattr(AutoConfig, "from_pretrained", mock_from_pretrained) 48 | 49 | 50 | @pytest.mark.meta 51 | class TestModelRegistry: 52 | def test_load_registry(self, sample_registry_file, mock_config_test): 53 | registry = ModelRegistry(sample_registry_file) 54 | 55 | assert len(registry._models) == 2 56 | assert "tiny-llama" in registry._models 57 | assert "tiny-t5" in registry._models 58 | 59 | llama = registry._models["tiny-llama"] 60 | assert llama.name == "tiny-llama" 61 | assert llama.repo_id == "fake-repo/tiny-llama" 62 | assert ModelCapability.CAUSAL_LM in llama.capabilities 63 | assert ModelCapability.SEQ2SEQ_LM not in llama.capabilities 64 | 65 | t5 = registry._models["tiny-t5"] 66 | assert t5.name == "tiny-t5" 67 | assert t5.repo_id == "fake-repo/tiny-t5" 68 | assert ModelCapability.SEQ2SEQ_LM in t5.capabilities 69 | 70 | @pytest.mark.parametrize( 71 | "capability,model_names,expected_models", 72 | [ 73 | (ModelCapability.CAUSAL_LM, None, ["tiny-llama"]), 74 | (ModelCapability.SEQ2SEQ_LM, None, ["tiny-t5"]), 75 | (~ModelCapability.SEQ2SEQ_LM, None, ["tiny-llama"]), 76 | (ModelCapability.CAUSAL_LM | 
ModelCapability.SEQ2SEQ_LM, None, ["tiny-llama", "tiny-t5"]), 77 | (ModelCapability.CAUSAL_LM & ModelCapability.SEQ2SEQ_LM, None, []), 78 | (None, ["tiny-llama"], ["tiny-llama"]), 79 | (None, ["tiny-llama", "tiny-t5"], ["tiny-llama", "tiny-t5"]), 80 | (ModelCapability.CAUSAL_LM, "tiny-llama", ["tiny-llama"]), 81 | (ModelCapability.CAUSAL_LM, ["tiny-t5"], []), 82 | ], 83 | ) 84 | def test_get_matching_models( 85 | self, sample_registry_file, mock_config_test, capability, model_names, expected_models 86 | ): 87 | registry = ModelRegistry(sample_registry_file) 88 | matches = registry.get_matching_models(capability, model_names=model_names) 89 | assert sorted([m.name for m in matches]) == sorted(expected_models) 90 | 91 | 92 | @pytest.mark.meta 93 | class TestCapabilityExpression: 94 | def test_simple_condition(self): 95 | expr = CapabilityExpression(lambda caps: ModelCapability.CAUSAL_LM in caps) 96 | assert expr.evaluate({ModelCapability.CAUSAL_LM}) 97 | assert not expr.evaluate({ModelCapability.SEQ2SEQ_LM}) 98 | 99 | def test_and_operator(self): 100 | expr1 = CapabilityExpression(lambda caps: ModelCapability.CAUSAL_LM in caps) 101 | expr2 = CapabilityExpression(lambda caps: ModelCapability.SEQ2SEQ_LM in caps) 102 | combined = expr1 & expr2 103 | 104 | assert not combined.evaluate({ModelCapability.CAUSAL_LM}) 105 | assert not combined.evaluate({ModelCapability.SEQ2SEQ_LM}) 106 | assert combined.evaluate({ModelCapability.CAUSAL_LM, ModelCapability.SEQ2SEQ_LM}) 107 | 108 | def test_or_operator(self): 109 | expr1 = CapabilityExpression(lambda caps: ModelCapability.CAUSAL_LM in caps) 110 | expr2 = CapabilityExpression(lambda caps: ModelCapability.SEQ2SEQ_LM in caps) 111 | combined = expr1 | expr2 112 | 113 | assert combined.evaluate({ModelCapability.CAUSAL_LM}) 114 | assert combined.evaluate({ModelCapability.SEQ2SEQ_LM}) 115 | assert combined.evaluate({ModelCapability.CAUSAL_LM, ModelCapability.SEQ2SEQ_LM}) 116 | assert not combined.evaluate(set()) 117 | 118 | def 
test_not_operator(self): 119 | expr = CapabilityExpression(lambda caps: ModelCapability.CAUSAL_LM in caps) 120 | inverted = ~expr 121 | 122 | assert not inverted.evaluate({ModelCapability.CAUSAL_LM}) 123 | assert inverted.evaluate({ModelCapability.SEQ2SEQ_LM}) 124 | assert inverted.evaluate(set()) 125 | 126 | 127 | @pytest.mark.meta 128 | class TestModelCapability: 129 | def test_from_str(self): 130 | assert ModelCapability.from_str("CAUSAL_LM") == ModelCapability.CAUSAL_LM 131 | assert ModelCapability.from_str("SEQ2SEQ_LM") == ModelCapability.SEQ2SEQ_LM 132 | 133 | with pytest.raises(KeyError): 134 | ModelCapability.from_str("INVALID") 135 | 136 | def test_to_expression(self): 137 | expr = ModelCapability.CAUSAL_LM.to_expression() 138 | assert isinstance(expr, CapabilityExpression) 139 | assert expr.evaluate({ModelCapability.CAUSAL_LM}) 140 | assert not expr.evaluate({ModelCapability.SEQ2SEQ_LM}) 141 | 142 | def test_capability_and(self): 143 | expr = ModelCapability.CAUSAL_LM & ModelCapability.SEQ2SEQ_LM 144 | assert isinstance(expr, CapabilityExpression) 145 | assert not expr.evaluate({ModelCapability.CAUSAL_LM}) 146 | assert not expr.evaluate({ModelCapability.SEQ2SEQ_LM}) 147 | assert expr.evaluate({ModelCapability.CAUSAL_LM, ModelCapability.SEQ2SEQ_LM}) 148 | 149 | def test_capability_or(self): 150 | expr = ModelCapability.CAUSAL_LM | ModelCapability.SEQ2SEQ_LM 151 | assert isinstance(expr, CapabilityExpression) 152 | assert expr.evaluate({ModelCapability.CAUSAL_LM}) 153 | assert expr.evaluate({ModelCapability.SEQ2SEQ_LM}) 154 | assert expr.evaluate({ModelCapability.CAUSAL_LM, ModelCapability.SEQ2SEQ_LM}) 155 | 156 | def test_capability_not(self): 157 | expr = ~ModelCapability.CAUSAL_LM 158 | assert isinstance(expr, CapabilityExpression) 159 | assert not expr.evaluate({ModelCapability.CAUSAL_LM}) 160 | assert expr.evaluate({ModelCapability.SEQ2SEQ_LM}) 161 | assert expr.evaluate(set()) 162 | 163 | 164 | @pytest.mark.meta 165 | class TestInferCapabilities: 166 
| @pytest.fixture 167 | def causal_lm_config(self): 168 | config = PretrainedConfig() 169 | config.architectures = ["LlamaForCausalLM"] 170 | config.model_type = "llama" 171 | config.is_encoder_decoder = False 172 | return config 173 | 174 | @pytest.fixture 175 | def seq2seq_config(self): 176 | config = PretrainedConfig() 177 | config.architectures = ["T5ForConditionalGeneration"] 178 | config.model_type = "t5" 179 | config.is_encoder_decoder = True 180 | return config 181 | 182 | def test_infer_causal_lm_from_architecture(self, causal_lm_config): 183 | capabilities = infer_capabilities_from_config(causal_lm_config) 184 | assert ModelCapability.CAUSAL_LM in capabilities 185 | assert ModelCapability.SEQ2SEQ_LM not in capabilities 186 | 187 | def test_infer_seq2seq_from_architecture(self, seq2seq_config): 188 | capabilities = infer_capabilities_from_config(seq2seq_config) 189 | assert ModelCapability.SEQ2SEQ_LM in capabilities 190 | assert ModelCapability.CAUSAL_LM not in capabilities 191 | 192 | def test_empty_config(self): 193 | config = PretrainedConfig() 194 | capabilities = infer_capabilities_from_config(config) 195 | assert len(capabilities) == 0 196 | -------------------------------------------------------------------------------- /tests/command/test_cli_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from dataclasses import dataclass 3 | from typing import Tuple 4 | 5 | import pytest 6 | import yaml 7 | 8 | from cycleformers.command.cli_utils import VALID_TASKS, CfArgumentParser 9 | 10 | 11 | @dataclass 12 | class DummyArgs: 13 | model_name: str 14 | learning_rate: float = 0.001 15 | 16 | 17 | @dataclass 18 | class DummyArgsAB: 19 | A_model: str 20 | B_model: str 21 | A_learning_rate: float = 0.001 22 | B_learning_rate: float = 0.0001 23 | shared_param: str = "default" 24 | 25 | 26 | @dataclass 27 | class DummyArgsC: 28 | C_model: str 29 | C_learning_rate: float = 0.0005 30 | 31 | 32 | class 
TestCFArgumentParser: 33 | def test_init_with_task(self): 34 | parser = CfArgumentParser(task="train") 35 | assert parser.task == "train" 36 | 37 | def test_init_without_task(self): 38 | parser = CfArgumentParser() 39 | assert parser.task is None 40 | 41 | def test_init_with_multiple_dataclasses(self): 42 | parser = CfArgumentParser([DummyArgs, DummyArgsAB]) 43 | assert len(parser.dataclass_types) == 2 44 | assert parser.dataclass_types[0] == DummyArgs 45 | assert parser.dataclass_types[1] == DummyArgsAB 46 | 47 | def test_init_with_single_dataclass_not_in_list(self): 48 | parser = CfArgumentParser(DummyArgs) 49 | assert len(parser.dataclass_types) == 1 50 | assert parser.dataclass_types[0] == DummyArgs 51 | 52 | def test_init_with_no_dataclasses(self): 53 | parser = CfArgumentParser() 54 | assert len(parser.dataclass_types) == 0 55 | 56 | def test_invalid_task_raises_error(self, monkeypatch): 57 | parser = CfArgumentParser() 58 | monkeypatch.setattr(sys, "argv", ["invalid_task"]) 59 | with pytest.raises(ValueError, match="Task must be one of"): 60 | parser.parse_args_and_config() 61 | 62 | def test_no_task_provided_raises_error(self, monkeypatch): 63 | parser = CfArgumentParser() 64 | monkeypatch.setattr(sys, "argv", ["script.py"]) 65 | with pytest.raises(ValueError, match="No task provided"): 66 | parser.parse_args_and_config() 67 | 68 | def test_task_mismatch_raises_error(self, monkeypatch): 69 | parser = CfArgumentParser(task="train") 70 | monkeypatch.setattr(sys, "argv", ["script.py", "train"]) 71 | with pytest.raises(ValueError, match="Task already set"): 72 | parser.parse_args_and_config() 73 | 74 | def test_parse_simple_args(self, monkeypatch): 75 | parser = CfArgumentParser([DummyArgs]) 76 | args = ["train", "--model_name", "bert", "--learning_rate", "0.001"] 77 | monkeypatch.setattr(sys, "argv", ["script.py"] + args) 78 | 79 | (parsed_args,) = parser.parse_args_and_config() 80 | assert isinstance(parsed_args, DummyArgs) 81 | assert parsed_args.model_name 
== "bert" 82 | assert parsed_args.learning_rate == 0.001 83 | 84 | def test_parse_multiple_dataclasses(self, monkeypatch): 85 | parser = CfArgumentParser([DummyArgs, DummyArgsC]) 86 | args = [ 87 | "train", 88 | "--model_name", 89 | "bert", 90 | "--learning_rate", 91 | "0.001", 92 | "--C_model", 93 | "gpt2", 94 | "--C_learning_rate", 95 | "0.0005", 96 | ] 97 | monkeypatch.setattr(sys, "argv", ["script.py"] + args) 98 | 99 | parsed_args = parser.parse_args_and_config() 100 | assert isinstance(parsed_args, Tuple) 101 | assert len(parsed_args) == 2 102 | assert parsed_args[0].model_name == "bert" 103 | assert parsed_args[1].C_model == "gpt2" 104 | 105 | def test_parse_yaml_config(self, tmp_path): 106 | config = { 107 | "A": {"model": "bert-base", "learning_rate": 0.001}, 108 | "B": {"model": "gpt2", "learning_rate": 0.0001}, 109 | "shared_param": "value", 110 | } 111 | 112 | config_path = tmp_path / "config.yaml" 113 | with open(config_path, "w") as f: 114 | yaml.dump(config, f) 115 | 116 | parser = CfArgumentParser([DummyArgsAB]) 117 | file_args = parser._parse_yaml_config(str(config_path)) 118 | 119 | # Convert file_args list to dict for easier assertion 120 | args_dict = dict(zip(file_args[::2], file_args[1::2])) 121 | 122 | assert args_dict["--A_model"] == "bert-base" 123 | assert args_dict["--B_model"] == "gpt2" 124 | assert args_dict["--A_learning_rate"] == "0.001" 125 | assert args_dict["--B_learning_rate"] == "0.0001" 126 | assert args_dict["--shared_param"] == "value" 127 | 128 | def test_yaml_with_cli_override(self, monkeypatch, tmp_path): 129 | config = {"A": {"model": "bert-base"}, "B": {"model": "gpt2"}} 130 | 131 | config_path = tmp_path / "config.yaml" 132 | with open(config_path, "w") as f: 133 | yaml.dump(config, f) 134 | 135 | parser = CfArgumentParser([DummyArgsAB]) 136 | args = [ 137 | "train", 138 | str(config_path), 139 | "--A_learning_rate", 140 | "0.001", 141 | "--B_learning_rate", 142 | "0.0001", 143 | "--shared_param", 144 | "override", 145 | 
] 146 | monkeypatch.setattr(sys, "argv", ["script.py"] + args) 147 | 148 | (parsed_args,) = parser.parse_args_and_config() 149 | assert parsed_args.A_model == "bert-base" 150 | assert parsed_args.B_model == "gpt2" 151 | assert parsed_args.A_learning_rate == 0.001 152 | assert parsed_args.B_learning_rate == 0.0001 153 | assert parsed_args.shared_param == "override" 154 | 155 | def test_yaml_with_defaults(self, monkeypatch, tmp_path): 156 | config = {"A": {"model": "bert-base"}, "B": {"model": "gpt2"}} 157 | 158 | config_path = tmp_path / "config.yaml" 159 | with open(config_path, "w") as f: 160 | yaml.dump(config, f) 161 | 162 | parser = CfArgumentParser([DummyArgsAB]) 163 | args = ["train", str(config_path)] 164 | monkeypatch.setattr(sys, "argv", ["script.py"] + args) 165 | 166 | (parsed_args,) = parser.parse_args_and_config() 167 | assert parsed_args.A_model == "bert-base" 168 | assert parsed_args.B_model == "gpt2" 169 | assert parsed_args.A_learning_rate == 0.001 # default value 170 | assert parsed_args.B_learning_rate == 0.0001 # default value 171 | assert parsed_args.shared_param == "default" # default value 172 | 173 | def test_duplicate_args_in_yaml(self, tmp_path): 174 | config = { 175 | "A": {"param": "value1"}, 176 | "A_param": "value2", # This creates a duplicate A_param 177 | } 178 | 179 | config_path = tmp_path / "config.yaml" 180 | with open(config_path, "w") as f: 181 | yaml.dump(config, f) 182 | 183 | parser = CfArgumentParser() 184 | with pytest.raises(ValueError, match="Duplicate argument"): 185 | parser._parse_yaml_config(str(config_path)) 186 | 187 | def test_yaml_with_invalid_param(self, monkeypatch, tmp_path): 188 | config = {"invalid_param": "value"} 189 | 190 | config_path = tmp_path / "config.yaml" 191 | with open(config_path, "w") as f: 192 | yaml.dump(config, f) 193 | 194 | parser = CfArgumentParser([DummyArgs]) 195 | args = ["train", str(config_path)] 196 | monkeypatch.setattr(sys, "argv", ["script.py"] + args) 197 | 198 | with 
pytest.raises(SystemExit): # Printing helpstring 199 | parser.parse_args_and_config() 200 | 201 | def test_cli_only_with_defaults(self, monkeypatch): 202 | parser = CfArgumentParser([DummyArgsAB]) 203 | args = ["train", "--A_model", "bert", "--B_model", "gpt2"] 204 | monkeypatch.setattr(sys, "argv", ["script.py"] + args) 205 | 206 | (parsed_args,) = parser.parse_args_and_config() 207 | assert parsed_args.A_model == "bert" 208 | assert parsed_args.B_model == "gpt2" 209 | assert parsed_args.A_learning_rate == 0.001 # default value 210 | assert parsed_args.B_learning_rate == 0.0001 # default value 211 | assert parsed_args.shared_param == "default" # default value 212 | 213 | def test_task_only_task(self, monkeypatch): 214 | parser = CfArgumentParser([DummyArgsAB]) 215 | args = ["train"] 216 | monkeypatch.setattr(sys, "argv", ["script.py"] + args) 217 | 218 | with pytest.raises(SystemExit): 219 | parser.parse_args_and_config() 220 | 221 | def test_no_task_when_preset(self, monkeypatch): 222 | parser = CfArgumentParser([DummyArgs], task="train") 223 | args = ["--model_name", "bert"] 224 | monkeypatch.setattr(sys, "argv", ["script.py"] + args) 225 | 226 | (parsed_args,) = parser.parse_args_and_config() 227 | assert parsed_args.model_name == "bert" 228 | assert parsed_args.learning_rate == 0.001 # default value 229 | -------------------------------------------------------------------------------- /tests/trainer/test_prepare_cycle_inputs.py: -------------------------------------------------------------------------------- 1 | from types import MethodType 2 | from unittest.mock import Mock 3 | 4 | import pytest 5 | import torch 6 | from accelerate import Accelerator 7 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer 8 | 9 | from cycleformers import DEFAULT_SEP_SEQ, CycleTrainer 10 | from cycleformers.cycles import _default_prepare_cycle_inputs, _prepare_causal_skip_cycle_inputs 11 | 12 | 13 | AVAILABLE_DEVICES = ["cpu"] 14 | if 
torch.cuda.is_available(): 15 | AVAILABLE_DEVICES.append("cuda") 16 | 17 | 18 | @pytest.mark.parametrize("device", AVAILABLE_DEVICES) 19 | class BaseTestCycleInputs: 20 | @pytest.fixture(name="cycle_trainer") 21 | def fixture_cycle_trainer(self, device): 22 | trainer = Mock(spec=CycleTrainer) 23 | accelerator = Mock(spec=Accelerator) 24 | accelerator.device = device 25 | 26 | trainer._prepare_cycle_inputs = CycleTrainer._prepare_cycle_inputs.__get__(trainer) 27 | trainer.sep_seq = DEFAULT_SEP_SEQ 28 | trainer.accelerator = accelerator 29 | return trainer 30 | 31 | @pytest.fixture(name="tokenizer") 32 | def fixture_tokenizer(self, model_name): 33 | tokenizer = AutoTokenizer.from_pretrained(model_name) 34 | if tokenizer.pad_token_id is None: 35 | tokenizer.pad_token_id = tokenizer.eos_token_id 36 | return tokenizer 37 | 38 | # TODO: Replace these with better test examples 39 | @pytest.fixture(name="real_seq") 40 | def fixture_real_sequence(self): 41 | """Intentinally jagged to force padding (x10, x6)""" 42 | self.real_fst = 16 43 | self.real_snd = 6 44 | self.real_seq = ["A" * self.real_fst, "A" * self.real_snd] 45 | return self.real_seq 46 | 47 | @pytest.fixture(name="synth_seq") 48 | def fixture_synth_sequence(self): 49 | """Intentinally jagged to force padding (x8, x10)""" 50 | self.synth_fst = 8 51 | self.synth_snd = 20 52 | self.synth_seq = ["B" * self.synth_fst, "B" * self.synth_snd] 53 | return self.synth_seq 54 | 55 | 56 | @pytest.mark.parametrize( 57 | "model_name", 58 | [ 59 | "trl-internal-testing/tiny-LlamaForCausalLM-3.1", # Llama tokenizer 60 | "Qwen/Qwen2.5-0.5B", 61 | # "gpt2", # GPT2 tokenizer 62 | # "facebook/opt-125m", # OPT tokenizer (GPT based) 63 | # "EleutherAI/pythia-70m", # Pythia/GPT-NeoX tokenizer (GPT based) 64 | ], 65 | ) 66 | class TestPrepareCausalCycleInputs(BaseTestCycleInputs): 67 | @pytest.fixture(name="model") 68 | def fixture_model(self, model_name): 69 | return AutoModelForCausalLM.from_pretrained(model_name) 70 | 71 | 
@pytest.fixture(name="causal_sample") 72 | def fixture_prepare_causal_data(self, tokenizer, real_seq, synth_seq, sep_seq=DEFAULT_SEP_SEQ): 73 | """Prepare input sequences and expected outputs for causal LM testing.""" 74 | tokenizer_kwargs = { 75 | "return_tensors": "pt", 76 | "padding": True, 77 | "truncation": True, 78 | "max_length": 512, 79 | } 80 | 81 | real_seq_with_sep = [seq + sep_seq for seq in real_seq] 82 | synth_seq_with_sep = [seq + sep_seq for seq in synth_seq] 83 | 84 | # The prompt that would be given to the generation model 85 | real_prompt_ids = tokenizer(real_seq_with_sep, **tokenizer_kwargs, padding_side="left").input_ids 86 | # PAD BOS R R R SEP 87 | 88 | # The output that would be generated as a continuation of the real input 89 | synth_seqs_eos = [synth + (tokenizer.eos_token or "") for synth in synth_seq] 90 | synth_response_ids = tokenizer( 91 | synth_seqs_eos, 92 | **tokenizer_kwargs, 93 | padding_side="right", 94 | ).input_ids 95 | # BOS S S S EOS PAD 96 | 97 | # Some tokenizers (like GPT2) don't add BOS automatically 98 | if tokenizer.bos_token_id is not None and synth_response_ids[0, 0] == tokenizer.bos_token_id: 99 | synth_response_ids = synth_response_ids[:, 1:] 100 | # S S S EOS PAD 101 | 102 | # The full sequence that would be returned from the causal model 103 | synth_output_ids = torch.cat([real_prompt_ids, synth_response_ids], dim=1) 104 | # PAD BOS R R R SEP S S S EOS PAD 105 | 106 | target_seq = [synth + real + (tokenizer.eos_token or "") for synth, real in zip(synth_seq_with_sep, real_seq)] 107 | targets = tokenizer(target_seq, **tokenizer_kwargs, padding_side="right") 108 | # BOS S S S SEP R R R EOS PAD 109 | 110 | labels = torch.clone(targets.input_ids) 111 | for i, text in enumerate(synth_seq_with_sep): 112 | prompt_ids = tokenizer.encode(text, **tokenizer_kwargs, padding_side=None) 113 | prompt_length = prompt_ids.shape[-1] 114 | labels[i, :prompt_length] = -100 115 | # Take care not to set eos token as -100 116 | 
labels[targets.attention_mask == 0] = -100 117 | 118 | sample_data = { 119 | "real_input_ids": real_prompt_ids, 120 | "synth_input_ids": synth_output_ids, 121 | "input_ids": targets.input_ids, 122 | "attention_mask": targets.attention_mask, 123 | "labels": labels, 124 | } 125 | return sample_data 126 | 127 | @pytest.mark.parametrize( 128 | "prepare_fn", 129 | [ 130 | _default_prepare_cycle_inputs, 131 | _prepare_causal_skip_cycle_inputs, 132 | ], 133 | ) 134 | def test_prepare_cycle_inputs(self, causal_sample, cycle_trainer, model, tokenizer, prepare_fn): 135 | # Copy the real method to our mock 136 | bound_method = MethodType(prepare_fn, cycle_trainer) 137 | setattr(cycle_trainer, "_prepare_cycle_inputs", bound_method) 138 | causal_sample = {k: v.to(cycle_trainer.accelerator.device) for k, v in causal_sample.items()} 139 | 140 | synth_batch = cycle_trainer._prepare_cycle_inputs( 141 | causal_sample["real_input_ids"], 142 | causal_sample["synth_input_ids"], 143 | model, 144 | model, 145 | tokenizer, 146 | tokenizer, 147 | "A", 148 | ) 149 | # Test outputs 150 | assert torch.allclose(synth_batch["input_ids"], causal_sample["input_ids"]) 151 | assert torch.allclose(synth_batch["attention_mask"], causal_sample["attention_mask"]) 152 | assert torch.allclose(synth_batch["labels"], causal_sample["labels"]) 153 | 154 | 155 | @pytest.mark.parametrize( 156 | "model_name", 157 | [ 158 | "google/flan-t5-small", 159 | ], 160 | ) 161 | class TestPrepareSeq2SeqCycleInputs(BaseTestCycleInputs): 162 | @pytest.fixture(name="model") 163 | def fixture_model(self, model_name): 164 | return AutoModelForSeq2SeqLM.from_pretrained(model_name) 165 | 166 | @pytest.fixture(name="seq2seq_sample") 167 | def fixture_seq2seq_data(self, tokenizer, real_seq, synth_seq): 168 | tokenizer_kwargs = { 169 | "return_tensors": "pt", 170 | "padding": True, 171 | "truncation": True, 172 | "max_length": 512, 173 | "padding_side": "right", 174 | } 175 | 176 | # The prompt that would be given to the 
generation model 177 | real_prompt_ids = tokenizer(real_seq, **tokenizer_kwargs) 178 | # As if the model generated the synthetic output 179 | synth_response_ids = tokenizer( 180 | [tokenizer.pad_token + synth for synth in synth_seq], 181 | **tokenizer_kwargs, 182 | ) 183 | targets = tokenizer(synth_seq, text_target=real_seq, **tokenizer_kwargs) 184 | 185 | return { 186 | "real_input_ids": real_prompt_ids.input_ids, 187 | "synth_input_ids": synth_response_ids.input_ids, 188 | "input_ids": targets.input_ids, 189 | "attention_mask": targets.attention_mask, 190 | "labels": targets.labels, 191 | } 192 | 193 | @pytest.mark.parametrize( 194 | "prepare_fn", 195 | [ 196 | _default_prepare_cycle_inputs, 197 | ], 198 | ) 199 | def test_prepare_cycle_inputs(self, seq2seq_sample, cycle_trainer, model, tokenizer, prepare_fn): 200 | # Copy the real method to our mock 201 | bound_method = MethodType(prepare_fn, cycle_trainer) 202 | setattr(cycle_trainer, "_prepare_cycle_inputs", bound_method) 203 | seq2seq_sample = {k: v.to(cycle_trainer.accelerator.device) for k, v in seq2seq_sample.items()} 204 | 205 | synth_batch = cycle_trainer._prepare_cycle_inputs( 206 | seq2seq_sample["real_input_ids"], 207 | seq2seq_sample["synth_input_ids"], 208 | model, 209 | model, 210 | tokenizer, 211 | tokenizer, 212 | "A", 213 | ) 214 | assert torch.allclose(synth_batch["input_ids"], seq2seq_sample["input_ids"]) 215 | assert torch.allclose(synth_batch["attention_mask"], seq2seq_sample["attention_mask"]) 216 | assert torch.allclose(synth_batch["labels"], seq2seq_sample["labels"]) 217 | -------------------------------------------------------------------------------- /src/cycleformers/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import MISSING, fields 3 | from functools import wraps 4 | from typing import get_type_hints 5 | 6 | from peft import LoraConfig 7 | from transformers.hf_argparser import DataClassType 8 | 9 | 
from .import_utils import is_liger_kernel_available 10 | 11 | 12 | if is_liger_kernel_available(): 13 | from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN # type: ignore 14 | 15 | VALID_LIGER_MODELS = MODEL_TYPE_TO_APPLY_LIGER_FN.keys() 16 | else: 17 | VALID_LIGER_MODELS = [] 18 | 19 | DEFAULT_SEP_SEQ = "\n\n" 20 | 21 | 22 | def prefixed_view(base_class: DataClassType, prefix: str): 23 | """Creates a dataclass-like decorator that provides a prefixed view of another dataclass. 24 | When instantiated, returns an instance of the original dataclass with unprefixed attributes. 25 | 26 | This decorator allows you to create a class that acts as a view of another dataclass, 27 | where all attributes are prefixed. When instantiating the prefixed class, it returns 28 | an instance of the original dataclass with the prefixes removed. 29 | 30 | Args: 31 | base_class: The original dataclass to create a view of 32 | prefix: The prefix to add to all attribute names 33 | 34 | Returns: 35 | A decorator function that creates the prefixed view class 36 | 37 | Example: 38 | >>> from dataclasses import dataclass 39 | >>> @dataclass 40 | ... class Config: 41 | ... name: str 42 | ... value: int = 42 43 | ... 44 | >>> @prefixed_view(Config, "test_") 45 | ... class TestConfig: 46 | ... pass 47 | ... 48 | >>> # The new class has prefixed type hints 49 | >>> TestConfig.__annotations__ 50 | {'test_name': , 'test_value': } 51 | >>> 52 | >>> # Creating an instance with prefixed attributes 53 | >>> config = TestConfig(test_name="example", test_value=100) 54 | >>> config 55 | Config(name='example', value=100) 56 | >>> 57 | >>> # Invalid attributes raise TypeError 58 | >>> config = TestConfig(invalid_attr="test") # doctest: +IGNORE_EXCEPTION_DETAIL 59 | Traceback (most recent call last): 60 | ... 
61 | TypeError: Unexpected argument: invalid_attr 62 | >>> 63 | >>> # Default values are preserved 64 | >>> config = TestConfig(test_name="example") 65 | >>> config 66 | Config(name='example', value=42) 67 | """ 68 | 69 | def wrapper(cls): 70 | # Get original class type hints and fields 71 | base_hints = get_type_hints(base_class) 72 | base_fields = {f.name: f for f in fields(base_class)} 73 | 74 | # Create new fields with prefixed names but same types/defaults 75 | new_annotations = {} 76 | new_defaults = {} 77 | 78 | for name, type_hint in base_hints.items(): 79 | prefixed_name = f"{prefix}{name}" 80 | new_annotations[prefixed_name] = type_hint 81 | 82 | # Copy default values if they exist 83 | if name in base_fields: 84 | field = base_fields[name] 85 | if field.default is not MISSING: 86 | new_defaults[prefixed_name] = field.default 87 | if field.default_factory is not MISSING: 88 | new_defaults[prefixed_name] = field.default_factory() 89 | 90 | # Add annotations and default values to the class 91 | cls.__annotations__ = new_annotations 92 | for name, value in new_defaults.items(): 93 | setattr(cls, name, value) 94 | 95 | def __new__(cls, **kwargs): 96 | # Validate input against our prefixed fields 97 | for key in kwargs: 98 | if key not in new_annotations: 99 | raise TypeError(f"Unexpected argument: {key}") 100 | 101 | # Create mapping of unprefixed attributes 102 | unprefixed_kwargs = {key[len(prefix) :]: value for key, value in kwargs.items()} 103 | 104 | # Return instance of original dataclass 105 | return base_class(**unprefixed_kwargs) 106 | 107 | cls.__new__ = staticmethod(__new__) 108 | return cls 109 | 110 | return wrapper 111 | 112 | 113 | def auto_temp_attributes(*attrs_to_cleanup): 114 | """Decorator that automatically manages temporary attributes on a class instance. 115 | 116 | This decorator solves the issue of methods that need to temporarily modify class attributes 117 | that might be needed by other methods. 
It automatically sets attributes based on method 118 | parameters and restores their original values (or removes them) after the method completes. 119 | 120 | Args: 121 | *attrs_to_cleanup: Variable number of attribute names to manage 122 | 123 | Returns: 124 | Callable: Decorated function that handles attribute lifecycle 125 | 126 | Examples: 127 | >>> class MyClass: 128 | ... def __init__(self): 129 | ... self.permanent = "permanent" 130 | ... 131 | ... @auto_temp_attributes("model", "optimizer") 132 | ... def my_method(self, model, optimizer=None): 133 | ... print(f"Using {model} and {optimizer}") 134 | ... print(f"Permanent: {self.permanent}") 135 | >>> 136 | >>> obj = MyClass() 137 | >>> obj.my_method("bert", optimizer="adam") 138 | Using bert and adam 139 | Permanent: permanent 140 | >>> hasattr(obj, "model") # Attribute is cleaned up 141 | False 142 | 143 | Notes: 144 | - Original attribute values are restored after method execution 145 | - Attributes that didn't exist are removed 146 | - Works with both positional and keyword arguments 147 | - Handles exceptions by ensuring cleanup 148 | """ 149 | 150 | def decorator(func): 151 | @wraps(func) 152 | def wrapper(self, *args, **kwargs): 153 | # Store original attribute values 154 | original_values = {} 155 | for attr in attrs_to_cleanup: 156 | if hasattr(self, attr): 157 | original_values[attr] = getattr(self, attr) 158 | 159 | # Get function signature to match positional args to parameter names 160 | sig = inspect.signature(func) 161 | bound_args = sig.bind(self, *args, **kwargs) 162 | bound_args.apply_defaults() 163 | 164 | # Set attributes based on parameters 165 | for attr in attrs_to_cleanup: 166 | # Skip 'self' parameter 167 | if attr in bound_args.arguments and attr != "self": 168 | setattr(self, attr, bound_args.arguments[attr]) 169 | else: 170 | setattr(self, attr, None) 171 | 172 | try: 173 | # Execute the method 174 | result = func(self, *args, **kwargs) 175 | return result 176 | finally: 177 | # 
Clean up attributes 178 | for attr in attrs_to_cleanup: 179 | if attr in original_values: 180 | setattr(self, attr, original_values[attr]) 181 | else: 182 | try: 183 | delattr(self, attr) 184 | except AttributeError: 185 | pass 186 | 187 | return wrapper 188 | 189 | return decorator 190 | 191 | 192 | def get_peft_config(model_config): 193 | """Creates a PEFT LoRA configuration from a model configuration dataclass.""" 194 | if model_config.use_peft is False: 195 | return None 196 | 197 | peft_config = LoraConfig( 198 | task_type=model_config.lora_task_type, 199 | r=model_config.lora_r, 200 | target_modules=model_config.lora_target_modules, 201 | lora_alpha=model_config.lora_alpha, 202 | lora_dropout=model_config.lora_dropout, 203 | bias="none", 204 | use_rslora=model_config.use_rslora, 205 | modules_to_save=model_config.lora_modules_to_save, 206 | ) 207 | 208 | return peft_config 209 | 210 | 211 | def print_trainable_params(model): 212 | """Prints the number of trainable and all parameters in a model.""" 213 | trainable_params = 0 214 | all_param = 0 215 | for _, param in model.named_parameters(): 216 | num_params = param.numel() 217 | # if using DS Zero 3 and the weights are initialized empty 218 | if num_params == 0 and hasattr(param, "ds_numel"): 219 | num_params = param.ds_numel 220 | 221 | # Due to the design of 4bit linear layers from bitsandbytes one needs to multiply 222 | # the number of parameters by 2 to get the correct number of parameters 223 | if param.__class__.__name__ == "Params4bit": 224 | num_params = num_params * 2 225 | 226 | all_param += num_params 227 | if param.requires_grad: 228 | trainable_params += num_params 229 | print( 230 | f"trainable params: {trainable_params:,d} || " 231 | f"all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param}" 232 | ) 233 | -------------------------------------------------------------------------------- /src/cycleformers/trainer_callback.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from typing import TYPE_CHECKING, Any 4 | 5 | from transformers.trainer_callback import DefaultFlowCallback, TrainerCallback, TrainerControl, TrainerState 6 | 7 | 8 | if TYPE_CHECKING: 9 | from transformers.training_args import TrainingArguments 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | @dataclass 16 | class CycleTrainerState(TrainerState): 17 | """Extension of TrainerState to handle cycle-specific states""" 18 | 19 | steps_a: int = 0 # Number of updates for model A 20 | steps_b: int = 0 # Number of updates for model B 21 | loss_a: float = 0.0 # Current loss for model A 22 | loss_b: float = 0.0 # Current loss for model B 23 | 24 | @property 25 | def global_step(self) -> int: 26 | """Global step is the maximum of steps between models""" 27 | return max(self.steps_a, self.steps_b) 28 | 29 | @global_step.setter 30 | def global_step(self, value: int): 31 | """Setting global step is not allowed as it's derived from model steps""" 32 | raise AttributeError("global_step cannot be set directly. 
Update steps_a or steps_b instead.") 33 | 34 | def update_model_step(self, cycle_name: str): 35 | """Update step count for a specific model""" 36 | if cycle_name == "A": 37 | self.steps_a += 1 38 | elif cycle_name == "B": 39 | self.steps_b += 1 40 | else: 41 | raise ValueError(f"Invalid cycle name: {cycle_name}") 42 | 43 | def get_model_step(self, cycle_name: str) -> int: 44 | """Get current step for a specific model""" 45 | if cycle_name == "A": 46 | return self.steps_a 47 | elif cycle_name == "B": 48 | return self.steps_b 49 | else: 50 | raise ValueError(f"Invalid cycle name: {cycle_name}") 51 | 52 | def update_model_loss(self, cycle_name: str, loss: float): 53 | """Update current loss for a specific model""" 54 | if cycle_name == "A": 55 | self.loss_a = loss 56 | elif cycle_name == "B": 57 | self.loss_b = loss 58 | else: 59 | raise ValueError(f"Invalid cycle name: {cycle_name}") 60 | 61 | def get_model_loss(self, cycle_name: str) -> float: 62 | """Get current loss for a specific model""" 63 | if cycle_name == "A": 64 | return self.loss_a 65 | elif cycle_name == "B": 66 | return self.loss_b 67 | else: 68 | raise ValueError(f"Invalid cycle name: {cycle_name}") 69 | 70 | 71 | class CallbackHandler(TrainerCallback): 72 | """Internal class that just calls the list of callbacks in order.""" 73 | 74 | def __init__( 75 | self, 76 | callbacks: list[TrainerCallback], 77 | model: dict[str, Any], 78 | processing_class: dict[str, Any], 79 | optimizer: dict[str, Any], 80 | lr_scheduler: dict[str, Any], 81 | ): 82 | self.callbacks: list[TrainerCallback] = [] 83 | for cb in callbacks: 84 | self.add_callback(cb) 85 | self.model = model 86 | self.processing_class = processing_class 87 | self.optimizer = optimizer 88 | self.lr_scheduler = lr_scheduler 89 | self.train_dataloader = None 90 | self.eval_dataloader = None 91 | 92 | if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks): 93 | logger.warning( 94 | "The Trainer will not work properly if you don't have a 
`DefaultFlowCallback` in its callbacks. You\n" 95 | + "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of" 96 | + "callbacks is\n:" 97 | + self.callback_list 98 | ) 99 | 100 | def add_callback(self, callback: TrainerCallback) -> None: 101 | cb = callback() if isinstance(callback, type) else callback 102 | cb_class = callback if isinstance(callback, type) else callback.__class__ 103 | if cb_class in [c.__class__ for c in self.callbacks]: 104 | logger.warning( 105 | f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current" 106 | + "list of callbacks is\n:" 107 | + self.callback_list 108 | ) 109 | self.callbacks.append(cb) 110 | 111 | def pop_callback(self, callback): 112 | if isinstance(callback, type): 113 | for cb in self.callbacks: 114 | if isinstance(cb, callback): 115 | self.callbacks.remove(cb) 116 | return cb 117 | else: 118 | for cb in self.callbacks: 119 | if cb == callback: 120 | self.callbacks.remove(cb) 121 | return cb 122 | 123 | def remove_callback(self, callback): 124 | if isinstance(callback, type): 125 | for cb in self.callbacks: 126 | if isinstance(cb, callback): 127 | self.callbacks.remove(cb) 128 | return 129 | else: 130 | self.callbacks.remove(callback) 131 | 132 | @property 133 | def callback_list(self): 134 | return "\n".join(cb.__class__.__name__ for cb in self.callbacks) 135 | 136 | def on_init_end(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl): 137 | return self.call_event("on_init_end", args, state, control) 138 | 139 | def on_train_begin(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl): 140 | control.should_training_stop = False 141 | return self.call_event("on_train_begin", args, state, control) 142 | 143 | def on_train_end(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl): 144 | return self.call_event("on_train_end", args, state, control) 145 | 146 | def 
on_epoch_begin(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl): 147 | control.should_epoch_stop = False 148 | return self.call_event("on_epoch_begin", args, state, control) 149 | 150 | def on_epoch_end(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl): 151 | return self.call_event("on_epoch_end", args, state, control) 152 | 153 | def on_step_begin(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl): 154 | control.should_log = False 155 | control.should_evaluate = False 156 | control.should_save = False 157 | return self.call_event("on_step_begin", args, state, control) 158 | 159 | def on_pre_optimizer_step(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl): 160 | return self.call_event("on_pre_optimizer_step", args, state, control) 161 | 162 | def on_optimizer_step(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl): 163 | return self.call_event("on_optimizer_step", args, state, control) 164 | 165 | def on_substep_end(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl): 166 | return self.call_event("on_substep_end", args, state, control) 167 | 168 | def on_step_end(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl): 169 | return self.call_event("on_step_end", args, state, control) 170 | 171 | def on_evaluate(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl, metrics): 172 | control.should_evaluate = False 173 | return self.call_event("on_evaluate", args, state, control, metrics=metrics) 174 | 175 | def on_predict(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl, metrics): 176 | return self.call_event("on_predict", args, state, control, metrics=metrics) 177 | 178 | def on_save(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl): 179 | control.should_save = False 180 | return 
self.call_event("on_save", args, state, control) 181 | 182 | def on_log(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl, logs): 183 | control.should_log = False 184 | return self.call_event("on_log", args, state, control, logs=logs) 185 | 186 | def on_prediction_step(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl): 187 | return self.call_event("on_prediction_step", args, state, control) 188 | 189 | def call_event(self, event, args, state, control, model_key="both", **kwargs): 190 | for callback in self.callbacks: 191 | if model_key == "both": 192 | result = getattr(callback, event)( 193 | args, 194 | state, 195 | control, 196 | model=self.model, 197 | processing_class=self.processing_class, 198 | optimizer=self.optimizer, 199 | lr_scheduler=self.lr_scheduler, 200 | train_dataloader=self.train_dataloader, 201 | eval_dataloader=self.eval_dataloader, 202 | **kwargs, 203 | ) 204 | # A Callback can skip the return of `control` if it doesn't change it. 
205 | if result is not None: 206 | control = result 207 | return control 208 | -------------------------------------------------------------------------------- /tests/benchmark/cycle_trainer_profiling.py: -------------------------------------------------------------------------------- 1 | import cProfile 2 | import logging 3 | import pstats 4 | import sys 5 | import time 6 | from pathlib import Path 7 | 8 | import torch.cuda as cuda 9 | import torch.profiler 10 | import yaml 11 | from memory_profiler import profile as memory_profile 12 | from profiler_utils import record_function_wrapper 13 | from torch.profiler import ProfilerActivity, profile, record_function 14 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer 15 | 16 | from cycleformers import CfArgumentParser, CycleTrainer, CycleTrainingArguments, ModelConfig 17 | from cycleformers.import_utils import is_liger_kernel_available 18 | from cycleformers.model_config import ModelConfigA, ModelConfigB, merge_configs 19 | from cycleformers.task_processors.ner import CONLL2003Processor, CONLL2003ProcessorConfig 20 | from cycleformers.utils import VALID_LIGER_MODELS, get_peft_config, print_trainable_params 21 | 22 | 23 | logger = logging.getLogger(__file__) 24 | 25 | MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT = 1_000_000 26 | 27 | 28 | class ProfilingCycleTrainer(CycleTrainer): 29 | def __init__(self, *args, **kwargs): 30 | super().__init__(*args, **kwargs) 31 | self.profiling_stats = { 32 | "step_times": [], 33 | "gpu_memory": [], 34 | "max_gpu_memory": 0, 35 | } 36 | 37 | # Create unique output directory for this run 38 | self.profile_output_dir = ( 39 | Path(__file__).parent / "profiles" / f'cycle_trainer--{time.strftime("%Y%m%d_%H%M%S")}' 40 | ) 41 | self.profile_output_dir.mkdir(parents=True, exist_ok=True) 42 | 43 | self.profile_path = self.profile_output_dir / "profile.prof" 44 | self.memory_path = self.profile_output_dir / "memory_profile.txt" 45 | self.cuda_memory_path = 
self.profile_output_dir / "cuda_memory_snapshots" 46 | 47 | # Save args to yaml file 48 | with open(self.profile_output_dir / "profiler_args.yaml", "w") as f: 49 | yaml.dump(self.args, f) 50 | 51 | print("=" * 40) 52 | print("Model A: ", end="") 53 | print_trainable_params(self.model_A) 54 | if not self.is_macct_model: 55 | print("Model B: ", end="") 56 | print_trainable_params(self.model_B) 57 | print("=" * 40) 58 | 59 | def _log_gpu_memory(self): 60 | """Log current GPU memory usage""" 61 | if not cuda.is_available(): 62 | return None 63 | 64 | current_memory = cuda.memory_allocated() / 1024**2 # Convert to MB 65 | max_memory = cuda.max_memory_allocated() / 1024**2 # Convert to MB 66 | 67 | self.profiling_stats["gpu_memory"].append(current_memory) 68 | self.profiling_stats["max_gpu_memory"] = max(self.profiling_stats["max_gpu_memory"], max_memory) 69 | 70 | return current_memory, max_memory 71 | 72 | def train(self): 73 | torch.cuda.memory._record_memory_history(max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT) 74 | 75 | # Profile entire training run 76 | profiler = cProfile.Profile() 77 | profiler.enable() 78 | 79 | # Setup PyTorch profiler for detailed GPU analysis 80 | pytorch_profiler = torch.profiler.profile( 81 | activities=[ 82 | torch.profiler.ProfilerActivity.CPU, 83 | torch.profiler.ProfilerActivity.CUDA, 84 | ], 85 | schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), 86 | on_trace_ready=torch.profiler.tensorboard_trace_handler(self.profile_output_dir / "tensorboard_trace"), 87 | record_shapes=True, 88 | profile_memory=True, 89 | with_stack=True, 90 | ) 91 | 92 | pytorch_profiler.start() 93 | 94 | try: 95 | result = super().train() 96 | 97 | # Collect profiling data 98 | profiler.disable() 99 | stats = pstats.Stats(profiler) 100 | stats.sort_stats("cumtime") 101 | 102 | # Save profiling data 103 | stats.dump_stats(str(self.profile_path)) 104 | 105 | # Save memory statistics 106 | with open(self.memory_path, "w") as f: 107 | 
f.write(f"Max GPU Memory Used: {self.profiling_stats['max_gpu_memory']:.2f} MB\n") 108 | f.write("GPU Memory Usage per Step:\n") 109 | for i, mem in enumerate(self.profiling_stats["gpu_memory"]): 110 | f.write(f"Step {i}: {mem:.2f} MB\n") 111 | 112 | return result 113 | finally: 114 | pytorch_profiler.stop() 115 | torch.cuda.memory._record_memory_history(enabled=None) 116 | try: 117 | torch.cuda.memory._dump_snapshot(f"{self.cuda_memory_path}.pickle") 118 | except Exception as e: 119 | logger.error(f"Failed to capture memory snapshot {e}") 120 | 121 | @record_function_wrapper("## Cycle Step ##") 122 | def _cycle_step(self, *args, **kwargs): 123 | return super()._cycle_step(*args, **kwargs) 124 | 125 | @record_function_wrapper("## Prepare Cycle Inputs ##") 126 | def _prepare_cycle_inputs(self, *args, **kwargs): 127 | return super().prepare_cycle_inputs(*args, **kwargs) 128 | 129 | def analyze_performance(self): 130 | """Analyze collected performance metrics""" 131 | if self.profiling_stats.get("step_times"): 132 | avg_step_time = sum(self.profiling_stats["step_times"]) / len(self.profiling_stats["step_times"]) 133 | print(f"Average step time: {avg_step_time:.4f} seconds") 134 | 135 | # GPU Memory Statistics 136 | if cuda.is_available(): 137 | print("\nGPU Memory Statistics:") 138 | print(f"Peak GPU memory usage: {self.profiling_stats['max_gpu_memory']:.2f} MB") 139 | avg_memory = sum(self.profiling_stats["gpu_memory"]) / len(self.profiling_stats["gpu_memory"]) 140 | print(f"Average GPU memory usage: {avg_memory:.2f} MB") 141 | print(f"Current GPU memory usage: {cuda.memory_allocated() / 1024**2:.2f} MB") 142 | print(f"Current GPU memory cached: {cuda.memory_reserved() / 1024**2:.2f} MB") 143 | 144 | # Load and analyze the cProfile data 145 | stats = pstats.Stats(str(self.profile_output_dir / "profile.prof")) 146 | print("\nTop 10 time-consuming functions:") 147 | stats.sort_stats("cumtime").print_stats(10) 148 | 149 | # Memory analysis tips 150 | print(f"\nProfiling 
data has been saved to: {self.profile_output_dir}") 151 | print("Check PyTorch profiler traces in TensorBoard for detailed GPU memory analysis") 152 | 153 | 154 | if is_liger_kernel_available(): 155 | from liger_kernel.transformers import AutoLigerKernelForCausalLM 156 | 157 | 158 | def get_model_and_tokenizer(model_config, training_args): 159 | """Initialize model and tokenizer from config""" 160 | config = AutoConfig.from_pretrained( 161 | model_config.model_name_or_path, 162 | trust_remote_code=model_config.trust_remote_code, 163 | ) 164 | config.use_cache = False 165 | 166 | model_kwargs = {} 167 | 168 | if not config.is_encoder_decoder: 169 | if is_liger_kernel_available() and model_config.use_liger and config.model_type in VALID_LIGER_MODELS: 170 | model_class = AutoLigerKernelForCausalLM 171 | model_kwargs["use_liger_kernel"] = training_args.use_liger_kernel 172 | else: 173 | model_class = AutoModelForCausalLM 174 | else: 175 | model_class = AutoModelForSeq2SeqLM 176 | 177 | model = model_class.from_pretrained( 178 | model_config.model_name_or_path, 179 | revision=model_config.model_revision, 180 | config=config, 181 | trust_remote_code=model_config.trust_remote_code, 182 | attn_implementation=model_config.attn_implementation, 183 | torch_dtype=model_config.torch_dtype, 184 | device_map="auto", 185 | ) 186 | 187 | if training_args.gradient_checkpointing: 188 | model.enable_input_require_grads() 189 | 190 | # Print the actual dtype of the first parameter 191 | print(f"Model weights dtype: {next(model.parameters()).dtype}") 192 | 193 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True) 194 | return model, tokenizer 195 | 196 | 197 | def main(): 198 | sys.argv = [__file__, str(Path(__file__).parent / "profiler_configs/causal.yaml")] 199 | parser = CfArgumentParser( 200 | (CycleTrainingArguments, ModelConfig, ModelConfigA, ModelConfigB, CONLL2003ProcessorConfig), task="train" 201 | ) 202 | args, model_config_base, 
model_config_A, model_config_B, conll_config = parser.parse_args_and_config() 203 | model_config_base = merge_configs(model_config_base, model_config_A, model_config_B) 204 | args.model_config = model_config_base 205 | 206 | task_processor = CONLL2003Processor(conll_config) 207 | dataset_A, dataset_B = task_processor.process() 208 | 209 | model_A, tokenizer_A = get_model_and_tokenizer(args.model_config.A, args) 210 | 211 | # Train by adapter swapping 212 | if not args.use_macct: 213 | # Get model B using merged B config 214 | model_B, tokenizer_B = get_model_and_tokenizer(args.model_config.B, args) 215 | models = {"A": model_A, "B": model_B} 216 | tokenizers = {"A": tokenizer_A, "B": tokenizer_B} if tokenizer_A != tokenizer_B else tokenizer_A 217 | else: 218 | models = model_A 219 | tokenizers = tokenizer_A 220 | 221 | trainer = ProfilingCycleTrainer( 222 | args=args, 223 | models=models, 224 | tokenizers=tokenizers, 225 | train_dataset_A=dataset_A["train"], 226 | train_dataset_B=dataset_B["train"], 227 | eval_dataset_A=dataset_A["eval"] if not args.eval_strategy == "no" else None, 228 | eval_dataset_B=dataset_B["eval"] if not args.eval_strategy == "no" else None, 229 | peft_configs=get_peft_config(model_config_base), 230 | ) 231 | 232 | trainer.train() 233 | 234 | trainer.analyze_performance() 235 | 236 | 237 | if __name__ == "__main__": 238 | main() 239 | --------------------------------------------------------------------------------