├── cosmos_predict1 ├── tokenizer │ ├── __init__.py │ ├── training │ │ ├── __init__.py │ │ ├── datasets │ │ │ └── __init__.py │ │ ├── configs │ │ │ ├── base │ │ │ │ ├── __init__.py │ │ │ │ ├── model.py │ │ │ │ ├── metric.py │ │ │ │ ├── callback.py │ │ │ │ ├── optim.py │ │ │ │ ├── checkpoint.py │ │ │ │ └── data.py │ │ │ ├── experiments │ │ │ │ ├── __init__.py │ │ │ │ └── utils.py │ │ │ ├── __init__.py │ │ │ └── config.py │ │ └── losses │ │ │ └── __init__.py │ ├── test_data │ │ └── video.mp4 │ ├── inference │ │ └── __init__.py │ └── modules │ │ ├── distributions.py │ │ └── __init__.py ├── diffusion │ ├── config │ │ ├── __init__.py │ │ ├── base │ │ │ ├── __init__.py │ │ │ ├── tokenizer.py │ │ │ ├── net.py │ │ │ └── model.py │ │ ├── inference │ │ │ ├── __init__.py │ │ │ ├── cosmos-1-diffusion-text2world-multiview.py │ │ │ ├── cosmos-1-diffusion-video2world-multiview.py │ │ │ └── cosmos-1-diffusion-gen3c.py │ │ └── config.py │ ├── module │ │ └── __init__.py │ ├── networks │ │ └── __init__.py │ ├── training │ │ ├── datasets │ │ │ └── data_sources │ │ │ │ └── item_dataset.py │ │ ├── config │ │ │ ├── video2world │ │ │ │ └── registry.py │ │ │ ├── world_interpolator │ │ │ │ └── registry.py │ │ │ ├── text2world_multiview │ │ │ │ └── registry.py │ │ │ ├── video2world_multiview │ │ │ │ └── registry.py │ │ │ ├── video2world_instruction │ │ │ │ └── registry.py │ │ │ └── base │ │ │ │ ├── ema.py │ │ │ │ ├── optim.py │ │ │ │ ├── vae.py │ │ │ │ └── model.py │ │ ├── trainer.py │ │ ├── callbacks │ │ │ └── low_precision.py │ │ ├── utils │ │ │ ├── peft │ │ │ │ └── lora_config.py │ │ │ └── optim_instantiate.py │ │ └── modules │ │ │ └── edm_sde.py │ ├── modules │ │ └── denoiser_scaling.py │ ├── types.py │ ├── utils │ │ └── customization │ │ │ └── customization_manager.py │ ├── functional │ │ ├── batch_ops.py │ │ └── multi_step.py │ └── checkpointers │ │ └── ema_fsdp_checkpointer.py ├── autoregressive │ ├── tokenizer │ │ ├── __init__.py │ │ └── networks.py │ ├── configs │ │ ├── experiment │ 
│ │ └── video2video │ │ │ │ └── __init__.py │ │ ├── __init__.py │ │ └── base │ │ │ ├── __init__.py │ │ │ ├── model_parallel.py │ │ │ ├── callbacks.py │ │ │ ├── dataset.py │ │ │ └── dataloader.py │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ └── misc.py │ ├── inference │ │ └── __init__.py │ ├── modules │ │ └── __init__.py │ └── diffusion_decoder │ │ ├── __init__.py │ │ └── config │ │ ├── base │ │ └── conditioner.py │ │ └── config_latent_diffusion_decoder.py ├── auxiliary │ └── guardrail │ │ ├── common │ │ ├── __init__.py │ │ └── io_utils.py │ │ ├── __init__.py │ │ ├── aegis │ │ └── __init__.py │ │ ├── blocklist │ │ ├── __init__.py │ │ └── utils.py │ │ ├── llamaGuard3 │ │ ├── __init__.py │ │ └── categories.py │ │ ├── face_blur_filter │ │ ├── __init__.py │ │ └── blur_utils.py │ │ └── video_content_safety_filter │ │ ├── __init__.py │ │ ├── vision_encoder.py │ │ └── model.py ├── utils │ ├── easy_io │ │ ├── backends │ │ │ ├── __init__.py │ │ │ └── base_backend.py │ │ ├── __init__.py │ │ └── handlers │ │ │ ├── torch_handler.py │ │ │ ├── pandas_handler.py │ │ │ ├── torchjit_handler.py │ │ │ ├── txt_handler.py │ │ │ ├── gzip_handler.py │ │ │ ├── __init__.py │ │ │ ├── yaml_handler.py │ │ │ ├── tarfile_handler.py │ │ │ ├── csv_handler.py │ │ │ ├── pickle_handler.py │ │ │ ├── base.py │ │ │ └── json_handler.py │ ├── __init__.py │ ├── parallel_state_helper.py │ ├── lazy_config │ │ ├── file_io.py │ │ ├── __init__.py │ │ ├── registry.py │ │ └── omegaconf_patch.py │ ├── device.py │ ├── scheduler.py │ ├── callbacks │ │ └── grad_clip.py │ └── env_parsers │ │ └── cred_env_parser.py ├── __init__.py ├── checkpointer │ ├── __init__.py │ └── tp.py └── callbacks │ └── grad_clip.py ├── configs ├── inference │ ├── 3dgs_res_176_320_views_17.yaml │ ├── 3dgs_res_176_320_views_49.yaml │ ├── 3dgs_res_352_640_views_49.yaml │ ├── 3dgs_res_704_1280_views_49.yaml │ ├── 3dgs_res_704_1280_views_121.yaml │ ├── 3dgs_res_704_1280_views_121_multi_6.yaml │ ├── 
3dgs_res_704_1280_views_121_multi_6_prune.yaml │ ├── 3dgs_res_704_1280_views_121_multi_6_dynamic.yaml │ ├── 3dgs_res_704_1280_views_121_multi_6_dynamic_prune.yaml │ └── default.yaml ├── training │ ├── 3dgs_res_176_320_views_49.yaml │ ├── 3dgs_res_176_320_views_17.yaml │ ├── 3dgs_res_352_640_views_49.yaml │ ├── 3dgs_res_704_1280_views_49.yaml │ ├── 3dgs_res_704_1280_views_121.yaml │ ├── 3dgs_res_704_1280_views_121_multi_6.yaml │ ├── 3dgs_res_704_1280_views_121_multi_6_prune.yaml │ ├── 3dgs_res_704_1280_views_121_multi_6_dynamic.yaml │ └── 3dgs_res_704_1280_views_121_multi_6_dynamic_prune.yaml ├── accelerate │ ├── accelerate_config.yaml │ └── accelerate_config_single.yaml └── demo │ ├── lyra_static.yaml │ └── lyra_dynamic.yaml ├── lyra.yaml ├── requirements_lyra.txt ├── scripts ├── bash │ ├── static_sdg.sh │ └── dynamic_sdg.sh ├── test_environment.py └── download_guardrail_checkpoints.py ├── src ├── __init__.py ├── models │ ├── __init__.py │ ├── recon │ │ └── __init__.py │ ├── utils │ │ ├── __init__.py │ │ ├── train.py │ │ └── misc.py │ └── data │ │ ├── datafield.py │ │ ├── __init__.py │ │ └── radym_wrapper.py ├── rendering │ └── __init__.py └── utils │ └── random_state_utils.py ├── requirements_gen3c.txt ├── inference.sh ├── CONTRIBUTING.md └── train.sh /cosmos_predict1/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/module/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/networks/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosmos_predict1/tokenizer/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosmos_predict1/autoregressive/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/config/base/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosmos_predict1/auxiliary/guardrail/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/config/inference/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosmos_predict1/tokenizer/training/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosmos_predict1/tokenizer/training/configs/base/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosmos_predict1/tokenizer/training/configs/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosmos_predict1/tokenizer/test_data/video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/lyra/HEAD/cosmos_predict1/tokenizer/test_data/video.mp4 -------------------------------------------------------------------------------- /configs/inference/3dgs_res_176_320_views_17.yaml: -------------------------------------------------------------------------------- 1 | config_path: outputs/training/3dgs_res_176_320_views_17/config.yaml 2 | out_dir_inference: outputs/inference/3dgs_res_176_320_views_17 3 | 4 | static_view_indices_fixed: ['0'] -------------------------------------------------------------------------------- /configs/inference/3dgs_res_176_320_views_49.yaml: -------------------------------------------------------------------------------- 1 | config_path: outputs/training/3dgs_res_176_320_views_49/config.yaml 2 | out_dir_inference: outputs/inference/3dgs_res_176_320_views_49 3 | 4 | static_view_indices_fixed: ['0'] -------------------------------------------------------------------------------- /configs/inference/3dgs_res_352_640_views_49.yaml: -------------------------------------------------------------------------------- 1 | config_path: outputs/training/3dgs_res_352_640_views_49/config.yaml 2 | out_dir_inference: outputs/inference/3dgs_res_352_640_views_49 3 | 4 | static_view_indices_fixed: ['0'] -------------------------------------------------------------------------------- /configs/inference/3dgs_res_704_1280_views_49.yaml: -------------------------------------------------------------------------------- 1 | config_path: outputs/training/3dgs_res_704_1280_views_49/config.yaml 2 | out_dir_inference: 
outputs/inference/3dgs_res_704_1280_views_49 3 | 4 | static_view_indices_fixed: ['0'] -------------------------------------------------------------------------------- /configs/inference/3dgs_res_704_1280_views_121.yaml: -------------------------------------------------------------------------------- 1 | config_path: outputs/training/3dgs_res_704_1280_views_121/config.yaml 2 | out_dir_inference: outputs/inference/3dgs_res_704_1280_views_121 3 | 4 | static_view_indices_fixed: ['0'] -------------------------------------------------------------------------------- /lyra.yaml: -------------------------------------------------------------------------------- 1 | name: lyra 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.10 6 | - pip=25.0 7 | - cmake 8 | - ninja 9 | - gcc=12.4.0 10 | - gxx=12.4.0 11 | - cuda=12.4 12 | - cuda-nvcc=12.4 13 | - cuda-toolkit=12.4 14 | -------------------------------------------------------------------------------- /configs/training/3dgs_res_176_320_views_49.yaml: -------------------------------------------------------------------------------- 1 | output_dir: outputs/training/3dgs_res_176_320_views_49 2 | img_size: [176, 320] 3 | num_views: 98 4 | num_input_views: 49 5 | gs_view_chunk_size: 1 6 | num_input_multi_views: 1 7 | batch_size: 4 8 | load_latents: False 9 | max_train_steps: 12500 -------------------------------------------------------------------------------- /configs/training/3dgs_res_176_320_views_17.yaml: -------------------------------------------------------------------------------- 1 | output_dir: outputs/training/3dgs_res_176_320_views_17 2 | img_size: [176, 320] 3 | num_views: 34 4 | num_input_views: 17 5 | gs_view_chunk_size: 1 6 | num_input_multi_views: 1 7 | batch_size: 4 8 | load_latents: False 9 | max_train_steps: 10000 10 | -------------------------------------------------------------------------------- /configs/inference/3dgs_res_704_1280_views_121_multi_6.yaml: 
-------------------------------------------------------------------------------- 1 | config_path: outputs/training/3dgs_res_704_1280_views_121_multi_6/config.yaml 2 | out_dir_inference: outputs/inference/3dgs_res_704_1280_views_121_multi_6 3 | 4 | static_view_indices_fixed: ['5', '0', '1', '2', '3', '4'] 5 | target_index_subsample: 4 -------------------------------------------------------------------------------- /configs/inference/3dgs_res_704_1280_views_121_multi_6_prune.yaml: -------------------------------------------------------------------------------- 1 | config_path: outputs/training/3dgs_res_704_1280_views_121_multi_6_prune/config.yaml 2 | out_dir_inference: outputs/inference/3dgs_res_704_1280_views_121_multi_6_prune 3 | 4 | static_view_indices_fixed: ['5', '0', '1', '2', '3', '4'] 5 | target_index_subsample: 4 -------------------------------------------------------------------------------- /configs/training/3dgs_res_352_640_views_49.yaml: -------------------------------------------------------------------------------- 1 | output_dir: outputs/training/3dgs_res_352_640_views_49 2 | img_size: [352, 640] 3 | num_views: 98 4 | num_input_views: 49 5 | gs_view_chunk_size: 1 6 | num_input_multi_views: 1 7 | batch_size: 2 8 | lpips_chunk_size: 8 9 | load_latents: False 10 | max_train_steps: 15000 -------------------------------------------------------------------------------- /configs/training/3dgs_res_704_1280_views_49.yaml: -------------------------------------------------------------------------------- 1 | output_dir: outputs/training/3dgs_res_704_1280_views_49 2 | img_size: [704, 1280] 3 | num_views: 98 4 | num_input_views: 49 5 | gs_view_chunk_size: 1 6 | num_input_multi_views: 1 7 | batch_size: 1 8 | lpips_chunk_size: 1 9 | load_latents: False 10 | max_train_steps: 17500 -------------------------------------------------------------------------------- /configs/training/3dgs_res_704_1280_views_121.yaml: 
-------------------------------------------------------------------------------- 1 | output_dir: outputs/training/3dgs_res_704_1280_views_121 2 | img_size: [704, 1280] 3 | num_views: 130 4 | num_input_views: 121 5 | gs_view_chunk_size: 1 6 | num_input_multi_views: 1 7 | batch_size: 1 8 | lpips_chunk_size: 1 9 | checkpointing_steps: 200 10 | max_train_steps: 75000 -------------------------------------------------------------------------------- /configs/training/3dgs_res_704_1280_views_121_multi_6.yaml: -------------------------------------------------------------------------------- 1 | output_dir: outputs/training/3dgs_res_704_1280_views_121_multi_6 2 | img_size: [704, 1280] 3 | num_views: 130 4 | num_input_views: 121 5 | gs_view_chunk_size: 1 6 | num_input_multi_views: 6 7 | batch_size: 1 8 | static_view_indices_sampling: random_bucket 9 | static_frame_sampling: exponential 10 | lpips_chunk_size: 1 11 | checkpointing_steps: 50 12 | max_train_steps: 82000 -------------------------------------------------------------------------------- /requirements_lyra.txt: -------------------------------------------------------------------------------- 1 | flash_attn==2.7.4.post1 2 | timm==1.0.19 3 | kiui==0.2.17 4 | lru-dict==1.3.0 5 | git+https://github.com/Dao-AILab/causal-conv1d@v1.4.0 6 | git+https://github.com/nerfstudio-project/gsplat.git@73fad53c31ec4d6b088470715a63f432990493de 7 | git+https://github.com/rahul-goel/fused-ssim/@8bdb59feb7b9a41b1fab625907cb21f5417deaac 8 | mpi4py==4.1.0 9 | plyfile==1.1.2 10 | deepspeed==0.17.5 11 | accelerate==1.10.0 12 | openexr==3.2.3 -------------------------------------------------------------------------------- /configs/accelerate/accelerate_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | enable_cpu_affinity: false 5 | downcast_bf16: 'no' 6 | gpu_ids: all 7 | machine_rank: 0 8 | 
main_training_function: main 9 | mixed_precision: 'no' 10 | num_machines: 1 11 | num_processes: 8 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false -------------------------------------------------------------------------------- /configs/accelerate/accelerate_config_single.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: NO 4 | enable_cpu_affinity: false 5 | downcast_bf16: 'no' 6 | gpu_ids: all 7 | machine_rank: 0 8 | main_training_function: main 9 | mixed_precision: 'no' 10 | num_machines: 1 11 | num_processes: 1 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false -------------------------------------------------------------------------------- /scripts/bash/static_sdg.sh: -------------------------------------------------------------------------------- 1 | CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) torchrun --nproc_per_node=1 cosmos_predict1/diffusion/inference/gen3c_single_image_sdg.py \ 2 | --checkpoint_dir checkpoints \ 3 | --num_gpus 1 \ 4 | --input_image_path assets/demo/static/diffusion_input/images/00172.png \ 5 | --video_save_folder assets/demo/static/diffusion_output_generated \ 6 | --foreground_masking \ 7 | --multi_trajectory \ 8 | --total_movement_distance_factor 1.0 -------------------------------------------------------------------------------- /scripts/bash/dynamic_sdg.sh: -------------------------------------------------------------------------------- 1 | CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) torchrun --nproc_per_node=1 cosmos_predict1/diffusion/inference/gen3c_dynamic_sdg.py \ 2 | --checkpoint_dir checkpoints \ 3 | --vipe_path assets/demo/dynamic/diffusion_input/rgb/6a71ee0422ff4222884f1b2a3cba6820.mp4 \ 4 | --video_save_folder 
assets/demo/dynamic/diffusion_output_generated \ 5 | --disable_prompt_upsampler \ 6 | --num_gpus 1 \ 7 | --foreground_masking \ 8 | --multi_trajectory -------------------------------------------------------------------------------- /configs/training/3dgs_res_704_1280_views_121_multi_6_prune.yaml: -------------------------------------------------------------------------------- 1 | output_dir: outputs/training/3dgs_res_704_1280_views_121_multi_6_prune 2 | img_size: [704, 1280] 3 | num_views: 130 4 | num_input_views: 121 5 | gs_view_chunk_size: 1 6 | num_input_multi_views: 6 7 | batch_size: 1 8 | static_view_indices_sampling: random_bucket 9 | static_frame_sampling: exponential 10 | lpips_chunk_size: 1 11 | lambda_opacity: 0.1 12 | gaussians_prune_ratio: 0.8 13 | gaussians_random_ratio: 0.0 14 | checkpointing_steps: 50 15 | max_train_steps: 83000 -------------------------------------------------------------------------------- /configs/training/3dgs_res_704_1280_views_121_multi_6_dynamic.yaml: -------------------------------------------------------------------------------- 1 | output_dir: outputs/training/3dgs_res_704_1280_views_121_multi_6_dynamic 2 | data_mode: [['lyra_dynamic', 1]] 3 | img_size: [704, 1280] 4 | num_views: 133 5 | num_input_views: 121 6 | gs_view_chunk_size: 1 7 | num_input_multi_views: 6 8 | batch_size: 1 9 | static_view_indices_sampling: random_bucket 10 | static_frame_sampling: exponential 11 | lpips_chunk_size: 1 12 | use_time_embedding: true 13 | # use flipped supervision 14 | select_target_views_input_dynamic: false 15 | checkpointing_steps: 50 16 | max_train_steps: 90000 -------------------------------------------------------------------------------- /cosmos_predict1/utils/easy_io/backends/__init__.py: -------------------------------------------------------------------------------- 1 | from cosmos_predict1.utils.easy_io.backends.base_backend import BaseStorageBackend 2 | from cosmos_predict1.utils.easy_io.backends.http_backend import 
HTTPBackend 3 | from cosmos_predict1.utils.easy_io.backends.local_backend import LocalBackend 4 | from cosmos_predict1.utils.easy_io.backends.registry_utils import backends, prefix_to_backends, register_backend 5 | 6 | __all__ = [ 7 | "BaseStorageBackend", 8 | "LocalBackend", 9 | "HTTPBackend", 10 | "register_backend", 11 | "backends", 12 | "prefix_to_backends", 13 | ] 14 | -------------------------------------------------------------------------------- /configs/training/3dgs_res_704_1280_views_121_multi_6_dynamic_prune.yaml: -------------------------------------------------------------------------------- 1 | output_dir: outputs/training/3dgs_res_704_1280_views_121_multi_6_dynamic_prune 2 | data_mode: [['lyra_dynamic', 1]] 3 | img_size: [704, 1280] 4 | num_views: 133 5 | num_input_views: 121 6 | gs_view_chunk_size: 1 7 | num_input_multi_views: 6 8 | batch_size: 1 9 | static_view_indices_sampling: random_bucket 10 | static_frame_sampling: exponential 11 | lpips_chunk_size: 1 12 | lambda_opacity: 0.1 13 | gaussians_prune_ratio: 0.8 14 | gaussians_random_ratio: 0.0 15 | # use flipped supervision 16 | select_target_views_input_dynamic: false 17 | use_time_embedding: true 18 | checkpointing_steps: 50 19 | max_train_steps: 91000 -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. -------------------------------------------------------------------------------- /src/models/recon/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. -------------------------------------------------------------------------------- /src/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. -------------------------------------------------------------------------------- /src/rendering/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. -------------------------------------------------------------------------------- /cosmos_predict1/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/autoregressive/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/checkpointer/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/utils/easy_io/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /configs/inference/3dgs_res_704_1280_views_121_multi_6_dynamic.yaml: -------------------------------------------------------------------------------- 1 | config_path: outputs/training/3dgs_res_704_1280_views_121_multi_6_dynamic/config.yaml 2 | out_dir_inference: outputs/inference/3dgs_res_704_1280_views_121_multi_6_dynamic 3 | 4 | static_view_indices_fixed: ['5', '0', '1', '2', '3', '4'] 5 | target_index_subsample: 4 6 | 7 | # For dynamic scenes, set target time 8 | set_manual_time_idx: true 9 | 10 | # Only create outputs for specified target times 11 | target_index_manual: [0, 60, 120] 12 | 13 | dataset_name: lyra_dynamic_demo 14 | 15 | # Alternative: Loop over start and number of time indices with given stride 16 | # target_index_manual: null 17 | # target_index_manual_stride: 1 18 | # target_index_manual_start_idx: 0 19 | # target_index_manual_num_idx: 121 -------------------------------------------------------------------------------- /cosmos_predict1/autoregressive/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/auxiliary/guardrail/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/tokenizer/inference/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/autoregressive/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/autoregressive/inference/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/autoregressive/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/auxiliary/guardrail/aegis/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/autoregressive/configs/base/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/auxiliary/guardrail/blocklist/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/tokenizer/training/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/autoregressive/diffusion_decoder/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/auxiliary/guardrail/llamaGuard3/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /configs/inference/3dgs_res_704_1280_views_121_multi_6_dynamic_prune.yaml: -------------------------------------------------------------------------------- 1 | config_path: outputs/training/3dgs_res_704_1280_views_121_multi_6_dynamic_prune/config.yaml 2 | out_dir_inference: outputs/inference/3dgs_res_704_1280_views_121_multi_6_dynamic_prune 3 | 4 | static_view_indices_fixed: ['5', '0', '1', '2', '3', '4'] 5 | target_index_subsample: 4 6 | 7 | # For dynamic scenes, set target time 8 | set_manual_time_idx: true 9 | 10 | # Only create outputs for specified target times 11 | target_index_manual: [0, 60, 120] 12 | 13 | dataset_name: lyra_dynamic_demo 14 | 15 | # Alternative: Loop over start and number of time indices with given stride 16 | # target_index_manual: null 17 | # target_index_manual_stride: 1 18 | # target_index_manual_start_idx: 0 19 | # target_index_manual_num_idx: 121 -------------------------------------------------------------------------------- /cosmos_predict1/auxiliary/guardrail/face_blur_filter/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | -------------------------------------------------------------------------------- /requirements_gen3c.txt: -------------------------------------------------------------------------------- 1 | attrs==25.1.0 2 | better-profanity==0.7.0 3 | boto3==1.35.99 4 | decord==0.6.0 5 | diffusers==0.32.2 6 | einops==0.8.1 7 | huggingface-hub==0.29.2 8 | hydra-core==1.3.2 9 | imageio[pyav,ffmpeg]==2.37.0 10 | iopath==0.1.10 11 | ipdb==0.13.13 12 | loguru==0.7.2 13 | mediapy==1.2.2 14 | megatron-core==0.10.0 15 | nltk==3.9.1 16 | numpy==1.26.4 17 | nvidia-ml-py==12.535.133 18 | omegaconf==2.3.0 19 | opencv-python==4.10.0.84 20 | pandas==2.2.3 21 | peft==0.14.0 22 | pillow==11.1.0 23 | protobuf==4.25.3 24 | pynvml==12.0.0 25 | pyyaml==6.0.2 26 | retinaface-py==0.0.2 27 | safetensors==0.5.3 28 | scikit-image==0.24.0 29 | sentencepiece==0.2.0 30 | setuptools==76.0.0 31 | termcolor==2.5.0 32 | torch==2.6.0 33 | torchvision==0.21.0 34 | tqdm==4.66.5 35 | transformers==4.49.0 36 | warp-lang==1.7.2 -------------------------------------------------------------------------------- /configs/demo/lyra_static.yaml: -------------------------------------------------------------------------------- 1 | # Save all renderings etc. in this folder 2 | out_dir_inference: outputs/demo/lyra_static 3 | 4 | # Dataset name, as defined in src/models/data/registry.py 5 | dataset_name: lyra_static_demo # Use pre-generated latents 6 | # dataset_name: lyra_static_demo_generated # Generate own latents 7 | 8 | # Order of camera trajectory indices 9 | static_view_indices_fixed: ['5', '0', '1', '2', '3', '4'] 10 | 11 | # Only render every 4th frame 12 | target_index_subsample: 4 13 | 14 | # For static scenes, do not set target time 15 | set_manual_time_idx: true 16 | 17 | # Inherit the model settings etc. from these configs
18 | config_path: [configs/training/default.yaml, configs/training/3dgs_res_704_1280_views_121_multi_6_prune.yaml] 19 | 20 | # Update path to where the static Lyra checkpoint is downloaded 21 | ckpt_path: checkpoints/Lyra/lyra_static.pt -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/training/datasets/data_sources/item_dataset.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import dataclasses 17 | 18 | 19 | @dataclasses.dataclass 20 | class ItemDatasetConfig:  # plain dataclass describing one item-dataset data source 21 | path: str  # location of the dataset; expected format is defined by the loader, not here — confirm there 22 | length: int  # presumably the number of items in the dataset; NOTE(review): confirm against the consumer 23 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/training/config/video2world/registry.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from hydra.core.config_store import ConfigStore 17 | 18 | from cosmos_predict1.diffusion.training.config.video2world.experiment import register_experiments 19 | 20 | 21 | def register_configs():  # register the video2world experiment configs with Hydra's global ConfigStore singleton 22 | cs = ConfigStore.instance() 23 | 24 | register_experiments(cs) 25 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/training/config/world_interpolator/registry.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
15 | 16 | from hydra.core.config_store import ConfigStore 17 | 18 | from cosmos_predict1.diffusion.training.config.world_interpolator.experiment import register_experiments 19 | 20 | 21 | def register_configs(): 22 | cs = ConfigStore.instance() 23 | 24 | register_experiments(cs) 25 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/training/config/text2world_multiview/registry.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from hydra.core.config_store import ConfigStore 17 | 18 | from cosmos_predict1.diffusion.training.config.text2world_multiview.experiment import register_experiments 19 | 20 | 21 | def register_configs(): 22 | cs = ConfigStore.instance() 23 | 24 | register_experiments(cs) 25 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/training/config/video2world_multiview/registry.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from hydra.core.config_store import ConfigStore 17 | 18 | from cosmos_predict1.diffusion.training.config.video2world_multiview.experiment import register_experiments 19 | 20 | 21 | def register_configs():  # register the video2world_multiview experiment configs with Hydra's global ConfigStore singleton 22 | cs = ConfigStore.instance() 23 | 24 | register_experiments(cs) 25 | -------------------------------------------------------------------------------- /cosmos_predict1/utils/parallel_state_helper.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | from megatron.core import parallel_state 17 | 18 | 19 | def is_tp_cp_pp_rank0():  # True only if this process is rank 0 in its tensor-, pipeline- and context-parallel groups (Megatron parallel_state) 20 | return ( 21 | parallel_state.get_tensor_model_parallel_rank() == 0 22 | and parallel_state.get_pipeline_model_parallel_rank() == 0 23 | and parallel_state.get_context_parallel_rank() == 0 24 | ) 25 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/training/config/video2world_instruction/registry.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from hydra.core.config_store import ConfigStore 17 | 18 | from cosmos_predict1.diffusion.training.config.video2world_instruction.experiment import register_experiments 19 | 20 | 21 | def register_configs():  # register the video2world_instruction experiment configs with Hydra's global ConfigStore singleton 22 | cs = ConfigStore.instance() 23 | 24 | register_experiments(cs) 25 | -------------------------------------------------------------------------------- /cosmos_predict1/utils/lazy_config/file_io.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler 17 | from iopath.common.file_io import PathManager as PathManagerBase 18 | 19 | __all__ = ["PathManager", "PathHandler"] 20 | 21 | 22 | PathManager = PathManagerBase()  # module-level iopath PathManager instance exported via __all__; HTTP URL and OneDrive handlers are registered on it next 23 | PathManager.register_handler(HTTPURLHandler()) 24 | PathManager.register_handler(OneDrivePathHandler()) 25 | -------------------------------------------------------------------------------- /configs/demo/lyra_dynamic.yaml: -------------------------------------------------------------------------------- 1 | # Save all renderings etc. in this folder 2 | out_dir_inference: outputs/demo/lyra_dynamic 3 | 4 | # Dataset name, as defined in src/models/data/registry.py 5 | dataset_name: lyra_dynamic_demo # Use pre-generated latents 6 | # dataset_name: lyra_dynamic_demo_generated # Generate own latents 7 | 8 | # Order of camera trajectory indices 9 | static_view_indices_fixed: ['5', '0', '1', '2', '3', '4'] 10 | 11 | # Only render every 4th frame 12 | target_index_subsample: 4 13 | 14 | # Inherit the model settings etc. from these configs
15 | config_path: [configs/training/default.yaml, configs/training/3dgs_res_704_1280_views_121_multi_6_dynamic_prune.yaml] 16 | 17 | # For dynamic scenes, set target time 18 | set_manual_time_idx: true 19 | 20 | # Only create outputs for specified target times (between 0=min and 120=max) 21 | target_index_manual: [0, 60, 120] 22 | 23 | # Alternative: Loop over start and number of time indices with given stride 24 | # target_index_manual: null 25 | # target_index_manual_stride: 1 26 | # target_index_manual_start_idx: 0 27 | # target_index_manual_num_idx: 121 28 | 29 | # Update path to where the dynamic Lyra checkpoint is downloaded 30 | ckpt_path: checkpoints/Lyra/lyra_dynamic.pt -------------------------------------------------------------------------------- /cosmos_predict1/auxiliary/guardrail/llamaGuard3/categories.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
15 | 16 | UNSAFE_CATEGORIES = { 17 | "S1": "Violent Crimes.", 18 | "S2": "Non-Violent Crimes.", 19 | "S3": "Sex Crimes.", 20 | "S4": "Child Exploitation.", 21 | "S5": "Defamation.", 22 | "S6": "Specialized Advice.", 23 | "S7": "Privacy.", 24 | "S8": "Intellectual Property.", 25 | "S9": "Indiscriminate Weapons.", 26 | "S10": "Hate.", 27 | "S11": "Self-Harm.", 28 | "S12": "Sexual Content.", 29 | "S13": "Elections.", 30 | "s14": "Code Interpreter Abuse.", 31 | } 32 | -------------------------------------------------------------------------------- /cosmos_predict1/utils/easy_io/handlers/torch_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | try: 17 | import torch 18 | except ImportError: 19 | torch = None 20 | 21 | from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler 22 | 23 | 24 | class TorchHandler(BaseFileHandler): 25 | str_like = False 26 | 27 | def load_from_fileobj(self, file, **kwargs): 28 | return torch.load(file, **kwargs) 29 | 30 | def dump_to_fileobj(self, obj, file, **kwargs): 31 | torch.save(obj, file, **kwargs) 32 | 33 | def dump_to_str(self, obj, **kwargs): 34 | raise NotImplementedError 35 | -------------------------------------------------------------------------------- /cosmos_predict1/utils/easy_io/handlers/pandas_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import pandas as pd 17 | 18 | from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler # isort:skip 19 | 20 | 21 | class PandasHandler(BaseFileHandler): 22 | str_like = False 23 | 24 | def load_from_fileobj(self, file, **kwargs): 25 | return pd.read_csv(file, **kwargs) 26 | 27 | def dump_to_fileobj(self, obj, file, **kwargs): 28 | obj.to_csv(file, **kwargs) 29 | 30 | def dump_to_str(self, obj, **kwargs): 31 | raise NotImplementedError("PandasHandler does not support dumping to str") 32 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/training/config/base/ema.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from cosmos_predict1.utils.ema import EMAModelTracker, PowerEMATracker 17 | from cosmos_predict1.utils.lazy_config import PLACEHOLDER 18 | from cosmos_predict1.utils.lazy_config import LazyCall as L 19 | from cosmos_predict1.utils.lazy_config import LazyDict 20 | 21 | PowerEMAConfig: LazyDict = L(PowerEMATracker.initialize_multi_rank_ema)( 22 | model=PLACEHOLDER, enabled=True, rate=0.10, num=3 23 | ) 24 | 25 | RegEMAConfig: LazyDict = L(EMAModelTracker.initialize_multi_rank_ema)( 26 | model=PLACEHOLDER, enabled=True, rate=0.999, num=1 27 | ) 28 | -------------------------------------------------------------------------------- /cosmos_predict1/utils/easy_io/handlers/torchjit_handler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
try:
    import torch
