├── rcm
├── configs
│ ├── experiments
│ │ └── rcm
│ │ │ └── __init__.py
│ ├── defaults
│ │ ├── trainer.py
│ │ ├── tokenizer.py
│ │ ├── ema.py
│ │ ├── model.py
│ │ ├── conditioner.py
│ │ ├── scheduler.py
│ │ ├── ckpt_type.py
│ │ ├── dataloader.py
│ │ ├── checkpoint.py
│ │ ├── optimizer.py
│ │ ├── net.py
│ │ └── callbacks.py
│ └── registry_distill.py
├── samplers
│ └── euler.py
├── utils
│ ├── lognormal.py
│ ├── misc.py
│ ├── selective_activation_checkpoint.py
│ ├── optim_instantiate_dtensor.py
│ ├── dtensor_helper.py
│ ├── checkpointer.py
│ ├── fsdp_helper.py
│ └── attention.py
├── modules
│ └── denoiser_scaling.py
├── datasets
│ ├── webdataset.py
│ ├── utils.py
│ ├── merge_tar_shards.py
│ └── visualize_tar.py
├── tokenizers
│ └── interface.py
├── callbacks
│ ├── compile_tokenizer.py
│ ├── dataloading_monitor.py
│ ├── grad_clip.py
│ ├── iter_speed.py
│ └── heart_beat.py
└── networks
│ └── wan2pt1_jvp_test.py
├── examples
├── i2v_input_1.jpg
├── i2v_input_2.jpg
├── i2v_input_3.jpg
├── i2v_input_4.jpg
├── i2v_input_5.jpg
├── i2v_input_6.jpg
└── i2v_input_7.jpg
├── imaginaire
├── __init__.py
├── utils
│ ├── __init__.py
│ ├── easy_io
│ │ ├── __init__.py
│ │ ├── backends
│ │ │ ├── __init__.py
│ │ │ ├── base_backend.py
│ │ │ ├── http_backend.py
│ │ │ └── registry_utils.py
│ │ └── handlers
│ │ │ ├── torch_handler.py
│ │ │ ├── pandas_handler.py
│ │ │ ├── torchjit_handler.py
│ │ │ ├── __init__.py
│ │ │ ├── txt_handler.py
│ │ │ ├── gzip_handler.py
│ │ │ ├── byte_handler.py
│ │ │ ├── yaml_handler.py
│ │ │ ├── tarfile_handler.py
│ │ │ ├── csv_handler.py
│ │ │ ├── pickle_handler.py
│ │ │ ├── base.py
│ │ │ ├── json_handler.py
│ │ │ ├── jsonl_handler.py
│ │ │ ├── np_handler.py
│ │ │ ├── registry_utils.py
│ │ │ ├── pil_handler.py
│ │ │ └── imageio_video_handler.py
│ ├── device.py
│ ├── parallel_state_helper.py
│ ├── wandb_util.py
│ ├── log.py
│ ├── profiling.py
│ └── io.py
├── callbacks
│ ├── __init__.py
│ ├── manual_gc.py
│ ├── low_precision.py
│ └── every_n.py
├── lazy_config
│ ├── file_io.py
│ ├── registry.py
│ ├── omegaconf_patch.py
│ ├── __init__.py
│ └── instantiate.py
└── model.py
├── scripts
├── dcp_to_pth.py
└── train.py
└── .gitignore

/rcm/configs/experiments/rcm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/i2v_input_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/rcm/HEAD/examples/i2v_input_1.jpg
--------------------------------------------------------------------------------
/examples/i2v_input_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/rcm/HEAD/examples/i2v_input_2.jpg
--------------------------------------------------------------------------------
/examples/i2v_input_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/rcm/HEAD/examples/i2v_input_3.jpg
--------------------------------------------------------------------------------
/examples/i2v_input_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/rcm/HEAD/examples/i2v_input_4.jpg
--------------------------------------------------------------------------------
/examples/i2v_input_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/rcm/HEAD/examples/i2v_input_5.jpg -------------------------------------------------------------------------------- /examples/i2v_input_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/rcm/HEAD/examples/i2v_input_6.jpg -------------------------------------------------------------------------------- /examples/i2v_input_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/rcm/HEAD/examples/i2v_input_7.jpg -------------------------------------------------------------------------------- /rcm/configs/defaults/trainer.py: -------------------------------------------------------------------------------- 1 | from hydra.core.config_store import ConfigStore 2 | from imaginaire.lazy_config import LazyCall as L 3 | from imaginaire.lazy_config import PLACEHOLDER, LazyDict 4 | from imaginaire.trainer import ImaginaireTrainer 5 | from rcm.trainers.trainer_distillation import ImaginaireTrainer_Distill 6 | 7 | TRAINER: LazyDict = L(ImaginaireTrainer)(config=PLACEHOLDER) 8 | TRAINER_DISTILL: LazyDict = L(ImaginaireTrainer_Distill)(config=PLACEHOLDER) 9 | 10 | 11 | def register_trainer(): 12 | cs = ConfigStore.instance() 13 | cs.store(group="trainer", package="trainer.type", name="standard", node=TRAINER) 14 | cs.store(group="trainer", package="trainer.type", name="distill", node=TRAINER_DISTILL) 15 | -------------------------------------------------------------------------------- /imaginaire/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /imaginaire/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | -------------------------------------------------------------------------------- /imaginaire/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /imaginaire/lazy_config/file_io.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler 17 | from iopath.common.file_io import PathManager as PathManagerBase 18 | 19 | __all__ = ["PathHandler", "PathManager"] 20 | 21 | 22 | PathManager = PathManagerBase() 23 | PathManager.register_handler(HTTPURLHandler()) 24 | PathManager.register_handler(OneDrivePathHandler()) 25 | -------------------------------------------------------------------------------- /rcm/configs/defaults/tokenizer.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from hydra.core.config_store import ConfigStore 17 | from imaginaire.lazy_config import LazyCall as L, LazyDict, PLACEHOLDER 18 | from rcm.tokenizers.wan2pt1 import Wan2pt1VAEInterface 19 | 20 | Wan2pt1VAEConfig: LazyDict = L(Wan2pt1VAEInterface)(vae_pth=PLACEHOLDER) 21 | 22 | 23 | def register_tokenizer(): 24 | cs = ConfigStore.instance() 25 | cs.store(group="tokenizer", package="model.config.tokenizer", name="wan2pt1_tokenizer", node=Wan2pt1VAEConfig) 26 | -------------------------------------------------------------------------------- /rcm/samplers/euler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class FlowEulerSampler: 5 | 6 | def __init__( 7 | self, 8 | num_train_timesteps=1000, 9 | sigma_max=1.0, 10 | sigma_min=0.0, 11 | ): 12 | self.num_train_timesteps = num_train_timesteps 13 | self.sigma_max = sigma_max 14 | self.sigma_min = sigma_min 15 | 16 | def set_timesteps(self, num_inference_steps=100, shift=3.0, device="cuda"): 17 | self.sigmas = torch.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1)[:-1] 18 | self.sigmas = shift * self.sigmas / (1 + (shift - 1) * self.sigmas) 19 | self.timesteps = self.sigmas * self.num_train_timesteps 20 | self.sigmas = self.sigmas.to(device) 21 | self.timesteps = self.timesteps.to(device) 22 | 23 | def step(self, model_output, timestep, sample): 24 | timestep_id = torch.argmin((self.timesteps - timestep).abs(), dim=0) 25 | sigma = self.sigmas[timestep_id] 26 | if timestep_id + 1 >= len(self.timesteps): 27 | sigma_ = 0 28 | else: 29 | sigma_ = self.sigmas[timestep_id + 1] 30 | prev_sample = sample + model_output * (sigma_ - sigma) 31 | return prev_sample 32 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/backends/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend 17 | from imaginaire.utils.easy_io.backends.http_backend import HTTPBackend 18 | from imaginaire.utils.easy_io.backends.local_backend import LocalBackend 19 | from imaginaire.utils.easy_io.backends.registry_utils import backends, prefix_to_backends, register_backend 20 | 21 | __all__ = [ 22 | "BaseStorageBackend", 23 | "HTTPBackend", 24 | "LocalBackend", 25 | "backends", 26 | "prefix_to_backends", 27 | "register_backend", 28 | ] 29 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/torch_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | try: 17 | import torch 18 | except ImportError: 19 | torch = None 20 | 21 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler 22 | 23 | 24 | class TorchHandler(BaseFileHandler): 25 | str_like = False 26 | 27 | def load_from_fileobj(self, file, **kwargs): 28 | return torch.load(file, **kwargs) 29 | 30 | def dump_to_fileobj(self, obj, file, **kwargs): 31 | torch.save(obj, file, **kwargs) 32 | 33 | def dump_to_str(self, obj, **kwargs): 34 | raise NotImplementedError 35 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/pandas_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import pandas as pd 17 | 18 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler # isort:skip 19 | 20 | 21 | class PandasHandler(BaseFileHandler): 22 | str_like = False 23 | 24 | def load_from_fileobj(self, file, **kwargs): 25 | return pd.read_csv(file, **kwargs) 26 | 27 | def dump_to_fileobj(self, obj, file, **kwargs): 28 | obj.to_csv(file, **kwargs) 29 | 30 | def dump_to_str(self, obj, **kwargs): 31 | raise NotImplementedError("PandasHandler does not support dumping to str") 32 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/torchjit_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | try: 17 | import torch 18 | except ImportError: 19 | torch = None 20 | 21 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler 22 | 23 | 24 | class TorchJitHandler(BaseFileHandler): 25 | str_like = False 26 | 27 | def load_from_fileobj(self, file, **kwargs): 28 | return torch.jit.load(file, **kwargs) 29 | 30 | def dump_to_fileobj(self, obj, file, **kwargs): 31 | torch.jit.save(obj, file, **kwargs) 32 | 33 | def dump_to_str(self, obj, **kwargs): 34 | raise NotImplementedError 35 | -------------------------------------------------------------------------------- /rcm/configs/defaults/ema.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import attrs 17 | from hydra.core.config_store import ConfigStore 18 | 19 | 20 | @attrs.define(slots=False) 21 | class EMAConfig: 22 | """ 23 | Config for the EMA. 
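    Attributes:
        enabled: whether to maintain an EMA copy of the network weights.
        rate: EMA rate hyper-parameter (the "power" preset below uses 0.10).
        iteration_shift: iteration offset applied to the EMA schedule.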
24 | """ 25 | 26 | enabled: bool = True 27 | rate: float = 0.1 28 | iteration_shift: int = 0 29 | 30 | 31 | PowerEMAConfig: EMAConfig = EMAConfig( 32 | enabled=True, 33 | rate=0.10, 34 | iteration_shift=0, 35 | ) 36 | 37 | 38 | def register_ema(): 39 | cs = ConfigStore.instance() 40 | cs.store(group="ema", package="model.config.ema", name="power", node=PowerEMAConfig) 41 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler 17 | from imaginaire.utils.easy_io.handlers.json_handler import JsonHandler 18 | from imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler 19 | from imaginaire.utils.easy_io.handlers.registry_utils import file_handlers, register_handler 20 | from imaginaire.utils.easy_io.handlers.yaml_handler import YamlHandler 21 | 22 | __all__ = [ 23 | "BaseFileHandler", 24 | "JsonHandler", 25 | "PickleHandler", 26 | "YamlHandler", 27 | "file_handlers", 28 | "register_handler", 29 | ] 30 | -------------------------------------------------------------------------------- /rcm/configs/defaults/model.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from hydra.core.config_store import ConfigStore 17 | 18 | from imaginaire.lazy_config import LazyCall as L 19 | from rcm.models.t2v_model_distill_rcm import T2VDistillConfig_rCM, T2VDistillModel_rCM 20 | 21 | FSDP_CONFIG_T2V_DISTILL_RCM = dict( 22 | trainer=dict(distributed_parallelism="fsdp"), 23 | model=L(T2VDistillModel_rCM)(config=T2VDistillConfig_rCM(fsdp_shard_size=8), _recursive_=False), 24 | ) 25 | 26 | 27 | def register_model(): 28 | cs = ConfigStore.instance() 29 | cs.store(group="model", package="_global_", name="fsdp_t2v_distill_rcm", node=FSDP_CONFIG_T2V_DISTILL_RCM) 30 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/txt_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler 17 | 18 | 19 | class TxtHandler(BaseFileHandler): 20 | def load_from_fileobj(self, file, **kwargs): 21 | del kwargs 22 | return file.read() 23 | 24 | def dump_to_fileobj(self, obj, file, **kwargs): 25 | del kwargs 26 | if not isinstance(obj, str): 27 | obj = str(obj) 28 | file.write(obj) 29 | 30 | def dump_to_str(self, obj, **kwargs): 31 | del kwargs 32 | if not isinstance(obj, str): 33 | obj = str(obj) 34 | return obj 35 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/gzip_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import gzip 17 | import pickle 18 | from io import BytesIO 19 | from typing import Any 20 | 21 | from imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler 22 | 23 | 24 | class GzipHandler(PickleHandler): 25 | str_like = False 26 | 27 | def load_from_fileobj(self, file: BytesIO, **kwargs): 28 | with gzip.GzipFile(fileobj=file, mode="rb") as f: 29 | return pickle.load(f) 30 | 31 | def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs): 32 | with gzip.GzipFile(fileobj=file, mode="wb") as f: 33 | pickle.dump(obj, f) 34 | -------------------------------------------------------------------------------- /rcm/configs/defaults/conditioner.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from hydra.core.config_store import ConfigStore 17 | 18 | from imaginaire.lazy_config import LazyCall as L 19 | from imaginaire.lazy_config import LazyDict 20 | from rcm.conditioner import TextAttr, TextConditioner 21 | 22 | TextConditionerNoDropConfig: LazyDict = L(TextConditioner)( 23 | text=L(TextAttr)( 24 | input_key=["t5_text_embeddings"], 25 | dropout_rate=0.0, 26 | ), 27 | ) 28 | 29 | 30 | def register_conditioner(): 31 | cs = ConfigStore.instance() 32 | cs.store(group="conditioner", package="model.config.conditioner", name="text_nodrop", node=TextConditionerNoDropConfig) 33 | -------------------------------------------------------------------------------- /rcm/configs/defaults/scheduler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from hydra.core.config_store import ConfigStore 17 | 18 | from imaginaire.lazy_config import LazyCall as L 19 | from imaginaire.lazy_config import LazyDict 20 | from rcm.utils.lr_scheduler import LambdaLinearScheduler 21 | 22 | LambdaLinearSchedulerConfig: LazyDict = L(LambdaLinearScheduler)( 23 | warm_up_steps=[100], 24 | cycle_lengths=[10000000000000], 25 | f_start=[1.0e-6], 26 | f_max=[1.0], 27 | f_min=[1.0], 28 | ) 29 | 30 | 31 | def register_scheduler(): 32 | cs = ConfigStore.instance() 33 | cs.store(group="scheduler", package="scheduler", name="lambdalinear", node=LambdaLinearSchedulerConfig) 34 | -------------------------------------------------------------------------------- /rcm/utils/lognormal.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from statistics import NormalDist 17 | 18 | import numpy as np 19 | import torch 20 | 21 | 22 | class LogNormal: 23 | def __init__( 24 | self, 25 | p_mean: float = 0.0, 26 | p_std: float = 1.0, 27 | ): 28 | self.gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) 29 | 30 | def __call__(self, batch_size: int) -> torch.Tensor: 31 | cdf_vals = np.random.uniform(size=(batch_size)) 32 | samples_interval_gaussian = [self.gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] 33 | 34 | log_sigma = torch.tensor(samples_interval_gaussian, device="cuda") 35 | return torch.exp(log_sigma) 36 | 37 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/byte_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from typing import IO 17 | 18 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler 19 | 20 | 21 | class ByteHandler(BaseFileHandler): 22 | str_like = False 23 | 24 | def load_from_fileobj(self, file: IO[bytes], **kwargs): 25 | file.seek(0) 26 | # extra all bytes and return 27 | return file.read() 28 | 29 | def dump_to_fileobj( 30 | self, 31 | obj: bytes, 32 | file: IO[bytes], 33 | **kwargs, 34 | ): 35 | # write all bytes to file 36 | file.write(obj) 37 | 38 | def dump_to_str(self, obj, **kwargs): 39 | raise NotImplementedError 40 | -------------------------------------------------------------------------------- /rcm/utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def count_params(model: torch.nn.Module, verbose=False) -> int: 5 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 6 | if verbose: 7 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 8 | return total_params 9 | 10 | 11 | def common_broadcast(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: 12 | ndims1 = x.ndim 13 | ndims2 = y.ndim 14 | 15 | common_ndims = min(ndims1, ndims2) 16 | for axis in range(common_ndims): 17 | assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis) 18 | 19 | if ndims1 < ndims2: 20 | x = x.reshape(x.shape + (1,) * (ndims2 - ndims1)) 21 | elif ndims2 < ndims1: 22 | y = y.reshape(y.shape + (1,) * (ndims1 - ndims2)) 23 | 24 | return x, y 25 | 26 | 27 | def batch_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 28 | x, y = common_broadcast(x, y) 29 | return x + y 30 | 31 | 32 | def batch_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 33 | x, y = common_broadcast(x, y) 34 | return x * y 35 | 36 | 37 | def batch_sub(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 38 | x, y = common_broadcast(x, y) 39 | return x - y 40 | 41 | 42 | def batch_div(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 43 | x, y = common_broadcast(x, y) 44 | return x / y 45 | -------------------------------------------------------------------------------- /rcm/configs/defaults/ckpt_type.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from typing import Dict 17 | 18 | from hydra.core.config_store import ConfigStore 19 | 20 | from imaginaire.lazy_config import LazyCall as L 21 | from rcm.checkpointers.dcp import DistributedCheckpointer 22 | from rcm.checkpointers.dcp_distill import DistributedCheckpointer_Distill 23 | 24 | DISTRIBUTED_CHECKPOINTER: Dict[str, str] = L(DistributedCheckpointer)() 25 | DISTRIBUTED_CHECKPOINTER_DISTILL: Dict[str, str] = L(DistributedCheckpointer_Distill)() 26 | 27 | 28 | def register_ckpt_type(): 29 | cs = ConfigStore.instance() 30 | cs.store(group="ckpt_type", package="checkpoint.type", name="dcp", node=DISTRIBUTED_CHECKPOINTER) 31 | cs.store(group="ckpt_type", package="checkpoint.type", name="dcp_distill", node=DISTRIBUTED_CHECKPOINTER_DISTILL) 32 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/yaml_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import yaml 17 | 18 | try: 19 | from yaml import CDumper as Dumper # type: ignore 20 | from yaml import CLoader as Loader # type: ignore 21 | except ImportError: 22 | from yaml import Dumper, Loader # type: ignore 23 | 24 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler # isort:skip 25 | 26 | 27 | class YamlHandler(BaseFileHandler): 28 | def load_from_fileobj(self, file, **kwargs): 29 | kwargs.setdefault("Loader", Loader) 30 | return yaml.load(file, **kwargs) 31 | 32 | def dump_to_fileobj(self, obj, file, **kwargs): 33 | kwargs.setdefault("Dumper", Dumper) 34 | yaml.dump(obj, file, **kwargs) 35 | 36 | def dump_to_str(self, obj, **kwargs): 37 | kwargs.setdefault("Dumper", Dumper) 38 | return yaml.dump(obj, **kwargs) 39 | -------------------------------------------------------------------------------- /rcm/modules/denoiser_scaling.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import torch 17 | 18 | 19 | class RectifiedFlow_TrigFlowWrapper: 20 | def __init__(self, sigma_data: float = 1.0, t_scaling_factor: float = 1.0): 21 | assert abs(sigma_data - 1.0) < 1e-6, "sigma_data must be 1.0 for RectifiedFlowScaling" 22 | self.t_scaling_factor = t_scaling_factor 23 | 24 | def __call__(self, trigflow_t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 25 | trigflow_t = trigflow_t.to(torch.float64) 26 | c_skip = 1 / (torch.cos(trigflow_t) + torch.sin(trigflow_t)) 27 | c_out = -1 * torch.sin(trigflow_t) / (torch.cos(trigflow_t) + torch.sin(trigflow_t)) 28 | c_in = 1 / (torch.cos(trigflow_t) + torch.sin(trigflow_t)) 29 | c_noise = (torch.sin(trigflow_t) / (torch.cos(trigflow_t) + torch.sin(trigflow_t))) * self.t_scaling_factor 30 | return c_skip, c_out, c_in, c_noise 31 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/tarfile_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import tarfile 17 | 18 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler 19 | 20 | 21 | class TarHandler(BaseFileHandler): 22 | str_like = False 23 | 24 | def load_from_fileobj(self, file, mode="r|*", **kwargs): 25 | return tarfile.open(fileobj=file, mode=mode, **kwargs) 26 | 27 | def load_from_path(self, filepath, mode="r|*", **kwargs): 28 | return tarfile.open(filepath, mode=mode, **kwargs) 29 | 30 | def dump_to_fileobj(self, obj, file, mode="w", **kwargs): 31 | with tarfile.open(fileobj=file, mode=mode) as tar: 32 | tar.add(obj, **kwargs) 33 | 34 | def dump_to_path(self, obj, filepath, mode="w", **kwargs): 35 | with tarfile.open(filepath, mode=mode) as tar: 36 | tar.add(obj, **kwargs) 37 | 38 | def dump_to_str(self, obj, **kwargs): 39 | raise NotImplementedError 40 | -------------------------------------------------------------------------------- /imaginaire/utils/device.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import math 17 | import os 18 | 19 | import pynvml 20 | 21 | 22 | class Device: 23 | _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore 24 | 25 | def __init__(self, device_idx: int): 26 | super().__init__() 27 | self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) 28 | 29 | def get_name(self) -> str: 30 | return pynvml.nvmlDeviceGetName(self.handle) 31 | 32 | def get_cpu_affinity(self) -> list[int]: 33 | affinity_string = "" 34 | for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements): 35 | # assume nvml returns list of 64 bit ints 36 | affinity_string = f"{j:064b}" + affinity_string 37 | affinity_list = [int(x) for x in affinity_string] 38 | affinity_list.reverse() # so core 0 is in 0th element of list 39 | return [i for i, e in enumerate(affinity_list) if e != 0] 40 | -------------------------------------------------------------------------------- /rcm/configs/defaults/dataloader.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | from hydra.core.config_store import ConfigStore 18 | 19 | from imaginaire.lazy_config import LazyCall as L 20 | from rcm.datasets.webdataset import create_dataloader 21 | 22 | 23 | DUMMY_DATALOADER = L(torch.utils.data.DataLoader)(dataset=lambda: torch.utils.data.TensorDataset(torch.empty(0, 1), torch.empty(0))) 24 | 25 | WEBDATASET_LOADER = L(create_dataloader)( 26 | tar_path_pattern="/path/to/dataset/shard_*.tar", 27 | batch_size=1, 28 | num_workers=8, 29 | shuffle_buffer=1000, 30 | prefetch_factor=2, 31 | ) 32 | 33 | 34 | def register_dataloader(): 35 | cs = ConfigStore() 36 | cs.store(group="data_train", package="dataloader_train", name="dummy", node=DUMMY_DATALOADER) 37 | cs.store(group="data_train", package="dataloader_train", name="webdataset", node=WEBDATASET_LOADER) 38 | cs.store(group="data_val", package="dataloader_val", name="dummy", node=DUMMY_DATALOADER) 39 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/csv_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import csv 17 | from io import StringIO 18 | 19 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler 20 | 21 | 22 | class CsvHandler(BaseFileHandler): 23 | def load_from_fileobj(self, file, **kwargs): 24 | del kwargs 25 | reader = csv.reader(file) 26 | return list(reader) 27 | 28 | def dump_to_fileobj(self, obj, file, **kwargs): 29 | del kwargs 30 | writer = csv.writer(file) 31 | if not all(isinstance(row, list) for row in obj): 32 | raise ValueError("Each row must be a list") 33 | writer.writerows(obj) 34 | 35 | def dump_to_str(self, obj, **kwargs): 36 | del kwargs 37 | output = StringIO() 38 | writer = csv.writer(output) 39 | if not all(isinstance(row, list) for row in obj): 40 | raise ValueError("Each row must be a list") 41 | writer.writerows(obj) 42 | return output.getvalue() 43 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/pickle_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import pickle 17 | from io import BytesIO 18 | from typing import Any 19 | 20 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler 21 | 22 | 23 | class PickleHandler(BaseFileHandler): 24 | str_like = False 25 | 26 | def load_from_fileobj(self, file: BytesIO, **kwargs): 27 | return pickle.load(file, **kwargs) 28 | 29 | def load_from_path(self, filepath, **kwargs): 30 | return super().load_from_path(filepath, mode="rb", **kwargs) 31 | 32 | def dump_to_str(self, obj, **kwargs): 33 | kwargs.setdefault("protocol", 2) 34 | return pickle.dumps(obj, **kwargs) 35 | 36 | def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs): 37 | kwargs.setdefault("protocol", 2) 38 | pickle.dump(obj, file, **kwargs) 39 | 40 | def dump_to_path(self, obj, filepath, **kwargs): 41 | with open(filepath, "wb") as f: 42 | pickle.dump(obj, f, **kwargs) 43 | -------------------------------------------------------------------------------- /imaginaire/utils/parallel_state_helper.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ 17 | This module contains various helper functions designed to extend the functionality of parallel states within the MCore library. 18 | 19 | MCore is a third-party library that is infrequently updated and may introduce backward compatibility issues in our codebase, such as changes in function signatures or missing / new functions in new versions. 20 | 21 | To mitigate these issues, this module provides stable functions that ensure the imaginaire codebase remains compatible with different versions of MCore. 22 | """ 23 | 24 | try: 25 | from megatron.core import parallel_state 26 | except ImportError: 27 | print("Megatron is not installed, is_tp_cp_pp_rank0 functions will not work.") 28 | 29 | 30 | def is_tp_cp_pp_rank0(): 31 | return ( 32 | parallel_state.get_tensor_model_parallel_rank() == 0 33 | and parallel_state.get_pipeline_model_parallel_rank() == 0 34 | and parallel_state.get_context_parallel_rank() == 0 35 | ) 36 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/base.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from abc import ABCMeta, abstractmethod 17 | 18 | 19 | class BaseFileHandler(metaclass=ABCMeta): 20 | # `str_like` is a flag to indicate whether the type of file object is 21 | # str-like object or bytes-like object. Pickle only processes bytes-like 22 | # objects but json only processes str-like object. If it is str-like 23 | # object, `StringIO` will be used to process the buffer. 
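    # For example, JsonHandler and YamlHandler keep the default str_like = True,
    # while PickleHandler, TorchHandler, and PandasHandler override it to False
    # because their payloads are binary.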
24 | str_like = True 25 | 26 | @abstractmethod 27 | def load_from_fileobj(self, file, **kwargs): 28 | pass 29 | 30 | @abstractmethod 31 | def dump_to_fileobj(self, obj, file, **kwargs): 32 | pass 33 | 34 | @abstractmethod 35 | def dump_to_str(self, obj, **kwargs): 36 | pass 37 | 38 | def load_from_path(self, filepath, mode="r", **kwargs): 39 | with open(filepath, mode) as f: 40 | return self.load_from_fileobj(f, **kwargs) 41 | 42 | def dump_to_path(self, obj, filepath, mode="w", **kwargs): 43 | with open(filepath, mode) as f: 44 | self.dump_to_fileobj(obj, f, **kwargs) 45 | -------------------------------------------------------------------------------- /rcm/configs/defaults/checkpoint.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from hydra.core.config_store import ConfigStore 17 | 18 | from imaginaire.config import CheckpointConfig, ObjectStoreConfig 19 | 20 | s3_object_store = ObjectStoreConfig( 21 | enabled=True, 22 | credentials="credentials/s3_checkpoint.secret", 23 | bucket="checkpoints-us-east-1", 24 | ) 25 | 26 | CHECKPOINT_LOCAL = CheckpointConfig( 27 | save_to_object_store=ObjectStoreConfig(), 28 | save_iter=1000, 29 | load_from_object_store=ObjectStoreConfig(), 30 | load_path="", 31 | load_training_state=False, 32 | strict_resume=True, 33 | ) 34 | 35 | CHECKPOINT_S3 = CheckpointConfig( 36 | save_to_object_store=s3_object_store, 37 | save_iter=1000, 38 | load_from_object_store=s3_object_store, 39 | load_path="", 40 | load_training_state=False, 41 | strict_resume=True, 42 | ) 43 | 44 | 45 | def register_checkpoint(): 46 | cs = ConfigStore.instance() 47 | cs.store(group="checkpoint", package="checkpoint", name="local", node=CHECKPOINT_LOCAL) 48 | cs.store(group="checkpoint", package="checkpoint", name="s3", node=CHECKPOINT_S3) 49 | -------------------------------------------------------------------------------- /rcm/datasets/webdataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import webdataset as wds 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | def dict_collation_fn(samples): 8 | if not samples: 9 | return {} 10 | 11 | keys = samples[0].keys() 12 | batched_dict = {key: [] for key in keys} 13 | 14 | for sample in samples: 15 | for key in keys: 16 | batched_dict[key].append(sample[key]) 17 | 18 | for key in keys: 19 | if isinstance(batched_dict[key][0], torch.Tensor): 20 | batched_dict[key] = torch.stack(batched_dict[key]) 21 | 22 | return batched_dict 23 | 24 | 25 | def create_dataloader( 26 | tar_path_pattern, # e.g., "/path/to/dataset/shard_*.tar" 27 | batch_size, 28 | num_workers=8, 29 | shuffle_buffer=1000, 30 | prefetch_factor=2, 31 | ): 32 | shards = glob.glob(tar_path_pattern) 33 | if not shards: 34 | raise 
FileNotFoundError(f"No files found with pattern '{tar_path_pattern}'") 35 | 36 | dataset = wds.DataPipeline( 37 | wds.SimpleShardList(shards), 38 | # this shuffles the shards 39 | wds.shuffle(1000), 40 | wds.split_by_node, 41 | wds.split_by_worker, 42 | wds.tarfile_to_samples(), 43 | # this shuffles the samples in memory 44 | wds.shuffle(shuffle_buffer), 45 | wds.decode(wds.handle_extension("pt", wds.torch_loads)), 46 | wds.rename(latents="latent.pt", t5_text_embeddings="embed.pt", prompts="prompt.txt"), 47 | wds.batched(batch_size, partial=False, collation_fn=dict_collation_fn), 48 | ) 49 | 50 | dataloader = DataLoader( 51 | dataset, 52 | batch_size=None, 53 | shuffle=False, 54 | num_workers=num_workers, 55 | pin_memory=True, 56 | prefetch_factor=prefetch_factor, 57 | ) 58 | 59 | return dataloader 60 | -------------------------------------------------------------------------------- /rcm/datasets/utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # "width:height" 17 | IMAGE_RES_SIZE_INFO: dict[str, tuple[int, int]] = { 18 | "1024": {"1:1": (1024, 1024), "4:3": (1168, 880), "3:4": (880, 1168), "16:9": (1360, 768), "9:16": (768, 1360)}, 19 | "720": {"1:1": (960, 960), "4:3": (960, 704), "3:4": (704, 960), "16:9": (1280, 704), "9:16": (704, 1280)}, 20 | "512": {"1:1": (512, 512), "4:3": (640, 512), "3:4": (512, 640), "16:9": (640, 384), "9:16": (384, 640)}, 21 | "480": {"1:1": (480, 480), "4:3": (640, 480), "3:4": (480, 640), "16:9": (768, 432), "9:16": (432, 768)}, 22 | } 23 | 24 | # "width:height" 25 | VIDEO_RES_SIZE_INFO: dict[str, tuple[int, int]] = { 26 | "720": {"1:1": (960, 960), "4:3": (960, 704), "3:4": (704, 960), "16:9": (1280, 704), "9:16": (704, 1280)}, 27 | "512": {"1:1": (512, 512), "4:3": (640, 512), "3:4": (512, 640), "16:9": (640, 384), "9:16": (384, 640)}, 28 | "480": {"1:1": (480, 480), "4:3": (640, 480), "3:4": (480, 640), "16:9": (768, 432), "9:16": (432, 768)}, 29 | "480p": {"1:1": (640, 640), "4:3": (640, 480), "3:4": (480, 640), "16:9": (832, 480), "9:16": (480, 832)}, 30 | "720p": {"1:1": (960, 960), "4:3": (960, 720), "3:4": (720, 960), "16:9": (1280, 720), "9:16": (720, 1280)}, 31 | } 32 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/json_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import json 17 | 18 | import numpy as np 19 | 20 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler 21 | 22 | 23 | def set_default(obj): 24 | """Set default json values for non-serializable values. 25 | 26 | It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. 27 | It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, 28 | etc.) into plain numbers of plain python built-in types. 29 | """ 30 | if isinstance(obj, (set, range)): 31 | return list(obj) 32 | elif isinstance(obj, np.ndarray): 33 | return obj.tolist() 34 | elif isinstance(obj, np.generic): 35 | return obj.item() 36 | raise TypeError(f"{type(obj)} is unsupported for json dump") 37 | 38 | 39 | class JsonHandler(BaseFileHandler): 40 | def load_from_fileobj(self, file): 41 | return json.load(file) 42 | 43 | def dump_to_fileobj(self, obj, file, **kwargs): 44 | kwargs.setdefault("default", set_default) 45 | json.dump(obj, file, **kwargs) 46 | 47 | def dump_to_str(self, obj, **kwargs): 48 | kwargs.setdefault("default", set_default) 49 | return json.dumps(obj, **kwargs) 50 | -------------------------------------------------------------------------------- /scripts/dcp_to_pth.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
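# A minimal usage sketch for the DCP-to-PyTorch conversion defined below,
# assuming the default arguments in parse_arguments(); the checkpoint paths are
# illustrative and not shipped with the repository.
#
#   python scripts/dcp_to_pth.py \
#       --dcp_checkpoint_dir checkpoints/iter_000010000/model \
#       --save_path saved_model.pt
#
# The output is a plain state dict whose "net_ema.*" keys are renamed to
# "net.*" and whose floating-point tensors are cast to bfloat16. A quick sanity
# check on the saved file could look like:
#
#   import torch
#   sd = torch.load("saved_model.pt", map_location="cpu")
#   assert all(k.startswith("net.") for k in sd)
#   assert all(v.dtype == torch.bfloat16 for v in sd.values() if v.is_floating_point())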
15 | 16 | import argparse 17 | 18 | import torch 19 | from torch.distributed.checkpoint import FileSystemReader 20 | from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner 21 | from torch.distributed.checkpoint.state_dict_loader import _load_state_dict 22 | 23 | 24 | def parse_arguments() -> argparse.Namespace: 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--dcp_checkpoint_dir", type=str, default="checkpoints/iter_000010000/model") 27 | parser.add_argument("--save_path", type=str, default="saved_model.pt") 28 | return parser.parse_args() 29 | 30 | 31 | if __name__ == "__main__": 32 | args = parse_arguments() 33 | storage_reader = FileSystemReader(args.dcp_checkpoint_dir) 34 | 35 | sd = {} 36 | _load_state_dict(sd, storage_reader=storage_reader, planner=_EmptyStateDictLoadPlanner(), no_dist=True) 37 | new_sd = {} 38 | for k, v in sd.items(): 39 | if k.startswith("net_ema."): 40 | new_key = k.replace("net_ema.", "net.") 41 | # Save in bf16 precision 42 | if v.is_floating_point(): 43 | new_sd[new_key] = v.to(torch.bfloat16) 44 | else: 45 | new_sd[new_key] = v 46 | torch.save(new_sd, args.save_path) 47 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/backends/base_backend.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import os.path as osp 18 | from abc import ABCMeta, abstractmethod 19 | 20 | 21 | def mkdir_or_exist(dir_name, mode=0o777): 22 | if dir_name == "": 23 | return 24 | dir_name = osp.expanduser(dir_name) 25 | os.makedirs(dir_name, mode=mode, exist_ok=True) 26 | 27 | 28 | def has_method(obj, method): 29 | return hasattr(obj, method) and callable(getattr(obj, method)) 30 | 31 | 32 | class BaseStorageBackend(metaclass=ABCMeta): 33 | """Abstract class of storage backends. 34 | 35 | All backends need to implement two apis: :meth:`get()` and 36 | :meth:`get_text()`. 37 | 38 | - :meth:`get()` reads the file as a byte stream. 39 | - :meth:`get_text()` reads the file as texts. 40 | """ 41 | 42 | # a flag to indicate whether the backend can create a symlink for a file 43 | # This attribute will be deprecated in future. 
44 | _allow_symlink = False 45 | 46 | @property 47 | def allow_symlink(self): 48 | return self._allow_symlink 49 | 50 | @property 51 | def name(self): 52 | return self.__class__.__name__ 53 | 54 | @abstractmethod 55 | def get(self, filepath): 56 | pass 57 | 58 | @abstractmethod 59 | def get_text(self, filepath): 60 | pass 61 | -------------------------------------------------------------------------------- /imaginaire/callbacks/manual_gc.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import gc 17 | 18 | from imaginaire.callbacks.every_n import EveryN 19 | from imaginaire.utils import log 20 | 21 | 22 | class ManualGarbageCollection(EveryN): 23 | """ 24 | Disable auto gc and manually trigger garbage collection every N iterations 25 | It is super useful for large scale training to reduce gpu sync time! 26 | Can reach 50% speedup. 27 | 28 | It is important to note that this callback only disables gc in main process and have auto gc enabled in subprocesses. 29 | 30 | We start disable gc after warm_up iterations to avoid disabling gc in subprocesses, such as dataloader, which can cause OOM 31 | """ 32 | 33 | def __init__(self, *args, warm_up: int = 5, **kwargs): 34 | kwargs["barrier_after_run"] = False 35 | super().__init__(*args, **kwargs) 36 | 37 | self.counter = 0 38 | self.warm = warm_up 39 | 40 | def every_n_impl(self, trainer, model, data_batch, output_batch, loss, iteration): 41 | del trainer, model, data_batch, output_batch, loss 42 | self.counter += 1 43 | if self.counter < self.warm: 44 | return 45 | if self.counter == self.warm: 46 | gc.disable() 47 | log.critical("Garbage collection disabled") 48 | 49 | gc.collect(1) 50 | -------------------------------------------------------------------------------- /rcm/configs/defaults/optimizer.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
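# A minimal sketch of how the lazily-defined optimizer nodes below can be
# resolved outside of Hydra (assumed usage; during training the trainer fills
# the PLACEHOLDER model and instantiates the node itself, and `my_module` here
# is a stand-in for a real nn.Module):
#
#   import copy
#   from imaginaire.lazy_config import instantiate
#   from rcm.configs.defaults.optimizer import FusedAdamWConfig
#
#   cfg = copy.deepcopy(FusedAdamWConfig)   # a LazyDict / DictConfig, not an optimizer yet
#   cfg.lr = 5e-5                           # same effect as a Hydra override optimizer.lr=5e-5
#   cfg.model = my_module                   # the PLACEHOLDER must be filled before instantiation
#   optimizer = instantiate(cfg)            # builds FusedAdam via get_base_optimizer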
15 | 16 | from hydra.core.config_store import ConfigStore 17 | 18 | from imaginaire.lazy_config import PLACEHOLDER, LazyDict 19 | from imaginaire.lazy_config import LazyCall as L 20 | from rcm.utils.optim_instantiate_dtensor import get_base_optimizer 21 | 22 | AdamWConfig = L(get_base_optimizer)( 23 | model=PLACEHOLDER, 24 | lr=1e-4, 25 | weight_decay=0.1, 26 | betas=[0.9, 0.99], 27 | optim_type="adamw", 28 | eps=1e-8, 29 | ) 30 | 31 | FusedAdamWConfig: LazyDict = L(get_base_optimizer)( 32 | model=PLACEHOLDER, 33 | lr=1e-4, 34 | weight_decay=0.1, 35 | betas=[0.9, 0.99], 36 | optim_type="fusedadam", 37 | eps=1e-8, 38 | master_weights=True, 39 | capturable=True, 40 | ) 41 | 42 | 43 | def register_optimizer(): 44 | cs = ConfigStore.instance() 45 | cs.store(group="optimizer", package="optimizer", name="fusedadamw", node=FusedAdamWConfig) 46 | cs.store(group="optimizer", package="optimizer", name="adamw", node=AdamWConfig) 47 | 48 | 49 | def register_optimizer_fake_score(): 50 | cs = ConfigStore.instance() 51 | cs.store(group="optimizer_fake_score", package="model.config.optimizer_fake_score", name="fusedadamw", node=FusedAdamWConfig) 52 | cs.store(group="optimizer_fake_score", package="model.config.optimizer_fake_score", name="adamw", node=AdamWConfig) 53 | -------------------------------------------------------------------------------- /rcm/tokenizers/interface.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from abc import ABC, abstractmethod 17 | 18 | import torch 19 | 20 | 21 | class VideoTokenizerInterface(ABC): 22 | def __init__(self): # noqa: B027 23 | pass 24 | 25 | @abstractmethod 26 | def reset_dtype(self): 27 | """ 28 | Reset the dtype of the model to the dtype its weights were trained with or quantized to. 
29 | """ 30 | pass 31 | 32 | @abstractmethod 33 | def encode(self, state: torch.Tensor) -> torch.Tensor: 34 | pass 35 | 36 | @abstractmethod 37 | def decode(self, latent: torch.Tensor) -> torch.Tensor: 38 | pass 39 | 40 | @abstractmethod 41 | def get_latent_num_frames(self, num_pixel_frames: int) -> int: 42 | pass 43 | 44 | @abstractmethod 45 | def get_pixel_num_frames(self, num_latent_frames: int) -> int: 46 | pass 47 | 48 | @property 49 | @abstractmethod 50 | def spatial_compression_factor(self): 51 | pass 52 | 53 | @property 54 | @abstractmethod 55 | def temporal_compression_factor(self): 56 | pass 57 | 58 | @property 59 | @abstractmethod 60 | def spatial_resolution(self): 61 | pass 62 | 63 | @property 64 | @abstractmethod 65 | def pixel_chunk_duration(self): 66 | pass 67 | 68 | @property 69 | @abstractmethod 70 | def latent_chunk_duration(self): 71 | pass 72 | 73 | @property 74 | @abstractmethod 75 | def latent_ch(self) -> int: 76 | pass 77 | 78 | @property 79 | def is_chunk_overlap(self): 80 | return False 81 | 82 | @property 83 | def is_causal(self): 84 | return True 85 | -------------------------------------------------------------------------------- /rcm/callbacks/compile_tokenizer.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | from imaginaire.utils import log 19 | from imaginaire.utils.callback import Callback 20 | 21 | 22 | class CompileTokenizer(Callback): 23 | def __init__(self, enabled: bool = False, compile_after_iterations: int = 4, dynamic: bool = False): 24 | super().__init__() 25 | self.enabled = enabled 26 | self.compiled = False 27 | self.compile_after_iterations = compile_after_iterations 28 | self.skip_counter = 0 29 | self.dynamic = dynamic # If there are issues with constant recompilations you may set this value to None or True 30 | 31 | def on_training_step_start(self, model, data_batch: dict[str, torch.Tensor], iteration: int = 0) -> None: 32 | if not self.enabled or self.compiled: 33 | return 34 | 35 | if isinstance(model.tokenizer, torch.jit.ScriptModule): 36 | log.critical(f"The Tokenizer model {type(model.tokenizer)} is a JIT model, which is not compilable. 
The Tokenizer will not be compiled.") 37 | 38 | if self.skip_counter == self.compile_after_iterations: 39 | try: 40 | # PyTorch >= 2.7 41 | torch._dynamo.config.recompile_limit = 32 42 | except AttributeError: 43 | try: 44 | torch._dynamo.config.cache_size_limit = 32 45 | except AttributeError: 46 | log.warning("Tokenizer compilation requested, but Torch Dynamo is unavailable – skipping compilation.") 47 | self.enabled = False 48 | return 49 | 50 | model.tokenizer.encode = torch.compile(model.tokenizer.encode, dynamic=self.dynamic) 51 | model.tokenizer.decode = torch.compile(model.tokenizer.decode, dynamic=self.dynamic) 52 | self.compiled = True 53 | self.skip_counter += 1 54 | -------------------------------------------------------------------------------- /rcm/utils/selective_activation_checkpoint.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from dataclasses import dataclass 17 | from enum import Enum 18 | 19 | import torch 20 | 21 | try: 22 | from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts, noop_context_fn 23 | except ImportError: 24 | CheckpointPolicy = None 25 | 26 | mm_only_save_list = { 27 | torch.ops.aten.mm.default, 28 | torch.ops.aten._scaled_dot_product_efficient_attention.default, 29 | torch.ops.aten._scaled_dot_product_flash_attention.default, 30 | torch.ops.aten.addmm.default, 31 | } 32 | 33 | 34 | class CheckpointMode(str, Enum): 35 | """ 36 | Enum for the different checkpoint modes. 37 | """ 38 | 39 | NONE = "none" 40 | MM_ONLY = "mm_only" 41 | BLOCK_WISE = "block_wise" 42 | 43 | def __str__(self) -> str: 44 | # Optional: makes print() show just the value 45 | return self.value 46 | 47 | 48 | def mm_only_policy(ctx, func, *args, **kwargs): 49 | """ 50 | In newer flash-attn and TE versions, FA2 shows up in the list of ops with the name of 'flash_attn._flash_attn_forward'. 51 | However, FA2 is much slower (2-3x) than FA3 or cuDNN kernel. Registering cuDNN kernel would require heavy changes in TE code. 52 | That's why the best option is to use FA3 with small modifications to flash_attn_interface.py to register FA3 as PyTorch op. 
53 | """ 54 | to_save = func in mm_only_save_list or "flash_attn" in str(func) 55 | return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE 56 | 57 | 58 | def mm_only_context_fn(): 59 | return create_selective_checkpoint_contexts(mm_only_policy) 60 | 61 | 62 | @dataclass 63 | class SACConfig: 64 | mode: str = "mm_only" 65 | every_n_blocks: int = 1 66 | 67 | def get_context_fn(self): 68 | if self.mode == CheckpointMode.MM_ONLY: 69 | return mm_only_context_fn 70 | elif self.mode == CheckpointMode.BLOCK_WISE: 71 | return noop_context_fn 72 | else: 73 | raise ValueError(f"Invalid mode: {self.mode}") 74 | -------------------------------------------------------------------------------- /imaginaire/lazy_config/registry.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import pydoc 17 | from typing import Any 18 | 19 | from fvcore.common.registry import Registry # for backward compatibility. 20 | 21 | """ 22 | ``Registry`` and `locate` provide ways to map a string (typically found 23 | in config files) to callable objects. 24 | """ 25 | 26 | __all__ = ["Registry", "locate"] 27 | 28 | 29 | def _convert_target_to_string(t: Any) -> str: 30 | """ 31 | Inverse of ``locate()``. 32 | 33 | Args: 34 | t: any object with ``__module__`` and ``__qualname__`` 35 | """ 36 | module, qualname = t.__module__, t.__qualname__ 37 | 38 | # Compress the path to this object, e.g. ``module.submodule._impl.class`` 39 | # may become ``module.submodule.class``, if the later also resolves to the same 40 | # object. This simplifies the string, and also is less affected by moving the 41 | # class implementation. 42 | module_parts = module.split(".") 43 | for k in range(1, len(module_parts)): 44 | prefix = ".".join(module_parts[:k]) 45 | candidate = f"{prefix}.{qualname}" 46 | try: 47 | if locate(candidate) is t: 48 | return candidate 49 | except ImportError: 50 | pass 51 | return f"{module}.{qualname}" 52 | 53 | 54 | def locate(name: str) -> Any: 55 | """ 56 | Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``, 57 | such as "module.submodule.class_name". 58 | 59 | Raise Exception if it cannot be found. 60 | """ 61 | obj = pydoc.locate(name) 62 | 63 | # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly 64 | # by pydoc.locate. Try a private function from hydra. 
65 | if obj is None: 66 | try: 67 | # from hydra.utils import get_method - will print many errors 68 | from hydra.utils import _locate 69 | except ImportError as e: 70 | raise ImportError(f"Cannot dynamically locate object {name}!") from e 71 | else: 72 | obj = _locate(name) # it raises if fails 73 | 74 | return obj 75 | -------------------------------------------------------------------------------- /imaginaire/lazy_config/omegaconf_patch.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Any 17 | 18 | from omegaconf import OmegaConf 19 | from omegaconf.base import DictKeyType, SCMode 20 | from omegaconf.dictconfig import DictConfig # pragma: no cover 21 | 22 | 23 | def to_object(cfg: Any) -> dict[DictKeyType, Any] | list[Any] | None | str | Any: 24 | """ 25 | Converts an OmegaConf configuration object to a native Python container (dict or list), unless 26 | the configuration is specifically created by LazyCall, in which case the original configuration 27 | is returned directly. 28 | 29 | This function serves as a modification of the original `to_object` method from OmegaConf, 30 | preventing DictConfig objects created by LazyCall from being automatically converted to Python 31 | dictionaries. This ensures that configurations meant to be lazily evaluated retain their intended 32 | structure and behavior. 33 | 34 | Differences from OmegaConf's original `to_object`: 35 | - Adds a check at the beginning to return the configuration unchanged if it is created by LazyCall. 36 | 37 | Reference: 38 | - Original OmegaConf `to_object` method: https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595 39 | 40 | Args: 41 | cfg (Any): The OmegaConf configuration object to convert. 42 | 43 | Returns: 44 | Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: The converted Python container if 45 | `cfg` is not a LazyCall created configuration, otherwise the unchanged `cfg`. 46 | 47 | Examples: 48 | >>> cfg = DictConfig({"key": "value", "_target_": "Model"}) 49 | >>> to_object(cfg) 50 | DictConfig({"key": "value", "_target_": "Model"}) 51 | 52 | >>> cfg = DictConfig({"list": [1, 2, 3]}) 53 | >>> to_object(cfg) 54 | {'list': [1, 2, 3]} 55 | """ 56 | if isinstance(cfg, DictConfig) and "_target_" in cfg.keys(): 57 | return cfg 58 | 59 | return OmegaConf.to_container( 60 | cfg=cfg, 61 | resolve=True, 62 | throw_on_missing=True, 63 | enum_to_str=False, 64 | structured_config_mode=SCMode.INSTANTIATE, 65 | ) 66 | -------------------------------------------------------------------------------- /imaginaire/lazy_config/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | 18 | from omegaconf import OmegaConf 19 | 20 | from imaginaire.lazy_config.instantiate import instantiate 21 | from imaginaire.lazy_config.lazy import LazyCall, LazyConfig, LazyDict 22 | from imaginaire.lazy_config.omegaconf_patch import to_object 23 | 24 | OmegaConf.to_object = to_object 25 | 26 | PLACEHOLDER = None 27 | 28 | __all__ = ["PLACEHOLDER", "LazyCall", "LazyConfig", "LazyDict", "instantiate"] 29 | 30 | 31 | DOC_BUILDING = os.getenv("_DOC_BUILDING", False) # set in docs/conf.py 32 | 33 | 34 | def fixup_module_metadata(module_name, namespace, keys=None): 35 | """ 36 | Fix the __qualname__ of module members to be their exported api name, so 37 | when they are referenced in docs, sphinx can find them. Reference: 38 | https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241 39 | """ 40 | if not DOC_BUILDING: 41 | return 42 | seen_ids = set() 43 | 44 | def fix_one(qualname, name, obj): 45 | # avoid infinite recursion (relevant when using 46 | # typing.Generic, for example) 47 | if id(obj) in seen_ids: 48 | return 49 | seen_ids.add(id(obj)) 50 | 51 | mod = getattr(obj, "__module__", None) 52 | if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")): 53 | obj.__module__ = module_name 54 | # Modules, unlike everything else in Python, put fully-qualitied 55 | # names into their __name__ attribute. We check for "." to avoid 56 | # rewriting these. 57 | if hasattr(obj, "__name__") and "." not in obj.__name__: 58 | obj.__name__ = name 59 | obj.__qualname__ = qualname 60 | if isinstance(obj, type): 61 | for attr_name, attr_value in obj.__dict__.items(): 62 | fix_one(objname + "." + attr_name, attr_name, attr_value) 63 | 64 | if keys is None: 65 | keys = namespace.keys() 66 | for objname in keys: 67 | if not objname.startswith("_"): 68 | obj = namespace[objname] 69 | fix_one(objname, objname, obj) 70 | 71 | 72 | fixup_module_metadata(__name__, globals(), __all__) 73 | del fixup_module_metadata 74 | -------------------------------------------------------------------------------- /rcm/configs/defaults/net.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from hydra.core.config_store import ConfigStore 17 | 18 | from imaginaire.lazy_config import LazyCall as L 19 | from imaginaire.lazy_config import LazyDict 20 | 21 | from rcm.networks.wan2pt1 import WanModel 22 | from rcm.networks.wan2pt1_jvp import WanModel_JVP 23 | 24 | wan2pt1_1pt3B_net_args = dict( 25 | dim=1536, 26 | eps=1e-06, 27 | ffn_dim=8960, 28 | freq_dim=256, 29 | in_dim=16, 30 | num_heads=12, 31 | num_layers=30, 32 | out_dim=16, 33 | text_len=512, 34 | ) 35 | 36 | wan2pt1_14B_net_args = dict( 37 | dim=5120, 38 | eps=1e-06, 39 | ffn_dim=13824, 40 | freq_dim=256, 41 | in_dim=16, 42 | num_heads=40, 43 | num_layers=40, 44 | out_dim=16, 45 | text_len=512, 46 | ) 47 | 48 | WAN2PT1_1PT3B_T2V: LazyDict = L(WanModel)(**wan2pt1_1pt3B_net_args, model_type="t2v") 49 | 50 | WAN2PT1_14B_T2V: LazyDict = L(WanModel)(**wan2pt1_14B_net_args, model_type="t2v") 51 | 52 | WAN2PT1_1PT3B_T2V_JVP: LazyDict = L(WanModel_JVP)(**wan2pt1_1pt3B_net_args, model_type="t2v") 53 | 54 | WAN2PT1_14B_T2V_JVP: LazyDict = L(WanModel_JVP)(**wan2pt1_14B_net_args, model_type="t2v") 55 | 56 | 57 | def register_net(): 58 | cs = ConfigStore.instance() 59 | cs.store(group="net", package="model.config.net", name="wan2pt1_1pt3B_t2v", node=WAN2PT1_1PT3B_T2V) 60 | cs.store(group="net", package="model.config.net", name="wan2pt1_14B_t2v", node=WAN2PT1_14B_T2V) 61 | cs.store(group="net", package="model.config.net", name="wan2pt1_1pt3B_t2v_jvp", node=WAN2PT1_1PT3B_T2V_JVP) 62 | cs.store(group="net", package="model.config.net", name="wan2pt1_14B_t2v_jvp", node=WAN2PT1_14B_T2V_JVP) 63 | 64 | 65 | def register_net_fake_score(): 66 | cs = ConfigStore.instance() 67 | cs.store(group="net_fake_score", package="model.config.net_fake_score", name="wan2pt1_1pt3B_t2v", node=WAN2PT1_1PT3B_T2V) 68 | cs.store(group="net_fake_score", package="model.config.net_fake_score", name="wan2pt1_14B_t2v", node=WAN2PT1_14B_T2V) 69 | 70 | 71 | def register_net_teacher(): 72 | cs = ConfigStore.instance() 73 | cs.store(group="net_teacher", package="model.config.net_teacher", name="wan2pt1_1pt3B_t2v", node=WAN2PT1_1PT3B_T2V) 74 | cs.store(group="net_teacher", package="model.config.net_teacher", name="wan2pt1_14B_t2v", node=WAN2PT1_14B_T2V) 75 | -------------------------------------------------------------------------------- /rcm/utils/optim_instantiate_dtensor.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
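# A minimal sketch of calling the helper below directly on a plain nn.Module
# (assumed standalone usage; during training the optimizer is normally built
# from the Hydra-registered optimizer configs instead):
#
#   import torch.nn as nn
#   from rcm.utils.optim_instantiate_dtensor import get_base_optimizer
#
#   net = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
#   opt = get_base_optimizer(net, lr=1e-4, weight_decay=0.1,
#                            optim_type="adamw", betas=(0.9, 0.99), eps=1e-8)
#   # all trainable parameters land in a single param group carrying the
#   # requested lr and weight_decay (the decay/no-decay split is computed,
#   # then the two lists are merged into one group)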
15 | 16 | import hydra 17 | import torch 18 | from omegaconf import ListConfig 19 | from torch import nn 20 | 21 | from rcm.utils.fused_adam_dtensor import FusedAdam 22 | from imaginaire.utils import log 23 | 24 | 25 | def get_regular_param_group(net: nn.Module): 26 | """ 27 | seperate the parameters of the network into two groups: decay and no_decay. 28 | based on nano_gpt codebase. 29 | """ 30 | param_dict = {pn: p for pn, p in net.named_parameters()} 31 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 32 | 33 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 34 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 35 | return decay_params, nodecay_params 36 | 37 | 38 | def get_base_optimizer( 39 | model: nn.Module, 40 | lr: float, 41 | weight_decay: float, 42 | optim_type: str = "adamw", 43 | **kwargs, 44 | ) -> torch.optim.Optimizer: 45 | net_decay_param, net_nodecay_param = get_regular_param_group(model) 46 | 47 | num_decay_params = sum(p.numel() for p in net_decay_param) 48 | num_nodecay_params = sum(p.numel() for p in net_nodecay_param) 49 | net_param_total = num_decay_params + num_nodecay_params 50 | log.critical(f"total num parameters : {net_param_total:,}") 51 | 52 | param_group = [ 53 | { 54 | "params": net_decay_param + net_nodecay_param, 55 | "lr": lr, 56 | "weight_decay": weight_decay, 57 | }, 58 | ] 59 | 60 | if optim_type == "adamw": 61 | opt_cls = torch.optim.AdamW 62 | elif optim_type == "fusedadam": 63 | opt_cls = FusedAdam 64 | else: 65 | raise ValueError(f"Unknown optimizer type: {optim_type}") 66 | 67 | for k, v in kwargs.items(): 68 | if isinstance(v, ListConfig): 69 | kwargs[k] = list(v) 70 | 71 | return opt_cls(param_group, **kwargs) 72 | 73 | 74 | def get_base_scheduler( 75 | optimizer: torch.optim.Optimizer, 76 | model: nn.Module, 77 | scheduler_config: dict, 78 | ): 79 | net_scheduler = hydra.utils.instantiate(scheduler_config) 80 | net_scheduler.model = model 81 | 82 | return torch.optim.lr_scheduler.LambdaLR( 83 | optimizer, 84 | lr_lambda=[ 85 | net_scheduler.schedule, 86 | ], 87 | ) 88 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/jsonl_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import json 17 | from typing import IO 18 | 19 | import numpy as np 20 | 21 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler 22 | 23 | 24 | def set_default(obj): 25 | """Set default json values for non-serializable values. 26 | 27 | It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. 28 | It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, 29 | etc.) into plain numbers of plain python built-in types. 
30 | """ 31 | if isinstance(obj, (set, range)): 32 | return list(obj) 33 | elif isinstance(obj, np.ndarray): 34 | return obj.tolist() 35 | elif isinstance(obj, np.generic): 36 | return obj.item() 37 | raise TypeError(f"{type(obj)} is unsupported for json dump") 38 | 39 | 40 | class JsonlHandler(BaseFileHandler): 41 | """Handler for JSON lines (JSONL) files.""" 42 | 43 | def load_from_fileobj(self, file: IO[bytes]): 44 | """Load JSON objects from a newline-delimited JSON (JSONL) file object. 45 | 46 | Returns: 47 | A list of Python objects loaded from each JSON line. 48 | """ 49 | data = [] 50 | for line in file: 51 | line = line.strip() 52 | if not line: 53 | continue # skip empty lines if any 54 | data.append(json.loads(line)) 55 | return data 56 | 57 | def dump_to_fileobj(self, obj: IO[bytes], file, **kwargs): 58 | """Dump a list of objects to a newline-delimited JSON (JSONL) file object. 59 | 60 | Args: 61 | obj: A list (or iterable) of objects to dump line by line. 62 | """ 63 | kwargs.setdefault("default", set_default) 64 | for item in obj: 65 | file.write(json.dumps(item, **kwargs) + "\n") 66 | 67 | def dump_to_str(self, obj, **kwargs): 68 | """Dump a list of objects to a newline-delimited JSON (JSONL) string.""" 69 | kwargs.setdefault("default", set_default) 70 | lines = [json.dumps(item, **kwargs) for item in obj] 71 | return "\n".join(lines) 72 | 73 | 74 | if __name__ == "__main__": 75 | from imaginaire.utils.easy_io import easy_io 76 | 77 | easy_io.dump([1, 2, 3], "test.jsonl", file_format="jsonl") 78 | print(easy_io.load("test.jsonl")) 79 | easy_io.dump([{"key1": 1, "key2": 2}, {"key1": 3, "key2": 4}], "test.jsonl", file_format="jsonl") 80 | print(easy_io.load("test.jsonl")) 81 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/np_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from io import BytesIO 17 | from typing import IO, Any 18 | 19 | import numpy as np 20 | 21 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler 22 | 23 | 24 | class NumpyHandler(BaseFileHandler): 25 | str_like = False 26 | 27 | def load_from_fileobj(self, file: IO[bytes], **kwargs) -> Any: 28 | """ 29 | Load a NumPy array from a file-like object. 30 | 31 | Parameters: 32 | file (IO[bytes]): The file-like object containing the NumPy array data. 33 | **kwargs: Additional keyword arguments passed to `np.load`. 34 | 35 | Returns: 36 | numpy.ndarray: The loaded NumPy array. 37 | """ 38 | return np.load(file, **kwargs) 39 | 40 | def load_from_path(self, filepath: str, **kwargs) -> Any: 41 | """ 42 | Load a NumPy array from a file path. 43 | 44 | Parameters: 45 | filepath (str): The path to the file to load. 
46 | **kwargs: Additional keyword arguments passed to `np.load`. 47 | 48 | Returns: 49 | numpy.ndarray: The loaded NumPy array. 50 | """ 51 | return super().load_from_path(filepath, mode="rb", **kwargs) 52 | 53 | def dump_to_str(self, obj: np.ndarray, **kwargs) -> str: 54 | """ 55 | Serialize a NumPy array to a string in binary format. 56 | 57 | Parameters: 58 | obj (np.ndarray): The NumPy array to serialize. 59 | **kwargs: Additional keyword arguments passed to `np.save`. 60 | 61 | Returns: 62 | str: The serialized NumPy array as a string. 63 | """ 64 | with BytesIO() as f: 65 | np.save(f, obj, **kwargs) 66 | return f.getvalue() 67 | 68 | def dump_to_fileobj(self, obj: np.ndarray, file: IO[bytes], **kwargs): 69 | """ 70 | Dump a NumPy array to a file-like object. 71 | 72 | Parameters: 73 | obj (np.ndarray): The NumPy array to dump. 74 | file (IO[bytes]): The file-like object to which the array is dumped. 75 | **kwargs: Additional keyword arguments passed to `np.save`. 76 | """ 77 | np.save(file, obj, **kwargs) 78 | 79 | def dump_to_path(self, obj: np.ndarray, filepath: str, **kwargs): 80 | """ 81 | Dump a NumPy array to a file path. 82 | 83 | Parameters: 84 | obj (np.ndarray): The NumPy array to dump. 85 | filepath (str): The file path where the array should be saved. 86 | **kwargs: Additional keyword arguments passed to `np.save`. 87 | """ 88 | with open(filepath, "wb") as f: 89 | np.save(f, obj, **kwargs) 90 | -------------------------------------------------------------------------------- /rcm/datasets/merge_tar_shards.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Merge many small TAR shards into new larger shards. 4 | Input: a directory containing shards like shard-00000.tar, shard-00001.tar, ... 5 | Output: new shards with target_shard_size samples each. 6 | 7 | Usage: 8 | python merge_tar_shards.py \ 9 | --input_dir path/to/small_shards \ 10 | --output_dir path/to/large_shards \ 11 | --target_shard_size 5000 12 | 13 | Each sample is assumed to be a group of tar members with common prefix, e.g. 14 | 000000.jpg 15 | 000000.txt 16 | 000001.jpg 17 | 000001.txt 18 | This script groups files by prefix before writing. 19 | """ 20 | 21 | import os 22 | import tarfile 23 | import argparse 24 | from collections import defaultdict 25 | 26 | 27 | def read_samples_from_tar(tar_path): 28 | """Yield samples as {key: {filename: bytes}}. 29 | Group consecutive members by prefix before '.' 30 | """ 31 | samples = defaultdict(dict) 32 | with tarfile.open(tar_path, "r") as tar: 33 | for m in tar.getmembers(): 34 | if not m.isfile(): 35 | continue 36 | name = os.path.basename(m.name) 37 | prefix = name.split(".")[0] 38 | f = tar.extractfile(m) 39 | if f is None: 40 | continue 41 | samples[prefix][name] = f.read() 42 | for key, files in samples.items(): 43 | yield key, files 44 | 45 | 46 | def write_shard(samples, out_path): 47 | """Write a list of samples to one TAR shard. 
48 | samples: list of (key, {filename: bytes}) 49 | """ 50 | with tarfile.open(out_path, "w") as tar: 51 | for key, file_dict in samples: 52 | for fname, data in file_dict.items(): 53 | info = tarfile.TarInfo(name=f"{key}/{fname}") 54 | info.size = len(data) 55 | tar.addfile(info, fileobj=BytesIO(data)) 56 | 57 | 58 | from io import BytesIO 59 | 60 | def merge_shards(input_dir, output_dir, target_shard_size): 61 | os.makedirs(output_dir, exist_ok=True) 62 | shard_paths = sorted( 63 | [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith(".tar")] 64 | ) 65 | 66 | current_samples = [] 67 | new_shard_index = 0 68 | 69 | for tar_path in shard_paths: 70 | for key, files in read_samples_from_tar(tar_path): 71 | current_samples.append((key, files)) 72 | if len(current_samples) >= target_shard_size: 73 | out_path = os.path.join(output_dir, f"shard-{new_shard_index:05d}.tar") 74 | write_shard(current_samples, out_path) 75 | print(f"Wrote {out_path}, {len(current_samples)} samples") 76 | new_shard_index += 1 77 | current_samples = [] 78 | 79 | # final remainder 80 | if current_samples: 81 | out_path = os.path.join(output_dir, f"shard-{new_shard_index:05d}.tar") 82 | write_shard(current_samples, out_path) 83 | print(f"Wrote {out_path}, {len(current_samples)} samples") 84 | 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument("--input_dir", required=True) 89 | parser.add_argument("--output_dir", required=True) 90 | parser.add_argument("--target_shard_size", type=int, required=True) 91 | args = parser.parse_args() 92 | 93 | merge_shards(args.input_dir, args.output_dir, args.target_shard_size) 94 | -------------------------------------------------------------------------------- /imaginaire/callbacks/low_precision.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
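# A minimal sketch of the cast that the callback below applies to each data
# batch, assuming a model whose `precision` attribute is torch.bfloat16
# (integer tensors such as token ids are deliberately left untouched):
#
#   import torch
#   precision = torch.bfloat16
#   data = {"latents": torch.randn(2, 16, 8, 8), "ids": torch.arange(2)}
#   data = {k: v.to(dtype=precision) if torch.is_floating_point(v) else v
#           for k, v in data.items()}
#   # data["latents"].dtype == torch.bfloat16, data["ids"].dtype == torch.int64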
15 | 16 | import torch 17 | 18 | from imaginaire.model import ImaginaireModel 19 | from imaginaire.trainer import ImaginaireTrainer 20 | from imaginaire.utils import distributed 21 | from imaginaire.utils.callback import Callback 22 | from imaginaire.config import Config 23 | from imaginaire.utils.misc import get_local_tensor_if_DTensor 24 | 25 | 26 | def update_master_weights(optimizer: torch.optim.Optimizer): 27 | if getattr(optimizer, "master_weights", False) and optimizer.param_groups_master is not None: 28 | params, master_params = [], [] 29 | for group, group_master in zip(optimizer.param_groups, optimizer.param_groups_master): 30 | for p, p_master in zip(group["params"], group_master["params"]): 31 | params.append(get_local_tensor_if_DTensor(p.data)) 32 | master_params.append(p_master.data) 33 | torch._foreach_copy_(params, master_params) 34 | 35 | 36 | class LowPrecisionCallback(Callback): 37 | """The callback class handling low precision training 38 | 39 | Config with non-primitive type makes it difficult to override the option. 40 | The callback gets precision from model.precision instead. 41 | It also auto disabled when using fp32. 42 | """ 43 | 44 | def __init__(self, config: Config, trainer: ImaginaireTrainer, update_iter: int): 45 | self.update_iter = update_iter 46 | 47 | def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None: 48 | assert model.precision in [ 49 | torch.bfloat16, 50 | torch.float16, 51 | torch.half, 52 | ], "LowPrecisionCallback must use a low precision dtype." 53 | self.precision_type = model.precision 54 | 55 | def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None: 56 | for k, v in data.items(): 57 | if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]): 58 | data[k] = v.to(dtype=self.precision_type) 59 | 60 | def on_validation_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None: 61 | for k, v in data.items(): 62 | if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]): 63 | data[k] = v.to(dtype=self.precision_type) 64 | 65 | def on_before_zero_grad( 66 | self, 67 | model_ddp: distributed.DistributedDataParallel, 68 | optimizer: torch.optim.Optimizer, 69 | scheduler: torch.optim.lr_scheduler.LRScheduler, 70 | iteration: int = 0, 71 | ) -> None: 72 | if iteration % self.update_iter == 0: 73 | update_master_weights(optimizer) 74 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/backends/http_backend.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
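# A minimal usage sketch for the HTTP backend defined below; the URLs are
# placeholders, not real endpoints:
#
#   from imaginaire.utils.easy_io.backends.http_backend import HTTPBackend
#
#   backend = HTTPBackend()
#   raw = backend.get("https://example.com/sample.bin")          # bytes
#   txt = backend.get_text("https://example.com/prompt.txt")     # str, utf-8 by default
#   with backend.get_local_path("https://example.com/image.jpg") as path:
#       ...  # `path` is a temporary local copy that is deleted on exit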
15 | 16 | import os 17 | import tempfile 18 | from collections.abc import Generator 19 | from contextlib import contextmanager 20 | from pathlib import Path 21 | from urllib.request import urlopen 22 | 23 | from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend 24 | 25 | 26 | class HTTPBackend(BaseStorageBackend): 27 | """HTTP and HTTPS storage bachend.""" 28 | 29 | def get(self, filepath: str) -> bytes: 30 | """Read bytes from a given ``filepath``. 31 | 32 | Args: 33 | filepath (str): Path to read data. 34 | 35 | Returns: 36 | bytes: Expected bytes object. 37 | 38 | Examples: 39 | >>> backend = HTTPBackend() 40 | >>> backend.get('http://path/of/file') 41 | b'hello world' 42 | """ 43 | return urlopen(filepath).read() 44 | 45 | def get_text(self, filepath, encoding="utf-8") -> str: 46 | """Read text from a given ``filepath``. 47 | 48 | Args: 49 | filepath (str): Path to read data. 50 | encoding (str): The encoding format used to open the ``filepath``. 51 | Defaults to 'utf-8'. 52 | 53 | Returns: 54 | str: Expected text reading from ``filepath``. 55 | 56 | Examples: 57 | >>> backend = HTTPBackend() 58 | >>> backend.get_text('http://path/of/file') 59 | 'hello world' 60 | """ 61 | return urlopen(filepath).read().decode(encoding) 62 | 63 | @contextmanager 64 | def get_local_path(self, filepath: str) -> Generator[str | Path, None, None]: 65 | """Download a file from ``filepath`` to a local temporary directory, 66 | and return the temporary path. 67 | 68 | ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It 69 | can be called with ``with`` statement, and when exists from the 70 | ``with`` statement, the temporary path will be released. 71 | 72 | Args: 73 | filepath (str): Download a file from ``filepath``. 74 | 75 | Yields: 76 | Iterable[str]: Only yield one temporary path. 77 | 78 | Examples: 79 | >>> backend = HTTPBackend() 80 | >>> # After existing from the ``with`` clause, 81 | >>> # the path will be removed 82 | >>> with backend.get_local_path('http://path/of/file') as path: 83 | ... # do something here 84 | """ 85 | try: 86 | f = tempfile.NamedTemporaryFile(delete=False) 87 | f.write(self.get(filepath)) 88 | f.close() 89 | yield f.name 90 | finally: 91 | os.remove(f.name) 92 | -------------------------------------------------------------------------------- /rcm/configs/defaults/callbacks.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
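# A minimal sketch of extending the callback groups defined below with an extra
# entry before registration (the custom callback class and the "basic_plus"
# group name are hypothetical, not part of this repository):
#
#   from hydra.core.config_store import ConfigStore
#   from imaginaire.lazy_config import LazyCall as L
#   from my_project.callbacks import MyIterTimer      # hypothetical callback class
#
#   MY_CALLBACKS = dict(
#       **BASIC_CALLBACKS,                             # reuse the defaults below
#       my_iter_timer=L(MyIterTimer)(every_n=100),
#   )
#   ConfigStore.instance().store(
#       group="callbacks", package="trainer.callbacks", name="basic_plus", node=MY_CALLBACKS
#   )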
15 | 16 | from hydra.core.config_store import ConfigStore 17 | 18 | from imaginaire.lazy_config import PLACEHOLDER 19 | from imaginaire.lazy_config import LazyCall as L 20 | from imaginaire.callbacks.manual_gc import ManualGarbageCollection 21 | from imaginaire.callbacks.low_precision import LowPrecisionCallback 22 | from rcm.callbacks.compile_tokenizer import CompileTokenizer 23 | from rcm.callbacks.dataloading_monitor import DetailedDataLoadingSpeedMonitor 24 | from rcm.callbacks.device_monitor import DeviceMonitor 25 | from rcm.callbacks.grad_clip import GradClip 26 | from rcm.callbacks.heart_beat import HeartBeat 27 | from rcm.callbacks.iter_speed import IterSpeed 28 | from rcm.callbacks.wandb_log import WandbCallback 29 | from rcm.callbacks.every_n_draw_distill import EveryNDrawSample_Distill 30 | 31 | BASIC_CALLBACKS = dict( 32 | grad_clip=L(GradClip)(), 33 | low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), 34 | iter_speed=L(IterSpeed)( 35 | every_n="${trainer.logging_iter}", 36 | save_s3_every_log_n=10, 37 | ), 38 | heart_beat=L(HeartBeat)( 39 | every_n=10, 40 | update_interval_in_minute=20, 41 | ), 42 | device_monitor=L(DeviceMonitor)( 43 | every_n="${trainer.logging_iter}", 44 | upload_every_n_mul=10, 45 | ), 46 | manual_gc=L(ManualGarbageCollection)(every_n=5), 47 | compile_tokenizer=L(CompileTokenizer)( 48 | enabled=True, 49 | compile_after_iterations=4, 50 | dynamic=False, # If there are issues with constant recompilations you may set this value to None or True 51 | ), 52 | ) 53 | 54 | SPEED_CALLBACKS = dict( 55 | dataloader_speed=L(DetailedDataLoadingSpeedMonitor)( 56 | every_n="${trainer.logging_iter}", 57 | ), 58 | ) 59 | 60 | WANDB_CALLBACK = dict( 61 | wandb=L(WandbCallback)( 62 | logging_iter_multipler=1, 63 | save_logging_iter_multipler=10, 64 | ), 65 | wandb_10x=L(WandbCallback)( 66 | logging_iter_multipler=10, 67 | save_logging_iter_multipler=1, 68 | ), 69 | ) 70 | 71 | VIZ_ONLINE_SAMPLING_DISTILL_CALLBACKS = dict( 72 | every_n_sample_reg=L(EveryNDrawSample_Distill)( 73 | every_n=5000, 74 | ), 75 | every_n_sample_ema=L(EveryNDrawSample_Distill)( 76 | every_n=5000, 77 | is_ema=True, 78 | ), 79 | ) 80 | 81 | 82 | def register_callbacks(): 83 | cs = ConfigStore.instance() 84 | cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS) 85 | cs.store(group="callbacks", package="trainer.callbacks", name="dataloading_speed", node=SPEED_CALLBACKS) 86 | cs.store(group="callbacks", package="trainer.callbacks", name="wandb", node=WANDB_CALLBACK) 87 | cs.store(group="callbacks", package="trainer.callbacks", name="viz_online_sampling_distill", node=VIZ_ONLINE_SAMPLING_DISTILL_CALLBACKS) 88 | -------------------------------------------------------------------------------- /imaginaire/callbacks/every_n.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from abc import abstractmethod 17 | 18 | import torch 19 | 20 | from imaginaire.model import ImaginaireModel 21 | from imaginaire.trainer import ImaginaireTrainer 22 | from imaginaire.utils import distributed, log 23 | from imaginaire.utils.callback import Callback 24 | 25 | 26 | class EveryN(Callback): 27 | def __init__( 28 | self, 29 | every_n: int | None = None, 30 | step_size: int = 1, 31 | barrier_after_run: bool = True, 32 | run_at_start: bool = False, 33 | ) -> None: 34 | """Constructor for `EveryN`. 35 | 36 | Args: 37 | every_n (int): Frequency with which callback is run during training. 38 | step_size (int): Size of iteration step count. Default 1. 39 | barrier_after_run (bool): Whether to have a distributed barrier after each execution. Default True, to avoid timeouts. 40 | run_at_start (bool): Whether to run at the beginning of training. Default False. 41 | """ 42 | self.every_n = every_n 43 | if self.every_n == 0: 44 | log.warning( 45 | f"every_n is set to 0. Callback {self.__class__.__name__} will be invoked only once in the beginning of the training. Calls happens on_training_step_end will be skipped." 46 | ) 47 | 48 | self.step_size = step_size 49 | self.barrier_after_run = barrier_after_run 50 | self.run_at_start = run_at_start 51 | 52 | def on_training_step_end( 53 | self, 54 | model: ImaginaireModel, 55 | data_batch: dict[str, torch.Tensor], 56 | output_batch: dict[str, torch.Tensor], 57 | loss: torch.Tensor, 58 | iteration: int = 0, 59 | ) -> None: 60 | # every_n = 0 is a special case which means every_n_impl will be called only once in the beginning of the training 61 | if self.every_n != 0: 62 | trainer = self.trainer 63 | global_step = iteration // self.step_size 64 | should_run = (iteration == 1 and self.run_at_start) or ( 65 | global_step % self.every_n == 0 66 | ) # (self.every_n - 1) 67 | if should_run: 68 | log.debug(f"Callback {self.__class__.__name__} fired on train_batch_end step {global_step}") 69 | self.every_n_impl(trainer, model, data_batch, output_batch, loss, iteration) 70 | log.debug(f"Callback {self.__class__.__name__} finished on train_batch_end step {global_step}") 71 | # add necessary barrier to avoid timeout 72 | if self.barrier_after_run: 73 | distributed.barrier() 74 | 75 | @abstractmethod 76 | def every_n_impl( 77 | self, 78 | trainer: ImaginaireTrainer, 79 | model: ImaginaireModel, 80 | data_batch: dict[str, torch.Tensor], 81 | output_batch: dict[str, torch.Tensor], 82 | loss: torch.Tensor, 83 | iteration: int, 84 | ) -> None: ... 85 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import argparse 17 | import importlib 18 | import os 19 | 20 | from loguru import logger as logging 21 | 22 | from imaginaire.config import Config, pretty_print_overrides 23 | from imaginaire.lazy_config import instantiate 24 | from imaginaire.lazy_config.lazy import LazyConfig 25 | from imaginaire.utils import distributed 26 | from imaginaire.utils.config_helper import get_config_module, override 27 | 28 | 29 | @logging.catch(reraise=True) 30 | def launch(config: Config, args: argparse.Namespace) -> None: 31 | # Need to initialize the distributed environment before calling config.validate() because it tries to synchronize 32 | # a buffer across ranks. If you don't do this, then you end up allocating a bunch of buffers on rank 0, and also that 33 | # check doesn't actually do anything. 34 | distributed.init() 35 | 36 | # Check that the config is valid 37 | config.validate() 38 | # Freeze the config so developers don't change it during training. 39 | config.freeze() # type: ignore 40 | trainer = instantiate(config.trainer.type, config=config) 41 | # Create the model 42 | model = instantiate(config.model) 43 | # Create the dataloaders. 44 | dataloader_train = instantiate(config.dataloader_train) 45 | dataloader_val = instantiate(config.dataloader_val) 46 | # Start training 47 | trainer.train(model, dataloader_train, dataloader_val) 48 | 49 | 50 | if __name__ == "__main__": 51 | # Usage: torchrun --nproc_per_node=1 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiments=predict2_video2world_training_2b_cosmos_nemo_assets 52 | 53 | # Get the config file from the input arguments. 54 | parser = argparse.ArgumentParser(description="Training") 55 | parser.add_argument("--config", help="Path to the config file", required=True) 56 | parser.add_argument( 57 | "opts", 58 | help=""" 59 | Modify config options at the end of the command. For Yacs configs, use 60 | space-separated "PATH.KEY VALUE" pairs. 61 | For python-based LazyConfig, use "path.key=value". 62 | """.strip(), 63 | default=None, 64 | nargs=argparse.REMAINDER, 65 | ) 66 | parser.add_argument( 67 | "--dryrun", 68 | action="store_true", 69 | help="Do a dry run without training. Useful for debugging the config.", 70 | ) 71 | args = parser.parse_args() 72 | config_module = get_config_module(args.config) 73 | config = importlib.import_module(config_module).make_config() 74 | config = override(config, args.opts) 75 | if args.dryrun: 76 | logging.info("Config:\n" + config.pretty_print(use_color=True) + "\n" + pretty_print_overrides(args.opts, use_color=True)) 77 | os.makedirs(config.job.path_local, exist_ok=True) 78 | LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") 79 | print(f"{config.job.path_local}/config.yaml") 80 | else: 81 | # Launch the training job. 82 | launch(config, args) 83 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/registry_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
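
# A minimal sketch (not repository code) of driving the same config flow as
# scripts/train.py above without torchrun, mirroring its --dryrun branch. The config
# path and the override strings are illustrative assumptions; overrides use the
# LazyConfig "path.key=value" form described in the argparse help.
import importlib
import os

from imaginaire.lazy_config.lazy import LazyConfig
from imaginaire.utils.config_helper import get_config_module, override

config_module = get_config_module("rcm/configs/registry_distill.py")
config = importlib.import_module(config_module).make_config()
config = override(config, ["trainer.max_iter=1000", "job.group=smoke_test"])
os.makedirs(config.job.path_local, exist_ok=True)
# Resolve and save the composed config for inspection, exactly as --dryrun does.
LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml")
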
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler 17 | from imaginaire.utils.easy_io.handlers.byte_handler import ByteHandler 18 | from imaginaire.utils.easy_io.handlers.csv_handler import CsvHandler 19 | from imaginaire.utils.easy_io.handlers.gzip_handler import GzipHandler 20 | from imaginaire.utils.easy_io.handlers.imageio_video_handler import ImageioVideoHandler 21 | from imaginaire.utils.easy_io.handlers.json_handler import JsonHandler 22 | from imaginaire.utils.easy_io.handlers.jsonl_handler import JsonlHandler 23 | from imaginaire.utils.easy_io.handlers.np_handler import NumpyHandler 24 | from imaginaire.utils.easy_io.handlers.pandas_handler import PandasHandler 25 | from imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler 26 | from imaginaire.utils.easy_io.handlers.pil_handler import PILHandler 27 | from imaginaire.utils.easy_io.handlers.tarfile_handler import TarHandler 28 | from imaginaire.utils.easy_io.handlers.torch_handler import TorchHandler 29 | from imaginaire.utils.easy_io.handlers.torchjit_handler import TorchJitHandler 30 | from imaginaire.utils.easy_io.handlers.txt_handler import TxtHandler 31 | from imaginaire.utils.easy_io.handlers.yaml_handler import YamlHandler 32 | 33 | file_handlers = { 34 | "json": JsonHandler(), 35 | "yaml": YamlHandler(), 36 | "yml": YamlHandler(), 37 | "pickle": PickleHandler(), 38 | "pkl": PickleHandler(), 39 | "tar": TarHandler(), 40 | "jit": TorchJitHandler(), 41 | "npy": NumpyHandler(), 42 | "txt": TxtHandler(), 43 | "csv": CsvHandler(), 44 | "pandas": PandasHandler(), 45 | "gz": GzipHandler(), 46 | "jsonl": JsonlHandler(), 47 | "byte": ByteHandler(), 48 | } 49 | 50 | for torch_type in ["pt", "pth", "ckpt"]: 51 | file_handlers[torch_type] = TorchHandler() 52 | for img_type in ["jpg", "jpeg", "png", "bmp", "gif"]: 53 | file_handlers[img_type] = PILHandler() 54 | file_handlers[img_type].format = img_type 55 | for video_type in ["mp4", "avi", "mov", "webm", "flv", "wmv"]: 56 | file_handlers[video_type] = ImageioVideoHandler() 57 | 58 | 59 | def _register_handler(handler, file_formats): 60 | """Register a handler for some file extensions. 61 | 62 | Args: 63 | handler (:obj:`BaseFileHandler`): Handler to be registered. 64 | file_formats (str or list[str]): File formats to be handled by this 65 | handler. 
66 | """ 67 | if not isinstance(handler, BaseFileHandler): 68 | raise TypeError(f"handler must be a child of BaseFileHandler, not {type(handler)}") 69 | if isinstance(file_formats, str): 70 | file_formats = [file_formats] 71 | if not all([isinstance(item, str) for item in file_formats]): 72 | raise TypeError("file_formats must be a str or a list of str") 73 | for ext in file_formats: 74 | file_handlers[ext] = handler 75 | 76 | 77 | def register_handler(file_formats, **kwargs): 78 | def wrap(cls): 79 | _register_handler(cls(**kwargs), file_formats) 80 | return cls 81 | 82 | return wrap 83 | -------------------------------------------------------------------------------- /rcm/callbacks/dataloading_monitor.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import time 17 | 18 | import numpy as np 19 | import torch 20 | import wandb 21 | 22 | from imaginaire.model import ImaginaireModel 23 | from imaginaire.utils import distributed 24 | from imaginaire.utils.callback import Callback 25 | from imaginaire.utils.easy_io import easy_io 26 | 27 | 28 | class DetailedDataLoadingSpeedMonitor(Callback): 29 | def __init__( 30 | self, 31 | every_n: int, 32 | step_size: int = 1, 33 | save_s3: bool = False, 34 | ): 35 | self.every_n = every_n 36 | self.step_size = step_size 37 | self.should_run = False 38 | self.start_dataloading_time = None 39 | self.dataloading_time = None 40 | self.name = self.__class__.__name__ 41 | self.save_s3 = save_s3 42 | self.time_delta_list = [] 43 | 44 | def on_before_dataloading(self, iteration: int = 0) -> None: 45 | # We want to run it one iteration before on_training_step_start should_run is set to True. 
46 | global_step = iteration // self.step_size 47 | self.should_run = (global_step + 1) % self.every_n == 0 48 | self.start_dataloading_time = time.time() 49 | 50 | def on_after_dataloading(self, iteration: int = 0) -> None: 51 | self.time_delta_list.append(time.time() - self.start_dataloading_time) 52 | 53 | def on_training_step_end( 54 | self, 55 | model: ImaginaireModel, 56 | data_batch: dict[str, torch.Tensor], 57 | output_batch: dict[str, torch.Tensor], 58 | loss: torch.Tensor, 59 | iteration: int = 0, 60 | ) -> None: 61 | if self.should_run: 62 | self.should_run = False 63 | cur_rank_mean, cur_rank_max = np.mean(self.time_delta_list), np.max(self.time_delta_list) 64 | self.time_delta_list = [] # Reset the list 65 | 66 | dataloading_time_gather_list = distributed.all_gather_tensor(torch.tensor([cur_rank_mean, cur_rank_max]).cuda()) 67 | wandb_info = {f"{self.name}_mean/dataloading_{k:03d}": v[0].item() for k, v in enumerate(dataloading_time_gather_list)} 68 | wandb_info.update({f"{self.name}_max/dataloading_{k:03d}": v[1].item() for k, v in enumerate(dataloading_time_gather_list)}) 69 | mean_times = torch.stack(dataloading_time_gather_list)[:, 0] 70 | slowest_dataloading_rank_id = torch.argmax(mean_times) 71 | max_dataloading = torch.max(mean_times) 72 | wandb_info.update( 73 | { 74 | "slowest_rank/slowest_dataloading_rank": slowest_dataloading_rank_id.item(), 75 | "slowest_rank/slowest_dataloading_time": max_dataloading.item(), 76 | } 77 | ) 78 | 79 | if wandb.run: 80 | wandb.log(wandb_info, step=iteration) 81 | 82 | if self.save_s3 and distributed.is_rank0(): 83 | easy_io.dump( 84 | wandb_info, 85 | f"s3://rundir/{self.name}/iter_{iteration:09d}.yaml", 86 | ) 87 | -------------------------------------------------------------------------------- /rcm/utils/dtensor_helper.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from __future__ import annotations 17 | 18 | import itertools 19 | from typing import Any 20 | 21 | import torch 22 | import torch.distributed as dist 23 | from torch.distributed.device_mesh import DeviceMesh 24 | 25 | from imaginaire.utils.misc import get_local_tensor_if_DTensor 26 | 27 | 28 | class DTensorFastEmaModelUpdater: 29 | """ 30 | Similar as FastEmaModelUpdater 31 | """ 32 | 33 | def __init__(self): 34 | # Flag to indicate whether the cache is taken or not. 
Useful to avoid cache overwrite 35 | self.is_cached = False 36 | 37 | def copy_to(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module) -> None: 38 | with torch.no_grad(): 39 | for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): 40 | tgt_params.to_local().data.copy_(src_params.to_local().data) 41 | 42 | @torch.no_grad() 43 | def update_average(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module, beta: float = 0.9999) -> None: 44 | target_list = [] 45 | source_list = [] 46 | for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): 47 | assert tgt_params.dtype == torch.float32, f"EMA model only works in FP32 dtype, got {tgt_params.dtype} instead." 48 | target_list.append(tgt_params.to_local()) 49 | source_list.append(src_params.to_local().data) 50 | torch._foreach_mul_(target_list, beta) 51 | torch._foreach_add_(target_list, source_list, alpha=1.0 - beta) 52 | 53 | @torch.no_grad() 54 | def cache(self, parameters: Any, is_cpu: bool = False) -> None: 55 | assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?" 56 | device = "cpu" if is_cpu else "cuda" 57 | self.collected_params = [param.to_local().clone().to(device) for param in parameters] 58 | self.is_cached = True 59 | 60 | @torch.no_grad() 61 | def restore(self, parameters: Any) -> None: 62 | assert self.is_cached, "EMA cache is not taken yet." 63 | for c_param, param in zip(self.collected_params, parameters, strict=False): 64 | param.to_local().copy_(c_param.data.type_as(param.data)) 65 | self.collected_params = [] 66 | # Release the cache after we call restore 67 | self.is_cached = False 68 | 69 | 70 | def broadcast_dtensor_model_states(model: torch.nn.Module, mesh: DeviceMesh): 71 | """Broadcast model states from replicate mesh's rank 0.""" 72 | replicate_group = mesh.get_group("replicate") 73 | all_ranks = dist.get_process_group_ranks(replicate_group) 74 | if len(all_ranks) == 1: 75 | return 76 | 77 | for _, tensor in itertools.chain(model.named_parameters(), model.named_buffers()): 78 | # Get src rank which is the first rank in each replication group 79 | src_rank = all_ranks[0] 80 | # Broadcast the local tensor 81 | local_tensor = get_local_tensor_if_DTensor(tensor) 82 | dist.broadcast( 83 | local_tensor, 84 | src=src_rank, 85 | group=replicate_group, 86 | ) 87 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/pil_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
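
# A minimal sketch (not repository code) of the cache / copy_to / restore cycle of
# DTensorFastEmaModelUpdater from rcm/utils/dtensor_helper.py above, e.g. to evaluate
# with EMA weights. `model`, `ema_model`, and `run_eval` are assumed to be a
# DTensor-sharded module pair with identical structure and a user-supplied eval function.
from rcm.utils.dtensor_helper import DTensorFastEmaModelUpdater


def eval_with_ema_weights(model, ema_model, run_eval):
    updater = DTensorFastEmaModelUpdater()
    # Stash the current training weights (optionally on CPU) so they can be restored.
    updater.cache(model.parameters(), is_cpu=True)
    # Swap the EMA weights into the live model and run evaluation.
    updater.copy_to(src_model=ema_model, tgt_model=model)
    try:
        return run_eval(model)
    finally:
        # Put the original training weights back and release the cache.
        updater.restore(model.parameters())
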
15 | 16 | from typing import IO 17 | 18 | import numpy as np 19 | 20 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler 21 | 22 | try: 23 | from PIL import Image 24 | except ImportError: 25 | Image = None 26 | 27 | 28 | class PILHandler(BaseFileHandler): 29 | format: str 30 | str_like = False 31 | 32 | def load_from_fileobj( 33 | self, 34 | file: IO[bytes], 35 | fmt: str = "pil", 36 | size: int | tuple[int, int] | None = None, 37 | **kwargs, 38 | ): 39 | """ 40 | Load an image from a file-like object and return it in a specified format. 41 | 42 | Args: 43 | file (IO[bytes]): A file-like object containing the image data. 44 | fmt (str): The format to convert the image into. Options are \ 45 | 'numpy', 'np', 'npy', 'type' (all return numpy arrays), \ 46 | 'pil' (returns PIL Image), 'th', 'torch' (returns a torch tensor). 47 | size (Optional[Union[int, Tuple[int, int]]]): The new size of the image as a single integer \ 48 | or a tuple of (width, height). If specified, the image is resized accordingly. 49 | **kwargs: Additional keyword arguments that can be passed to conversion functions. 50 | 51 | Returns: 52 | Image data in the format specified by `fmt`. 53 | 54 | Raises: 55 | IOError: If the image cannot be loaded or processed. 56 | ValueError: If the specified format is unsupported. 57 | """ 58 | try: 59 | img = Image.open(file) 60 | img.load() # Explicitly load the image data 61 | if size is not None: 62 | if isinstance(size, int): 63 | size = ( 64 | size, 65 | size, 66 | ) # create a tuple if only one integer is provided 67 | img = img.resize(size, Image.ANTIALIAS) 68 | 69 | # Return the image in the requested format 70 | if fmt in ["numpy", "np", "npy"]: 71 | return np.array(img, **kwargs) 72 | if fmt == "pil": 73 | return img 74 | if fmt in ["th", "torch"]: 75 | import torch 76 | 77 | # Convert to tensor 78 | img_tensor = torch.from_numpy(np.array(img, **kwargs)) 79 | # Convert image from HxWxC to CxHxW 80 | if img_tensor.ndim == 3: 81 | img_tensor = img_tensor.permute(2, 0, 1) 82 | return img_tensor 83 | raise ValueError( 84 | "Unsupported format. Supported formats are 'numpy', 'np', 'npy', 'pil', 'th', and 'torch'." 85 | ) 86 | except Exception as e: 87 | raise OSError(f"Unable to load image: {e}") from e 88 | 89 | def dump_to_fileobj(self, obj, file: IO[bytes], **kwargs): 90 | if "format" not in kwargs: 91 | kwargs["format"] = self.format 92 | kwargs["format"] = "JPEG" if self.format.lower() == "jpg" else self.format.upper() 93 | obj.save(file, **kwargs) 94 | 95 | def dump_to_str(self, obj, **kwargs): 96 | raise NotImplementedError 97 | -------------------------------------------------------------------------------- /rcm/callbacks/grad_clip.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
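
# A minimal sketch (not repository code) of using PILHandler above through the
# extension-based registry in imaginaire/utils/easy_io/handlers/registry_utils.py.
# The example image path is taken from the examples/ folder of this repository.
from imaginaire.utils.easy_io.handlers.registry_utils import file_handlers

handler = file_handlers["jpg"]  # a PILHandler instance registered for .jpg files
with open("examples/i2v_input_1.jpg", "rb") as f:
    # fmt="torch" returns a CxHxW uint8 tensor; fmt="pil" (the default) returns a PIL image.
    image_tensor = handler.load_from_fileobj(f, fmt="torch")
print(image_tensor.shape)
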
15 | 16 | from dataclasses import dataclass 17 | from typing import List, Tuple 18 | 19 | import torch 20 | import wandb 21 | 22 | from imaginaire.utils import distributed 23 | from imaginaire.utils.callback import Callback 24 | 25 | 26 | @torch.jit.script 27 | def _fused_nan_to_num(params: List[torch.Tensor]): 28 | for param in params: 29 | torch.nan_to_num(param, nan=0.0, posinf=0.0, neginf=0.0, out=param) 30 | 31 | 32 | @dataclass 33 | class _MagnitudeRecord: 34 | state: float = 0 35 | iter_count: int = 0 36 | 37 | def reset(self) -> None: 38 | self.state = 0 39 | self.iter_count = 0 40 | 41 | def update(self, cur_state: torch.Tensor) -> None: 42 | self.state += cur_state 43 | self.iter_count += 1 44 | 45 | def get_stat(self) -> Tuple[float, float]: 46 | if self.iter_count > 0: 47 | avg_state = self.state / self.iter_count 48 | avg_state = avg_state.item() 49 | else: 50 | avg_state = 0 51 | self.reset() 52 | return avg_state 53 | 54 | 55 | class GradClip(Callback): 56 | """ 57 | This callback is used to clip the gradient norm of the model. 58 | It also logs the average gradient norm of the model to wandb. 59 | """ 60 | 61 | def __init__(self, clip_norm=1.0, force_finite: bool = True): 62 | self.clip_norm = clip_norm 63 | self.force_finite = force_finite 64 | 65 | self.img_mag_log = _MagnitudeRecord() 66 | self.video_mag_log = _MagnitudeRecord() 67 | self._cur_state = None 68 | 69 | def on_training_step_start( 70 | self, model, data_batch: dict[str, torch.Tensor], iteration: int = 0 71 | ) -> None: 72 | if model.is_image_batch(data_batch): 73 | self._cur_state = self.img_mag_log 74 | else: 75 | self._cur_state = self.video_mag_log 76 | 77 | def on_before_optimizer_step( 78 | self, 79 | model_ddp: distributed.DistributedDataParallel, 80 | optimizer: torch.optim.Optimizer, 81 | scheduler: torch.optim.lr_scheduler.LRScheduler, 82 | grad_scaler: torch.amp.GradScaler, 83 | iteration: int = 0, 84 | ) -> None: 85 | del optimizer, scheduler 86 | if isinstance(model_ddp, distributed.DistributedDataParallel): 87 | model = model_ddp.module 88 | else: 89 | model = model_ddp 90 | params = [] 91 | 92 | if self.force_finite: 93 | for param in model.parameters(): 94 | if param.grad is not None: 95 | params.append(param.grad) 96 | _fused_nan_to_num(params) 97 | 98 | total_norm = model.clip_grad_norm_(self.clip_norm) 99 | 100 | self._cur_state.update(total_norm) 101 | if iteration % self.config.trainer.logging_iter == 0: 102 | avg_img_mag, avg_video_mag = self.img_mag_log.get_stat(), self.video_mag_log.get_stat() 103 | if wandb.run: 104 | wandb.log( 105 | { 106 | "clip_grad_norm/image": avg_img_mag, 107 | "clip_grad_norm/video": avg_video_mag, 108 | "iteration": iteration, 109 | }, 110 | step=iteration, 111 | ) 112 | -------------------------------------------------------------------------------- /imaginaire/utils/wandb_util.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
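
# A minimal sketch (not repository code) of the two operations the GradClip callback
# above performs before the optimizer step, written against a plain nn.Module for
# illustration. The real callback calls the model's own clip_grad_norm_ (e.g. an
# FSDP-aware implementation); torch.nn.utils.clip_grad_norm_ is only a stand-in here.
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
loss = model(torch.randn(4, 8)).square().mean()
loss.backward()

# Zero out any non-finite gradient entries, as _fused_nan_to_num does.
for p in model.parameters():
    if p.grad is not None:
        torch.nan_to_num(p.grad, nan=0.0, posinf=0.0, neginf=0.0, out=p.grad)

# Clip to a total norm of 1.0; the returned value is the pre-clip gradient norm,
# which GradClip accumulates and logs to wandb.
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"pre-clip grad norm: {total_norm:.4f}")
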
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from __future__ import annotations 17 | 18 | import os 19 | from typing import TYPE_CHECKING 20 | 21 | import attrs 22 | import wandb 23 | import wandb.util 24 | from omegaconf import DictConfig 25 | 26 | from imaginaire.lazy_config.lazy import LazyConfig 27 | from imaginaire.utils import distributed, log 28 | from imaginaire.utils.easy_io import easy_io 29 | 30 | if TYPE_CHECKING: 31 | from imaginaire.config import CheckpointConfig, Config, JobConfig 32 | from imaginaire.model import ImaginaireModel 33 | 34 | 35 | @distributed.rank0_only 36 | def init_wandb(config: Config, model: ImaginaireModel) -> None: 37 | """Initialize Weights & Biases (wandb) logger. 38 | 39 | Args: 40 | config (Config): The config object for the Imaginaire codebase. 41 | model (ImaginaireModel): The PyTorch model. 42 | """ 43 | if isinstance(config.job, DictConfig): 44 | from imaginaire.config import JobConfig 45 | 46 | config_job = JobConfig(**config.job) 47 | else: 48 | config_job = config.job 49 | config_checkpoint = config.checkpoint 50 | # Try to fetch the W&B job ID for resuming training. 51 | wandb_id = _read_wandb_id(config_job, config_checkpoint) 52 | if wandb_id is None: 53 | # Generate a new W&B job ID. 54 | wandb_id = wandb.util.generate_id() 55 | _write_wandb_id(config_job, config_checkpoint, wandb_id=wandb_id) 56 | log.info(f"Generating new wandb ID: {wandb_id}") 57 | else: 58 | log.info(f"Resuming with existing wandb ID: {wandb_id}") 59 | # refactor config so that wandb better understands it 60 | local_safe_yaml_fp = LazyConfig.save_yaml(config, os.path.join(config_job.path_local, "config.yaml")) 61 | if os.path.exists(local_safe_yaml_fp): 62 | config_resolved = easy_io.load(local_safe_yaml_fp) 63 | else: 64 | config_resolved = attrs.asdict(config) 65 | # Initialize the wandb library. 66 | wandb.init( 67 | force=True, 68 | id=wandb_id, 69 | project=config_job.project, 70 | group=config_job.group, 71 | name=config_job.name, 72 | config=config_resolved, 73 | dir=config_job.path_local, 74 | resume="allow", 75 | mode=config_job.wandb_mode, 76 | ) 77 | 78 | 79 | def _read_wandb_id(config_job: JobConfig, config_checkpoint: CheckpointConfig) -> str | None: 80 | """Read the W&B job ID. If it doesn't exist, return None. 81 | 82 | Args: 83 | config_wandb (JobConfig): The config object for the W&B logger. 84 | config_checkpoint (CheckpointConfig): The config object for the checkpointer. 85 | 86 | Returns: 87 | wandb_id (str | None): W&B job ID. 88 | """ 89 | wandb_id = None 90 | wandb_id_path = f"{config_job.path_local}/wandb_id.txt" 91 | if os.path.isfile(wandb_id_path): 92 | wandb_id = open(wandb_id_path).read().strip() 93 | return wandb_id 94 | 95 | 96 | def _write_wandb_id(config_job: JobConfig, config_checkpoint: CheckpointConfig, wandb_id: str) -> None: 97 | """Write the generated W&B job ID. 98 | 99 | Args: 100 | config_wandb (JobConfig): The config object for the W&B logger. 101 | config_checkpoint (CheckpointConfig): The config object for the checkpointer. 102 | wandb_id (str): The W&B job ID. 
103 | """ 104 | content = f"{wandb_id}\n" 105 | wandb_id_path = f"{config_job.path_local}/wandb_id.txt" 106 | with open(wandb_id_path, "w") as file: 107 | file.write(content) 108 | -------------------------------------------------------------------------------- /rcm/configs/registry_distill.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Any, List 17 | 18 | import attrs 19 | 20 | from imaginaire import config 21 | from imaginaire.utils.config_helper import import_all_modules_from_package 22 | from rcm.configs.defaults.trainer import register_trainer 23 | from rcm.configs.defaults.checkpoint import register_checkpoint 24 | from rcm.configs.defaults.ema import register_ema 25 | from rcm.configs.defaults.optimizer import register_optimizer, register_optimizer_fake_score 26 | from rcm.configs.defaults.scheduler import register_scheduler 27 | from rcm.configs.defaults.conditioner import register_conditioner 28 | from rcm.configs.defaults.callbacks import register_callbacks 29 | from rcm.configs.defaults.ckpt_type import register_ckpt_type 30 | from rcm.configs.defaults.dataloader import register_dataloader 31 | from rcm.configs.defaults.tokenizer import register_tokenizer 32 | from rcm.configs.defaults.model import register_model 33 | from rcm.configs.defaults.net import register_net, register_net_fake_score, register_net_teacher 34 | 35 | 36 | @attrs.define(slots=False) 37 | class Config(config.Config): 38 | # default config groups that will be used unless overwritten 39 | # see config groups in registry.py 40 | defaults: List[Any] = attrs.field( 41 | factory=lambda: [ 42 | "_self_", 43 | {"trainer": "standard"}, 44 | {"data_train": "dummy"}, 45 | {"data_val": "dummy"}, 46 | {"optimizer": "fusedadamw"}, 47 | {"scheduler": "lambdalinear"}, 48 | {"callbacks": "basic"}, 49 | {"checkpoint": "local"}, 50 | {"ckpt_type": "dcp"}, 51 | {"model": "fsdp_t2v_distill_rcm"}, 52 | {"net": None}, 53 | {"net_teacher": None}, 54 | {"net_fake_score": None}, 55 | {"optimizer_fake_score": "fusedadamw"}, 56 | {"conditioner": "text_nodrop"}, 57 | {"ema": "power"}, 58 | {"tokenizer": "wan2pt1_tokenizer"}, 59 | # the list is with order, we need global experiment to be the last one 60 | {"experiment": None}, 61 | ] 62 | ) 63 | 64 | 65 | def make_config() -> Config: 66 | c = Config( 67 | model=None, 68 | optimizer=None, 69 | scheduler=None, 70 | dataloader_train=None, 71 | dataloader_val=None, 72 | ) 73 | 74 | # Specifying values through instances of attrs 75 | c.job.project = "rcm" # this decides the wandb project name 76 | c.job.group = "debug" 77 | c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}" 78 | 79 | c.trainer.max_iter = 400_000 80 | c.trainer.logging_iter = 100 81 | c.trainer.validation_iter = 100 82 | 
c.trainer.run_validation = False
83 |     c.trainer.callbacks = None
84 | 
85 |     # Call these functions to register config groups for advanced overriding. The order follows the default config groups.
86 |     register_trainer()
87 |     register_dataloader()
88 |     register_optimizer()
89 |     register_optimizer_fake_score()
90 |     register_scheduler()
91 |     register_callbacks()
92 |     register_checkpoint()
93 |     register_ckpt_type()
94 |     register_model()
95 |     register_net()
96 |     register_net_teacher()
97 |     register_net_fake_score()
98 |     register_conditioner()
99 |     register_ema()
100 |     register_tokenizer()
101 | 
102 |     # experiment configs are defined in the experiments folder
103 |     import_all_modules_from_package("rcm.configs.experiments.rcm", reload=True)
104 |     return c
105 | 
--------------------------------------------------------------------------------
/rcm/callbacks/iter_speed.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | import time
17 | 
18 | import torch
19 | import wandb
20 | from torch import Tensor
21 | 
22 | from imaginaire.callbacks.every_n import EveryN
23 | from imaginaire.model import ImaginaireModel
24 | from imaginaire.trainer import ImaginaireTrainer
25 | from imaginaire.utils import log
26 | from imaginaire.utils.distributed import rank0_only
27 | from imaginaire.utils.easy_io import easy_io
28 | 
29 | 
30 | class IterSpeed(EveryN):
31 |     """
32 |     Args:
33 |         hit_thres (int): Number of iterations to wait before logging.
34 |         save_s3 (bool): Whether to save to S3.
35 |         save_s3_every_log_n (int): Save to S3 every save_s3_every_log_n logging events, i.e. every save_s3_every_log_n * every_n global iterations.
36 |     """
37 | 
38 |     def __init__(self, *args, hit_thres: int = 5, save_s3: bool = False, save_s3_every_log_n: int = 10, **kwargs):
39 |         super().__init__(*args, **kwargs)
40 |         self.time = None
41 |         self.hit_counter = 0
42 |         self.hit_thres = hit_thres
43 |         self.save_s3 = save_s3
44 |         self.save_s3_every_log_n = save_s3_every_log_n
45 |         self.name = self.__class__.__name__
46 |         self.last_hit_time = time.time()
47 | 
48 |     def on_training_step_end(
49 |         self,
50 |         model: ImaginaireModel,
51 |         data_batch: dict[str, torch.Tensor],
52 |         output_batch: dict[str, torch.Tensor],
53 |         loss: torch.Tensor,
54 |         iteration: int = 0,
55 |     ) -> None:
56 |         if self.hit_counter < self.hit_thres:
57 |             log.info(
58 |                 f"Iteration {iteration}: "
59 |                 f"Hit counter: {self.hit_counter + 1}/{self.hit_thres} | "
60 |                 f"Loss: {loss.item():.4f} | "
61 |                 f"Time: {time.time() - self.last_hit_time:.2f}s"
62 |             )
63 |             self.hit_counter += 1
64 |             self.last_hit_time = time.time()
65 |             #! Useful for large-scale training: synchronizing here avoids OOM crashes in the first two iterations.
66 | torch.cuda.synchronize() 67 | return 68 | super().on_training_step_end(model, data_batch, output_batch, loss, iteration) 69 | 70 | @rank0_only 71 | def every_n_impl( 72 | self, 73 | trainer: ImaginaireTrainer, 74 | model: ImaginaireModel, 75 | data_batch: dict[str, Tensor], 76 | output_batch: dict[str, Tensor], 77 | loss: Tensor, 78 | iteration: int, 79 | ) -> None: 80 | if self.time is None: 81 | self.time = time.time() 82 | return 83 | cur_time = time.time() 84 | iter_speed = (cur_time - self.time) / self.every_n / self.step_size 85 | 86 | log.info(f"{iteration} : iter_speed {iter_speed:.2f} seconds per iteration | Loss: {loss.item():.4f}") 87 | 88 | if wandb.run: 89 | sample_counter = getattr(trainer, "sample_counter", iteration) 90 | wandb.log( 91 | { 92 | "timer/iter_speed": iter_speed, 93 | "sample_counter": sample_counter, 94 | }, 95 | step=iteration, 96 | ) 97 | self.time = cur_time 98 | if self.save_s3: 99 | if iteration % (self.save_s3_every_log_n * self.every_n) == 0: 100 | easy_io.dump( 101 | { 102 | "iter_speed": iter_speed, 103 | "iteration": iteration, 104 | }, 105 | f"s3://rundir/{self.name}/iter_{iteration:09d}.yaml", 106 | ) 107 | -------------------------------------------------------------------------------- /rcm/datasets/visualize_tar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import argparse 4 | import tarfile 5 | import torch 6 | from einops import rearrange 7 | from tqdm import tqdm 8 | 9 | from rcm.tokenizers.wan2pt1 import Wan2pt1VAEInterface 10 | from imaginaire.utils.io import save_image_or_video 11 | 12 | 13 | def main(args): 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | print(f"Using device: {device}") 16 | 17 | print("Loading VAE tokenizer...") 18 | try: 19 | tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) 20 | except Exception as e: 21 | print(f"Error loading VAE from {args.vae_path}: {e}") 22 | print("Please ensure the VAE model path is correct.") 23 | return 24 | 25 | print(f"Opening tar file: {args.tar_path}") 26 | if not os.path.exists(args.tar_path): 27 | print(f"Error: Tar file not found at {args.tar_path}") 28 | return 29 | 30 | latent_files = [] 31 | with tarfile.open(args.tar_path, "r") as tar: 32 | all_files = tar.getnames() 33 | latent_files = sorted([f for f in all_files if f.endswith(".latent.pt")]) 34 | 35 | if not latent_files: 36 | print("Error: No '.latent.pt' files found in the tar archive.") 37 | return 38 | 39 | print(f"Found {len(latent_files)} latent files in the archive.") 40 | 41 | if args.num_samples > 0: 42 | latent_files = latent_files[: args.num_samples] 43 | print(f"Decoding the first {len(latent_files)} samples.") 44 | 45 | decoded_videos = [] 46 | print("Decoding samples...") 47 | with tarfile.open(args.tar_path, "r") as tar: 48 | for member_name in tqdm(latent_files, desc="Decoding"): 49 | member_file = tar.extractfile(member_name) 50 | if member_file is None: 51 | print(f"Warning: Could not extract {member_name}. Skipping.") 52 | continue 53 | 54 | latent_tensor = torch.load(io.BytesIO(member_file.read()), map_location=device) 55 | 56 | if latent_tensor.dim() == 4: 57 | samples = latent_tensor.unsqueeze(0) 58 | else: 59 | samples = latent_tensor 60 | 61 | video = tokenizer.decode(samples.to(device, dtype=torch.bfloat16)) 62 | 63 | decoded_videos.append(video.float().cpu()) 64 | 65 | if not decoded_videos: 66 | print("No videos were decoded. 
Exiting.") 67 | return 68 | 69 | print("Stacking videos into a grid...") 70 | to_show = torch.stack(decoded_videos, dim=0) 71 | 72 | to_show = (1.0 + to_show.clamp(-1, 1)) / 2.0 73 | 74 | num_videos = to_show.shape[0] 75 | grid_cols = args.grid_cols 76 | grid_rows = (num_videos + grid_cols - 1) // grid_cols # 向上取整 77 | 78 | if num_videos != grid_rows * grid_cols: 79 | print(f"Warning: The number of videos ({num_videos}) doesn't fit a perfect {grid_rows}x{grid_cols} grid. Rearranging may be imperfect.") 80 | 81 | to_show = rearrange(to_show, "(rows cols) b c t h w -> c t (rows h) (cols b w)", cols=grid_cols) 82 | 83 | save_dir = os.path.dirname(args.save_path) 84 | if save_dir: 85 | os.makedirs(save_dir, exist_ok=True) 86 | 87 | print(f"Saving video grid to: {args.save_path}") 88 | save_image_or_video(to_show, args.save_path, fps=args.fps) 89 | 90 | print("Done!") 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser(description="Decode latent samples from a .tar file into a video grid.") 95 | 96 | parser.add_argument("--tar_path", type=str, required=True, help="Path to the input .tar file containing latent samples.") 97 | parser.add_argument("--save_path", type=str, default="preview.mp4", help="Path to save the output video file.") 98 | parser.add_argument("--vae_path", type=str, default="assets/checkpoints/Wan2.1_VAE.pth", help="Path to the Wan2.1 VAE model weights.") 99 | parser.add_argument("--num_samples", type=int, default=16, help="Number of samples to decode from the tar file. Set to 0 to decode all.") 100 | parser.add_argument("--grid_cols", type=int, default=4, help="Number of columns in the output video grid.") 101 | parser.add_argument("--fps", type=int, default=16, help="Frames per second for the output video.") 102 | 103 | args = parser.parse_args() 104 | main(args) 105 | -------------------------------------------------------------------------------- /rcm/callbacks/heart_beat.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import time 17 | from datetime import datetime 18 | 19 | import pytz 20 | import torch 21 | 22 | from imaginaire.callbacks.every_n import EveryN 23 | from imaginaire.model import ImaginaireModel 24 | from imaginaire.trainer import ImaginaireTrainer 25 | from imaginaire.utils import distributed 26 | from imaginaire.utils.easy_io import easy_io 27 | 28 | 29 | class HeartBeat(EveryN): 30 | """ 31 | A callback that logs a heartbeat message at regular intervals to indicate that the training process is still running. 32 | 33 | Args: 34 | every_n (int): The frequency at which the callback is invoked. 35 | step_size (int, optional): The step size for the callback. Defaults to 1. 
36 | update_interval_in_minute (int, optional): The interval in minutes for logging the heartbeat. Defaults to 20 minutes. 37 | save_s3 (bool, optional): Whether to save the heartbeat information to S3. Defaults to False. 38 | """ 39 | 40 | def __init__(self, every_n: int, step_size: int = 1, update_interval_in_minute: int = 20, save_s3: bool = False): 41 | super().__init__(every_n=every_n, step_size=step_size) 42 | self.name = self.__class__.__name__ 43 | self.update_interval_in_minute = update_interval_in_minute 44 | self.save_s3 = save_s3 45 | self.pst = pytz.timezone("America/Los_Angeles") 46 | self.is_hitted = False 47 | 48 | @distributed.rank0_only 49 | def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None: 50 | self.time = time.time() 51 | if self.save_s3: 52 | current_time_pst = datetime.now(self.pst).strftime("%Y_%m_%d-%H_%M_%S") 53 | info = { 54 | "iteration": iteration, 55 | "time": current_time_pst, 56 | } 57 | easy_io.dump(info, f"s3://rundir/{self.name}_start.yaml") 58 | easy_io.dump(info, f"s3://timestamps_rundir/{self.name}_start.yaml") 59 | 60 | def on_training_step_end( 61 | self, 62 | model: ImaginaireModel, 63 | data_batch: dict[str, torch.Tensor], 64 | output_batch: dict[str, torch.Tensor], 65 | loss: torch.Tensor, 66 | iteration: int = 0, 67 | ) -> None: 68 | if not self.is_hitted: 69 | self.is_hitted = True 70 | if distributed.get_rank() == 0: 71 | self.report(iteration) 72 | super().on_training_step_end(model, data_batch, output_batch, loss, iteration) 73 | 74 | @distributed.rank0_only 75 | def every_n_impl( 76 | self, 77 | trainer: ImaginaireTrainer, 78 | model: ImaginaireModel, 79 | data_batch: dict[str, torch.Tensor], 80 | output_batch: dict[str, torch.Tensor], 81 | loss: torch.Tensor, 82 | iteration: int, 83 | ) -> None: 84 | if time.time() - self.time > 60 * self.update_interval_in_minute: 85 | self.report(iteration) 86 | 87 | def report(self, iteration: int = 0): 88 | self.time = time.time() 89 | if self.save_s3: 90 | current_time_pst = datetime.now(self.pst).strftime("%Y_%m_%d-%H_%M_%S") 91 | info = { 92 | "iteration": iteration, 93 | "time": current_time_pst, 94 | } 95 | easy_io.dump(info, f"s3://rundir/{self.name}.yaml") 96 | 97 | @distributed.rank0_only 98 | def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None: 99 | if self.save_s3: 100 | current_time_pst = datetime.now(self.pst).strftime("%Y_%m_%d-%H_%M_%S") 101 | info = { 102 | "iteration": iteration, 103 | "time": current_time_pst, 104 | } 105 | easy_io.dump(info, f"s3://rundir/{self.name}_end.yaml") 106 | easy_io.dump(info, f"s3://timestamps_rundir/{self.name}_end.yaml") 107 | -------------------------------------------------------------------------------- /imaginaire/lazy_config/instantiate.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import collections.abc as abc 17 | import dataclasses 18 | from typing import Any 19 | 20 | import attrs 21 | 22 | from imaginaire.lazy_config.registry import _convert_target_to_string, locate 23 | from imaginaire.utils import log 24 | 25 | __all__ = ["dump_dataclass", "instantiate"] 26 | 27 | 28 | def is_dataclass_or_attrs(target): 29 | return dataclasses.is_dataclass(target) or attrs.has(target) 30 | 31 | 32 | def dump_dataclass(obj: Any): 33 | """ 34 | Dump a dataclass recursively into a dict that can be later instantiated. 35 | 36 | Args: 37 | obj: a dataclass object 38 | 39 | Returns: 40 | dict 41 | """ 42 | assert dataclasses.is_dataclass(obj) and not isinstance(obj, type), ( 43 | "dump_dataclass() requires an instance of a dataclass." 44 | ) 45 | ret = {"_target_": _convert_target_to_string(type(obj))} 46 | for f in dataclasses.fields(obj): 47 | v = getattr(obj, f.name) 48 | if dataclasses.is_dataclass(v): 49 | v = dump_dataclass(v) 50 | if isinstance(v, (list, tuple)): 51 | v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v] 52 | ret[f.name] = v 53 | return ret 54 | 55 | 56 | def instantiate(cfg, *args, **kwargs): 57 | """ 58 | Recursively instantiate objects defined in dictionaries by 59 | "_target_" and arguments. 60 | 61 | Args: 62 | cfg: a dict-like object with "_target_" that defines the caller, and 63 | other keys that define the arguments 64 | args: Optional positional parameters pass-through. 65 | kwargs: Optional named parameters pass-through. 66 | 67 | Returns: 68 | object instantiated by cfg 69 | """ 70 | from omegaconf import DictConfig, ListConfig, OmegaConf 71 | 72 | if isinstance(cfg, ListConfig): 73 | lst = [instantiate(x) for x in cfg] 74 | return ListConfig(lst, flags={"allow_objects": True}) 75 | if isinstance(cfg, list): 76 | # Specialize for list, because many classes take 77 | # list[objects] as arguments, such as ResNet, DatasetMapper 78 | return [instantiate(x) for x in cfg] 79 | 80 | # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config), 81 | # instantiate it to the actual dataclass. 82 | if isinstance(cfg, DictConfig) and is_dataclass_or_attrs(cfg._metadata.object_type): 83 | return OmegaConf.to_object(cfg) 84 | 85 | if isinstance(cfg, abc.Mapping) and "_target_" in cfg: 86 | # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all, 87 | # but faster: https://github.com/facebookresearch/hydra/issues/1200 88 | is_recursive = getattr(cfg, "_recursive_", True) 89 | if is_recursive: 90 | cfg = {k: instantiate(v) for k, v in cfg.items()} 91 | else: 92 | cfg = {k: v for k, v in cfg.items()} 93 | # pop the _recursive_ key to avoid passing it as a parameter 94 | if "_recursive_" in cfg: 95 | cfg.pop("_recursive_") 96 | cls = cfg.pop("_target_") 97 | cls = instantiate(cls) 98 | 99 | if isinstance(cls, str): 100 | cls_name = cls 101 | cls = locate(cls_name) 102 | assert cls is not None, cls_name 103 | else: 104 | try: 105 | cls_name = cls.__module__ + "." 
+ cls.__qualname__ 106 | except Exception: 107 | # target could be anything, so the above could fail 108 | cls_name = str(cls) 109 | assert callable(cls), f"_target_ {cls} does not define a callable object" 110 | try: 111 | # override config with kwargs 112 | instantiate_kwargs = {} 113 | instantiate_kwargs.update(cfg) 114 | instantiate_kwargs.update(kwargs) 115 | return cls(*args, **instantiate_kwargs) 116 | except TypeError: 117 | log.error(f"Error when instantiating {cls_name}!") 118 | raise 119 | return cfg # return as-is if don't know what to do 120 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Misc 2 | outputs/ 3 | output/ 4 | logs/ 5 | checkpoints/ 6 | assets/ 7 | 8 | *.jit 9 | *.pt 10 | *.png 11 | *.hdr 12 | *.jpg 13 | *.jpeg 14 | *.webp 15 | *.pgm 16 | *.tiff 17 | *.tif 18 | *.gif 19 | *.mp4 20 | *.tar 21 | *.tar.gz 22 | *.gz 23 | *.pkl 24 | *.pt 25 | *.bin 26 | 27 | # Other uncheckable file types 28 | *.zip 29 | *.exe 30 | *.dll 31 | *.swp 32 | *.vscode/** 33 | *.ipynb 34 | *.DS_Store 35 | *.pyc 36 | *Thumbs.db 37 | *.patch 38 | __MACOSX 39 | 40 | # ------------------------ BELOW IS AUTO-GENERATED FOR PYTHON REPOS ------------------------ 41 | 42 | # Byte-compiled / optimized / DLL files 43 | **/__pycache__/ 44 | *.py[cod] 45 | *$py.class 46 | 47 | # C extensions 48 | *.so 49 | 50 | # Distribution / packaging 51 | .Python 52 | build/ 53 | develop-eggs/ 54 | dist/ 55 | downloads/ 56 | eggs/ 57 | .eggs/ 58 | lib/ 59 | lib64/ 60 | parts/ 61 | results/ 62 | sdist/ 63 | var/ 64 | wheels/ 65 | share/python-wheels/ 66 | *.egg-info/ 67 | .installed.config 68 | *.egg 69 | MANIFEST 70 | 71 | # PyInstaller 72 | # Usually these files are written by a python script from a template 73 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 74 | *.manifest 75 | *.spec 76 | 77 | # Installer logs 78 | pip-log.txt 79 | pip-delete-this-directory.txt 80 | 81 | # Unit test / coverage reports 82 | htmlcov/ 83 | .tox/ 84 | .nox/ 85 | .coverage 86 | .coverage.* 87 | .cache 88 | nosetests.xml 89 | coverage.xml 90 | *.cover 91 | *.py,cover 92 | .hypothesis/ 93 | .pytest_cache/ 94 | cover/ 95 | 96 | # Translations 97 | *.mo 98 | *.pot 99 | 100 | # Django stuff: 101 | *.log 102 | local_settings.py 103 | db.sqlite3 104 | db.sqlite3-journal 105 | 106 | # Flask stuff: 107 | instance/ 108 | .webassets-cache 109 | 110 | # Scrapy stuff: 111 | .scrapy 112 | 113 | # Sphinx documentation 114 | docs/_build/ 115 | 116 | # PyBuilder 117 | .pybuilder/ 118 | target/ 119 | 120 | # Third party 121 | inf_lib/SpargeAttn/csrc/qattn/instantiations_sm*/*.cu 122 | 123 | # Jupyter Notebook 124 | .ipynb_checkpoints 125 | 126 | # IPython 127 | profile_default/ 128 | ipython_config.py 129 | 130 | # pyenv 131 | # For a library or package, you might want to ignore these files since the code is 132 | # intended to run in multiple environments; otherwise, check them in: 133 | # .python-version 134 | 135 | # pipenv 136 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 137 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 138 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 139 | # install all needed dependencies. 
140 | #Pipfile.lock 141 | 142 | # poetry 143 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 144 | # This is especially recommended for binary packages to ensure reproducibility, and is more 145 | # commonly ignored for libraries. 146 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 147 | #poetry.lock 148 | 149 | # pdm 150 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 151 | #pdm.lock 152 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 153 | # in version control. 154 | # https://pdm.fming.dev/#use-with-ide 155 | .pdm.toml 156 | 157 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 158 | __pypackages__/ 159 | 160 | # Celery stuff 161 | celerybeat-schedule 162 | celerybeat.pid 163 | 164 | # SageMath parsed files 165 | *.sage.py 166 | 167 | # Environments 168 | .env 169 | .venv 170 | env/ 171 | venv/ 172 | ENV/ 173 | env.bak/ 174 | venv.bak/ 175 | 176 | # Spyder project settings 177 | .spyderproject 178 | .spyproject 179 | 180 | # Rope project settings 181 | .ropeproject 182 | 183 | # mkdocs documentation 184 | /site 185 | 186 | # mypy 187 | .mypy_cache/ 188 | .dmypy.json 189 | dmypy.json 190 | 191 | # Pyre type checker 192 | .pyre/ 193 | 194 | # pytype static type analyzer 195 | .pytype/ 196 | 197 | # Cython debug symbols 198 | cython_debug/ 199 | 200 | # ruff 201 | .ruff_cache 202 | 203 | # PyCharm 204 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 205 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 206 | # and can be added to the global gitignore or merged into this file. For a more nuclear 207 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 208 | #.idea/ 209 | CLIP 210 | .devcontainer/devcontainer.json 211 | 212 | # Coverage 213 | .coverage 214 | coverage.xml 215 | 216 | # JUnit Reports 217 | report.xml 218 | 219 | # CI-CD 220 | ci-cd/edify/self_test/output_cicd.yaml 221 | temp/ 222 | envs.txt 223 | manifest.json 224 | 225 | 226 | # locks 227 | *.locks* 228 | *.no_exist* 229 | 230 | wandb/ 231 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/backends/registry_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import inspect 17 | 18 | from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend 19 | from imaginaire.utils.easy_io.backends.http_backend import HTTPBackend 20 | from imaginaire.utils.easy_io.backends.local_backend import LocalBackend 21 | 22 | backends: dict = {} 23 | prefix_to_backends: dict = {} 24 | 25 | 26 | def _register_backend( 27 | name: str, 28 | backend: type[BaseStorageBackend], 29 | force: bool = False, 30 | prefixes: str | list | tuple | None = None, 31 | ): 32 | """Register a backend. 33 | 34 | Args: 35 | name (str): The name of the registered backend. 36 | backend (BaseStorageBackend): The backend class to be registered, 37 | which must be a subclass of :class:`BaseStorageBackend`. 38 | force (bool): Whether to override the backend if the name has already 39 | been registered. Defaults to False. 40 | prefixes (str or list[str] or tuple[str], optional): The prefix 41 | of the registered storage backend. Defaults to None. 42 | """ 43 | global backends, prefix_to_backends 44 | 45 | if not isinstance(name, str): 46 | raise TypeError(f"the backend name should be a string, but got {type(name)}") 47 | 48 | if not inspect.isclass(backend): 49 | raise TypeError(f"backend should be a class, but got {type(backend)}") 50 | if not issubclass(backend, BaseStorageBackend): 51 | raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend") 52 | 53 | if name in backends and not force: 54 | raise ValueError( 55 | f'{name} is already registered as a storage backend, add "force=True" if you want to override it' 56 | ) 57 | backends[name] = backend 58 | 59 | if prefixes is not None: 60 | if isinstance(prefixes, str): 61 | prefixes = [prefixes] 62 | else: 63 | assert isinstance(prefixes, (list, tuple)) 64 | 65 | for prefix in prefixes: 66 | if prefix in prefix_to_backends and not force: 67 | raise ValueError( 68 | f'{prefix} is already registered as a storage backend, add "force=True" if you want to override it' 69 | ) 70 | 71 | prefix_to_backends[prefix] = backend 72 | 73 | 74 | def register_backend( 75 | name: str, 76 | backend: type[BaseStorageBackend] | None = None, 77 | force: bool = False, 78 | prefixes: str | list | tuple | None = None, 79 | ): 80 | """Register a backend. 81 | 82 | Args: 83 | name (str): The name of the registered backend. 84 | backend (class, optional): The backend class to be registered, 85 | which must be a subclass of :class:`BaseStorageBackend`. 86 | When this method is used as a decorator, backend is None. 87 | Defaults to None. 88 | force (bool): Whether to override the backend if the name has already 89 | been registered. Defaults to False. 90 | prefixes (str or list[str] or tuple[str], optional): The prefix 91 | of the registered storage backend. Defaults to None. 92 | 93 | This method can be used as a normal method or a decorator. 94 | 95 | Examples: 96 | 97 | >>> class NewBackend(BaseStorageBackend): 98 | ... def get(self, filepath): 99 | ... return filepath 100 | ... 101 | ... def get_text(self, filepath): 102 | ... return filepath 103 | >>> register_backend('new', NewBackend) 104 | 105 | >>> @register_backend('new') 106 | ... class NewBackend(BaseStorageBackend): 107 | ... def get(self, filepath): 108 | ... return filepath 109 | ... 110 | ... def get_text(self, filepath): 111 | ... 
return filepath 112 | """ 113 | if backend is not None: 114 | _register_backend(name, backend, force=force, prefixes=prefixes) 115 | return 116 | 117 | def _register(backend_cls): 118 | _register_backend(name, backend_cls, force=force, prefixes=prefixes) 119 | return backend_cls 120 | 121 | return _register 122 | 123 | 124 | register_backend("local", LocalBackend, prefixes="") 125 | register_backend("http", HTTPBackend, prefixes=["http", "https"]) 126 | -------------------------------------------------------------------------------- /imaginaire/utils/log.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import atexit 17 | import os 18 | import sys 19 | from typing import Any 20 | 21 | import torch.distributed as dist 22 | from loguru._logger import Core, Logger 23 | 24 | RANK0_ONLY = True 25 | LEVEL = os.environ.get("LOGURU_LEVEL", "INFO") 26 | 27 | logger = Logger( 28 | core=Core(), 29 | exception=None, 30 | depth=1, 31 | record=False, 32 | lazy=False, 33 | colors=False, 34 | raw=False, 35 | capture=True, 36 | patchers=[], 37 | extra={}, 38 | ) 39 | 40 | atexit.register(logger.remove) 41 | 42 | 43 | def _add_relative_path(record: dict[str, Any]) -> None: 44 | start = os.getcwd() 45 | record["extra"]["relative_path"] = os.path.relpath(record["file"].path, start) 46 | 47 | 48 | *options, _, extra = logger._options # type: ignore 49 | logger._options = tuple([*options, [_add_relative_path], extra]) # type: ignore 50 | 51 | 52 | def init_loguru_stdout() -> None: 53 | logger.remove() 54 | machine_format = get_machine_format() 55 | message_format = get_message_format() 56 | logger.add( 57 | sys.stdout, 58 | level=LEVEL, 59 | format=f"[{{time:MM-DD HH:mm:ss}}|{machine_format}{message_format}", 60 | filter=_rank0_only_filter, 61 | ) 62 | 63 | 64 | def init_loguru_file(path: str) -> None: 65 | machine_format = get_machine_format() 66 | message_format = get_message_format() 67 | logger.add( 68 | path, 69 | encoding="utf8", 70 | level=LEVEL, 71 | format=f"[{{time:MM-DD HH:mm:ss}}|{machine_format}{message_format}", 72 | rotation="100 MB", 73 | filter=lambda result: _rank0_only_filter(result) or not RANK0_ONLY, 74 | enqueue=True, 75 | ) 76 | 77 | 78 | def get_machine_format() -> str: 79 | node_id = "0" 80 | num_nodes = 1 81 | machine_format = "" 82 | rank = 0 83 | if dist.is_available(): 84 | if not RANK0_ONLY and dist.is_initialized(): 85 | rank = dist.get_rank() 86 | world_size = dist.get_world_size() 87 | machine_format = ( 88 | f"[Node{node_id:<3}/{num_nodes:<3}][RANK{rank:<5}/{world_size:<5}]" + "[{process.name:<8}]| " 89 | ) 90 | return machine_format 91 | 92 | 93 | def get_message_format() -> str: 94 | message_format = "{level}|{extra[relative_path]}:{line}:{function}] {message}" 95 | return message_format 96 | 97 | 98 | def 
_rank0_only_filter(record: Any) -> bool: 99 | is_rank0 = record["extra"].get("rank0_only", True) 100 | if _get_rank() == 0 and is_rank0: 101 | return True 102 | if not is_rank0: 103 | record["message"] = f"[RANK{_get_rank()}] " + record["message"] 104 | return not is_rank0 105 | 106 | 107 | def trace(message: str, rank0_only: bool = True) -> None: 108 | logger.opt(depth=1).bind(rank0_only=rank0_only).trace(message) 109 | 110 | 111 | def debug(message: str, rank0_only: bool = True) -> None: 112 | logger.opt(depth=1).bind(rank0_only=rank0_only).debug(message) 113 | 114 | 115 | def info(message: str, rank0_only: bool = True) -> None: 116 | logger.opt(depth=1).bind(rank0_only=rank0_only).info(message) 117 | 118 | 119 | def success(message: str, rank0_only: bool = True) -> None: 120 | logger.opt(depth=1).bind(rank0_only=rank0_only).success(message) 121 | 122 | 123 | def warning(message: str, rank0_only: bool = True) -> None: 124 | logger.opt(depth=1).bind(rank0_only=rank0_only).warning(message) 125 | 126 | 127 | def error(message: str, rank0_only: bool = True) -> None: 128 | logger.opt(depth=1).bind(rank0_only=rank0_only).error(message) 129 | 130 | 131 | def critical(message: str, rank0_only: bool = True) -> None: 132 | logger.opt(depth=1).bind(rank0_only=rank0_only).critical(message) 133 | 134 | 135 | def exception(message: str, rank0_only: bool = True) -> None: 136 | logger.opt(depth=1).bind(rank0_only=rank0_only).exception(message) 137 | 138 | 139 | def _get_rank(group: dist.ProcessGroup | None = None) -> int: 140 | """Get the rank (GPU device) of the worker. 141 | 142 | Returns: 143 | rank (int): The rank of the worker. 144 | """ 145 | rank = 0 146 | if dist.is_available() and dist.is_initialized(): 147 | rank = dist.get_rank(group) 148 | return rank 149 | 150 | 151 | # Execute at import time. 152 | init_loguru_stdout() 153 | -------------------------------------------------------------------------------- /rcm/utils/checkpointer.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from __future__ import annotations 17 | 18 | from typing import List, NamedTuple, Tuple 19 | 20 | import torch 21 | 22 | from imaginaire.utils import log 23 | 24 | TORCH_VERSION: Tuple[int, ...] 
= tuple(int(x) for x in torch.__version__.split(".")[:2]) 25 | if TORCH_VERSION >= (1, 11): 26 | from torch.ao import quantization 27 | from torch.ao.quantization import FakeQuantizeBase, ObserverBase 28 | elif TORCH_VERSION >= (1, 8) and hasattr(torch.quantization, "FakeQuantizeBase") and hasattr(torch.quantization, "ObserverBase"): 29 | from torch import quantization 30 | from torch.quantization import FakeQuantizeBase, ObserverBase 31 | 32 | 33 | class _IncompatibleKeys( 34 | NamedTuple( 35 | "IncompatibleKeys", 36 | [ 37 | ("missing_keys", List[str]), 38 | ("unexpected_keys", List[str]), 39 | ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]), 40 | ], 41 | ) 42 | ): 43 | pass 44 | 45 | 46 | # https://github.com/facebookresearch/fvcore/blob/9d683aae73fb899dd35d6cf6720e5ef567761c57/fvcore/common/checkpoint.py 47 | def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys: 48 | # workaround https://github.com/pytorch/pytorch/issues/24139 49 | model_state_dict = model.state_dict() 50 | incorrect_shapes = [] 51 | for k in list(checkpoint_state_dict.keys()): 52 | if k in model_state_dict: 53 | if "_extra_state" in k: # Key introduced by TransformerEngine for FP8 54 | log.warning(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.") 55 | continue 56 | model_param = model_state_dict[k] 57 | # Allow mismatch for uninitialized parameters 58 | if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter): 59 | continue 60 | if not isinstance(model_param, torch.Tensor): 61 | raise ValueError( 62 | f"Find non-tensor parameter {k} in the model. type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not." 63 | ) 64 | 65 | shape_model = tuple(model_param.shape) 66 | shape_checkpoint = tuple(checkpoint_state_dict[k].shape) 67 | if shape_model != shape_checkpoint: 68 | has_observer_base_classes = ( 69 | TORCH_VERSION >= (1, 8) and hasattr(quantization, "ObserverBase") and hasattr(quantization, "FakeQuantizeBase") 70 | ) 71 | if has_observer_base_classes: 72 | # Handle the special case of quantization per channel observers, 73 | # where buffer shape mismatches are expected. 74 | def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module: 75 | # foo.bar.param_or_buffer_name -> [foo, bar] 76 | key_parts = key.split(".")[:-1] 77 | cur_module = model 78 | for key_part in key_parts: 79 | cur_module = getattr(cur_module, key_part) 80 | return cur_module 81 | 82 | cls_to_skip = ( 83 | ObserverBase, 84 | FakeQuantizeBase, 85 | ) 86 | target_module = _get_module_for_key(model, k) 87 | if isinstance(target_module, cls_to_skip): 88 | # Do not remove modules with expected shape mismatches 89 | # them from the state_dict loading. They have special logic 90 | # in _load_from_state_dict to handle the mismatches. 
91 | continue 92 | 93 | incorrect_shapes.append((k, shape_checkpoint, shape_model)) 94 | checkpoint_state_dict.pop(k) 95 | incompatible = model.load_state_dict(checkpoint_state_dict, strict=False) 96 | # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling 97 | missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k] 98 | unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k] 99 | return _IncompatibleKeys( 100 | missing_keys=missing_keys, 101 | unexpected_keys=unexpected_keys, 102 | incorrect_shapes=incorrect_shapes, 103 | ) 104 | -------------------------------------------------------------------------------- /rcm/networks/wan2pt1_jvp_test.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 3 | # All rights reserved. 4 | # 5 | # This codebase constitutes NVIDIA proprietary technology and is strictly 6 | # confidential. Any unauthorized reproduction, distribution, or disclosure 7 | # of this code, in whole or in part, outside NVIDIA is strictly prohibited 8 | # without prior written consent. 9 | # 10 | # For inquiries regarding the use of this code in other NVIDIA proprietary 11 | # projects, please contact the Deep Imagination Research Team at 12 | # dir@exchange.nvidia.com. 13 | # ----------------------------------------------------------------------------- 14 | 15 | import functools 16 | 17 | import pytest 18 | import torch 19 | from einops import repeat 20 | 21 | from imaginaire.lazy_config import LazyCall as L 22 | from imaginaire.lazy_config import instantiate 23 | from rcm.networks.wan2pt1 import WanModel 24 | from rcm.networks.wan2pt1_jvp import CheckpointMode, SACConfig, WanModel_JVP 25 | 26 | """ 27 | Usage: 28 | pytest -s rcm/networks/wan2pt1_jvp_test.py 29 | """ 30 | 31 | N_HEADS = 8 32 | HEAD_DIM = 32 33 | mini_net = L(WanModel)( 34 | model_type="t2v", 35 | patch_size=(1, 2, 2), 36 | text_len=512, 37 | in_dim=16, 38 | dim=2048, 39 | ffn_dim=8192, 40 | freq_dim=256, 41 | text_dim=1024, 42 | num_layers=4, 43 | ) 44 | mini_net_jvp = L(WanModel_JVP)( 45 | model_type="t2v", 46 | patch_size=(1, 2, 2), 47 | text_len=512, 48 | in_dim=16, 49 | dim=2048, 50 | ffn_dim=8192, 51 | freq_dim=256, 52 | text_dim=1024, 53 | num_layers=4, 54 | ) 55 | mini_net_jvp_naive = L(WanModel_JVP)( 56 | model_type="t2v", 57 | patch_size=(1, 2, 2), 58 | text_len=512, 59 | in_dim=16, 60 | dim=2048, 61 | ffn_dim=8192, 62 | freq_dim=256, 63 | text_dim=1024, 64 | num_layers=4, 65 | naive_attn=True, 66 | ) 67 | 68 | 69 | @pytest.mark.L1 70 | def test_equivalent_forward_raw_vs_jvp(): 71 | dtype = torch.float16 72 | net = instantiate(mini_net).cuda().to(dtype=dtype) 73 | net.eval() 74 | net_jvp = instantiate(mini_net_jvp).cuda().to(dtype=dtype) 75 | net_jvp.eval() 76 | 77 | net_jvp.load_state_dict(net.state_dict(), strict=False) 78 | 79 | batch_size = 2 80 | t = 8 81 | x_B_C_T_H_W = torch.randn(batch_size, 16, t, 40, 40).cuda().to(dtype=dtype) 82 | noise_labels_B = torch.randn(batch_size).cuda().to(dtype=dtype) 83 | noise_labels_BT = repeat(noise_labels_B, "b -> b 1") 84 | crossattn_emb_B_T_D = torch.randn(batch_size, 512, 1024).cuda().to(dtype=dtype) 85 | padding_mask_B_T_H_W = torch.zeros(batch_size, 1, 40, 40).cuda().to(dtype=dtype) 86 | 87 | output_BT = net(x_B_C_T_H_W, noise_labels_BT, crossattn_emb_B_T_D, 
padding_mask=padding_mask_B_T_H_W)
88 |     output_BT_jvp = net_jvp(x_B_C_T_H_W, noise_labels_BT, crossattn_emb_B_T_D, padding_mask=padding_mask_B_T_H_W)
89 |     torch.testing.assert_close(output_BT, output_BT_jvp, rtol=1e-3, atol=1e-3)
90 | 
91 | 
92 | """
93 | Usage:
94 |     pytest -s rcm/networks/wan2pt1_jvp_test.py --all -k test_equivalent_jvp_naive_vs_flash
95 | """
96 | 
97 | 
98 | @pytest.mark.L1
99 | def test_equivalent_jvp_naive_vs_flash():
100 |     dtype = torch.float16
101 |     net_jvp = instantiate(mini_net_jvp, sac_config=SACConfig(mode=CheckpointMode.NONE)).cuda().to(dtype=dtype)
102 |     net_jvp.eval()
103 |     net_jvp_naive = instantiate(mini_net_jvp_naive, sac_config=SACConfig(mode=CheckpointMode.NONE)).cuda().to(dtype=dtype)
104 |     net_jvp_naive.eval()
105 |     net_jvp_naive.load_state_dict(net_jvp.state_dict(), strict=False)
106 | 
107 |     batch_size = 2
108 |     t = 8
109 |     x_B_C_T_H_W = torch.randn(batch_size, 16, t, 40, 40).cuda().to(dtype=dtype)
110 |     t_x_B_C_T_H_W = torch.randn_like(x_B_C_T_H_W)
111 |     noise_labels_B = torch.randn(batch_size).cuda().to(dtype=dtype)
112 |     t_noise_labels_B = torch.randn_like(noise_labels_B)
113 |     noise_labels_BT = repeat(noise_labels_B, "b -> b 1")
114 |     t_noise_labels_BT = repeat(t_noise_labels_B, "b -> b 1")
115 |     crossattn_emb_B_T_D = torch.randn(batch_size, 512, 1024).cuda().to(dtype=dtype)
116 |     fps_B = torch.randint(size=(1,), low=2, high=30).cuda().float().repeat(batch_size)
117 |     padding_mask_B_T_H_W = torch.zeros(batch_size, 1, 40, 40).cuda().to(dtype=dtype)
118 | 
119 |     output_BT_withoutT = net_jvp(x_B_C_T_H_W, noise_labels_BT, crossattn_emb_B_T_D, fps=fps_B, padding_mask=padding_mask_B_T_H_W)
120 |     output_BT, t_output_BT = net_jvp(
121 |         (x_B_C_T_H_W, t_x_B_C_T_H_W),
122 |         (noise_labels_BT, t_noise_labels_BT),
123 |         crossattn_emb_B_T_D,
124 |         fps=fps_B,
125 |         padding_mask=padding_mask_B_T_H_W,
126 |         withT=True,
127 |     )
128 | 
129 |     fn_naive = functools.partial(net_jvp_naive.forward, crossattn_emb=crossattn_emb_B_T_D, fps=fps_B, padding_mask=padding_mask_B_T_H_W)
130 | 
131 |     output_BT_naive, t_output_BT_naive = torch.func.jvp(fn_naive, (x_B_C_T_H_W, noise_labels_BT), (t_x_B_C_T_H_W, t_noise_labels_BT))
132 | 
133 |     torch.testing.assert_close(output_BT_withoutT, output_BT_naive, rtol=1e-3, atol=1e-3)
134 |     torch.testing.assert_close(output_BT_withoutT, output_BT, rtol=1e-3, atol=1e-3)
135 |     torch.testing.assert_close(t_output_BT, t_output_BT_naive, rtol=1e-3, atol=1e-3)
136 | 
--------------------------------------------------------------------------------
/imaginaire/model.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | from typing import Any
17 | 
18 | import torch
19 | 
20 | from imaginaire.lazy_config import LazyDict, instantiate
21 | 
22 | 
23 | class ImaginaireModel(torch.nn.Module):
24 |     """The base model class of Imaginaire. It is inherited from torch.nn.Module.
25 | 
26 |     All models in Imaginaire should inherit ImaginaireModel. It should include the implementations for all the
27 |     computation graphs. All inheriting child classes should implement the following methods:
28 |     - training_step(): The training step of the model, including the loss computation.
29 |     - validation_step(): The validation step of the model, including the loss computation.
30 |     - forward(): The computation graph for model inference.
31 |     The following methods have default implementations in ImaginaireModel:
32 |     - init_optimizer_scheduler(): Creates the optimizer and scheduler for the model.
33 |     """
34 | 
35 |     def __init__(self) -> None:
36 |         super().__init__()
37 | 
38 |     def init_optimizer_scheduler(
39 |         self,
40 |         optimizer_config: LazyDict[torch.optim.Optimizer],
41 |         scheduler_config: LazyDict[torch.optim.lr_scheduler.LRScheduler],
42 |     ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
43 |         """Creates the optimizer and scheduler for the model.
44 | 
45 |         Args:
46 |             optimizer_config / scheduler_config (LazyDict): The config objects for the optimizer and the learning rate scheduler.
47 | 
48 |         Returns:
49 |             optimizer (torch.optim.Optimizer): The model optimizer.
50 |             scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
51 |         """
52 |         optimizer_config.params = self.parameters()
53 |         optimizer = instantiate(optimizer_config)
54 |         scheduler_config.optimizer = optimizer
55 |         scheduler = instantiate(scheduler_config)
56 |         return optimizer, scheduler
57 | 
58 |     def training_step(
59 |         self, data_batch: dict[str, torch.Tensor], iteration: int
60 |     ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
61 |         """The training step of the model, including the loss computation.
62 | 
63 |         Args:
64 |             data_batch (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
65 |             iteration (int): Current iteration number.
66 | 
67 |         Returns:
68 |             output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch.
69 |             loss (torch.Tensor): The total loss for backprop (weighted sum of various losses).
70 |         """
71 |         raise NotImplementedError
72 | 
73 |     @torch.no_grad()
74 |     def validation_step(
75 |         self, data_batch: dict[str, torch.Tensor], iteration: int
76 |     ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
77 |         """The validation step of the model, including the loss computation.
78 | 
79 |         Args:
80 |             data_batch (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
81 |             iteration (int): Current iteration number.
82 | 
83 |         Returns:
84 |             output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch.
85 |             loss (torch.Tensor): The total loss (weighted sum of various losses).
86 |         """
87 |         raise NotImplementedError
88 | 
89 |     @torch.inference_mode()
90 |     def forward(self, *args: Any, **kwargs: Any) -> Any:
91 |         """The computation graph for model inference.
92 | 
93 |         Args:
94 |             *args: Whatever you decide to pass into the forward method.
95 |             **kwargs: Keyword arguments are also possible.
96 | 
97 |         Returns:
98 |             Your model's output.
99 | """ 100 | raise NotImplementedError 101 | 102 | def on_model_init_start(self, set_barrier=False) -> None: 103 | return 104 | 105 | def on_model_init_end(self, set_barrier=False) -> None: 106 | return 107 | 108 | def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: 109 | """The model preparation before the training is launched 110 | 111 | Args: 112 | memory_format (torch.memory_format): Memory format of the model. 113 | """ 114 | pass 115 | 116 | def on_before_zero_grad( 117 | self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int 118 | ) -> None: 119 | """Hook before zero_grad() is called. 120 | 121 | Args: 122 | optimizer (torch.optim.Optimizer): The model optimizer. 123 | scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. 124 | iteration (int): Current iteration number. 125 | """ 126 | pass 127 | 128 | def on_after_backward(self, iteration: int = 0) -> None: 129 | """Hook after loss.backward() is called. 130 | 131 | This method is called immediately after the backward pass, allowing for custom operations 132 | or modifications to be performed on the gradients before the optimizer step. 133 | 134 | Args: 135 | iteration (int): Current iteration number. 136 | """ 137 | pass 138 | -------------------------------------------------------------------------------- /imaginaire/utils/profiling.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import contextlib 17 | import os 18 | import time 19 | 20 | import torch 21 | 22 | from imaginaire.utils import distributed, log 23 | from imaginaire.utils.easy_io import easy_io 24 | 25 | # the number of warmup steps before the active step in each profiling cycle 26 | TORCH_TRACE_WARMUP = 3 27 | 28 | # how much memory allocation/free ops to record in memory snapshots 29 | MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 30 | 31 | 32 | @contextlib.contextmanager 33 | def maybe_enable_profiling(config, *, global_step: int = 0): 34 | # get user defined profiler settings 35 | enable_profiling = config.trainer.profiling.enable_profiling 36 | profile_freq = config.trainer.profiling.profile_freq 37 | 38 | if enable_profiling: 39 | trace_dir = os.path.join(config.job.path_local, "torch_trace") 40 | if distributed.get_rank() == 0: 41 | os.makedirs(trace_dir, exist_ok=True) 42 | 43 | rank = distributed.get_rank() 44 | 45 | def trace_handler(prof): 46 | curr_trace_dir_name = "iteration_" + str(prof.step_num) 47 | curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name) 48 | if not os.path.exists(curr_trace_dir): 49 | os.makedirs(curr_trace_dir, exist_ok=True) 50 | 51 | log.info(f"Dumping traces at step {prof.step_num}") 52 | begin = time.monotonic() 53 | if config.trainer.profiling.first_n_rank < 0 or rank < config.trainer.profiling.first_n_rank: 54 | prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json.gz") # saved as gz to save space 55 | log.info(f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds") 56 | 57 | log.info(f"Profiling active. Traces will be saved at {trace_dir}") 58 | 59 | if not os.path.exists(trace_dir): 60 | os.makedirs(trace_dir, exist_ok=True) 61 | 62 | warmup, active = TORCH_TRACE_WARMUP, 1 63 | wait = profile_freq - (active + warmup) 64 | assert wait >= 0, "profile_freq must be greater than or equal to warmup + active" 65 | 66 | with torch.profiler.profile( 67 | activities=[ 68 | torch.profiler.ProfilerActivity.CPU, 69 | torch.profiler.ProfilerActivity.CUDA, 70 | ], 71 | schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active), 72 | on_trace_ready=trace_handler, 73 | record_shapes=config.trainer.profiling.record_shape, 74 | profile_memory=config.trainer.profiling.profile_memory, 75 | with_stack=config.trainer.profiling.with_stack, 76 | with_modules=config.trainer.profiling.with_modules, 77 | ) as torch_profiler: 78 | torch_profiler.step_num = global_step 79 | yield torch_profiler 80 | else: 81 | torch_profiler = contextlib.nullcontext() 82 | yield None 83 | 84 | 85 | @contextlib.contextmanager 86 | def maybe_enable_memory_snapshot(config, *, global_step: int = 0): 87 | enable_snapshot = config.trainer.profiling.enable_memory_snapshot 88 | if enable_snapshot: 89 | snapshot_dir = os.path.join(config.job.path_local, "memory_snapshot") 90 | if distributed.get_rank() == 0: 91 | os.makedirs(snapshot_dir, exist_ok=True) 92 | 93 | rank = torch.distributed.get_rank() 94 | 95 | class MemoryProfiler: 96 | def __init__(self, step_num: int, freq: int): 97 | torch.cuda.memory._record_memory_history(max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES) 98 | # when resume training, we start from the last step 99 | self.step_num = step_num 100 | self.freq = freq 101 | 102 | def step(self, exit_ctx: bool = False): 103 | self.step_num += 1 104 | if not exit_ctx and self.step_num % self.freq != 0: 105 | return 106 | if not exit_ctx: 107 | curr_step = self.step_num 108 | dir_name = f"iteration_{curr_step}" 109 | else: 110 | # dump as iteration_0_exit if OOM at 
iter 1 111 | curr_step = self.step_num - 1 112 | dir_name = f"iteration_{curr_step}_exit" 113 | curr_snapshot_dir = os.path.join(snapshot_dir, dir_name) 114 | if not os.path.exists(curr_snapshot_dir): 115 | os.makedirs(curr_snapshot_dir, exist_ok=True) 116 | log.info(f"Dumping memory snapshot at step {curr_step}") 117 | begin = time.monotonic() 118 | 119 | if config.trainer.profiling.first_n_rank < 0 or rank < config.trainer.profiling.first_n_rank: 120 | easy_io.dump( 121 | torch.cuda.memory._snapshot(), 122 | f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", 123 | ) 124 | log.info(f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds") 125 | 126 | log.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}") 127 | profiler = MemoryProfiler(global_step, config.trainer.profiling.profile_freq) 128 | try: 129 | yield profiler 130 | except torch.cuda.OutOfMemoryError as e: 131 | profiler.step(exit_ctx=True) 132 | else: 133 | yield None 134 | -------------------------------------------------------------------------------- /rcm/utils/fsdp_helper.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 
16 | from __future__ import annotations
17 | 
18 | from contextlib import contextmanager
19 | from functools import partial
20 | 
21 | import torch
22 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
23 |     CheckpointImpl,
24 |     apply_activation_checkpointing,
25 |     checkpoint_wrapper,
26 | )
27 | from torch.distributed.device_mesh import init_device_mesh
28 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
29 | from torch.distributed.fsdp._runtime_utils import (
30 |     _post_forward,
31 |     _post_forward_reshard,
32 |     _pre_forward,
33 |     _pre_forward_unshard,
34 |     _root_pre_forward,
35 | )
36 | from torch.distributed.utils import _p_assert
37 | 
38 | from imaginaire.utils import distributed, log
39 | 
40 | 
41 | def apply_fsdp_checkpointing(model, list_block_cls):
42 |     """Apply activation checkpointing to the model.
43 |     Returns None, as the model is updated in place.
44 |     """
45 |     log.critical("--> applying fsdp activation checkpointing...")
46 |     non_reentrant_wrapper = partial(
47 |         checkpoint_wrapper,
48 |         # offload_to_cpu=False,
49 |         checkpoint_impl=CheckpointImpl.NO_REENTRANT,
50 |     )
51 | 
52 |     def check_fn(submodule):
53 |         result = False
54 |         for block_cls in list_block_cls:
55 |             if isinstance(submodule, block_cls):
56 |                 result = True
57 |                 break
58 |         return result
59 | 
60 |     apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)
61 | 
62 | 
63 | @contextmanager
64 | def possible_fsdp_scope(
65 |     model: torch.nn.Module,
66 | ):
67 |     enabled = isinstance(model, FSDP)
68 |     if enabled:
69 |         assert not torch.is_grad_enabled(), "FSDP context should be entered with grad disabled"
70 |         handle = model._handle
71 |         args, kwargs = [0], dict(dummy=0)
72 |         with torch.autograd.profiler.record_function("FullyShardedDataParallel.possible_fsdp_scope"):
73 |             args, kwargs = _root_pre_forward(model, model, args, kwargs)
74 |             unused = None
75 |             args, kwargs = _pre_forward(
76 |                 model,
77 |                 handle,
78 |                 _pre_forward_unshard,
79 |                 model._fsdp_wrapped_module,
80 |                 args,
81 |                 kwargs,
82 |             )
83 |             if handle:
84 |                 _p_assert(
85 |                     handle.flat_param.device == model.compute_device,
86 |                     "Expected `FlatParameter` to be on the compute device " f"{model.compute_device} but got {handle.flat_param.device}",
87 |                 )
88 |     try:
89 |         yield None
90 |     finally:
91 |         if enabled:
92 |             output = {"output": 1}
93 |             _post_forward(model, handle, _post_forward_reshard, model, unused, output)
94 | 
95 | 
96 | def hsdp_device_mesh(replica_group_size=None, sharding_group_size=None, device=None):
97 |     """
98 |     Initializes a device mesh for use with Hybrid Sharding strategy in FSDP (HSDP) training.
99 | 
100 |     This function accepts explicit sizes for the replica and sharding groups to accommodate models
101 |     whose GPU fit is unknown, and derives sensible defaults when they are not provided.
102 | 
103 |     Args:
104 |         replica_group_size (int, optional): The size of each replica group. Defaults to
105 |             world_size // sharding_group_size when not provided.
106 |         sharding_group_size (int, optional): The size of each sharding group that the model can fit. Defaults to
107 |             min(world_size, 8) when not provided.
108 |         device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda"
109 |             with the local rank as the device index.
110 | 
111 |     Returns:
112 |         A device mesh object compatible with FSDP.
113 | 
114 |     Raises:
115 |         ValueError: If the world size is not evenly divisible by the sharding group size, or if the
116 |             number of sharding groups is not evenly divisible by the replica group size.
117 |         RuntimeError: If a valid device mesh cannot be created.
118 | 
119 |     Usage:
120 |         If your model fits on 4 GPUs, and you have 3 nodes of 8 GPUs, then:
121 |         Sharding_Group_Size = 4
122 |         Replica_Groups_Size = (24 total GPUs, 4 per sharding group) = 6 replica groups
123 |         >>> device_mesh = hsdp_device_mesh(replica_group_size, sharding_group_size)
124 |         >>> sharded_model = FSDP(model, device_mesh=device_mesh, ...)
125 |     """
126 | 
127 |     # world_size = int(os.getenv("WORLD_SIZE", "1"))
128 |     world_size = distributed.get_world_size()
129 |     if sharding_group_size is None:
130 |         sharding_group_size = min(world_size, 8)
131 |     sharding_group_size = min(sharding_group_size, world_size)
132 |     if replica_group_size is None:
133 |         replica_group_size = world_size // sharding_group_size
134 | 
135 |     device = device or "cuda"
136 | 
137 |     if world_size % sharding_group_size != 0:
138 |         raise ValueError(f"World size {world_size} is not evenly divisible by sharding group size {sharding_group_size}.")
139 | 
140 |     if (world_size // sharding_group_size) % replica_group_size != 0:
141 |         raise ValueError(f"The calculated number of replica groups is not evenly divisible by " f"replica_group_size {replica_group_size}.")
142 | 
143 |     device_mesh = init_device_mesh(device, (replica_group_size, sharding_group_size), mesh_dim_names=("replicate", "shard"))
144 |     if device_mesh is None:
145 |         raise RuntimeError("Failed to create a valid device mesh.")
146 | 
147 |     log.critical(f"Device mesh initialized with replica group size {replica_group_size} and sharding group size {sharding_group_size}")
148 | 
149 |     return device_mesh
150 | 
--------------------------------------------------------------------------------
/rcm/utils/attention.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | # Description:
17 | # Single point of entry for all generic attention ops (self and cross attention) that tries to
18 | # deliver the best performance possible given any use case (GPU and environment).
19 | #
20 | # On Hopper GPUs (i.e. H100, H20, H200), Flash Attention 3 is the best-performing choice, but it
21 | # needs to be installed. When it is not available, the second best choice is cuDNN attention, which
22 | # we get using PyTorch's SDPA API.
23 | #
24 | # For all other use cases, we will just use PyTorch's SDPA, but we need to specify backends and
25 | # priorities.
26 | # Flash Attention 2, which is one of the backends, is the best choice for Ampere GPUs (both RTX and
27 | # datacenter-class).
28 | # 29 | # For anything pre-Ampere, the only choice is "memory-efficient" (xformers) FMHA. 30 | # 31 | # For Ada and Blackwell RTX, it is unclear at the moment, so we defer to Flash Attention 2, and 32 | # fallbacks are cuDNN and xformers. 33 | # 34 | # For Blackwell datacenter-class (B200, GB200), cuDNN is the best choice. 35 | # 36 | # 37 | # Dispatching to the desired backends/paths are done by checking the compute capability (really SM 38 | # number, which is just compute capability * 10) of the GPU device the input tensors are on. 39 | # 40 | # Here's a breakdown of relevant compute capabilities: 41 | # 42 | # | GPU / category | Arch | 43 | # |================|=======| 44 | # | A100 | SM80 | 45 | # | A40 | SM80 | 46 | # | Ampere RTX | SM86 | 47 | # |----------------|-------| 48 | # | Ada Lovelace | SM89 | 49 | # |----------------|-------| 50 | # | H20 | SM90 | 51 | # | H100 | SM90 | 52 | # | H200 | SM90 | 53 | # |----------------|-------| 54 | # | B200 | SM100 | 55 | # | Blackwell RTX | SM103 | 56 | # |----------------|-------| 57 | # 58 | 59 | from functools import partial 60 | 61 | import torch 62 | from torch.nn.attention import SDPBackend, sdpa_kernel 63 | 64 | try: 65 | from flash_attn_3.flash_attn_interface import flash_attn_func 66 | 67 | FLASH_ATTN_3_AVAILABLE = True 68 | except ModuleNotFoundError: 69 | FLASH_ATTN_3_AVAILABLE = False 70 | 71 | 72 | def get_device_cc(device) -> int: 73 | """ 74 | Returns the compute capability of a given torch device if it's a CUDA device, otherwise returns 0. 75 | 76 | Args: 77 | device: torch device. 78 | 79 | Returns: 80 | device_cc (int): compute capability in the SmXXX format (i.e. 90 for Hopper). 81 | """ 82 | if torch.cuda.is_available() and torch.version.cuda and device.type == "cuda": 83 | major, minor = torch.cuda.get_device_capability(device) 84 | return major * 10 + minor 85 | return 0 86 | 87 | 88 | def attention( 89 | q, 90 | k, 91 | v, 92 | dropout_p=0.0, 93 | softmax_scale=None, 94 | q_scale=None, 95 | causal=False, 96 | deterministic=False, 97 | ): 98 | assert q.dtype == k.dtype and k.dtype == v.dtype 99 | dtype = q.dtype 100 | supported_dtypes = [torch.bfloat16, torch.float16, torch.float32] 101 | is_half = dtype in [torch.bfloat16, torch.float16] 102 | compute_cap = get_device_cc(q.device) 103 | 104 | if dtype not in supported_dtypes: 105 | raise NotImplementedError(f"{dtype=} is not supported.") 106 | 107 | if q_scale is not None: 108 | q = q * q_scale 109 | 110 | # If Flash Attention 3 is installed, and the user's running on a Hopper GPU (compute capability 111 | # 9.0, or SM90), use Flash Attention 3. 112 | if compute_cap == 90 and FLASH_ATTN_3_AVAILABLE and is_half: 113 | return flash_attn_func( 114 | q=q, 115 | k=k, 116 | v=v, 117 | softmax_scale=softmax_scale, 118 | causal=causal, 119 | deterministic=deterministic, 120 | )[0] 121 | else: 122 | # If Blackwell or Hopper (SM100 or SM90), cuDNN has native FMHA kernels. The Hopper one is 123 | # not always as fast as Flash Attention 3, but when Flash Attention is unavailable, it's 124 | # still a far better choice than Flash Attention 2 (Ampere). 
125 | if compute_cap in [90, 100] and is_half: 126 | SDPA_BACKENDS = [ 127 | SDPBackend.CUDNN_ATTENTION, 128 | SDPBackend.FLASH_ATTENTION, 129 | SDPBackend.EFFICIENT_ATTENTION, 130 | ] 131 | BEST_SDPA_BACKEND = SDPBackend.CUDNN_ATTENTION 132 | elif is_half: 133 | SDPA_BACKENDS = [ 134 | SDPBackend.FLASH_ATTENTION, 135 | SDPBackend.CUDNN_ATTENTION, 136 | SDPBackend.EFFICIENT_ATTENTION, 137 | ] 138 | BEST_SDPA_BACKEND = SDPBackend.FLASH_ATTENTION if compute_cap >= 80 else SDPBackend.EFFICIENT_ATTENTION 139 | else: 140 | assert dtype == torch.float32, f"Unrecognized {dtype=}." 141 | SDPA_BACKENDS = [SDPBackend.EFFICIENT_ATTENTION] 142 | BEST_SDPA_BACKEND = SDPBackend.EFFICIENT_ATTENTION 143 | 144 | if deterministic: 145 | raise NotImplementedError("Deterministic mode in attention is only supported when Flash Attention 3 is available.") 146 | 147 | # Torch 2.6 and later allows priorities for backends, but for older versions 148 | # we can only run with a specific backend. As long as we pick ones we're certain 149 | # will work on that device, it should be fine. 150 | try: 151 | sdpa_kernel(backends=SDPA_BACKENDS, set_priority_order=True) 152 | sdpa_kernel_ = partial(sdpa_kernel, set_priority_order=True) 153 | except TypeError: 154 | sdpa_kernel_ = sdpa_kernel 155 | SDPA_BACKENDS = [BEST_SDPA_BACKEND] 156 | 157 | q = q.transpose(1, 2) 158 | k = k.transpose(1, 2) 159 | v = v.transpose(1, 2) 160 | 161 | with sdpa_kernel_(backends=SDPA_BACKENDS): 162 | out = torch.nn.functional.scaled_dot_product_attention( 163 | q, 164 | k, 165 | v, 166 | is_causal=causal, 167 | dropout_p=dropout_p, 168 | scale=softmax_scale, 169 | ) 170 | 171 | out = out.transpose(1, 2).contiguous() 172 | return out 173 | -------------------------------------------------------------------------------- /imaginaire/utils/easy_io/handlers/imageio_video_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import IO, Any 17 | 18 | import imageio 19 | import imageio.v3 as iio_v3 20 | import numpy as np 21 | import torch 22 | 23 | from imaginaire.utils import log 24 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler 25 | 26 | 27 | class ImageioVideoHandler(BaseFileHandler): 28 | str_like = False 29 | 30 | def load_from_fileobj( 31 | self, file: IO[bytes], format: str = "mp4", mode: str = "rgb", **kwargs 32 | ) -> tuple[np.ndarray, dict[str, Any]]: 33 | """ 34 | Load video from a file-like object using imageio.v3 with specified format and color mode. 35 | 36 | Parameters: 37 | file (IO[bytes]): A file-like object containing video data. 38 | format (str): Format of the video file (default 'mp4'). 39 | mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb'). 
40 | 41 | Returns: 42 | tuple: A tuple containing an array of video frames and metadata about the video. 43 | """ 44 | file.seek(0) 45 | 46 | # The plugin argument in v3 replaces the format argument in v2 47 | plugin = kwargs.pop("plugin", "pyav") 48 | 49 | # Load all frames at once using v3 API 50 | video_frames = iio_v3.imread(file, plugin=plugin, **kwargs) 51 | 52 | # Handle grayscale conversion if needed 53 | if mode == "gray": 54 | import cv2 55 | 56 | if len(video_frames.shape) == 4: # (frames, height, width, channels) 57 | gray_frames = [] 58 | for frame in video_frames: 59 | gray_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 60 | gray_frame = np.expand_dims(gray_frame, axis=2) # Keep dimensions consistent 61 | gray_frames.append(gray_frame) 62 | video_frames = np.array(gray_frames) 63 | 64 | # Extract metadata 65 | # Note: iio_v3.imread doesn't return metadata directly like v2 did 66 | # We need to extract it separately 67 | file.seek(0) 68 | metadata = self._extract_metadata(file, plugin=plugin) 69 | 70 | return video_frames, metadata 71 | 72 | def _extract_metadata(self, file: IO[bytes], plugin: str = "pyav") -> dict[str, Any]: 73 | """ 74 | Extract metadata from a video file. 75 | 76 | Parameters: 77 | file (IO[bytes]): File-like object containing video data. 78 | plugin (str): Plugin to use for reading. 79 | 80 | Returns: 81 | dict: Video metadata. 82 | """ 83 | try: 84 | # Create a generator to read frames and metadata 85 | metadata = iio_v3.immeta(file, plugin=plugin) 86 | 87 | # Add some standard fields similar to v2 metadata format 88 | if "fps" not in metadata and "duration" in metadata: 89 | # Read the first frame to get shape information 90 | file.seek(0) 91 | first_frame = iio_v3.imread(file, plugin=plugin, index=0) 92 | metadata["size"] = first_frame.shape[1::-1] # (width, height) 93 | metadata["source_size"] = metadata["size"] 94 | 95 | # Create a consistent metadata structure with v2 96 | metadata["plugin"] = plugin 97 | if "codec" not in metadata: 98 | metadata["codec"] = "unknown" 99 | if "pix_fmt" not in metadata: 100 | metadata["pix_fmt"] = "unknown" 101 | 102 | # Calculate nframes if possible 103 | if "fps" in metadata and "duration" in metadata: 104 | metadata["nframes"] = int(metadata["fps"] * metadata["duration"]) 105 | else: 106 | metadata["nframes"] = float("inf") 107 | 108 | return metadata 109 | 110 | except Exception as e: 111 | # Fallback to basic metadata 112 | return { 113 | "plugin": plugin, 114 | "nframes": float("inf"), 115 | "codec": "unknown", 116 | "fps": 30.0, # Default values 117 | "duration": 0, 118 | "size": (0, 0), 119 | } 120 | 121 | def dump_to_fileobj( 122 | self, 123 | obj: np.ndarray | torch.Tensor, 124 | file: IO[bytes], 125 | format: str = "mp4", # pylint: disable=redefined-builtin 126 | fps: int = 17, 127 | quality: int = 7, 128 | ffmpeg_params=None, 129 | **kwargs, 130 | ): 131 | """ 132 | Save an array of video frames to a file-like object using imageio. 133 | 134 | Parameters: 135 | obj (Union[np.ndarray, torch.Tensor]): An array of frames to be saved as video. 136 | file (IO[bytes]): A file-like object to which the video data will be written. 137 | format (str): Format of the video file (default 'mp4'). 138 | fps (int): Frames per second of the output video (default 17). 139 | quality (int): Quality of the video (0-10, default 5). 140 | ffmpeg_params (list): Additional parameters to pass to ffmpeg. 
141 | 142 | """ 143 | if isinstance(obj, torch.Tensor): 144 | assert obj.dtype == torch.uint8, "Tensor must be of type uint8" 145 | obj = obj.cpu().numpy() 146 | h, w = obj.shape[1:-1] 147 | 148 | # Default ffmpeg params that ensure width and height are set 149 | default_ffmpeg_params = ["-s", f"{w}x{h}"] 150 | 151 | # Use provided ffmpeg_params if any, otherwise use defaults 152 | final_ffmpeg_params = ffmpeg_params if ffmpeg_params is not None else default_ffmpeg_params 153 | 154 | mimsave_kwargs = { 155 | "fps": fps, 156 | "quality": quality, 157 | "macro_block_size": 1, 158 | "ffmpeg_params": final_ffmpeg_params, 159 | "output_params": ["-f", "mp4"], 160 | } 161 | # Update with any other kwargs 162 | mimsave_kwargs.update(kwargs) 163 | log.debug(f"mimsave_kwargs: {mimsave_kwargs}") 164 | 165 | imageio.mimsave(file, obj, format, **mimsave_kwargs) 166 | 167 | def dump_to_str(self, obj, **kwargs): 168 | raise NotImplementedError 169 | -------------------------------------------------------------------------------- /imaginaire/utils/io.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | from typing import IO, Any 18 | 19 | import numpy as np 20 | import torch 21 | from einops import rearrange 22 | from PIL import Image as PILImage 23 | from torch import Tensor 24 | 25 | from imaginaire.utils import log 26 | from imaginaire.utils.easy_io import easy_io 27 | 28 | 29 | def save_image_or_video_multiview( 30 | tensor: Tensor, save_path: str | IO[Any], fps: int = 24, quality=None, ffmpeg_params=None, n_views: int = 1 31 | ) -> None: 32 | """ 33 | Split the tensor into n_views, stack them along the width dimension, and save as a video 34 | Args: 35 | tensor (Tensor): Input tensor with shape (B, C, T, H, W) or (C, T, H, W) in [-1, 1] or [0, 1] range. 36 | If in [-1, 1] range, it will be automatically converted to [0, 1] range. 37 | save_path (Union[str, IO[Any]]): File path (with or without extension) or file-like object. 38 | fps (int): Frames per second for video. Default is 24. 39 | quality: Optional quality parameter for images (passed to easy_io). 40 | ffmpeg_params: Optional ffmpeg parameters for videos (passed to easy_io). 
41 | """ 42 | # Handle batch dimension if present 43 | if tensor.ndim == 5: 44 | tensor = tensor[0] # Take the first item from the batch 45 | 46 | assert tensor.ndim == 4, "Tensor must have shape (C, T, H, W) or (B, C, T, H, W)" 47 | assert isinstance(save_path, str) or hasattr(save_path, "write"), "save_path must be a string or file-like object" 48 | 49 | # Normalize to [0, 1] range 50 | if torch.is_floating_point(tensor): 51 | # Check if tensor is in [-1, 1] range (approximately) 52 | if tensor.min() < -0.5: 53 | # Convert from [-1, 1] to [0, 1] 54 | tensor = (tensor + 1.0) / 2.0 55 | tensor = tensor.clamp(0, 1) 56 | else: 57 | assert tensor.dtype == torch.uint8, "Only support uint8 tensor" 58 | tensor = tensor.float().div(255) 59 | 60 | kwargs = {} 61 | if quality is not None: 62 | kwargs["quality"] = quality 63 | if ffmpeg_params is not None: 64 | kwargs["ffmpeg_params"] = ffmpeg_params 65 | 66 | save_obj = (rearrange((tensor.cpu().float().numpy() * 255), "c (v t) h w -> t (v h) w c", v=n_views) + 0.5).astype( 67 | np.uint8 68 | ) 69 | if isinstance(save_path, str): 70 | # Check if path already has an extension 71 | base, ext = os.path.splitext(save_path) 72 | if not ext: 73 | save_path = f"{base}.mp4" 74 | log.info(f"Saving video to {save_path} with fps {fps} and result shape {save_obj.shape}") 75 | easy_io.dump(save_obj, save_path, file_format="mp4", format="mp4", fps=fps, **kwargs) 76 | 77 | 78 | def save_image_or_video( 79 | tensor: Tensor, save_path: str | IO[Any], fps: int = 24, quality=None, ffmpeg_params=None 80 | ) -> None: 81 | """ 82 | Save a tensor as an image or video file based on shape 83 | 84 | Args: 85 | tensor (Tensor): Input tensor with shape (B, C, T, H, W) or (C, T, H, W) in [-1, 1] or [0, 1] range. 86 | If in [-1, 1] range, it will be automatically converted to [0, 1] range. 87 | save_path (Union[str, IO[Any]]): File path (with or without extension) or file-like object. 88 | fps (int): Frames per second for video. Default is 24. 89 | quality: Optional quality parameter for images (passed to easy_io). 90 | ffmpeg_params: Optional ffmpeg parameters for videos (passed to easy_io). 
91 |     """
92 |     # Handle batch dimension if present
93 |     if tensor.ndim == 5:
94 |         tensor = tensor[0]  # Take the first item from the batch
95 | 
96 |     assert tensor.ndim == 4, "Tensor must have shape (C, T, H, W) or (B, C, T, H, W)"
97 |     assert isinstance(save_path, str) or hasattr(save_path, "write"), "save_path must be a string or file-like object"
98 | 
99 |     # Normalize to [0, 1] range
100 |     if torch.is_floating_point(tensor):
101 |         # Check if tensor is in [-1, 1] range (approximately)
102 |         if tensor.min() < -0.5:
103 |             # Convert from [-1, 1] to [0, 1]
104 |             tensor = (tensor + 1.0) / 2.0
105 |         tensor = tensor.clamp(0, 1)
106 |     else:
107 |         assert tensor.dtype == torch.uint8, "Only support uint8 tensor"
108 |         tensor = tensor.float().div(255)
109 | 
110 |     kwargs = {}
111 |     if quality is not None:
112 |         kwargs["quality"] = quality
113 |     if ffmpeg_params is not None:
114 |         kwargs["ffmpeg_params"] = ffmpeg_params
115 | 
116 |     if tensor.shape[1] == 1:
117 |         save_obj = PILImage.fromarray(
118 |             (rearrange((tensor.cpu().float().numpy() * 255), "c 1 h w -> h w c") + 0.5).astype(np.uint8),
119 |             mode="RGB",
120 |         )
121 |         if isinstance(save_path, str):
122 |             # Check if path already has an extension
123 |             base, ext = os.path.splitext(save_path)
124 |             if not ext:
125 |                 save_path = f"{base}.jpg"
126 |         easy_io.dump(save_obj, save_path, file_format="jpg", format="JPEG", quality=85, **kwargs)
127 |     else:
128 |         save_obj = (rearrange((tensor.cpu().float().numpy() * 255), "c t h w -> t h w c") + 0.5).astype(np.uint8)
129 |         if isinstance(save_path, str):
130 |             # Check if path already has an extension
131 |             base, ext = os.path.splitext(save_path)
132 |             if not ext:
133 |                 save_path = f"{base}.mp4"
134 |         easy_io.dump(save_obj, save_path, file_format="mp4", format="mp4", fps=fps, **kwargs)
135 | 
136 | 
137 | def save_text_prompts(prompts: dict[str, str | list], save_path: str | IO[Any]) -> None:
138 |     """
139 |     Save text prompts to a file.
140 | 
141 |     Args:
142 |         prompts (dict[str, str | list]): Dictionary of text prompts to save. Expected keys: "prompt", "negative_prompt", "refined_prompt".
143 |         save_path (Union[str, IO[Any]]): File path (with or without extension) or file-like object.
144 |     """
145 |     if isinstance(save_path, str):
146 |         base, ext = os.path.splitext(save_path)
147 |         if not ext:
148 |             save_path = f"{base}.txt"
149 |     with open(save_path, "w") as f:
150 |         f.write(f"[Prompt]\n{prompts['prompt']}\n")
151 |         if prompts.get("negative_prompt"):
152 |             f.write(f"[Negative Prompt]\n{prompts['negative_prompt']}\n")
153 | 
154 |         if prompts.get("refined_prompt"):
155 |             if isinstance(prompts["refined_prompt"], str):
156 |                 f.write(f"[Refined Prompt]\n{prompts['refined_prompt']}\n")
157 |             elif isinstance(prompts["refined_prompt"], list):
158 |                 for chunk_id, refined_prompt in enumerate(prompts["refined_prompt"]):
159 |                     f.write(f"[Refined Prompt for chunk {chunk_id}]\n{refined_prompt}\n")
160 |             else:
161 |                 raise ValueError("refined_prompt must be a string or a list of strings")
162 | 
--------------------------------------------------------------------------------
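--------------------------------------------------------------------------------
Appendix: usage sketches (illustrative, not part of the repository sources)
--------------------------------------------------------------------------------

A minimal sketch of the I/O helpers in imaginaire/utils/io.py above. It assumes the repository root is on PYTHONPATH and that imageio/ffmpeg are available for the mp4 path; the output paths and prompt text below are made-up placeholders.

import os

import torch

from imaginaire.utils.io import save_image_or_video, save_text_prompts

os.makedirs("outputs", exist_ok=True)  # hypothetical output directory

# A fake clip in (C, T, H, W) layout with values in [-1, 1]; a (B, C, T, H, W) batch is also accepted.
fake_video = torch.rand(3, 16, 64, 64) * 2.0 - 1.0
save_image_or_video(fake_video, "outputs/demo_clip", fps=24)  # extension-less path becomes outputs/demo_clip.mp4

# A single-frame tensor (T == 1) is written as a JPEG instead of a video.
fake_frame = torch.rand(3, 1, 64, 64) * 2.0 - 1.0
save_image_or_video(fake_frame, "outputs/demo_frame")  # becomes outputs/demo_frame.jpg

save_text_prompts({"prompt": "a red car on a wet street"}, "outputs/demo_clip")  # becomes outputs/demo_clip.txt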
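A similar sketch for the attention() wrapper in rcm/utils/attention.py. The (batch, seq_len, num_heads, head_dim) input layout is inferred from the transpose calls inside the wrapper, a CUDA device is assumed, and the rcm import path assumes the repository root is on PYTHONPATH.

import torch

from rcm.utils.attention import attention

assert torch.cuda.is_available(), "the wrapper's fast paths target CUDA devices"

B, S, H, D = 2, 1024, 8, 64  # batch, sequence length, heads, head dim (illustrative sizes)
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)

# Flash Attention 3 is used on Hopper when installed; otherwise the call falls back to
# PyTorch SDPA with a backend priority list chosen from the device's compute capability.
out = attention(q, k, v, causal=False)
print(out.shape)  # torch.Size([2, 1024, 8, 64])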
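Finally, a small sketch of non_strict_load_model from rcm/utils/checkpointer.py; the toy model and checkpoint below are invented to show how matching, mismatched, and unexpected keys are reported.

import torch

from imaginaire.utils import log
from rcm.utils.checkpointer import non_strict_load_model

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.Linear(16, 4))

# One matching key, one shape mismatch, one unexpected key.
ckpt = {
    "0.weight": torch.zeros(16, 8),     # loaded
    "0.bias": torch.zeros(32),          # wrong shape -> reported in incorrect_shapes, skipped
    "extra.weight": torch.zeros(3, 3),  # not in the model -> reported in unexpected_keys
}
result = non_strict_load_model(model, ckpt)
log.info(f"missing={result.missing_keys} unexpected={result.unexpected_keys} incorrect_shapes={result.incorrect_shapes}")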