├── .github
│   └── workflows
│       └── static.yml
├── .gitignore
├── LICENSE
├── README.md
├── ar_gen.ipynb
├── configs
│   ├── autoregressive_l.yaml
│   ├── autoregressive_xl.yaml
│   ├── onenode_config.yaml
│   ├── tokenizer_l.yaml
│   └── tokenizer_xl.yaml
├── examples
│   ├── city.jpg
│   ├── food.jpg
│   └── highland.webp
├── fid_stats
│   └── adm_in256_stats.npz
├── gen_demo.py
├── imagenet_classes.py
├── pages
│   ├── figs
│   │   ├── Token_PCA.png
│   │   ├── comp_table.jpg
│   │   ├── spectral_analysis.png
│   │   ├── spectral_titok_ours.jpg
│   │   ├── teaser.jpg
│   │   └── tokenizer.png
│   └── index.html
├── requirements.txt
├── semanticist
│   ├── engine
│   │   ├── diffusion_trainer.py
│   │   ├── gpt_trainer.py
│   │   └── trainer_utils.py
│   ├── stage1
│   │   ├── diffuse_slot.py
│   │   ├── diffusion
│   │   │   ├── __init__.py
│   │   │   ├── diffusion_utils.py
│   │   │   ├── gaussian_diffusion.py
│   │   │   ├── respace.py
│   │   │   └── timestep_sampler.py
│   │   ├── diffusion_transfomer.py
│   │   ├── fused_attention.py
│   │   ├── pos_embed.py
│   │   ├── transport
│   │   │   ├── __init__.py
│   │   │   ├── integrators.py
│   │   │   ├── path.py
│   │   │   ├── transport.py
│   │   │   └── utils.py
│   │   └── vision_transformer.py
│   ├── stage2
│   │   ├── diffloss.py
│   │   ├── generate.py
│   │   └── gpt.py
│   └── utils
│       ├── datasets.py
│       ├── device_utils.py
│       ├── logger.py
│       └── lr_scheduler.py
├── submitit_test.py
├── submitit_train.py
├── test.sh
├── test_net.py
├── tok_demo.py
├── train.sh
└── train_net.py
/.github/workflows/static.yml:
--------------------------------------------------------------------------------
1 | # Simple workflow for deploying static content to GitHub Pages
2 | name: Deploy static content to Pages
3 |
4 | on:
5 | # Runs on pushes targeting the default branch
6 | push:
7 | branches: ["main"]
8 |
9 | # Allows you to run this workflow manually from the Actions tab
10 | workflow_dispatch:
11 |
12 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
13 | permissions:
14 | contents: read
15 | pages: write
16 | id-token: write
17 |
18 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued.
19 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete.
20 | concurrency:
21 | group: "pages"
22 | cancel-in-progress: false
23 |
24 | jobs:
25 | # Single deploy job since we're just deploying
26 | deploy:
27 | environment:
28 | name: github-pages
29 | url: ${{ steps.deployment.outputs.page_url }}
30 | runs-on: ubuntu-latest
31 | steps:
32 | - name: Checkout
33 | uses: actions/checkout@v4
34 | - name: Setup Pages
35 | uses: actions/configure-pages@v5
36 | - name: Upload artifact
37 | uses: actions/upload-pages-artifact@v3
38 | with:
39 | # Upload only the project page (the pages/ directory)
40 | path: 'pages/'
41 | - name: Deploy to GitHub Pages
42 | id: deployment
43 | uses: actions/deploy-pages@v4
44 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | cache
3 | dataset
4 | output
5 | build
6 | *_results*
7 | .ipynb_checkpoints
8 | *.zip
9 | *.pth
10 | *.ttf
11 | *.so
12 | .nfs*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Semanticist Authors
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # "Principal Components" Enable A New Language of Images
2 | ### A New Paradigm for Compact and Interpretable Image Representations
3 | [Read the Paper] |
4 | [Project Page] |
5 | [Huggingface Tokenizer Demo] |
6 | [Huggingface Generation Demo]
7 |
8 | [Xin Wen](https://wen-xin.info/)1*,
9 | [Bingchen Zhao](https://bzhao.me/)2*,
10 | [Ismail Elezi](https://therevanchist.github.io/)3,
11 | [Jiankang Deng](https://jiankangdeng.github.io/)4,
12 | [Xiaojuan Qi](https://xjqi.github.io/)1
13 |
14 | * Equal Contribution
15 |
16 | 1 University of Hong Kong |
17 | 2 University of Edinburgh |
18 | 3 Noah's Ark Lab |
19 | 4 Imperial College London
20 |
21 | 
22 |
23 | ## Introduction & Motivation
24 | Deep generative models have revolutionized image synthesis, but how we tokenize visual data remains an open question.
25 | While classical methods like **Principal Component Analysis (PCA)** introduced compact, structured representations, modern **visual tokenizers**—from **VQ-VAE** to **SD-VAE**—often prioritize **reconstruction fidelity** at the cost of interpretability and efficiency.
26 |
27 | ### The Problem
28 |
29 | - **Lack of Structure:** Tokens are arbitrarily learned, without an ordering that prioritizes important visual features first.
30 | - **Semantic-Spectrum Coupling:** Tokens entangle *high-level semantics* with *low-level spectral details*, leading to inefficiencies in downstream applications.
31 |
32 | Can we design a **compact, structured tokenizer** that retains the benefits of PCA while leveraging modern generative techniques?
33 |
34 | ### Key Contributions (What's New?)
35 | - **📌 PCA-Guided Tokenization:** Introduces a *causal ordering* where earlier tokens capture the most important visual details, reducing redundancy.
36 | - **⚡ Semantic-Spectrum Decoupling:** Resolves the issue of semantic-spectrum coupling to ensure tokens focus on high-level semantic information.
37 | - **🌀 Diffusion-Based Decoding:** Uses a *diffusion decoder* for the spectral auto-regressive property to naturally separate semantic and spectral content.
38 | - **🚀 Compact & Interpretability-Friendly:** Enables *flexible token selection*, where fewer tokens can still yield high-quality reconstructions.
39 |
40 | For more details, please refer to our [project page](https://visual-gen.github.io/semanticist/).
41 |
42 | ## Getting Started
43 |
44 | ### Preparation
45 |
46 | First, please make sure PyTorch is installed (we used 2.5.1, but we expect any version >= 2.0 to work).
47 |
48 | Then install the rest of the dependencies.
49 |
50 | ```
51 | pip install -r requirements.txt
52 | ```
53 |
54 | Please then download [ImageNet](https://www.image-net.org/) and soft-link it to `./dataset/imagenet`. For evaluating FID, it is recommended to pre-process the validation set of ImageNet with [this script](https://github.com/LTH14/rcg/blob/main/prepare_imgnet_val.py). The target folder is `./dataset/imagenet/val256` in our case.
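
If it helps, here is a minimal Python sketch of the expected layout; the source path is a placeholder for your local ImageNet copy, and `val256` is produced by the preprocessing script linked above:

```python
# Hedged sketch: link an existing ImageNet copy into ./dataset/imagenet and
# check that the expected splits are present. The source path is hypothetical.
import os

imagenet_src = "/path/to/your/imagenet"   # placeholder: your local ImageNet copy
imagenet_dst = "./dataset/imagenet"

os.makedirs("./dataset", exist_ok=True)
if not os.path.exists(imagenet_dst):
    os.symlink(os.path.abspath(imagenet_src), imagenet_dst)

# train/ and val/ come from the official release; val256/ is created by the
# RCG prepare_imgnet_val.py script referenced above (used for FID evaluation).
for split in ("train", "val", "val256"):
    path = os.path.join(imagenet_dst, split)
    print(f"{split}: {'found' if os.path.isdir(path) else 'missing'}")
```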
55 |
56 | ### Training
57 |
58 | Our codebase supports DDP training with accelerate, torchrun, and submitit (for Slurm users). To train a Semanticist tokenizer (DiT-L) on 8 GPUs, you can run
59 | ```bash
60 | accelerate launch --config_file=configs/onenode_config.yaml train_net.py --cfg configs/tokenizer_l.yaml
61 | ```
62 | or
63 | ```bash
64 | torchrun --nproc-per-node=8 train_net.py --cfg configs/tokenizer_l.yaml
65 | ```
66 | or
67 | ```bash
68 | python submitit_train.py --ngpus=8 --nodes=1 --partition=xxx --config configs/tokenizer_l.yaml
69 | ```
70 | We used a global batch size of 2048, so the effective batch size per GPU is 256 in this case. You may modify the batch size and gradient accumulation steps in the config file according to your training resources; a quick sanity check is sketched below.
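
For reference, a minimal sketch of how the global batch size follows from the config values (assuming the 8-GPU launch commands above):

```python
# Effective global batch size = per-GPU batch size x number of GPUs x grad accumulation steps.
# Values below mirror configs/tokenizer_l.yaml and the 8-GPU launch commands above.
batch_size = 256        # trainer.params.batch_size (per GPU)
num_gpus = 8            # --nproc-per-node / num_processes
grad_accum_steps = 1    # trainer.params.grad_accum_steps
print(batch_size * num_gpus * grad_accum_steps)  # 2048
```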
71 |
72 | To train an ϵLlamaGen autoregressive model with a tokenizer trained as above, you can run the following command. Remember to change the path to the tokenizer checkpoint in the config file; the EMA model is `custom_checkpoint_1.pkl` under the output folder.
73 | ```bash
74 | accelerate launch --config_file=configs/onenode_config.yaml train_net.py --cfg configs/autoregressive_l.yaml
75 | ```
76 | Note that latent caching is enabled by default and takes around 400 GB of memory (dumped to `/dev/shm`) for ten-crop augmentation on ImageNet. If you want to disable it, set `enable_cache_latents` to False in the config file and/or specify a different data augmentation method (e.g., centercrop_cached, centercrop, randcrop); see the sketch below.
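
If you prefer to edit the config programmatically rather than by hand, the following is a hedged sketch using OmegaConf; the checkpoint path shown is the default tokenizer output location, and the saved file name is illustrative:

```python
# Sketch: adjust configs/autoregressive_l.yaml with OmegaConf before launching training.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/autoregressive_l.yaml")

# Point the frozen tokenizer at your trained (EMA) tokenizer checkpoint.
cfg.trainer.params.ae_model.params.ckpt_path = (
    "./output/tokenizer/models_l/step250000/custom_checkpoint_1.pkl"
)

# Disable latent caching if /dev/shm is limited; alternatively, change the
# dataset `aug` field in the YAML to a non-cached augmentation (e.g., centercrop).
cfg.trainer.params.enable_cache_latents = False

# Save a local copy and pass it via --cfg configs/autoregressive_l_local.yaml.
OmegaConf.save(cfg, "configs/autoregressive_l_local.yaml")
```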
77 |
78 | ### Evaluation
79 |
80 | By default, online evaluation does not use the EMA model. To obtain the final performance, we therefore suggest running a separate evaluation after training. As above, our scripts are compatible with accelerate, torchrun, and submitit.
81 | ```bash
82 | accelerate launch --config_file=configs/onenode_config.yaml test_net.py --model ./output/tokenizer/models_l --step 250000 --cfg_value 3.0 --test_num_slots 32
83 | ```
84 | or
85 | ```bash
86 | torchrun --nproc-per-node=8 test_net.py --model ./output/tokenizer/models_l --step 250000 --cfg_value 3.0 --test_num_slots 32
87 | ```
88 | or
89 | ```bash
90 | python submitit_test.py --ngpus=8 --nodes=1 --partition=xxx --model ./output/tokenizer/models_l --step 250000 --cfg_value 3.0 --test_num_slots 32
91 | ```
92 | And for the AR model:
93 | ```bash
94 | torchrun --nproc-per-node=8 test_net.py --model ./output/autoregressive/models_l --step 250000 --cfg_value 6.0 --ae_cfg 1.0 --test_num_slots 32
95 | ```
96 | If `enable_ema` is set to True, the EMA model will be loaded automatically. You can adjust the number of GPUs flexibly. You can also specify multiple arguments in the command line to perform a grid search.
97 |
98 | ### Demos
99 |
100 | Please refer to our demo pages on Huggingface for the tokenizer and the AR model.
101 | - [Tokenizer Demo](https://huggingface.co/spaces/tennant/semanticist_tokenizer)
102 | - [AR Demo](https://huggingface.co/spaces/tennant/Semanticist_AR)
103 |
104 | ## Note
105 |
106 | This code may not exactly reproduce the results reported in the paper due to potential human errors during the preparation and cleaning of the code for release. If you encounter any difficulties reproducing our findings, please don't hesitate to let us know. We will also continue to refine the README and code, and carry out sanity-check experiments in the near future.
107 |
108 | ## Acknowledgements
109 |
110 | Our codebase builds upon several existing publicly available codes. Specifically, we have modified or taken inspiration from the following repos: [DiT](https://github.com/facebookresearch/DiT), [SiT](https://github.com/willisma/SiT), [DiffAE](https://github.com/phizaz/diffae), [LlamaGen](https://github.com/FoundationVision/LlamaGen), [RCG](https://github.com/LTH14/rcg), [MAR](https://github.com/LTH14/mar), [REPA](https://github.com/sihyun-yu/REPA), etc. We thank the authors for their contributions to the community.
111 |
112 | ## Citation
113 |
114 | If you find this work useful in your research, please consider citing us!
115 |
116 | ```bibtex
117 | @article{semanticist,
118 | title={``Principal Components'' Enable A New Language of Images},
119 | author={Wen, Xin and Zhao, Bingchen and Elezi, Ismail and Deng, Jiankang and Qi, Xiaojuan},
120 | journal={arXiv preprint arXiv:2503.08685},
121 | year={2025}
122 | }
123 | ```
124 |
--------------------------------------------------------------------------------
/configs/autoregressive_l.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | target: semanticist.engine.gpt_trainer.GPTTrainer
3 | params:
4 | num_epoch: 400
5 | blr: 1e-4
6 | cosine_lr: False
7 | warmup_epochs: 100
8 | batch_size: 256
9 | num_workers: 8
10 | pin_memory: True
11 | grad_accum_steps: 1
12 | precision: 'bf16'
13 | max_grad_norm: 1.0
14 | enable_ema: True
15 | save_every: 10000
16 | sample_every: 5000
17 | fid_every: 50000
18 | eval_fid: False
19 | result_folder: "./output/autoregressive"
20 | log_dir: "./output/autoregressive/logs"
21 | ae_cfg: 1.0
22 | cfg: 6.0
23 | cfg_schedule: "linear"
24 | train_num_slots: 32
25 | test_num_slots: 32
26 | compile: True
27 | enable_cache_latents: True
28 | ae_model:
29 | target: semanticist.stage1.diffuse_slot.DiffuseSlot
30 | params:
31 | encoder: 'vit_base_patch16'
32 | enc_img_size: 256
33 | enc_causal: True
34 | num_slots: 256
35 | slot_dim: 16
36 | norm_slots: True
37 | cond_method: 'token'
38 | dit_model: 'DiT-L-2'
39 | vae: 'xwen99/mar-vae-kl16'
40 | num_sampling_steps: '250'
41 | ckpt_path: ./output/tokenizer/models_l/step250000/custom_checkpoint_1.pkl
42 |
43 | gpt_model:
44 | target: GPT-L
45 | params:
46 | num_slots: 32
47 | slot_dim: 16
48 | num_classes: 1000
49 | cls_token_num: 1
50 | resid_dropout_p: 0.1
51 | ffn_dropout_p: 0.1
52 | diffloss_d: 12
53 | diffloss_w: 1536
54 | num_sampling_steps: '100'
55 | diffusion_batch_mul: 4
56 | use_si: True
57 | cond_method: 'concat'
58 | ckpt_path: None
59 |
60 | dataset:
61 | target: semanticist.utils.datasets.ImageNet
62 | params:
63 | root: ./dataset/imagenet/
64 | split: train
65 | aug: tencrop_cached # or centercrop_cached
66 | img_size: 256
--------------------------------------------------------------------------------
/configs/autoregressive_xl.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | target: semanticist.engine.gpt_trainer.GPTTrainer
3 | params:
4 | num_epoch: 400
5 | blr: 1e-4
6 | cosine_lr: False
7 | warmup_epochs: 100
8 | batch_size: 256
9 | num_workers: 8
10 | pin_memory: True
11 | grad_accum_steps: 1
12 | precision: 'bf16'
13 | max_grad_norm: 1.0
14 | enable_ema: True
15 | save_every: 10000
16 | sample_every: 5000
17 | fid_every: 50000
18 | eval_fid: False
19 | result_folder: "./output/autoregressive"
20 | log_dir: "./output/autoregressive/logs"
21 | ae_cfg: 1.0
22 | cfg: 5.0
23 | cfg_schedule: "linear"
24 | train_num_slots: 32
25 | test_num_slots: 32
26 | compile: True
27 | enable_cache_latents: True
28 | ae_model:
29 | target: semanticist.stage1.diffuse_slot.DiffuseSlot
30 | params:
31 | encoder: 'vit_base_patch16'
32 | enc_img_size: 256
33 | enc_causal: True
34 | num_slots: 256
35 | slot_dim: 16
36 | norm_slots: True
37 | cond_method: 'token'
38 | dit_model: 'DiT-XL-2'
39 | vae: 'xwen99/mar-vae-kl16'
40 | num_sampling_steps: '250'
41 | ckpt_path: ./output/tokenizer/models_xl/step250000/custom_checkpoint_1.pkl
42 |
43 | gpt_model:
44 | target: GPT-L
45 | params:
46 | num_slots: 32
47 | slot_dim: 16
48 | num_classes: 1000
49 | cls_token_num: 1
50 | resid_dropout_p: 0.1
51 | ffn_dropout_p: 0.1
52 | diffloss_d: 12
53 | diffloss_w: 1536
54 | num_sampling_steps: '100'
55 | diffusion_batch_mul: 4
56 | use_si: True
57 | cond_method: 'concat'
58 | ckpt_path: None
59 |
60 | dataset:
61 | target: semanticist.utils.datasets.ImageNet
62 | params:
63 | root: ./dataset/imagenet/
64 | split: train
65 | aug: tencrop_cached # or centercrop_cached
66 | img_size: 256
--------------------------------------------------------------------------------
/configs/onenode_config.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config: {}
3 | distributed_type: MULTI_GPU
4 | fsdp_config: {}
5 | machine_rank: 0
6 | main_process_ip: null
7 | main_process_port: null
8 | main_training_function: main
9 | num_machines: 1
10 | num_processes: 8
11 | use_cpu: false
--------------------------------------------------------------------------------
/configs/tokenizer_l.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | target: semanticist.engine.diffusion_trainer.DiffusionTrainer
3 | params:
4 | num_epoch: 400
5 | valid_size: 64
6 | blr: 2.5e-5
7 | cosine_lr: True
8 | warmup_epochs: 100
9 | batch_size: 256
10 | num_workers: 16
11 | pin_memory: True
12 | grad_accum_steps: 1
13 | precision: 'bf16'
14 | max_grad_norm: 3.0
15 | enable_ema: True
16 | save_every: 10000
17 | sample_every: 5000
18 | fid_every: 50000
19 | result_folder: "./output/tokenizer/models_l"
20 | log_dir: "./output/tokenizer/models_l/logs"
21 | cfg: 3.0
22 | compile: True
23 | model:
24 | target: semanticist.stage1.diffuse_slot.DiffuseSlot
25 | params:
26 | encoder: 'vit_base_patch16'
27 | enc_img_size: 256
28 | enc_causal: True
29 | enc_use_mlp: False
30 | num_slots: 256
31 | slot_dim: 16
32 | norm_slots: True
33 | dit_model: 'DiT-L-2'
34 | vae: 'xwen99/mar-vae-kl16'
35 | enable_nest: False
36 | enable_nest_after: 50
37 | use_repa: True
38 | eval_fid: True
39 | fid_stats: 'fid_stats/adm_in256_stats.npz'
40 | num_sampling_steps: '250'
41 | ckpt_path: None
42 |
43 | dataset:
44 | target: semanticist.utils.datasets.ImageNet
45 | params:
46 | root: ./dataset/imagenet/
47 | split: train
48 | img_size: 256
49 |
50 | test_dataset:
51 | target: semanticist.utils.datasets.ImageNet
52 | params:
53 | root: ./dataset/imagenet/
54 | split: val
55 | img_size: 256
56 |
--------------------------------------------------------------------------------
/configs/tokenizer_xl.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | target: semanticist.engine.diffusion_trainer.DiffusionTrainer
3 | params:
4 | num_epoch: 400
5 | valid_size: 64
6 | blr: 2.5e-5
7 | cosine_lr: True
8 | warmup_epochs: 100
9 | batch_size: 256
10 | num_workers: 16
11 | pin_memory: True
12 | grad_accum_steps: 1
13 | precision: 'bf16'
14 | max_grad_norm: 3.0
15 | enable_ema: True
16 | save_every: 10000
17 | sample_every: 5000
18 | fid_every: 50000
19 | result_folder: "./output/tokenizer/models_xl"
20 | log_dir: "./output/tokenizer/models_xl/logs"
21 | cfg: 3.0
22 | compile: True
23 | model:
24 | target: semanticist.stage1.diffuse_slot.DiffuseSlot
25 | params:
26 | encoder: 'vit_base_patch16'
27 | enc_img_size: 256
28 | enc_causal: True
29 | enc_use_mlp: False
30 | num_slots: 256
31 | slot_dim: 16
32 | norm_slots: True
33 | dit_model: 'DiT-XL-2'
34 | vae: 'xwen99/mar-vae-kl16'
35 | enable_nest: False
36 | enable_nest_after: 50
37 | use_repa: True
38 | eval_fid: True
39 | fid_stats: 'fid_stats/adm_in256_stats.npz'
40 | num_sampling_steps: '250'
41 | ckpt_path: None
42 |
43 | dataset:
44 | target: semanticist.utils.datasets.ImageNet
45 | params:
46 | root: ./dataset/imagenet/
47 | split: train
48 | img_size: 256
49 |
50 | test_dataset:
51 | target: semanticist.utils.datasets.ImageNet
52 | params:
53 | root: ./dataset/imagenet/
54 | split: val
55 | img_size: 256
56 |
--------------------------------------------------------------------------------
/examples/city.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/examples/city.jpg
--------------------------------------------------------------------------------
/examples/food.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/examples/food.jpg
--------------------------------------------------------------------------------
/examples/highland.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/examples/highland.webp
--------------------------------------------------------------------------------
/fid_stats/adm_in256_stats.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/fid_stats/adm_in256_stats.npz
--------------------------------------------------------------------------------
/gen_demo.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import numpy as np
3 | from PIL import Image
4 | import os.path as osp
5 | import torch
6 | import matplotlib.pyplot as plt
7 | from omegaconf import OmegaConf
8 | from tqdm import tqdm
9 | from huggingface_hub import hf_hub_download
10 | from semanticist.engine.trainer_utils import instantiate_from_config
11 | from semanticist.stage1.diffuse_slot import DiffuseSlot
12 | from semanticist.stage2.gpt import GPT_models
13 | from semanticist.stage2.generate import generate
14 | from safetensors import safe_open
15 | from semanticist.utils.datasets import vae_transforms
16 | from PIL import Image
17 | from imagenet_classes import imagenet_classes
18 |
19 | transform = vae_transforms('test')
20 |
21 |
22 | def norm_ip(img, low, high):
23 | img.clamp_(min=low, max=high)
24 | img.sub_(low).div_(max(high - low, 1e-5))
25 |
26 | def norm_range(t, value_range):
27 | if value_range is not None:
28 | norm_ip(t, value_range[0], value_range[1])
29 | else:
30 | norm_ip(t, float(t.min()), float(t.max()))
31 |
32 | from PIL import Image
33 | def convert_np(img):
34 | ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\
35 | .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
36 | return ndarr
37 | def convert_PIL(img):
38 | ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\
39 | .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
40 | img = Image.fromarray(ndarr)
41 | return img
42 |
43 | def norm_slots(slots):
44 | mean = torch.mean(slots, dim=-1, keepdim=True)
45 | std = torch.std(slots, dim=-1, keepdim=True)
46 | return (slots - mean) / std
47 |
48 | def load_state_dict(state_dict, model):
49 | """Helper to load a state dict with proper prefix handling."""
50 | if 'state_dict' in state_dict:
51 | state_dict = state_dict['state_dict']
52 | # Remove '_orig_mod' prefix if present
53 | state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
54 | missing, unexpected = model.load_state_dict(
55 | state_dict, strict=False
56 | )
57 | # print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")
58 |
59 | def load_safetensors(path, model):
60 | """Helper to load a safetensors checkpoint."""
61 | from safetensors.torch import safe_open
62 | with safe_open(path, framework="pt", device="cpu") as f:
63 | state_dict = {k: f.get_tensor(k) for k in f.keys()}
64 | load_state_dict(state_dict, model)
65 |
66 | def load_checkpoint(ckpt_path, model):
67 | if ckpt_path is None or not osp.exists(ckpt_path):
68 | return
69 |
70 | if osp.isdir(ckpt_path):
71 | # ckpt_path is something like 'path/to/models/step10/'
72 | model_path = osp.join(ckpt_path, "model.safetensors")
73 | if osp.exists(model_path):
74 | load_safetensors(model_path, model)
75 | else:
76 | # ckpt_path is something like 'path/to/models/step10.pt'
77 | if ckpt_path.endswith(".safetensors"):
78 | load_safetensors(ckpt_path, model)
79 | else:
80 | state_dict = torch.load(ckpt_path, map_location="cpu")
81 | load_state_dict(state_dict, model)
82 |
83 | print(f"Loaded checkpoint from {ckpt_path}")
84 |
85 | device = "cuda" if torch.cuda.is_available() else "cpu"
86 | print(f"Is CUDA available: {torch.cuda.is_available()}")
87 | if device == 'cuda':
88 | print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
89 |
90 | ckpt_path = hf_hub_download(repo_id='tennant/semanticist', filename="semanticist_ar_gen_L.pkl", cache_dir='/mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/')
91 | config_path = 'configs/autoregressive_xl.yaml'
92 |
93 | cfg = OmegaConf.load(config_path)
94 | params = cfg.trainer.params
95 |
96 | ae_model = instantiate_from_config(params.ae_model).to(device)
97 | ae_model_path = hf_hub_download(repo_id='tennant/semanticist', filename="semanticist_tok_XL.pkl", cache_dir='/mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/')
98 | load_checkpoint(ae_model_path, ae_model)
99 | ae_model.eval()
100 |
101 | gpt_model = GPT_models[params.gpt_model.target](**params.gpt_model.params).to(device)
102 | load_checkpoint(ckpt_path, gpt_model)
103 | gpt_model.eval();
104 |
105 | def viz_diff_slots(model, slots, nums, cfg=1.0, return_figs=False):
106 | n_slots_inf = []
107 | for num_slots_to_inference in nums:
108 | drop_mask = model.nested_sampler(slots.shape[0], device, num_slots_to_inference)
109 | recon_n = model.sample(slots, drop_mask=drop_mask, cfg=cfg)
110 | n_slots_inf.append(recon_n)
111 | return [convert_np(n_slots_inf[i][0]) for i in range(len(n_slots_inf))]
112 |
113 | num_slots = params.ae_model.params.num_slots
114 | slot_dim = params.ae_model.params.slot_dim
115 | dtype = torch.bfloat16
116 | # the model is trained with only 32 tokens.
117 | num_slots_to_gen = 32
118 |
119 | # Function to generate image from class
120 | def generate_from_class(class_id, cfg_scale):
121 | with torch.no_grad():
122 | dtype = torch.bfloat16
123 | num_slots_to_gen = 32
124 | with torch.autocast(device, dtype=dtype):
125 | slots_gen = generate(
126 | gpt_model,
127 | torch.tensor([class_id]).to(device),
128 | num_slots_to_gen,
129 | cfg_scale=cfg_scale,
130 | cfg_schedule="linear"
131 | )
132 | if num_slots_to_gen < num_slots:
133 | null_slots = ae_model.dit.null_cond.expand(slots_gen.shape[0], -1, -1)
134 | null_slots = null_slots[:, num_slots_to_gen:, :]
135 | slots_gen = torch.cat([slots_gen, null_slots], dim=1)
136 | return slots_gen
137 |
138 | with gr.Blocks() as demo:
139 | with gr.Row():
140 | # First column - Input and configs
141 | with gr.Column(scale=1):
142 | gr.Markdown("## Input")
143 |
144 | # Replace image input with ImageNet class selection
145 | imagenet_classes = {k: v for k, v in enumerate(imagenet_classes)}
146 | class_choices = [f"{id}: {name}" for id, name in imagenet_classes.items()]
147 |
148 | # Dropdown for class selection
149 | class_dropdown = gr.Dropdown(
150 | choices=class_choices[:20], # Limit for demonstration
151 | label="Select ImageNet Class",
152 | value=class_choices[0] if class_choices else None
153 | )
154 |
155 | # Option to enter class ID directly
156 | class_id_input = gr.Number(
157 | label="Or enter class ID directly (0-999)",
158 | value=0,
159 | minimum=0,
160 | maximum=999,
161 | step=1
162 | )
163 |
164 | with gr.Group():
165 | gr.Markdown("### Configuration")
166 | show_gallery = gr.Checkbox(label="Show Gallery", value=True)
167 | slider = gr.Slider(minimum=0.1, maximum=20.0, value=4.0, label="CFG value")
168 | labels_input = gr.Textbox(
169 | label="Number of tokens to reconstruct (comma-separated)",
170 | value="1, 2, 4, 8, 16",
171 | placeholder="Enter comma-separated numbers for the number of slots to use"
172 | )
173 |
174 | # Second column - Output (conditionally rendered)
175 | with gr.Column(scale=1):
176 | gr.Markdown("## Output")
177 |
178 | # Container for conditional rendering
179 | with gr.Group(visible=True) as gallery_container:
180 | gallery = gr.Gallery(label="Result Gallery", columns=3, height="auto", show_label=True)
181 |
182 | # Always visible output image
183 | output_image = gr.Image(label="Generated Image", type="numpy")
184 |
185 | # Handle form submission
186 | submit_btn = gr.Button("Generate")
187 |
188 | # Define the processing logic
189 | def update_outputs(class_selection, class_id, show_gallery_value, slider_value, labels_text):
190 | # Determine which class to use - either from dropdown or direct input
191 | if class_selection:
192 | # Extract class ID from the dropdown selection
193 | selected_class_id = int(class_selection.split(":")[0])
194 | else:
195 | selected_class_id = int(class_id)
196 |
197 | # Update the visibility of the gallery container
198 | gallery_container.visible = show_gallery_value
199 |
200 | try:
201 | # Parse the labels from the text input
202 | if labels_text and "," in labels_text:
203 | labels = [int(label.strip()) for label in labels_text.split(",")]
204 | else:
205 | # Default labels if none provided or in wrong format
206 | labels = [1, 4, 16, 64, 256]
207 | except:
208 | labels = [1, 4, 16, 64, 256]
209 |
210 | while len(labels) < 3:
211 | labels.append(256)
212 |
213 | # Generate the image based on the selected class
214 | slots_gen = generate_from_class(selected_class_id, cfg_scale=slider_value)
215 |
216 | recon = viz_diff_slots(ae_model, slots_gen, [32], cfg=slider_value)[0]
217 |
218 | # Always generate the model decomposition for potential gallery display
219 | model_decompose = viz_diff_slots(ae_model, slots_gen, labels, cfg=slider_value)
220 |
221 | if not show_gallery_value:
222 | # If only the image should be shown, return just the processed image
223 | return gallery_container, [], recon
224 | else:
225 | # Create image variations and pair them with labels
226 | gallery_images = [
227 | (recon, f'Generated from class {selected_class_id}'),
228 | ] + [(img, 'Gen. with ' + str(label) + ' tokens') for img, label in zip(model_decompose, labels)]
229 | return gallery_container, gallery_images, recon
230 |
231 | # Connect the inputs and outputs
232 | submit_btn.click(
233 | fn=update_outputs,
234 | inputs=[class_dropdown, class_id_input, show_gallery, slider, labels_input],
235 | outputs=[gallery_container, gallery, output_image]
236 | )
237 |
238 | # Also update when checkbox changes
239 | show_gallery.change(
240 | fn=lambda value: gr.update(visible=value),
241 | inputs=[show_gallery],
242 | outputs=[gallery_container]
243 | )
244 |
245 | # Add examples
246 | examples = [
247 | # ["0: tench, Tinca tinca", 0, True, 4.0, "1,2,4,8,16"],
248 | ["1: goldfish", 1, True, 4.0, "1,2,4,8,16"],
249 | # ["2: great white shark, white shark", 2, True, 4.0, "1,2,4,8,16"],
250 | ]
251 |
252 | gr.Examples(
253 | examples=examples,
254 | inputs=[class_dropdown, class_id_input, show_gallery, slider, labels_input],
255 | outputs=[gallery_container, gallery, output_image],
256 | fn=update_outputs,
257 | cache_examples=False
258 | )
259 |
260 | # Launch the demo
261 | if __name__ == "__main__":
262 | demo.launch()
263 |
--------------------------------------------------------------------------------
/imagenet_classes.py:
--------------------------------------------------------------------------------
1 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English 
Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", 
"bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 
"pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", 
"pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
--------------------------------------------------------------------------------
/pages/figs/Token_PCA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/pages/figs/Token_PCA.png
--------------------------------------------------------------------------------
/pages/figs/comp_table.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/pages/figs/comp_table.jpg
--------------------------------------------------------------------------------
/pages/figs/spectral_analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/pages/figs/spectral_analysis.png
--------------------------------------------------------------------------------
/pages/figs/spectral_titok_ours.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/pages/figs/spectral_titok_ours.jpg
--------------------------------------------------------------------------------
/pages/figs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/pages/figs/teaser.jpg
--------------------------------------------------------------------------------
/pages/figs/tokenizer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visual-gen/semanticist/4c44856c16e6a420bda719ce5d1cf81e0dcb3b78/pages/figs/tokenizer.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.26.4
2 | accelerate
3 | diffusers[torch]
4 | transformers
5 | safetensors
6 | omegaconf
7 | tensorboard
8 | huggingface-hub
9 | einops
10 | timm
11 | scipy
12 | scikit-learn
13 | scikit-image
14 | git+https://github.com/LTH14/torch-fidelity.git@master#egg=torch-fidelity
15 | opencv-python-headless
16 | torchmetrics
17 | submitit
--------------------------------------------------------------------------------
/semanticist/engine/trainer_utils.py:
--------------------------------------------------------------------------------
1 | import os, torch
2 | import cv2
3 | import numpy as np
4 | import torch_fidelity
5 | from collections import OrderedDict
6 | from concurrent.futures import ThreadPoolExecutor
7 | import importlib
8 | from torch.optim import AdamW
9 | from semanticist.utils.lr_scheduler import build_scheduler
10 |
11 |
12 | def get_obj_from_str(string, reload=False):
13 | """Get object from string path."""
14 | module, cls = string.rsplit(".", 1)
15 | if reload:
16 | module_imp = importlib.import_module(module)
17 | importlib.reload(module_imp)
18 | return getattr(importlib.import_module(module, package=None), cls)
19 |
20 |
21 | def instantiate_from_config(config):
22 | """Instantiate an object from a config dictionary."""
23 | if not "target" in config:
24 | raise KeyError("Expected key `target` to instantiate.")
25 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
26 |
27 |
28 | def is_dist_avail_and_initialized():
29 | """Check if distributed training is available and initialized."""
30 | if not torch.distributed.is_initialized():
31 | return False
32 | return True
33 |
34 |
35 | def is_main_process():
36 | """Check if the current process is the main process."""
37 | return not is_dist_avail_and_initialized() or torch.distributed.get_rank() == 0
38 |
39 |
40 | def concat_all_gather(tensor):
41 | """
42 | Performs all_gather operation on the provided tensors.
43 | *** Warning ***: torch.distributed.all_gather has no gradient.
44 | """
45 | tensors_gather = [torch.ones_like(tensor)
46 | for _ in range(torch.distributed.get_world_size())]
47 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
48 |
49 | output = torch.cat(tensors_gather, dim=0)
50 | return output
51 |
52 |
53 | def requires_grad(model, flag=True):
54 | """Set requires_grad flag for all model parameters."""
55 | for p in model.parameters():
56 | p.requires_grad = flag
57 |
58 |
59 | def save_img(img, save_path):
60 | """Save a single image to disk."""
61 | img = np.clip(img.float().numpy().transpose([1, 2, 0]) * 255, 0, 255)
62 | img = img.astype(np.uint8)[:, :, ::-1]
63 | cv2.imwrite(save_path, img)
64 |
65 |
66 | def save_img_batch(imgs, save_paths):
67 | """Process and save multiple images at once using a thread pool."""
68 | # Convert to numpy and prepare all images in one go
69 | imgs = np.clip(imgs.float().numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
70 | imgs = imgs[:, :, :, ::-1] # RGB to BGR for all images at once
71 |
72 | with ThreadPoolExecutor(max_workers=32) as pool:
73 | # Submit all tasks at once
74 | futures = [pool.submit(cv2.imwrite, path, img)
75 | for path, img in zip(save_paths, imgs)]
76 | # Wait for all tasks to complete
77 | for future in futures:
78 | future.result() # This will raise any exceptions that occurred
79 |
80 |
81 | def get_fid_stats(real_dir, rec_dir, fid_stats):
82 | """Calculate FID statistics between real and reconstructed images."""
83 | stats = torch_fidelity.calculate_metrics(
84 | input1=rec_dir,
85 | input2=real_dir,
86 | fid_statistics_file=fid_stats,
87 | cuda=True,
88 | isc=True,
89 | fid=True,
90 | kid=False,
91 | prc=False,
92 | verbose=False,
93 | )
94 | return stats
95 |
96 |
97 | def create_scheduler(optimizer, num_epoch, steps_per_epoch, lr_min, warmup_steps,
98 | warmup_lr_init, decay_steps, cosine_lr):
99 | """Create a learning rate scheduler."""
100 | scheduler = build_scheduler(
101 | optimizer,
102 | num_epoch,
103 | steps_per_epoch,
104 | lr_min,
105 | warmup_steps,
106 | warmup_lr_init,
107 | decay_steps,
108 | cosine_lr,
109 | )
110 | return scheduler
111 |
112 |
113 | def load_state_dict(state_dict, model):
114 | """Helper to load a state dict with proper prefix handling."""
115 | if 'state_dict' in state_dict:
116 | state_dict = state_dict['state_dict']
117 | # Remove '_orig_mod' prefix if present
118 | state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
119 | missing, unexpected = model.load_state_dict(
120 | state_dict, strict=False
121 | )
122 | if is_main_process():
123 | print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")
124 |
125 |
126 | def load_safetensors(path, model):
127 | """Helper to load a safetensors checkpoint."""
128 | from safetensors.torch import safe_open
129 | with safe_open(path, framework="pt", device="cpu") as f:
130 | state_dict = {k: f.get_tensor(k) for k in f.keys()}
131 | load_state_dict(state_dict, model)
132 |
133 |
134 | def setup_result_folders(result_folder):
135 | """Setup result folders for saving models and images."""
136 | model_saved_dir = os.path.join(result_folder, "models")
137 | os.makedirs(model_saved_dir, exist_ok=True)
138 |
139 | image_saved_dir = os.path.join(result_folder, "images")
140 | os.makedirs(image_saved_dir, exist_ok=True)
141 |
142 | return model_saved_dir, image_saved_dir
143 |
144 |
145 | def create_optimizer(model, weight_decay, learning_rate, betas=(0.9, 0.95)):
146 | """Create an AdamW optimizer with weight decay for 2D parameters only."""
147 | # start with all of the candidate parameters
148 | param_dict = {pn: p for pn, p in model.named_parameters()}
149 | # filter out those that do not require grad
150 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
151 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
152 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
153 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
154 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
155 | optim_groups = [
156 | {'params': decay_params, 'weight_decay': weight_decay},
157 | {'params': nodecay_params, 'weight_decay': 0.0}
158 | ]
159 | num_decay_params = sum(p.numel() for p in decay_params)
160 | num_nodecay_params = sum(p.numel() for p in nodecay_params)
161 | if is_main_process():
162 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
163 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
164 | optimizer = AdamW(optim_groups, lr=learning_rate, betas=betas)
165 | return optimizer
166 |
167 |
168 | class EMAModel:
169 | """Model Exponential Moving Average."""
170 | def __init__(self, model, device, decay=0.999):
171 | self.device = device
172 | self.decay = decay
173 | self.ema_params = OrderedDict(
174 | (name, param.clone().detach().to(device))
175 | for name, param in model.named_parameters()
176 | if param.requires_grad
177 | )
178 |
179 | @torch.no_grad()
180 | def update(self, model):
181 | for name, param in model.named_parameters():
182 | if param.requires_grad:
183 | if name in self.ema_params:
184 | self.ema_params[name].lerp_(param.data, 1 - self.decay)
185 | else:
186 | self.ema_params[name] = param.data.clone().detach()
187 |
188 | def state_dict(self):
189 | return self.ema_params
190 |
191 | def load_state_dict(self, params):
192 | self.ema_params = OrderedDict(
193 | (name, param.clone().detach().to(self.device))
194 | for name, param in params.items()
195 | )
196 |
197 |
198 | class PaddedDataset(torch.utils.data.Dataset):
199 | """Dataset wrapper that pads a dataset to ensure even distribution across processes."""
200 | def __init__(self, dataset, padding_size):
201 | self.dataset = dataset
202 | self.padding_size = padding_size
203 |
204 | def __len__(self):
205 | return len(self.dataset) + self.padding_size
206 |
207 | def __getitem__(self, idx):
208 | if idx < len(self.dataset):
209 | return self.dataset[idx]
210 | return self.dataset[0]
211 |
212 | class CacheDataLoader:
213 | """DataLoader-like interface for cached data with epoch-based shuffling."""
214 | def __init__(self, slots, targets=None, batch_size=32, num_augs=1, seed=None):
215 | self.slots = slots
216 | self.targets = targets
217 | self.batch_size = batch_size
218 | self.num_augs = num_augs
219 | self.seed = seed
220 | self.epoch = 0
221 | # Original dataset size (before augmentations)
222 | self.num_samples = len(slots) // num_augs
223 |
224 | def set_epoch(self, epoch):
225 | """Set epoch for deterministic shuffling."""
226 | self.epoch = epoch
227 |
228 | def __len__(self):
229 | """Return number of batches based on original dataset size."""
230 | return self.num_samples // self.batch_size
231 |
232 | def __iter__(self):
233 | """Return random indices for current epoch."""
234 | g = torch.Generator()
235 | g.manual_seed(self.seed + self.epoch if self.seed is not None else self.epoch)
236 |
237 | # Randomly sample indices from the entire augmented dataset
238 | indices = torch.randint(
239 | 0, len(self.slots),
240 | (self.num_samples,),
241 | generator=g
242 | ).numpy()
243 |
244 | # Yield batches of indices
245 | for start in range(0, self.num_samples, self.batch_size):
246 | end = min(start + self.batch_size, self.num_samples)
247 | batch_indices = indices[start:end]
248 | yield (
249 | torch.from_numpy(self.slots[batch_indices]),
250 | torch.from_numpy(self.targets[batch_indices])
251 | )
--------------------------------------------------------------------------------
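A minimal sketch (not part of the repo) of how the EMA helper above is typically driven after each optimizer step, assuming these utilities live in `semanticist/engine/trainer_utils.py` as in the repo tree and the repo root is importable:

```python
# Illustrative only; the real training loops are wired up in semanticist/engine/.
import torch
import torch.nn as nn
from semanticist.engine.trainer_utils import EMAModel

model = nn.Linear(8, 8)
ema = EMAModel(model, device="cpu", decay=0.999)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(4, 8)
loss = model(x).pow(2).mean()
loss.backward()
opt.step()

# After every optimizer step the shadow weights move a fraction (1 - decay)
# toward the live weights: ema <- decay * ema + (1 - decay) * param.
ema.update(model)
```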
/semanticist/stage1/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from . import gaussian_diffusion as gd
7 | from .respace import SpacedDiffusion, space_timesteps
8 |
9 |
10 | def create_diffusion(
11 | timestep_respacing,
12 | noise_schedule="linear",
13 | use_kl=False,
14 | sigma_small=False,
15 | predict_xstart=False,
16 | learn_sigma=True,
17 | rescale_learned_sigmas=False,
18 | diffusion_steps=1000
19 | ):
20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21 | if use_kl:
22 | loss_type = gd.LossType.RESCALED_KL
23 | elif rescale_learned_sigmas:
24 | loss_type = gd.LossType.RESCALED_MSE
25 | else:
26 | loss_type = gd.LossType.MSE
27 | if timestep_respacing is None or timestep_respacing == "":
28 | timestep_respacing = [diffusion_steps]
29 | return SpacedDiffusion(
30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31 | betas=betas,
32 | model_mean_type=(
33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
34 | ),
35 | model_var_type=(
36 | (
37 | gd.ModelVarType.FIXED_LARGE
38 | if not sigma_small
39 | else gd.ModelVarType.FIXED_SMALL
40 | )
41 | if not learn_sigma
42 | else gd.ModelVarType.LEARNED_RANGE
43 | ),
44 | loss_type=loss_type
45 | # rescale_timesteps=rescale_timesteps,
46 | )
47 |
--------------------------------------------------------------------------------
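A hedged usage sketch of the factory above: with the defaults it builds the full 1000-step training diffusion, while a `"ddimN"` respacing string builds a short sampling schedule.

```python
# Illustrative only; assumes the repo root is on PYTHONPATH.
from semanticist.stage1.diffusion import create_diffusion

train_diffusion = create_diffusion(timestep_respacing="")         # all 1000 steps
sample_diffusion = create_diffusion(timestep_respacing="ddim50")  # 50 respaced steps
```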
/semanticist/stage1/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a continuous Gaussian distribution.
50 | :param x: the targets
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | centered_x = x - means
56 | inv_stdv = th.exp(-log_scales)
57 | normalized_x = centered_x * inv_stdv
58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59 | return log_probs
60 |
61 |
62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63 | """
64 | Compute the log-likelihood of a Gaussian distribution discretizing to a
65 | given image.
66 | :param x: the target images. It is assumed that these were uint8 values,
67 | rescaled to the range [-1, 1].
68 | :param means: the Gaussian mean Tensor.
69 | :param log_scales: the Gaussian log stddev Tensor.
70 | :return: a tensor like x of log probabilities (in nats).
71 | """
72 | assert x.shape == means.shape == log_scales.shape
73 | centered_x = x - means
74 | inv_stdv = th.exp(-log_scales)
75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76 | cdf_plus = approx_standard_normal_cdf(plus_in)
77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78 | cdf_min = approx_standard_normal_cdf(min_in)
79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81 | cdf_delta = cdf_plus - cdf_min
82 | log_probs = th.where(
83 | x < -0.999,
84 | log_cdf_plus,
85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86 | )
87 | assert log_probs.shape == x.shape
88 | return log_probs
89 |
--------------------------------------------------------------------------------
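A quick, illustrative sanity check of `normal_kl` above (assuming the module is importable as `semanticist.stage1.diffusion.diffusion_utils`): the closed form matches PyTorch's Gaussian KL.

```python
import torch as th
from semanticist.stage1.diffusion.diffusion_utils import normal_kl

mean1, logvar1 = th.tensor(0.3), th.tensor(-1.0)
mean2, logvar2 = th.tensor(0.0), th.tensor(0.0)

kl_manual = normal_kl(mean1, logvar1, mean2, logvar2)

# Reference: KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))) via torch.distributions.
p = th.distributions.Normal(mean1, (0.5 * logvar1).exp())
q = th.distributions.Normal(mean2, (0.5 * logvar2).exp())
kl_ref = th.distributions.kl_divergence(p, q)

assert th.allclose(kl_manual, kl_ref, atol=1e-6)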
/semanticist/stage1/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there are 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 | f"cannot create exactly {num_timesteps} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | if isinstance(model, _WrappedModel):
107 | return model
108 | return _WrappedModel(
109 | model, self.timestep_map, self.original_num_steps
110 | )
111 |
112 | def _scale_timesteps(self, t):
113 | # Scaling is done by the wrapped model.
114 | return t
115 |
116 |
117 | class _WrappedModel:
118 | def __init__(self, model, timestep_map, original_num_steps):
119 | self.model = model
120 | self.timestep_map = timestep_map
121 | # self.rescale_timesteps = rescale_timesteps
122 | self.original_num_steps = original_num_steps
123 |
124 | def __call__(self, x, ts, **kwargs):
125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126 | new_ts = map_tensor[ts]
127 | # if self.rescale_timesteps:
128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129 | return self.model(x, new_ts, **kwargs)
130 |
--------------------------------------------------------------------------------
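Illustrative examples of `space_timesteps` above, mirroring the docstring's 300-step example (assuming the repo root is on `PYTHONPATH`):

```python
from semanticist.stage1.diffusion.respace import space_timesteps

# 300 original steps split into three sections keeping 10, 15 and 20 steps each.
steps = space_timesteps(300, [10, 15, 20])
assert len(steps) == 45

# "ddimN" keeps an evenly strided subset, e.g. 50 of 1000 steps (stride 20).
ddim_steps = space_timesteps(1000, "ddim50")
assert len(ddim_steps) == 50
```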
/semanticist/stage1/diffusion/timestep_sampler.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from abc import ABC, abstractmethod
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 |
12 |
13 | def create_named_schedule_sampler(name, diffusion):
14 | """
15 | Create a ScheduleSampler from a library of pre-defined samplers.
16 | :param name: the name of the sampler.
17 | :param diffusion: the diffusion object to sample for.
18 | """
19 | if name == "uniform":
20 | return UniformSampler(diffusion)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | else:
24 | raise NotImplementedError(f"unknown schedule sampler: {name}")
25 |
26 |
27 | class ScheduleSampler(ABC):
28 | """
29 | A distribution over timesteps in the diffusion process, intended to reduce
30 | variance of the objective.
31 | By default, samplers perform unbiased importance sampling, in which the
32 | objective's mean is unchanged.
33 | However, subclasses may override sample() to change how the resampled
34 | terms are reweighted, allowing for actual changes in the objective.
35 | """
36 |
37 | @abstractmethod
38 | def weights(self):
39 | """
40 | Get a numpy array of weights, one per diffusion step.
41 | The weights needn't be normalized, but must be positive.
42 | """
43 |
44 | def sample(self, batch_size, device):
45 | """
46 | Importance-sample timesteps for a batch.
47 | :param batch_size: the number of timesteps.
48 | :param device: the torch device to save to.
49 | :return: a tuple (timesteps, weights):
50 | - timesteps: a tensor of timestep indices.
51 | - weights: a tensor of weights to scale the resulting losses.
52 | """
53 | w = self.weights()
54 | p = w / np.sum(w)
55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56 | indices = th.from_numpy(indices_np).long().to(device)
57 | weights_np = 1 / (len(p) * p[indices_np])
58 | weights = th.from_numpy(weights_np).float().to(device)
59 | return indices, weights
60 |
61 |
62 | class UniformSampler(ScheduleSampler):
63 | def __init__(self, diffusion):
64 | self.diffusion = diffusion
65 | self._weights = np.ones([diffusion.num_timesteps])
66 |
67 | def weights(self):
68 | return self._weights
69 |
70 |
71 | class LossAwareSampler(ScheduleSampler):
72 | def update_with_local_losses(self, local_ts, local_losses):
73 | """
74 | Update the reweighting using losses from a model.
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 | :param local_ts: an integer Tensor of timesteps.
80 | :param local_losses: a 1D Tensor of losses.
81 | """
82 | batch_sizes = [
83 | th.tensor([0], dtype=th.int32, device=local_ts.device)
84 | for _ in range(dist.get_world_size())
85 | ]
86 | dist.all_gather(
87 | batch_sizes,
88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 | )
90 |
91 | # Pad all_gather batches to be the maximum batch size.
92 | batch_sizes = [x.item() for x in batch_sizes]
93 | max_bs = max(batch_sizes)
94 |
95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97 | dist.all_gather(timestep_batches, local_ts)
98 | dist.all_gather(loss_batches, local_losses)
99 | timesteps = [
100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 | ]
102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 | self.update_with_all_losses(timesteps, losses)
104 |
105 | @abstractmethod
106 | def update_with_all_losses(self, ts, losses):
107 | """
108 | Update the reweighting using losses from a model.
109 | Sub-classes should override this method to update the reweighting
110 | using losses from the model.
111 | This method directly updates the reweighting without synchronizing
112 | between workers. It is called by update_with_local_losses from all
113 | ranks with identical arguments. Thus, it should have deterministic
114 | behavior to maintain state across workers.
115 | :param ts: a list of int timesteps.
116 | :param losses: a list of float losses, one per timestep.
117 | """
118 |
119 |
120 | class LossSecondMomentResampler(LossAwareSampler):
121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 | self.diffusion = diffusion
123 | self.history_per_term = history_per_term
124 | self.uniform_prob = uniform_prob
125 | self._loss_history = np.zeros(
126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 | )
128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
129 |
130 | def weights(self):
131 | if not self._warmed_up():
132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 | weights /= np.sum(weights)
135 | weights *= 1 - self.uniform_prob
136 | weights += self.uniform_prob / len(weights)
137 | return weights
138 |
139 | def update_with_all_losses(self, ts, losses):
140 | for t, loss in zip(ts, losses):
141 | if self._loss_counts[t] == self.history_per_term:
142 | # Shift out the oldest loss term.
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 | self._loss_history[t, -1] = loss
145 | else:
146 | self._loss_history[t, self._loss_counts[t]] = loss
147 | self._loss_counts[t] += 1
148 |
149 | def _warmed_up(self):
150 | return (self._loss_counts == self.history_per_term).all()
151 |
--------------------------------------------------------------------------------
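A small, illustrative sketch of drawing timesteps and loss weights with the uniform sampler above; it only relies on the diffusion object exposing a `num_timesteps` attribute.

```python
import torch as th
from semanticist.stage1.diffusion import create_diffusion
from semanticist.stage1.diffusion.timestep_sampler import create_named_schedule_sampler

diffusion = create_diffusion(timestep_respacing="")
sampler = create_named_schedule_sampler("uniform", diffusion)

t, weights = sampler.sample(batch_size=16, device="cpu")
# The uniform sampler gives every timestep equal probability, so all weights are 1
# and the training loss is effectively unweighted.
assert t.shape == (16,) and th.allclose(weights, th.ones(16))
```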
/semanticist/stage1/diffusion_transfomer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 | import math
16 | from timm.models.vision_transformer import PatchEmbed, Mlp
17 | from semanticist.stage1.fused_attention import Attention
18 |
19 |
20 |
21 | def modulate(x, shift, scale):
22 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
23 |
24 |
25 | #################################################################################
26 | # Embedding Layers for Timesteps and Class Labels #
27 | #################################################################################
28 |
29 | class TimestepEmbedder(nn.Module):
30 | """
31 | Embeds scalar timesteps into vector representations.
32 | """
33 | def __init__(self, hidden_size, frequency_embedding_size=256):
34 | super().__init__()
35 | self.mlp = nn.Sequential(
36 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
37 | nn.SiLU(),
38 | nn.Linear(hidden_size, hidden_size, bias=True),
39 | )
40 | self.frequency_embedding_size = frequency_embedding_size
41 |
42 | @staticmethod
43 | def timestep_embedding(t, dim, max_period=10000):
44 | """
45 | Create sinusoidal timestep embeddings.
46 | :param t: a 1-D Tensor of N indices, one per batch element.
47 | These may be fractional.
48 | :param dim: the dimension of the output.
49 | :param max_period: controls the minimum frequency of the embeddings.
50 | :return: an (N, D) Tensor of positional embeddings.
51 | """
52 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
53 | half = dim // 2
54 | freqs = torch.exp(
55 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
56 | ).to(device=t.device)
57 | args = t[:, None].float() * freqs[None]
58 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
59 | if dim % 2:
60 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
61 | return embedding
62 |
63 | def forward(self, t):
64 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
65 | t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
66 | return t_emb
67 |
68 |
69 | class LabelEmbedder(nn.Module):
70 | """
71 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
72 | """
73 | def __init__(self, num_classes, hidden_size, dropout_prob):
74 | super().__init__()
75 | use_cfg_embedding = dropout_prob > 0
76 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
77 | self.num_classes = num_classes
78 | self.dropout_prob = dropout_prob
79 |
80 | def token_drop(self, labels, force_drop_ids=None):
81 | """
82 | Drops labels to enable classifier-free guidance.
83 | """
84 | if force_drop_ids is None:
85 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
86 | else:
87 | drop_ids = force_drop_ids == 1
88 | labels = torch.where(drop_ids, self.num_classes, labels)
89 | return labels
90 |
91 | def forward(self, labels, train, force_drop_ids=None):
92 | use_dropout = self.dropout_prob > 0
93 | if (train and use_dropout) or (force_drop_ids is not None):
94 | labels = self.token_drop(labels, force_drop_ids)
95 | embeddings = self.embedding_table(labels)
96 | return embeddings
97 |
98 |
99 | #################################################################################
100 | # Core DiT Model #
101 | #################################################################################
102 |
103 | class DiTBlock(nn.Module):
104 | """
105 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
106 | """
107 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
108 | super().__init__()
109 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
110 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
111 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
112 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
113 | approx_gelu = lambda: nn.GELU(approximate="tanh")
114 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
115 | self.adaLN_modulation = nn.Sequential(
116 | nn.SiLU(),
117 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
118 | )
119 |
120 | def forward(self, x, c, mask=None):
121 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
122 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), mask)
123 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
124 | return x
125 |
126 |
127 | class FinalLayer(nn.Module):
128 | """
129 | The final layer of DiT.
130 | """
131 | def __init__(self, hidden_size, patch_size, out_channels):
132 | super().__init__()
133 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
134 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
135 | self.adaLN_modulation = nn.Sequential(
136 | nn.SiLU(),
137 | nn.Linear(hidden_size, 2 * hidden_size, bias=True)
138 | )
139 |
140 | def forward(self, x, c):
141 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
142 | x = modulate(self.norm_final(x), shift, scale)
143 | x = self.linear(x)
144 | return x
145 |
146 |
147 | class DiT(nn.Module):
148 | """
149 | Diffusion model with a Transformer backbone.
150 | """
151 | def __init__(
152 | self,
153 | input_size=32,
154 | patch_size=2,
155 | in_channels=4,
156 | hidden_size=1152,
157 | depth=28,
158 | num_heads=16,
159 | mlp_ratio=4.0,
160 | class_dropout_prob=0.1,
161 | num_classes=1000,
162 | learn_sigma=True,
163 | ):
164 | super().__init__()
165 | self.learn_sigma = learn_sigma
166 | self.in_channels = in_channels
167 | self.out_channels = in_channels * 2 if learn_sigma else in_channels
168 | self.patch_size = patch_size
169 | self.num_heads = num_heads
170 |
171 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
172 | self.t_embedder = TimestepEmbedder(hidden_size)
173 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
174 | num_patches = self.x_embedder.num_patches
175 | # Will use fixed sin-cos embedding:
176 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
177 |
178 | self.blocks = nn.ModuleList([
179 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
180 | ])
181 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
182 | self.initialize_weights()
183 |
184 | def initialize_weights(self):
185 | # Initialize transformer layers:
186 | def _basic_init(module):
187 | if isinstance(module, nn.Linear):
188 | torch.nn.init.xavier_uniform_(module.weight)
189 | if module.bias is not None:
190 | nn.init.constant_(module.bias, 0)
191 | self.apply(_basic_init)
192 |
193 | # Initialize (and freeze) pos_embed by sin-cos embedding:
194 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
195 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
196 |
197 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
198 | w = self.x_embedder.proj.weight.data
199 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
200 | nn.init.constant_(self.x_embedder.proj.bias, 0)
201 |
202 | # Initialize label embedding table:
203 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
204 |
205 | # Initialize timestep embedding MLP:
206 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
207 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
208 |
209 | # Zero-out adaLN modulation layers in DiT blocks:
210 | for block in self.blocks:
211 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
212 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
213 |
214 | # Zero-out output layers:
215 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
216 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
217 | nn.init.constant_(self.final_layer.linear.weight, 0)
218 | nn.init.constant_(self.final_layer.linear.bias, 0)
219 |
220 | def unpatchify(self, x):
221 | """
222 | x: (N, T, patch_size**2 * C)
223 | imgs: (N, C, H, W)
224 | """
225 | c = self.out_channels
226 | p = self.x_embedder.patch_size[0]
227 | h = w = int(x.shape[1] ** 0.5)
228 | assert h * w == x.shape[1]
229 |
230 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
231 | x = torch.einsum('nhwpqc->nchpwq', x)
232 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
233 | return imgs
234 |
235 | def forward(self, x, t, y):
236 | """
237 | Forward pass of DiT.
238 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
239 | t: (N,) tensor of diffusion timesteps
240 | y: (N,) tensor of class labels
241 | """
242 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
243 | t = self.t_embedder(t) # (N, D)
244 | y = self.y_embedder(y, self.training) # (N, D)
245 | c = t + y # (N, D)
246 | for block in self.blocks:
247 | x = block(x, c) # (N, T, D)
248 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
249 | x = self.unpatchify(x) # (N, out_channels, H, W)
250 | return x
251 |
252 | def forward_with_cfg(self, x, t, y, cfg_scale):
253 | """
254 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
255 | """
256 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
257 | half = x[: len(x) // 2]
258 | combined = torch.cat([half, half], dim=0)
259 | model_out = self.forward(combined, t, y)
260 | # Classifier-free guidance is applied here to all `in_channels` channels, which is the
261 | # standard approach. The original DiT code applied it to only the first three channels
262 | # for exact reproducibility; that variant is kept in the commented-out line below.
263 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
264 | # eps, rest = model_out[:, :3], model_out[:, 3:]
265 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
266 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
267 | eps = torch.cat([half_eps, half_eps], dim=0)
268 | return torch.cat([eps, rest], dim=1)
269 |
270 |
271 | #################################################################################
272 | # Sine/Cosine Positional Embedding Functions #
273 | #################################################################################
274 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
275 |
276 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
277 | """
278 | grid_size: int of the grid height and width
279 | return:
280 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
281 | """
282 | grid_h = np.arange(grid_size, dtype=np.float32)
283 | grid_w = np.arange(grid_size, dtype=np.float32)
284 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
285 | grid = np.stack(grid, axis=0)
286 |
287 | grid = grid.reshape([2, 1, grid_size, grid_size])
288 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
289 | if cls_token and extra_tokens > 0:
290 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
291 | return pos_embed
292 |
293 |
294 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
295 | assert embed_dim % 2 == 0
296 |
297 | # use half of dimensions to encode grid_h
298 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
299 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
300 |
301 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
302 | return emb
303 |
304 |
305 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
306 | """
307 | embed_dim: output dimension for each position
308 | pos: a list of positions to be encoded: size (M,)
309 | out: (M, D)
310 | """
311 | assert embed_dim % 2 == 0
312 | omega = np.arange(embed_dim // 2, dtype=np.float64)
313 | omega /= embed_dim / 2.
314 | omega = 1. / 10000**omega # (D/2,)
315 |
316 | pos = pos.reshape(-1) # (M,)
317 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
318 |
319 | emb_sin = np.sin(out) # (M, D/2)
320 | emb_cos = np.cos(out) # (M, D/2)
321 |
322 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
323 | return emb
324 |
325 |
326 | #################################################################################
327 | # DiT Configs #
328 | #################################################################################
329 |
330 | def DiT_XL_2(**kwargs):
331 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
332 |
333 | def DiT_XL_4(**kwargs):
334 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
335 |
336 | def DiT_XL_8(**kwargs):
337 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
338 |
339 | def DiT_L_2(**kwargs):
340 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
341 |
342 | def DiT_L_4(**kwargs):
343 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
344 |
345 | def DiT_L_8(**kwargs):
346 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
347 |
348 | def DiT_B_2(**kwargs):
349 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
350 |
351 | def DiT_B_4(**kwargs):
352 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
353 |
354 | def DiT_B_8(**kwargs):
355 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
356 |
357 | def DiT_S_2(**kwargs):
358 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
359 |
360 | def DiT_S_4(**kwargs):
361 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
362 |
363 | def DiT_S_8(**kwargs):
364 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
365 |
366 |
367 | DiT_models = {
368 | 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
369 | 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
370 | 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
371 | 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
372 | }
--------------------------------------------------------------------------------
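An illustrative instantiation of a small model from the `DiT_models` registry above (note the repo's module filename spelling, `diffusion_transfomer.py`); shapes follow the constructor defaults.

```python
import torch
from semanticist.stage1.diffusion_transfomer import DiT_models

model = DiT_models['DiT-S/2'](input_size=32, in_channels=4, num_classes=1000)

x = torch.randn(2, 4, 32, 32)     # (N, C, H, W) latents
t = torch.randint(0, 1000, (2,))  # diffusion timesteps
y = torch.randint(0, 1000, (2,))  # class labels

out = model(x, t, y)
# learn_sigma=True (the default) doubles the channel count of the prediction.
assert out.shape == (2, 8, 32, 32)
```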
/semanticist/stage1/fused_attention.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from typing import Type
4 |
5 | class Attention(nn.Module):
6 | def __init__(
7 | self,
8 | dim: int,
9 | num_heads: int = 8,
10 | qkv_bias: bool = False,
11 | qk_norm: bool = False,
12 | proj_bias: bool = True,
13 | attn_drop: float = 0.,
14 | proj_drop: float = 0.,
15 | norm_layer: Type[nn.Module] = nn.LayerNorm,
16 | ) -> None:
17 | super().__init__()
18 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
19 | self.num_heads = num_heads
20 | self.head_dim = dim // num_heads
21 | self.scale = self.head_dim ** -0.5
22 |
23 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
24 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
25 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
26 | self.attn_drop = nn.Dropout(attn_drop)
27 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
28 | self.proj_drop = nn.Dropout(proj_drop)
29 |
30 | def forward(self, x, attn_mask=None):
31 | B, N, C = x.shape
32 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
33 | q, k, v = qkv.unbind(0)
34 | q, k = self.q_norm(q), self.k_norm(k)
35 |
36 | x = F.scaled_dot_product_attention(
37 | q, k, v,
38 | attn_mask=attn_mask,
39 | dropout_p=self.attn_drop.p if self.training else 0.,
40 | )
41 |
42 | x = x.transpose(1, 2).reshape(B, N, C)
43 | x = self.proj(x)
44 | x = self.proj_drop(x)
45 | return x
46 |
--------------------------------------------------------------------------------
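Illustrative usage of the fused `Attention` module above as a drop-in multi-head self-attention over `(batch, tokens, channels)` inputs; the causal-mask example at the end is hypothetical, not something the repo necessarily does.

```python
import torch
from semanticist.stage1.fused_attention import Attention

attn = Attention(dim=768, num_heads=12, qkv_bias=True)
tokens = torch.randn(2, 197, 768)

out = attn(tokens)                 # full (bidirectional) attention
assert out.shape == tokens.shape

# An optional boolean attn_mask (broadcastable to B x heads x N x N) restricts
# which tokens may attend to which, e.g. a lower-triangular causal mask.
causal = torch.ones(197, 197, dtype=torch.bool).tril()
out_causal = attn(tokens, attn_mask=causal)
```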
/semanticist/stage1/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 |
10 | import numpy as np
11 |
12 | import torch
13 |
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 | """
22 | grid_size: int of the grid height and width
23 | return:
24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 | """
26 | grid_h = np.arange(grid_size, dtype=np.float32)
27 | grid_w = np.arange(grid_size, dtype=np.float32)
28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
29 | grid = np.stack(grid, axis=0)
30 |
31 | grid = grid.reshape([2, 1, grid_size, grid_size])
32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 | if cls_token:
34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 | return pos_embed
36 |
37 |
38 | def get_1d_sincos_pos_embed(embed_dim, grid_size):
39 | grid = np.arange(grid_size, dtype=np.float32)
40 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
41 | return pos_embed
42 |
43 |
44 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
45 | assert embed_dim % 2 == 0
46 |
47 | # use half of dimensions to encode grid_h
48 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
49 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
50 |
51 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
52 | return emb
53 |
54 |
55 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
56 | """
57 | embed_dim: output dimension for each position
58 | pos: a list of positions to be encoded: size (M,)
59 | out: (M, D)
60 | """
61 | assert embed_dim % 2 == 0
62 | omega = np.arange(embed_dim // 2, dtype=float)
63 | omega /= embed_dim / 2.
64 | omega = 1. / 10000**omega # (D/2,)
65 |
66 | pos = pos.reshape(-1) # (M,)
67 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
68 |
69 | emb_sin = np.sin(out) # (M, D/2)
70 | emb_cos = np.cos(out) # (M, D/2)
71 |
72 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
73 | return emb
74 |
75 |
76 | # --------------------------------------------------------
77 | # Interpolate position embeddings for high-resolution
78 | # References:
79 | # DeiT: https://github.com/facebookresearch/deit
80 | # --------------------------------------------------------
81 | def interpolate_pos_embed(model, checkpoint_model):
82 | if 'pos_embed' in checkpoint_model:
83 | pos_embed_checkpoint = checkpoint_model['pos_embed']
84 | embedding_size = pos_embed_checkpoint.shape[-1]
85 | num_patches = model.patch_embed.num_patches
86 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
87 | # height (== width) for the checkpoint position embedding
88 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
89 | # height (== width) for the new position embedding
90 | new_size = int(num_patches ** 0.5)
91 | # class_token and dist_token are kept unchanged
92 | if orig_size != new_size:
93 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
94 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
95 | # only the position tokens are interpolated
96 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
97 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
98 | pos_tokens = torch.nn.functional.interpolate(
99 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
100 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
101 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
102 | checkpoint_model['pos_embed'] = new_pos_embed
--------------------------------------------------------------------------------
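An illustrative sketch of building a sine-cosine grid with `get_2d_sincos_pos_embed` above and freezing it into a parameter, as the models in this repo do:

```python
import torch
from semanticist.stage1.pos_embed import get_2d_sincos_pos_embed

# 16x16 grid of patches plus one leading zero row for a class token.
pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16, cls_token=True)
assert pos.shape == (1 + 16 * 16, 768)

pos_embed = torch.nn.Parameter(torch.from_numpy(pos).float().unsqueeze(0),
                               requires_grad=False)
```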
/semanticist/stage1/transport/__init__.py:
--------------------------------------------------------------------------------
1 | from .transport import Transport, ModelType, WeightType, PathType, Sampler
2 |
3 | def create_transport(
4 | path_type='Linear',
5 | prediction="velocity",
6 | loss_weight=None,
7 | train_eps=None,
8 | sample_eps=None,
9 | ):
10 | """function for creating Transport object
11 | **Note**: model prediction defaults to velocity
12 | Args:
13 | - path_type: type of path to use; default to linear
14 | - learn_score: set model prediction to score
15 | - learn_noise: set model prediction to noise
16 | - velocity_weighted: weight loss by velocity weight
17 | - likelihood_weighted: weight loss by likelihood weight
18 | - train_eps: small epsilon for avoiding instability during training
19 | - sample_eps: small epsilon for avoiding instability during sampling
20 | """
21 |
22 | if prediction == "noise":
23 | model_type = ModelType.NOISE
24 | elif prediction == "score":
25 | model_type = ModelType.SCORE
26 | else:
27 | model_type = ModelType.VELOCITY
28 |
29 | if loss_weight == "velocity":
30 | loss_type = WeightType.VELOCITY
31 | elif loss_weight == "likelihood":
32 | loss_type = WeightType.LIKELIHOOD
33 | else:
34 | loss_type = WeightType.NONE
35 |
36 | path_choice = {
37 | "Linear": PathType.LINEAR,
38 | "GVP": PathType.GVP,
39 | "VP": PathType.VP,
40 | }
41 |
42 | path_type = path_choice[path_type]
43 |
44 | if (path_type in [PathType.VP]):
45 | train_eps = 1e-5 if train_eps is None else train_eps
46 | sample_eps = 1e-3 if sample_eps is None else sample_eps
47 | elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):
48 | train_eps = 1e-3 if train_eps is None else train_eps
49 | sample_eps = 1e-3 if sample_eps is None else sample_eps
50 | else: # velocity & [GVP, LINEAR] is stable everywhere
51 | train_eps = 0
52 | sample_eps = 0
53 |
54 | # create flow state
55 | state = Transport(
56 | model_type=model_type,
57 | path_type=path_type,
58 | loss_type=loss_type,
59 | train_eps=train_eps,
60 | sample_eps=sample_eps,
61 | )
62 |
63 | return state
--------------------------------------------------------------------------------
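A hedged usage sketch of `create_transport` above with its defaults (velocity prediction on a linear path, so no epsilon clipping is needed):

```python
from semanticist.stage1.transport import create_transport

transport = create_transport(path_type='Linear', prediction="velocity")
# transport.training_losses(model, x1, model_kwargs) returns a dict with
# per-sample 'loss' and 'pred' entries, as defined in transport.py below.
```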
/semanticist/stage1/transport/integrators.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 | import torch.nn as nn
4 | from torchdiffeq import odeint
5 | from functools import partial
6 | from tqdm import tqdm
7 |
8 | class sde:
9 | """SDE solver class"""
10 | def __init__(
11 | self,
12 | drift,
13 | diffusion,
14 | *,
15 | t0,
16 | t1,
17 | num_steps,
18 | sampler_type,
19 | temperature=1.0,
20 | ):
21 | assert t0 < t1, "SDE sampler has to be in forward time"
22 |
23 | self.num_timesteps = num_steps
24 | self.t = th.linspace(t0, t1, num_steps)
25 | self.dt = self.t[1] - self.t[0]
26 | self.drift = drift
27 | self.diffusion = diffusion
28 | self.sampler_type = sampler_type
29 | self.temperature = temperature
30 |
31 | def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
32 | w_cur = th.randn(x.size()).to(x)
33 | t = th.ones(x.size(0)).to(x) * t
34 | dw = w_cur * th.sqrt(self.dt)
35 | drift = self.drift(x, t, model, **model_kwargs)
36 | diffusion = self.diffusion(x, t)
37 | mean_x = x + drift * self.dt
38 | x = mean_x + th.sqrt(2 * diffusion) * dw * self.temperature
39 | return x, mean_x
40 |
41 | def __Heun_step(self, x, _, t, model, **model_kwargs):
42 | w_cur = th.randn(x.size()).to(x)
43 | dw = w_cur * th.sqrt(self.dt) * self.temperature
44 | t_cur = th.ones(x.size(0)).to(x) * t
45 | diffusion = self.diffusion(x, t_cur)
46 | xhat = x + th.sqrt(2 * diffusion) * dw
47 | K1 = self.drift(xhat, t_cur, model, **model_kwargs)
48 | xp = xhat + self.dt * K1
49 | K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
50 | return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step
51 |
52 | def __forward_fn(self):
53 | """TODO: generalize here by adding all private functions ending with steps to it"""
54 | sampler_dict = {
55 | "Euler": self.__Euler_Maruyama_step,
56 | "Heun": self.__Heun_step,
57 | }
58 |
59 | try:
60 | sampler = sampler_dict[self.sampler_type]
61 | except KeyError:
62 | raise NotImplementedError("Sampler type not implemented.")
63 |
64 | return sampler
65 |
66 | def sample(self, init, model, **model_kwargs):
67 | """forward loop of sde"""
68 | x = init
69 | mean_x = init
70 | samples = []
71 | sampler = self.__forward_fn()
72 | for ti in self.t[:-1]:
73 | with th.no_grad():
74 | x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
75 | samples.append(x)
76 |
77 | return samples
78 |
79 | class ode:
80 | """ODE solver class"""
81 | def __init__(
82 | self,
83 | drift,
84 | *,
85 | t0,
86 | t1,
87 | sampler_type,
88 | num_steps,
89 | atol,
90 | rtol,
91 | temperature=1.0,
92 | ):
93 | assert t0 < t1, "ODE sampler has to be in forward time"
94 |
95 | self.drift = drift
96 | self.t = th.linspace(t0, t1, num_steps)
97 | self.atol = atol
98 | self.rtol = rtol
99 | self.sampler_type = sampler_type
100 | self.temperature = temperature
101 |
102 | def sample(self, x, model, **model_kwargs):
103 |
104 | device = x[0].device if isinstance(x, tuple) else x.device
105 | def _fn(t, x):
106 | t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
107 | # For ODE, we scale the drift by the temperature
108 | # This is equivalent to scaling time by 1/temperature
109 | model_output = self.drift(x, t, model, **model_kwargs)
110 | if self.temperature != 1.0:
111 | # If it's a tuple (for likelihood calculation), only scale the first element
112 | if isinstance(model_output, tuple):
113 | scaled_output = (model_output[0] / self.temperature, model_output[1])
114 | return scaled_output
115 | else:
116 | return model_output / self.temperature
117 | return model_output
118 |
119 | t = self.t.to(device)
120 | atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
121 | rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
122 | samples = odeint(
123 | _fn,
124 | x,
125 | t,
126 | method=self.sampler_type,
127 | atol=atol,
128 | rtol=rtol
129 | )
130 | return samples
--------------------------------------------------------------------------------
/semanticist/stage1/transport/path.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import numpy as np
3 | from functools import partial
4 |
5 | def expand_t_like_x(t, x):
6 | """Function to reshape time t to broadcastable dimension of x
7 | Args:
8 | t: [batch_dim,], time vector
9 | x: [batch_dim,...], data point
10 | """
11 | dims = [1] * (len(x.size()) - 1)
12 | t = t.view(t.size(0), *dims)
13 | return t
14 |
15 |
16 | #################### Coupling Plans ####################
17 |
18 | class ICPlan:
19 | """Linear Coupling Plan"""
20 | def __init__(self, sigma=0.0):
21 | self.sigma = sigma
22 |
23 | def compute_alpha_t(self, t):
24 | """Compute the data coefficient along the path"""
25 | return t, 1
26 |
27 | def compute_sigma_t(self, t):
28 | """Compute the noise coefficient along the path"""
29 | return 1 - t, -1
30 |
31 | def compute_d_alpha_alpha_ratio_t(self, t):
32 | """Compute the ratio between d_alpha and alpha"""
33 | return 1 / t
34 |
35 | def compute_drift(self, x, t):
36 | """We always output sde according to score parametrization; """
37 | t = expand_t_like_x(t, x)
38 | alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
39 | sigma_t, d_sigma_t = self.compute_sigma_t(t)
40 | drift = alpha_ratio * x
41 | diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t
42 |
43 | return -drift, diffusion
44 |
45 | def compute_diffusion(self, x, t, form="constant", norm=1.0):
46 | """Compute the diffusion term of the SDE
47 | Args:
48 | x: [batch_dim, ...], data point
49 | t: [batch_dim,], time vector
50 | form: str, form of the diffusion term
51 | norm: float, norm of the diffusion term
52 | """
53 | t = expand_t_like_x(t, x)
54 | choices = {
55 | "constant": norm,
56 | "SBDM": norm * self.compute_drift(x, t)[1],
57 | "sigma": norm * self.compute_sigma_t(t)[0],
58 | "linear": norm * (1 - t),
59 | "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
60 | "inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2,
61 | }
62 |
63 | try:
64 | diffusion = choices[form]
65 | except KeyError:
66 | raise NotImplementedError(f"Diffusion form {form} not implemented")
67 |
68 | return diffusion
69 |
70 | def get_score_from_velocity(self, velocity, x, t):
71 | """Wrapper function: transfrom velocity prediction model to score
72 | Args:
73 | velocity: [batch_dim, ...] shaped tensor; velocity model output
74 | x: [batch_dim, ...] shaped tensor; x_t data point
75 | t: [batch_dim,] time tensor
76 | """
77 | t = expand_t_like_x(t, x)
78 | alpha_t, d_alpha_t = self.compute_alpha_t(t)
79 | sigma_t, d_sigma_t = self.compute_sigma_t(t)
80 | mean = x
81 | reverse_alpha_ratio = alpha_t / d_alpha_t
82 | var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
83 | score = (reverse_alpha_ratio * velocity - mean) / var
84 | return score
85 |
86 | def get_noise_from_velocity(self, velocity, x, t):
87 | """Wrapper function: transfrom velocity prediction model to denoiser
88 | Args:
89 | velocity: [batch_dim, ...] shaped tensor; velocity model output
90 | x: [batch_dim, ...] shaped tensor; x_t data point
91 | t: [batch_dim,] time tensor
92 | """
93 | t = expand_t_like_x(t, x)
94 | alpha_t, d_alpha_t = self.compute_alpha_t(t)
95 | sigma_t, d_sigma_t = self.compute_sigma_t(t)
96 | mean = x
97 | reverse_alpha_ratio = alpha_t / d_alpha_t
98 | var = reverse_alpha_ratio * d_sigma_t - sigma_t
99 | noise = (reverse_alpha_ratio * velocity - mean) / var
100 | return noise
101 |
102 | def get_velocity_from_score(self, score, x, t):
103 | """Wrapper function: transfrom score prediction model to velocity
104 | Args:
105 | score: [batch_dim, ...] shaped tensor; score model output
106 | x: [batch_dim, ...] shaped tensor; x_t data point
107 | t: [batch_dim,] time tensor
108 | """
109 | t = expand_t_like_x(t, x)
110 | drift, var = self.compute_drift(x, t)
111 | velocity = var * score - drift
112 | return velocity
113 |
114 | def compute_mu_t(self, t, x0, x1):
115 | """Compute the mean of time-dependent density p_t"""
116 | t = expand_t_like_x(t, x1)
117 | alpha_t, _ = self.compute_alpha_t(t)
118 | sigma_t, _ = self.compute_sigma_t(t)
119 | return alpha_t * x1 + sigma_t * x0
120 |
121 | def compute_xt(self, t, x0, x1):
122 | """Sample xt from time-dependent density p_t; rng is required"""
123 | xt = self.compute_mu_t(t, x0, x1)
124 | return xt
125 |
126 | def compute_ut(self, t, x0, x1, xt):
127 | """Compute the vector field corresponding to p_t"""
128 | t = expand_t_like_x(t, x1)
129 | _, d_alpha_t = self.compute_alpha_t(t)
130 | _, d_sigma_t = self.compute_sigma_t(t)
131 | return d_alpha_t * x1 + d_sigma_t * x0
132 |
133 | def plan(self, t, x0, x1):
134 | xt = self.compute_xt(t, x0, x1)
135 | ut = self.compute_ut(t, x0, x1, xt)
136 | return t, xt, ut
137 |
138 |
139 | class VPCPlan(ICPlan):
140 | """class for VP path flow matching"""
141 |
142 | def __init__(self, sigma_min=0.1, sigma_max=20.0):
143 | self.sigma_min = sigma_min
144 | self.sigma_max = sigma_max
145 | self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
146 | self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
147 |
148 |
149 | def compute_alpha_t(self, t):
150 | """Compute coefficient of x1"""
151 | alpha_t = self.log_mean_coeff(t)
152 | alpha_t = th.exp(alpha_t)
153 | d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
154 | return alpha_t, d_alpha_t
155 |
156 | def compute_sigma_t(self, t):
157 | """Compute coefficient of x0"""
158 | p_sigma_t = 2 * self.log_mean_coeff(t)
159 | sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
160 | d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
161 | return sigma_t, d_sigma_t
162 |
163 | def compute_d_alpha_alpha_ratio_t(self, t):
164 | """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
165 | return self.d_log_mean_coeff(t)
166 |
167 | def compute_drift(self, x, t):
168 | """Compute the drift term of the SDE"""
169 | t = expand_t_like_x(t, x)
170 | beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
171 | return -0.5 * beta_t * x, beta_t / 2
172 |
173 |
174 | class GVPCPlan(ICPlan):
175 | def __init__(self, sigma=0.0):
176 | super().__init__(sigma)
177 |
178 | def compute_alpha_t(self, t):
179 | """Compute coefficient of x1"""
180 | alpha_t = th.sin(t * np.pi / 2)
181 | d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
182 | return alpha_t, d_alpha_t
183 |
184 | def compute_sigma_t(self, t):
185 | """Compute coefficient of x0"""
186 | sigma_t = th.cos(t * np.pi / 2)
187 | d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
188 | return sigma_t, d_sigma_t
189 |
190 | def compute_d_alpha_alpha_ratio_t(self, t):
191 | """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
192 | return np.pi / (2 * th.tan(t * np.pi / 2))
--------------------------------------------------------------------------------
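A quick, illustrative check of the linear coupling plan above: under `ICPlan`, x_t = t * x1 + (1 - t) * x0 and the target vector field is u_t = x1 - x0.

```python
import torch as th
from semanticist.stage1.transport.path import ICPlan

plan = ICPlan()
x0, x1 = th.randn(4, 8), th.randn(4, 8)
t = th.rand(4)

t_out, xt, ut = plan.plan(t, x0, x1)

t_b = t.view(-1, 1)  # broadcast time over the feature dimension
assert th.allclose(xt, t_b * x1 + (1 - t_b) * x0, atol=1e-6)
assert th.allclose(ut, x1 - x0)
```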
/semanticist/stage1/transport/transport.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import numpy as np
3 | import logging
4 |
5 | import enum
6 |
7 | from . import path
8 | from .utils import EasyDict, log_state, mean_flat
9 | from .integrators import ode, sde
10 |
11 | class ModelType(enum.Enum):
12 | """
13 | Which type of output the model predicts.
14 | """
15 |
16 | NOISE = enum.auto() # the model predicts epsilon
17 | SCORE = enum.auto() # the model predicts \nabla \log p(x)
18 | VELOCITY = enum.auto() # the model predicts v(x)
19 |
20 | class PathType(enum.Enum):
21 | """
22 | Which type of path to use.
23 | """
24 |
25 | LINEAR = enum.auto()
26 | GVP = enum.auto()
27 | VP = enum.auto()
28 |
29 | class WeightType(enum.Enum):
30 | """
31 | Which type of weighting to use.
32 | """
33 |
34 | NONE = enum.auto()
35 | VELOCITY = enum.auto()
36 | LIKELIHOOD = enum.auto()
37 |
38 |
39 | class Transport:
40 |
41 | def __init__(
42 | self,
43 | *,
44 | model_type,
45 | path_type,
46 | loss_type,
47 | train_eps,
48 | sample_eps,
49 | ):
50 | path_options = {
51 | PathType.LINEAR: path.ICPlan,
52 | PathType.GVP: path.GVPCPlan,
53 | PathType.VP: path.VPCPlan,
54 | }
55 |
56 | self.loss_type = loss_type
57 | self.model_type = model_type
58 | self.path_sampler = path_options[path_type]()
59 | self.train_eps = train_eps
60 | self.sample_eps = sample_eps
61 |
62 | def prior_logp(self, z):
63 | '''
64 | Standard multivariate normal prior
65 | Assume z is batched
66 | '''
67 | shape = th.tensor(z.size())
68 | N = th.prod(shape[1:])
69 | _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.
70 | return th.vmap(_fn)(z)
71 |
72 |
73 | def check_interval(
74 | self,
75 | train_eps,
76 | sample_eps,
77 | *,
78 | diffusion_form="SBDM",
79 | sde=False,
80 | reverse=False,
81 | eval=False,
82 | last_step_size=0.0,
83 | ):
84 | t0 = 0
85 | t1 = 1
86 | eps = train_eps if not eval else sample_eps
87 | if (type(self.path_sampler) in [path.VPCPlan]):
88 |
89 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
90 |
91 | elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \
92 | and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step
93 |
94 | t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
95 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
96 |
97 | if reverse:
98 | t0, t1 = 1 - t0, 1 - t1
99 |
100 | return t0, t1
101 |
102 |
103 | def sample(self, x1):
104 | """Sampling x0 & t based on shape of x1 (if needed)
105 | Args:
106 | x1 - data point; [batch, *dim]
107 | """
108 |
109 | x0 = th.randn_like(x1)
110 | t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
111 | t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
112 | t = t.to(x1)
113 | return t, x0, x1
114 |
115 |
116 | def training_losses(
117 | self,
118 | model,
119 | x1,
120 | model_kwargs=None
121 | ):
122 | """Loss for training the score model
123 | Args:
124 | - model: backbone model; could be score, noise, or velocity
125 | - x1: datapoint
126 | - model_kwargs: additional arguments for the model
127 | """
128 | if model_kwargs is None:
129 | model_kwargs = {}
130 |
131 | t, x0, x1 = self.sample(x1)
132 | t, xt, ut = self.path_sampler.plan(t, x0, x1)
133 | model_output = model(xt, t, **model_kwargs)
134 | if len(model_output.shape) == len(xt.shape) + 1:
135 | x0 = x0.unsqueeze(-1).expand(*([-1] * (len(x0.shape))), model_output.shape[-1])
136 | xt = xt.unsqueeze(-1).expand(*([-1] * (len(xt.shape))), model_output.shape[-1])
137 | ut = ut.unsqueeze(-1).expand(*([-1] * (len(ut.shape))), model_output.shape[-1])
138 | B, C = xt.shape[:2]
139 | assert model_output.shape == (B, C, *xt.shape[2:])
140 |
141 | terms = {}
142 | terms['pred'] = model_output
143 | if self.model_type == ModelType.VELOCITY:
144 | terms['loss'] = mean_flat(((model_output - ut) ** 2))
145 | else:
146 | _, drift_var = self.path_sampler.compute_drift(xt, t)
147 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
148 | if self.loss_type in [WeightType.VELOCITY]:
149 | weight = (drift_var / sigma_t) ** 2
150 | elif self.loss_type in [WeightType.LIKELIHOOD]:
151 | weight = drift_var / (sigma_t ** 2)
152 | elif self.loss_type in [WeightType.NONE]:
153 | weight = 1
154 | else:
155 | raise NotImplementedError()
156 |
157 | if self.model_type == ModelType.NOISE:
158 | terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
159 | else:
160 | terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))
161 |
162 | return terms
163 |
164 |
165 | def get_drift(
166 | self
167 | ):
168 | """member function for obtaining the drift of the probability flow ODE"""
169 | def score_ode(x, t, model, **model_kwargs):
170 | drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
171 | model_output = model(x, t, **model_kwargs)
172 | return (-drift_mean + drift_var * model_output) # by change of variable
173 |
174 | def noise_ode(x, t, model, **model_kwargs):
175 | drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
176 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
177 | model_output = model(x, t, **model_kwargs)
178 | score = model_output / -sigma_t
179 | return (-drift_mean + drift_var * score)
180 |
181 | def velocity_ode(x, t, model, **model_kwargs):
182 | model_output = model(x, t, **model_kwargs)
183 | return model_output
184 |
185 | if self.model_type == ModelType.NOISE:
186 | drift_fn = noise_ode
187 | elif self.model_type == ModelType.SCORE:
188 | drift_fn = score_ode
189 | else:
190 | drift_fn = velocity_ode
191 |
192 | def body_fn(x, t, model, **model_kwargs):
193 | model_output = drift_fn(x, t, model, **model_kwargs)
194 | assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
195 | return model_output
196 |
197 | return body_fn
198 |
199 |
200 | def get_score(
201 | self,
202 | ):
203 | """member function for obtaining score of
204 | x_t = alpha_t * x + sigma_t * eps"""
205 | if self.model_type == ModelType.NOISE:
206 | score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
207 | elif self.model_type == ModelType.SCORE:
208 | score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
209 | elif self.model_type == ModelType.VELOCITY:
210 | score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
211 | else:
212 | raise NotImplementedError()
213 |
214 | return score_fn
215 |
216 |
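# Reading aid (summarizing the two helpers above; no additional behavior):
# regardless of whether the network predicts noise, score, or velocity,
# get_drift() and get_score() expose a uniform interface for the Sampler below:
#   noise model:    score(x, t) = model(x, t) / -sigma_t
#   score model:    score(x, t) = model(x, t)
#   velocity model: score(x, t) = path_sampler.get_score_from_velocity(model(x, t), x, t)
# and the probability-flow ODE drift is the velocity itself, or
# -drift_mean + drift_var * score for the noise / score parameterizations.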
217 | class Sampler:
218 | """Sampler class for the transport model"""
219 | def __init__(
220 | self,
221 | transport,
222 | ):
223 | """Constructor for a general sampler; supporting different sampling methods
224 | Args:
225 | - transport: a transport object specifying model prediction & interpolant type
226 | """
227 |
228 | self.transport = transport
229 | self.drift = self.transport.get_drift()
230 | self.score = self.transport.get_score()
231 |
232 | def __get_sde_diffusion_and_drift(
233 | self,
234 | *,
235 | diffusion_form="SBDM",
236 | diffusion_norm=1.0,
237 | ):
238 |
239 | def diffusion_fn(x, t):
240 | diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
241 | return diffusion
242 |
243 | sde_drift = \
244 | lambda x, t, model, **kwargs: \
245 | self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
246 |
247 | sde_diffusion = diffusion_fn
248 |
249 | return sde_drift, sde_diffusion
250 |
251 | def __get_last_step(
252 | self,
253 | sde_drift,
254 | *,
255 | last_step,
256 | last_step_size,
257 | ):
258 | """Get the last step function of the SDE solver"""
259 |
260 | if last_step is None:
261 | last_step_fn = \
262 | lambda x, t, model, **model_kwargs: \
263 | x
264 | elif last_step == "Mean":
265 | last_step_fn = \
266 | lambda x, t, model, **model_kwargs: \
267 | x + sde_drift(x, t, model, **model_kwargs) * last_step_size
268 | elif last_step == "Tweedie":
269 | alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long
270 | sigma = self.transport.path_sampler.compute_sigma_t
271 | last_step_fn = \
272 | lambda x, t, model, **model_kwargs: \
273 | x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs)
274 | elif last_step == "Euler":
275 | last_step_fn = \
276 | lambda x, t, model, **model_kwargs: \
277 | x + self.drift(x, t, model, **model_kwargs) * last_step_size
278 | else:
279 | raise NotImplementedError()
280 |
281 | return last_step_fn
282 |
283 | def sample_sde(
284 | self,
285 | *,
286 | sampling_method="Euler",
287 | diffusion_form="SBDM",
288 | diffusion_norm=1.0,
289 | last_step="Mean",
290 | last_step_size=0.04,
291 | num_steps=250,
292 | temperature=1.0,
293 | ):
294 | """returns a sampling function with given SDE settings
295 | Args:
296 | - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
297 | - diffusion_form: function form of diffusion coefficient; default to be matching SBDM
298 | - diffusion_norm: function magnitude of diffusion coefficient; default to 1
299 | - last_step: type of the last step; default to identity
300 | - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
301 | - num_steps: total integration step of SDE
302 | - temperature: temperature scaling for the noise during sampling; default to 1.0
303 | """
304 |
305 | if last_step is None:
306 | last_step_size = 0.0
307 |
308 | sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
309 | diffusion_form=diffusion_form,
310 | diffusion_norm=diffusion_norm,
311 | )
312 |
313 | t0, t1 = self.transport.check_interval(
314 | self.transport.train_eps,
315 | self.transport.sample_eps,
316 | diffusion_form=diffusion_form,
317 | sde=True,
318 | eval=True,
319 | reverse=False,
320 | last_step_size=last_step_size,
321 | )
322 |
323 | _sde = sde(
324 | sde_drift,
325 | sde_diffusion,
326 | t0=t0,
327 | t1=t1,
328 | num_steps=num_steps,
329 | sampler_type=sampling_method,
330 | temperature=temperature
331 | )
332 |
333 | last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
334 |
335 |
336 | def _sample(init, model, **model_kwargs):
337 | xs = _sde.sample(init, model, **model_kwargs)
338 | ts = th.ones(init.size(0), device=init.device) * t1
339 | x = last_step_fn(xs[-1], ts, model, **model_kwargs)
340 | xs.append(x)
341 |
342 | assert len(xs) == num_steps, "Number of samples does not match the number of steps"
343 |
344 | return xs
345 |
346 | return _sample
347 |
348 | def sample_ode(
349 | self,
350 | *,
351 | sampling_method="dopri5",
352 | num_steps=50,
353 | atol=1e-6,
354 | rtol=1e-3,
355 | reverse=False,
356 | temperature=1.0,
357 | ):
358 | """returns a sampling function with given ODE settings
359 | Args:
360 | - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
361 | - num_steps:
362 | - fixed solver (Euler, Heun): the actual number of integration steps performed
363 | - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
364 | - atol: absolute error tolerance for the solver
365 | - rtol: relative error tolerance for the solver
366 | - reverse: whether solving the ODE in reverse (data to noise); default to False
367 | - temperature: temperature scaling for the drift during sampling; default to 1.0
368 | """
369 | if reverse:
370 | drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
371 | else:
372 | drift = self.drift
373 |
374 | t0, t1 = self.transport.check_interval(
375 | self.transport.train_eps,
376 | self.transport.sample_eps,
377 | sde=False,
378 | eval=True,
379 | reverse=reverse,
380 | last_step_size=0.0,
381 | )
382 |
383 | _ode = ode(
384 | drift=drift,
385 | t0=t0,
386 | t1=t1,
387 | sampler_type=sampling_method,
388 | num_steps=num_steps,
389 | atol=atol,
390 | rtol=rtol,
391 | temperature=temperature,
392 | )
393 |
394 | return _ode.sample
395 |
396 | def sample_ode_likelihood(
397 | self,
398 | *,
399 | sampling_method="dopri5",
400 | num_steps=50,
401 | atol=1e-6,
402 | rtol=1e-3,
403 | temperature=1.0,
404 | ):
405 |
406 | """returns a sampling function for calculating likelihood with given ODE settings
407 | Args:
408 | - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
409 | - num_steps:
410 | - fixed solver (Euler, Heun): the actual number of integration steps performed
411 | - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
412 | - atol: absolute error tolerance for the solver
413 | - rtol: relative error tolerance for the solver
414 | - temperature: temperature scaling for the drift during sampling; default to 1.0
415 | """
416 | def _likelihood_drift(x, t, model, **model_kwargs):
417 | x, _ = x
418 | eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
419 | t = th.ones_like(t) * (1 - t)
420 | with th.enable_grad():
421 | x.requires_grad = True
422 | grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
423 | logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
424 | drift = self.drift(x, t, model, **model_kwargs)
425 | return (-drift, logp_grad)
426 |
427 | t0, t1 = self.transport.check_interval(
428 | self.transport.train_eps,
429 | self.transport.sample_eps,
430 | sde=False,
431 | eval=True,
432 | reverse=False,
433 | last_step_size=0.0,
434 | )
435 |
436 | _ode = ode(
437 | drift=_likelihood_drift,
438 | t0=t0,
439 | t1=t1,
440 | sampler_type=sampling_method,
441 | num_steps=num_steps,
442 | atol=atol,
443 | rtol=rtol,
444 | temperature=temperature,
445 | )
446 |
447 | def _sample_fn(x, model, **model_kwargs):
448 | init_logp = th.zeros(x.size(0)).to(x)
449 | states = (x, init_logp)
450 | drift, delta_logp = _ode.sample(states, model, **model_kwargs)
451 | drift, delta_logp = drift[-1], delta_logp[-1]
452 | prior_logp = self.transport.prior_logp(drift)
453 | logp = prior_logp - delta_logp
454 | return logp, drift
455 |
456 | return _sample_fn
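# ------------------------------------------------------------------------
# Usage sketch (illustrative only; `create_transport`, `model`, `latent_dim`
# and `x_data` are assumptions about the surrounding codebase, following the
# SiT-style interface, and are not defined in this file):
#
#   import torch as th
#   from semanticist.stage1.transport import create_transport, Sampler
#
#   transport = create_transport()               # assumed factory in __init__.py
#   sampler = Sampler(transport)
#   z = th.randn(8, latent_dim)                   # initial noise
#
#   # Probability-flow ODE sampling (adaptive dopri5 by default).
#   ode_fn = sampler.sample_ode(sampling_method="dopri5", num_steps=50)
#   samples = ode_fn(z, model)[-1]                # last entry is the final state
#
#   # SDE sampling with Euler-Maruyama steps and a "Mean" last step.
#   sde_fn = sampler.sample_sde(sampling_method="Euler", num_steps=250)
#   samples = sde_fn(z, model)[-1]
#
#   # Log-likelihood of data via the Hutchinson trace estimator.
#   logp_fn = sampler.sample_ode_likelihood()
#   logp, z1 = logp_fn(x_data, model)             # z1: terminal (noise-like) state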
--------------------------------------------------------------------------------
/semanticist/stage1/transport/utils.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 | class EasyDict:
4 |
5 | def __init__(self, sub_dict):
6 | for k, v in sub_dict.items():
7 | setattr(self, k, v)
8 |
9 | def __getitem__(self, key):
10 | return getattr(self, key)
11 |
12 | def mean_flat(x):
13 | """
14 | Take the mean over all non-batch dimensions.
15 | """
16 | return th.mean(x, dim=list(range(1, len(x.size()))))
17 |
18 | def log_state(state):
19 | result = []
20 |
21 | sorted_state = dict(sorted(state.items()))
22 | for key, value in sorted_state.items():
23 | # Check if the value is an instance of a class
24 | if "