except ImportError:
    torch = None

from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler


class TorchJitHandler(BaseFileHandler):
    """Load and save TorchScript modules with ``torch.jit``."""

    str_like = False

    @staticmethod
    def _require_torch():
        # ``torch`` is an optional import above; without this check a missing
        # install surfaces as "'NoneType' object has no attribute 'jit'".
        if torch is None:
            raise ImportError("TorchJitHandler requires PyTorch, but `torch` could not be imported.")

    def load_from_fileobj(self, file, **kwargs):
        """Load a TorchScript module from ``file`` via ``torch.jit.load``."""
        self._require_torch()
        return torch.jit.load(file, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Save TorchScript module ``obj`` into ``file`` via ``torch.jit.save``."""
        self._require_torch()
        torch.jit.save(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        raise NotImplementedError


# --------------------------------------------------------------------------------
# /src/models/utils/train.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os


def _checkpoint_step(name: str) -> int | None:
    """Parse the step from a ``checkpoint-<step>[...]`` name, or return None if absent."""
    parts = name.split("-")
    if len(parts) < 2:
        return None
    try:
        return int(parts[1])
    except ValueError:
        return None


def get_most_recent_checkpoint(output_dir: str) -> str | None:
    """
    Returns the most recent checkpoint directory from the given output directory.

    Entries are considered when their name starts with ``"checkpoint"``, the
    text after the first ``-`` parses as an integer step, and they are
    directories. Other entries (e.g. ``"checkpoint"``, ``"checkpoint-best"``,
    plain files) are skipped instead of raising.

    Args:
        output_dir (str): Path to the directory containing checkpoints.

    Returns:
        str | None: The name (not full path) of the checkpoint directory with
        the highest step, or None if none exist.
    """
    candidates = []
    for name in os.listdir(output_dir):
        if not name.startswith("checkpoint"):
            continue
        step = _checkpoint_step(name)
        if step is None or not os.path.isdir(os.path.join(output_dir, name)):
            continue
        candidates.append((step, name))
    if not candidates:
        return None
    # Highest step wins; ties are broken deterministically by name.
    return max(candidates)[1]


# --------------------------------------------------------------------------------
# /cosmos_predict1/utils/easy_io/handlers/txt_handler.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler


class TxtHandler(BaseFileHandler):
    """Plain-text handler; non-string objects are written via ``str(obj)``."""

    def load_from_fileobj(self, file, **kwargs):
        del kwargs
        return file.read()

    def dump_to_fileobj(self, obj, file, **kwargs):
        del kwargs
        if not isinstance(obj, str):
            obj = str(obj)
        file.write(obj)

    def dump_to_str(self, obj, **kwargs):
        del kwargs
        if not isinstance(obj, str):
            obj = str(obj)
        return obj


# --------------------------------------------------------------------------------
# /cosmos_predict1/utils/easy_io/handlers/gzip_handler.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gzip
import pickle
from io import BytesIO
from typing import Any

from cosmos_predict1.utils.easy_io.handlers.pickle_handler import PickleHandler


class GzipHandler(PickleHandler):
    """Pickle handler whose payload is wrapped in a gzip stream."""

    str_like = False

    def load_from_fileobj(self, file: BytesIO, **kwargs):
        """Decompress ``file`` with gzip and unpickle its contents."""
        with gzip.GzipFile(fileobj=file, mode="rb") as f:
            return pickle.load(f)

    def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs):
        """Pickle ``obj`` and write it gzip-compressed into ``file``."""
        with gzip.GzipFile(fileobj=file, mode="wb") as f:
            pickle.dump(obj, f)

    def dump_to_path(self, obj, filepath, **kwargs):
        """Write ``obj`` gzip-compressed to ``filepath``.

        The inherited PickleHandler.dump_to_path calls ``pickle.dump`` directly,
        which would skip compression and produce a file that
        ``load_from_fileobj`` cannot read.
        """
        with open(filepath, "wb") as f:
            self.dump_to_fileobj(obj, f, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        """Return the gzip-compressed pickle of ``obj`` as bytes.

        Overrides PickleHandler.dump_to_str, whose output is uncompressed and
        therefore not readable by ``load_from_fileobj``.
        """
        buffer = BytesIO()
        self.dump_to_fileobj(obj, buffer, **kwargs)
        return buffer.getvalue()


# --------------------------------------------------------------------------------
# /cosmos_predict1/utils/easy_io/handlers/__init__.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Public surface of the handlers package: the core handler classes plus the
# handler registry helpers, re-exported for convenient package-level imports.
from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler
from cosmos_predict1.utils.easy_io.handlers.json_handler import JsonHandler
from cosmos_predict1.utils.easy_io.handlers.pickle_handler import PickleHandler
from cosmos_predict1.utils.easy_io.handlers.registry_utils import (
    file_handlers,
    register_handler,
)
from cosmos_predict1.utils.easy_io.handlers.yaml_handler import YamlHandler

__all__ = [
    # handler classes
    "BaseFileHandler",
    "JsonHandler",
    "PickleHandler",
    "YamlHandler",
    # registry helpers
    "register_handler",
    "file_handlers",
]

# --------------------------------------------------------------------------------
# /cosmos_predict1/diffusion/modules/denoiser_scaling.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch


class EDMScaling:
    """EDM preconditioning coefficients (Karras et al., 2022).

    For noise level ``sigma`` and data std ``sigma_data`` this returns
    ``(c_skip, c_out, c_in, c_noise)`` with::

        c_skip  = sigma_data**2 / (sigma**2 + sigma_data**2)
        c_out   = sigma * sigma_data / sqrt(sigma**2 + sigma_data**2)
        c_in    = 1 / sqrt(sigma**2 + sigma_data**2)
        c_noise = 0.25 * ln(sigma)
    """

    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        sd = self.sigma_data
        # Shared denominator terms, computed once.
        total_var = sigma**2 + sd**2
        total_std = total_var**0.5

        c_skip = sd**2 / total_var
        c_out = sigma * sd / total_std
        c_in = 1 / total_std
        c_noise = 0.25 * sigma.log()
        return c_skip, c_out, c_in, c_noise


# --------------------------------------------------------------------------------
# /inference.sh:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Select the inference config; uncomment one alternative to switch resolution/view count.
cfg_file_name=3dgs_res_176_320_views_17.yaml
# cfg_file_name=3dgs_res_176_320_views_49.yaml
# cfg_file_name=3dgs_res_352_640_views_49.yaml
# cfg_file_name=3dgs_res_704_1280_views_49.yaml
# cfg_file_name=3dgs_res_704_1280_views_121.yaml
# cfg_file_name=3dgs_res_704_1280_views_121_multi_6.yaml
# cfg_file_name=3dgs_res_704_1280_views_121_multi_6_prune.yaml
# cfg_file_name=3dgs_res_704_1280_views_121_multi_6_dynamic.yaml
# cfg_file_name=3dgs_res_704_1280_views_121_multi_6_dynamic_prune.yaml

# NOTE: out_dir_main is not referenced by the command below.
out_dir_main=outputs/lyra

cfg="configs/inference/${cfg_file_name}"
# Quote the expansion so paths containing spaces or glob characters stay one argument.
accelerate launch sample.py --config "$cfg"

# --------------------------------------------------------------------------------
# /configs/inference/default.yaml:
# --------------------------------------------------------------------------------
config_path: outputs/training/cosmos3dgs/config.yaml
out_dir_inference: outputs/inference
dataset_name: lyra_static_demo

# If no checkpoint name is given, take the latest
ckpt_path: null
ckpt_name: null

# Set view indices manually
static_view_indices_fixed: null

# Render only every N-th camera from the 3DGS (N = target_index_subsample)
target_index_subsample: 1

# Run evaluation
do_eval: false

# Don't read or write depth
use_depth: false

# Override the number of test images
num_test_images: null

# Output video frame rate (fps)
out_fps: 24

# Assume static scenes by default
target_index_manual: null
target_index_manual_start_idx: null

## Export file config
# Output a grid of results; num_grid_samples sets how many scenes appear in one grid
save_grid: false
num_grid_samples: 4
# Output RGB decoder output next to the 3DGS rendering
save_gt_input: true
# Output a separate file of the RGB decoder output
save_video_input: false
# Output annotated gt depth
save_gt_depth: true
# Save an RGB-decoded version of the latents
save_rgb_decoding: false
# Output 3D gaussians as a simple PLY file
save_gaussians: false
# Output 3D gaussians using the original 3DGS export script
save_gaussians_orig:
# Skip outputs that have already been generated
skip_existing: false
# --------------------------------------------------------------------------------
# /cosmos_predict1/diffusion/types.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class LabelImageCondition:
    """Condition holding a ``label`` tensor."""

    label: torch.Tensor

    def get_classifier_free_guidance_condition(self) -> LabelImageCondition:
        """Return the unconditional counterpart: a zero tensor shaped like ``label``."""
        return LabelImageCondition(torch.zeros_like(self.label))


@dataclass
class DenoisePrediction:
    x0: torch.Tensor  # clean data prediction
    eps: Optional[torch.Tensor] = None  # noise prediction
    logvar: Optional[torch.Tensor] = None  # log variance of noise prediction, can be used as a confidence / uncertainty


# --------------------------------------------------------------------------------
# /cosmos_predict1/auxiliary/guardrail/face_blur_filter/blur_utils.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
import numpy as np


def pixelate_face(face_img: np.ndarray, blocks: int = 5) -> np.ndarray:
    """
    Pixelate a face region by reducing resolution and then upscaling.

    Args:
        face_img: Face region to pixelate
        blocks: Number of blocks to divide the face into (in each dimension); must be >= 1

    Returns:
        Pixelated face region with the same height and width as ``face_img``.
        An empty region (zero height or width) is returned as a copy unchanged.

    Raises:
        ValueError: if ``blocks`` is less than 1.
    """
    if blocks < 1:
        raise ValueError(f"blocks must be >= 1, got {blocks}")
    h, w = face_img.shape[:2]
    if h == 0 or w == 0:
        # cv2.resize asserts on empty input; an empty region has nothing to pixelate.
        return face_img.copy()
    # Shrink the image and scale back up to create pixelation effect
    temp = cv2.resize(face_img, (blocks, blocks), interpolation=cv2.INTER_LINEAR)
    pixelated = cv2.resize(temp, (w, h), interpolation=cv2.INTER_NEAREST)
    return pixelated


# --------------------------------------------------------------------------------
# /cosmos_predict1/diffusion/utils/customization/customization_manager.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum


class CustomizationType(Enum):
    """Supported model customization types."""

    LORA = 1
    REPLACE = 2

    @classmethod
    def from_value(cls, value):
        """Convert a string, an int, or an existing member to the corresponding enum.

        Args:
            value: ``"lora"`` / ``"replace"`` (case-insensitive), a member's
                integer value (``1`` / ``2``), or a ``CustomizationType``.
                The empty string maps to ``None``.

        Returns:
            The matching ``CustomizationType``, or ``None`` for ``""``.

        Raises:
            ValueError: for an unknown string or integer.
            TypeError: for any other type (including ``bool``).
        """
        if isinstance(value, cls):
            return value
        if isinstance(value, str):
            value = value.lower()
            if value == "lora":
                return cls.LORA
            elif value == "replace":
                return cls.REPLACE
            elif value == "":
                return None
            else:
                raise ValueError("Customization type must be lora or replace")
        # bool is a subclass of int; reject it rather than mapping True -> LORA.
        if isinstance(value, int) and not isinstance(value, bool):
            try:
                return cls(value)
            except ValueError:
                raise ValueError("Customization type must be lora or replace") from None
        raise TypeError("CustomizationType must be specified as a string or an int.")


# --------------------------------------------------------------------------------
# /cosmos_predict1/tokenizer/training/configs/base/model.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import attrs

from cosmos_predict1.tokenizer.training.model import TokenizerModel
from cosmos_predict1.utils.config import EMAConfig
from cosmos_predict1.utils.lazy_config import LazyCall as L
from cosmos_predict1.utils.lazy_config import LazyDict


@attrs.define(slots=False)
class ModelConfig:
    """Configuration consumed by ``TokenizerModel``."""

    network: LazyDict = None
    loss: LazyDict = None
    metric: LazyDict = None
    # Built per instance via a factory: a plain ``EMAConfig(...)`` default is
    # evaluated once and the same object would be shared (and mutated) by
    # every ModelConfig.
    ema: EMAConfig = attrs.field(factory=lambda: EMAConfig(enabled=True, beta=0.9999))
    precision: str = "bfloat16"
    torch_compile: bool = False
    disc: LazyDict = None
    disc_optimizer: LazyDict = None
    disc_scheduler: LazyDict = None


DefaultModelConfig: LazyDict = L(TokenizerModel)(config=ModelConfig())

# --------------------------------------------------------------------------------
# /cosmos_predict1/autoregressive/configs/base/model_parallel.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from megatron.core import ModelParallelConfig

from cosmos_predict1.utils.lazy_config import LazyDict


def create_model_parallel_config():
    """Build the bf16 Megatron ``ModelParallelConfig`` used by the configs.

    The parallel sizes are set to ``${model.model_parallel.*}`` interpolation
    strings rather than concrete values; presumably they are resolved by the
    lazy-config system when the full config is materialized — confirm.
    """
    parallel_cfg = ModelParallelConfig(bf16=True, params_dtype=torch.bfloat16)
    parallel_cfg.tensor_model_parallel_size = "${model.model_parallel.tensor_model_parallel_size}"
    parallel_cfg.context_parallel_size = "${model.model_parallel.context_parallel_size}"
    parallel_cfg.sequence_parallel = "${model.model_parallel.sequence_parallel}"

    # Wrap in a LazyDict (allowing arbitrary objects) and return the entry
    # read back from it.
    registry = LazyDict(
        {"model_parallel_bf16": parallel_cfg},
        flags={"allow_objects": True},
    )
    return registry["model_parallel_bf16"]


# --------------------------------------------------------------------------------
# /cosmos_predict1/utils/easy_io/handlers/yaml_handler.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml

try:
    from yaml import CDumper as Dumper  # type: ignore
    from yaml import CLoader as Loader  # type: ignore
except ImportError:
    from yaml import Loader, Dumper  # type: ignore

from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler  # isort:skip


class YamlHandler(BaseFileHandler):
    """YAML handler that prefers PyYAML's C-accelerated Loader/Dumper when available.

    NOTE(review): the default ``Loader`` is PyYAML's full (unsafe) loader, not
    ``SafeLoader``; it can construct arbitrary Python objects. Do not load
    untrusted files without passing ``Loader=yaml.SafeLoader``.
    """

    def load_from_fileobj(self, file, **kwargs):
        """Parse YAML from ``file``; an explicit ``Loader`` kwarg overrides the default."""
        options = {"Loader": Loader, **kwargs}
        return yaml.load(file, **options)

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Serialize ``obj`` as YAML into ``file``."""
        options = {"Dumper": Dumper, **kwargs}
        yaml.dump(obj, file, **options)

    def dump_to_str(self, obj, **kwargs):
        """Serialize ``obj`` as YAML and return it."""
        options = {"Dumper": Dumper, **kwargs}
        return yaml.dump(obj, **options)


# --------------------------------------------------------------------------------
# /cosmos_predict1/autoregressive/configs/base/callbacks.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cosmos_predict1.autoregressive.callbacks.video_sampling_teacher_forcing import VideoSamplingTeacherForcing
from cosmos_predict1.callbacks.grad_clip import GradClip
from cosmos_predict1.utils.callback import ProgressBarCallback
from cosmos_predict1.utils.lazy_config import LazyCall as L

# Callbacks shared by every autoregressive training run.
BASIC_CALLBACKS = {
    "progress_bar": L(ProgressBarCallback)(),
    "grad_clip": L(GradClip)(
        clip_norm=1.0,
        fsdp_enabled="${model.model_config.fsdp_enabled}",
        model_key="model",
    ),
}

# Teacher-forced video sampling callback.
VIDEO_TEACHER_FORCING_CALLBACK = {
    "vid_sampling_tf": L(VideoSamplingTeacherForcing)(
        every_n=500,
        video_latent_shape="${model.model_config.video_latent_shape}",
        num_frames_to_display=4,
        save_folder="video_sampling_teacher_forcing",
    ),
}

# --------------------------------------------------------------------------------
# /cosmos_predict1/diffusion/training/trainer.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cosmos_predict1.diffusion.training.utils.checkpointer import MultiRankCheckpointer
from cosmos_predict1.utils.fsdp_checkpointer import FSDPCheckpointer
from cosmos_predict1.utils.trainer import Trainer as BaseTrainer


class Trainer(BaseTrainer):
    """Diffusion trainer that picks a checkpointer matching the parallelism mode.

    ``"ddp"`` uses ``MultiRankCheckpointer`` and ``"fsdp"`` uses
    ``FSDPCheckpointer``; any other value raises ``ValueError``.
    """

    def __init__(self, config):
        super().__init__(config)
        mode = config.trainer.distributed_parallelism
        if mode == "ddp":
            checkpointer_cls = MultiRankCheckpointer
        elif mode == "fsdp":
            checkpointer_cls = FSDPCheckpointer
        else:
            raise ValueError(f"Unsupported distributed parallelism: {mode}")
        self.checkpointer = checkpointer_cls(config.checkpoint, config.job, callbacks=self.callbacks)


# --------------------------------------------------------------------------------
# /cosmos_predict1/autoregressive/configs/base/dataset.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset config class."""

import attrs

from cosmos_predict1.utils.config import make_freezable


@make_freezable
@attrs.define(slots=False)
class VideoDatasetConfig:
    """
    Args:
        dataset_dir (str): Base path to the dataset directory
        sequence_interval (int): Interval between sampled frames in a sequence
        num_frames (int): Number of frames to load per sequence
        video_size (list): Target size [H,W] for video frames
        start_frame_interval (int): Interval between starting frames of sequences
    """

    dataset_dir: str = "datasets/cosmos_nemo_assets/videos/"
    sequence_interval: int = 1
    num_frames: int = 33
    # Factory so each instance owns its list; a list literal default would be a
    # single object shared (and mutated) across all instances.
    video_size: list[int] = attrs.field(factory=lambda: [640, 848])
    start_frame_interval: int = 1


# --------------------------------------------------------------------------------
# /cosmos_predict1/utils/easy_io/handlers/tarfile_handler.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tarfile

from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler


class TarHandler(BaseFileHandler):
    """Open tar archives for reading and add filesystem paths to new archives.

    Loading returns an open ``tarfile.TarFile`` (default mode ``"r|*"``: a
    stream of tar blocks with transparent compression); the caller must close
    it. Dumping passes ``obj`` to ``TarFile.add`` as the name to archive.
    """

    str_like = False

    def load_from_fileobj(self, file, mode="r|*", **kwargs):
        return tarfile.open(fileobj=file, mode=mode, **kwargs)

    def load_from_path(self, filepath, mode="r|*", **kwargs):
        return tarfile.open(filepath, mode=mode, **kwargs)

    def dump_to_fileobj(self, obj, file, mode="w", **kwargs):
        # ``kwargs`` go to TarFile.add, not to tarfile.open.
        with tarfile.open(fileobj=file, mode=mode) as archive:
            archive.add(obj, **kwargs)

    def dump_to_path(self, obj, filepath, mode="w", **kwargs):
        with tarfile.open(filepath, mode=mode) as archive:
            archive.add(obj, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        raise NotImplementedError


# --------------------------------------------------------------------------------
# /cosmos_predict1/diffusion/training/config/base/optim.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cosmos_predict1.diffusion.training.functional.lr_scheduler import LambdaLinearScheduler
from cosmos_predict1.diffusion.training.utils.optim_instantiate import get_base_optimizer
from cosmos_predict1.utils.lazy_config import PLACEHOLDER
from cosmos_predict1.utils.lazy_config import LazyCall as L
from cosmos_predict1.utils.lazy_config import LazyDict

# Optimizer built lazily by get_base_optimizer; ``model`` is filled in later.
FusedAdamWConfig: LazyDict = L(get_base_optimizer)(
    model=PLACEHOLDER,
    # core hyper-parameters
    lr=1e-4,
    weight_decay=0.3,
    betas=[0.9, 0.999],
    eps=1e-8,
    # implementation selection / flags
    optim_type="fusedadam",
    sharding=False,
    master_weights=True,
    capturable=True,
)

# Warm-up followed by a single, extremely long cycle. With f_max == f_min == 1.0
# the multiplier presumably stays at 1.0 after warm-up — confirm in lr_scheduler.
LambdaLinearSchedulerConfig: LazyDict = L(LambdaLinearScheduler)(
    warm_up_steps=[1000],
    cycle_lengths=[10_000_000_000_000],
    f_start=[1.0e-6],
    f_max=[1.0],
    f_min=[1.0],
)

# --------------------------------------------------------------------------------
# /cosmos_predict1/utils/easy_io/handlers/csv_handler.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
from io import StringIO

from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler


class CsvHandler(BaseFileHandler):
    """Read and write CSV as a list of rows, each row a list of fields."""

    def load_from_fileobj(self, file, **kwargs):
        """Return every row of ``file`` as a list of string fields."""
        del kwargs
        return list(csv.reader(file))

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Write the rows of ``obj`` (an iterable of lists) to ``file``.

        Raises:
            ValueError: if any row is not a ``list``.
        """
        del kwargs
        rows = self._validated_rows(obj)
        csv.writer(file).writerows(rows)

    def dump_to_str(self, obj, **kwargs):
        """Return the CSV text for the rows of ``obj``; same validation as dump_to_fileobj."""
        del kwargs
        output = StringIO()
        self.dump_to_fileobj(obj, output)
        return output.getvalue()

    @staticmethod
    def _validated_rows(obj):
        """Materialize ``obj`` into a list and check that every row is a list."""
        # Materialize first: validating a one-shot iterable (e.g. a generator)
        # directly would consume it and leave nothing for writerows.
        rows = list(obj)
        if not all(isinstance(row, list) for row in rows):
            raise ValueError("Each row must be a list")
        return rows


# --------------------------------------------------------------------------------
# /cosmos_predict1/utils/easy_io/handlers/pickle_handler.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from io import BytesIO
from typing import Any

from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler


class PickleHandler(BaseFileHandler):
    """Handler that (de)serializes arbitrary Python objects with :mod:`pickle`.

    Security: unpickling can execute arbitrary code, so only load data that
    comes from a trusted source.
    """

    # Pickle produces and consumes bytes, so buffers are bytes-like, not str.
    str_like = False

    def load_from_fileobj(self, file: BytesIO, **kwargs):
        """Unpickle and return one object from the binary file object ``file``."""
        return pickle.load(file, **kwargs)

    def load_from_path(self, filepath, **kwargs):
        """Open ``filepath`` in binary mode and unpickle one object from it."""
        return super().load_from_path(filepath, mode="rb", **kwargs)

    def dump_to_str(self, obj, **kwargs):
        """Return the pickled ``bytes`` of ``obj`` (protocol 2 unless overridden)."""
        options = {"protocol": 2, **kwargs}
        return pickle.dumps(obj, **options)

    def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs):
        """Pickle ``obj`` into ``file`` (protocol 2 unless overridden)."""
        options = {"protocol": 2, **kwargs}
        pickle.dump(obj, file, **options)

    def dump_to_path(self, obj, filepath, **kwargs):
        """Pickle ``obj`` into a new binary file at ``filepath``."""
        # NOTE(review): unlike dump_to_str/dump_to_fileobj this method does not
        # default to protocol 2, so pickle's own default protocol applies unless
        # the caller passes one -- confirm whether the difference is intended.
        with open(filepath, "wb") as fh:
            pickle.dump(obj, fh, **kwargs)
# --------------------------------------------------------------------------------
# /cosmos_predict1/tokenizer/training/configs/experiments/utils.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""registry for commandline override options for config."""
from cosmos_predict1.utils.lazy_config import LazyDict


def create_debug_job_with_mock_data(full_experiment_name):
    """Build a tiny debug job that reuses an experiment but trains on mock data.

    Args:
        full_experiment_name: Name of the experiment to inherit from. Dashes are
            replaced with underscores when forming the ``/experiment/...`` path;
            the job name keeps the original spelling.

    Returns:
        LazyDict: an override config that:

        - swaps both training and validation data for ``mock_video360``;
        - runs 2 iterations, logging every iteration and validating and
          checkpointing at iteration 2;
        - disables strict resume and loading of training state.
    """
    experiment_path = "/experiment/" + full_experiment_name.replace("-", "_")
    timestamp_suffix = "_${now:%Y-%m-%d}_${now:%H-%M-%S}"
    defaults = [
        experiment_path,
        {"override /data_train": "mock_video360"},
        {"override /data_val": "mock_video360"},
        "_self_",
    ]
    return LazyDict(
        dict(
            defaults=defaults,
            job=dict(group="debug", name="mock_" + full_experiment_name + timestamp_suffix),
            trainer=dict(max_iter=2, logging_iter=1, max_val_iter=1, validation_iter=2),
            checkpoint=dict(strict_resume=False, load_training_state=False, save_iter=2),
        )
    )
# --------------------------------------------------------------------------------
# /cosmos_predict1/tokenizer/modules/distributions.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The distribution modes to use for continuous image tokenizers."""

import torch


class IdentityDistribution(torch.nn.Module):
    """Pass-through posterior that returns the encoder output unchanged.

    The second return value mirrors :class:`GaussianDistribution`'s
    ``(mean, logvar)`` pair with zero placeholders so callers can handle both
    formulations the same way.
    """

    def __init__(self):
        super().__init__()

    def forward(self, parameters):
        """Return ``parameters`` and a ``(zeros, zeros)`` placeholder pair.

        Each placeholder is a 1-element float tensor of zeros. It is created on
        ``parameters.device`` so that downstream arithmetic with the pair does
        not mix CPU and accelerator tensors.
        """
        device = parameters.device
        return parameters, (torch.zeros(1, device=device), torch.zeros(1, device=device))


class GaussianDistribution(torch.nn.Module):
    """Diagonal Gaussian posterior parameterized by concatenated mean and log-variance.

    Args:
        min_logvar: Lower clamp applied to the log-variance.
        max_logvar: Upper clamp applied to the log-variance.
    """

    def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
        super().__init__()
        self.min_logvar = min_logvar
        self.max_logvar = max_logvar

    def sample(self, mean, logvar):
        """Return ``mean + exp(0.5 * logvar) * eps`` with ``eps`` drawn from a standard normal."""
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(mean)

    def forward(self, parameters):
        """Split ``parameters`` in two along dim 1 into mean and log-variance, then sample.

        Returns:
            ``(sample, (mean, logvar))``, where ``logvar`` has been clamped to
            ``[min_logvar, max_logvar]``.
        """
        mean, logvar = torch.chunk(parameters, 2, dim=1)
        logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar)
        return self.sample(mean, logvar), (mean, logvar)
# --------------------------------------------------------------------------------
# /cosmos_predict1/tokenizer/training/losses/__init__.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The loss reduction modes."""

from enum import Enum

import torch


def _mean(recon: torch.Tensor) -> torch.Tensor:
    """Average over every element of ``recon``."""
    return recon.mean()


def _sum_per_frame(recon: torch.Tensor) -> torch.Tensor:
    """Sum over every element, divided by the number of frames.

    For a 5-D tensor the divisor is ``shape[0] * shape[2]``. For any other rank
    it is ``shape[0]``.
    NOTE(review): presumably the 5-D layout is (batch, channel, time, height,
    width), making the divisor batch x time -- confirm with callers.
    """
    num_frames = recon.shape[0]
    if recon.ndim == 5:
        num_frames *= recon.shape[2]
    return recon.sum() / num_frames


def _sum(recon: torch.Tensor) -> torch.Tensor:
    """Sum over every element, divided by the leading dimension size."""
    return recon.sum() / recon.shape[0]


class ReduceMode(Enum):
    """Named reduction strategies; ``.function`` returns the matching reducer."""

    MEAN = "MEAN"
    SUM_PER_FRAME = "SUM_PER_FRAME"
    SUM = "SUM"

    @property
    def function(self):
        """Return the reducer callable associated with this mode."""
        reducers = {
            ReduceMode.MEAN: _mean,
            ReduceMode.SUM_PER_FRAME: _sum_per_frame,
            ReduceMode.SUM: _sum,
        }
        try:
            return reducers[self]
        except KeyError:
            raise ValueError("Invalid ReduceMode") from None
# --------------------------------------------------------------------------------
# /cosmos_predict1/auxiliary/guardrail/blocklist/utils.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re

from cosmos_predict1.utils import log


def read_keyword_list_from_dir(folder_path: str) -> list[str]:
    """Read keyword list from all files in a folder.

    Every regular file directly inside ``folder_path`` is read as UTF-8 text;
    subdirectories are not descended into. Each line of each file, stripped of
    surrounding whitespace, becomes one entry. Files are visited in sorted name
    order so the result does not depend on filesystem listing order. A file
    that cannot be opened or decoded is logged and skipped as a whole.

    Args:
        folder_path: Directory containing the keyword files.

    Returns:
        The stripped lines of all readable files, concatenated.
    """
    output_list = []
    file_names = sorted(
        name for name in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, name))
    )

    for file in file_names:
        file_path = os.path.join(folder_path, file)
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                # Read the whole file before extending so a mid-file decode
                # error does not leave a partial file in the result.
                lines = [line.strip() for line in f]
        except (OSError, UnicodeDecodeError) as e:
            log.error(f"Error reading file {file}: {str(e)}")
            continue
        output_list.extend(lines)

    return output_list


