├── rcm
│ ├── configs
│ │ ├── experiments
│ │ │ └── rcm
│ │ │   └── __init__.py
│ │ ├── defaults
│ │ │ ├── trainer.py
│ │ │ ├── tokenizer.py
│ │ │ ├── ema.py
│ │ │ ├── model.py
│ │ │ ├── conditioner.py
│ │ │ ├── scheduler.py
│ │ │ ├── ckpt_type.py
│ │ │ ├── dataloader.py
│ │ │ ├── checkpoint.py
│ │ │ ├── optimizer.py
│ │ │ ├── net.py
│ │ │ └── callbacks.py
│ │ └── registry_distill.py
│ ├── samplers
│ │ └── euler.py
│ ├── utils
│ │ ├── lognormal.py
│ │ ├── misc.py
│ │ ├── selective_activation_checkpoint.py
│ │ ├── optim_instantiate_dtensor.py
│ │ ├── dtensor_helper.py
│ │ ├── checkpointer.py
│ │ ├── fsdp_helper.py
│ │ └── attention.py
│ ├── modules
│ │ └── denoiser_scaling.py
│ ├── datasets
│ │ ├── webdataset.py
│ │ ├── utils.py
│ │ ├── merge_tar_shards.py
│ │ └── visualize_tar.py
│ ├── tokenizers
│ │ └── interface.py
│ ├── callbacks
│ │ ├── compile_tokenizer.py
│ │ ├── dataloading_monitor.py
│ │ ├── grad_clip.py
│ │ ├── iter_speed.py
│ │ └── heart_beat.py
│ └── networks
│   └── wan2pt1_jvp_test.py
├── examples
│ ├── i2v_input_1.jpg
│ ├── i2v_input_2.jpg
│ ├── i2v_input_3.jpg
│ ├── i2v_input_4.jpg
│ ├── i2v_input_5.jpg
│ ├── i2v_input_6.jpg
│ └── i2v_input_7.jpg
├── imaginaire
│ ├── __init__.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── easy_io
│ │ │ ├── __init__.py
│ │ │ ├── backends
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base_backend.py
│ │ │ │ ├── http_backend.py
│ │ │ │ └── registry_utils.py
│ │ │ └── handlers
│ │ │   ├── torch_handler.py
│ │ │   ├── pandas_handler.py
│ │ │   ├── torchjit_handler.py
│ │ │   ├── __init__.py
│ │ │   ├── txt_handler.py
│ │ │   ├── gzip_handler.py
│ │ │   ├── byte_handler.py
│ │ │   ├── yaml_handler.py
│ │ │   ├── tarfile_handler.py
│ │ │   ├── csv_handler.py
│ │ │   ├── pickle_handler.py
│ │ │   ├── base.py
│ │ │   ├── json_handler.py
│ │ │   ├── jsonl_handler.py
│ │ │   ├── np_handler.py
│ │ │   ├── registry_utils.py
│ │ │   ├── pil_handler.py
│ │ │   └── imageio_video_handler.py
│ │ ├── device.py
│ │ ├── parallel_state_helper.py
│ │ ├── wandb_util.py
│ │ ├── log.py
│ │ ├── profiling.py
│ │ └── io.py
│ ├── callbacks
│ │ ├── __init__.py
│ │ ├── manual_gc.py
│ │ ├── low_precision.py
│ │ └── every_n.py
│ ├── lazy_config
│ │ ├── file_io.py
│ │ ├── registry.py
│ │ ├── omegaconf_patch.py
│ │ ├── __init__.py
│ │ └── instantiate.py
│ └── model.py
├── scripts
│ ├── dcp_to_pth.py
│ └── train.py
└── .gitignore
/rcm/configs/experiments/rcm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/i2v_input_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/rcm/HEAD/examples/i2v_input_1.jpg
--------------------------------------------------------------------------------
/examples/i2v_input_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/rcm/HEAD/examples/i2v_input_2.jpg
--------------------------------------------------------------------------------
/examples/i2v_input_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/rcm/HEAD/examples/i2v_input_3.jpg
--------------------------------------------------------------------------------
/examples/i2v_input_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/rcm/HEAD/examples/i2v_input_4.jpg
--------------------------------------------------------------------------------
/examples/i2v_input_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/rcm/HEAD/examples/i2v_input_5.jpg
--------------------------------------------------------------------------------
/examples/i2v_input_6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/rcm/HEAD/examples/i2v_input_6.jpg
--------------------------------------------------------------------------------
/examples/i2v_input_7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/rcm/HEAD/examples/i2v_input_7.jpg
--------------------------------------------------------------------------------
/rcm/configs/defaults/trainer.py:
--------------------------------------------------------------------------------
1 | from hydra.core.config_store import ConfigStore
2 | from imaginaire.lazy_config import LazyCall as L
3 | from imaginaire.lazy_config import PLACEHOLDER, LazyDict
4 | from imaginaire.trainer import ImaginaireTrainer
5 | from rcm.trainers.trainer_distillation import ImaginaireTrainer_Distill
6 |
7 | TRAINER: LazyDict = L(ImaginaireTrainer)(config=PLACEHOLDER)
8 | TRAINER_DISTILL: LazyDict = L(ImaginaireTrainer_Distill)(config=PLACEHOLDER)
9 |
10 |
11 | def register_trainer():
12 | cs = ConfigStore.instance()
13 | cs.store(group="trainer", package="trainer.type", name="standard", node=TRAINER)
14 | cs.store(group="trainer", package="trainer.type", name="distill", node=TRAINER_DISTILL)
15 |
--------------------------------------------------------------------------------
/imaginaire/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/imaginaire/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/imaginaire/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/imaginaire/lazy_config/file_io.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler
17 | from iopath.common.file_io import PathManager as PathManagerBase
18 |
19 | __all__ = ["PathHandler", "PathManager"]
20 |
21 |
22 | PathManager = PathManagerBase()
23 | PathManager.register_handler(HTTPURLHandler())
24 | PathManager.register_handler(OneDrivePathHandler())
25 |
--------------------------------------------------------------------------------
/rcm/configs/defaults/tokenizer.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from hydra.core.config_store import ConfigStore
17 | from imaginaire.lazy_config import LazyCall as L, LazyDict, PLACEHOLDER
18 | from rcm.tokenizers.wan2pt1 import Wan2pt1VAEInterface
19 |
20 | Wan2pt1VAEConfig: LazyDict = L(Wan2pt1VAEInterface)(vae_pth=PLACEHOLDER)
21 |
22 |
23 | def register_tokenizer():
24 | cs = ConfigStore.instance()
25 | cs.store(group="tokenizer", package="model.config.tokenizer", name="wan2pt1_tokenizer", node=Wan2pt1VAEConfig)
26 |
--------------------------------------------------------------------------------
/rcm/samplers/euler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class FlowEulerSampler:
5 |
6 | def __init__(
7 | self,
8 | num_train_timesteps=1000,
9 | sigma_max=1.0,
10 | sigma_min=0.0,
11 | ):
12 | self.num_train_timesteps = num_train_timesteps
13 | self.sigma_max = sigma_max
14 | self.sigma_min = sigma_min
15 |
16 | def set_timesteps(self, num_inference_steps=100, shift=3.0, device="cuda"):
17 | self.sigmas = torch.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1)[:-1]
18 | self.sigmas = shift * self.sigmas / (1 + (shift - 1) * self.sigmas)
19 | self.timesteps = self.sigmas * self.num_train_timesteps
20 | self.sigmas = self.sigmas.to(device)
21 | self.timesteps = self.timesteps.to(device)
22 |
23 | def step(self, model_output, timestep, sample):
24 | timestep_id = torch.argmin((self.timesteps - timestep).abs(), dim=0)
25 | sigma = self.sigmas[timestep_id]
26 | if timestep_id + 1 >= len(self.timesteps):
27 | sigma_ = 0
28 | else:
29 | sigma_ = self.sigmas[timestep_id + 1]
30 | prev_sample = sample + model_output * (sigma_ - sigma)
31 | return prev_sample
32 |
--------------------------------------------------------------------------------
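Usage sketch (not from the repository): a minimal loop driving FlowEulerSampler. The model below is a hypothetical stand-in for a velocity-prediction network, and the latent shape is arbitrary.

import torch

from rcm.samplers.euler import FlowEulerSampler

def fake_velocity_model(x, t):
    # Placeholder dynamics standing in for the real velocity-prediction network.
    return -x

sampler = FlowEulerSampler(num_train_timesteps=1000)
sampler.set_timesteps(num_inference_steps=50, shift=3.0, device="cpu")

x = torch.randn(1, 16, 8, 8)  # start from pure noise (sigma = sigma_max)
for t in sampler.timesteps:
    v = fake_velocity_model(x, t)
    x = sampler.step(v, t, x)  # Euler update: x += v * (sigma_next - sigma)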
/imaginaire/utils/easy_io/backends/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
17 | from imaginaire.utils.easy_io.backends.http_backend import HTTPBackend
18 | from imaginaire.utils.easy_io.backends.local_backend import LocalBackend
19 | from imaginaire.utils.easy_io.backends.registry_utils import backends, prefix_to_backends, register_backend
20 |
21 | __all__ = [
22 | "BaseStorageBackend",
23 | "HTTPBackend",
24 | "LocalBackend",
25 | "backends",
26 | "prefix_to_backends",
27 | "register_backend",
28 | ]
29 |
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/torch_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | try:
17 | import torch
18 | except ImportError:
19 | torch = None
20 |
21 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
22 |
23 |
24 | class TorchHandler(BaseFileHandler):
25 | str_like = False
26 |
27 | def load_from_fileobj(self, file, **kwargs):
28 | return torch.load(file, **kwargs)
29 |
30 | def dump_to_fileobj(self, obj, file, **kwargs):
31 | torch.save(obj, file, **kwargs)
32 |
33 | def dump_to_str(self, obj, **kwargs):
34 | raise NotImplementedError
35 |
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/pandas_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import pandas as pd
17 |
18 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler # isort:skip
19 |
20 |
21 | class PandasHandler(BaseFileHandler):
22 | str_like = False
23 |
24 | def load_from_fileobj(self, file, **kwargs):
25 | return pd.read_csv(file, **kwargs)
26 |
27 | def dump_to_fileobj(self, obj, file, **kwargs):
28 | obj.to_csv(file, **kwargs)
29 |
30 | def dump_to_str(self, obj, **kwargs):
31 | raise NotImplementedError("PandasHandler does not support dumping to str")
32 |
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/torchjit_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | try:
17 | import torch
18 | except ImportError:
19 | torch = None
20 |
21 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
22 |
23 |
24 | class TorchJitHandler(BaseFileHandler):
25 | str_like = False
26 |
27 | def load_from_fileobj(self, file, **kwargs):
28 | return torch.jit.load(file, **kwargs)
29 |
30 | def dump_to_fileobj(self, obj, file, **kwargs):
31 | torch.jit.save(obj, file, **kwargs)
32 |
33 | def dump_to_str(self, obj, **kwargs):
34 | raise NotImplementedError
35 |
--------------------------------------------------------------------------------
/rcm/configs/defaults/ema.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import attrs
17 | from hydra.core.config_store import ConfigStore
18 |
19 |
20 | @attrs.define(slots=False)
21 | class EMAConfig:
22 | """
23 | Config for the EMA.
24 | """
25 |
26 | enabled: bool = True
27 | rate: float = 0.1
28 | iteration_shift: int = 0
29 |
30 |
31 | PowerEMAConfig: EMAConfig = EMAConfig(
32 | enabled=True,
33 | rate=0.10,
34 | iteration_shift=0,
35 | )
36 |
37 |
38 | def register_ema():
39 | cs = ConfigStore.instance()
40 | cs.store(group="ema", package="model.config.ema", name="power", node=PowerEMAConfig)
41 |
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
17 | from imaginaire.utils.easy_io.handlers.json_handler import JsonHandler
18 | from imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler
19 | from imaginaire.utils.easy_io.handlers.registry_utils import file_handlers, register_handler
20 | from imaginaire.utils.easy_io.handlers.yaml_handler import YamlHandler
21 |
22 | __all__ = [
23 | "BaseFileHandler",
24 | "JsonHandler",
25 | "PickleHandler",
26 | "YamlHandler",
27 | "file_handlers",
28 | "register_handler",
29 | ]
30 |
--------------------------------------------------------------------------------
/rcm/configs/defaults/model.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from hydra.core.config_store import ConfigStore
17 |
18 | from imaginaire.lazy_config import LazyCall as L
19 | from rcm.models.t2v_model_distill_rcm import T2VDistillConfig_rCM, T2VDistillModel_rCM
20 |
21 | FSDP_CONFIG_T2V_DISTILL_RCM = dict(
22 | trainer=dict(distributed_parallelism="fsdp"),
23 | model=L(T2VDistillModel_rCM)(config=T2VDistillConfig_rCM(fsdp_shard_size=8), _recursive_=False),
24 | )
25 |
26 |
27 | def register_model():
28 | cs = ConfigStore.instance()
29 | cs.store(group="model", package="_global_", name="fsdp_t2v_distill_rcm", node=FSDP_CONFIG_T2V_DISTILL_RCM)
30 |
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/txt_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
17 |
18 |
19 | class TxtHandler(BaseFileHandler):
20 | def load_from_fileobj(self, file, **kwargs):
21 | del kwargs
22 | return file.read()
23 |
24 | def dump_to_fileobj(self, obj, file, **kwargs):
25 | del kwargs
26 | if not isinstance(obj, str):
27 | obj = str(obj)
28 | file.write(obj)
29 |
30 | def dump_to_str(self, obj, **kwargs):
31 | del kwargs
32 | if not isinstance(obj, str):
33 | obj = str(obj)
34 | return obj
35 |
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/gzip_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import gzip
17 | import pickle
18 | from io import BytesIO
19 | from typing import Any
20 |
21 | from imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler
22 |
23 |
24 | class GzipHandler(PickleHandler):
25 | str_like = False
26 |
27 | def load_from_fileobj(self, file: BytesIO, **kwargs):
28 | with gzip.GzipFile(fileobj=file, mode="rb") as f:
29 | return pickle.load(f)
30 |
31 | def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs):
32 | with gzip.GzipFile(fileobj=file, mode="wb") as f:
33 | pickle.dump(obj, f)
34 |
--------------------------------------------------------------------------------
/rcm/configs/defaults/conditioner.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from hydra.core.config_store import ConfigStore
17 |
18 | from imaginaire.lazy_config import LazyCall as L
19 | from imaginaire.lazy_config import LazyDict
20 | from rcm.conditioner import TextAttr, TextConditioner
21 |
22 | TextConditionerNoDropConfig: LazyDict = L(TextConditioner)(
23 | text=L(TextAttr)(
24 | input_key=["t5_text_embeddings"],
25 | dropout_rate=0.0,
26 | ),
27 | )
28 |
29 |
30 | def register_conditioner():
31 | cs = ConfigStore.instance()
32 | cs.store(group="conditioner", package="model.config.conditioner", name="text_nodrop", node=TextConditionerNoDropConfig)
33 |
--------------------------------------------------------------------------------
/rcm/configs/defaults/scheduler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from hydra.core.config_store import ConfigStore
17 |
18 | from imaginaire.lazy_config import LazyCall as L
19 | from imaginaire.lazy_config import LazyDict
20 | from rcm.utils.lr_scheduler import LambdaLinearScheduler
21 |
22 | LambdaLinearSchedulerConfig: LazyDict = L(LambdaLinearScheduler)(
23 | warm_up_steps=[100],
24 | cycle_lengths=[10000000000000],
25 | f_start=[1.0e-6],
26 | f_max=[1.0],
27 | f_min=[1.0],
28 | )
29 |
30 |
31 | def register_scheduler():
32 | cs = ConfigStore.instance()
33 | cs.store(group="scheduler", package="scheduler", name="lambdalinear", node=LambdaLinearSchedulerConfig)
34 |
--------------------------------------------------------------------------------
/rcm/utils/lognormal.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from statistics import NormalDist
17 |
18 | import numpy as np
19 | import torch
20 |
21 |
22 | class LogNormal:
23 | def __init__(
24 | self,
25 | p_mean: float = 0.0,
26 | p_std: float = 1.0,
27 | ):
28 | self.gaussian_dist = NormalDist(mu=p_mean, sigma=p_std)
29 |
30 | def __call__(self, batch_size: int) -> torch.Tensor:
31 | cdf_vals = np.random.uniform(size=(batch_size))
32 | samples_interval_gaussian = [self.gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals]
33 |
34 | log_sigma = torch.tensor(samples_interval_gaussian, device="cuda")
35 | return torch.exp(log_sigma)
36 |
37 |
--------------------------------------------------------------------------------
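Usage sketch (not from the repository): LogNormal draws noise levels sigma = exp(z) with z ~ N(p_mean, p_std) via the inverse CDF of uniform samples. Because the class places its output on CUDA, the sketch below reproduces the same math on CPU with NumPy only.

import numpy as np
from statistics import NormalDist

p_mean, p_std, batch_size = 0.0, 1.0, 4
dist = NormalDist(mu=p_mean, sigma=p_std)
u = np.random.uniform(size=batch_size)              # uniform CDF values in [0, 1)
log_sigma = np.array([dist.inv_cdf(v) for v in u])  # Gaussian samples via inverse CDF
sigma = np.exp(log_sigma)                           # log-normally distributed noise levels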
/imaginaire/utils/easy_io/handlers/byte_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from typing import IO
17 |
18 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
19 |
20 |
21 | class ByteHandler(BaseFileHandler):
22 | str_like = False
23 |
24 | def load_from_fileobj(self, file: IO[bytes], **kwargs):
25 | file.seek(0)
26 | # extract all bytes and return
27 | return file.read()
28 |
29 | def dump_to_fileobj(
30 | self,
31 | obj: bytes,
32 | file: IO[bytes],
33 | **kwargs,
34 | ):
35 | # write all bytes to file
36 | file.write(obj)
37 |
38 | def dump_to_str(self, obj, **kwargs):
39 | raise NotImplementedError
40 |
--------------------------------------------------------------------------------
/rcm/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def count_params(model: torch.nn.Module, verbose=False) -> int:
5 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
6 | if verbose:
7 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
8 | return total_params
9 |
10 |
11 | def common_broadcast(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
12 | ndims1 = x.ndim
13 | ndims2 = y.ndim
14 |
15 | common_ndims = min(ndims1, ndims2)
16 | for axis in range(common_ndims):
17 | assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis)
18 |
19 | if ndims1 < ndims2:
20 | x = x.reshape(x.shape + (1,) * (ndims2 - ndims1))
21 | elif ndims2 < ndims1:
22 | y = y.reshape(y.shape + (1,) * (ndims1 - ndims2))
23 |
24 | return x, y
25 |
26 |
27 | def batch_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
28 | x, y = common_broadcast(x, y)
29 | return x + y
30 |
31 |
32 | def batch_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
33 | x, y = common_broadcast(x, y)
34 | return x * y
35 |
36 |
37 | def batch_sub(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
38 | x, y = common_broadcast(x, y)
39 | return x - y
40 |
41 |
42 | def batch_div(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
43 | x, y = common_broadcast(x, y)
44 | return x / y
45 |
--------------------------------------------------------------------------------
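Usage sketch (not from the repository): common_broadcast pads the lower-rank tensor with trailing singleton dimensions (the leading axes must already match), the opposite of NumPy/PyTorch's leading-axis broadcasting. The shapes below are arbitrary.

import torch

from rcm.utils.misc import batch_mul

x = torch.randn(4, 3, 8, 8)              # e.g. a batch of latents
w = torch.tensor([0.1, 0.2, 0.3, 0.4])   # one scalar weight per batch element

# w is reshaped to (4, 1, 1, 1), so each sample is scaled by its own weight;
# plain `x * w` would instead try to broadcast against the trailing axis and fail.
y = batch_mul(x, w)
assert y.shape == (4, 3, 8, 8)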
/rcm/configs/defaults/ckpt_type.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from typing import Dict
17 |
18 | from hydra.core.config_store import ConfigStore
19 |
20 | from imaginaire.lazy_config import LazyCall as L
21 | from rcm.checkpointers.dcp import DistributedCheckpointer
22 | from rcm.checkpointers.dcp_distill import DistributedCheckpointer_Distill
23 |
24 | DISTRIBUTED_CHECKPOINTER: Dict[str, str] = L(DistributedCheckpointer)()
25 | DISTRIBUTED_CHECKPOINTER_DISTILL: Dict[str, str] = L(DistributedCheckpointer_Distill)()
26 |
27 |
28 | def register_ckpt_type():
29 | cs = ConfigStore.instance()
30 | cs.store(group="ckpt_type", package="checkpoint.type", name="dcp", node=DISTRIBUTED_CHECKPOINTER)
31 | cs.store(group="ckpt_type", package="checkpoint.type", name="dcp_distill", node=DISTRIBUTED_CHECKPOINTER_DISTILL)
32 |
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/yaml_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import yaml
17 |
18 | try:
19 | from yaml import CDumper as Dumper # type: ignore
20 | from yaml import CLoader as Loader # type: ignore
21 | except ImportError:
22 | from yaml import Dumper, Loader # type: ignore
23 |
24 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler # isort:skip
25 |
26 |
27 | class YamlHandler(BaseFileHandler):
28 | def load_from_fileobj(self, file, **kwargs):
29 | kwargs.setdefault("Loader", Loader)
30 | return yaml.load(file, **kwargs)
31 |
32 | def dump_to_fileobj(self, obj, file, **kwargs):
33 | kwargs.setdefault("Dumper", Dumper)
34 | yaml.dump(obj, file, **kwargs)
35 |
36 | def dump_to_str(self, obj, **kwargs):
37 | kwargs.setdefault("Dumper", Dumper)
38 | return yaml.dump(obj, **kwargs)
39 |
--------------------------------------------------------------------------------
/rcm/modules/denoiser_scaling.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 |
18 |
19 | class RectifiedFlow_TrigFlowWrapper:
20 | def __init__(self, sigma_data: float = 1.0, t_scaling_factor: float = 1.0):
21 | assert abs(sigma_data - 1.0) < 1e-6, "sigma_data must be 1.0 for RectifiedFlowScaling"
22 | self.t_scaling_factor = t_scaling_factor
23 |
24 | def __call__(self, trigflow_t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
25 | trigflow_t = trigflow_t.to(torch.float64)
26 | c_skip = 1 / (torch.cos(trigflow_t) + torch.sin(trigflow_t))
27 | c_out = -1 * torch.sin(trigflow_t) / (torch.cos(trigflow_t) + torch.sin(trigflow_t))
28 | c_in = 1 / (torch.cos(trigflow_t) + torch.sin(trigflow_t))
29 | c_noise = (torch.sin(trigflow_t) / (torch.cos(trigflow_t) + torch.sin(trigflow_t))) * self.t_scaling_factor
30 | return c_skip, c_out, c_in, c_noise
31 |
--------------------------------------------------------------------------------
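Usage sketch (not from the repository): the wrapper maps a TrigFlow time to preconditioning coefficients. The snippet shows the conventional EDM-style way such coefficients are applied; the dummy network and the t_scaling_factor value are assumptions, not the repository's actual model wiring.

import torch

from rcm.modules.denoiser_scaling import RectifiedFlow_TrigFlowWrapper

scaling = RectifiedFlow_TrigFlowWrapper(sigma_data=1.0, t_scaling_factor=1000.0)

def precondition(net, x_t, trigflow_t):
    # D(x_t, t) = c_skip * x_t + c_out * net(c_in * x_t, c_noise)
    c_skip, c_out, c_in, c_noise = scaling(trigflow_t)
    return c_skip * x_t + c_out * net(c_in * x_t, c_noise)

dummy_net = lambda x, t: torch.zeros_like(x)        # stand-in for the denoiser backbone
x_t = torch.randn(1, 4, 8, 8, dtype=torch.float64)  # coefficients are computed in float64
out = precondition(dummy_net, x_t, torch.tensor(0.8))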
/imaginaire/utils/easy_io/handlers/tarfile_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import tarfile
17 |
18 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
19 |
20 |
21 | class TarHandler(BaseFileHandler):
22 | str_like = False
23 |
24 | def load_from_fileobj(self, file, mode="r|*", **kwargs):
25 | return tarfile.open(fileobj=file, mode=mode, **kwargs)
26 |
27 | def load_from_path(self, filepath, mode="r|*", **kwargs):
28 | return tarfile.open(filepath, mode=mode, **kwargs)
29 |
30 | def dump_to_fileobj(self, obj, file, mode="w", **kwargs):
31 | with tarfile.open(fileobj=file, mode=mode) as tar:
32 | tar.add(obj, **kwargs)
33 |
34 | def dump_to_path(self, obj, filepath, mode="w", **kwargs):
35 | with tarfile.open(filepath, mode=mode) as tar:
36 | tar.add(obj, **kwargs)
37 |
38 | def dump_to_str(self, obj, **kwargs):
39 | raise NotImplementedError
40 |
--------------------------------------------------------------------------------
/imaginaire/utils/device.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import math
17 | import os
18 |
19 | import pynvml
20 |
21 |
22 | class Device:
23 | _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore
24 |
25 | def __init__(self, device_idx: int):
26 | super().__init__()
27 | self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
28 |
29 | def get_name(self) -> str:
30 | return pynvml.nvmlDeviceGetName(self.handle)
31 |
32 | def get_cpu_affinity(self) -> list[int]:
33 | affinity_string = ""
34 | for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
35 | # assume nvml returns list of 64 bit ints
36 | affinity_string = f"{j:064b}" + affinity_string
37 | affinity_list = [int(x) for x in affinity_string]
38 | affinity_list.reverse() # so core 0 is in 0th element of list
39 | return [i for i, e in enumerate(affinity_list) if e != 0]
40 |
--------------------------------------------------------------------------------
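Usage sketch (not from the repository): pinning the current process to the CPUs local to a GPU with Device.get_cpu_affinity. Linux-only, and it assumes pynvml can reach an NVIDIA driver.

import os

import pynvml

from imaginaire.utils.device import Device

pynvml.nvmlInit()
gpu0 = Device(device_idx=0)
os.sched_setaffinity(0, gpu0.get_cpu_affinity())  # pin to GPU 0's NUMA-local cores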
/rcm/configs/defaults/dataloader.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | from hydra.core.config_store import ConfigStore
18 |
19 | from imaginaire.lazy_config import LazyCall as L
20 | from rcm.datasets.webdataset import create_dataloader
21 |
22 |
23 | DUMMY_DATALOADER = L(torch.utils.data.DataLoader)(dataset=lambda: torch.utils.data.TensorDataset(torch.empty(0, 1), torch.empty(0)))
24 |
25 | WEBDATASET_LOADER = L(create_dataloader)(
26 | tar_path_pattern="/path/to/dataset/shard_*.tar",
27 | batch_size=1,
28 | num_workers=8,
29 | shuffle_buffer=1000,
30 | prefetch_factor=2,
31 | )
32 |
33 |
34 | def register_dataloader():
35 | cs = ConfigStore.instance()
36 | cs.store(group="data_train", package="dataloader_train", name="dummy", node=DUMMY_DATALOADER)
37 | cs.store(group="data_train", package="dataloader_train", name="webdataset", node=WEBDATASET_LOADER)
38 | cs.store(group="data_val", package="dataloader_val", name="dummy", node=DUMMY_DATALOADER)
39 |
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/csv_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import csv
17 | from io import StringIO
18 |
19 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
20 |
21 |
22 | class CsvHandler(BaseFileHandler):
23 | def load_from_fileobj(self, file, **kwargs):
24 | del kwargs
25 | reader = csv.reader(file)
26 | return list(reader)
27 |
28 | def dump_to_fileobj(self, obj, file, **kwargs):
29 | del kwargs
30 | writer = csv.writer(file)
31 | if not all(isinstance(row, list) for row in obj):
32 | raise ValueError("Each row must be a list")
33 | writer.writerows(obj)
34 |
35 | def dump_to_str(self, obj, **kwargs):
36 | del kwargs
37 | output = StringIO()
38 | writer = csv.writer(output)
39 | if not all(isinstance(row, list) for row in obj):
40 | raise ValueError("Each row must be a list")
41 | writer.writerows(obj)
42 | return output.getvalue()
43 |
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/pickle_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import pickle
17 | from io import BytesIO
18 | from typing import Any
19 |
20 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
21 |
22 |
23 | class PickleHandler(BaseFileHandler):
24 | str_like = False
25 |
26 | def load_from_fileobj(self, file: BytesIO, **kwargs):
27 | return pickle.load(file, **kwargs)
28 |
29 | def load_from_path(self, filepath, **kwargs):
30 | return super().load_from_path(filepath, mode="rb", **kwargs)
31 |
32 | def dump_to_str(self, obj, **kwargs):
33 | kwargs.setdefault("protocol", 2)
34 | return pickle.dumps(obj, **kwargs)
35 |
36 | def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs):
37 | kwargs.setdefault("protocol", 2)
38 | pickle.dump(obj, file, **kwargs)
39 |
40 | def dump_to_path(self, obj, filepath, **kwargs):
41 | with open(filepath, "wb") as f:
42 | pickle.dump(obj, f, **kwargs)
43 |
--------------------------------------------------------------------------------
/imaginaire/utils/parallel_state_helper.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """
17 | This module contains various helper functions designed to extend the functionality of parallel states within the MCore library.
18 |
19 | MCore is a third-party library that is infrequently updated and may introduce backward compatibility issues in our codebase, such as changes in function signatures or missing / new functions in new versions.
20 |
21 | To mitigate these issues, this module provides stable functions that ensure the imaginaire codebase remains compatible with different versions of MCore.
22 | """
23 |
24 | try:
25 | from megatron.core import parallel_state
26 | except ImportError:
27 | print("Megatron is not installed, is_tp_cp_pp_rank0 functions will not work.")
28 |
29 |
30 | def is_tp_cp_pp_rank0():
31 | return (
32 | parallel_state.get_tensor_model_parallel_rank() == 0
33 | and parallel_state.get_pipeline_model_parallel_rank() == 0
34 | and parallel_state.get_context_parallel_rank() == 0
35 | )
36 |
--------------------------------------------------------------------------------
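Usage sketch (not from the repository): a typical guard built on is_tp_cp_pp_rank0. It assumes megatron.core's parallel_state has already been initialized.

from imaginaire.utils.parallel_state_helper import is_tp_cp_pp_rank0

# Only the first rank of its tensor-, context-, and pipeline-parallel groups acts,
# so the side effect happens once per data-parallel model replica.
if is_tp_cp_pp_rank0():
    print("this rank owns logging/checkpointing for its model replica")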
/imaginaire/utils/easy_io/handlers/base.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from abc import ABCMeta, abstractmethod
17 |
18 |
19 | class BaseFileHandler(metaclass=ABCMeta):
20 | # `str_like` is a flag to indicate whether the type of file object is
21 | # str-like object or bytes-like object. Pickle only processes bytes-like
22 | # objects but json only processes str-like object. If it is str-like
23 | # object, `StringIO` will be used to process the buffer.
24 | str_like = True
25 |
26 | @abstractmethod
27 | def load_from_fileobj(self, file, **kwargs):
28 | pass
29 |
30 | @abstractmethod
31 | def dump_to_fileobj(self, obj, file, **kwargs):
32 | pass
33 |
34 | @abstractmethod
35 | def dump_to_str(self, obj, **kwargs):
36 | pass
37 |
38 | def load_from_path(self, filepath, mode="r", **kwargs):
39 | with open(filepath, mode) as f:
40 | return self.load_from_fileobj(f, **kwargs)
41 |
42 | def dump_to_path(self, obj, filepath, mode="w", **kwargs):
43 | with open(filepath, mode) as f:
44 | self.dump_to_fileobj(obj, f, **kwargs)
45 |
--------------------------------------------------------------------------------
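Usage sketch (not from the repository): a hypothetical handler subclass illustrating the three abstract methods and the str_like flag; the newline-separated-integer format is invented for the example.

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler


class IntListHandler(BaseFileHandler):
    # Reads/writes newline-separated integers (illustrative only).
    str_like = True  # text-mode buffers suffice, so StringIO would be used upstream

    def load_from_fileobj(self, file, **kwargs):
        return [int(line) for line in file if line.strip()]

    def dump_to_fileobj(self, obj, file, **kwargs):
        file.write(self.dump_to_str(obj))

    def dump_to_str(self, obj, **kwargs):
        return "\n".join(str(x) for x in obj)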
/rcm/configs/defaults/checkpoint.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from hydra.core.config_store import ConfigStore
17 |
18 | from imaginaire.config import CheckpointConfig, ObjectStoreConfig
19 |
20 | s3_object_store = ObjectStoreConfig(
21 | enabled=True,
22 | credentials="credentials/s3_checkpoint.secret",
23 | bucket="checkpoints-us-east-1",
24 | )
25 |
26 | CHECKPOINT_LOCAL = CheckpointConfig(
27 | save_to_object_store=ObjectStoreConfig(),
28 | save_iter=1000,
29 | load_from_object_store=ObjectStoreConfig(),
30 | load_path="",
31 | load_training_state=False,
32 | strict_resume=True,
33 | )
34 |
35 | CHECKPOINT_S3 = CheckpointConfig(
36 | save_to_object_store=s3_object_store,
37 | save_iter=1000,
38 | load_from_object_store=s3_object_store,
39 | load_path="",
40 | load_training_state=False,
41 | strict_resume=True,
42 | )
43 |
44 |
45 | def register_checkpoint():
46 | cs = ConfigStore.instance()
47 | cs.store(group="checkpoint", package="checkpoint", name="local", node=CHECKPOINT_LOCAL)
48 | cs.store(group="checkpoint", package="checkpoint", name="s3", node=CHECKPOINT_S3)
49 |
--------------------------------------------------------------------------------
/rcm/datasets/webdataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import webdataset as wds
3 | import torch
4 | from torch.utils.data import DataLoader
5 |
6 |
7 | def dict_collation_fn(samples):
8 | if not samples:
9 | return {}
10 |
11 | keys = samples[0].keys()
12 | batched_dict = {key: [] for key in keys}
13 |
14 | for sample in samples:
15 | for key in keys:
16 | batched_dict[key].append(sample[key])
17 |
18 | for key in keys:
19 | if isinstance(batched_dict[key][0], torch.Tensor):
20 | batched_dict[key] = torch.stack(batched_dict[key])
21 |
22 | return batched_dict
23 |
24 |
25 | def create_dataloader(
26 | tar_path_pattern, # e.g., "/path/to/dataset/shard_*.tar"
27 | batch_size,
28 | num_workers=8,
29 | shuffle_buffer=1000,
30 | prefetch_factor=2,
31 | ):
32 | shards = glob.glob(tar_path_pattern)
33 | if not shards:
34 | raise FileNotFoundError(f"No files found with pattern '{tar_path_pattern}'")
35 |
36 | dataset = wds.DataPipeline(
37 | wds.SimpleShardList(shards),
38 | # this shuffles the shards
39 | wds.shuffle(1000),
40 | wds.split_by_node,
41 | wds.split_by_worker,
42 | wds.tarfile_to_samples(),
43 | # this shuffles the samples in memory
44 | wds.shuffle(shuffle_buffer),
45 | wds.decode(wds.handle_extension("pt", wds.torch_loads)),
46 | wds.rename(latents="latent.pt", t5_text_embeddings="embed.pt", prompts="prompt.txt"),
47 | wds.batched(batch_size, partial=False, collation_fn=dict_collation_fn),
48 | )
49 |
50 | dataloader = DataLoader(
51 | dataset,
52 | batch_size=None,
53 | shuffle=False,
54 | num_workers=num_workers,
55 | pin_memory=True,
56 | prefetch_factor=prefetch_factor,
57 | )
58 |
59 | return dataloader
60 |
--------------------------------------------------------------------------------
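Usage sketch (not from the repository): consuming the loader. The tar pattern is a placeholder, and each sample in the shards is expected to contain latent.pt, embed.pt, and prompt.txt entries (see the rename stage above).

from rcm.datasets.webdataset import create_dataloader

loader = create_dataloader("/path/to/dataset/shard_*.tar", batch_size=2, num_workers=4)
for batch in loader:
    latents = batch["latents"]            # stacked tensors (see dict_collation_fn)
    embeds = batch["t5_text_embeddings"]  # stacked text-embedding tensors
    prompts = batch["prompts"]            # per-sample prompts, left unstacked
    break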
/rcm/datasets/utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # aspect ratio ("width:height") -> (width, height) in pixels
17 | IMAGE_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = {
18 | "1024": {"1:1": (1024, 1024), "4:3": (1168, 880), "3:4": (880, 1168), "16:9": (1360, 768), "9:16": (768, 1360)},
19 | "720": {"1:1": (960, 960), "4:3": (960, 704), "3:4": (704, 960), "16:9": (1280, 704), "9:16": (704, 1280)},
20 | "512": {"1:1": (512, 512), "4:3": (640, 512), "3:4": (512, 640), "16:9": (640, 384), "9:16": (384, 640)},
21 | "480": {"1:1": (480, 480), "4:3": (640, 480), "3:4": (480, 640), "16:9": (768, 432), "9:16": (432, 768)},
22 | }
23 |
24 | # aspect ratio "width:height" -> (width, height)
25 | VIDEO_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = {
26 | "720": {"1:1": (960, 960), "4:3": (960, 704), "3:4": (704, 960), "16:9": (1280, 704), "9:16": (704, 1280)},
27 | "512": {"1:1": (512, 512), "4:3": (640, 512), "3:4": (512, 640), "16:9": (640, 384), "9:16": (384, 640)},
28 | "480": {"1:1": (480, 480), "4:3": (640, 480), "3:4": (480, 640), "16:9": (768, 432), "9:16": (432, 768)},
29 | "480p": {"1:1": (640, 640), "4:3": (640, 480), "3:4": (480, 640), "16:9": (832, 480), "9:16": (480, 832)},
30 | "720p": {"1:1": (960, 960), "4:3": (960, 720), "3:4": (720, 960), "16:9": (1280, 720), "9:16": (720, 1280)},
31 | }
32 |
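A quick lookup sketch for the tables above; the chosen bucket and aspect ratio are arbitrary:

    from rcm.datasets.utils import VIDEO_RES_SIZE_INFO

    # Keys are resolution buckets, then aspect ratios; values are (width, height).
    width, height = VIDEO_RES_SIZE_INFO["480p"]["16:9"]
    print(width, height)  # 832 480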
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/json_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import json
17 |
18 | import numpy as np
19 |
20 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
21 |
22 |
23 | def set_default(obj):
24 | """Set default json values for non-serializable values.
25 |
26 | It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
27 | It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
28 | etc.) into plain numbers of plain python built-in types.
29 | """
30 | if isinstance(obj, (set, range)):
31 | return list(obj)
32 | elif isinstance(obj, np.ndarray):
33 | return obj.tolist()
34 | elif isinstance(obj, np.generic):
35 | return obj.item()
36 | raise TypeError(f"{type(obj)} is unsupported for json dump")
37 |
38 |
39 | class JsonHandler(BaseFileHandler):
40 | def load_from_fileobj(self, file):
41 | return json.load(file)
42 |
43 | def dump_to_fileobj(self, obj, file, **kwargs):
44 | kwargs.setdefault("default", set_default)
45 | json.dump(obj, file, **kwargs)
46 |
47 | def dump_to_str(self, obj, **kwargs):
48 | kwargs.setdefault("default", set_default)
49 | return json.dumps(obj, **kwargs)
50 |
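A small sketch of why `set_default` is registered as the ``default`` hook: NumPy arrays/scalars and sets become JSON-serializable without the caller converting them first (the payload below is made up):

    import numpy as np

    from imaginaire.utils.easy_io.handlers.json_handler import JsonHandler

    handler = JsonHandler()
    payload = {"ids": np.arange(3), "score": np.float32(0.5), "tags": {"a", "b"}}
    # ndarray -> list, np.generic -> plain number, set -> list, all via set_default
    print(handler.dump_to_str(payload, sort_keys=True))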
--------------------------------------------------------------------------------
/scripts/dcp_to_pth.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import argparse
17 |
18 | import torch
19 | from torch.distributed.checkpoint import FileSystemReader
20 | from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
21 | from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
22 |
23 |
24 | def parse_arguments() -> argparse.Namespace:
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--dcp_checkpoint_dir", type=str, default="checkpoints/iter_000010000/model")
27 | parser.add_argument("--save_path", type=str, default="saved_model.pt")
28 | return parser.parse_args()
29 |
30 |
31 | if __name__ == "__main__":
32 | args = parse_arguments()
33 | storage_reader = FileSystemReader(args.dcp_checkpoint_dir)
34 |
35 | sd = {}
36 | _load_state_dict(sd, storage_reader=storage_reader, planner=_EmptyStateDictLoadPlanner(), no_dist=True)
37 | new_sd = {}
38 | for k, v in sd.items():
39 | if k.startswith("net_ema."):
40 | new_key = k.replace("net_ema.", "net.")
41 | # Save in bf16 precision
42 | if v.is_floating_point():
43 | new_sd[new_key] = v.to(torch.bfloat16)
44 | else:
45 | new_sd[new_key] = v
46 | torch.save(new_sd, args.save_path)
47 |
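Once the script has produced the consolidated checkpoint, the output is an ordinary PyTorch state dict. A quick inspection sketch, assuming the default --save_path of "saved_model.pt":

    import torch

    state_dict = torch.load("saved_model.pt", map_location="cpu")
    # Keys were remapped from "net_ema.*" to "net.*"; floating-point tensors are bfloat16.
    first_key = next(iter(state_dict))
    print(len(state_dict), first_key, state_dict[first_key].dtype)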
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/backends/base_backend.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import os.path as osp
18 | from abc import ABCMeta, abstractmethod
19 |
20 |
21 | def mkdir_or_exist(dir_name, mode=0o777):
22 | if dir_name == "":
23 | return
24 | dir_name = osp.expanduser(dir_name)
25 | os.makedirs(dir_name, mode=mode, exist_ok=True)
26 |
27 |
28 | def has_method(obj, method):
29 | return hasattr(obj, method) and callable(getattr(obj, method))
30 |
31 |
32 | class BaseStorageBackend(metaclass=ABCMeta):
33 | """Abstract class of storage backends.
34 |
35 | All backends need to implement two apis: :meth:`get()` and
36 | :meth:`get_text()`.
37 |
38 | - :meth:`get()` reads the file as a byte stream.
39 | - :meth:`get_text()` reads the file as texts.
40 | """
41 |
42 | # a flag to indicate whether the backend can create a symlink for a file
43 | # This attribute will be deprecated in future.
44 | _allow_symlink = False
45 |
46 | @property
47 | def allow_symlink(self):
48 | return self._allow_symlink
49 |
50 | @property
51 | def name(self):
52 | return self.__class__.__name__
53 |
54 | @abstractmethod
55 | def get(self, filepath):
56 | pass
57 |
58 | @abstractmethod
59 | def get_text(self, filepath):
60 | pass
61 |
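A minimal concrete backend illustrating the abstract API above; this local-disk sketch is not one of the backends shipped in the repository:

    from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend


    class LocalBackendSketch(BaseStorageBackend):
        """Toy backend that reads files from the local filesystem."""

        def get(self, filepath) -> bytes:
            with open(filepath, "rb") as f:
                return f.read()

        def get_text(self, filepath, encoding="utf-8") -> str:
            with open(filepath, "r", encoding=encoding) as f:
                return f.read()


    backend = LocalBackendSketch()
    print(backend.name)  # "LocalBackendSketch"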
--------------------------------------------------------------------------------
/imaginaire/callbacks/manual_gc.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import gc
17 |
18 | from imaginaire.callbacks.every_n import EveryN
19 | from imaginaire.utils import log
20 |
21 |
22 | class ManualGarbageCollection(EveryN):
23 |     """
24 |     Disable automatic garbage collection and manually trigger a collection every N iterations.
25 |     This is very useful for large-scale training because it removes unpredictable gc pauses
26 |     that inflate GPU sync time; it can yield up to a 50% speedup.
27 |
28 |     Note that this callback only disables gc in the main process; automatic gc stays enabled in subprocesses.
29 |
30 |     We disable gc only after warm_up iterations so that subprocesses such as dataloader workers keep automatic gc, which otherwise could cause OOM.
31 |     """
32 |
33 | def __init__(self, *args, warm_up: int = 5, **kwargs):
34 | kwargs["barrier_after_run"] = False
35 | super().__init__(*args, **kwargs)
36 |
37 | self.counter = 0
38 | self.warm = warm_up
39 |
40 | def every_n_impl(self, trainer, model, data_batch, output_batch, loss, iteration):
41 | del trainer, model, data_batch, output_batch, loss
42 | self.counter += 1
43 | if self.counter < self.warm:
44 | return
45 | if self.counter == self.warm:
46 | gc.disable()
47 | log.critical("Garbage collection disabled")
48 |
49 | gc.collect(1)
50 |
--------------------------------------------------------------------------------
/rcm/configs/defaults/optimizer.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from hydra.core.config_store import ConfigStore
17 |
18 | from imaginaire.lazy_config import PLACEHOLDER, LazyDict
19 | from imaginaire.lazy_config import LazyCall as L
20 | from rcm.utils.optim_instantiate_dtensor import get_base_optimizer
21 |
22 | AdamWConfig = L(get_base_optimizer)(
23 | model=PLACEHOLDER,
24 | lr=1e-4,
25 | weight_decay=0.1,
26 | betas=[0.9, 0.99],
27 | optim_type="adamw",
28 | eps=1e-8,
29 | )
30 |
31 | FusedAdamWConfig: LazyDict = L(get_base_optimizer)(
32 | model=PLACEHOLDER,
33 | lr=1e-4,
34 | weight_decay=0.1,
35 | betas=[0.9, 0.99],
36 | optim_type="fusedadam",
37 | eps=1e-8,
38 | master_weights=True,
39 | capturable=True,
40 | )
41 |
42 |
43 | def register_optimizer():
44 | cs = ConfigStore.instance()
45 | cs.store(group="optimizer", package="optimizer", name="fusedadamw", node=FusedAdamWConfig)
46 | cs.store(group="optimizer", package="optimizer", name="adamw", node=AdamWConfig)
47 |
48 |
49 | def register_optimizer_fake_score():
50 | cs = ConfigStore.instance()
51 | cs.store(group="optimizer_fake_score", package="model.config.optimizer_fake_score", name="fusedadamw", node=FusedAdamWConfig)
52 | cs.store(group="optimizer_fake_score", package="model.config.optimizer_fake_score", name="adamw", node=AdamWConfig)
53 |
--------------------------------------------------------------------------------
/rcm/tokenizers/interface.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from abc import ABC, abstractmethod
17 |
18 | import torch
19 |
20 |
21 | class VideoTokenizerInterface(ABC):
22 | def __init__(self): # noqa: B027
23 | pass
24 |
25 | @abstractmethod
26 | def reset_dtype(self):
27 | """
28 | Reset the dtype of the model to the dtype its weights were trained with or quantized to.
29 | """
30 | pass
31 |
32 | @abstractmethod
33 | def encode(self, state: torch.Tensor) -> torch.Tensor:
34 | pass
35 |
36 | @abstractmethod
37 | def decode(self, latent: torch.Tensor) -> torch.Tensor:
38 | pass
39 |
40 | @abstractmethod
41 | def get_latent_num_frames(self, num_pixel_frames: int) -> int:
42 | pass
43 |
44 | @abstractmethod
45 | def get_pixel_num_frames(self, num_latent_frames: int) -> int:
46 | pass
47 |
48 | @property
49 | @abstractmethod
50 | def spatial_compression_factor(self):
51 | pass
52 |
53 | @property
54 | @abstractmethod
55 | def temporal_compression_factor(self):
56 | pass
57 |
58 | @property
59 | @abstractmethod
60 | def spatial_resolution(self):
61 | pass
62 |
63 | @property
64 | @abstractmethod
65 | def pixel_chunk_duration(self):
66 | pass
67 |
68 | @property
69 | @abstractmethod
70 | def latent_chunk_duration(self):
71 | pass
72 |
73 | @property
74 | @abstractmethod
75 | def latent_ch(self) -> int:
76 | pass
77 |
78 | @property
79 | def is_chunk_overlap(self):
80 | return False
81 |
82 | @property
83 | def is_causal(self):
84 | return True
85 |
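For reference, a deliberately trivial implementation of the interface above: an identity "tokenizer" with no compression. Every returned value is an assumption made purely for illustration.

    import torch

    from rcm.tokenizers.interface import VideoTokenizerInterface


    class IdentityTokenizer(VideoTokenizerInterface):
        def reset_dtype(self):
            pass  # nothing to reset for a stateless tokenizer

        def encode(self, state: torch.Tensor) -> torch.Tensor:
            return state

        def decode(self, latent: torch.Tensor) -> torch.Tensor:
            return latent

        def get_latent_num_frames(self, num_pixel_frames: int) -> int:
            return num_pixel_frames

        def get_pixel_num_frames(self, num_latent_frames: int) -> int:
            return num_latent_frames

        @property
        def spatial_compression_factor(self):
            return 1

        @property
        def temporal_compression_factor(self):
            return 1

        @property
        def spatial_resolution(self):
            return "480p"

        @property
        def pixel_chunk_duration(self):
            return 1

        @property
        def latent_chunk_duration(self):
            return 1

        @property
        def latent_ch(self) -> int:
            return 3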
--------------------------------------------------------------------------------
/rcm/callbacks/compile_tokenizer.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 |
18 | from imaginaire.utils import log
19 | from imaginaire.utils.callback import Callback
20 |
21 |
22 | class CompileTokenizer(Callback):
23 | def __init__(self, enabled: bool = False, compile_after_iterations: int = 4, dynamic: bool = False):
24 | super().__init__()
25 | self.enabled = enabled
26 | self.compiled = False
27 | self.compile_after_iterations = compile_after_iterations
28 | self.skip_counter = 0
29 | self.dynamic = dynamic # If there are issues with constant recompilations you may set this value to None or True
30 |
31 | def on_training_step_start(self, model, data_batch: dict[str, torch.Tensor], iteration: int = 0) -> None:
32 | if not self.enabled or self.compiled:
33 | return
34 |
35 | if isinstance(model.tokenizer, torch.jit.ScriptModule):
36 | log.critical(f"The Tokenizer model {type(model.tokenizer)} is a JIT model, which is not compilable. The Tokenizer will not be compiled.")
37 |             self.enabled = False  # match the message above: never attempt to compile a JIT tokenizer
38 | if self.skip_counter == self.compile_after_iterations:
39 | try:
40 | # PyTorch >= 2.7
41 | torch._dynamo.config.recompile_limit = 32
42 | except AttributeError:
43 | try:
44 | torch._dynamo.config.cache_size_limit = 32
45 | except AttributeError:
46 | log.warning("Tokenizer compilation requested, but Torch Dynamo is unavailable – skipping compilation.")
47 | self.enabled = False
48 | return
49 |
50 | model.tokenizer.encode = torch.compile(model.tokenizer.encode, dynamic=self.dynamic)
51 | model.tokenizer.decode = torch.compile(model.tokenizer.decode, dynamic=self.dynamic)
52 | self.compiled = True
53 | self.skip_counter += 1
54 |
--------------------------------------------------------------------------------
/rcm/utils/selective_activation_checkpoint.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from dataclasses import dataclass
17 | from enum import Enum
18 |
19 | import torch
20 |
21 | try:
22 | from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts, noop_context_fn
23 | except ImportError:
24 |     CheckpointPolicy = create_selective_checkpoint_contexts = noop_context_fn = None  # keep names defined if the import fails
25 |
26 | mm_only_save_list = {
27 | torch.ops.aten.mm.default,
28 | torch.ops.aten._scaled_dot_product_efficient_attention.default,
29 | torch.ops.aten._scaled_dot_product_flash_attention.default,
30 | torch.ops.aten.addmm.default,
31 | }
32 |
33 |
34 | class CheckpointMode(str, Enum):
35 | """
36 | Enum for the different checkpoint modes.
37 | """
38 |
39 | NONE = "none"
40 | MM_ONLY = "mm_only"
41 | BLOCK_WISE = "block_wise"
42 |
43 | def __str__(self) -> str:
44 | # Optional: makes print() show just the value
45 | return self.value
46 |
47 |
48 | def mm_only_policy(ctx, func, *args, **kwargs):
49 | """
50 | In newer flash-attn and TE versions, FA2 shows up in the list of ops with the name of 'flash_attn._flash_attn_forward'.
51 | However, FA2 is much slower (2-3x) than FA3 or cuDNN kernel. Registering cuDNN kernel would require heavy changes in TE code.
52 | That's why the best option is to use FA3 with small modifications to flash_attn_interface.py to register FA3 as PyTorch op.
53 | """
54 | to_save = func in mm_only_save_list or "flash_attn" in str(func)
55 | return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE
56 |
57 |
58 | def mm_only_context_fn():
59 | return create_selective_checkpoint_contexts(mm_only_policy)
60 |
61 |
62 | @dataclass
63 | class SACConfig:
64 | mode: str = "mm_only"
65 | every_n_blocks: int = 1
66 |
67 | def get_context_fn(self):
68 | if self.mode == CheckpointMode.MM_ONLY:
69 | return mm_only_context_fn
70 | elif self.mode == CheckpointMode.BLOCK_WISE:
71 | return noop_context_fn
72 | else:
73 | raise ValueError(f"Invalid mode: {self.mode}")
74 |
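A sketch of how the "mm_only" policy above plugs into PyTorch's non-reentrant activation checkpointing. The toy module and tensor shapes are assumptions, and this requires a PyTorch version that provides create_selective_checkpoint_contexts:

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    from rcm.utils.selective_activation_checkpoint import SACConfig

    block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
    x = torch.randn(2, 64, requires_grad=True)

    context_fn = SACConfig(mode="mm_only").get_context_fn()
    # Matmul / attention outputs are saved; everything else is recomputed in backward.
    out = checkpoint(block, x, use_reentrant=False, context_fn=context_fn)
    out.sum().backward()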
--------------------------------------------------------------------------------
/imaginaire/lazy_config/registry.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import pydoc
17 | from typing import Any
18 |
19 | from fvcore.common.registry import Registry # for backward compatibility.
20 |
21 | """
22 | ``Registry`` and `locate` provide ways to map a string (typically found
23 | in config files) to callable objects.
24 | """
25 |
26 | __all__ = ["Registry", "locate"]
27 |
28 |
29 | def _convert_target_to_string(t: Any) -> str:
30 | """
31 | Inverse of ``locate()``.
32 |
33 | Args:
34 | t: any object with ``__module__`` and ``__qualname__``
35 | """
36 | module, qualname = t.__module__, t.__qualname__
37 |
38 | # Compress the path to this object, e.g. ``module.submodule._impl.class``
39 | # may become ``module.submodule.class``, if the later also resolves to the same
40 | # object. This simplifies the string, and also is less affected by moving the
41 | # class implementation.
42 | module_parts = module.split(".")
43 | for k in range(1, len(module_parts)):
44 | prefix = ".".join(module_parts[:k])
45 | candidate = f"{prefix}.{qualname}"
46 | try:
47 | if locate(candidate) is t:
48 | return candidate
49 | except ImportError:
50 | pass
51 | return f"{module}.{qualname}"
52 |
53 |
54 | def locate(name: str) -> Any:
55 | """
56 | Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``,
57 | such as "module.submodule.class_name".
58 |
59 | Raise Exception if it cannot be found.
60 | """
61 | obj = pydoc.locate(name)
62 |
63 | # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly
64 | # by pydoc.locate. Try a private function from hydra.
65 | if obj is None:
66 | try:
67 | # from hydra.utils import get_method - will print many errors
68 | from hydra.utils import _locate
69 | except ImportError as e:
70 | raise ImportError(f"Cannot dynamically locate object {name}!") from e
71 | else:
72 | obj = _locate(name) # it raises if fails
73 |
74 | return obj
75 |
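A quick illustration of `locate` resolving dotted paths; the second call would use the hydra fallback only if pydoc cannot resolve the name directly:

    from imaginaire.lazy_config.registry import locate

    print(locate("collections.OrderedDict"))  # <class 'collections.OrderedDict'>
    print(locate("torch.optim.AdamW"))        # the torch optimizer class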
--------------------------------------------------------------------------------
/imaginaire/lazy_config/omegaconf_patch.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from typing import Any
17 |
18 | from omegaconf import OmegaConf
19 | from omegaconf.base import DictKeyType, SCMode
20 | from omegaconf.dictconfig import DictConfig # pragma: no cover
21 |
22 |
23 | def to_object(cfg: Any) -> dict[DictKeyType, Any] | list[Any] | None | str | Any:
24 | """
25 | Converts an OmegaConf configuration object to a native Python container (dict or list), unless
26 | the configuration is specifically created by LazyCall, in which case the original configuration
27 | is returned directly.
28 |
29 | This function serves as a modification of the original `to_object` method from OmegaConf,
30 | preventing DictConfig objects created by LazyCall from being automatically converted to Python
31 | dictionaries. This ensures that configurations meant to be lazily evaluated retain their intended
32 | structure and behavior.
33 |
34 | Differences from OmegaConf's original `to_object`:
35 | - Adds a check at the beginning to return the configuration unchanged if it is created by LazyCall.
36 |
37 | Reference:
38 | - Original OmegaConf `to_object` method: https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595
39 |
40 | Args:
41 | cfg (Any): The OmegaConf configuration object to convert.
42 |
43 | Returns:
44 | Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: The converted Python container if
45 | `cfg` is not a LazyCall created configuration, otherwise the unchanged `cfg`.
46 |
47 | Examples:
48 | >>> cfg = DictConfig({"key": "value", "_target_": "Model"})
49 | >>> to_object(cfg)
50 | DictConfig({"key": "value", "_target_": "Model"})
51 |
52 | >>> cfg = DictConfig({"list": [1, 2, 3]})
53 | >>> to_object(cfg)
54 | {'list': [1, 2, 3]}
55 | """
56 | if isinstance(cfg, DictConfig) and "_target_" in cfg.keys():
57 | return cfg
58 |
59 | return OmegaConf.to_container(
60 | cfg=cfg,
61 | resolve=True,
62 | throw_on_missing=True,
63 | enum_to_str=False,
64 | structured_config_mode=SCMode.INSTANTIATE,
65 | )
66 |
--------------------------------------------------------------------------------
/imaginaire/lazy_config/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 |
18 | from omegaconf import OmegaConf
19 |
20 | from imaginaire.lazy_config.instantiate import instantiate
21 | from imaginaire.lazy_config.lazy import LazyCall, LazyConfig, LazyDict
22 | from imaginaire.lazy_config.omegaconf_patch import to_object
23 |
24 | OmegaConf.to_object = to_object
25 |
26 | PLACEHOLDER = None
27 |
28 | __all__ = ["PLACEHOLDER", "LazyCall", "LazyConfig", "LazyDict", "instantiate"]
29 |
30 |
31 | DOC_BUILDING = os.getenv("_DOC_BUILDING", False) # set in docs/conf.py
32 |
33 |
34 | def fixup_module_metadata(module_name, namespace, keys=None):
35 | """
36 | Fix the __qualname__ of module members to be their exported api name, so
37 | when they are referenced in docs, sphinx can find them. Reference:
38 | https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241
39 | """
40 | if not DOC_BUILDING:
41 | return
42 | seen_ids = set()
43 |
44 | def fix_one(qualname, name, obj):
45 | # avoid infinite recursion (relevant when using
46 | # typing.Generic, for example)
47 | if id(obj) in seen_ids:
48 | return
49 | seen_ids.add(id(obj))
50 |
51 | mod = getattr(obj, "__module__", None)
52 | if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")):
53 | obj.__module__ = module_name
54 |             # Modules, unlike everything else in Python, put fully-qualified
55 | # names into their __name__ attribute. We check for "." to avoid
56 | # rewriting these.
57 | if hasattr(obj, "__name__") and "." not in obj.__name__:
58 | obj.__name__ = name
59 | obj.__qualname__ = qualname
60 | if isinstance(obj, type):
61 | for attr_name, attr_value in obj.__dict__.items():
62 |                     fix_one(qualname + "." + attr_name, attr_name, attr_value)
63 |
64 | if keys is None:
65 | keys = namespace.keys()
66 | for objname in keys:
67 | if not objname.startswith("_"):
68 | obj = namespace[objname]
69 | fix_one(objname, objname, obj)
70 |
71 |
72 | fixup_module_metadata(__name__, globals(), __all__)
73 | del fixup_module_metadata
74 |
--------------------------------------------------------------------------------
/rcm/configs/defaults/net.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from hydra.core.config_store import ConfigStore
17 |
18 | from imaginaire.lazy_config import LazyCall as L
19 | from imaginaire.lazy_config import LazyDict
20 |
21 | from rcm.networks.wan2pt1 import WanModel
22 | from rcm.networks.wan2pt1_jvp import WanModel_JVP
23 |
24 | wan2pt1_1pt3B_net_args = dict(
25 | dim=1536,
26 | eps=1e-06,
27 | ffn_dim=8960,
28 | freq_dim=256,
29 | in_dim=16,
30 | num_heads=12,
31 | num_layers=30,
32 | out_dim=16,
33 | text_len=512,
34 | )
35 |
36 | wan2pt1_14B_net_args = dict(
37 | dim=5120,
38 | eps=1e-06,
39 | ffn_dim=13824,
40 | freq_dim=256,
41 | in_dim=16,
42 | num_heads=40,
43 | num_layers=40,
44 | out_dim=16,
45 | text_len=512,
46 | )
47 |
48 | WAN2PT1_1PT3B_T2V: LazyDict = L(WanModel)(**wan2pt1_1pt3B_net_args, model_type="t2v")
49 |
50 | WAN2PT1_14B_T2V: LazyDict = L(WanModel)(**wan2pt1_14B_net_args, model_type="t2v")
51 |
52 | WAN2PT1_1PT3B_T2V_JVP: LazyDict = L(WanModel_JVP)(**wan2pt1_1pt3B_net_args, model_type="t2v")
53 |
54 | WAN2PT1_14B_T2V_JVP: LazyDict = L(WanModel_JVP)(**wan2pt1_14B_net_args, model_type="t2v")
55 |
56 |
57 | def register_net():
58 | cs = ConfigStore.instance()
59 | cs.store(group="net", package="model.config.net", name="wan2pt1_1pt3B_t2v", node=WAN2PT1_1PT3B_T2V)
60 | cs.store(group="net", package="model.config.net", name="wan2pt1_14B_t2v", node=WAN2PT1_14B_T2V)
61 | cs.store(group="net", package="model.config.net", name="wan2pt1_1pt3B_t2v_jvp", node=WAN2PT1_1PT3B_T2V_JVP)
62 | cs.store(group="net", package="model.config.net", name="wan2pt1_14B_t2v_jvp", node=WAN2PT1_14B_T2V_JVP)
63 |
64 |
65 | def register_net_fake_score():
66 | cs = ConfigStore.instance()
67 | cs.store(group="net_fake_score", package="model.config.net_fake_score", name="wan2pt1_1pt3B_t2v", node=WAN2PT1_1PT3B_T2V)
68 | cs.store(group="net_fake_score", package="model.config.net_fake_score", name="wan2pt1_14B_t2v", node=WAN2PT1_14B_T2V)
69 |
70 |
71 | def register_net_teacher():
72 | cs = ConfigStore.instance()
73 | cs.store(group="net_teacher", package="model.config.net_teacher", name="wan2pt1_1pt3B_t2v", node=WAN2PT1_1PT3B_T2V)
74 | cs.store(group="net_teacher", package="model.config.net_teacher", name="wan2pt1_14B_t2v", node=WAN2PT1_14B_T2V)
75 |
--------------------------------------------------------------------------------
/rcm/utils/optim_instantiate_dtensor.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import hydra
17 | import torch
18 | from omegaconf import ListConfig
19 | from torch import nn
20 |
21 | from rcm.utils.fused_adam_dtensor import FusedAdam
22 | from imaginaire.utils import log
23 |
24 |
25 | def get_regular_param_group(net: nn.Module):
26 | """
27 |     Separate the parameters of the network into two groups: decay and no_decay.
28 |     Based on the nanoGPT codebase.
29 | """
30 | param_dict = {pn: p for pn, p in net.named_parameters()}
31 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
32 |
33 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
34 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
35 | return decay_params, nodecay_params
36 |
37 |
38 | def get_base_optimizer(
39 | model: nn.Module,
40 | lr: float,
41 | weight_decay: float,
42 | optim_type: str = "adamw",
43 | **kwargs,
44 | ) -> torch.optim.Optimizer:
45 | net_decay_param, net_nodecay_param = get_regular_param_group(model)
46 |
47 | num_decay_params = sum(p.numel() for p in net_decay_param)
48 | num_nodecay_params = sum(p.numel() for p in net_nodecay_param)
49 | net_param_total = num_decay_params + num_nodecay_params
50 | log.critical(f"total num parameters : {net_param_total:,}")
51 |
52 | param_group = [
53 | {
54 | "params": net_decay_param + net_nodecay_param,
55 | "lr": lr,
56 | "weight_decay": weight_decay,
57 | },
58 | ]
59 |
60 | if optim_type == "adamw":
61 | opt_cls = torch.optim.AdamW
62 | elif optim_type == "fusedadam":
63 | opt_cls = FusedAdam
64 | else:
65 | raise ValueError(f"Unknown optimizer type: {optim_type}")
66 |
67 | for k, v in kwargs.items():
68 | if isinstance(v, ListConfig):
69 | kwargs[k] = list(v)
70 |
71 | return opt_cls(param_group, **kwargs)
72 |
73 |
74 | def get_base_scheduler(
75 | optimizer: torch.optim.Optimizer,
76 | model: nn.Module,
77 | scheduler_config: dict,
78 | ):
79 | net_scheduler = hydra.utils.instantiate(scheduler_config)
80 | net_scheduler.model = model
81 |
82 | return torch.optim.lr_scheduler.LambdaLR(
83 | optimizer,
84 | lr_lambda=[
85 | net_scheduler.schedule,
86 | ],
87 | )
88 |
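A standalone sketch of calling `get_base_optimizer` directly, outside of Hydra instantiation. The tiny module is a placeholder; the hyperparameters mirror the AdamWConfig defaults in rcm/configs/defaults/optimizer.py:

    import torch
    from torch import nn

    from rcm.utils.optim_instantiate_dtensor import get_base_optimizer

    net = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    optimizer = get_base_optimizer(net, lr=1e-4, weight_decay=0.1, optim_type="adamw", betas=(0.9, 0.99), eps=1e-8)

    loss = net(torch.randn(2, 8)).sum()
    loss.backward()
    optimizer.step()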
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/jsonl_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import json
17 | from typing import IO
18 |
19 | import numpy as np
20 |
21 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
22 |
23 |
24 | def set_default(obj):
25 | """Set default json values for non-serializable values.
26 |
27 | It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
28 | It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
29 | etc.) into plain numbers of plain python built-in types.
30 | """
31 | if isinstance(obj, (set, range)):
32 | return list(obj)
33 | elif isinstance(obj, np.ndarray):
34 | return obj.tolist()
35 | elif isinstance(obj, np.generic):
36 | return obj.item()
37 | raise TypeError(f"{type(obj)} is unsupported for json dump")
38 |
39 |
40 | class JsonlHandler(BaseFileHandler):
41 | """Handler for JSON lines (JSONL) files."""
42 |
43 | def load_from_fileobj(self, file: IO[bytes]):
44 | """Load JSON objects from a newline-delimited JSON (JSONL) file object.
45 |
46 | Returns:
47 | A list of Python objects loaded from each JSON line.
48 | """
49 | data = []
50 | for line in file:
51 | line = line.strip()
52 | if not line:
53 | continue # skip empty lines if any
54 | data.append(json.loads(line))
55 | return data
56 |
57 |     def dump_to_fileobj(self, obj, file: IO, **kwargs):
58 | """Dump a list of objects to a newline-delimited JSON (JSONL) file object.
59 |
60 | Args:
61 | obj: A list (or iterable) of objects to dump line by line.
62 | """
63 | kwargs.setdefault("default", set_default)
64 | for item in obj:
65 | file.write(json.dumps(item, **kwargs) + "\n")
66 |
67 | def dump_to_str(self, obj, **kwargs):
68 | """Dump a list of objects to a newline-delimited JSON (JSONL) string."""
69 | kwargs.setdefault("default", set_default)
70 | lines = [json.dumps(item, **kwargs) for item in obj]
71 | return "\n".join(lines)
72 |
73 |
74 | if __name__ == "__main__":
75 | from imaginaire.utils.easy_io import easy_io
76 |
77 | easy_io.dump([1, 2, 3], "test.jsonl", file_format="jsonl")
78 | print(easy_io.load("test.jsonl"))
79 | easy_io.dump([{"key1": 1, "key2": 2}, {"key1": 3, "key2": 4}], "test.jsonl", file_format="jsonl")
80 | print(easy_io.load("test.jsonl"))
81 |
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/np_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from io import BytesIO
17 | from typing import IO, Any
18 |
19 | import numpy as np
20 |
21 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
22 |
23 |
24 | class NumpyHandler(BaseFileHandler):
25 | str_like = False
26 |
27 | def load_from_fileobj(self, file: IO[bytes], **kwargs) -> Any:
28 | """
29 | Load a NumPy array from a file-like object.
30 |
31 | Parameters:
32 | file (IO[bytes]): The file-like object containing the NumPy array data.
33 | **kwargs: Additional keyword arguments passed to `np.load`.
34 |
35 | Returns:
36 | numpy.ndarray: The loaded NumPy array.
37 | """
38 | return np.load(file, **kwargs)
39 |
40 | def load_from_path(self, filepath: str, **kwargs) -> Any:
41 | """
42 | Load a NumPy array from a file path.
43 |
44 | Parameters:
45 | filepath (str): The path to the file to load.
46 | **kwargs: Additional keyword arguments passed to `np.load`.
47 |
48 | Returns:
49 | numpy.ndarray: The loaded NumPy array.
50 | """
51 | return super().load_from_path(filepath, mode="rb", **kwargs)
52 |
53 |     def dump_to_str(self, obj: np.ndarray, **kwargs) -> bytes:
54 |         """
55 |         Serialize a NumPy array to raw bytes in the ``.npy`` binary format.
56 |
57 |         Parameters:
58 |             obj (np.ndarray): The NumPy array to serialize.
59 |             **kwargs: Additional keyword arguments passed to `np.save`.
60 |
61 |         Returns:
62 |             bytes: The serialized NumPy array as bytes.
63 | """
64 | with BytesIO() as f:
65 | np.save(f, obj, **kwargs)
66 | return f.getvalue()
67 |
68 | def dump_to_fileobj(self, obj: np.ndarray, file: IO[bytes], **kwargs):
69 | """
70 | Dump a NumPy array to a file-like object.
71 |
72 | Parameters:
73 | obj (np.ndarray): The NumPy array to dump.
74 | file (IO[bytes]): The file-like object to which the array is dumped.
75 | **kwargs: Additional keyword arguments passed to `np.save`.
76 | """
77 | np.save(file, obj, **kwargs)
78 |
79 | def dump_to_path(self, obj: np.ndarray, filepath: str, **kwargs):
80 | """
81 | Dump a NumPy array to a file path.
82 |
83 | Parameters:
84 | obj (np.ndarray): The NumPy array to dump.
85 | filepath (str): The file path where the array should be saved.
86 | **kwargs: Additional keyword arguments passed to `np.save`.
87 | """
88 | with open(filepath, "wb") as f:
89 | np.save(f, obj, **kwargs)
90 |
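A small in-memory round-trip sketch for the handler above:

    from io import BytesIO

    import numpy as np

    from imaginaire.utils.easy_io.handlers.np_handler import NumpyHandler

    handler = NumpyHandler()
    buffer = BytesIO()
    handler.dump_to_fileobj(np.arange(6).reshape(2, 3), buffer)
    buffer.seek(0)
    print(handler.load_from_fileobj(buffer))  # [[0 1 2] [3 4 5]]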
--------------------------------------------------------------------------------
/rcm/datasets/merge_tar_shards.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Merge many small TAR shards into new larger shards.
4 | Input: a directory containing shards like shard-00000.tar, shard-00001.tar, ...
5 | Output: new shards with target_shard_size samples each.
6 |
7 | Usage:
8 | python merge_tar_shards.py \
9 | --input_dir path/to/small_shards \
10 | --output_dir path/to/large_shards \
11 | --target_shard_size 5000
12 |
13 | Each sample is assumed to be a group of tar members with common prefix, e.g.
14 | 000000.jpg
15 | 000000.txt
16 | 000001.jpg
17 | 000001.txt
18 | This script groups files by prefix before writing.
19 | """
20 |
21 | import argparse
22 | import os
23 | import tarfile
24 | from collections import defaultdict
25 | from io import BytesIO
26 |
27 | def read_samples_from_tar(tar_path):
28 |     """Yield samples as (key, {filename: bytes}) pairs.
29 |     Members are grouped by the prefix before the first '.' in their basename.
30 |     """
31 | samples = defaultdict(dict)
32 | with tarfile.open(tar_path, "r") as tar:
33 | for m in tar.getmembers():
34 | if not m.isfile():
35 | continue
36 | name = os.path.basename(m.name)
37 | prefix = name.split(".")[0]
38 | f = tar.extractfile(m)
39 | if f is None:
40 | continue
41 | samples[prefix][name] = f.read()
42 | for key, files in samples.items():
43 | yield key, files
44 |
45 |
46 | def write_shard(samples, out_path):
47 | """Write a list of samples to one TAR shard.
48 | samples: list of (key, {filename: bytes})
49 | """
50 | with tarfile.open(out_path, "w") as tar:
51 | for key, file_dict in samples:
52 | for fname, data in file_dict.items():
53 | info = tarfile.TarInfo(name=f"{key}/{fname}")
54 | info.size = len(data)
55 | tar.addfile(info, fileobj=BytesIO(data))
56 |
57 |
58 |
59 |
60 | def merge_shards(input_dir, output_dir, target_shard_size):
61 | os.makedirs(output_dir, exist_ok=True)
62 | shard_paths = sorted(
63 | [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith(".tar")]
64 | )
65 |
66 | current_samples = []
67 | new_shard_index = 0
68 |
69 | for tar_path in shard_paths:
70 | for key, files in read_samples_from_tar(tar_path):
71 | current_samples.append((key, files))
72 | if len(current_samples) >= target_shard_size:
73 | out_path = os.path.join(output_dir, f"shard-{new_shard_index:05d}.tar")
74 | write_shard(current_samples, out_path)
75 | print(f"Wrote {out_path}, {len(current_samples)} samples")
76 | new_shard_index += 1
77 | current_samples = []
78 |
79 | # final remainder
80 | if current_samples:
81 | out_path = os.path.join(output_dir, f"shard-{new_shard_index:05d}.tar")
82 | write_shard(current_samples, out_path)
83 | print(f"Wrote {out_path}, {len(current_samples)} samples")
84 |
85 |
86 | if __name__ == "__main__":
87 | parser = argparse.ArgumentParser()
88 | parser.add_argument("--input_dir", required=True)
89 | parser.add_argument("--output_dir", required=True)
90 | parser.add_argument("--target_shard_size", type=int, required=True)
91 | args = parser.parse_args()
92 |
93 | merge_shards(args.input_dir, args.output_dir, args.target_shard_size)
94 |
--------------------------------------------------------------------------------
/imaginaire/callbacks/low_precision.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 |
18 | from imaginaire.model import ImaginaireModel
19 | from imaginaire.trainer import ImaginaireTrainer
20 | from imaginaire.utils import distributed
21 | from imaginaire.utils.callback import Callback
22 | from imaginaire.config import Config
23 | from imaginaire.utils.misc import get_local_tensor_if_DTensor
24 |
25 |
26 | def update_master_weights(optimizer: torch.optim.Optimizer):
27 | if getattr(optimizer, "master_weights", False) and optimizer.param_groups_master is not None:
28 | params, master_params = [], []
29 | for group, group_master in zip(optimizer.param_groups, optimizer.param_groups_master):
30 | for p, p_master in zip(group["params"], group_master["params"]):
31 | params.append(get_local_tensor_if_DTensor(p.data))
32 | master_params.append(p_master.data)
33 | torch._foreach_copy_(params, master_params)
34 |
35 |
36 | class LowPrecisionCallback(Callback):
37 |     """The callback class handling low-precision training.
38 |
39 |     Because a config option with a non-primitive type is difficult to override,
40 |     the callback reads the precision from ``model.precision`` instead.
41 |     It is automatically disabled when training in fp32.
42 |     """
43 |
44 | def __init__(self, config: Config, trainer: ImaginaireTrainer, update_iter: int):
45 | self.update_iter = update_iter
46 |
47 | def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
48 | assert model.precision in [
49 | torch.bfloat16,
50 | torch.float16,
51 | torch.half,
52 | ], "LowPrecisionCallback must use a low precision dtype."
53 | self.precision_type = model.precision
54 |
55 | def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
56 | for k, v in data.items():
57 | if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]):
58 | data[k] = v.to(dtype=self.precision_type)
59 |
60 | def on_validation_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
61 | for k, v in data.items():
62 | if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]):
63 | data[k] = v.to(dtype=self.precision_type)
64 |
65 | def on_before_zero_grad(
66 | self,
67 | model_ddp: distributed.DistributedDataParallel,
68 | optimizer: torch.optim.Optimizer,
69 | scheduler: torch.optim.lr_scheduler.LRScheduler,
70 | iteration: int = 0,
71 | ) -> None:
72 | if iteration % self.update_iter == 0:
73 | update_master_weights(optimizer)
74 |
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/backends/http_backend.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import tempfile
18 | from collections.abc import Generator
19 | from contextlib import contextmanager
20 | from pathlib import Path
21 | from urllib.request import urlopen
22 |
23 | from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
24 |
25 |
26 | class HTTPBackend(BaseStorageBackend):
27 |     """HTTP and HTTPS storage backend."""
28 |
29 | def get(self, filepath: str) -> bytes:
30 | """Read bytes from a given ``filepath``.
31 |
32 | Args:
33 | filepath (str): Path to read data.
34 |
35 | Returns:
36 | bytes: Expected bytes object.
37 |
38 | Examples:
39 | >>> backend = HTTPBackend()
40 | >>> backend.get('http://path/of/file')
41 | b'hello world'
42 | """
43 | return urlopen(filepath).read()
44 |
45 | def get_text(self, filepath, encoding="utf-8") -> str:
46 | """Read text from a given ``filepath``.
47 |
48 | Args:
49 | filepath (str): Path to read data.
50 | encoding (str): The encoding format used to open the ``filepath``.
51 | Defaults to 'utf-8'.
52 |
53 | Returns:
54 | str: Expected text reading from ``filepath``.
55 |
56 | Examples:
57 | >>> backend = HTTPBackend()
58 | >>> backend.get_text('http://path/of/file')
59 | 'hello world'
60 | """
61 | return urlopen(filepath).read().decode(encoding)
62 |
63 | @contextmanager
64 | def get_local_path(self, filepath: str) -> Generator[str | Path, None, None]:
65 | """Download a file from ``filepath`` to a local temporary directory,
66 | and return the temporary path.
67 |
68 |         ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
69 |         can be called with a ``with`` statement, and when exiting the
70 |         ``with`` statement, the temporary path will be released.
71 |
72 | Args:
73 | filepath (str): Download a file from ``filepath``.
74 |
75 | Yields:
76 | Iterable[str]: Only yield one temporary path.
77 |
78 | Examples:
79 | >>> backend = HTTPBackend()
80 |             >>> # After exiting from the ``with`` clause,
81 | >>> # the path will be removed
82 | >>> with backend.get_local_path('http://path/of/file') as path:
83 | ... # do something here
84 | """
85 | try:
86 | f = tempfile.NamedTemporaryFile(delete=False)
87 | f.write(self.get(filepath))
88 | f.close()
89 | yield f.name
90 | finally:
91 | os.remove(f.name)
92 |
--------------------------------------------------------------------------------
/rcm/configs/defaults/callbacks.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from hydra.core.config_store import ConfigStore
17 |
18 | from imaginaire.lazy_config import PLACEHOLDER
19 | from imaginaire.lazy_config import LazyCall as L
20 | from imaginaire.callbacks.manual_gc import ManualGarbageCollection
21 | from imaginaire.callbacks.low_precision import LowPrecisionCallback
22 | from rcm.callbacks.compile_tokenizer import CompileTokenizer
23 | from rcm.callbacks.dataloading_monitor import DetailedDataLoadingSpeedMonitor
24 | from rcm.callbacks.device_monitor import DeviceMonitor
25 | from rcm.callbacks.grad_clip import GradClip
26 | from rcm.callbacks.heart_beat import HeartBeat
27 | from rcm.callbacks.iter_speed import IterSpeed
28 | from rcm.callbacks.wandb_log import WandbCallback
29 | from rcm.callbacks.every_n_draw_distill import EveryNDrawSample_Distill
30 |
31 | BASIC_CALLBACKS = dict(
32 | grad_clip=L(GradClip)(),
33 | low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1),
34 | iter_speed=L(IterSpeed)(
35 | every_n="${trainer.logging_iter}",
36 | save_s3_every_log_n=10,
37 | ),
38 | heart_beat=L(HeartBeat)(
39 | every_n=10,
40 | update_interval_in_minute=20,
41 | ),
42 | device_monitor=L(DeviceMonitor)(
43 | every_n="${trainer.logging_iter}",
44 | upload_every_n_mul=10,
45 | ),
46 | manual_gc=L(ManualGarbageCollection)(every_n=5),
47 | compile_tokenizer=L(CompileTokenizer)(
48 | enabled=True,
49 | compile_after_iterations=4,
50 | dynamic=False, # If there are issues with constant recompilations you may set this value to None or True
51 | ),
52 | )
53 |
54 | SPEED_CALLBACKS = dict(
55 | dataloader_speed=L(DetailedDataLoadingSpeedMonitor)(
56 | every_n="${trainer.logging_iter}",
57 | ),
58 | )
59 |
60 | WANDB_CALLBACK = dict(
61 | wandb=L(WandbCallback)(
62 | logging_iter_multipler=1,
63 | save_logging_iter_multipler=10,
64 | ),
65 | wandb_10x=L(WandbCallback)(
66 | logging_iter_multipler=10,
67 | save_logging_iter_multipler=1,
68 | ),
69 | )
70 |
71 | VIZ_ONLINE_SAMPLING_DISTILL_CALLBACKS = dict(
72 | every_n_sample_reg=L(EveryNDrawSample_Distill)(
73 | every_n=5000,
74 | ),
75 | every_n_sample_ema=L(EveryNDrawSample_Distill)(
76 | every_n=5000,
77 | is_ema=True,
78 | ),
79 | )
80 |
81 |
82 | def register_callbacks():
83 | cs = ConfigStore.instance()
84 | cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS)
85 | cs.store(group="callbacks", package="trainer.callbacks", name="dataloading_speed", node=SPEED_CALLBACKS)
86 | cs.store(group="callbacks", package="trainer.callbacks", name="wandb", node=WANDB_CALLBACK)
87 | cs.store(group="callbacks", package="trainer.callbacks", name="viz_online_sampling_distill", node=VIZ_ONLINE_SAMPLING_DISTILL_CALLBACKS)
88 |
--------------------------------------------------------------------------------
/imaginaire/callbacks/every_n.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from abc import abstractmethod
17 |
18 | import torch
19 |
20 | from imaginaire.model import ImaginaireModel
21 | from imaginaire.trainer import ImaginaireTrainer
22 | from imaginaire.utils import distributed, log
23 | from imaginaire.utils.callback import Callback
24 |
25 |
26 | class EveryN(Callback):
27 | def __init__(
28 | self,
29 | every_n: int | None = None,
30 | step_size: int = 1,
31 | barrier_after_run: bool = True,
32 | run_at_start: bool = False,
33 | ) -> None:
34 | """Constructor for `EveryN`.
35 |
36 | Args:
37 | every_n (int): Frequency with which callback is run during training.
38 | step_size (int): Size of iteration step count. Default 1.
39 | barrier_after_run (bool): Whether to have a distributed barrier after each execution. Default True, to avoid timeouts.
40 | run_at_start (bool): Whether to run at the beginning of training. Default False.
41 | """
42 | self.every_n = every_n
43 | if self.every_n == 0:
44 | log.warning(
45 |                 f"every_n is set to 0. Callback {self.__class__.__name__} will be invoked only once at the beginning of training. Calls happening in on_training_step_end will be skipped."
46 | )
47 |
48 | self.step_size = step_size
49 | self.barrier_after_run = barrier_after_run
50 | self.run_at_start = run_at_start
51 |
52 | def on_training_step_end(
53 | self,
54 | model: ImaginaireModel,
55 | data_batch: dict[str, torch.Tensor],
56 | output_batch: dict[str, torch.Tensor],
57 | loss: torch.Tensor,
58 | iteration: int = 0,
59 | ) -> None:
60 | # every_n = 0 is a special case which means every_n_impl will be called only once in the beginning of the training
61 | if self.every_n != 0:
62 | trainer = self.trainer
63 | global_step = iteration // self.step_size
64 | should_run = (iteration == 1 and self.run_at_start) or (
65 | global_step % self.every_n == 0
66 | ) # (self.every_n - 1)
67 | if should_run:
68 | log.debug(f"Callback {self.__class__.__name__} fired on train_batch_end step {global_step}")
69 | self.every_n_impl(trainer, model, data_batch, output_batch, loss, iteration)
70 | log.debug(f"Callback {self.__class__.__name__} finished on train_batch_end step {global_step}")
71 | # add necessary barrier to avoid timeout
72 | if self.barrier_after_run:
73 | distributed.barrier()
74 |
75 | @abstractmethod
76 | def every_n_impl(
77 | self,
78 | trainer: ImaginaireTrainer,
79 | model: ImaginaireModel,
80 | data_batch: dict[str, torch.Tensor],
81 | output_batch: dict[str, torch.Tensor],
82 | loss: torch.Tensor,
83 | iteration: int,
84 | ) -> None: ...
85 |
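86 | 
87 | # Editor's sketch (not part of the original module): the typical way EveryN is
88 | # subclassed -- only `every_n_impl` needs to be implemented; the base class handles
89 | # the scheduling, the optional start-of-training run, and the distributed barrier.
90 | class _ExampleLossLogger(EveryN):
91 |     def every_n_impl(
92 |         self,
93 |         trainer: ImaginaireTrainer,
94 |         model: ImaginaireModel,
95 |         data_batch: dict[str, torch.Tensor],
96 |         output_batch: dict[str, torch.Tensor],
97 |         loss: torch.Tensor,
98 |         iteration: int,
99 |     ) -> None:
100 |         # Runs once every `every_n` global steps (see on_training_step_end above).
101 |         log.info(f"[sketch] iteration {iteration}: loss={loss.item():.4f}")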
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import argparse
17 | import importlib
18 | import os
19 |
20 | from loguru import logger as logging
21 |
22 | from imaginaire.config import Config, pretty_print_overrides
23 | from imaginaire.lazy_config import instantiate
24 | from imaginaire.lazy_config.lazy import LazyConfig
25 | from imaginaire.utils import distributed
26 | from imaginaire.utils.config_helper import get_config_module, override
27 |
28 |
29 | @logging.catch(reraise=True)
30 | def launch(config: Config, args: argparse.Namespace) -> None:
31 |     # Initialize the distributed environment before calling config.validate(), because validation synchronizes
32 |     # a buffer across ranks. Without this, the buffers would all be allocated on rank 0 and the cross-rank
33 |     # check would effectively be a no-op.
34 | distributed.init()
35 |
36 | # Check that the config is valid
37 | config.validate()
38 | # Freeze the config so developers don't change it during training.
39 | config.freeze() # type: ignore
40 | trainer = instantiate(config.trainer.type, config=config)
41 | # Create the model
42 | model = instantiate(config.model)
43 | # Create the dataloaders.
44 | dataloader_train = instantiate(config.dataloader_train)
45 | dataloader_val = instantiate(config.dataloader_val)
46 | # Start training
47 | trainer.train(model, dataloader_train, dataloader_val)
48 |
49 |
50 | if __name__ == "__main__":
51 |     # Usage: torchrun --nproc_per_node=1 -m scripts.train --config=rcm/configs/registry_distill.py -- experiment=<registered_experiment_name>
52 |
53 | # Get the config file from the input arguments.
54 | parser = argparse.ArgumentParser(description="Training")
55 | parser.add_argument("--config", help="Path to the config file", required=True)
56 | parser.add_argument(
57 | "opts",
58 | help="""
59 | Modify config options at the end of the command. For Yacs configs, use
60 | space-separated "PATH.KEY VALUE" pairs.
61 | For python-based LazyConfig, use "path.key=value".
62 | """.strip(),
63 | default=None,
64 | nargs=argparse.REMAINDER,
65 | )
66 | parser.add_argument(
67 | "--dryrun",
68 | action="store_true",
69 | help="Do a dry run without training. Useful for debugging the config.",
70 | )
71 | args = parser.parse_args()
72 | config_module = get_config_module(args.config)
73 | config = importlib.import_module(config_module).make_config()
74 | config = override(config, args.opts)
75 | if args.dryrun:
76 | logging.info("Config:\n" + config.pretty_print(use_color=True) + "\n" + pretty_print_overrides(args.opts, use_color=True))
77 | os.makedirs(config.job.path_local, exist_ok=True)
78 | LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml")
79 | print(f"{config.job.path_local}/config.yaml")
80 | else:
81 | # Launch the training job.
82 | launch(config, args)
83 |
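84 | 
85 | # Editor's note (illustrative): adding --dryrun prints and saves the composed config
86 | # without launching training, which is handy for debugging overrides, e.g.
87 | #   torchrun --nproc_per_node=1 -m scripts.train --config=rcm/configs/registry_distill.py --dryrun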
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/registry_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
17 | from imaginaire.utils.easy_io.handlers.byte_handler import ByteHandler
18 | from imaginaire.utils.easy_io.handlers.csv_handler import CsvHandler
19 | from imaginaire.utils.easy_io.handlers.gzip_handler import GzipHandler
20 | from imaginaire.utils.easy_io.handlers.imageio_video_handler import ImageioVideoHandler
21 | from imaginaire.utils.easy_io.handlers.json_handler import JsonHandler
22 | from imaginaire.utils.easy_io.handlers.jsonl_handler import JsonlHandler
23 | from imaginaire.utils.easy_io.handlers.np_handler import NumpyHandler
24 | from imaginaire.utils.easy_io.handlers.pandas_handler import PandasHandler
25 | from imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler
26 | from imaginaire.utils.easy_io.handlers.pil_handler import PILHandler
27 | from imaginaire.utils.easy_io.handlers.tarfile_handler import TarHandler
28 | from imaginaire.utils.easy_io.handlers.torch_handler import TorchHandler
29 | from imaginaire.utils.easy_io.handlers.torchjit_handler import TorchJitHandler
30 | from imaginaire.utils.easy_io.handlers.txt_handler import TxtHandler
31 | from imaginaire.utils.easy_io.handlers.yaml_handler import YamlHandler
32 |
33 | file_handlers = {
34 | "json": JsonHandler(),
35 | "yaml": YamlHandler(),
36 | "yml": YamlHandler(),
37 | "pickle": PickleHandler(),
38 | "pkl": PickleHandler(),
39 | "tar": TarHandler(),
40 | "jit": TorchJitHandler(),
41 | "npy": NumpyHandler(),
42 | "txt": TxtHandler(),
43 | "csv": CsvHandler(),
44 | "pandas": PandasHandler(),
45 | "gz": GzipHandler(),
46 | "jsonl": JsonlHandler(),
47 | "byte": ByteHandler(),
48 | }
49 |
50 | for torch_type in ["pt", "pth", "ckpt"]:
51 | file_handlers[torch_type] = TorchHandler()
52 | for img_type in ["jpg", "jpeg", "png", "bmp", "gif"]:
53 | file_handlers[img_type] = PILHandler()
54 | file_handlers[img_type].format = img_type
55 | for video_type in ["mp4", "avi", "mov", "webm", "flv", "wmv"]:
56 | file_handlers[video_type] = ImageioVideoHandler()
57 |
58 |
59 | def _register_handler(handler, file_formats):
60 | """Register a handler for some file extensions.
61 |
62 | Args:
63 | handler (:obj:`BaseFileHandler`): Handler to be registered.
64 | file_formats (str or list[str]): File formats to be handled by this
65 | handler.
66 | """
67 | if not isinstance(handler, BaseFileHandler):
68 | raise TypeError(f"handler must be a child of BaseFileHandler, not {type(handler)}")
69 | if isinstance(file_formats, str):
70 | file_formats = [file_formats]
71 | if not all([isinstance(item, str) for item in file_formats]):
72 | raise TypeError("file_formats must be a str or a list of str")
73 | for ext in file_formats:
74 | file_handlers[ext] = handler
75 |
76 |
77 | def register_handler(file_formats, **kwargs):
78 | def wrap(cls):
79 | _register_handler(cls(**kwargs), file_formats)
80 | return cls
81 |
82 | return wrap
83 |
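84 | 
85 | # Editor's sketch (not part of the original module): how a project could register a
86 | # custom handler with the decorator above. The ".toml" extension, the handler class,
87 | # and the use of the standard-library `tomllib` are illustrative assumptions only;
88 | # nothing is registered unless this helper is actually called.
89 | def _register_example_toml_handler():
90 |     import tomllib  # Python 3.11+; read-only TOML parser
91 | 
92 |     @register_handler("toml")
93 |     class _ExampleTomlHandler(BaseFileHandler):
94 |         str_like = False  # tomllib expects a binary file object
95 | 
96 |         def load_from_fileobj(self, file, **kwargs):
97 |             return tomllib.load(file, **kwargs)
98 | 
99 |         def dump_to_fileobj(self, obj, file, **kwargs):
100 |             raise NotImplementedError("tomllib is read-only")
101 | 
102 |         def dump_to_str(self, obj, **kwargs):
103 |             raise NotImplementedError("tomllib is read-only")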
--------------------------------------------------------------------------------
/rcm/callbacks/dataloading_monitor.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import time
17 |
18 | import numpy as np
19 | import torch
20 | import wandb
21 |
22 | from imaginaire.model import ImaginaireModel
23 | from imaginaire.utils import distributed
24 | from imaginaire.utils.callback import Callback
25 | from imaginaire.utils.easy_io import easy_io
26 |
27 |
28 | class DetailedDataLoadingSpeedMonitor(Callback):
29 | def __init__(
30 | self,
31 | every_n: int,
32 | step_size: int = 1,
33 | save_s3: bool = False,
34 | ):
35 | self.every_n = every_n
36 | self.step_size = step_size
37 | self.should_run = False
38 | self.start_dataloading_time = None
39 | self.dataloading_time = None
40 | self.name = self.__class__.__name__
41 | self.save_s3 = save_s3
42 | self.time_delta_list = []
43 |
44 | def on_before_dataloading(self, iteration: int = 0) -> None:
45 |         # should_run is decided here, one hook ahead of on_training_step_start; note the (global_step + 1) offset.
46 | global_step = iteration // self.step_size
47 | self.should_run = (global_step + 1) % self.every_n == 0
48 | self.start_dataloading_time = time.time()
49 |
50 | def on_after_dataloading(self, iteration: int = 0) -> None:
51 | self.time_delta_list.append(time.time() - self.start_dataloading_time)
52 |
53 | def on_training_step_end(
54 | self,
55 | model: ImaginaireModel,
56 | data_batch: dict[str, torch.Tensor],
57 | output_batch: dict[str, torch.Tensor],
58 | loss: torch.Tensor,
59 | iteration: int = 0,
60 | ) -> None:
61 | if self.should_run:
62 | self.should_run = False
63 | cur_rank_mean, cur_rank_max = np.mean(self.time_delta_list), np.max(self.time_delta_list)
64 | self.time_delta_list = [] # Reset the list
65 |
66 | dataloading_time_gather_list = distributed.all_gather_tensor(torch.tensor([cur_rank_mean, cur_rank_max]).cuda())
67 | wandb_info = {f"{self.name}_mean/dataloading_{k:03d}": v[0].item() for k, v in enumerate(dataloading_time_gather_list)}
68 | wandb_info.update({f"{self.name}_max/dataloading_{k:03d}": v[1].item() for k, v in enumerate(dataloading_time_gather_list)})
69 | mean_times = torch.stack(dataloading_time_gather_list)[:, 0]
70 | slowest_dataloading_rank_id = torch.argmax(mean_times)
71 | max_dataloading = torch.max(mean_times)
72 | wandb_info.update(
73 | {
74 | "slowest_rank/slowest_dataloading_rank": slowest_dataloading_rank_id.item(),
75 | "slowest_rank/slowest_dataloading_time": max_dataloading.item(),
76 | }
77 | )
78 |
79 | if wandb.run:
80 | wandb.log(wandb_info, step=iteration)
81 |
82 | if self.save_s3 and distributed.is_rank0():
83 | easy_io.dump(
84 | wandb_info,
85 | f"s3://rundir/{self.name}/iter_{iteration:09d}.yaml",
86 | )
87 |
--------------------------------------------------------------------------------
/rcm/utils/dtensor_helper.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from __future__ import annotations
17 |
18 | import itertools
19 | from typing import Any
20 |
21 | import torch
22 | import torch.distributed as dist
23 | from torch.distributed.device_mesh import DeviceMesh
24 |
25 | from imaginaire.utils.misc import get_local_tensor_if_DTensor
26 |
27 |
28 | class DTensorFastEmaModelUpdater:
29 | """
30 |     Similar to FastEmaModelUpdater, but operates on the local shards of DTensor parameters.
31 | """
32 |
33 | def __init__(self):
34 | # Flag to indicate whether the cache is taken or not. Useful to avoid cache overwrite
35 | self.is_cached = False
36 |
37 | def copy_to(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module) -> None:
38 | with torch.no_grad():
39 | for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()):
40 | tgt_params.to_local().data.copy_(src_params.to_local().data)
41 |
42 | @torch.no_grad()
43 | def update_average(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module, beta: float = 0.9999) -> None:
44 | target_list = []
45 | source_list = []
46 | for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()):
47 | assert tgt_params.dtype == torch.float32, f"EMA model only works in FP32 dtype, got {tgt_params.dtype} instead."
48 | target_list.append(tgt_params.to_local())
49 | source_list.append(src_params.to_local().data)
50 | torch._foreach_mul_(target_list, beta)
51 | torch._foreach_add_(target_list, source_list, alpha=1.0 - beta)
52 |
53 | @torch.no_grad()
54 | def cache(self, parameters: Any, is_cpu: bool = False) -> None:
55 | assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?"
56 | device = "cpu" if is_cpu else "cuda"
57 | self.collected_params = [param.to_local().clone().to(device) for param in parameters]
58 | self.is_cached = True
59 |
60 | @torch.no_grad()
61 | def restore(self, parameters: Any) -> None:
62 | assert self.is_cached, "EMA cache is not taken yet."
63 | for c_param, param in zip(self.collected_params, parameters, strict=False):
64 | param.to_local().copy_(c_param.data.type_as(param.data))
65 | self.collected_params = []
66 | # Release the cache after we call restore
67 | self.is_cached = False
68 |
69 |
70 | def broadcast_dtensor_model_states(model: torch.nn.Module, mesh: DeviceMesh):
71 | """Broadcast model states from replicate mesh's rank 0."""
72 | replicate_group = mesh.get_group("replicate")
73 | all_ranks = dist.get_process_group_ranks(replicate_group)
74 | if len(all_ranks) == 1:
75 | return
76 |
77 | for _, tensor in itertools.chain(model.named_parameters(), model.named_buffers()):
78 | # Get src rank which is the first rank in each replication group
79 | src_rank = all_ranks[0]
80 | # Broadcast the local tensor
81 | local_tensor = get_local_tensor_if_DTensor(tensor)
82 | dist.broadcast(
83 | local_tensor,
84 | src=src_rank,
85 | group=replicate_group,
86 | )
87 |
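88 | 
89 | # Editor's sketch (not part of the original module): the typical life cycle of
90 | # DTensorFastEmaModelUpdater, assuming `model` and `model_ema` are two structurally
91 | # identical DTensor-sharded modules (names and call sites are illustrative).
92 | def _ema_usage_sketch(model: torch.nn.Module, model_ema: torch.nn.Module) -> None:
93 |     ema = DTensorFastEmaModelUpdater()
94 |     ema.copy_to(src_model=model, tgt_model=model_ema)  # initialize EMA weights once
95 |     # After every optimizer step:
96 |     ema.update_average(src_model=model, tgt_model=model_ema, beta=0.9999)
97 |     # Around evaluation: cache the training weights, swap in the EMA weights, ...
98 |     ema.cache(model.parameters(), is_cpu=True)
99 |     ema.copy_to(src_model=model_ema, tgt_model=model)
100 |     # ... run evaluation, then restore the original training weights.
101 |     ema.restore(model.parameters())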
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/pil_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from typing import IO
17 |
18 | import numpy as np
19 |
20 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
21 |
22 | try:
23 | from PIL import Image
24 | except ImportError:
25 | Image = None
26 |
27 |
28 | class PILHandler(BaseFileHandler):
29 | format: str
30 | str_like = False
31 |
32 | def load_from_fileobj(
33 | self,
34 | file: IO[bytes],
35 | fmt: str = "pil",
36 | size: int | tuple[int, int] | None = None,
37 | **kwargs,
38 | ):
39 | """
40 | Load an image from a file-like object and return it in a specified format.
41 |
42 | Args:
43 | file (IO[bytes]): A file-like object containing the image data.
44 | fmt (str): The format to convert the image into. Options are \
45 |                 'numpy', 'np', 'npy' (all return numpy arrays), \
46 |                 'pil' (returns a PIL Image), and 'th', 'torch' (both return a torch tensor).
47 | size (Optional[Union[int, Tuple[int, int]]]): The new size of the image as a single integer \
48 | or a tuple of (width, height). If specified, the image is resized accordingly.
49 | **kwargs: Additional keyword arguments that can be passed to conversion functions.
50 |
51 | Returns:
52 | Image data in the format specified by `fmt`.
53 |
54 | Raises:
55 | IOError: If the image cannot be loaded or processed.
56 | ValueError: If the specified format is unsupported.
57 | """
58 | try:
59 | img = Image.open(file)
60 | img.load() # Explicitly load the image data
61 | if size is not None:
62 | if isinstance(size, int):
63 | size = (
64 | size,
65 | size,
66 | ) # create a tuple if only one integer is provided
67 |                 img = img.resize(size, Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10; LANCZOS is its replacement
68 |
69 | # Return the image in the requested format
70 | if fmt in ["numpy", "np", "npy"]:
71 | return np.array(img, **kwargs)
72 | if fmt == "pil":
73 | return img
74 | if fmt in ["th", "torch"]:
75 | import torch
76 |
77 | # Convert to tensor
78 | img_tensor = torch.from_numpy(np.array(img, **kwargs))
79 | # Convert image from HxWxC to CxHxW
80 | if img_tensor.ndim == 3:
81 | img_tensor = img_tensor.permute(2, 0, 1)
82 | return img_tensor
83 | raise ValueError(
84 | "Unsupported format. Supported formats are 'numpy', 'np', 'npy', 'pil', 'th', and 'torch'."
85 | )
86 | except Exception as e:
87 | raise OSError(f"Unable to load image: {e}") from e
88 |
89 | def dump_to_fileobj(self, obj, file: IO[bytes], **kwargs):
90 |         # PIL expects "JPEG" rather than "JPG"; normalize while honoring a caller-supplied format.
91 |         fmt = kwargs.get("format", self.format)
92 |         kwargs["format"] = "JPEG" if fmt.lower() in ("jpg", "jpeg") else fmt.upper()
93 | obj.save(file, **kwargs)
94 |
95 | def dump_to_str(self, obj, **kwargs):
96 | raise NotImplementedError
97 |
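98 | 
99 | # Editor's sketch (not part of the original module): reading a JPEG through the
100 | # handler directly. In practice the handler is usually reached via easy_io, which
101 | # selects it from the file extension; "example.jpg" is a placeholder path.
102 | def _pil_handler_usage_sketch(path: str = "example.jpg"):
103 |     handler = PILHandler()
104 |     handler.format = "jpg"
105 |     with open(path, "rb") as f:
106 |         # Returns a C x H x W uint8 tensor, resized to 256 x 256.
107 |         return handler.load_from_fileobj(f, fmt="torch", size=(256, 256))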
--------------------------------------------------------------------------------
/rcm/callbacks/grad_clip.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from dataclasses import dataclass
17 | from typing import List
18 |
19 | import torch
20 | import wandb
21 |
22 | from imaginaire.utils import distributed
23 | from imaginaire.utils.callback import Callback
24 |
25 |
26 | @torch.jit.script
27 | def _fused_nan_to_num(params: List[torch.Tensor]):
28 | for param in params:
29 | torch.nan_to_num(param, nan=0.0, posinf=0.0, neginf=0.0, out=param)
30 |
31 |
32 | @dataclass
33 | class _MagnitudeRecord:
34 | state: float = 0
35 | iter_count: int = 0
36 |
37 | def reset(self) -> None:
38 | self.state = 0
39 | self.iter_count = 0
40 |
41 | def update(self, cur_state: torch.Tensor) -> None:
42 | self.state += cur_state
43 | self.iter_count += 1
44 |
45 |     def get_stat(self) -> float:
46 | if self.iter_count > 0:
47 | avg_state = self.state / self.iter_count
48 | avg_state = avg_state.item()
49 | else:
50 | avg_state = 0
51 | self.reset()
52 | return avg_state
53 |
54 |
55 | class GradClip(Callback):
56 | """
57 | This callback is used to clip the gradient norm of the model.
58 | It also logs the average gradient norm of the model to wandb.
59 | """
60 |
61 | def __init__(self, clip_norm=1.0, force_finite: bool = True):
62 | self.clip_norm = clip_norm
63 | self.force_finite = force_finite
64 |
65 | self.img_mag_log = _MagnitudeRecord()
66 | self.video_mag_log = _MagnitudeRecord()
67 | self._cur_state = None
68 |
69 | def on_training_step_start(
70 | self, model, data_batch: dict[str, torch.Tensor], iteration: int = 0
71 | ) -> None:
72 | if model.is_image_batch(data_batch):
73 | self._cur_state = self.img_mag_log
74 | else:
75 | self._cur_state = self.video_mag_log
76 |
77 | def on_before_optimizer_step(
78 | self,
79 | model_ddp: distributed.DistributedDataParallel,
80 | optimizer: torch.optim.Optimizer,
81 | scheduler: torch.optim.lr_scheduler.LRScheduler,
82 | grad_scaler: torch.amp.GradScaler,
83 | iteration: int = 0,
84 | ) -> None:
85 | del optimizer, scheduler
86 | if isinstance(model_ddp, distributed.DistributedDataParallel):
87 | model = model_ddp.module
88 | else:
89 | model = model_ddp
90 | params = []
91 |
92 | if self.force_finite:
93 | for param in model.parameters():
94 | if param.grad is not None:
95 | params.append(param.grad)
96 | _fused_nan_to_num(params)
97 |
98 | total_norm = model.clip_grad_norm_(self.clip_norm)
99 |
100 | self._cur_state.update(total_norm)
101 | if iteration % self.config.trainer.logging_iter == 0:
102 | avg_img_mag, avg_video_mag = self.img_mag_log.get_stat(), self.video_mag_log.get_stat()
103 | if wandb.run:
104 | wandb.log(
105 | {
106 | "clip_grad_norm/image": avg_img_mag,
107 | "clip_grad_norm/video": avg_video_mag,
108 | "iteration": iteration,
109 | },
110 | step=iteration,
111 | )
112 |
--------------------------------------------------------------------------------
/imaginaire/utils/wandb_util.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from __future__ import annotations
17 |
18 | import os
19 | from typing import TYPE_CHECKING
20 |
21 | import attrs
22 | import wandb
23 | import wandb.util
24 | from omegaconf import DictConfig
25 |
26 | from imaginaire.lazy_config.lazy import LazyConfig
27 | from imaginaire.utils import distributed, log
28 | from imaginaire.utils.easy_io import easy_io
29 |
30 | if TYPE_CHECKING:
31 | from imaginaire.config import CheckpointConfig, Config, JobConfig
32 | from imaginaire.model import ImaginaireModel
33 |
34 |
35 | @distributed.rank0_only
36 | def init_wandb(config: Config, model: ImaginaireModel) -> None:
37 | """Initialize Weights & Biases (wandb) logger.
38 |
39 | Args:
40 | config (Config): The config object for the Imaginaire codebase.
41 | model (ImaginaireModel): The PyTorch model.
42 | """
43 | if isinstance(config.job, DictConfig):
44 | from imaginaire.config import JobConfig
45 |
46 | config_job = JobConfig(**config.job)
47 | else:
48 | config_job = config.job
49 | config_checkpoint = config.checkpoint
50 | # Try to fetch the W&B job ID for resuming training.
51 | wandb_id = _read_wandb_id(config_job, config_checkpoint)
52 | if wandb_id is None:
53 | # Generate a new W&B job ID.
54 | wandb_id = wandb.util.generate_id()
55 | _write_wandb_id(config_job, config_checkpoint, wandb_id=wandb_id)
56 | log.info(f"Generating new wandb ID: {wandb_id}")
57 | else:
58 | log.info(f"Resuming with existing wandb ID: {wandb_id}")
59 |     # Resolve the config into plain YAML so that wandb can display it properly.
60 | local_safe_yaml_fp = LazyConfig.save_yaml(config, os.path.join(config_job.path_local, "config.yaml"))
61 | if os.path.exists(local_safe_yaml_fp):
62 | config_resolved = easy_io.load(local_safe_yaml_fp)
63 | else:
64 | config_resolved = attrs.asdict(config)
65 | # Initialize the wandb library.
66 | wandb.init(
67 | force=True,
68 | id=wandb_id,
69 | project=config_job.project,
70 | group=config_job.group,
71 | name=config_job.name,
72 | config=config_resolved,
73 | dir=config_job.path_local,
74 | resume="allow",
75 | mode=config_job.wandb_mode,
76 | )
77 |
78 |
79 | def _read_wandb_id(config_job: JobConfig, config_checkpoint: CheckpointConfig) -> str | None:
80 | """Read the W&B job ID. If it doesn't exist, return None.
81 |
82 | Args:
83 |         config_job (JobConfig): The job config object.
84 | config_checkpoint (CheckpointConfig): The config object for the checkpointer.
85 |
86 | Returns:
87 | wandb_id (str | None): W&B job ID.
88 | """
89 | wandb_id = None
90 | wandb_id_path = f"{config_job.path_local}/wandb_id.txt"
91 | if os.path.isfile(wandb_id_path):
92 | wandb_id = open(wandb_id_path).read().strip()
93 | return wandb_id
94 |
95 |
96 | def _write_wandb_id(config_job: JobConfig, config_checkpoint: CheckpointConfig, wandb_id: str) -> None:
97 | """Write the generated W&B job ID.
98 |
99 | Args:
100 |         config_job (JobConfig): The job config object.
101 | config_checkpoint (CheckpointConfig): The config object for the checkpointer.
102 | wandb_id (str): The W&B job ID.
103 | """
104 | content = f"{wandb_id}\n"
105 | wandb_id_path = f"{config_job.path_local}/wandb_id.txt"
106 | with open(wandb_id_path, "w") as file:
107 | file.write(content)
108 |
--------------------------------------------------------------------------------
/rcm/configs/registry_distill.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from typing import Any, List
17 |
18 | import attrs
19 |
20 | from imaginaire import config
21 | from imaginaire.utils.config_helper import import_all_modules_from_package
22 | from rcm.configs.defaults.trainer import register_trainer
23 | from rcm.configs.defaults.checkpoint import register_checkpoint
24 | from rcm.configs.defaults.ema import register_ema
25 | from rcm.configs.defaults.optimizer import register_optimizer, register_optimizer_fake_score
26 | from rcm.configs.defaults.scheduler import register_scheduler
27 | from rcm.configs.defaults.conditioner import register_conditioner
28 | from rcm.configs.defaults.callbacks import register_callbacks
29 | from rcm.configs.defaults.ckpt_type import register_ckpt_type
30 | from rcm.configs.defaults.dataloader import register_dataloader
31 | from rcm.configs.defaults.tokenizer import register_tokenizer
32 | from rcm.configs.defaults.model import register_model
33 | from rcm.configs.defaults.net import register_net, register_net_fake_score, register_net_teacher
34 |
35 |
36 | @attrs.define(slots=False)
37 | class Config(config.Config):
38 | # default config groups that will be used unless overwritten
39 |     # see the config groups registered by the register_* calls in rcm/configs/defaults
40 | defaults: List[Any] = attrs.field(
41 | factory=lambda: [
42 | "_self_",
43 | {"trainer": "standard"},
44 | {"data_train": "dummy"},
45 | {"data_val": "dummy"},
46 | {"optimizer": "fusedadamw"},
47 | {"scheduler": "lambdalinear"},
48 | {"callbacks": "basic"},
49 | {"checkpoint": "local"},
50 | {"ckpt_type": "dcp"},
51 | {"model": "fsdp_t2v_distill_rcm"},
52 | {"net": None},
53 | {"net_teacher": None},
54 | {"net_fake_score": None},
55 | {"optimizer_fake_score": "fusedadamw"},
56 | {"conditioner": "text_nodrop"},
57 | {"ema": "power"},
58 | {"tokenizer": "wan2pt1_tokenizer"},
59 | # the list is with order, we need global experiment to be the last one
60 | {"experiment": None},
61 | ]
62 | )
63 |
64 |
65 | def make_config() -> Config:
66 | c = Config(
67 | model=None,
68 | optimizer=None,
69 | scheduler=None,
70 | dataloader_train=None,
71 | dataloader_val=None,
72 | )
73 |
74 | # Specifying values through instances of attrs
75 | c.job.project = "rcm" # this decides the wandb project name
76 | c.job.group = "debug"
77 | c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}"
78 |
79 | c.trainer.max_iter = 400_000
80 | c.trainer.logging_iter = 100
81 | c.trainer.validation_iter = 100
82 | c.trainer.run_validation = False
83 | c.trainer.callbacks = None
84 |
85 |     # Register the config groups for advanced overriding; the order follows the default config groups above.
86 | register_trainer()
87 | register_dataloader()
88 | register_optimizer()
89 | register_optimizer_fake_score()
90 | register_scheduler()
91 | register_callbacks()
92 | register_checkpoint()
93 | register_ckpt_type()
94 | register_model()
95 | register_net()
96 | register_net_teacher()
97 | register_net_fake_score()
98 | register_conditioner()
99 | register_ema()
100 | register_tokenizer()
101 |
102 |     # Experiment configs are defined in the experiments folder.
103 | import_all_modules_from_package("rcm.configs.experiments.rcm", reload=True)
104 | return c
105 |
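106 | 
107 | # Editor's note (illustrative, not part of the original module): the `defaults` list
108 | # above is composed of Hydra-style config groups, so individual groups can be swapped
109 | # on the command line using the override syntax documented in scripts/train.py, e.g.
110 | # picking the "wandb" callbacks group and shortening the run:
111 | #
112 | #   torchrun --nproc_per_node=1 -m scripts.train \
113 | #       --config=rcm/configs/registry_distill.py -- callbacks=wandb trainer.max_iter=1000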
--------------------------------------------------------------------------------
/rcm/callbacks/iter_speed.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import time
17 |
18 | import torch
19 | import wandb
20 | from torch import Tensor
21 |
22 | from imaginaire.callbacks.every_n import EveryN
23 | from imaginaire.model import ImaginaireModel
24 | from imaginaire.trainer import ImaginaireTrainer
25 | from imaginaire.utils import log
26 | from imaginaire.utils.distributed import rank0_only
27 | from imaginaire.utils.easy_io import easy_io
28 |
29 |
30 | class IterSpeed(EveryN):
31 | """
32 | Args:
33 | hit_thres (int): Number of iterations to wait before logging.
34 | save_s3 (bool): Whether to save to S3.
35 |         save_s3_every_log_n (int): Save to S3 every save_s3_every_log_n log events, i.e. every save_s3_every_log_n * every_n global iterations.
36 | """
37 |
38 | def __init__(self, *args, hit_thres: int = 5, save_s3: bool = False, save_s3_every_log_n: int = 10, **kwargs):
39 | super().__init__(*args, **kwargs)
40 | self.time = None
41 | self.hit_counter = 0
42 | self.hit_thres = hit_thres
43 | self.save_s3 = save_s3
44 | self.save_s3_every_log_n = save_s3_every_log_n
45 | self.name = self.__class__.__name__
46 | self.last_hit_time = time.time()
47 |
48 | def on_training_step_end(
49 | self,
50 | model: ImaginaireModel,
51 | data_batch: dict[str, torch.Tensor],
52 | output_batch: dict[str, torch.Tensor],
53 | loss: torch.Tensor,
54 | iteration: int = 0,
55 | ) -> None:
56 | if self.hit_counter < self.hit_thres:
57 | log.info(
58 | f"Iteration {iteration}: "
59 | f"Hit counter: {self.hit_counter + 1}/{self.hit_thres} | "
60 | f"Loss: {loss.item():.4f} | "
61 | f"Time: {time.time() - self.last_hit_time:.2f}s"
62 | )
63 | self.hit_counter += 1
64 | self.last_hit_time = time.time()
65 |             # Synchronizing here is useful for large-scale training: it avoids OOM crashes in the first couple of iterations.
66 | torch.cuda.synchronize()
67 | return
68 | super().on_training_step_end(model, data_batch, output_batch, loss, iteration)
69 |
70 | @rank0_only
71 | def every_n_impl(
72 | self,
73 | trainer: ImaginaireTrainer,
74 | model: ImaginaireModel,
75 | data_batch: dict[str, Tensor],
76 | output_batch: dict[str, Tensor],
77 | loss: Tensor,
78 | iteration: int,
79 | ) -> None:
80 | if self.time is None:
81 | self.time = time.time()
82 | return
83 | cur_time = time.time()
84 | iter_speed = (cur_time - self.time) / self.every_n / self.step_size
85 |
86 | log.info(f"{iteration} : iter_speed {iter_speed:.2f} seconds per iteration | Loss: {loss.item():.4f}")
87 |
88 | if wandb.run:
89 | sample_counter = getattr(trainer, "sample_counter", iteration)
90 | wandb.log(
91 | {
92 | "timer/iter_speed": iter_speed,
93 | "sample_counter": sample_counter,
94 | },
95 | step=iteration,
96 | )
97 | self.time = cur_time
98 | if self.save_s3:
99 | if iteration % (self.save_s3_every_log_n * self.every_n) == 0:
100 | easy_io.dump(
101 | {
102 | "iter_speed": iter_speed,
103 | "iteration": iteration,
104 | },
105 | f"s3://rundir/{self.name}/iter_{iteration:09d}.yaml",
106 | )
107 |
--------------------------------------------------------------------------------
/rcm/datasets/visualize_tar.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import argparse
4 | import tarfile
5 | import torch
6 | from einops import rearrange
7 | from tqdm import tqdm
8 |
9 | from rcm.tokenizers.wan2pt1 import Wan2pt1VAEInterface
10 | from imaginaire.utils.io import save_image_or_video
11 |
12 |
13 | def main(args):
14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15 | print(f"Using device: {device}")
16 |
17 | print("Loading VAE tokenizer...")
18 | try:
19 | tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path)
20 | except Exception as e:
21 | print(f"Error loading VAE from {args.vae_path}: {e}")
22 | print("Please ensure the VAE model path is correct.")
23 | return
24 |
25 | print(f"Opening tar file: {args.tar_path}")
26 | if not os.path.exists(args.tar_path):
27 | print(f"Error: Tar file not found at {args.tar_path}")
28 | return
29 |
30 | latent_files = []
31 | with tarfile.open(args.tar_path, "r") as tar:
32 | all_files = tar.getnames()
33 | latent_files = sorted([f for f in all_files if f.endswith(".latent.pt")])
34 |
35 | if not latent_files:
36 | print("Error: No '.latent.pt' files found in the tar archive.")
37 | return
38 |
39 | print(f"Found {len(latent_files)} latent files in the archive.")
40 |
41 | if args.num_samples > 0:
42 | latent_files = latent_files[: args.num_samples]
43 | print(f"Decoding the first {len(latent_files)} samples.")
44 |
45 | decoded_videos = []
46 | print("Decoding samples...")
47 | with tarfile.open(args.tar_path, "r") as tar:
48 | for member_name in tqdm(latent_files, desc="Decoding"):
49 | member_file = tar.extractfile(member_name)
50 | if member_file is None:
51 | print(f"Warning: Could not extract {member_name}. Skipping.")
52 | continue
53 |
54 | latent_tensor = torch.load(io.BytesIO(member_file.read()), map_location=device)
55 |
56 | if latent_tensor.dim() == 4:
57 | samples = latent_tensor.unsqueeze(0)
58 | else:
59 | samples = latent_tensor
60 |
61 | video = tokenizer.decode(samples.to(device, dtype=torch.bfloat16))
62 |
63 | decoded_videos.append(video.float().cpu())
64 |
65 | if not decoded_videos:
66 | print("No videos were decoded. Exiting.")
67 | return
68 |
69 | print("Stacking videos into a grid...")
70 | to_show = torch.stack(decoded_videos, dim=0)
71 |
72 | to_show = (1.0 + to_show.clamp(-1, 1)) / 2.0
73 |
74 | num_videos = to_show.shape[0]
75 | grid_cols = args.grid_cols
76 |     grid_rows = (num_videos + grid_cols - 1) // grid_cols  # ceiling division
77 |
78 | if num_videos != grid_rows * grid_cols:
79 |         print(f"Warning: The number of videos ({num_videos}) is not a multiple of grid_cols ({grid_cols}); the rearrange below expects a full {grid_rows}x{grid_cols} grid and will fail.")
80 |
81 | to_show = rearrange(to_show, "(rows cols) b c t h w -> c t (rows h) (cols b w)", cols=grid_cols)
82 |
83 | save_dir = os.path.dirname(args.save_path)
84 | if save_dir:
85 | os.makedirs(save_dir, exist_ok=True)
86 |
87 | print(f"Saving video grid to: {args.save_path}")
88 | save_image_or_video(to_show, args.save_path, fps=args.fps)
89 |
90 | print("Done!")
91 |
92 |
93 | if __name__ == "__main__":
94 | parser = argparse.ArgumentParser(description="Decode latent samples from a .tar file into a video grid.")
95 |
96 | parser.add_argument("--tar_path", type=str, required=True, help="Path to the input .tar file containing latent samples.")
97 | parser.add_argument("--save_path", type=str, default="preview.mp4", help="Path to save the output video file.")
98 | parser.add_argument("--vae_path", type=str, default="assets/checkpoints/Wan2.1_VAE.pth", help="Path to the Wan2.1 VAE model weights.")
99 | parser.add_argument("--num_samples", type=int, default=16, help="Number of samples to decode from the tar file. Set to 0 to decode all.")
100 | parser.add_argument("--grid_cols", type=int, default=4, help="Number of columns in the output video grid.")
101 | parser.add_argument("--fps", type=int, default=16, help="Frames per second for the output video.")
102 |
103 | args = parser.parse_args()
104 | main(args)
105 |
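106 | 
107 | # Editor's note (illustrative): a typical invocation, with a placeholder shard path:
108 | #   python -m rcm.datasets.visualize_tar --tar_path datasets/shard_000000.tar \
109 | #       --save_path preview.mp4 --num_samples 8 --grid_cols 4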
--------------------------------------------------------------------------------
/rcm/callbacks/heart_beat.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import time
17 | from datetime import datetime
18 |
19 | import pytz
20 | import torch
21 |
22 | from imaginaire.callbacks.every_n import EveryN
23 | from imaginaire.model import ImaginaireModel
24 | from imaginaire.trainer import ImaginaireTrainer
25 | from imaginaire.utils import distributed
26 | from imaginaire.utils.easy_io import easy_io
27 |
28 |
29 | class HeartBeat(EveryN):
30 | """
31 | A callback that logs a heartbeat message at regular intervals to indicate that the training process is still running.
32 |
33 | Args:
34 | every_n (int): The frequency at which the callback is invoked.
35 | step_size (int, optional): The step size for the callback. Defaults to 1.
36 | update_interval_in_minute (int, optional): The interval in minutes for logging the heartbeat. Defaults to 20 minutes.
37 | save_s3 (bool, optional): Whether to save the heartbeat information to S3. Defaults to False.
38 | """
39 |
40 | def __init__(self, every_n: int, step_size: int = 1, update_interval_in_minute: int = 20, save_s3: bool = False):
41 | super().__init__(every_n=every_n, step_size=step_size)
42 | self.name = self.__class__.__name__
43 | self.update_interval_in_minute = update_interval_in_minute
44 | self.save_s3 = save_s3
45 | self.pst = pytz.timezone("America/Los_Angeles")
46 | self.is_hitted = False
47 |
48 | @distributed.rank0_only
49 | def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
50 | self.time = time.time()
51 | if self.save_s3:
52 | current_time_pst = datetime.now(self.pst).strftime("%Y_%m_%d-%H_%M_%S")
53 | info = {
54 | "iteration": iteration,
55 | "time": current_time_pst,
56 | }
57 | easy_io.dump(info, f"s3://rundir/{self.name}_start.yaml")
58 | easy_io.dump(info, f"s3://timestamps_rundir/{self.name}_start.yaml")
59 |
60 | def on_training_step_end(
61 | self,
62 | model: ImaginaireModel,
63 | data_batch: dict[str, torch.Tensor],
64 | output_batch: dict[str, torch.Tensor],
65 | loss: torch.Tensor,
66 | iteration: int = 0,
67 | ) -> None:
68 | if not self.is_hitted:
69 | self.is_hitted = True
70 | if distributed.get_rank() == 0:
71 | self.report(iteration)
72 | super().on_training_step_end(model, data_batch, output_batch, loss, iteration)
73 |
74 | @distributed.rank0_only
75 | def every_n_impl(
76 | self,
77 | trainer: ImaginaireTrainer,
78 | model: ImaginaireModel,
79 | data_batch: dict[str, torch.Tensor],
80 | output_batch: dict[str, torch.Tensor],
81 | loss: torch.Tensor,
82 | iteration: int,
83 | ) -> None:
84 | if time.time() - self.time > 60 * self.update_interval_in_minute:
85 | self.report(iteration)
86 |
87 | def report(self, iteration: int = 0):
88 | self.time = time.time()
89 | if self.save_s3:
90 | current_time_pst = datetime.now(self.pst).strftime("%Y_%m_%d-%H_%M_%S")
91 | info = {
92 | "iteration": iteration,
93 | "time": current_time_pst,
94 | }
95 | easy_io.dump(info, f"s3://rundir/{self.name}.yaml")
96 |
97 | @distributed.rank0_only
98 | def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
99 | if self.save_s3:
100 | current_time_pst = datetime.now(self.pst).strftime("%Y_%m_%d-%H_%M_%S")
101 | info = {
102 | "iteration": iteration,
103 | "time": current_time_pst,
104 | }
105 | easy_io.dump(info, f"s3://rundir/{self.name}_end.yaml")
106 | easy_io.dump(info, f"s3://timestamps_rundir/{self.name}_end.yaml")
107 |
--------------------------------------------------------------------------------
/imaginaire/lazy_config/instantiate.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import collections.abc as abc
17 | import dataclasses
18 | from typing import Any
19 |
20 | import attrs
21 |
22 | from imaginaire.lazy_config.registry import _convert_target_to_string, locate
23 | from imaginaire.utils import log
24 |
25 | __all__ = ["dump_dataclass", "instantiate"]
26 |
27 |
28 | def is_dataclass_or_attrs(target):
29 | return dataclasses.is_dataclass(target) or attrs.has(target)
30 |
31 |
32 | def dump_dataclass(obj: Any):
33 | """
34 | Dump a dataclass recursively into a dict that can be later instantiated.
35 |
36 | Args:
37 | obj: a dataclass object
38 |
39 | Returns:
40 | dict
41 | """
42 | assert dataclasses.is_dataclass(obj) and not isinstance(obj, type), (
43 | "dump_dataclass() requires an instance of a dataclass."
44 | )
45 | ret = {"_target_": _convert_target_to_string(type(obj))}
46 | for f in dataclasses.fields(obj):
47 | v = getattr(obj, f.name)
48 | if dataclasses.is_dataclass(v):
49 | v = dump_dataclass(v)
50 | if isinstance(v, (list, tuple)):
51 | v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v]
52 | ret[f.name] = v
53 | return ret
54 |
55 |
56 | def instantiate(cfg, *args, **kwargs):
57 | """
58 | Recursively instantiate objects defined in dictionaries by
59 | "_target_" and arguments.
60 |
61 | Args:
62 | cfg: a dict-like object with "_target_" that defines the caller, and
63 | other keys that define the arguments
64 | args: Optional positional parameters pass-through.
65 | kwargs: Optional named parameters pass-through.
66 |
67 | Returns:
68 | object instantiated by cfg
69 | """
70 | from omegaconf import DictConfig, ListConfig, OmegaConf
71 |
72 | if isinstance(cfg, ListConfig):
73 | lst = [instantiate(x) for x in cfg]
74 | return ListConfig(lst, flags={"allow_objects": True})
75 | if isinstance(cfg, list):
76 | # Specialize for list, because many classes take
77 | # list[objects] as arguments, such as ResNet, DatasetMapper
78 | return [instantiate(x) for x in cfg]
79 |
80 | # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config),
81 | # instantiate it to the actual dataclass.
82 | if isinstance(cfg, DictConfig) and is_dataclass_or_attrs(cfg._metadata.object_type):
83 | return OmegaConf.to_object(cfg)
84 |
85 | if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
86 | # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
87 | # but faster: https://github.com/facebookresearch/hydra/issues/1200
88 | is_recursive = getattr(cfg, "_recursive_", True)
89 | if is_recursive:
90 | cfg = {k: instantiate(v) for k, v in cfg.items()}
91 | else:
92 | cfg = {k: v for k, v in cfg.items()}
93 | # pop the _recursive_ key to avoid passing it as a parameter
94 | if "_recursive_" in cfg:
95 | cfg.pop("_recursive_")
96 | cls = cfg.pop("_target_")
97 | cls = instantiate(cls)
98 |
99 | if isinstance(cls, str):
100 | cls_name = cls
101 | cls = locate(cls_name)
102 | assert cls is not None, cls_name
103 | else:
104 | try:
105 | cls_name = cls.__module__ + "." + cls.__qualname__
106 | except Exception:
107 | # target could be anything, so the above could fail
108 | cls_name = str(cls)
109 | assert callable(cls), f"_target_ {cls} does not define a callable object"
110 | try:
111 | # override config with kwargs
112 | instantiate_kwargs = {}
113 | instantiate_kwargs.update(cfg)
114 | instantiate_kwargs.update(kwargs)
115 | return cls(*args, **instantiate_kwargs)
116 | except TypeError:
117 | log.error(f"Error when instantiating {cls_name}!")
118 | raise
119 |     return cfg  # return as-is if we don't know what to do
120 |
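121 | 
122 | # Editor's sketch (not part of the original module): a minimal "_target_" config and
123 | # its instantiation. torch.nn.Linear is used purely as a convenient importable target;
124 | # any callable reachable by its dotted path works the same way.
125 | def _instantiate_usage_sketch():
126 |     cfg = {
127 |         "_target_": "torch.nn.Linear",
128 |         "in_features": 8,
129 |         "out_features": 4,
130 |         "bias": False,
131 |     }
132 |     # Equivalent to torch.nn.Linear(8, 4, bias=False); nested dicts with "_target_"
133 |     # would be instantiated recursively unless "_recursive_" is set to False.
134 |     return instantiate(cfg)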
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Misc
2 | outputs/
3 | output/
4 | logs/
5 | checkpoints/
6 | assets/
7 |
8 | *.jit
9 | *.pt
10 | *.png
11 | *.hdr
12 | *.jpg
13 | *.jpeg
14 | *.webp
15 | *.pgm
16 | *.tiff
17 | *.tif
18 | *.gif
19 | *.mp4
20 | *.tar
21 | *.tar.gz
22 | *.gz
23 | *.pkl
25 | *.bin
26 |
27 | # Other uncheckable file types
28 | *.zip
29 | *.exe
30 | *.dll
31 | *.swp
32 | *.vscode/**
33 | *.ipynb
34 | *.DS_Store
35 | *.pyc
36 | *Thumbs.db
37 | *.patch
38 | __MACOSX
39 |
40 | # ------------------------ BELOW IS AUTO-GENERATED FOR PYTHON REPOS ------------------------
41 |
42 | # Byte-compiled / optimized / DLL files
43 | **/__pycache__/
44 | *.py[cod]
45 | *$py.class
46 |
47 | # C extensions
48 | *.so
49 |
50 | # Distribution / packaging
51 | .Python
52 | build/
53 | develop-eggs/
54 | dist/
55 | downloads/
56 | eggs/
57 | .eggs/
58 | lib/
59 | lib64/
60 | parts/
61 | results/
62 | sdist/
63 | var/
64 | wheels/
65 | share/python-wheels/
66 | *.egg-info/
67 | .installed.cfg
68 | *.egg
69 | MANIFEST
70 |
71 | # PyInstaller
72 | # Usually these files are written by a python script from a template
73 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
74 | *.manifest
75 | *.spec
76 |
77 | # Installer logs
78 | pip-log.txt
79 | pip-delete-this-directory.txt
80 |
81 | # Unit test / coverage reports
82 | htmlcov/
83 | .tox/
84 | .nox/
85 | .coverage
86 | .coverage.*
87 | .cache
88 | nosetests.xml
89 | coverage.xml
90 | *.cover
91 | *.py,cover
92 | .hypothesis/
93 | .pytest_cache/
94 | cover/
95 |
96 | # Translations
97 | *.mo
98 | *.pot
99 |
100 | # Django stuff:
101 | *.log
102 | local_settings.py
103 | db.sqlite3
104 | db.sqlite3-journal
105 |
106 | # Flask stuff:
107 | instance/
108 | .webassets-cache
109 |
110 | # Scrapy stuff:
111 | .scrapy
112 |
113 | # Sphinx documentation
114 | docs/_build/
115 |
116 | # PyBuilder
117 | .pybuilder/
118 | target/
119 |
120 | # Third party
121 | inf_lib/SpargeAttn/csrc/qattn/instantiations_sm*/*.cu
122 |
123 | # Jupyter Notebook
124 | .ipynb_checkpoints
125 |
126 | # IPython
127 | profile_default/
128 | ipython_config.py
129 |
130 | # pyenv
131 | # For a library or package, you might want to ignore these files since the code is
132 | # intended to run in multiple environments; otherwise, check them in:
133 | # .python-version
134 |
135 | # pipenv
136 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
137 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
138 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
139 | # install all needed dependencies.
140 | #Pipfile.lock
141 |
142 | # poetry
143 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
144 | # This is especially recommended for binary packages to ensure reproducibility, and is more
145 | # commonly ignored for libraries.
146 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
147 | #poetry.lock
148 |
149 | # pdm
150 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
151 | #pdm.lock
152 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
153 | # in version control.
154 | # https://pdm.fming.dev/#use-with-ide
155 | .pdm.toml
156 |
157 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
158 | __pypackages__/
159 |
160 | # Celery stuff
161 | celerybeat-schedule
162 | celerybeat.pid
163 |
164 | # SageMath parsed files
165 | *.sage.py
166 |
167 | # Environments
168 | .env
169 | .venv
170 | env/
171 | venv/
172 | ENV/
173 | env.bak/
174 | venv.bak/
175 |
176 | # Spyder project settings
177 | .spyderproject
178 | .spyproject
179 |
180 | # Rope project settings
181 | .ropeproject
182 |
183 | # mkdocs documentation
184 | /site
185 |
186 | # mypy
187 | .mypy_cache/
188 | .dmypy.json
189 | dmypy.json
190 |
191 | # Pyre type checker
192 | .pyre/
193 |
194 | # pytype static type analyzer
195 | .pytype/
196 |
197 | # Cython debug symbols
198 | cython_debug/
199 |
200 | # ruff
201 | .ruff_cache
202 |
203 | # PyCharm
204 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
205 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
206 | # and can be added to the global gitignore or merged into this file. For a more nuclear
207 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
208 | #.idea/
209 | CLIP
210 | .devcontainer/devcontainer.json
211 |
212 | # Coverage
213 | .coverage
214 | coverage.xml
215 |
216 | # JUnit Reports
217 | report.xml
218 |
219 | # CI-CD
220 | ci-cd/edify/self_test/output_cicd.yaml
221 | temp/
222 | envs.txt
223 | manifest.json
224 |
225 |
226 | # locks
227 | *.locks*
228 | *.no_exist*
229 |
230 | wandb/
231 |
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/backends/registry_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import inspect
17 |
18 | from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
19 | from imaginaire.utils.easy_io.backends.http_backend import HTTPBackend
20 | from imaginaire.utils.easy_io.backends.local_backend import LocalBackend
21 |
22 | backends: dict = {}
23 | prefix_to_backends: dict = {}
24 |
25 |
26 | def _register_backend(
27 | name: str,
28 | backend: type[BaseStorageBackend],
29 | force: bool = False,
30 | prefixes: str | list | tuple | None = None,
31 | ):
32 | """Register a backend.
33 |
34 | Args:
35 | name (str): The name of the registered backend.
36 | backend (BaseStorageBackend): The backend class to be registered,
37 | which must be a subclass of :class:`BaseStorageBackend`.
38 | force (bool): Whether to override the backend if the name has already
39 | been registered. Defaults to False.
40 | prefixes (str or list[str] or tuple[str], optional): The prefix
41 | of the registered storage backend. Defaults to None.
42 | """
43 | global backends, prefix_to_backends
44 |
45 | if not isinstance(name, str):
46 | raise TypeError(f"the backend name should be a string, but got {type(name)}")
47 |
48 | if not inspect.isclass(backend):
49 | raise TypeError(f"backend should be a class, but got {type(backend)}")
50 | if not issubclass(backend, BaseStorageBackend):
51 | raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend")
52 |
53 | if name in backends and not force:
54 | raise ValueError(
55 | f'{name} is already registered as a storage backend, add "force=True" if you want to override it'
56 | )
57 | backends[name] = backend
58 |
59 | if prefixes is not None:
60 | if isinstance(prefixes, str):
61 | prefixes = [prefixes]
62 | else:
63 | assert isinstance(prefixes, (list, tuple))
64 |
65 | for prefix in prefixes:
66 | if prefix in prefix_to_backends and not force:
67 | raise ValueError(
68 | f'{prefix} is already registered as a storage backend, add "force=True" if you want to override it'
69 | )
70 |
71 | prefix_to_backends[prefix] = backend
72 |
73 |
74 | def register_backend(
75 | name: str,
76 | backend: type[BaseStorageBackend] | None = None,
77 | force: bool = False,
78 | prefixes: str | list | tuple | None = None,
79 | ):
80 | """Register a backend.
81 |
82 | Args:
83 | name (str): The name of the registered backend.
84 | backend (class, optional): The backend class to be registered,
85 | which must be a subclass of :class:`BaseStorageBackend`.
86 | When this method is used as a decorator, backend is None.
87 | Defaults to None.
88 | force (bool): Whether to override the backend if the name has already
89 | been registered. Defaults to False.
90 | prefixes (str or list[str] or tuple[str], optional): The prefix
91 | of the registered storage backend. Defaults to None.
92 |
93 | This method can be used as a normal method or a decorator.
94 |
95 | Examples:
96 |
97 | >>> class NewBackend(BaseStorageBackend):
98 | ... def get(self, filepath):
99 | ... return filepath
100 | ...
101 | ... def get_text(self, filepath):
102 | ... return filepath
103 | >>> register_backend('new', NewBackend)
104 |
105 | >>> @register_backend('new')
106 | ... class NewBackend(BaseStorageBackend):
107 | ... def get(self, filepath):
108 | ... return filepath
109 | ...
110 | ... def get_text(self, filepath):
111 | ... return filepath
112 | """
113 | if backend is not None:
114 | _register_backend(name, backend, force=force, prefixes=prefixes)
115 | return
116 |
117 | def _register(backend_cls):
118 | _register_backend(name, backend_cls, force=force, prefixes=prefixes)
119 | return backend_cls
120 |
121 | return _register
122 |
123 |
124 | register_backend("local", LocalBackend, prefixes="")
125 | register_backend("http", HTTPBackend, prefixes=["http", "https"])
126 |
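127 | # Usage sketch (illustrative only; `MemoryBackend` is a hypothetical class, not part of this module):
128 | # registering a backend makes it discoverable by name and, optionally, by path prefix.
129 | #
130 | #   >>> @register_backend("memory", prefixes="mem")
131 | #   ... class MemoryBackend(BaseStorageBackend):
132 | #   ...     def get(self, filepath):
133 | #   ...         return b""
134 | #   ...
135 | #   ...     def get_text(self, filepath):
136 | #   ...         return ""
137 | #   >>> backends["memory"] is MemoryBackend
138 | #   True
139 | #   >>> prefix_to_backends["mem"] is MemoryBackend
140 | #   True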
--------------------------------------------------------------------------------
/imaginaire/utils/log.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import atexit
17 | import os
18 | import sys
19 | from typing import Any
20 |
21 | import torch.distributed as dist
22 | from loguru._logger import Core, Logger
23 |
24 | RANK0_ONLY = True
25 | LEVEL = os.environ.get("LOGURU_LEVEL", "INFO")
26 |
27 | logger = Logger(
28 | core=Core(),
29 | exception=None,
30 | depth=1,
31 | record=False,
32 | lazy=False,
33 | colors=False,
34 | raw=False,
35 | capture=True,
36 | patchers=[],
37 | extra={},
38 | )
39 |
40 | atexit.register(logger.remove)
41 |
42 |
43 | def _add_relative_path(record: dict[str, Any]) -> None:
44 | start = os.getcwd()
45 | record["extra"]["relative_path"] = os.path.relpath(record["file"].path, start)
46 |
47 |
48 | *options, _, extra = logger._options # type: ignore
49 | logger._options = tuple([*options, [_add_relative_path], extra]) # type: ignore
50 |
51 |
52 | def init_loguru_stdout() -> None:
53 | logger.remove()
54 | machine_format = get_machine_format()
55 | message_format = get_message_format()
56 | logger.add(
57 | sys.stdout,
58 | level=LEVEL,
59 | format=f"[{{time:MM-DD HH:mm:ss}}|{machine_format}{message_format}",
60 | filter=_rank0_only_filter,
61 | )
62 |
63 |
64 | def init_loguru_file(path: str) -> None:
65 | machine_format = get_machine_format()
66 | message_format = get_message_format()
67 | logger.add(
68 | path,
69 | encoding="utf8",
70 | level=LEVEL,
71 | format=f"[{{time:MM-DD HH:mm:ss}}|{machine_format}{message_format}",
72 | rotation="100 MB",
73 | filter=lambda result: _rank0_only_filter(result) or not RANK0_ONLY,
74 | enqueue=True,
75 | )
76 |
77 |
78 | def get_machine_format() -> str:
79 | node_id = "0"
80 | num_nodes = 1
81 | machine_format = ""
82 | rank = 0
83 | if dist.is_available():
84 | if not RANK0_ONLY and dist.is_initialized():
85 | rank = dist.get_rank()
86 | world_size = dist.get_world_size()
87 | machine_format = (
88 | f"[Node{node_id:<3}/{num_nodes:<3}][RANK{rank:<5}/{world_size:<5}]" + "[{process.name:<8}]| "
89 | )
90 | return machine_format
91 |
92 |
93 | def get_message_format() -> str:
94 | message_format = "{level}|{extra[relative_path]}:{line}:{function}] {message}"
95 | return message_format
96 |
97 |
98 | def _rank0_only_filter(record: Any) -> bool:
99 | is_rank0 = record["extra"].get("rank0_only", True)
100 | if _get_rank() == 0 and is_rank0:
101 | return True
102 | if not is_rank0:
103 | record["message"] = f"[RANK{_get_rank()}] " + record["message"]
104 | return not is_rank0
105 |
106 |
107 | def trace(message: str, rank0_only: bool = True) -> None:
108 | logger.opt(depth=1).bind(rank0_only=rank0_only).trace(message)
109 |
110 |
111 | def debug(message: str, rank0_only: bool = True) -> None:
112 | logger.opt(depth=1).bind(rank0_only=rank0_only).debug(message)
113 |
114 |
115 | def info(message: str, rank0_only: bool = True) -> None:
116 | logger.opt(depth=1).bind(rank0_only=rank0_only).info(message)
117 |
118 |
119 | def success(message: str, rank0_only: bool = True) -> None:
120 | logger.opt(depth=1).bind(rank0_only=rank0_only).success(message)
121 |
122 |
123 | def warning(message: str, rank0_only: bool = True) -> None:
124 | logger.opt(depth=1).bind(rank0_only=rank0_only).warning(message)
125 |
126 |
127 | def error(message: str, rank0_only: bool = True) -> None:
128 | logger.opt(depth=1).bind(rank0_only=rank0_only).error(message)
129 |
130 |
131 | def critical(message: str, rank0_only: bool = True) -> None:
132 | logger.opt(depth=1).bind(rank0_only=rank0_only).critical(message)
133 |
134 |
135 | def exception(message: str, rank0_only: bool = True) -> None:
136 | logger.opt(depth=1).bind(rank0_only=rank0_only).exception(message)
137 |
138 |
139 | def _get_rank(group: dist.ProcessGroup | None = None) -> int:
140 | """Get the rank (GPU device) of the worker.
141 |
142 | Returns:
143 | rank (int): The rank of the worker.
144 | """
145 | rank = 0
146 | if dist.is_available() and dist.is_initialized():
147 | rank = dist.get_rank(group)
148 | return rank
149 |
150 |
151 | # Execute at import time.
152 | init_loguru_stdout()
153 |
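154 | # Usage sketch (illustrative; the file path is a placeholder): the module-level helpers bind
155 | # `rank0_only`, so by default only rank 0 emits a message; pass rank0_only=False to log from
156 | # every rank with a [RANK{n}] prefix.
157 | #
158 | #   >>> from imaginaire.utils import log
159 | #   >>> log.info("loaded checkpoint")                      # rank 0 only
160 | #   >>> log.warning("slow dataloader", rank0_only=False)   # every rank, prefixed with [RANK{n}]
161 | #   >>> log.init_loguru_file("train.log")                  # also mirror logs to a rotating file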
--------------------------------------------------------------------------------
/rcm/utils/checkpointer.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from __future__ import annotations
17 |
18 | from typing import List, NamedTuple, Tuple
19 |
20 | import torch
21 |
22 | from imaginaire.utils import log
23 |
24 | TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
25 | if TORCH_VERSION >= (1, 11):
26 | from torch.ao import quantization
27 | from torch.ao.quantization import FakeQuantizeBase, ObserverBase
28 | elif TORCH_VERSION >= (1, 8) and hasattr(torch.quantization, "FakeQuantizeBase") and hasattr(torch.quantization, "ObserverBase"):
29 | from torch import quantization
30 | from torch.quantization import FakeQuantizeBase, ObserverBase
31 |
32 |
33 | class _IncompatibleKeys(
34 | NamedTuple(
35 | "IncompatibleKeys",
36 | [
37 | ("missing_keys", List[str]),
38 | ("unexpected_keys", List[str]),
39 | ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]),
40 | ],
41 | )
42 | ):
43 | pass
44 |
45 |
46 | # https://github.com/facebookresearch/fvcore/blob/9d683aae73fb899dd35d6cf6720e5ef567761c57/fvcore/common/checkpoint.py
47 | def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys:
48 | # workaround https://github.com/pytorch/pytorch/issues/24139
49 | model_state_dict = model.state_dict()
50 | incorrect_shapes = []
51 | for k in list(checkpoint_state_dict.keys()):
52 | if k in model_state_dict:
53 | if "_extra_state" in k: # Key introduced by TransformerEngine for FP8
54 | log.warning(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.")
55 | continue
56 | model_param = model_state_dict[k]
57 | # Allow mismatch for uninitialized parameters
58 | if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter):
59 | continue
60 | if not isinstance(model_param, torch.Tensor):
61 | raise ValueError(
62 |                     f"Found non-tensor parameter {k} in the model (types: {type(model_param)} vs {type(checkpoint_state_dict[k])}); please check whether this key is safe to skip."
63 | )
64 |
65 | shape_model = tuple(model_param.shape)
66 | shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
67 | if shape_model != shape_checkpoint:
68 | has_observer_base_classes = (
69 | TORCH_VERSION >= (1, 8) and hasattr(quantization, "ObserverBase") and hasattr(quantization, "FakeQuantizeBase")
70 | )
71 | if has_observer_base_classes:
72 | # Handle the special case of quantization per channel observers,
73 | # where buffer shape mismatches are expected.
74 | def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
75 | # foo.bar.param_or_buffer_name -> [foo, bar]
76 | key_parts = key.split(".")[:-1]
77 | cur_module = model
78 | for key_part in key_parts:
79 | cur_module = getattr(cur_module, key_part)
80 | return cur_module
81 |
82 | cls_to_skip = (
83 | ObserverBase,
84 | FakeQuantizeBase,
85 | )
86 | target_module = _get_module_for_key(model, k)
87 | if isinstance(target_module, cls_to_skip):
88 |                         # Do not remove modules with expected shape mismatches
89 |                         # from the state_dict loading; they have special logic
90 |                         # in _load_from_state_dict to handle the mismatches.
91 | continue
92 |
93 | incorrect_shapes.append((k, shape_checkpoint, shape_model))
94 | checkpoint_state_dict.pop(k)
95 | incompatible = model.load_state_dict(checkpoint_state_dict, strict=False)
96 | # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling
97 | missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k]
98 | unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k]
99 | return _IncompatibleKeys(
100 | missing_keys=missing_keys,
101 | unexpected_keys=unexpected_keys,
102 | incorrect_shapes=incorrect_shapes,
103 | )
104 |
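105 | # Usage sketch (illustrative; `model` and the checkpoint path are placeholders): load a state dict
106 | # tolerantly and report what did not match instead of failing on missing keys or shape mismatches.
107 | #
108 | #   >>> state_dict = torch.load("checkpoint.pth", map_location="cpu")
109 | #   >>> result = non_strict_load_model(model, state_dict)
110 | #   >>> log.warning(f"missing: {result.missing_keys}")
111 | #   >>> log.warning(f"unexpected: {result.unexpected_keys}")
112 | #   >>> log.warning(f"shape mismatches: {result.incorrect_shapes}")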
--------------------------------------------------------------------------------
/rcm/networks/wan2pt1_jvp_test.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
3 | # All rights reserved.
4 | #
5 | # This codebase constitutes NVIDIA proprietary technology and is strictly
6 | # confidential. Any unauthorized reproduction, distribution, or disclosure
7 | # of this code, in whole or in part, outside NVIDIA is strictly prohibited
8 | # without prior written consent.
9 | #
10 | # For inquiries regarding the use of this code in other NVIDIA proprietary
11 | # projects, please contact the Deep Imagination Research Team at
12 | # dir@exchange.nvidia.com.
13 | # -----------------------------------------------------------------------------
14 |
15 | import functools
16 |
17 | import pytest
18 | import torch
19 | from einops import repeat
20 |
21 | from imaginaire.lazy_config import LazyCall as L
22 | from imaginaire.lazy_config import instantiate
23 | from rcm.networks.wan2pt1 import WanModel
24 | from rcm.networks.wan2pt1_jvp import CheckpointMode, SACConfig, WanModel_JVP
25 |
26 | """
27 | Usage:
28 | pytest -s rcm/networks/wan2pt1_jvp_test.py
29 | """
30 |
31 | N_HEADS = 8
32 | HEAD_DIM = 32
33 | mini_net = L(WanModel)(
34 | model_type="t2v",
35 | patch_size=(1, 2, 2),
36 | text_len=512,
37 | in_dim=16,
38 | dim=2048,
39 | ffn_dim=8192,
40 | freq_dim=256,
41 | text_dim=1024,
42 | num_layers=4,
43 | )
44 | mini_net_jvp = L(WanModel_JVP)(
45 | model_type="t2v",
46 | patch_size=(1, 2, 2),
47 | text_len=512,
48 | in_dim=16,
49 | dim=2048,
50 | ffn_dim=8192,
51 | freq_dim=256,
52 | text_dim=1024,
53 | num_layers=4,
54 | )
55 | mini_net_jvp_naive = L(WanModel_JVP)(
56 | model_type="t2v",
57 | patch_size=(1, 2, 2),
58 | text_len=512,
59 | in_dim=16,
60 | dim=2048,
61 | ffn_dim=8192,
62 | freq_dim=256,
63 | text_dim=1024,
64 | num_layers=4,
65 | naive_attn=True,
66 | )
67 |
68 |
69 | @pytest.mark.L1
70 | def test_equivalent_forward_raw_vs_jvp():
71 | dtype = torch.float16
72 | net = instantiate(mini_net).cuda().to(dtype=dtype)
73 | net.eval()
74 | net_jvp = instantiate(mini_net_jvp).cuda().to(dtype=dtype)
75 | net_jvp.eval()
76 |
77 | net_jvp.load_state_dict(net.state_dict(), strict=False)
78 |
79 | batch_size = 2
80 | t = 8
81 | x_B_C_T_H_W = torch.randn(batch_size, 16, t, 40, 40).cuda().to(dtype=dtype)
82 | noise_labels_B = torch.randn(batch_size).cuda().to(dtype=dtype)
83 | noise_labels_BT = repeat(noise_labels_B, "b -> b 1")
84 | crossattn_emb_B_T_D = torch.randn(batch_size, 512, 1024).cuda().to(dtype=dtype)
85 | padding_mask_B_T_H_W = torch.zeros(batch_size, 1, 40, 40).cuda().to(dtype=dtype)
86 |
87 | output_BT = net(x_B_C_T_H_W, noise_labels_BT, crossattn_emb_B_T_D, padding_mask=padding_mask_B_T_H_W)
88 | output_BT_jvp = net_jvp(x_B_C_T_H_W, noise_labels_BT, crossattn_emb_B_T_D, padding_mask=padding_mask_B_T_H_W)
89 | torch.testing.assert_close(output_BT, output_BT_jvp, rtol=1e-3, atol=1e-3)
90 |
91 |
92 | """
93 | Usage:
94 |     pytest -s rcm/networks/wan2pt1_jvp_test.py --all -k test_equivalent_jvp_naive_vs_flash
95 | """
96 |
97 |
98 | @pytest.mark.L1
99 | def test_equivalent_jvp_naive_vs_flash():
100 | dtype = torch.float16
101 | net_jvp = instantiate(mini_net_jvp, sac_config=SACConfig(mode=CheckpointMode.NONE)).cuda().to(dtype=dtype)
102 | net_jvp.eval()
103 | net_jvp_naive = instantiate(mini_net_jvp_naive, sac_config=SACConfig(mode=CheckpointMode.NONE)).cuda().to(dtype=dtype)
104 | net_jvp_naive.eval()
105 | net_jvp_naive.load_state_dict(net_jvp.state_dict(), strict=False)
106 |
107 | batch_size = 2
108 | t = 8
109 | x_B_C_T_H_W = torch.randn(batch_size, 16, t, 40, 40).cuda().to(dtype=dtype)
110 | t_x_B_C_T_H_W = torch.randn_like(x_B_C_T_H_W)
111 | noise_labels_B = torch.randn(batch_size).cuda().to(dtype=dtype)
112 | t_noise_labels_B = torch.randn_like(noise_labels_B)
113 | noise_labels_BT = repeat(noise_labels_B, "b -> b 1")
114 | t_noise_labels_BT = repeat(t_noise_labels_B, "b -> b 1")
115 | crossattn_emb_B_T_D = torch.randn(batch_size, 512, 1024).cuda().to(dtype=dtype)
116 | fps_B = torch.randint(size=(1,), low=2, high=30).cuda().float().repeat(batch_size)
117 | padding_mask_B_T_H_W = torch.zeros(batch_size, 1, 40, 40).cuda().to(dtype=dtype)
118 |
119 | output_BT_withoutT = net_jvp(x_B_C_T_H_W, noise_labels_BT, crossattn_emb_B_T_D, fps=fps_B, padding_mask=padding_mask_B_T_H_W)
120 | output_BT, t_output_BT = net_jvp(
121 | (x_B_C_T_H_W, t_x_B_C_T_H_W),
122 | (noise_labels_BT, t_noise_labels_BT),
123 | crossattn_emb_B_T_D,
124 | fps=fps_B,
125 | padding_mask=padding_mask_B_T_H_W,
126 | withT=True,
127 | )
128 |
129 | fn_naive = functools.partial(net_jvp_naive.forward, crossattn_emb=crossattn_emb_B_T_D, fps=fps_B, padding_mask=padding_mask_B_T_H_W)
130 |
131 | output_BT_naive, t_output_BT_naive = torch.func.jvp(fn_naive, (x_B_C_T_H_W, noise_labels_BT), (t_x_B_C_T_H_W, t_noise_labels_BT))
132 |
133 | torch.testing.assert_close(output_BT_withoutT, output_BT_naive, rtol=1e-3, atol=1e-3)
134 | torch.testing.assert_close(output_BT_withoutT, output_BT, rtol=1e-3, atol=1e-3)
135 | torch.testing.assert_close(t_output_BT, t_output_BT_naive, rtol=1e-3, atol=1e-3)
136 |
--------------------------------------------------------------------------------
/imaginaire/model.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from typing import Any
17 |
18 | import torch
19 |
20 | from imaginaire.lazy_config import LazyDict, instantiate
21 |
22 |
23 | class ImaginaireModel(torch.nn.Module):
24 | """The base model class of Imaginaire. It is inherited from torch.nn.Module.
25 |     """The base model class of Imaginaire. It inherits from torch.nn.Module.
26 |     All models in Imaginaire should inherit ImaginaireModel. It should include the implementations for all the
27 | computation graphs. All inheriting child classes should implement the following methods:
28 | - training_step(): The training step of the model, including the loss computation.
29 | - validation_step(): The validation step of the model, including the loss computation.
30 | - forward(): The computation graph for model inference.
31 | The following methods have default implementations in ImaginaireModel:
32 | - init_optimizer_scheduler(): Creates the optimizer and scheduler for the model.
33 | """
34 |
35 | def __init__(self) -> None:
36 | super().__init__()
37 |
38 | def init_optimizer_scheduler(
39 | self,
40 | optimizer_config: LazyDict[torch.optim.Optimizer],
41 | scheduler_config: LazyDict[torch.optim.lr_scheduler.LRScheduler],
42 | ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
43 | """Creates the optimizer and scheduler for the model.
44 |
45 |         Args:
46 |             optimizer_config (LazyDict): The lazy config used to instantiate the optimizer.
47 |             scheduler_config (LazyDict): The lazy config used to instantiate the LR scheduler.
48 | Returns:
49 | optimizer (torch.optim.Optimizer): The model optimizer.
50 | scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
51 | """
52 | optimizer_config.params = self.parameters()
53 | optimizer = instantiate(optimizer_config)
54 | scheduler_config.optimizer = optimizer
55 | scheduler = instantiate(scheduler_config)
56 | return optimizer, scheduler
57 |
58 | def training_step(
59 | self, data_batch: dict[str, torch.Tensor], iteration: int
60 | ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
61 | """The training step of the model, including the loss computation.
62 |
63 | Args:
64 | data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
65 | iteration (int): Current iteration number.
66 |
67 | Returns:
68 | output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch.
69 | loss (torch.Tensor): The total loss for backprop (weighted sum of various losses).
70 | """
71 | raise NotImplementedError
72 |
73 | @torch.no_grad()
74 | def validation_step(
75 | self, data_batch: dict[str, torch.Tensor], iteration: int
76 | ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
77 | """The validation step of the model, including the loss computation.
78 |
79 | Args:
80 | data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
81 | iteration (int): Current iteration number.
82 |
83 | Returns:
84 | output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch.
85 | loss (torch.Tensor): The total loss (weighted sum of various losses).
86 | """
87 | raise NotImplementedError
88 |
89 | @torch.inference_mode()
90 | def forward(self, *args: Any, **kwargs: Any) -> Any:
91 | """The computation graph for model inference.
92 |
93 | Args:
94 | *args: Whatever you decide to pass into the forward method.
95 | **kwargs: Keyword arguments are also possible.
96 |
97 | Return:
98 | Your model's output.
99 | """
100 | raise NotImplementedError
101 |
102 | def on_model_init_start(self, set_barrier=False) -> None:
103 | return
104 |
105 | def on_model_init_end(self, set_barrier=False) -> None:
106 | return
107 |
108 | def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
109 |         """Model preparation hook, called before training is launched.
110 |
111 | Args:
112 | memory_format (torch.memory_format): Memory format of the model.
113 | """
114 | pass
115 |
116 | def on_before_zero_grad(
117 | self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
118 | ) -> None:
119 | """Hook before zero_grad() is called.
120 |
121 | Args:
122 | optimizer (torch.optim.Optimizer): The model optimizer.
123 | scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
124 | iteration (int): Current iteration number.
125 | """
126 | pass
127 |
128 | def on_after_backward(self, iteration: int = 0) -> None:
129 | """Hook after loss.backward() is called.
130 |
131 | This method is called immediately after the backward pass, allowing for custom operations
132 | or modifications to be performed on the gradients before the optimizer step.
133 |
134 | Args:
135 | iteration (int): Current iteration number.
136 | """
137 | pass
138 |
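139 | # Minimal subclass sketch (illustrative only; `ToyModel` and its loss are placeholders, not part of
140 | # the library): concrete models implement training_step/validation_step/forward on top of this base.
141 | #
142 | #   >>> class ToyModel(ImaginaireModel):
143 | #   ...     def __init__(self):
144 | #   ...         super().__init__()
145 | #   ...         self.linear = torch.nn.Linear(8, 1)
146 | #   ...
147 | #   ...     def training_step(self, data_batch, iteration):
148 | #   ...         loss = self.linear(data_batch["x"]).pow(2).mean()
149 | #   ...         return {"loss_mse": loss.detach()}, loss
150 | #   ...
151 | #   ...     def forward(self, x):
152 | #   ...         return self.linear(x)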
--------------------------------------------------------------------------------
/imaginaire/utils/profiling.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import contextlib
17 | import os
18 | import time
19 |
20 | import torch
21 |
22 | from imaginaire.utils import distributed, log
23 | from imaginaire.utils.easy_io import easy_io
24 |
25 | # the number of warmup steps before the active step in each profiling cycle
26 | TORCH_TRACE_WARMUP = 3
27 |
28 | # how much memory allocation/free ops to record in memory snapshots
29 | MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
30 |
31 |
32 | @contextlib.contextmanager
33 | def maybe_enable_profiling(config, *, global_step: int = 0):
34 | # get user defined profiler settings
35 | enable_profiling = config.trainer.profiling.enable_profiling
36 | profile_freq = config.trainer.profiling.profile_freq
37 |
38 | if enable_profiling:
39 | trace_dir = os.path.join(config.job.path_local, "torch_trace")
40 | if distributed.get_rank() == 0:
41 | os.makedirs(trace_dir, exist_ok=True)
42 |
43 | rank = distributed.get_rank()
44 |
45 | def trace_handler(prof):
46 | curr_trace_dir_name = "iteration_" + str(prof.step_num)
47 | curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
48 | if not os.path.exists(curr_trace_dir):
49 | os.makedirs(curr_trace_dir, exist_ok=True)
50 |
51 | log.info(f"Dumping traces at step {prof.step_num}")
52 | begin = time.monotonic()
53 | if config.trainer.profiling.first_n_rank < 0 or rank < config.trainer.profiling.first_n_rank:
54 | prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json.gz") # saved as gz to save space
55 | log.info(f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds")
56 |
57 | log.info(f"Profiling active. Traces will be saved at {trace_dir}")
58 |
59 | if not os.path.exists(trace_dir):
60 | os.makedirs(trace_dir, exist_ok=True)
61 |
62 | warmup, active = TORCH_TRACE_WARMUP, 1
63 | wait = profile_freq - (active + warmup)
64 | assert wait >= 0, "profile_freq must be greater than or equal to warmup + active"
65 |
66 | with torch.profiler.profile(
67 | activities=[
68 | torch.profiler.ProfilerActivity.CPU,
69 | torch.profiler.ProfilerActivity.CUDA,
70 | ],
71 | schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
72 | on_trace_ready=trace_handler,
73 | record_shapes=config.trainer.profiling.record_shape,
74 | profile_memory=config.trainer.profiling.profile_memory,
75 | with_stack=config.trainer.profiling.with_stack,
76 | with_modules=config.trainer.profiling.with_modules,
77 | ) as torch_profiler:
78 | torch_profiler.step_num = global_step
79 | yield torch_profiler
80 | else:
81 |         # Profiling disabled: there is nothing to set up, so just yield None.
82 |         yield None
83 |
84 |
85 | @contextlib.contextmanager
86 | def maybe_enable_memory_snapshot(config, *, global_step: int = 0):
87 | enable_snapshot = config.trainer.profiling.enable_memory_snapshot
88 | if enable_snapshot:
89 | snapshot_dir = os.path.join(config.job.path_local, "memory_snapshot")
90 | if distributed.get_rank() == 0:
91 | os.makedirs(snapshot_dir, exist_ok=True)
92 |
93 | rank = torch.distributed.get_rank()
94 |
95 | class MemoryProfiler:
96 | def __init__(self, step_num: int, freq: int):
97 | torch.cuda.memory._record_memory_history(max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES)
98 | # when resume training, we start from the last step
99 | self.step_num = step_num
100 | self.freq = freq
101 |
102 | def step(self, exit_ctx: bool = False):
103 | self.step_num += 1
104 | if not exit_ctx and self.step_num % self.freq != 0:
105 | return
106 | if not exit_ctx:
107 | curr_step = self.step_num
108 | dir_name = f"iteration_{curr_step}"
109 | else:
110 | # dump as iteration_0_exit if OOM at iter 1
111 | curr_step = self.step_num - 1
112 | dir_name = f"iteration_{curr_step}_exit"
113 | curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
114 | if not os.path.exists(curr_snapshot_dir):
115 | os.makedirs(curr_snapshot_dir, exist_ok=True)
116 | log.info(f"Dumping memory snapshot at step {curr_step}")
117 | begin = time.monotonic()
118 |
119 | if config.trainer.profiling.first_n_rank < 0 or rank < config.trainer.profiling.first_n_rank:
120 | easy_io.dump(
121 | torch.cuda.memory._snapshot(),
122 | f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle",
123 | )
124 | log.info(f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds")
125 |
126 | log.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}")
127 | profiler = MemoryProfiler(global_step, config.trainer.profiling.profile_freq)
128 | try:
129 | yield profiler
130 | except torch.cuda.OutOfMemoryError as e:
131 | profiler.step(exit_ctx=True)
132 | else:
133 | yield None
134 |
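135 | # Usage sketch (illustrative; `config`, `dataloader`, `train_step`, and `start_step` are placeholders):
136 | # both helpers yield None when the corresponding flags under config.trainer.profiling are disabled.
137 | #
138 | #   >>> with maybe_enable_profiling(config, global_step=start_step) as torch_profiler, \
139 | #   ...      maybe_enable_memory_snapshot(config, global_step=start_step) as memory_profiler:
140 | #   ...     for step, batch in enumerate(dataloader, start=start_step):
141 | #   ...         train_step(batch)
142 | #   ...         if torch_profiler is not None:
143 | #   ...             torch_profiler.step()
144 | #   ...         if memory_profiler is not None:
145 | #   ...             memory_profiler.step()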
--------------------------------------------------------------------------------
/rcm/utils/fsdp_helper.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from __future__ import annotations
17 |
18 | from contextlib import contextmanager
19 | from functools import partial
20 |
21 | import torch
22 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
23 | CheckpointImpl,
24 | apply_activation_checkpointing,
25 | checkpoint_wrapper,
26 | )
27 | from torch.distributed.device_mesh import init_device_mesh
28 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
29 | from torch.distributed.fsdp._runtime_utils import (
30 | _post_forward,
31 | _post_forward_reshard,
32 | _pre_forward,
33 | _pre_forward_unshard,
34 | _root_pre_forward,
35 | )
36 | from torch.distributed.utils import _p_assert
37 |
38 | from imaginaire.utils import distributed, log
39 |
40 |
41 | def apply_fsdp_checkpointing(model, list_block_cls):
42 |     """Apply activation checkpointing to the model in place (returns None).
43 |     The checkpoint wrapper is applied to every submodule whose class is listed in `list_block_cls`.
44 |     """
45 |     log.critical("--> applying FSDP activation checkpointing...")
46 | non_reentrant_wrapper = partial(
47 | checkpoint_wrapper,
48 | # offload_to_cpu=False,
49 | checkpoint_impl=CheckpointImpl.NO_REENTRANT,
50 | )
51 |
52 | def check_fn(submodule):
53 | result = False
54 | for block_cls in list_block_cls:
55 | if isinstance(submodule, block_cls):
56 | result = True
57 | break
58 | return result
59 |
60 | apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)
61 |
62 |
63 | @contextmanager
64 | def possible_fsdp_scope(
65 | model: torch.nn.Module,
66 | ):
67 | enabled = isinstance(model, FSDP)
68 | if enabled:
69 | assert not torch.is_grad_enabled(), "FSDP context should be entered with grad disabled"
70 | handle = model._handle
71 | args, kwargs = [0], dict(dummy=0)
72 | with torch.autograd.profiler.record_function("FullyShardedDataParallel.possible_fsdp_scope"):
73 | args, kwargs = _root_pre_forward(model, model, args, kwargs)
74 | unused = None
75 | args, kwargs = _pre_forward(
76 | model,
77 | handle,
78 | _pre_forward_unshard,
79 | model._fsdp_wrapped_module,
80 | args,
81 | kwargs,
82 | )
83 | if handle:
84 | _p_assert(
85 | handle.flat_param.device == model.compute_device,
86 | "Expected `FlatParameter` to be on the compute device " f"{model.compute_device} but got {handle.flat_param.device}",
87 | )
88 | try:
89 | yield None
90 | finally:
91 | if enabled:
92 | output = {"output": 1}
93 | _post_forward(model, handle, _post_forward_reshard, model, unused, output)
94 |
95 |
96 | def hsdp_device_mesh(replica_group_size=None, sharding_group_size=None, device=None):
97 | """
98 | Initializes a device mesh for use with Hybrid Sharding strategy in FSDP (HSDP) training.
99 |
100 |     Explicit sizes for the replica and sharding groups can be provided to control how the model is
101 |     distributed; if omitted, sensible defaults are derived from the world size.
102 |
103 | Args:
104 |         replica_group_size (int, optional): The size of each replica group. If None, it defaults to
105 |             world_size // sharding_group_size.
106 |         sharding_group_size (int, optional): The number of GPUs each model replica is sharded across
107 |             (i.e. the number of GPUs the model must fit on). If None, defaults to min(world_size, 8).
108 |         device (str, optional): The device type passed to init_device_mesh (e.g., "cuda").
109 |             If None, defaults to "cuda".
110 |
111 | Returns:
112 | A device mesh object compatible with FSDP.
113 |
114 | Raises:
115 |         ValueError: If the world size is not evenly divisible by the sharding group size, or if
116 |             world_size // sharding_group_size is not evenly divisible by the replica group size.
117 | RuntimeError: If a valid device mesh cannot be created.
118 |
119 | Usage:
120 |         If your model fits on 4 GPUs and you have 3 nodes of 8 GPUs (24 GPUs total), then:
121 |         sharding_group_size = 4
122 |         replica_group_size = 24 // 4 = 6
123 |         >>> device_mesh = hsdp_device_mesh(replica_group_size, sharding_group_size)
124 | >>> sharded_model = FSDP(model, device_mesh=device_mesh, ...)
125 | """
126 |
127 | # world_size = int(os.getenv("WORLD_SIZE", "1"))
128 | world_size = distributed.get_world_size()
129 | if sharding_group_size is None:
130 | sharding_group_size = min(world_size, 8)
131 | sharding_group_size = min(sharding_group_size, world_size)
132 | if replica_group_size is None:
133 | replica_group_size = world_size // sharding_group_size
134 |
135 | device = device or "cuda"
136 |
137 | if world_size % sharding_group_size != 0:
138 | raise ValueError(f"World size {world_size} is not evenly divisible by sharding group size {sharding_group_size}.")
139 |
140 | if (world_size // sharding_group_size) % replica_group_size != 0:
141 | raise ValueError(f"The calculated number of replica groups is not evenly divisible by " f"replica_group_size {replica_group_size}.")
142 |
143 | device_mesh = init_device_mesh(device, (replica_group_size, sharding_group_size), mesh_dim_names=("replicate", "shard"))
144 | if device_mesh is None:
145 | raise RuntimeError("Failed to create a valid device mesh.")
146 |
147 | log.critical(f"Device mesh initialized with replica group size {replica_group_size} and sharding group size {sharding_group_size}")
148 |
149 | return device_mesh
150 |
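151 | # Usage sketch (illustrative; `model` and `TransformerBlock` are placeholders): build an HSDP device
152 | # mesh, wrap the model with FSDP, then enable activation checkpointing on the transformer blocks.
153 | #
154 | #   >>> mesh = hsdp_device_mesh(replica_group_size=6, sharding_group_size=4)
155 | #   >>> sharded_model = FSDP(model, device_mesh=mesh, use_orig_params=True)
156 | #   >>> apply_fsdp_checkpointing(sharded_model, [TransformerBlock])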
--------------------------------------------------------------------------------
/rcm/utils/attention.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Description:
17 | # Single point of entry for all generic attention ops (self and cross attention), that tries to
18 | # deliver the best performance possible given any use case (GPU and environment).
19 | #
20 | # On Hopper GPUs (i.e. H100, H20, H200), Flash Attention 3 is the best-performing choice, but it
21 | # needs to be installed. When it is not available, the second best choice is cuDNN attention, which
22 | # we get using PyTorch's SDPA API.
23 | #
24 | # For all other use cases, we will just use PyTorch's SDPA, but we need to specify backends and
25 | # priorities.
26 | # Flash Attention 2, which is one of the backends, is the best choice for Ampere GPUs (both RTX and
27 | # datacenter-class).
28 | #
29 | # For anything pre-Ampere, the only choice is "memory-efficient" (xformers) FMHA.
30 | #
31 | # For Ada and Blackwell RTX, it is unclear at the moment, so we defer to Flash Attention 2, and
32 | # fallbacks are cuDNN and xformers.
33 | #
34 | # For Blackwell datacenter-class (B200, GB200), cuDNN is the best choice.
35 | #
36 | #
37 | # Dispatching to the desired backends/paths are done by checking the compute capability (really SM
38 | # number, which is just compute capability * 10) of the GPU device the input tensors are on.
39 | #
40 | # Here's a breakdown of relevant compute capabilities:
41 | #
42 | # | GPU / category | Arch |
43 | # |================|=======|
44 | # | A100 | SM80 |
45 | # | A40 | SM80 |
46 | # | Ampere RTX | SM86 |
47 | # |----------------|-------|
48 | # | Ada Lovelace | SM89 |
49 | # |----------------|-------|
50 | # | H20 | SM90 |
51 | # | H100 | SM90 |
52 | # | H200 | SM90 |
53 | # |----------------|-------|
54 | # | B200 | SM100 |
55 | # | Blackwell RTX  | SM120 |
56 | # |----------------|-------|
57 | #
58 |
59 | from functools import partial
60 |
61 | import torch
62 | from torch.nn.attention import SDPBackend, sdpa_kernel
63 |
64 | try:
65 | from flash_attn_3.flash_attn_interface import flash_attn_func
66 |
67 | FLASH_ATTN_3_AVAILABLE = True
68 | except ModuleNotFoundError:
69 | FLASH_ATTN_3_AVAILABLE = False
70 |
71 |
72 | def get_device_cc(device) -> int:
73 | """
74 | Returns the compute capability of a given torch device if it's a CUDA device, otherwise returns 0.
75 |
76 | Args:
77 | device: torch device.
78 |
79 | Returns:
80 |         device_cc (int): compute capability in SM-number format (e.g., 90 for Hopper).
81 | """
82 | if torch.cuda.is_available() and torch.version.cuda and device.type == "cuda":
83 | major, minor = torch.cuda.get_device_capability(device)
84 | return major * 10 + minor
85 | return 0
86 |
87 |
88 | def attention(
89 | q,
90 | k,
91 | v,
92 | dropout_p=0.0,
93 | softmax_scale=None,
94 | q_scale=None,
95 | causal=False,
96 | deterministic=False,
97 | ):
98 | assert q.dtype == k.dtype and k.dtype == v.dtype
99 | dtype = q.dtype
100 | supported_dtypes = [torch.bfloat16, torch.float16, torch.float32]
101 | is_half = dtype in [torch.bfloat16, torch.float16]
102 | compute_cap = get_device_cc(q.device)
103 |
104 | if dtype not in supported_dtypes:
105 | raise NotImplementedError(f"{dtype=} is not supported.")
106 |
107 | if q_scale is not None:
108 | q = q * q_scale
109 |
110 | # If Flash Attention 3 is installed, and the user's running on a Hopper GPU (compute capability
111 | # 9.0, or SM90), use Flash Attention 3.
112 | if compute_cap == 90 and FLASH_ATTN_3_AVAILABLE and is_half:
113 | return flash_attn_func(
114 | q=q,
115 | k=k,
116 | v=v,
117 | softmax_scale=softmax_scale,
118 | causal=causal,
119 | deterministic=deterministic,
120 | )[0]
121 | else:
122 | # If Blackwell or Hopper (SM100 or SM90), cuDNN has native FMHA kernels. The Hopper one is
123 | # not always as fast as Flash Attention 3, but when Flash Attention is unavailable, it's
124 | # still a far better choice than Flash Attention 2 (Ampere).
125 | if compute_cap in [90, 100] and is_half:
126 | SDPA_BACKENDS = [
127 | SDPBackend.CUDNN_ATTENTION,
128 | SDPBackend.FLASH_ATTENTION,
129 | SDPBackend.EFFICIENT_ATTENTION,
130 | ]
131 | BEST_SDPA_BACKEND = SDPBackend.CUDNN_ATTENTION
132 | elif is_half:
133 | SDPA_BACKENDS = [
134 | SDPBackend.FLASH_ATTENTION,
135 | SDPBackend.CUDNN_ATTENTION,
136 | SDPBackend.EFFICIENT_ATTENTION,
137 | ]
138 | BEST_SDPA_BACKEND = SDPBackend.FLASH_ATTENTION if compute_cap >= 80 else SDPBackend.EFFICIENT_ATTENTION
139 | else:
140 | assert dtype == torch.float32, f"Unrecognized {dtype=}."
141 | SDPA_BACKENDS = [SDPBackend.EFFICIENT_ATTENTION]
142 | BEST_SDPA_BACKEND = SDPBackend.EFFICIENT_ATTENTION
143 |
144 | if deterministic:
145 | raise NotImplementedError("Deterministic mode in attention is only supported when Flash Attention 3 is available.")
146 |
147 | # Torch 2.6 and later allows priorities for backends, but for older versions
148 | # we can only run with a specific backend. As long as we pick ones we're certain
149 | # will work on that device, it should be fine.
150 | try:
151 | sdpa_kernel(backends=SDPA_BACKENDS, set_priority_order=True)
152 | sdpa_kernel_ = partial(sdpa_kernel, set_priority_order=True)
153 | except TypeError:
154 | sdpa_kernel_ = sdpa_kernel
155 | SDPA_BACKENDS = [BEST_SDPA_BACKEND]
156 |
157 | q = q.transpose(1, 2)
158 | k = k.transpose(1, 2)
159 | v = v.transpose(1, 2)
160 |
161 | with sdpa_kernel_(backends=SDPA_BACKENDS):
162 | out = torch.nn.functional.scaled_dot_product_attention(
163 | q,
164 | k,
165 | v,
166 | is_causal=causal,
167 | dropout_p=dropout_p,
168 | scale=softmax_scale,
169 | )
170 |
171 | out = out.transpose(1, 2).contiguous()
172 | return out
173 |
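174 | # Usage sketch (illustrative): inputs follow the flash-attention layout (batch, seq_len, num_heads,
175 | # head_dim); the function transposes internally when it falls back to PyTorch SDPA, so the output
176 | # comes back in the same layout.
177 | #
178 | #   >>> q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
179 | #   >>> k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
180 | #   >>> v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
181 | #   >>> attention(q, k, v, softmax_scale=64 ** -0.5).shape
182 | #   torch.Size([2, 1024, 8, 64])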
--------------------------------------------------------------------------------
/imaginaire/utils/easy_io/handlers/imageio_video_handler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from typing import IO, Any
17 |
18 | import imageio
19 | import imageio.v3 as iio_v3
20 | import numpy as np
21 | import torch
22 |
23 | from imaginaire.utils import log
24 | from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
25 |
26 |
27 | class ImageioVideoHandler(BaseFileHandler):
28 | str_like = False
29 |
30 | def load_from_fileobj(
31 | self, file: IO[bytes], format: str = "mp4", mode: str = "rgb", **kwargs
32 | ) -> tuple[np.ndarray, dict[str, Any]]:
33 | """
34 | Load video from a file-like object using imageio.v3 with specified format and color mode.
35 |
36 | Parameters:
37 | file (IO[bytes]): A file-like object containing video data.
38 | format (str): Format of the video file (default 'mp4').
39 | mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb').
40 |
41 | Returns:
42 | tuple: A tuple containing an array of video frames and metadata about the video.
43 | """
44 | file.seek(0)
45 |
46 | # The plugin argument in v3 replaces the format argument in v2
47 | plugin = kwargs.pop("plugin", "pyav")
48 |
49 | # Load all frames at once using v3 API
50 | video_frames = iio_v3.imread(file, plugin=plugin, **kwargs)
51 |
52 | # Handle grayscale conversion if needed
53 | if mode == "gray":
54 | import cv2
55 |
56 | if len(video_frames.shape) == 4: # (frames, height, width, channels)
57 | gray_frames = []
58 | for frame in video_frames:
59 | gray_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
60 | gray_frame = np.expand_dims(gray_frame, axis=2) # Keep dimensions consistent
61 | gray_frames.append(gray_frame)
62 | video_frames = np.array(gray_frames)
63 |
64 | # Extract metadata
65 | # Note: iio_v3.imread doesn't return metadata directly like v2 did
66 | # We need to extract it separately
67 | file.seek(0)
68 | metadata = self._extract_metadata(file, plugin=plugin)
69 |
70 | return video_frames, metadata
71 |
72 | def _extract_metadata(self, file: IO[bytes], plugin: str = "pyav") -> dict[str, Any]:
73 | """
74 | Extract metadata from a video file.
75 |
76 | Parameters:
77 | file (IO[bytes]): File-like object containing video data.
78 | plugin (str): Plugin to use for reading.
79 |
80 | Returns:
81 | dict: Video metadata.
82 | """
83 | try:
84 | # Create a generator to read frames and metadata
85 | metadata = iio_v3.immeta(file, plugin=plugin)
86 |
87 | # Add some standard fields similar to v2 metadata format
88 | if "fps" not in metadata and "duration" in metadata:
89 | # Read the first frame to get shape information
90 | file.seek(0)
91 | first_frame = iio_v3.imread(file, plugin=plugin, index=0)
92 | metadata["size"] = first_frame.shape[1::-1] # (width, height)
93 | metadata["source_size"] = metadata["size"]
94 |
95 | # Create a consistent metadata structure with v2
96 | metadata["plugin"] = plugin
97 | if "codec" not in metadata:
98 | metadata["codec"] = "unknown"
99 | if "pix_fmt" not in metadata:
100 | metadata["pix_fmt"] = "unknown"
101 |
102 | # Calculate nframes if possible
103 | if "fps" in metadata and "duration" in metadata:
104 | metadata["nframes"] = int(metadata["fps"] * metadata["duration"])
105 | else:
106 | metadata["nframes"] = float("inf")
107 |
108 | return metadata
109 |
110 | except Exception as e:
111 | # Fallback to basic metadata
112 | return {
113 | "plugin": plugin,
114 | "nframes": float("inf"),
115 | "codec": "unknown",
116 | "fps": 30.0, # Default values
117 | "duration": 0,
118 | "size": (0, 0),
119 | }
120 |
121 | def dump_to_fileobj(
122 | self,
123 | obj: np.ndarray | torch.Tensor,
124 | file: IO[bytes],
125 | format: str = "mp4", # pylint: disable=redefined-builtin
126 | fps: int = 17,
127 | quality: int = 7,
128 | ffmpeg_params=None,
129 | **kwargs,
130 | ):
131 | """
132 | Save an array of video frames to a file-like object using imageio.
133 |
134 | Parameters:
135 | obj (Union[np.ndarray, torch.Tensor]): An array of frames to be saved as video.
136 | file (IO[bytes]): A file-like object to which the video data will be written.
137 | format (str): Format of the video file (default 'mp4').
138 | fps (int): Frames per second of the output video (default 17).
139 |             quality (int): Quality of the video (0-10, default 7).
140 | ffmpeg_params (list): Additional parameters to pass to ffmpeg.
141 |
142 | """
143 | if isinstance(obj, torch.Tensor):
144 | assert obj.dtype == torch.uint8, "Tensor must be of type uint8"
145 | obj = obj.cpu().numpy()
146 | h, w = obj.shape[1:-1]
147 |
148 | # Default ffmpeg params that ensure width and height are set
149 | default_ffmpeg_params = ["-s", f"{w}x{h}"]
150 |
151 | # Use provided ffmpeg_params if any, otherwise use defaults
152 | final_ffmpeg_params = ffmpeg_params if ffmpeg_params is not None else default_ffmpeg_params
153 |
154 | mimsave_kwargs = {
155 | "fps": fps,
156 | "quality": quality,
157 | "macro_block_size": 1,
158 | "ffmpeg_params": final_ffmpeg_params,
159 | "output_params": ["-f", "mp4"],
160 | }
161 | # Update with any other kwargs
162 | mimsave_kwargs.update(kwargs)
163 | log.debug(f"mimsave_kwargs: {mimsave_kwargs}")
164 |
165 | imageio.mimsave(file, obj, format, **mimsave_kwargs)
166 |
167 | def dump_to_str(self, obj, **kwargs):
168 | raise NotImplementedError
169 |
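170 | # Usage sketch (illustrative; the array and path are placeholders, and routing of the "mp4" file
171 | # format to this handler is assumed from how imaginaire.utils.io calls easy_io.dump):
172 | #
173 | #   >>> frames = np.random.randint(0, 255, size=(48, 256, 256, 3), dtype=np.uint8)  # (T, H, W, C)
174 | #   >>> easy_io.dump(frames, "clip.mp4", file_format="mp4", format="mp4", fps=24)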
--------------------------------------------------------------------------------
/imaginaire/utils/io.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | from typing import IO, Any
18 |
19 | import numpy as np
20 | import torch
21 | from einops import rearrange
22 | from PIL import Image as PILImage
23 | from torch import Tensor
24 |
25 | from imaginaire.utils import log
26 | from imaginaire.utils.easy_io import easy_io
27 |
28 |
29 | def save_image_or_video_multiview(
30 | tensor: Tensor, save_path: str | IO[Any], fps: int = 24, quality=None, ffmpeg_params=None, n_views: int = 1
31 | ) -> None:
32 | """
33 |     Split the tensor into n_views along the time axis, stack the views along the height dimension, and save as a video.
34 | Args:
35 | tensor (Tensor): Input tensor with shape (B, C, T, H, W) or (C, T, H, W) in [-1, 1] or [0, 1] range.
36 | If in [-1, 1] range, it will be automatically converted to [0, 1] range.
37 | save_path (Union[str, IO[Any]]): File path (with or without extension) or file-like object.
38 | fps (int): Frames per second for video. Default is 24.
39 | quality: Optional quality parameter for images (passed to easy_io).
40 | ffmpeg_params: Optional ffmpeg parameters for videos (passed to easy_io).
41 | """
42 | # Handle batch dimension if present
43 | if tensor.ndim == 5:
44 | tensor = tensor[0] # Take the first item from the batch
45 |
46 | assert tensor.ndim == 4, "Tensor must have shape (C, T, H, W) or (B, C, T, H, W)"
47 | assert isinstance(save_path, str) or hasattr(save_path, "write"), "save_path must be a string or file-like object"
48 |
49 | # Normalize to [0, 1] range
50 | if torch.is_floating_point(tensor):
51 | # Check if tensor is in [-1, 1] range (approximately)
52 | if tensor.min() < -0.5:
53 | # Convert from [-1, 1] to [0, 1]
54 | tensor = (tensor + 1.0) / 2.0
55 | tensor = tensor.clamp(0, 1)
56 | else:
57 | assert tensor.dtype == torch.uint8, "Only support uint8 tensor"
58 | tensor = tensor.float().div(255)
59 |
60 | kwargs = {}
61 | if quality is not None:
62 | kwargs["quality"] = quality
63 | if ffmpeg_params is not None:
64 | kwargs["ffmpeg_params"] = ffmpeg_params
65 |
66 | save_obj = (rearrange((tensor.cpu().float().numpy() * 255), "c (v t) h w -> t (v h) w c", v=n_views) + 0.5).astype(
67 | np.uint8
68 | )
69 | if isinstance(save_path, str):
70 | # Check if path already has an extension
71 | base, ext = os.path.splitext(save_path)
72 | if not ext:
73 | save_path = f"{base}.mp4"
74 | log.info(f"Saving video to {save_path} with fps {fps} and result shape {save_obj.shape}")
75 | easy_io.dump(save_obj, save_path, file_format="mp4", format="mp4", fps=fps, **kwargs)
76 |
77 |
78 | def save_image_or_video(
79 | tensor: Tensor, save_path: str | IO[Any], fps: int = 24, quality=None, ffmpeg_params=None
80 | ) -> None:
81 | """
82 | Save a tensor as an image or video file based on shape
83 |
84 | Args:
85 | tensor (Tensor): Input tensor with shape (B, C, T, H, W) or (C, T, H, W) in [-1, 1] or [0, 1] range.
86 | If in [-1, 1] range, it will be automatically converted to [0, 1] range.
87 | save_path (Union[str, IO[Any]]): File path (with or without extension) or file-like object.
88 | fps (int): Frames per second for video. Default is 24.
89 | quality: Optional quality parameter for images (passed to easy_io).
90 | ffmpeg_params: Optional ffmpeg parameters for videos (passed to easy_io).
91 | """
92 | # Handle batch dimension if present
93 | if tensor.ndim == 5:
94 | tensor = tensor[0] # Take the first item from the batch
95 |
96 | assert tensor.ndim == 4, "Tensor must have shape (C, T, H, W) or (B, C, T, H, W)"
97 | assert isinstance(save_path, str) or hasattr(save_path, "write"), "save_path must be a string or file-like object"
98 |
99 | # Normalize to [0, 1] range
100 | if torch.is_floating_point(tensor):
101 | # Check if tensor is in [-1, 1] range (approximately)
102 | if tensor.min() < -0.5:
103 | # Convert from [-1, 1] to [0, 1]
104 | tensor = (tensor + 1.0) / 2.0
105 | tensor = tensor.clamp(0, 1)
106 | else:
107 | assert tensor.dtype == torch.uint8, "Only support uint8 tensor"
108 | tensor = tensor.float().div(255)
109 |
110 | kwargs = {}
111 | if quality is not None:
112 | kwargs["quality"] = quality
113 | if ffmpeg_params is not None:
114 | kwargs["ffmpeg_params"] = ffmpeg_params
115 |
116 | if tensor.shape[1] == 1:
117 | save_obj = PILImage.fromarray(
118 | (rearrange((tensor.cpu().float().numpy() * 255), "c 1 h w -> h w c") + 0.5).astype(np.uint8),
119 | mode="RGB",
120 | )
121 | if isinstance(save_path, str):
122 | # Check if path already has an extension
123 | base, ext = os.path.splitext(save_path)
124 | if not ext:
125 | save_path = f"{base}.jpg"
126 |         easy_io.dump(save_obj, save_path, file_format="jpg", format="JPEG", quality=kwargs.pop("quality", 85), **kwargs)
127 | else:
128 | save_obj = (rearrange((tensor.cpu().float().numpy() * 255), "c t h w -> t h w c") + 0.5).astype(np.uint8)
129 | if isinstance(save_path, str):
130 | # Check if path already has an extension
131 | base, ext = os.path.splitext(save_path)
132 | if not ext:
133 | save_path = f"{base}.mp4"
134 | easy_io.dump(save_obj, save_path, file_format="mp4", format="mp4", fps=fps, **kwargs)
135 |
136 |
137 | def save_text_prompts(prompts: dict[str, str | list[str]], save_path: str | IO[Any]) -> None:
138 | """
139 | Save text prompts to a file.
140 |
141 | Args:
142 |         prompts (dict[str, str | list[str]]): Dictionary of text prompts to save. Expected keys: "prompt", "negative_prompt", "refined_prompt".
143 | save_path (Union[str, IO[Any]]): File path (with or without extension) or file-like object.
144 | """
145 | if isinstance(save_path, str):
146 | base, ext = os.path.splitext(save_path)
147 | if not ext:
148 | save_path = f"{base}.txt"
149 | with open(save_path, "w") as f:
150 | f.write(f"[Prompt]\n{prompts['prompt']}\n")
151 | if prompts.get("negative_prompt"):
152 | f.write(f"[Negative Prompt]\n{prompts['negative_prompt']}\n")
153 |
154 | if prompts.get("refined_prompt"):
155 | if isinstance(prompts["refined_prompt"], str):
156 | f.write(f"[Refined Prompt]\n{prompts['refined_prompt']}\n")
157 | elif isinstance(prompts["refined_prompt"], list):
158 | for chunk_id, refined_prompt in enumerate(prompts["refined_prompt"]):
159 | f.write(f"[Refined Prompt for chunk {chunk_id}]\n{refined_prompt}\n")
160 | else:
161 | raise ValueError("refined_prompt must be a string or a list of strings")
162 |
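163 | # Usage sketch (illustrative; tensors and paths are placeholders): a (C, T, H, W) tensor with T > 1
164 | # is written as an .mp4, while T == 1 is written as a .jpg; extensions are appended when missing.
165 | #
166 | #   >>> video = torch.rand(3, 48, 256, 256) * 2 - 1   # RGB video in [-1, 1]
167 | #   >>> save_image_or_video(video, "outputs/sample", fps=24)   # -> outputs/sample.mp4
168 | #   >>> frame = torch.rand(3, 1, 256, 256)                     # single RGB frame in [0, 1]
169 | #   >>> save_image_or_video(frame, "outputs/frame")            # -> outputs/frame.jpg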
--------------------------------------------------------------------------------