├── lcm ├── nn │ ├── __init__.py │ ├── schedulers │ │ └── __init__.py │ ├── denoisers │ │ └── __init__.py │ ├── transformer │ │ └── __init__.py │ ├── incremental_state.py │ ├── projection.py │ ├── normalization.py │ ├── timestep_encoder.py │ └── initialization.py ├── utils │ ├── __init__.py │ ├── data_utils.py │ ├── model_type_registry.py │ ├── common.py │ └── logging.py ├── datasets │ ├── __init__.py │ ├── utils.py │ └── base.py ├── evaluation │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ ├── sonar.py │ │ └── hf.py │ ├── cli │ │ ├── __init__.py │ │ ├── local.py │ │ └── slurm.py │ ├── metrics │ │ ├── utils.py │ │ ├── round_trip.py │ │ ├── common.py │ │ ├── seahorse.py │ │ ├── __init__.py │ │ └── similarity.py │ ├── predictors │ │ ├── two_tower_diffusion_lcm.py │ │ ├── gemma.py │ │ ├── __init__.py │ │ └── dummy.py │ ├── __main__.py │ └── tasks │ │ ├── lcm_generation.py │ │ ├── xsum.py │ │ └── cnn_dailymail.py ├── train │ ├── lcm │ │ └── __init__.py │ ├── mse_lcm │ │ └── __init__.py │ ├── two_tower_diffusion_lcm │ │ ├── __init__.py │ │ └── trainer.py │ ├── common.py │ ├── optim.py │ ├── criterion.py │ ├── step_sampler.py │ └── __main__.py ├── cards │ ├── mock_data │ │ └── dummy_normalizer.pt │ ├── TODO_mse_dummy_model.yaml │ ├── sonar_normalizer.yaml │ └── TODO_2t_dummy_model.yaml ├── models │ ├── two_tower_diffusion_lcm │ │ ├── __init__.py │ │ ├── loader.py │ │ ├── archs.py │ │ └── frontend.py │ ├── abstract_lcm │ │ ├── __init__.py │ │ └── builder.py │ ├── __init__.py │ ├── base_lcm │ │ ├── __init__.py │ │ ├── archs.py │ │ ├── normalization.py │ │ └── loader.py │ └── sonar_normalizer │ │ ├── __init__.py │ │ ├── loader.py │ │ └── archs.py ├── inference │ ├── lcm │ │ └── __init__.py │ └── two_tower_diffusion_lcm │ │ └── __init__.py ├── datacards │ └── datacards.yaml └── __init__.py ├── tests ├── units │ ├── __init__.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── test_predictors.py │ │ ├── conftest.py │ │ ├── test_round_trip.py │ │ ├── test_task_registry.py │ │ ├── test_model_based_metrics.py │ │ ├── test_metrics.py │ │ ├── test_judge_tasks.py │ │ ├── test_similarity.py │ │ └── test_cli.py │ ├── training │ │ ├── test_batch.py │ │ ├── test_get_trainer.py │ │ ├── conftest.py │ │ └── test_toy_task_trainer.py │ ├── datapipeline │ │ └── test_sentence_splitter.py │ ├── inference │ │ ├── test_base_lcm_kv_caching.py │ │ └── test_base_lcm_batched_inference.py │ ├── test_recipes.py │ └── conftest.py ├── __init__.py ├── common.py └── test_headers.py ├── recipes ├── common │ ├── launcher │ │ └── standalone.yaml │ ├── requirements.yaml │ └── evals.yaml └── train │ ├── defaults.yaml │ ├── finetune │ ├── mse.yaml │ └── two_tower.yaml │ ├── pretrain │ ├── mse.yaml │ └── two_tower.yaml │ └── README.md ├── .github ├── actions │ └── setup │ │ └── action.yaml ├── pull_request_template.md └── workflows │ └── lint_and_test.yaml ├── examples └── evaluation │ └── instruction.yaml ├── .pre-commit-config.yaml ├── LICENSE ├── CONTRIBUTING.md ├── .gitignore ├── scripts ├── fit_embedding_normalizer.py └── prepare_wikipedia.py ├── pyproject.toml └── CODE_OF_CONDUCT.md /lcm/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 3 | # 4 | # 5 | -------------------------------------------------------------------------------- /lcm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates 2 | # All rights reserved. 3 | # 4 | # 5 | -------------------------------------------------------------------------------- /lcm/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 3 | # 4 | # 5 | -------------------------------------------------------------------------------- /lcm/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | -------------------------------------------------------------------------------- /lcm/train/lcm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | -------------------------------------------------------------------------------- /tests/units/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | -------------------------------------------------------------------------------- /lcm/evaluation/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | -------------------------------------------------------------------------------- /lcm/train/mse_lcm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | -------------------------------------------------------------------------------- /lcm/evaluation/cli/__init__.py: -------------------------------------------------------------------------------- 1 | # # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | -------------------------------------------------------------------------------- /tests/units/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | -------------------------------------------------------------------------------- /lcm/train/two_tower_diffusion_lcm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | -------------------------------------------------------------------------------- /lcm/cards/mock_data/dummy_normalizer.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/large_concept_model/HEAD/lcm/cards/mock_data/dummy_normalizer.pt -------------------------------------------------------------------------------- /lcm/cards/TODO_mse_dummy_model.yaml: -------------------------------------------------------------------------------- 1 | name: dummy_mse_pretrained_model 2 | model_family: base_lcm 3 | checkpoint: file:///dev/null 4 | model_arch: toy_base_lcm 5 | -------------------------------------------------------------------------------- /lcm/cards/sonar_normalizer.yaml: -------------------------------------------------------------------------------- 1 | name: dummy_sonar_normalizer 2 | model_family: sonar_normalizer 3 | model_arch: base 4 | checkpoint: mock_data/dummy_normalizer.pt 5 | -------------------------------------------------------------------------------- /recipes/common/launcher/standalone.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - cache: file_cache 3 | 4 | _target_: stopes.core.Launcher 5 | log_folder: executor_logs 6 | cluster: debug 7 | partition: null 8 | -------------------------------------------------------------------------------- /lcm/cards/TODO_2t_dummy_model.yaml: -------------------------------------------------------------------------------- 1 | name: dummy_2t_pretrained_model 2 | model_family: two_tower_diffusion_lcm 3 | checkpoint: file:///dev/null 4 | model_arch: toy_two_tower_diffusion_lcm 5 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | import pytest 7 | 8 | pytest.register_assert_rewrite("tests.common") 9 | -------------------------------------------------------------------------------- /lcm/models/two_tower_diffusion_lcm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | # Register architectures 7 | import lcm.models.two_tower_diffusion_lcm.archs # noqa 8 | -------------------------------------------------------------------------------- /recipes/common/requirements.yaml: -------------------------------------------------------------------------------- 1 | requirements: 2 | nodes: 1 3 | tasks_per_node: 8 4 | gpus_per_node: 8 5 | cpus_per_task: 8 6 | mem_gb: 256 7 | timeout_min: 4_320 # 3 days 8 | constraint: null 9 | max_num_timeout: 10 10 | -------------------------------------------------------------------------------- /.github/actions/setup/action.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | using: composite 3 | steps: 4 | - name: "Install UV" 5 | shell: bash 6 | run: | 7 | curl -LsSf https://astral.sh/uv/install.sh | sh 8 | - name: "Install libsndfile" 9 | shell: bash 10 | run: | 11 | sudo apt-get install libsndfile1 -------------------------------------------------------------------------------- /lcm/inference/lcm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | from lcm.inference.lcm.generator import LCMGenerator as LCMGenerator 7 | from lcm.inference.lcm.generator import LCMGeneratorOptions as LCMGeneratorOptions 8 | 9 | __all__ = ["LCMGenerator", "LCMGeneratorOptions"] 10 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Why ? 2 | 3 | Why do we need to implement this feature ? What is the use case ? 4 | 5 | ## How ? 6 | 7 | Document the technical decisions you made. 8 | If some parts are WIP, please explicit them here. 9 | 10 | ## Test plan 11 | 12 | How did you test your changes ? 13 | Include full command line to help other people reproduce if needed. 14 | -------------------------------------------------------------------------------- /lcm/nn/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | 7 | from lcm.nn.schedulers.ddim import ( 8 | DDIMScheduler, 9 | DDIMSchedulerConfig, 10 | DDIMSchedulerOutput, 11 | ) 12 | 13 | __all__ = [ 14 | "DDIMScheduler", 15 | "DDIMSchedulerConfig", 16 | "DDIMSchedulerOutput", 17 | ] 18 | -------------------------------------------------------------------------------- /lcm/nn/denoisers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | 7 | from lcm.nn.denoisers.factory import ( 8 | DenoiserConfig, 9 | LCMDenoiser, 10 | LCMDenoiserTransformerFactory, 11 | ) 12 | 13 | __all__ = [ 14 | "DenoiserConfig", 15 | "LCMDenoiser", 16 | "LCMDenoiserTransformerFactory", 17 | ] 18 | -------------------------------------------------------------------------------- /lcm/models/abstract_lcm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | from lcm.models.abstract_lcm.builder import ( 7 | AbstractLCModel, 8 | AbstractLCModelBuilder, 9 | AbstractLCModelConfig, 10 | ) 11 | 12 | __all__ = [ 13 | "AbstractLCModel", 14 | "AbstractLCModelBuilder", 15 | "AbstractLCModelConfig", 16 | ] 17 | -------------------------------------------------------------------------------- /examples/evaluation/instruction.yaml: -------------------------------------------------------------------------------- 1 | source_prefix_text: | 2 | [INST] You are reading a chapter of the book. 3 | Please provide the summary of the chapter in maximum 1500 words. 4 | Adhere to the main plot and keep the characters if found in the chapter. 5 | The chapter is written between the tags [CHAPTER] and [/CHAPTER]. 6 | [CHAPTER] 7 | 8 | source_suffix_text: | 9 | [/CHAPTER]. 10 | Write at least 1000 words. 
11 | [/INST] 12 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/uv-pre-commit 3 | rev: 0.5.7 4 | hooks: 5 | - id: uv-lock 6 | - repo: https://github.com/astral-sh/ruff-pre-commit 7 | rev: v0.8.2 8 | hooks: 9 | # Lint 10 | - id: ruff 11 | args: [ --fix ] 12 | # sort imports 13 | - id: ruff 14 | args: ["check", "--select", "I", "--fix"] 15 | # format 16 | - id: ruff-format -------------------------------------------------------------------------------- /lcm/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | # We import all the model types in order to populate the model type registry 7 | from lcm.models.base_lcm.loader import BASE_LCM_MODEL_TYPE 8 | from lcm.models.two_tower_diffusion_lcm.loader import ( 9 | TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE, 10 | ) 11 | 12 | __all__ = [ 13 | "BASE_LCM_MODEL_TYPE", 14 | "TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE", 15 | ] 16 | -------------------------------------------------------------------------------- /recipes/train/defaults.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - launcher: submitit 3 | - requirements@trainer 4 | - _self_ 5 | dry_run: false 6 | trainer: 7 | # please specify a trainer conf, the easiest is to have this in another config, see train/__main__.py for more details. 8 | output_dir: ??? 9 | launcher: 10 | # set the executor logs and config dumps to be under the trainer outputs directory 11 | log_folder: ${trainer.output_dir}/executor_logs 12 | config_dump_dir: ${trainer.output_dir}/config_logs 13 | -------------------------------------------------------------------------------- /lcm/models/base_lcm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | # Register architectures 7 | import lcm.models.base_lcm.archs # noqa 8 | from lcm.models.base_lcm.builder import ( 9 | BaseLCModel, 10 | BaseLCModelBuilder, 11 | BaseLCModelConfig, 12 | create_base_lcm_model, 13 | ) 14 | 15 | __all__ = [ 16 | "BaseLCModel", 17 | "BaseLCModelBuilder", 18 | "BaseLCModelConfig", 19 | "create_base_lcm_model", 20 | ] 21 | -------------------------------------------------------------------------------- /lcm/inference/two_tower_diffusion_lcm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | from lcm.inference.two_tower_diffusion_lcm.generator import ( 7 | DiffusionLCMGeneratorOptions as DiffusionLCMGeneratorOptions, 8 | ) 9 | from lcm.inference.two_tower_diffusion_lcm.generator import ( 10 | TwoTowerDiffusionLCMGenerator as TwoTowerDiffusionLCMGenerator, 11 | ) 12 | 13 | __all__ = [ 14 | "TwoTowerDiffusionLCMGenerator", 15 | "DiffusionLCMGeneratorOptions", 16 | ] 17 | -------------------------------------------------------------------------------- /lcm/datacards/datacards.yaml: -------------------------------------------------------------------------------- 1 | # FIXME 2 | name: "pretraining_data" 3 | parquet_path: 4 | s3: "wiki_data" 5 | source_column: "text_sentences_sonar_emb" 6 | source_text_column: "text_sentences" 7 | # partition columns: 8 | # "split" (train, validation) 9 | --- 10 | # FIXME 11 | name: "finetuning_data" 12 | parquet_path: 13 | s3: "cosmopedia_sample" 14 | source_column: prompt_sentences_sonar_emb 15 | source_text_column: prompt_sentences 16 | target_column: text_sentences_sonar_emb 17 | target_text_column: text_sentences 18 | # partition columns: 19 | # "split" (train, validation) 20 | -------------------------------------------------------------------------------- /lcm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | """ 7 | LCM: Modular and Extensible Reasoning in an Embedding Space 8 | Code base for training different LCM models. 9 | """ 10 | 11 | from fairseq2 import setup_extensions 12 | from fairseq2.assets import default_asset_store 13 | 14 | __version__ = "0.1.0.dev0" 15 | 16 | 17 | def setup_fairseq2() -> None: 18 | default_asset_store.add_package_metadata_provider("lcm.cards") 19 | 20 | 21 | # This call activates setup_fairseq2 and potentially other extensions, 22 | setup_extensions() 23 | -------------------------------------------------------------------------------- /lcm/models/sonar_normalizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | # Register architectures 7 | import lcm.models.sonar_normalizer.archs # noqa 8 | from lcm.models.sonar_normalizer.builder import ( 9 | SonarNormalizer, 10 | SonarNormalizerConfig, 11 | create_sonar_normalizer, 12 | ) 13 | from lcm.models.sonar_normalizer.loader import load_sonar_normalizer_model 14 | 15 | __all__ = [ 16 | "SonarNormalizer", 17 | "SonarNormalizerConfig", 18 | "create_sonar_normalizer", 19 | "load_sonar_normalizer_model", 20 | ] 21 | -------------------------------------------------------------------------------- /lcm/nn/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | from lcm.nn.transformer.attention import ( 7 | QKNormMultiheadAttention, 8 | ) 9 | from lcm.nn.transformer.decoder import ( 10 | LCMStandardTransformerDecoderLayer, 11 | LCMTransformerDecoder, 12 | ) 13 | from lcm.nn.transformer.factory import ( 14 | TransformerConfig, 15 | TransformerFactory, 16 | ) 17 | 18 | __all__ = [ 19 | "QKNormMultiheadAttention", 20 | "LCMStandardTransformerDecoderLayer", 21 | "LCMTransformerDecoder", 22 | "TransformerConfig", 23 | "TransformerFactory", 24 | ] 25 | -------------------------------------------------------------------------------- /tests/units/evaluation/test_predictors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | import pytest 7 | 8 | from lcm.evaluation.predictors import build_predictor 9 | from lcm.evaluation.predictors.huggingface import ( 10 | HuggingfacePredictor, 11 | HuggingfacePredictorConfig, 12 | ) 13 | 14 | HF_PREDICTORS = [ 15 | HuggingfacePredictorConfig( 16 | model_name="google/pegasus-x-base", 17 | model_class="PegasusXForConditionalGeneration", 18 | tokenizer_name="google/pegasus-x-large", 19 | tokenizer_class="PreTrainedTokenizerFast", 20 | ) 21 | ] 22 | 23 | 24 | @pytest.mark.parametrize("config", HF_PREDICTORS) 25 | def test_hf_predictor(config): 26 | predictor = build_predictor(config) 27 | assert isinstance(predictor, HuggingfacePredictor) 28 | -------------------------------------------------------------------------------- /lcm/models/sonar_normalizer/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | 7 | from fairseq2.models.config_loader import StandardModelConfigLoader 8 | from fairseq2.models.loader import StandardModelLoader, load_model 9 | 10 | from lcm.models.sonar_normalizer.builder import ( 11 | SonarNormalizerConfig, 12 | create_sonar_normalizer, 13 | sonar_normalizer_archs, 14 | ) 15 | 16 | load_sonar_normalizer_config = StandardModelConfigLoader( 17 | family="sonar_normalizer", 18 | config_kls=SonarNormalizerConfig, 19 | arch_configs=sonar_normalizer_archs, 20 | ) 21 | 22 | load_sonar_normalizer_model = StandardModelLoader( 23 | config_loader=load_sonar_normalizer_config, 24 | factory=create_sonar_normalizer, 25 | restrict_checkpoints=False, 26 | ) 27 | 28 | load_model.register("sonar_normalizer", load_sonar_normalizer_model) 29 | -------------------------------------------------------------------------------- /tests/units/evaluation/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | # flake8: noqa 6 | 7 | from pathlib import Path 8 | from typing import Any, Dict, List 9 | 10 | import pytest 11 | 12 | from lcm.evaluation.utils.common import write_to_jsonl 13 | 14 | 15 | def get_jsonl() -> List[Dict[str, Any]]: 16 | return [ 17 | { 18 | "input_text": f"This is a long sentence that sometimes have extraordinally complex words that is unimaginatively ambiguous such as Sesquipedalianism. 
We want to test this with a shorter sentence that follows text {i}", 19 | "target_text": f"random target {i}", 20 | } 21 | for i in range(100) 22 | ] 23 | 24 | 25 | @pytest.fixture 26 | def simple_json_dataset(tmp_path: Path): 27 | file_path = tmp_path / "dataset.jsonl" 28 | write_to_jsonl(get_jsonl(), str(file_path)) 29 | yield file_path 30 | -------------------------------------------------------------------------------- /tests/units/evaluation/test_round_trip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | 7 | import pytest 8 | from fairseq2.gang import FakeGang 9 | 10 | from lcm.datasets.configs import SonarEncoderConfig 11 | from lcm.evaluation.metrics.round_trip import round_trip_l2_distance, text_encoder 12 | 13 | 14 | def test_round_trip_l2(monkeypatch): 15 | monkeypatch.setattr("lcm.evaluation.utils.sonar.get_gang", lambda: FakeGang()) 16 | 17 | x = ["The quick brown fox jumps over the lazy dog."] * 2 18 | y = [ 19 | "A quick brown fox jumps over a lazy dog.", 20 | "The fast brown fox jumps over a sleeping dog.", 21 | ] 22 | 23 | encoder_config = SonarEncoderConfig() 24 | encoder = text_encoder(encoder_config) 25 | encoded_y = encoder.predict(y, "eng_Latn", batch_size=2) 26 | 27 | scores = round_trip_l2_distance(x, encoded_y, encoder_config=encoder_config) 28 | 29 | assert pytest.approx(scores, abs=0.02) == [0.05, 0.11] 30 | -------------------------------------------------------------------------------- /recipes/common/evals.yaml: -------------------------------------------------------------------------------- 1 | # @package evaluator 2 | 3 | requirements: 4 | nodes: 1 5 | tasks_per_node: 1 6 | gpus_per_node: 1 7 | cpus_per_task: 8 8 | mem_gb: 256 9 | timeout_min: 120 # 2 hours 10 | constraint: null 11 | max_num_timeout: 10 12 | 13 | _builder_: lcm.evaluation.arun.build_async_task 14 | _runner_: lcm.evaluation.arun.schedule_task 15 | 16 | run_config: 17 | dump_dir: null 18 | seed: 42 19 | confidence_level: 0.95 20 | show_progress: false 21 | log_raw_results: false 22 | temperature: 0 23 | top_k: 0 24 | top_p: 1.0 25 | 26 | # Number of input shards to evaluate in parallel 27 | nshards: 1 28 | 29 | # Run the evaluation for every n saved and consolidated checkpoints. The number 30 | # of evaluation is calculated by: 31 | # ``TrainingConfig.save_model_every_n_steps * evaluate_every_n_saved_models 32 | # We set this attribute to be mandatory to enforce 33 | evaluate_every_save_steps: ??? 34 | 35 | predictor: ??? 36 | 37 | # Calculated automatically 38 | evaluate_every_n_steps: null 39 | max_evaluation_steps: null 40 | 41 | tasks: [] 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /tests/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | from typing import Any, List, Union 7 | 8 | import torch 9 | from fairseq2.typing import Device 10 | from torch import Tensor 11 | 12 | # The default device that tests should use. Note that pytest can change it based 13 | # on the provided command line arguments. 14 | device = Device("cpu") 15 | 16 | 17 | # The default debugging flag is False; pytest can turn debugging on 18 | # via command line arguments `pytest --debug-training` 19 | DEBUG = False 20 | 21 | 22 | def assert_close(a: Tensor, b: Union[Tensor, List[Any]]) -> None: 23 |     """Assert that ``a`` and ``b`` are element-wise equal within a tolerance.""" 24 |     if not isinstance(b, Tensor): 25 |         b = torch.tensor(b, device=device, dtype=a.dtype) 26 | 27 |     torch.testing.assert_close(a, b)  # type: ignore[attr-defined] 28 | 29 | 30 | def assert_equal(a: Tensor, b: Union[Tensor, List[Any]]) -> None: 31 |     """Assert that ``a`` and ``b`` are element-wise equal.""" 32 |     if not isinstance(b, Tensor): 33 |         b = torch.tensor(b, device=device, dtype=a.dtype) 34 | 35 |     torch.testing.assert_close(a, b, rtol=0, atol=0)  # type: ignore[attr-defined] 36 | -------------------------------------------------------------------------------- /lcm/models/sonar_normalizer/archs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | from lcm.models.sonar_normalizer.builder import ( 7 | SonarNormalizerConfig, 8 | sonar_normalizer_arch, 9 | ) 10 | 11 | 12 | @sonar_normalizer_arch("base") 13 | def _base_sonar_normalizer() -> SonarNormalizerConfig: 14 | """The base architecture for all center-and-scale normalizers 15 | regardless of how the center/scale are estimated""" 16 | return SonarNormalizerConfig( 17 | dim=1024, 18 | ) 19 | 20 | 21 | @sonar_normalizer_arch("base_page4k") 22 | def _base_page_normalizer() -> SonarNormalizerConfig: 23 | return SonarNormalizerConfig( 24 | dim=4 * 1024, 25 | ) 26 | 27 | 28 | @sonar_normalizer_arch("base_fft") 29 | def _base_fft_sonar_normalizer() -> SonarNormalizerConfig: 30 | return SonarNormalizerConfig(dim=1024, with_fft=True) 31 | 32 | 33 | @sonar_normalizer_arch("clipping") 34 | def _clipping_sonar_normalizer() -> SonarNormalizerConfig: 35 | return SonarNormalizerConfig(dim=1024, clip_proba=1e-4) 36 | 37 | 38 | @sonar_normalizer_arch("clipping_fft") 39 | def _clipping_fft_sonar_normalizer() -> SonarNormalizerConfig: 40 | return SonarNormalizerConfig(dim=1024, clip_proba=1e-4, with_fft=True) 41 | -------------------------------------------------------------------------------- /lcm/evaluation/metrics/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | from itertools import islice 7 | from typing import Iterator, Sequence, Tuple 8 | 9 | 10 | def create_context_prediction_pairs( 11 | gt_docs: Sequence[Sequence[str]], 12 | pred_docs: Sequence[Sequence[str]], 13 | max_ctx_len: int, 14 | ) -> Iterator[Tuple[Sequence[str], str]]: 15 | """ 16 | For each pair of documents (where a predicted document has 1:1 alignment with a suffix of a gt document), 17 | generate all pairs of a predicted sentence and up to max_ctx_len gt sentences immediately preceding it. 18 | This could be used to prepare data for evaluating the quality of teacher forced generation. 19 | """ 20 | for gt_doc, pred_doc in zip(gt_docs, pred_docs): 21 | prefix_size = len(gt_doc) - len(pred_doc) 22 | for pred_idx, pred_sent in enumerate(pred_doc): 23 | full_ctx_len = pred_idx + prefix_size 24 | context_sents = gt_doc[max(0, full_ctx_len - max_ctx_len) : full_ctx_len] 25 | yield context_sents, pred_sent 26 | 27 | 28 | def divide_chunks_as(iterable, reference_sequences): 29 | it = iter(iterable) 30 | chunks = [len(seq) for seq in reference_sequences] 31 | return [list(islice(it, c)) for c in chunks] 32 | -------------------------------------------------------------------------------- /tests/units/training/test_batch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | import torch 7 | from fairseq2.nn.padding import pad_seqs 8 | 9 | from lcm.datasets.batch import EmbeddingsBatch 10 | 11 | test_cases = [(0, 2), (2, 6), (0, 15), (3, 4), (1, 9)] 12 | 13 | 14 | def test_unbatchon_embedding_batch(): 15 | ragged_seq = [torch.randn((i**2 % 11, 3)) for i in range(100)] 16 | full_eb = EmbeddingsBatch(*pad_seqs(ragged_seq)) 17 | full_eb_bis = EmbeddingsBatch(*pad_seqs(full_eb.unbatch())) 18 | 19 | assert (full_eb_bis.seqs == full_eb.seqs).all().item() 20 | assert ( 21 | (full_eb_bis.padding_mask.seq_lens == full_eb.padding_mask.seq_lens) 22 | # type: ignore 23 | .all() 24 | .item() 25 | ) 26 | assert ( 27 | full_eb_bis.padding_mask._batch_seq_len == full_eb.padding_mask._batch_seq_len # type: ignore 28 | ) 29 | 30 | 31 | def test_last_element_embedding_batch(): 32 | ragged_seq = [torch.randn((i**2 % 11 + 1, 3)) for i in range(100)] 33 | full_eb = EmbeddingsBatch(*pad_seqs(ragged_seq)) 34 | 35 | expected_ans = torch.stack([tt[-1] for tt in ragged_seq], dim=0) 36 | print(expected_ans.shape) 37 | found_ans = full_eb.get_last_element() 38 | print(found_ans.shape) 39 | 40 | assert (expected_ans == found_ans).all().item() 41 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to large_concept_model 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to large_concept_model, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /recipes/train/finetune/mse.yaml: -------------------------------------------------------------------------------- 1 | # @package trainer 2 | 3 | _trainer_: lcm.train.lcm.trainer.prepare_lcm_trainer 4 | 5 | requirements: 6 | nodes: 4 7 | tasks_per_node: 8 8 | gpus_per_node: 8 9 | cpus_per_task: 32 10 | mem_gb: 0 11 | 12 | # Overwrite to start from a different pre-trained model 13 | model_config_or_name: dummy_mse_pretrained_model 14 | 15 | criterion: 16 | name: target_mse 17 | reduction: sum 18 | compute_rmse: False 19 | 20 | output_dir: ?? 
21 | dtype: "torch.float16" 22 | use_optimizer_in_fp32: true 23 | use_fsdp: true 24 | fsdp_fp32_reduce: true 25 | 26 | lr_schedule: cosine 27 | start_lr: 1e-6 28 | final_lr: 1e-6 29 | lr: 0.00003 30 | num_lr_warmup_steps: 3_000 31 | max_grad_norm: 25 32 | weight_decay: 0.01 33 | 34 | max_steps: 20_000 35 | gradient_accumulation: 1 36 | validate_every_n_steps: 1_000 37 | checkpoint_every_n_steps: 1_000 38 | save_model_every_n_steps: 1_000 39 | keep_last_n_checkpoints: 2 40 | publish_metrics_every_n_steps: 100 41 | preserve_consolidated_models: true 42 | 43 | seed: 1 44 | profile: false 45 | 46 | data_loading_config: 47 |   max_tokens: 7168 48 |   nb_epochs: 5 49 | 50 | training_data: 51 |   - name: "finetuning_data=train" 52 |     source_suffix_text: "[MODEL]:" 53 |     target_suffix_text: "End of text." 54 | 55 | validation_data: 56 |   - name: "finetuning_data=validation" 57 |     source_suffix_text: "[MODEL]:" 58 |     target_suffix_text: "End of text." 59 | -------------------------------------------------------------------------------- /lcm/evaluation/utils/sonar.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | 7 | from functools import lru_cache 8 | from typing import Optional 9 | 10 | from fairseq2.typing import Device 11 | from sonar.inference_pipelines.text import ( 12 |     TextToEmbeddingModelPipeline, 13 |     TextToTextModelPipeline, 14 | ) 15 | 16 | from lcm.datasets.configs import SonarEncoderConfig 17 | 18 | from .distributed import get_gang 19 | 20 | # We fix the sonar encoder for the LCM prompt 21 | eng_config = SonarEncoderConfig( 22 |     tokenizer="text_sonar_basic_encoder", 23 |     encoder="text_sonar_basic_encoder", 24 |     lang="eng_Latn", 25 | ) 26 | 27 | 28 | @lru_cache(maxsize=2) 29 | def text_encoder( 30 |     config: SonarEncoderConfig = eng_config, device: Optional[Device] = None 31 | ): 32 |     """Load a text embedding pipeline with a sonar encoder""" 33 |     if device is None: 34 |         gang = get_gang() 35 |         device = gang.device 36 | 37 |     return TextToEmbeddingModelPipeline( 38 |         encoder=config.encoder, 39 |         tokenizer=config.tokenizer, 40 |         device=device, 41 |     ) 42 | 43 | 44 | @lru_cache 45 | def text_translator(): 46 |     t2t_model = TextToTextModelPipeline( 47 |         encoder="text_sonar_basic_encoder", 48 |         decoder="text_sonar_basic_decoder", 49 |         tokenizer="text_sonar_basic_encoder", 50 |     ) 51 |     return t2t_model 52 | -------------------------------------------------------------------------------- /lcm/models/base_lcm/archs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | from lcm.models.base_lcm.builder import ( 7 | BaseLCModelConfig, 8 | LCMFrontendConfig, 9 | ProjectionConfig, 10 | TransformerConfig, 11 | lcm_arch, 12 | ) 13 | 14 | 15 | # Every model must register a toy_{model_family} 16 | @lcm_arch("toy_base_lcm") 17 | def toy_base_lcm() -> BaseLCModelConfig: 18 | return BaseLCModelConfig( 19 | lcm=TransformerConfig(num_layers=2), 20 | ) 21 | 22 | 23 | @lcm_arch("base_lcm_1_6B") 24 | def base_lcm_1_6B() -> BaseLCModelConfig: 25 | """Base 1.6B model 26 | Parameter Size: 1,647,635,456 27 | """ 28 | model_dim: int = 2048 29 | num_attn_heads: int = 16 30 | return BaseLCModelConfig( 31 | max_seq_len=4096, 32 | model_dim=model_dim, 33 | sonar_embed_dim=1024, 34 | sonar_normalizer_name="dummy_sonar_normalizer", 35 | frontend=LCMFrontendConfig(), 36 | lcm=TransformerConfig( 37 | final_dropout_p=0.0, 38 | attention_dropout_p=0.0, 39 | dropout_p=0.1, 40 | mha_output_proj_bias=True, 41 | ffn_inner_dim=model_dim * 4, 42 | num_attn_heads=num_attn_heads, 43 | num_layers=32, 44 | pos_embedding_style="rope", 45 | use_swiglu=True, 46 | layer_normalization_style="rms", 47 | ), 48 | postnet=ProjectionConfig(), 49 | ) 50 | -------------------------------------------------------------------------------- /lcm/datasets/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | 7 | import torch 8 | from fairseq2.models.sequence import SequenceBatch 9 | 10 | 11 | def move_eos_to_the_end( 12 | batch: SequenceBatch, pad_token_id: int = 0, eos_token_id: int = 3 13 | ) -> SequenceBatch: 14 | """ 15 | Convert a decoder-input batch (with the eos token in the beginning) to a decoder-output batch 16 | (with eos in the end) of the same shape. 17 | Note that this processing is missing two potentially critical issues: 18 | 1) If the sequence end has been truncated away, EOS token will be appended erroneously. 19 | 2) The language code token is still included in the loss computation (we may want to avoid it). 20 | """ 21 | # strip the EOS token prepended to the input and add an empty token in the end 22 | seqs = torch.cat( 23 | [ 24 | batch.seqs[:, 1:], 25 | torch.zeros_like(batch.seqs[:, :1]) + pad_token_id, 26 | ], 27 | dim=-1, 28 | ) 29 | # fill the last real token in the batch with the eos value 30 | if batch.padding_mask: 31 | seqs[ 32 | torch.arange(seqs.shape[0], dtype=torch.int32), 33 | batch.padding_mask.seq_lens - 1, 34 | ] = eos_token_id 35 | else: 36 | seqs[:, -1] = eos_token_id 37 | 38 | result = SequenceBatch( 39 | seqs=seqs, 40 | padding_mask=batch.padding_mask, 41 | ) 42 | return result 43 | -------------------------------------------------------------------------------- /lcm/evaluation/predictors/two_tower_diffusion_lcm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | from dataclasses import dataclass 7 | from typing import Any 8 | 9 | from fairseq2.typing import override 10 | 11 | from lcm.evaluation.predictors.lcm import LCMConfig, LCMPredictor 12 | from lcm.inference.two_tower_diffusion_lcm import ( 13 |     DiffusionLCMGeneratorOptions, 14 |     TwoTowerDiffusionLCMGenerator, 15 | ) 16 | 17 | 18 | @dataclass(unsafe_hash=True) 19 | class TwoTowerDiffusionLCMConfig(DiffusionLCMGeneratorOptions, LCMConfig): 20 |     @classmethod 21 |     def predictor_class(cls): 22 |         return TwoTowerDiffusionLCMPredictor 23 | 24 | 25 | class TwoTowerDiffusionLCMPredictor(LCMPredictor): 26 |     """ 27 |     A predictor that wraps LCMGenerator and formats the output for evaluation 28 |     """ 29 | 30 |     config: TwoTowerDiffusionLCMConfig 31 | 32 |     def __init__( 33 |         self, 34 |         config: TwoTowerDiffusionLCMConfig, 35 |         **kwargs: Any, 36 |     ): 37 |         super().__init__(config, **kwargs) 38 | 39 |     def build_generator(self, model): 40 |         self.generator = TwoTowerDiffusionLCMGenerator(  # type: ignore 41 |             model=model, options=self.config, eos_vec=self.eos_vec 42 |         ) 43 | 44 |     @override 45 |     @staticmethod 46 |     def from_config(config: LCMConfig, **kwargs) -> "TwoTowerDiffusionLCMPredictor": 47 |         return TwoTowerDiffusionLCMPredictor(config=config, **kwargs)  # type: ignore 48 | -------------------------------------------------------------------------------- /lcm/models/base_lcm/normalization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | from typing import Optional, final 7 | 8 | import torch 9 | from fairseq2.nn import LayerNorm, RMSNorm 10 | from fairseq2.typing import DataType, Device, override 11 | 12 | 13 | @final 14 | class FP32LayerNorm(LayerNorm): 15 |     """Applies Layer Normalization in single-precision.""" 16 | 17 |     @override 18 |     def forward(self, x: torch.Tensor) -> torch.Tensor: 19 |         w, b = self.weight, self.bias 20 | 21 |         # cast input and params to float32 22 |         fp32_x = x.float() 23 |         fp32_w = w.float() if w is not None else None 24 |         fp32_b = b.float() if b is not None else None 25 | 26 |         y = torch.nn.functional.layer_norm( 27 |             fp32_x, self.normalized_shape, fp32_w, fp32_b, self.eps 28 |         ) 29 | 30 |         return y.type_as(x) 31 | 32 | 33 | def build_rms_layer_norm( 34 |     model_dim: int, 35 |     *, 36 |     device: Optional[Device] = None, 37 |     dtype: Optional[DataType] = None, 38 | ) -> LayerNorm: 39 |     """Build an RMS Layer Normalization module.""" 40 |     return RMSNorm(model_dim, bias=False, device=device, dtype=dtype) 41 | 42 | 43 | def build_fp32_layer_norm( 44 |     model_dim: int, 45 |     *, 46 |     device: Optional[Device] = None, 47 |     dtype: Optional[DataType] = None, 48 | ) -> LayerNorm: 49 |     """Build a single-precision Layer Normalization module.""" 50 |     return FP32LayerNorm(model_dim, bias=False, device=device, dtype=dtype) 51 | -------------------------------------------------------------------------------- /tests/units/evaluation/test_task_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # type: ignore 4 | 5 | from dataclasses import dataclass 6 | 7 | import pytest 8 | 9 | from lcm.evaluation.tasks import TaskConfig, TaskRegistry 10 | from lcm.evaluation.utils.data_utils import EvalDataLoader 11 | 12 | 13 | @pytest.fixture() 14 | def reset_task_registry(): 15 | registry = TaskRegistry._REGISTRY 16 | TaskRegistry.reset() 17 | yield 18 | TaskRegistry._REGISTRY = registry 19 | 20 | 21 | @dataclass 22 | class TestTaskConfig(TaskConfig): 23 | foo: int = 0 24 | bar: str = "qux" 25 | 26 | 27 | @pytest.mark.usefixtures("reset_task_registry") 28 | def test_new_tasks() -> None: 29 | TaskRegistry.register( 30 | "task0", lambda: TestTaskConfig(None), data_loader_type=EvalDataLoader 31 | ) 32 | TaskRegistry.register( 33 | "task1", 34 | lambda foo: TestTaskConfig(None, foo=foo), 35 | data_loader_type=EvalDataLoader, 36 | ) 37 | TaskRegistry.register( 38 | "task2", 39 | lambda bar: TestTaskConfig(None, bar=bar), 40 | data_loader_type=EvalDataLoader, 41 | ) 42 | 43 | assert TaskRegistry.names() == {"task0", "task1", "task2"}, TaskRegistry.names() 44 | assert TaskRegistry.get_config("task0") == TestTaskConfig(None, foo=0, bar="qux") 45 | assert TaskRegistry.get_config("task1", foo=4) == TestTaskConfig(None, foo=4) 46 | assert TaskRegistry.get_config("task2", bar="waldo") == TestTaskConfig( 47 | None, bar="waldo" 48 | ) 49 | -------------------------------------------------------------------------------- /.github/workflows/lint_and_test.yaml: -------------------------------------------------------------------------------- 1 | name: Lint Python Code 2 | 3 | on: 4 | # Trigger the workflow on push to master or any pull request 5 | push: 6 | pull_request: 7 | branches: 8 | - main 9 | 10 | jobs: 11 | lock_file: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - uses: ./.github/actions/setup 16 | - run: uv lock --locked 17 | linting: 18 | runs-on: ubuntu-latest 19 | needs: [lock_file] 20 | steps: 21 | - uses: actions/checkout@v4 22 | - uses: ./.github/actions/setup 23 | - run: uvx ruff check . 24 | formatting: 25 | runs-on: ubuntu-latest 26 | needs: [lock_file] 27 | steps: 28 | - uses: actions/checkout@v4 29 | - uses: ./.github/actions/setup 30 | - run: uvx ruff check --select I . 31 | - run: uvx ruff format --check . 32 | type_consistency: 33 | runs-on: ubuntu-latest 34 | needs: [lock_file] 35 | steps: 36 | - uses: actions/checkout@v4 37 | - uses: ./.github/actions/setup 38 | - run: uvx --with=types-PyYAML mypy 39 | tests: 40 | runs-on: ubuntu-latest 41 | needs: [lock_file] 42 | steps: 43 | - uses: actions/checkout@v4 44 | - uses: ./.github/actions/setup 45 | - run: uv run --extra cpu --extra eval --extra data pytest -v --full-trace 46 | build: 47 | runs-on: ubuntu-latest 48 | needs: [lock_file, linting, formatting, type_consistency, tests] 49 | steps: 50 | - uses: actions/checkout@v4 51 | - uses: ./.github/actions/setup 52 | - run: uv build -------------------------------------------------------------------------------- /lcm/nn/incremental_state.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | from typing import Dict, Optional, final 7 | 8 | from fairseq2.nn.incremental_state import IncrementalState, IncrementalStateBag 9 | from fairseq2.nn.transformer import FullAttentionState 10 | from torch import Tensor 11 | from torch.nn import Module 12 | 13 | 14 | @final 15 | class LCMIncrementalStateBag(IncrementalStateBag): # type: ignore 16 | """Holds the module states during incremental decoding.""" 17 | 18 | _module_states: Dict[Module, FullAttentionState] # type: ignore 19 | 20 | def __init__( 21 | self, max_num_steps: int, *, capacity_increment: Optional[int] = 16 22 | ) -> None: 23 | super().__init__( 24 | max_num_steps=max_num_steps, capacity_increment=capacity_increment 25 | ) 26 | 27 | def reorder(self, new_order: Tensor) -> None: 28 | """Reorder the module states. 29 | 30 | See :meth:`IncrementalState.reorder` for more information. 31 | """ 32 | # FIXME Deal with reordering diffusion state bags here 33 | for state in self._module_states.values(): 34 | state.reorder(new_order) 35 | 36 | def set_state(self, m: Module, state: IncrementalState) -> None: 37 | """Set the state of ``m``. 38 | :param m: The module. 39 | :param state: The state to store. 40 | There is no current call to `set_state` when the bag 41 | is frozen, but it's implemented here for completeness 42 | """ 43 | super().set_state(m, state) 44 | -------------------------------------------------------------------------------- /lcm/models/two_tower_diffusion_lcm/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | 7 | from fairseq2.models.config_loader import StandardModelConfigLoader 8 | from fairseq2.models.loader import StandardModelLoader, load_model 9 | 10 | from lcm.models.base_lcm.loader import convert_lcm_checkpoint 11 | from lcm.models.two_tower_diffusion_lcm.builder import ( 12 | TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE, 13 | TwoTowerDiffusionLCModelConfig, 14 | create_two_tower_diffusion_lcm_model, 15 | lcm_archs, 16 | ) 17 | from lcm.utils.model_type_registry import ModelTypeConfig, lcm_model_type_registry 18 | 19 | load_two_tower_diffusion_lcm_config = StandardModelConfigLoader( 20 | family=TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE, 21 | config_kls=TwoTowerDiffusionLCModelConfig, 22 | arch_configs=lcm_archs, 23 | ) 24 | 25 | 26 | load_two_tower_diffusion_lcm_model = StandardModelLoader( # type: ignore # FIXME 27 | config_loader=load_two_tower_diffusion_lcm_config, 28 | factory=create_two_tower_diffusion_lcm_model, 29 | checkpoint_converter=convert_lcm_checkpoint, 30 | restrict_checkpoints=False, 31 | ) 32 | 33 | load_model.register( 34 | TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE, load_two_tower_diffusion_lcm_model 35 | ) 36 | 37 | lcm_model_type_registry.register( 38 | ModelTypeConfig( 39 | model_type=TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE, 40 | config_loader=load_two_tower_diffusion_lcm_config, 41 | model_factory=create_two_tower_diffusion_lcm_model, 42 | model_loader=load_two_tower_diffusion_lcm_model, 43 | ) 44 | ) 45 | -------------------------------------------------------------------------------- /recipes/train/pretrain/mse.yaml: -------------------------------------------------------------------------------- 1 | # @package trainer 2 | 3 | _trainer_: lcm.train.lcm.trainer.prepare_lcm_trainer 4 | 5 | output_dir: ?? 
6 | 7 | # Parameter Size: 1,647,635,456 8 | model_arch: base_lcm_1_6B 9 | 10 | criterion: 11 | name: next_sentence_mse 12 | reduction: sum 13 | compute_rmse: False 14 | 15 | dtype: "torch.float16" 16 | use_optimizer_in_fp32: true 17 | use_fsdp: true 18 | fsdp_fp32_reduce: true 19 | 20 | lr: 0.0004 21 | lr_schedule: cosine 22 | num_lr_warmup_steps: 10_000 23 | max_steps: 250_000 24 | gradient_accumulation: 1 25 | max_grad_norm: 25 26 | weight_decay: 0.1 27 | adam_betas: 28 | - 0.9 29 | - 0.95 30 | adam_eps: 1e-5 31 | 32 | validate_every_n_steps: 10_000 33 | save_model_every_n_steps: 2_000 34 | checkpoint_every_n_steps: 2_000 35 | keep_last_n_checkpoints: 2 36 | preserve_consolidated_models: True 37 | publish_metrics_every_n_steps: 100 38 | 39 | seed: 1 40 | profile: false 41 | 42 | data_loading_config: 43 | max_tokens: 7168 44 | min_batch_size: 1 45 | len_to_wrap_long_seq: 128 46 | packing: false 47 | min_length_of_sequences: 1 48 | min_length_after_batching: 2 49 | num_parallel_calls: 1 50 | nb_prefetch: 5 51 | nb_epochs: 1 52 | 53 | validation_data_loading_config: 54 | len_to_wrap_long_seq: 128 55 | 56 | training_data: 57 | - name: "pretraining_data=train" 58 | source_suffix_text: "End of text." 59 | 60 | validation_data: 61 | - name: "pretraining_data=validation" 62 | source_suffix_text: "End of text." 63 | 64 | requirements: 65 | nodes: 4 66 | tasks_per_node: 8 67 | gpus_per_node: 8 68 | cpus_per_task: 32 69 | mem_gb: 0 70 | timeout_min: 10000 71 | -------------------------------------------------------------------------------- /recipes/train/finetune/two_tower.yaml: -------------------------------------------------------------------------------- 1 | # @package trainer 2 | 3 | _trainer_: lcm.train.two_tower_diffusion_lcm.trainer.prepare_two_tower_diffusion_lcm_trainer 4 | 5 | requirements: 6 | nodes: 1 7 | tasks_per_node: 8 8 | gpus_per_node: 8 9 | cpus_per_task: 32 10 | mem_gb: 0 11 | 12 | # Overwrite to start from a different pre-trained model 13 | model_config_or_name: dummy_2t_pretrained_model 14 | 15 | criterion: 16 | name: two_tower_diffusion_next_sent_finetuning 17 | cf_guidance_probability: 0.15 18 | reduction: sum 19 | log_losses_per_timestep_bucket: False 20 | compute_rmse: False 21 | step_sampling: 22 | sampling: "uniform" 23 | weighting: "none" 24 | 25 | output_dir: ?? 26 | dtype: "torch.float16" 27 | use_optimizer_in_fp32: true 28 | use_fsdp: true 29 | fsdp_fp32_reduce: true 30 | 31 | lr_schedule: cosine 32 | start_lr: 1e-6 33 | final_lr: 1e-6 34 | lr: 0.0003 35 | num_lr_warmup_steps: 3_000 36 | max_grad_norm: 25 37 | weight_decay: 0.01 38 | 39 | max_steps: 20_000 40 | gradient_accumulation: 1 41 | validate_every_n_steps: 1_000 42 | checkpoint_every_n_steps: 1_000 43 | save_model_every_n_steps: 1_000 44 | keep_last_n_checkpoints: 2 45 | publish_metrics_every_n_steps: 100 46 | preserve_consolidated_models: true 47 | 48 | seed: 1 49 | profile: false 50 | 51 | data_loading_config: 52 | max_tokens: 7168 53 | nb_epochs: 5 54 | 55 | training_data: 56 | - name: "finetuning_data=train" 57 | source_suffix_text: "[MODEL]:" 58 | target_suffix_text: "End of text." 59 | 60 | validation_data: 61 | - name: "finetuning_data=validation" 62 | source_suffix_text: "[MODEL]:" 63 | target_suffix_text: "End of text." 64 | -------------------------------------------------------------------------------- /lcm/train/two_tower_diffusion_lcm/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | from dataclasses import dataclass, field 7 | from typing import Union 8 | 9 | from lcm.models.two_tower_diffusion_lcm.builder import TwoTowerDiffusionLCModelConfig 10 | from lcm.models.two_tower_diffusion_lcm.loader import ( 11 | load_two_tower_diffusion_lcm_model, 12 | ) 13 | from lcm.train.lcm.trainer import LCMTrainer, LCMTrainerBuilder, LCMTrainingConfig 14 | from lcm.train.two_tower_diffusion_lcm.criterion import ( 15 | TowerDiffusionLCMCriterionConfig, 16 | ) 17 | 18 | 19 | @dataclass 20 | class TwoTowerDiffusionLCMTrainingConfig(LCMTrainingConfig): 21 | model_config_or_name: Union[TwoTowerDiffusionLCModelConfig, str, None] = None 22 | """The model configuration or name to train.""" 23 | 24 | criterion: TowerDiffusionLCMCriterionConfig = field( # type: ignore 25 | default_factory=lambda: TowerDiffusionLCMCriterionConfig() 26 | ) 27 | 28 | 29 | class DiffusionLCMTrainerBuilder(LCMTrainerBuilder): 30 | config: TwoTowerDiffusionLCMTrainingConfig 31 | 32 | def __init__(self, config: TwoTowerDiffusionLCMTrainingConfig): 33 | super().__init__(config) 34 | 35 | @property 36 | def model_loader(self): 37 | """A fairseq2 ModelLoader""" 38 | return load_two_tower_diffusion_lcm_model 39 | 40 | 41 | def prepare_two_tower_diffusion_lcm_trainer( 42 | config: TwoTowerDiffusionLCMTrainingConfig, 43 | ) -> LCMTrainer: 44 | """Create an LCM Trainer. 45 | :param config: The training configuration. 46 | """ 47 | return DiffusionLCMTrainerBuilder(config).build_trainer() 48 | -------------------------------------------------------------------------------- /tests/units/training/test_get_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | from dataclasses import dataclass, field 7 | 8 | from omegaconf import DictConfig 9 | 10 | from lcm.train.common import get_trainer 11 | 12 | 13 | @dataclass 14 | class Foo: 15 | a: float = 0 16 | b: float = 0 17 | c: float = field(init=False) 18 | 19 | def __post_init__(self): 20 | self.c = self.a + self.b 21 | 22 | 23 | @dataclass 24 | class Config: 25 | foobar: str = "test" 26 | cfg: Foo = field(default_factory=lambda: Foo()) 27 | c: float = field(init=False) 28 | 29 | def __post_init__(self): 30 | self.c = 10.0 31 | 32 | 33 | class TrainerClass: 34 | def __init__(self, config: Config) -> None: 35 | self.config = config 36 | pass 37 | 38 | 39 | def trainer_builder(config: Config): 40 | return TrainerClass(config) 41 | 42 | 43 | def test_get_trainer_fn(): 44 | conf_dict = DictConfig( 45 | { 46 | "_trainer_": f"{trainer_builder.__module__}.trainer_builder", 47 | "foobar": "bar", 48 | "cfg": { 49 | "a": 1, 50 | "b": 3, 51 | }, 52 | }, 53 | ) 54 | tr = get_trainer(conf_dict) 55 | assert isinstance(tr, TrainerClass) 56 | assert tr.config.foobar == "bar" 57 | assert tr.config.cfg.c == 4.0 58 | 59 | 60 | def test_get_trainer_class(): 61 | conf_dict = DictConfig( 62 | { 63 | "_trainer_": f"{TrainerClass.__module__}.TrainerClass", 64 | "foobar": "bar", 65 | } 66 | ) 67 | tr = get_trainer(conf_dict) 68 | assert isinstance(tr, TrainerClass) 69 | assert tr.config.foobar == "bar" 70 | -------------------------------------------------------------------------------- /lcm/evaluation/metrics/round_trip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # 4 | # 5 | 6 | 7 | from typing import List, Union 8 | 9 | import torch 10 | 11 | from lcm.datasets.configs import SonarEncoderConfig 12 | from lcm.evaluation.metrics.similarity import cos_sim, l2_distance 13 | from lcm.evaluation.utils.sonar import text_encoder 14 | 15 | 16 | def round_trip_l2_distance( 17 | prediction_text: List[str], 18 | targets: torch.Tensor, 19 | encoder_config: SonarEncoderConfig, 20 | flatten: bool = False, 21 | ) -> Union[List[float], List[List[float]]]: 22 | """ 23 | Calculate the L2 distance of a text and a vector by putting the text 24 | into the sonar space embedding 25 | """ 26 | text2vec = text_encoder(encoder_config, device=targets.device) 27 | prediction_projected = text2vec.predict( 28 | prediction_text, 29 | source_lang=encoder_config.lang, 30 | batch_size=len(prediction_text), 31 | ).reshape(targets.shape) 32 | 33 | return l2_distance(prediction_projected, targets, flatten=flatten) 34 | 35 | 36 | def round_trip_cos( 37 | prediction_text: List[str], 38 | targets: torch.Tensor, 39 | encoder_config: SonarEncoderConfig, 40 | ) -> Union[List[float], List[List[float]]]: 41 | """ 42 | Calculate the cosine similarity between a text and the target embeddings 43 | (e.g. input embedding of the decoder that generates the text) 44 | """ 45 | text2vec = text_encoder(encoder_config) 46 | prediction_projected = text2vec.predict( 47 | prediction_text, 48 | source_lang=encoder_config.lang, 49 | batch_size=len(prediction_text), 50 | ).reshape(targets.shape) 51 | 52 | return cos_sim(prediction_projected.numpy(), targets.numpy()) 53 | -------------------------------------------------------------------------------- /lcm/models/base_lcm/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | import logging 7 | from typing import Any, Dict 8 | 9 | from fairseq2.models.config_loader import StandardModelConfigLoader 10 | from fairseq2.models.loader import StandardModelLoader, load_model 11 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 12 | 13 | from lcm.models.base_lcm.builder import ( 14 | BASE_LCM_MODEL_TYPE, 15 | BaseLCModelConfig, 16 | create_base_lcm_model, 17 | lcm_archs, 18 | ) 19 | from lcm.utils.model_type_registry import ModelTypeConfig, lcm_model_type_registry 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def convert_lcm_checkpoint( 25 | checkpoint: Dict[str, Any], config: BaseLCModelConfig 26 | ) -> Dict[str, Any]: 27 | # For DDP checkpoints 28 | # We need to first remove the prefix "module." from state dict keys. 
29 | consume_prefix_in_state_dict_if_present(checkpoint["model"], "module.") 30 | return checkpoint 31 | 32 | 33 | load_base_lcm_config = StandardModelConfigLoader( 34 | family=BASE_LCM_MODEL_TYPE, 35 | config_kls=BaseLCModelConfig, 36 | arch_configs=lcm_archs, 37 | ) 38 | 39 | load_base_lcm_model = StandardModelLoader( 40 | config_loader=load_base_lcm_config, 41 | factory=create_base_lcm_model, 42 | checkpoint_converter=convert_lcm_checkpoint, 43 | restrict_checkpoints=False, 44 | ) 45 | 46 | load_model.register(BASE_LCM_MODEL_TYPE, load_base_lcm_model) 47 | 48 | lcm_model_type_registry.register( 49 | ModelTypeConfig( 50 | model_type=BASE_LCM_MODEL_TYPE, 51 | config_loader=load_base_lcm_config, 52 | model_factory=create_base_lcm_model, 53 | model_loader=load_base_lcm_model, 54 | ) 55 | ) 56 | -------------------------------------------------------------------------------- /recipes/train/pretrain/two_tower.yaml: -------------------------------------------------------------------------------- 1 | # @package trainer 2 | 3 | _trainer_: lcm.train.two_tower_diffusion_lcm.trainer.prepare_two_tower_diffusion_lcm_trainer 4 | 5 | output_dir: ?? 6 | 7 | #Parameter Size: 1,635,101,696 8 | model_arch: two_tower_diffusion_lcm_1_6B 9 | 10 | criterion: 11 | name: two_tower_diffusion_next_sent 12 | cf_guidance_probability: 0.15 13 | reduction: sum 14 | log_losses_per_timestep_bucket: False 15 | compute_rmse: False 16 | step_sampling: 17 | sampling: "uniform" 18 | weighting: "none" 19 | 20 | dtype: "torch.float16" 21 | use_optimizer_in_fp32: true 22 | use_fsdp: true 23 | fsdp_fp32_reduce: true 24 | 25 | lr: 0.0004 26 | lr_schedule: cosine 27 | num_lr_warmup_steps: 10_000 28 | max_steps: 250_000 29 | gradient_accumulation: 1 30 | max_grad_norm: 25 31 | weight_decay: 0.1 32 | adam_betas: 33 | - 0.9 34 | - 0.95 35 | adam_eps: 1e-5 36 | 37 | validate_every_n_steps: 10_000 38 | save_model_every_n_steps: 2_000 39 | checkpoint_every_n_steps: 2_000 40 | keep_last_n_checkpoints: 2 41 | preserve_consolidated_models: True 42 | publish_metrics_every_n_steps: 100 43 | 44 | seed: 1 45 | profile: false 46 | 47 | data_loading_config: 48 | max_tokens: 7168 49 | min_batch_size: 1 50 | len_to_wrap_long_seq: 128 51 | packing: false 52 | min_length_of_sequences: 1 53 | min_length_after_batching: 2 54 | num_parallel_calls: 1 55 | nb_prefetch: 5 56 | nb_epochs: 1 57 | 58 | validation_data_loading_config: 59 | len_to_wrap_long_seq: 128 60 | 61 | training_data: 62 | - name: "pretraining_data=train" 63 | source_suffix_text: "End of text." 64 | 65 | validation_data: 66 | - name: "pretraining_data=validation" 67 | source_suffix_text: "End of text." 
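# Illustrative only (not part of the original recipe): a training_data entry can also
# carry the per-dataset options documented in recipes/train/README.md, such as prefix/suffix
# concepts and a length filter on the document's `text_sentences` column. The dataset name
# and weight below are hypothetical.
#
# training_data:
#   - name: "my_pretraining_data=train:1"
#     source_prefix_text: "Beginning of text."
#     source_suffix_text: "End of text."
#     filters: 'pa.compute.less(pa.compute.list_value_length(pa.dataset.field("text_sentences")), 128)'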
68 | 69 | requirements: 70 | nodes: 4 71 | tasks_per_node: 8 72 | gpus_per_node: 8 73 | cpus_per_task: 32 74 | mem_gb: 0 75 | timeout_min: 10000 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # JetBrains PyCharm IDE 3 | .idea/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | **/*/__pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # macOS dir files 14 | .DS_Store 15 | 16 | # Distribution / packaging 17 | .Python 18 | env/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # mkdocs documentation 91 | /site 92 | 93 | # mypy 94 | .mypy_cache/ 95 | 96 | .pytest_cache 97 | .ruff_cache 98 | 99 | # VSCODE 100 | .vscode/ftp-sync.json 101 | .vscode/settings.json 102 | .vscode/launch.json 103 | 104 | # stopes logs 105 | executor_logs/ 106 | config_logs/ 107 | outputs/ 108 | 109 | logs/ 110 | **/dask_jobqueue_logs 111 | core.* 112 | mortimer_env.txt 113 | -------------------------------------------------------------------------------- /tests/units/datapipeline/test_sentence_splitter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | 7 | from lcm.datasets.sentence_splitting import ( 8 | ResplitSentenceSplitter, 9 | deescape_special_chars, 10 | filter_empty_string, 11 | remove_emojis, 12 | remove_non_printable_chars, 13 | resplit, 14 | ) 15 | 16 | 17 | def test_remove_emojis(): 18 | assert remove_emojis("Hello 😊, 🤣 how are you? 🤔") == "Hello , how are you? " 19 | 20 | 21 | def test_filter_empty_string(): 22 | assert not filter_empty_string("This is a long sentence with multiple words.") 23 | assert filter_empty_string(" ") 24 | 25 | 26 | def test_remove_non_printable_chars(): 27 | assert ( 28 | remove_non_printable_chars("Hello\nWorld. This is a test sentence.") 29 | == "HelloWorld. This is a test sentence." 30 | ) 31 | 32 | 33 | def test_deescape_special_chars(): 34 | assert ( 35 | deescape_special_chars("Hello\\nWorld. This is a test\\nsentence.") 36 | == "Hello\nWorld. This is a test\nsentence." 
37 | ) 38 | 39 | 40 | def test_resplit(): 41 | assert resplit( 42 | "This is a long sentence that should be split into multiple parts.", 43 | max_length=20, 44 | sep=" ", 45 | ) == [ 46 | "This is a long ", 47 | "sentence that ", 48 | "should be split ", 49 | "into multiple parts.", 50 | ] 51 | 52 | 53 | def test_ResplitSentenceSplitter(): 54 | splitter = ResplitSentenceSplitter() 55 | assert splitter( 56 | "This is a long sentence. It should be split into two parts.", "eng", 200 57 | ) == ["This is a long sentence.", "It should be split into two parts."] 58 | 59 | assert splitter( 60 | "This is a long sentence?It should be split into two parts.", "eng", 50 61 | ) == ["This is a long sentence?", "It should be split into two parts."] 62 | -------------------------------------------------------------------------------- /lcm/evaluation/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | 7 | import sys 8 | import typing as tp 9 | 10 | from lcm.evaluation.cli.configs import CliConfig, LauncherOptions 11 | from lcm.evaluation.cli.params import ( 12 | extract_args_from_cli, 13 | from_cli, 14 | parse_args, 15 | to_cli, 16 | ) 17 | from lcm.evaluation.predictors import get_config_cls 18 | from lcm.evaluation.utils.common import initialize_logger 19 | 20 | 21 | def cfg_from_cli( 22 | args: tp.Optional[tp.Sequence[str]] = None, 23 | ) -> tp.Tuple[CliConfig, tp.Optional[LauncherOptions]]: 24 | known_eval_cfg, unknown = to_cli(CliConfig).parse_known_args(args) 25 | cfg: CliConfig = from_cli(CliConfig, vars(known_eval_cfg), allow_incomplete=True) 26 | 27 | # Extract data configs (dataset and data_loading) 28 | dataset_args, unknown = extract_args_from_cli(unknown, prefix="dataset.") 29 | if dataset_args: 30 | cfg.dataset_args = dataset_args 31 | dataloading_args, unknown = extract_args_from_cli(unknown, prefix="data_loading.") 32 | if dataloading_args: 33 | cfg.dataloading_args = dataloading_args 34 | 35 | cfg.predictor_config = parse_args(get_config_cls(cfg.predictor), unknown) 36 | 37 | # For CLI, the `seed` param is passed to the task config and this is not 38 | # passed to the predictor config, so we have to set it manually 39 | setattr(cfg.predictor_config, "seed", cfg.seed) 40 | 41 | return cfg, cfg.launcher 42 | 43 | 44 | if __name__ == "__main__": 45 | cfg, launcher_opts = cfg_from_cli() 46 | logger = initialize_logger() 47 | 48 | if cfg.dry_run: 49 | logger.info(f"Eval config: {cfg}") 50 | sys.exit(0) 51 | 52 | if launcher_opts: 53 | from lcm.evaluation.cli import slurm 54 | 55 | slurm.main(cfg, launcher_opts, logger=logger) 56 | else: 57 | from lcm.evaluation.cli import local 58 | 59 | local.main(cfg, logger=logger) 60 | -------------------------------------------------------------------------------- /lcm/evaluation/predictors/gemma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | from dataclasses import dataclass 7 | 8 | from lcm.evaluation.api import ( 9 | PREDICTION_COLUMN, 10 | Example, 11 | ) 12 | from lcm.evaluation.predictors.huggingface import ( 13 | HuggingfacePredictor, 14 | HuggingfacePredictorConfig, 15 | ) 16 | 17 | 18 | @dataclass 19 | class GemmaPredictorConfig(HuggingfacePredictorConfig): 20 | @classmethod 21 | def predictor_class(cls): 22 | return GemmaPredictor 23 | 24 | 25 | class GemmaPredictor(HuggingfacePredictor): 26 | @staticmethod 27 | def from_config(config: GemmaPredictorConfig, **kwargs) -> "GemmaPredictor": # type: ignore 28 | predictor = HuggingfacePredictor.from_config(config, **kwargs) 29 | return GemmaPredictor(predictor.model, predictor.tokenizer, config) 30 | 31 | def post_process(self, x: Example) -> Example: 32 | """Handle the cleaning of responses from Gemma models""" 33 | 34 | pred = x.pop(PREDICTION_COLUMN) 35 | 36 | # Pretrained model 37 | if not self.config.model_name.endswith("-it"): 38 | # Clean prediction for summarization task based on some default prompts 39 | if "Write me a summary" in pred: 40 | response_idx = pred.find( 41 | "\n\nResponse:", pred.rfind("Write me a summary") 42 | ) 43 | if response_idx > 0: 44 | pred = pred[response_idx + len("\n\nResponse:") :].strip() 45 | 46 | # Instruction-fine-tuned models 47 | else: 48 | if "\n\nmodel\n\n" in pred: 49 | pred = pred[pred.rfind("\n\nmodel\n") + len("\nmodel\n") :].strip() 50 | if ( 51 | pred.startswith("Sure, here") 52 | or pred.startswith("Sure here") 53 | or pred.startswith("## Summary") 54 | ): 55 | colon_idx = pred.find(":") 56 | pred = pred[colon_idx + 1 :].strip() 57 | 58 | return { 59 | PREDICTION_COLUMN: pred, 60 | **x, 61 | } 62 | -------------------------------------------------------------------------------- /recipes/train/README.md: -------------------------------------------------------------------------------- 1 | # Main ingredients of training recipes 2 | 3 | ### Training and validation data 4 | 5 | ```yaml 6 | training_data: 7 | - name: "<dataset_name>=<split>:<weight>" 8 | source_prefix_text: "Beginning of source." # concept added at the beginning of source 9 | source_suffix_text: "End of source." # concept added at the end of source 10 | target_prefix_text: "Beginning of target." # concept added at the beginning of target (supervised data only) 11 | target_suffix_text: "End of target." # concept added at the end of target (supervised data only) 12 | 13 | - name: "<another_dataset>=<split>:<weight>" 14 | 15 | 16 | ``` 17 | 18 | ### Data loading config 19 | 20 | ```yaml 21 | data_loading_config: 22 | max_tokens: 7168 # Exclusive with batch_size 23 | batch_size: none # Exclusive with max_tokens 24 | len_to_wrap_long_seq: 128 # Sequences longer than this will be wrapped. 25 | packing: true # if True, documents in the batch will be packed. 26 | ``` 27 | 28 | The batch content can be defined in several ways: 29 | - `max_tokens` / `len_to_wrap_long_seq` approximates `batch_size`. 30 | - `batch_size` x `len_to_wrap_long_seq` approximates `max_tokens`. 31 | 32 | Note that `len_to_wrap_long_seq` has to be smaller than the model's `max_seq_len` defined in the architecture (e.g. [`two_tower_diffusion_lcm_1_6B`](../../lcm/models/two_tower_diffusion_lcm/archs.py#L36)). 33 | 34 | 35 | To filter out long samples without wrapping, you can add `filters` to each dataset config to filter based on the length of the document's list of sentences (`text_sentences`): 36 | ```yaml 37 | - name: "<dataset_name>=<split>:<weight>" 38 | source_prefix_text: "Beginning of source."
39 | filters: 'pa.compute.less(pa.compute.list_value_length(pa.dataset.field("text_sentences")), 128)' 40 | ``` 41 | ### Checkpointing config 42 | 43 | ```yaml 44 | checkpoint_every_n_steps: 2_000 # QED 45 | keep_last_n_checkpoints: 2 # delete all but last N non-consolidated checkpoints 46 | save_model_every_n_steps: 10_000 # consolidate model every N steps (valid if using FSDP) 47 | preserve_consolidated_models: True # preserve the consolidated checkpoints 48 | ``` 49 | -------------------------------------------------------------------------------- /tests/units/training/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | 7 | from pathlib import Path 8 | 9 | import pyarrow as pa 10 | import pyarrow.parquet as pq 11 | import pytest 12 | import torch # type: ignore 13 | from stopes.utils.arrow_utils import nested_numpy_to_pyarrow 14 | 15 | from lcm.datasets.configs import ColumnsNames, ParquetDatasetConfig 16 | 17 | 18 | def simple_data(batch_size: int, split: str) -> pa.Table: 19 | # Creating a toy batch 20 | src_len, tgt_len = 17, 11 21 | sonar_dim, sonar_std = 1024, 0.006 22 | 23 | batch = { 24 | ColumnsNames.dataset_name.value: ["_train_dataset"] * batch_size, 25 | "split": [split] * batch_size, 26 | } 27 | table = pa.Table.from_pydict(batch) 28 | 29 | x = torch.randn(size=[batch_size, src_len, sonar_dim]) * sonar_std 30 | y = torch.randn(size=[batch_size, tgt_len, sonar_dim]) * sonar_std 31 | x_pa = nested_numpy_to_pyarrow([row.numpy() for row in x]) 32 | y_pa = nested_numpy_to_pyarrow([row.numpy() for row in y]) 33 | table = table.append_column("dummy_source_column", x_pa) 34 | table = table.append_column("dummy_target_column", y_pa) 35 | return table 36 | 37 | 38 | @pytest.fixture() 39 | def simple_train_dataset(tmp_path: Path): 40 | (tmp_path / "train").mkdir() 41 | pq.write_to_dataset( 42 | simple_data(10, "train"), tmp_path / "train", partition_cols=["split"] 43 | ) 44 | 45 | yield ParquetDatasetConfig( 46 | parquet_path=str(tmp_path / "train"), 47 | source_column="dummy_source_column", 48 | target_column="dummy_target_column", 49 | filesystem_expr="pc.equal(pc.field('split'), 'train')", 50 | ) 51 | 52 | 53 | @pytest.fixture() 54 | def simple_validation_dataset(tmp_path: Path): 55 | (tmp_path / "dev").mkdir() 56 | pq.write_to_dataset( 57 | simple_data(10, "dev"), tmp_path / "dev", partition_cols=["split"] 58 | ) 59 | 60 | yield ParquetDatasetConfig( 61 | parquet_path=str(tmp_path / "dev"), 62 | source_column="dummy_source_column", 63 | target_column="dummy_target_column", 64 | filesystem_expr="pc.equal(pc.field('split'), 'dev')", 65 | ) 66 | -------------------------------------------------------------------------------- /lcm/evaluation/cli/local.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | import logging 7 | import time 8 | 9 | import torch 10 | 11 | from lcm.evaluation.cli.configs import CliConfig, parse_configs 12 | from lcm.evaluation.run import run_task 13 | from lcm.evaluation.utils.common import ( 14 | flatten_dict, 15 | format_dict, 16 | log_final_results, 17 | write_to_json, 18 | ) 19 | from lcm.evaluation.utils.distributed import get_gang, get_global_rank, rank_zero_info 20 | 21 | logger = logging.getLogger("lcm.evaluation") 22 | 23 | 24 | def main(cfg: CliConfig, logger: logging.Logger = logger) -> None: 25 | run_configs = parse_configs(cfg) 26 | 27 | assert len(run_configs) > 0, f"No tasks were found given pattern '{cfg.tasks}'" 28 | run_ids = [r.name for r in run_configs] 29 | rank_zero_info(f"Selected task execution: {run_ids}") 30 | 31 | all_runs_metrics = {} 32 | 33 | for run_config in run_configs: 34 | name = run_config.name 35 | rank_zero_info(f"Running evaluation on task {name}", logger=logger) 36 | start = time.monotonic() 37 | 38 | metrics, result_file = run_task(run_config, logger=logger, gang=get_gang()) 39 | if run_config.dump_dir is not None and get_global_rank() == 0: 40 | result_content = { 41 | "results": flatten_dict(metrics), 42 | "configs": run_config.params, 43 | } 44 | rank_zero_info(f"Writing metric results to {result_file}", logger=logger) 45 | write_to_json(result_content, result_file, indent=4) 46 | 47 | log = format_dict(flatten_dict(metrics), delimiter=" | ", decimal=6) 48 | rank_zero_info(f"Evaluation results on task {name}: {log}", logger=logger) 49 | rank_zero_info( 50 | f"Task {name} took {time.monotonic() - start:.2f} seconds", logger=logger 51 | ) 52 | all_runs_metrics[name] = metrics 53 | torch.cuda.empty_cache() 54 | 55 | results = flatten_dict(all_runs_metrics) 56 | rank_zero_info(f"All evaluation results: {format_dict(results)}", logger=logger) 57 | log_final_results( 58 | results, cfg.predictor_config, cfg.tb_log_dir, cfg.metric_log_dir, logger 59 | ) 60 | -------------------------------------------------------------------------------- /lcm/train/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | from inspect import signature 7 | from typing import Any, Dict, Protocol, Union, runtime_checkable 8 | 9 | import hydra 10 | from omegaconf import DictConfig, OmegaConf, read_write 11 | 12 | from lcm.utils.common import promote_config 13 | 14 | TRAINER_KEY = "_trainer_" 15 | 16 | 17 | @runtime_checkable 18 | class Trainer(Protocol): 19 | """Abstract trainer in LCM""" 20 | 21 | def run(self) -> Any: ... 22 | 23 | 24 | def _parse_training_config(train_config: DictConfig): 25 | """Return the TrainingConfig object from the omegaconf inputs""" 26 | # The train_config should have 2 keys "_target_" and "_trainer_" 27 | # the config is set to read-only within stopes module __init__ 28 | assert TRAINER_KEY in train_config, ( 29 | f"The trainer configuration is missing a {TRAINER_KEY} configuration, " 30 | "you need to specify a Callable to initialize your config." 
31 | ) 32 | trainer_cls_or_func = train_config.get(TRAINER_KEY) 33 | try: 34 | trainer_obj = hydra.utils.get_object(trainer_cls_or_func) 35 | sign = signature(trainer_obj) 36 | assert len(sign.parameters) == 1 and "config" in sign.parameters, ( 37 | f'{trainer_cls_or_func} should take a single argument called "config"' 38 | ) 39 | param_type = sign.parameters["config"].annotation 40 | 41 | OmegaConf.resolve(train_config) 42 | with read_write(train_config): 43 | del train_config._trainer_ 44 | 45 | typed_config = promote_config(train_config, param_type) 46 | return trainer_obj, typed_config 47 | except Exception as ex: 48 | raise ValueError( 49 | f"couldn't parse the train config: {train_config}.", str(ex) 50 | ) from ex 51 | 52 | 53 | def get_trainer(train_config: DictConfig) -> Trainer: 54 | trainer_obj, typed_config = _parse_training_config(train_config) 55 | return trainer_obj(typed_config) 56 | 57 | 58 | def _is_missing(config: Union[DictConfig, Dict], attr: str) -> bool: 59 | if isinstance(config, Dict): 60 | return attr in config and config[attr] 61 | if OmegaConf.is_missing(config, attr): 62 | return True 63 | if not hasattr(config, attr) or not getattr(config, attr): 64 | return True 65 | return False 66 | -------------------------------------------------------------------------------- /tests/units/inference/test_base_lcm_kv_caching.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | import pytest 7 | import torch 8 | 9 | from lcm.datasets.batch import EmbeddingsBatch 10 | from lcm.inference.lcm import LCMGenerator, LCMGeneratorOptions 11 | from lcm.models.base_lcm import BaseLCModelConfig, create_base_lcm_model 12 | from lcm.nn.transformer import TransformerConfig 13 | 14 | torch.manual_seed(0) 15 | 16 | 17 | @pytest.mark.parametrize("prefix_len", [1, 2, 8]) 18 | def test_kv_caching(prefix_len): 19 | """ 20 | Test that KV caching works as expected.
21 | Special case if prefix_len = 1 in which case 22 | the generator's prefill is a no-op 23 | """ 24 | # Sample input data 25 | batch_size = 1 26 | sonar_embed_dim = 4 27 | max_gen_len = 8 28 | # Create sample input tensor 29 | sample_input = torch.randn(batch_size, prefix_len, sonar_embed_dim) 30 | 31 | # Create an LCM model 32 | model_cfg = BaseLCModelConfig( 33 | sonar_embed_dim=sonar_embed_dim, 34 | model_dim=sonar_embed_dim, 35 | lcm=TransformerConfig( 36 | ffn_inner_dim=4 * sonar_embed_dim, 37 | num_layers=2, 38 | num_attn_heads=1, 39 | ), 40 | ) 41 | 42 | model = create_base_lcm_model(model_cfg) 43 | eos_vec = torch.zeros(sonar_embed_dim) 44 | 45 | generator = LCMGenerator( 46 | model, 47 | eos_vec=eos_vec, 48 | options=LCMGeneratorOptions( 49 | eos_threshold=1, 50 | stop_on_repetition_cosine_threshold=1, 51 | sample_latent_variable=False, 52 | seed=0, 53 | ), 54 | ) 55 | # Generate without caching 56 | lcm_output_no_cache = generator( 57 | EmbeddingsBatch(seqs=sample_input, padding_mask=None), 58 | max_gen_len=max_gen_len, 59 | disable_cache=True, 60 | ) 61 | 62 | # Enable KV caching 63 | lcm_output_with_cache = generator( 64 | EmbeddingsBatch(seqs=sample_input, padding_mask=None), 65 | max_gen_len=max_gen_len, 66 | disable_cache=False, 67 | ) 68 | 69 | # Check if the outputs are equal (indicating successful caching) 70 | assert torch.allclose( 71 | lcm_output_no_cache.hypotheses[0][0].seq, 72 | lcm_output_with_cache.hypotheses[0][0].seq, 73 | atol=1e-3, 74 | ), "Outputs with and without caching do not match" 75 | -------------------------------------------------------------------------------- /lcm/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | 7 | from dataclasses import fields 8 | from typing import Any, List, Mapping 9 | 10 | from fairseq2.typing import DataClass, is_dataclass_instance 11 | 12 | 13 | def update_dataclass( 14 | obj: DataClass, 15 | overrides: Mapping[str, Any], 16 | ) -> List[str]: 17 | """Update ``obj`` with the data contained in ``overrides`` Return the unknown fields. 18 | Copied from an old version of fairseq2 with simplification. 19 | 20 | :param obj: 21 | The data class instance to update. 22 | :param overrides: 23 | The dictionary containing the data to set in ``obj``. 24 | """ 25 | 26 | unknown_fields: List[str] = [] 27 | 28 | field_path: List[str] = [] 29 | 30 | # The dataset config has a special attribute `silent_freeze` that does not allow hard update 31 | forbidden_fields_ = ["silent_freeze"] 32 | 33 | def update(obj_: DataClass, overrides_: Mapping[str, Any]) -> None: 34 | overrides_copy = {**overrides_} 35 | 36 | for field in fields(obj_): 37 | if field.name in forbidden_fields_: 38 | continue 39 | value = getattr(obj_, field.name) 40 | 41 | try: 42 | override = overrides_copy.pop(field.name) 43 | except KeyError: 44 | continue 45 | 46 | # Recursively traverse child dataclasses. 
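# A dataclass-typed field can only be overridden with a mapping (merged field by field);
# any other override type raises a RuntimeError below. Non-dataclass fields, as well as
# None overrides, are assigned directly in the `else` branch.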
47 | if override is not None and is_dataclass_instance(value): 48 | if not isinstance(override, Mapping): 49 | pathname = ".".join(field_path + [field.name]) 50 | 51 | raise RuntimeError( 52 | pathname, 53 | f"The field '{pathname}' is expected to be of type `{type(value)}`, but is of type `{type(override)}` instead.", # fmt: skip 54 | ) 55 | 56 | field_path.append(field.name) 57 | 58 | update(value, override) 59 | 60 | field_path.pop() 61 | else: 62 | setattr(obj_, field.name, override) 63 | 64 | if overrides_copy: 65 | unknown_fields.extend( 66 | ".".join(field_path + [name]) for name in overrides_copy 67 | ) 68 | 69 | update(obj, overrides) 70 | 71 | unknown_fields.sort() 72 | 73 | return unknown_fields 74 | -------------------------------------------------------------------------------- /lcm/evaluation/predictors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | import inspect 7 | from typing import Optional, Type 8 | 9 | from lcm.utils.common import promote_config 10 | 11 | from ..api import Predictor, PredictorConfig 12 | 13 | _PREDICTOR_CONFIG_MAP = { 14 | "dummy": "lcm.evaluation.predictors.dummy.DummyPredictorConfig", 15 | "dummy_judge": "lcm.evaluation.predictors.dummy.DummyJudgeConfig", 16 | "llama3": "lcm.evaluation.predictors.llama.HFLlamaPredictorConfig", 17 | "base_lcm": "lcm.evaluation.predictors.lcm.LCMConfig", 18 | "two_tower_diffusion_lcm": "lcm.evaluation.predictors.two_tower_diffusion_lcm.TwoTowerDiffusionLCMConfig", 19 | "huggingface": "lcm.evaluation.predictors.huggingface.HuggingfacePredictorConfig", 20 | "gemma": "lcm.evaluation.predictors.gemma.GemmaPredictorConfig", 21 | } 22 | 23 | 24 | def get_config_cls(name: str) -> Type[PredictorConfig]: 25 | if name not in _PREDICTOR_CONFIG_MAP: 26 | raise ValueError(f"No predictor registered under the name {name}") 27 | 28 | module_path, config_cls_name = _PREDICTOR_CONFIG_MAP[name].rsplit(".", 1) 29 | module = __import__(module_path, fromlist=[config_cls_name]) 30 | return getattr(module, config_cls_name) 31 | 32 | 33 | def build_predictor( 34 | predictor_config: PredictorConfig, 35 | predictor_type: Optional[str] = None, 36 | **kwargs, 37 | ) -> Predictor: 38 | """ 39 | The factory function that loads the predictor from its config. The config can be 40 | a real config class, or a duck-typed config (e.g. loaded via Hydra) 41 | """ 42 | if isinstance(predictor_config, PredictorConfig): 43 | config_cls = predictor_config.__class__ 44 | else: 45 | assert predictor_type is not None, ( 46 | f"Cannot infer predictor from config type {type(predictor_config)}" 47 | ) 48 | config_cls = get_config_cls(predictor_type) 49 | predictor_config = promote_config(predictor_config, config_cls) 50 | 51 | predictor_cls: Predictor = config_cls.predictor_class() 52 | sig = inspect.signature(predictor_cls.from_config) 53 | params = sig.parameters.values() 54 | has_kwargs = any([True for p in params if p.kind == p.VAR_KEYWORD]) 55 | if has_kwargs: 56 | return predictor_cls.from_config(predictor_config, **kwargs) 57 | else: 58 | return predictor_cls.from_config(predictor_config) 59 | -------------------------------------------------------------------------------- /lcm/nn/projection.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | from dataclasses import dataclass 7 | from typing import Optional 8 | 9 | import torch 10 | from fairseq2.nn.projection import Linear 11 | from fairseq2.typing import DataType, Device 12 | from torch import Tensor 13 | from torch.nn import Module 14 | 15 | from lcm.nn.initialization import ( 16 | SUPPORTED_INIT_TYPES, 17 | get_init_fn, 18 | parse_activation_fn, 19 | ) 20 | from lcm.nn.normalization import SUPPORTED_LN_TYPES 21 | 22 | 23 | @dataclass 24 | class ProjectionConfig: 25 | dropout_p: float = 0.0 26 | """ The dropout probability applied to the module's output""" 27 | 28 | linear_bias: bool = True 29 | """ Whether or not the pre-linear layer has a bias term""" 30 | 31 | linear_init_fn: SUPPORTED_INIT_TYPES = "kaiming_uniform" 32 | 33 | weight_normalization: bool = False 34 | 35 | layer_normalization_style: SUPPORTED_LN_TYPES = "standard" 36 | 37 | activation_name: Optional[str] = None 38 | """the activation function to apply, if any""" 39 | 40 | 41 | class Projection(Module): 42 | """ 43 | An output projection module. 44 | """ 45 | 46 | def __init__( 47 | self, 48 | output_dim: int, 49 | input_dim: int, 50 | config: ProjectionConfig, 51 | device: Optional[Device] = None, 52 | dtype: Optional[DataType] = None, 53 | ) -> None: 54 | super().__init__() 55 | 56 | self.dtype = dtype 57 | 58 | init_fn = get_init_fn(config.linear_init_fn) 59 | 60 | lin = Linear( 61 | input_dim, 62 | output_dim, 63 | bias=config.linear_bias, 64 | device=device, 65 | dtype=dtype, 66 | init_fn=init_fn, 67 | ) 68 | if config.weight_normalization: 69 | self.fc = torch.nn.utils.parametrizations.weight_norm(lin) 70 | else: 71 | self.fc = lin 72 | 73 | self.activation_fn = parse_activation_fn(config.activation_name) 74 | 75 | if self.activation_fn is not None: 76 | # some activation functions (e.g., PReLU) have parameters 77 | # and so we need to move them to the right device 78 | self.activation_fn.to(device) 79 | 80 | def forward(self, seqs: Tensor): 81 | seqs = self.fc(seqs) 82 | 83 | if self.activation_fn is not None: 84 | seqs = self.activation_fn(seqs) 85 | 86 | return seqs 87 | -------------------------------------------------------------------------------- /lcm/utils/model_type_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | 7 | from dataclasses import dataclass 8 | from typing import Callable, Dict 9 | 10 | 11 | @dataclass 12 | class ModelTypeConfig: 13 | """A container for all functions associated with a specific model type.""" 14 | 15 | model_type: str 16 | config_loader: Callable 17 | model_factory: Callable 18 | model_loader: Callable 19 | 20 | 21 | class ModelTypeRegistry: 22 | """ 23 | Represents a registry of model types. 24 | In fairseq2 terms, "architecture" refers to a set of model hyperparameters, 25 | and "model type" refers to a more generic way of constructing the model with the given hyperparameters. 26 | """ 27 | 28 | _configs: Dict[str, ModelTypeConfig] 29 | 30 | def __init__(self) -> None: 31 | self._configs = {} 32 | 33 | def register(self, model_type_config: ModelTypeConfig) -> None: 34 | """Register a new model type. 35 | 36 | :param model_type_config: 37 | The configuration holding the config loader, model factory, 38 | and model loader to associate with the registered 39 | model type.
40 | """ 41 | model_type = model_type_config.model_type 42 | assert model_type, ( 43 | "To register a model type, the model_type parameter should be non-empty." 44 | ) 45 | if model_type in self._configs: 46 | raise ValueError( 47 | f"`model_type` must be a unique model type name, but '{model_type}' is already registered." 48 | ) 49 | self._configs[model_type] = model_type_config 50 | 51 | def get_config(self, model_type: str) -> ModelTypeConfig: 52 | """Return the ModelTypeConfig for the specified model type. 53 | 54 | :param model_type: 55 | The model type. 56 | """ 57 | # we import lcm.modules at runtime in order to populate the registry and avoid cyclical imports 58 | 59 | try: 60 | return self._configs[model_type] 61 | except KeyError: 62 | raise ValueError( 63 | f"The registry of model types does not contain a model type named '{model_type}'." 64 | ) 65 | 66 | def get_model_loader(self, model_type: str) -> Callable: 67 | """Get a model loader function for the given model type.""" 68 | model_type_config = self.get_config(model_type) 69 | return model_type_config.model_loader 70 | 71 | 72 | lcm_model_type_registry = ModelTypeRegistry() 73 | -------------------------------------------------------------------------------- /lcm/evaluation/tasks/lcm_generation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | # 6 | 7 | from typing import Optional 8 | 9 | from lcm.datasets.configs import ParquetDatasetConfig 10 | from lcm.evaluation.api import EOSConfig 11 | from lcm.evaluation.metrics.common import rouge_score 12 | from lcm.evaluation.tasks import register_task 13 | from lcm.evaluation.tasks.base import GenerationTaskConfig 14 | from lcm.evaluation.utils.common import evaluate 15 | from lcm.evaluation.utils.data_utils import ( 16 | ParquetTestDataLoader, 17 | default_embed_prompt, 18 | default_lcm_postprocess, 19 | ) 20 | 21 | 22 | @register_task( 23 | "lcm_generation", 24 | data_loader_type=ParquetTestDataLoader, 25 | ) 26 | def get_task_config_lcm( 27 | dataset: ParquetDatasetConfig, 28 | max_gen_len: int = 128, 29 | max_gen_len_ratio: Optional[float] = None, 30 | max_prompt_len: int = 2048, 31 | eos_config: Optional[EOSConfig] = None, 32 | ) -> GenerationTaskConfig: 33 | return GenerationTaskConfig( 34 | dataset=dataset, 35 | prompt_func=default_embed_prompt, # type: ignore 36 | postprocess_fn=default_lcm_postprocess, 37 | metric_fns=[ 38 | evaluate( 39 | rouge_score, 40 | outputs=("rouge2", "rougeL", "rougeLsum"), 41 | types=("rouge2", "rougeL", "rougeLsum"), 42 | ) 43 | ], 44 | max_gen_len=max_gen_len, 45 | max_gen_len_ratio=max_gen_len_ratio, 46 | max_prompt_len=max_prompt_len, 47 | eos_config=eos_config, 48 | ) 49 | 50 | 51 | @register_task( 52 | "finetuning_data_lcm.validation", 53 | data_loader_type=ParquetTestDataLoader, 54 | ) 55 | def get_validation_task_config_lcm( 56 | dataset: ParquetDatasetConfig, 57 | max_gen_len: int = 128, 58 | max_gen_len_ratio: Optional[float] = None, 59 | max_prompt_len: int = 2048, 60 | eos_config: Optional[EOSConfig] = None, 61 | ) -> GenerationTaskConfig: 62 | dataset.name = "finetuning_data=validation" 63 | return GenerationTaskConfig( 64 | dataset=dataset, 65 | prompt_func=default_embed_prompt, # type: ignore 66 | postprocess_fn=default_lcm_postprocess, 67 | metric_fns=[ 68 | evaluate( 69 | rouge_score, 70 | outputs=("rouge2", "rougeL", "rougeLsum"), 71 | types=("rouge2", "rougeL", "rougeLsum"), 72 | ) 73 | ], 74 |
max_gen_len=max_gen_len, 75 | max_gen_len_ratio=max_gen_len_ratio, 76 | max_prompt_len=max_prompt_len, 77 | eos_config=eos_config, 78 | ) 79 | -------------------------------------------------------------------------------- /lcm/train/optim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | from typing import Tuple 7 | 8 | from fairseq2.logging import get_log_writer 9 | from fairseq2.optim.lr_scheduler import ( 10 | AbstractLRScheduler, 11 | CosineAnnealingLR, 12 | MyleLR, 13 | NoopLR, 14 | PolynomialDecayLR, 15 | TriStageLR, 16 | ) 17 | from torch.optim import Optimizer 18 | 19 | logger = get_log_writer(__name__) 20 | 21 | 22 | def build_lr_scheduler( 23 | optimizer: Optimizer, 24 | lr: float, 25 | warmup_steps: int, 26 | start_lr: float = 1e-7, 27 | final_lr: float = 1e-5, 28 | max_steps: int = 10_000, 29 | stage_ratio: Tuple[float, ...] = (0.1, 0.4, 0.5), 30 | schedule: str = "myle", 31 | ) -> AbstractLRScheduler: 32 | assert schedule in [ 33 | "noop", 34 | "myle", 35 | "cosine", 36 | "wsd", 37 | "polynomial", 38 | ], ( 39 | f"Cannot recognize the learning rate schedule {schedule}, only noop, myle, cosine, wsd and polynomial are supported" 40 | ) 41 | 42 | assert lr > 0, "The learning rate should be strictly positive" 43 | 44 | lr_scheduler: AbstractLRScheduler 45 | 46 | if schedule == "noop": 47 | lr_scheduler = NoopLR(optimizer) 48 | 49 | elif schedule == "myle": 50 | lr_scheduler = MyleLR( 51 | optimizer, 52 | num_warmup_steps=warmup_steps, 53 | start_lr=[start_lr], 54 | ) 55 | 56 | elif schedule == "cosine": 57 | lr_scheduler = CosineAnnealingLR( 58 | optimizer, 59 | cycle_len=max_steps - warmup_steps + 1, 60 | num_warmup_steps=warmup_steps, 61 | start_lr=[start_lr], 62 | final_lr=[final_lr], 63 | cycle_mul=1.0, 64 | lr_mul=1.0, 65 | ) 66 | 67 | elif schedule == "wsd": 68 | assert lr > start_lr, ( 69 | f"the starting learning rate {start_lr} should be less than the main lr {lr}" 70 | ) 71 | start_lr_scale = start_lr / lr 72 | 73 | assert lr > final_lr, ( 74 | f"the final learning rate {final_lr} should be less than the main lr {lr}" 75 | ) 76 | final_lr_scale = final_lr / lr 77 | 78 | lr_scheduler = TriStageLR( 79 | optimizer, 80 | max_steps, 81 | stage_ratio=stage_ratio, # type: ignore 82 | start_lr_scale=start_lr_scale, 83 | final_lr_scale=final_lr_scale, 84 | ) 85 | 86 | elif schedule == "polynomial": 87 | lr_scheduler = PolynomialDecayLR( 88 | optimizer, 89 | max_steps, 90 | warmup_steps, 91 | power=200, 92 | start_lr=start_lr, 93 | final_lr=final_lr, 94 | ) 95 | 96 | return lr_scheduler 97 | -------------------------------------------------------------------------------- /lcm/nn/normalization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved.
3 | # 4 | # 5 | 6 | from typing import Literal, Optional, final 7 | 8 | import torch 9 | from fairseq2.nn import LayerNorm, RMSNorm, StandardLayerNorm 10 | from fairseq2.nn.transformer import LayerNormFactory, create_standard_layer_norm 11 | from fairseq2.typing import DataType, Device, override 12 | 13 | SUPPORTED_LN_TYPES = Literal["standard", "fp32", "rms", "unit"] 14 | 15 | 16 | @final 17 | class FP32LayerNorm(LayerNorm): 18 | """Applies Layer Normalization in single-precision.""" 19 | 20 | @override 21 | def forward(self, x: torch.Tensor) -> torch.Tensor: 22 | w, b = self.weight, self.bias 23 | 24 | # cast input and params to float32 25 | fp32_x = x.float() 26 | fp32_w = w.float() if w is not None else None 27 | fp32_b = b.float() if b is not None else None 28 | 29 | y = torch.nn.functional.layer_norm( 30 | fp32_x, self.normalized_shape, fp32_w, fp32_b, self.eps 31 | ) 32 | 33 | return y.type_as(x) 34 | 35 | 36 | def build_rms_layer_norm( 37 | model_dim: int, 38 | *, 39 | device: Optional[Device] = None, 40 | dtype: Optional[DataType] = None, 41 | ) -> LayerNorm: 42 | """Build an RMS Layer Normalization module.""" 43 | return RMSNorm(model_dim, bias=False, device=device, dtype=dtype) 44 | 45 | 46 | def build_fp32_layer_norm( 47 | model_dim: int, 48 | *, 49 | device: Optional[Device] = None, 50 | dtype: Optional[DataType] = None, 51 | ) -> LayerNorm: 52 | """Build a single-precision Layer Normalization module.""" 53 | return FP32LayerNorm(model_dim, bias=False, device=device, dtype=dtype) 54 | 55 | 56 | def build_unit_layer_norm( 57 | model_dim: int, 58 | *, 59 | device: Optional[Device] = None, 60 | dtype: Optional[DataType] = None, 61 | ) -> LayerNorm: 62 | """Create an instance of :class:`StandardLayerNorm` 63 | without learnable mean and variance.""" 64 | return StandardLayerNorm( 65 | model_dim, 66 | bias=False, 67 | elementwise_affine=False, 68 | device=device, 69 | dtype=dtype, 70 | ) 71 | 72 | 73 | def parse_layer_norm_factory(layer_normalization_style: str) -> LayerNormFactory: 74 | if layer_normalization_style == "rms": 75 | # Note that RMSNorm normalizes in single-precision by default 76 | return build_rms_layer_norm 77 | 78 | elif layer_normalization_style == "unit": 79 | return build_unit_layer_norm 80 | 81 | elif layer_normalization_style == "fp32": 82 | return build_fp32_layer_norm 83 | 84 | elif layer_normalization_style == "standard": 85 | return create_standard_layer_norm 86 | 87 | else: 88 | raise ValueError(f"Unsupported LayerNorm style {layer_normalization_style}") 89 | -------------------------------------------------------------------------------- /tests/units/evaluation/test_model_based_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | # 4 | # 5 | 6 | 7 | import os 8 | 9 | import pytest 10 | from fairseq2.gang import FakeGang 11 | 12 | from lcm.evaluation.api import Scorer, ScorerConfig 13 | from lcm.evaluation.metrics import get_scorer 14 | from lcm.evaluation.utils.data_utils import load_jsonl 15 | from tests.common import device 16 | 17 | REFERENCE_FREE_METRICS = [ 18 | "sentence_fluency", 19 | # "sentence_perplexity", # skip due to recent error accessing the public model in HF hub 20 | # "round_trip_translation", # skip due to large model, run in a separate node 21 | "word_repetition", 22 | "token_repetition", 23 | ] 24 | 25 | IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS", "") == "true" 26 | 27 | 28 | @pytest.mark.parametrize("scorer_type", REFERENCE_FREE_METRICS) 29 | def test_reference_free_scorer(simple_json_dataset, scorer_type): 30 | examples = load_jsonl(simple_json_dataset)[:3] 31 | config = ScorerConfig( 32 | scorer_type=scorer_type, 33 | inputs="input_text", # type: ignore 34 | params={"batch_size": 3}, 35 | ) 36 | metric_fn = get_scorer(config, gang=FakeGang(device=device)) 37 | assert isinstance(metric_fn, Scorer) 38 | result = metric_fn(examples) 39 | for metric_name in metric_fn.outputs: 40 | assert metric_name in result and len(result[metric_name]) == 3 41 | 42 | 43 | @pytest.mark.skipif( 44 | IN_GITHUB_ACTIONS, reason="Skip tests that download big models in CI" 45 | ) 46 | def test_round_trip_translation(simple_json_dataset): 47 | examples = load_jsonl(simple_json_dataset)[:3] 48 | config = ScorerConfig( 49 | scorer_type="round_trip_translation", 50 | inputs="input_text", # type: ignore 51 | params={"batch_size": 3}, 52 | ) 53 | metric_fn = get_scorer(config, gang=FakeGang(device=device)) 54 | assert isinstance(metric_fn, Scorer) 55 | result = metric_fn(examples) 56 | for metric_name in metric_fn.outputs: 57 | assert metric_name in result and len(result[metric_name]) == 3 58 | 59 | 60 | @pytest.mark.skip(reason="long runtime") 61 | @pytest.mark.parametrize("question_id", range(1, 7)) 62 | def test_seahorse(simple_json_dataset, question_id): 63 | examples = load_jsonl(simple_json_dataset)[:3] 64 | config = ScorerConfig( 65 | scorer_type="seahorse", 66 | model_name=f"google/seahorse-large-q{question_id}", 67 | inputs=("input_text", "target_text"), # type: ignore 68 | params={"batch_size": 3}, 69 | ) 70 | metric_fn = get_scorer(config, gang=FakeGang(device=device)) 71 | assert isinstance(metric_fn, Scorer) 72 | result = metric_fn(examples) 73 | 74 | assert ( 75 | f"seahorse-q{question_id}" in result 76 | and len(result[f"seahorse-q{question_id}"]) == 3 77 | ) 78 | -------------------------------------------------------------------------------- /lcm/evaluation/tasks/xsum.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | # 4 | # 5 | # 6 | 7 | from functools import partial 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | from lcm.datasets.configs import JSONDatasetConfig 12 | from lcm.evaluation.metrics.common import ngram_score, rouge_score 13 | from lcm.evaluation.tasks import register_task 14 | from lcm.evaluation.tasks.base import GenerationTaskConfig 15 | from lcm.evaluation.utils.common import evaluate 16 | from lcm.evaluation.utils.data_utils import ( 17 | default_text_postprocess, 18 | default_text_prompt, 19 | ) 20 | 21 | SPLITS = ["test", "validation", "train"] 22 | FORMS = ["", "inverse_"] 23 | 24 | 25 | @register_task("xsum_{form}llm.{split}", {"split": SPLITS, "form": FORMS}) 26 | def get_task_config_llm( 27 | dataset: JSONDatasetConfig, 28 | dataset_dir: str, 29 | split: str, 30 | form: str, 31 | min_gen_len: int = 4, 32 | max_gen_len: int = 512, 33 | max_gen_len_ratio: Optional[float] = None, 34 | max_prompt_len: int = 4096, 35 | ) -> GenerationTaskConfig: 36 | file_path = f"{dataset_dir}/xsum/{split}.jsonl" 37 | 38 | # In case the user specifies the directory that point directly to the task dir 39 | if not Path(file_path).exists(): 40 | file_path = f"{dataset_dir}/{split}.jsonl" 41 | 42 | assert Path(file_path).exists(), f"{file_path} not found." 43 | 44 | dataset.file_path = file_path 45 | 46 | # Default prompt if not specified by the user. Use Llama-3.1 prompts by default 47 | if form != "inverse_": 48 | source_text_column = "document" 49 | target_text_column = "summary" 50 | else: 51 | source_text_column = "summary" 52 | target_text_column = "document" 53 | 54 | dataset.source_text_column = source_text_column 55 | dataset.target_text_column = target_text_column 56 | 57 | # Add original columns for judge tasks 58 | dataset.columns = [dataset.source_text_column, dataset.target_text_column] 59 | 60 | postprocess_fn = partial( 61 | default_text_postprocess, source_text_column=source_text_column 62 | ) 63 | 64 | return GenerationTaskConfig( 65 | dataset=dataset, 66 | prompt_func=default_text_prompt, 67 | postprocess_fn=postprocess_fn, 68 | metric_fns=[ 69 | evaluate( 70 | rouge_score, 71 | outputs=("rouge2", "rougeL"), 72 | types=("rouge2", "rougeL"), 73 | ), 74 | evaluate( 75 | ngram_score, 76 | inputs=("prediction", "source"), 77 | outputs=("ngram_overlap", "repetition_4"), 78 | ), 79 | ], 80 | min_gen_len=min_gen_len, 81 | max_gen_len=max_gen_len, 82 | max_gen_len_ratio=max_gen_len_ratio, 83 | max_prompt_len=max_prompt_len, 84 | ) 85 | -------------------------------------------------------------------------------- /tests/units/evaluation/test_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | # fmt: off 6 | 7 | from typing import List 8 | 9 | import pytest 10 | 11 | from lcm.evaluation.metrics import ( 12 | bleu, 13 | exact_match, 14 | exact_match_f1, 15 | rouge_score, 16 | sentence_bleu, 17 | ) 18 | from lcm.evaluation.utils.common import evaluate 19 | 20 | 21 | @pytest.mark.parametrize( 22 | "x,y,expected", [("yes", ["yes", "YES"], 1.0), ("no", ["yes", "YES"], 0.0)] 23 | ) 24 | def test_exact_match(x, y, expected) -> None: 25 | assert exact_match(x, y) == expected 26 | 27 | 28 | def test_exact_match_f1() -> None: 29 | precision, recall = exact_match_f1("yes", ["yes", "YES"]) 30 | assert precision == 1.0 31 | assert recall == 1.0 32 | precision, recall = exact_match_f1("no", ["yes", "YES"]) 33 | assert precision == 0.0 34 | assert recall == 0.0 35 | 36 | 37 | @pytest.mark.skip(reason="flaky test") 38 | def test_rouge_score() -> None: 39 | types: List[str] = ["rouge1", "rouge2", "rougeL"] 40 | 41 | reference = "The cat is on the mat" 42 | candidate = ["The cat is on the mat", "A dog is near the mat"] 43 | 44 | result = rouge_score(reference, candidate, types) 45 | assert list(result.values()) == [1.0, 1.0, 1.0] 46 | 47 | candidate = ["A dog is near the mat"] 48 | results = rouge_score(reference, candidate, types) 49 | 50 | assert results["rouge1"] == 0.5 51 | assert round(result["rouge2"], 1) == 1.0 52 | assert results["rougeL"] == 0.5 53 | 54 | 55 | def test_bleu() -> None: 56 | x: str = "The quick brown fox jumps over the lazy dog." 57 | y: List[str] = [ 58 | "A quick brown fox jumps over a lazy dog.", 59 | "The fast brown fox jumps over a sleeping dog.", 60 | ] 61 | expected_bleu = 52.53 62 | assert pytest.approx(bleu(x, y), abs=0.02) == expected_bleu 63 | 64 | expected_bleu = 54.1 65 | assert pytest.approx(sentence_bleu(x, y), abs=0.02) == expected_bleu 66 | 67 | 68 | def test_evaluate() -> None: 69 | """Test that a python function can be wrapped into a MetricFn successfully""" 70 | 71 | examples = { 72 | "prediction": [ 73 | "Billy Bob . They are on trial for tax fraud", 74 | "Billy Bob . They are on trial for tax fraud", 75 | ], 76 | "targets": [ 77 | ["Billy Bob . Are they really on trial for tax"], 78 | ["Billy Bob . They are on trial for tax fraud"], 79 | ], 80 | } 81 | expected_outputs = { 82 | "exact_match": [0.0, 1.0], 83 | "f1": [0.7, 1.0], 84 | } 85 | 86 | metric_fn = evaluate( 87 | exact_match_f1, 88 | inputs=("prediction", "targets"), 89 | outputs=("exact_match", "f1"), 90 | collate=True, 91 | ) 92 | 93 | assert metric_fn(examples) == expected_outputs 94 | 95 | # fmt: on 96 | -------------------------------------------------------------------------------- /lcm/evaluation/utils/hf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | import logging 7 | import os 8 | import time 9 | from pathlib import Path 10 | from typing import Dict, Optional, Union 11 | 12 | import hydra 13 | import torch 14 | 15 | from lcm.evaluation.utils.distributed import get_local_rank 16 | from lcm.utils.common import torch_type 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def infer_cache_dir() -> Optional[str]: 22 | if os.getenv("HF_HOME"): 23 | # or use `HF_HUB_CACHE` 24 | logger.info(f"Using env HF_HOME={os.environ['HF_HOME']}") 25 | return f"{os.environ['HF_HOME']}/hub" 26 | return None 27 | 28 | 29 | def infer_offline_model_dir( 30 | dtype: Optional[Union[str, torch.dtype]] = None, 31 | ) -> Optional[str]: 32 | if os.getenv("HF_DOWNLOAD"): 33 | return os.environ["HF_DOWNLOAD"] 34 | return None 35 | 36 | 37 | def download_model( 38 | model_name: str, 39 | model_class: str = "AutoModelForCausalLM", 40 | tokenizer_class: str = "AutoTokenizer", 41 | model_dtype: str = "torch.float32", 42 | model_dir: Optional[str] = None, 43 | ) -> None: 44 | token = os.getenv("HF_AUTH_TOKEN") 45 | assert token is not None, "set HF_AUTH_TOKEN path please." 46 | 47 | from huggingface_hub import login 48 | 49 | login(token=token) 50 | 51 | if model_dir is None: 52 | model_dir = infer_offline_model_dir(model_dtype) 53 | assert isinstance(model_dir, str), "Unknown model_dir" 54 | 55 | start_time = time.time() 56 | dtype = torch_type(model_dtype) 57 | model_cls = hydra.utils.get_class(f"transformers.{model_class}") 58 | tokenizer_cls = hydra.utils.get_class(f"transformers.{tokenizer_class}") 59 | 60 | tokenizer = tokenizer_cls.from_pretrained(model_name) # type: ignore 61 | model = model_cls.from_pretrained( # type: ignore 62 | model_name, trust_remote_code=True, torch_dtype=dtype 63 | ) 64 | 65 | Path(model_dir).joinpath(model_name).mkdir(parents=True, exist_ok=True) 66 | print(f"Saving model to {model_dir}") 67 | 68 | model.save_pretrained(Path(model_dir).joinpath(model_name)) 69 | tokenizer.save_pretrained(Path(model_dir).joinpath(model_name)) 70 | print(f"Finish downloading in {time.time() - start_time} seconds.") 71 | 72 | 73 | def infer_hf_device_memory(model_parallel: int) -> Dict[int, int]: 74 | """Infers maximum memory allocation for each GPU in model parallel group.""" 75 | gpus_per_node = torch.cuda.device_count() 76 | start = model_parallel * get_local_rank() % gpus_per_node 77 | end = start + model_parallel 78 | max_memory = { 79 | i: torch.cuda.mem_get_info(i)[0] if start <= i < end else 0 80 | for i in range(gpus_per_node) 81 | } 82 | return max_memory 83 | 84 | 85 | if __name__ == "__main__": 86 | from fire import Fire 87 | 88 | Fire(download_model) 89 | -------------------------------------------------------------------------------- /lcm/models/abstract_lcm/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | from abc import abstractmethod 7 | from dataclasses import dataclass 8 | from typing import Optional 9 | 10 | from fairseq2.config_registry import ConfigRegistry 11 | from fairseq2.logging import get_log_writer 12 | from fairseq2.typing import DataType, Device 13 | from torch.nn import Module 14 | 15 | from lcm.models.sonar_normalizer import SonarNormalizer, load_sonar_normalizer_model 16 | 17 | logger = get_log_writer(__name__) 18 | 19 | 20 | """ 21 | An abstract LCM model class providing the bare minimum interface 22 | """ 23 | 24 | ABSTRACT_LCM_MODEL_TYPE = "abstract_lcm" 25 | 26 | 27 | @dataclass 28 | class AbstractLCModelConfig: 29 | model_type: str = ABSTRACT_LCM_MODEL_TYPE 30 | 31 | sonar_embed_dim: int = 1024 32 | 33 | sonar_normalizer_name: Optional[str] = None 34 | 35 | 36 | lcm_archs = ConfigRegistry[AbstractLCModelConfig]() 37 | lcm_arch = lcm_archs.decorator 38 | 39 | 40 | class AbstractLCModel(Module): 41 | """Abstract class for LCM models""" 42 | 43 | def __init__( 44 | self, 45 | config: AbstractLCModelConfig, 46 | ) -> None: 47 | """ 48 | Abstract LCM model 49 | """ 50 | super().__init__() 51 | 52 | self.config = config 53 | 54 | @property 55 | def dtype(self): 56 | return next(self.parameters()).dtype 57 | 58 | @property 59 | def device(self): 60 | return next(self.parameters()).device 61 | 62 | 63 | class AbstractLCModelBuilder: 64 | """Builds modules of an LCM""" 65 | 66 | config: AbstractLCModelConfig 67 | device: Optional[Device] 68 | dtype: Optional[DataType] 69 | 70 | def __init__( 71 | self, 72 | config: AbstractLCModelConfig, 73 | *, 74 | device: Optional[Device] = None, 75 | dtype: Optional[DataType] = None, 76 | ) -> None: 77 | """ 78 | :param config: 79 | The configuration. 80 | :param device: 81 | The device on which to initialize modules. 82 | :param dtype: 83 | The data type of module parameters and buffers. 84 | """ 85 | self.config = config 86 | 87 | self.device, self.dtype = device, dtype 88 | 89 | def build_sonar_normalizer( 90 | self, 91 | ) -> Optional[SonarNormalizer]: 92 | if self.config.sonar_normalizer_name is not None: 93 | logger.info( 94 | f"Building sonar_normalizer = {self.config.sonar_normalizer_name}" 95 | ) 96 | return load_sonar_normalizer_model( 97 | self.config.sonar_normalizer_name, 98 | device=self.device, 99 | dtype=self.dtype, 100 | ) 101 | return None 102 | 103 | @abstractmethod 104 | def build_model(self) -> AbstractLCModel: 105 | """Build a model.""" 106 | ... 107 | -------------------------------------------------------------------------------- /lcm/train/criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | # 4 | # 5 | 6 | from abc import abstractmethod 7 | from dataclasses import dataclass 8 | from typing import Any, Callable, Dict, List, Literal 9 | 10 | from fairseq2.logging import get_log_writer 11 | from omegaconf import MISSING 12 | from torch.distributed.fsdp.fully_sharded_data_parallel import ( 13 | FullyShardedDataParallel as FSDP, 14 | ) 15 | from torch.nn import Module 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | 18 | from lcm.train.metrics import LossTerm 19 | 20 | logger = get_log_writer(__name__) 21 | 22 | 23 | @dataclass 24 | class CriterionConfig: 25 | """A dataclass for criterion parameters""" 26 | 27 | name: str = MISSING 28 | """Name of the criterion, a unique identifier used in the CriterionsFactory""" 29 | 30 | reduction: Literal["sum", "mean"] = "sum" 31 | """How to reduce the loss across samples""" 32 | 33 | 34 | class Criterion: 35 | """An abstract class for training criterions""" 36 | 37 | def __init__( 38 | self, 39 | config: CriterionConfig, 40 | model: Module, 41 | ): 42 | self.config = config 43 | 44 | self.model = model 45 | 46 | self.summands: List[str] = [] 47 | """ A list of loss term names to track during training. 48 | This will create a metric bag for each term. 49 | """ 50 | 51 | self.reduction = config.reduction 52 | 53 | @property 54 | def throughput_metric_name(self) -> str: 55 | return "num_target_elements" 56 | 57 | @property 58 | def base_model(self): 59 | """A pointer to the unwrapped model if training with FSDP/DDP""" 60 | if isinstance(self.model, (DDP, FSDP)): 61 | _model = self.model.module 62 | else: 63 | _model = self.model 64 | return _model 65 | 66 | @abstractmethod 67 | def __call__(self, batch) -> LossTerm: 68 | """ 69 | Computes the loss given an input batch. 70 | The model's forward pass is performed here 71 | """ 72 | 73 | 74 | class CriterionsFactory: 75 | """Factory for LCM criterions""" 76 | 77 | registry: Dict[str, Any] = {} 78 | 79 | @classmethod 80 | def build_criterion(cls, name: str, **kwargs) -> Any: 81 | """build the criterion of choice from within the trainer""" 82 | 83 | criterion_class = cls.registry[name] 84 | 85 | criterion = criterion_class(**kwargs) 86 | 87 | return criterion 88 | 89 | @classmethod 90 | def register(cls, name: str) -> Callable: 91 | """decorator for adding criterions to the registry""" 92 | 93 | def inner_wrapper(wrapped_class: Criterion) -> Callable: 94 | assert name not in cls.registry, ( 95 | f"{name} is already registered as a criterion" 96 | ) 97 | cls.registry[name] = wrapped_class 98 | return wrapped_class 99 | 100 | return inner_wrapper 101 | -------------------------------------------------------------------------------- /lcm/evaluation/tasks/cnn_dailymail.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | # 4 | # 5 | # 6 | 7 | from functools import partial 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | from lcm.datasets.configs import JSONDatasetConfig 12 | from lcm.evaluation.metrics.common import ngram_score, rouge_score 13 | from lcm.evaluation.tasks import register_task 14 | from lcm.evaluation.tasks.base import GenerationTaskConfig 15 | from lcm.evaluation.utils.common import evaluate 16 | from lcm.evaluation.utils.data_utils import ( 17 | default_text_postprocess, 18 | default_text_prompt, 19 | ) 20 | 21 | SPLITS = ["test", "validation", "train"] 22 | FORMS = ["", "inverse_"] 23 | 24 | 25 | @register_task("cnn_dailymail_{form}llm.{split}", {"split": SPLITS, "form": FORMS}) 26 | def get_task_config_llm( 27 | dataset: JSONDatasetConfig, 28 | dataset_dir: str, 29 | split: str, 30 | form: str, 31 | min_gen_len: int = 10, 32 | max_gen_len: int = 512, 33 | max_gen_len_ratio: Optional[float] = None, 34 | max_prompt_len: int = 4096, 35 | ) -> GenerationTaskConfig: 36 | file_path = f"{dataset_dir}/cnn_daily_mail/{split}.jsonl" 37 | 38 | # In case the user specifies the directory that point directly to the task dir 39 | if not Path(file_path).exists(): 40 | file_path = f"{dataset_dir}/{split}.jsonl" 41 | 42 | assert Path(file_path).exists(), f"{file_path} not found." 43 | 44 | dataset.file_path = file_path 45 | 46 | if form != "inverse_": 47 | source_text_column = "article" 48 | target_text_column = "highlights" 49 | dataset.source_prefix_text = "[INST] Summarize the following article: " 50 | dataset.source_suffix_text = " [/INST]" 51 | else: 52 | source_text_column = "highlights" 53 | target_text_column = "article" 54 | dataset.source_prefix_text = ("[INST] Write an article from the following summary: ") # fmt: skip 55 | dataset.source_suffix_text = " [/INST]" 56 | 57 | dataset.source_text_column = source_text_column 58 | dataset.target_text_column = target_text_column 59 | 60 | # Add original columns for judge tasks 61 | dataset.columns = [dataset.source_text_column, dataset.target_text_column] 62 | postprocess_fn = partial(default_text_postprocess, source_text_column=dataset.source_text_column) # fmt: skip 63 | 64 | return GenerationTaskConfig( 65 | dataset=dataset, 66 | postprocess_fn=postprocess_fn, 67 | prompt_func=default_text_prompt, 68 | metric_fns=[ 69 | evaluate( 70 | rouge_score, 71 | outputs=("rouge2", "rougeL", "rougeLsum"), 72 | types=("rouge2", "rougeL", "rougeLsum"), 73 | ), 74 | evaluate( 75 | ngram_score, 76 | inputs=("prediction", "source"), 77 | outputs=("ngram_overlap", "repetition_4"), 78 | ), 79 | ], 80 | min_gen_len=min_gen_len, 81 | max_gen_len=max_gen_len, 82 | max_gen_len_ratio=max_gen_len_ratio, 83 | max_prompt_len=max_prompt_len, 84 | ) 85 | -------------------------------------------------------------------------------- /scripts/fit_embedding_normalizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
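# ---------------------------------------------------------------------------
# Illustrative note (not part of the repository): the decorator above registers
# one CNN/DailyMail task per (form, split) pair, i.e. the template
# "cnn_dailymail_{form}llm.{split}" expands to six task names. A quick sketch of
# that expansion:
# ---------------------------------------------------------------------------
SPLITS = ["test", "validation", "train"]
FORMS = ["", "inverse_"]
task_names = [f"cnn_dailymail_{form}llm.{split}" for form in FORMS for split in SPLITS]
# -> ["cnn_dailymail_llm.test", ..., "cnn_dailymail_inverse_llm.train"]
# ------------------------------ end of sketch -------------------------------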
3 | # 4 | # 5 | 6 | 7 | import argparse 8 | from typing import List 9 | 10 | import numpy as np 11 | import pyarrow.compute as pc 12 | import torch 13 | from stopes.utils.arrow_utils import pyarrow_fixed_size_array_to_numpy 14 | from tqdm.auto import tqdm 15 | 16 | from lcm.datasets.configs import ( 17 | DataLoadingConfig, 18 | ParquetBatchFormat, 19 | get_parquet_config_from_name, 20 | ) 21 | from lcm.datasets.dataloading import ( 22 | build_weighted_pipeline_with_renaming, 23 | ) 24 | from lcm.models.sonar_normalizer import SonarNormalizer, SonarNormalizerConfig 25 | 26 | 27 | def sample_sentences_from_mixed_sources( 28 | name_with_weights: List[str], 29 | max_nb_samples: int = 10**6, 30 | down_sample: int = 5, 31 | column: str = "_source_column", 32 | ) -> np.ndarray: 33 | ds_list = list(map(get_parquet_config_from_name, name_with_weights)) 34 | 35 | dlc = DataLoadingConfig( 36 | max_tokens=10000, 37 | min_length_of_sequences=1, 38 | order_by_length=False, 39 | nb_prefetch=2, 40 | num_parallel_calls=2, 41 | nb_epochs=1, 42 | output_format=ParquetBatchFormat.pyarrow, 43 | ) 44 | 45 | basic_iterator = build_weighted_pipeline_with_renaming(ds_list, dlc, 0, 1) 46 | 47 | nb_sentences = 0 48 | sentences_batch = [] 49 | 50 | pbar = tqdm(total=None) 51 | for batch in tqdm(basic_iterator): 52 | vecs = pyarrow_fixed_size_array_to_numpy(pc.list_flatten(batch[column]))[ 53 | ::down_sample 54 | ].astype(np.float32) 55 | sentences_batch.append(vecs) 56 | nb_sentences += len(vecs) 57 | pbar.update(len(vecs)) 58 | if nb_sentences > max_nb_samples: 59 | break 60 | 61 | return np.vstack(sentences_batch) 62 | 63 | 64 | def main( 65 | ds_mixture: List[str], 66 | save_path: str, 67 | max_nb_samples: int = 10**6, 68 | ): 69 | """ 70 | Args example: 71 | 72 | ds_mixture = [ 73 | "dataset1:5", 74 | "dataset2:10", 75 | "dataset3=train:2", 76 | ] 77 | save_path = f"/path/to/new/normalizer.pt" 78 | """ 79 | embs = sample_sentences_from_mixed_sources( 80 | ds_mixture, max_nb_samples=max_nb_samples 81 | ) 82 | normalizer = SonarNormalizer(SonarNormalizerConfig()) 83 | normalizer.fit(torch.from_numpy(embs)) 84 | 85 | torch.save( 86 | { 87 | "model": normalizer.state_dict(), 88 | "dataset_mixture": ds_mixture, 89 | }, 90 | save_path, 91 | ) 92 | print(f"Normalizer saved to {save_path}") 93 | 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--ds", nargs="+", type=str, required=True) 98 | parser.add_argument("--save_path", type=str, required=True) 99 | parser.add_argument("--max_nb_samples", type=int, default=10**6) 100 | args = parser.parse_args() 101 | main(args.ds, args.save_path, args.max_nb_samples) 102 | -------------------------------------------------------------------------------- /tests/units/test_recipes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
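# ---------------------------------------------------------------------------
# Illustrative usage (not part of the repository) for the normalizer-fitting
# script above. Dataset names and the output path are placeholders; the
# "name:weight" syntax follows the example given in the docstring of main().
# ---------------------------------------------------------------------------
# python scripts/fit_embedding_normalizer.py \
#     --ds dataset1:5 dataset2:10 \
#     --save_path /path/to/new/normalizer.pt \
#     --max_nb_samples 1000000
# ------------------------------ end of sketch -------------------------------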
3 | # 4 | # 5 | 6 | import warnings 7 | from pathlib import Path 8 | 9 | import pytest 10 | from fairseq2.assets.error import AssetError 11 | from hydra import compose, initialize 12 | from omegaconf import DictConfig 13 | from torch.cuda import OutOfMemoryError 14 | 15 | from lcm.train.common import Trainer, get_trainer 16 | 17 | KEY_TRAINING_RECIPES = [ 18 | "pretrain/two_tower", 19 | "pretrain/mse", 20 | "finetune/two_tower", 21 | "finetune/mse", 22 | ] 23 | 24 | 25 | @pytest.mark.skip("Need to create a real datacards") 26 | @pytest.mark.parametrize("conf_name", KEY_TRAINING_RECIPES) 27 | def test_train_recipes(monkeypatch, conf_name, tmp_path, group="train"): 28 | """ 29 | Make sure that the recipes are synced with changes from the trainers' 30 | signatures and the training configs 31 | """ 32 | from lcm.utils.common import setup_conf 33 | from tests.common import DEBUG 34 | 35 | setup_conf() 36 | 37 | # The new dynamic loss scaler does not work with non-Cuda env, disable it 38 | monkeypatch.setattr( 39 | "lcm.train.trainer.Trainer.setup_optimizer_and_lr_schedule", lambda self: None 40 | ) 41 | 42 | with initialize( 43 | version_base="1.2", 44 | config_path="../../recipes/train/", 45 | ): 46 | config = compose( 47 | config_name="defaults", 48 | overrides=[ 49 | f"+{group}={conf_name}", 50 | f"trainer.output_dir={tmp_path}", 51 | f"++trainer.debug={DEBUG}", 52 | "++trainer.use_fsdp=false", 53 | ], 54 | ) 55 | assert isinstance(config, DictConfig), ( 56 | f"+{group}={conf_name} expect dict-type config, get {type(config)}." 57 | ) 58 | 59 | try: 60 | trainer = get_trainer(config.trainer) 61 | except (ValueError, AssetError, OutOfMemoryError) as err: 62 | if isinstance(err, OutOfMemoryError): 63 | warnings.warn( 64 | f"Ignoring the error because the model from {conf_name} is too big in the test machine" 65 | ) 66 | return 67 | if isinstance(err, ValueError): 68 | main_errs = err.args[:1] 69 | main_err = " ".join(map(str, main_errs)) 70 | else: 71 | main_err = err.args[0] 72 | if "The checkpoint" in main_err: 73 | warnings.warn( 74 | f"Ignoring the error when initializing the trainer for recipe {conf_name}." 75 | "Probably, it is because the initial checkpoint is missing in the test environment." 76 | f"The error: {err}" 77 | ) 78 | return 79 | else: 80 | raise 81 | 82 | assert isinstance(trainer, Trainer), ( 83 | f"+{group}={conf_name} Error parsing recipe." 84 | ) 85 | 86 | 87 | def find_eval_recipes(): 88 | recipes_dir = Path(__file__).parent.parent.parent / "recipes/eval/lcm" 89 | config_files = recipes_dir.glob("*.yaml") 90 | for config_file in config_files: 91 | yield config_file.stem 92 | -------------------------------------------------------------------------------- /lcm/utils/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | import ctypes 7 | from abc import abstractmethod 8 | from pathlib import Path 9 | from typing import ( 10 | Any, 11 | Dict, 12 | Iterable, 13 | Optional, 14 | Protocol, 15 | Sized, 16 | Type, 17 | TypeVar, 18 | Union, 19 | runtime_checkable, 20 | ) 21 | 22 | import torch 23 | from omegaconf import DictConfig, OmegaConf 24 | 25 | root_working_dir = Path(__file__).parent.parent.parent 26 | 27 | 28 | def set_mkl_num_threads(): 29 | """Setting mkl num threads to 1, so that we don't get thread explosion.""" 30 | mkl_rt = ctypes.CDLL("libmkl_rt.so") 31 | mkl_rt.mkl_set_num_threads(ctypes.byref(ctypes.c_int(1))) 32 | 33 | 34 | def working_dir_resolver(p: str): 35 | """The omegaconf resolver that translates a relative path to the absolute path""" 36 | return "file://" + str(root_working_dir.joinpath(p).resolve()) 37 | 38 | 39 | def setup_conf(): 40 | """Register the common Hydra config groups used in LCM (for now only the launcher)""" 41 | from stopes.pipelines import config_registry # noqa 42 | 43 | recipe_root = Path(__file__).parent.parent.parent / "recipes" 44 | config_registry["lcm-common"] = "file://" + str((recipe_root / "common").resolve()) 45 | config_registry["lcm-root"] = "file://" + str(recipe_root.resolve()) 46 | 47 | # Register omegaconf resolvers 48 | OmegaConf.register_new_resolver("realpath", working_dir_resolver, replace=True) 49 | 50 | 51 | def torch_type( 52 | dtype: Optional[Union[str, torch.dtype]] = None, 53 | ) -> Optional[torch.dtype]: 54 | # Convert dtype string from the checkpoint to torch.dtype 55 | # https://github.com/pytorch/pytorch/issues/40471 56 | if dtype is None: 57 | return None 58 | 59 | if isinstance(dtype, torch.dtype): 60 | return dtype 61 | 62 | _dtype = eval(dtype) # type: ignore 63 | assert isinstance(_dtype, torch.dtype), f"Invalid dtype value: {dtype}" 64 | return _dtype 65 | 66 | 67 | @runtime_checkable 68 | class Batched(Sized, Protocol): 69 | """Abstract class for batched data""" 70 | 71 | @abstractmethod 72 | def __getitem__(self, i: int) -> Any: ... 73 | 74 | 75 | T = TypeVar("T") 76 | 77 | 78 | def promote_config(config: Union[T, DictConfig, Dict], config_cls: Type[T]) -> T: 79 | if isinstance(config, (Dict, DictConfig)): 80 | import dacite 81 | 82 | if isinstance(config, DictConfig): 83 | config = OmegaConf.to_container(config) # type: ignore 84 | 85 | return dacite.from_dict( 86 | data_class=config_cls, 87 | data=config, # type: ignore 88 | config=dacite.Config(cast=[Path]), # type: ignore 89 | ) 90 | else: 91 | assert isinstance(config, config_cls), f"Unknown config type: {type(config)}" 92 | return config 93 | 94 | 95 | def batched(inputs: Iterable, batch_size=10000) -> Iterable: 96 | batch = [] 97 | for line in inputs: 98 | batch.append(line) 99 | if len(batch) == batch_size: 100 | yield batch 101 | batch = [] 102 | if len(batch) > 0: 103 | yield batch 104 | -------------------------------------------------------------------------------- /tests/test_headers.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from pathlib import Path 3 | 4 | COPYRIGHT = "Copyright (c) Meta Platforms, Inc. and affiliates" 5 | FB_COPYRIGHT = "Copyright (c) Facebook, Inc. and its affiliates" 6 | 7 | PY_HEADER = """# # Copyright (c) Meta Platforms, Inc. and affiliates. 8 | # All rights reserved. 9 | # 10 | # 11 | 12 | """ 13 | 14 | DOUBLE_SLASH_COMMENT_HEADER = """// Copyright (c) Meta Platforms, Inc. and affiliates 15 | // All rights reserved.
16 | // 17 | // This source code is licensed under the license found in the 18 | // LICENSE file in the root directory of this source tree. 19 | 20 | """ 21 | 22 | 23 | def check_file(file: Path, autofix: bool = False) -> bool: 24 | full_text = file.read_text() 25 | if COPYRIGHT in full_text: 26 | return True 27 | 28 | # Either returns immediately or first tries to fix things. 29 | if not autofix: 30 | return False 31 | 32 | if FB_COPYRIGHT in full_text: 33 | file.write_text(full_text.replace(FB_COPYRIGHT, COPYRIGHT)) 34 | return True 35 | 36 | if file.suffix == ".py": 37 | file.write_text(PY_HEADER + full_text) 38 | return True 39 | 40 | double_slash_comment_header_suffixes = {".ts", ".tsx", ".js", ".jsx", ".css"} 41 | if file.suffix in double_slash_comment_header_suffixes: 42 | file.write_text(DOUBLE_SLASH_COMMENT_HEADER + full_text) 43 | return True 44 | 45 | return False 46 | 47 | 48 | def test_all_files_have_a_copyright_header(autofix: bool = False): 49 | root = Path(__file__).resolve().parents[1] 50 | assert (root / ".git").is_dir() 51 | ls_tree = subprocess.check_output( 52 | ["git", "ls-tree", "-r", "HEAD", "--name-only"], 53 | encoding="utf-8", 54 | ) 55 | files = ls_tree.strip().splitlines() 56 | failed = [] 57 | for f in files: 58 | file = root / f 59 | if any(part.startswith("fb_") for part in file.parts): 60 | continue 61 | if file.suffix in ( 62 | ".lock", 63 | ".png", 64 | ".ico", 65 | ".json", 66 | ".jsonl", 67 | ".yml", 68 | ".yaml", 69 | ".md", 70 | ".tsv", 71 | ".svg", 72 | ".txt", 73 | ".toml", 74 | ".ipynb", 75 | ".html", 76 | ".csv", 77 | ".env", 78 | ".pt", 79 | ): 80 | continue 81 | if file.name in ( 82 | ".gitignore", 83 | ".prettierignore", 84 | ".prettierrc", 85 | ".nojekyll", 86 | "moses-config.lowercase", 87 | "LICENSE", 88 | "parse_options.sh", 89 | ): 90 | continue 91 | if file.is_symlink() or not file.exists(): 92 | continue 93 | try: 94 | license = check_file(file, autofix=autofix) 95 | except: # noqa 96 | license = False 97 | if not license: 98 | print(file) 99 | failed.append(file) 100 | 101 | assert not failed, f"{failed} are missing the license header" 102 | 103 | 104 | if __name__ == "__main__": 105 | test_all_files_have_a_copyright_header(autofix=True) 106 | -------------------------------------------------------------------------------- /lcm/evaluation/cli/slurm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | # 4 | # 5 | 6 | import asyncio 7 | import logging 8 | import time 9 | 10 | import hydra 11 | from hydra import compose, initialize 12 | from omegaconf import DictConfig, OmegaConf 13 | from stopes.core import Requirements 14 | 15 | from lcm.evaluation.arun import EvalRunModule, RunModuleConfig, schedule_task 16 | from lcm.evaluation.cli.configs import CliConfig, LauncherOptions, parse_configs 17 | from lcm.utils.common import promote_config, setup_conf 18 | 19 | logger = logging.getLogger("lcm.evaluation") 20 | 21 | 22 | def main( 23 | cfg: CliConfig, launcher_opts: LauncherOptions, logger: logging.Logger = logger 24 | ) -> None: 25 | """ 26 | Pipeline main steps: 27 | - Create multiple EvalModuleConfig for each task 28 | - Schedule and run the (sharded) tasks on SLURM 29 | - Aggregate the metrics in the scheduler node 30 | """ 31 | 32 | job_args = getattr(cfg, "job_args", None) or {} 33 | 34 | if isinstance(job_args, DictConfig): 35 | job_args = OmegaConf.to_container(job_args) 36 | 37 | assert isinstance(job_args, dict), f"Unexpected `job_args` type: {type(job_args)}" 38 | 39 | # 1. Set up launcher 40 | # If launcher_opts is a string (i.e. passed via non-Hydra CLI), we set up an embedded 41 | # Hydra session to construct the launcher 42 | if isinstance(launcher_opts, DictConfig): 43 | launcher = hydra.utils.instantiate(launcher_opts) 44 | else: 45 | launcher_args = [] 46 | for k, v in job_args.items(): # type: ignore 47 | if k.startswith("launcher."): 48 | # Escape list-style string for Hydra 49 | if v and "," in v: 50 | launcher_args.append(f"++{k}='{v}'") 51 | else: 52 | launcher_args.append(f"++{k}={v}") 53 | 54 | setup_conf() # Register stopes' and lcm's launchers 55 | with initialize(version_base="1.2", config_path="../../../recipes/common"): 56 | launcher_cfg = compose( 57 | config_name="requirements", # load any config to attach and detach launcher later 58 | overrides=[f"+launcher={launcher_opts}"] + launcher_args, 59 | ) 60 | 61 | launcher = hydra.utils.instantiate(launcher_cfg)["launcher"] 62 | 63 | # 2. Set up requirements 64 | requirements_args = job_args.get("requirements", None) 65 | if requirements_args: 66 | requirements = promote_config(requirements_args, Requirements) 67 | else: 68 | requirements = None 69 | nshards = job_args.get("nshards", None) # type: ignore 70 | 71 | # 3. Set up run 72 | run_configs = parse_configs(cfg) 73 | task_names = [r.task_name for r in run_configs] 74 | task_modules = [ 75 | EvalRunModule( 76 | RunModuleConfig( 77 | requirements=requirements, nshards=nshards, **run_config.__dict__ 78 | ) 79 | ) 80 | for run_config in run_configs 81 | ] 82 | start = time.monotonic() 83 | loop = asyncio.get_event_loop() 84 | all_runs = asyncio.gather( 85 | *[schedule_task(m, launcher, logger=logger) for m in task_modules] 86 | ) 87 | loop.run_until_complete(all_runs) 88 | logger.info( 89 | f"Tasks {task_names} took {time.monotonic() - start:.2f} seconds (including scheduling)." 90 | ) 91 | -------------------------------------------------------------------------------- /tests/units/evaluation/test_judge_tasks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | # 4 | # 5 | 6 | import logging 7 | from typing import Iterable, List, Optional, Sequence 8 | 9 | import torch 10 | 11 | from lcm.evaluation.api import ( 12 | PREDICTION_TEXT_COLUMN, 13 | Example, 14 | Prediction, 15 | Predictor, 16 | Prompts, 17 | ) 18 | from lcm.evaluation.tasks import register_task 19 | from lcm.evaluation.tasks.base import GenerationTaskConfig 20 | from lcm.evaluation.utils.common import evaluate 21 | from lcm.evaluation.utils.data_utils import ResultsDataLoader, ResultsDatasetConfig 22 | 23 | logger = logging.getLogger("lcm.evaluation.test_judge") 24 | 25 | 26 | class Judge(Predictor): 27 | """ 28 | Example judge. 29 | 30 | A judge is any model that predicts from the outcome of an 31 | evaluation task. 32 | """ 33 | 34 | def __call__( 35 | self, 36 | prompts: Prompts, 37 | max_prompt_len: Optional[int] = None, 38 | max_gen_len: Optional[int] = None, 39 | temperature: float = 0.0, 40 | top_p: float = 0.0, 41 | top_k: int = 0, 42 | echo: bool = True, 43 | return_logprobs: bool = False, 44 | show_progress: bool = False, 45 | disable_cache: bool = False, 46 | **kwargs, 47 | ) -> Sequence[Prediction]: 48 | assert isinstance(prompts, Iterable) 49 | preds: List[Prediction] = [] 50 | for seq in prompts: 51 | # The outcome of the dummy task is {"metrics": list, "prediction_embed": list, "prompts": list} 52 | # The judge will calculate the L2 distance between the means of the prediction and the inputs 53 | # of the previous task 54 | assert "prompts" in seq and "prediction_embed" in seq 55 | prompts_mean = torch.mean(torch.stack(seq["prompts"])) # type: ignore 56 | preds_mean = torch.mean(torch.stack(seq["prediction_embed"])) # type: ignore 57 | preds.append( 58 | Prediction(text="judge", embed=preds_mean, tokens=prompts_mean) 59 | ) # type: ignore 60 | return preds 61 | 62 | 63 | def dist(prediction: torch.Tensor, targets: torch.Tensor) -> float: 64 | return prediction.item() - targets.item() 65 | 66 | 67 | def postprocess(x: Example) -> Example: 68 | # Get the best hypothesis 69 | prediction = x["prediction_embed"][0] 70 | 71 | prompts_mean = torch.mean(torch.stack(x["prompts"])) 72 | 73 | # Because L2 distance is a batch-wise measure, we must add a dimension to it 74 | return {"prediction": prediction.unsqueeze(0), "targets": prompts_mean.unsqueeze(0)} 75 | 76 | 77 | @register_task("l2_as_judge", data_loader_type=ResultsDataLoader) 78 | def get_judge_task_config(dataset: ResultsDatasetConfig) -> GenerationTaskConfig: 79 | """ 80 | Implement the L2 distance as a judge task, to illustrate the 81 | new API. This trivial task just reloads the generation results from a 82 | next-sentence-prediction task, then recalculates the L2 from the averaged result 83 | 84 | Note that the naming convention `dataset_dir` is mandatory for all tasks that 85 | require a dataset located locally. 86 | """ 87 | dataset.source_text_column = PREDICTION_TEXT_COLUMN 88 | return GenerationTaskConfig( 89 | dataset=dataset, 90 | postprocess_fn=postprocess, 91 | metric_fns=[ 92 | evaluate(dist, outputs=("dist")), 93 | ], 94 | ) 95 | -------------------------------------------------------------------------------- /tests/units/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | # 4 | # 5 | 6 | import logging 7 | from pathlib import Path 8 | from typing import Any 9 | 10 | import numpy as np 11 | import pyarrow as pa 12 | import pyarrow.parquet as pq 13 | import pytest 14 | from fairseq2.gang import FakeGang 15 | 16 | from lcm.datasets.configs import ( 17 | DataLoadingConfig, 18 | ParquetBatchFormat, 19 | ParquetDatasetConfig, 20 | ) 21 | 22 | logger = logging.getLogger("lcm.test.units") 23 | 24 | 25 | def mock_init_process_group(dconfig: Any, logger: logging.Logger): 26 | from tests.common import device 27 | 28 | return FakeGang(device=device) 29 | 30 | 31 | def mock_get_gang(): 32 | from tests.common import device 33 | 34 | return FakeGang(device=device) 35 | 36 | 37 | def simple_table() -> pa.Table: 38 | d = { 39 | "0": "zero", 40 | "1": "one", 41 | "2": "two", 42 | "3": "three", 43 | "4": "four", # fmt: skip 44 | "5": "five", 45 | "6": "six", 46 | "7": "seven", 47 | "8": "eight", 48 | "9": "nine", 49 | } 50 | 51 | def num2text(x): 52 | return " ".join([d[i] for i in list(str(x))]) 53 | 54 | data = { 55 | "cat": np.arange(1000) // 100, 56 | "id": np.arange(1000), 57 | "seq": [np.arange(i % 10) for i in range(1000)], 58 | "text": [f"random text {num2text(i)}" for i in range(1000)], 59 | } 60 | return pa.Table.from_pydict(data) 61 | 62 | 63 | @pytest.fixture(autouse=True) 64 | def patches(monkeypatch): 65 | # Change behaviour of the training pipelines for testing. 66 | # Please put all the patch in this fixture 67 | 68 | # LCM patch 69 | monkeypatch.setattr( 70 | "lcm.utils.distributed.init_process_group", mock_init_process_group 71 | ) 72 | monkeypatch.setattr("lcm.train.trainer.init_process_group", mock_init_process_group) 73 | monkeypatch.setattr( 74 | "lcm.train.trainer.TrainerBuilder._setup_additional_logging", lambda _: None 75 | ) 76 | monkeypatch.setattr("lcm.utils.logging.log_env_variables", lambda _: None) 77 | monkeypatch.setattr("lcm.train.trainer.log_env_variables", lambda _: None) 78 | monkeypatch.setattr("lcm.datasets.base.set_mkl_num_threads", lambda: None) 79 | monkeypatch.setattr("lcm.datasets.dataloader.set_mkl_num_threads", lambda: None) 80 | monkeypatch.setattr("lcm.evaluation.utils.common.setup_env", lambda: None) 81 | monkeypatch.setattr("lcm.evaluation.run.setup_env", lambda: None) 82 | 83 | monkeypatch.setattr("lcm.evaluation.cli.local.get_gang", mock_get_gang) 84 | 85 | 86 | @pytest.fixture() 87 | def simple_dataset(tmp_path: Path): 88 | pq.write_to_dataset(simple_table(), tmp_path, partition_cols=["cat"]) 89 | yield tmp_path 90 | 91 | 92 | @pytest.fixture() 93 | def simple_data_config(simple_dataset): 94 | dlc = DataLoadingConfig( 95 | batch_size=10, 96 | seed=12, 97 | nb_epochs=1, 98 | shuffle=False, 99 | output_format=ParquetBatchFormat.pyarrow, 100 | min_length_of_sequences=None, 101 | ) 102 | pdc = ParquetDatasetConfig( 103 | parquet_path=simple_dataset, 104 | columns=["id", "cat"], 105 | source_column="seq", 106 | source_text_column="text", 107 | nb_parallel_fragments=1, 108 | split_to_row_groups=False, 109 | ) 110 | yield dlc, pdc 111 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | [project] 7 | name = "large-concept-model" 8 | readme = "README.md" 9 | requires-python = ">=3.10" 10 | version = "0.1.0" 11 | description = "LCM: Large Concept Model an architecture using a higher-level semantic representation called a concept, which is language- and modality-agnostic, representing ideas or actions." 12 | dependencies = [ 13 | "dacite>=1.8.1", 14 | "fire>=0.7.0", 15 | "hydra-core>=1.3.2", 16 | "importlib-resources~=6.4", 17 | "numpy>=1.21", 18 | "polars>=1.16.0", 19 | "pyarrow>=16.1.0", 20 | "retrying>=1.3.4", 21 | "sentence-splitter>=1.4", 22 | "sonar-space>=0.3.2", 23 | "stopes[mono]>=2.2.0", 24 | "tensorboard>=2.18.0", 25 | ] 26 | 27 | classifiers = [ 28 | "License :: OSI Approved :: MIT License", 29 | "Topic :: Scientific/Engineering", 30 | "Development Status :: 4 - Beta", 31 | ] 32 | 33 | [project.urls] 34 | Source = "https://github.com/facebookresearch/large_concept_model" 35 | Tracker = "https://github.com/facebookresearch/large_concept_model/issues" 36 | 37 | [build-system] 38 | requires = ["flit_core >=3.2,<4", "setuptools < 74"] 39 | build-backend = "flit_core.buildapi" 40 | 41 | [tool.flit.module] 42 | name = "lcm" # TODO change module name 43 | 44 | [project.optional-dependencies] 45 | cpu = [ 46 | "torch==2.5.1+cpu", 47 | "torchaudio==2.5.1+cpu", 48 | "fairseq2n==0.3.0rc1", 49 | "fairseq2[arrow]==0.3.0rc1", 50 | ] 51 | eval = [ 52 | "accelerate>=1.2.0", 53 | "bert-score>=0.3.13", 54 | "editdistance>=0.8.1", 55 | "jinja2>=3.1.3", 56 | "nltk>=3.9.1", 57 | "rouge-score>=0.1.2", 58 | "sacrebleu>=2.4.3", 59 | "scikit-learn>=1.5.2", 60 | "spacy>=3.7.5", 61 | "textdescriptives>=2.8.2", 62 | "tiktoken>=0.8.0", 63 | "transformers>=4.45.0", 64 | "fairscale>=0.4.13", 65 | ] 66 | data = [ 67 | "numpy>=1.21", 68 | "numba>=0.60.0", 69 | "spacy>=3.7.5", 70 | "en_core_web_sm@https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl", 71 | "sacremoses>=0.1.1", 72 | "nltk>=3.8.1", 73 | "scipy>=1.14", 74 | "editdistance>=0.8.1", 75 | "sacrebleu>=2.4.1", 76 | "datasets>=2.18.0", 77 | "wtpsplit>=2.1.0", 78 | "transformers>=4.45.0", 79 | ] 80 | 81 | 82 | [tool.ruff] 83 | target-version = "py310" 84 | 85 | [tool.mypy] 86 | python_version = "3.10" 87 | show_error_codes = true 88 | check_untyped_defs = true 89 | ignore_missing_imports = true 90 | implicit_optional = true 91 | implicit_reexport = true 92 | 93 | files = [ 94 | "lcm/", # TODO 95 | ] 96 | 97 | [tool.uv] 98 | prerelease = "explicit" # for fairseq2 0.3.0rc1 99 | 100 | # TODO Change versions 101 | [tool.uv.sources] 102 | fairseq2 = [ 103 | { index = "fairseq2-cpu", extra = 'cpu' } 104 | ] 105 | fairseq2n = [ 106 | { index = "fairseq2-cpu", extra = 'cpu' } 107 | ] 108 | torch={index="pytorch-cpu"} 109 | torchaudio={index="pytorch-cpu"} 110 | # sonar-space = { git = "https://github.com/facebookresearch/SONAR", branch = "update_fs2" } # TODO 111 | 112 | [[tool.uv.index]] 113 | name = "fairseq2-cpu" 114 | url = "https://fair.pkg.atmeta.com/fairseq2/whl/rc/pt2.5.1/cpu/" 115 | explicit = true 116 | 117 | [[tool.uv.index]] 118 | name = "pytorch-cpu" 119 | url = "https://download.pytorch.org/whl/cpu" 120 | explicit = true 121 | 122 | [dependency-groups] 123 | dev = ["pytest-asyncio>=0.23.2", "pytest>=8.0.0"] 124 | 125 | 126 | [project.entry-points."fairseq2"] 127 | "fairseq2" = "lcm:setup_fairseq2" # TODO update lcm 128 | -------------------------------------------------------------------------------- /lcm/evaluation/metrics/common.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from collections import Counter 4 | from typing import Any, Dict, List, Sequence, Tuple 5 | 6 | import torch 7 | 8 | 9 | def f1(prediction: str, targets: List[str]) -> float: 10 | def _f1(pred_tokens: List[str], gt_tokens: List[str]) -> float: 11 | common = Counter(pred_tokens) & Counter(gt_tokens) 12 | num_same = sum(common.values()) 13 | if num_same == 0: 14 | return 0 15 | precision = 1.0 * num_same / len(pred_tokens) 16 | recall = 1.0 * num_same / len(gt_tokens) 17 | return (2 * precision * recall) / (precision + recall) 18 | 19 | return max(_f1(prediction.split(), target.split()) for target in targets) 20 | 21 | 22 | def exact_match(prediction: str, targets: List[str]) -> float: 23 | return max(float(prediction == target) for target in targets) 24 | 25 | 26 | def one_character_exact_match(prediction: str, targets: List[str]) -> float: 27 | return max(float(prediction.strip()[0] == target.strip()[0]) for target in targets) 28 | 29 | 30 | def cosine_based_match( 31 | prediction: torch.Tensor, choices: torch.Tensor, tgt_idx: int 32 | ) -> float: 33 | cosine_sim = torch.nn.CosineSimilarity(dim=-1) 34 | pred_idx = cosine_sim(prediction, choices).argmax() 35 | return float(pred_idx.item() == tgt_idx) 36 | 37 | 38 | def exact_match_f1(prediction: str, targets: List[str]) -> Tuple[float, float]: 39 | return (exact_match(prediction, targets), f1(prediction, targets)) 40 | 41 | 42 | def bleu(prediction: str, targets: List[str], lang=None, **kwargs: Any) -> float: 43 | import sacrebleu # type: ignore 44 | 45 | if not lang: 46 | return sacrebleu.corpus_bleu([prediction], [targets], **kwargs).score 47 | 48 | tokenizer = "13a" 49 | if lang.startswith("zho"): 50 | tokenizer = "zh" 51 | elif lang.startswith("jpn"): 52 | tokenizer = "ja-mecab" 53 | elif lang.startswith("kor"): 54 | tokenizer = "ko-mecab" 55 | 56 | return sacrebleu.corpus_bleu( 57 | [prediction], [targets], tokenize=tokenizer, **kwargs 58 | ).score 59 | 60 | 61 | def sentence_bleu(prediction: str, targets: List[str], **kwargs: Any) -> float: 62 | import sacrebleu # type: ignore 63 | 64 | return sacrebleu.sentence_bleu(prediction, targets, **kwargs).score 65 | 66 | 67 | def rouge_score( 68 | prediction: str, 69 | targets: List[str], 70 | types: Sequence[str] = ("rouge3", "rougeL", "rougeLsum"), 71 | **kwargs: Any, 72 | ) -> Dict[str, float]: 73 | from rouge_score import rouge_scorer # type: ignore 74 | 75 | split_summaries: bool = kwargs.pop("split_summaries", True) 76 | scorer = rouge_scorer.RougeScorer(types, split_summaries=split_summaries, **kwargs) 77 | if hasattr(scorer, "score_multi"): 78 | scores = scorer.score_multi(targets, prediction) # type: ignore 79 | else: 80 | assert len(targets) == 1, len(targets) 81 | scores = scorer.score(targets[0], prediction) 82 | avg_fmeasures = {s: scores[s].fmeasure for s in types} 83 | return avg_fmeasures 84 | 85 | 86 | def ngram_score(prediction: str, source: str) -> Dict[str, float]: 87 | from lcm.evaluation.utils.segment_alignment import get_all_ngrams 88 | 89 | src_ngrams = get_all_ngrams(source, 1, 4) 90 | tgt_ngrams = get_all_ngrams(prediction, 1, 4) 91 | result = {} 92 | result["ngram_overlap"] = len(set(tgt_ngrams).intersection(src_ngrams)) / max(1, len(set(src_ngrams))) # fmt: skip 93 | result["repetition_4"] = len(tgt_ngrams) / max(1, len(set(tgt_ngrams))) - 1 # fmt: skip 94 | return result 95 | 
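# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the repository): calling the reference-based
# metrics from lcm/evaluation/metrics/common.py above on a toy example. The
# strings are made up; rouge_score additionally needs the optional
# `rouge-score` package from the "eval" extra.
# ---------------------------------------------------------------------------
from lcm.evaluation.metrics.common import exact_match, f1, ngram_score, rouge_score

prediction = "the cat sat on the mat"
targets = ["the cat sat on a mat"]

print(f1(prediction, targets))           # token-level F1 against the closest target
print(exact_match(prediction, targets))  # 0.0 for this pair
print(rouge_score(prediction, targets, types=("rouge2", "rougeL")))
print(ngram_score(prediction, source="the cat sat on the mat today"))
# ------------------------------ end of sketch -------------------------------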
-------------------------------------------------------------------------------- /lcm/evaluation/metrics/seahorse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # 4 | 5 | 6 | from pathlib import Path 7 | from typing import List, Optional, Sequence, Tuple 8 | 9 | import numpy as np 10 | import torch 11 | from fairseq2.typing import CPU, DataType, Device 12 | from tqdm.auto import tqdm 13 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 14 | 15 | from lcm.evaluation.api import Scorer 16 | from lcm.evaluation.utils.hf import infer_offline_model_dir 17 | from lcm.utils.common import batched 18 | 19 | 20 | class SeahorseScorer(Scorer): 21 | def __init__( 22 | self, 23 | model_name: str = "google/seahorse-large-q5", 24 | inputs: Tuple[str, ...] = ("prediction", "inputs"), 25 | bad_token_id: int = 497, 26 | good_token_id: int = 333, 27 | device: Device = CPU, 28 | dtype: Optional[DataType] = None, 29 | batch_size: int = 1, 30 | **kwargs, 31 | ): 32 | super().__init__( 33 | model_name=model_name, 34 | inputs=inputs, 35 | device=device, 36 | dtype=dtype, 37 | **kwargs, 38 | ) 39 | self.batch_size = batch_size 40 | self.bad_token_id = bad_token_id 41 | self.good_token_id = good_token_id 42 | 43 | @staticmethod 44 | def prompt_seahorse(article: str, summary: str): 45 | return f"premise: {article} hypothesis: {summary}" 46 | 47 | @classmethod 48 | def default_outputs(cls, model_name: str) -> Tuple[str, ...]: 49 | question_id = model_name[-1] 50 | return (f"seahorse-q{question_id}",) 51 | 52 | def init_model(self): 53 | # For Seahorse, we do not use the HF cache, as this does not save 54 | # all SPM models. 55 | model_root = infer_offline_model_dir(dtype=self.dtype) 56 | if model_root: 57 | model_name = str(Path(model_root).joinpath(self.model_name)) 58 | else: 59 | model_name = self.model_name 60 | offload_folder = self.kwargs.get("offload_folder", None) 61 | self.tokenizer = AutoTokenizer.from_pretrained( 62 | model_name, 63 | clean_up_tokenization_spaces=False, 64 | ) 65 | self.model = AutoModelForSeq2SeqLM.from_pretrained( 66 | model_name, 67 | offload_folder=offload_folder, 68 | ).to(self.device) 69 | 70 | def score_texts( 71 | self, 72 | texts: Sequence[str], 73 | references: Optional[Sequence[str]] = None, 74 | show_progress: bool = False, 75 | ) -> np.ndarray: 76 | assert references, "Seahorse requires reference texts (source documents)" 77 | batch_size = self.batch_size or len(texts) 78 | 79 | results: List[np.ndarray] = [] 80 | 81 | pairs_iter = batched(zip(references, texts), batch_size) 82 | if show_progress: 83 | pairs_iter = tqdm(pairs_iter) 84 | for pairs_batch in pairs_iter: 85 | prompts = list(map(lambda x: self.prompt_seahorse(*x), pairs_batch)) 86 | 87 | inputs = self.tokenizer( 88 | prompts, padding=True, truncation=True, return_tensors="pt" 89 | ).to(self.device) 90 | prefix = torch.tensor([[self.tokenizer.pad_token_id]] * len(prompts)).to(self.device) # fmt: skip 91 | with torch.inference_mode(): 92 | outputs = self.model(**inputs, decoder_input_ids=prefix) 93 | logits = outputs.logits[:, 0] 94 | norm_logits = torch.nn.functional.softmax(logits, dim=1) 95 | results.append(norm_logits[:, self.good_token_id].cpu().numpy())  # probability mass on the "good" token 96 | return np.concatenate(results) 97 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge
4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /lcm/evaluation/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | 4 | # flake8: noqa 5 | import inspect 6 | from typing import List, Optional 7 | 8 | from fairseq2.gang import Gang 9 | 10 | from lcm.evaluation.api import Scorer, ScorerConfig 11 | from lcm.evaluation.metrics.common import ( 12 | bleu, 13 | exact_match, 14 | exact_match_f1, 15 | f1, 16 | rouge_score, 17 | sentence_bleu, 18 | ) 19 | from lcm.evaluation.metrics.similarity import ( 20 | edit_distance, 21 | l2_distance, 22 | longest_common_substring, 23 | memorization_score, 24 | mse_constrative_accuracy, 25 | nltk_sentence_bleu, 26 | ) 27 | 28 | _SCORER_MAP = { 29 | "sentence_fluency": "lcm.evaluation.metrics.sentence_fluency.FluencyClassifierScorer", 30 | "sentence_perplexity": "lcm.evaluation.metrics.sentence_fluency.PerplexityScorer", 31 | "round_trip_translation": "lcm.evaluation.metrics.sentence_fluency.RoundTripTranslationScorer", 32 | "word_repetition": "lcm.evaluation.metrics.sentence_fluency.WordRepetitionScorer", 33 | "token_repetition": "lcm.evaluation.metrics.sentence_fluency.TokenRepetitionScorer", 34 | "seahorse": "lcm.evaluation.metrics.seahorse.SeahorseScorer", 35 | "momentum_coherence": "lcm.evaluation.metrics.coherence_metrics.MomentumCoherenceProcessor", 36 | "translated_rouge": "lcm.evaluation.metrics.multilingual_similarity.TranslatedRougeScorer", 37 | "bertscore": "lcm.evaluation.metrics.multilingual_similarity.BertScoreScorer", 38 | } 39 | 40 | 41 | def get_scorer( 42 | config: ScorerConfig, 43 | metrics_to_report: Optional[List] = None, 44 | gang: Optional[Gang] = None, 45 | ) -> Optional[Scorer]: 46 | scorer_type = config.scorer_type 47 | if scorer_type not in _SCORER_MAP: 48 | raise ValueError(f"No metrics found for {scorer_type}") 49 | 50 | module_path, config_cls_name = _SCORER_MAP[scorer_type].rsplit(".", 1) 51 | module = __import__(module_path, fromlist=[config_cls_name]) 52 | scorer_cls = getattr(module, config_cls_name) 53 | assert issubclass(scorer_cls, Scorer), f"Unsupported scorer type: {scorer_cls}" 54 | defaults = inspect.signature(scorer_cls.__init__).parameters 55 | 56 | # Mark the metric that we don't want to calculate 57 | if "outputs" in defaults: 58 | output_columns = defaults["outputs"].default 59 | else: 60 | assert config.model_name, ( 61 | f"Cannot resolve output name for the scorer type {scorer_cls}" 62 | ) 63 | output_columns = scorer_cls.default_outputs(config.model_name) 64 | 65 | if isinstance(output_columns, str): 66 | output_columns = [output_columns] 67 | elif isinstance(output_columns, tuple): 68 | output_columns = list(output_columns) 69 | if metrics_to_report: 70 | for i, metric_name in enumerate(output_columns): 71 | if metric_name not in metrics_to_report: 72 | output_columns[i] = None 73 | 74 | if all(c is None for c in output_columns): 75 | return None 76 | 77 | kwargs = {"model_name": "", "outputs": tuple(output_columns)} 78 | 79 | if config.model_name: 80 | kwargs["model_name"] = 
config.model_name 81 | elif "model_name" in defaults: 82 | kwargs["model_name"] = defaults["model_name"].default 83 | 84 | if config.inputs: 85 | kwargs["inputs"] = config.inputs 86 | elif "inputs" in defaults: 87 | kwargs["inputs"] = defaults["inputs"].default 88 | 89 | _params = config.params or {} 90 | has_kwargs = len({k: v for k, v in defaults.items() if v.kind == v.VAR_KEYWORD}) > 0 91 | for k, v in _params.items(): 92 | if k in defaults: 93 | v = v or defaults[k].default 94 | elif not has_kwargs: 95 | continue 96 | kwargs[k] = v 97 | return scorer_cls(gang=gang, **kwargs) 98 | -------------------------------------------------------------------------------- /lcm/train/step_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | from dataclasses import dataclass 7 | from typing import Literal, Optional 8 | 9 | import torch 10 | import torch.distributions as D 11 | from fairseq2.logging import get_log_writer 12 | from torch import Tensor 13 | 14 | from lcm.nn.schedulers import DDIMScheduler 15 | 16 | SUPPORTED_SAMPLERS = Literal["uniform", "beta"] 17 | SUPPORTED_WEIGHTINGS = Literal["none", "clamp_snr"] 18 | 19 | logger = get_log_writer(__name__) 20 | 21 | 22 | def beta_function(a, b): 23 | result = torch.exp(torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a + b)) 24 | return result 25 | 26 | 27 | @dataclass 28 | class StepsSamplerConfig: 29 | sampling: SUPPORTED_SAMPLERS = "uniform" 30 | weighting: SUPPORTED_WEIGHTINGS = "none" 31 | beta_a: float = 0.8 32 | beta_b: float = 1 33 | max_gamma: float = 5.0 34 | min_gamma: float = 0 35 | 36 | 37 | class StepsSampler(object): 38 | def __init__( 39 | self, 40 | config: StepsSamplerConfig, 41 | noise_scheduler: DDIMScheduler, 42 | ): 43 | num_diffusion_train_steps = noise_scheduler.num_diffusion_train_steps 44 | weights: Optional[Tensor] = None 45 | 46 | if config.sampling == "uniform": 47 | weights = torch.ones( 48 | num_diffusion_train_steps, 49 | ) 50 | 51 | elif config.sampling == "beta": 52 | # As motivated in https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/00328.pdf 53 | a = torch.tensor([config.beta_a]) 54 | b = torch.tensor([config.beta_b]) 55 | # a=1, b=1 -> uniform 56 | # The paper empirically chooses b=1, a=0.8 < 1 57 | 58 | steps = ( 59 | torch.arange(1, num_diffusion_train_steps + 1) 60 | / num_diffusion_train_steps 61 | ) 62 | weights = ( 63 | 1 / beta_function(a, b) * (steps ** (a - 1)) * ((1 - steps) ** (b - 1)) 64 | ) 65 | 66 | assert weights is not None, "The sampling weights were not properly set!" 
67 | logger.info(f"Training with sampling weights={weights}") 68 | 69 | self.distrib = D.Categorical( 70 | probs=weights / weights.sum(), 71 | ) 72 | 73 | # setup weights for scaling: 74 | if config.weighting == "none": 75 | self.gamma_per_step = None 76 | 77 | elif config.weighting == "clamp_snr": 78 | # Min-SNR scheme from 79 | # https://arxiv.org/abs/2303.09556 80 | snrs = noise_scheduler.get_snrs() 81 | # gamma(t) = min(max_gamma, snr(t)) 82 | self.gamma_per_step = torch.clamp( 83 | snrs, max=config.max_gamma, min=config.min_gamma 84 | ) 85 | 86 | logger.info(f"Training with Gamma={self.gamma_per_step}") 87 | 88 | @property 89 | def _training_weights(self) -> Tensor: 90 | return self.distrib.probs 91 | 92 | def sample(self, size: torch.Size, device: torch.device): 93 | samples = self.distrib.sample(size).to(device) 94 | # print('Samples', samples) 95 | # print('Counts:', torch.bincount(samples.flatten())) 96 | return samples 97 | 98 | def get_loss_scales(self, steps): 99 | if self.gamma_per_step is None: 100 | return None 101 | 102 | # If we're using constant Gamma=1 (returning None), then the sum of 103 | # the loss scales is steps.numel(), to match the total mass, 104 | # we normalize the scales to sum to steps.numel() 105 | gamma = self.gamma_per_step.to(steps.device)[steps] 106 | gamma = gamma / gamma.sum() * steps.numel() 107 | return gamma 108 | -------------------------------------------------------------------------------- /lcm/nn/timestep_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | import math 7 | from typing import Optional 8 | 9 | import torch 10 | from fairseq2.nn.projection import Linear 11 | from fairseq2.typing import DataType, Device 12 | from torch import Tensor 13 | from torch.nn import Module 14 | 15 | from lcm.nn.initialization import parse_activation_fn 16 | 17 | 18 | class DiTTimestepEncoder(Module): 19 | """ 20 | Embeds scalar timesteps into vector representations. 21 | Based on DiT's `TimestepEmbedder` 22 | https://github.com/facebookresearch/DiT/blob/main/models.py 23 | """ 24 | 25 | def __init__( 26 | self, 27 | embedding_dim: int, 28 | frequency_embedding_size: int = 256, 29 | activation_fn_name: str = "silu", 30 | device: Optional[Device] = None, 31 | dtype: Optional[DataType] = None, 32 | ): 33 | super().__init__() 34 | 35 | self.dtype = dtype 36 | 37 | self.device = device 38 | 39 | self.embedding_dim = embedding_dim 40 | 41 | self.frequency_embedding_size = frequency_embedding_size 42 | 43 | self.fc1 = Linear( 44 | frequency_embedding_size, 45 | embedding_dim, 46 | bias=True, 47 | device=device, 48 | dtype=dtype, 49 | ) 50 | self.nonlin = parse_activation_fn(activation_fn_name) 51 | self.fc2 = Linear( 52 | embedding_dim, 53 | embedding_dim, 54 | bias=True, 55 | device=device, 56 | dtype=dtype, 57 | ) 58 | 59 | self.reset_parameters() 60 | 61 | def reset_parameters(self) -> None: 62 | """Reset the parameters and buffers of the module.""" 63 | torch.nn.init.normal_(self.fc1.weight, std=0.02) 64 | torch.nn.init.normal_(self.fc2.weight, std=0.02) 65 | 66 | if self.fc1.bias is not None: 67 | torch.nn.init.zeros_(self.fc1.bias) 68 | 69 | if self.fc2.bias is not None: 70 | torch.nn.init.zeros_(self.fc2.bias) 71 | 72 | @staticmethod 73 | def sinusoidal_timestep_embedding( 74 | timestep, frequency_embedding_size, max_period=10000 75 | ): 76 | """ 77 | Create sinusoidal timestep embeddings. 
78 | :param timestep: a 1-D Tensor of N indices, one per batch element. 79 | These may be fractional. 80 | :param frequency_embedding_size: the dimension of the output. 81 | :param max_period: controls the minimum frequency of the embeddings. 82 | :return: an (N, D) Tensor of positional embeddings. 83 | 84 | Based on DiT's `TimestepEmbedder` 85 | https://github.com/facebookresearch/DiT/blob/main/models.py 86 | """ 87 | half = frequency_embedding_size // 2 88 | 89 | freqs = torch.exp( 90 | -math.log(max_period) 91 | * torch.arange(start=0, end=half, dtype=torch.float32) 92 | / half 93 | ).to(device=timestep.device) 94 | 95 | args = timestep[:, None].float() * freqs[None] 96 | 97 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 98 | 99 | if frequency_embedding_size % 2: 100 | embedding = torch.cat( 101 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 102 | ) 103 | 104 | return embedding 105 | 106 | def forward(self, timesteps: Tensor) -> Tensor: 107 | initial_size = timesteps.size() 108 | 109 | flat_timesteps = timesteps.view(-1, 1) 110 | 111 | t_freq = self.sinusoidal_timestep_embedding( 112 | flat_timesteps, self.frequency_embedding_size 113 | ).to(self.dtype) 114 | 115 | t_emb = self.fc1(t_freq) 116 | 117 | if self.nonlin is not None: 118 | t_emb = self.nonlin(t_emb) 119 | 120 | t_emb = self.fc2(t_emb) 121 | 122 | return t_emb.view(*initial_size, self.embedding_dim) 123 | -------------------------------------------------------------------------------- /lcm/evaluation/predictors/dummy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | from dataclasses import dataclass 7 | from typing import Iterable, List, Optional, Sequence 8 | 9 | import torch 10 | 11 | from ..api import Prediction, Predictor, PredictorConfig, Prompts 12 | 13 | 14 | @dataclass 15 | class DummyPredictorConfig(PredictorConfig): 16 | @classmethod 17 | def predictor_class(cls): 18 | return DummyPredictor 19 | 20 | 21 | class DummyPredictor: 22 | def __init__(self, config: DummyPredictorConfig, **kwargs): 23 | self.config = config 24 | if "eos_config" in kwargs: 25 | self.eos = "" 26 | else: 27 | self.eos = " " 28 | 29 | @staticmethod 30 | def from_config(config: DummyPredictorConfig, **kwargs) -> "DummyPredictor": # type: ignore 31 | return DummyPredictor(config, **kwargs) 32 | 33 | def __call__( # type: ignore 34 | self, 35 | prompts: Prompts, 36 | max_prompt_len: Optional[int] = None, 37 | max_gen_len: Optional[int] = None, 38 | temperature: float = 0.0, 39 | disable_cache: bool = False, 40 | greedy: bool = True, 41 | top_p: float = 0.0, 42 | top_k: int = 0, 43 | echo: bool = True, 44 | return_logprobs: bool = False, 45 | show_progress: bool = False, 46 | num_generations: int = 1, 47 | **kwargs, 48 | ) -> Sequence[Prediction]: 49 | if return_logprobs: 50 | raise NotImplementedError("The Dummy predictor does not support logprobs") 51 | if greedy and num_generations > 1: 52 | raise ValueError("Greedy generation only works with beam size 1") 53 | # prompts should be a NestedTensor by now. 
Here we return a fake 54 | # prediction that shares the same embeddings and input, with empty text 55 | preds: List[Prediction] = [] 56 | assert isinstance(prompts, Iterable) 57 | for seq in prompts: # 58 | preds.extend( 59 | [Prediction(text=self.eos, tokens=[0], embed=seq)] * num_generations # type: ignore 60 | ) 61 | return preds 62 | 63 | 64 | @dataclass 65 | class DummyJudgeConfig(PredictorConfig): 66 | @classmethod 67 | def predictor_class(cls): 68 | return DummyJudge 69 | 70 | 71 | class DummyJudge(Predictor): 72 | """ 73 | Example judge. 74 | 75 | A judge is any model that predicts from the outcome of an evaluation task. In this example, 76 | we just average out the prediction from the previous (dummy) predictor 77 | """ 78 | 79 | def __init__(self, config: DummyPredictorConfig, **kwargs): 80 | self.config = config 81 | 82 | def __call__( 83 | self, 84 | prompts: Prompts, 85 | max_prompt_len: Optional[int] = None, 86 | max_gen_len: Optional[int] = None, 87 | temperature: float = 0.0, 88 | top_p: float = 0.0, 89 | top_k: int = 0, 90 | echo: bool = True, 91 | return_logprobs: bool = False, 92 | show_progress: bool = False, 93 | disable_cache: bool = False, 94 | num_generations: int = 1, 95 | **kwargs, 96 | ) -> Sequence[Prediction]: 97 | assert isinstance(prompts, Iterable) 98 | preds: List[Prediction] = [] 99 | for seq in prompts: 100 | # The outcome of the dummy task is {"metrics": list, "prediction_embed": list, "prompts": list} 101 | # The judge will calculate the L2 distance between the means of the prediction and inputs previous 102 | # task 103 | assert "prompts" in seq and "prediction_embed" in seq 104 | preds_mean = torch.mean(torch.stack(seq["prediction_embed"])) # type: ignore 105 | preds.append(Prediction(text="judge", embed=preds_mean, tokens=[1])) 106 | return preds 107 | 108 | @staticmethod 109 | def from_config(config: DummyJudgeConfig, **kwargs) -> "DummyJudge": # type: ignore 110 | return DummyJudge(config) # type: ignore 111 | -------------------------------------------------------------------------------- /lcm/datasets/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | import logging 7 | from abc import ABC, abstractmethod 8 | from typing import Callable, Dict, Generic, Iterator, Optional, Sequence, TypeVar, Union 9 | 10 | import torch 11 | from fairseq2.data.data_pipeline import DataPipeline 12 | from fairseq2.gang import FakeGang, Gang 13 | from fairseq2.typing import DataType 14 | 15 | from lcm.datasets.configs import ( 16 | DataLoadingConfig, 17 | DatasetConfigT, 18 | create_dataset_config_from_cards, 19 | ) 20 | from lcm.datasets.dataloading import ( 21 | build_weighted_pipeline_with_renaming as default_build_fn, 22 | ) 23 | from lcm.utils.common import Batched, set_mkl_num_threads 24 | 25 | BatchT_co = TypeVar("BatchT_co", bound=Union[Dict, Batched], covariant=True) 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class DataLoader(ABC, Generic[BatchT_co, DatasetConfigT]): 30 | def __init__( 31 | self, 32 | data_config: DataLoadingConfig, 33 | datasets: Sequence[DatasetConfigT], 34 | gang: Gang, 35 | builder_func: Callable[..., DataPipeline] = default_build_fn, 36 | dtype: DataType = torch.float16, 37 | ): 38 | self.data_config = data_config 39 | self.datasets = list(map(create_dataset_config_from_cards, datasets)) 40 | self.dtype = dtype 41 | self.gang = gang 42 | self.builder_func = builder_func 43 | 44 | self._pipeline: Optional[DataPipeline] = None 45 | 46 | @property 47 | def pipeline(self) -> DataPipeline: 48 | if self._pipeline is None: 49 | logger.info(f"R{self.gang.rank} self._pipeline is None, building...") 50 | gang_rank = self.gang.rank if self.gang else 0 51 | world_size = self.gang.size if self.gang else 1 52 | 53 | self._pipeline = self.builder_func( 54 | self.datasets, self.data_config, gang_rank, world_size 55 | ) 56 | assert self._pipeline, ( 57 | f"Cannot build data pipeline from config {self.data_config}" 58 | ) 59 | return self._pipeline 60 | 61 | def destroy(self) -> None: 62 | """Destroy the pipeline to rebuild it with different shuffling""" 63 | self._pipeline = None 64 | # Build again and reset it 65 | logger.info(f"R{self.gang.rank} resetting the pipeline in DataLoader.destroy") 66 | self.reset() 67 | 68 | def reset(self) -> None: 69 | """ 70 | Applying reset will result in different shuffling for the next iterations, 71 | since the pipeline will use the modified generator state from the previous one. 72 | This is a suitable side effect for the `sharding_in_memory=False` (training) scenario. 73 | 74 | Illustrative example: 75 | >>> import torch 76 | >>> from fairseq2.data import read_sequence 77 | 78 | >>> def get_one_epoch_pipeline(): 79 | ... torch.manual_seed(13) 80 | ... return read_sequence(list(range(10))).shuffle(5) 81 | 82 | >>> bb = get_one_epoch_pipeline().and_return() 83 | >>> list(bb) 84 | [3, 1, 2, 4, 0, 8, 5, 6, 9, 7] 85 | >>> bb.reset() 86 | >>> list(bb) 87 | [4, 0, 3, 2, 1, 9, 7, 6, 8, 5] 88 | """ 89 | self.pipeline.reset() 90 | 91 | @abstractmethod 92 | def iterate_batches(self) -> Iterator[BatchT_co]: ...
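# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the repository): iterating over batches with
# the concrete BaseDataLoader defined just below. The parquet path and column
# names are placeholders; the config fields mirror the ones used in
# tests/units/conftest.py above.
# ---------------------------------------------------------------------------
from lcm.datasets.base import BaseDataLoader
from lcm.datasets.configs import (
    DataLoadingConfig,
    ParquetBatchFormat,
    ParquetDatasetConfig,
)

loading_config = DataLoadingConfig(
    batch_size=10,
    nb_epochs=1,
    output_format=ParquetBatchFormat.pyarrow,
)
dataset_config = ParquetDatasetConfig(
    parquet_path="/path/to/parquet/dataset",  # placeholder
    source_column="seq",                      # placeholder column names
    source_text_column="text",
)

loader = BaseDataLoader(data_config=loading_config, datasets=[dataset_config])
for batch in loader.iterate_batches():
    ...  # consume pyarrow record batches
# ------------------------------ end of sketch -------------------------------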
93 | 94 | 95 | class BaseDataLoader(DataLoader[dict, DatasetConfigT]): 96 | def __init__( 97 | self, 98 | data_config: DataLoadingConfig, 99 | datasets: Sequence[DatasetConfigT], 100 | dtype: DataType = torch.float16, 101 | gang: Optional[Gang] = None, 102 | ) -> None: 103 | gang = gang or FakeGang() 104 | super().__init__( 105 | data_config=data_config, 106 | datasets=datasets, 107 | builder_func=default_build_fn, 108 | dtype=dtype, 109 | gang=gang, 110 | ) 111 | set_mkl_num_threads() 112 | 113 | def iterate_batches(self) -> Iterator[dict]: 114 | yield from iter(self.pipeline) 115 | -------------------------------------------------------------------------------- /tests/units/evaluation/test_similarity.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | import math 7 | import unittest 8 | from typing import Any, Callable, Dict, List, Tuple 9 | 10 | import numpy as np 11 | import torch 12 | 13 | from lcm.evaluation.metrics import ( 14 | edit_distance, 15 | longest_common_substring, 16 | memorization_score, 17 | nltk_sentence_bleu, 18 | ) 19 | from lcm.evaluation.metrics.similarity import mse_constrative_accuracy 20 | 21 | 22 | class WhitespaceTokenizer: 23 | def __call__(self, s: str) -> List[str]: 24 | return s.split() 25 | 26 | 27 | class StubTokenizer: 28 | def __init__(self) -> None: 29 | self.tok_to_id: Dict[str, int] = {} 30 | 31 | def __call__(self, s: str) -> List[int]: 32 | toks = s.split() 33 | tokenized = [] 34 | for tok in toks: 35 | if tok not in self.tok_to_id: 36 | self.tok_to_id[tok] = len(self.tok_to_id) 37 | tokenized.append(self.tok_to_id[tok]) 38 | return tokenized 39 | 40 | 41 | def get_examples() -> List[Tuple[str, str]]: 42 | return [ 43 | ( 44 | "John Doe and he lives in the United Kingdom .", 45 | "Jane Doe and she lives in the United States .", 46 | ), 47 | ( 48 | "the ratio of the radius of a circle to its", 49 | "a famous decimal that never enters a repeating pattern .", 50 | ), 51 | ( 52 | "Billy Bob . They are on trial for tax fraud", 53 | "Billy Bob . Are they really on trial for tax", 54 | ), 55 | ( 56 | "Billy Bob . They are on trial for tax fraud", 57 | "Billy Bob . 
They are on trial for tax fraud", 58 | ), 59 | ] 60 | 61 | 62 | class TestSimilarity(unittest.TestCase): 63 | def run_scenarios( 64 | self, 65 | score_fn: Callable[..., Any], 66 | expected: List[float], 67 | ) -> None: 68 | for i, (q, g) in enumerate(get_examples()): 69 | # Ensure it works with str tokens and token ids (List[int]) 70 | for tok_cls in [WhitespaceTokenizer, StubTokenizer]: 71 | with self.subTest(tokenizer=tok_cls.__name__, example_id=i): 72 | tok = tok_cls() 73 | self.assertAlmostEqual( 74 | expected[i], 75 | score_fn(tok(q), tok(g)), # type: ignore 76 | places=4, # type: ignore 77 | ) 78 | 79 | def test_bleu(self) -> None: 80 | self.run_scenarios( 81 | score_fn=nltk_sentence_bleu, 82 | expected=[0.3247, 0.0211, 0.3799, 1.0], 83 | ) 84 | 85 | def test_edit_distance(self) -> None: 86 | self.run_scenarios( 87 | score_fn=edit_distance, 88 | expected=[3.0, 9.0, 4.0, 0.0], 89 | ) 90 | 91 | def test_memorization_score(self) -> None: 92 | self.run_scenarios( 93 | score_fn=memorization_score, 94 | expected=[0.7, 0.1, 0.3, 1.0], 95 | ) 96 | 97 | def test_longest_common_substring(self) -> None: 98 | self.run_scenarios( 99 | score_fn=longest_common_substring, 100 | expected=[4.0, 1.0, 4.0, 10.0], 101 | ) 102 | 103 | def test_mse_constrative_accuracy(self) -> None: 104 | q, g = get_examples()[-1] 105 | expected = 0.9 106 | tok = StubTokenizer() 107 | tokq = tok(q) 108 | tokg = tok(g) 109 | 110 | embed = torch.nn.Embedding(len(tok.tok_to_id), 5).requires_grad_(False) 111 | 112 | # Repeat the text to run the constrative against the first sentence as distractors 113 | embedq = embed(torch.tensor(tokq)).repeat(2, 1, 1) 114 | embedg = embed(torch.tensor(tokg)).repeat(2, 1, 1) 115 | 116 | acc = mse_constrative_accuracy(embedq, embedg) 117 | assert len(acc) == 20 # batch_size x seq_len 118 | 119 | mean_acc = np.asarray(acc).reshape(2, 10).mean(axis=1)[0] 120 | 121 | assert math.isclose(expected, mean_acc, abs_tol=0.01) 122 | 123 | 124 | if __name__ == "__main__": 125 | unittest.main() 126 | -------------------------------------------------------------------------------- /lcm/utils/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | import os 7 | import subprocess 8 | from pathlib import Path 9 | from typing import Dict 10 | 11 | import torch.distributed as dist 12 | from fairseq2.gang import get_rank 13 | from fairseq2.logging import get_log_writer 14 | from fairseq2.recipes.logging import _setup_aten_logging, _setup_nccl_logging 15 | from fairseq2.recipes.utils.log import log_environment_info 16 | from fairseq2.typing import Device 17 | 18 | logger = get_log_writer(__name__) 19 | 20 | LCM_REPOS = ["lcm", "fairseq2", "sonar", "stopes"] 21 | 22 | 23 | def setup_additional_logging(log_folder: Path): 24 | slurm_job_id: str = os.environ.get("SLURM_JOB_ID", "local") 25 | base_log_file = log_folder / f"{slurm_job_id}_{get_rank()}.log" 26 | _setup_aten_logging(base_log_file, force=False) 27 | _setup_nccl_logging(base_log_file, force=False) 28 | 29 | 30 | def log_git_status( 31 | repo: str = "lcm", 32 | tolerate_uncommitted: bool = False, 33 | ) -> str: 34 | assert repo in LCM_REPOS, ( 35 | f"Only the LCM core repos ({LCM_REPOS}) are supported in `log_git_status`" 36 | ) 37 | 38 | repo_path = os.path.dirname(globals()[repo].__file__) 39 | 40 | try: 41 | # check for modifications 42 | mod_output = subprocess.run( 43 | f"cd {repo_path}; git status --porcelain", capture_output=True, shell=True 44 | ) 45 | modifications = mod_output.stdout.decode("utf-8").split("\n") 46 | uncommitted = len( 47 | [ 48 | m 49 | for m in modifications 50 | if m.startswith(" M") or m.startswith(("M ", "A ", "D ", "R ")) 51 | ] 52 | ) 53 | if uncommitted > 0: 54 | if tolerate_uncommitted: 55 | logger.warning( 56 | ( 57 | "Changes to {} should be committed before running a job " 58 | "- found {} change(s)." 59 | " We will continue regardless, but the git commit hashes are unreliable!" 60 | ).format(repo, uncommitted) 61 | ) 62 | else: 63 | raise AssertionError( 64 | f"Changes to {repo} should be committed before running a job - found {uncommitted} change(s). 
If runing tests try adding `--debug-training`" 65 | ) 66 | 67 | # get commit hash 68 | output = subprocess.run( 69 | f"cd {repo_path}; git rev-parse HEAD", capture_output=True, shell=True 70 | ) 71 | commit_hash = output.stdout.decode("ascii").strip() 72 | logger.info(f"{repo} ({repo_path}) commit hash: {commit_hash}") 73 | 74 | return commit_hash 75 | 76 | except AssertionError: 77 | raise 78 | 79 | except BaseException: 80 | raise ValueError( 81 | f"Could not check the git revision hash, make sure you can run `git status` in {repo} ({repo_path})" 82 | ) 83 | 84 | 85 | def log_lcm_environment(tolerate_uncommitted: bool = False) -> Dict: 86 | """ 87 | For traceability and reproducibility, get the latest commit hash for the four key repos 88 | """ 89 | 90 | commit_hashes = { 91 | repo: log_git_status(repo, tolerate_uncommitted) for repo in LCM_REPOS 92 | } 93 | 94 | return commit_hashes 95 | 96 | 97 | def log_env_variables(device: Device) -> None: 98 | """Log environment variables useful for debugging, including 99 | fs2's `log_environment_info` to dump Fairseq2, torch, nccl and other relevant metadata 100 | """ 101 | for key in sorted(os.environ.keys()): 102 | if not ( 103 | key.startswith( 104 | ("SLURM_", "SUBMITIT_", "NCCL_", "FI_", "CUDA_", "FAIRSEQ2_", "TORCH_") 105 | ) 106 | or key 107 | in ( 108 | "MASTER_ADDR", 109 | "MASTER_PORT", 110 | "RANK", 111 | "WORLD_SIZE", 112 | "LOCAL_RANK", 113 | "LOCAL_WORLD_SIZE", 114 | ) 115 | ): 116 | continue 117 | value = os.environ[key] 118 | logger.info(f"R{dist.get_rank()} -- {key}={value}") 119 | 120 | # For Fairseq2, torch and devices 121 | log_environment_info(logger, device) 122 | -------------------------------------------------------------------------------- /lcm/train/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | 7 | import asyncio 8 | from dataclasses import dataclass 9 | from pathlib import Path 10 | from typing import Any, Optional 11 | 12 | import hydra 13 | import submitit 14 | from omegaconf import DictConfig, OmegaConf 15 | from omegaconf.omegaconf import open_dict, read_write 16 | from stopes.core import Requirements, StopesModule 17 | 18 | from lcm.train.common import get_trainer 19 | from lcm.utils.common import setup_conf 20 | 21 | setup_conf() 22 | 23 | 24 | class TrainModule(StopesModule): 25 | def requirements(self) -> Requirements: 26 | return self.config.requirements 27 | 28 | def run(self, iteration_value: Optional[Any] = None, iteration_index: int = 0): 29 | # Add module.name to the config's log_folder 30 | with read_write(self.config): 31 | self.config.log_folder = Path(self.config.log_folder) / self.name() 32 | 33 | trainer = get_trainer(self.config) 34 | 35 | # trainer should have a run() method 36 | trainer.run() 37 | 38 | def should_retry( 39 | self, 40 | ex: Exception, 41 | attempt: int, 42 | iteration_value: Optional[Any] = None, 43 | iteration_index: int = 0, 44 | ) -> bool: 45 | # Before retrying the failed train run, clean the environment to make sure 46 | # fs2 ProcessGroupGang can set up properly without raising error if the 47 | # gang is not set up reliably 48 | with submitit.helpers.clean_env(): 49 | return "ValueError" not in str(ex) 50 | 51 | def name(self): 52 | """ 53 | implement this if you want to give a fancy name to your job 54 | """ 55 | name = self.config.get( 56 | "experiment_name", f"{self.__class__.__name__}_{self.sha_key()[:10]}" 57 | ) 58 | return name 59 | 60 | 61 | @dataclass 62 | class TrainingConfig: 63 | trainer: DictConfig 64 | launcher: DictConfig 65 | dry_run: bool = False 66 | 67 | 68 | async def run(config: TrainingConfig): 69 | # dump the all config to the outputs config log 70 | dump_dir = Path(config.launcher.config_dump_dir) 71 | dump_dir.mkdir(parents=True, exist_ok=True) 72 | OmegaConf.resolve(config) # type: ignore 73 | # XXX: do we want to promote datasets configs from thier names to the final params 74 | OmegaConf.save( 75 | config=config, 76 | f=str(dump_dir / "all_config.yaml"), 77 | ) 78 | 79 | train_config = config.trainer 80 | 81 | # If launcher.cluster = debug set debug in the trainer to True 82 | with open_dict(train_config): 83 | if config.launcher.cluster == "debug": 84 | train_config.debug = True 85 | train_config.log_folder = config.launcher.log_folder 86 | 87 | if getattr(config, "dry_run", False): 88 | trainer = get_trainer(train_config) 89 | print(f"Trainer: {trainer}") 90 | print(f"Train config: {getattr(trainer, 'config')}") 91 | 92 | return 93 | 94 | launcher = hydra.utils.instantiate(config.launcher) 95 | 96 | train_module = TrainModule(train_config) 97 | wait_on = launcher.schedule(train_module) 98 | 99 | await wait_on 100 | 101 | 102 | @hydra.main( 103 | version_base="1.2", 104 | config_path="../../recipes/train", 105 | config_name="defaults.yaml", 106 | ) 107 | def main(config: TrainingConfig) -> None: 108 | """ 109 | Launch a train module from CLI. 110 | 111 | Example: 112 | 113 | ```sh 114 | python -m lcm.train +pretrain=mse 115 | ``` 116 | 117 | in this example, `pretrain` is a folder under the `recipes` directory and `mse` 118 | is a yaml file with the trainer configuration. 119 | This yaml file must be in the `trainer` package (i.e. start with the `# @package trainer` 120 | hydra directive). 121 | It must contain a `__trainer__` entry defining the constructor for the trainer. 
122 | 123 | You can use `-c job` to see the configuration without running anything. You can use 124 | `dry_run=true` to initialize the trainer from the configuration and make sure it's correct 125 | without running the actual training. To debug the jobs, you can use `launcher.cluster=debug` 126 | """ 127 | asyncio.run(run(config)) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /lcm/models/two_tower_diffusion_lcm/archs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | from lcm.models.two_tower_diffusion_lcm.builder import ( 7 | DenoiserConfig, 8 | EncoderFrontendConfig, 9 | TransformerConfig, 10 | TwoTowerDiffusionLCModelConfig, 11 | lcm_arch, 12 | ) 13 | from lcm.nn.projection import ProjectionConfig 14 | from lcm.nn.schedulers import DDIMSchedulerConfig 15 | 16 | 17 | @lcm_arch("toy_two_tower_diffusion_lcm") 18 | def toy_lcm() -> TwoTowerDiffusionLCModelConfig: 19 | return TwoTowerDiffusionLCModelConfig( 20 | context_encoder=TransformerConfig(num_layers=2), 21 | denoiser=DenoiserConfig(num_layers=2), 22 | # TODO change normalizer name to align with the normalizer instructions 23 | sonar_normalizer_name="dummy_sonar_normalizer", 24 | ) 25 | 26 | 27 | @lcm_arch("two_tower_diffusion_lcm_1_6B") 28 | def two_tower_diffusion_lcm_1_6B() -> TwoTowerDiffusionLCModelConfig: 29 | """5-layer encodder / 13-layer denoiser / model dim 2048 30 | Parameter Size: 1,635,101,696""" 31 | model_dim: int = 2048 32 | num_attn_heads: int = 16 33 | return TwoTowerDiffusionLCModelConfig( 34 | model_dim=model_dim, 35 | max_seq_len=4096, 36 | frontend=EncoderFrontendConfig(), 37 | context_encoder=TransformerConfig( 38 | num_layers=5, 39 | ffn_inner_dim=4 * model_dim, 40 | num_attn_heads=num_attn_heads, 41 | final_dropout_p=0.0, 42 | attention_dropout_p=0.0, 43 | dropout_p=0.1, 44 | mha_output_proj_bias=True, 45 | use_swiglu=True, 46 | layer_normalization_style="rms", 47 | pos_embedding_style="rope", 48 | ), 49 | denoiser=DenoiserConfig( 50 | num_layers=13, 51 | timestep_embed_dim=model_dim, 52 | ffn_inner_dim=4 * model_dim, 53 | pos_embedding_style="none", 54 | num_attn_heads=num_attn_heads, 55 | final_dropout_p=0.0, 56 | attention_dropout_p=0.0, 57 | dropout_p=0.1, 58 | mha_output_proj_bias=True, 59 | use_swiglu=True, 60 | layer_normalization_style="rms", 61 | pre_denoiser=ProjectionConfig(), 62 | post_denoiser=ProjectionConfig(), 63 | ), 64 | # TODO change normalizer name to align with the normalizer instructions 65 | sonar_normalizer_name="dummy_sonar_normalizer", 66 | trained_with_cf_guidance=True, 67 | noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100), 68 | ) 69 | 70 | 71 | @lcm_arch("two_tower_diffusion_lcm_7B") 72 | def two_tower_diffusion_lcm_7B() -> TwoTowerDiffusionLCModelConfig: 73 | # 5-layer encodder / 14-layer denoiser / model dim 4096 74 | # Parameter Size: 6,930,781,696 75 | model_dim: int = 4096 76 | num_attn_heads: int = 32 77 | return TwoTowerDiffusionLCModelConfig( 78 | model_dim=model_dim, 79 | max_seq_len=4096, 80 | frontend=EncoderFrontendConfig(), 81 | context_encoder=TransformerConfig( 82 | num_layers=5, 83 | ffn_inner_dim=4 * model_dim, 84 | num_attn_heads=num_attn_heads, 85 | final_dropout_p=0.0, 86 | attention_dropout_p=0.0, 87 | dropout_p=0.1, 88 | mha_output_proj_bias=True, 89 | use_swiglu=True, 90 | layer_normalization_style="rms", 91 | 
pos_embedding_style="rope", 92 | ), 93 | denoiser=DenoiserConfig( 94 | num_layers=14, 95 | timestep_embed_dim=model_dim, 96 | ffn_inner_dim=4 * model_dim, 97 | pos_embedding_style="none", 98 | num_attn_heads=num_attn_heads, 99 | final_dropout_p=0.0, 100 | attention_dropout_p=0.0, 101 | dropout_p=0.1, 102 | mha_output_proj_bias=True, 103 | use_swiglu=True, 104 | layer_normalization_style="rms", 105 | pre_denoiser=ProjectionConfig(), 106 | post_denoiser=ProjectionConfig(), 107 | ), 108 | # TODO change normalizer name to align with the normalizer instructions 109 | sonar_normalizer_name="dummy_sonar_normalizer", 110 | trained_with_cf_guidance=True, 111 | noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100), 112 | ) 113 | -------------------------------------------------------------------------------- /scripts/prepare_wikipedia.py: -------------------------------------------------------------------------------- 1 | # # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | 5 | 6 | import asyncio 7 | from pathlib import Path 8 | 9 | import fire 10 | from stopes.core.launcher import Launcher 11 | from stopes.core.stopes_module import Requirements 12 | from stopes.modules.partitioned_data_mapper import stopes_data_mapper 13 | from stopes.modules.preprocess.sonar_text_embedding import ( 14 | LangColumnConfig, 15 | SonarTextEmbedderConfig, 16 | ) 17 | from stopes.utils.sharding.abstract_shards import BatchFormat 18 | from stopes.utils.sharding.hf_shards import HFInputConfig 19 | from stopes.utils.sharding.parquet_shards import ( 20 | ParquetOutputConfig, 21 | ) 22 | 23 | from lcm.datasets.sentence_splitter_pipeline import ( 24 | FullPipeline, 25 | FullPipelineConfig, 26 | SentenceSplitterConfig, 27 | ) 28 | 29 | 30 | def run(output_dir: Path): 31 | """ 32 | launch a preprocessing pipeline, this will use SAT to split text in sentences and then use SONAR to 33 | embed each sentence. 34 | This example downloads data from huggingface and outputs it to a parquet dataset. 35 | 36 | `output_dir` is the directory where the processed data will be written. The output will be in a parquet file format. 37 | """ 38 | # setup the sentence splitter 39 | splitter_config = SentenceSplitterConfig( 40 | columns=[ 41 | "text" 42 | ], # this is the column in the input dataset where we expect to find text to split 43 | model_name="sat-3l", 44 | verbose=True, 45 | sentence_threshold=0.2, # sentence splitting threshold to tune based on the data (domain, language, etc.) 46 | max_sentence_len=256, 47 | ) 48 | # setup SONAR, we are only going to deal with english 49 | sonar_encoder_config = SonarTextEmbedderConfig( 50 | column_config=[ # we can process several columns at once which is useful for finetuning datasets 51 | LangColumnConfig("text_sentences", lang_value="eng_Latn") 52 | ], # splitter has output a new column `text_sentences` and this is what we will embed 53 | device="cuda", # we want to work on a GPU, if you want to try this on a cpu, change the device here 54 | ) 55 | # setup the full pipeline, that will use the splitter and the sonar embeddings, 56 | full_config = FullPipelineConfig( 57 | splitter_config=splitter_config, 58 | sonar_encoder_config=sonar_encoder_config, 59 | ) 60 | 61 | # setup the input to download from huggingface, adjust this to the dataset you care about 62 | # Checkout https://github.com/facebookresearch/stopes/tree/main/stopes/utils/sharding for other potential 63 | # input systems (jsonl, parquet) and how to configure them in this pipeline. 
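    # Note: `split="train[0:200]"` below uses the Hugging Face datasets slicing syntax to
    # take only the first 200 rows; for a full run you would typically use split="train"
    # and increase `num_shards` so the shards can be processed in parallel by the launcher.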
64 | input_config = HFInputConfig( 65 | input_file="wikimedia/wikipedia", 66 | data_dir="20231101.en", 67 | split="train[0:200]", # we are only taking a small sample for the toy example 68 | num_shards=1, # as we have a small sample, we don't need many shards, you should increase this for larger datasets 69 | batch_format=BatchFormat.ARROW, 70 | batch_size=5, # adjust to your system's size 71 | ) 72 | # setup the output to write to parquet 73 | output_config = ParquetOutputConfig( 74 | output_dir, 75 | keep_same_partitioning=False, 76 | row_group_size=200, 77 | batch_size=200, 78 | ) 79 | 80 | # requirements for our slurm jobs, if you are using a local cpu, you can ignore this 81 | # if you are using slurm but no gpus, remove the gpus_per_node config 82 | req = Requirements( 83 | mem_gb=120, gpus_per_node=1, cpus_per_task=10, timeout_min=3 * 24 * 60 84 | ) 85 | # launching config, here we use `local` to run locally, but you can switch it to `slurm` if you have a SLURM cluster. 86 | launcher = Launcher( 87 | cache=None, 88 | cluster="local", 89 | # for SLURM you can set some parameters of the launcher here 90 | # cluster="slurm", 91 | # update_parameters={ 92 | # "partition": "learn", 93 | # }, 94 | ) 95 | 96 | # launch the shards processing 97 | stopes_wrapped = stopes_data_mapper(req, {"name": "prep_wiki"})(FullPipeline) 98 | stopes_module = stopes_wrapped(input_config, output_config, full_config) 99 | 100 | asyncio.run(launcher.schedule(stopes_module)) 101 | 102 | 103 | if __name__ == "__main__": 104 | fire.Fire(run) 105 | -------------------------------------------------------------------------------- /lcm/evaluation/metrics/similarity.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # 4 | # 5 | 6 | from typing import List, Union 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def l2_distance( 13 | prediction: torch.Tensor, targets: torch.Tensor, flatten: bool = False 14 | ) -> Union[List[float], List[List[float]]]: 15 | l2_dist = (prediction - targets).pow(2).sum(dim=-1).sqrt() 16 | l2_dist = l2_dist.squeeze() 17 | if flatten: 18 | l2_dist = torch.flatten(l2_dist) 19 | x = l2_dist.cpu().tolist() 20 | if isinstance(x, float): # l2_dist is a torch scalar 21 | return [x] 22 | return x 23 | 24 | 25 | def mse_constrative_accuracy( 26 | prediction: torch.Tensor, 27 | targets: torch.Tensor, 28 | ) -> List[float]: 29 | """ 30 | Calculate the mse_loss between the predictions and groundtruth, each 31 | has the shape batch_size x seq_len x model_dim. 
32 | 33 | Returns: 34 | a list of length batch_size x seq_len and scores are between 0 and 35 | 1 (with 0.5 corresponding to a random model) 36 | """ 37 | assert prediction.size() == targets.size() 38 | batch_size, seq_len, model_dim = prediction.size() 39 | preds_flat = prediction.reshape(batch_size * seq_len, model_dim) 40 | gt_flat = targets.reshape(batch_size * seq_len, model_dim) 41 | pos_dist = torch.nn.functional.mse_loss(preds_flat, gt_flat, reduction="none").sum( 42 | dim=-1 43 | ) 44 | distractors_flat = torch.stack( 45 | [ 46 | targets[torch.arange(batch_size) != batch_id, :].reshape( 47 | (batch_size - 1) * (seq_len), model_dim 48 | ) 49 | for batch_id in range(batch_size) 50 | for j in range(seq_len) 51 | ], 52 | dim=0, 53 | ) 54 | n_distractors = distractors_flat.shape[1] 55 | 56 | neg_dist = torch.nn.functional.mse_loss( 57 | preds_flat.unsqueeze(1).repeat(1, n_distractors, 1), 58 | distractors_flat, 59 | reduction="none", 60 | ).sum(dim=-1) 61 | ptw_acc = (neg_dist > pos_dist.unsqueeze(-1)).to(torch.float).mean(-1) 62 | return ptw_acc.cpu().tolist() 63 | 64 | 65 | def nltk_sentence_bleu(prediction_tokens: List[int], target_tokens: List[int]) -> float: 66 | try: 67 | from nltk.translate.bleu_score import ( # type: ignore 68 | SmoothingFunction, 69 | sentence_bleu, 70 | ) 71 | except (ImportError, ModuleNotFoundError): 72 | return -1.0 73 | 74 | return float( 75 | sentence_bleu( 76 | [target_tokens], 77 | prediction_tokens, 78 | smoothing_function=SmoothingFunction().method1, 79 | ) 80 | ) 81 | 82 | 83 | def edit_distance(prediction_tokens: List[int], target_tokens: List[int]) -> float: 84 | # Get minimum edit distance between prediction and targets in the case of multiple targets 85 | try: 86 | import editdistance 87 | except (ImportError, ModuleNotFoundError): 88 | return -1.0 89 | 90 | return float(editdistance.eval(prediction_tokens, target_tokens)) 91 | 92 | 93 | def longest_common_substring( 94 | prediction_tokens: List[int], target_tokens: List[int] 95 | ) -> float: 96 | lengths = np.zeros((len(prediction_tokens), len(target_tokens)), dtype=int).tolist() 97 | longest = 0 98 | 99 | for i in range(len(prediction_tokens)): 100 | for j in range(len(target_tokens)): 101 | if prediction_tokens[i] != target_tokens[j]: 102 | continue 103 | elif i == 0 or j == 0: 104 | lengths[i][j] = 1 105 | else: 106 | lengths[i][j] = lengths[i - 1][j - 1] + 1 107 | 108 | longest = max(longest, lengths[i][j]) 109 | 110 | return float(longest) 111 | 112 | 113 | def memorization_score(prediction_tokens: List[int], target_tokens: List[int]) -> float: 114 | # See "Emergent and Predictable Memorization in Large Language Models" 115 | # https://arxiv.org/pdf/2304.11158.pdf 116 | correct = sum( 117 | pred == target for pred, target in zip(prediction_tokens, target_tokens) 118 | ) 119 | correct_avg = correct / len(target_tokens) 120 | 121 | return float(correct_avg) 122 | 123 | 124 | def cos_sim(prediction: np.ndarray, targets: np.ndarray): 125 | pred_norm = np.linalg.norm(prediction, axis=-1, keepdims=True) 126 | targets_norm = np.linalg.norm(targets, axis=-1, keepdims=True) 127 | 128 | normalized_preds = prediction / pred_norm 129 | normalized_targets = targets / targets_norm 130 | 131 | return np.einsum("ij,ij->i", normalized_preds, normalized_targets) 132 | -------------------------------------------------------------------------------- /lcm/nn/initialization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | import math 7 | from functools import partial 8 | from typing import Literal, Optional 9 | 10 | import torch 11 | from fairseq2.nn.projection import Linear 12 | from fairseq2.nn.transformer import TransformerNormOrder 13 | from torch.nn import Module 14 | 15 | SUPPORTED_INIT_TYPES = Literal[ 16 | "xavier", 17 | "sonar", 18 | "zero", 19 | "trunc_normal", 20 | "kaiming_uniform", 21 | "none", 22 | ] 23 | 24 | 25 | SONAR_STD = 0.006 26 | # Most SONAR embeddings have a distribution with the mean close to 0 and std close to 0.006 27 | # Initializing embedding-like parameters (e.g. end-of-text vector) from a similar distribution is recommended, 28 | # to minimize their disruption of the model training 29 | 30 | 31 | def get_init_fn(style: str = "xavier", sonar_std: float = SONAR_STD): 32 | if style == "xavier": 33 | return init_linear_xavier 34 | 35 | if style == "kaiming_uniform": 36 | return init_linear_kaiming_uniform 37 | 38 | if style == "sonar": 39 | return partial(init_linear_to_sonar, sonar_std=sonar_std) 40 | 41 | if style == "zero": 42 | return init_linear_zero 43 | 44 | if style == "trunc_normal": 45 | return init_linear_trunc_normal 46 | 47 | if style == "none": 48 | return None 49 | 50 | else: 51 | raise ValueError(f"Could not recognize initialization function {style}") 52 | 53 | 54 | def init_linear_to_sonar(layer: Linear, sonar_std: float) -> None: 55 | """ 56 | Initialize the post-lcm in such a way, that if it is fed layer-normed 57 | lcm outputs (with zero mean and unit variance), its outputs have zero 58 | mean and the variance of SONAR embeddings. 59 | """ 60 | if layer.bias is not None: 61 | torch.nn.init.zeros_(layer.bias) 62 | 63 | std = sonar_std * (3 / layer.input_dim) ** 0.5 64 | 65 | torch.nn.init.uniform_(layer.weight, a=-std, b=std) 66 | 67 | 68 | def init_linear_xavier(layer: Linear) -> None: 69 | torch.nn.init.xavier_uniform_(layer.weight) 70 | if layer.bias is not None: 71 | torch.nn.init.zeros_(layer.bias) 72 | 73 | 74 | def init_linear_zero(layer: Linear) -> None: 75 | torch.nn.init.zeros_(layer.weight) 76 | if layer.bias is not None: 77 | torch.nn.init.zeros_(layer.bias) 78 | 79 | 80 | def init_linear_trunc_normal(layer: Linear) -> None: 81 | torch.nn.init.trunc_normal_(layer.weight, std=1e-3) 82 | if layer.bias is not None: 83 | torch.nn.init.zeros_(layer.bias) 84 | 85 | 86 | def init_linear_kaiming_uniform(layer: Linear) -> None: 87 | torch.nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5)) 88 | 89 | if layer.bias is not None: 90 | fan_in = layer.weight.size(1) 91 | 92 | m = 1 93 | if layer.weight.ndim > 2: 94 | for s in layer.weight.shape[2:]: 95 | m *= s 96 | 97 | fan_in *= m 98 | 99 | # We do not calculate the true standard deviation of the uniform 100 | # distribution (i.e. multiply with sqrt(3)). See 101 | # https://github.com/pytorch/pytorch/issues/57109#issuecomment-828847575. 
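    # In other words, the bias is drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)), matching
    # torch.nn.Linear's default bias initialization, rather than using the "true" uniform
    # bound of sqrt(3/fan_in).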
102 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 103 | 104 | torch.nn.init.uniform_(layer.bias, -bound, bound) 105 | 106 | 107 | def parse_norm_order(var: str) -> TransformerNormOrder: 108 | norm_order: TransformerNormOrder 109 | if var == "pre": 110 | norm_order = TransformerNormOrder.PRE 111 | elif var == "post": 112 | norm_order = TransformerNormOrder.POST 113 | elif var == "normformer": 114 | norm_order = TransformerNormOrder.PRE_WITH_NORMFORMER 115 | else: 116 | raise ValueError(f"Unknown normalization order {var}") 117 | 118 | return norm_order 119 | 120 | 121 | def parse_activation_fn(var: str = None) -> Optional[Module]: 122 | if var is None: 123 | return None 124 | 125 | activ_fn: Module 126 | 127 | if var == "relu": 128 | activ_fn = torch.nn.ReLU() 129 | elif var == "tanh": 130 | activ_fn = torch.nn.Tanh() 131 | elif var == "elu": 132 | activ_fn = torch.nn.ELU() 133 | elif var == "leaky_relu": 134 | activ_fn = torch.nn.LeakyReLU() 135 | elif var == "prelu": 136 | activ_fn = torch.nn.PReLU() 137 | elif var == "selu": 138 | activ_fn = torch.nn.SELU() 139 | elif var == "gelu": 140 | activ_fn = torch.nn.GELU() 141 | elif var == "silu": 142 | activ_fn = torch.nn.SiLU() 143 | elif var == "softsign": 144 | activ_fn = torch.nn.Softsign() 145 | elif var == "sigmoid": 146 | activ_fn = torch.nn.Sigmoid() 147 | elif var == "hardsigmoid": 148 | activ_fn = torch.nn.Hardsigmoid() 149 | else: 150 | raise ValueError(f"Unknown activation function {var}") 151 | 152 | return activ_fn 153 | -------------------------------------------------------------------------------- /tests/units/training/test_toy_task_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # 5 | 6 | import math 7 | 8 | import numpy as np 9 | import torch 10 | from fairseq2.models import get_model_family 11 | from tqdm.auto import tqdm 12 | 13 | from lcm.datasets.configs import ( 14 | DataLoadingConfig, 15 | ) 16 | from lcm.models.base_lcm import BaseLCModelConfig 17 | from lcm.nn.transformer import TransformerConfig 18 | from lcm.train.lcm.trainer import LCMTrainerBuilder, LCMTrainingConfig 19 | from lcm.train.mse_lcm.criterion import ReconstructionCriterionConfig 20 | from lcm.train.trainer import Trainer 21 | from lcm.utils.card_utils import load_model_from_card, load_model_with_overrides 22 | from lcm.utils.model_type_registry import lcm_model_type_registry 23 | from tests.common import DEBUG, device 24 | 25 | 26 | def get_eval_loss(trainer: Trainer) -> float: 27 | trainer.model.eval() 28 | trainer.validation_data_loader.pipeline.reset() # type: ignore 29 | for mb in trainer.valid_metric_bag.values(): 30 | mb.reset_metrics() 31 | for batch in tqdm(trainer.validation_data_loader.iterate_batches()): # type: ignore 32 | loss = trainer.criterion(batch) 33 | trainer.valid_metric_bag[batch.name].update( 34 | [loss], 35 | ) 36 | values = { 37 | name: mb.sync_and_compute_metrics() 38 | for name, mb in trainer.valid_metric_bag.items() 39 | } 40 | trainer.model.train() 41 | # taking average value from over all datasets 42 | return np.mean([x["loss"].item() for x in values.values()]) # type: ignore 43 | 44 | 45 | def compare_models(loaded_model, trained_model): 46 | loaded_state_dict = loaded_model.state_dict() 47 | for param_name, param in trained_model.named_parameters(): 48 | assert torch.allclose( 49 | param.data, loaded_state_dict[param_name].to(param.device) 50 | ), f"{param_name} differs after loading the model!" 51 | 52 | 53 | def test_toy_mse_training(tmp_path, simple_train_dataset, simple_validation_dataset): 54 | """ 55 | Test that the trainer can be built, that it can run, and that it saves the model well. 56 | """ 57 | model_config_or_name = BaseLCModelConfig(lcm=TransformerConfig(num_layers=1)) 58 | criterion_cfg = ReconstructionCriterionConfig( 59 | name="next_sentence_mse", reduction="mean" 60 | ) 61 | 62 | train_dirname = tmp_path / "tmp_lcm_trainer_output" 63 | n_steps = 10 64 | training_cfg = LCMTrainingConfig( 65 | debug=DEBUG, 66 | fake_gang_device=device, 67 | model_config_or_name=model_config_or_name, 68 | use_fsdp=False, 69 | use_submitit=False, 70 | data_loading_config=DataLoadingConfig(batch_size=10), 71 | training_data=[simple_train_dataset], 72 | validation_data=[simple_validation_dataset], 73 | output_dir=train_dirname, 74 | criterion=criterion_cfg, 75 | num_lr_warmup_steps=n_steps // 3 + 1, 76 | max_steps=n_steps, 77 | checkpoint_every_n_steps=1, 78 | save_model_every_n_steps=1, 79 | lr=1e-6, 80 | ) 81 | 82 | # Testing that the trainer is buildable 83 | builder = LCMTrainerBuilder(training_cfg) 84 | trainer = builder.build_trainer() 85 | 86 | # Testing that the training does happen and decreases the loss 87 | old_eval_loss = get_eval_loss(trainer) 88 | assert math.isfinite(old_eval_loss), "Old eval loss is not finite!" 89 | trainer.run() 90 | new_eval_loss = get_eval_loss(trainer) 91 | assert math.isfinite(new_eval_loss), "New eval loss is not finite!" 
92 | assert new_eval_loss < old_eval_loss 93 | 94 | # testing that the checkpointing works 95 | step_id, state_dict = trainer.checkpoint_manager.load_last_checkpoint() 96 | assert step_id == n_steps, f"step_id={step_id} does not match n_steps={n_steps}" 97 | for param_name, param in trainer.model.named_parameters(): 98 | assert torch.allclose( 99 | param.data, state_dict["model"][param_name].to(param.device) 100 | ), f"{param_name} differs in checkpoint!" 101 | 102 | # Testing that the model card has been created 103 | assert (train_dirname / "model_card.yaml").exists(), ( 104 | f"The file {train_dirname}/model_card.yaml does not exist" 105 | ) 106 | 107 | # Testing that the model card can be used to load the model correctly 108 | card = trainer.create_model_card_for_last_checkpoint() 109 | model_type = get_model_family(card) 110 | model_loader = lcm_model_type_registry.get_model_loader(model_type) 111 | loaded_model = model_loader(card) 112 | compare_models(loaded_model, trainer.model) 113 | 114 | # Test that the model card API works out-of-the-box 115 | loaded_model_1 = load_model_from_card(str(train_dirname / "model_card.yaml")) 116 | compare_models(loaded_model_1, trainer.model) 117 | 118 | loaded_model_2 = load_model_with_overrides(train_dirname) 119 | compare_models(loaded_model_2, trainer.model) 120 | -------------------------------------------------------------------------------- /tests/units/inference/test_base_lcm_batched_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | from random import randint 7 | from typing import List 8 | 9 | import pytest 10 | import torch 11 | from fairseq2.nn.padding import pad_seqs 12 | from stopes.core.utils import batch as into_batches 13 | from torch import Tensor 14 | 15 | from lcm.datasets.batch import EmbeddingsBatch 16 | from lcm.inference.lcm import LCMGenerator, LCMGeneratorOptions 17 | from lcm.models.base_lcm import BaseLCModelConfig, create_base_lcm_model 18 | from lcm.nn.transformer import TransformerConfig 19 | 20 | 21 | @pytest.mark.parametrize( 22 | "disable_cache,eos_threshold", 23 | [(True, 0.3), (True, 1.1), (False, 0.3), (False, 1.1)], 24 | ) 25 | def test_caching_and_batching(disable_cache, eos_threshold): 26 | """ 27 | Test that generation is consistent with and without caching and across batch sizes, 28 | and with different stopping criteria to test the control of output length.
29 | """ 30 | # Sample input data 31 | batch_size = 5 32 | sonar_embed_dim = 4 33 | max_prompt_len = 7 34 | max_gen_len = 8 35 | 36 | # Create sample input list 37 | sample_inputs: List[Tensor] = [] 38 | 39 | for _ in range(batch_size): 40 | random_len = randint(1, max_prompt_len) 41 | sample_inputs.append(torch.randn(random_len, sonar_embed_dim)) 42 | 43 | # Create an LCM model 44 | model_cfg = BaseLCModelConfig( 45 | sonar_embed_dim=sonar_embed_dim, 46 | model_dim=sonar_embed_dim, 47 | lcm=TransformerConfig( 48 | ffn_inner_dim=4 * sonar_embed_dim, 49 | num_layers=2, 50 | num_attn_heads=1, 51 | ), 52 | ) 53 | 54 | model = create_base_lcm_model(model_cfg) 55 | eos_vec = torch.zeros(sonar_embed_dim) 56 | 57 | generator = LCMGenerator( 58 | model, 59 | eos_vec=eos_vec, 60 | options=LCMGeneratorOptions( 61 | eos_threshold=eos_threshold, 62 | sample_latent_variable=False, 63 | ), 64 | ) 65 | 66 | def generate(batch_size): 67 | print( 68 | f"Generating with a batch_size of {batch_size} - disable_cache={disable_cache} and eos_threshold={eos_threshold}" 69 | ) 70 | lcm_outputs = [] 71 | for batch in into_batches(sample_inputs, batch_size=3): 72 | padded_batch = EmbeddingsBatch(*pad_seqs(batch)) 73 | 74 | lcm_output = generator( 75 | padded_batch, 76 | max_gen_len=max_gen_len, 77 | disable_cache=disable_cache, 78 | ) 79 | lcm_outputs.extend([hyp[0].seq for hyp in lcm_output.hypotheses]) 80 | 81 | return lcm_outputs 82 | 83 | lcm_output_with_batching = generate(batch_size=3) 84 | lcm_output_without_batching = generate(batch_size=1) 85 | 86 | # Check if the outputs are equal (indicating successful batching/caching) 87 | assert all( 88 | [ 89 | torch.allclose( 90 | a, 91 | b, 92 | atol=1e-6, 93 | ) 94 | for a, b in zip(lcm_output_with_batching, lcm_output_without_batching) 95 | ] 96 | ), "Outputs with and without batching do not match" 97 | 98 | 99 | @pytest.mark.parametrize( 100 | "batch_size", 101 | [3, 1], 102 | ) 103 | def test_single_input_stopping(batch_size): 104 | """ 105 | Test that batching we don't stop prematurely with small batches 106 | """ 107 | # Sample input data 108 | sonar_embed_dim = 4 109 | max_prompt_len = 4 110 | max_gen_len = 3 111 | 112 | # Create sample input list 113 | sample_inputs: List[Tensor] = [] 114 | 115 | for _ in range(batch_size): 116 | random_len = randint(1, max_prompt_len) 117 | sample_inputs.append(torch.randn(random_len, sonar_embed_dim)) 118 | 119 | # Create an LCM model 120 | model_cfg = BaseLCModelConfig( 121 | sonar_embed_dim=sonar_embed_dim, 122 | model_dim=sonar_embed_dim, 123 | lcm=TransformerConfig( 124 | ffn_inner_dim=4 * sonar_embed_dim, 125 | num_layers=2, 126 | num_attn_heads=1, 127 | ), 128 | ) 129 | 130 | model = create_base_lcm_model(model_cfg) 131 | eos_vec = torch.zeros(sonar_embed_dim) 132 | 133 | generator = LCMGenerator( 134 | model, 135 | eos_vec=eos_vec, 136 | options=LCMGeneratorOptions( 137 | eos_threshold=1, 138 | sample_latent_variable=False, 139 | trim_hypotheses=False, 140 | ), 141 | ) 142 | 143 | lcm_outputs = [] 144 | for batch in into_batches(sample_inputs, batch_size=3): 145 | padded_batch = EmbeddingsBatch(*pad_seqs(batch)) 146 | 147 | lcm_output = generator( 148 | padded_batch, 149 | max_gen_len=max_gen_len, 150 | ) 151 | lcm_outputs.extend([hyp[0].seq for hyp in lcm_output.hypotheses]) 152 | 153 | # checking that we didn't stop prematurely 154 | for output_seq, prompt_seq in zip(lcm_outputs, sample_inputs): 155 | assert output_seq.size(0) - prompt_seq.size(0) == max_gen_len 156 | 
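Distilled from the tests above, a minimal end-to-end generation sketch (illustrative only: the tiny dimensions are arbitrary and the model is randomly initialized, so the generated embeddings carry no meaning):

import torch
from fairseq2.nn.padding import pad_seqs

from lcm.datasets.batch import EmbeddingsBatch
from lcm.inference.lcm import LCMGenerator, LCMGeneratorOptions
from lcm.models.base_lcm import BaseLCModelConfig, create_base_lcm_model
from lcm.nn.transformer import TransformerConfig

dim = 4
model = create_base_lcm_model(
    BaseLCModelConfig(
        sonar_embed_dim=dim,
        model_dim=dim,
        lcm=TransformerConfig(ffn_inner_dim=4 * dim, num_layers=2, num_attn_heads=1),
    )
)
generator = LCMGenerator(
    model,
    eos_vec=torch.zeros(dim),
    options=LCMGeneratorOptions(eos_threshold=1.1, sample_latent_variable=False),
)
prompt = EmbeddingsBatch(*pad_seqs([torch.randn(3, dim)]))  # one prompt of 3 sentence embeddings
output = generator(prompt, max_gen_len=2)
generated_seqs = [hyp[0].seq for hyp in output.hypotheses]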
-------------------------------------------------------------------------------- /lcm/models/two_tower_diffusion_lcm/frontend.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | from dataclasses import dataclass 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | from fairseq2.logging import get_log_writer 11 | from fairseq2.nn import Embedding, LearnedPositionEncoder, PositionEncoder 12 | from fairseq2.nn.incremental_state import IncrementalStateBag 13 | from fairseq2.nn.padding import PaddingMask 14 | from fairseq2.nn.projection import Linear 15 | from fairseq2.typing import DataType, Device 16 | from torch import Tensor 17 | from torch.nn import Dropout, Module 18 | 19 | from lcm.nn.initialization import SUPPORTED_INIT_TYPES, get_init_fn 20 | 21 | logger = get_log_writer(__name__) 22 | 23 | 24 | @dataclass 25 | class EncoderFrontendConfig: 26 | dropout_p: float = 0.0 27 | """ The dropout probability applied to the module' output""" 28 | 29 | pre_linear_bias: bool = True 30 | """ Whether or not the pre-linear layer has a bias term""" 31 | 32 | pre_linear_init_fn: SUPPORTED_INIT_TYPES = "kaiming_uniform" 33 | 34 | weight_normalization: bool = False 35 | 36 | embedding_std: float = 1.0 37 | 38 | 39 | class EncoderFrontend(Module): 40 | """ 41 | A fronted for the context encoder in encoder-decoder LCMs 42 | """ 43 | 44 | embed: Embedding 45 | pos_encoder: Optional[PositionEncoder] 46 | dropout: Optional[Dropout] 47 | 48 | def __init__( 49 | self, 50 | sonar_embed_dim: int, 51 | model_dim: int, 52 | config: EncoderFrontendConfig, 53 | pos_encoder: Optional[PositionEncoder], 54 | *, 55 | device: Optional[Device] = None, 56 | dtype: Optional[DataType] = None, 57 | ) -> None: 58 | """ 59 | :param sonar_embed_dim 60 | The embedding dimension of the sentence encoder, in this case SONAR 61 | :param model_dim 62 | The model embedding dimension 63 | :param config: 64 | A Frontend config. See `LCMFrontendConfig` 65 | :param pos_encoder: 66 | An optional position encoder. 67 | """ 68 | 69 | super().__init__() 70 | 71 | self.sonar_embed_dim = sonar_embed_dim 72 | 73 | self.model_dim = model_dim 74 | 75 | self.device = device 76 | 77 | # Pre-linear to map to model dimension 78 | init_fn = get_init_fn(config.pre_linear_init_fn) 79 | 80 | lin = Linear( 81 | sonar_embed_dim, 82 | model_dim, 83 | bias=config.pre_linear_bias, 84 | device=device, 85 | dtype=dtype, 86 | init_fn=init_fn, 87 | ) 88 | 89 | if config.weight_normalization: 90 | self.pre_linear = torch.nn.utils.parametrizations.weight_norm(lin) 91 | else: 92 | self.pre_linear = lin 93 | 94 | if pos_encoder is not None: 95 | if pos_encoder.encoding_dim != self.model_dim: 96 | raise ValueError( 97 | f"`encoding_dim` of `pos_encoder` and `embedding_dim` of \ 98 | `embed` must be equal, but are {pos_encoder.encoding_dim} \ 99 | and {self.model_dim} instead." 100 | ) 101 | 102 | self.pos_encoder = pos_encoder 103 | else: 104 | self.register_module("pos_encoder", None) 105 | 106 | if config.dropout_p > 0.0: 107 | self.dropout = Dropout(config.dropout_p) 108 | else: 109 | self.register_module("dropout", None) 110 | 111 | self.reset_parameters(embedding_std=config.embedding_std) 112 | 113 | def reset_parameters(self, embedding_std: float) -> None: 114 | """Initialize module parameters. 
115 | The positional embeddings should be initialized with the 116 | same order of magnitude as the semantic embeddings, in order 117 | to make the early training as stable as possible. 118 | Otherwise, the positional and special token embeddings would 119 | flood out the semantic information. 120 | """ 121 | logger.info( 122 | f"Initializing frontend embeddings (special and positional) ~ N(0, {embedding_std})" 123 | ) 124 | if isinstance(self.pos_encoder, LearnedPositionEncoder): 125 | torch.nn.init.normal_(self.pos_encoder.weight, std=embedding_std) 126 | 127 | def forward( 128 | self, 129 | seqs: Tensor, 130 | padding_mask: Optional[PaddingMask], 131 | state_bag: Optional[IncrementalStateBag] = None, 132 | **kwargs, 133 | ) -> Tuple[Tensor, Optional[PaddingMask]]: 134 | """ 135 | Apply pre-linear (if relevant) and add positional embeddings 136 | """ 137 | 138 | # pre-linear if any: 139 | seqs = self.pre_linear(seqs) 140 | 141 | if self.pos_encoder is not None: 142 | seqs = self.pos_encoder( 143 | seqs, 144 | padding_mask, 145 | state_bag=state_bag, 146 | **kwargs, 147 | ) 148 | 149 | if self.dropout is not None: 150 | seqs = self.dropout(seqs) 151 | 152 | return seqs, padding_mask 153 | -------------------------------------------------------------------------------- /tests/units/evaluation/test_cli.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # 5 | 6 | import json 7 | import logging 8 | import sys 9 | 10 | from lcm.evaluation.__main__ import cfg_from_cli 11 | from lcm.evaluation.cli.local import main as local_main 12 | from lcm.evaluation.utils.data_utils import load_jsonl 13 | from lcm.utils.common import batched, setup_conf 14 | 15 | logger = logging.getLogger("lcm.evaluation.test_cli") 16 | 17 | 18 | def test_dynamic_prompt(tmp_path, simple_json_dataset, monkeypatch): 19 | """Test that the dynamic prompt can be set via source_prefix_text""" 20 | setup_conf() 21 | bsz = 5 22 | commands = [ 23 | "lcm.evaluation.__main__.py", 24 | "--tasks", 25 | "dummy_json_generation", 26 | "--predictor", 27 | "dummy", 28 | "--data_loading.batch_size", 29 | str(bsz), 30 | "--dataset.file_path", 31 | str(simple_json_dataset), 32 | ] 33 | 34 | raw_data = load_jsonl(simple_json_dataset) 35 | 36 | # Default prompt defined in the task 37 | with monkeypatch.context() as m1: 38 | m1.setattr( 39 | sys, 40 | "argv", 41 | commands 42 | + [ 43 | "--dump_dir", 44 | str(tmp_path / "test_1"), 45 | "--dataset.source_text_column", 46 | "input_text", 47 | ], 48 | ) 49 | eval_config, _ = cfg_from_cli() 50 | local_main(eval_config, logger=logger) 51 | 52 | with open( 53 | tmp_path.joinpath("test_1", "results", "dummy_json_generation.json") 54 | ) as fh: 55 | result = json.load(fh) 56 | assert "results" in result and result["results"]["m1"] == 0.0 57 | 58 | default_prompts = [f"[INST] Prompt: {x['input_text']}" for x in raw_data] 59 | default_prompts = batched(default_prompts, batch_size=bsz) # type: ignore 60 | results = load_jsonl( 61 | tmp_path.joinpath( 62 | "test_1", 63 | "raw_results", 64 | "dummy_json_generation", 65 | "dummy_json_generation.json", 66 | ) 67 | ) 68 | for prompt, result in zip(default_prompts, results): 69 | assert result["text_prompts"] == prompt, ( 70 | f"Not match: {result['text_prompts']} != {prompt}" 71 | ) 72 | # Custom prompt with prefix and suffix 73 | with monkeypatch.context() as m2: 74 | m2.setattr( 75 | sys, 76 | "argv", 77 | commands 78 | + [ 79 | 
"--dump_dir", 80 | str(tmp_path / "test_2"), 81 | "--dataset.source_text_column", 82 | "input_text", 83 | "--dataset.source_prefix_text", 84 | "[Myprompt] ", 85 | "--dataset.source_suffix_text", 86 | "[/Myprompt]", 87 | ], 88 | ) 89 | eval_config, _ = cfg_from_cli() 90 | local_main(eval_config, logger=logger) 91 | 92 | custom_prompts = [f"[Myprompt] {x['input_text']}[/Myprompt]" for x in raw_data] 93 | custom_prompts = batched(custom_prompts, batch_size=bsz) # type: ignore 94 | results = load_jsonl( 95 | tmp_path.joinpath( 96 | "test_2", 97 | "raw_results", 98 | "dummy_json_generation", 99 | "dummy_json_generation.json", 100 | ) 101 | ) 102 | for prompt, result in zip(custom_prompts, results): 103 | assert result["text_prompts"] == prompt, ( 104 | f"Not match: {result['text_prompts']} != {prompt}" 105 | ) 106 | 107 | # Custom prompt with complex sequences of text 108 | with monkeypatch.context() as m3: 109 | m3.setattr( 110 | sys, 111 | "argv", 112 | commands 113 | + [ 114 | "--dump_dir", 115 | str(tmp_path / "test_3"), 116 | "--dataset.source_sequences", 117 | '{"text_value": "[SEQ]"}', 118 | "--dataset.source_sequences", 119 | '{"text_column": "input_text"}', 120 | "--dataset.source_sequences", 121 | '{"text_value": "-"}', 122 | "--dataset.source_sequences", 123 | '{"text_column": "input_text"}', 124 | "--dataset.source_sequences", 125 | '{"text_value": "[/SEQ]"}', 126 | ], 127 | ) 128 | eval_config, _ = cfg_from_cli() 129 | local_main(eval_config, logger=logger) 130 | 131 | custom_prompts = [ 132 | f"[SEQ] {x['input_text']} - {x['input_text']} [/SEQ]" for x in raw_data 133 | ] 134 | custom_prompts = batched(custom_prompts, batch_size=bsz) # type: ignore 135 | results = load_jsonl( 136 | tmp_path.joinpath( 137 | "test_3", 138 | "raw_results", 139 | "dummy_json_generation", 140 | "dummy_json_generation.json", 141 | ) 142 | ) 143 | for prompt, result in zip(custom_prompts, results): 144 | assert result["text_prompts"] == prompt, ( 145 | f"Not match: {result['text_prompts']} != {prompt}" 146 | ) 147 | --------------------------------------------------------------------------------