# Runs of one or more non-ASCII characters; compiled once at import time.
_NON_ASCII_RUN = re.compile(r"[^\x00-\x7F]+")


def to_ascii(prompt: str) -> str:
    """Convert prompt to ASCII by replacing each run of non-ASCII characters with one space."""
    return _NON_ASCII_RUN.sub(" ", prompt)
# --------------------------------------------------------------------------------
# /cosmos_predict1/diffusion/training/callbacks/low_precision.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from cosmos_predict1.diffusion.training.trainer import Trainer
from cosmos_predict1.utils.callback import LowPrecisionCallback as BaseCallback
from cosmos_predict1.utils.config import Config
from cosmos_predict1.utils.model import Model


class LowPrecisionCallback(BaseCallback):
    """Low-precision callback that takes its dtype from the model.

    A torch dtype stored in the config is a non-primitive value and awkward to
    override, so the precision is read from ``model.precision`` when training
    starts.
    """

    def __init__(self, config: Config, trainer: Trainer, update_iter: int):
        # NOTE(review): the base-class __init__ is not called; only these three
        # attributes are set here -- confirm the base class needs nothing else.
        self.config, self.trainer, self.update_iter = config, trainer, update_iter

    def on_train_start(self, model: Model, iteration: int = 0) -> None:
        """Check that ``model.precision`` is a low-precision dtype and record it.

        Sets ``self.precision_type`` to ``model.precision``.

        Raises:
            AssertionError: if the precision is neither bfloat16 nor float16.
        """
        # torch.half is an alias of torch.float16, so two entries cover all cases.
        low_precision_dtypes = (torch.bfloat16, torch.float16)
        assert model.precision in low_precision_dtypes, "LowPrecisionCallback must use a low precision dtype."
        self.precision_type = model.precision
