├── .github
│   └── workflows
│       └── static.yml
├── .gitignore
├── LICENSE
├── README.md
├── ar_gen.ipynb
├── configs
│   ├── autoregressive_l.yaml
│   ├── autoregressive_xl.yaml
│   ├── onenode_config.yaml
│   ├── tokenizer_l.yaml
│   └── tokenizer_xl.yaml
├── examples
│   ├── city.jpg
│   ├── food.jpg
│   └── highland.webp
├── fid_stats
│   └── adm_in256_stats.npz
├── gen_demo.py
├── imagenet_classes.py
├── pages
│   ├── figs
│   │   ├── Token_PCA.png
│   │   ├── comp_table.jpg
│   │   ├── spectral_analysis.png
│   │   ├── spectral_titok_ours.jpg
│   │   ├── teaser.jpg
│   │   └── tokenizer.png
│   └── index.html
├── requirements.txt
├── semanticist
│   ├── engine
│   │   ├── diffusion_trainer.py
│   │   ├── gpt_trainer.py
│   │   └── trainer_utils.py
│   ├── stage1
│   │   ├── diffuse_slot.py
│   │   ├── diffusion
│   │   │   ├── __init__.py
│   │   │   ├── diffusion_utils.py
│   │   │   ├── gaussian_diffusion.py
│   │   │   ├── respace.py
│   │   │   └── timestep_sampler.py
│   │   ├── diffusion_transfomer.py
│   │   ├── fused_attention.py
│   │   ├── pos_embed.py
│   │   ├── transport
│   │   │   ├── __init__.py
│   │   │   ├── integrators.py
│   │   │   ├── path.py
│   │   │   ├── transport.py
│   │   │   └── utils.py
│   │   └── vision_transformer.py
│   ├── stage2
│   │   ├── diffloss.py
│   │   ├── generate.py
│   │   └── gpt.py
│   └── utils
│       ├── datasets.py
│       ├── device_utils.py
│       ├── logger.py
│       └── lr_scheduler.py
├── submitit_test.py
├── submitit_train.py
├── test.sh
├── test_net.py
├── tok_demo.py
├── train.sh
└── train_net.py

/.github/workflows/static.yml:
--------------------------------------------------------------------------------
1 | # Simple workflow for deploying static content to GitHub Pages
2 | name: Deploy static content to Pages
3 | 
4 | on:
5 |   # Runs on pushes targeting the default branch
6 |   push:
7 |     branches: ["main"]
8 | 
9 |   # Allows you to run this workflow manually from the Actions tab
10 |   workflow_dispatch:
11 | 
12 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
13 | permissions:
14 |   contents: read
15 |   pages: write
16 |   id-token: write
17 | 
18 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued.
19 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete.
20 | concurrency:
21 |   group: "pages"
22 |   cancel-in-progress: false
23 | 
24 | jobs:
25 |   # Single deploy job since we're just deploying
26 |   deploy:
27 |     environment:
28 |       name: github-pages
29 |       url: ${{ steps.deployment.outputs.page_url }}
30 |     runs-on: ubuntu-latest
31 |     steps:
32 |       - name: Checkout
33 |         uses: actions/checkout@v4
34 |       - name: Setup Pages
35 |         uses: actions/configure-pages@v5
36 |       - name: Upload artifact
37 |         uses: actions/upload-pages-artifact@v3
38 |         with:
39 |           # Upload entire repository
40 |           path: 'pages/'
41 |       - name: Deploy to GitHub Pages
42 |         id: deployment
43 |         uses: actions/deploy-pages@v4
44 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | cache
3 | dataset
4 | output
5 | build
6 | *_results*
7 | .ipynb_checkpoints
8 | *.zip
9 | *.pth
10 | *.ttf
11 | *.so
12 | .nfs*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 Semanticist Authors
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # "Principal Components" Enable A New Language of Images
2 | ### A New Paradigm for Compact and Interpretable Image Representations
3 | [Read the Paper](https://arxiv.org/abs/2503.08685)   |  
4 | [Project Page](https://visual-gen.github.io/semanticist/)   |  
5 | [Huggingface Tokenizer Demo](https://huggingface.co/spaces/tennant/semanticist_tokenizer)   |  
6 | [Huggingface Generation Demo](https://huggingface.co/spaces/tennant/Semanticist_AR)
7 | 
8 | [Xin Wen](https://wen-xin.info/)1*,
9 | [Bingchen Zhao](https://bzhao.me/)2*,
10 | [Ismail Elezi](https://therevanchist.github.io/)3,
11 | [Jiankang Deng](https://jiankangdeng.github.io/)4,
12 | [Xiaojuan Qi](https://xjqi.github.io/)1
13 | 
14 | * Equal Contribution  
15 | 
16 | 1 University of Hong Kong   |  
17 | 2 University of Edinburgh   |  
18 | 3 Noah's Ark Lab   |  
19 | 4 Imperial College London
20 | 
21 | ![Semanticist Teaser](pages/figs/teaser.jpg)
22 | 
23 | ## Introduction & Motivation
24 | Deep generative models have revolutionized image synthesis, but how we tokenize visual data remains an open question.
25 | While classical methods like **Principal Component Analysis (PCA)** introduced compact, structured representations, modern **visual tokenizers**—from **VQ-VAE** to **SD-VAE**—often prioritize **reconstruction fidelity** at the cost of interpretability and efficiency.
26 | 
27 | ### The Problem
28 | 
29 | - **Lack of Structure:** Tokens are learned arbitrarily, with no ordering that prioritizes the most important visual features first.
30 | - **Semantic-Spectrum Coupling:** Tokens entangle *high-level semantics* with *low-level spectral details*, leading to inefficiencies in downstream applications.
31 | 
32 | Can we design a **compact, structured tokenizer** that retains the benefits of PCA while leveraging modern generative techniques?
33 | 
34 | ### Key Contributions (What's New?)
35 | - **📌 PCA-Guided Tokenization:** Introduces a *causal ordering* where earlier tokens capture the most important visual details, reducing redundancy.
36 | - **⚡ Semantic-Spectrum Decoupling:** Resolves the issue of semantic-spectrum coupling to ensure tokens focus on high-level semantic information.
37 | - **🌀 Diffusion-Based Decoding:** Uses a *diffusion decoder* whose spectral autoregressive property naturally separates semantic and spectral content.
38 | - **🚀 Compact & Interpretability-Friendly:** Enables *flexible token selection*, where fewer tokens can still yield high-quality reconstructions.
39 | 
40 | For more details, please refer to our [project page](https://visual-gen.github.io/semanticist/).
41 | 
42 | ## Getting Started
43 | 
44 | ### Preparation
45 | 
46 | First, please make sure PyTorch is installed (we used 2.5.1, but we expect any version >= 2.0 to work).
47 | 
48 | Then install the rest of the dependencies:
49 | 
50 | ```bash
51 | pip install -r requirements.txt
52 | ```
53 | 
54 | Then download [ImageNet](https://www.image-net.org/) and soft-link it to `./dataset/imagenet`. For FID evaluation, we recommend pre-processing the ImageNet validation set with [this script](https://github.com/LTH14/rcg/blob/main/prepare_imgnet_val.py). The target folder is `./dataset/imagenet/val256` in our case.
55 | 
56 | ### Training
57 | 
58 | Our codebase supports DDP training with accelerate, torchrun, and submitit (for Slurm users). To train a Semanticist tokenizer with a DiT-L decoder on 8 GPUs, you can run
59 | ```bash
60 | accelerate launch --config_file=configs/onenode_config.yaml train_net.py --cfg configs/tokenizer_l.yaml
61 | ```
62 | or
63 | ```bash
64 | torchrun --nproc-per-node=8 train_net.py --cfg configs/tokenizer_l.yaml
65 | ```
66 | or
67 | ```bash
68 | python submitit_train.py --ngpus=8 --nodes=1 --partition=xxx --config configs/tokenizer_l.yaml
69 | ```
70 | We used a global batch size of 2048, so the effective batch size per GPU is 256 in this case. You may modify the batch size and gradient accumulation steps in the config file according to your training resources.
71 | 
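For reference, the effective global batch size is the per-GPU batch size times the number of GPUs times the gradient-accumulation steps. A minimal sanity-check sketch (the helper below is hypothetical, not part of this codebase):

```python
def global_batch_size(per_gpu_batch: int, num_gpus: int, grad_accum_steps: int) -> int:
    """Effective global batch size under DDP."""
    return per_gpu_batch * num_gpus * grad_accum_steps

# configs/tokenizer_l.yaml uses batch_size: 256 and grad_accum_steps: 1,
# which on 8 GPUs recovers the global batch size of 2048 quoted above.
assert global_batch_size(256, 8, 1) == 2048
```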
72 | To train an ϵLlamaGen autoregressive model with a tokenizer trained as above, you can run the following command. Remember to change the path to the tokenizer in the config file. The EMA model is `custom_checkpoint_1.pkl` under the output folder.
73 | ```bash
74 | accelerate launch --config_file=configs/onenode_config.yaml train_net.py --cfg configs/autoregressive_l.yaml
75 | ```
76 | Note that caching is enabled by default; it takes around 400 GB of memory (dumped to `/dev/shm`) for ten-crop augmentation (`tencrop_cached`) on ImageNet. To disable it, set `enable_cache_latents` to False in the config file and/or specify a different data augmentation method (e.g., `centercrop_cached`, `centercrop`, `randcrop`).
77 | 
78 | ### Evaluation
79 | 
80 | By default, online evaluation does not use the EMA model, so to obtain the final performance we suggest running a separate evaluation after training. As above, our scripts are compatible with accelerate, torchrun, and submitit.
81 | ```bash
82 | accelerate launch --config_file=configs/onenode_config.yaml test_net.py --model ./output/tokenizer/models_l --step 250000 --cfg_value 3.0 --test_num_slots 32
83 | ```
84 | or
85 | ```bash
86 | torchrun --nproc-per-node=8 test_net.py --model ./output/tokenizer/models_l --step 250000 --cfg_value 3.0 --test_num_slots 32
87 | ```
88 | or
89 | ```bash
90 | python submitit_test.py --ngpus=8 --nodes=1 --partition=xxx --model ./output/tokenizer/models_l --step 250000 --cfg_value 3.0 --test_num_slots 32
91 | ```
92 | And for the AR model:
93 | ```bash
94 | torchrun --nproc-per-node=8 test_net.py --model ./output/autoregressive/models_l --step 250000 --cfg_value 6.0 --ae_cfg 1.0 --test_num_slots 32
95 | ```
96 | If `enable_ema` is set to True, the EMA model is loaded automatically. You can adjust the number of GPUs flexibly, and you can pass multiple values for these arguments on the command line to perform a grid search.
97 | 
98 | ### Demos
99 | 
100 | Please refer to our demo pages on Hugging Face for the tokenizer and the AR model.
101 | - [Tokenizer Demo](https://huggingface.co/spaces/tennant/semanticist_tokenizer)
102 | - [AR Demo](https://huggingface.co/spaces/tennant/Semanticist_AR)
103 | 
104 | ## Note
105 | 
106 | This code may not exactly reproduce the results reported in the paper due to potential human errors during the preparation and cleanup of the code for release. If you have any difficulty reproducing our findings, please don't hesitate to let us know. We will also keep refining the README and code, and carry out sanity-check experiments in the near future.
107 | 
108 | ## Acknowledgements
109 | 
110 | Our codebase builds upon several publicly available codebases. Specifically, we have modified or taken inspiration from the following repos: [DiT](https://github.com/facebookresearch/DiT), [SiT](https://github.com/willisma/SiT), [DiffAE](https://github.com/phizaz/diffae), [LlamaGen](https://github.com/FoundationVision/LlamaGen), [RCG](https://github.com/LTH14/rcg), [MAR](https://github.com/LTH14/mar), [REPA](https://github.com/sihyun-yu/REPA), etc. We thank the authors for their contributions to the community.
111 | 
112 | ## Citation
113 | 
114 | If you find this work useful in your research, please consider citing us!
115 | 116 | ```bibtex 117 | @article{semanticist, 118 | title={``Principal Components'' Enable A New Language of Images}, 119 | author={Wen, Xin and Zhao, Bingchen and Elezi, Ismail and Deng, Jiankang and Qi, Xiaojuan}, 120 | journal={arXiv preprint arXiv:2503.08685}, 121 | year={2025} 122 | } 123 | ``` 124 | -------------------------------------------------------------------------------- /configs/autoregressive_l.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | target: semanticist.engine.gpt_trainer.GPTTrainer 3 | params: 4 | num_epoch: 400 5 | blr: 1e-4 6 | cosine_lr: False 7 | warmup_epochs: 100 8 | batch_size: 256 9 | num_workers: 8 10 | pin_memory: True 11 | grad_accum_steps: 1 12 | precision: 'bf16' 13 | max_grad_norm: 1.0 14 | enable_ema: True 15 | save_every: 10000 16 | sample_every: 5000 17 | fid_every: 50000 18 | eval_fid: False 19 | result_folder: "./output/autoregressive" 20 | log_dir: "./output/autoregressive/logs" 21 | ae_cfg: 1.0 22 | cfg: 6.0 23 | cfg_schedule: "linear" 24 | train_num_slots: 32 25 | test_num_slots: 32 26 | compile: True 27 | enable_cache_latents: True 28 | ae_model: 29 | target: semanticist.stage1.diffuse_slot.DiffuseSlot 30 | params: 31 | encoder: 'vit_base_patch16' 32 | enc_img_size: 256 33 | enc_causal: True 34 | num_slots: 256 35 | slot_dim: 16 36 | norm_slots: True 37 | cond_method: 'token' 38 | dit_model: 'DiT-L-2' 39 | vae: 'xwen99/mar-vae-kl16' 40 | num_sampling_steps: '250' 41 | ckpt_path: ./output/tokenizer/models_l/step250000/custom_checkpoint_1.pkl 42 | 43 | gpt_model: 44 | target: GPT-L 45 | params: 46 | num_slots: 32 47 | slot_dim: 16 48 | num_classes: 1000 49 | cls_token_num: 1 50 | resid_dropout_p: 0.1 51 | ffn_dropout_p: 0.1 52 | diffloss_d: 12 53 | diffloss_w: 1536 54 | num_sampling_steps: '100' 55 | diffusion_batch_mul: 4 56 | use_si: True 57 | cond_method: 'concat' 58 | ckpt_path: None 59 | 60 | dataset: 61 | target: semanticist.utils.datasets.ImageNet 62 | params: 63 | root: ./dataset/imagenet/ 64 | split: train 65 | aug: tencrop_cached # or centercrop_cached 66 | img_size: 256 -------------------------------------------------------------------------------- /configs/autoregressive_xl.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | target: semanticist.engine.gpt_trainer.GPTTrainer 3 | params: 4 | num_epoch: 400 5 | blr: 1e-4 6 | cosine_lr: False 7 | warmup_epochs: 100 8 | batch_size: 256 9 | num_workers: 8 10 | pin_memory: True 11 | grad_accum_steps: 1 12 | precision: 'bf16' 13 | max_grad_norm: 1.0 14 | enable_ema: True 15 | save_every: 10000 16 | sample_every: 5000 17 | fid_every: 50000 18 | eval_fid: False 19 | result_folder: "./output/autoregressive" 20 | log_dir: "./output/autoregressive/logs" 21 | ae_cfg: 1.0 22 | cfg: 5.0 23 | cfg_schedule: "linear" 24 | train_num_slots: 32 25 | test_num_slots: 32 26 | compile: True 27 | enable_cache_latents: True 28 | ae_model: 29 | target: semanticist.stage1.diffuse_slot.DiffuseSlot 30 | params: 31 | encoder: 'vit_base_patch16' 32 | enc_img_size: 256 33 | enc_causal: True 34 | num_slots: 256 35 | slot_dim: 16 36 | norm_slots: True 37 | cond_method: 'token' 38 | dit_model: 'DiT-XL-2' 39 | vae: 'xwen99/mar-vae-kl16' 40 | num_sampling_steps: '250' 41 | ckpt_path: ./output/tokenizer/models_xl/step250000/custom_checkpoint_1.pkl 42 | 43 | gpt_model: 44 | target: GPT-L 45 | params: 46 | num_slots: 32 47 | slot_dim: 16 48 | num_classes: 1000 49 | cls_token_num: 1 50 | resid_dropout_p: 0.1 51 | 
ffn_dropout_p: 0.1 52 | diffloss_d: 12 53 | diffloss_w: 1536 54 | num_sampling_steps: '100' 55 | diffusion_batch_mul: 4 56 | use_si: True 57 | cond_method: 'concat' 58 | ckpt_path: None 59 | 60 | dataset: 61 | target: semanticist.utils.datasets.ImageNet 62 | params: 63 | root: ./dataset/imagenet/ 64 | split: train 65 | aug: tencrop_cached # or centercrop_cached 66 | img_size: 256 -------------------------------------------------------------------------------- /configs/onenode_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: {} 3 | distributed_type: MULTI_GPU 4 | fsdp_config: {} 5 | machine_rank: 0 6 | main_process_ip: null 7 | main_process_port: null 8 | main_training_function: main 9 | num_machines: 1 10 | num_processes: 8 11 | use_cpu: false -------------------------------------------------------------------------------- /configs/tokenizer_l.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | target: semanticist.engine.diffusion_trainer.DiffusionTrainer 3 | params: 4 | num_epoch: 400 5 | valid_size: 64 6 | blr: 2.5e-5 7 | cosine_lr: True 8 | warmup_epochs: 100 9 | batch_size: 256 10 | num_workers: 16 11 | pin_memory: True 12 | grad_accum_steps: 1 13 | precision: 'bf16' 14 | max_grad_norm: 3.0 15 | enable_ema: True 16 | save_every: 10000 17 | sample_every: 5000 18 | fid_every: 50000 19 | result_folder: "./output/tokenizer/models_l" 20 | log_dir: "./output/tokenizer/models_l/logs" 21 | cfg: 3.0 22 | compile: True 23 | model: 24 | target: semanticist.stage1.diffuse_slot.DiffuseSlot 25 | params: 26 | encoder: 'vit_base_patch16' 27 | enc_img_size: 256 28 | enc_causal: True 29 | enc_use_mlp: False 30 | num_slots: 256 31 | slot_dim: 16 32 | norm_slots: True 33 | dit_model: 'DiT-L-2' 34 | vae: 'xwen99/mar-vae-kl16' 35 | enable_nest: False 36 | enable_nest_after: 50 37 | use_repa: True 38 | eval_fid: True 39 | fid_stats: 'fid_stats/adm_in256_stats.npz' 40 | num_sampling_steps: '250' 41 | ckpt_path: None 42 | 43 | dataset: 44 | target: semanticist.utils.datasets.ImageNet 45 | params: 46 | root: ./dataset/imagenet/ 47 | split: train 48 | img_size: 256 49 | 50 | test_dataset: 51 | target: semanticist.utils.datasets.ImageNet 52 | params: 53 | root: ./dataset/imagenet/ 54 | split: val 55 | img_size: 256 56 | -------------------------------------------------------------------------------- /configs/tokenizer_xl.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | target: semanticist.engine.diffusion_trainer.DiffusionTrainer 3 | params: 4 | num_epoch: 400 5 | valid_size: 64 6 | blr: 2.5e-5 7 | cosine_lr: True 8 | warmup_epochs: 100 9 | batch_size: 256 10 | num_workers: 16 11 | pin_memory: True 12 | grad_accum_steps: 1 13 | precision: 'bf16' 14 | max_grad_norm: 3.0 15 | enable_ema: True 16 | save_every: 10000 17 | sample_every: 5000 18 | fid_every: 50000 19 | result_folder: "./output/tokenizer/models_xl" 20 | log_dir: "./output/tokenizer/models_xl/logs" 21 | cfg: 3.0 22 | compile: True 23 | model: 24 | target: semanticist.stage1.diffuse_slot.DiffuseSlot 25 | params: 26 | encoder: 'vit_base_patch16' 27 | enc_img_size: 256 28 | enc_causal: True 29 | enc_use_mlp: False 30 | num_slots: 256 31 | slot_dim: 16 32 | norm_slots: True 33 | dit_model: 'DiT-XL-2' 34 | vae: 'xwen99/mar-vae-kl16' 35 | enable_nest: False 36 | enable_nest_after: 50 37 | use_repa: True 38 | eval_fid: True 39 | fid_stats: 
'fid_stats/adm_in256_stats.npz' 40 | num_sampling_steps: '250' 41 | ckpt_path: None 42 | 43 | dataset: 44 | target: semanticist.utils.datasets.ImageNet 45 | params: 46 | root: ./dataset/imagenet/ 47 | split: train 48 | img_size: 256 49 | 50 | test_dataset: 51 | target: semanticist.utils.datasets.ImageNet 52 | params: 53 | root: ./dataset/imagenet/ 54 | split: val 55 | img_size: 256 56 | -------------------------------------------------------------------------------- /examples/city.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/examples/city.jpg -------------------------------------------------------------------------------- /examples/food.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/examples/food.jpg -------------------------------------------------------------------------------- /examples/highland.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/examples/highland.webp -------------------------------------------------------------------------------- /fid_stats/adm_in256_stats.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/fid_stats/adm_in256_stats.npz -------------------------------------------------------------------------------- /gen_demo.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import numpy as np 3 | from PIL import Image 4 | import os.path as osp 5 | import torch 6 | import matplotlib.pyplot as plt 7 | from omegaconf import OmegaConf 8 | from tqdm import tqdm 9 | from huggingface_hub import hf_hub_download 10 | from semanticist.engine.trainer_utils import instantiate_from_config 11 | from semanticist.stage1.diffuse_slot import DiffuseSlot 12 | from semanticist.stage2.gpt import GPT_models 13 | from semanticist.stage2.generate import generate 14 | from safetensors import safe_open 15 | from semanticist.utils.datasets import vae_transforms 16 | from PIL import Image 17 | from imagenet_classes import imagenet_classes 18 | 19 | transform = vae_transforms('test') 20 | 21 | 22 | def norm_ip(img, low, high): 23 | img.clamp_(min=low, max=high) 24 | img.sub_(low).div_(max(high - low, 1e-5)) 25 | 26 | def norm_range(t, value_range): 27 | if value_range is not None: 28 | norm_ip(t, value_range[0], value_range[1]) 29 | else: 30 | norm_ip(t, float(t.min()), float(t.max())) 31 | 32 | from PIL import Image 33 | def convert_np(img): 34 | ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\ 35 | .permute(1, 2, 0).to("cpu", torch.uint8).numpy() 36 | return ndarr 37 | def convert_PIL(img): 38 | ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\ 39 | .permute(1, 2, 0).to("cpu", torch.uint8).numpy() 40 | img = Image.fromarray(ndarr) 41 | return img 42 | 43 | def norm_slots(slots): 44 | mean = torch.mean(slots, dim=-1, keepdim=True) 45 | std = torch.std(slots, dim=-1, keepdim=True) 46 | return (slots - mean) / std 47 | 48 | def load_state_dict(state_dict, model): 49 | """Helper to load a state dict with proper prefix handling.""" 50 | if 'state_dict' in state_dict: 51 | state_dict = state_dict['state_dict'] 
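# Checkpoints saved by the trainer may nest the weights under a 'state_dict'
# key (unwrapped above), and models trained with torch.compile store their
# parameters with an '_orig_mod.' prefix (stripped below).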
52 | # Remove '_orig_mod' prefix if present 53 | state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()} 54 | missing, unexpected = model.load_state_dict( 55 | state_dict, strict=False 56 | ) 57 | # print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}") 58 | 59 | def load_safetensors(path, model): 60 | """Helper to load a safetensors checkpoint.""" 61 | from safetensors.torch import safe_open 62 | with safe_open(path, framework="pt", device="cpu") as f: 63 | state_dict = {k: f.get_tensor(k) for k in f.keys()} 64 | load_state_dict(state_dict, model) 65 | 66 | def load_checkpoint(ckpt_path, model): 67 | if ckpt_path is None or not osp.exists(ckpt_path): 68 | return 69 | 70 | if osp.isdir(ckpt_path): 71 | # ckpt_path is something like 'path/to/models/step10/' 72 | model_path = osp.join(ckpt_path, "model.safetensors") 73 | if osp.exists(model_path): 74 | load_safetensors(model_path, model) 75 | else: 76 | # ckpt_path is something like 'path/to/models/step10.pt' 77 | if ckpt_path.endswith(".safetensors"): 78 | load_safetensors(ckpt_path, model) 79 | else: 80 | state_dict = torch.load(ckpt_path, map_location="cpu") 81 | load_state_dict(state_dict, model) 82 | 83 | print(f"Loaded checkpoint from {ckpt_path}") 84 | 85 | device = "cuda" if torch.cuda.is_available() else "cpu" 86 | print(f"Is CUDA available: {torch.cuda.is_available()}") 87 | if device == 'cuda': 88 | print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") 89 | 90 | ckpt_path = hf_hub_download(repo_id='tennant/semanticist', filename="semanticist_ar_gen_L.pkl", cache_dir='/mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/') 91 | config_path = 'configs/autoregressive_xl.yaml' 92 | 93 | cfg = OmegaConf.load(config_path) 94 | params = cfg.trainer.params 95 | 96 | ae_model = instantiate_from_config(params.ae_model).to(device) 97 | ae_model_path = hf_hub_download(repo_id='tennant/semanticist', filename="semanticist_tok_XL.pkl", cache_dir='/mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/') 98 | load_checkpoint(ae_model_path, ae_model) 99 | ae_model.eval() 100 | 101 | gpt_model = GPT_models[params.gpt_model.target](**params.gpt_model.params).to(device) 102 | load_checkpoint(ckpt_path, gpt_model) 103 | gpt_model.eval(); 104 | 105 | def viz_diff_slots(model, slots, nums, cfg=1.0, return_figs=False): 106 | n_slots_inf = [] 107 | for num_slots_to_inference in nums: 108 | drop_mask = model.nested_sampler(slots.shape[0], device, num_slots_to_inference) 109 | recon_n = model.sample(slots, drop_mask=drop_mask, cfg=cfg) 110 | n_slots_inf.append(recon_n) 111 | return [convert_np(n_slots_inf[i][0]) for i in range(len(n_slots_inf))] 112 | 113 | num_slots = params.ae_model.params.num_slots 114 | slot_dim = params.ae_model.params.slot_dim 115 | dtype = torch.bfloat16 116 | # the model is trained with only 32 tokens. 
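# The tokenizer exposes num_slots (256) positions, but the AR model generates
# only the first 32; generate_from_class below pads the remaining positions
# with the DiT's learned null condition so the decoder always receives a
# full slot sequence.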
117 | num_slots_to_gen = 32 118 | 119 | # Function to generate image from class 120 | def generate_from_class(class_id, cfg_scale): 121 | with torch.no_grad(): 122 | dtype = torch.bfloat16 123 | num_slots_to_gen = 32 124 | with torch.autocast(device, dtype=dtype): 125 | slots_gen = generate( 126 | gpt_model, 127 | torch.tensor([class_id]).to(device), 128 | num_slots_to_gen, 129 | cfg_scale=cfg_scale, 130 | cfg_schedule="linear" 131 | ) 132 | if num_slots_to_gen < num_slots: 133 | null_slots = ae_model.dit.null_cond.expand(slots_gen.shape[0], -1, -1) 134 | null_slots = null_slots[:, num_slots_to_gen:, :] 135 | slots_gen = torch.cat([slots_gen, null_slots], dim=1) 136 | return slots_gen 137 | 138 | with gr.Blocks() as demo: 139 | with gr.Row(): 140 | # First column - Input and configs 141 | with gr.Column(scale=1): 142 | gr.Markdown("## Input") 143 | 144 | # Replace image input with ImageNet class selection 145 | imagenet_classes = {k: v for k, v in enumerate(imagenet_classes)} 146 | class_choices = [f"{id}: {name}" for id, name in imagenet_classes.items()] 147 | 148 | # Dropdown for class selection 149 | class_dropdown = gr.Dropdown( 150 | choices=class_choices[:20], # Limit for demonstration 151 | label="Select ImageNet Class", 152 | value=class_choices[0] if class_choices else None 153 | ) 154 | 155 | # Option to enter class ID directly 156 | class_id_input = gr.Number( 157 | label="Or enter class ID directly (0-999)", 158 | value=0, 159 | minimum=0, 160 | maximum=999, 161 | step=1 162 | ) 163 | 164 | with gr.Group(): 165 | gr.Markdown("### Configuration") 166 | show_gallery = gr.Checkbox(label="Show Gallery", value=True) 167 | slider = gr.Slider(minimum=0.1, maximum=20.0, value=4.0, label="CFG value") 168 | labels_input = gr.Textbox( 169 | label="Number of tokens to reconstruct (comma-separated)", 170 | value="1, 2, 4, 8, 16", 171 | placeholder="Enter comma-separated numbers for the number of slots to use" 172 | ) 173 | 174 | # Second column - Output (conditionally rendered) 175 | with gr.Column(scale=1): 176 | gr.Markdown("## Output") 177 | 178 | # Container for conditional rendering 179 | with gr.Group(visible=True) as gallery_container: 180 | gallery = gr.Gallery(label="Result Gallery", columns=3, height="auto", show_label=True) 181 | 182 | # Always visible output image 183 | output_image = gr.Image(label="Generated Image", type="numpy") 184 | 185 | # Handle form submission 186 | submit_btn = gr.Button("Generate") 187 | 188 | # Define the processing logic 189 | def update_outputs(class_selection, class_id, show_gallery_value, slider_value, labels_text): 190 | # Determine which class to use - either from dropdown or direct input 191 | if class_selection: 192 | # Extract class ID from the dropdown selection 193 | selected_class_id = int(class_selection.split(":")[0]) 194 | else: 195 | selected_class_id = int(class_id) 196 | 197 | # Update the visibility of the gallery container 198 | gallery_container.visible = show_gallery_value 199 | 200 | try: 201 | # Parse the labels from the text input 202 | if labels_text and "," in labels_text: 203 | labels = [int(label.strip()) for label in labels_text.split(",")] 204 | else: 205 | # Default labels if none provided or in wrong format 206 | labels = [1, 4, 16, 64, 256] 207 | except: 208 | labels = [1, 4, 16, 64, 256] 209 | 210 | while len(labels) < 3: 211 | labels.append(256) 212 | 213 | # Generate the image based on the selected class 214 | slots_gen = generate_from_class(selected_class_id, cfg_scale=slider_value) 215 | 216 | recon = 
viz_diff_slots(ae_model, slots_gen, [32], cfg=slider_value)[0] 217 | 218 | # Always generate the model decomposition for potential gallery display 219 | model_decompose = viz_diff_slots(ae_model, slots_gen, labels, cfg=slider_value) 220 | 221 | if not show_gallery_value: 222 | # If only the image should be shown, return just the processed image 223 | return gallery_container, [], recon 224 | else: 225 | # Create image variations and pair them with labels 226 | gallery_images = [ 227 | (recon, f'Generated from class {selected_class_id}'), 228 | ] + [(img, 'Gen. with ' + str(label) + ' tokens') for img, label in zip(model_decompose, labels)] 229 | return gallery_container, gallery_images, recon 230 | 231 | # Connect the inputs and outputs 232 | submit_btn.click( 233 | fn=update_outputs, 234 | inputs=[class_dropdown, class_id_input, show_gallery, slider, labels_input], 235 | outputs=[gallery_container, gallery, output_image] 236 | ) 237 | 238 | # Also update when checkbox changes 239 | show_gallery.change( 240 | fn=lambda value: gr.update(visible=value), 241 | inputs=[show_gallery], 242 | outputs=[gallery_container] 243 | ) 244 | 245 | # Add examples 246 | examples = [ 247 | # ["0: tench, Tinca tinca", 0, True, 4.0, "1,2,4,8,16"], 248 | ["1: goldfish", 1, True, 4.0, "1,2,4,8,16"], 249 | # ["2: great white shark, white shark", 2, True, 4.0, "1,2,4,8,16"], 250 | ] 251 | 252 | gr.Examples( 253 | examples=examples, 254 | inputs=[class_dropdown, class_id_input, show_gallery, slider, labels_input], 255 | outputs=[gallery_container, gallery, output_image], 256 | fn=update_outputs, 257 | cache_examples=False 258 | ) 259 | 260 | # Launch the demo 261 | if __name__ == "__main__": 262 | demo.launch() 263 | -------------------------------------------------------------------------------- /imagenet_classes.py: -------------------------------------------------------------------------------- 1 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black 
swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 
"storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", 
"missile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 
-------------------------------------------------------------------------------- /pages/figs/Token_PCA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/pages/figs/Token_PCA.png -------------------------------------------------------------------------------- /pages/figs/comp_table.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/pages/figs/comp_table.jpg -------------------------------------------------------------------------------- /pages/figs/spectral_analysis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/pages/figs/spectral_analysis.png -------------------------------------------------------------------------------- /pages/figs/spectral_titok_ours.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/pages/figs/spectral_titok_ours.jpg -------------------------------------------------------------------------------- /pages/figs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/pages/figs/teaser.jpg -------------------------------------------------------------------------------- /pages/figs/tokenizer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/pages/figs/tokenizer.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.26.4 2 | accelerate 3 | diffusers[torch] 4 | transformers 5 | safetensors 6 | omegaconf 7 | tensorboard 8 | huggingface-hub 9 | einops 10 | timm 11 | scipy 12 | scikit-learn 13 | scikit-image 14 | git+https://github.com/LTH14/torch-fidelity.git@master#egg=torch-fidelity 15 | opencv-python-headless 16 | torchmetrics 17 | submitit -------------------------------------------------------------------------------- /semanticist/engine/trainer_utils.py: -------------------------------------------------------------------------------- 1 | import os, torch 2 | import cv2 3 | import numpy as np 4 | import torch_fidelity 5 | from collections import OrderedDict 6 | from concurrent.futures import ThreadPoolExecutor 7 | import importlib 8 | from torch.optim import AdamW 9 | from semanticist.utils.lr_scheduler import build_scheduler 10 | 11 | 12 | def get_obj_from_str(string, reload=False): 13 | """Get object from string path.""" 14 | module, cls = string.rsplit(".", 1) 15 | if reload: 16 | module_imp = importlib.import_module(module) 17 | importlib.reload(module_imp) 18 | return getattr(importlib.import_module(module, package=None), cls) 19 | 20 | 21 | def instantiate_from_config(config): 22 | """Instantiate an object from a config dictionary.""" 23 | if not "target" in config: 24 | raise KeyError("Expected key `target` to instantiate.") 25 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 26 | 27 | 28 | def 
is_dist_avail_and_initialized(): 29 | """Check if distributed training is available and initialized.""" 30 | if not torch.distributed.is_initialized(): 31 | return False 32 | return True 33 | 34 | 35 | def is_main_process(): 36 | """Check if the current process is the main process.""" 37 | return not is_dist_avail_and_initialized() or torch.distributed.get_rank() == 0 38 | 39 | 40 | def concat_all_gather(tensor): 41 | """ 42 | Performs all_gather operation on the provided tensors. 43 | *** Warning ***: torch.distributed.all_gather has no gradient. 44 | """ 45 | tensors_gather = [torch.ones_like(tensor) 46 | for _ in range(torch.distributed.get_world_size())] 47 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 48 | 49 | output = torch.cat(tensors_gather, dim=0) 50 | return output 51 | 52 | 53 | def requires_grad(model, flag=True): 54 | """Set requires_grad flag for all model parameters.""" 55 | for p in model.parameters(): 56 | p.requires_grad = flag 57 | 58 | 59 | def save_img(img, save_path): 60 | """Save a single image to disk.""" 61 | img = np.clip(img.float().numpy().transpose([1, 2, 0]) * 255, 0, 255) 62 | img = img.astype(np.uint8)[:, :, ::-1] 63 | cv2.imwrite(save_path, img) 64 | 65 | 66 | def save_img_batch(imgs, save_paths): 67 | """Process and save multiple images at once using a thread pool.""" 68 | # Convert to numpy and prepare all images in one go 69 | imgs = np.clip(imgs.float().numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8) 70 | imgs = imgs[:, :, :, ::-1] # RGB to BGR for all images at once 71 | 72 | with ThreadPoolExecutor(max_workers=32) as pool: 73 | # Submit all tasks at once 74 | futures = [pool.submit(cv2.imwrite, path, img) 75 | for path, img in zip(save_paths, imgs)] 76 | # Wait for all tasks to complete 77 | for future in futures: 78 | future.result() # This will raise any exceptions that occurred 79 | 80 | 81 | def get_fid_stats(real_dir, rec_dir, fid_stats): 82 | """Calculate FID statistics between real and reconstructed images.""" 83 | stats = torch_fidelity.calculate_metrics( 84 | input1=rec_dir, 85 | input2=real_dir, 86 | fid_statistics_file=fid_stats, 87 | cuda=True, 88 | isc=True, 89 | fid=True, 90 | kid=False, 91 | prc=False, 92 | verbose=False, 93 | ) 94 | return stats 95 | 96 | 97 | def create_scheduler(optimizer, num_epoch, steps_per_epoch, lr_min, warmup_steps, 98 | warmup_lr_init, decay_steps, cosine_lr): 99 | """Create a learning rate scheduler.""" 100 | scheduler = build_scheduler( 101 | optimizer, 102 | num_epoch, 103 | steps_per_epoch, 104 | lr_min, 105 | warmup_steps, 106 | warmup_lr_init, 107 | decay_steps, 108 | cosine_lr, 109 | ) 110 | return scheduler 111 | 112 | 113 | def load_state_dict(state_dict, model): 114 | """Helper to load a state dict with proper prefix handling.""" 115 | if 'state_dict' in state_dict: 116 | state_dict = state_dict['state_dict'] 117 | # Remove '_orig_mod' prefix if present 118 | state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()} 119 | missing, unexpected = model.load_state_dict( 120 | state_dict, strict=False 121 | ) 122 | if is_main_process(): 123 | print(f"Loaded model. 
Missing: {missing}, Unexpected: {unexpected}") 124 | 125 | 126 | def load_safetensors(path, model): 127 | """Helper to load a safetensors checkpoint.""" 128 | from safetensors.torch import safe_open 129 | with safe_open(path, framework="pt", device="cpu") as f: 130 | state_dict = {k: f.get_tensor(k) for k in f.keys()} 131 | load_state_dict(state_dict, model) 132 | 133 | 134 | def setup_result_folders(result_folder): 135 | """Setup result folders for saving models and images.""" 136 | model_saved_dir = os.path.join(result_folder, "models") 137 | os.makedirs(model_saved_dir, exist_ok=True) 138 | 139 | image_saved_dir = os.path.join(result_folder, "images") 140 | os.makedirs(image_saved_dir, exist_ok=True) 141 | 142 | return model_saved_dir, image_saved_dir 143 | 144 | 145 | def create_optimizer(model, weight_decay, learning_rate, betas=(0.9, 0.95)): 146 | """Create an AdamW optimizer with weight decay for 2D parameters only.""" 147 | # start with all of the candidate parameters 148 | param_dict = {pn: p for pn, p in model.named_parameters()} 149 | # filter out those that do not require grad 150 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 151 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 152 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 153 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 154 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 155 | optim_groups = [ 156 | {'params': decay_params, 'weight_decay': weight_decay}, 157 | {'params': nodecay_params, 'weight_decay': 0.0} 158 | ] 159 | num_decay_params = sum(p.numel() for p in decay_params) 160 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 161 | if is_main_process(): 162 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 163 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 164 | optimizer = AdamW(optim_groups, lr=learning_rate, betas=betas) 165 | return optimizer 166 | 167 | 168 | class EMAModel: 169 | """Model Exponential Moving Average.""" 170 | def __init__(self, model, device, decay=0.999): 171 | self.device = device 172 | self.decay = decay 173 | self.ema_params = OrderedDict( 174 | (name, param.clone().detach().to(device)) 175 | for name, param in model.named_parameters() 176 | if param.requires_grad 177 | ) 178 | 179 | @torch.no_grad() 180 | def update(self, model): 181 | for name, param in model.named_parameters(): 182 | if param.requires_grad: 183 | if name in self.ema_params: 184 | self.ema_params[name].lerp_(param.data, 1 - self.decay) 185 | else: 186 | self.ema_params[name] = param.data.clone().detach() 187 | 188 | def state_dict(self): 189 | return self.ema_params 190 | 191 | def load_state_dict(self, params): 192 | self.ema_params = OrderedDict( 193 | (name, param.clone().detach().to(self.device)) 194 | for name, param in params.items() 195 | ) 196 | 197 | 198 | class PaddedDataset(torch.utils.data.Dataset): 199 | """Dataset wrapper that pads a dataset to ensure even distribution across processes.""" 200 | def __init__(self, dataset, padding_size): 201 | self.dataset = dataset 202 | self.padding_size = padding_size 203 | 204 | def __len__(self): 205 | return len(self.dataset) + self.padding_size 206 | 207 | def __getitem__(self, idx): 208 | if idx < len(self.dataset): 209 | return self.dataset[idx] 210 | return self.dataset[0] 211 | 212 | class 
CacheDataLoader: 213 | """DataLoader-like interface for cached data with epoch-based shuffling.""" 214 | def __init__(self, slots, targets=None, batch_size=32, num_augs=1, seed=None): 215 | self.slots = slots 216 | self.targets = targets 217 | self.batch_size = batch_size 218 | self.num_augs = num_augs 219 | self.seed = seed 220 | self.epoch = 0 221 | # Original dataset size (before augmentations) 222 | self.num_samples = len(slots) // num_augs 223 | 224 | def set_epoch(self, epoch): 225 | """Set epoch for deterministic shuffling.""" 226 | self.epoch = epoch 227 | 228 | def __len__(self): 229 | """Return number of batches based on original dataset size.""" 230 | return self.num_samples // self.batch_size 231 | 232 | def __iter__(self): 233 | """Return random indices for current epoch.""" 234 | g = torch.Generator() 235 | g.manual_seed(self.seed + self.epoch if self.seed is not None else self.epoch) 236 | 237 | # Randomly sample indices from the entire augmented dataset 238 | indices = torch.randint( 239 | 0, len(self.slots), 240 | (self.num_samples,), 241 | generator=g 242 | ).numpy() 243 | 244 | # Yield batches of indices 245 | for start in range(0, self.num_samples, self.batch_size): 246 | end = min(start + self.batch_size, self.num_samples) 247 | batch_indices = indices[start:end] 248 | yield ( 249 | torch.from_numpy(self.slots[batch_indices]), 250 | torch.from_numpy(self.targets[batch_indices]) 251 | ) -------------------------------------------------------------------------------- /semanticist/stage1/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . 
import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | learn_sigma=True, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000 19 | ): 20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 21 | if use_kl: 22 | loss_type = gd.LossType.RESCALED_KL 23 | elif rescale_learned_sigmas: 24 | loss_type = gd.LossType.RESCALED_MSE 25 | else: 26 | loss_type = gd.LossType.MSE 27 | if timestep_respacing is None or timestep_respacing == "": 28 | timestep_respacing = [diffusion_steps] 29 | return SpacedDiffusion( 30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 31 | betas=betas, 32 | model_mean_type=( 33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 34 | ), 35 | model_var_type=( 36 | ( 37 | gd.ModelVarType.FIXED_LARGE 38 | if not sigma_small 39 | else gd.ModelVarType.FIXED_SMALL 40 | ) 41 | if not learn_sigma 42 | else gd.ModelVarType.LEARNED_RANGE 43 | ), 44 | loss_type=loss_type 45 | # rescale_timesteps=rescale_timesteps, 46 | ) 47 | -------------------------------------------------------------------------------- /semanticist/stage1/diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 
54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /semanticist/stage1/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 
30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 
114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | # self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | # if self.rescale_timesteps: 128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /semanticist/stage1/diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 
Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)  # np.int was removed in NumPy 1.24+ 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term. 143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() -------------------------------------------------------------------------------- /semanticist/stage1/diffusion_transfomer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | import math 16 | from timm.models.vision_transformer import PatchEmbed, Mlp 17 | from semanticist.stage1.fused_attention import Attention 18 | 19 | 20 | 21 | def modulate(x, shift, scale): 22 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 23 | 24 | 25 | ################################################################################# 26 | # Embedding Layers for Timesteps and Class Labels # 27 | ################################################################################# 28 | 29 | class TimestepEmbedder(nn.Module): 30 | """ 31 | Embeds scalar timesteps into vector representations. 32 | """ 33 | def __init__(self, hidden_size, frequency_embedding_size=256): 34 | super().__init__() 35 | self.mlp = nn.Sequential( 36 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 37 | nn.SiLU(), 38 | nn.Linear(hidden_size, hidden_size, bias=True), 39 | ) 40 | self.frequency_embedding_size = frequency_embedding_size 41 | 42 | @staticmethod 43 | def timestep_embedding(t, dim, max_period=10000): 44 | """ 45 | Create sinusoidal timestep embeddings. 46 | :param t: a 1-D Tensor of N indices, one per batch element. 47 | These may be fractional. 48 | :param dim: the dimension of the output. 49 | :param max_period: controls the minimum frequency of the embeddings. 50 | :return: an (N, D) Tensor of positional embeddings. 51 | """ 52 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 53 | half = dim // 2 54 | freqs = torch.exp( 55 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 56 | ).to(device=t.device) 57 | args = t[:, None].float() * freqs[None] 58 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 59 | if dim % 2: 60 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 61 | return embedding 62 | 63 | def forward(self, t): 64 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 65 | t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) 66 | return t_emb 67 | 68 | 69 | class LabelEmbedder(nn.Module): 70 | """ 71 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 72 | """ 73 | def __init__(self, num_classes, hidden_size, dropout_prob): 74 | super().__init__() 75 | use_cfg_embedding = dropout_prob > 0 76 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 77 | self.num_classes = num_classes 78 | self.dropout_prob = dropout_prob 79 | 80 | def token_drop(self, labels, force_drop_ids=None): 81 | """ 82 | Drops labels to enable classifier-free guidance. 
83 | """ 84 | if force_drop_ids is None: 85 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 86 | else: 87 | drop_ids = force_drop_ids == 1 88 | labels = torch.where(drop_ids, self.num_classes, labels) 89 | return labels 90 | 91 | def forward(self, labels, train, force_drop_ids=None): 92 | use_dropout = self.dropout_prob > 0 93 | if (train and use_dropout) or (force_drop_ids is not None): 94 | labels = self.token_drop(labels, force_drop_ids) 95 | embeddings = self.embedding_table(labels) 96 | return embeddings 97 | 98 | 99 | ################################################################################# 100 | # Core DiT Model # 101 | ################################################################################# 102 | 103 | class DiTBlock(nn.Module): 104 | """ 105 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 106 | """ 107 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 108 | super().__init__() 109 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 110 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 111 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 112 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 113 | approx_gelu = lambda: nn.GELU(approximate="tanh") 114 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 115 | self.adaLN_modulation = nn.Sequential( 116 | nn.SiLU(), 117 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 118 | ) 119 | 120 | def forward(self, x, c, mask=None): 121 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 122 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), mask) 123 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 124 | return x 125 | 126 | 127 | class FinalLayer(nn.Module): 128 | """ 129 | The final layer of DiT. 130 | """ 131 | def __init__(self, hidden_size, patch_size, out_channels): 132 | super().__init__() 133 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 134 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 135 | self.adaLN_modulation = nn.Sequential( 136 | nn.SiLU(), 137 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 138 | ) 139 | 140 | def forward(self, x, c): 141 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 142 | x = modulate(self.norm_final(x), shift, scale) 143 | x = self.linear(x) 144 | return x 145 | 146 | 147 | class DiT(nn.Module): 148 | """ 149 | Diffusion model with a Transformer backbone. 
150 | """ 151 | def __init__( 152 | self, 153 | input_size=32, 154 | patch_size=2, 155 | in_channels=4, 156 | hidden_size=1152, 157 | depth=28, 158 | num_heads=16, 159 | mlp_ratio=4.0, 160 | class_dropout_prob=0.1, 161 | num_classes=1000, 162 | learn_sigma=True, 163 | ): 164 | super().__init__() 165 | self.learn_sigma = learn_sigma 166 | self.in_channels = in_channels 167 | self.out_channels = in_channels * 2 if learn_sigma else in_channels 168 | self.patch_size = patch_size 169 | self.num_heads = num_heads 170 | 171 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 172 | self.t_embedder = TimestepEmbedder(hidden_size) 173 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) 174 | num_patches = self.x_embedder.num_patches 175 | # Will use fixed sin-cos embedding: 176 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 177 | 178 | self.blocks = nn.ModuleList([ 179 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) 180 | ]) 181 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 182 | self.initialize_weights() 183 | 184 | def initialize_weights(self): 185 | # Initialize transformer layers: 186 | def _basic_init(module): 187 | if isinstance(module, nn.Linear): 188 | torch.nn.init.xavier_uniform_(module.weight) 189 | if module.bias is not None: 190 | nn.init.constant_(module.bias, 0) 191 | self.apply(_basic_init) 192 | 193 | # Initialize (and freeze) pos_embed by sin-cos embedding: 194 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 195 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 196 | 197 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 198 | w = self.x_embedder.proj.weight.data 199 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 200 | nn.init.constant_(self.x_embedder.proj.bias, 0) 201 | 202 | # Initialize label embedding table: 203 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 204 | 205 | # Initialize timestep embedding MLP: 206 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 207 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 208 | 209 | # Zero-out adaLN modulation layers in DiT blocks: 210 | for block in self.blocks: 211 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 212 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 213 | 214 | # Zero-out output layers: 215 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 216 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 217 | nn.init.constant_(self.final_layer.linear.weight, 0) 218 | nn.init.constant_(self.final_layer.linear.bias, 0) 219 | 220 | def unpatchify(self, x): 221 | """ 222 | x: (N, T, patch_size**2 * C) 223 | imgs: (N, H, W, C) 224 | """ 225 | c = self.out_channels 226 | p = self.x_embedder.patch_size[0] 227 | h = w = int(x.shape[1] ** 0.5) 228 | assert h * w == x.shape[1] 229 | 230 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 231 | x = torch.einsum('nhwpqc->nchpwq', x) 232 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 233 | return imgs 234 | 235 | def forward(self, x, t, y): 236 | """ 237 | Forward pass of DiT. 
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 239 | t: (N,) tensor of diffusion timesteps 240 | y: (N,) tensor of class labels 241 | """ 242 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 243 | t = self.t_embedder(t) # (N, D) 244 | y = self.y_embedder(y, self.training) # (N, D) 245 | c = t + y # (N, D) 246 | for block in self.blocks: 247 | x = block(x, c) # (N, T, D) 248 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 249 | x = self.unpatchify(x) # (N, out_channels, H, W) 250 | return x 251 | 252 | def forward_with_cfg(self, x, t, y, cfg_scale): 253 | """ 254 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. 255 | """ 256 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 257 | half = x[: len(x) // 2] 258 | combined = torch.cat([half, half], dim=0) 259 | model_out = self.forward(combined, t, y) 260 | # Unlike the original DiT release, classifier-free guidance is applied to all 261 | # channels here (the standard approach). To reproduce the original DiT behavior of 262 | # guiding only the first three channels, swap the active and commented lines below. 263 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 264 | # eps, rest = model_out[:, :3], model_out[:, 3:] 265 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 266 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 267 | eps = torch.cat([half_eps, half_eps], dim=0) 268 | return torch.cat([eps, rest], dim=1) 269 | 270 | 271 | ################################################################################# 272 | # Sine/Cosine Positional Embedding Functions # 273 | ################################################################################# 274 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 275 | 276 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 277 | """ 278 | grid_size: int of the grid height and width 279 | return: 280 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 281 | """ 282 | grid_h = np.arange(grid_size, dtype=np.float32) 283 | grid_w = np.arange(grid_size, dtype=np.float32) 284 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 285 | grid = np.stack(grid, axis=0) 286 | 287 | grid = grid.reshape([2, 1, grid_size, grid_size]) 288 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 289 | if cls_token and extra_tokens > 0: 290 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 291 | return pos_embed 292 | 293 | 294 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 295 | assert embed_dim % 2 == 0 296 | 297 | # use half of dimensions to encode grid_h 298 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 299 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 300 | 301 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 302 | return emb 303 | 304 | 305 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 306 | """ 307 | embed_dim: output dimension for each position 308 | pos: a list of positions to be encoded: size (M,) 309 | out: (M, D) 310 | """ 311 | assert embed_dim % 2 == 0 312 | omega = np.arange(embed_dim // 2, dtype=np.float64) 313 | omega /= embed_dim / 2. 314 | omega = 1. 
/ 10000**omega # (D/2,) 315 | 316 | pos = pos.reshape(-1) # (M,) 317 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 318 | 319 | emb_sin = np.sin(out) # (M, D/2) 320 | emb_cos = np.cos(out) # (M, D/2) 321 | 322 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 323 | return emb 324 | 325 | 326 | ################################################################################# 327 | # DiT Configs # 328 | ################################################################################# 329 | 330 | def DiT_XL_2(**kwargs): 331 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 332 | 333 | def DiT_XL_4(**kwargs): 334 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) 335 | 336 | def DiT_XL_8(**kwargs): 337 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) 338 | 339 | def DiT_L_2(**kwargs): 340 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 341 | 342 | def DiT_L_4(**kwargs): 343 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) 344 | 345 | def DiT_L_8(**kwargs): 346 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) 347 | 348 | def DiT_B_2(**kwargs): 349 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) 350 | 351 | def DiT_B_4(**kwargs): 352 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) 353 | 354 | def DiT_B_8(**kwargs): 355 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) 356 | 357 | def DiT_S_2(**kwargs): 358 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 359 | 360 | def DiT_S_4(**kwargs): 361 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) 362 | 363 | def DiT_S_8(**kwargs): 364 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) 365 | 366 | 367 | DiT_models = { 368 | 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, 369 | 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, 370 | 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, 371 | 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, 372 | } -------------------------------------------------------------------------------- /semanticist/stage1/fused_attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from typing import Type 4 | 5 | class Attention(nn.Module): 6 | def __init__( 7 | self, 8 | dim: int, 9 | num_heads: int = 8, 10 | qkv_bias: bool = False, 11 | qk_norm: bool = False, 12 | proj_bias: bool = True, 13 | attn_drop: float = 0., 14 | proj_drop: float = 0., 15 | norm_layer: Type[nn.Module] = nn.LayerNorm, 16 | ) -> None: 17 | super().__init__() 18 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 19 | self.num_heads = num_heads 20 | self.head_dim = dim // num_heads 21 | self.scale = self.head_dim ** -0.5 22 | 23 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 24 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 25 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 26 | self.attn_drop = nn.Dropout(attn_drop) 27 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 28 | self.proj_drop = nn.Dropout(proj_drop) 29 | 30 | def forward(self, x, attn_mask=None): 31 | B, N, C = x.shape 32 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // 
self.num_heads).permute(2, 0, 3, 1, 4) 33 | q, k, v = qkv.unbind(0) 34 | q, k = self.q_norm(q), self.k_norm(k) 35 | 36 | x = F.scaled_dot_product_attention( 37 | q, k, v, 38 | attn_mask=attn_mask, 39 | dropout_p=self.attn_drop.p if self.training else 0., 40 | ) 41 | 42 | x = x.transpose(1, 2).reshape(B, N, C) 43 | x = self.proj(x) 44 | x = self.proj_drop(x) 45 | return x 46 | -------------------------------------------------------------------------------- /semanticist/stage1/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_1d_sincos_pos_embed(embed_dim, grid_size): 39 | grid = np.arange(grid_size, dtype=np.float32) 40 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) 41 | return pos_embed 42 | 43 | 44 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 45 | assert embed_dim % 2 == 0 46 | 47 | # use half of dimensions to encode grid_h 48 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 49 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 50 | 51 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 52 | return emb 53 | 54 | 55 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 56 | """ 57 | embed_dim: output dimension for each position 58 | pos: a list of positions to be encoded: size (M,) 59 | out: (M, D) 60 | """ 61 | assert embed_dim % 2 == 0 62 | omega = np.arange(embed_dim // 2, dtype=float) 63 | omega /= embed_dim / 2. 64 | omega = 1. 
/ 10000**omega # (D/2,) 65 | 66 | pos = pos.reshape(-1) # (M,) 67 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 68 | 69 | emb_sin = np.sin(out) # (M, D/2) 70 | emb_cos = np.cos(out) # (M, D/2) 71 | 72 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 73 | return emb 74 | 75 | 76 | # -------------------------------------------------------- 77 | # Interpolate position embeddings for high-resolution 78 | # References: 79 | # DeiT: https://github.com/facebookresearch/deit 80 | # -------------------------------------------------------- 81 | def interpolate_pos_embed(model, checkpoint_model): 82 | if 'pos_embed' in checkpoint_model: 83 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 84 | embedding_size = pos_embed_checkpoint.shape[-1] 85 | num_patches = model.patch_embed.num_patches 86 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 87 | # height (== width) for the checkpoint position embedding 88 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 89 | # height (== width) for the new position embedding 90 | new_size = int(num_patches ** 0.5) 91 | # class_token and dist_token are kept unchanged 92 | if orig_size != new_size: 93 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 94 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 95 | # only the position tokens are interpolated 96 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 97 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 98 | pos_tokens = torch.nn.functional.interpolate( 99 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 100 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 101 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 102 | checkpoint_model['pos_embed'] = new_pos_embed
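# --- Editorial sketch: building and loading the fixed 2D sin-cos table above. ---
# Minimal, hypothetical usage of get_2d_sincos_pos_embed; the shapes assume a
# 16x16 patch grid and 768-dim embeddings, mirroring how DiT initializes pos_embed.
#
#   import torch
#
#   table = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)   # (256, 768) ndarray
#   pos_embed = torch.nn.Parameter(torch.zeros(1, 256, 768), requires_grad=False)
#   pos_embed.data.copy_(torch.from_numpy(table).float().unsqueeze(0))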
-------------------------------------------------------------------------------- /semanticist/stage1/transport/__init__.py: -------------------------------------------------------------------------------- 1 | from .transport import Transport, ModelType, WeightType, PathType, Sampler 2 | 3 | def create_transport( 4 | path_type='Linear', 5 | prediction="velocity", 6 | loss_weight=None, 7 | train_eps=None, 8 | sample_eps=None, 9 | ): 10 | """function for creating Transport object 11 | **Note**: model prediction defaults to velocity 12 | Args: 13 | - path_type: type of path to use; default to linear 14 | - prediction: model prediction type; "velocity" (default), 15 | "score", or "noise" 16 | - loss_weight: optional loss weighting; "velocity" weights the loss by the 17 | velocity weight, "likelihood" by the likelihood weight 18 | - train_eps: small epsilon for avoiding instability during training 19 | - sample_eps: small epsilon for avoiding instability during sampling 20 | """ 21 | 22 | if prediction == "noise": 23 | model_type = ModelType.NOISE 24 | elif prediction == "score": 25 | model_type = ModelType.SCORE 26 | else: 27 | model_type = ModelType.VELOCITY 28 | 29 | if loss_weight == "velocity": 30 | loss_type = WeightType.VELOCITY 31 | elif loss_weight == "likelihood": 32 | loss_type = WeightType.LIKELIHOOD 33 | else: 34 | loss_type = WeightType.NONE 35 | 36 | path_choice = { 37 | "Linear": PathType.LINEAR, 38 | "GVP": PathType.GVP, 39 | "VP": PathType.VP, 40 | } 41 | 42 | path_type = path_choice[path_type] 43 | 44 | if (path_type in [PathType.VP]): 45 | train_eps = 1e-5 if train_eps is None else train_eps 46 | sample_eps = 1e-3 if sample_eps is None else sample_eps 47 | elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY): 48 | train_eps = 1e-3 if train_eps is None else train_eps 49 | sample_eps = 1e-3 if sample_eps is None else sample_eps 50 | else: # velocity & [GVP, LINEAR] are stable everywhere 51 | train_eps = 0 52 | sample_eps = 0 53 | 54 | # create flow state 55 | state = Transport( 56 | model_type=model_type, 57 | path_type=path_type, 58 | loss_type=loss_type, 59 | train_eps=train_eps, 60 | sample_eps=sample_eps, 61 | ) 62 | 63 | return state -------------------------------------------------------------------------------- /semanticist/stage1/transport/integrators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | import torch.nn as nn 4 | from torchdiffeq import odeint 5 | from functools import partial 6 | from tqdm import tqdm 7 | 8 | class sde: 9 | """SDE solver class""" 10 | def __init__( 11 | self, 12 | drift, 13 | diffusion, 14 | *, 15 | t0, 16 | t1, 17 | num_steps, 18 | sampler_type, 19 | temperature=1.0, 20 | ): 21 | assert t0 < t1, "SDE sampler has to be in forward time" 22 | 23 | self.num_timesteps = num_steps 24 | self.t = th.linspace(t0, t1, num_steps) 25 | self.dt = self.t[1] - self.t[0] 26 | self.drift = drift 27 | self.diffusion = diffusion 28 | self.sampler_type = sampler_type 29 | self.temperature = temperature 30 | 31 | def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs): 32 | w_cur = th.randn(x.size()).to(x) 33 | t = th.ones(x.size(0)).to(x) * t 34 | dw = w_cur * th.sqrt(self.dt) 35 | drift = self.drift(x, t, model, **model_kwargs) 36 | diffusion = self.diffusion(x, t) 37 | mean_x = x + drift * self.dt 38 | x = mean_x + th.sqrt(2 * diffusion) * dw * self.temperature 39 | return x, mean_x 40 | 41 | def __Heun_step(self, x, _, t, model, **model_kwargs): 42 | w_cur = th.randn(x.size()).to(x) 43 | dw = w_cur * th.sqrt(self.dt) * self.temperature 44 | t_cur = th.ones(x.size(0)).to(x) * t 45 | diffusion = self.diffusion(x, t_cur) 46 | xhat = x + th.sqrt(2 * diffusion) * dw 47 | K1 = self.drift(xhat, t_cur, model, **model_kwargs) 48 | xp = xhat + self.dt * K1 49 | K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs) 50 | return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step 51 | 52 | def __forward_fn(self): 53 | """TODO: generalize here by adding all private functions ending with steps to it""" 54 | sampler_dict = { 55 | "Euler": self.__Euler_Maruyama_step, 56 | "Heun": self.__Heun_step, 57 | } 58 | 59 | try: 60 | sampler = sampler_dict[self.sampler_type] 61 | except KeyError: 62 | raise NotImplementedError("Sampler type not implemented.") 63 | 64 | return sampler 65 | 66 | def sample(self, init, model, **model_kwargs): 67 | """forward loop of sde""" 68 | x = init 69 | mean_x = init 70 | samples = [] 71 | sampler = self.__forward_fn() 72 | for ti in self.t[:-1]: 73 | with th.no_grad(): 74 | x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs) 75 | samples.append(x) 76 | 77 | return samples 78 | 
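# --- Editorial sketch: stepping the `sde` class above on a toy problem. ---
# Hypothetical usage, not part of the original file: an Ornstein-Uhlenbeck-like
# process with drift -x and constant unit diffusion, integrated by Euler-Maruyama.
#
#   import torch as th
#
#   drift = lambda x, t, model, **kw: -x            # toy drift; ignores the model
#   diffusion = lambda x, t: th.ones_like(x)        # constant diffusion coefficient
#   solver = sde(drift, diffusion, t0=0.0, t1=1.0, num_steps=100, sampler_type="Euler")
#   xs = solver.sample(th.randn(4, 8), model=None)  # list of 99 intermediate states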
79 | class ode: 80 | """ODE solver class""" 81 | def __init__( 82 | self, 83 | drift, 84 | *, 85 | t0, 86 | t1, 87 | sampler_type, 88 | num_steps, 89 | atol, 90 | rtol, 91 | temperature=1.0, 92 | ): 93 | assert t0 < t1, "ODE sampler has to be in forward time" 94 | 95 | self.drift = drift 96 | self.t = th.linspace(t0, t1, num_steps) 97 | self.atol = atol 98 | self.rtol = rtol 99 | self.sampler_type = sampler_type 100 | self.temperature = temperature 101 | 102 | def sample(self, x, model, **model_kwargs): 103 | 104 | device = x[0].device if isinstance(x, tuple) else x.device 105 | def _fn(t, x): 106 | t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t 107 | # For ODE, we scale the drift by the temperature 108 | # This is equivalent to scaling time by 1/temperature 109 | model_output = self.drift(x, t, model, **model_kwargs) 110 | if self.temperature != 1.0: 111 | # If it's a tuple (for likelihood calculation), only scale the first element 112 | if isinstance(model_output, tuple): 113 | scaled_output = (model_output[0] / self.temperature, model_output[1]) 114 | return scaled_output 115 | else: 116 | return model_output / self.temperature 117 | return model_output 118 | 119 | t = self.t.to(device) 120 | atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol] 121 | rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol] 122 | samples = odeint( 123 | _fn, 124 | x, 125 | t, 126 | method=self.sampler_type, 127 | atol=atol, 128 | rtol=rtol 129 | ) 130 | return samples -------------------------------------------------------------------------------- /semanticist/stage1/transport/path.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | from functools import partial 4 | 5 | def expand_t_like_x(t, x): 6 | """Function to reshape time t to broadcastable dimension of x 7 | Args: 8 | t: [batch_dim,], time vector 9 | x: [batch_dim,...], data point 10 | """ 11 | dims = [1] * (len(x.size()) - 1) 12 | t = t.view(t.size(0), *dims) 13 | return t 14 | 15 | 16 | #################### Coupling Plans #################### 17 | 18 | class ICPlan: 19 | """Linear Coupling Plan""" 20 | def __init__(self, sigma=0.0): 21 | self.sigma = sigma 22 | 23 | def compute_alpha_t(self, t): 24 | """Compute the data coefficient along the path""" 25 | return t, 1 26 | 27 | def compute_sigma_t(self, t): 28 | """Compute the noise coefficient along the path""" 29 | return 1 - t, -1 30 | 31 | def compute_d_alpha_alpha_ratio_t(self, t): 32 | """Compute the ratio between d_alpha and alpha""" 33 | return 1 / t 34 | 35 | def compute_drift(self, x, t): 36 | """We always output sde according to score parametrization; """ 37 | t = expand_t_like_x(t, x) 38 | alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t) 39 | sigma_t, d_sigma_t = self.compute_sigma_t(t) 40 | drift = alpha_ratio * x 41 | diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t 42 | 43 | return -drift, diffusion 44 | 45 | def compute_diffusion(self, x, t, form="constant", norm=1.0): 46 | """Compute the diffusion term of the SDE 47 | Args: 48 | x: [batch_dim, ...], data point 49 | t: [batch_dim,], time vector 50 | form: str, form of the diffusion term 51 | norm: float, norm of the diffusion term 52 | """ 53 | t = expand_t_like_x(t, x) 54 | choices = { 55 | "constant": norm, 56 | "SBDM": norm * self.compute_drift(x, t)[1], 57 | "sigma": norm * self.compute_sigma_t(t)[0], 58 | "linear": norm * (1 - t), 59 | "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2, 60 | "increasing-decreasing": norm * th.sin(np.pi * t) ** 2, 61 | } 62 | 63 | try: 64 | diffusion = choices[form] 65 | except KeyError: 66 | raise NotImplementedError(f"Diffusion form {form} not implemented") 67 | 68 | return diffusion 69 | 
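# --- Editorial sketch: the velocity <-> score conversion implemented below. ---
# For the linear plan x_t = t * x1 + (1 - t) * x0, the exact one-sample velocity is
# v = x1 - x0, and get_score_from_velocity recovers the score of p_t(x_t | x1).
# Hypothetical check, not part of the original file:
#
#   import torch as th
#
#   plan = ICPlan()
#   x0, x1 = th.randn(2, 4), th.randn(2, 4)
#   t = th.full((2,), 0.7)
#   t_, xt, ut = plan.plan(t, x0, x1)
#   score = plan.get_score_from_velocity(ut, xt, t_)   # equals -x0 / (1 - t) here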
70 | def get_score_from_velocity(self, velocity, x, t): 71 | """Wrapper function: transform velocity prediction model to score 72 | Args: 73 | velocity: [batch_dim, ...] shaped tensor; velocity model output 74 | x: [batch_dim, ...] shaped tensor; x_t data point 75 | t: [batch_dim,] time tensor 76 | """ 77 | t = expand_t_like_x(t, x) 78 | alpha_t, d_alpha_t = self.compute_alpha_t(t) 79 | sigma_t, d_sigma_t = self.compute_sigma_t(t) 80 | mean = x 81 | reverse_alpha_ratio = alpha_t / d_alpha_t 82 | var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t 83 | score = (reverse_alpha_ratio * velocity - mean) / var 84 | return score 85 | 86 | def get_noise_from_velocity(self, velocity, x, t): 87 | """Wrapper function: transform velocity prediction model to denoiser 88 | Args: 89 | velocity: [batch_dim, ...] shaped tensor; velocity model output 90 | x: [batch_dim, ...] shaped tensor; x_t data point 91 | t: [batch_dim,] time tensor 92 | """ 93 | t = expand_t_like_x(t, x) 94 | alpha_t, d_alpha_t = self.compute_alpha_t(t) 95 | sigma_t, d_sigma_t = self.compute_sigma_t(t) 96 | mean = x 97 | reverse_alpha_ratio = alpha_t / d_alpha_t 98 | var = reverse_alpha_ratio * d_sigma_t - sigma_t 99 | noise = (reverse_alpha_ratio * velocity - mean) / var 100 | return noise 101 | 102 | def get_velocity_from_score(self, score, x, t): 103 | """Wrapper function: transform score prediction model to velocity 104 | Args: 105 | score: [batch_dim, ...] shaped tensor; score model output 106 | x: [batch_dim, ...] shaped tensor; x_t data point 107 | t: [batch_dim,] time tensor 108 | """ 109 | t = expand_t_like_x(t, x) 110 | drift, var = self.compute_drift(x, t) 111 | velocity = var * score - drift 112 | return velocity 113 | 114 | def compute_mu_t(self, t, x0, x1): 115 | """Compute the mean of time-dependent density p_t""" 116 | t = expand_t_like_x(t, x1) 117 | alpha_t, _ = self.compute_alpha_t(t) 118 | sigma_t, _ = self.compute_sigma_t(t) 119 | return alpha_t * x1 + sigma_t * x0 120 | 121 | def compute_xt(self, t, x0, x1): 122 | """Sample xt from time-dependent density p_t; rng is required""" 123 | xt = self.compute_mu_t(t, x0, x1) 124 | return xt 125 | 126 | def compute_ut(self, t, x0, x1, xt): 127 | """Compute the vector field corresponding to p_t""" 128 | t = expand_t_like_x(t, x1) 129 | _, d_alpha_t = self.compute_alpha_t(t) 130 | _, d_sigma_t = self.compute_sigma_t(t) 131 | return d_alpha_t * x1 + d_sigma_t * x0 132 | 133 | def plan(self, t, x0, x1): 134 | xt = self.compute_xt(t, x0, x1) 135 | ut = self.compute_ut(t, x0, x1, xt) 136 | return t, xt, ut 137 | 138 | 
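# --- Editorial sketch: the variance-preserving property of the GVP plan below. ---
# For GVPCPlan, alpha_t = sin(pi*t/2) and sigma_t = cos(pi*t/2), so
# alpha_t**2 + sigma_t**2 == 1 for all t (hence "generalized variance preserving").
# Hypothetical check, not part of the original file:
#
#   import torch as th
#
#   gvp = GVPCPlan()
#   t = th.rand(5)
#   a, _ = gvp.compute_alpha_t(t)
#   s, _ = gvp.compute_sigma_t(t)
#   assert th.allclose(a**2 + s**2, th.ones_like(t))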
139 | class VPCPlan(ICPlan): 140 | """class for VP path flow matching""" 141 | 142 | def __init__(self, sigma_min=0.1, sigma_max=20.0): 143 | self.sigma_min = sigma_min 144 | self.sigma_max = sigma_max 145 | self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min 146 | self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min 147 | 148 | 149 | def compute_alpha_t(self, t): 150 | """Compute coefficient of x1""" 151 | alpha_t = self.log_mean_coeff(t) 152 | alpha_t = th.exp(alpha_t) 153 | d_alpha_t = alpha_t * self.d_log_mean_coeff(t) 154 | return alpha_t, d_alpha_t 155 | 156 | def compute_sigma_t(self, t): 157 | """Compute coefficient of x0""" 158 | p_sigma_t = 2 * self.log_mean_coeff(t) 159 | sigma_t = th.sqrt(1 - th.exp(p_sigma_t)) 160 | d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t) 161 | return sigma_t, d_sigma_t 162 | 163 | def compute_d_alpha_alpha_ratio_t(self, t): 164 | """Numerically stable computation of d_alpha_t / alpha_t""" 165 | return self.d_log_mean_coeff(t) 166 | 167 | def compute_drift(self, x, t): 168 | """Compute the drift term of the SDE""" 169 | t = expand_t_like_x(t, x) 170 | beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min) 171 | return -0.5 * beta_t * x, beta_t / 2 172 | 173 | 174 | class GVPCPlan(ICPlan): 175 | def __init__(self, sigma=0.0): 176 | super().__init__(sigma) 177 | 178 | def compute_alpha_t(self, t): 179 | """Compute coefficient of x1""" 180 | alpha_t = th.sin(t * np.pi / 2) 181 | d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2) 182 | return alpha_t, d_alpha_t 183 | 184 | def compute_sigma_t(self, t): 185 | """Compute coefficient of x0""" 186 | sigma_t = th.cos(t * np.pi / 2) 187 | d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2) 188 | return sigma_t, d_sigma_t 189 | 190 | def compute_d_alpha_alpha_ratio_t(self, t): 191 | """Numerically stable computation of d_alpha_t / alpha_t""" 192 | return np.pi / (2 * th.tan(t * np.pi / 2)) -------------------------------------------------------------------------------- /semanticist/stage1/transport/transport.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | import logging 4 | 5 | import enum 6 | 7 | from . import path 8 | from .utils import EasyDict, log_state, mean_flat 9 | from .integrators import ode, sde 10 | 11 | class ModelType(enum.Enum): 12 | """ 13 | Which type of output the model predicts. 14 | """ 15 | 16 | NOISE = enum.auto() # the model predicts epsilon 17 | SCORE = enum.auto() # the model predicts \nabla \log p(x) 18 | VELOCITY = enum.auto() # the model predicts v(x) 19 | 20 | class PathType(enum.Enum): 21 | """ 22 | Which type of path to use. 23 | """ 24 | 25 | LINEAR = enum.auto() 26 | GVP = enum.auto() 27 | VP = enum.auto() 28 | 29 | class WeightType(enum.Enum): 30 | """ 31 | Which type of weighting to use. 32 | """ 33 | 34 | NONE = enum.auto() 35 | VELOCITY = enum.auto() 36 | LIKELIHOOD = enum.auto() 37 | 38 | 39 | class Transport: 40 | 41 | def __init__( 42 | self, 43 | *, 44 | model_type, 45 | path_type, 46 | loss_type, 47 | train_eps, 48 | sample_eps, 49 | ): 50 | path_options = { 51 | PathType.LINEAR: path.ICPlan, 52 | PathType.GVP: path.GVPCPlan, 53 | PathType.VP: path.VPCPlan, 54 | } 55 | 56 | self.loss_type = loss_type 57 | self.model_type = model_type 58 | self.path_sampler = path_options[path_type]() 59 | self.train_eps = train_eps 60 | self.sample_eps = sample_eps 61 | 62 | def prior_logp(self, z): 63 | ''' 64 | Standard multivariate normal prior 65 | Assume z is batched 66 | ''' 67 | shape = th.tensor(z.size()) 68 | N = th.prod(shape[1:]) 69 | _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2. 70 | return th.vmap(_fn)(z) 71 | 72 | 
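# --- Editorial note: the prior log-density computed by prior_logp above. ---
# For z ~ N(0, I_N) flattened to N scalars, log p(z) = -(N/2) log(2*pi) - ||z||^2 / 2.
# Hypothetical sanity check against torch.distributions (not part of the original file):
#
#   import torch as th
#
#   z = th.randn(3, 4)
#   ref = th.distributions.Normal(0., 1.).log_prob(z).sum(dim=1)
#   # transport.prior_logp(z) should match `ref` elementwise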
73 | def check_interval( 74 | self, 75 | train_eps, 76 | sample_eps, 77 | *, 78 | diffusion_form="SBDM", 79 | sde=False, 80 | reverse=False, 81 | eval=False, 82 | last_step_size=0.0, 83 | ): 84 | t0 = 0 85 | t1 = 1 86 | eps = train_eps if not eval else sample_eps 87 | if (type(self.path_sampler) in [path.VPCPlan]): 88 | 89 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size 90 | 91 | elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \ 92 | and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step 93 | 94 | t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0 95 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size 96 | 97 | if reverse: 98 | t0, t1 = 1 - t0, 1 - t1 99 | 100 | return t0, t1 101 | 102 | 103 | def sample(self, x1): 104 | """Sampling x0 & t based on shape of x1 (if needed) 105 | Args: 106 | x1 - data point; [batch, *dim] 107 | """ 108 | 109 | x0 = th.randn_like(x1) 110 | t0, t1 = self.check_interval(self.train_eps, self.sample_eps) 111 | t = th.rand((x1.shape[0],)) * (t1 - t0) + t0 112 | t = t.to(x1) 113 | return t, x0, x1 114 | 115 | 116 | def training_losses( 117 | self, 118 | model, 119 | x1, 120 | model_kwargs=None 121 | ): 122 | """Loss for training the score model 123 | Args: 124 | - model: backbone model; could be score, noise, or velocity 125 | - x1: datapoint 126 | - model_kwargs: additional arguments for the model 127 | """ 128 | if model_kwargs is None: 129 | model_kwargs = {} 130 | 131 | t, x0, x1 = self.sample(x1) 132 | t, xt, ut = self.path_sampler.plan(t, x0, x1) 133 | model_output = model(xt, t, **model_kwargs) 134 | if len(model_output.shape) == len(xt.shape) + 1: 135 | x0 = x0.unsqueeze(-1).expand(*([-1] * (len(x0.shape))), model_output.shape[-1]) 136 | xt = xt.unsqueeze(-1).expand(*([-1] * (len(xt.shape))), model_output.shape[-1]) 137 | ut = ut.unsqueeze(-1).expand(*([-1] * (len(ut.shape))), model_output.shape[-1]) 138 | B, C = xt.shape[:2] 139 | assert model_output.shape == (B, C, *xt.shape[2:]) 140 | 141 | terms = {} 142 | terms['pred'] = model_output 143 | if self.model_type == ModelType.VELOCITY: 144 | terms['loss'] = mean_flat(((model_output - ut) ** 2)) 145 | else: 146 | _, drift_var = self.path_sampler.compute_drift(xt, t) 147 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt)) 148 | if self.loss_type in [WeightType.VELOCITY]: 149 | weight = (drift_var / sigma_t) ** 2 150 | elif self.loss_type in [WeightType.LIKELIHOOD]: 151 | weight = drift_var / (sigma_t ** 2) 152 | elif self.loss_type in [WeightType.NONE]: 153 | weight = 1 154 | else: 155 | raise NotImplementedError() 156 | 157 | if self.model_type == ModelType.NOISE: 158 | terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2)) 159 | else: 160 | terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2)) 161 | 162 | return terms 163 | 164 | 
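# --- Editorial sketch: a velocity-matching training step using training_losses above. ---
# Hypothetical usage with a toy model; assumes create_transport from this package.
#
#   import torch as th
#   from semanticist.stage1.transport import create_transport
#
#   transport = create_transport(path_type="Linear", prediction="velocity")
#   net = th.nn.Linear(8, 8)
#   model_fn = lambda x, t, **kw: net(x)       # toy velocity model that ignores t
#   x1 = th.randn(16, 8)                       # a batch of "data"
#   terms = transport.training_losses(model_fn, x1)
#   loss = terms["loss"].mean()                # per-sample losses -> scalar
#   loss.backward()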
165 | def get_drift( 166 | self 167 | ): 168 | """member function for obtaining the drift of the probability flow ODE""" 169 | def score_ode(x, t, model, **model_kwargs): 170 | drift_mean, drift_var = self.path_sampler.compute_drift(x, t) 171 | model_output = model(x, t, **model_kwargs) 172 | return (-drift_mean + drift_var * model_output) # by change of variable 173 | 174 | def noise_ode(x, t, model, **model_kwargs): 175 | drift_mean, drift_var = self.path_sampler.compute_drift(x, t) 176 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x)) 177 | model_output = model(x, t, **model_kwargs) 178 | score = model_output / -sigma_t 179 | return (-drift_mean + drift_var * score) 180 | 181 | def velocity_ode(x, t, model, **model_kwargs): 182 | model_output = model(x, t, **model_kwargs) 183 | return model_output 184 | 185 | if self.model_type == ModelType.NOISE: 186 | drift_fn = noise_ode 187 | elif self.model_type == ModelType.SCORE: 188 | drift_fn = score_ode 189 | else: 190 | drift_fn = velocity_ode 191 | 192 | def body_fn(x, t, model, **model_kwargs): 193 | model_output = drift_fn(x, t, model, **model_kwargs) 194 | assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape" 195 | return model_output 196 | 197 | return body_fn 198 | 199 | 200 | def get_score( 201 | self, 202 | ): 203 | """member function for obtaining score of 204 | x_t = alpha_t * x + sigma_t * eps""" 205 | if self.model_type == ModelType.NOISE: 206 | score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0] 207 | elif self.model_type == ModelType.SCORE: 208 | score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) 209 | elif self.model_type == ModelType.VELOCITY: 210 | score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t) 211 | else: 212 | raise NotImplementedError() 213 | 214 | return score_fn 215 | 216 | 217 | class Sampler: 218 | """Sampler class for the transport model""" 219 | def __init__( 220 | self, 221 | transport, 222 | ): 223 | """Constructor for a general sampler; supporting different sampling methods 224 | Args: 225 | - transport: a transport object specifying model prediction & interpolant type 226 | """ 227 | 228 | self.transport = transport 229 | self.drift = self.transport.get_drift() 230 | self.score = self.transport.get_score() 231 | 232 | def __get_sde_diffusion_and_drift( 233 | self, 234 | *, 235 | diffusion_form="SBDM", 236 | diffusion_norm=1.0, 237 | ): 238 | 239 | def diffusion_fn(x, t): 240 | diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm) 241 | return diffusion 242 | 243 | sde_drift = \ 244 | lambda x, t, model, **kwargs: \ 245 | self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs) 246 | 247 | sde_diffusion = diffusion_fn 248 | 249 | return sde_drift, sde_diffusion 250 | 251 | def __get_last_step( 252 | self, 253 | sde_drift, 254 | *, 255 | last_step, 256 | last_step_size, 257 | ): 258 | """Get the last step function of the SDE solver""" 259 | 260 | if last_step is None: 261 | last_step_fn = \ 262 | lambda x, t, model, **model_kwargs: \ 263 | x 264 | elif last_step == "Mean": 265 | last_step_fn = \ 266 | lambda x, t, model, **model_kwargs: \ 267 | x + sde_drift(x, t, model, **model_kwargs) * last_step_size 268 | elif last_step == "Tweedie": 269 | alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long 270 | sigma = self.transport.path_sampler.compute_sigma_t 271 | last_step_fn = \ 272 | lambda x, t, model, **model_kwargs: \ 273 | x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs) 274 | elif last_step == "Euler": 275 | last_step_fn = \ 276 | lambda x, t, model, **model_kwargs: \ 277 | x + self.drift(x, t, model, **model_kwargs) * last_step_size 278 | else: 279 | raise NotImplementedError() 280 | 281 | return last_step_fn 282 | 
283 | def sample_sde(
284 | self,
285 | *,
286 | sampling_method="Euler",
287 | diffusion_form="SBDM",
288 | diffusion_norm=1.0,
289 | last_step="Mean",
290 | last_step_size=0.04,
291 | num_steps=250,
292 | temperature=1.0,
293 | ):
294 | """returns a sampling function with given SDE settings
295 | Args:
296 | - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
297 | - diffusion_form: function form of diffusion coefficient; default to be matching SBDM
298 | - diffusion_norm: function magnitude of diffusion coefficient; default to 1
299 | - last_step: type of the last step; default to "Mean" (None applies an identity last step)
300 | - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
301 | - num_steps: total number of integration steps of the SDE
302 | - temperature: temperature scaling for the noise during sampling; default to 1.0
303 | """
304 |
305 | if last_step is None:
306 | last_step_size = 0.0
307 |
308 | sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
309 | diffusion_form=diffusion_form,
310 | diffusion_norm=diffusion_norm,
311 | )
312 |
313 | t0, t1 = self.transport.check_interval(
314 | self.transport.train_eps,
315 | self.transport.sample_eps,
316 | diffusion_form=diffusion_form,
317 | sde=True,
318 | eval=True,
319 | reverse=False,
320 | last_step_size=last_step_size,
321 | )
322 |
323 | _sde = sde(
324 | sde_drift,
325 | sde_diffusion,
326 | t0=t0,
327 | t1=t1,
328 | num_steps=num_steps,
329 | sampler_type=sampling_method,
330 | temperature=temperature
331 | )
332 |
333 | last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
334 |
335 |
336 | def _sample(init, model, **model_kwargs):
337 | xs = _sde.sample(init, model, **model_kwargs)
338 | ts = th.ones(init.size(0), device=init.device) * t1
339 | x = last_step_fn(xs[-1], ts, model, **model_kwargs)
340 | xs.append(x)
341 |
342 | assert len(xs) == num_steps, "Number of samples does not match the number of steps"
343 |
344 | return xs
345 |
346 | return _sample
347 |
348 | def sample_ode(
349 | self,
350 | *,
351 | sampling_method="dopri5",
352 | num_steps=50,
353 | atol=1e-6,
354 | rtol=1e-3,
355 | reverse=False,
356 | temperature=1.0,
357 | ):
358 | """returns a sampling function with given ODE settings
359 | Args:
360 | - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
361 | - num_steps:
362 | - fixed solver (Euler, Heun): the actual number of integration steps performed
363 | - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
364 | - atol: absolute error tolerance for the solver
365 | - rtol: relative error tolerance for the solver
366 | - reverse: whether solving the ODE in reverse (data to noise); default to False
367 | - temperature: temperature scaling for the drift during sampling; default to 1.0
368 | """
369 | if reverse:
370 | drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
371 | else:
372 | drift = self.drift
373 |
374 | t0, t1 = self.transport.check_interval(
375 | self.transport.train_eps,
376 | self.transport.sample_eps,
377 | sde=False,
378 | eval=True,
379 | reverse=reverse,
380 | last_step_size=0.0,
381 | )
382 |
383 | _ode = ode(
384 | drift=drift,
385 | t0=t0,
386 | t1=t1,
387 | sampler_type=sampling_method,
388 | num_steps=num_steps,
389 | atol=atol,
390 | rtol=rtol,
391 | temperature=temperature,
392 | )
393 |
394 | return _ode.sample
395 |
396 | def sample_ode_likelihood(
397 | self,
398 | *,
399 | sampling_method="dopri5",
400 | num_steps=50,
401 | atol=1e-6,
402 | rtol=1e-3,
403 | temperature=1.0,
404 | ):
405 |
406 | """returns a sampling function for calculating likelihood with given ODE settings
407 | Args:
408 | - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
409 | - num_steps:
410 | - fixed solver (Euler, Heun): the actual number of integration steps performed
411 | - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
412 | - atol: absolute error tolerance for the solver
413 | - rtol: relative error tolerance for the solver
414 | - temperature: temperature scaling for the drift during sampling; default to 1.0
415 | """
416 | def _likelihood_drift(x, t, model, **model_kwargs):
417 | x, _ = x
418 | eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
419 | t = th.ones_like(t) * (1 - t)
420 | with th.enable_grad():
421 | x.requires_grad = True
422 | grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
423 | logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
424 | drift = self.drift(x, t, model, **model_kwargs)
425 | return (-drift, logp_grad)
426 |
427 | t0, t1 = self.transport.check_interval(
428 | self.transport.train_eps,
429 | self.transport.sample_eps,
430 | sde=False,
431 | eval=True,
432 | reverse=False,
433 | last_step_size=0.0,
434 | )
435 |
436 | _ode = ode(
437 | drift=_likelihood_drift,
438 | t0=t0,
439 | t1=t1,
440 | sampler_type=sampling_method,
441 | num_steps=num_steps,
442 | atol=atol,
443 | rtol=rtol,
444 | temperature=temperature,
445 | )
446 |
447 | def _sample_fn(x, model, **model_kwargs):
448 | init_logp = th.zeros(x.size(0)).to(x)
449 | input = (x, init_logp)
450 | drift, delta_logp = _ode.sample(input, model, **model_kwargs)
451 | drift, delta_logp = drift[-1], delta_logp[-1]
452 | prior_logp = self.transport.prior_logp(drift)
453 | logp = prior_logp - delta_logp
454 | return logp, drift
455 |
456 | return _sample_fn
--------------------------------------------------------------------------------
/semanticist/stage1/transport/utils.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 | class EasyDict:
4 |
5 | def __init__(self, sub_dict):
6 | for k, v in sub_dict.items():
7 | setattr(self, k, v)
8 |
9 | def __getitem__(self, key):
10 | return getattr(self, key)
11 |
12 | def mean_flat(x):
13 | """
14 | Take the mean over all non-batch dimensions.
15 | """
16 | return th.mean(x, dim=list(range(1, len(x.size()))))
17 |
18 | def log_state(state):
19 | result = []
20 |
21 | sorted_state = dict(sorted(state.items()))
22 | for key, value in sorted_state.items():
23 | # Check if the value is an instance of a class
24 | if "<object" in str(value) or "object at" in str(value):
25 | result.append(f"{key}: [{type(value).__name__}]")
26 | else:
27 | result.append(f"{key}: {value}")
28 |
29 | return '\n'.join(result)
--------------------------------------------------------------------------------
/semanticist/stage1/vision_transformer.py:
--------------------------------------------------------------------------------
[lines 1-34 of this file (module header, imports, and the opening of drop_path) were lost to markup stripping in the source dump; the code resumes mid-way through drop_path]
35 | if keep_prob > 0.0:
36 | random_tensor.div_(keep_prob)
37 | return x * random_tensor
38 |
39 |
40 | class DropPath(nn.Module):
41 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
42 | """ 43 | 44 | def __init__(self, drop_prob=None): 45 | super(DropPath, self).__init__() 46 | self.drop_prob = drop_prob 47 | 48 | def forward(self, x): 49 | return drop_path(x, self.drop_prob, self.training) 50 | 51 | 52 | class Mlp(nn.Module): 53 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 54 | super().__init__() 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | self.fc1 = nn.Linear(in_features, hidden_features) 58 | self.act = act_layer() 59 | self.fc2 = nn.Linear(hidden_features, out_features) 60 | self.drop = nn.Dropout(drop) 61 | 62 | def forward(self, x): 63 | x = self.fc1(x) 64 | x = self.act(x) 65 | x = self.drop(x) 66 | x = self.fc2(x) 67 | x = self.drop(x) 68 | return x 69 | 70 | 71 | class Block(nn.Module): 72 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., 73 | attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, init_values=0): 74 | super().__init__() 75 | self.norm1 = norm_layer(dim) 76 | self.attn = Attention( 77 | dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 78 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 79 | self.norm2 = norm_layer(dim) 80 | mlp_hidden_dim = int(dim * mlp_ratio) 81 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 82 | 83 | if init_values > 0: 84 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) 85 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) 86 | else: 87 | self.gamma_1, self.gamma_2 = None, None 88 | 89 | def forward(self, x, attn_mask=None): 90 | y = self.attn(self.norm1(x), attn_mask=attn_mask) 91 | if self.gamma_1 is None: 92 | x = x + self.drop_path(y) 93 | x = x + self.drop_path(self.mlp(self.norm2(x))) 94 | else: 95 | x = x + self.drop_path(self.gamma_1 * y) 96 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 97 | return x 98 | 99 | 100 | class PatchEmbed(nn.Module): 101 | """ Image to Patch Embedding 102 | """ 103 | 104 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 105 | super().__init__() 106 | num_patches = (img_size // patch_size) * (img_size // patch_size) 107 | self.img_size = img_size 108 | self.patch_size = patch_size 109 | self.num_patches = num_patches 110 | 111 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 112 | 113 | def forward(self, x): 114 | B, C, H, W = x.shape 115 | return self.proj(x) 116 | 117 | 118 | class VisionTransformer(nn.Module): 119 | """ Vision Transformer """ 120 | 121 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, embed_dim=768, depth=12, 122 | num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., 123 | drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), 124 | init_values=0, num_slots=16): 125 | super().__init__() 126 | self.num_features = self.embed_dim = embed_dim 127 | 128 | self.patch_embed = PatchEmbed( 129 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 130 | num_patches = self.patch_embed.num_patches 131 | 132 | self.num_slots = num_slots 133 | 134 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 135 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1 + self.num_slots, embed_dim)) 136 | self.slot_embed = nn.Parameter(torch.zeros(1, num_slots, embed_dim)) 137 | 
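# Added note (annotation, not part of the original file): the token sequence
# assembled in prepare_tokens below is laid out as
#     [ cls_token | num_patches patch tokens | num_slots slot tokens ],
# which is why pos_embed above holds num_patches + 1 + num_slots positions.
# For example, with img_size=224, patch_size=16 and num_slots=16 this gives
# (224 // 16) ** 2 + 1 + 16 = 196 + 1 + 16 = 213 positions.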
138 | self.pos_drop = nn.Dropout(p=drop_rate)
139 |
140 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
141 | self.blocks = nn.ModuleList([
142 | Block(
143 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
144 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
145 | init_values=init_values)
146 | for i in range(depth)])
147 |
148 | self.norm = norm_layer(embed_dim)
149 |
150 | nn.init.trunc_normal_(self.pos_embed, std=.02)
151 | nn.init.trunc_normal_(self.cls_token, std=.02)
152 | nn.init.trunc_normal_(self.slot_embed, std=.02)
153 | self.apply(self._init_weights)
154 |
155 | def _init_weights(self, m):
156 | if isinstance(m, nn.Linear):
157 | nn.init.trunc_normal_(m.weight, std=.02)
158 | if isinstance(m, nn.Linear) and m.bias is not None:
159 | nn.init.constant_(m.bias, 0)
160 | elif isinstance(m, nn.LayerNorm):
161 | nn.init.constant_(m.bias, 0)
162 | nn.init.constant_(m.weight, 1.0)
163 |
164 | def interpolate_pos_encoding(self, x, w, h):
165 | npatch = x.shape[1] - 1 - self.num_slots
166 | N = self.pos_embed.shape[1] - 1 - self.num_slots
167 | if npatch == N and w == h:
168 | return self.pos_embed
169 | class_pos_embed = self.pos_embed[:, 0]
170 | patch_pos_embed = self.pos_embed[:, 1:1+npatch]
171 | dim = x.shape[-1]
172 | w0 = w // self.patch_embed.patch_size # patch_size is stored as an int in PatchEmbed, so no indexing
173 | h0 = h // self.patch_embed.patch_size
174 | # we add a small number to avoid floating point error in the interpolation
175 | # see discussion at https://github.com/facebookresearch/dino/issues/8
176 | w0, h0 = w0 + 0.1, h0 + 0.1
177 | patch_pos_embed = nn.functional.interpolate(
178 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
179 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
180 | mode='bicubic',
181 | )
182 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
183 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
184 |
185 | slots_pos_embed = self.pos_embed[:, 1+npatch:]
186 | slots_pos_embed = slots_pos_embed.view(1, -1, dim) # (1, num_slots, dim)
187 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed, slots_pos_embed), dim=1)
188 |
189 | def prepare_tokens(self, x):
190 | B, nc, w, h = x.shape
191 | x = self.patch_embed(x)
192 | x = x.flatten(2).transpose(1, 2)
193 | x = torch.cat((self.cls_token.expand(B, -1, -1), x, self.slot_embed.expand(B, -1, -1)), dim=1)
194 | x = x + self.interpolate_pos_encoding(x, w, h)
195 | return self.pos_drop(x)
196 |
197 | def forward(self, x, is_causal=True):
198 | x = self.prepare_tokens(x)
199 | if is_causal:
200 | attn_mask = torch.ones(x.shape[1], x.shape[1], device=x.device, dtype=torch.bool)
201 | # slots are causal to each other
202 | causal_mask = torch.ones(self.num_slots, self.num_slots, device=x.device, dtype=torch.bool).tril(diagonal=0)
203 | attn_mask[-self.num_slots:, -self.num_slots:] = causal_mask
204 | # cls token and patches should not see slots
205 | attn_mask[:-self.num_slots, -self.num_slots:] = False
206 | else:
207 | attn_mask = None
208 |
209 | for blk in self.blocks:
210 | x = blk(x, attn_mask=attn_mask)
211 |
212 | x = self.norm(x)
213 | outcome = x[:, -self.num_slots:] # return the slots
214 | return outcome
215 |
216 | def get_intermediate_layers(self, x, n=1):
217 | x = self.prepare_tokens(x)
218 | # we return the output tokens from the `n` last blocks
219 | output = []
220 | for i, blk in enumerate(self.blocks):
221 | x = blk(x) 222 | if len(self.blocks) - i <= n: 223 | output.append(self.norm(x)) 224 | return output 225 | 226 | 227 | def vit_tiny_patch16(**kwargs): 228 | model = VisionTransformer( 229 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 230 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 231 | return model 232 | 233 | 234 | def vit_small_patch16(**kwargs): 235 | model = VisionTransformer( 236 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 237 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 238 | return model 239 | 240 | 241 | def vit_base_patch16(**kwargs): 242 | model = VisionTransformer( 243 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 244 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 245 | return model 246 | 247 | 248 | def vit_large_patch16(**kwargs): 249 | model = VisionTransformer( 250 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 251 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 252 | return model 253 | 254 | 255 | def vit_huge_patch14(**kwargs): 256 | model = VisionTransformer( 257 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 258 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 259 | return model -------------------------------------------------------------------------------- /semanticist/stage2/diffloss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from semanticist.stage1.diffusion import create_diffusion 6 | from semanticist.stage1.transport import create_transport, Sampler 7 | 8 | 9 | class DiffLoss(nn.Module): 10 | """Diffusion Loss""" 11 | def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, predict_xstart=False, use_si=False, cond_method="adaln"): 12 | super(DiffLoss, self).__init__() 13 | self.in_channels = target_channels 14 | self.net = SimpleMLPAdaLN( 15 | in_channels=target_channels, 16 | model_channels=width, 17 | out_channels=target_channels * 2 if not use_si else target_channels, # for vlb loss 18 | z_channels=z_channels, 19 | num_res_blocks=depth, 20 | cond_method=cond_method, 21 | ) 22 | self.use_si = use_si 23 | if not use_si: 24 | self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine", predict_xstart=predict_xstart) 25 | self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine", predict_xstart=predict_xstart) 26 | else: 27 | self.transport = create_transport() 28 | self.sampler = Sampler(self.transport) 29 | 30 | def forward(self, target, z, mask=None): 31 | model_kwargs = dict(c=z) 32 | if not self.use_si: 33 | t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device) 34 | loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs) 35 | else: 36 | loss_dict = self.transport.training_losses(self.net, target, model_kwargs) 37 | loss = loss_dict["loss"] 38 | if mask is not None: 39 | loss = (loss * mask).sum() / mask.sum() 40 | return loss.mean() 41 | 42 | def sample(self, z, temperature=1.0, cfg=1.0): 43 | # diffusion loss sampling 44 | device = z.device 45 | if not cfg == 1.0: 46 | noise = torch.randn(z.shape[0] // 2, self.in_channels, device=device) 47 | noise = torch.cat([noise, noise], dim=0) 48 | model_kwargs = dict(c=z, cfg_scale=cfg) 49 | sample_fn = 
self.net.forward_with_cfg 50 | else: 51 | noise = torch.randn(z.shape[0], self.in_channels, device=device) 52 | model_kwargs = dict(c=z) 53 | sample_fn = self.net.forward 54 | 55 | if not self.use_si: 56 | sampled_token_latent = self.gen_diffusion.p_sample_loop( 57 | sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False, 58 | temperature=temperature, device=device 59 | ) 60 | else: 61 | sde_sample_fn = self.sampler.sample_sde(diffusion_form="sigma", temperature=temperature) 62 | sampled_token_latent = sde_sample_fn(noise, sample_fn, **model_kwargs)[-1] 63 | if cfg != 1.0: 64 | sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0) 65 | return sampled_token_latent 66 | 67 | 68 | def modulate(x, shift, scale): 69 | return x * (1 + scale) + shift 70 | 71 | 72 | class TimestepEmbedder(nn.Module): 73 | """ 74 | Embeds scalar timesteps into vector representations. 75 | """ 76 | def __init__(self, hidden_size, frequency_embedding_size=256): 77 | super().__init__() 78 | self.mlp = nn.Sequential( 79 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 80 | nn.SiLU(), 81 | nn.Linear(hidden_size, hidden_size, bias=True), 82 | ) 83 | self.frequency_embedding_size = frequency_embedding_size 84 | 85 | @staticmethod 86 | def timestep_embedding(t, dim, max_period=10000): 87 | """ 88 | Create sinusoidal timestep embeddings. 89 | :param t: a 1-D Tensor of N indices, one per batch element. 90 | These may be fractional. 91 | :param dim: the dimension of the output. 92 | :param max_period: controls the minimum frequency of the embeddings. 93 | :return: an (N, D) Tensor of positional embeddings. 94 | """ 95 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 96 | half = dim // 2 97 | freqs = torch.exp( 98 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 99 | ).to(device=t.device) 100 | args = t[:, None].float() * freqs[None] 101 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 102 | if dim % 2: 103 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 104 | return embedding 105 | 106 | def forward(self, t): 107 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 108 | t_emb = self.mlp(t_freq) 109 | return t_emb 110 | 111 | 112 | class ResBlock(nn.Module): 113 | """ 114 | A residual block with AdaLN for timestep and optional concatenation for condition. 
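
    Added sketch (annotation, not in the original docstring): with
    cond_method="adaln", the embedding t yields (shift, scale, gate) via
    adaLN_modulation and the block computes
        h = modulate(LN(x), shift, scale) = LN(x) * (1 + scale) + shift
        x = x + gate * MLP(h);
    with cond_method="concat", the condition c is instead concatenated to h
    before the MLP (hence mlp_in_dim = channels * 2 in that case).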
115 | """ 116 | def __init__( 117 | self, 118 | channels, 119 | cond_method="adaln", 120 | ): 121 | super().__init__() 122 | self.channels = channels 123 | self.cond_method = cond_method 124 | 125 | self.in_ln = nn.LayerNorm(channels, eps=1e-6) 126 | self.adaLN_modulation = nn.Sequential( 127 | nn.SiLU(), 128 | nn.Linear(channels, 3 * channels, bias=True) 129 | ) 130 | 131 | # Input dimension depends on conditioning method 132 | mlp_in_dim = channels * 2 if cond_method == "concat" else channels 133 | self.mlp = nn.Sequential( 134 | nn.Linear(mlp_in_dim, channels, bias=True), 135 | nn.SiLU(), 136 | nn.Linear(channels, channels, bias=True), 137 | ) 138 | 139 | def forward(self, x, t, c=None): 140 | # Apply timestep embedding via AdaLN 141 | shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(t).chunk(3, dim=-1) 142 | h = modulate(self.in_ln(x), shift_mlp, scale_mlp) 143 | 144 | # Concatenate condition if using concat method 145 | if self.cond_method == "concat" and c is not None: 146 | h = torch.cat([h, c], dim=-1) 147 | 148 | h = self.mlp(h) 149 | x = x + gate_mlp * h 150 | return x 151 | 152 | 153 | class FinalLayer(nn.Module): 154 | """ 155 | Final layer with AdaLN for timestep and optional concatenation for condition. 156 | """ 157 | def __init__(self, model_channels, out_channels, cond_method="adaln"): 158 | super().__init__() 159 | self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) 160 | self.cond_method = cond_method 161 | 162 | self.adaLN_modulation = nn.Sequential( 163 | nn.SiLU(), 164 | nn.Linear(model_channels, 2 * model_channels, bias=True) 165 | ) 166 | 167 | # Output dimension depends on conditioning method 168 | linear_in_dim = model_channels * 2 if cond_method == "concat" else model_channels 169 | self.linear = nn.Linear(linear_in_dim, out_channels, bias=True) 170 | 171 | def forward(self, x, t, c=None): 172 | # Apply timestep embedding via AdaLN 173 | shift, scale = self.adaLN_modulation(t).chunk(2, dim=-1) 174 | x = modulate(self.norm_final(x), shift, scale) 175 | 176 | # Concatenate condition if using concat method 177 | if self.cond_method == "concat" and c is not None: 178 | x = torch.cat([x, c], dim=-1) 179 | 180 | return self.linear(x) 181 | 182 | 183 | class SimpleMLPAdaLN(nn.Module): 184 | """ 185 | MLP for Diffusion Loss with AdaLN for timestep and optional concatenation for condition. 
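
    Architecture sketch (annotation, not in the original docstring): input_proj
    lifts x from in_channels to model_channels, a stack of num_res_blocks
    ResBlocks applies the timestep/condition modulation, and FinalLayer projects
    to out_channels. As configured in DiffLoss above, out_channels is
    target_channels * 2 when the DDPM loss also predicts a variance term (the
    "vlb" half), and just target_channels when use_si=True.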
186 | """ 187 | def __init__( 188 | self, 189 | in_channels, 190 | model_channels, 191 | out_channels, 192 | z_channels, 193 | num_res_blocks, 194 | cond_method="adaln" 195 | ): 196 | super().__init__() 197 | self.in_channels = in_channels 198 | self.model_channels = model_channels 199 | self.out_channels = out_channels 200 | self.cond_method = cond_method 201 | 202 | self.time_embed = TimestepEmbedder(model_channels) 203 | self.cond_embed = nn.Linear(z_channels, model_channels) 204 | self.input_proj = nn.Linear(in_channels, model_channels) 205 | 206 | # Create residual blocks 207 | res_blocks = [ResBlock(model_channels, cond_method) for _ in range(num_res_blocks)] 208 | self.res_blocks = nn.ModuleList(res_blocks) 209 | 210 | self.final_layer = FinalLayer(model_channels, out_channels, cond_method=cond_method) 211 | self.initialize_weights() 212 | 213 | def initialize_weights(self): 214 | # Basic initialization for all linear layers 215 | def _basic_init(module): 216 | if isinstance(module, nn.Linear): 217 | torch.nn.init.xavier_uniform_(module.weight) 218 | if module.bias is not None: 219 | nn.init.constant_(module.bias, 0) 220 | self.apply(_basic_init) 221 | 222 | # Initialize timestep embedding MLP 223 | nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) 224 | nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) 225 | 226 | # Zero-out adaLN modulation layers (always used for timestep) 227 | for i, block in enumerate(self.res_blocks): 228 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 229 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 230 | 231 | # Zero-out output layers 232 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 233 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 234 | nn.init.constant_(self.final_layer.linear.weight, 0) 235 | nn.init.constant_(self.final_layer.linear.bias, 0) 236 | 237 | def forward(self, x, t, c): 238 | """ 239 | Apply the model to an input batch. 240 | :param x: an [N x C] Tensor of inputs. 241 | :param t: a 1-D batch of timesteps. 242 | :param c: conditioning from AR transformer. 243 | :return: an [N x C] Tensor of outputs. 
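
        Added note (annotation): forward_with_cfg below splits the model output
        into the first in_channels channels (treated as eps and mixed across the
        cond/uncond halves of the batch) and the remaining channels, which are
        passed through unchanged.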
244 | """ 245 | x = self.input_proj(x) 246 | t_emb = self.time_embed(t) 247 | c_emb = self.cond_embed(c) 248 | 249 | # Prepare conditioning based on method 250 | if self.cond_method == "adaln": 251 | t_combined, c_for_concat = t_emb + c_emb, None 252 | else: # concat 253 | t_combined, c_for_concat = t_emb, c_emb 254 | 255 | for block in self.res_blocks: 256 | x = block(x, t_combined, c_for_concat) 257 | return self.final_layer(x, t_combined, c_for_concat) 258 | 259 | def forward_with_cfg(self, x, t, c, cfg_scale): 260 | half = x[: len(x) // 2] 261 | combined = torch.cat([half, half], dim=0) 262 | model_out = self.forward(combined, t, c) 263 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 264 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 265 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 266 | eps = torch.cat([half_eps, half_eps], dim=0) 267 | return torch.cat([eps, rest], dim=1) -------------------------------------------------------------------------------- /semanticist/stage2/generate.py: -------------------------------------------------------------------------------- 1 | # Modified from: 2 | # gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py 3 | # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py 4 | import torch 5 | 6 | def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, temperature: float = 1.0): 7 | tokens = model(None, cond_idx, input_pos, cfg=cfg_scale, temperature=temperature) 8 | return tokens.unsqueeze(1) 9 | 10 | 11 | def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, temperature: float = 1.0): 12 | assert input_pos.shape[-1] == 1 13 | if cfg_scale > 1.0: 14 | x = torch.cat([x, x]) 15 | tokens = model(x, cond_idx=None, input_pos=input_pos, cfg=cfg_scale, temperature=temperature) 16 | return tokens 17 | 18 | 19 | def decode_n_tokens( 20 | model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, 21 | cfg_scale: float, cfg_schedule = "constant", temperature: float = 1.0): 22 | new_tokens = [] 23 | for i in range(num_new_tokens): 24 | cfg_iter = get_cfg(cfg_scale, i + 1, num_new_tokens + 1, cfg_schedule) 25 | next_token = decode_one_token(model, cur_token, input_pos, cfg_iter, temperature=temperature).unsqueeze(1) 26 | input_pos += 1 27 | new_tokens.append(next_token.clone()) 28 | cur_token = next_token 29 | 30 | return new_tokens 31 | 32 | 33 | @torch.no_grad() 34 | def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_schedule = "constant", temperature: float = 1.0): 35 | if cfg_scale > 1.0: 36 | cond_null = torch.ones_like(cond) * model.num_classes 37 | cond_combined = torch.cat([cond, cond_null]) 38 | else: 39 | cond_combined = cond 40 | T = model.cls_token_num 41 | 42 | T_new = T + max_new_tokens 43 | max_seq_length = T_new 44 | max_batch_size = cond.shape[0] 45 | 46 | device = cond.device 47 | dtype = model.z_proj.weight.dtype 48 | if torch.is_autocast_enabled(): 49 | dtype = torch.get_autocast_dtype(device_type=device.type) 50 | with torch.device(device): 51 | max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size 52 | model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=dtype) 53 | 54 | if emb_masks is not None: 55 | assert emb_masks.shape[0] == max_batch_size 56 | assert emb_masks.shape[-1] == T 57 | if cfg_scale > 1.0: 58 | model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * 
torch.cat([emb_masks, emb_masks]).unsqueeze(1) 59 | else: 60 | model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1) 61 | 62 | eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device) 63 | model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix 64 | 65 | # create an empty tensor of the expected final shape and fill in the current tokens 66 | seq = torch.empty((max_batch_size, T_new, model.slot_dim), dtype=dtype, device=device) 67 | 68 | input_pos = torch.arange(0, T, device=device) 69 | cfg_iter = get_cfg(cfg_scale, 0, max_new_tokens, cfg_schedule) 70 | next_token = prefill(model, cond_combined, input_pos, cfg_iter, temperature=temperature) 71 | seq[:, T:T+1] = next_token 72 | 73 | if max_new_tokens > 1: 74 | input_pos = torch.tensor([T], device=device, dtype=torch.int) 75 | generated_tokens = decode_n_tokens(model, next_token, input_pos, max_new_tokens - 1, cfg_scale, cfg_schedule=cfg_schedule, temperature=temperature) 76 | seq[:, T+1:] = torch.cat(generated_tokens, dim=1) 77 | 78 | model.reset_caches() 79 | return seq[:, T:] 80 | 81 | 82 | def get_cfg(cfg, cur_step, total_step, cfg_schedule="constant"): 83 | if cfg_schedule == "linear": 84 | return 1 + (cfg - 1) * (cur_step + 1) / total_step 85 | elif cfg_schedule == "constant": 86 | return cfg 87 | else: 88 | raise NotImplementedError 89 | -------------------------------------------------------------------------------- /semanticist/stage2/gpt.py: -------------------------------------------------------------------------------- 1 | # Modified from: 2 | # VQGAN: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py 3 | # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py 4 | # nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py 5 | # llama: https://github.com/facebookresearch/llama/blob/main/llama/model.py 6 | # gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py 7 | # PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py 8 | from typing import Optional, List, Union 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import functional as F 13 | 14 | from semanticist.stage1.vision_transformer import DropPath 15 | from semanticist.stage2.diffloss import DiffLoss 16 | 17 | def find_multiple(n: int, k: int): 18 | if n % k == 0: 19 | return n 20 | return n + k - (n % k) 21 | 22 | 23 | 24 | ################################################################################# 25 | # Embedding Layers for Class Labels # 26 | ################################################################################# 27 | class LabelEmbedder(nn.Module): 28 | """ 29 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 30 | """ 31 | def __init__(self, num_classes, hidden_size, dropout_prob): 32 | super().__init__() 33 | use_cfg_embedding = dropout_prob > 0 34 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 35 | self.num_classes = num_classes 36 | self.dropout_prob = dropout_prob 37 | 38 | def token_drop(self, labels, force_drop_ids=None): 39 | """ 40 | Drops labels to enable classifier-free guidance. 
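
        Added example (annotation): with num_classes=1000 and dropout_prob=0.1,
        roughly 10% of labels are replaced by the id 1000, which indexes the
        extra "null" row of embedding_table and acts as the unconditional
        embedding for classifier-free guidance.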
41 | """ 42 | if force_drop_ids is None: 43 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 44 | else: 45 | drop_ids = force_drop_ids == 1 46 | labels = torch.where(drop_ids, self.num_classes, labels) 47 | return labels 48 | 49 | def forward(self, labels, train, force_drop_ids=None): 50 | use_dropout = self.dropout_prob > 0 51 | if (train and use_dropout) or (force_drop_ids is not None): 52 | labels = self.token_drop(labels, force_drop_ids) 53 | embeddings = self.embedding_table(labels).unsqueeze(1) 54 | return embeddings 55 | 56 | 57 | class MLP(nn.Module): 58 | def __init__(self, in_features, hidden_features, out_features): 59 | super().__init__() 60 | out_features = out_features or in_features 61 | hidden_features = hidden_features or in_features 62 | self.fc1 = nn.Linear(in_features, hidden_features, bias=False) 63 | self.act = nn.GELU(approximate='tanh') 64 | self.fc2 = nn.Linear(hidden_features, out_features, bias=False) 65 | 66 | def forward(self, x): 67 | x = self.fc1(x) 68 | x = self.act(x) 69 | x = self.fc2(x) 70 | return x 71 | 72 | 73 | ################################################################################# 74 | # GPT Model # 75 | ################################################################################# 76 | class RMSNorm(torch.nn.Module): 77 | def __init__(self, dim: int, eps: float = 1e-5): 78 | super().__init__() 79 | self.eps = eps 80 | self.weight = nn.Parameter(torch.ones(dim)) 81 | 82 | def _norm(self, x): 83 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 84 | 85 | def forward(self, x): 86 | output = self._norm(x.float()).type_as(x) 87 | return output * self.weight 88 | 89 | 90 | class FeedForward(nn.Module): 91 | def __init__( 92 | self, 93 | dim: int, 94 | multiple_of: int = 256, 95 | ffn_dropout_p: float = 0.0, 96 | ): 97 | super().__init__() 98 | hidden_dim = 4 * dim 99 | hidden_dim = int(2 * hidden_dim / 3) 100 | hidden_dim = find_multiple(hidden_dim, multiple_of) 101 | 102 | self.w1 = nn.Linear(dim, hidden_dim, bias=False) 103 | self.w3 = nn.Linear(dim, hidden_dim, bias=False) 104 | self.w2 = nn.Linear(hidden_dim, dim, bias=False) 105 | self.ffn_dropout = nn.Dropout(ffn_dropout_p) 106 | 107 | def forward(self, x): 108 | return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) 109 | 110 | 111 | class KVCache(nn.Module): 112 | def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype): 113 | super().__init__() 114 | cache_shape = (max_batch_size, n_head, max_seq_length, head_dim) 115 | self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) 116 | self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) 117 | 118 | def update(self, input_pos, k_val, v_val): 119 | # input_pos: [S], k_val: [B, H, S, D] 120 | assert input_pos.shape[0] == k_val.shape[2] 121 | k_out = self.k_cache 122 | v_out = self.v_cache 123 | k_out[:, :, input_pos] = k_val 124 | v_out[:, :, input_pos] = v_val 125 | 126 | return k_out, v_out 127 | 128 | 129 | class Attention(nn.Module): 130 | def __init__( 131 | self, 132 | dim: int, 133 | n_head: int, 134 | attn_dropout_p: float = 0.0, 135 | resid_dropout_p: float = 0.1, 136 | ): 137 | super().__init__() 138 | assert dim % n_head == 0 139 | self.dim = dim 140 | self.head_dim = dim // n_head 141 | self.n_head = n_head 142 | 143 | # key, query, value projections for all heads, but in a batch 144 | self.wqkv = nn.Linear(dim, dim * 3, bias=False) 145 | self.wo = nn.Linear(dim, dim, bias=False) 146 | self.kv_cache = None 
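# Added note (annotation, not part of the original file): at inference time,
# setup_caches() swaps this None for a KVCache. KVCache.update writes the new
# keys/values at the given positions and returns the *full* cache buffers,
# e.g. (hypothetical sizes):
#
#     cache = KVCache(max_batch_size=2, max_seq_length=32, n_head=8,
#                     head_dim=64, dtype=torch.float16)
#     k_out, v_out = cache.update(torch.tensor([5]), k_new, v_new)
#     # k_out, v_out: [2, 8, 32, 64]; positions beyond 5 are still zeros and
#     # must be masked out, which is why forward() passes is_causal=False
#     # together with an explicit mask whenever the cache is active.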
147 | 148 | # regularization 149 | self.attn_dropout_p = attn_dropout_p 150 | self.resid_dropout = nn.Dropout(resid_dropout_p) 151 | 152 | def forward( 153 | self, x: torch.Tensor, 154 | input_pos: Optional[torch.Tensor] = None, 155 | mask: Optional[torch.Tensor] = None 156 | ): 157 | bsz, seqlen, _ = x.shape 158 | xq, xk, xv = self.wqkv(x).split([self.dim, self.dim, self.dim], dim=-1) 159 | 160 | xq = xq.view(bsz, seqlen, self.n_head, self.head_dim) 161 | xk = xk.view(bsz, seqlen, self.n_head, self.head_dim) 162 | xv = xv.view(bsz, seqlen, self.n_head, self.head_dim) 163 | 164 | xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv)) 165 | 166 | if self.kv_cache is not None: 167 | keys, values = self.kv_cache.update(input_pos, xk, xv) 168 | else: 169 | keys, values = xk, xv 170 | 171 | output = F.scaled_dot_product_attention( 172 | xq, keys, values, 173 | attn_mask=mask, 174 | is_causal=True if mask is None else False, # is_causal=False is for KV cache 175 | dropout_p=self.attn_dropout_p if self.training else 0) 176 | 177 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) 178 | 179 | output = self.resid_dropout(self.wo(output)) 180 | return output 181 | 182 | 183 | class TransformerBlock(nn.Module): 184 | def __init__( 185 | self, 186 | dim: int, 187 | n_head: int, 188 | multiple_of: int = 256, 189 | norm_eps: float = 1e-5, 190 | attn_dropout_p: float = 0.0, 191 | ffn_dropout_p: float = 0.1, 192 | resid_dropout_p: float = 0.1, 193 | drop_path: float = 0.0, 194 | ): 195 | super().__init__() 196 | self.attention = Attention( 197 | dim=dim, 198 | n_head=n_head, 199 | attn_dropout_p=attn_dropout_p, 200 | resid_dropout_p=resid_dropout_p, 201 | ) 202 | self.feed_forward = FeedForward( 203 | dim=dim, 204 | multiple_of=multiple_of, 205 | ffn_dropout_p=ffn_dropout_p, 206 | ) 207 | self.attention_norm = RMSNorm(dim, eps=norm_eps) 208 | self.ffn_norm = RMSNorm(dim, eps=norm_eps) 209 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 210 | 211 | def forward(self, x: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None): 212 | h = x + self.drop_path(self.attention(self.attention_norm(x), start_pos, mask)) 213 | out = h + self.drop_path(self.feed_forward(self.ffn_norm(h))) 214 | return out 215 | 216 | 217 | class Transformer(nn.Module): 218 | def __init__( 219 | self, 220 | dim: int = 4096, 221 | n_layer: int = 32, 222 | n_head: int = 32, 223 | attn_dropout_p: float = 0.0, 224 | resid_dropout_p: float = 0.1, 225 | ffn_dropout_p: float = 0.1, 226 | drop_path_rate: float = 0.0, 227 | num_classes: Union[int, List[int]] = 1000, 228 | class_dropout_prob: float = 0.1, 229 | 230 | cls_token_num: int = 1, 231 | num_slots: int = 16, 232 | slot_dim: int = 256, 233 | 234 | diffloss_d: int = 3, 235 | diffloss_w: int = 1024, 236 | num_sampling_steps: str = '100', 237 | diffusion_batch_mul: int = 4, 238 | predict_xstart: bool = False, 239 | use_si: bool = False, 240 | cond_method: str = "adaln", 241 | **kwargs, 242 | ): 243 | super().__init__() 244 | 245 | # Store configuration 246 | self.dim = dim 247 | self.n_layer = n_layer 248 | self.n_head = n_head 249 | self.num_slots = num_slots 250 | self.slot_dim = slot_dim 251 | self.num_classes = num_classes 252 | self.cls_token_num = cls_token_num 253 | 254 | # Initialize embeddings 255 | self.cls_embedding = LabelEmbedder(num_classes, dim, class_dropout_prob) 256 | self.z_proj = nn.Linear(slot_dim, dim, bias=True) 257 | self.z_proj_ln = RMSNorm(dim) 258 | self.pos_embed_learned = nn.Parameter(torch.zeros(1, num_slots + cls_token_num, dim)) 259 | 260 | # transformer blocks 261 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layer)] 262 | self.layers = torch.nn.ModuleList() 263 | for layer_id in range(n_layer): 264 | self.layers.append(TransformerBlock( 265 | dim=dim, 266 | n_head=n_head, 267 | ffn_dropout_p=ffn_dropout_p, 268 | attn_dropout_p=attn_dropout_p, 269 | resid_dropout_p=resid_dropout_p, 270 | drop_path=dpr[layer_id], 271 | )) 272 | 273 | # output layer 274 | self.norm = RMSNorm(dim) 275 | 276 | self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, num_slots, dim)) 277 | 278 | # KVCache 279 | self.max_batch_size = -1 280 | self.max_seq_length = -1 281 | 282 | self.initialize_weights() 283 | 284 | # Diffusion Loss 285 | self.diffloss = DiffLoss( 286 | target_channels=slot_dim, 287 | z_channels=dim, 288 | width=diffloss_w, 289 | depth=diffloss_d, 290 | num_sampling_steps=num_sampling_steps, 291 | predict_xstart=predict_xstart, 292 | use_si=use_si, 293 | cond_method=cond_method, 294 | ) 295 | self.diffusion_batch_mul = diffusion_batch_mul 296 | 297 | def initialize_weights(self): 298 | nn.init.normal_(self.pos_embed_learned, std=0.02) 299 | nn.init.normal_(self.diffusion_pos_embed_learned, std=0.02) 300 | # Initialize nn.Linear and nn.Embedding 301 | self.apply(self._init_weights) 302 | 303 | def _init_weights(self, module): 304 | if isinstance(module, nn.Linear): 305 | module.weight.data.normal_(std=0.02) 306 | if module.bias is not None: 307 | module.bias.data.zero_() 308 | elif isinstance(module, nn.Embedding): 309 | module.weight.data.normal_(std=0.02) 310 | 311 | def setup_caches(self, max_batch_size, max_seq_length, dtype): 312 | # if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: 313 | # return 314 | head_dim = self.dim // self.n_head 315 | max_seq_length = find_multiple(max_seq_length, 8) 316 | self.max_seq_length = max_seq_length 317 | self.max_batch_size = max_batch_size 318 | for b in 
self.layers: 319 | b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.n_head, head_dim, dtype) 320 | 321 | causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) 322 | self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1) 323 | 324 | def reset_caches(self): 325 | self.max_seq_length = -1 326 | self.max_batch_size = -1 327 | for b in self.layers: 328 | b.attention.kv_cache = None 329 | 330 | def forward_loss(self, z, target): 331 | bsz, seq_len, _ = target.shape 332 | target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1) 333 | z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1) 334 | loss = self.diffloss(z=z, target=target) 335 | return loss 336 | 337 | def forward_cfg(self, h, cfg): 338 | if cfg > 1.0: 339 | h_cond, h_uncond = h.chunk(2, dim=0) 340 | h = h_uncond + cfg * (h_cond - h_uncond) 341 | return h 342 | 343 | def forward( 344 | self, 345 | slots: torch.Tensor, 346 | cond_idx: torch.Tensor, 347 | input_pos: Optional[torch.Tensor] = None, 348 | mask: Optional[torch.Tensor] = None, 349 | cfg: float = 1.0, 350 | temperature: float = 1.0 351 | ): 352 | if slots is not None and cond_idx is not None: # training or naive inference 353 | cond_embeddings = self.cls_embedding(cond_idx, train=self.training) 354 | cond_embeddings = cond_embeddings.expand(-1, self.cls_token_num, -1) 355 | token_embeddings = self.z_proj(slots) 356 | token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1) 357 | else: 358 | if cond_idx is not None: # prefill in inference 359 | token_embeddings = self.cls_embedding(cond_idx, train=self.training) 360 | token_embeddings = token_embeddings.expand(-1, self.cls_token_num, -1) 361 | else: # decode_n_tokens(kv cache) in inference 362 | token_embeddings = self.z_proj(slots) 363 | 364 | bs = token_embeddings.shape[0] 365 | mask = self.causal_mask[:bs, None, input_pos] 366 | 367 | h = token_embeddings 368 | if self.training: 369 | h = h + self.pos_embed_learned 370 | else: 371 | h = h + self.pos_embed_learned[:, input_pos].view(1, -1, self.dim) 372 | 373 | h = self.z_proj_ln(h) # not sure if this is needed 374 | 375 | # transformer blocks 376 | for layer in self.layers: 377 | h = layer(h, input_pos, mask) 378 | 379 | h = self.norm(h) 380 | 381 | if self.training: 382 | h = h[:, self.cls_token_num - 1 : -1].contiguous() 383 | h = h + self.diffusion_pos_embed_learned 384 | loss = self.forward_loss(h, slots.detach()) 385 | return loss 386 | else: 387 | h = h[:, -1] 388 | h = h + self.diffusion_pos_embed_learned[:, input_pos[-1] - self.cls_token_num + 1] 389 | next_tokens = self.diffloss.sample(h, temperature=temperature, cfg=cfg) 390 | return next_tokens 391 | 392 | 393 | def get_fsdp_wrap_module_list(self) -> List[nn.Module]: 394 | return list(self.layers) 395 | 396 | 397 | 398 | ################################################################################# 399 | # GPT Configs # 400 | ################################################################################# 401 | ### text-conditional 402 | def GPT_7B(**kwargs): 403 | return Transformer(n_layer=32, n_head=32, dim=4096, **kwargs) # 6.6B 404 | 405 | def GPT_3B(**kwargs): 406 | return Transformer(n_layer=24, n_head=32, dim=3200, **kwargs) # 3.1B 407 | 408 | def GPT_1B(**kwargs): 409 | return Transformer(n_layer=22, n_head=32, dim=2048, **kwargs) # 1.2B 410 | 411 | ### class-conditional 412 | def GPT_XXXL(**kwargs): 413 | return Transformer(n_layer=48, n_head=40, dim=2560, **kwargs) # 
3.9B
414 |
415 | def GPT_XXL(**kwargs):
416 | return Transformer(n_layer=48, n_head=24, dim=1536, **kwargs) # 1.4B
417 |
418 | def GPT_XL(**kwargs):
419 | return Transformer(n_layer=36, n_head=20, dim=1280, **kwargs) # 775M
420 |
421 | def GPT_L(**kwargs):
422 | return Transformer(n_layer=24, n_head=16, dim=1024, **kwargs) # 343M
423 |
424 | def GPT_B(**kwargs):
425 | return Transformer(n_layer=12, n_head=12, dim=768, **kwargs) # 111M
426 |
427 |
428 | GPT_models = {
429 | 'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
430 | 'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
431 | }
--------------------------------------------------------------------------------
/semanticist/utils/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import numpy as np
4 | import os.path as osp
5 | from PIL import Image
6 |
7 | import torchvision.transforms as TF
8 |
9 | def pair(t):
10 | return t if isinstance(t, tuple) else (t, t)
11 |
12 | def center_crop_arr(pil_image, image_size):
13 | """
14 | Center cropping implementation from ADM.
15 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
16 | """
17 | while min(*pil_image.size) >= 2 * image_size:
18 | pil_image = pil_image.resize(
19 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
20 | )
21 |
22 | scale = image_size / min(*pil_image.size)
23 | pil_image = pil_image.resize(
24 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
25 | )
26 |
27 | arr = np.array(pil_image)
28 | crop_y = (arr.shape[0] - image_size) // 2
29 | crop_x = (arr.shape[1] - image_size) // 2
30 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
31 |
32 | def vae_transforms(split, aug='randcrop', img_size=256):
33 | t = []
34 | if split == 'train':
35 | if aug == 'randcrop':
36 | t.append(TF.Resize(img_size, interpolation=TF.InterpolationMode.BICUBIC, antialias=True))
37 | t.append(TF.RandomCrop(img_size))
38 | elif aug == 'centercrop':
39 | t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size)))
40 | else:
41 | raise ValueError(f"Invalid augmentation: {aug}")
42 | t.append(TF.RandomHorizontalFlip(p=0.5))
43 | else:
44 | t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size)))
45 |
46 | t.append(TF.ToTensor())
47 |
48 | return TF.Compose(t)
49 |
50 |
51 | def cached_transforms(aug='tencrop', img_size=256, crop_ranges=[1.05, 1.10]):
52 | t = []
53 | if 'centercrop' in aug:
54 | t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size)))
55 | t.append(TF.Lambda(lambda x: torch.stack([TF.ToTensor()(x), TF.ToTensor()(TF.functional.hflip(x))])))
56 | elif 'tencrop' in aug:
57 | crop_sizes = [int(img_size * crop_range) for crop_range in crop_ranges]
58 | t.append(TF.Lambda(lambda x: [center_crop_arr(x, crop_size) for crop_size in crop_sizes]))
59 | t.append(TF.Lambda(lambda crops: [crop for crop_tuple in [TF.TenCrop(img_size)(crop) for crop in crops] for crop in crop_tuple]))
60 | t.append(TF.Lambda(lambda crops: torch.stack([TF.ToTensor()(crop) for crop in crops])))
61 | else:
62 | raise ValueError(f"Invalid augmentation: {aug}")
63 |
64 | return TF.Compose(t)
65 |
66 | class ImageNet(torchvision.datasets.ImageFolder):
67 | def __init__(self, root, split='train', aug='randcrop', img_size=256):
68 | super().__init__(osp.join(root, split))
69 | if 'cache' not in aug:
70 | self.transform
= vae_transforms(split, aug=aug, img_size=img_size) 71 | else: 72 | self.transform = cached_transforms(aug=aug, img_size=img_size) -------------------------------------------------------------------------------- /semanticist/utils/device_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def configure_compute_backend(): 4 | """Configure PyTorch compute backend settings for CUDA.""" 5 | if torch.cuda.is_available(): 6 | torch.backends.cuda.matmul.allow_tf32 = True 7 | torch.backends.cudnn.allow_tf32 = True 8 | torch.backends.cudnn.benchmark = True 9 | torch.backends.cudnn.deterministic = False 10 | else: 11 | raise ValueError("No CUDA available") 12 | 13 | def get_device(): 14 | """Get the device to use for training.""" 15 | if torch.cuda.is_available(): 16 | return torch.device("cuda") 17 | else: 18 | raise ValueError("No CUDA available") -------------------------------------------------------------------------------- /semanticist/utils/logger.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, deque 2 | import datetime 3 | import time 4 | import torch 5 | import torch.distributed as dist 6 | from semanticist.engine.trainer_utils import is_dist_avail_and_initialized, is_main_process 7 | from semanticist.utils.device_utils import get_device 8 | 9 | def synchronize_processes(): 10 | if torch.cuda.is_available(): 11 | torch.cuda.synchronize() 12 | else: # do nothing 13 | pass 14 | 15 | def empty_cache(): 16 | if torch.cuda.is_available(): 17 | torch.cuda.empty_cache() 18 | else: # do nothing 19 | pass 20 | 21 | class SmoothedValue(object): 22 | """Track a series of values and provide access to smoothed values over a 23 | window or the global series average. 24 | """ 25 | 26 | def __init__(self, window_size=20, fmt=None): 27 | if fmt is None: 28 | fmt = "{median:.4f} ({global_avg:.4f})" 29 | self.deque = deque(maxlen=window_size) 30 | self.total = 0.0 31 | self.count = 0 32 | self.fmt = fmt 33 | 34 | def update(self, value, n=1): 35 | self.deque.append(value) 36 | self.count += n 37 | self.total += value * n 38 | 39 | def synchronize_between_processes(self): 40 | """ 41 | Warning: does not synchronize the deque! 
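
        Added note (annotation): only count and total are all-reduced across
        processes, so global_avg becomes a true cross-process average, while
        median, avg, max and value (all derived from the local deque) remain
        per-process statistics.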
42 | """ 43 | if not is_dist_avail_and_initialized(): 44 | return 45 | t = torch.tensor([self.count, self.total], dtype=torch.float32, device=get_device()) 46 | dist.barrier() 47 | dist.all_reduce(t) 48 | t = t.tolist() 49 | self.count = int(t[0]) 50 | self.total = t[1] 51 | 52 | @property 53 | def median(self): 54 | d = torch.tensor(list(self.deque)) 55 | return d.median().item() 56 | 57 | @property 58 | def avg(self): 59 | d = torch.tensor(list(self.deque), dtype=torch.float32) 60 | return d.mean().item() 61 | 62 | @property 63 | def global_avg(self): 64 | return self.total / self.count 65 | 66 | @property 67 | def max(self): 68 | return max(self.deque) 69 | 70 | @property 71 | def value(self): 72 | return self.deque[-1] 73 | 74 | def __str__(self): 75 | return self.fmt.format( 76 | median=self.median, 77 | avg=self.avg, 78 | global_avg=self.global_avg, 79 | max=self.max, 80 | value=self.value) 81 | 82 | 83 | class MetricLogger(object): 84 | def __init__(self, delimiter="\t"): 85 | self.meters = defaultdict(SmoothedValue) 86 | self.delimiter = delimiter 87 | 88 | def update(self, **kwargs): 89 | for k, v in kwargs.items(): 90 | if v is None: 91 | continue 92 | if isinstance(v, torch.Tensor): 93 | v = v.item() 94 | assert isinstance(v, (float, int)) 95 | self.meters[k].update(v) 96 | 97 | def __getattr__(self, attr): 98 | if attr in self.meters: 99 | return self.meters[attr] 100 | if attr in self.__dict__: 101 | return self.__dict__[attr] 102 | raise AttributeError("'{}' object has no attribute '{}'".format( 103 | type(self).__name__, attr)) 104 | 105 | def __str__(self): 106 | loss_str = [] 107 | for name, meter in self.meters.items(): 108 | loss_str.append( 109 | "{}: {}".format(name, str(meter)) 110 | ) 111 | return self.delimiter.join(loss_str) 112 | 113 | def synchronize_between_processes(self): 114 | for meter in self.meters.values(): 115 | meter.synchronize_between_processes() 116 | 117 | def add_meter(self, name, meter): 118 | self.meters[name] = meter 119 | 120 | def log_every(self, iterable, print_freq, header=None): 121 | i = 0 122 | if not header: 123 | header = '' 124 | start_time = time.time() 125 | end = time.time() 126 | iter_time = SmoothedValue(fmt='{avg:.4f}') 127 | data_time = SmoothedValue(fmt='{avg:.4f}') 128 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 129 | log_msg = [ 130 | header, 131 | '[{0' + space_fmt + '}/{1}]', 132 | 'eta: {eta}', 133 | '{meters}', 134 | 'time: {time}', 135 | 'data: {data}' 136 | ] 137 | if torch.cuda.is_available(): 138 | log_msg.append('mem: {memory:.0f}') 139 | log_msg.append("util: {util:.1f}%") 140 | log_msg = self.delimiter.join(log_msg) 141 | MB = 1024.0 * 1024.0 142 | for obj in iterable: 143 | data_time.update(time.time() - end) 144 | yield obj 145 | iter_time.update(time.time() - end) 146 | if i % print_freq == 0 or i == len(iterable) - 1: 147 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 148 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 149 | if torch.cuda.is_available(): 150 | if is_main_process(): 151 | memory = torch.cuda.max_memory_allocated() 152 | util = torch.cuda.utilization() 153 | print(log_msg.format( 154 | i, len(iterable), eta=eta_string, 155 | meters=str(self), 156 | time=str(iter_time), data=str(data_time), 157 | memory=memory / MB, util=util)) 158 | else: 159 | if is_main_process(): 160 | print(log_msg.format( 161 | i, len(iterable), eta=eta_string, 162 | meters=str(self), 163 | time=str(iter_time), data=str(data_time))) 164 | i += 1 165 | end = time.time() 166 | total_time 
= time.time() - start_time 167 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 168 | if is_main_process(): 169 | print('{} Total time: {} ({:.4f} s / it)'.format( 170 | header, total_time_str, total_time / len(iterable))) -------------------------------------------------------------------------------- /semanticist/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from timm.scheduler.cosine_lr import CosineLRScheduler 2 | from timm.scheduler.step_lr import StepLRScheduler 3 | 4 | def build_scheduler(optimizer, n_epoch, n_iter_per_epoch, lr_min=0, warmup_steps=0, warmup_lr_init=0, decay_steps=None, cosine_lr=True): 5 | if decay_steps is None: 6 | decay_steps = n_epoch * n_iter_per_epoch 7 | 8 | if cosine_lr: 9 | scheduler = CosineLRScheduler(optimizer, t_initial=decay_steps, lr_min=lr_min, warmup_t=warmup_steps, warmup_lr_init=warmup_lr_init, 10 | cycle_limit=1, t_in_epochs=False, warmup_prefix=True) 11 | else: 12 | scheduler = StepLRScheduler(optimizer, decay_t=decay_steps, warmup_t=warmup_steps, warmup_lr_init=warmup_lr_init, 13 | t_in_epochs=False, warmup_prefix=True) 14 | 15 | return scheduler 16 | -------------------------------------------------------------------------------- /submitit_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os.path as osp 12 | import submitit 13 | import itertools 14 | 15 | from omegaconf import OmegaConf 16 | from semanticist.engine.trainer_utils import instantiate_from_config 17 | from semanticist.utils.device_utils import configure_compute_backend 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser("Submitit for accelerator training") 22 | # Slurm configuration 23 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 24 | parser.add_argument("--nodes", default=1, type=int, help="Number of nodes to request") 25 | parser.add_argument("--timeout", default=7000, type=int, help="Duration of the job, default 5 days") 26 | parser.add_argument("--qos", default="normal", type=str, help="QOS to request") 27 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. 
Leave empty for automatic.") 28 | parser.add_argument("--partition", default="your-partition", type=str, help="Partition where to submit") 29 | parser.add_argument("--exclude", default="", type=str, help="Exclude nodes from the partition") 30 | parser.add_argument("--nodelist", default="", type=str, help="Nodelist to request") 31 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 32 | 33 | # Model and testing configuration 34 | parser.add_argument('--model', type=str, nargs='+', default=[None], help="Path to model(s)") 35 | parser.add_argument('--step', type=int, nargs='+', default=[250000], help="Step number(s)") 36 | parser.add_argument('--cfg', type=str, default=None, help="Path to config file") 37 | parser.add_argument('--dataset', type=str, default='imagenet', help="Dataset to use") 38 | 39 | # Legacy parameter (preserved for backward compatibility) 40 | parser.add_argument('--cfg_value', type=float, nargs='+', default=[None], 41 | help='Legacy parameter for GPT classifier-free guidance scale') 42 | 43 | # CFG-related parameters - all with nargs='+' to support multiple values 44 | parser.add_argument('--ae_cfg', type=float, nargs='+', default=[None], 45 | help="Autoencoder classifier-free guidance scale") 46 | parser.add_argument('--cfg_schedule', type=str, nargs='+', default=[None], 47 | help="CFG schedule type (e.g., constant, linear)") 48 | parser.add_argument('--test_num_slots', type=int, nargs='+', default=[None], 49 | help="Number of slots to use for inference") 50 | parser.add_argument('--temperature', type=float, nargs='+', default=[None], 51 | help="Temperature for sampling") 52 | 53 | return parser.parse_args() 54 | 55 | 56 | def load_config(model_path, cfg_path=None): 57 | """Load configuration from file or model directory.""" 58 | if cfg_path is not None and osp.exists(cfg_path): 59 | config_path = cfg_path 60 | elif model_path and osp.exists(osp.join(model_path, 'config.yaml')): 61 | config_path = osp.join(model_path, 'config.yaml') 62 | else: 63 | raise ValueError(f"No config file found at {model_path} or {cfg_path}") 64 | 65 | return OmegaConf.load(config_path) 66 | 67 | 68 | def setup_checkpoint_path(model_path, step, config): 69 | """Set up the checkpoint path based on model and step.""" 70 | if model_path: 71 | ckpt_path = osp.join(model_path, 'models', f'step{step}') 72 | if not osp.exists(ckpt_path): 73 | print(f"Skipping non-existent checkpoint: {ckpt_path}") 74 | return None 75 | if hasattr(config.trainer.params, 'model'): 76 | config.trainer.params.model.params.ckpt_path = ckpt_path 77 | else: 78 | config.trainer.params.gpt_model.params.ckpt_path = ckpt_path 79 | else: 80 | result_folder = config.trainer.params.result_folder 81 | ckpt_path = osp.join(result_folder, 'models', f'step{step}') 82 | if hasattr(config.trainer.params, 'model'): 83 | config.trainer.params.model.params.ckpt_path = ckpt_path 84 | else: 85 | config.trainer.params.gpt_model.params.ckpt_path = ckpt_path 86 | 87 | return ckpt_path 88 | 89 | 90 | def setup_test_config(config): 91 | """Set up common test configuration parameters.""" 92 | config.trainer.params.test_dataset = config.trainer.params.dataset 93 | config.trainer.params.test_dataset.params.split = 'val' 94 | config.trainer.params.test_only = True 95 | config.trainer.params.compile = False 96 | config.trainer.params.eval_fid = True 97 | config.trainer.params.fid_stats = 'fid_stats/adm_in256_stats.npz' 98 | if hasattr(config.trainer.params, 'model'): 99 | 
config.trainer.params.model.params.num_sampling_steps = '250' 100 | else: 101 | config.trainer.params.ae_model.params.num_sampling_steps = '250' 102 | 103 | def apply_cfg_params(config, param_dict): 104 | """Apply CFG-related parameters to the config.""" 105 | # Apply each parameter if it's not None 106 | if param_dict.get('cfg_value') is not None: 107 | config.trainer.params.cfg = param_dict['cfg_value'] 108 | print(f"Setting cfg to {param_dict['cfg_value']}") 109 | 110 | if param_dict.get('ae_cfg') is not None: 111 | config.trainer.params.ae_cfg = param_dict['ae_cfg'] 112 | print(f"Setting ae_cfg to {param_dict['ae_cfg']}") 113 | 114 | if param_dict.get('cfg_schedule') is not None: 115 | config.trainer.params.cfg_schedule = param_dict['cfg_schedule'] 116 | print(f"Setting cfg_schedule to {param_dict['cfg_schedule']}") 117 | 118 | if param_dict.get('test_num_slots') is not None: 119 | config.trainer.params.test_num_slots = param_dict['test_num_slots'] 120 | print(f"Setting test_num_slots to {param_dict['test_num_slots']}") 121 | 122 | if param_dict.get('temperature') is not None: 123 | config.trainer.params.temperature = param_dict['temperature'] 124 | print(f"Setting temperature to {param_dict['temperature']}") 125 | 126 | 127 | def run_test(config): 128 | """Instantiate trainer and run test.""" 129 | trainer = instantiate_from_config(config.trainer) 130 | trainer.train() 131 | 132 | 133 | def generate_param_combinations(args): 134 | """Generate all combinations of parameters from the provided arguments.""" 135 | # Create parameter grid for all combinations 136 | param_grid = { 137 | 'cfg_value': [None] if args.cfg_value == [None] else args.cfg_value, 138 | 'ae_cfg': [None] if args.ae_cfg == [None] else args.ae_cfg, 139 | 'cfg_schedule': [None] if args.cfg_schedule == [None] else args.cfg_schedule, 140 | 'test_num_slots': [None] if args.test_num_slots == [None] else args.test_num_slots, 141 | 'temperature': [None] if args.temperature == [None] else args.temperature 142 | } 143 | 144 | # Get all parameter names that have non-None values 145 | active_params = [k for k, v in param_grid.items() if v != [None]] 146 | 147 | if not active_params: 148 | # If no parameters are specified, yield a dict with all None values 149 | yield {k: None for k in param_grid.keys()} 150 | return 151 | 152 | # Generate all combinations of active parameters 153 | active_values = [param_grid[k] for k in active_params] 154 | for combination in itertools.product(*active_values): 155 | param_dict = {k: None for k in param_grid.keys()} # Start with all None 156 | for i, param_name in enumerate(active_params): 157 | param_dict[param_name] = combination[i] 158 | yield param_dict 159 | 160 | 161 | class Trainer(object): 162 | def __init__(self, args): 163 | self.args = args 164 | 165 | def __call__(self): 166 | """Main entry point for the submitit job.""" 167 | self._setup_gpu_args() 168 | configure_compute_backend() 169 | self._run_tests() 170 | 171 | def _run_tests(self): 172 | """Run tests for all specified models and steps.""" 173 | for step in self.args.step: 174 | for model in self.args.model: 175 | print(f"Testing model: {model} at step: {step}") 176 | 177 | # Load configuration 178 | config = load_config(model, self.args.cfg) 179 | 180 | # Setup checkpoint path 181 | ckpt_path = setup_checkpoint_path(model, step, config) 182 | if ckpt_path is None: 183 | continue 184 | 185 | # Setup test configuration 186 | setup_test_config(config) 187 | 188 | # Generate and apply all parameter combinations 189 | for param_dict 
in generate_param_combinations(self.args):
190 |                     # Create a copy of the config for each parameter combination
191 |                     current_config = OmegaConf.create(OmegaConf.to_container(config, resolve=True))
192 | 
193 |                     # Print parameter combination
194 |                     param_str = ", ".join([f"{k}={v}" for k, v in param_dict.items() if v is not None])
195 |                     print(f"Testing with parameters: {param_str}")
196 | 
197 |                     # Apply parameters and run test
198 |                     apply_cfg_params(current_config, param_dict)
199 |                     run_test(current_config)
200 | 
201 |     def _setup_gpu_args(self):
202 |         """Set up GPU and distributed environment variables."""
203 |         import submitit
204 | 
205 |         print("Exporting PyTorch distributed environment variables")
206 |         dist_env = submitit.helpers.TorchDistributedEnvironment().export(set_cuda_visible_devices=False)
207 |         print(f"Master: {dist_env.master_addr}:{dist_env.master_port}")
208 |         print(f"Rank: {dist_env.rank}")
209 |         print(f"World size: {dist_env.world_size}")
210 |         print(f"Local rank: {dist_env.local_rank}")
211 |         print(f"Local world size: {dist_env.local_world_size}")
212 | 
213 |         job_env = submitit.JobEnvironment()
214 |         self.args.output_dir = str(self.args.output_dir).replace("%j", str(job_env.job_id))
215 |         self.args.log_dir = self.args.output_dir
216 |         print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
217 | 
218 | 
219 | def main():
220 |     """Main function to set up and submit the job."""
221 |     args = parse_args()
222 | 
223 |     # Determine job directory
224 |     if args.cfg is not None and osp.exists(args.cfg):
225 |         config = OmegaConf.load(args.cfg)
226 |     elif args.model[0] and osp.exists(osp.join(args.model[0], 'config.yaml')):
227 |         config = OmegaConf.load(osp.join(args.model[0], 'config.yaml'))
228 |     else:
229 |         raise ValueError(f"No config file found at {args.model[0]} or {args.cfg}")
230 | 
231 |     args.job_dir = config.trainer.params.result_folder
232 | 
233 |     # Set up the executor
234 |     executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
235 | 
236 |     # Configure slurm parameters
237 |     slurm_kwargs = {
238 |         'slurm_signal_delay_s': 120,
239 |         'slurm_qos': args.qos
240 |     }
241 | 
242 |     if args.comment:
243 |         slurm_kwargs['slurm_comment'] = args.comment
244 |     if args.exclude:
245 |         slurm_kwargs['slurm_exclude'] = args.exclude
246 |     if args.nodelist:
247 |         slurm_kwargs['slurm_nodelist'] = args.nodelist
248 | 
249 |     # Update executor parameters
250 |     executor.update_parameters(
251 |         gpus_per_node=args.ngpus,
252 |         tasks_per_node=args.ngpus,  # one task per GPU
253 |         nodes=args.nodes,
254 |         timeout_min=args.timeout,
255 |         slurm_partition=args.partition,
256 |         name="semanticist",
257 |         **slurm_kwargs
258 |     )
259 | 
260 |     args.output_dir = args.job_dir
261 | 
262 |     # Submit the job
263 |     trainer = Trainer(args)
264 |     job = executor.submit(trainer)
265 | 
266 |     print("Submitted job_id:", job.job_id)
267 | 
268 | 
269 | if __name__ == "__main__":
270 |     main()
271 | 
--------------------------------------------------------------------------------
/submitit_train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # A script to run multinode training with submitit.
8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import submitit 13 | 14 | from omegaconf import OmegaConf 15 | from semanticist.engine.trainer_utils import instantiate_from_config 16 | from semanticist.utils.device_utils import configure_compute_backend 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser("Submitit for accelerator training") 20 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 21 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 22 | parser.add_argument("--timeout", default=7000, type=int, help="Duration of the job, default 5 days") 23 | parser.add_argument("--qos", default="normal", type=str, help="QOS to request") 24 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 25 | 26 | parser.add_argument("--partition", default="your-partition", type=str, help="Partition where to submit") 27 | parser.add_argument("--exclude", default="", type=str, help="Exclude nodes from the partition") 28 | parser.add_argument("--nodelist", default="", type=str, help="Nodelist to request") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | parser.add_argument('--cfg', type=str, default='configs/your_config.yaml', help='accelerator configs') 31 | return parser.parse_args() 32 | 33 | 34 | class Trainer(object): 35 | def __init__(self, args, config): 36 | self.args = args 37 | self.config = config 38 | 39 | def __call__(self): 40 | self._setup_gpu_args() 41 | configure_compute_backend() 42 | trainer = instantiate_from_config(self.config.trainer) 43 | trainer.train(self.config) 44 | 45 | def checkpoint(self): 46 | import os 47 | import submitit 48 | 49 | model_dir = os.path.join(self.args.output_dir, "models") 50 | if os.path.exists(model_dir): 51 | # Get all step folders 52 | step_folders = [d for d in os.listdir(model_dir) if d.startswith("step")] 53 | if step_folders: 54 | # Extract step numbers and find max 55 | steps = [int(f.replace("step", "")) for f in step_folders] 56 | max_step = max(steps) 57 | # Set ckpt path to the latest step folder 58 | self.config.trainer.params.model.params.ckpt_path = os.path.join(model_dir, f"step{max_step}") 59 | print("Requeuing ", self.args, self.config) 60 | empty_trainer = type(self)(self.args, self.config) 61 | return submitit.helpers.DelayedSubmission(empty_trainer) 62 | 63 | def _setup_gpu_args(self): 64 | import submitit 65 | 66 | print("exporting PyTorch distributed environment variables") 67 | dist_env = submitit.helpers.TorchDistributedEnvironment().export(set_cuda_visible_devices=False) 68 | print(f"master: {dist_env.master_addr}:{dist_env.master_port}") 69 | print(f"rank: {dist_env.rank}") 70 | print(f"world size: {dist_env.world_size}") 71 | print(f"local rank: {dist_env.local_rank}") 72 | print(f"local world size: {dist_env.local_world_size}") 73 | 74 | job_env = submitit.JobEnvironment() 75 | self.args.output_dir = str(self.args.output_dir).replace("%j", str(job_env.job_id)) 76 | self.args.log_dir = self.args.output_dir 77 | self.config.trainer.params.result_folder = self.args.output_dir 78 | self.config.trainer.params.log_dir = os.path.join(self.args.output_dir, "logs") 79 | # self.args.gpu = job_env.local_rank 80 | # self.args.rank = job_env.global_rank 81 | # self.args.world_size = job_env.num_tasks 82 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 
83 | 
84 | 
85 | def main():
86 |     args = parse_args()
87 |     cfg_file = args.cfg
88 |     assert os.path.exists(cfg_file)
89 |     config = OmegaConf.load(cfg_file)
90 | 
91 |     if config.trainer.params.result_folder is None:
92 |         if args.job_dir == "":
93 |             args.job_dir = "./output/%j"
94 | 
95 |         config.trainer.params.result_folder = args.job_dir
96 |         config.trainer.params.log_dir = os.path.join(args.job_dir, "logs")
97 |     else:
98 |         args.job_dir = config.trainer.params.result_folder
99 | 
100 |     # Note that the folder will depend on the job_id, to easily track experiments
101 |     executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
102 | 
103 |     num_gpus_per_node = args.ngpus
104 |     nodes = args.nodes
105 |     timeout_min = args.timeout
106 |     qos = args.qos
107 | 
108 |     partition = args.partition
109 |     kwargs = {}
110 |     if args.comment:
111 |         kwargs['slurm_comment'] = args.comment
112 |     if args.exclude:
113 |         kwargs["slurm_exclude"] = args.exclude
114 |     if args.nodelist:
115 |         kwargs["slurm_nodelist"] = args.nodelist
116 | 
117 |     executor.update_parameters(
118 |         mem_gb=40 * num_gpus_per_node,
119 |         gpus_per_node=num_gpus_per_node,
120 |         tasks_per_node=num_gpus_per_node,  # one task per GPU
121 |         # cpus_per_task=16,
122 |         nodes=nodes,
123 |         timeout_min=timeout_min,  # max is 60 * 72
124 |         # Below are cluster dependent parameters
125 |         slurm_partition=partition,
126 |         slurm_signal_delay_s=120,
127 |         slurm_qos=qos,
128 |         **kwargs
129 |     )
130 | 
131 |     executor.update_parameters(name="semanticist")
132 | 
133 |     args.output_dir = args.job_dir
134 | 
135 |     trainer = Trainer(args, config)
136 |     job = executor.submit(trainer)
137 | 
138 |     print("Submitted job_id:", job.job_id)
139 | 
140 | 
141 | if __name__ == "__main__":
142 |     main()
143 | 
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | # Semanticist testing
2 | accelerate launch --config_file=configs/onenode_config.yaml test_net.py --model ./output/tokenizer/models_xl --step 250000 --cfg_value 3.0 --test_num_slots 256
3 | # or use torchrun
4 | torchrun --nproc-per-node=8 test_net.py --model ./output/tokenizer/models_xl --step 250000 --cfg_value 3.0 --test_num_slots 256
5 | # or use submitit
6 | python submitit_test.py --ngpus=8 --nodes=1 --partition=xxx --model ./output/tokenizer/models_xl --step 250000 --cfg_value 3.0 --test_num_slots 256
7 | 
8 | # ϵLlamaGen testing
9 | accelerate launch --config_file=configs/onenode_config.yaml test_net.py --model ./output/autoregressive/models_xl --step 250000 --cfg_value 5.0 --ae_cfg 1.0 --test_num_slots 32
10 | # or use torchrun
11 | torchrun --nproc-per-node=8 test_net.py --model ./output/autoregressive/models_xl --step 250000 --cfg_value 5.0 --ae_cfg 1.0 --test_num_slots 32
12 | # or use submitit
13 | python submitit_test.py --ngpus=8 --nodes=1 --partition=xxx --model ./output/autoregressive/models_xl --step 250000 --cfg_value 5.0 --ae_cfg 1.0 --test_num_slots 32
--------------------------------------------------------------------------------
/test_net.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path as osp
3 | import itertools
4 | from omegaconf import OmegaConf
5 | from semanticist.engine.trainer_utils import instantiate_from_config
6 | from semanticist.utils.device_utils import configure_compute_backend
7 | 
8 | def parse_args():
9 |     """Parse command line arguments."""
10 |     parser = argparse.ArgumentParser("Test a model")
11 | 
12 |     # 
Model and testing configuration 13 | parser.add_argument('--model', type=str, nargs='+', default=[None], help="Path to model directory") 14 | parser.add_argument('--step', type=int, nargs='+', default=[250000], help="Step number to test") 15 | parser.add_argument('--cfg', type=str, default=None, help="Path to config file") 16 | parser.add_argument('--dataset', type=str, default='imagenet', help="Dataset to use") 17 | 18 | # Legacy parameter (preserved for backward compatibility) 19 | parser.add_argument('--cfg_value', type=float, nargs='+', default=[None], 20 | help='Legacy parameter for GPT classifier-free guidance scale') 21 | parser.add_argument('--ae_cfg', type=float, nargs='+', default=[None], 22 | help="Autoencoder classifier-free guidance scale") 23 | parser.add_argument('--cfg_schedule', type=str, nargs='+', default=[None], 24 | help="CFG schedule type (e.g., constant, linear)") 25 | parser.add_argument('--test_num_slots', type=int, nargs='+', default=[None], 26 | help="Number of slots to use for inference") 27 | parser.add_argument('--temperature', type=float, nargs='+', default=[None], 28 | help="Temperature for sampling") 29 | 30 | return parser.parse_args() 31 | 32 | 33 | def load_config(model_path, cfg_path=None): 34 | """Load configuration from file or model directory.""" 35 | if cfg_path is not None and osp.exists(cfg_path): 36 | config_path = cfg_path 37 | elif model_path and osp.exists(osp.join(model_path, 'config.yaml')): 38 | config_path = osp.join(model_path, 'config.yaml') 39 | else: 40 | raise ValueError(f"No config file found at {model_path} or {cfg_path}") 41 | 42 | return OmegaConf.load(config_path) 43 | 44 | 45 | def setup_checkpoint_path(model_path, step, config): 46 | """Set up the checkpoint path based on model and step.""" 47 | if model_path: 48 | ckpt_path = osp.join(model_path, 'models', f'step{step}') 49 | if not osp.exists(ckpt_path): 50 | print(f"Skipping non-existent checkpoint: {ckpt_path}") 51 | return None 52 | if hasattr(config.trainer.params, 'model'): 53 | config.trainer.params.model.params.ckpt_path = ckpt_path 54 | else: 55 | config.trainer.params.gpt_model.params.ckpt_path = ckpt_path 56 | else: 57 | result_folder = config.trainer.params.result_folder 58 | ckpt_path = osp.join(result_folder, 'models', f'step{step}') 59 | if hasattr(config.trainer.params, 'model'): 60 | config.trainer.params.model.params.ckpt_path = ckpt_path 61 | else: 62 | config.trainer.params.gpt_model.params.ckpt_path = ckpt_path 63 | 64 | return ckpt_path 65 | 66 | 67 | def setup_test_config(config): 68 | """Set up common test configuration parameters.""" 69 | config.trainer.params.test_dataset = config.trainer.params.dataset 70 | config.trainer.params.test_dataset.params.split = 'val' 71 | config.trainer.params.test_only = True 72 | config.trainer.params.compile = False 73 | config.trainer.params.eval_fid = True 74 | config.trainer.params.fid_stats = 'fid_stats/adm_in256_stats.npz' 75 | if hasattr(config.trainer.params, 'model'): 76 | config.trainer.params.model.params.num_sampling_steps = '250' 77 | else: 78 | config.trainer.params.ae_model.params.num_sampling_steps = '250' 79 | 80 | 81 | def apply_cfg_params(config, param_dict): 82 | """Apply CFG-related parameters to the config.""" 83 | # Apply each parameter if it's not None 84 | if param_dict.get('cfg_value') is not None: 85 | config.trainer.params.cfg = param_dict['cfg_value'] 86 | print(f"Setting cfg to {param_dict['cfg_value']}") 87 | 88 | if param_dict.get('ae_cfg') is not None: 89 | config.trainer.params.ae_cfg = 
param_dict['ae_cfg'] 90 | print(f"Setting ae_cfg to {param_dict['ae_cfg']}") 91 | 92 | if param_dict.get('cfg_schedule') is not None: 93 | config.trainer.params.cfg_schedule = param_dict['cfg_schedule'] 94 | print(f"Setting cfg_schedule to {param_dict['cfg_schedule']}") 95 | 96 | if param_dict.get('test_num_slots') is not None: 97 | config.trainer.params.test_num_slots = param_dict['test_num_slots'] 98 | print(f"Setting test_num_slots to {param_dict['test_num_slots']}") 99 | 100 | if param_dict.get('temperature') is not None: 101 | config.trainer.params.temperature = param_dict['temperature'] 102 | print(f"Setting temperature to {param_dict['temperature']}") 103 | 104 | 105 | def run_test(config): 106 | """Instantiate trainer and run test.""" 107 | trainer = instantiate_from_config(config.trainer) 108 | trainer.train() 109 | 110 | 111 | def generate_param_combinations(args): 112 | """Generate all combinations of parameters from the provided arguments.""" 113 | # Create parameter grid for all combinations 114 | param_grid = { 115 | 'cfg_value': [None] if args.cfg_value == [None] else args.cfg_value, 116 | 'ae_cfg': [None] if args.ae_cfg == [None] else args.ae_cfg, 117 | 'cfg_schedule': [None] if args.cfg_schedule == [None] else args.cfg_schedule, 118 | 'test_num_slots': [None] if args.test_num_slots == [None] else args.test_num_slots, 119 | 'temperature': [None] if args.temperature == [None] else args.temperature 120 | } 121 | 122 | # Get all parameter names that have non-None values 123 | active_params = [k for k, v in param_grid.items() if v != [None]] 124 | 125 | if not active_params: 126 | # If no parameters are specified, yield a dict with all None values 127 | yield {k: None for k in param_grid.keys()} 128 | return 129 | 130 | # Generate all combinations of active parameters 131 | active_values = [param_grid[k] for k in active_params] 132 | for combination in itertools.product(*active_values): 133 | param_dict = {k: None for k in param_grid.keys()} # Start with all None 134 | for i, param_name in enumerate(active_params): 135 | param_dict[param_name] = combination[i] 136 | yield param_dict 137 | 138 | 139 | def test(args): 140 | """Main test function that processes arguments and runs tests.""" 141 | # Iterate through all model and step combinations 142 | for model in args.model: 143 | for step in args.step: 144 | print(f"Testing model: {model} at step: {step}") 145 | 146 | # Load configuration 147 | config = load_config(model, args.cfg) 148 | 149 | # Setup checkpoint path 150 | ckpt_path = setup_checkpoint_path(model, step, config) 151 | if ckpt_path is None: 152 | continue 153 | 154 | # Setup test configuration 155 | setup_test_config(config) 156 | 157 | # Generate and apply all parameter combinations 158 | for param_dict in generate_param_combinations(args): 159 | # Create a copy of the config for each parameter combination 160 | current_config = OmegaConf.create(OmegaConf.to_container(config, resolve=True)) 161 | 162 | # Print parameter combination 163 | param_str = ", ".join([f"{k}={v}" for k, v in param_dict.items() if v is not None]) 164 | print(f"Testing with parameters: {param_str}") 165 | 166 | # Apply parameters and run test 167 | apply_cfg_params(current_config, param_dict) 168 | run_test(current_config) 169 | 170 | 171 | def main(): 172 | """Main entry point for the script.""" 173 | args = parse_args() 174 | configure_compute_backend() 175 | test(args) 176 | 177 | 178 | if __name__ == "__main__": 179 | main() 180 | 
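A minimal usage sketch of the sweep logic above, assuming the repository root is on PYTHONPATH; the flag values are illustrative, not recommended settings:

from argparse import Namespace
from test_net import generate_param_combinations

args = Namespace(
    cfg_value=[3.0, 5.0],     # two GPT guidance scales (active)
    cfg_schedule=['linear'],  # one schedule (active)
    ae_cfg=[None],            # inactive flags stay None in every combination
    test_num_slots=[None],
    temperature=[None],
)
for combo in generate_param_combinations(args):
    print(combo)
# Prints the Cartesian product over the two active flags:
# {'cfg_value': 3.0, 'ae_cfg': None, 'cfg_schedule': 'linear', 'test_num_slots': None, 'temperature': None}
# {'cfg_value': 5.0, 'ae_cfg': None, 'cfg_schedule': 'linear', 'test_num_slots': None, 'temperature': None}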
-------------------------------------------------------------------------------- /tok_demo.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import numpy as np 3 | from PIL import Image 4 | import os 5 | import torch 6 | import matplotlib.pyplot as plt 7 | from omegaconf import OmegaConf 8 | from huggingface_hub import hf_hub_download 9 | 10 | from semanticist.engine.trainer_utils import instantiate_from_config 11 | from semanticist.stage1.diffuse_slot import DiffuseSlot 12 | 13 | device = "cuda" if torch.cuda.is_available() else "cpu" 14 | ckpt_path = hf_hub_download(repo_id='tennant/semanticist', filename='semanticist_tok_XL.pkl', cache_dir='/mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/') 15 | config_path = 'configs/tokenizer_xl.yaml' 16 | cfg = OmegaConf.load(config_path) 17 | ckpt = torch.load(ckpt_path, map_location='cpu') 18 | from semanticist.utils.datasets import vae_transforms 19 | from PIL import Image 20 | 21 | transform = vae_transforms('test') 22 | print(f"Is CUDA available: {torch.cuda.is_available()}") 23 | if device == 'cuda': 24 | print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") 25 | 26 | 27 | def norm_ip(img, low, high): 28 | img.clamp_(min=low, max=high) 29 | img.sub_(low).div_(max(high - low, 1e-5)) 30 | 31 | def norm_range(t, value_range): 32 | if value_range is not None: 33 | norm_ip(t, value_range[0], value_range[1]) 34 | else: 35 | norm_ip(t, float(t.min()), float(t.max())) 36 | 37 | from PIL import Image 38 | def convert_np(img): 39 | ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\ 40 | .permute(1, 2, 0).to("cpu", torch.uint8).numpy() 41 | return ndarr 42 | def convert_PIL(img): 43 | ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\ 44 | .permute(1, 2, 0).to("cpu", torch.uint8).numpy() 45 | img = Image.fromarray(ndarr) 46 | return img 47 | 48 | ckpt = {k.replace('._orig_mod', ''): v for k, v in ckpt.items()} 49 | 50 | model = DiffuseSlot(**cfg['trainer']['params']['model']['params']) 51 | msg = model.load_state_dict(ckpt, strict=False) 52 | model = model.to(device) 53 | model = model.eval() 54 | model.enable_nest = True 55 | 56 | def viz_diff_slots(model, img, nums, cfg=1.0, return_img=False): 57 | n_slots_inf = [] 58 | for num_slots_to_inference in nums: 59 | recon_n = model( 60 | img, sample=True, cfg=cfg, 61 | inference_with_n_slots=num_slots_to_inference, 62 | ) 63 | n_slots_inf.append(recon_n) 64 | return [convert_np(n_slots_inf[i][0]) for i in range(len(n_slots_inf))] 65 | 66 | # Removed process_image function as its functionality is now in the update_outputs function 67 | 68 | with gr.Blocks() as demo: 69 | with gr.Row(): 70 | # First column - Input and configs 71 | with gr.Column(scale=1): 72 | gr.Markdown("## Input") 73 | input_image = gr.Image(label="Upload an image", type="numpy") 74 | 75 | with gr.Group(): 76 | gr.Markdown("### Configuration") 77 | show_gallery = gr.Checkbox(label="Show Gallery", value=True) 78 | # You can add more config options here 79 | # slider = gr.Slider(minimum=0, maximum=10, value=5, label="Processing Intensity") 80 | slider = gr.Slider(minimum=0.1, maximum=20.0, value=4.0, label="CFG value") 81 | labels_input = gr.Textbox( 82 | label="Number of tokens to reconstruct (comma-separated)", 83 | value="1, 4, 16, 64, 256", 84 | placeholder="Enter comma-separated numbers for the number of slots to use" 85 | ) 86 | 87 | # Second column - Output (conditionally rendered) 88 | with gr.Column(scale=1): 89 | gr.Markdown("## Output") 90 | 91 | # Container 
for conditional rendering
92 |             with gr.Group(visible=True) as gallery_container:
93 |                 gallery = gr.Gallery(label="Result Gallery", columns=3, height="auto", show_label=True)
94 | 
95 |             # Always visible output image
96 |             output_image = gr.Image(label="Processed Image", type="numpy")
97 | 
98 |     # Handle form submission
99 |     submit_btn = gr.Button("Process")
100 | 
101 |     # Define the processing logic
102 |     def update_outputs(image, show_gallery_value, slider_value, labels_text):
103 |         # Build a visibility update for the gallery container
104 |         gallery_update = gr.update(visible=show_gallery_value)
105 | 
106 |         try:
107 |             # Parse the labels from the text input
108 |             if labels_text and "," in labels_text:
109 |                 labels = [int(label.strip()) for label in labels_text.split(",")]
110 |             else:
111 |                 # Default labels if none provided or in wrong format
112 |                 labels = [1, 4, 16, 64, 256]
113 |         except ValueError:
114 |             labels = [1, 4, 16, 64, 256]
115 |         while len(labels) < 3:
116 |             labels.append(256)
117 | 
118 |         # Process the image based on configurations
119 |         if image is None:
120 |             # Return placeholder if no image is uploaded
121 |             placeholder = np.zeros((300, 300, 3), dtype=np.uint8)
122 |             return gallery_update, [], placeholder
123 |         image = Image.fromarray(image)
124 |         img = transform(image)
125 |         img = img.unsqueeze(0).to(device)
126 |         recon = viz_diff_slots(model, img, [256], cfg=slider_value)[0]
127 | 
128 | 
129 |         if not show_gallery_value:
130 |             # If only the image should be shown, return just the processed image
131 |             return gallery_update, [], recon
132 |         else:
133 |             model_decompose = viz_diff_slots(model, img, labels, cfg=slider_value)
134 |             # Create image variations and pair them with labels
135 |             gallery_images = [
136 |                 (image, 'GT'),
137 |                 # (np.array(Image.fromarray(image).convert("L").convert("RGB")), labels[1]),
138 |                 # (np.array(Image.fromarray(image).rotate(180)), labels[2])
139 |             ] + [(img, 'Recon. with ' + str(label) + ' tokens') for img, label in zip(model_decompose, labels)]
140 |             return gallery_update, gallery_images, image
141 | 
142 |     # Connect the inputs and outputs
143 |     submit_btn.click(
144 |         fn=update_outputs,
145 |         inputs=[input_image, show_gallery, slider, labels_input],
146 |         outputs=[gallery_container, gallery, output_image]
147 |     )
148 | 
149 |     # Also update when checkbox changes
150 |     show_gallery.change(
151 |         fn=lambda value: gr.update(visible=value),
152 |         inputs=[show_gallery],
153 |         outputs=[gallery_container]
154 |     )
155 | 
156 |     # Add examples
157 |     examples = [
158 |         ["examples/city.jpg", True, 4.0, "1,4,16,64,256"],
159 |         ["examples/food.jpg", True, 4.0, "1,4,16,64,256"],
160 |         ["examples/highland.webp", True, 4.0, "1,4,16,64,256"],
161 |     ]
162 | 
163 |     gr.Examples(
164 |         examples=examples,
165 |         inputs=[input_image, show_gallery, slider, labels_input],
166 |         outputs=[gallery_container, gallery, output_image],
167 |         fn=update_outputs,
168 |         cache_examples=True
169 |     )
170 | 
171 | # Launch the demo
172 | if __name__ == "__main__":
173 |     demo.launch()
174 | 
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | # Semanticist training
2 | accelerate launch --config_file=configs/onenode_config.yaml train_net.py --cfg configs/tokenizer_xl.yaml
3 | # or use torchrun
4 | torchrun --nproc-per-node=8 train_net.py --cfg configs/tokenizer_xl.yaml
5 | # or use submitit
6 | python submitit_train.py --ngpus=8 --nodes=1 --partition=xxx --cfg configs/tokenizer_xl.yaml
7 | 
8 | # ϵLlamaGen training
9 | accelerate launch --config_file=configs/onenode_config.yaml train_net.py --cfg configs/autoregressive_xl.yaml
10 | # or use torchrun
11 | torchrun --nproc-per-node=8 train_net.py --cfg configs/autoregressive_xl.yaml
12 | # or use submitit
13 | python submitit_train.py --ngpus=8 --nodes=1 --partition=xxx --cfg configs/autoregressive_xl.yaml
--------------------------------------------------------------------------------
/train_net.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import argparse
3 | from omegaconf import OmegaConf
4 | from semanticist.engine.trainer_utils import instantiate_from_config
5 | from semanticist.utils.device_utils import configure_compute_backend
6 | 
7 | def train():
8 |     configure_compute_backend()
9 |     parser = argparse.ArgumentParser()
10 |     parser.add_argument('--cfg', type=str, default='configs/vit_vqgan.yaml')
11 |     args = parser.parse_args()
12 | 
13 |     cfg_file = args.cfg
14 |     assert osp.exists(cfg_file)
15 |     config = OmegaConf.load(cfg_file)
16 |     trainer = instantiate_from_config(config.trainer)
17 |     trainer.train(args.cfg)
18 | 
19 | if __name__ == '__main__':
20 |     train()
21 | 
--------------------------------------------------------------------------------
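For reference, a minimal sketch of driving build_scheduler from semanticist/utils/lr_scheduler.py in a per-iteration training loop; the toy model, optimizer, and step counts below are illustrative placeholders:

import torch
from semanticist.utils.lr_scheduler import build_scheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
n_epoch, n_iter_per_epoch = 10, 100
scheduler = build_scheduler(optimizer, n_epoch, n_iter_per_epoch,
                            lr_min=1e-6, warmup_steps=500, warmup_lr_init=1e-7)
for step in range(n_epoch * n_iter_per_epoch):
    # forward / backward / optimizer.step() would go here
    # build_scheduler sets t_in_epochs=False, so the timm scheduler is
    # advanced once per iteration via step_update rather than once per epoch
    scheduler.step_update(step)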