├── tests
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── test_utils.py
│   ├── benchmark
│   │   ├── __init__.py
│   │   ├── profiler_configs
│   │   │   ├── causal.yaml
│   │   │   └── causal_macct.yaml
│   │   ├── benchmark_utils.py
│   │   ├── profiler_utils.py
│   │   ├── README.md
│   │   ├── benchmark_tokenizer_skip.py
│   │   └── cycle_trainer_profiling.py
│   ├── command
│   │   ├── __init__.py
│   │   └── test_cli_utils.py
│   ├── configs
│   │   ├── __init__.py
│   │   └── test_model_config.py
│   ├── trainer
│   │   ├── __init__.py
│   │   └── test_prepare_cycle_inputs.py
│   ├── task_processors
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── test_ner_processors.py
│   │   └── test_translation_processors.py
│   ├── testing_utils
│   │   ├── __init__.py
│   │   ├── custom_marks.py
│   │   ├── model_registry.py
│   │   └── test_model_registry.py
│   ├── sample_data
│   │   ├── conll2003
│   │   │   ├── dataset_dict.json
│   │   │   ├── test
│   │   │   │   ├── data-00000-of-00001.arrow
│   │   │   │   ├── state.json
│   │   │   │   └── dataset_info.json
│   │   │   ├── train
│   │   │   │   ├── data-00000-of-00001.arrow
│   │   │   │   ├── state.json
│   │   │   │   └── dataset_info.json
│   │   │   └── validation
│   │   │       ├── data-00000-of-00001.arrow
│   │   │       ├── state.json
│   │   │       └── dataset_info.json
│   │   ├── wmt14
│   │   │   ├── dataset_dict.json
│   │   │   ├── test
│   │   │   │   ├── data-00000-of-00001.arrow
│   │   │   │   ├── state.json
│   │   │   │   └── dataset_info.json
│   │   │   ├── train
│   │   │   │   ├── data-00000-of-00001.arrow
│   │   │   │   ├── state.json
│   │   │   │   └── dataset_info.json
│   │   │   └── validation
│   │   │       ├── data-00000-of-00001.arrow
│   │   │       ├── state.json
│   │   │       └── dataset_info.json
│   │   └── generate_sample_data.py
│   ├── models_to_test.yaml
│   └── test_examples.py
├── docs
│   ├── examples
│   │   ├── index.md
│   │   ├── cycle_ner.md
│   │   └── wmt_2014.md
│   ├── api_reference
│   │   ├── configuration.md
│   │   ├── cycles.md
│   │   ├── task_processors.md
│   │   └── cycle_trainer.md
│   ├── conceptual_reference
│   │   ├── index.md
│   │   ├── cycle_consistency_training.md
│   │   ├── task_processors.md
│   │   └── macct.md
│   ├── assets
│   │   ├── icicle_snakeviz.png
│   │   └── sunburst_snakeviz.png
│   ├── installation.md
│   ├── index.md
│   ├── usage.md
│   └── performance.md
├── trufflehog.yaml
├── src
│   └── cycleformers
│       ├── command
│       │   ├── __init__.py
│       │   └── cli_utils.py
│       ├── cycles
│       │   ├── __init__.py
│       │   └── utils.py
│       ├── data_config.py
│       ├── import_utils.py
│       ├── __init__.py
│       ├── task_processors
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── translation.py
│       │   └── ner.py
│       ├── cycle_trainer_utils.py
│       ├── cycle_training_arguments.py
│       ├── exceptions.py
│       ├── model_config.py
│       ├── utils.py
│       └── trainer_callback.py
├── .github
│   ├── workflows
│   │   ├── trufflehog.yml
│   │   ├── pull-request.yml
│   │   ├── build-docs.yml
│   │   ├── release.yml
│   │   └── checks.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── question.md
│   │   ├── feature_request.md
│   │   └── bug_report.md
│   ├── actions
│   │   └── setup-poetry
│   │       └── action.yml
│   ├── badges
│   │   ├── build.svg
│   │   └── coverage.svg
│   └── PULL_REQUEST_TEMPLATE.md
├── .gitignore
├── examples
│   ├── configs
│   │   ├── seq2seq.yaml
│   │   ├── causal.yaml
│   │   ├── causal_macct.yaml
│   │   └── seq2seq_macct.yaml
│   ├── translation_wmt14
│   │   ├── prepare_data.py
│   │   └── train.py
│   └── cycle_ner
│       └── train.py
├── .pre-commit-config.yaml
├── Makefile
├── mkdocs.yml
├── CONTRIBUTING.md
├── README.md
└── pyproject.toml
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/benchmark/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/command/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/configs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/trainer/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/task_processors/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/testing_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/examples/index.md:
--------------------------------------------------------------------------------
1 | 🚧 Under Construction 🚧
--------------------------------------------------------------------------------
/docs/examples/cycle_ner.md:
--------------------------------------------------------------------------------
1 | 🚧 Under Construction 🚧
--------------------------------------------------------------------------------
/docs/examples/wmt_2014.md:
--------------------------------------------------------------------------------
1 | 🚧 Under Construction 🚧
--------------------------------------------------------------------------------
/docs/api_reference/configuration.md:
--------------------------------------------------------------------------------
1 | 🚧 Under Construction 🚧
--------------------------------------------------------------------------------
/trufflehog.yaml:
--------------------------------------------------------------------------------
1 | detectors:
2 | SentryToken:
3 | disabled: true
4 |
--------------------------------------------------------------------------------
/tests/sample_data/conll2003/dataset_dict.json:
--------------------------------------------------------------------------------
1 | {"splits": ["train", "validation", "test"]}
--------------------------------------------------------------------------------
/tests/sample_data/wmt14/dataset_dict.json:
--------------------------------------------------------------------------------
1 | {"splits": ["train", "validation", "test"]}
--------------------------------------------------------------------------------
/docs/conceptual_reference/index.md:
--------------------------------------------------------------------------------
1 | Fundamental concepts are accessible via the left sidebar navigation.
--------------------------------------------------------------------------------
/tests/task_processors/utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | class TaskProcessorTestMixin:
5 | pass
6 |
--------------------------------------------------------------------------------
/docs/assets/icicle_snakeviz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/docs/assets/icicle_snakeviz.png
--------------------------------------------------------------------------------
/docs/assets/sunburst_snakeviz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/docs/assets/sunburst_snakeviz.png
--------------------------------------------------------------------------------
/src/cycleformers/command/__init__.py:
--------------------------------------------------------------------------------
1 | from .cli_utils import CfArgumentParser
2 |
3 |
4 | __all__ = ["CfArgumentParser"]
5 |
--------------------------------------------------------------------------------
/docs/api_reference/cycles.md:
--------------------------------------------------------------------------------
1 | # Cycle Implementations
2 |
3 | ## Causal to Causal Tokenizer Bypass
4 |
5 | ::: cycleformers.cycles._prepare_causal_skip_cycle_inputs
--------------------------------------------------------------------------------
/tests/sample_data/wmt14/test/data-00000-of-00001.arrow:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/tests/sample_data/wmt14/test/data-00000-of-00001.arrow
--------------------------------------------------------------------------------
/tests/sample_data/wmt14/train/data-00000-of-00001.arrow:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/tests/sample_data/wmt14/train/data-00000-of-00001.arrow
--------------------------------------------------------------------------------
/tests/sample_data/conll2003/test/data-00000-of-00001.arrow:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/tests/sample_data/conll2003/test/data-00000-of-00001.arrow
--------------------------------------------------------------------------------
/tests/sample_data/conll2003/train/data-00000-of-00001.arrow:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/tests/sample_data/conll2003/train/data-00000-of-00001.arrow
--------------------------------------------------------------------------------
/tests/sample_data/wmt14/validation/data-00000-of-00001.arrow:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/tests/sample_data/wmt14/validation/data-00000-of-00001.arrow
--------------------------------------------------------------------------------
/tests/sample_data/conll2003/validation/data-00000-of-00001.arrow:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wrmthorne/cycleformers/HEAD/tests/sample_data/conll2003/validation/data-00000-of-00001.arrow
--------------------------------------------------------------------------------
/docs/api_reference/task_processors.md:
--------------------------------------------------------------------------------
1 | # Task Processors
2 |
3 | ::: cycleformers.task_processors.base.BaseProcessor
4 |
5 |
6 | ## NER
7 |
8 | ### CONLL2003Processor
9 |
10 | ::: cycleformers.task_processors.ner.CONLL2003Processor
11 |
--------------------------------------------------------------------------------
/docs/api_reference/cycle_trainer.md:
--------------------------------------------------------------------------------
1 | # CycleTrainer
2 |
3 | ::: cycleformers.cycle_trainer.CycleTrainer
4 |
5 |
6 | ### Methods
7 |
8 | ::: cycleformers.cycle_trainer.CycleTrainer.train
9 |
10 | ::: cycleformers.cycle_trainer.CycleTrainer._cycle_step
11 |
--------------------------------------------------------------------------------
/src/cycleformers/cycles/__init__.py:
--------------------------------------------------------------------------------
1 | from .cycles import _default_prepare_cycle_inputs, _prepare_causal_skip_cycle_inputs
2 | from .utils import PrepareCycleInputsNotSet
3 |
4 |
5 | __all__ = ["_prepare_causal_skip_cycle_inputs", "_default_prepare_cycle_inputs", "PrepareCycleInputsNotSet"]
6 |
--------------------------------------------------------------------------------
/tests/models_to_test.yaml:
--------------------------------------------------------------------------------
1 | # Causal Models
2 | tiny-llama-3.1:
3 | repo_id: "trl-internal-testing/tiny-LlamaForCausalLM-3.1"
4 |
5 | qwen-2.5:
6 | repo_id: "Qwen/Qwen2.5-0.5B"
7 | description: "Needs to be replaced with a tiny model at some point"
8 |
9 |
10 | # Seq2Seq Models
11 | tiny-t5:
12 | repo_id: "google/t5-efficient-tiny"
13 |
--------------------------------------------------------------------------------
/tests/sample_data/conll2003/test/state.json:
--------------------------------------------------------------------------------
1 | {
2 | "_data_files": [
3 | {
4 | "filename": "data-00000-of-00001.arrow"
5 | }
6 | ],
7 | "_fingerprint": "9b1ccf21d530db94",
8 | "_format_columns": null,
9 | "_format_kwargs": {},
10 | "_format_type": null,
11 | "_output_all_columns": false,
12 | "_split": "test[:20]"
13 | }
--------------------------------------------------------------------------------
/tests/sample_data/wmt14/test/state.json:
--------------------------------------------------------------------------------
1 | {
2 | "_data_files": [
3 | {
4 | "filename": "data-00000-of-00001.arrow"
5 | }
6 | ],
7 | "_fingerprint": "85d5b6c8f838535a",
8 | "_format_columns": null,
9 | "_format_kwargs": {},
10 | "_format_type": null,
11 | "_output_all_columns": false,
12 | "_split": "test[:20]"
13 | }
--------------------------------------------------------------------------------
/tests/sample_data/wmt14/train/state.json:
--------------------------------------------------------------------------------
1 | {
2 | "_data_files": [
3 | {
4 | "filename": "data-00000-of-00001.arrow"
5 | }
6 | ],
7 | "_fingerprint": "f2070ded563ec01c",
8 | "_format_columns": null,
9 | "_format_kwargs": {},
10 | "_format_type": null,
11 | "_output_all_columns": false,
12 | "_split": "train[:100]"
13 | }
--------------------------------------------------------------------------------
/tests/sample_data/conll2003/train/state.json:
--------------------------------------------------------------------------------
1 | {
2 | "_data_files": [
3 | {
4 | "filename": "data-00000-of-00001.arrow"
5 | }
6 | ],
7 | "_fingerprint": "35806e9fdfaf6cd6",
8 | "_format_columns": null,
9 | "_format_kwargs": {},
10 | "_format_type": null,
11 | "_output_all_columns": false,
12 | "_split": "train[:100]"
13 | }
--------------------------------------------------------------------------------
/tests/sample_data/wmt14/validation/state.json:
--------------------------------------------------------------------------------
1 | {
2 | "_data_files": [
3 | {
4 | "filename": "data-00000-of-00001.arrow"
5 | }
6 | ],
7 | "_fingerprint": "60eb8be03cba3fa7",
8 | "_format_columns": null,
9 | "_format_kwargs": {},
10 | "_format_type": null,
11 | "_output_all_columns": false,
12 | "_split": "validation[:20]"
13 | }
--------------------------------------------------------------------------------
/.github/workflows/trufflehog.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 |
4 | name: Secret Leaks
5 |
6 | jobs:
7 | trufflehog:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - name: Checkout code
11 | uses: actions/checkout@v4
12 | with:
13 | fetch-depth: 0
14 | - name: Secret Scanning
15 | uses: trufflesecurity/trufflehog@main
--------------------------------------------------------------------------------
/tests/sample_data/conll2003/validation/state.json:
--------------------------------------------------------------------------------
1 | {
2 | "_data_files": [
3 | {
4 | "filename": "data-00000-of-00001.arrow"
5 | }
6 | ],
7 | "_fingerprint": "229dfaf63b462a8d",
8 | "_format_columns": null,
9 | "_format_kwargs": {},
10 | "_format_type": null,
11 | "_output_all_columns": false,
12 | "_split": "validation[:20]"
13 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 | *.so
5 | .Python
6 | .env
7 | .venv
8 | env/
9 | venv/
10 | ENV/
11 | dist/
12 | build/
13 | *.egg-info/
14 | .coverage
15 | htmlcov/
16 | .pytest_cache/
17 | .vscode/
18 | .idea
19 | site/
20 | wandb/
21 | examples/**/data/
22 | misc/
23 | benchmark_results/
24 | .mypy_cache/
25 | .ruff_cache/
26 | profiles/
27 | *cache*
--------------------------------------------------------------------------------
/docs/installation.md:
--------------------------------------------------------------------------------
1 | Cycleformers is available on PyPI:
2 |
3 | ```bash
4 | pip install cycleformers
5 | ```
6 |
7 | The module has only been tested on Linux. I have no plans to support Windows or macOS at this time.
8 |
9 | ## Development Setup
10 |
11 | To build the module locally for development and bugfixing, I recommend the following command, which initialises a .env file, runs poetry install, and sets up the pre-commit hooks:
12 |
13 | ```bash
14 | make init
15 | ```
--------------------------------------------------------------------------------
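A quick sanity check after installing, assuming the published package matches this source tree (the `__version__` attribute is defined in `src/cycleformers/__init__.py`):

```python
import cycleformers

print(cycleformers.__version__)  # e.g. "0.1.0" at the time of writing
```
--------------------------------------------------------------------------------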
/src/cycleformers/cycles/utils.py:
--------------------------------------------------------------------------------
1 | class PrepareCycleInputsNotSet(Exception):
2 | """Exception raised when class has not properly set"""
3 |
4 | def __init__(self, msg: str | None = None):
5 | self.msg = (
6 | msg
7 | or "_prepare_cycle_inputs is not corrently set. If you are modifying CycleTrainer.__init__ "
8 | "make sure to call self.set_cycle_inputs_fn(), optionally with a valid function."
9 | )
10 | super().__init__(self.msg)
11 |
--------------------------------------------------------------------------------
/examples/configs/seq2seq.yaml:
--------------------------------------------------------------------------------
1 | ### ONLY REQUIRED ARGUMENTS ###
2 | output_dir: "./outputs"
3 | ###############################
4 |
5 | # === Training Arguments === #
6 | per_device_train_batch_size: 2
7 | gradient_accumulation_steps: 8
8 | per_device_eval_batch_size: 2
9 | logging_strategy: steps
10 | logging_steps: 1
11 |
12 |
13 | # === Model Configs === #
14 | model_name_or_path: "google/flan-t5-base"
15 |
16 | # Optimisation
17 | # attn_implementation: "flash_attention_2"
18 | # use_liger_kernel: true
--------------------------------------------------------------------------------
/.github/workflows/pull-request.yml:
--------------------------------------------------------------------------------
1 | name: Pull Request
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - main
7 |
8 | permissions:
9 | pull-requests: read
10 | contents: write
11 |
12 | jobs:
13 | check-pr-title:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - uses: amannn/action-semantic-pull-request@v5
17 | env:
18 | GITHUB_TOKEN: ${{ github.token }}
19 |
20 | checks:
21 | uses: ./.github/workflows/checks.yml
22 |
23 | build-docs:
24 | uses: ./.github/workflows/build-docs.yml
--------------------------------------------------------------------------------
/tests/testing_utils/custom_marks.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from cycleformers.import_utils import is_liger_kernel_available, is_peft_available
4 |
5 |
6 | def requires_peft(func):
7 | """Decorator to skip test if PEFT is not available"""
8 | return pytest.mark.skipif(not is_peft_available(), reason="PEFT not available")(func)
9 |
10 |
11 | def requires_liger_kernel(func):
12 | """Decorator to skip test if Liger kernel is not available"""
13 | return pytest.mark.skipif(not is_liger_kernel_available(), reason="Liger kernel not available")(func)
14 |
--------------------------------------------------------------------------------
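For reference, a minimal sketch of applying these marks in a test module; the test name is hypothetical and the import path assumes pytest is run from the `tests/` root:

```python
from testing_utils.custom_marks import requires_peft  # assumes tests/ is importable


@requires_peft
def test_macct_adapter_training():
    """Runs only when PEFT is installed; skipped with a reason otherwise."""
    ...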
/src/cycleformers/data_config.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | from dataclasses import dataclass
3 |
4 |
5 | @dataclass
6 | class DataConfig:
7 | """
8 | Arguments pertaining to what data we are going to input our model for training and eval.
9 | """
10 |
11 | dataset_name: str | None = None
12 | dataset_config_name: str | None = None
13 | text_column: str = "text"
14 | formatting_func: Callable | None = None
15 | max_seq_length: int | None = None
16 | remove_unused_columns: bool = True
17 | preprocessing_num_workers: int | None = None
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/question.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Question
3 | about: Ask a question about this project
4 | title: '[QUESTION] '
5 | labels: question
6 | assignees: ''
7 | ---
8 |
9 | **Question**
10 | State your question clearly and provide as much context as possible.
11 |
12 | **Context**
13 | Provide any relevant context about what you're trying to achieve:
14 | - What have you already tried?
15 | - What documentation have you already consulted?
16 | - What research have you done?
17 |
18 | **Additional Information**
19 | Add any other relevant information or screenshots about your question here.
--------------------------------------------------------------------------------
/tests/benchmark/profiler_configs/causal.yaml:
--------------------------------------------------------------------------------
1 | ### ONLY REQUIRED ARGUMENTS ###
2 | output_dir: "/tmp/profiler_outputs"
3 | ###############################
4 |
5 | # === Training Arguments === #
6 | per_device_train_batch_size: 2
7 | gradient_accumulation_steps: 8
8 | per_device_eval_batch_size: 2
9 | logging_strategy: steps
10 | logging_steps: 1
11 | report_to: "tensorboard"
12 |
13 |
14 | max_steps: 100
15 |
16 |
17 | # === Model Configs === #
18 | model_name_or_path: "Qwen/Qwen2.5-1.5B"
19 |
20 |
21 | # Optimisation
22 | # attn_implementation: "flash_attention_2"
23 | # use_liger_kernel: true
--------------------------------------------------------------------------------
/examples/configs/causal.yaml:
--------------------------------------------------------------------------------
1 | ### ONLY REQUIRED ARGUMENTS ###
2 | output_dir: "./outputs"
3 | ###############################
4 |
5 | # === Training Arguments === #
6 | per_device_train_batch_size: 1
7 | gradient_accumulation_steps: 16
8 | per_device_eval_batch_size: 1
9 | gradient_checkpointing: true
10 | logging_strategy: steps
11 | logging_steps: 1
12 |
13 |
14 | # === Model Configs === #
15 | A_model_name_or_path: "Qwen/Qwen2.5-1.5B"
16 | A_torch_dtype: "bfloat16"
17 |
18 | B_model_name_or_path: "Qwen/Qwen2.5-1.5B"
19 | B_torch_dtype: "bfloat16"
20 |
21 |
22 | # Optimisation
23 | # attn_implementation: "flash_attention_2"
24 | # use_liger_kernel: true
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_install_hook_types: [ pre-commit, pre-push ]
2 |
3 | repos:
4 | - repo: local
5 | hooks:
6 | - id: make-format
7 | name: Run formatting
8 | entry: make
9 | args: [ format ]
10 | language: system
11 | pass_filenames: false
12 |
13 | - id: make-lint
14 | name: Run linting
15 | entry: make
16 | args: [ lint ]
17 | language: system
18 | pass_filenames: false
19 |
20 | - id: build-docs
21 | name: Build documentation
22 | entry: make
23 | args: [ build-docs ]
24 | language: system
25 | pass_filenames: false
26 | stages: [ pre-push ]
27 |
--------------------------------------------------------------------------------
/examples/configs/causal_macct.yaml:
--------------------------------------------------------------------------------
1 | ### ONLY REQUIRED ARGUMENTS ###
2 | output_dir: "./outputs"
3 | ###############################
4 |
5 | # === Training Arguments === #
6 | per_device_train_batch_size: 4
7 | gradient_accumulation_steps: 4
8 | per_device_eval_batch_size: 4
9 | logging_strategy: steps
10 | logging_steps: 1
11 |
12 |
13 | # === Model Configs === #
14 | model_name_or_path: "Qwen/Qwen2.5-3B"
15 |
16 | # Lora
17 | use_peft: true
18 | lora_r: 32
19 | lora_alpha: 64
20 | use_rslora: true
21 | lora_dropout: 0.05
22 | lora_target_modules: "all-linear"
23 | lora_task_type: "CAUSAL_LM"
24 |
25 | # Optimisation
26 | # attn_implementation: "flash_attention_2"
27 | # use_liger_kernel: true
--------------------------------------------------------------------------------
/src/cycleformers/import_utils.py:
--------------------------------------------------------------------------------
1 | from transformers.utils.import_utils import _is_package_available
2 |
3 |
4 | _liger_kernel_available = _is_package_available("liger_kernel")
5 | _flash_attn_available = _is_package_available("flash_attn")
6 | _peft_available = _is_package_available("peft")
7 | _wandb_available = _is_package_available("wandb")
8 |
9 |
10 | def is_liger_kernel_available() -> bool:
11 | return _liger_kernel_available
12 |
13 |
14 | def is_flash_attn_available() -> bool:
15 | return _flash_attn_available
16 |
17 |
18 | def is_peft_available() -> bool:
19 | return _peft_available
20 |
21 |
22 | def is_wandb_available() -> bool:
23 | return _wandb_available
24 |
--------------------------------------------------------------------------------
/examples/configs/seq2seq_macct.yaml:
--------------------------------------------------------------------------------
1 | ### ONLY REQUIRED ARGUMENTS ###
2 | output_dir: "./outputs"
3 | ###############################
4 |
5 | # === Training Arguments === #
6 | per_device_train_batch_size: 2
7 | gradient_accumulation_steps: 8
8 | per_device_eval_batch_size: 2
9 | logging_strategy: steps
10 | logging_steps: 1
11 |
12 |
13 | # === Model Configs === #
14 | model_name_or_path: "google/flan-t5-base"
15 |
16 | # Lora
17 | use_peft: true
18 | lora_r: 32
19 | lora_alpha: 64
20 | use_rslora: true
21 | lora_dropout: 0.05
22 | lora_target_modules: "all-linear"
23 | lora_task_type: "CAUSAL_LM"
24 |
25 | # Optimisation
26 | # attn_implementation: "flash_attention_2"
27 | # use_liger_kernel: true
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature Request
3 | about: Suggest an idea for this project
4 | title: '[FEATURE] '
5 | labels: enhancement
6 | assignees: ''
7 | ---
8 |
9 | **Is your feature request related to a problem? Please describe.**
10 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
11 |
12 | **Describe the solution you'd like**
13 | A clear and concise description of what you want to happen.
14 |
15 | **Describe alternatives you've considered**
16 | A clear and concise description of any alternative solutions or features you've considered.
17 |
18 | **Additional context**
19 | Add any other context or screenshots about the feature request here.
--------------------------------------------------------------------------------
/.github/workflows/build-docs.yml:
--------------------------------------------------------------------------------
1 | name: Build Documentation
2 |
3 | on:
4 | workflow_call:
5 | inputs:
6 | deploy:
7 | type: boolean
8 | description: "If true, the docs will be deployed."
9 | default: false
10 |
11 | jobs:
12 | build-docs:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v4
16 |
17 | - uses: ./.github/actions/setup-poetry
18 | with:
19 | python-version: 3.11
20 |
21 | - name: Build & maybe deploy documentation
22 | run: |
23 | poetry run mkdocs build --verbose --clean
24 | if ${{ inputs.deploy }}; then
25 | poetry run mkdocs gh-deploy --force
26 | fi
27 |
--------------------------------------------------------------------------------
/tests/benchmark/profiler_configs/causal_macct.yaml:
--------------------------------------------------------------------------------
1 | ### ONLY REQUIRED ARGUMENTS ###
2 | output_dir: "/tmp/profiler_outputs"
3 | ###############################
4 |
5 | # === Training Arguments === #
6 | use_macct: true
7 |
8 | per_device_train_batch_size: 1
9 | gradient_accumulation_steps: 16
10 | per_device_eval_batch_size: 1
11 | logging_strategy: steps
12 | logging_steps: 1
13 | report_to: "tensorboard"
14 |
15 |
16 | max_steps: 100
17 |
18 |
19 | # === Model Configs === #
20 | model_name_or_path: "Qwen/Qwen2.5-0.5B"
21 |
22 |
23 | # === LoRA Configs === #
24 | lora_r: 16
25 | lora_alpha: 32
26 | lora_dropout: 0.05
27 | task_type: "CAUSAL_LM"
28 | use_rslora: true
29 |
30 |
31 | # Optimisation
32 | # attn_implementation: "flash_attention_2"
33 | # use_liger_kernel: true
--------------------------------------------------------------------------------
/src/cycleformers/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.1.0"
2 |
3 | from .command import CfArgumentParser
4 | from .cycle_trainer import CycleTrainer
5 | from .cycle_training_arguments import CycleTrainingArguments
6 | from .data_config import DataConfig
7 | from .exceptions import InvalidCycleKeyError, MACCTModelError, MissingModelError
8 | from .model_config import ModelConfig, ModelConfigA, ModelConfigB, merge_configs
9 | from .utils import DEFAULT_SEP_SEQ
10 |
11 |
12 | __all__ = [
13 | "CycleTrainer",
14 | "CycleTrainingArguments",
15 | "ModelConfig",
16 | "ModelConfigA",
17 | "ModelConfigB",
18 | "DataConfig",
19 | "MACCTModelError",
20 | "MissingModelError",
21 | "InvalidCycleKeyError",
22 | "DEFAULT_SEP_SEQ",
23 | "CfArgumentParser",
24 | "merge_configs",
25 | ]
26 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug Report
3 | about: Create a report to help us improve
4 | title: '[BUG] '
5 | labels: bug
6 | assignees: ''
7 | ---
8 |
9 | **Describe the bug**
10 | A clear and concise description of what the bug is.
11 |
12 | **To Reproduce**
13 | Steps to reproduce the behavior:
14 | 1. Go to '...'
15 | 2. Click on '....'
16 | 3. Scroll down to '....'
17 | 4. See error
18 |
19 | **Expected behavior**
20 | A clear and concise description of what you expected to happen.
21 |
22 | **Screenshots**
23 | If applicable, add screenshots to help explain your problem.
24 |
25 | **Environment (please complete the following information):**
26 | - OS: [e.g. Ubuntu 22.04]
27 | - Python version: [e.g. 3.11]
28 | - Cycleformers version: [e.g. 0.1.0]
29 |
30 | **Additional context**
31 | Add any other context about the problem here.
--------------------------------------------------------------------------------
/src/cycleformers/task_processors/__init__.py:
--------------------------------------------------------------------------------
1 | """Processors module for transforming various dataset formats into cycleformers-compatible format.
2 |
3 | This module provides a flexible API for converting different dataset formats into the structure cycleformers expects.
4 | Each processor handles a specific dataset format or task type (e.g. NER, machine translation) while sharing
5 | common functionality through the base class.
6 | """
7 |
8 | from .base import BaseProcessor, ProcessorConfig
9 | from .ner import CONLL2003Processor, CONLL2003ProcessorConfig
10 | from .translation import TranslationProcessor, TranslationProcessorConfig
11 |
12 |
13 | __all__ = [
14 | "BaseProcessor",
15 | "CONLL2003Processor",
16 | "CONLL2003ProcessorConfig",
17 | "ProcessorConfig",
18 | "TranslationProcessor",
19 | "TranslationProcessorConfig",
20 | ]
21 |
--------------------------------------------------------------------------------
/.github/actions/setup-poetry/action.yml:
--------------------------------------------------------------------------------
1 | name: 'Set up Poetry and install'
2 | description: 'Set up a specific version of Poetry and install dependencies using caching.'
3 | inputs:
4 | python-version:
5 | description: "Version range or exact version of Python or PyPy to use, using SemVer's version range syntax."
6 | default: 3.11
7 | install-dependencies:
8 | description: "If true, dependencies will be installed."
9 | default: true
10 |
11 | runs:
12 | using: 'composite'
13 | steps:
14 | - name: Install poetry
15 | run: pipx install poetry==1.8.3
16 | shell: bash
17 |
18 | - uses: actions/setup-python@v5
19 | with:
20 | python-version: ${{ inputs.python-version }}
21 | cache: 'poetry'
22 |
23 | - name: Install dependencies
24 | if: ${{ inputs.install-dependencies }}
25 | run: poetry install --all-extras
26 | shell: bash
--------------------------------------------------------------------------------
/tests/sample_data/generate_sample_data.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from datasets import DatasetDict, load_dataset
4 |
5 |
6 | CURR_DIR = Path(__file__).parent
7 |
8 | dataset_configs = {
9 | "wmt14": {"name": "wmt14", "config": "de-en"}, # Specify language pair
10 | "conll2003": {"name": "conll2003", "config": None},
11 | }
12 |
13 | for dataset_info in dataset_configs.values():
14 | dataset_dict = DatasetDict()
15 | for split in ["train[:100]", "validation[:20]", "test[:20]"]:
16 |         # Load only a small slice of each split
17 | dataset = load_dataset(
18 | dataset_info["name"],
19 | dataset_info["config"],
20 | split=split,
21 | )
22 | dataset_dict[split.split("[")[0]] = dataset
23 |
24 | # Save the small sample dataset
25 | output_path = CURR_DIR / f"{dataset_info['name']}"
26 | dataset_dict.save_to_disk(output_path)
27 |
--------------------------------------------------------------------------------
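Once generated, the sample datasets can be reloaded with the standard `datasets` API; the row counts follow the `train[:100]`/`test[:20]` slices used above:

```python
from datasets import load_from_disk

conll = load_from_disk("tests/sample_data/conll2003")
print(conll["train"].num_rows, conll["test"].num_rows)  # 100 20
```
--------------------------------------------------------------------------------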
/tests/benchmark/benchmark_utils.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from dataclasses import dataclass
3 |
4 | from transformers import PreTrainedTokenizerBase
5 |
6 | from cycleformers import CycleTrainer, CycleTrainingArguments
7 |
8 |
9 | @dataclass
10 | class BenchmarkConfig:
11 | config_name: str
12 | output_dir: str = "benchmark_results/"
13 |
14 |
15 | class Benchmark(ABC):
16 | def __init__(self, config: BenchmarkConfig, *args, **kwargs):
17 | self.config = config
18 |
19 | @abstractmethod
20 | def run(self, *args, **kwargs):
21 | raise NotImplementedError
22 |
23 |
24 | class MockCycleTrainer(CycleTrainer):
25 | """Lightweight mock of CycleTrainer for benchmarking"""
26 |
27 | def __init__(self, tokenizer: PreTrainedTokenizerBase):
28 | self.args = CycleTrainingArguments(output_dir="./tmp")
29 | self.tokenizer_A = self.tokenizer_B = tokenizer
30 | self.sep_seq = "\n\n"
31 |
--------------------------------------------------------------------------------
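A minimal sketch of a concrete benchmark built on the `Benchmark` base class above; the subclass name and timing logic are illustrative, not part of the library:

```python
import time

# Benchmark and BenchmarkConfig as defined in benchmark_utils.py above


class SleepBenchmark(Benchmark):
    def run(self, *args, **kwargs):
        start = time.perf_counter()
        time.sleep(0.01)  # stand-in for the code under test
        return {"elapsed_s": time.perf_counter() - start}


result = SleepBenchmark(BenchmarkConfig(config_name="sleep")).run()
```
--------------------------------------------------------------------------------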
/src/cycleformers/cycle_trainer_utils.py:
--------------------------------------------------------------------------------
1 | from os import PathLike
2 | from typing import Any
3 |
4 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, PreTrainedModel
5 |
6 | from cycleformers.exceptions import CycleModelError
7 |
8 |
9 | def load_model(model_path: str | PathLike[str], **model_init_kwargs: Any) -> PreTrainedModel:
10 | auto_config = AutoConfig.from_pretrained(model_path)
11 | if "ForCausalLM" in auto_config.model_type:
12 | model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
13 | elif auto_config.is_encoder_decoder:
14 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path, **model_init_kwargs)
15 | else:
16 | raise CycleModelError(
17 | "Unsupported or unrecognised model type. Make sure the provided model is either "
18 | "CausalLM or Seq2SeqLM. If you are using a custom model, you may need to pass the instantiated model to "
19 | "CycleTrainer."
20 | )
21 |
22 | # TODO: Handle quantisation
23 | return model
24 |
--------------------------------------------------------------------------------
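Example use of `load_model`; the checkpoint is one of the tiny models from `tests/models_to_test.yaml`, and any extra keyword arguments are forwarded to `from_pretrained`:

```python
from cycleformers.cycle_trainer_utils import load_model

# T5 is encoder-decoder, so this resolves to AutoModelForSeq2SeqLM
model = load_model("google/t5-efficient-tiny")
```
--------------------------------------------------------------------------------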
/.github/badges/build.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.github/badges/coverage.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/cycleformers/cycle_training_arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import TYPE_CHECKING, Any
3 |
4 | from transformers.training_args import TrainingArguments
5 |
6 |
7 | if TYPE_CHECKING:
8 | from .model_config import ModelConfig
9 |
10 |
11 | @dataclass
12 | class CycleTrainingArguments(TrainingArguments):
13 | """Will eventually contain cycle-specific arguments"""
14 |
15 | use_macct: bool = False
16 | report_to: list[str] | None = None
17 | sep_seq: str = field(default="") # Separator sequence
18 | model_init_kwargs: dict[str, Any] = field(default_factory=dict)
19 |
20 | def __post_init__(self):
21 | super().__post_init__()
22 | self._model_config = None
23 |
24 | @property
25 | def model_config(self) -> "ModelConfig":
26 | return self._model_config
27 |
28 | @model_config.setter
29 | def model_config(self, value: "ModelConfig"):
30 | self._model_config = value
31 |
32 |
33 | @dataclass
34 | class ModelTrainingArguments:
35 | """Will eventually contain model-specific arguments"""
36 |
37 | pass
38 |
39 |
40 | __all__ = ["CycleTrainingArguments", "ModelTrainingArguments"]
41 |
--------------------------------------------------------------------------------
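Constructing the arguments mirrors `transformers.TrainingArguments`; only `output_dir` is required, as in the example configs:

```python
from cycleformers import CycleTrainingArguments

args = CycleTrainingArguments(
    output_dir="./outputs",
    use_macct=True,  # single frozen base model with swapped adapters
    sep_seq="\n\n",
)
```
--------------------------------------------------------------------------------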
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Publish to PyPi
2 |
3 | on:
4 | workflow_dispatch:
5 |
6 | jobs:
7 |   checks:
8 |     name: Run quality checks
9 |     uses: ./.github/workflows/checks.yml
10 |
11 |   release:
12 |     name: Create Release
13 |     needs: checks
14 |     runs-on: ubuntu-latest
15 |     environment:
16 |       name: pypi
17 |       url: https://pypi.org/project/cycleformers
18 |     permissions:
19 |       id-token: write
20 |       contents: write
21 |       issues: write
22 |       pull-requests: write
23 |
24 |     steps:
25 |     - uses: actions/checkout@v4
26 |       with:
27 |         fetch-depth: 0
28 |
29 |     - uses: ./.github/actions/setup-poetry
30 |       with:
31 |         python-version: 3.11
32 |
33 |     - name: Python Semantic Release
34 |       env:
35 |         GH_TOKEN: ${{ github.token }}
36 |       run: |
37 |         pip install python-semantic-release
38 |         git config user.name github-actions
39 |         git config user.email github-actions@github.com
40 |         semantic-release version
41 |         semantic-release publish
42 |
43 |   documentation:
44 |     name: Build & deploy documentation
45 |     uses: ./.github/workflows/build-docs.yml
46 |     with:
47 |       deploy: true
46 |
--------------------------------------------------------------------------------
/docs/conceptual_reference/cycle_consistency_training.md:
--------------------------------------------------------------------------------
1 | ## Cycle-consistency Training
2 |
3 | Cycle-consistency training (CCT) creates a closed feedback loop between source and target domains by linking two models together during training. Each model implements the inverse function of the other, i.e. $f(g(x)) = x$; for example, one model translates English to German and the other German to English. Both models are trained jointly, using a single optimiser, on the round-trip reconstruction of inputs drawn from non-parallel datasets. Cycle-consistency loss is often used as a secondary loss component, as in [Zhu et al. 2017](https://arxiv.org/abs/1703.10593), which also uses an adversarial loss.
4 |
5 | 📈 GRAPH TO BE INSERTED HERE
6 |
7 | For tasks such as text generation, we must sample discrete tokens from the model's continuous output distribution, breaking the gradient flow. In these settings, $f(g(x))$ is non-differentiable, preventing us from using the standard CCT loss. While CCT enforces a closed loop within each training batch, iterative back-translation (IBT) avoids this optimisation issue by using one model to generate synthetic parallel data for the other to use as input. Each cycle therefore has a separate loss function and optimiser, alternating between training each model ([Gou et al. 2020](https://arxiv.org/abs/2006.04702)).
8 |
9 | 📈 GRAPH TO BE INSERTED HERE
10 |
11 |
12 |
--------------------------------------------------------------------------------
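To make the IBT loop concrete, here is a minimal sketch of a single training step, assuming two Hugging Face seq2seq models that share a tokenizer; the function and variable names are simplified assumptions, not the library's internals:

```python
import torch


def ibt_step(gen_model, train_model, optimizer, tok, texts):
    # 1) Frozen pass: generate synthetic source text; no gradients flow here
    with torch.no_grad():
        inputs = tok(texts, return_tensors="pt", padding=True)
        synthetic_ids = gen_model.generate(**inputs, max_new_tokens=64)
    synthetic = tok.batch_decode(synthetic_ids, skip_special_tokens=True)

    # 2) Trainable pass: reconstruct the original text from the synthetic input
    batch = tok(synthetic, text_target=texts, return_tensors="pt", padding=True)
    loss = train_model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```
--------------------------------------------------------------------------------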
/docs/conceptual_reference/task_processors.md:
--------------------------------------------------------------------------------
1 | # Task Processors
2 |
3 | 🚧 This section is under construction. 🚧
4 |
5 | When using task processors, the splits that are downloaded, and their sizes, can be controlled via the `datasets` slice-split syntax. For example, the following will download the first 100 samples from the train split of the WMT14 dataset and the first 30 samples from the test split.
6 |
7 | ```python
8 | from cycleformers.task_processors.translation import TranslationProcessor
9 | from cycleformers.task_processors.translation import TranslationProcessorConfig
10 |
11 | config = TranslationProcessorConfig(
12 | dataset_name="wmt14",
13 | dataset_config_name="de-en",
14 | split=["train[:100]", "test[:30]"],
15 | )
16 | processor = TranslationProcessor(config)
17 | ```
18 |
19 | More information on the syntax for `slice split` can be found in the [datasets documentation](https://huggingface.co/docs/datasets/loading#slice-splits).
20 |
21 | ## BaseProcessor
22 |
23 | ::: cycleformers.task_processors.base.BaseProcessor
24 |
25 | ::: cycleformers.task_processors.base.ProcessorConfig
26 |
27 |
28 | ## Named-Entity Recognition (NER)
29 |
30 | ::: cycleformers.task_processors.ner.CONLL2003Processor
31 |
32 | ::: cycleformers.task_processors.ner.CONLL2003ProcessorConfig
33 |
34 |
35 |
36 | ### Helper Functions
37 |
38 | ::: cycleformers.task_processors.ner.reconstruct_sentence
39 |
40 | ::: cycleformers.task_processors.ner.ner_to_sequences
41 |
--------------------------------------------------------------------------------
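The same slicing works with `datasets` directly, independent of the processors:

```python
from datasets import load_dataset

train_sample = load_dataset("wmt14", "de-en", split="train[:100]")
test_sample = load_dataset("wmt14", "de-en", split="test[:30]")
```
--------------------------------------------------------------------------------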
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Cycleformers
2 |
3 |
4 |
5 | [](https://www.python.org/downloads/)
6 | [](https://pypi.org/project/cycleformers/)
7 | [](https://creativecommons.org/licenses/by/4.0/)
8 | [](https://codecov.io/gh/wrmthorne/cycleformers)
9 | [](https://github.com/wrmthorne/cycleformers/actions/workflows)
10 |
11 |
12 |
13 | A Python library for efficient cycle-consistency training of transformer models. Cycleformers simplifies iterative back-translation with support for both causal and seq2seq architectures. We also implement Multi-Adapter Cycle-Consistency Training (MACCT), which trains LoRA adapters on a frozen base model, fitting a `7.5x` larger model into the same memory footprint.
14 |
15 | ## Features
16 |
17 | - 🤗 Seamless integration with Hugging Face Transformers
18 | - 🚀 PEFT/LoRA support for memory-efficient training
19 | - 🤖 Compatible with both causal and seq2seq models
20 | - 🔥 Optimized for various hardware configurations
21 |
22 | ## Documentation
23 |
24 | - [Conceptual Reference](conceptual_reference/task_processors.md)
25 | - [API Reference](api_reference/task_processors.md)
26 | - [Examples](examples/index.md)
27 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # What does this PR do?
2 |
3 |
12 |
13 |
14 |
15 | Fixes # (issue)
16 |
17 |
18 | ## Before submitting
19 | - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
20 | - [ ] Did you read the [contributor guideline](https://github.com/wrmthorne/cycleformers/blob/main/CONTRIBUTING.md)?
21 | - [ ] Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
22 | - [ ] Did you make sure to update the documentation with your changes? Here are the
23 | [documentation guidelines](https://github.com/wrmthorne/cycleformers/tree/main/docs).
24 | - [ ] Did you write any new necessary tests?
25 |
26 |
27 | ## Who can review?
28 |
29 | Please tag @wrmthorne for review. The request can be merged once one review is provided and all checks have passed.
30 |
--------------------------------------------------------------------------------
/examples/translation_wmt14/prepare_data.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from datasets import DatasetDict, load_dataset
4 |
5 |
6 | def prepare_wmt14_en_de_datasets(num_samples=5000, test_size=100):
7 | dataset = load_dataset("wmt/wmt14", split="train")
8 | dataset = dataset.map(
9 | lambda x: {"english": x["translation"]["en"], "german": x["translation"]["de"]}
10 | ).remove_columns(["translation"])
11 |
12 | # Pull out sample of paired data for evaluation
13 | dataset = dataset.train_test_split(test_size=test_size)
14 |
15 | en_train = (
16 | dataset["train"]
17 | .select_columns("english")
18 | .shuffle(seed=42) # Shuffle to simulate not having any exact translations
19 | .select(range(num_samples))
20 | .rename_column("english", "text")
21 | )
22 | en_eval = dataset["test"].rename_columns({"english": "text", "german": "labels"})
23 | dataset_en = DatasetDict({"train": en_train, "test": en_eval})
24 |
25 | de_train = (
26 | dataset["train"]
27 | .select_columns("german")
28 | .shuffle(seed=0) # Shuffle to simulate not having any exact translations
29 | .select(range(num_samples))
30 | .rename_column("german", "text")
31 | )
32 | de_eval = dataset["test"].rename_columns({"german": "text", "english": "labels"})
33 | dataset_de = DatasetDict({"train": de_train, "test": de_eval})
34 |
35 | return dataset_en, dataset_de
36 |
37 |
38 | if __name__ == "__main__":
39 | dataset_en, dataset_de = prepare_wmt14_en_de_datasets()
40 |
41 | Path("./data").mkdir(exist_ok=True)
42 |
43 | dataset_en.save_to_disk("./data/en")
44 | dataset_de.save_to_disk("./data/de")
45 |
--------------------------------------------------------------------------------
/tests/benchmark/profiler_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | from functools import wraps
3 |
4 | import torch
5 |
6 |
7 | def record_function_wrapper(name=None):
8 | """Decorator to profile a function's execution time and memory usage.
9 |
10 | Args:
11 | name (str, optional): Custom name for the profiling label. Defaults to function name.
12 | """
13 |
14 | def decorator(func):
15 | @wraps(func)
16 | def wrapper(self, *args, **kwargs):
17 | step_start_time = time.perf_counter()
18 |
19 | # Clear cache before step if CUDA is available
20 | if torch.cuda.is_available():
21 | torch.cuda.empty_cache()
22 |
23 | profile_name = name or f"## {func.__name__} ##"
24 | with torch.profiler.record_function(profile_name):
25 | result = func(self, *args, **kwargs)
26 |
27 | # Track timing and memory if the object has profiling_stats
28 | if hasattr(self, "profiling_stats"):
29 | step_duration = time.perf_counter() - step_start_time
30 | self.profiling_stats["step_times"].append(step_duration)
31 |
32 | # Log GPU memory
33 | if hasattr(self, "_log_gpu_memory"):
34 | current_memory, max_memory = self._log_gpu_memory() or (0, 0)
35 |
36 | # Update metrics if result is a dict
37 | if isinstance(result, dict):
38 | result.update(
39 | {
40 | "gpu_memory_used": current_memory,
41 | "gpu_memory_max": max_memory,
42 | }
43 | )
44 |
45 | return result
46 |
47 | return wrapper
48 |
49 | return decorator
50 |
--------------------------------------------------------------------------------
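A hypothetical use of the decorator; the class and method names are illustrative, but the `profiling_stats` attribute follows the contract checked in the wrapper above:

```python
from tests.benchmark.profiler_utils import record_function_wrapper


class ProfiledStepper:
    def __init__(self):
        self.profiling_stats = {"step_times": []}

    @record_function_wrapper("## demo_step ##")
    def demo_step(self):
        # gpu_memory_* keys are only added when _log_gpu_memory is defined
        return {"loss": 0.0}
```
--------------------------------------------------------------------------------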
/.github/workflows/checks.yml:
--------------------------------------------------------------------------------
1 | name: Code quality checks & tests
2 | on:
3 | workflow_call:
4 |
5 | permissions:
6 | contents: write
7 |
8 | jobs:
9 | checks:
10 | runs-on: ubuntu-latest
11 | strategy:
12 | matrix:
13 | python-version: [3.11, 3.12]
14 |
15 | steps:
16 | - uses: actions/checkout@v4
17 |
18 | - uses: ./.github/actions/setup-poetry
19 | with:
20 | python-version: ${{ matrix.python-version }}
21 |
22 | # Cheapest check first
23 | - name: Run audit
24 | run: make audit
25 |
26 | - name: Run quality checks
27 | run: make quality
28 |
29 | - name: Run doctests
30 | run: poetry run python -m doctest **/*.py
31 |
32 | - name: Run tests with coverage
33 | run: >
34 | poetry run pytest
35 | --cov=src/cycleformers
36 | --cov=examples
37 | --cov-report=xml
38 | --cov-report=term-missing
39 | tests/
40 | --all
41 | --instafail
42 | -n auto
43 | --junitxml=junit.xml
44 | # --cov-fail-under=80
45 |
46 | - name: Generate coverage badge
47 | run: |
48 | mkdir -p .github/badges
49 | poetry run genbadge coverage -i coverage.xml -o .github/badges/coverage.svg
50 | poetry run genbadge tests -i junit.xml -o .github/badges/build.svg
51 |
52 | - name: Commit badge
53 | if: github.event_name == 'pull_request'
54 | run: |
55 | git config --local user.email "github-actions[bot]@users.noreply.github.com"
56 | git config --local user.name "github-actions[bot]"
57 | # Fetch and checkout the PR branch
58 | git fetch origin ${{ github.head_ref }}
59 | git checkout -f ${{ github.head_ref }}
60 | git add .github/badges/coverage.svg .github/badges/build.svg
61 | git commit -m "Update build and coverage badges" || exit 0
62 | git push origin ${{ github.head_ref }}
63 |
--------------------------------------------------------------------------------
/tests/benchmark/README.md:
--------------------------------------------------------------------------------
1 | # Profiling Process
2 |
3 | ## Compute and Runtime Performance
4 |
5 | 1) Set up the appropriate config yaml in `profiler_configs/` and make any modifications to the script as needed.
6 | 2) Run the script as you would any of the examples e.g.
7 | ```bash
8 | python cycle_trainer_profiling.py profiler_configs/causal.yaml
9 | ```
10 |
11 | 3) The script will generate a folder in `cycle_trainer_profiles/` with the start date and time of the run.
12 | 4) Snakeviz can be used to visualise the flame graph for the run:
13 | ```bash
14 | # If working on remote machine
15 | ssh -L 8080:localhost:8080 <user>@<remote-host>
16 |
17 | # Run snakeviz
18 | snakeviz cycle_trainer_profiles/<run_datetime>/profile.prof --server
19 | >> flame graph can be viewed at http://localhost:8080/
20 | ```
21 |
22 | ### The Sunburst View
23 |
24 | Read from the centre outwards, where each ring represents a level of function calls. The size (arc length) of each segment shows how much time that function took, and the colours help distinguish segments. Inner segments are the callers of the segments that radiate outwards from them.
25 |
26 | 
27 |
28 | ### Icicle View
29 |
30 | Each row represents a level in the call stack where the width of each bar represents the time taken by that function. Functions at the top call the functions below them. The full width represents 100% of the execution time.
31 |
32 | 
33 |
34 |
35 | ## Memory Usage
36 |
37 | Memory stats are recorded while computing the runtime stats. To view the CUDA memory snapshots, visit the [pytorch memory viz webpage](https://pytorch.org/memory_viz) and drag and drop in your `cuda_memory_snapshots.pickle` file.
38 |
39 |
40 | ### References
41 | - [Understanding GPU Memory 1: Visualizing All Allocations over Time](https://pytorch.org/blog/understanding-gpu-memory-1/)
42 | - [Understanding GPU Memory 2: Finding and Removing Reference Cycles](https://pytorch.org/blog/understanding-gpu-memory-2/)
--------------------------------------------------------------------------------
/src/cycleformers/exceptions.py:
--------------------------------------------------------------------------------
1 | class CycleModelError(Exception):
2 | """Exception raised when models are not properly configured for CycleTrainer.
3 |
4 | This exception is raised when models are either not provided or are invalid for use with CycleTrainer.
5 | """
6 |
7 | def __init__(self, message="CycleTrainer is missing valid models for training."):
8 | self.message = message
9 | super().__init__(self.message)
10 |
11 |
12 | class MissingModelError(CycleModelError):
13 | """Exception raised when a model is not provided for CycleTrainer.
14 |
15 |     Raised when CycleTrainer is given only one model but two are required for cycle training.
16 | """
17 |
18 | default_message = (
19 | "CycleTrainer requires two models to train but only received a single model. "
20 | "To train using adapter-switching, set `args.use_macct = True`. Otherwise, pass two separate models "
21 | "for cycle training."
22 | )
23 |
24 | def __init__(self, message=None):
25 | self.message = message or self.default_message
26 | super().__init__(self.message)
27 |
28 |
29 | class MACCTModelError(Exception):
30 | """Exception raised when MACCT models are not properly configured for CycleTrainer.
31 |
32 | This exception is raised when MACCT models are either not provided or are invalid for use with CycleTrainer.
33 | """
34 |
35 | def __init__(self, message="There is something wrong with the MACCT model or configuration provided."):
36 | self.message = message
37 | super().__init__(self.message)
38 |
39 |
40 | class InvalidCycleKeyError(Exception):
41 | """Exception raised when an invalid model key is provided to CycleTrainer.
42 |
43 | This exception is raised when a cycle key other than 'A' or 'B' is provided when configuring or accessing
44 | models in CycleTrainer.
45 | """
46 |
47 | def __init__(self, message="Invalid cycle key provided. Only 'A' and 'B' are valid keys for cycle training."):
48 | self.message = message
49 | super().__init__(self.message)
50 |
--------------------------------------------------------------------------------
/tests/sample_data/wmt14/test/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "builder_name": "parquet",
3 | "citation": "",
4 | "config_name": "de-en",
5 | "dataset_name": "wmt14",
6 | "dataset_size": 1359920533,
7 | "description": "",
8 | "download_checksums": {
9 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00000-of-00003.parquet": {
10 | "num_bytes": 279527247,
11 | "checksum": null
12 | },
13 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00001-of-00003.parquet": {
14 | "num_bytes": 264970514,
15 | "checksum": null
16 | },
17 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00002-of-00003.parquet": {
18 | "num_bytes": 272987200,
19 | "checksum": null
20 | },
21 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/validation-00000-of-00001.parquet": {
22 | "num_bytes": 473798,
23 | "checksum": null
24 | },
25 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/test-00000-of-00001.parquet": {
26 | "num_bytes": 508753,
27 | "checksum": null
28 | }
29 | },
30 | "download_size": 818467512,
31 | "features": {
32 | "translation": {
33 | "languages": [
34 | "de",
35 | "en"
36 | ],
37 | "_type": "Translation"
38 | }
39 | },
40 | "homepage": "",
41 | "license": "",
42 | "size_in_bytes": 2178388045,
43 | "splits": {
44 | "train": {
45 | "name": "train",
46 | "num_bytes": 1358406800,
47 | "num_examples": 4508785,
48 | "shard_lengths": [
49 | 1521929,
50 | 1721928,
51 | 1264928
52 | ],
53 | "dataset_name": "wmt14"
54 | },
55 | "validation": {
56 | "name": "validation",
57 | "num_bytes": 736407,
58 | "num_examples": 3000,
59 | "dataset_name": "wmt14"
60 | },
61 | "test": {
62 | "name": "test",
63 | "num_bytes": 777326,
64 | "num_examples": 3003,
65 | "dataset_name": "wmt14"
66 | }
67 | },
68 | "version": {
69 | "version_str": "0.0.0",
70 | "major": 0,
71 | "minor": 0,
72 | "patch": 0
73 | }
74 | }
--------------------------------------------------------------------------------
/tests/sample_data/wmt14/train/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "builder_name": "parquet",
3 | "citation": "",
4 | "config_name": "de-en",
5 | "dataset_name": "wmt14",
6 | "dataset_size": 1359920533,
7 | "description": "",
8 | "download_checksums": {
9 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00000-of-00003.parquet": {
10 | "num_bytes": 279527247,
11 | "checksum": null
12 | },
13 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00001-of-00003.parquet": {
14 | "num_bytes": 264970514,
15 | "checksum": null
16 | },
17 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00002-of-00003.parquet": {
18 | "num_bytes": 272987200,
19 | "checksum": null
20 | },
21 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/validation-00000-of-00001.parquet": {
22 | "num_bytes": 473798,
23 | "checksum": null
24 | },
25 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/test-00000-of-00001.parquet": {
26 | "num_bytes": 508753,
27 | "checksum": null
28 | }
29 | },
30 | "download_size": 818467512,
31 | "features": {
32 | "translation": {
33 | "languages": [
34 | "de",
35 | "en"
36 | ],
37 | "_type": "Translation"
38 | }
39 | },
40 | "homepage": "",
41 | "license": "",
42 | "size_in_bytes": 2178388045,
43 | "splits": {
44 | "train": {
45 | "name": "train",
46 | "num_bytes": 1358406800,
47 | "num_examples": 4508785,
48 | "shard_lengths": [
49 | 1521929,
50 | 1721928,
51 | 1264928
52 | ],
53 | "dataset_name": "wmt14"
54 | },
55 | "validation": {
56 | "name": "validation",
57 | "num_bytes": 736407,
58 | "num_examples": 3000,
59 | "dataset_name": "wmt14"
60 | },
61 | "test": {
62 | "name": "test",
63 | "num_bytes": 777326,
64 | "num_examples": 3003,
65 | "dataset_name": "wmt14"
66 | }
67 | },
68 | "version": {
69 | "version_str": "0.0.0",
70 | "major": 0,
71 | "minor": 0,
72 | "patch": 0
73 | }
74 | }
--------------------------------------------------------------------------------
/tests/sample_data/wmt14/validation/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "builder_name": "parquet",
3 | "citation": "",
4 | "config_name": "de-en",
5 | "dataset_name": "wmt14",
6 | "dataset_size": 1359920533,
7 | "description": "",
8 | "download_checksums": {
9 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00000-of-00003.parquet": {
10 | "num_bytes": 279527247,
11 | "checksum": null
12 | },
13 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00001-of-00003.parquet": {
14 | "num_bytes": 264970514,
15 | "checksum": null
16 | },
17 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/train-00002-of-00003.parquet": {
18 | "num_bytes": 272987200,
19 | "checksum": null
20 | },
21 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/validation-00000-of-00001.parquet": {
22 | "num_bytes": 473798,
23 | "checksum": null
24 | },
25 | "hf://datasets/wmt14@b199e406369ec1b7634206d3ded5ba45de2fe696/de-en/test-00000-of-00001.parquet": {
26 | "num_bytes": 508753,
27 | "checksum": null
28 | }
29 | },
30 | "download_size": 818467512,
31 | "features": {
32 | "translation": {
33 | "languages": [
34 | "de",
35 | "en"
36 | ],
37 | "_type": "Translation"
38 | }
39 | },
40 | "homepage": "",
41 | "license": "",
42 | "size_in_bytes": 2178388045,
43 | "splits": {
44 | "train": {
45 | "name": "train",
46 | "num_bytes": 1358406800,
47 | "num_examples": 4508785,
48 | "shard_lengths": [
49 | 1521929,
50 | 1721928,
51 | 1264928
52 | ],
53 | "dataset_name": "wmt14"
54 | },
55 | "validation": {
56 | "name": "validation",
57 | "num_bytes": 736407,
58 | "num_examples": 3000,
59 | "dataset_name": "wmt14"
60 | },
61 | "test": {
62 | "name": "test",
63 | "num_bytes": 777326,
64 | "num_examples": 3003,
65 | "dataset_name": "wmt14"
66 | }
67 | },
68 | "version": {
69 | "version_str": "0.0.0",
70 | "major": 0,
71 | "minor": 0,
72 | "patch": 0
73 | }
74 | }
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # https://medium.com/@dkraczkowski/crafting-a-ci-pipeline-my-experience-with-github-actions-python-aws-c67f428adee8
2 | -include .env
3 | SOURCE_DIR = src
4 | TEST_DIR = tests
5 | EXAMPLE_DIR = examples
6 | PROJECT_DIRS = $(SOURCE_DIR) $(TEST_DIR) $(EXAMPLE_DIR)
7 | PWD := $(dir $(abspath $(firstword $(MAKEFILE_LIST))))
8 | PROJECT_NAME ?= Cycleformers
9 | PROJECT_VERSION ?= v$(shell poetry version -s)
10 | PYTHON_VERSION ?= 3.11
11 | .DEFAULT_GOAL := all
12 |
13 | .PHONY: init-env init check-toml lint-src format lint quality audit test test-all all clean info build-docs
14 |
15 | init-env:
16 | @if [ ! -f .env ]; then \
17 | echo "Creating .env file..."; \
18 | echo "PROJECT_NAME=${PROJECT_NAME}" > .env; \
19 | echo "PYTHON_VERSION=${PYTHON_VERSION}" >> .env; \
20 | echo "export PYTHONPATH=${SOURCE_DIR}" >> .env; \
21 | else \
22 | echo "using existing .env file..."; \
23 | fi
24 |
25 | init: init-env
26 | @echo "Installing dependencies..."
27 | poetry install
28 | poetry run pre-commit install
29 |
30 | -check-toml:
31 | poetry check
32 |
33 | -reformat-src:
34 | poetry run ruff format $(PROJECT_DIRS)
35 | poetry run ruff check --select I --fix $(PROJECT_DIRS)
36 |
37 | -lint-src:
38 | poetry run ruff check --fix $(SOURCE_DIR)
39 | poetry run mypy --install-types --show-error-codes --non-interactive $(SOURCE_DIR)
40 |
41 | format: -check-toml -reformat-src
42 |
43 | lint: -lint-src
44 |
45 | quality:
46 | @poetry run python -c "from cycleformers import *" || (echo "Import failure. Unprotected import?"; exit 1)
47 | @make format
48 | @make lint
49 |
50 | audit:
51 | poetry run bandit -r $(SOURCE_DIR) -x $(TEST_DIR)
52 |
53 | test:
54 | poetry run pytest $(TEST_DIR)
55 |
56 | test-all:
57 | poetry run pytest $(TEST_DIR) -v --slow --meta -n auto --instafail
58 |
59 | build-docs:
60 | poetry run mkdocs build
61 |
62 | all: audit quality test build-docs
63 |
64 | clean:
65 | rm -rf dist/
66 | rm -rf build/
67 | rm -rf *.egg-info
68 | find . -type d -name '__pycache__' -exec rm -rf {} +
69 | find . -type d -name '.pytest_cache' -exec rm -rf {} +
70 | find . -type d -name '.mypy_cache' -exec rm -rf {} +
71 |
72 | info:
73 | @echo "Project name: ${PROJECT_NAME}"
74 | @echo "Project version: ${PROJECT_VERSION}"
75 | @echo "Python version: ${PYTHON_VERSION}"
--------------------------------------------------------------------------------
/docs/usage.md:
--------------------------------------------------------------------------------
1 | ### Training
2 |
3 | The `CycleTrainer` class is an extension, and significant redesign, of the 🤗 Transformers trainer, designed to abstract away the specifics of cycle-consistency training while remaining configurable. Both Seq2Seq and Causal architectures are supported, and each can be trained via PEFT adapter swapping for memory-efficient configurations. Check the [docs] for [usage] details and [examples].
4 |
5 | To train two identical models using two datasets, the following sample code can be used:
6 |
7 | ```python
8 | from cycleformers import CycleTrainer, CycleTrainingArguments
9 |
10 | model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
11 | tokenizer = AutoTokenizer.from_pretrained("gpt2")
12 |
13 | args = CycleTrainingArguments(output_dir="gpt2-cct")
14 | trainer = CycleTrainer(
15 | args,
16 |     models = model,
17 |     tokenizers = tokenizer,
18 | train_dataset_A = dataset_A,
19 | train_dataset_B = dataset_B
20 | )
21 | trainer.train()
22 | ```
23 |
24 | Any two models (🚧 currently both seq2seq or both causal) can be combined for completely customisable training:
25 |
26 | ```python
27 | model_A = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
28 | model_B = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base", device_map="auto")
29 | tokenizer_A = AutoTokenizer.from_pretrained("gpt2")
30 | tokenizer_B = AutoTokenizer.from_pretrained("google/flan-t5-base")
31 |
32 | trainer = CycleTrainer(
33 | args,
34 | models = {
35 | "A": model_A,
36 | "B": model_B
37 |     },
38 | tokenizers = {
39 | "A": tokenizer_A,
40 | "B": tokenizer_B
41 |     },
42 | train_dataset_A = dataset_A,
43 | train_dataset_B = dataset_B
44 | )
45 | ```
46 |
47 | ### Multi-Adapter Cycle-Consistency Training (MACCT)
48 |
49 | The `CycleTrainer` class is also set up to accept a single base model and train two PEFT adapters on top of it, switching between them to emulate the two-model setup. This allows the training of `~7.5x larger` models for the same memory footprint:
50 |
51 | ```python
52 | peft_config = LoraConfig(
53 | task_type="CAUSAL_LM",
54 | r=16,
55 | lora_alpha=32,
56 | target_modules="all-linear",
57 | inference_mode=False,
58 | bias="none"
59 | )
60 |
61 | args = CycleTrainingArguments(output_dir="gpt2-macct")
62 | trainer = CycleTrainer(
63 | args,
64 |     models = model,
65 |     tokenizers = tokenizer,
66 |     peft_configs = peft_config # Or a dict with keys "A" and "B"
67 | )
68 | ```
--------------------------------------------------------------------------------
/src/cycleformers/command/cli_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | from collections.abc import Iterable
4 |
5 | import yaml
6 | from transformers.hf_argparser import DataClassType, HfArgumentParser
7 |
8 |
9 | VALID_TASKS = ["train"]
10 |
11 |
12 | class CfArgumentParser(HfArgumentParser):
13 | def __init__(
14 | self,
15 | dataclass_types: list[DataClassType] | None = None,
16 | task: str | None = None,
17 | **kwargs,
18 | ):
19 | # Make sure dataclass_types is an iterable
20 | if dataclass_types is None:
21 | dataclass_types = []
22 | elif not isinstance(dataclass_types, Iterable):
23 | dataclass_types = [dataclass_types]
24 |
25 | self.task = task
26 | super().__init__(dataclass_types=dataclass_types, **kwargs)
27 |
28 | def _parse_yaml_config(self, config_file: str):
29 | with open(config_file, "r") as f:
30 | config = yaml.safe_load(f)
31 |
32 | # Convert config YAMLs to be A_ and B_
33 | config_a = {f"A_{k}": v for k, v in config.pop("A", {}).items()}
34 | config_b = {f"B_{k}": v for k, v in config.pop("B", {}).items()}
35 |
36 | file_args = []
37 | for c in [config, config_a, config_b]:
38 | for k, v in c.items():
39 | v = json.dumps(v) if isinstance(v, dict) else str(v)
40 |
41 | if f"--{k}" in file_args:
42 | raise ValueError(f"Duplicate argument {k} found in config files")
43 |
44 | file_args.extend([f"--{k}", v])
45 | return file_args
46 |
47 | def parse_args_and_config(self):
48 | args = sys.argv[1:]
49 |
50 | # Handle task argument
51 | if not self.task and len(args) == 0:
52 | raise ValueError(f"No task provided. Task must be one of {VALID_TASKS}.")
53 |
54 | # Task is already set
55 |         if self.task and len(args) > 0 and args[0] in VALID_TASKS:
56 | raise ValueError(f"Task already set by script to {self.task}. Try again without {args[0]}.")
57 | # Task is not set and arg is a valid task
58 | elif not self.task:
59 | if args[0] in VALID_TASKS:
60 | self.task = args.pop(0)
61 | else:
62 | raise ValueError(f"Task must be one of {VALID_TASKS}, got {args[0]}")
63 |
64 | if len(args) > 0 and (args[0].endswith(".yaml") or args[0].endswith(".yml")):
65 | config_file = args.pop(0)
66 | file_args = self._parse_yaml_config(config_file)
67 | args = file_args + args
68 |
69 | return self.parse_args_into_dataclasses(args)
70 |
--------------------------------------------------------------------------------
/tests/task_processors/test_ner_processors.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from cycleformers.task_processors.ner import CONLL2003Processor, CONLL2003ProcessorConfig
4 |
5 |
6 | SAMPLES_DIR = Path(__file__).parent.parent / "sample_data"
7 |
8 |
9 | class TestCONLL2003Processor:
10 | def test_preprocess(self, monkeypatch, temp_dir):
11 | """Test preprocessing of CONLL2003 dataset."""
12 | monkeypatch.setenv("HF_DATASETS_CACHE", str(temp_dir))
13 |
14 | config = CONLL2003ProcessorConfig(
15 | dataset_name=SAMPLES_DIR / "conll2003",
16 | cache_dir=str(temp_dir),
17 | dataset_seed=42, # Fixed seed for reproducible tests
18 | )
19 | processor = CONLL2003Processor(config)
20 | dataset_A, dataset_B = processor.process()
21 |
22 | # Check that all splits are preserved
23 | assert set(dataset_A.keys()) == {"train", "validation", "test"}
24 | assert set(dataset_B.keys()) == {"train", "validation", "test"}
25 |
26 | # Check that training data is non-parallel (no labels)
27 | assert "label" not in dataset_A["train"].column_names
28 | assert "label" not in dataset_B["train"].column_names
29 |
30 | # Verify training splits are properly shuffled and unaligned
31 | train_texts_A = dataset_A["train"]["text"]
32 | train_texts_B = dataset_B["train"]["text"]
33 |
34 | # Check no exact alignment exists
35 | assert not any(
36 | all(a == b for a, b in zip(train_texts_A, train_texts_B[i:] + train_texts_B[:i]))
37 | for i in range(len(train_texts_B))
38 | ), "Training splits appear to be aligned with some offset"
39 |
40 | # Verify that shuffling is deterministic with the same seed
41 | config2 = CONLL2003ProcessorConfig(
42 | dataset_name=SAMPLES_DIR / "conll2003",
43 | dataset_seed=42,
44 | cache_dir=str(temp_dir),
45 | )
46 | processor2 = CONLL2003Processor(config2)
47 | dataset_A2, dataset_B2 = processor2.process()
48 |
49 | assert list(dataset_A["train"]["text"]) == list(dataset_A2["train"]["text"])
50 | assert list(dataset_B["train"]["text"]) == list(dataset_B2["train"]["text"])
51 |
52 | # Check that evaluation splits maintain parallel data
53 | for key in ["validation", "test"]:
54 | assert dataset_A[key]["text"] == dataset_B[key]["label"]
55 | assert dataset_A[key]["label"] == dataset_B[key]["text"]
56 | assert dataset_A[key]["text"] != dataset_B[key]["text"]
57 | assert dataset_A[key]["label"] != dataset_B[key]["label"]
58 |
--------------------------------------------------------------------------------
/docs/conceptual_reference/macct.md:
--------------------------------------------------------------------------------
1 | # Multi-Adapter Cycle-Consistency Training (MACCT)
2 |
3 | Multi-Adapter Cycle-Consistency Training (MACCT) is an implementation of the Cycle-Consistency Training (CCT) method that uses [PEFT](https://huggingface.co/docs/peft/en/index) LoRA adapters ([Hu et al. 2022](https://openreview.net/forum?id=nZeVKeeFYf9)) in place of full model weights to learn the A->B and B->A translation mappings. The frozen base model weights are shared between the two adapters, greatly reducing the number of optimizer states and therefore the memory footprint. The saving is large enough that we can load a frozen base model that is `~7.5x larger` than either model in the full fine-tuning case.
4 |
5 | [A HELPFUL DIAGRAM WILL GO HERE]
6 |
7 |
8 | ??? "How are these figures calculated?"
9 |
10 |     Assuming a restricted case where we just look at static memory, i.e. model weights and optimizer states, the memory savings are significant. In the dual-model ($DMCCT$) case we assume that both models are the same; likewise, in multi-adapter CCT ($MACCT$) we assume that both LoRA adapters are initialised identically. In all cases we use the AdamW optimizer, as it is the most popular:
11 | 
12 |     With $\theta$ as the foundational model parameters, $\phi$ as the LoRA adapter parameters, $p$ as the number of bits per parameter ${\theta}_i$, $q$ as the number of bits per parameter ${\phi}_i$, and $r$ as the ratio of LoRA adapter size $|\phi|$ to base model size $|\theta|$, we can derive the memory savings as follows:
13 | 
14 |     $$
15 |     \begin{aligned}
16 |     DMCCT =& \left[ 2 \text{ models} * |\theta| \text{ params} * p \text{ bits} \right] + \left[ 2 \text{ models} * |\theta| \text{ params} * 2 \text{ states} * (4*8) \text{ bits} \right] \\
17 |     =& 2|\theta|(p + 64)
18 |     \end{aligned} \tag{1}
19 |     $$
20 | 
21 |     $$
22 |     \begin{aligned}
23 |     MACCT =& \left[ |\theta| \text{ params} * p \text{ bits} \right] + \left[ 2 \text{ loras} * (|\theta| \text{ params} * r) * \left( q \text{ bits} + 2 \text{ states} * (4*8) \text{ bits} \right) \right] \\
24 |     =& |\theta|(p + 2r(q + 64))
25 |     \end{aligned} \tag{2}
26 |     $$
27 | 
28 |     Assuming $p = 16$, $q = 32$ (LoRA adapters are trained in 32-bit by default) and $r = 0.03$ (~3% of base model size), and a 1B parameter model for each translation direction in $DMCCT$:
29 | 
30 |     $$
31 |     \begin{aligned}
32 |     DMCCT =& 2 * 1e9 * (16 + 64) \\
33 |     =& 1.6e11 \text{ bits} \\
34 |     \\
35 |     DMCCT =& MACCT \\
36 |     1.6e11 \text{ bits} =& N * (16 + 2 * 0.03 * (32 + 64)) \\
37 |     N =& \frac{1.6e11}{21.76} \\
38 |     \approx& 7.35B \text{ params} \\
39 |     \end{aligned}
40 |     $$
41 | 
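42 |     As a sanity check, the arithmetic above is easy to reproduce in a few lines of Python (using the same assumed values for $p$, $q$ and $r$):
43 | 
44 |     ```python
45 |     # Static memory in bits: weights plus AdamW optimizer states (2 fp32 states per parameter)
46 |     p, q, r = 16, 32, 0.03          # base bits/param, LoRA bits/param, LoRA-to-base size ratio
47 |     dmcct = 2 * 1e9 * (p + 64)      # two full 1B-parameter models
48 | 
49 |     # Solve DMCCT == MACCT for the base model size N
50 |     n = dmcct / (p + 2 * r * (q + 64))
51 |     print(f"{n / 1e9:.2f}B params")  # ~7.35B
52 |     ```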
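53 | 
54 | To make the mechanism concrete, the following is a minimal sketch of the adapter swapping that MACCT relies on, written against the PEFT API directly (the model and adapter names are illustrative):
55 | 
56 | ```python
57 | from peft import LoraConfig, get_peft_model
58 | from transformers import AutoModelForCausalLM
59 | 
60 | base = AutoModelForCausalLM.from_pretrained("gpt2")  # shared, frozen base weights
61 | config = LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32)
62 | 
63 | # Attach two adapters to the same base model
64 | model = get_peft_model(base, config, adapter_name="A")  # learns A->B
65 | model.add_adapter("B", config)  # learns B->A
66 | 
67 | # Only the active adapter's parameters receive gradients, so training
68 | # alternates between the two translation directions by swapping adapters
69 | model.set_adapter("A")
70 | model.set_adapter("B")
71 | ```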
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Cycleformers
2 | site_url: https://wrmthorne.github.io/cycleformers
3 | repo_name: wrmthorne/cycleformers
4 | repo_url: https://github.com/wrmthorne/cycleformers
5 |
6 | theme:
7 | name: material
8 | palette:
9 | # Palette toggle for automatic mode
10 | - media: "(prefers-color-scheme)"
11 | scheme: default
12 | primary: black
13 | toggle:
14 | icon: material/brightness-auto
15 | name: Switch to light mode
16 |
17 | # Palette toggle for light mode
18 | - media: "(prefers-color-scheme: light)"
19 | scheme: default
20 | primary: black
21 | toggle:
22 | icon: material/brightness-7
23 | name: Switch to dark mode
24 |
25 | # Palette toggle for dark mode
26 | - media: "(prefers-color-scheme: dark)"
27 | scheme: slate
28 | primary: black
29 | toggle:
30 | icon: material/brightness-4
31 | name: Switch to system preference
32 |
33 | features:
34 | - navigation.top
35 | - navigation.tabs
36 | - navigation.path
37 | - navigation.footer
38 | - navigation.instant
39 | - navigation.sections
40 | - navigation.tracking
41 | - navigation.instant.prefetch
42 | - navigation.instant.progress
43 | - content.code.annotate
44 | - content.code.copy
45 |
46 | nav:
47 | - Get Started:
48 | - Cycleformers: index.md
49 | - Installation: installation.md
50 | - Usage: usage.md
51 | - Performance: performance.md
52 | - Conceptual Reference:
53 | - Conceptual Overview: conceptual_reference/index.md
54 | - Cycle-Consistency Training: conceptual_reference/cycle_consistency_training.md
55 | - MACCT: conceptual_reference/macct.md
56 | - Examples:
57 | - Example: examples/index.md
58 | - CycleNER: examples/cycle_ner.md
59 |     - WMT2014: examples/wmt_2014.md
60 | - API Reference:
61 | - CycleTrainer: api_reference/cycle_trainer.md
62 | - Configuration: api_reference/configuration.md
63 | - Cycles: api_reference/cycles.md
64 | - Task Processors: api_reference/task_processors.md
65 |
66 | markdown_extensions:
67 | - pymdownx.highlight:
68 | anchor_linenums: true
69 | line_spans: __span
70 | pygments_lang_class: true
71 | - pymdownx.superfences:
72 | custom_fences:
73 | - name: mermaid
74 | class: mermaid
75 | format: !!python/name:pymdownx.superfences.fence_code_format
76 |
77 | plugins:
78 | - mkdocstrings:
79 | default_handler: python
80 | handlers:
81 | python:
82 | paths: [cycleformers] # Path to your source code
83 | options:
84 | show_source: true
85 | show_root_heading: true
86 | heading_level: 1
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 |
2 | # Contributing Guide
3 |
4 | We are interested in external contributions to Cycleformers, particularly from those experienced in automated testing of language models and in multi-device training.
5 |
6 | Please contribute code via pull requests. For small bug fixes or improvements, please feel free to submit a PR directly to the main branch. For larger bugs or features, please open an issue for discussion first. If you would like to request a feature, please also open an issue. For any PRs, please ensure that you have added tests and documentation where appropriate.
7 |
8 | All new features should be developed on a new branch and use [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/) prefixes. The following are recommended:
9 |
10 | - `fix:` for bug fixes
11 | - `feat:` for new features
12 | - `build:` for changes that affect the build system or external dependencies
13 | - `refactor:` for code refactoring
14 | - `ci:` for changes to CI configuration files and scripts
15 | - `docs:` for documentation changes
16 | - `perf:` for performance improvements
17 | - `test:` for adding tests
18 | - `chore:` for other changes that don't fit into the other categories
19 |
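20 | For example, commits on a feature branch might look like this (illustrative messages):
21 | 
22 | ```
23 | feat: add adapter-swapping support to CycleTrainer
24 | fix: guard against empty CLI argument lists
25 | docs: document MACCT memory savings
26 | ```
27 | 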
28 | Although this rule is not strictly enforced for commits, pull requests must be titled with a valid conventional commit prefix or the merge pipeline will fail. Pull requests should be made to main and will not be released immediately. Changes will be accumulated on main until a release is decided. The release workflow will handle versioning, publishing, changelog generation, etc. and is triggered by a release workflow dispatch event, selecting `main` as the branch.
29 | 
30 | ## Development Setup
31 | 
32 | We use Poetry to manage dependencies on Python 3.11. To install Poetry, follow the documentation here: https://python-poetry.org/docs/master/#installing-with-the-official-installer
33 | 
34 | We recommend the following command to set up the build environment for the module:
35 | 
36 | ```bash
37 | make init
38 | ```
39 | 
40 | Many other useful commands are available in the Makefile, which automate formatting, linting, testing and building the documentation:
41 | 
42 | - `make format` automatically fixes code style issues and standardizes formatting across the codebase.
43 | - `make lint` checks for code quality issues, potential bugs, and style violations.
44 | - `make audit` performs security vulnerability scanning to identify potential security risks.
45 | - `make test` executes the automated test suite to verify code functionality.
46 | - `make build-docs` generates the HTML documentation from source files based on the configuration in `mkdocs.yml`.
47 | - `make all` runs all of the above commands in sequence.
48 | - `make clean` removes build and distribution artifacts.
49 | - `make info` displays information about the project, including the project name, version, and Python version.
50 | 
51 | 
52 | ## Project Procedures
53 | 
54 | 
--------------------------------------------------------------------------------
/tests/test_examples.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import tempfile
3 | from pathlib import Path
4 |
5 | import pytest
6 | import yaml
7 |
8 |
9 | SAMPLES_DIR = Path(__file__).parent / "sample_data"
10 |
11 | yaml_base = {
12 | "per_device_train_batch_size": 1,
13 | "gradient_accumulation_steps": 1,
14 | "per_device_eval_batch_size": 1,
15 | "save_strategy": "steps",
16 | "max_steps": 3,
17 | "eval_steps": 1,
18 | "save_steps": 1,
19 | "logging_strategy": "steps",
20 | "logging_steps": 1,
21 | }
22 |
23 | lora_config = {
24 | "use_peft": True,
25 | "lora_r": 8,
26 | "lora_alpha": 16,
27 | }
28 |
29 | causal_yaml = {
30 | "model_name_or_path": "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
31 | "A": {
32 | "model_name_or_path": "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
33 | },
34 | "B": {
35 | "model_name_or_path": "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
36 | },
37 | } | yaml_base
38 |
39 | seq2seq_yaml = {
40 | "model_name_or_path": "google/flan-t5-small",
41 | "A": {
42 | "model_name_or_path": "google/flan-t5-small",
43 | },
44 | "B": {
45 | "model_name_or_path": "google/flan-t5-small",
46 | },
47 | } | yaml_base
48 |
49 | causal_yaml_macct = {**lora_config, "lora_task_type": "CAUSAL_LM"} | causal_yaml
50 |
51 | seq2seq_yaml_macct = {**lora_config, "lora_task_type": "SEQ_2_SEQ_LM"} | seq2seq_yaml
52 |
53 |
54 | @pytest.mark.slow
55 | @pytest.mark.requires_gpu
56 | @pytest.mark.parametrize(
57 | "example_script,dataset_name",
58 | [("cycle_ner/train.py", SAMPLES_DIR / "conll2003"), ("translation_wmt14/train.py", SAMPLES_DIR / "wmt14")],
59 | )
60 | @pytest.mark.parametrize(
61 | "config_yaml",
62 | [
63 | ("causal-base", "causal.yaml", causal_yaml),
64 | ("seq2seq-base", "seq2seq.yaml", seq2seq_yaml),
65 | ("causal-macct", "causal-macct.yaml", causal_yaml_macct),
66 | ("seq2seq-macct", "seq2seq-macct.yaml", seq2seq_yaml_macct),
67 | ],
68 | )
69 | def test_example_scripts(example_script, dataset_name, config_yaml, temp_dir):
70 | out_dirname, filename, config = config_yaml
71 | out_dir = Path(temp_dir) / out_dirname
72 |
73 | config["output_dir"] = str(out_dir)
74 | config["dataset_name"] = str(dataset_name)
75 | yaml_file = Path(temp_dir) / filename
76 | with open(yaml_file, "w") as f:
77 | yaml.dump(config, f)
78 |
79 | project_root = Path(__file__)
80 | while project_root.name != "cycleformers":
81 | project_root = project_root.parent
82 |
83 | command = f"poetry run python {project_root}/examples/{example_script} {yaml_file}"
84 |
85 | result = subprocess.run(command, shell=True, capture_output=True, text=True)
86 | # Check if the process completed successfully
87 | assert result.returncode == 0, f"Process failed with error: {result.stderr}"
88 |
89 |     # Check for model checkpoints
90 | checkpoint_dir = Path(config["output_dir"])
91 | assert checkpoint_dir.exists(), "Checkpoint directory was not created"
92 | assert any(checkpoint_dir.iterdir()), "No checkpoints were saved"
93 |
--------------------------------------------------------------------------------
/docs/performance.md:
--------------------------------------------------------------------------------
1 | # Performance
2 |
3 | ## 1 Potential Ways to Improve Evaluation Metric Performance
4 |
5 |
6 | rsLoRA ([Kalajdzievski 2023](https://huggingface.co/papers/2312.03732)) scales adapters during each forward pass by `lora_alpha/math.sqrt(r)`, which stabilizes performance at higher ranks $r$ ([rsLoRA docs](https://huggingface.co/papers/2312.03732)).
7 |
8 | DoRA ([Liu et al. 2024](https://arxiv.org/abs/2402.09353)) decomposes weight updates into magnitude and direction components, which the authors show correlate better with full fine-tuning loss signals. The technique is particularly useful at low ranks but can incur a significant speed penalty. Much of this speed can be recovered, at the expense of higher VRAM usage, by setting `ephemeral_gpu_offload=True`; more information can be found in the PEFT DoRA documentation.
9 |
10 |
11 | LoRA+ ([Hayou et al. 2024](https://arxiv.org/abs/2402.12354)) is an optimisation strategy for LoRA that allows for different learning rates for adapter matrices A and B. This can increase fine-tuning speed by up to 2x and *potentially* boost performance on some tasks by 1-2% ([LoRA+ docs](https://arxiv.org/abs/2402.12354)).
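12 | 
13 | Both rsLoRA and DoRA are exposed as flags on PEFT's `LoraConfig`, and LoRA+ has an optimizer helper in recent PEFT releases. A minimal sketch, assuming a recent PEFT version (all values are illustrative):
14 | 
15 | ```python
16 | from peft import LoraConfig
17 | 
18 | config = LoraConfig(
19 |     task_type="CAUSAL_LM",
20 |     r=64,
21 |     lora_alpha=32,
22 |     use_rslora=True,  # scale by lora_alpha / sqrt(r) instead of lora_alpha / r
23 |     use_dora=True,    # decompose updates into magnitude and direction components
24 | )
25 | 
26 | # LoRA+ (different learning rates for the A and B matrices) via PEFT's optimizer helper:
27 | # from peft.optimizers import create_loraplus_optimizer
28 | # optimizer = create_loraplus_optimizer(model, torch.optim.AdamW, lr=2e-4, loraplus_lr_ratio=16)
29 | ```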
30 | 
31 | 
32 | ## 2 Potential Ways to Improve Throughput
33 | 
34 | ### 2.1 Tokenization
35 | 
36 | Wherever possible, we recommend using the same model for each direction of translation; at the very least, we recommend using models that share a tokenizer, i.e. are from the same generation of the same model family.
37 | 
38 | Sending tokens from the GPU back to the CPU to detokenize, manipulating them as strings, then re-tokenizing and sending them back to the GPU is costly. This is particularly significant for causal models, which require sequences to be split and concatenated to produce the correct input_ids for training. We can skip this overhead by manipulating tokens as tensors on the GPU if we know that the tokenizers are compatible.
39 | 
40 | Just having a compatible tokenizer is not always sufficient. If you perform any custom processing of synthetic samples before they are given to the training model, such as applying a chat template, you may find that the tokenization overhead is unavoidable. A quick compatibility check is sketched below.
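41 | 
42 | As a rough first check that two checkpoints can share token ids directly, their vocabularies and special tokens can be compared (a minimal sketch; the model names are illustrative):
43 | 
44 | ```python
45 | from transformers import AutoTokenizer
46 | 
47 | tok_a = AutoTokenizer.from_pretrained("google/flan-t5-small")
48 | tok_b = AutoTokenizer.from_pretrained("google/flan-t5-base")
49 | 
50 | # Identical vocabularies and special tokens mean token ids can be reused
51 | # directly, letting synthetic samples stay on the GPU as tensors between cycles.
52 | compatible = (
53 |     tok_a.get_vocab() == tok_b.get_vocab()
54 |     and tok_a.special_tokens_map == tok_b.special_tokens_map
55 | )
56 | print(compatible)
57 | ```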
58 | 
59 | ### 2.2 Specific Attention Kernels and Implementations
60 | 
61 | Flash Attention 2 ([Dao 2023](https://arxiv.org/abs/2307.08691)) is a drop-in replacement for the standard attention mechanism that significantly reduces memory usage and increases throughput. It can be installed via `pip install flash-attn` (see their [github](https://github.com/Dao-AILab/flash-attention) if you have installation issues) and enabled by setting `attn_implementation="flash_attention_2"` when loading your model.
62 | 
63 | The Liger Kernel ([Hsu et al. 2024](https://arxiv.org/abs/2410.10989)) is a set of Triton kernels for efficient training of transformer models. The authors claim up to a 20% throughput increase and a 60% reduction in GPU memory usage. It can be installed via `pip install liger-kernel` and enabled by setting `use_liger_kernel=True` in your training arguments.
64 | 
65 | These methods can be combined for a very significant throughput increase and reduction in VRAM usage.
66 | 
67 | ```python
68 | model = AutoModelForCausalLM.from_pretrained(
69 |     # ...
70 |     attn_implementation="flash_attention_2",
71 | )
72 | args = CycleTrainingArguments(
73 |     # ...
74 |     use_liger_kernel=True,
75 | )
76 | ```
77 | 
78 | 🚧 # TODO: Test whether you can compile, use optimum, etc. on base model when using MACCT.
--------------------------------------------------------------------------------
/src/cycleformers/task_processors/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from dataclasses import dataclass
3 | from os import PathLike
4 | from pathlib import Path
5 |
6 | from datasets import DatasetDict, load_dataset, load_from_disk
7 |
8 |
9 | @dataclass
10 | class ProcessorConfig:
11 | """Configuration class for dataset processors.
12 |
13 | This class defines the common configuration parameters used across different processors.
14 | Each specific processor can extend this with additional task-specific parameters.
15 | """
16 |
17 | dataset_name: str | PathLike[str] | None = None
18 | eval_split_ratio: float = 0.2
19 | dataset_seed: int = 42
20 | cache_dir: str | None = None
21 | # split: list[str] | str | None = None FIXME: not a key feature right now - taking too much time
22 | max_samples: int | None = None
23 | streaming: bool = False # FIXME: not a key feature right now - taking too much time
24 |
25 |
26 | class BaseProcessor(ABC):
27 | """Base class for dataset processors.
28 |
29 | This abstract base class defines the interface and common functionality for all dataset processors.
30 | Each specific format/task processor should inherit from this class and implement the required methods.
31 |
32 | The processor handles:
33 | 1. Loading source datasets
34 | 2. Converting to cycleformers format by splitting into two separate datasets A and B
35 | 3. Applying common transformations (train/val splitting, shuffling)
36 |
37 | If the dataset already has a train/val/test split, those splits will be preserved and only train will be made
38 | non-parallel.
39 |
40 | Args:
41 | config: Configuration object controlling processor behavior. If not provided, uses default CONFIG_CLS.
42 |
43 | Example:
44 | >>> config = ProcessorConfig(eval_split_ratio=0.2, dataset_seed=42)
45 | >>> processor = ConcreteProcessor(config)
46 | >>> dataset_A, dataset_B = processor.process()
47 | >>> print(dataset_A.keys())
48 | dict_keys(['train', 'test'])
49 | """
50 |
51 | def __init__(self, config: ProcessorConfig = ProcessorConfig()):
52 | self.config = config
53 |
54 | def load(self) -> DatasetDict:
55 | """Load the source dataset. Override for custom loading logic."""
56 | if self.config.dataset_name is None:
57 | raise ValueError("No dataset name was provided. Cannot load `None`.")
58 |
59 | if Path(self.config.dataset_name).exists():
60 | return load_from_disk(self.config.dataset_name)
61 |
62 | return load_dataset(
63 | self.config.dataset_name,
64 | cache_dir=self.config.cache_dir,
65 | streaming=self.config.streaming,
66 | )
67 |
68 | @abstractmethod
69 |     def preprocess(self, dataset: DatasetDict) -> tuple[DatasetDict, DatasetDict]:
70 | """Preprocess the dataset into two separate datasets A and B."""
71 | raise NotImplementedError
72 |
73 |     def process(self) -> tuple[DatasetDict, DatasetDict]:
74 | """Process the dataset into two separate datasets A and B.
75 |
76 | Returns:
77 | Tuple[DatasetDict, DatasetDict]: Two datasets A and B, each containing 'train' and 'test' splits
78 | """
79 | dataset = self.load()
80 |
81 | if not isinstance(dataset, DatasetDict):
82 | dataset = DatasetDict({"train": dataset})
83 |
84 |         if set(dataset.keys()) == {"train"}:
85 | dataset = dataset.train_test_split(test_size=self.config.eval_split_ratio, seed=self.config.dataset_seed)
86 |
87 | dataset_A, dataset_B = self.preprocess(dataset)
88 |
89 | dataset_A["train"] = dataset_A["train"].shuffle(seed=self.config.dataset_seed)
90 | dataset_B["train"] = dataset_B["train"].shuffle(seed=self.config.dataset_seed + 1)
91 |
92 | return dataset_A, dataset_B
93 |
--------------------------------------------------------------------------------
/examples/cycle_ner/train.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
2 |
3 | from cycleformers import (
4 | CfArgumentParser,
5 | CycleTrainer,
6 | CycleTrainingArguments,
7 | ModelConfig,
8 | ModelConfigA,
9 | ModelConfigB,
10 | merge_configs,
11 | )
12 | from cycleformers.import_utils import is_liger_kernel_available
13 | from cycleformers.task_processors.ner import CONLL2003Processor, CONLL2003ProcessorConfig
14 | from cycleformers.utils import VALID_LIGER_MODELS, get_peft_config
15 |
16 |
17 | if is_liger_kernel_available():
18 | from liger_kernel.transformers import AutoLigerKernelForCausalLM
19 |
20 |
21 | def get_model_and_tokenizer(model_config, training_args):
22 | """Initialize model and tokenizer from config"""
23 | config = AutoConfig.from_pretrained(
24 | model_config.model_name_or_path,
25 | trust_remote_code=model_config.trust_remote_code,
26 | )
27 | config.use_cache = False
28 |
29 | model_kwargs = {}
30 |
31 | if not config.is_encoder_decoder:
32 | if is_liger_kernel_available() and model_config.use_liger and config.model_type in VALID_LIGER_MODELS:
33 | model_class = AutoLigerKernelForCausalLM
34 | model_kwargs["use_liger_kernel"] = training_args.use_liger_kernel
35 | else:
36 | model_class = AutoModelForCausalLM
37 | else:
38 | model_class = AutoModelForSeq2SeqLM
39 |
40 | model = model_class.from_pretrained(
41 | model_config.model_name_or_path,
42 | revision=model_config.model_revision,
43 | config=config,
44 | trust_remote_code=model_config.trust_remote_code,
45 | attn_implementation=model_config.attn_implementation,
46 | torch_dtype=model_config.torch_dtype,
47 |         device_map="auto", **model_kwargs
48 |     )
49 |
50 | if training_args.gradient_checkpointing:
51 | model.enable_input_require_grads()
52 |
53 | # Print the actual dtype of the first parameter
54 | print(f"Model weights dtype: {next(model.parameters()).dtype}")
55 |
56 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
57 | return model, tokenizer
58 |
59 |
60 | def main():
61 | parser = CfArgumentParser(
62 | (CycleTrainingArguments, ModelConfig, ModelConfigA, ModelConfigB, CONLL2003ProcessorConfig), task="train"
63 | )
64 | args, model_config_base, model_config_A, model_config_B, conll_config = parser.parse_args_and_config()
65 | model_config_base = merge_configs(model_config_base, model_config_A, model_config_B)
66 | args.model_config = model_config_base
67 |
68 | task_processor = CONLL2003Processor(conll_config)
69 | dataset_A, dataset_B = task_processor.process()
70 |
71 | # Get model A using merged A config
72 | model_A, tokenizer_A = get_model_and_tokenizer(args.model_config.A, args)
73 |
74 | # Train by adapter swapping
75 | if not args.use_macct:
76 | # Get model B using merged B config
77 | model_B, tokenizer_B = get_model_and_tokenizer(args.model_config.B, args)
78 | models = {"A": model_A, "B": model_B}
79 | tokenizers = {"A": tokenizer_A, "B": tokenizer_B} if tokenizer_A != tokenizer_B else tokenizer_A
80 | else:
81 | models = model_A
82 | tokenizers = tokenizer_A
83 |
84 | trainer = CycleTrainer(
85 | args=args,
86 | models=models,
87 | tokenizers=tokenizers,
88 | train_dataset_A=dataset_A["train"],
89 | train_dataset_B=dataset_B["train"],
90 |         eval_dataset_A=dataset_A["test"] if args.eval_strategy != "no" else None,
91 |         eval_dataset_B=dataset_B["test"] if args.eval_strategy != "no" else None,
92 | peft_configs=get_peft_config(model_config_base),
93 | )
94 |
95 | trainer.train()
96 |
97 |
98 | if __name__ == "__main__":
99 | main()
100 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cycleformers
2 |
3 |
4 |
5 | [](https://www.python.org/downloads/)
6 | [](https://pypi.org/project/cycleformers/)
7 | [](https://creativecommons.org/licenses/by/4.0/)
8 | [](https://codecov.io/gh/wrmthorne/cycleformers)
9 | [](https://github.com/wrmthorne/cycleformers/actions/workflows)
10 |
11 |
12 |
13 | A Python library for efficient cycle-consistency training of transformer models. Cycleformers simplifies iterative back-translation with support for both causal and seq2seq architectures. We also implement Multi-Adapter Cycle-Consistency Training (MACCT), which trains LoRA adapters on a frozen base model for `~7.5x` larger model capacity at the same memory footprint.
14 |
15 | ## Features
16 |
17 | - 🤗 Seamless integration with Hugging Face Transformers
18 | - 🚀 PEFT/LoRA support for memory-efficient training
19 | - 🤖 Compatible with both causal and seq2seq models
20 | - 🔥 Optimized for various hardware configurations
21 |
22 |
23 | ## Quick Tour
24 |
25 | ### Installation
26 |
27 | ```bash
28 | pip install cycleformers
29 | ```
30 |
31 | ### Training
32 |
33 | The `CycleTrainer` class is an extension, and significant redesign, of the 🤗 Transformers trainer, designed to abstract away the specifics of cycle-consistency training while remaining configurable. Both Seq2Seq and Causal architectures are supported, and each can be trained via PEFT adapter swapping for memory-efficient configurations. Check the [docs] for [usage] details and [examples].
34 |
35 | To train two identical models using two datasets, the following sample code can be used:
36 |
37 | ```python
38 | from cycleformers import CycleTrainer, CycleTrainingArguments
39 |
40 | model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
41 | tokenizer = AutoTokenizer.from_pretrained("gpt2")
42 |
43 | args = CycleTrainingArguments(output_dir="gpt2-cct")
44 | trainer = CycleTrainer(
45 | args,
46 |     models = model,
47 |     tokenizers = tokenizer,
48 | train_dataset_A = dataset_A,
49 | train_dataset_B = dataset_B
50 | )
51 | trainer.train()
52 | ```
53 |
54 | Any two models (🚧 currently both seq2seq or both causal) can be combined for completely customisable training:
55 |
56 | ```python
57 | model_A = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
58 | model_B = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base", device_map="auto")
59 | tokenizer_A = AutoTokenizer.from_pretrained("gpt2")
60 | tokenizer_B = AutoTokenizer.from_pretrained("google/flan-t5-base")
61 |
62 | trainer = CycleTrainer(
63 | args,
64 | models = {
65 | "A": model_A,
66 | "B": model_B
67 |     },
68 | tokenizers = {
69 | "A": tokenizer_A,
70 | "B": tokenizer_B
71 |     },
72 | train_dataset_A = dataset_A,
73 | train_dataset_B = dataset_B
74 | )
75 | ```
76 |
77 | ### Multi-Adapter Cycle-Consistency Training (MACCT)
78 |
79 | The `CycleTrainer` class is also set up to accept a single base model and train two PEFT adapters on top of it, switching between them to emulate the two-model setup. This allows the training of `~7.5x larger` models for the same memory footprint:
80 |
81 | ```python
82 | peft_config = LoraConfig(
83 | task_type="CAUSAL_LM",
84 | r=16,
85 | lora_alpha=32,
86 | target_modules="all-linear",
87 | inference_mode=False,
88 | bias="none"
89 | )
90 |
91 | args = CycleTrainingArguments(output_dir="gpt2-macct")
92 | trainer = CycleTrainer(
93 | args,
94 |     models = model,
95 |     tokenizers = tokenizer,
96 |     peft_configs = peft_config # Or a dict with keys "A" and "B"
97 | )
98 | ```
99 |
100 |
101 |
102 | ## Citing
103 |
104 | If you use Cycleformers in your research, please cite:
105 |
106 | ```bibtex
107 | add once zenodo/paper citation is available
108 | ```
--------------------------------------------------------------------------------
/examples/translation_wmt14/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 |
4 | from datasets import load_from_disk
5 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
6 |
7 | from cycleformers import (
8 | CfArgumentParser,
9 | CycleTrainer,
10 | CycleTrainingArguments,
11 | ModelConfig,
12 | ModelConfigA,
13 | ModelConfigB,
14 | merge_configs,
15 | )
16 | from cycleformers.import_utils import is_liger_kernel_available
17 | from cycleformers.task_processors.translation import TranslationProcessor, TranslationProcessorConfig
18 | from cycleformers.utils import VALID_LIGER_MODELS, get_peft_config
19 |
20 |
21 | if is_liger_kernel_available():
22 | from liger_kernel.transformers import AutoLigerKernelForCausalLM
23 |
24 |
25 | def get_model_and_tokenizer(model_config, training_args):
26 | """Initialize model and tokenizer from config"""
27 | config = AutoConfig.from_pretrained(
28 | model_config.model_name_or_path,
29 | trust_remote_code=model_config.trust_remote_code,
30 | )
31 | config.use_cache = False
32 |
33 | model_kwargs = {}
34 |
35 | if not config.is_encoder_decoder:
36 | if is_liger_kernel_available() and model_config.use_liger and config.model_type in VALID_LIGER_MODELS:
37 | model_class = AutoLigerKernelForCausalLM
38 | model_kwargs["use_liger_kernel"] = training_args.use_liger_kernel
39 | else:
40 | model_class = AutoModelForCausalLM
41 | else:
42 | model_class = AutoModelForSeq2SeqLM
43 |
44 | model = model_class.from_pretrained(
45 | model_config.model_name_or_path,
46 | revision=model_config.model_revision,
47 | config=config,
48 | trust_remote_code=model_config.trust_remote_code,
49 | attn_implementation=model_config.attn_implementation,
50 | torch_dtype=model_config.torch_dtype,
51 | device_map="auto",
52 | **model_kwargs,
53 | )
54 |
55 | if training_args.gradient_checkpointing:
56 | model.enable_input_require_grads()
57 |
58 | # Print the actual dtype of the first parameter
59 | print(f"Model weights dtype: {next(model.parameters()).dtype}")
60 |
61 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
62 | return model, tokenizer
63 |
64 |
65 | def main():
66 | parser = CfArgumentParser(
67 | (CycleTrainingArguments, ModelConfig, ModelConfigA, ModelConfigB, TranslationProcessorConfig), task="train"
68 | )
69 | args, model_config_base, model_config_A, model_config_B, translation_config = parser.parse_args_and_config()
70 | model_config_base = merge_configs(model_config_base, model_config_A, model_config_B)
71 | args.model_config = model_config_base
72 |
73 | task_processor = TranslationProcessor(translation_config)
74 | dataset_A, dataset_B = task_processor.process()
75 |
76 | # Get model A using merged A config
77 | model_A, tokenizer_A = get_model_and_tokenizer(args.model_config.A, args)
78 |
79 | # Train by adapter swapping
80 | if not args.use_macct:
81 | # Get model B using merged B config
82 | model_B, tokenizer_B = get_model_and_tokenizer(args.model_config.B, args)
83 | models = {"A": model_A, "B": model_B}
84 | tokenizers = {"A": tokenizer_A, "B": tokenizer_B} if tokenizer_A != tokenizer_B else tokenizer_A
85 | else:
86 | models = model_A
87 | tokenizers = tokenizer_A
88 |
89 | trainer = CycleTrainer(
90 | args=args,
91 | models=models,
92 | tokenizers=tokenizers,
93 | train_dataset_A=dataset_A["train"],
94 | train_dataset_B=dataset_B["train"],
95 |         eval_dataset_A=dataset_A["test"] if args.eval_strategy != "no" else None,
96 |         eval_dataset_B=dataset_B["test"] if args.eval_strategy != "no" else None,
97 | peft_configs=get_peft_config(model_config_base),
98 | )
99 |
100 | trainer.train()
101 |
102 |
103 | if __name__ == "__main__":
104 | main()
105 |
--------------------------------------------------------------------------------
/tests/configs/test_model_config.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from cycleformers.model_config import ModelConfig, ModelConfigA, ModelConfigB, merge_configs
4 |
5 |
6 | class TestMergeConfigs:
7 | @pytest.mark.parametrize(
8 | "base_config, config_a, config_b, expected_a, expected_b",
9 | [
10 | # Test case 1: Basic override of model names
11 | (
12 | ModelConfig(model_name_or_path="base"),
13 | ModelConfigA(A_model_name_or_path="model_a"),
14 | ModelConfigB(B_model_name_or_path="model_b"),
15 | "model_a",
16 | "model_b",
17 | ),
18 | # Test case 2: Override with default values (should not override)
19 | (
20 | ModelConfig(lora_r=64),
21 | ModelConfigA(A_lora_r=16), # default value
22 | ModelConfigB(B_lora_r=32),
23 | 64, # should keep base value since A uses default
24 | 32,
25 | ),
26 | # Test case 3: Multiple field overrides
27 | (
28 | ModelConfig(model_name_or_path="base", lora_r=16, lora_alpha=32),
29 | ModelConfigA(A_model_name_or_path="model_a", A_lora_r=64, A_lora_alpha=128),
30 | ModelConfigB(B_model_name_or_path="model_b", B_lora_r=32, B_lora_alpha=64),
31 | (64, 128), # (lora_r, lora_alpha)
32 | (32, 64),
33 | ),
34 | ],
35 | ids=["basic_override", "default_value_handling", "multiple_fields"],
36 | )
37 | def test_merge_configs_parametrized(self, base_config, config_a, config_b, expected_a, expected_b):
38 | result = merge_configs(base_config, config_a, config_b)
39 |
40 | if isinstance(expected_a, tuple):
41 | # Handle multiple field test case
42 | assert result.A.lora_r == expected_a[0]
43 | assert result.A.lora_alpha == expected_a[1]
44 | assert result.B.lora_r == expected_b[0]
45 | assert result.B.lora_alpha == expected_b[1]
46 | else:
47 | # Handle single field test cases
48 | if isinstance(expected_a, str):
49 | assert result.A.model_name_or_path == expected_a
50 | assert result.B.model_name_or_path == expected_b
51 | else:
52 | assert result.A.lora_r == expected_a
53 | assert result.B.lora_r == expected_b
54 |
55 | def test_merge_configs_preserves_base_values(self):
56 | base = ModelConfig(model_name_or_path="base", lora_r=32, trust_remote_code=True)
57 | config_a = ModelConfigA(A_model_name_or_path="model_a")
58 | config_b = ModelConfigB(B_model_name_or_path="model_b")
59 |
60 | result = merge_configs(base, config_a, config_b)
61 |
62 | # Check that base values are preserved
63 | assert result.A.lora_r == 32
64 | assert result.A.trust_remote_code is True
65 | assert result.B.lora_r == 32
66 | assert result.B.trust_remote_code is True
67 |
68 | def test_merge_configs_list_handling(self):
69 | base = ModelConfig(lora_target_modules=["query", "value"])
70 | config_a = ModelConfigA(A_lora_target_modules=["key"])
71 | config_b = ModelConfigB(B_lora_target_modules=["output"])
72 |
73 | result = merge_configs(base, config_a, config_b)
74 |
75 | assert result.A.lora_target_modules == ["key"]
76 | assert result.B.lora_target_modules == ["output"]
77 |
78 | def test_merge_configs_original_unmodified(self):
79 | base = ModelConfig(model_name_or_path="base", lora_r=32)
80 | config_a = ModelConfigA(A_model_name_or_path="model_a", A_lora_r=64)
81 | config_b = ModelConfigB(B_model_name_or_path="model_b", B_lora_r=128)
82 |
83 | # Store original values
84 | original_base_model = base.model_name_or_path
85 | original_base_lora_r = base.lora_r
86 |
87 | merge_configs(base, config_a, config_b)
88 |
89 | # Check original configs weren't modified
90 | assert base.model_name_or_path == original_base_model
91 | assert base.lora_r == original_base_lora_r
92 |
--------------------------------------------------------------------------------
/tests/utils/test_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 | import pytest
5 |
6 | from cycleformers.utils import prefixed_view
7 |
8 |
9 | # Test fixtures
10 | @dataclass
11 | class SimpleConfig:
12 | name: str | None = None
13 | value: int = 42
14 |
15 |
16 | @dataclass
17 | class ComplexConfig:
18 | name: str
19 | values: list[int] = field(default_factory=list)
20 | optional: Optional[str] = None
21 |
22 |
23 | @dataclass(frozen=True)
24 | class FrozenConfig:
25 | name: str
26 | value: int = 42
27 |
28 |
29 | class TestPrefixedView:
30 | def test_basic_creation(self):
31 | @prefixed_view(SimpleConfig, "test_")
32 | class TestConfig:
33 | pass
34 |
35 | assert TestConfig.__annotations__ == {"test_name": str | None, "test_value": int}
36 |
37 | @pytest.mark.parametrize(
38 | "input_values,expected",
39 | [
40 | ({"test_name": "example"}, SimpleConfig(name="example", value=42)),
41 | ({"test_name": "test", "test_value": 100}, SimpleConfig(name="test", value=100)),
42 | ({}, SimpleConfig(name=None, value=42)),
43 | ],
44 | )
45 | def test_valid_instances(self, input_values, expected):
46 | @prefixed_view(SimpleConfig, "test_")
47 | class TestConfig:
48 | pass
49 |
50 | instance = TestConfig(**input_values)
51 | assert instance == expected
52 |
53 | @pytest.mark.parametrize(
54 | "prefix",
55 | [
56 | "test_",
57 | "_",
58 | "PREFIX_",
59 | "A_",
60 | ],
61 | )
62 | def test_prefix_styles(self, prefix):
63 | @prefixed_view(SimpleConfig, prefix)
64 | class TestConfig:
65 | pass
66 |
67 | # Verify prefixed attributes exist
68 | assert all(name.startswith(prefix) for name in TestConfig.__annotations__)
69 |
70 | def test_default_factory(self):
71 | @prefixed_view(ComplexConfig, "test_")
72 | class TestConfig:
73 | pass
74 |
75 | instance = TestConfig(test_name="test")
76 | assert instance.values == []
77 |
78 | instance = TestConfig(test_name="test", test_values=[1, 2, 3])
79 | assert instance.values == [1, 2, 3]
80 |
81 | def test_frozen_dataclass(self):
82 | @prefixed_view(FrozenConfig, "test_")
83 | class TestConfig:
84 | pass
85 |
86 | instance = TestConfig(test_name="test")
87 | assert instance.name == "test"
88 | assert instance.value == 42
89 |
90 | @pytest.mark.parametrize(
91 | "invalid_base,prefix,expected_error",
92 | [
93 | (None, "test_", TypeError), # Missing base class
94 | (dict, "test_", TypeError), # not a dataclass
95 | ],
96 | )
97 | def test_creation_errors(self, invalid_base, prefix, expected_error):
98 | with pytest.raises(expected_error):
99 |
100 | @prefixed_view(invalid_base, prefix)
101 | class TestConfig:
102 | pass
103 |
104 | def test_nested_dataclass(self):
105 | @dataclass
106 | class NestedConfig:
107 | config: SimpleConfig
108 | name: str
109 |
110 | @prefixed_view(NestedConfig, "test_")
111 | class TestNestedConfig:
112 | pass
113 |
114 | nested = SimpleConfig(name="nested")
115 | instance = TestNestedConfig(test_config=nested, test_name="test")
116 | assert instance.config == nested
117 | assert instance.name == "test"
118 |
119 | def test_inheritance(self):
120 | @dataclass
121 | class BaseConfig:
122 | name: str
123 |
124 | @dataclass
125 | class ChildConfig(BaseConfig):
126 | value: int = 42
127 |
128 | @prefixed_view(ChildConfig, "test_")
129 | class TestConfig:
130 | pass
131 |
132 | # Should include both inherited and child fields
133 | assert set(TestConfig.__annotations__.keys()) == {"test_name", "test_value"}
134 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "cycleformers"
3 | packages = [
4 | { from = "src", include = "cycleformers" },
5 | ]
6 | version = "0.1.0"
7 | description = "A comprehensive implementation of the cycle-consistency training paradigm, extending the Huggingface Transformers trainer API to accommodate arbitrary combinations of generative models."
8 | authors = ["William Thorne "]
9 | license = "Attribution 4.0 International"
10 | readme = "README.md"
11 |
12 | [tool.poetry.dependencies]
13 | python = ">=3.11,<3.13"
14 | transformers = "4.46"
15 | torch = "^2.5.0"
16 | datasets = "^3.0.2"
17 | typing-extensions = {version = "^4.12.2", python = ">=3.8,<3.12"}
18 | peft = "^0.13.2" # TODO: Add importutils is_peft_available methods and remove dependency
19 |
20 | [tool.poetry.group.dev.dependencies]
21 | bandit = "^1.8.0"
22 | pre-commit = "^4.0.1"
23 | ruff = "^0.7.1"
24 | mypy = "^1.13.0"
25 | semantic-release = "^0.1.0"
26 | genbadge = {extras = ["build", "coverage"], version = "^1.1.1"}
27 | pytest = "^8.3.3"
28 | pytest-cov = "^6.0.0"
29 | pytest-benchmark = "^5.1.0"
30 | pytest-xdist = "^3.6.1"
31 | pytest-sugar = "^1.0.0"
32 | pytest-instafail = "^0.5.0"
33 | pytest-flakefinder = "^1.1.0"
34 | pytest-picked = "^0.5.1"
35 | pytest-random-order = "^1.1.1"
36 |
37 | [tool.poetry.group.profiling.dependencies]
38 | memory-profiler = "^0.61.0"
39 | snakeviz = "^2.2.2"
40 | torch-tb-profiler = "^0.4.3"
41 | tensorboard-plugin-profile = "^2.18.0"
42 |
43 | [tool.poetry.group.docs.dependencies]
44 | mkdocs = "^1.6.1"
45 | mkdocstrings = {extras = ["python"], version = "^0.26.2"}
46 | mkdocs-material = "^9.5.43"
47 | pygments = "^2.18.0"
48 |
49 | [tool.ruff]
50 | # Exclude a variety of commonly ignored directories.
51 | exclude = [
52 | ".bzr",
53 | ".direnv",
54 | ".eggs",
55 | ".git",
56 | ".git-rewrite",
57 | ".hg",
58 | ".ipynb_checkpoints",
59 | ".mypy_cache",
60 | ".nox",
61 | ".pants.d",
62 | ".pyenv",
63 | ".pytest_cache",
64 | ".pytype",
65 | ".ruff_cache",
66 | ".svn",
67 | ".tox",
68 | ".venv",
69 | ".vscode",
70 | "__pypackages__",
71 | "_build",
72 | "buck-out",
73 | "build",
74 | "dist",
75 | "node_modules",
76 | "site-packages",
77 | "venv",
78 | "misc",
79 | ]
80 |
81 | # Same as Transformers.
82 | line-length = 119
83 | indent-width = 4
84 |
85 | target-version = "py311"
86 |
87 | [tool.ruff.lint]
88 | # Same as Transformers.
89 | ignore = ["C901", "E501", "E741", "F402", "F823" ]
90 | select = [
91 | "C",
92 | "E",
93 | "F",
94 | "I",
95 | "W",
96 | "UP006", # Use built-in types instead of typing ones (list vs List)
97 | "UP007", # Use | instead of Union
98 | "UP008", # Use | None instead of Optional
99 | "UP009", # Use collections.abc instead of typing
100 | "UP035", # Remove redundant unions
101 | ]
102 |
103 | [tool.ruff.lint.isort]
104 | lines-after-imports = 2
105 | known-first-party = ["cycleformers"]
106 |
107 | [tool.ruff.format]
108 | # Same as Transformers.
109 | quote-style = "double"
110 | indent-style = "space"
111 | skip-magic-trailing-comma = false
112 | line-ending = "auto"
113 |
114 | [tool.ruff.lint.per-file-ignores]
115 | # Ignore import violations in all `__init__.py` files.
116 | "__init__.py" = ["F401"]
117 |
118 | [[tool.mypy.overrides]]
119 | module = "accelerate.*,datasets.*,transformers.*"
120 | ignore_missing_imports = true
121 |
122 | [tool.pytest.ini_options]
123 | markers = [
124 | "slow: mark test as slow to run (run with --slow to enable)",
125 | "meta: mark test as a test of the testing framework (run with --meta to enable)",
126 | "requires_gpu: mark test as requiring a GPU or taking far too long to run on a CPU (skip with --no-gpu)"
127 | ]
128 |
129 | [tool.coverage.run]
130 | data_file = ".coverage"
131 | source = ["src/cycleformers"]
132 | omit = [
133 | "tests/*",
134 | "misc/*",
135 | ]
136 |
137 | [tool.coverage.report]
138 | exclude_lines = [
139 | "pragma: no cover",
140 | "def __repr__",
141 | "raise NotImplementedError",
142 | ]
143 |
144 | [tool.semantic_release]
145 | version_variable = [
146 | "src/cycleformers/__init__.py",
147 | "pyproject.toml:version",
148 | ]
149 |
150 | parser_angular_allowed_types = "build, chore, ci, docs, feat, fix, perf, style, refactor, test"
151 | parser_angular_minor_types = "feat"
152 | parser_angular_patch_types = "fix, perf"
153 |
154 | [build-system]
155 | requires = ["poetry-core"]
156 | build-backend = "poetry.core.masonry.api"
157 |
--------------------------------------------------------------------------------
/tests/sample_data/conll2003/test/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "builder_name": "conll2003",
3 | "citation": "@inproceedings{tjong-kim-sang-de-meulder-2003-introduction,\n title = \"Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition\",\n author = \"Tjong Kim Sang, Erik F. and\n De Meulder, Fien\",\n booktitle = \"Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003\",\n year = \"2003\",\n url = \"https://www.aclweb.org/anthology/W03-0419\",\n pages = \"142--147\",\n}\n",
4 | "config_name": "conll2003",
5 | "dataset_name": "conll2003",
6 | "dataset_size": 10252622,
7 | "description": "The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on\nfour types of named entities: persons, locations, organizations and names of miscellaneous entities that do\nnot belong to the previous three groups.\n\nThe CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on\na separate line and there is an empty line after each sentence. The first item on each line is a word, the second\na part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags\nand the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only\nif two phrases of the same type immediately follow each other, the first word of the second phrase will have tag\nB-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2\ntagging scheme, whereas the original dataset uses IOB1.\n\nFor more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419\n",
8 | "download_checksums": {
9 | "https://data.deepai.org/conll2003.zip": {
10 | "num_bytes": 982975,
11 | "checksum": null
12 | }
13 | },
14 | "download_size": 982975,
15 | "features": {
16 | "id": {
17 | "dtype": "string",
18 | "_type": "Value"
19 | },
20 | "tokens": {
21 | "feature": {
22 | "dtype": "string",
23 | "_type": "Value"
24 | },
25 | "_type": "Sequence"
26 | },
27 | "pos_tags": {
28 | "feature": {
29 | "names": [
30 | "\"",
31 | "''",
32 | "#",
33 | "$",
34 | "(",
35 | ")",
36 | ",",
37 | ".",
38 | ":",
39 | "``",
40 | "CC",
41 | "CD",
42 | "DT",
43 | "EX",
44 | "FW",
45 | "IN",
46 | "JJ",
47 | "JJR",
48 | "JJS",
49 | "LS",
50 | "MD",
51 | "NN",
52 | "NNP",
53 | "NNPS",
54 | "NNS",
55 | "NN|SYM",
56 | "PDT",
57 | "POS",
58 | "PRP",
59 | "PRP$",
60 | "RB",
61 | "RBR",
62 | "RBS",
63 | "RP",
64 | "SYM",
65 | "TO",
66 | "UH",
67 | "VB",
68 | "VBD",
69 | "VBG",
70 | "VBN",
71 | "VBP",
72 | "VBZ",
73 | "WDT",
74 | "WP",
75 | "WP$",
76 | "WRB"
77 | ],
78 | "_type": "ClassLabel"
79 | },
80 | "_type": "Sequence"
81 | },
82 | "chunk_tags": {
83 | "feature": {
84 | "names": [
85 | "O",
86 | "B-ADJP",
87 | "I-ADJP",
88 | "B-ADVP",
89 | "I-ADVP",
90 | "B-CONJP",
91 | "I-CONJP",
92 | "B-INTJ",
93 | "I-INTJ",
94 | "B-LST",
95 | "I-LST",
96 | "B-NP",
97 | "I-NP",
98 | "B-PP",
99 | "I-PP",
100 | "B-PRT",
101 | "I-PRT",
102 | "B-SBAR",
103 | "I-SBAR",
104 | "B-UCP",
105 | "I-UCP",
106 | "B-VP",
107 | "I-VP"
108 | ],
109 | "_type": "ClassLabel"
110 | },
111 | "_type": "Sequence"
112 | },
113 | "ner_tags": {
114 | "feature": {
115 | "names": [
116 | "O",
117 | "B-PER",
118 | "I-PER",
119 | "B-ORG",
120 | "I-ORG",
121 | "B-LOC",
122 | "I-LOC",
123 | "B-MISC",
124 | "I-MISC"
125 | ],
126 | "_type": "ClassLabel"
127 | },
128 | "_type": "Sequence"
129 | }
130 | },
131 | "homepage": "https://www.aclweb.org/anthology/W03-0419/",
132 | "license": "",
133 | "size_in_bytes": 11235597,
134 | "splits": {
135 | "train": {
136 | "name": "train",
137 | "num_bytes": 6931345,
138 | "num_examples": 14041,
139 | "dataset_name": "conll2003"
140 | },
141 | "validation": {
142 | "name": "validation",
143 | "num_bytes": 1739223,
144 | "num_examples": 3250,
145 | "dataset_name": "conll2003"
146 | },
147 | "test": {
148 | "name": "test",
149 | "num_bytes": 1582054,
150 | "num_examples": 3453,
151 | "dataset_name": "conll2003"
152 | }
153 | },
154 | "version": {
155 | "version_str": "1.0.0",
156 | "major": 1,
157 | "minor": 0,
158 | "patch": 0
159 | }
160 | }
--------------------------------------------------------------------------------
/tests/sample_data/conll2003/train/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "builder_name": "conll2003",
3 | "citation": "@inproceedings{tjong-kim-sang-de-meulder-2003-introduction,\n title = \"Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition\",\n author = \"Tjong Kim Sang, Erik F. and\n De Meulder, Fien\",\n booktitle = \"Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003\",\n year = \"2003\",\n url = \"https://www.aclweb.org/anthology/W03-0419\",\n pages = \"142--147\",\n}\n",
4 | "config_name": "conll2003",
5 | "dataset_name": "conll2003",
6 | "dataset_size": 10252622,
7 | "description": "The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on\nfour types of named entities: persons, locations, organizations and names of miscellaneous entities that do\nnot belong to the previous three groups.\n\nThe CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on\na separate line and there is an empty line after each sentence. The first item on each line is a word, the second\na part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags\nand the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only\nif two phrases of the same type immediately follow each other, the first word of the second phrase will have tag\nB-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2\ntagging scheme, whereas the original dataset uses IOB1.\n\nFor more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419\n",
8 | "download_checksums": {
9 | "https://data.deepai.org/conll2003.zip": {
10 | "num_bytes": 982975,
11 | "checksum": null
12 | }
13 | },
14 | "download_size": 982975,
15 | "features": {
16 | "id": {
17 | "dtype": "string",
18 | "_type": "Value"
19 | },
20 | "tokens": {
21 | "feature": {
22 | "dtype": "string",
23 | "_type": "Value"
24 | },
25 | "_type": "Sequence"
26 | },
27 | "pos_tags": {
28 | "feature": {
29 | "names": [
30 | "\"",
31 | "''",
32 | "#",
33 | "$",
34 | "(",
35 | ")",
36 | ",",
37 | ".",
38 | ":",
39 | "``",
40 | "CC",
41 | "CD",
42 | "DT",
43 | "EX",
44 | "FW",
45 | "IN",
46 | "JJ",
47 | "JJR",
48 | "JJS",
49 | "LS",
50 | "MD",
51 | "NN",
52 | "NNP",
53 | "NNPS",
54 | "NNS",
55 | "NN|SYM",
56 | "PDT",
57 | "POS",
58 | "PRP",
59 | "PRP$",
60 | "RB",
61 | "RBR",
62 | "RBS",
63 | "RP",
64 | "SYM",
65 | "TO",
66 | "UH",
67 | "VB",
68 | "VBD",
69 | "VBG",
70 | "VBN",
71 | "VBP",
72 | "VBZ",
73 | "WDT",
74 | "WP",
75 | "WP$",
76 | "WRB"
77 | ],
78 | "_type": "ClassLabel"
79 | },
80 | "_type": "Sequence"
81 | },
82 | "chunk_tags": {
83 | "feature": {
84 | "names": [
85 | "O",
86 | "B-ADJP",
87 | "I-ADJP",
88 | "B-ADVP",
89 | "I-ADVP",
90 | "B-CONJP",
91 | "I-CONJP",
92 | "B-INTJ",
93 | "I-INTJ",
94 | "B-LST",
95 | "I-LST",
96 | "B-NP",
97 | "I-NP",
98 | "B-PP",
99 | "I-PP",
100 | "B-PRT",
101 | "I-PRT",
102 | "B-SBAR",
103 | "I-SBAR",
104 | "B-UCP",
105 | "I-UCP",
106 | "B-VP",
107 | "I-VP"
108 | ],
109 | "_type": "ClassLabel"
110 | },
111 | "_type": "Sequence"
112 | },
113 | "ner_tags": {
114 | "feature": {
115 | "names": [
116 | "O",
117 | "B-PER",
118 | "I-PER",
119 | "B-ORG",
120 | "I-ORG",
121 | "B-LOC",
122 | "I-LOC",
123 | "B-MISC",
124 | "I-MISC"
125 | ],
126 | "_type": "ClassLabel"
127 | },
128 | "_type": "Sequence"
129 | }
130 | },
131 | "homepage": "https://www.aclweb.org/anthology/W03-0419/",
132 | "license": "",
133 | "size_in_bytes": 11235597,
134 | "splits": {
135 | "train": {
136 | "name": "train",
137 | "num_bytes": 6931345,
138 | "num_examples": 14041,
139 | "dataset_name": "conll2003"
140 | },
141 | "validation": {
142 | "name": "validation",
143 | "num_bytes": 1739223,
144 | "num_examples": 3250,
145 | "dataset_name": "conll2003"
146 | },
147 | "test": {
148 | "name": "test",
149 | "num_bytes": 1582054,
150 | "num_examples": 3453,
151 | "dataset_name": "conll2003"
152 | }
153 | },
154 | "version": {
155 | "version_str": "1.0.0",
156 | "major": 1,
157 | "minor": 0,
158 | "patch": 0
159 | }
160 | }
--------------------------------------------------------------------------------
/tests/sample_data/conll2003/validation/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "builder_name": "conll2003",
3 | "citation": "@inproceedings{tjong-kim-sang-de-meulder-2003-introduction,\n title = \"Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition\",\n author = \"Tjong Kim Sang, Erik F. and\n De Meulder, Fien\",\n booktitle = \"Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003\",\n year = \"2003\",\n url = \"https://www.aclweb.org/anthology/W03-0419\",\n pages = \"142--147\",\n}\n",
4 | "config_name": "conll2003",
5 | "dataset_name": "conll2003",
6 | "dataset_size": 10252622,
7 | "description": "The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on\nfour types of named entities: persons, locations, organizations and names of miscellaneous entities that do\nnot belong to the previous three groups.\n\nThe CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on\na separate line and there is an empty line after each sentence. The first item on each line is a word, the second\na part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags\nand the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only\nif two phrases of the same type immediately follow each other, the first word of the second phrase will have tag\nB-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2\ntagging scheme, whereas the original dataset uses IOB1.\n\nFor more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419\n",
8 | "download_checksums": {
9 | "https://data.deepai.org/conll2003.zip": {
10 | "num_bytes": 982975,
11 | "checksum": null
12 | }
13 | },
14 | "download_size": 982975,
15 | "features": {
16 | "id": {
17 | "dtype": "string",
18 | "_type": "Value"
19 | },
20 | "tokens": {
21 | "feature": {
22 | "dtype": "string",
23 | "_type": "Value"
24 | },
25 | "_type": "Sequence"
26 | },
27 | "pos_tags": {
28 | "feature": {
29 | "names": [
30 | "\"",
31 | "''",
32 | "#",
33 | "$",
34 | "(",
35 | ")",
36 | ",",
37 | ".",
38 | ":",
39 | "``",
40 | "CC",
41 | "CD",
42 | "DT",
43 | "EX",
44 | "FW",
45 | "IN",
46 | "JJ",
47 | "JJR",
48 | "JJS",
49 | "LS",
50 | "MD",
51 | "NN",
52 | "NNP",
53 | "NNPS",
54 | "NNS",
55 | "NN|SYM",
56 | "PDT",
57 | "POS",
58 | "PRP",
59 | "PRP$",
60 | "RB",
61 | "RBR",
62 | "RBS",
63 | "RP",
64 | "SYM",
65 | "TO",
66 | "UH",
67 | "VB",
68 | "VBD",
69 | "VBG",
70 | "VBN",
71 | "VBP",
72 | "VBZ",
73 | "WDT",
74 | "WP",
75 | "WP$",
76 | "WRB"
77 | ],
78 | "_type": "ClassLabel"
79 | },
80 | "_type": "Sequence"
81 | },
82 | "chunk_tags": {
83 | "feature": {
84 | "names": [
85 | "O",
86 | "B-ADJP",
87 | "I-ADJP",
88 | "B-ADVP",
89 | "I-ADVP",
90 | "B-CONJP",
91 | "I-CONJP",
92 | "B-INTJ",
93 | "I-INTJ",
94 | "B-LST",
95 | "I-LST",
96 | "B-NP",
97 | "I-NP",
98 | "B-PP",
99 | "I-PP",
100 | "B-PRT",
101 | "I-PRT",
102 | "B-SBAR",
103 | "I-SBAR",
104 | "B-UCP",
105 | "I-UCP",
106 | "B-VP",
107 | "I-VP"
108 | ],
109 | "_type": "ClassLabel"
110 | },
111 | "_type": "Sequence"
112 | },
113 | "ner_tags": {
114 | "feature": {
115 | "names": [
116 | "O",
117 | "B-PER",
118 | "I-PER",
119 | "B-ORG",
120 | "I-ORG",
121 | "B-LOC",
122 | "I-LOC",
123 | "B-MISC",
124 | "I-MISC"
125 | ],
126 | "_type": "ClassLabel"
127 | },
128 | "_type": "Sequence"
129 | }
130 | },
131 | "homepage": "https://www.aclweb.org/anthology/W03-0419/",
132 | "license": "",
133 | "size_in_bytes": 11235597,
134 | "splits": {
135 | "train": {
136 | "name": "train",
137 | "num_bytes": 6931345,
138 | "num_examples": 14041,
139 | "dataset_name": "conll2003"
140 | },
141 | "validation": {
142 | "name": "validation",
143 | "num_bytes": 1739223,
144 | "num_examples": 3250,
145 | "dataset_name": "conll2003"
146 | },
147 | "test": {
148 | "name": "test",
149 | "num_bytes": 1582054,
150 | "num_examples": 3453,
151 | "dataset_name": "conll2003"
152 | }
153 | },
154 | "version": {
155 | "version_str": "1.0.0",
156 | "major": 1,
157 | "minor": 0,
158 | "patch": 0
159 | }
160 | }
--------------------------------------------------------------------------------
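Note: the three dataset_info.json files above (identical per split, as written by datasets' save_to_disk) describe the on-disk conll2003 DatasetDict used as test fixture data. It can be reloaded directly, e.g.:

from datasets import load_from_disk

ds = load_from_disk("tests/sample_data/conll2003")
print(ds["train"].features["ner_tags"].feature.names)  # ['O', 'B-PER', 'I-PER', ...]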
/src/cycleformers/task_processors/translation.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | from dataclasses import dataclass
3 |
4 | from datasets import DatasetDict, IterableDatasetDict
5 |
6 | from .base import BaseProcessor, ProcessorConfig
7 |
8 |
9 | @dataclass
10 | class TranslationProcessorConfig(ProcessorConfig):
11 | """Configuration class for translation dataset processors.
12 |
13 | This class extends the base ProcessorConfig with translation-specific parameters.
14 |
15 | Args:
16 | dataset_name (str): HuggingFace dataset name/path. Defaults to "wmt14".
17 |         dataset_config_name (str): Specific configuration of the dataset to load.
18 | For WMT14, must be one of ['cs-en', 'de-en', 'fr-en', 'hi-en', 'ru-en'].
19 | Defaults to "de-en".
20 | source_lang (str): Source language code (e.g., "en" for English). Defaults to "en".
21 | target_lang (str): Target language code (e.g., "de" for German). Defaults to "de".
22 | source_column (str): Column name containing source text. Defaults to "translation".
23 |         target_column (str | None): Column name containing target text. If None, uses source_column. Defaults to None.
24 | preprocessing_fn (callable | None): Optional function to preprocess raw dataset entries.
25 | Should take a dataset entry and return a dict with 'source' and 'target' keys.
26 | Defaults to None.
27 |
28 | Example:
29 | >>> config = TranslationProcessorConfig(
30 | ... dataset_name="wmt14",
31 | ... dataset_config_name="de-en",
32 | ... source_lang="en",
33 | ... target_lang="de"
34 | ... )
35 | >>> processor = TranslationProcessor(config)
36 | """
37 |
38 | dataset_name: str = "wmt14"
39 | dataset_config_name: str = "de-en" # Required for WMT14
40 | source_lang: str = "en"
41 | target_lang: str = "de"
42 | source_column: str = "translation"
43 | target_column: str | None = None
44 | preprocessing_fn: Callable | None = None
45 |
46 |
47 | class TranslationProcessor(BaseProcessor):
48 | """Processor for translation datasets.
49 |
50 | This processor handles translation datasets in various formats and converts them into
51 | cycleformers-compatible format. It supports both:
52 | - Standard parallel corpora (source -> target)
53 | - Back-translation style training (target -> source)
54 |
55 | The processor:
56 | 1. Loads the translation dataset (streaming for large datasets)
57 | 2. Extracts source and target text
58 | 3. Creates two complementary datasets for cycle training
59 |
60 | Args:
61 | config (TranslationProcessorConfig): Configuration object controlling processor behavior.
62 | Includes settings like dataset name, language pairs, and column names.
63 |
64 | Example:
65 | >>> config = TranslationProcessorConfig(
66 | ... dataset_name="wmt14",
67 | ... dataset_config_name="de-en",
68 | ... source_lang="en",
69 | ... target_lang="de"
70 | ... )
71 | >>> processor = TranslationProcessor(config)
72 | >>> dataset_A, dataset_B = processor.process()
73 | >>> print(dataset_A["train"][0])
74 | {'text': 'The cat sat on the mat.'}
75 | >>> print(dataset_B["train"][0])
76 | {'text': 'Die Katze saß auf der Matte.'}
77 | """
78 |
79 |     def __init__(self, config: TranslationProcessorConfig | None = None):
80 |         self.config: TranslationProcessorConfig = config or TranslationProcessorConfig()
81 |         super().__init__(self.config)
82 |
83 | def _extract_text_pair(self, example: dict) -> dict:
84 | """Extract source and target text from a dataset example.
85 |
86 | This method handles different dataset formats:
87 | 1. Nested dictionary format (e.g., {'translation': {'en': '...', 'de': '...'}})
88 | 2. Flat dictionary format (e.g., {'source': '...', 'target': '...'})
89 | 3. Custom formats via preprocessing_fn
90 |
91 | Args:
92 | example (dict): A single example from the dataset
93 |
94 | Returns:
95 | dict: Dictionary with 'source' and 'target' text
96 | """
97 | if self.config.preprocessing_fn is not None:
98 | return self.config.preprocessing_fn(example)
99 |
100 | source_col = self.config.source_column
101 | target_col = self.config.target_column or source_col
102 |
103 | if isinstance(example[source_col], dict):
104 | # Handle nested dictionary format (e.g., WMT datasets)
105 | return {
106 | "source": example[source_col][self.config.source_lang],
107 | "target": example[target_col][self.config.target_lang],
108 | }
109 | else:
110 | return {
111 | "source": example[source_col],
112 | "target": example[target_col],
113 | }
114 |
115 | def preprocess(self, dataset: DatasetDict | IterableDatasetDict) -> tuple[DatasetDict, DatasetDict]:
116 | """Preprocess the dataset into two separate datasets for cycle training.
117 |
118 | Args:
119 | dataset (DatasetDict | IterableDatasetDict): The raw dataset containing translation pairs
120 |
121 | Returns:
122 | tuple[DatasetDict, DatasetDict]: Two datasets:
123 | - Dataset A: Source language texts
124 | - Dataset B: Target language texts
125 | Each containing 'train' and 'test' splits with parallel data in test
126 | """
127 | original_cols = set(dataset["train"].column_names)
128 | dataset = dataset.map(self._extract_text_pair)
129 |         dataset = dataset.remove_columns(list(original_cols - {"source", "target"}))
130 |
131 | dataset_A = dataset.map(lambda x: {"text": x["source"], "label": x["target"]})
132 | dataset_B = dataset.map(lambda x: {"text": x["target"], "label": x["source"]})
133 |
134 | dataset_A["train"] = dataset_A["train"].remove_columns(["label"])
135 | dataset_B["train"] = dataset_B["train"].remove_columns(["label"])
136 |
137 | return dataset_A, dataset_B
138 |
--------------------------------------------------------------------------------
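A minimal sketch (not part of the repository) exercising TranslationProcessor.preprocess on an in-memory flat-format dataset; the "source"/"target" column names are supplied via the config, matching the unit tests:

from datasets import Dataset, DatasetDict

from cycleformers.task_processors.translation import TranslationProcessor, TranslationProcessorConfig

dataset = DatasetDict({
    "train": Dataset.from_dict({"source": ["Hello", "World"], "target": ["Hallo", "Welt"]}),
    "test": Dataset.from_dict({"source": ["Test"], "target": ["Test"]}),
})
config = TranslationProcessorConfig(source_column="source", target_column="target")
dataset_A, dataset_B = TranslationProcessor(config).preprocess(dataset)

print(dataset_A["train"]["text"])    # ['Hello', 'World']; train splits carry no 'label' column
print(dataset_B["test"][0]["text"])  # 'Test'; eval splits keep the parallel 'label'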
/src/cycleformers/model_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Literal
3 |
4 | from .utils import prefixed_view
5 |
6 |
7 | @dataclass
8 | class ModelConfig:
9 | """https://github.com/huggingface/trl/blob/main/trl/trainer/model_config.py
10 |
11 | Parameters:
12 | model_name_or_path (`Optional[str]`, *optional*, defaults to `None`):
13 | Model checkpoint for weights initialization.
14 | model_revision (`str`, *optional*, defaults to `"main"`):
15 | Specific model version to use. It can be a branch name, a tag name, or a commit id.
16 | torch_dtype (`Optional[Literal["auto", "bfloat16", "float16", "float32"]]`, *optional*, defaults to `None`):
17 | Override the default `torch.dtype` and load the model under this dtype. Possible values are
18 |
19 | - `"bfloat16"`: `torch.bfloat16`
20 | - `"float16"`: `torch.float16`
21 | - `"float32"`: `torch.float32`
22 | - `"auto"`: Automatically derive the dtype from the model's weights.
23 |
24 | trust_remote_code (`bool`, *optional*, defaults to `False`):
25 | Whether to allow for custom models defined on the Hub in their own modeling files. This option should only
26 | be set to `True` for repositories you trust and in which you have read the code, as it will execute code
27 | present on the Hub on your local machine.
28 | attn_implementation (`Optional[str]`, *optional*, defaults to `None`):
29 | Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in which case
30 | you must install this manually by running `pip install flash-attn --no-build-isolation`.
31 | use_peft (`bool`, *optional*, defaults to `False`):
32 | Whether to use PEFT for training.
33 | lora_r (`int`, *optional*, defaults to `16`):
34 | LoRA R value.
35 | lora_alpha (`int`, *optional*, defaults to `32`):
36 | LoRA alpha.
37 | lora_dropout (`float`, *optional*, defaults to `0.05`):
38 | LoRA dropout.
39 | lora_target_modules (`Optional[Union[str, list[str]]]`, *optional*, defaults to `None`):
40 | LoRA target modules.
41 | lora_modules_to_save (`Optional[list[str]]`, *optional*, defaults to `None`):
42 | Model layers to unfreeze & train.
43 | lora_task_type (`str`, *optional*, defaults to `"CAUSAL_LM"`):
44 | Task type to pass for LoRA (use `"SEQ_CLS"` for reward modeling).
45 | use_rslora (`bool`, *optional*, defaults to `False`):
46 | Whether to use Rank-Stabilized LoRA, which sets the adapter scaling factor to `lora_alpha/√r`, instead of
47 | the original default value of `lora_alpha/r`."""
48 |
49 | model_name_or_path: str | None = None
50 | model_revision: str = "main"
51 | torch_dtype: Literal["auto", "bfloat16", "float16", "float32"] | None = None
52 | trust_remote_code: bool = False
53 | attn_implementation: str | None = None
54 | use_peft: bool = False
55 | lora_r: int = 16
56 | lora_alpha: int = 32
57 | lora_dropout: float = 0.05
58 | lora_target_modules: list[str] | None = None
59 | lora_modules_to_save: list[str] | None = None
60 | lora_task_type: str = "CAUSAL_LM"
61 | use_rslora: bool = False
62 | use_dora: bool = False
63 |
64 | def __post_init__(self):
65 | self._A = None
66 | self._B = None
67 |
68 | @property
69 |     def A(self) -> "ModelConfig | None":
70 | return self._A
71 |
72 | @A.setter
73 | def A(self, value: "ModelConfig"):
74 | self._A = value
75 |
76 | @property
77 |     def B(self) -> "ModelConfig | None":
78 | return self._B
79 |
80 | @B.setter
81 | def B(self, value: "ModelConfig"):
82 | self._B = value
83 |
84 |
85 | @dataclass
86 | @prefixed_view(ModelConfig, "A_")
87 | class ModelConfigA:
88 | pass
89 |
90 |
91 | @dataclass
92 | @prefixed_view(ModelConfig, "B_")
93 | class ModelConfigB:
94 | pass
95 |
96 |
97 | def merge_configs(base_config: ModelConfig, config_a: ModelConfigA, config_b: ModelConfigB) -> ModelConfig:
98 | """Merge configs, with A/B specific values overriding base values, unless they're defaults.
99 |
100 | Args:
101 | base_config (ModelConfig): Base configuration with default values
102 | config_a (ModelConfigA): Model A specific configuration that may override base values
103 | config_b (ModelConfigB): Model B specific configuration that may override base values
104 |
105 | Returns:
106 | ModelConfig: The base config with A and B specific configs merged in
107 |
108 | Example:
109 |         >>> base = ModelConfig(model_name_or_path="base", lora_r=32)
110 |         >>> a = ModelConfigA(A_model_name_or_path="model_a", A_lora_r=64)
111 |         >>> b = ModelConfigB(B_model_name_or_path="model_b")
112 |         >>> merged = merge_configs(base, a, b)
113 |         >>> merged.A.model_name_or_path
114 |         'model_a'
115 |         >>> merged.A.lora_r
116 |         64
117 |         >>> merged.B.model_name_or_path
118 |         'model_b'
119 |         >>> merged.B.lora_r
120 |         32
121 | """
122 | # Create copies to avoid modifying originals
123 | merged_a = ModelConfig(**{k: getattr(base_config, k) for k in base_config.__dataclass_fields__})
124 | merged_b = ModelConfig(**{k: getattr(base_config, k) for k in base_config.__dataclass_fields__})
125 |
126 | # Create a default config to check against
127 | default_config = ModelConfig()
128 |
129 | # Override with A-specific values, but only if they're not defaults
130 | for field in base_config.__dataclass_fields__:
131 | if hasattr(config_a, field):
132 | config_a_value = getattr(config_a, field)
133 | # Only override if the A-specific value is different from default
134 | if config_a_value != getattr(default_config, field):
135 | setattr(merged_a, field, config_a_value)
136 |
137 | # Override with B-specific values, but only if they're not defaults
138 | for field in base_config.__dataclass_fields__:
139 | if hasattr(config_b, field):
140 | config_b_value = getattr(config_b, field)
141 | # Only override if the B-specific value is different from default
142 | if config_b_value != getattr(default_config, field):
143 | setattr(merged_b, field, config_b_value)
144 |
145 | base_config.A = merged_a
146 | base_config.B = merged_b
147 | return base_config
148 |
149 |
150 | __all__ = ["ModelConfig", "ModelConfigA", "ModelConfigB", "merge_configs"]
151 |
--------------------------------------------------------------------------------
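A usage sketch for merge_configs, mirroring the docstring example with the dataclass's actual field names (the "gpt2" checkpoints are placeholders):

from cycleformers.model_config import ModelConfig, ModelConfigA, ModelConfigB, merge_configs

base = ModelConfig(model_name_or_path="gpt2", lora_r=32)
a = ModelConfigA(A_model_name_or_path="gpt2-medium", A_lora_r=64)
b = ModelConfigB(B_model_name_or_path="gpt2-large")

merged = merge_configs(base, a, b)
assert merged.A.model_name_or_path == "gpt2-medium"
assert merged.A.lora_r == 64  # A-specific non-default value wins
assert merged.B.lora_r == 32  # B kept the default, so the base value is inherited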
/tests/benchmark/benchmark_tokenizer_skip.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import timeit
4 | from collections.abc import Callable
5 | from dataclasses import dataclass, field
6 | from datetime import datetime
7 | from pathlib import Path
8 | from random import Random
9 | import torch
10 | from benchmark_utils import Benchmark, BenchmarkConfig, MockCycleTrainer
11 | from datasets import load_dataset
12 | from torch.utils.data import DataLoader
13 | from transformers import AutoModelForCausalLM, AutoTokenizer
14 |
15 | from cycleformers.cycles import _default_prepare_cycle_inputs, _prepare_causal_skip_cycle_inputs
16 | from cycleformers.utils import DEFAULT_SEP_SEQ
17 |
18 |
19 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
20 |
21 |
22 | @dataclass
23 | class TokenPrepBenchmarkConfig(BenchmarkConfig):
24 |     dataset_name: str
25 |     model_A_name: str
26 |     model_B_name: str
27 |     config_name: str = "token_prep"
28 | cycle_name: str | None = None
29 | num_samples: int = 500
30 | batch_sizes: list[int] = field(default_factory=lambda: [1, 8, 32, 128])
31 | devices: list[str] = field(default_factory=lambda: ["cpu", "cuda"])
32 | seed: int = field(default_factory=lambda: int(datetime.now().timestamp()))
33 |
34 |
35 | class TokenPreparationBenchmark:
36 | def __init__(self, config: TokenPrepBenchmarkConfig):
37 | self.config = config
38 | self.random = Random(config.seed)
39 |
40 | # Initialize tokenizer
41 | self.tokenizer = AutoTokenizer.from_pretrained(config.model_A_name)
42 | if self.tokenizer.pad_token is None:
43 | self.tokenizer.pad_token = self.tokenizer.eos_token
44 |
45 | self.model_A = AutoModelForCausalLM.from_pretrained(config.model_A_name)
46 |
47 | self.save_dir = Path(__file__).parent / "benchmark_results" / "token_prep"
48 | self.prepare_dataset()
49 |
50 | def prepare_dataset(self):
51 | # FIXME: Currently only supports datasets with "instruction", "response" as columns
52 | self.dataset = load_dataset(self.config.dataset_name, split="train").select(range(self.config.num_samples))
53 | self.dataset = self.dataset.map(
54 | lambda x: {
55 | "synth_ids": self.tokenizer(
56 | x["instruction"] + DEFAULT_SEP_SEQ + x["response"], padding=False
57 | ).input_ids,
58 | "real_ids": self.tokenizer(x["instruction"], padding=False).input_ids,
59 | }
60 | )
61 |
62 | def manual_right_pad_batch(self, batch_ids: list[list[int]], pad_token_id: int) -> torch.Tensor:
63 | """Right pad a batch of token IDs to the maximum length in the batch."""
64 | max_length = max(len(ids) for ids in batch_ids)
65 | padded_batch = []
66 |
67 | for ids in batch_ids:
68 | padding_length = max_length - len(ids)
69 | padded_ids = ids + [pad_token_id] * padding_length
70 | padded_batch.append(padded_ids)
71 |
72 | return torch.tensor(padded_batch)
73 |
74 | def get_dataloader(self, batch_size: int, device: str):
75 | generator = torch.Generator(device=device)
76 | generator.manual_seed(self.config.seed)
77 |
78 | def collate_fn(batch):
79 | real_ids = [item["real_ids"] for item in batch]
80 | synth_ids = [item["synth_ids"] for item in batch]
81 |
82 | # Manually pad both sequences
83 | real_ids_padded = self.manual_right_pad_batch(real_ids, self.tokenizer.pad_token_id)
84 | synth_ids_padded = self.manual_right_pad_batch(synth_ids, self.tokenizer.pad_token_id)
85 |
86 | return real_ids_padded, synth_ids_padded
87 |
88 | dataloader = DataLoader(
89 | self.dataset,
90 | batch_size=batch_size,
91 | shuffle=True,
92 | num_workers=4,
93 | pin_memory=True,
94 | # generator=generator,
95 | collate_fn=collate_fn,
96 | )
97 | return dataloader
98 |
99 | def _save_metrics(self, metrics: dict):
100 | run_dir = self.save_dir / f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
101 | run_dir.mkdir(parents=True, exist_ok=True)
102 |
103 |         with open(run_dir / "metrics.json", "w") as f:
104 | json.dump(metrics, f)
105 |
106 |     def run(self, implementations: list[Callable], n_runs: int = 5):
107 | """Run benchmarks across all configurations"""
108 |
109 | results = []
110 |
111 | for batch_size in self.config.batch_sizes:
112 | for device in self.config.devices:
113 | if device == "cuda" and not torch.cuda.is_available():
114 | continue
115 |
116 | print(f"\nBenchmarking batch_size={batch_size} on {device}")
117 |
118 | # Prepare data
119 | dataloader = self.get_dataloader(batch_size, device)
120 |
121 | for impl in implementations:
122 | # Time the implementation
123 |
124 | def run_impl():
125 | outputs = []
126 | for batch in dataloader:
127 | real_ids, synth_ids = batch
128 |
129 | output = impl(
130 | real_input_ids=real_ids.to(device),
131 | synth_input_ids=synth_ids.to(device),
132 | model_gen=self.model_A,
133 | model_train=self.model_A,
134 | tokenizer_gen=self.tokenizer,
135 | tokenizer_train=self.tokenizer,
136 | cycle_name="A",
137 | )
138 | outputs.append(output)
139 | return outputs
140 |
141 | # Time runs
142 | if device == "cuda":
143 | torch.cuda.synchronize()
144 |
145 | timer = timeit.Timer(run_impl)
146 |                     times = timer.repeat(repeat=n_runs, number=5)  # each entry is the total time for 5 passes
147 |
148 | metrics = {"mean": sum(times) / len(times), "min": min(times), "max": max(times)}
149 |
150 |                     results.append({"implementation": impl.__name__, "batch_size": batch_size, "device": device, **metrics})
151 |
152 |                     print(f"\n{impl.__name__}:")
153 | print(f"Mean: {metrics['mean']*1000:.2f}ms")
154 | print(f"Min: {metrics['min']*1000:.2f}ms")
155 | print(f"Max: {metrics['max']*1000:.2f}ms")
156 |
157 | return results
158 |
159 |
160 | if __name__ == "__main__":
161 | config = TokenPrepBenchmarkConfig(
162 | dataset_name="MBZUAI/LaMini-instruction",
163 | model_A_name="trl-internal-testing/tiny-LlamaForCausalLM-3.1",
164 | model_B_name="trl-internal-testing/tiny-LlamaForCausalLM-3.1",
165 | )
166 |
167 | benchmark = TokenPreparationBenchmark(config=config)
168 | benchmark.run(implementations=[_default_prepare_cycle_inputs, _prepare_causal_skip_cycle_inputs])
169 |
--------------------------------------------------------------------------------
/tests/testing_utils/model_registry.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | from dataclasses import dataclass
3 | from enum import Enum, auto
4 | from pathlib import Path
5 | from typing import Union
6 |
7 | import yaml
8 | from transformers import AutoConfig, PretrainedConfig
9 |
10 |
11 | class CapabilityExpression:
12 | """Expression for filtering models based on their capabilities."""
13 |
14 | def __init__(self, condition: Callable[[set["ModelCapability"]], bool]):
15 | self.condition = condition
16 |
17 | def evaluate(self, capabilities: set["ModelCapability"]) -> bool:
18 | return self.condition(capabilities)
19 |
20 | def __and__(self, other: "CapabilityExpression") -> "CapabilityExpression":
21 | return CapabilityExpression(lambda caps: self.condition(caps) and other.condition(caps))
22 |
23 | def __or__(self, other: "CapabilityExpression") -> "CapabilityExpression":
24 | return CapabilityExpression(lambda caps: self.condition(caps) or other.condition(caps))
25 |
26 | def __invert__(self) -> "CapabilityExpression":
27 | return CapabilityExpression(lambda caps: not self.condition(caps))
28 |
29 |
30 | class ModelCapability(Enum):
31 | SEQ2SEQ_LM = auto()
32 | CAUSAL_LM = auto()
33 |
34 | @classmethod
35 | def from_str(cls, name: str) -> "ModelCapability":
36 | return cls[name]
37 |
38 | def to_expression(self) -> "CapabilityExpression":
39 | return CapabilityExpression(lambda caps: self in caps)
40 |
41 | def __and__(self, other: Union["ModelCapability", "CapabilityExpression"]) -> "CapabilityExpression":
42 | return self.to_expression() & (other.to_expression() if isinstance(other, ModelCapability) else other)
43 |
44 | def __or__(self, other: Union["ModelCapability", "CapabilityExpression"]) -> "CapabilityExpression":
45 | return self.to_expression() | (other.to_expression() if isinstance(other, ModelCapability) else other)
46 |
47 | def __invert__(self) -> "CapabilityExpression":
48 | return ~self.to_expression()
49 |
50 |
51 | def infer_capabilities_from_config(config: PretrainedConfig) -> set[ModelCapability]:
52 | """Infer model capabilities from HuggingFace config.
53 |
54 | Args:
55 | config: HuggingFace model config
56 |
57 | Returns:
58 | Set of model capabilities inferred from the config.
59 | """
60 | capabilities = set()
61 |
62 | # Architecture-based capabilities
63 | if hasattr(config, "is_encoder_decoder") and config.is_encoder_decoder:
64 | capabilities.add(ModelCapability.SEQ2SEQ_LM)
65 |
66 | if hasattr(config, "architectures") and config.architectures:
67 | if any("ForCausalLM" in arch for arch in config.architectures):
68 | capabilities.add(ModelCapability.CAUSAL_LM)
69 |
70 | return capabilities
71 |
72 |
73 | @dataclass
74 | class ModelSpec:
75 | name: str
76 | repo_id: str
77 | capabilities: set[ModelCapability]
78 | config: PretrainedConfig
79 | description: str = "" # Any notes that are important to remember about the model
80 |
81 | def __hash__(self) -> int:
82 | return hash((self.name, self.repo_id))
83 |
84 | def __eq__(self, other: object) -> bool:
85 | if not isinstance(other, ModelSpec):
86 | return NotImplemented
87 | return self.name == other.name and self.repo_id == other.repo_id
88 |
89 | @classmethod
90 | def from_hub(cls, name: str, repo_id: str) -> "ModelSpec":
91 | config = AutoConfig.from_pretrained(repo_id)
92 | capabilities = infer_capabilities_from_config(config)
93 | return cls(name=name, repo_id=repo_id, capabilities=capabilities, config=config)
94 |
95 |
96 | class ModelRegistry:
97 | def __init__(self, registry_path: Path):
98 | self._models: dict[str, ModelSpec] = {}
99 | self.load_registry(registry_path)
100 |
101 | def load_registry(self, registry_path: Path):
102 | with registry_path.open() as f:
103 | registry_dict = yaml.safe_load(f)
104 |
105 | for name, spec in registry_dict.items():
106 | self._models[name] = ModelSpec.from_hub(name=name, repo_id=spec["repo_id"])
107 |
108 | def get_matching_models(
109 | self,
110 | capability_expr: ModelCapability | CapabilityExpression | None = None,
111 | model_names: list[str] | str | None = None,
112 | ) -> list[ModelSpec]:
113 | """Get models matching capability expression AND model names.
114 |
115 | Args:
116 | capability_expr: Capability expression to match. Can be a ModelCapability or a CapabilityExpression.
117 | If None, no capability filtering is applied.
118 | model_names: List of model names to match. Can be a single string or a list of strings.
119 | If None, no name filtering is applied.
120 |
121 | Returns:
122 | List of models matching both the capability expression and model names.
123 | If both capability_expr and model_names are None, returns all models.
124 |
125 | Examples:
126 | >>> registry = ModelRegistry(Path("models_to_test.yaml"))
127 | >>> # Find models that are NOT seq2seq
128 | >>> models = registry.get_matching_models(~ModelCapability.SEQ2SEQ_LM)
129 |
130 | >>> # Find models that are causal LM
131 | >>> models = registry.get_matching_models(ModelCapability.CAUSAL_LM)
132 |
133 | >>> # Find models that are either causal LM or seq2seq
134 | >>> models = registry.get_matching_models(
135 | ... ModelCapability.CAUSAL_LM | ModelCapability.SEQ2SEQ_LM
136 | ... )
137 |
138 | >>> # Find models that are both causal LM and seq2seq (empty list)
139 | >>> models = registry.get_matching_models(
140 | ... ModelCapability.CAUSAL_LM & ModelCapability.SEQ2SEQ_LM
141 | ... )
142 |
143 | >>> # Get specific models by name
144 | >>> models = registry.get_matching_models(model_names=["tiny-llama", "tiny-t5"])
145 |
146 | >>> # Combine capability and name filters
147 | >>> models = registry.get_matching_models(
148 | ... capability_expr=ModelCapability.SEQ2SEQ_LM,
149 | ... model_names=["tiny-t5"]
150 | ... )
151 |
152 | >>> # Return tiny-llama-3.1
153 | >>> models = registry.get_matching_models(model_names="tiny-llama-3.1")
154 |
155 | >>> # Get all models
156 | >>> models = registry.get_matching_models()
157 | """
158 | if capability_expr is None and model_names is None:
159 | return list(self._models.values())
160 |
161 | if isinstance(capability_expr, ModelCapability):
162 | capability_expr = capability_expr.to_expression()
163 |
164 | if isinstance(model_names, str):
165 | model_names = [model_names]
166 |
167 | matches = []
168 | for spec in self._models.values():
169 | if capability_expr is not None:
170 | if not capability_expr.evaluate(spec.capabilities):
171 | continue
172 |
173 | if model_names is not None:
174 | if spec.name not in model_names:
175 | continue
176 |
177 | matches.append(spec)
178 |
179 | return matches
180 |
--------------------------------------------------------------------------------
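A sketch of the registry format that ModelRegistry.load_registry expects: a YAML mapping of nickname to {repo_id: ...}, with capabilities inferred from each model's Hub config. The repo ids below are illustrative, and ModelSpec.from_hub needs network access (or a mocked AutoConfig, as in the tests):

from pathlib import Path

import yaml

from tests.testing_utils.model_registry import ModelCapability, ModelRegistry

registry_spec = {
    "tiny-llama": {"repo_id": "trl-internal-testing/tiny-LlamaForCausalLM-3.1"},
    "tiny-t5": {"repo_id": "fake-repo/tiny-t5"},
}
path = Path("models_to_test.yaml")
path.write_text(yaml.safe_dump(registry_spec))

registry = ModelRegistry(path)
causal_only = registry.get_matching_models(~ModelCapability.SEQ2SEQ_LM)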
/tests/task_processors/test_translation_processors.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import pytest
4 | from datasets import Dataset, DatasetDict
5 |
6 | from cycleformers.task_processors.translation import TranslationProcessor, TranslationProcessorConfig
7 |
8 |
9 | SAMPLES_DIR = Path(__file__).parent.parent / "sample_data"
10 |
11 |
12 | class TestTranslationProcessor:
13 | def test_preprocess_wmt_format(self, monkeypatch, temp_dir):
14 | """Test preprocessing of WMT-style datasets with nested translation dictionaries."""
15 | monkeypatch.setenv("HF_DATASETS_CACHE", str(temp_dir))
16 |
17 | config = TranslationProcessorConfig(
18 | dataset_name=SAMPLES_DIR / "wmt14",
19 | dataset_config_name="de-en", # Required for WMT14
20 | dataset_seed=42, # Fixed seed for reproducible tests
21 | cache_dir=str(temp_dir),
22 | )
23 | processor = TranslationProcessor(config)
24 | dataset_A, dataset_B = processor.process()
25 |
26 | # Check that all splits are preserved
27 | assert set(dataset_A.keys()) == {"train", "validation", "test"}
28 | assert set(dataset_B.keys()) == {"train", "validation", "test"}
29 |
30 | # Check that training data is non-parallel (no labels)
31 | assert "label" not in dataset_A["train"].column_names
32 | assert "label" not in dataset_B["train"].column_names
33 |
34 | # Verify training splits are properly shuffled and unaligned
35 | train_texts_A = dataset_A["train"]["text"]
36 | train_texts_B = dataset_B["train"]["text"]
37 |
38 | # Check no exact alignment exists
39 | assert not any(
40 | all(a == b for a, b in zip(train_texts_A, train_texts_B[i:] + train_texts_B[:i]))
41 | for i in range(len(train_texts_B))
42 | ), "Training splits appear to be aligned with some offset"
43 |
44 | # Verify that shuffling is deterministic with the same seed
45 | config2 = TranslationProcessorConfig(
46 | dataset_name=SAMPLES_DIR / "wmt14",
47 | dataset_config_name="de-en",
48 | dataset_seed=42,
49 | cache_dir=str(temp_dir),
50 | )
51 | processor2 = TranslationProcessor(config2)
52 | dataset_A2, dataset_B2 = processor2.process()
53 |
54 | assert list(dataset_A["train"]["text"]) == list(dataset_A2["train"]["text"])
55 | assert list(dataset_B["train"]["text"]) == list(dataset_B2["train"]["text"])
56 |
57 | # Check that evaluation splits maintain parallel data
58 | for key in ["validation", "test"]:
59 | assert dataset_A[key]["text"] == dataset_B[key]["label"]
60 | assert dataset_A[key]["label"] == dataset_B[key]["text"]
61 | assert dataset_A[key]["text"] != dataset_B[key]["text"]
62 | assert dataset_A[key]["label"] != dataset_B[key]["label"]
63 |
64 | def test_preprocess_flat_format(self):
65 | """Test preprocessing of datasets with flat source/target columns."""
66 | data = {
67 | "train": {
68 | "source": ["Hello", "World", "Good", "Morning"],
69 | "target": ["Hallo", "Welt", "Gut", "Morgen"],
70 | },
71 | "test": {
72 | "source": ["Test"],
73 | "target": ["Test"],
74 | },
75 | }
76 | dataset = DatasetDict({split: Dataset.from_dict(data) for split, data in data.items()})
77 |
78 | config = TranslationProcessorConfig(
79 | dataset_name=SAMPLES_DIR / "wmt14",
80 | source_column="source",
81 | target_column="target",
82 | dataset_seed=42, # Fixed seed for reproducible tests
83 | )
84 | processor = TranslationProcessor(config)
85 | dataset_A, dataset_B = processor.preprocess(dataset)
86 |
87 | # Check basic structure
88 | assert len(dataset_A.keys()) == len(dataset.keys())
89 | assert len(dataset_B.keys()) == len(dataset.keys())
90 |
91 | # Verify training splits are properly shuffled and unaligned
92 | train_texts_A = dataset_A["train"]["text"]
93 | train_texts_B = dataset_B["train"]["text"]
94 |
95 | # Check no exact alignment exists
96 | assert not any(
97 | all(a == b for a, b in zip(train_texts_A, train_texts_B[i:] + train_texts_B[:i]))
98 | for i in range(len(train_texts_B))
99 | ), "Training splits appear to be aligned with some offset"
100 |
101 | # Verify that shuffling is deterministic with the same seed
102 | config2 = TranslationProcessorConfig(
103 | source_column="source",
104 | target_column="target",
105 | dataset_seed=42,
106 | )
107 | processor2 = TranslationProcessor(config2)
108 | dataset_A2, dataset_B2 = processor2.preprocess(dataset)
109 |
110 | assert list(dataset_A["train"]["text"]) == list(dataset_A2["train"]["text"])
111 | assert list(dataset_B["train"]["text"]) == list(dataset_B2["train"]["text"])
112 |
113 | # Check evaluation data is parallel
114 | assert dataset_A["test"][0]["text"] == "Test"
115 | assert dataset_A["test"][0]["label"] == "Test"
116 |
117 | def test_custom_preprocessing(self):
118 | """Test preprocessing with a custom preprocessing function."""
119 | data = {
120 | "train": {
121 | "text": [
122 | "en: Hello || de: Hallo",
123 | "en: World || de: Welt",
124 | "en: Good || de: Gut",
125 | "en: Morning || de: Morgen",
126 | ],
127 | },
128 | "test": {
129 | "text": ["en: Test || de: Test"],
130 | },
131 | }
132 | dataset = DatasetDict({split: Dataset.from_dict(data) for split, data in data.items()})
133 |
134 | def custom_preprocessor(example):
135 | en, de = example["text"].split("||")
136 | return {
137 | "source": en.split(": ")[1].strip(),
138 | "target": de.split(": ")[1].strip(),
139 | }
140 |
141 | config = TranslationProcessorConfig(
142 | dataset_name=SAMPLES_DIR / "wmt14",
143 | preprocessing_fn=custom_preprocessor,
144 | dataset_seed=42, # Fixed seed for reproducible tests
145 | )
146 | processor = TranslationProcessor(config)
147 | dataset_A, dataset_B = processor.preprocess(dataset)
148 |
149 | # Verify training splits are properly shuffled and unaligned
150 | train_texts_A = dataset_A["train"]["text"]
151 | train_texts_B = dataset_B["train"]["text"]
152 |
153 | # Check no exact alignment exists
154 | assert not any(
155 | all(a == b for a, b in zip(train_texts_A, train_texts_B[i:] + train_texts_B[:i]))
156 | for i in range(len(train_texts_B))
157 | ), "Training splits appear to be aligned with some offset"
158 |
159 | # Verify that shuffling is deterministic with the same seed
160 | config2 = TranslationProcessorConfig(
161 | preprocessing_fn=custom_preprocessor,
162 | dataset_seed=42,
163 | )
164 | processor2 = TranslationProcessor(config2)
165 | dataset_A2, dataset_B2 = processor2.preprocess(dataset)
166 |
167 | assert list(dataset_A["train"]["text"]) == list(dataset_A2["train"]["text"])
168 | assert list(dataset_B["train"]["text"]) == list(dataset_B2["train"]["text"])
169 |
170 | # Check evaluation data is parallel
171 | assert dataset_A["test"][0]["text"] == "Test"
172 | assert dataset_A["test"][0]["label"] == "Test"
173 |
--------------------------------------------------------------------------------
/src/cycleformers/task_processors/ner.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Literal
3 |
4 | from datasets import DatasetDict
5 |
6 | from .base import BaseProcessor, ProcessorConfig
7 |
8 |
9 | tag_to_idx = {"O": 0, "B-PER": 1, "I-PER": 2, "B-ORG": 3, "I-ORG": 4, "B-LOC": 5, "I-LOC": 6, "B-MISC": 7, "I-MISC": 8}
10 | idx_to_tag = {idx: tag for tag, idx in tag_to_idx.items()}
11 | tag_to_string = {"PER": "person", "ORG": "organisation", "LOC": "location", "MISC": "miscellaneous"}
12 | string_to_tag = {string: tag for tag, string in tag_to_string.items()}
13 |
14 |
15 | @dataclass
16 | class CONLL2003ProcessorConfig(ProcessorConfig):
17 | """Configuration class for CONLL2003 dataset processor.
18 |
19 | This class extends the base ProcessorConfig with CONLL2003-specific parameters.
20 |
21 | Args:
22 | dataset_name (str): HuggingFace dataset name/path. Defaults to "eriktks/conll2003".
23 | preset (Literal["entity_seqs"] | None): Processing preset to use. Currently only supports "entity_seqs"
24 | which extracts sequences containing named entities. Defaults to "entity_seqs".
25 | sep_token (str): Token used to separate elements in processed sequences. Defaults to "|".
26 |
27 | Example:
28 | >>> config = CONLL2003ProcessorConfig(
29 | ... dataset_name="conll2003",
30 | ... preset="entity_seqs",
31 | ... sep_token="|"
32 | ... )
33 | >>> processor = CONLL2003Processor(config)
34 | """
35 |
36 | dataset_name: str = "eriktks/conll2003"
37 | preset: Literal["entity_seqs"] | None = "entity_seqs"
38 | # trust_remote_code: bool = False
39 | sep_token: str = "|"
40 |
41 |
42 | def reconstruct_sentence(tokens: list[str]) -> str:
43 | """Reconstructs a sentence from CoNLL-2003 tokens while handling whitespace correctly.
44 |
45 | This function takes a list of tokens from the CoNLL-2003 format and reconstructs them into a properly formatted sentence
46 | by applying standard English language spacing rules:
47 |
48 | - No spaces before punctuation marks (,.!?:;)
49 | - Spaces between regular words
50 | - Proper handling of contractions (don't, I'm, etc.)
51 | - Correct spacing around quotation marks
52 |
53 | Args:
54 | tokens (list[str]): List of tokens from the CoNLL-2003 dataset
55 |
56 | Returns:
57 | str: The reconstructed sentence with proper spacing and punctuation
58 |
59 | Example:
60 | >>> tokens = ["I", "don", "'t", "like", "New", "York", "."]
61 | >>> reconstruct_sentence(tokens)
62 | "I don't like New York."
63 | """
64 | if not tokens:
65 | return ""
66 |
67 | # Define punctuation categories
68 | closing_puncts = set(",.!?:;")
69 |     opening_quotes = {'"', "“"}  # straight and curly opening quotes
70 |     closing_quotes = {'"', "”"}  # straight and curly closing quotes
71 | contractions = set("'s 're 've 'm 't 'll 'd".split())
72 |
73 | result = []
74 | for i, token in enumerate(tokens):
75 | if i == 0:
76 | result.append(token)
77 | continue
78 |
79 | prev_token = tokens[i - 1]
80 |
81 | needs_space = not (
82 | token in closing_puncts
83 | or token.lower() in contractions
84 | or prev_token in opening_quotes
85 | or token in closing_quotes
86 | )
87 |
88 | if needs_space:
89 | result.append(" " + token)
90 | else:
91 | result.append(token)
92 |
93 | return "".join(result)
94 |
95 |
96 | def ner_to_sequences(tokens: list[str], tags: list[int], sep_token: str) -> str:
97 | """Convert a list of tokens and their corresponding tags to a sequence of entity types.
98 |
99 | This function takes tokens and their NER tags and converts them into a sequence format
100 | suitable for sequence-to-sequence training. It follows the BIO2 tagging scheme where:
101 | - B- prefix indicates beginning of an entity
102 | - I- prefix indicates inside/continuation of an entity
103 | - O indicates outside any entity (not used in output)
104 |
105 | Args:
106 | tokens (list[str]): List of string tokens from a single sentence
107 | tags (list[int]): List of integer tags corresponding to the tokens using BIO2 scheme
108 | sep_token (str): Separator token to use between entity and its type
109 |
110 | Returns:
111 | str: A string containing entities and their types separated by the sep_token.
112 |             For example: "John Smith | person | Google | organisation"
113 |
114 | Raises:
115 | ValueError: If an I- tag appears without a preceding B- tag (invalid BIO2 sequence)
116 |
117 | Example:
118 | >>> tokens = ["John", "Smith", "works", "at", "Google", "."]
119 |         >>> tags = [1, 2, 0, 0, 3, 0]  # B-PER, I-PER, O, O, B-ORG, O
120 |         >>> ner_to_sequences(tokens, tags, "|")
121 |         'John Smith | person | Google | organisation'
122 | """
123 | compound_tokens = []
124 | for token, tag in zip(tokens, tags):
125 | if tag in [1, 3, 5, 7]:
126 | tag_string = tag_to_string[idx_to_tag[tag].split("-")[-1]]
127 | compound_tokens.append([token, sep_token, tag_string])
128 | elif tag in [2, 4, 6, 8]:
129 | if not compound_tokens:
130 | raise ValueError("Missing B-tag before I-tag. Please use BIO2 tagging scheme, not BIO1.")
131 |
132 | compound_tokens[-1].insert(-2, token)
133 |
134 | return f" {sep_token} ".join([" ".join(token_tags) for token_tags in compound_tokens])
135 |
136 |
137 | class CONLL2003Processor(BaseProcessor):
138 | """Processor for the CONLL2003 Named Entity Recognition dataset.
139 |
140 | This processor handles the CONLL2003 dataset which contains text annotated with named entities.
141 | It converts the dataset into two formats:
142 | - Dataset A: Raw text -> Entity sequences
143 | - Dataset B: Entity sequences -> Raw text
144 |
145 | The processor:
146 | 1. Loads the CONLL2003 dataset
147 | 2. Converts the token-level NER annotations into sequence format
148 | 3. Creates two complementary datasets for cycle training
149 |
150 | Args:
151 | config (CONLL2003ProcessorConfig): Configuration object controlling processor behavior.
152 | Includes settings like separator token between entities and their types.
153 |
154 | Example:
155 | >>> config = CONLL2003ProcessorConfig(sep_token=" | ")
156 | >>> processor = CONLL2003Processor(config)
157 | >>> dataset_A, dataset_B = processor.process()
158 |         >>> print(dataset_A["train"][0])
159 |         {'text': 'John Smith works at Google.'}
160 |         >>> print(dataset_B["train"][0])
161 |         {'text': 'John Smith | person | Google | organisation'}
162 | """
163 |
164 |     def __init__(self, config: CONLL2003ProcessorConfig | None = None):
165 |         self.config: CONLL2003ProcessorConfig = config or CONLL2003ProcessorConfig()
166 |         super().__init__(self.config)
167 |         # Ensure formatting of sep token is correct
168 |         self.sep_token = self.config.sep_token.strip()
169 |
170 | def preprocess(self, dataset: DatasetDict) -> tuple[DatasetDict, DatasetDict]:
171 | original_cols = dataset["train"].column_names
172 | dataset = dataset.map(
173 | lambda x: {
174 | "sentence": reconstruct_sentence(x["tokens"]),
175 | "entity_seq": ner_to_sequences(x["tokens"], x["ner_tags"], self.sep_token),
176 | }
177 | ).remove_columns(original_cols)
178 |
179 | dataset_A = dataset.map(lambda x: {"text": x["sentence"], "label": x["entity_seq"]})
180 | dataset_B = dataset.map(lambda x: {"text": x["entity_seq"], "label": x["sentence"]})
181 |
182 | dataset_A["train"] = dataset_A["train"].remove_columns(["label"])
183 | dataset_B["train"] = dataset_B["train"].remove_columns(["label"])
184 |
185 | return dataset_A, dataset_B
186 |
--------------------------------------------------------------------------------
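A minimal sketch (not part of the repository) running one CoNLL-2003-style example through CONLL2003Processor.preprocess; the tag ids follow the tag_to_idx mapping at the top of ner.py:

from datasets import Dataset, DatasetDict

from cycleformers.task_processors.ner import CONLL2003Processor, CONLL2003ProcessorConfig

data = DatasetDict({
    "train": Dataset.from_dict({
        "tokens": [["John", "Smith", "works", "at", "Google", "."]],
        "ner_tags": [[1, 2, 0, 0, 3, 0]],  # B-PER, I-PER, O, O, B-ORG, O
    })
})
processor = CONLL2003Processor(CONLL2003ProcessorConfig(sep_token="|"))
dataset_A, dataset_B = processor.preprocess(data)

print(dataset_A["train"][0]["text"])  # John Smith works at Google.
print(dataset_B["train"][0]["text"])  # John Smith | person | Google | organisation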
/tests/testing_utils/test_model_registry.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import yaml
3 | from transformers import AutoConfig, PretrainedConfig
4 |
5 | from tests.testing_utils.model_registry import (
6 | CapabilityExpression,
7 | ModelCapability,
8 | ModelRegistry,
9 | infer_capabilities_from_config,
10 | )
11 |
12 |
13 | @pytest.fixture
14 | def sample_registry_file(temp_dir):
15 | registry_path = temp_dir / "Verify_models.yaml"
16 | registry_content = {
17 | "tiny-llama": {
18 | "repo_id": "fake-repo/tiny-llama",
19 | },
20 | "tiny-t5": {
21 | "repo_id": "fake-repo/tiny-t5",
22 | },
23 | }
24 |
25 | with registry_path.open("w") as f:
26 | yaml.dump(registry_content, f)
27 |
28 | return registry_path
29 |
30 |
31 | @pytest.fixture
32 | def mock_config_test(monkeypatch):
33 | def mock_from_pretrained(repo_id):
34 | if "llama" in repo_id.lower():
35 | config = PretrainedConfig()
36 | config.architectures = ["LlamaForCausalLM"]
37 | config.model_type = "llama"
38 | return config
39 | elif "t5" in repo_id.lower():
40 | config = PretrainedConfig()
41 | config.architectures = ["T5ForConditionalGeneration"]
42 | config.is_encoder_decoder = True
43 | config.model_type = "t5"
44 | return config
45 | raise ValueError(f"Unknown model: {repo_id}")
46 |
47 | monkeypatch.setattr(AutoConfig, "from_pretrained", mock_from_pretrained)
48 |
49 |
50 | @pytest.mark.meta
51 | class TestModelRegistry:
52 | def test_load_registry(self, sample_registry_file, mock_config_test):
53 | registry = ModelRegistry(sample_registry_file)
54 |
55 | assert len(registry._models) == 2
56 | assert "tiny-llama" in registry._models
57 | assert "tiny-t5" in registry._models
58 |
59 | llama = registry._models["tiny-llama"]
60 | assert llama.name == "tiny-llama"
61 | assert llama.repo_id == "fake-repo/tiny-llama"
62 | assert ModelCapability.CAUSAL_LM in llama.capabilities
63 | assert ModelCapability.SEQ2SEQ_LM not in llama.capabilities
64 |
65 | t5 = registry._models["tiny-t5"]
66 | assert t5.name == "tiny-t5"
67 | assert t5.repo_id == "fake-repo/tiny-t5"
68 | assert ModelCapability.SEQ2SEQ_LM in t5.capabilities
69 |
70 | @pytest.mark.parametrize(
71 | "capability,model_names,expected_models",
72 | [
73 | (ModelCapability.CAUSAL_LM, None, ["tiny-llama"]),
74 | (ModelCapability.SEQ2SEQ_LM, None, ["tiny-t5"]),
75 | (~ModelCapability.SEQ2SEQ_LM, None, ["tiny-llama"]),
76 | (ModelCapability.CAUSAL_LM | ModelCapability.SEQ2SEQ_LM, None, ["tiny-llama", "tiny-t5"]),
77 | (ModelCapability.CAUSAL_LM & ModelCapability.SEQ2SEQ_LM, None, []),
78 | (None, ["tiny-llama"], ["tiny-llama"]),
79 | (None, ["tiny-llama", "tiny-t5"], ["tiny-llama", "tiny-t5"]),
80 | (ModelCapability.CAUSAL_LM, "tiny-llama", ["tiny-llama"]),
81 | (ModelCapability.CAUSAL_LM, ["tiny-t5"], []),
82 | ],
83 | )
84 | def test_get_matching_models(
85 | self, sample_registry_file, mock_config_test, capability, model_names, expected_models
86 | ):
87 | registry = ModelRegistry(sample_registry_file)
88 | matches = registry.get_matching_models(capability, model_names=model_names)
89 | assert sorted([m.name for m in matches]) == sorted(expected_models)
90 |
91 |
92 | @pytest.mark.meta
93 | class TestCapabilityExpression:
94 | def test_simple_condition(self):
95 | expr = CapabilityExpression(lambda caps: ModelCapability.CAUSAL_LM in caps)
96 | assert expr.evaluate({ModelCapability.CAUSAL_LM})
97 | assert not expr.evaluate({ModelCapability.SEQ2SEQ_LM})
98 |
99 | def test_and_operator(self):
100 | expr1 = CapabilityExpression(lambda caps: ModelCapability.CAUSAL_LM in caps)
101 | expr2 = CapabilityExpression(lambda caps: ModelCapability.SEQ2SEQ_LM in caps)
102 | combined = expr1 & expr2
103 |
104 | assert not combined.evaluate({ModelCapability.CAUSAL_LM})
105 | assert not combined.evaluate({ModelCapability.SEQ2SEQ_LM})
106 | assert combined.evaluate({ModelCapability.CAUSAL_LM, ModelCapability.SEQ2SEQ_LM})
107 |
108 | def test_or_operator(self):
109 | expr1 = CapabilityExpression(lambda caps: ModelCapability.CAUSAL_LM in caps)
110 | expr2 = CapabilityExpression(lambda caps: ModelCapability.SEQ2SEQ_LM in caps)
111 | combined = expr1 | expr2
112 |
113 | assert combined.evaluate({ModelCapability.CAUSAL_LM})
114 | assert combined.evaluate({ModelCapability.SEQ2SEQ_LM})
115 | assert combined.evaluate({ModelCapability.CAUSAL_LM, ModelCapability.SEQ2SEQ_LM})
116 | assert not combined.evaluate(set())
117 |
118 | def test_not_operator(self):
119 | expr = CapabilityExpression(lambda caps: ModelCapability.CAUSAL_LM in caps)
120 | inverted = ~expr
121 |
122 | assert not inverted.evaluate({ModelCapability.CAUSAL_LM})
123 | assert inverted.evaluate({ModelCapability.SEQ2SEQ_LM})
124 | assert inverted.evaluate(set())
125 |
126 |
127 | @pytest.mark.meta
128 | class TestModelCapability:
129 | def test_from_str(self):
130 | assert ModelCapability.from_str("CAUSAL_LM") == ModelCapability.CAUSAL_LM
131 | assert ModelCapability.from_str("SEQ2SEQ_LM") == ModelCapability.SEQ2SEQ_LM
132 |
133 | with pytest.raises(KeyError):
134 | ModelCapability.from_str("INVALID")
135 |
136 | def test_to_expression(self):
137 | expr = ModelCapability.CAUSAL_LM.to_expression()
138 | assert isinstance(expr, CapabilityExpression)
139 | assert expr.evaluate({ModelCapability.CAUSAL_LM})
140 | assert not expr.evaluate({ModelCapability.SEQ2SEQ_LM})
141 |
142 | def test_capability_and(self):
143 | expr = ModelCapability.CAUSAL_LM & ModelCapability.SEQ2SEQ_LM
144 | assert isinstance(expr, CapabilityExpression)
145 | assert not expr.evaluate({ModelCapability.CAUSAL_LM})
146 | assert not expr.evaluate({ModelCapability.SEQ2SEQ_LM})
147 | assert expr.evaluate({ModelCapability.CAUSAL_LM, ModelCapability.SEQ2SEQ_LM})
148 |
149 | def test_capability_or(self):
150 | expr = ModelCapability.CAUSAL_LM | ModelCapability.SEQ2SEQ_LM
151 | assert isinstance(expr, CapabilityExpression)
152 | assert expr.evaluate({ModelCapability.CAUSAL_LM})
153 | assert expr.evaluate({ModelCapability.SEQ2SEQ_LM})
154 | assert expr.evaluate({ModelCapability.CAUSAL_LM, ModelCapability.SEQ2SEQ_LM})
155 |
156 | def test_capability_not(self):
157 | expr = ~ModelCapability.CAUSAL_LM
158 | assert isinstance(expr, CapabilityExpression)
159 | assert not expr.evaluate({ModelCapability.CAUSAL_LM})
160 | assert expr.evaluate({ModelCapability.SEQ2SEQ_LM})
161 | assert expr.evaluate(set())
162 |
163 |
164 | @pytest.mark.meta
165 | class TestInferCapabilities:
166 | @pytest.fixture
167 | def causal_lm_config(self):
168 | config = PretrainedConfig()
169 | config.architectures = ["LlamaForCausalLM"]
170 | config.model_type = "llama"
171 | config.is_encoder_decoder = False
172 | return config
173 |
174 | @pytest.fixture
175 | def seq2seq_config(self):
176 | config = PretrainedConfig()
177 | config.architectures = ["T5ForConditionalGeneration"]
178 | config.model_type = "t5"
179 | config.is_encoder_decoder = True
180 | return config
181 |
182 | def test_infer_causal_lm_from_architecture(self, causal_lm_config):
183 | capabilities = infer_capabilities_from_config(causal_lm_config)
184 | assert ModelCapability.CAUSAL_LM in capabilities
185 | assert ModelCapability.SEQ2SEQ_LM not in capabilities
186 |
187 | def test_infer_seq2seq_from_architecture(self, seq2seq_config):
188 | capabilities = infer_capabilities_from_config(seq2seq_config)
189 | assert ModelCapability.SEQ2SEQ_LM in capabilities
190 | assert ModelCapability.CAUSAL_LM not in capabilities
191 |
192 | def test_empty_config(self):
193 | config = PretrainedConfig()
194 | capabilities = infer_capabilities_from_config(config)
195 | assert len(capabilities) == 0
196 |
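197 | # A minimal sketch (illustrative only, not the implementation under test in
198 | # tests/testing_utils/model_registry.py) of a CapabilityExpression consistent
199 | # with the combinator behaviour pinned down above:
200 | #
201 | #     class CapabilityExpression:
202 | #         def __init__(self, predicate):
203 | #             self.predicate = predicate
204 | #
205 | #         def evaluate(self, caps):
206 | #             return self.predicate(caps)
207 | #
208 | #         def __and__(self, other):
209 | #             return CapabilityExpression(lambda c: self.evaluate(c) and other.evaluate(c))
210 | #
211 | #         def __or__(self, other):
212 | #             return CapabilityExpression(lambda c: self.evaluate(c) or other.evaluate(c))
213 | #
214 | #         def __invert__(self):
215 | #             return CapabilityExpression(lambda c: not self.evaluate(c))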
--------------------------------------------------------------------------------
/tests/command/test_cli_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from dataclasses import dataclass
3 | from typing import Tuple
4 |
5 | import pytest
6 | import yaml
7 |
8 | from cycleformers.command.cli_utils import VALID_TASKS, CfArgumentParser
9 |
10 |
11 | @dataclass
12 | class DummyArgs:
13 | model_name: str
14 | learning_rate: float = 0.001
15 |
16 |
17 | @dataclass
18 | class DummyArgsAB:
19 | A_model: str
20 | B_model: str
21 | A_learning_rate: float = 0.001
22 | B_learning_rate: float = 0.0001
23 | shared_param: str = "default"
24 |
25 |
26 | @dataclass
27 | class DummyArgsC:
28 | C_model: str
29 | C_learning_rate: float = 0.0005
30 |
31 |
32 | class TestCFArgumentParser:
33 | def test_init_with_task(self):
34 | parser = CfArgumentParser(task="train")
35 | assert parser.task == "train"
36 |
37 | def test_init_without_task(self):
38 | parser = CfArgumentParser()
39 | assert parser.task is None
40 |
41 | def test_init_with_multiple_dataclasses(self):
42 | parser = CfArgumentParser([DummyArgs, DummyArgsAB])
43 | assert len(parser.dataclass_types) == 2
44 | assert parser.dataclass_types[0] == DummyArgs
45 | assert parser.dataclass_types[1] == DummyArgsAB
46 |
47 | def test_init_with_single_dataclass_not_in_list(self):
48 | parser = CfArgumentParser(DummyArgs)
49 | assert len(parser.dataclass_types) == 1
50 | assert parser.dataclass_types[0] == DummyArgs
51 |
52 | def test_init_with_no_dataclasses(self):
53 | parser = CfArgumentParser()
54 | assert len(parser.dataclass_types) == 0
55 |
56 | def test_invalid_task_raises_error(self, monkeypatch):
57 | parser = CfArgumentParser()
58 | monkeypatch.setattr(sys, "argv", ["invalid_task"])
59 | with pytest.raises(ValueError, match="Task must be one of"):
60 | parser.parse_args_and_config()
61 |
62 | def test_no_task_provided_raises_error(self, monkeypatch):
63 | parser = CfArgumentParser()
64 | monkeypatch.setattr(sys, "argv", ["script.py"])
65 | with pytest.raises(ValueError, match="No task provided"):
66 | parser.parse_args_and_config()
67 |
68 | def test_task_mismatch_raises_error(self, monkeypatch):
69 | parser = CfArgumentParser(task="train")
70 | monkeypatch.setattr(sys, "argv", ["script.py", "train"])
71 | with pytest.raises(ValueError, match="Task already set"):
72 | parser.parse_args_and_config()
73 |
74 | def test_parse_simple_args(self, monkeypatch):
75 | parser = CfArgumentParser([DummyArgs])
76 | args = ["train", "--model_name", "bert", "--learning_rate", "0.001"]
77 | monkeypatch.setattr(sys, "argv", ["script.py"] + args)
78 |
79 | (parsed_args,) = parser.parse_args_and_config()
80 | assert isinstance(parsed_args, DummyArgs)
81 | assert parsed_args.model_name == "bert"
82 | assert parsed_args.learning_rate == 0.001
83 |
84 | def test_parse_multiple_dataclasses(self, monkeypatch):
85 | parser = CfArgumentParser([DummyArgs, DummyArgsC])
86 | args = [
87 | "train",
88 | "--model_name",
89 | "bert",
90 | "--learning_rate",
91 | "0.001",
92 | "--C_model",
93 | "gpt2",
94 | "--C_learning_rate",
95 | "0.0005",
96 | ]
97 | monkeypatch.setattr(sys, "argv", ["script.py"] + args)
98 |
99 | parsed_args = parser.parse_args_and_config()
100 | assert isinstance(parsed_args, Tuple)
101 | assert len(parsed_args) == 2
102 | assert parsed_args[0].model_name == "bert"
103 | assert parsed_args[1].C_model == "gpt2"
104 |
105 | def test_parse_yaml_config(self, tmp_path):
106 | config = {
107 | "A": {"model": "bert-base", "learning_rate": 0.001},
108 | "B": {"model": "gpt2", "learning_rate": 0.0001},
109 | "shared_param": "value",
110 | }
111 |
112 | config_path = tmp_path / "config.yaml"
113 | with open(config_path, "w") as f:
114 | yaml.dump(config, f)
115 |
116 | parser = CfArgumentParser([DummyArgsAB])
117 | file_args = parser._parse_yaml_config(str(config_path))
118 |
119 | # Convert file_args list to dict for easier assertion
120 | args_dict = dict(zip(file_args[::2], file_args[1::2]))
121 |
122 | assert args_dict["--A_model"] == "bert-base"
123 | assert args_dict["--B_model"] == "gpt2"
124 | assert args_dict["--A_learning_rate"] == "0.001"
125 | assert args_dict["--B_learning_rate"] == "0.0001"
126 | assert args_dict["--shared_param"] == "value"
127 |
128 | def test_yaml_with_cli_override(self, monkeypatch, tmp_path):
129 | config = {"A": {"model": "bert-base"}, "B": {"model": "gpt2"}}
130 |
131 | config_path = tmp_path / "config.yaml"
132 | with open(config_path, "w") as f:
133 | yaml.dump(config, f)
134 |
135 | parser = CfArgumentParser([DummyArgsAB])
136 | args = [
137 | "train",
138 | str(config_path),
139 | "--A_learning_rate",
140 | "0.001",
141 | "--B_learning_rate",
142 | "0.0001",
143 | "--shared_param",
144 | "override",
145 | ]
146 | monkeypatch.setattr(sys, "argv", ["script.py"] + args)
147 |
148 | (parsed_args,) = parser.parse_args_and_config()
149 | assert parsed_args.A_model == "bert-base"
150 | assert parsed_args.B_model == "gpt2"
151 | assert parsed_args.A_learning_rate == 0.001
152 | assert parsed_args.B_learning_rate == 0.0001
153 | assert parsed_args.shared_param == "override"
154 |
155 | def test_yaml_with_defaults(self, monkeypatch, tmp_path):
156 | config = {"A": {"model": "bert-base"}, "B": {"model": "gpt2"}}
157 |
158 | config_path = tmp_path / "config.yaml"
159 | with open(config_path, "w") as f:
160 | yaml.dump(config, f)
161 |
162 | parser = CfArgumentParser([DummyArgsAB])
163 | args = ["train", str(config_path)]
164 | monkeypatch.setattr(sys, "argv", ["script.py"] + args)
165 |
166 | (parsed_args,) = parser.parse_args_and_config()
167 | assert parsed_args.A_model == "bert-base"
168 | assert parsed_args.B_model == "gpt2"
169 | assert parsed_args.A_learning_rate == 0.001 # default value
170 | assert parsed_args.B_learning_rate == 0.0001 # default value
171 | assert parsed_args.shared_param == "default" # default value
172 |
173 | def test_duplicate_args_in_yaml(self, tmp_path):
174 | config = {
175 | "A": {"param": "value1"},
176 | "A_param": "value2", # This creates a duplicate A_param
177 | }
178 |
179 | config_path = tmp_path / "config.yaml"
180 | with open(config_path, "w") as f:
181 | yaml.dump(config, f)
182 |
183 | parser = CfArgumentParser()
184 | with pytest.raises(ValueError, match="Duplicate argument"):
185 | parser._parse_yaml_config(str(config_path))
186 |
187 | def test_yaml_with_invalid_param(self, monkeypatch, tmp_path):
188 | config = {"invalid_param": "value"}
189 |
190 | config_path = tmp_path / "config.yaml"
191 | with open(config_path, "w") as f:
192 | yaml.dump(config, f)
193 |
194 | parser = CfArgumentParser([DummyArgs])
195 | args = ["train", str(config_path)]
196 | monkeypatch.setattr(sys, "argv", ["script.py"] + args)
197 |
198 |         with pytest.raises(SystemExit):  # argparse rejects the unknown flag, prints usage and exits
199 | parser.parse_args_and_config()
200 |
201 | def test_cli_only_with_defaults(self, monkeypatch):
202 | parser = CfArgumentParser([DummyArgsAB])
203 | args = ["train", "--A_model", "bert", "--B_model", "gpt2"]
204 | monkeypatch.setattr(sys, "argv", ["script.py"] + args)
205 |
206 | (parsed_args,) = parser.parse_args_and_config()
207 | assert parsed_args.A_model == "bert"
208 | assert parsed_args.B_model == "gpt2"
209 | assert parsed_args.A_learning_rate == 0.001 # default value
210 | assert parsed_args.B_learning_rate == 0.0001 # default value
211 | assert parsed_args.shared_param == "default" # default value
212 |
213 | def test_task_only_task(self, monkeypatch):
214 | parser = CfArgumentParser([DummyArgsAB])
215 | args = ["train"]
216 | monkeypatch.setattr(sys, "argv", ["script.py"] + args)
217 |
218 | with pytest.raises(SystemExit):
219 | parser.parse_args_and_config()
220 |
221 | def test_no_task_when_preset(self, monkeypatch):
222 | parser = CfArgumentParser([DummyArgs], task="train")
223 | args = ["--model_name", "bert"]
224 | monkeypatch.setattr(sys, "argv", ["script.py"] + args)
225 |
226 | (parsed_args,) = parser.parse_args_and_config()
227 | assert parsed_args.model_name == "bert"
228 | assert parsed_args.learning_rate == 0.001 # default value
229 |
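230 | # Convention exercised above (illustrative recap): _parse_yaml_config flattens
231 | # nested "A"/"B" YAML keys into prefixed CLI flags (values stringified) before
232 | # normal argparse handling, so
233 | #
234 | #     A:
235 | #       model: bert-base
236 | #     shared_param: value
237 | #
238 | # becomes ["--A_model", "bert-base", "--shared_param", "value"], and flags
239 | # passed explicitly on the command line override the file-derived values
240 | # (see test_yaml_with_cli_override).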
--------------------------------------------------------------------------------
/tests/trainer/test_prepare_cycle_inputs.py:
--------------------------------------------------------------------------------
1 | from types import MethodType
2 | from unittest.mock import Mock
3 |
4 | import pytest
5 | import torch
6 | from accelerate import Accelerator
7 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
8 |
9 | from cycleformers import DEFAULT_SEP_SEQ, CycleTrainer
10 | from cycleformers.cycles import _default_prepare_cycle_inputs, _prepare_causal_skip_cycle_inputs
11 |
12 |
13 | AVAILABLE_DEVICES = ["cpu"]
14 | if torch.cuda.is_available():
15 | AVAILABLE_DEVICES.append("cuda")
16 |
17 |
18 | @pytest.mark.parametrize("device", AVAILABLE_DEVICES)
19 | class BaseTestCycleInputs:
20 | @pytest.fixture(name="cycle_trainer")
21 | def fixture_cycle_trainer(self, device):
22 | trainer = Mock(spec=CycleTrainer)
23 | accelerator = Mock(spec=Accelerator)
24 | accelerator.device = device
25 |
26 | trainer._prepare_cycle_inputs = CycleTrainer._prepare_cycle_inputs.__get__(trainer)
27 | trainer.sep_seq = DEFAULT_SEP_SEQ
28 | trainer.accelerator = accelerator
29 | return trainer
30 |
31 | @pytest.fixture(name="tokenizer")
32 | def fixture_tokenizer(self, model_name):
33 | tokenizer = AutoTokenizer.from_pretrained(model_name)
34 | if tokenizer.pad_token_id is None:
35 | tokenizer.pad_token_id = tokenizer.eos_token_id
36 | return tokenizer
37 |
38 | # TODO: Replace these with better test examples
39 | @pytest.fixture(name="real_seq")
40 | def fixture_real_sequence(self):
41 | """Intentinally jagged to force padding (x10, x6)"""
42 | self.real_fst = 16
43 | self.real_snd = 6
44 | self.real_seq = ["A" * self.real_fst, "A" * self.real_snd]
45 | return self.real_seq
46 |
47 | @pytest.fixture(name="synth_seq")
48 | def fixture_synth_sequence(self):
49 | """Intentinally jagged to force padding (x8, x10)"""
50 | self.synth_fst = 8
51 | self.synth_snd = 20
52 | self.synth_seq = ["B" * self.synth_fst, "B" * self.synth_snd]
53 | return self.synth_seq
54 |
55 |
56 | @pytest.mark.parametrize(
57 | "model_name",
58 | [
59 | "trl-internal-testing/tiny-LlamaForCausalLM-3.1", # Llama tokenizer
60 | "Qwen/Qwen2.5-0.5B",
61 | # "gpt2", # GPT2 tokenizer
62 | # "facebook/opt-125m", # OPT tokenizer (GPT based)
63 | # "EleutherAI/pythia-70m", # Pythia/GPT-NeoX tokenizer (GPT based)
64 | ],
65 | )
66 | class TestPrepareCausalCycleInputs(BaseTestCycleInputs):
67 | @pytest.fixture(name="model")
68 | def fixture_model(self, model_name):
69 | return AutoModelForCausalLM.from_pretrained(model_name)
70 |
71 | @pytest.fixture(name="causal_sample")
72 | def fixture_prepare_causal_data(self, tokenizer, real_seq, synth_seq, sep_seq=DEFAULT_SEP_SEQ):
73 | """Prepare input sequences and expected outputs for causal LM testing."""
74 | tokenizer_kwargs = {
75 | "return_tensors": "pt",
76 | "padding": True,
77 | "truncation": True,
78 | "max_length": 512,
79 | }
80 |
81 | real_seq_with_sep = [seq + sep_seq for seq in real_seq]
82 | synth_seq_with_sep = [seq + sep_seq for seq in synth_seq]
83 |
84 | # The prompt that would be given to the generation model
85 | real_prompt_ids = tokenizer(real_seq_with_sep, **tokenizer_kwargs, padding_side="left").input_ids
86 | # PAD BOS R R R SEP
87 |
88 | # The output that would be generated as a continuation of the real input
89 | synth_seqs_eos = [synth + (tokenizer.eos_token or "") for synth in synth_seq]
90 | synth_response_ids = tokenizer(
91 | synth_seqs_eos,
92 | **tokenizer_kwargs,
93 | padding_side="right",
94 | ).input_ids
95 | # BOS S S S EOS PAD
96 |
97 | # Some tokenizers (like GPT2) don't add BOS automatically
98 | if tokenizer.bos_token_id is not None and synth_response_ids[0, 0] == tokenizer.bos_token_id:
99 | synth_response_ids = synth_response_ids[:, 1:]
100 | # S S S EOS PAD
101 |
102 | # The full sequence that would be returned from the causal model
103 | synth_output_ids = torch.cat([real_prompt_ids, synth_response_ids], dim=1)
104 | # PAD BOS R R R SEP S S S EOS PAD
105 |
106 | target_seq = [synth + real + (tokenizer.eos_token or "") for synth, real in zip(synth_seq_with_sep, real_seq)]
107 | targets = tokenizer(target_seq, **tokenizer_kwargs, padding_side="right")
108 | # BOS S S S SEP R R R EOS PAD
109 |
110 | labels = torch.clone(targets.input_ids)
111 | for i, text in enumerate(synth_seq_with_sep):
112 | prompt_ids = tokenizer.encode(text, **tokenizer_kwargs, padding_side=None)
113 | prompt_length = prompt_ids.shape[-1]
114 | labels[i, :prompt_length] = -100
115 | # Take care not to set eos token as -100
116 | labels[targets.attention_mask == 0] = -100
117 |
118 | sample_data = {
119 | "real_input_ids": real_prompt_ids,
120 | "synth_input_ids": synth_output_ids,
121 | "input_ids": targets.input_ids,
122 | "attention_mask": targets.attention_mask,
123 | "labels": labels,
124 | }
125 | return sample_data
126 |
127 | @pytest.mark.parametrize(
128 | "prepare_fn",
129 | [
130 | _default_prepare_cycle_inputs,
131 | _prepare_causal_skip_cycle_inputs,
132 | ],
133 | )
134 | def test_prepare_cycle_inputs(self, causal_sample, cycle_trainer, model, tokenizer, prepare_fn):
135 | # Copy the real method to our mock
136 | bound_method = MethodType(prepare_fn, cycle_trainer)
137 | setattr(cycle_trainer, "_prepare_cycle_inputs", bound_method)
138 | causal_sample = {k: v.to(cycle_trainer.accelerator.device) for k, v in causal_sample.items()}
139 |
140 | synth_batch = cycle_trainer._prepare_cycle_inputs(
141 | causal_sample["real_input_ids"],
142 | causal_sample["synth_input_ids"],
143 | model,
144 | model,
145 | tokenizer,
146 | tokenizer,
147 | "A",
148 | )
149 | # Test outputs
150 | assert torch.allclose(synth_batch["input_ids"], causal_sample["input_ids"])
151 | assert torch.allclose(synth_batch["attention_mask"], causal_sample["attention_mask"])
152 | assert torch.allclose(synth_batch["labels"], causal_sample["labels"])
153 |
154 |
155 | @pytest.mark.parametrize(
156 | "model_name",
157 | [
158 | "google/flan-t5-small",
159 | ],
160 | )
161 | class TestPrepareSeq2SeqCycleInputs(BaseTestCycleInputs):
162 | @pytest.fixture(name="model")
163 | def fixture_model(self, model_name):
164 | return AutoModelForSeq2SeqLM.from_pretrained(model_name)
165 |
166 | @pytest.fixture(name="seq2seq_sample")
167 | def fixture_seq2seq_data(self, tokenizer, real_seq, synth_seq):
168 | tokenizer_kwargs = {
169 | "return_tensors": "pt",
170 | "padding": True,
171 | "truncation": True,
172 | "max_length": 512,
173 | "padding_side": "right",
174 | }
175 |
176 | # The prompt that would be given to the generation model
177 | real_prompt_ids = tokenizer(real_seq, **tokenizer_kwargs)
178 | # As if the model generated the synthetic output
179 | synth_response_ids = tokenizer(
180 | [tokenizer.pad_token + synth for synth in synth_seq],
181 | **tokenizer_kwargs,
182 | )
183 | targets = tokenizer(synth_seq, text_target=real_seq, **tokenizer_kwargs)
184 |
185 | return {
186 | "real_input_ids": real_prompt_ids.input_ids,
187 | "synth_input_ids": synth_response_ids.input_ids,
188 | "input_ids": targets.input_ids,
189 | "attention_mask": targets.attention_mask,
190 | "labels": targets.labels,
191 | }
192 |
193 | @pytest.mark.parametrize(
194 | "prepare_fn",
195 | [
196 | _default_prepare_cycle_inputs,
197 | ],
198 | )
199 | def test_prepare_cycle_inputs(self, seq2seq_sample, cycle_trainer, model, tokenizer, prepare_fn):
200 | # Copy the real method to our mock
201 | bound_method = MethodType(prepare_fn, cycle_trainer)
202 | setattr(cycle_trainer, "_prepare_cycle_inputs", bound_method)
203 | seq2seq_sample = {k: v.to(cycle_trainer.accelerator.device) for k, v in seq2seq_sample.items()}
204 |
205 | synth_batch = cycle_trainer._prepare_cycle_inputs(
206 | seq2seq_sample["real_input_ids"],
207 | seq2seq_sample["synth_input_ids"],
208 | model,
209 | model,
210 | tokenizer,
211 | tokenizer,
212 | "A",
213 | )
214 | assert torch.allclose(synth_batch["input_ids"], seq2seq_sample["input_ids"])
215 | assert torch.allclose(synth_batch["attention_mask"], seq2seq_sample["attention_mask"])
216 | assert torch.allclose(synth_batch["labels"], seq2seq_sample["labels"])
217 |
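218 | # Label convention recap (illustrative): for the causal cycle, the training
219 | # target is "<synth><SEP><real><EOS>" with the loss masked over the prompt and
220 | # the padding. With a prompt of length 4, a single row looks like:
221 | #
222 | #     input_ids: [BOS,    S,    S,  SEP, R, R,  EOS,  PAD]
223 | #     labels:    [-100, -100, -100, -100, R, R,  EOS, -100]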
--------------------------------------------------------------------------------
/src/cycleformers/utils.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from dataclasses import MISSING, fields
3 | from functools import wraps
4 | from typing import get_type_hints
5 |
6 | from peft import LoraConfig
7 | from transformers.hf_argparser import DataClassType
8 |
9 | from .import_utils import is_liger_kernel_available
10 |
11 |
12 | if is_liger_kernel_available():
13 | from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN # type: ignore
14 |
15 | VALID_LIGER_MODELS = MODEL_TYPE_TO_APPLY_LIGER_FN.keys()
16 | else:
17 | VALID_LIGER_MODELS = []
18 |
19 | DEFAULT_SEP_SEQ = "\n\n"
20 |
21 |
22 | def prefixed_view(base_class: DataClassType, prefix: str):
23 | """Creates a dataclass-like decorator that provides a prefixed view of another dataclass.
24 | When instantiated, returns an instance of the original dataclass with unprefixed attributes.
25 |
26 | This decorator allows you to create a class that acts as a view of another dataclass,
27 | where all attributes are prefixed. When instantiating the prefixed class, it returns
28 | an instance of the original dataclass with the prefixes removed.
29 |
30 | Args:
31 | base_class: The original dataclass to create a view of
32 | prefix: The prefix to add to all attribute names
33 |
34 | Returns:
35 | A decorator function that creates the prefixed view class
36 |
37 | Example:
38 | >>> from dataclasses import dataclass
39 | >>> @dataclass
40 | ... class Config:
41 | ... name: str
42 | ... value: int = 42
43 | ...
44 | >>> @prefixed_view(Config, "test_")
45 | ... class TestConfig:
46 | ... pass
47 | ...
48 | >>> # The new class has prefixed type hints
49 | >>> TestConfig.__annotations__
50 |         {'test_name': <class 'str'>, 'test_value': <class 'int'>}
51 | >>>
52 | >>> # Creating an instance with prefixed attributes
53 | >>> config = TestConfig(test_name="example", test_value=100)
54 | >>> config
55 | Config(name='example', value=100)
56 | >>>
57 | >>> # Invalid attributes raise TypeError
58 | >>> config = TestConfig(invalid_attr="test") # doctest: +IGNORE_EXCEPTION_DETAIL
59 | Traceback (most recent call last):
60 | ...
61 | TypeError: Unexpected argument: invalid_attr
62 | >>>
63 | >>> # Default values are preserved
64 | >>> config = TestConfig(test_name="example")
65 | >>> config
66 | Config(name='example', value=42)
67 | """
68 |
69 | def wrapper(cls):
70 | # Get original class type hints and fields
71 | base_hints = get_type_hints(base_class)
72 | base_fields = {f.name: f for f in fields(base_class)}
73 |
74 | # Create new fields with prefixed names but same types/defaults
75 | new_annotations = {}
76 | new_defaults = {}
77 |
78 | for name, type_hint in base_hints.items():
79 | prefixed_name = f"{prefix}{name}"
80 | new_annotations[prefixed_name] = type_hint
81 |
82 | # Copy default values if they exist
83 | if name in base_fields:
84 | field = base_fields[name]
85 | if field.default is not MISSING:
86 | new_defaults[prefixed_name] = field.default
87 | if field.default_factory is not MISSING:
88 | new_defaults[prefixed_name] = field.default_factory()
89 |
90 | # Add annotations and default values to the class
91 | cls.__annotations__ = new_annotations
92 | for name, value in new_defaults.items():
93 | setattr(cls, name, value)
94 |
95 | def __new__(cls, **kwargs):
96 | # Validate input against our prefixed fields
97 | for key in kwargs:
98 | if key not in new_annotations:
99 | raise TypeError(f"Unexpected argument: {key}")
100 |
101 | # Create mapping of unprefixed attributes
102 | unprefixed_kwargs = {key[len(prefix) :]: value for key, value in kwargs.items()}
103 |
104 | # Return instance of original dataclass
105 | return base_class(**unprefixed_kwargs)
106 |
107 | cls.__new__ = staticmethod(__new__)
108 | return cls
109 |
110 | return wrapper
111 |
112 |
113 | def auto_temp_attributes(*attrs_to_cleanup):
114 | """Decorator that automatically manages temporary attributes on a class instance.
115 |
116 | This decorator solves the issue of methods that need to temporarily modify class attributes
117 | that might be needed by other methods. It automatically sets attributes based on method
118 | parameters and restores their original values (or removes them) after the method completes.
119 |
120 | Args:
121 | *attrs_to_cleanup: Variable number of attribute names to manage
122 |
123 | Returns:
124 | Callable: Decorated function that handles attribute lifecycle
125 |
126 | Examples:
127 | >>> class MyClass:
128 | ... def __init__(self):
129 | ... self.permanent = "permanent"
130 | ...
131 | ... @auto_temp_attributes("model", "optimizer")
132 | ... def my_method(self, model, optimizer=None):
133 | ... print(f"Using {model} and {optimizer}")
134 | ... print(f"Permanent: {self.permanent}")
135 | >>>
136 | >>> obj = MyClass()
137 | >>> obj.my_method("bert", optimizer="adam")
138 | Using bert and adam
139 | Permanent: permanent
140 | >>> hasattr(obj, "model") # Attribute is cleaned up
141 | False
142 |
143 | Notes:
144 | - Original attribute values are restored after method execution
145 | - Attributes that didn't exist are removed
146 | - Works with both positional and keyword arguments
147 | - Handles exceptions by ensuring cleanup
148 | """
149 |
150 | def decorator(func):
151 | @wraps(func)
152 | def wrapper(self, *args, **kwargs):
153 | # Store original attribute values
154 | original_values = {}
155 | for attr in attrs_to_cleanup:
156 | if hasattr(self, attr):
157 | original_values[attr] = getattr(self, attr)
158 |
159 | # Get function signature to match positional args to parameter names
160 | sig = inspect.signature(func)
161 | bound_args = sig.bind(self, *args, **kwargs)
162 | bound_args.apply_defaults()
163 |
164 | # Set attributes based on parameters
165 | for attr in attrs_to_cleanup:
166 | # Skip 'self' parameter
167 | if attr in bound_args.arguments and attr != "self":
168 | setattr(self, attr, bound_args.arguments[attr])
169 | else:
170 | setattr(self, attr, None)
171 |
172 | try:
173 | # Execute the method
174 | result = func(self, *args, **kwargs)
175 | return result
176 | finally:
177 | # Clean up attributes
178 | for attr in attrs_to_cleanup:
179 | if attr in original_values:
180 | setattr(self, attr, original_values[attr])
181 | else:
182 | try:
183 | delattr(self, attr)
184 | except AttributeError:
185 | pass
186 |
187 | return wrapper
188 |
189 | return decorator
190 |
191 |
192 | def get_peft_config(model_config):
193 | """Creates a PEFT LoRA configuration from a model configuration dataclass."""
194 | if model_config.use_peft is False:
195 | return None
196 |
197 | peft_config = LoraConfig(
198 | task_type=model_config.lora_task_type,
199 | r=model_config.lora_r,
200 | target_modules=model_config.lora_target_modules,
201 | lora_alpha=model_config.lora_alpha,
202 | lora_dropout=model_config.lora_dropout,
203 | bias="none",
204 | use_rslora=model_config.use_rslora,
205 | modules_to_save=model_config.lora_modules_to_save,
206 | )
207 |
208 | return peft_config
209 |
210 |
211 | def print_trainable_params(model):
212 | """Prints the number of trainable and all parameters in a model."""
213 | trainable_params = 0
214 | all_param = 0
215 | for _, param in model.named_parameters():
216 | num_params = param.numel()
217 | # if using DS Zero 3 and the weights are initialized empty
218 | if num_params == 0 and hasattr(param, "ds_numel"):
219 | num_params = param.ds_numel
220 |
221 | # Due to the design of 4bit linear layers from bitsandbytes one needs to multiply
222 | # the number of parameters by 2 to get the correct number of parameters
223 | if param.__class__.__name__ == "Params4bit":
224 | num_params = num_params * 2
225 |
226 | all_param += num_params
227 | if param.requires_grad:
228 | trainable_params += num_params
229 | print(
230 | f"trainable params: {trainable_params:,d} || "
231 | f"all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param}"
232 | )
233 |
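234 | 
235 | if __name__ == "__main__":
236 |     # Illustrative only: any object exposing the attributes read by
237 |     # get_peft_config works here; a SimpleNamespace stands in for the
238 |     # ModelConfig dataclass defined elsewhere in the package.
239 |     from types import SimpleNamespace
240 | 
241 |     demo_config = SimpleNamespace(
242 |         use_peft=True,
243 |         lora_task_type="CAUSAL_LM",
244 |         lora_r=16,
245 |         lora_target_modules=["q_proj", "v_proj"],
246 |         lora_alpha=32,
247 |         lora_dropout=0.05,
248 |         use_rslora=False,
249 |         lora_modules_to_save=None,
250 |     )
251 |     print(get_peft_config(demo_config))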
--------------------------------------------------------------------------------
/src/cycleformers/trainer_callback.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from dataclasses import dataclass
3 | from typing import TYPE_CHECKING, Any
4 |
5 | from transformers.trainer_callback import DefaultFlowCallback, TrainerCallback, TrainerControl, TrainerState
6 |
7 |
8 | if TYPE_CHECKING:
9 | from transformers.training_args import TrainingArguments
10 |
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | @dataclass
16 | class CycleTrainerState(TrainerState):
17 | """Extension of TrainerState to handle cycle-specific states"""
18 |
19 | steps_a: int = 0 # Number of updates for model A
20 | steps_b: int = 0 # Number of updates for model B
21 | loss_a: float = 0.0 # Current loss for model A
22 | loss_b: float = 0.0 # Current loss for model B
23 |
24 | @property
25 | def global_step(self) -> int:
26 | """Global step is the maximum of steps between models"""
27 | return max(self.steps_a, self.steps_b)
28 |
29 | @global_step.setter
30 |     def global_step(self, value: int):
31 |         """Direct writes are ignored; global_step is always derived from steps_a/steps_b."""
32 |         pass  # NB: the dataclass-generated __init__ assigns this inherited field, so raising here would make the class uninstantiable
33 |
34 | def update_model_step(self, cycle_name: str):
35 | """Update step count for a specific model"""
36 | if cycle_name == "A":
37 | self.steps_a += 1
38 | elif cycle_name == "B":
39 | self.steps_b += 1
40 | else:
41 | raise ValueError(f"Invalid cycle name: {cycle_name}")
42 |
43 | def get_model_step(self, cycle_name: str) -> int:
44 | """Get current step for a specific model"""
45 | if cycle_name == "A":
46 | return self.steps_a
47 | elif cycle_name == "B":
48 | return self.steps_b
49 | else:
50 | raise ValueError(f"Invalid cycle name: {cycle_name}")
51 |
52 | def update_model_loss(self, cycle_name: str, loss: float):
53 | """Update current loss for a specific model"""
54 | if cycle_name == "A":
55 | self.loss_a = loss
56 | elif cycle_name == "B":
57 | self.loss_b = loss
58 | else:
59 | raise ValueError(f"Invalid cycle name: {cycle_name}")
60 |
61 | def get_model_loss(self, cycle_name: str) -> float:
62 | """Get current loss for a specific model"""
63 | if cycle_name == "A":
64 | return self.loss_a
65 | elif cycle_name == "B":
66 | return self.loss_b
67 | else:
68 | raise ValueError(f"Invalid cycle name: {cycle_name}")
69 |
70 |
71 | class CallbackHandler(TrainerCallback):
72 | """Internal class that just calls the list of callbacks in order."""
73 |
74 | def __init__(
75 | self,
76 | callbacks: list[TrainerCallback],
77 | model: dict[str, Any],
78 | processing_class: dict[str, Any],
79 | optimizer: dict[str, Any],
80 | lr_scheduler: dict[str, Any],
81 | ):
82 | self.callbacks: list[TrainerCallback] = []
83 | for cb in callbacks:
84 | self.add_callback(cb)
85 | self.model = model
86 | self.processing_class = processing_class
87 | self.optimizer = optimizer
88 | self.lr_scheduler = lr_scheduler
89 | self.train_dataloader = None
90 | self.eval_dataloader = None
91 |
92 | if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks):
93 | logger.warning(
94 | "The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n"
95 | + "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of"
96 | + "callbacks is\n:"
97 | + self.callback_list
98 | )
99 |
100 | def add_callback(self, callback: TrainerCallback) -> None:
101 | cb = callback() if isinstance(callback, type) else callback
102 | cb_class = callback if isinstance(callback, type) else callback.__class__
103 | if cb_class in [c.__class__ for c in self.callbacks]:
104 | logger.warning(
105 | f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current"
106 | + "list of callbacks is\n:"
107 | + self.callback_list
108 | )
109 | self.callbacks.append(cb)
110 |
111 | def pop_callback(self, callback):
112 | if isinstance(callback, type):
113 | for cb in self.callbacks:
114 | if isinstance(cb, callback):
115 | self.callbacks.remove(cb)
116 | return cb
117 | else:
118 | for cb in self.callbacks:
119 | if cb == callback:
120 | self.callbacks.remove(cb)
121 | return cb
122 |
123 | def remove_callback(self, callback):
124 | if isinstance(callback, type):
125 | for cb in self.callbacks:
126 | if isinstance(cb, callback):
127 | self.callbacks.remove(cb)
128 | return
129 | else:
130 | self.callbacks.remove(callback)
131 |
132 | @property
133 | def callback_list(self):
134 | return "\n".join(cb.__class__.__name__ for cb in self.callbacks)
135 |
136 | def on_init_end(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl):
137 | return self.call_event("on_init_end", args, state, control)
138 |
139 | def on_train_begin(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl):
140 | control.should_training_stop = False
141 | return self.call_event("on_train_begin", args, state, control)
142 |
143 | def on_train_end(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl):
144 | return self.call_event("on_train_end", args, state, control)
145 |
146 | def on_epoch_begin(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl):
147 | control.should_epoch_stop = False
148 | return self.call_event("on_epoch_begin", args, state, control)
149 |
150 | def on_epoch_end(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl):
151 | return self.call_event("on_epoch_end", args, state, control)
152 |
153 | def on_step_begin(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl):
154 | control.should_log = False
155 | control.should_evaluate = False
156 | control.should_save = False
157 | return self.call_event("on_step_begin", args, state, control)
158 |
159 | def on_pre_optimizer_step(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl):
160 | return self.call_event("on_pre_optimizer_step", args, state, control)
161 |
162 | def on_optimizer_step(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl):
163 | return self.call_event("on_optimizer_step", args, state, control)
164 |
165 | def on_substep_end(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl):
166 | return self.call_event("on_substep_end", args, state, control)
167 |
168 | def on_step_end(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl):
169 | return self.call_event("on_step_end", args, state, control)
170 |
171 | def on_evaluate(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl, metrics):
172 | control.should_evaluate = False
173 | return self.call_event("on_evaluate", args, state, control, metrics=metrics)
174 |
175 | def on_predict(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl, metrics):
176 | return self.call_event("on_predict", args, state, control, metrics=metrics)
177 |
178 | def on_save(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl):
179 | control.should_save = False
180 | return self.call_event("on_save", args, state, control)
181 |
182 | def on_log(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl, logs):
183 | control.should_log = False
184 | return self.call_event("on_log", args, state, control, logs=logs)
185 |
186 | def on_prediction_step(self, args: "TrainingArguments", state: TrainerState, control: TrainerControl):
187 | return self.call_event("on_prediction_step", args, state, control)
188 |
189 | def call_event(self, event, args, state, control, model_key="both", **kwargs):
190 | for callback in self.callbacks:
191 | if model_key == "both":
192 | result = getattr(callback, event)(
193 | args,
194 | state,
195 | control,
196 | model=self.model,
197 | processing_class=self.processing_class,
198 | optimizer=self.optimizer,
199 | lr_scheduler=self.lr_scheduler,
200 | train_dataloader=self.train_dataloader,
201 | eval_dataloader=self.eval_dataloader,
202 | **kwargs,
203 | )
204 | # A Callback can skip the return of `control` if it doesn't change it.
205 | if result is not None:
206 | control = result
207 | return control
208 |
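209 | 
210 | if __name__ == "__main__":
211 |     # Illustrative only: global_step is derived as max(steps_a, steps_b), so it
212 |     # advances with whichever model has taken more optimiser steps.
213 |     state = CycleTrainerState()
214 |     state.update_model_step("A")
215 |     state.update_model_step("A")
216 |     state.update_model_step("B")
217 |     state.update_model_loss("A", 0.5)
218 |     assert state.global_step == 2  # max(steps_a=2, steps_b=1)
219 |     print(state.get_model_step("A"), state.get_model_loss("A"))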
--------------------------------------------------------------------------------
/tests/benchmark/cycle_trainer_profiling.py:
--------------------------------------------------------------------------------
1 | import cProfile
2 | import logging
3 | import pstats
4 | import sys
5 | import time
6 | from pathlib import Path
7 |
8 | import torch.cuda as cuda
9 | import torch.profiler
10 | import yaml
11 | from memory_profiler import profile as memory_profile
12 | from profiler_utils import record_function_wrapper
13 | from torch.profiler import ProfilerActivity, profile, record_function
14 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
15 |
16 | from cycleformers import CfArgumentParser, CycleTrainer, CycleTrainingArguments, ModelConfig
17 | from cycleformers.import_utils import is_liger_kernel_available
18 | from cycleformers.model_config import ModelConfigA, ModelConfigB, merge_configs
19 | from cycleformers.task_processors.ner import CONLL2003Processor, CONLL2003ProcessorConfig
20 | from cycleformers.utils import VALID_LIGER_MODELS, get_peft_config, print_trainable_params
21 |
22 |
23 | logger = logging.getLogger(__file__)
24 |
25 | MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT = 1_000_000
26 |
27 |
28 | class ProfilingCycleTrainer(CycleTrainer):
29 | def __init__(self, *args, **kwargs):
30 | super().__init__(*args, **kwargs)
31 | self.profiling_stats = {
32 | "step_times": [],
33 | "gpu_memory": [],
34 | "max_gpu_memory": 0,
35 | }
36 |
37 | # Create unique output directory for this run
38 | self.profile_output_dir = (
39 | Path(__file__).parent / "profiles" / f'cycle_trainer--{time.strftime("%Y%m%d_%H%M%S")}'
40 | )
41 | self.profile_output_dir.mkdir(parents=True, exist_ok=True)
42 |
43 | self.profile_path = self.profile_output_dir / "profile.prof"
44 | self.memory_path = self.profile_output_dir / "memory_profile.txt"
45 | self.cuda_memory_path = self.profile_output_dir / "cuda_memory_snapshots"
46 |
47 | # Save args to yaml file
48 | with open(self.profile_output_dir / "profiler_args.yaml", "w") as f:
49 | yaml.dump(self.args, f)
50 |
51 | print("=" * 40)
52 | print("Model A: ", end="")
53 | print_trainable_params(self.model_A)
54 | if not self.is_macct_model:
55 | print("Model B: ", end="")
56 | print_trainable_params(self.model_B)
57 | print("=" * 40)
58 |
59 | def _log_gpu_memory(self):
60 | """Log current GPU memory usage"""
61 | if not cuda.is_available():
62 | return None
63 |
64 | current_memory = cuda.memory_allocated() / 1024**2 # Convert to MB
65 | max_memory = cuda.max_memory_allocated() / 1024**2 # Convert to MB
66 |
67 | self.profiling_stats["gpu_memory"].append(current_memory)
68 | self.profiling_stats["max_gpu_memory"] = max(self.profiling_stats["max_gpu_memory"], max_memory)
69 |
70 | return current_memory, max_memory
71 |
72 | def train(self):
73 | torch.cuda.memory._record_memory_history(max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT)
74 |
75 | # Profile entire training run
76 | profiler = cProfile.Profile()
77 | profiler.enable()
78 |
79 | # Setup PyTorch profiler for detailed GPU analysis
80 | pytorch_profiler = torch.profiler.profile(
81 | activities=[
82 | torch.profiler.ProfilerActivity.CPU,
83 | torch.profiler.ProfilerActivity.CUDA,
84 | ],
85 | schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
86 | on_trace_ready=torch.profiler.tensorboard_trace_handler(self.profile_output_dir / "tensorboard_trace"),
87 | record_shapes=True,
88 | profile_memory=True,
89 | with_stack=True,
90 | )
91 |
92 | pytorch_profiler.start()
93 |
94 | try:
95 | result = super().train()
96 |
97 | # Collect profiling data
98 | profiler.disable()
99 | stats = pstats.Stats(profiler)
100 | stats.sort_stats("cumtime")
101 |
102 | # Save profiling data
103 | stats.dump_stats(str(self.profile_path))
104 |
105 | # Save memory statistics
106 | with open(self.memory_path, "w") as f:
107 | f.write(f"Max GPU Memory Used: {self.profiling_stats['max_gpu_memory']:.2f} MB\n")
108 | f.write("GPU Memory Usage per Step:\n")
109 | for i, mem in enumerate(self.profiling_stats["gpu_memory"]):
110 | f.write(f"Step {i}: {mem:.2f} MB\n")
111 |
112 | return result
113 | finally:
114 | pytorch_profiler.stop()
115 | torch.cuda.memory._record_memory_history(enabled=None)
116 | try:
117 | torch.cuda.memory._dump_snapshot(f"{self.cuda_memory_path}.pickle")
118 | except Exception as e:
119 | logger.error(f"Failed to capture memory snapshot {e}")
120 |
121 | @record_function_wrapper("## Cycle Step ##")
122 | def _cycle_step(self, *args, **kwargs):
123 | return super()._cycle_step(*args, **kwargs)
124 |
125 | @record_function_wrapper("## Prepare Cycle Inputs ##")
126 | def _prepare_cycle_inputs(self, *args, **kwargs):
127 |         return super()._prepare_cycle_inputs(*args, **kwargs)
128 |
129 | def analyze_performance(self):
130 | """Analyze collected performance metrics"""
131 | if self.profiling_stats.get("step_times"):
132 | avg_step_time = sum(self.profiling_stats["step_times"]) / len(self.profiling_stats["step_times"])
133 | print(f"Average step time: {avg_step_time:.4f} seconds")
134 |
135 | # GPU Memory Statistics
136 | if cuda.is_available():
137 | print("\nGPU Memory Statistics:")
138 | print(f"Peak GPU memory usage: {self.profiling_stats['max_gpu_memory']:.2f} MB")
139 | avg_memory = sum(self.profiling_stats["gpu_memory"]) / len(self.profiling_stats["gpu_memory"])
140 | print(f"Average GPU memory usage: {avg_memory:.2f} MB")
141 | print(f"Current GPU memory usage: {cuda.memory_allocated() / 1024**2:.2f} MB")
142 | print(f"Current GPU memory cached: {cuda.memory_reserved() / 1024**2:.2f} MB")
143 |
144 | # Load and analyze the cProfile data
145 | stats = pstats.Stats(str(self.profile_output_dir / "profile.prof"))
146 | print("\nTop 10 time-consuming functions:")
147 | stats.sort_stats("cumtime").print_stats(10)
148 |
149 | # Memory analysis tips
150 | print(f"\nProfiling data has been saved to: {self.profile_output_dir}")
151 | print("Check PyTorch profiler traces in TensorBoard for detailed GPU memory analysis")
152 |
153 |
154 | if is_liger_kernel_available():
155 | from liger_kernel.transformers import AutoLigerKernelForCausalLM
156 |
157 |
158 | def get_model_and_tokenizer(model_config, training_args):
159 | """Initialize model and tokenizer from config"""
160 | config = AutoConfig.from_pretrained(
161 | model_config.model_name_or_path,
162 | trust_remote_code=model_config.trust_remote_code,
163 | )
164 | config.use_cache = False
165 |
166 | model_kwargs = {}
167 |
168 | if not config.is_encoder_decoder:
169 | if is_liger_kernel_available() and model_config.use_liger and config.model_type in VALID_LIGER_MODELS:
170 | model_class = AutoLigerKernelForCausalLM
171 | model_kwargs["use_liger_kernel"] = training_args.use_liger_kernel
172 | else:
173 | model_class = AutoModelForCausalLM
174 | else:
175 | model_class = AutoModelForSeq2SeqLM
176 |
177 | model = model_class.from_pretrained(
178 | model_config.model_name_or_path,
179 | revision=model_config.model_revision,
180 | config=config,
181 | trust_remote_code=model_config.trust_remote_code,
182 | attn_implementation=model_config.attn_implementation,
183 | torch_dtype=model_config.torch_dtype,
184 | device_map="auto",
185 | )
186 |
187 | if training_args.gradient_checkpointing:
188 | model.enable_input_require_grads()
189 |
190 | # Print the actual dtype of the first parameter
191 | print(f"Model weights dtype: {next(model.parameters()).dtype}")
192 |
193 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
194 | return model, tokenizer
195 |
196 |
197 | def main():
198 | sys.argv = [__file__, str(Path(__file__).parent / "profiler_configs/causal.yaml")]
199 | parser = CfArgumentParser(
200 | (CycleTrainingArguments, ModelConfig, ModelConfigA, ModelConfigB, CONLL2003ProcessorConfig), task="train"
201 | )
202 | args, model_config_base, model_config_A, model_config_B, conll_config = parser.parse_args_and_config()
203 | model_config_base = merge_configs(model_config_base, model_config_A, model_config_B)
204 | args.model_config = model_config_base
205 |
206 | task_processor = CONLL2003Processor(conll_config)
207 | dataset_A, dataset_B = task_processor.process()
208 |
209 | model_A, tokenizer_A = get_model_and_tokenizer(args.model_config.A, args)
210 |
211 |     # Standard cycle training uses two separate models; MACCT instead swaps adapters on a single base model
212 | if not args.use_macct:
213 | # Get model B using merged B config
214 | model_B, tokenizer_B = get_model_and_tokenizer(args.model_config.B, args)
215 | models = {"A": model_A, "B": model_B}
216 | tokenizers = {"A": tokenizer_A, "B": tokenizer_B} if tokenizer_A != tokenizer_B else tokenizer_A
217 | else:
218 | models = model_A
219 | tokenizers = tokenizer_A
220 |
221 | trainer = ProfilingCycleTrainer(
222 | args=args,
223 | models=models,
224 | tokenizers=tokenizers,
225 | train_dataset_A=dataset_A["train"],
226 | train_dataset_B=dataset_B["train"],
227 | eval_dataset_A=dataset_A["eval"] if not args.eval_strategy == "no" else None,
228 | eval_dataset_B=dataset_B["eval"] if not args.eval_strategy == "no" else None,
229 | peft_configs=get_peft_config(model_config_base),
230 | )
231 |
232 | trainer.train()
233 |
234 | trainer.analyze_performance()
235 |
236 |
237 | if __name__ == "__main__":
238 | main()
239 |
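240 | # To inspect the cProfile output interactively, `snakeviz profiles/<run>/profile.prof`
241 | # works well (assuming snakeviz is installed); the PyTorch profiler traces written
242 | # above can be viewed by pointing TensorBoard at profiles/<run>/tensorboard_trace.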
--------------------------------------------------------------------------------