# --------------------------------------------------------------------------------
# /cosmos_predict1/diffusion/training/utils/peft/lora_config.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def get_fa_ca_qv_lora_config(first_nblocks=28, rank=8, scale=1):
    """
    Get a LoRA configuration for the Self-Attention (FA) and Cross-Attention (CA) blocks in the model.
    This LoRA configuration is used to inject LoRA parameters into the model.

    Only the query and value projections (``to_q`` and ``to_v``) of the
    selected blocks are targeted.

    Args:
        first_nblocks (int): The number of leading blocks (indices ``0`` to
            ``first_nblocks - 1``) to apply LoRA to.
        rank (int): The rank of the LoRA matrices.
        scale (int | float): The LoRA scale. It is copied into both the
            top-level config and the edit.

    Returns:
        dict: An enabled ``"LoRA"`` customization config with a single edit.
        The edit's ``blocks`` field is a regex matching the selected block
        indices as whole words.
    """
    # NOTE(review): with first_nblocks <= 0 the alternation is empty and the
    # pattern becomes r"\b()\b", which matches at any word boundary instead of
    # matching no block -- confirm how the consumer applies this regex.
    blocks_regex = r"\b(" + "|".join(str(i) for i in range(first_nblocks)) + r")\b"
    return dict(
        enabled=True,
        customization_type="LoRA",
        rank=rank,
        scale=scale,
        edits=[
            dict(
                blocks=blocks_regex,
                customization_type="LoRA",
                rank=rank,
                scale=scale,
                block_edit=[
                    "FA[to_q, to_v]",
                    "CA[to_q, to_v]",
                ],
            )
        ],
    )
# --------------------------------------------------------------------------------
# /cosmos_predict1/utils/easy_io/handlers/base.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABCMeta, abstractmethod


class BaseFileHandler(metaclass=ABCMeta):
    """Abstract base for format handlers that (de)serialize objects.

    Subclasses implement :meth:`load_from_fileobj`, :meth:`dump_to_fileobj`
    and :meth:`dump_to_str`. The path-based helpers open the file with the
    given ``mode`` and delegate to the file-object methods.
    """

    # `str_like` is a flag to indicate whether the type of file object is
    # str-like object or bytes-like object. Pickle only processes bytes-like
    # objects but json only processes str-like object. If it is str-like
    # object, `StringIO` will be used to process the buffer.
    str_like = True

    @abstractmethod
    def load_from_fileobj(self, file, **kwargs):
        """Deserialize and return an object read from the open file object ``file``."""
        pass

    @abstractmethod
    def dump_to_fileobj(self, obj, file, **kwargs):
        """Serialize ``obj`` into the open file object ``file``."""
        pass

    @abstractmethod
    def dump_to_str(self, obj, **kwargs):
        """Return ``obj`` serialized in memory (``str``, or ``bytes`` for byte-oriented formats)."""
        pass

    def load_from_path(self, filepath, mode="r", **kwargs):
        """Open ``filepath`` with ``mode`` and return :meth:`load_from_fileobj`'s result.

        Extra ``kwargs`` are forwarded to :meth:`load_from_fileobj`.
        """
        with open(filepath, mode) as f:
            return self.load_from_fileobj(f, **kwargs)

    def dump_to_path(self, obj, filepath, mode="w", **kwargs):
        """Open ``filepath`` with ``mode`` and write ``obj`` via :meth:`dump_to_fileobj`.

        Extra ``kwargs`` are forwarded to :meth:`dump_to_fileobj`.
        """
        with open(filepath, mode) as f:
            self.dump_to_fileobj(obj, f, **kwargs)
# --------------------------------------------------------------------------------
# /cosmos_predict1/diffusion/training/modules/edm_sde.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from statistics import NormalDist

import numpy as np
import torch


class EDMSDE:
    """Log-normal noise-level sampler.

    ``sample_t`` draws ``sigma = exp(z)`` with ``z ~ N(p_mean, p_std**2)``.
    ``sigma_max`` and ``sigma_min`` are stored on the instance but are not used
    by ``sample_t`` (no clipping is applied).
    """

    def __init__(
        self,
        p_mean: float = -1.2,
        p_std: float = 1.2,
        sigma_max: float = 80.0,
        sigma_min: float = 0.002,
    ):
        self.gaussian_dist = NormalDist(mu=p_mean, sigma=p_std)
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min

    def sample_t(self, batch_size: int, device="cuda") -> torch.Tensor:
        """Sample ``batch_size`` noise levels.

        Uniform draws from NumPy's global RNG are mapped through the Gaussian
        inverse CDF and then exponentiated.

        Args:
            batch_size: Number of samples to draw.
            device: Device of the returned tensor. Defaults to ``"cuda"``, which
                matches the previous hard-coded behavior.

        Returns:
            A 1-D float tensor of length ``batch_size`` on ``device``.
        """
        cdf_vals = np.random.uniform(size=batch_size)
        # np.random.uniform samples from [0, 1), but NormalDist.inv_cdf only
        # accepts the open interval (0, 1). Nudge an exact 0.0 up to the
        # smallest positive normal double instead of raising StatisticsError.
        cdf_vals = np.maximum(cdf_vals, np.finfo(np.float64).tiny)
        samples_interval_gaussian = [self.gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals]

        log_sigma = torch.tensor(samples_interval_gaussian, device=device)
        return torch.exp(log_sigma)

    def marginal_prob(self, x0: torch.Tensor, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """This is trivial in the base class, but may be used by derived classes in a more interesting way"""
        return x0, sigma
# --------------------------------------------------------------------------------
# /cosmos_predict1/tokenizer/training/configs/base/metric.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Metric configurations for the tokenizer model.

Supports PSNR and SSIM; these are validation-only metrics. Discrete tokenizers
additionally track codebook usage.
"""
import attrs

from cosmos_predict1.tokenizer.training.metrics import CodeUsageMetric, PSNRMetric, SSIMMetric, TokenizerMetric
from cosmos_predict1.utils.lazy_config import LazyCall as L
from cosmos_predict1.utils.lazy_config import LazyDict


@attrs.define(slots=False)
class Metric:
    # Validation metrics for continuous tokenizers: PSNR and SSIM.
    PSNR: LazyDict = L(PSNRMetric)()
    SSIM: LazyDict = L(SSIMMetric)()


@attrs.define(slots=False)
class DiscreteTokenizerMetric:
    # with code usage (perplexity PPL), for discrete tokenizers only
    PSNR: LazyDict = L(PSNRMetric)()
    SSIM: LazyDict = L(SSIMMetric)()
    CodeUsage: LazyDict = L(CodeUsageMetric)(codebook_size=64000)


MetricConfig: LazyDict = L(TokenizerMetric)(config=Metric())

DiscreteTokenizerMetricConfig: LazyDict = L(TokenizerMetric)(config=DiscreteTokenizerMetric())
# --------------------------------------------------------------------------------
# /cosmos_predict1/utils/easy_io/handlers/json_handler.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

import numpy as np

from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler


def set_default(obj):
    """Set default json values for non-serializable values.

    It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
    It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
    etc.) into plain numbers of plain python built-in types.

    Raises:
        TypeError: for any other type, as ``json`` expects from a ``default`` hook.
    """
    if isinstance(obj, (set, range)):
        return list(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"{type(obj)} is unsupported for json dump")


class JsonHandler(BaseFileHandler):
    """(De)serialize JSON; sets, ranges and NumPy values are converted on dump."""

    def load_from_fileobj(self, file, **kwargs):
        """Parse and return JSON from ``file``.

        ``kwargs`` are forwarded to :func:`json.load`. Accepting them keeps this
        method compatible with ``BaseFileHandler.load_from_path``, which passes
        its extra keyword arguments through to ``load_from_fileobj``.
        """
        return json.load(file, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Write ``obj`` as JSON to ``file``; ``set_default`` is used unless ``default`` is given."""
        kwargs.setdefault("default", set_default)
        json.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        """Return ``obj`` as a JSON string; ``set_default`` is used unless ``default`` is given."""
        kwargs.setdefault("default", set_default)
        return json.dumps(obj, **kwargs)
# --------------------------------------------------------------------------------
# /cosmos_predict1/tokenizer/modules/__init__.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum

from cosmos_predict1.tokenizer.modules.distributions import GaussianDistribution, IdentityDistribution
from cosmos_predict1.tokenizer.modules.layers2d import Decoder, Encoder
from cosmos_predict1.tokenizer.modules.layers3d import (
    DecoderBase,
    DecoderBaseVanilla,
    DecoderFactorized,
    DecoderFactorizedVanilla,
    EncoderBase,
    EncoderFactorized,
)
from cosmos_predict1.tokenizer.modules.quantizers import FSQuantizer, LFQuantizer, ResidualFSQuantizer, VectorQuantizer


class EncoderType(Enum):
    """Selectable 2-D (image) encoder classes."""

    Default = Encoder


class DecoderType(Enum):
    """Selectable 2-D (image) decoder classes."""

    Default = Decoder


class Encoder3DType(Enum):
    """Selectable 3-D (video) encoder classes."""

    BASE = EncoderBase
    FACTORIZED = EncoderFactorized


class Decoder3DType(Enum):
    """Selectable 3-D (video) decoder classes, including the Vanilla variants."""

    BASE = DecoderBase
    FACTORIZED = DecoderFactorized
    FACTORIZEDVanilla = DecoderFactorizedVanilla
    BASEVanilla = DecoderBaseVanilla


class ContinuousFormulation(Enum):
    """Latent formulation for continuous tokenizers: Gaussian (VAE) or identity (AE)."""

    VAE = GaussianDistribution
    AE = IdentityDistribution


class DiscreteQuantizer(Enum):
    """Selectable quantizer classes for discrete tokenizers."""

    VQ = VectorQuantizer
    LFQ = LFQuantizer
    FSQ = FSQuantizer
    RESFSQ = ResidualFSQuantizer
# --------------------------------------------------------------------------------
# /cosmos_predict1/utils/easy_io/backends/base_backend.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
from abc import ABCMeta, abstractmethod


def mkdir_or_exist(dir_name, mode=0o777):
    """Create ``dir_name`` and any missing parents; existing directories are fine.

    ``~`` is expanded. An empty string is ignored.
    """
    if dir_name != "":
        os.makedirs(osp.expanduser(dir_name), mode=mode, exist_ok=True)


def has_method(obj, method):
    """Return True when ``obj`` has an attribute named ``method`` that is callable."""
    if not hasattr(obj, method):
        return False
    return callable(getattr(obj, method))


class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract interface shared by all storage backends.

    Concrete backends must implement two methods:

    - :meth:`get` returns the file contents as a byte stream.
    - :meth:`get_text` returns the file contents as text.
    """

    # Whether the backend can create a symlink for a file.
    # This attribute will be deprecated in future.
    _allow_symlink = False

    @property
    def allow_symlink(self):
        """Read-only view of ``_allow_symlink``."""
        return self._allow_symlink

    @property
    def name(self):
        """Name of the concrete backend class."""
        return self.__class__.__name__

    @abstractmethod
    def get(self, filepath):
        """Return the contents of ``filepath`` as bytes."""
        pass

    @abstractmethod
    def get_text(self, filepath):
        """Return the contents of ``filepath`` as text."""
        pass
# --------------------------------------------------------------------------------
# /cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from hydra.core.config_store import ConfigStore

from cosmos_predict1.utils.lazy_config import LazyCall as L
from cosmos_predict1.utils.lazy_config import LazyDict

# Multiview variant of the 7B Text2World experiment. It inherits the base
# experiment and then:
# - swaps in the multiview network and a frame-repeat conditioner;
# - sets the latent shape;
# - sets the tokenizer's pixel chunk duration.
Cosmos_Predict1_Text2World_7B_Multiview: LazyDict = LazyDict(
    dict(
        defaults=[
            "/experiment/Cosmos_Predict1_Text2World_7B",
            {"override /net": "faditv2_multiview_7b"},
            {"override /conditioner": "add_fps_image_size_padding_mask_frame_repeat"},
            "_self_",
        ],
        job=dict(group="Text2World", name="Cosmos_Predict1_Text2World_7B_Multiview"),
        model=dict(
            latent_shape=[16, 16, 88, 160],
            tokenizer=dict(video_vae=dict(pixel_chunk_duration=57)),
        ),
    )
)


# Register the config in the "experiment" group under its job name.
cs = ConfigStore.instance()
cs.store(
    group="experiment",
    package="_global_",
    name=Cosmos_Predict1_Text2World_7B_Multiview["job"]["name"],
    node=Cosmos_Predict1_Text2World_7B_Multiview,
)
# --------------------------------------------------------------------------------
# /cosmos_predict1/tokenizer/training/configs/base/callback.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""callbacks config options:

BASIC_CALLBACKS: always recommended to use
"""

from cosmos_predict1.tokenizer.training.callbacks import (
    AdaptCkptStateDict,
    ExpandLossMask,
    GradClipCallback,
    TorchCompile,
)
from cosmos_predict1.utils.callback import EMAModelCallback, LowPrecisionCallback, ProgressBarCallback
from cosmos_predict1.utils.lazy_config import PLACEHOLDER
from cosmos_predict1.utils.lazy_config import LazyCall as L

# Lazily-instantiated callbacks keyed by name. Most take the config and trainer
# as PLACEHOLDERs that are filled in at instantiation time.
BASIC_CALLBACKS = {
    "low_precision": L(LowPrecisionCallback)(update_iter=1, config=PLACEHOLDER, trainer=PLACEHOLDER),
    "grad_clip": L(GradClipCallback)(grad_clip_norm=1, verbose=False, config=PLACEHOLDER, trainer=PLACEHOLDER),
    "ema": L(EMAModelCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER),
    "progress_bar": L(ProgressBarCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER),
    "expand_loss_mask": L(ExpandLossMask)(kernel_size=51, config=PLACEHOLDER, trainer=PLACEHOLDER),
    "adapt_ckpt_state_dict": L(AdaptCkptStateDict)(config=PLACEHOLDER, trainer=PLACEHOLDER),
    "torch_compile": L(TorchCompile)(
        compile_after_iterations=8,
        compile_network=False,
        compile_loss=False,
        compile_loss_keys=["flow", "perceptual"],
    ),
}
# --------------------------------------------------------------------------------
/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/vision_encoder.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | from PIL import Image 18 | from transformers import SiglipModel, SiglipProcessor 19 | 20 | 21 | class SigLIPEncoder(torch.nn.Module): 22 | def __init__( 23 | self, 24 | checkpoint_dir: str, 25 | model_name: str = "google/siglip-so400m-patch14-384", 26 | device="cuda" if torch.cuda.is_available() else "cpu", 27 | dtype=torch.float32, 28 | ) -> None: 29 | super().__init__() 30 | self.checkpoint_dir = checkpoint_dir 31 | self.device = device 32 | self.dtype = dtype 33 | self.model = SiglipModel.from_pretrained(model_name, cache_dir=self.checkpoint_dir) 34 | self.processor = SiglipProcessor.from_pretrained(model_name, cache_dir=self.checkpoint_dir) 35 | self.model.to(self.device, dtype=self.dtype).eval() 36 | 37 | @torch.inference_mode() 38 | def encode_image(self, input_img: Image.Image) -> torch.Tensor: 39 | """Encode an image into a feature vector.""" 40 | with torch.no_grad(): 41 | inputs = self.processor(images=input_img, return_tensors="pt").to(self.device, dtype=self.dtype) 42 | image_features = self.model.get_image_features(**inputs) 43 | 
image_features /= image_features.norm(dim=-1, keepdim=True) 44 | return image_features 45 | -------------------------------------------------------------------------------- /cosmos_predict1/checkpointer/tp.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from cosmos_predict1.checkpointer.ddp import Checkpointer as DDPCheckpointer 17 | from cosmos_predict1.utils.model import Model 18 | 19 | 20 | class Checkpointer(DDPCheckpointer): 21 | """ 22 | Checkpointer class for Tensor Parallelism (TP) in distributed training. 23 | 24 | This implementation supports the combination of Tensor Parallelism (TP) and Data Parallel Processing (DDP), with optional Context Parallelism (CP). 25 | 26 | Note: 27 | - Fully Sharded Data Parallelism (FSDP) is not supported by this checkpointer. 28 | - In principle, this implementation is also compatible with Pipeline Parallelism (PP) and Expert Parallelism (EP), which are other forms of model parallelism. However, PP and EP have not been tested yet. 
29 | """ 30 | 31 | def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str: 32 | """ 33 | Override the `add_type_postfix_to_checkpoint_path` method of the base class (DDP checkpointer) 34 | to append the TP-rank postfix to the checkpoint path. 35 | """ 36 | checkpoint_path = super().add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) 37 | if key == "trainer": 38 | return checkpoint_path 39 | else: 40 | checkpoint_path = checkpoint_path.replace(".pt", f"_mp_{self.mp_rank}.pt") 41 | 42 | return checkpoint_path 43 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/functional/batch_ops.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Functions for performing operations with broadcasting to the right axis 17 | # 18 | # Example 19 | # input1: tensor of size (N1, N2) 20 | # input2: tensor of size (N1, N2, N3, N4) 21 | # batch_mul(input1, input2) = input1[:, :, None, None] * input2 22 | # 23 | # If the common dimensions don't match, we raise an assertion error.
24 | 25 | from torch import Tensor 26 | 27 | 28 | def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: 29 | ndims1 = x.ndim 30 | ndims2 = y.ndim 31 | 32 | common_ndims = min(ndims1, ndims2) 33 | for axis in range(common_ndims): 34 | assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis) 35 | 36 | if ndims1 < ndims2: 37 | x = x.reshape(x.shape + (1,) * (ndims2 - ndims1)) 38 | elif ndims2 < ndims1: 39 | y = y.reshape(y.shape + (1,) * (ndims1 - ndims2)) 40 | 41 | return x, y 42 | 43 | 44 | def batch_add(x: Tensor, y: Tensor) -> Tensor: 45 | x, y = common_broadcast(x, y) 46 | return x + y 47 | 48 | 49 | def batch_mul(x: Tensor, y: Tensor) -> Tensor: 50 | x, y = common_broadcast(x, y) 51 | return x * y 52 | 53 | 54 | def batch_sub(x: Tensor, y: Tensor) -> Tensor: 55 | x, y = common_broadcast(x, y) 56 | return x - y 57 | 58 | 59 | def batch_div(x: Tensor, y: Tensor) -> Tensor: 60 | x, y = common_broadcast(x, y) 61 | return x / y 62 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/functional/multi_step.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """ 17 | Impl of multistep methods to solve the ODE in the diffusion model. 18 | """ 19 | 20 | from typing import Callable, List, Tuple 21 | 22 | import torch 23 | 24 | from cosmos_predict1.diffusion.functional.runge_kutta import reg_x0_euler_step, res_x0_rk2_step 25 | 26 | 27 | def order2_fn( 28 | x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_s: torch.Tensor, x0_preds: torch.Tensor 29 | ) -> Tuple[torch.Tensor, List[torch.Tensor]]: 30 | """ 31 | impl the second order multistep method in https://arxiv.org/pdf/2308.02157 32 | Adams Bashforth approach! 33 | """ 34 | if x0_preds: 35 | x0_s1, s1 = x0_preds[0] 36 | x_t = res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1) 37 | else: 38 | x_t = reg_x0_euler_step(x_s, s, t, x0_s)[0] 39 | return x_t, [(x0_s, s)] 40 | 41 | 42 | # key: method name, value: method function 43 | # key: order + algorithm name 44 | MULTISTEP_FNs = { 45 | "2ab": order2_fn, 46 | } 47 | 48 | 49 | def get_multi_step_fn(name: str) -> Callable: 50 | if name in MULTISTEP_FNs: 51 | return MULTISTEP_FNs[name] 52 | methods = "\n\t".join(MULTISTEP_FNs.keys()) 53 | raise RuntimeError("Only support multistep method\n" + methods) 54 | 55 | 56 | def is_multi_step_fn_supported(name: str) -> bool: 57 | """ 58 | Check if the multistep method is supported. 59 | """ 60 | return name in MULTISTEP_FNs 61 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/training/config/base/vae.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import omegaconf 17 | 18 | from cosmos_predict1.diffusion.training.module.pretrained_vae import VideoJITTokenizer 19 | from cosmos_predict1.utils.lazy_config import LazyCall as L 20 | 21 | TOKENIZER_OPTIONS = {} 22 | 23 | 24 | def tokenizer_register(key): 25 | def decorator(func): 26 | TOKENIZER_OPTIONS[key] = func 27 | return func 28 | 29 | return decorator 30 | 31 | 32 | @tokenizer_register("cosmos_diffusion_tokenizer_comp8x8x8") 33 | def get_cosmos_tokenizer_comp8x8x8( 34 | resolution: str, 35 | chunk_duration: int, 36 | ) -> omegaconf.dictconfig.DictConfig: 37 | assert resolution in ["512", "720"] 38 | 39 | pixel_chunk_duration = chunk_duration 40 | temporal_compression_factor = 8 41 | spatial_compression_factor = 8 42 | 43 | return L(VideoJITTokenizer)( 44 | name="cosmos_diffusion_tokenizer_comp8x8x8", 45 | enc_fp="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit", 46 | dec_fp="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit", 47 | mean_std_fp="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt", 48 | latent_ch=16, 49 | is_bf16=True, 50 | pixel_chunk_duration=pixel_chunk_duration, 51 | temporal_compression_factor=temporal_compression_factor, 52 | spatial_compression_factor=spatial_compression_factor, 53 | spatial_resolution=resolution, 54 | ) 55 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/config/config.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 
2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Any, List 17 | 18 | import attrs 19 | 20 | from cosmos_predict1.diffusion.config.base.model import DefaultModelConfig 21 | from cosmos_predict1.diffusion.config.registry import register_configs 22 | from cosmos_predict1.utils import config 23 | from cosmos_predict1.utils.config_helper import import_all_modules_from_package 24 | 25 | 26 | @attrs.define(slots=False) 27 | class Config(config.Config): 28 | # default config groups that will be used unless overwritten 29 | # see config groups in registry.py 30 | defaults: List[Any] = attrs.field( 31 | factory=lambda: [ 32 | "_self_", 33 | {"net": None}, 34 | {"conditioner": "add_fps_image_size_padding_mask"}, 35 | {"tokenizer": "tokenizer"}, 36 | {"experiment": None}, 37 | ] 38 | ) 39 | 40 | 41 | def make_config(): 42 | c = Config( 43 | model=DefaultModelConfig(), 44 | ) 45 | 46 | # Specifying values through instances of attrs 47 | c.job.project = "cosmos_diffusion" 48 | c.job.group = "inference" 49 | 50 | # Call this function to register config groups for advanced overriding. 
51 | register_configs() 52 | 53 | # experiment config are defined in the experiment folder 54 | # call import_all_modules_from_package to register them 55 | import_all_modules_from_package("cosmos_predict1.diffusion.config.inference", reload=True) 56 | return c 57 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from hydra.core.config_store import ConfigStore 17 | 18 | from cosmos_predict1.diffusion.networks.general_dit_video_conditioned_multiview import MultiviewVideoExtendGeneralDIT 19 | from cosmos_predict1.utils.lazy_config import LazyCall as L 20 | from cosmos_predict1.utils.lazy_config import LazyDict 21 | 22 | Cosmos_Predict1_Video2World_7B_Multiview: LazyDict = LazyDict( 23 | dict( 24 | defaults=[ 25 | "/experiment/Cosmos_Predict1_Text2World_7B_Multiview", 26 | {"override /conditioner": "video_cond_frame_repeat"}, 27 | "_self_", 28 | ], 29 | job=dict( 30 | group="Text2World", 31 | name="Cosmos_Predict1_Video2World_7B_Multiview", 32 | ), 33 | model=dict( 34 | latent_shape=[ 35 | 16, 36 | 16, 37 | 88, 38 | 160, 39 | ], 40 | net=L(MultiviewVideoExtendGeneralDIT)( 41 | n_views=6, 42 | view_condition_dim=6, 43 | add_repeat_frame_embedding=True, 44 | ), 45 | conditioner=dict(video_cond_bool=dict()), 46 | ), 47 | ) 48 | ) 49 | 50 | 51 | cs = ConfigStore.instance() 52 | cs.store( 53 | group="experiment", 54 | package="_global_", 55 | name=Cosmos_Predict1_Video2World_7B_Multiview["job"]["name"], 56 | node=Cosmos_Predict1_Video2World_7B_Multiview, 57 | ) 58 | -------------------------------------------------------------------------------- /cosmos_predict1/tokenizer/training/configs/base/optim.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """optimizer config options: 17 | 18 | fused_adam - FusedAdamConfig 19 | adamw - AdamWConfig 20 | """ 21 | 22 | import torch 23 | 24 | from cosmos_predict1.utils import fused_adam 25 | from cosmos_predict1.utils.lazy_config import PLACEHOLDER 26 | from cosmos_predict1.utils.lazy_config import LazyCall as L 27 | from cosmos_predict1.utils.lazy_config import LazyDict 28 | from cosmos_predict1.utils.scheduler import WarmupCosineLR, WarmupLambdaLR 29 | 30 | FusedAdamConfig: LazyDict = L(fused_adam.FusedAdam)( 31 | capturable=True, 32 | master_weights=True, 33 | adam_w_mode=True, 34 | params=PLACEHOLDER, 35 | lr=1e-4, 36 | betas=(0.5, 0.999), 37 | eps=1e-8, 38 | weight_decay=0.01, 39 | ) 40 | 41 | AdamWConfig: LazyDict = L(torch.optim.AdamW)( 42 | params=PLACEHOLDER, 43 | lr=1e-4, 44 | betas=(0.5, 0.999), 45 | eps=1e-8, 46 | weight_decay=0.01, 47 | ) 48 | 49 | WarmupLRConfig: LazyDict = L(WarmupLambdaLR)(optimizer=PLACEHOLDER, warmup=5000) 50 | 51 | FusedAdamDiscConfig: LazyDict = L(fused_adam.FusedAdam)( 52 | capturable=True, 53 | master_weights=True, 54 | adam_w_mode=True, 55 | params=PLACEHOLDER, 56 | lr=4e-4, 57 | betas=(0.5, 0.999), 58 | eps=1e-8, 59 | weight_decay=0.01, 60 | ) 61 | 62 | WarmupLRDiscConfig: LazyDict = L(WarmupLambdaLR)(optimizer=PLACEHOLDER, warmup=5000) 63 | 64 | WarmupCosineLRConfig: LazyDict = L(WarmupCosineLR)( 65 | optimizer=PLACEHOLDER, warmup_iters=5000, lr_decay_iters=1000000, min_lr=1e-8 66 | ) 67 | -------------------------------------------------------------------------------- 
/cosmos_predict1/diffusion/config/base/tokenizer.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import omegaconf 17 | 18 | from cosmos_predict1.diffusion.module.pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer 19 | from cosmos_predict1.utils.lazy_config import LazyCall as L 20 | 21 | TOKENIZER_OPTIONS = {} 22 | 23 | 24 | def tokenizer_register(key): 25 | def decorator(func): 26 | TOKENIZER_OPTIONS[key] = func 27 | return func 28 | 29 | return decorator 30 | 31 | 32 | @tokenizer_register("cosmos_diffusion_tokenizer_comp8x8x8") 33 | def get_cosmos_diffusion_tokenizer_comp8x8x8(resolution: str, chunk_duration: int) -> omegaconf.dictconfig.DictConfig: 34 | assert resolution in ["720"] 35 | 36 | pixel_chunk_duration = chunk_duration 37 | temporal_compression_factor = 8 38 | spatial_compression_factor = 8 39 | 40 | return L(JointImageVideoSharedJITTokenizer)( 41 | video_vae=L(VideoJITTokenizer)( 42 | name="cosmos_predict1_tokenizer", 43 | latent_ch=16, 44 | is_bf16=True, 45 | pixel_chunk_duration=pixel_chunk_duration, 46 | temporal_compression_factor=temporal_compression_factor, 47 | spatial_compression_factor=spatial_compression_factor, 48 | 
spatial_resolution=resolution, 49 | ), 50 | image_vae=L(JITVAE)( 51 | name="cosmos_predict1_tokenizer", 52 | latent_ch=16, 53 | is_image=False, 54 | is_bf16=True, 55 | ), 56 | name="cosmos_predict1_tokenizer", 57 | latent_ch=16, 58 | ) 59 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-gen3c.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from hydra.core.config_store import ConfigStore 17 | 18 | from cosmos_predict1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT 19 | from cosmos_predict1.utils.lazy_config import LazyCall as L 20 | from cosmos_predict1.utils.lazy_config import LazyDict 21 | 22 | GEN3C_Cosmos_7B: LazyDict = LazyDict( 23 | dict( 24 | defaults=[ 25 | {"override /net": "faditv2_7b"}, 26 | {"override /conditioner": "video_cond"}, 27 | {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, 28 | "_self_", 29 | ], 30 | model=dict( 31 | latent_shape=[ 32 | 16, 33 | 16, 34 | 88, 35 | 160, 36 | ], 37 | conditioner=dict(video_cond_bool=dict()), 38 | net=L(VideoExtendGeneralDIT)( 39 | rope_h_extrapolation_ratio=1.0, 40 | rope_w_extrapolation_ratio=1.0, 41 | rope_t_extrapolation_ratio=2.0, 42 | in_channels=16 + 16 * 4 + 1 # 16: video_latent, 16 * 4: (warped_frames + warped_frames_mask) * buffer 2, 1: mask 43 | ), 44 | frame_buffer_max=2, 45 | ), 46 | job=dict(group="Gen3c", name="GEN3C_Cosmos_7B"), 47 | ) 48 | ) 49 | 50 | cs = ConfigStore.instance() 51 | for _item in [ 52 | GEN3C_Cosmos_7B, 53 | ]: 54 | cs.store(group="experiment", package="_global_", name=_item["job"]["name"], node=_item) 55 | -------------------------------------------------------------------------------- /cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/model.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import attrs 17 | import torch 18 | import torch.nn as nn 19 | 20 | from cosmos_predict1.utils.config import make_freezable 21 | 22 | 23 | @make_freezable 24 | @attrs.define(slots=False) 25 | class ModelConfig: 26 | input_size: int = 1152 27 | num_classes: int = 7 28 | 29 | 30 | class SafetyClassifier(nn.Module): 31 | def __init__(self, input_size: int = 1024, num_classes: int = 2): 32 | super().__init__() 33 | self.input_size = input_size 34 | self.num_classes = num_classes 35 | self.layers = nn.Sequential( 36 | nn.Linear(self.input_size, 512), 37 | nn.BatchNorm1d(512), 38 | nn.ReLU(), 39 | nn.Linear(512, 256), 40 | nn.BatchNorm1d(256), 41 | nn.ReLU(), 42 | nn.Linear(256, self.num_classes), 43 | # Note: No activation function here; CrossEntropyLoss expects raw logits 44 | ) 45 | 46 | def forward(self, x): 47 | return self.layers(x) 48 | 49 | 50 | class VideoSafetyModel(nn.Module): 51 | def __init__(self, config: ModelConfig) -> None: 52 | super().__init__() 53 | self.config = config 54 | self.num_classes = config.num_classes 55 | self.network = SafetyClassifier(input_size=config.input_size, num_classes=self.num_classes) 56 | 57 | @torch.inference_mode() 58 | def forward(self, data_batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: 59 | logits = self.network(data_batch["data"].cuda()) 60 | return {"logits": logits} 61 | -------------------------------------------------------------------------------- /src/models/data/datafield.py: 
-------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from enum import Enum 17 | 18 | # class DataField(Enum): 19 | class DataField(str, Enum): 20 | # [B, C, H, W], float32, RGB image ranges from 0 to 1. 21 | IMAGE_RGB = "image_rgb" 22 | # [B, 4, 4], float32, camera-to-world transformation matrix. 23 | CAMERA_C2W_TRANSFORM = "camera_c2w_transform" 24 | # [B, 4], float32, OpenCV pinhole intrinsics represented as [fx, fy, cx, cy]. 25 | CAMERA_INTRINSICS = "camera_intrinsics" 26 | # list of captions of size B. 27 | CAPTION = "caption" 28 | # [B, H, W], float32, depth map in metric scale. 29 | METRIC_DEPTH = "metric_depth" 30 | # [B, H, W], uint8, instance mask (0 is background). 31 | DYNAMIC_INSTANCE_MASK = "dynamic_instance_mask" 32 | # [B, H, W], float32, backward flow from this frame to previous frame. 33 | BACKWARD_FLOW = "backward_flow" 34 | # [B, H, W, 3], float32, ray direction (assume no motion/RS). 35 | RAY_DIRECTION = "ray_direction" 36 | # TODO [Add description] 37 | OBJECT_BBOX = "object_bbox" 38 | # TODO [Add description] a list of float32 point cloud. 39 | POINT_CLOUD = "point_cloud" 40 | # [B, N, (3 + 3x3)], N future positions. 
For the last dim, 41 | # the first 3 are xyz locations, and the last 9 are rotations 42 | # B corresponds to the number of timestamps for the base camera type 43 | TRAJECTORY = "trajectory" 44 | # [V,] dictionary of meta data 45 | META_DATA = "meta_data" 46 | # [V, N, C] N is variable for different V float32 47 | LANGUAGE_EMBEDDING = "language_embedding" 48 | # [B, C, T, H, W], float32, latent image 49 | LATENT_RGB = "latent_rgb" 50 | -------------------------------------------------------------------------------- /cosmos_predict1/utils/lazy_config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import os 3 | 4 | from omegaconf import DictConfig, OmegaConf 5 | 6 | from cosmos_predict1.utils.lazy_config.instantiate import instantiate 7 | from cosmos_predict1.utils.lazy_config.lazy import LazyCall, LazyConfig 8 | from cosmos_predict1.utils.lazy_config.omegaconf_patch import to_object 9 | 10 | OmegaConf.to_object = to_object 11 | 12 | PLACEHOLDER = None 13 | LazyDict = DictConfig 14 | 15 | __all__ = ["instantiate", "LazyCall", "LazyConfig", "PLACEHOLDER", "LazyDict"] 16 | 17 | 18 | DOC_BUILDING = os.getenv("_DOC_BUILDING", False) # set in docs/conf.py 19 | 20 | 21 | def fixup_module_metadata(module_name, namespace, keys=None): 22 | """ 23 | Fix the __qualname__ of module members to be their exported api name, so 24 | when they are referenced in docs, sphinx can find them.
Reference: 25 | https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241 26 | """ 27 | if not DOC_BUILDING: 28 | return 29 | seen_ids = set() 30 | 31 | def fix_one(qualname, name, obj): 32 | # avoid infinite recursion (relevant when using 33 | # typing.Generic, for example) 34 | if id(obj) in seen_ids: 35 | return 36 | seen_ids.add(id(obj)) 37 | 38 | mod = getattr(obj, "__module__", None) 39 | if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")): 40 | obj.__module__ = module_name 41 | # Modules, unlike everything else in Python, put fully-qualified 42 | # names into their __name__ attribute. We check for "." to avoid 43 | # rewriting these. 44 | if hasattr(obj, "__name__") and "." not in obj.__name__: 45 | obj.__name__ = name 46 | obj.__qualname__ = qualname 47 | if isinstance(obj, type): 48 | for attr_name, attr_value in obj.__dict__.items(): 49 | fix_one(objname + "." + attr_name, attr_name, attr_value) 50 | 51 | if keys is None: 52 | keys = namespace.keys() 53 | for objname in keys: 54 | if not objname.startswith("_"): 55 | obj = namespace[objname] 56 | fix_one(objname, objname, obj) 57 | 58 | 59 | fixup_module_metadata(__name__, globals(), __all__) 60 | del fixup_module_metadata 61 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/checkpointers/ema_fsdp_checkpointer.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import warnings 17 | 18 | import attrs 19 | 20 | from cosmos_predict1.utils import log 21 | from cosmos_predict1.utils.config import CheckpointConfig as BaseCheckpointConfig 22 | from cosmos_predict1.utils.config import make_freezable 23 | from cosmos_predict1.utils.fsdp_checkpointer import FSDPCheckpointer as BaseFSDPCheckpointer 24 | 25 | 26 | @make_freezable 27 | @attrs.define(slots=False) 28 | class CheckpointConfig(BaseCheckpointConfig): 29 | load_ema_to_reg: bool = False 30 | 31 | 32 | class FSDPCheckpointer(BaseFSDPCheckpointer): 33 | def __init__(self, *args, **kwargs): 34 | super().__init__(*args, **kwargs) 35 | if not isinstance(self.config_checkpoint, CheckpointConfig): 36 | warnings.warn( 37 | "The 'config_checkpoint' is not an instance of 'CheckpointConfig'. " 38 | "This behavior is deprecated and will not be supported in future versions. 
" 39 | "Please update 'config_checkpoint' to be of type 'CheckpointConfig'.", 40 | DeprecationWarning, 41 | ) 42 | 43 | self.load_ema_to_reg = False 44 | else: 45 | self.load_ema_to_reg = self.config_checkpoint.load_ema_to_reg 46 | 47 | log.critical(f"load_ema_to_reg: {self.load_ema_to_reg}", rank0_only=False) 48 | 49 | def load_model_during_init(self, model, is_ema: bool = False, ema_id: int = 0): 50 | if self.load_ema_to_reg and is_ema is False: 51 | is_ema = True 52 | ema_id = 0 53 | log.critical("Loading EMA model to regular model during initialization.", rank0_only=False) 54 | super().load_model_during_init(model, is_ema, ema_id) 55 | -------------------------------------------------------------------------------- /src/utils/random_state_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import os 17 | import random 18 | import numpy as np 19 | import torch 20 | from accelerate.utils import ( 21 | is_xpu_available, 22 | is_torch_xla_available, 23 | ) 24 | 25 | if is_torch_xla_available(): 26 | import torch_xla.core.xla_model as xm 27 | 28 | RNG_STATE_NAME = "random_states" 29 | 30 | def save_random_state(output_dir, process_index): 31 | states = {} 32 | states_name = f"{RNG_STATE_NAME}_{process_index}.pkl" 33 | states["random_state"] = random.getstate() 34 | states["numpy_random_seed"] = np.random.get_state() 35 | states["torch_manual_seed"] = torch.get_rng_state() 36 | if is_xpu_available(): 37 | states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all() 38 | else: 39 | states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all() 40 | if is_torch_xla_available(): 41 | states["xm_seed"] = xm.get_rng_state() 42 | output_states_file = os.path.join(output_dir, states_name) 43 | torch.save(states, output_states_file) 44 | 45 | def load_random_state(input_dir, process_index): 46 | try: 47 | states = torch.load(os.path.join(input_dir, f"{RNG_STATE_NAME}_{process_index}.pkl")) 48 | random.setstate(states["random_state"]) 49 | np.random.set_state(states["numpy_random_seed"]) 50 | torch.set_rng_state(states["torch_manual_seed"]) 51 | if is_xpu_available(): 52 | torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"]) 53 | else: 54 | torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"]) 55 | if is_torch_xla_available(): 56 | xm.set_rng_state(states["xm_seed"]) 57 | except Exception: 58 | print(f"Failed to load random states from {input_dir}") -------------------------------------------------------------------------------- /cosmos_predict1/autoregressive/diffusion_decoder/config/base/conditioner.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from dataclasses import dataclass 17 | from typing import Dict, Optional 18 | 19 | import torch 20 | 21 | from cosmos_predict1.diffusion.conditioner import BaseVideoCondition, GeneralConditioner 22 | from cosmos_predict1.diffusion.config.base.conditioner import ( 23 | FPSConfig, 24 | ImageSizeConfig, 25 | LatentConditionConfig, 26 | LatentConditionSigmaConfig, 27 | NumFramesConfig, 28 | PaddingMaskConfig, 29 | TextConfig, 30 | ) 31 | from cosmos_predict1.utils.lazy_config import LazyCall as L 32 | from cosmos_predict1.utils.lazy_config import LazyDict 33 | 34 | 35 | @dataclass 36 | class VideoLatentDiffusionDecoderCondition(BaseVideoCondition): 37 | # latent_condition will concat to the input of network, along channel dim; 38 | # cfg will make latent_condition all zero padding. 
39 | latent_condition: Optional[torch.Tensor] = None 40 | latent_condition_sigma: Optional[torch.Tensor] = None 41 | 42 | 43 | class VideoDiffusionDecoderConditioner(GeneralConditioner): 44 | def forward( 45 | self, 46 | batch: Dict, 47 | override_dropout_rate: Optional[Dict[str, float]] = None, 48 | ) -> VideoLatentDiffusionDecoderCondition: 49 | output = super()._forward(batch, override_dropout_rate) 50 | return VideoLatentDiffusionDecoderCondition(**output) 51 | 52 | 53 | VideoLatentDiffusionDecoderConditionerConfig: LazyDict = L(VideoDiffusionDecoderConditioner)( 54 | text=TextConfig(), 55 | fps=FPSConfig(), 56 | num_frames=NumFramesConfig(), 57 | image_size=ImageSizeConfig(), 58 | padding_mask=PaddingMaskConfig(), 59 | latent_condition=LatentConditionConfig(), 60 | latent_condition_sigma=LatentConditionSigmaConfig(), 61 | ) 62 | -------------------------------------------------------------------------------- /cosmos_predict1/tokenizer/training/configs/base/checkpoint.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """checkpoints config options: 17 | 18 | CHECKPOINT_LOCAL: store at local file system 19 | 20 | """ 21 | import attrs 22 | 23 | from cosmos_predict1.utils import config 24 | from cosmos_predict1.utils.config import make_freezable 25 | from cosmos_predict1.utils.lazy_config import LazyDict 26 | 27 | 28 | @make_freezable 29 | @attrs.define(slots=False) 30 | class ExperimentConfig: 31 | # Enables enforcing experiment naming. 32 | enabled: bool = True 33 | # The project, e.g. edify_video4. 34 | project: str = None 35 | # The valid groups, e.g ["video"]. 36 | groups: list[str] = None 37 | # The approved name prefixes, e.g. ["DV1024", "DI256"]. 38 | name_prefixes: list[str] = None 39 | 40 | 41 | @make_freezable 42 | @attrs.define(slots=False) 43 | class TokenizerCheckpointConfig(config.CheckpointConfig): 44 | # Experiment naming configs. 45 | experiment: ExperimentConfig = attrs.field(factory=ExperimentConfig) 46 | 47 | 48 | jit_config = config.JITConfig( 49 | enabled=True, 50 | input_shape=[1, 3, 1024, 1024], 51 | ) 52 | 53 | experiment_config = ExperimentConfig( 54 | enabled=True, 55 | project="cosmos_tokenizer", 56 | groups=["debug", "video"], 57 | name_prefixes=[ 58 | f"{base}{size}" if base in ["CI", "DI"] else f"{base}{size}_Causal" 59 | for base in ["CI", "DI", "CV", "DV"] 60 | for size in [256, 320, 480, 512, 720, 1024, 1080] 61 | ] 62 | + [f"{base}{size}" for base in ["CV", "DV"] for size in [256, 320, 512, 720]] 63 | + ["mock"], 64 | ) 65 | 66 | CHECKPOINT_LOCAL: LazyDict = attrs.asdict( 67 | TokenizerCheckpointConfig( 68 | save_iter=5000, 69 | jit=jit_config, 70 | experiment=experiment_config, 71 | ) 72 | ) 73 | -------------------------------------------------------------------------------- /scripts/test_environment.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import argparse 17 | import importlib 18 | import os 19 | import sys 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument( 25 | "--training", 26 | action="store_true", 27 | help="Whether to check training-specific dependencies", 28 | ) 29 | return parser.parse_args() 30 | 31 | 32 | def check_packages(package_list): 33 | global all_success 34 | for package in package_list: 35 | try: 36 | _ = importlib.import_module(package) 37 | except Exception as e: 38 | print(f"\033[91m[ERROR]\033[0m Package not successfully imported: \033[93m{package}\033[0m") 39 | all_success = False 40 | else: 41 | print(f"\033[92m[SUCCESS]\033[0m {package} found") 42 | 43 | 44 | args = parse_args() 45 | 46 | if not (sys.version_info.major == 3 and sys.version_info.minor >= 10): 47 | detected = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" 48 | print(f"\033[91m[ERROR]\033[0m Python 3.10+ is required. 
You have: \033[93m{detected}\033[0m") 49 | sys.exit(1) 50 | 51 | if "CONDA_PREFIX" not in os.environ: 52 | print("\033[93m[WARNING]\033[0m Cosmos should be run under a conda environment.") 53 | 54 | print("Attempting to import critical packages...") 55 | 56 | packages = [ 57 | "torch", 58 | "torchvision", 59 | "diffusers", 60 | "transformers", 61 | "megatron.core", 62 | "transformer_engine", 63 | ] 64 | packages_training = [ 65 | "apex.multi_tensor_apply", 66 | ] 67 | all_success = True 68 | 69 | check_packages(packages) 70 | if args.training: 71 | check_packages(packages_training) 72 | 73 | if all_success: 74 | print("-----------------------------------------------------------") 75 | print("\033[92m[SUCCESS]\033[0m Cosmos environment setup is successful!") 76 | -------------------------------------------------------------------------------- /cosmos_predict1/autoregressive/configs/base/dataloader.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from megatron.core import parallel_state 17 | from torch.utils.data import DataLoader, DistributedSampler 18 | 19 | from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig 20 | from cosmos_predict1.autoregressive.datasets.video_dataset import VideoDataset 21 | from cosmos_predict1.utils import log 22 | from cosmos_predict1.utils.lazy_config import LazyCall as L 23 | 24 | DATALOADER_OPTIONS = {} 25 | 26 | 27 | def get_sampler(dataset): 28 | return DistributedSampler( 29 | dataset, 30 | num_replicas=parallel_state.get_data_parallel_world_size(), 31 | rank=parallel_state.get_data_parallel_rank(), 32 | shuffle=True, 33 | seed=0, 34 | ) 35 | 36 | 37 | def dataloader_register(key): 38 | log.info(f"registering dataloader {key}...") 39 | 40 | def decorator(func): 41 | DATALOADER_OPTIONS[key] = func 42 | return func 43 | 44 | return decorator 45 | 46 | 47 | @dataloader_register("tealrobot_video") 48 | def get_tealrobot_video( 49 | batch_size: int = 1, 50 | dataset_dir: str = "datasets/cosmos_nemo_assets/videos/", 51 | sequence_interval: int = 1, 52 | num_frames: int = 33, 53 | video_size: list[int, int] = [640, 848], 54 | start_frame_interval: int = 1, 55 | ): 56 | dataset = L(VideoDataset)( 57 | config=VideoDatasetConfig( 58 | dataset_dir=dataset_dir, 59 | sequence_interval=sequence_interval, 60 | num_frames=num_frames, 61 | video_size=video_size, 62 | start_frame_interval=start_frame_interval, 63 | ) 64 | ) 65 | return L(DataLoader)( 66 | dataset=dataset, 67 | sampler=L(get_sampler)(dataset=dataset), 68 | batch_size=batch_size, 69 | drop_last=True, 70 | pin_memory=True, 71 | num_workers=8, 72 | ) 73 | -------------------------------------------------------------------------------- /cosmos_predict1/tokenizer/training/configs/base/data.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """dataloader config options 17 | 18 | Available dataloader options: 19 | image_loader_basic 20 | video_loader_basic 21 | joint_image_video_loader_basic 22 | """ 23 | 24 | from torch.utils.data import DataLoader 25 | 26 | from cosmos_predict1.tokenizer.training.configs.base.mock_data import get_mock_video_dataloader 27 | from cosmos_predict1.tokenizer.training.datasets.dataset_provider import dataset_entry 28 | from cosmos_predict1.utils.lazy_config import LazyCall 29 | 30 | DATALOADER_OPTIONS = {} 31 | 32 | 33 | def dataloader_register(key): 34 | def decorator(func): 35 | DATALOADER_OPTIONS[key] = func 36 | return func 37 | 38 | return decorator 39 | 40 | 41 | @dataloader_register("video_loader_basic") 42 | def get_video_dataloader( 43 | dataset_name, 44 | is_train, 45 | batch_size=1, 46 | num_video_frames=25, 47 | resolution="720", 48 | crop_height=128, 49 | num_workers=8, 50 | ): 51 | if dataset_name.startswith("mock"): 52 | return get_mock_video_dataloader( 53 | batch_size=batch_size, 54 | is_train=is_train, 55 | num_video_frames=num_video_frames, 56 | resolution=resolution, 57 | crop_height=crop_height, 58 | ) 59 | return LazyCall(DataLoader)( 60 | dataset=LazyCall(dataset_entry)( 61 | dataset_name=dataset_name, 62 | dataset_type="video", 63 | is_train=is_train, 64 | resolution=resolution, 65 | crop_height=crop_height, 66 | 
num_video_frames=num_video_frames, 67 | ), 68 | batch_size=batch_size, # 2 69 | num_workers=num_workers, # 8 70 | prefetch_factor=2, 71 | shuffle=None, # do we need this? 72 | sampler=None, 73 | persistent_workers=False, 74 | pin_memory=True, 75 | ) 76 | -------------------------------------------------------------------------------- /src/models/data/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import time 17 | import torch 18 | from src.models.data.provider import Provider 19 | 20 | def get_multi_dataloader(opt, accelerator=None): 21 | train_datasets, test_datasets = get_datasets(opt, accelerator) 22 | train_dataset = torch.utils.data.ConcatDataset(train_datasets) 23 | 24 | train_dataloader = torch.utils.data.DataLoader( 25 | train_dataset, 26 | batch_size=opt.batch_size, 27 | shuffle=True, 28 | num_workers=opt.num_workers, 29 | pin_memory=True, 30 | drop_last=True, 31 | ) 32 | 33 | test_dataset = torch.utils.data.ConcatDataset(test_datasets) 34 | test_dataloader = torch.utils.data.DataLoader( 35 | test_dataset, 36 | batch_size=opt.batch_size, 37 | shuffle=False, 38 | num_workers=opt.num_workers, 39 | pin_memory=True, 40 | drop_last=False, 41 | ) 42 | 43 | 44 | return train_dataloader, test_dataloader 45 | 46 | def get_datasets(opt, accelerator=None): 47 | train_datasets = [] 48 | test_datasets = [] 49 | 50 | for idx in range(len(opt.data_mode)): 51 | begin_time = time.time() 52 | if isinstance(opt.data_mode[idx], str): 53 | dataset_name, num_repeat = opt.data_mode[idx], 1 54 | else: 55 | dataset_name, num_repeat = opt.data_mode[idx] 56 | 57 | train_dataset = Provider(dataset_name, opt, training=True, num_repeat=num_repeat) 58 | train_datasets.append(train_dataset) 59 | 60 | test_dataset = Provider(dataset_name, opt, training=False, num_repeat=num_repeat) 61 | test_datasets.append(test_dataset) 62 | if accelerator is None or accelerator.is_main_process: 63 | print(f"Loaded {dataset_name}, train size: {len(train_dataset)}, test size: {len(test_dataset)}, loading took {time.time() - begin_time} seconds") 64 | 65 | return train_datasets, test_datasets 66 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to receive your patches and contributions. 
Please keep your PRs in draft until you would like us to review them. 4 | 5 | ## Code Reviews 6 | 7 | All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult 8 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. 9 | 10 | ## Signing Your Work 11 | 12 | * We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. 13 | 14 | * Any contribution which contains commits that are not Signed-Off will not be accepted. 15 | 16 | * To sign off on a commit you simply use the `--signoff` (or `-s`) option when committing your changes: 17 | ```bash 18 | $ git commit -s -m "Add cool feature." 19 | ``` 20 | This will append the following to your commit message: 21 | ``` 22 | Signed-off-by: Your Name <your@email.com> 23 | ``` 24 | 25 | * Full text of the DCO: 26 | 27 | ``` 28 | Developer Certificate of Origin 29 | Version 1.1 30 | 31 | Copyright (C) 2004, 2006 The Linux Foundation and its contributors. 32 | 1 Letterman Drive 33 | Suite D4700 34 | San Francisco, CA, 94129 35 | 36 | Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed.
37 | ``` 38 | 39 | ``` 40 | Developer's Certificate of Origin 1.1 41 | 42 | By making a contribution to this project, I certify that: 43 | 44 | (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or 45 | 46 | (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or 47 | 48 | (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. 49 | 50 | (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. 51 | ``` -------------------------------------------------------------------------------- /cosmos_predict1/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Any, List 17 | 18 | import attrs 19 | 20 | from cosmos_predict1.autoregressive.diffusion_decoder.config.registry import register_configs as register_dd_configs 21 | from cosmos_predict1.diffusion.config.base.model import LatentDiffusionDecoderModelConfig 22 | from cosmos_predict1.diffusion.config.registry import register_configs 23 | from cosmos_predict1.utils import config 24 | from cosmos_predict1.utils.config_helper import import_all_modules_from_package 25 | 26 | 27 | @attrs.define(slots=False) 28 | class Config(config.Config): 29 | # default config groups that will be used unless overwritten 30 | # see config groups in registry.py 31 | defaults: List[Any] = attrs.field( 32 | factory=lambda: [ 33 | "_self_", 34 | {"net": None}, 35 | {"conditioner": "basic"}, 36 | {"tokenizer": "tokenizer"}, 37 | {"tokenizer_corruptor": None}, 38 | {"latent_corruptor": None}, 39 | {"pixel_corruptor": None}, 40 | {"experiment": None}, 41 | ] 42 | ) 43 | 44 | 45 | def make_config(): 46 | c = Config(model=LatentDiffusionDecoderModelConfig()) 47 | 48 | # Specifying values through instances of attrs 49 | c.job.project = "cosmos_video4" 50 | c.job.group = "debug" 51 | c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}" 52 | 53 | # Call this function to register config groups for advanced overriding. 
54 | register_configs() 55 | register_dd_configs() 56 | 57 | # experiment config are defined in the experiment folder 58 | # call import_all_modules_from_package to register them 59 | import_all_modules_from_package("cosmos_predict1.diffusion.config.inference", reload=True) 60 | import_all_modules_from_package("cosmos_predict1.autoregressive.diffusion_decoder.config.inference", reload=True) 61 | return c 62 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/config/base/net.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import copy 17 | 18 | from cosmos_predict1.diffusion.networks.general_dit import GeneralDIT 19 | from cosmos_predict1.diffusion.networks.general_dit_multiview import MultiviewGeneralDIT 20 | from cosmos_predict1.utils.lazy_config import LazyCall as L 21 | from cosmos_predict1.utils.lazy_config import LazyDict 22 | 23 | FADITV2Config: LazyDict = L(GeneralDIT)( 24 | max_img_h=240, 25 | max_img_w=240, 26 | max_frames=128, 27 | in_channels=16, 28 | out_channels=16, 29 | patch_spatial=2, 30 | patch_temporal=1, 31 | model_channels=4096, 32 | block_config="FA-CA-MLP", 33 | num_blocks=28, 34 | num_heads=32, 35 | concat_padding_mask=True, 36 | pos_emb_cls="rope3d", 37 | pos_emb_learnable=False, 38 | pos_emb_interpolation="crop", 39 | block_x_format="THWBD", 40 | affline_emb_norm=True, 41 | use_adaln_lora=True, 42 | adaln_lora_dim=256, 43 | ) 44 | 45 | 46 | FADITV2_14B_Config = copy.deepcopy(FADITV2Config) 47 | FADITV2_14B_Config.model_channels = 5120 48 | FADITV2_14B_Config.num_heads = 40 49 | FADITV2_14B_Config.num_blocks = 36 50 | 51 | 52 | FADITV2_Multiview_Config: LazyDict = L(MultiviewGeneralDIT)( 53 | max_img_h=240, 54 | max_img_w=240, 55 | max_frames=128, 56 | in_channels=16, 57 | out_channels=16, 58 | patch_spatial=2, 59 | patch_temporal=1, 60 | model_channels=4096, 61 | block_config="FA-CA-MLP", 62 | num_blocks=28, 63 | num_heads=32, 64 | concat_padding_mask=True, 65 | pos_emb_cls="rope3d", 66 | pos_emb_learnable=False, 67 | pos_emb_interpolation="crop", 68 | block_x_format="THWBD", 69 | affline_emb_norm=True, 70 | use_adaln_lora=True, 71 | adaln_lora_dim=256, 72 | n_views=6, 73 | view_condition_dim=6, 74 | add_repeat_frame_embedding=True, 75 | rope_h_extrapolation_ratio=1.0, 76 | rope_w_extrapolation_ratio=1.0, 77 | rope_t_extrapolation_ratio=1.0, 78 | ) 79 | -------------------------------------------------------------------------------- /src/models/data/radym_wrapper.py: -------------------------------------------------------------------------------- 1 
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | import os 18 | from typing import Any, List, Optional 19 | from src.models.data.radym import Radym 20 | 21 | class RadymWrapper(Radym): 22 | def __init__(self, is_static: bool = True, is_multi_view: bool = False, **kwargs): 23 | super().__init__(**kwargs) 24 | 25 | # For recon code base 26 | self.is_static = is_static 27 | self.sample_list = self.mp4_file_paths 28 | self.num_cameras = len([camera_name for camera_name in os.listdir(self.root_path) if camera_name != 'flag']) if is_multi_view else 1 29 | if is_multi_view: 30 | self.n_views = self.num_cameras 31 | 32 | def __len__(self): 33 | return len(self.sample_list) 34 | 35 | def count_frames(self, video_idx: int): 36 | return self.num_frames(video_idx) 37 | 38 | def count_cameras(self, video_idx: int): 39 | return self.num_cameras 40 | 41 | def get_data( 42 | self, 43 | idx, 44 | data_fields: List[str], 45 | frame_indices: Optional[List[int]] = None, 46 | view_indices: Optional[List[int]] = None, 47 | camera_convention: str = "opencv", 48 | ): 49 | assert camera_convention == 'opencv', f"No support for camera convention {camera_convention}" 50 | if view_indices is None or len(view_indices) == 0: 51 | view_indices = list(range(self.count_cameras(idx))) 52 | 
final_dict = None 53 | for view_idx in view_indices: 54 | output_dict = self._read_data( 55 | idx, frame_indices, [view_idx], data_fields, 56 | ) 57 | if final_dict is None: 58 | final_dict = output_dict 59 | else: 60 | for k in final_dict: 61 | if k == "__key__": 62 | continue 63 | final_dict[k] = torch.concatenate([final_dict[k], output_dict[k]]) 64 | return final_dict -------------------------------------------------------------------------------- /cosmos_predict1/utils/lazy_config/registry.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import pydoc 17 | from typing import Any 18 | 19 | """ 20 | `locate` provide ways to map a string (typically found 21 | in config files) to callable objects. 22 | """ 23 | 24 | __all__ = ["locate"] 25 | 26 | 27 | def _convert_target_to_string(t: Any) -> str: 28 | """ 29 | Inverse of ``locate()``. 30 | 31 | Args: 32 | t: any object with ``__module__`` and ``__qualname__`` 33 | """ 34 | module, qualname = t.__module__, t.__qualname__ 35 | 36 | # Compress the path to this object, e.g. ``module.submodule._impl.class`` 37 | # may become ``module.submodule.class``, if the later also resolves to the same 38 | # object. 
This simplifies the string, and also is less affected by moving the 39 | # class implementation. 40 | module_parts = module.split(".") 41 | for k in range(1, len(module_parts)): 42 | prefix = ".".join(module_parts[:k]) 43 | candidate = f"{prefix}.{qualname}" 44 | try: 45 | if locate(candidate) is t: 46 | return candidate 47 | except ImportError: 48 | pass 49 | return f"{module}.{qualname}" 50 | 51 | 52 | def locate(name: str) -> Any: 53 | """ 54 | Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``, 55 | such as "module.submodule.class_name". 56 | 57 | Raise Exception if it cannot be found. 58 | """ 59 | obj = pydoc.locate(name) 60 | 61 | # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly 62 | # by pydoc.locate. Try a private function from hydra. 63 | if obj is None: 64 | try: 65 | # from hydra.utils import get_method - will print many errors 66 | from hydra.utils import _locate 67 | except ImportError as e: 68 | raise ImportError(f"Cannot dynamically locate object {name}!") from e 69 | else: 70 | obj = _locate(name) # it raises if fails 71 | 72 | return obj 73 | -------------------------------------------------------------------------------- /scripts/download_guardrail_checkpoints.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | from typing import List 18 | 19 | from huggingface_hub import snapshot_download 20 | 21 | 22 | def download_models(models: List[str], destination_root: str): 23 | """ 24 | Download models from Hugging Face Hub and save them in org/project structure. 25 | 26 | Args: 27 | models: List of model IDs in format 'org/project' 28 | destination_root: Root directory where models will be saved 29 | """ 30 | for model_id in models: 31 | model_id, revision = model_id.split(":") if ":" in model_id else (model_id, None) 32 | print(f"Downloading {model_id}...") 33 | 34 | # Create the full path for the model 35 | model_path = os.path.join(destination_root, model_id) 36 | 37 | try: 38 | # Download the model 39 | snapshot_download( 40 | repo_id=model_id, 41 | local_dir=model_path, 42 | revision=revision, 43 | ) 44 | print(f"Successfully downloaded {model_id} to {model_path}") 45 | 46 | except Exception as e: 47 | raise RuntimeError(f"Error downloading {model_id}: {str(e)}. Please delete the directory and try again.") 48 | 49 | 50 | def download_guardrail_checkpoints(destination_root: str): 51 | """ 52 | Download guardrail checkpoints from Hugging Face Hub and save them in org/project structure. 
53 | 54 | Args: 55 | destination_root: Root directory where checkpoints will be saved 56 | """ 57 | # List of models to download 58 | models_to_download = [ 59 | "meta-llama/Llama-Guard-3-8B", 60 | "nvidia/Cosmos-Guardrail1", 61 | ] 62 | 63 | # Create the destination directory if it doesn't exist 64 | os.makedirs(destination_root, exist_ok=True) 65 | 66 | # Download the models 67 | download_models(models_to_download, destination_root) 68 | 69 | 70 | if __name__ == "__main__": 71 | download_guardrail_checkpoints("checkpoints") 72 | -------------------------------------------------------------------------------- /src/models/utils/misc.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | from omegaconf import OmegaConf 18 | from typing import List 19 | 20 | def load_and_merge_configs(config_paths: List[str]): 21 | """ 22 | Load and merge multiple OmegaConf configs in order. 23 | Later configs override earlier ones. 24 | Any missing keys in later configs are added to the schema as None. 25 | 26 | Args: 27 | config_paths (List[str]): List of paths to config files. 28 | The first config acts as the base schema. 29 | 30 | Returns: 31 | OmegaConf.DictConfig: The merged configuration. 
32 | """ 33 | if not config_paths: 34 | raise ValueError("No config paths provided.") 35 | 36 | # Start with the first config as schema 37 | schema = OmegaConf.load(config_paths[0]) 38 | 39 | # Iteratively merge the rest 40 | for path in config_paths[1:]: 41 | cfg = OmegaConf.load(path) 42 | 43 | # Add missing keys into schema 44 | missing_keys = set(cfg.keys()) - set(schema.keys()) 45 | for key in missing_keys: 46 | OmegaConf.update(schema, key, None, force_add=True) 47 | 48 | # Merge current config into schema 49 | schema = OmegaConf.merge(schema, cfg) 50 | 51 | return schema 52 | 53 | 54 | def seed_everything(seed: int): 55 | import random, os 56 | import numpy as np 57 | import torch 58 | 59 | random.seed(seed) 60 | os.environ['PYTHONHASHSEED'] = str(seed) 61 | np.random.seed(seed) 62 | torch.manual_seed(seed) 63 | torch.cuda.manual_seed(seed) 64 | # torch.backends.cudnn.deterministic = True 65 | # torch.backends.cudnn.benchmark = True 66 | 67 | dtype_map = { 68 | 'float32': torch.float32, 69 | 'float': torch.float32, 70 | 'float64': torch.float64, 71 | 'double': torch.float64, 72 | 'float16': torch.float16, 73 | 'half': torch.float16, 74 | 'bfloat16': torch.bfloat16, 75 | 'int32': torch.int32, 76 | 'int64': torch.int64, 77 | 'long': torch.int64, 78 | } 79 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | config_file=configs/accelerate/accelerate_config.yaml 17 | train_file=train.py 18 | 19 | ### Static 3D Generation ### 20 | 21 | # Stage 1 22 | config1=configs/training/3dgs_res_176_320_views_17.yaml 23 | # accelerate launch --config_file $config_file $train_file --config $config1 24 | 25 | # Stage 2 26 | config2=configs/training/3dgs_res_176_320_views_49.yaml 27 | python src/utils/copy_to_resume.py $config1 $config2 28 | # accelerate launch --config_file $config_file $train_file --config $config2 29 | 30 | # Stage 3 31 | config3=configs/training/3dgs_res_352_640_views_49.yaml 32 | python src/utils/copy_to_resume.py $config2 $config3 33 | # accelerate launch --config_file $config_file $train_file --config $config3 34 | 35 | # Stage 4 36 | config4=configs/training/3dgs_res_704_1280_views_49.yaml 37 | python src/utils/copy_to_resume.py $config3 $config4 38 | # accelerate launch --config_file $config_file $train_file --config $config4 39 | 40 | # Stage 5 41 | config5=configs/training/3dgs_res_704_1280_views_121.yaml 42 | python src/utils/copy_to_resume.py $config4 $config5 43 | # accelerate launch --config_file $config_file $train_file --config $config5 44 | 45 | # Stage 6 46 | config6=configs/training/3dgs_res_704_1280_views_121_multi_6.yaml 47 | python src/utils/copy_to_resume.py $config5 $config6 48 | # accelerate launch --config_file $config_file $train_file --config $config6 49 | 50 | # Stage 7 51 | config7=configs/training/3dgs_res_704_1280_views_121_multi_6_prune.yaml 52 | python src/utils/copy_to_resume.py 
$config6 $config7 53 | # accelerate launch --config_file $config_file $train_file --config $config7 54 | 55 | ### Dynamic 3D Generation ### 56 | # Stage 8 57 | config8=configs/training/3dgs_res_704_1280_views_121_multi_6_dynamic.yaml 58 | python src/utils/copy_to_resume.py $config6 $config8 59 | accelerate launch --config_file $config_file $train_file --config $config8 60 | 61 | config9=configs/training/3dgs_res_704_1280_views_121_multi_6_dynamic_prune.yaml 62 | python src/utils/copy_to_resume.py $config8 $config9 63 | accelerate launch --config_file $config_file $train_file --config $config9 64 | -------------------------------------------------------------------------------- /cosmos_predict1/utils/device.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import math 17 | import os 18 | 19 | import pynvml 20 | 21 | 22 | class Device: 23 | """A class to handle NVIDIA GPU device operations using NVML. 24 | 25 | This class provides an interface to access and manage NVIDIA GPU devices, 26 | including retrieving device information and CPU affinity settings. 
27 | 28 | Attributes: 29 | _nvml_affinity_elements (int): Number of 64-bit elements needed to represent CPU affinity 30 | """ 31 | 32 | _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore 33 | 34 | def __init__(self, device_idx: int): 35 | """Initialize a Device instance for a specific GPU. 36 | 37 | Args: 38 | device_idx (int): Index of the GPU device to manage 39 | 40 | Raises: 41 | NVMLError: If the device cannot be found or initialized 42 | """ 43 | super().__init__() 44 | self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) 45 | 46 | def get_cpu_affinity(self) -> list[int]: 47 | """Get the CPU affinity mask for this GPU device. 48 | 49 | Retrieves the CPU affinity mask indicating which CPU cores are assigned 50 | to this GPU device. The affinity is returned as a list of CPU core indices. 51 | 52 | Returns: 53 | list[int]: List of CPU core indices that have affinity with this GPU 54 | 55 | Raises: 56 | NVMLError: If the CPU affinity information cannot be retrieved 57 | 58 | Example: 59 | >>> device = Device(0) 60 | >>> device.get_cpu_affinity() 61 | [0, 1, 2, 3] # Shows this GPU has affinity with CPU cores 0-3 62 | """ 63 | affinity_string = "" 64 | for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements): 65 | # assume nvml returns list of 64 bit ints 66 | affinity_string = "{:064b}".format(j) + affinity_string 67 | affinity_list = [int(x) for x in affinity_string] 68 | affinity_list.reverse() # so core 0 is in 0th element of list 69 | return [i for i, e in enumerate(affinity_list) if e != 0] 70 | -------------------------------------------------------------------------------- /cosmos_predict1/diffusion/config/base/model.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional

import attrs

from cosmos_predict1.utils.lazy_config import LazyDict


@attrs.define(slots=False)
class DefaultModelConfig:
    """Default configuration for the diffusion model.

    ``tokenizer``, ``conditioner`` and ``net`` default to ``None``; presumably they are
    filled in by experiment configs — confirm against the registry.
    """

    tokenizer: LazyDict = None
    conditioner: LazyDict = None
    net: LazyDict = None
    sigma_data: float = 0.5
    precision: str = "bfloat16"
    input_data_key: str = "video"  # key to fetch input data from data_batch
    # attrs stores plain defaults by reference, so a list literal would be one object shared
    # (and mutable) across every instance; the factory builds a fresh list per instance.
    latent_shape: List[int] = attrs.field(factory=lambda: [16, 24, 44, 80])  # 24 corresponding to 136 frames
    input_image_key: str = "images_1024"
    adjust_video_noise: bool = False
    context_parallel_size: int = 1
    # `num_latents_to_drop` is a flag that helps satisfy (1I,N*P,1I) latents setup.
    # Since the tokenizer is causal and has the `T+1` input frames setup, it's
    # challenging to encode arbitrary number of frames. To circumvent this,
    # we sample as many frames, run the tokenizer twice, and discard the last
    # chunk's P-latents, ensuring the requirement: I-latents for the input frames
    # and P-latent for the-to-be-predicted in-between frames.
    # By default, this flag does not have any effect.
    num_latents_to_drop: int = 0  # number of P-latents to discard after encoding

    sde: Optional[Dict] = None
    vae: Optional[Dict] = None
    peft_control: LazyDict | None = None
    frame_buffer_max: Optional[int] = 1


@attrs.define(slots=False)
class LatentDiffusionDecoderModelConfig(DefaultModelConfig):
    """DefaultModelConfig plus the corruptor / conditioning-noise settings of the diffusion decoder."""

    tokenizer_corruptor: LazyDict = None
    latent_corruptor: LazyDict = None
    pixel_corruptor: LazyDict = None
    diffusion_decoder_cond_sigma_low: float = None
    diffusion_decoder_cond_sigma_high: float = None
    diffusion_decoder_corrupt_prob: float = None
    condition_on_tokenizer_corruptor_token: bool = False


@attrs.define(slots=False)
class MultiviewModelConfig(DefaultModelConfig):
    """DefaultModelConfig with a number of camera views (default 4)."""

    n_views: int = 4

# --------------------------------------------------------------------------------
# /cosmos_predict1/utils/lazy_config/omegaconf_patch.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Union

from omegaconf import OmegaConf
from omegaconf.base import DictKeyType, SCMode
from omegaconf.dictconfig import DictConfig  # pragma: no cover


def to_object(cfg: Any) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]:
    """Convert an OmegaConf config to native Python containers, except for LazyCall configs.

    This is a patched variant of OmegaConf's own ``to_object``. A ``DictConfig`` that has
    a ``_target_`` key (the marker LazyCall writes) is handed back unchanged, so configs
    meant for lazy instantiation keep their structure. Everything else goes through
    ``OmegaConf.to_container`` with interpolations resolved, missing values raising,
    enums kept as enums, and structured configs instantiated.

    Reference:
        Original OmegaConf ``to_object``:
        https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595

    Args:
        cfg (Any): The OmegaConf configuration object to convert.

    Returns:
        Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: ``cfg`` itself when it
        is a LazyCall config, otherwise the converted Python container.

    Examples:
        >>> cfg = DictConfig({"key": "value", "_target_": "Model"})
        >>> to_object(cfg) is cfg
        True

        >>> cfg = DictConfig({"list": [1, 2, 3]})
        >>> to_object(cfg)
        {'list': [1, 2, 3]}
    """
    created_by_lazy_call = isinstance(cfg, DictConfig) and "_target_" in cfg.keys()
    if not created_by_lazy_call:
        return OmegaConf.to_container(
            cfg=cfg,
            resolve=True,
            throw_on_missing=True,
            enum_to_str=False,
            structured_config_mode=SCMode.INSTANTIATE,
        )
    # LazyCall config: converting it to a plain dict would lose its lazy semantics.
    return cfg

# --------------------------------------------------------------------------------
# /cosmos_predict1/utils/scheduler.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List

import torch


class WarmupLambdaLR(torch.optim.lr_scheduler.LambdaLR):
    """Linearly warm the learning rate up over ``warmup`` epochs, then hold it constant.

    The multiplier applied to every base learning rate is ``(epoch + 1) / warmup`` while
    ``epoch < warmup`` and ``1.0`` afterwards.

    Args:
        optimizer: Wrapped optimizer.
        warmup: Number of warmup epochs (must be > 0).
        last_epoch: Index of the last epoch, ``-1`` to start fresh.
        verbose: Forwarded to ``LambdaLR`` only when truthy (see below).
    """

    def __init__(self, optimizer, warmup, last_epoch=-1, verbose=False):
        self.warmup = warmup

        def lr_lambda(epoch):
            # Increase lr linearly for the first 'warmup' epochs
            if epoch < warmup:
                return float(epoch + 1) / warmup
            # After 'warmup' epochs, keep lr constant
            return 1.0

        # Recent PyTorch releases deprecate the `verbose` argument of LR schedulers (passing
        # any value warns, and later releases may drop it), so forward it only on request.
        if verbose:
            super().__init__(optimizer, lr_lambda, last_epoch, verbose)
        else:
            super().__init__(optimizer, lr_lambda, last_epoch)


# cosine lr decay scheduler with warmup from https://github.com/karpathy/nanoGPT/blob/master/train.py#L228
class WarmupCosineLR(torch.optim.lr_scheduler.LRScheduler):
    """Linear warmup followed by cosine decay down to ``min_lr``.

    With ``t = last_epoch``:

    * ``t < warmup_iters``: ``base_lr * t / warmup_iters``.
    * ``warmup_iters <= t <= lr_decay_iters``: cosine from ``base_lr`` down to ``min_lr``.
    * ``t > lr_decay_iters``: ``min_lr``.

    When ``lr_decay_iters <= warmup_iters`` there is no decay window, and the learning rate
    drops to ``min_lr`` as soon as warmup ends. Previously the equal case divided by zero.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_iters: int,
        lr_decay_iters: int,
        min_lr: float,
        last_epoch: int = -1,
    ):
        self.warmup_iters = warmup_iters
        self.lr_decay_iters = lr_decay_iters
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        step = self.last_epoch
        # 1) linear warmup for warmup_iters steps
        if step < self.warmup_iters:
            return [base_lr * step / self.warmup_iters for base_lr in self.base_lrs]
        # 2) past the decay window, or no decay window at all: hold at min learning rate
        decay_span = self.lr_decay_iters - self.warmup_iters
        if step > self.lr_decay_iters or decay_span <= 0:
            return [self.min_lr for _ in self.base_lrs]
        # 3) in between, use cosine decay down to min learning rate;
        #    here warmup_iters <= step <= lr_decay_iters and decay_span > 0, so ratio is in [0, 1]
        decay_ratio = (step - self.warmup_iters) / decay_span
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
        return [self.min_lr + coeff * (base_lr - self.min_lr) for base_lr in self.base_lrs]
# --------------------------------------------------------------------------------
# /cosmos_predict1/callbacks/grad_clip.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from cosmos_predict1.utils import distributed
from cosmos_predict1.utils.callback import Callback


@torch.jit.script
def _fused_nan_to_num(params: List[torch.Tensor]):
    # In place: NaN, +inf and -inf entries of every tensor become 0.0.
    for tensor in params:
        torch.nan_to_num(tensor, nan=0.0, posinf=0.0, neginf=0.0, out=tensor)


class GradClip(Callback):
    """Callback that sanitizes and clips gradients right before the optimizer step.

    Args:
        clip_norm: Maximum total gradient norm handed to the clipping routine.
        force_finite: When True, NaN/+inf/-inf gradient entries are zeroed before clipping.
        model_key: Optional dotted attribute path (e.g. ``"net.blocks"``) selecting the
            sub-module whose gradients are processed.
        fsdp_enabled: When True and the selected module is an FSDP wrapper, FSDP's own
            ``clip_grad_norm_`` is used instead of ``torch.nn.utils.clip_grad_norm_``.
    """

    def __init__(
        self, clip_norm=1.0, force_finite: bool = True, model_key: Optional[str] = None, fsdp_enabled: bool = False
    ):
        self.clip_norm, self.force_finite = clip_norm, force_finite
        self.model_key, self.fsdp_enabled = model_key, fsdp_enabled

    def on_before_optimizer_step(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int = 0,
    ) -> None:
        """Optionally zero non-finite gradients, then clip their total norm in place.

        ``grad_scaler`` and ``iteration`` are not used by this callback.
        """
        del optimizer, scheduler
        is_ddp = isinstance(model_ddp, distributed.DistributedDataParallel)
        model = model_ddp.module if is_ddp else model_ddp

        # Walk the dotted path down to the sub-network to clip.
        if self.model_key is not None:
            for attr_name in self.model_key.split("."):
                model = getattr(model, attr_name)

        if self.force_finite:
            grads = [p.grad for p in model.parameters() if p.grad is not None]
            _fused_nan_to_num(grads)

        # FSDP-sharded modules must compute the norm across shards via their own method.
        if self.fsdp_enabled and isinstance(model, FSDP):
            model.clip_grad_norm_(self.clip_norm)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True)

# --------------------------------------------------------------------------------
# /cosmos_predict1/utils/callbacks/grad_clip.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from cosmos_predict1.utils import distributed
from cosmos_predict1.utils.callback import Callback


@torch.jit.script
def _fused_nan_to_num(params: List[torch.Tensor]):
    # In place: NaN, +inf and -inf entries of every tensor become 0.0.
    for tensor in params:
        torch.nan_to_num(tensor, nan=0.0, posinf=0.0, neginf=0.0, out=tensor)


class GradClip(Callback):
    """Callback that sanitizes and clips gradients right before the optimizer step.

    Args:
        clip_norm: Maximum total gradient norm handed to the clipping routine.
        force_finite: When True, NaN/+inf/-inf gradient entries are zeroed before clipping.
        model_key: Optional dotted attribute path (e.g. ``"net.blocks"``) selecting the
            sub-module whose gradients are processed.
        fsdp_enabled: When True and the selected module is an FSDP wrapper, FSDP's own
            ``clip_grad_norm_`` is used instead of ``torch.nn.utils.clip_grad_norm_``.
    """

    def __init__(
        self, clip_norm=1.0, force_finite: bool = True, model_key: Optional[str] = None, fsdp_enabled: bool = False
    ):
        self.clip_norm, self.force_finite = clip_norm, force_finite
        self.model_key, self.fsdp_enabled = model_key, fsdp_enabled

    def on_before_optimizer_step(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int = 0,
    ) -> None:
        """Optionally zero non-finite gradients, then clip their total norm in place.

        ``grad_scaler`` and ``iteration`` are not used by this callback.
        """
        del optimizer, scheduler
        is_ddp = isinstance(model_ddp, distributed.DistributedDataParallel)
        model = model_ddp.module if is_ddp else model_ddp

        # Walk the dotted path down to the sub-network to clip.
        if self.model_key is not None:
            for attr_name in self.model_key.split("."):
                model = getattr(model, attr_name)

        if self.force_finite:
            grads = [p.grad for p in model.parameters() if p.grad is not None]
            _fused_nan_to_num(grads)

        # FSDP-sharded modules must compute the norm across shards via their own method.
        if self.fsdp_enabled and isinstance(model, FSDP):
            model.clip_grad_norm_(self.clip_norm)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True)

# --------------------------------------------------------------------------------
# /cosmos_predict1/diffusion/training/config/base/model.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import attrs

from cosmos_predict1.diffusion.training.config.base.ema import PowerEMAConfig
from cosmos_predict1.diffusion.training.modules.edm_sde import EDMSDE
from cosmos_predict1.utils.lazy_config import LazyCall as L
from cosmos_predict1.utils.lazy_config import LazyDict


@attrs.define(slots=False)
class FSDPConfig:
    """FSDP wrapping settings (policy, activation checkpointing, sharding group/strategy)."""

    policy: str = "block"
    checkpoint: bool = False
    min_num_params: int = 1024
    sharding_group_size: int = 8
    sharding_strategy: str = "full"


@attrs.define(slots=False)
class DefaultModelConfig:
    """Default training configuration for the diffusion model.

    Mutable defaults (lists, LazyDict / LazyCall objects) are built by factories. attrs
    stores plain defaults by reference, so a literal would be one object shared by, and
    mutable through, every config instance.
    """

    vae: LazyDict = None
    conditioner: LazyDict = None
    net: LazyDict = None
    # NOTE(review): PowerEMAConfig is a single module-level object, so every instance shares it;
    # consider a factory returning a copy if configs are ever mutated in place.
    ema: LazyDict = PowerEMAConfig
    sde: LazyDict = attrs.field(
        factory=lambda: L(EDMSDE)(
            p_mean=0.0,
            p_std=1.0,
            sigma_max=80,
            sigma_min=0.0002,
        )
    )
    sigma_data: float = 0.5
    camera_sample_weight: LazyDict = attrs.field(
        factory=lambda: LazyDict(
            dict(
                enabled=False,
                weight=5.0,
            )
        )
    )
    aesthetic_finetuning: LazyDict = attrs.field(
        factory=lambda: LazyDict(
            dict(
                enabled=False,
            )
        )
    )
    loss_mask_enabled: bool = False
    loss_masking: LazyDict = None
    loss_add_logvar: bool = True
    precision: str = "bfloat16"
    input_data_key: str = "video"  # key to fetch input data from data_batch
    input_image_key: str = "images_1024"  # key to fetch input image from data_batch
    loss_reduce: str = "sum"
    loss_scale: float = 1.0
    latent_shape: List[int] = attrs.field(factory=lambda: [16, 24, 44, 80])  # 24 corresponding to 136 frames
    fsdp_enabled: bool = False
    use_torch_compile: bool = False
    fsdp: FSDPConfig = attrs.field(factory=FSDPConfig)
    use_dummy_temporal_dim: bool = False  # Whether to use dummy temporal dimension in data
    adjust_video_noise: bool = False  # whether or not adjust video noise according to the video length
    peft_control: LazyDict | None = None


@attrs.define(slots=False)
class MultiviewModelConfig(DefaultModelConfig):
    """DefaultModelConfig with a number of camera views (default 6)."""

    n_views: int = 6

# --------------------------------------------------------------------------------
# /cosmos_predict1/diffusion/training/utils/optim_instantiate.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hydra
import torch
from torch import nn

from cosmos_predict1.utils import log
from cosmos_predict1.utils.fused_adam import FusedAdam


def get_regular_param_group(net: nn.Module):
    """Separate the trainable parameters of ``net`` into decay and no_decay groups.

    Based on the nanoGPT codebase: tensors with 2 or more dims go to the decay group and
    0/1-dim tensors (typically biases and norm scales) go to the no_decay group.
    Parameters with ``requires_grad=False`` are excluded.

    Returns:
        Tuple ``(decay_params, nodecay_params)``, each in ``named_parameters()`` order.
    """
    trainable = [p for _, p in net.named_parameters() if p.requires_grad]
    decay_params = [p for p in trainable if p.dim() >= 2]
    nodecay_params = [p for p in trainable if p.dim() < 2]
    return decay_params, nodecay_params


def get_base_optimizer(
    model: nn.Module,
    lr: float,
    weight_decay: float,
    optim_type: str = "adamw",
    sharding: bool = False,
    **kwargs,
) -> torch.optim.Optimizer:
    """Build an AdamW / FusedAdam optimizer over the trainable parameters of ``model``.

    Args:
        model: Module whose trainable parameters are optimized.
        lr: Learning rate of the (single) parameter group.
        weight_decay: Weight decay of the (single) parameter group.
        optim_type: ``"adamw"`` (``torch.optim.AdamW``) or ``"fusedadam"`` (``FusedAdam``).
        sharding: Accepted for interface compatibility; not used here.
        **kwargs: Forwarded to the optimizer constructor.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: If ``optim_type`` is not recognised.
    """
    net_decay_param, net_nodecay_param = get_regular_param_group(model)

    num_decay_params = sum(p.numel() for p in net_decay_param)
    num_nodecay_params = sum(p.numel() for p in net_nodecay_param)
    net_param_total = num_decay_params + num_nodecay_params
    log.critical(f"total num parameters : {net_param_total:,}")

    # NOTE(review): both groups are concatenated into ONE param group, so the "no_decay"
    # tensors also receive `weight_decay`; the split currently only feeds the parameter
    # count above. Confirm whether a separate weight_decay=0.0 group was intended before
    # changing this, since it alters training dynamics.
    param_group = [
        {
            "params": net_decay_param + net_nodecay_param,
            "lr": lr,
            "weight_decay": weight_decay,
        },
    ]

    if optim_type == "adamw":
        opt_cls = torch.optim.AdamW
    elif optim_type == "fusedadam":
        opt_cls = FusedAdam
    else:
        raise ValueError(f"Unknown optimizer type: {optim_type}")

    return opt_cls(param_group, **kwargs)


def get_base_scheduler(
    optimizer: torch.optim.Optimizer,
    model: nn.Module,
    scheduler_config: dict,
):
    """Wrap a hydra-instantiated schedule object into a ``LambdaLR``.

    ``scheduler_config`` is instantiated with ``hydra.utils.instantiate``; the result gets a
    ``model`` attribute and its ``schedule`` callable is used as the LambdaLR multiplier.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` applying the same multiplier to every
        parameter group of ``optimizer``.
    """
    net_scheduler = hydra.utils.instantiate(scheduler_config)
    net_scheduler.model = model

    # LambdaLR applies a single callable to every param group. A one-element list would only
    # work when the optimizer has exactly one group; for that case this is identical.
    return torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=net_scheduler.schedule,
    )

# --------------------------------------------------------------------------------
# /cosmos_predict1/tokenizer/training/configs/config.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default config for cosmos/tokenizer project."""

from typing import Any, List

import attrs

from cosmos_predict1.tokenizer.training.configs.base.model import DefaultModelConfig
from cosmos_predict1.tokenizer.training.configs.registry import register_configs
from cosmos_predict1.tokenizer.training.trainer import TokenizerTrainer
from cosmos_predict1.utils import config
from cosmos_predict1.utils.config_helper import import_all_modules_from_package


@attrs.define(slots=False)
class Config(config.Config):
    """Tokenizer training config; ``defaults`` selects the default option of each config group."""

    defaults: List[Any] = attrs.field(
        factory=lambda: [
            "_self_",
            {"data_train": "mock_video720"},
            {"data_val": "mock_video720"},
            {"optimizer": "fused_adam"},
            {"scheduler": "warmup"},
            {"network": "continuous_factorized_video"},
            {"loss": "video"},
            {"metric": "reconstruction"},
            {"checkpoint": "local"},
            {"callbacks": "basic"},
            {"experiment": None},
        ]
    )
def make_config(): 49 | c = Config( 50 | model=DefaultModelConfig, 51 | optimizer=None, 52 | scheduler=None, 53 | dataloader_train=None, 54 | dataloader_val=None, 55 | checkpoint=None, 56 | ) 57 | c.job.project = "posttraining" 58 | c.job.group = "debug" 59 | c.job.name = "default_${now:%Y-%m-%d}_${now:%H-%M-%S}" 60 | 61 | c.trainer.type = TokenizerTrainer 62 | c.trainer.run_validation = True 63 | 64 | c.trainer.seed = 1234 65 | c.trainer.max_iter = 10_000_000 66 | c.trainer.validation_iter = 5000 67 | c.trainer.max_val_iter = 1 68 | c.trainer.logging_iter = 100 69 | 70 | c.trainer.callbacks = None 71 | c.trainer.ddp.static_graph = True 72 | c.trainer.ddp.find_unused_parameters = False 73 | register_configs() 74 | 75 | # experiment config are defined in the experiment folder 76 | # call import_all_modules_from_package to register them 77 | import_all_modules_from_package("cosmos_predict1.tokenizer.training.configs.experiments") 78 | 79 | return c 80 | -------------------------------------------------------------------------------- /cosmos_predict1/auxiliary/guardrail/common/io_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import glob 17 | from dataclasses import dataclass 18 | 19 | import imageio 20 | import numpy as np 21 | 22 | from cosmos_predict1.utils import log 23 | 24 | 25 | @dataclass 26 | class VideoData: 27 | frames: np.ndarray # Shape: [B, H, W, C] 28 | fps: int 29 | duration: int # in seconds 30 | 31 | 32 | def get_video_filepaths(input_dir: str) -> list[str]: 33 | """Get a list of filepaths for all videos in the input directory.""" 34 | paths = glob.glob(f"{input_dir}/**/*.mp4", recursive=True) 35 | paths += glob.glob(f"{input_dir}/**/*.avi", recursive=True) 36 | paths += glob.glob(f"{input_dir}/**/*.mov", recursive=True) 37 | paths = sorted(paths) 38 | log.debug(f"Found {len(paths)} videos") 39 | return paths 40 | 41 | 42 | def read_video(filepath: str) -> VideoData: 43 | """Read a video file and extract its frames and metadata.""" 44 | try: 45 | reader = imageio.get_reader(filepath, "ffmpeg") 46 | except Exception as e: 47 | raise ValueError(f"Failed to read video file: {filepath}") from e 48 | 49 | # Extract metadata from the video file 50 | try: 51 | metadata = reader.get_meta_data() 52 | fps = metadata.get("fps") 53 | duration = metadata.get("duration") 54 | except Exception as e: 55 | reader.close() 56 | raise ValueError(f"Failed to extract metadata from video file: {filepath}") from e 57 | 58 | # Extract frames from the video file 59 | try: 60 | frames = np.array([frame for frame in reader]) 61 | except Exception as e: 62 | raise ValueError(f"Failed to extract frames from video file: {filepath}") from e 63 | finally: 64 | reader.close() 65 | 66 | return VideoData(frames=frames, fps=fps, duration=duration) 67 | 68 | 69 | def save_video(filepath: str, frames: np.ndarray, fps: int) -> None: 70 | """Save a video file from a sequence of frames.""" 71 | try: 72 | writer = imageio.get_writer(filepath, fps=fps, macro_block_size=1) 73 | for frame in frames: 74 | writer.append_data(frame) 75 | except Exception as e: 76 | raise ValueError(f"Failed to save video 
file to {filepath}") from e 77 | finally: 78 | writer.close() 79 | -------------------------------------------------------------------------------- /cosmos_predict1/utils/env_parsers/cred_env_parser.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from cosmos_predict1.utils.env_parsers.env_parser import EnvParser 17 | from cosmos_predict1.utils.validator import String 18 | 19 | 20 | class CredentialEnvParser(EnvParser): 21 | APP_ENV = String(default="") 22 | PROD_FT_AWS_CREDS_ACCESS_KEY_ID = String(default="") 23 | PROD_FT_AWS_CREDS_SECRET_ACCESS_KEY = String(default="") 24 | PROD_FT_AWS_CREDS_ENDPOINT_URL = String(default="https://s3.us-west-2.amazonaws.com") 25 | PROD_FT_AWS_CREDS_REGION_NAME = String(default="us-west-2") 26 | 27 | PROD_S3_CHECKPOINT_ACCESS_KEY_ID = String(default="") 28 | PROD_S3_CHECKPOINT_SECRET_ACCESS_KEY = String(default="") 29 | PROD_S3_CHECKPOINT_ENDPOINT_URL = String(default="") 30 | PROD_S3_CHECKPOINT_REGION_NAME = String(default="") 31 | 32 | PROD_TEAM_DIR_ACCESS_KEY_ID = String(default="") 33 | PROD_TEAM_DIR_SECRET_ACCESS_KEY = String(default="") 34 | PROD_TEAM_DIR_ENDPOINT_URL = String(default="") 35 | PROD_TEAM_DIR_REGION_NAME = String(default="") 36 | 37 | PICASSO_AUTH_MODEL_REGISTRY_API_KEY = String(default="") 38 | PICASSO_API_ENDPOINT_URL = String(default="https://meeocvslt2.execute-api.us-west-2.amazonaws.com") 39 | 40 | 41 | CRED_ENVS = CredentialEnvParser() 42 | CRED_ENVS_DICT = { 43 | "PROD_FT_AWS_CREDS": { 44 | "aws_access_key_id": CRED_ENVS.PROD_FT_AWS_CREDS_ACCESS_KEY_ID, 45 | "aws_secret_access_key": CRED_ENVS.PROD_FT_AWS_CREDS_SECRET_ACCESS_KEY, 46 | "endpoint_url": CRED_ENVS.PROD_FT_AWS_CREDS_ENDPOINT_URL, 47 | "region_name": CRED_ENVS.PROD_FT_AWS_CREDS_REGION_NAME, 48 | }, 49 | "PROD_S3_CHECKPOINT": { 50 | "aws_access_key_id": CRED_ENVS.PROD_S3_CHECKPOINT_ACCESS_KEY_ID, 51 | "aws_secret_access_key": CRED_ENVS.PROD_S3_CHECKPOINT_SECRET_ACCESS_KEY, 52 | "endpoint_url": CRED_ENVS.PROD_S3_CHECKPOINT_ENDPOINT_URL, 53 | "region_name": CRED_ENVS.PROD_S3_CHECKPOINT_REGION_NAME, 54 | }, 55 | "PROD_TEAM_DIR": { 56 | "aws_access_key_id": CRED_ENVS.PROD_TEAM_DIR_ACCESS_KEY_ID, 57 | "aws_secret_access_key": CRED_ENVS.PROD_TEAM_DIR_SECRET_ACCESS_KEY, 58 | 
"endpoint_url": CRED_ENVS.PROD_TEAM_DIR_ENDPOINT_URL, 59 | "region_name": CRED_ENVS.PROD_TEAM_DIR_REGION_NAME, 60 | }, 61 | } 62 | -------------------------------------------------------------------------------- /cosmos_predict1/autoregressive/tokenizer/networks.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from collections import namedtuple 17 | 18 | import torch 19 | from torch import nn 20 | 21 | from cosmos_predict1.autoregressive.tokenizer.modules import CausalConv3d, DecoderFactorized, EncoderFactorized 22 | from cosmos_predict1.autoregressive.tokenizer.quantizers import FSQuantizer 23 | from cosmos_predict1.utils import log 24 | 25 | NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) 26 | 27 | 28 | class CausalDiscreteVideoTokenizer(nn.Module): 29 | def __init__(self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs) -> None: 30 | super().__init__() 31 | self.name = kwargs.get("name", "CausalDiscreteVideoTokenizer") 32 | self.embedding_dim = embedding_dim 33 | self.encoder = EncoderFactorized(z_channels=z_factor * z_channels, **kwargs) 34 | self.decoder = DecoderFactorized(z_channels=z_channels, **kwargs) 35 | 36 | self.quant_conv = CausalConv3d(z_factor * z_channels, embedding_dim, kernel_size=1, padding=0) 37 | self.post_quant_conv = CausalConv3d(embedding_dim, z_channels, kernel_size=1, padding=0) 38 | 39 | self.quantizer = FSQuantizer(**kwargs) 40 | 41 | num_parameters = sum(param.numel() for param in self.parameters()) 42 | log.debug(f"model={self.name}, num_parameters={num_parameters:,}") 43 | log.debug(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.") 44 | 45 | def to(self, *args, **kwargs): 46 | setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16)) 47 | return super(CausalDiscreteVideoTokenizer, self).to(*args, **kwargs) 48 | 49 | def encode(self, x): 50 | h = self.encoder(x) 51 | h = self.quant_conv(h) 52 | return self.quantizer(h) 53 | 54 | def decode(self, quant): 55 | quant = self.post_quant_conv(quant) 56 | return self.decoder(quant) 57 | 58 | def forward(self, input): 59 | quant_info, quant_codes, quant_loss = self.encode(input) 60 | reconstructions = self.decode(quant_codes) 61 | if self.training: 62 | return dict(reconstructions=reconstructions, 
quant_loss=quant_loss, quant_info=quant_info) 63 | return NetworkEval(reconstructions=reconstructions, quant_loss=quant_loss, quant_info=quant_info) 64 | -------------------------------------------------------------------------------- /cosmos_predict1/autoregressive/utils/misc.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | from omegaconf import DictConfig, OmegaConf 18 | 19 | 20 | class CustomSimpleNamespace: 21 | """ 22 | A simple namespace class that supports both attribute-style and dictionary-style access. 23 | """ 24 | 25 | def __init__(self, d): 26 | self._d = d 27 | 28 | def __getattr__(self, attr): 29 | # Attribute-style access: config.key 30 | try: 31 | return self._d[attr] 32 | except KeyError: 33 | raise AttributeError(f"'CustomSimpleNamespace' object has no attribute '{attr}'") 34 | 35 | def __getitem__(self, key): 36 | # Dictionary-style access: config['key'] 37 | return self._d[key] 38 | 39 | 40 | def maybe_convert_to_namespace(config): 41 | """ 42 | This function cast a OmegaConf's DictConfig or a standard dict to CustomSimpleNamespace, which supports both 43 | attribute-style and dictionary-style access. 
44 | Note: We need to convert OmegaConf's DictConfig since it is not compatible with torch.compile. 45 | """ 46 | # If input is OmegaConf's DictConfig, convert to a standard dict 47 | if isinstance(config, DictConfig): 48 | config = OmegaConf.to_container(config, resolve=True) 49 | 50 | if isinstance(config, dict): 51 | return CustomSimpleNamespace(config) 52 | else: 53 | return config 54 | 55 | 56 | def random_dropout(embeddings, drop_rate): 57 | r""" 58 | Function to perform random dropout for embeddings. 59 | When we drop embeddings, we zero them out. 60 | Args: 61 | embeddings (tensor): Input embeddings 62 | drop_rate (float): Rate of dropping the embedding. 63 | """ 64 | num_samples = embeddings.shape[0] 65 | # Create a shape (num_samples, 1, 1, 1, 1, ...) depending on embeddings dim. 66 | # This is done to ensure we can broadcast the zero_flag to the embeddings. 67 | # embeddings.ndim is 3 for images, and 4 for videos, and the corresponding 68 | # shapes are (num_samples, 1, 1) and (num_samples, 1, 1, 1) respectively. 69 | tensor_shape = (num_samples,) + tuple([1] * (embeddings.ndim - 1)) 70 | zero_flag = torch.ones(tensor_shape).to(embeddings.dtype) * (1 - drop_rate) 71 | zero_flag = torch.bernoulli(zero_flag).to(embeddings.device) 72 | embeddings = embeddings * zero_flag 73 | return embeddings 74 | --------------------------------------------------------------------------------