├── vq
│   ├── py.typed
│   ├── tasks
│   │   ├── utils
│   │   │   └── __init__.py
│   │   ├── __init__.py
│   │   ├── sequence_modeling
│   │   │   ├── runners
│   │   │   │   ├── __init__.py
│   │   │   │   ├── registries.py
│   │   │   │   ├── metrics.py
│   │   │   │   └── base.py
│   │   │   ├── __init__.py
│   │   │   ├── models
│   │   │   │   ├── __init__.py
│   │   │   │   ├── registries.py
│   │   │   │   ├── c2i.py
│   │   │   │   └── transformers.py
│   │   │   └── registries.py
│   │   ├── image_tokenization
│   │   │   ├── __init__.py
│   │   │   ├── models
│   │   │   │   ├── quantizers
│   │   │   │   │   ├── utils
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── ste.py
│   │   │   │   │   │   └── quantizer_holder.py
│   │   │   │   │   ├── callbacks
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── lazy_init_weights.py
│   │   │   │   │   │   └── base.py
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── registries.py
│   │   │   │   │   └── losses.py
│   │   │   │   ├── connectors
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── base.py
│   │   │   │   │   ├── conv.py
│   │   │   │   │   └── composed.py
│   │   │   │   ├── __init__.py
│   │   │   │   └── registries.py
│   │   │   ├── runners
│   │   │   │   ├── __init__.py
│   │   │   │   ├── registries.py
│   │   │   │   ├── callbacks.py
│   │   │   │   ├── tokenizer.py
│   │   │   │   └── metrics.py
│   │   │   ├── registries.py
│   │   │   ├── tokenize.py
│   │   │   └── demo.py
│   │   ├── image_reconstruction
│   │   │   ├── __init__.py
│   │   │   └── registries.py
│   │   ├── image_classification
│   │   │   ├── __init__.py
│   │   │   ├── registries.py
│   │   │   └── optimizers.py
│   │   └── registries.py
│   ├── algorithms
│   │   ├── exp
│   │   │   ├── __init__.py
│   │   │   └── vqgan_vqkd
│   │   │       └── __init__.py
│   │   ├── nar
│   │   │   ├── __init__.py
│   │   │   └── transformers
│   │   │       └── __init__.py
│   │   ├── __init__.py
│   │   ├── sq
│   │   │   ├── __init__.py
│   │   │   └── quantizers.py
│   │   ├── fsq
│   │   │   └── __init__.py
│   │   ├── cluster
│   │   │   ├── __init__.py
│   │   │   ├── autoencoders.py
│   │   │   └── base.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   └── losses.py
│   │   ├── vq
│   │   │   ├── callbacks
│   │   │   │   ├── __init__.py
│   │   │   │   ├── normalize.py
│   │   │   │   └── update.py
│   │   │   ├── __init__.py
│   │   │   ├── distances.py
│   │   │   └── utils.py
│   │   ├── vqkd
│   │   │   ├── quantizers
│   │   │   │   ├── __init__.py
│   │   │   │   └── base.py
│   │   │   ├── __init__.py
│   │   │   ├── registries.py
│   │   │   ├── teachers
│   │   │   │   ├── __init__.py
│   │   │   │   ├── convnext.py
│   │   │   │   ├── torchvision.py
│   │   │   │   ├── dino.py
│   │   │   │   ├── clip.py
│   │   │   │   └── base.py
│   │   │   └── connector.py
│   │   ├── cvqvae
│   │   │   ├── __init__.py
│   │   │   └── registries.py
│   │   ├── vqgan
│   │   │   ├── discriminators
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base.py
│   │   │   │   └── patchgan.py
│   │   │   ├── losses
│   │   │   │   ├── __init__.py
│   │   │   │   ├── registries.py
│   │   │   │   ├── generator.py
│   │   │   │   └── discriminator.py
│   │   │   ├── __init__.py
│   │   │   ├── registries.py
│   │   │   └── quantizer.py
│   │   └── ar
│   │       ├── transformers
│   │       │   ├── __init__.py
│   │       │   ├── gpt.py
│   │       │   ├── llama.py
│   │       │   ├── base.py
│   │       │   └── hf.py
│   │       ├── __init__.py
│   │       ├── image.py
│   │       ├── c2i.py
│   │       └── base.py
│   ├── runners
│   │   ├── callbacks
│   │   │   └── __init__.py
│   │   ├── metrics
│   │   │   ├── __init__.py
│   │   │   ├── loss.py
│   │   │   └── fid.py
│   │   ├── __init__.py
│   │   └── registries.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── registries.py
│   │   └── autoencoders.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── stores.py
│   │   ├── builders.py
│   │   ├── fid.py
│   │   └── misc.py
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── imagenet.py
│   │   ├── coco.py
│   │   ├── vanilla.py
│   │   ├── sa_med2d.py
│   │   ├── laion_aesthetics.py
│   │   ├── satin.py
│   │   ├── split.py
│   │   ├── base.py
│   │   └── concat.py
│   ├── registries.py
│   ├── fid.py
│   ├── test.py
│   └── train.py
├── configs
│   ├── sq
│   │   └── interface.py
│   ├── llamagen
│   │   ├── README.md
│   │   ├── vqgan_imagenet_ddp.py
│   │   ├── vqgan_imagenet_384_ddp.py
│   │   ├── vqgan_128_imagenet_384_ddp.py
│   │   ├── vqgan_stylegan2_imagenet_ddp.py
│   │   ├── c2i_medium_imagenet_ddp.py
│   │   ├── vqgan_256_f8_imagenet_ddp.py
│   │   ├── vqgan_128x16_f8_imagenet_ddp.py
│   │   ├── c2i_medium_vqgan_imagenet_ddp.py
│   │   ├── ar.py
│   │   └── vqgan.py
│   ├── ar
│   │   ├── custom_imports.py
│   │   ├── c2i_gpt2_medium_cfg_imagenet_ddp.py
│   │   ├── c2i_llama_medium_cfg_imagenet_ddp.py
│   │   ├── c2i.py
│   │   ├── model.py
│   │   ├── transformers
│   │   │   ├── gpt2.py
│   │   │   ├── interface.py
│   │   │   └── llama.py
│   │   ├── cfg.py
│   │   ├── c2i_gpt2_medium_imagenet_ddp.py
│   │   ├── c2i_llama_medium_imagenet_ddp.py
│   │   ├── interface.py
│   │   ├── x2i.py
│   │   └── README.md
│   ├── ic
│   │   ├── README.md
│   │   ├── custom_imports.py
│   │   ├── imagenet_ddp.py
│   │   ├── model.py
│   │   ├── interface.py
│   │   └── runner.py
│   ├── vqgan
│   │   ├── custom_imports.py
│   │   ├── 8192_dd2_aglwg075_imagenet_ddp.py
│   │   ├── 16384_dd2_aglwg075_imagenet_ddp.py
│   │   ├── dd2_aglwg075.py
│   │   ├── f8.py
│   │   ├── 8192_satin_ddp.py
│   │   ├── 1024_imagenet_ddp.py
│   │   ├── 16384_imagenet_ddp.py
│   │   ├── 8192_imagenet_ddp.py
│   │   ├── 8192_sa_med2d_20m_ddp.py
│   │   ├── 8192_laion_aesthetics_ddp.py
│   │   ├── 8192_stylegan2_imagenet_ddp.py
│   │   ├── interface.py
│   │   ├── README.md
│   │   └── model.py
│   ├── vqkd
│   │   ├── custom_imports.py
│   │   ├── teachers
│   │   │   ├── openclip.py
│   │   │   ├── clip.py
│   │   │   ├── dino.py
│   │   │   ├── vit.py
│   │   │   ├── evaclip.py
│   │   │   ├── interface.py
│   │   │   ├── openclip_L_14.py
│   │   │   ├── openclip_H_14.py
│   │   │   ├── openclip_bigG_14.py
│   │   │   ├── mae.py
│   │   │   └── convnext.py
│   │   ├── clip_8192_satin_ddp.py
│   │   ├── mae_8192_imagenet_ddp.py
│   │   ├── vit_8192_imagenet_ddp.py
│   │   ├── clip_8192_imagenet_ddp.py
│   │   ├── dino_8192_imagenet_ddp.py
│   │   ├── vit_16384_imagenet_ddp.py
│   │   ├── clip_8192_sa_med2d_20m_ddp.py
│   │   ├── convnext_8192_imagenet_ddp.py
│   │   ├── evaclip_8192_imagenet_ddp.py
│   │   ├── clip_8192_laion_aesthetics_ddp.py
│   │   ├── README.md
│   │   ├── openclip_L_14_8192_imagenet_ddp.py
│   │   ├── openclip_bigG_14_8192_imagenet_ddp.py
│   │   ├── interface.py
│   │   └── model.py
│   ├── strategies
│   │   ├── find_unused_parameters.py
│   │   ├── base.py
│   │   ├── ddp.py
│   │   ├── fsdp.py
│   │   ├── cuda.py
│   │   └── interface.py
│   ├── fsq
│   │   ├── custom_imports.py
│   │   ├── 64000_imagenet_ddp.py
│   │   ├── 8000_imagenet_ddp.py
│   │   ├── README.md
│   │   ├── interface.py
│   │   └── model.py
│   ├── cvqvae
│   │   ├── custom_imports.py
│   │   ├── quantizer.py
│   │   ├── README.md
│   │   └── 8192_dd2_aglwg075_imagenet_ddp.py
│   ├── decoder
│   │   ├── vqkd.py
│   │   ├── vqkd_large.py
│   │   ├── llamagen.py
│   │   ├── vqgan.py
│   │   ├── interface.py
│   │   └── README.md
│   ├── cluster
│   │   ├── custom_imports.py
│   │   ├── README.md
│   │   ├── clip_8192_imagenet_ddp.py
│   │   ├── encoders
│   │   │   ├── interface.py
│   │   │   ├── dino.py
│   │   │   ├── clip.py
│   │   │   ├── vit.py
│   │   │   └── mae.py
│   │   ├── interface.py
│   │   ├── model.py
│   │   └── runner.py
│   ├── datasets
│   │   ├── batch_size_in_total.py
│   │   ├── transforms
│   │   │   ├── none.py
│   │   │   ├── default.py
│   │   │   ├── weak.py
│   │   │   ├── strong.py
│   │   │   └── interface.py
│   │   ├── interface.py
│   │   ├── imagenet.py
│   │   ├── vanilla.py
│   │   ├── ffhq.py
│   │   ├── celeba_hq.py
│   │   ├── coco_2014.py
│   │   ├── sa_med2d_20m.py
│   │   ├── batch_size.py
│   │   ├── laion_aesthetics.py
│   │   ├── hq_faces.py
│   │   └── satin.py
│   ├── vq
│   │   ├── distance.py
│   │   ├── embedding_dim.py
│   │   ├── interface.py
│   │   └── num_embeddings.py
│   ├── exps
│   │   └── llamagen_vqgan_imagenet_ddp-no_refine_layer.py
│   └── fid
│       └── interface.py
├── docs
│   ├── model_card.md
│   ├── assets
│   │   └── fig1.jpg
│   ├── installation.md
│   ├── data.md
│   ├── inference.md
│   ├── validation.md
│   ├── training.md
│   └── pretrained_models.md
├── .todd_version
├── .flake8
├── tools
│   ├── debugpy.sh
│   ├── model_ema.py
│   └── fid.py
├── setup.sh
├── .vscode
│   ├── tasks.json
│   ├── settings.json
│   └── launch.json
├── setup.py
├── makefile
└── README.md

/vq/py.typed:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/configs/sq/interface.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/vq/tasks/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/vq/algorithms/exp/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/vq/algorithms/nar/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/vq/algorithms/nar/transformers/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/docs/model_card.md:
--------------------------------------------------------------------------------
# Model Card

TBD.
--------------------------------------------------------------------------------
/vq/algorithms/__init__.py:
--------------------------------------------------------------------------------
from .utils import *
--------------------------------------------------------------------------------
/.todd_version:
--------------------------------------------------------------------------------
ed2a3ae75a6698e6dafda5e5dba67ac88e1daa22
--------------------------------------------------------------------------------
/vq/algorithms/sq/__init__.py:
--------------------------------------------------------------------------------
from .quantizers import *
--------------------------------------------------------------------------------
/vq/runners/callbacks/__init__.py:
--------------------------------------------------------------------------------
from .visual import *
--------------------------------------------------------------------------------
/vq/algorithms/fsq/__init__.py:
--------------------------------------------------------------------------------
from .quantizers import *
--------------------------------------------------------------------------------
/configs/llamagen/README.md:
--------------------------------------------------------------------------------
# LlamaGen

TBD.
--------------------------------------------------------------------------------
/vq/runners/metrics/__init__.py:
--------------------------------------------------------------------------------
from .fid import *
from .loss import *
--------------------------------------------------------------------------------
/vq/tasks/__init__.py:
--------------------------------------------------------------------------------
from . import utils
from .registries import *
--------------------------------------------------------------------------------
/vq/models/__init__.py:
--------------------------------------------------------------------------------
from .autoencoders import *
from .registries import *
--------------------------------------------------------------------------------
/configs/ar/custom_imports.py:
--------------------------------------------------------------------------------
custom_imports = [
    'vq.algorithms.ar',
]
--------------------------------------------------------------------------------
/configs/ic/README.md:
--------------------------------------------------------------------------------
# Image Classification

TBD.
--------------------------------------------------------------------------------
/configs/vqgan/custom_imports.py:
--------------------------------------------------------------------------------
custom_imports = [
    'vq.algorithms.vqgan',
]
--------------------------------------------------------------------------------
/configs/vqkd/custom_imports.py:
--------------------------------------------------------------------------------
custom_imports = [
    'vq.algorithms.vqkd',
]
--------------------------------------------------------------------------------
/configs/vqkd/teachers/openclip.py:
--------------------------------------------------------------------------------
distiller = dict(teacher=dict(type='OpenCLIPTeacher'))
--------------------------------------------------------------------------------
/vq/algorithms/cluster/__init__.py:
--------------------------------------------------------------------------------
from .autoencoders import *
from .base import *
--------------------------------------------------------------------------------
/vq/algorithms/utils/__init__.py:
--------------------------------------------------------------------------------
from .autoencoders import *
from .losses import *
--------------------------------------------------------------------------------
/vq/algorithms/vq/callbacks/__init__.py:
--------------------------------------------------------------------------------
from .normalize import *
from .update import *
--------------------------------------------------------------------------------
/configs/ic/custom_imports.py:
--------------------------------------------------------------------------------
custom_imports = [
    'vq.tasks.image_classification',
]
--------------------------------------------------------------------------------
/vq/algorithms/vqkd/quantizers/__init__.py:
--------------------------------------------------------------------------------
from .base import *
from .callbacks import *
--------------------------------------------------------------------------------
/vq/tasks/sequence_modeling/runners/__init__.py:
--------------------------------------------------------------------------------
from .base import *
from .metrics import *
--------------------------------------------------------------------------------
/vq/algorithms/exp/vqgan_vqkd/__init__.py:
--------------------------------------------------------------------------------
from .base import *
from .quantizer_callback import *
--------------------------------------------------------------------------------
/vq/tasks/image_tokenization/__init__.py:
--------------------------------------------------------------------------------
from . import models, runners
from .registries import *
--------------------------------------------------------------------------------
/vq/tasks/sequence_modeling/__init__.py:
--------------------------------------------------------------------------------
from . import models, runners
from .registries import *
--------------------------------------------------------------------------------
/configs/strategies/find_unused_parameters.py:
--------------------------------------------------------------------------------
trainer = dict(wrap_model=dict(find_unused_parameters=True))
--------------------------------------------------------------------------------
/docs/assets/fig1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/magic-research/vector_quantization/HEAD/docs/assets/fig1.jpg
--------------------------------------------------------------------------------
/vq/runners/__init__.py:
--------------------------------------------------------------------------------
from . import callbacks, metrics
from .base import *
from .registries import *
--------------------------------------------------------------------------------
/configs/fsq/custom_imports.py:
--------------------------------------------------------------------------------
custom_imports = [
    'vq.algorithms.vqgan',
    'vq.algorithms.fsq',
]
--------------------------------------------------------------------------------
/vq/tasks/image_reconstruction/__init__.py:
--------------------------------------------------------------------------------
from .losses import *
from .models import *
from .registries import *
--------------------------------------------------------------------------------
/vq/tasks/image_tokenization/models/quantizers/utils/__init__.py:
--------------------------------------------------------------------------------
from .quantizer_holder import *
from .ste import *
--------------------------------------------------------------------------------
/vq/utils/__init__.py:
--------------------------------------------------------------------------------
from .builders import *
from .fid import *
from .misc import *
from .stores import *
--------------------------------------------------------------------------------
/configs/ar/c2i_gpt2_medium_cfg_imagenet_ddp.py:
--------------------------------------------------------------------------------
_base_ = [
    'c2i_gpt2_medium_imagenet_ddp.py',
    'cfg.py',
]
--------------------------------------------------------------------------------
/configs/cvqvae/custom_imports.py:
--------------------------------------------------------------------------------
custom_imports = [
    'vq.algorithms.vqgan',
    'vq.algorithms.cvqvae',
]
--------------------------------------------------------------------------------
/configs/decoder/vqkd.py:
--------------------------------------------------------------------------------
_export_ = dict(
    type='VQKDDecoder',
    out_chans=3,
    out_patch_size=16,
)
--------------------------------------------------------------------------------
/vq/algorithms/cvqvae/__init__.py:
--------------------------------------------------------------------------------
from .anchors import *
from .quantizer_callback import *
from .registries import *
--------------------------------------------------------------------------------
/vq/algorithms/vqgan/discriminators/__init__.py:
--------------------------------------------------------------------------------
from .base import *
from .patchgan import *
from .stylegan2 import *
--------------------------------------------------------------------------------
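One convention worth noting in the configs above: files such as /configs/ar/c2i_gpt2_medium_cfg_imagenet_ddp.py are pure overlays that only declare a `_base_` list. A minimal sketch of the call side, assuming only the PyConfig behavior these configs already rely on (each `_base_` entry is loaded relative to the config file and merged in order, with later fragments taking precedence):

```python
from todd.configs import PyConfig

# Resolves the overlay's _base_ list: 'cfg.py' settings are layered on
# top of 'c2i_gpt2_medium_imagenet_ddp.py' to produce one merged config.
config = PyConfig.load('configs/ar/c2i_gpt2_medium_cfg_imagenet_ddp.py')
```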
/vq/tasks/image_classification/__init__.py:
--------------------------------------------------------------------------------
from .models import *
from .optimizers import *
from .registries import *
--------------------------------------------------------------------------------
/configs/ar/c2i_llama_medium_cfg_imagenet_ddp.py:
--------------------------------------------------------------------------------
_base_ = [
    'c2i_llama_medium_imagenet_ddp.py',
    'cfg.py',
]
--------------------------------------------------------------------------------
/configs/vqkd/teachers/clip.py:
--------------------------------------------------------------------------------
distiller = dict(teacher=dict(
    type='CLIPTeacher',
    downsample_factor=16,
))
--------------------------------------------------------------------------------
/vq/algorithms/vqgan/losses/__init__.py:
--------------------------------------------------------------------------------
from .discriminator import *
from .generator import *
from .registries import *
--------------------------------------------------------------------------------
/vq/tasks/image_tokenization/models/connectors/__init__.py:
--------------------------------------------------------------------------------
from .base import *
from .composed import *
from .conv import *
--------------------------------------------------------------------------------
/configs/decoder/vqkd_large.py:
--------------------------------------------------------------------------------
_base_ = [
    'vqkd.py',
]

_export_ = dict(embed_dim=1280, depth=32, num_heads=16)
--------------------------------------------------------------------------------
/configs/strategies/base.py:
--------------------------------------------------------------------------------
trainer = dict(strategy=dict(type='BaseStrategy'))
validator = dict(strategy=dict(type='BaseStrategy'))
--------------------------------------------------------------------------------
/vq/algorithms/ar/transformers/__init__.py:
--------------------------------------------------------------------------------
from .base import *
from .gpt import *
from .hf import *
from .llama import *
--------------------------------------------------------------------------------
/vq/tasks/image_tokenization/models/__init__.py:
--------------------------------------------------------------------------------
from . import connectors, quantizers
from .base import *
from .registries import *
--------------------------------------------------------------------------------
/vq/__init__.py:
--------------------------------------------------------------------------------
__version__ = '0.0.1'

from . import algorithms, datasets, models, runners, tasks, utils
from .registries import *
--------------------------------------------------------------------------------
/configs/ar/c2i.py:
--------------------------------------------------------------------------------
_base_ = [
    'x2i.py',
]

trainer = dict(model=dict(type='ARC2I'))
validator = dict(model=dict(type='ARC2I'))
--------------------------------------------------------------------------------
/configs/cluster/custom_imports.py:
--------------------------------------------------------------------------------
custom_imports = [
    'vq.algorithms.vqgan',
    'vq.algorithms.cvqvae',
    'vq.algorithms.cluster',
]
--------------------------------------------------------------------------------
/vq/algorithms/ar/__init__.py:
--------------------------------------------------------------------------------
from . import transformers
from .base import *
from .c2i import *
from .image import *
from .x2i import *
--------------------------------------------------------------------------------
/vq/tasks/image_tokenization/models/quantizers/callbacks/__init__.py:
--------------------------------------------------------------------------------
from .base import *
from .composed import *
from .lazy_init_weights import *
--------------------------------------------------------------------------------
/configs/vqkd/teachers/dino.py:
--------------------------------------------------------------------------------
decoder = dict(out_chans=768)
distiller = dict(teacher=dict(
    type='DINOTeacher',
    downsample_factor=16,
))
--------------------------------------------------------------------------------
/vq/tasks/image_tokenization/runners/__init__.py:
--------------------------------------------------------------------------------
from .callbacks import *
from .metrics import *
from .registries import *
from .tokenizer import *
--------------------------------------------------------------------------------
/configs/datasets/batch_size_in_total.py:
--------------------------------------------------------------------------------
trainer = dict(dataloader=dict(batch_size_in_total=True))
validator = dict(dataloader=dict(batch_size_in_total=True))
--------------------------------------------------------------------------------
/configs/vqgan/8192_dd2_aglwg075_imagenet_ddp.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

_base_ = [
    '8192_imagenet_ddp.py',
    'dd2_aglwg075.py',
]
--------------------------------------------------------------------------------
/vq/algorithms/vq/__init__.py:
--------------------------------------------------------------------------------
from . import callbacks
from .distances import *
from .losses import *
from .quantizers import *
from .utils import *
--------------------------------------------------------------------------------
/configs/vqgan/16384_dd2_aglwg075_imagenet_ddp.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

_base_ = [
    '16384_imagenet_ddp.py',
    'dd2_aglwg075.py',
]
--------------------------------------------------------------------------------
/vq/tasks/image_tokenization/models/quantizers/__init__.py:
--------------------------------------------------------------------------------
from . import callbacks, utils
from .base import *
from .losses import *
from .registries import *
--------------------------------------------------------------------------------
/vq/algorithms/cvqvae/registries.py:
--------------------------------------------------------------------------------
__all__ = [
    'AnchorRegistry',
]

import todd


class AnchorRegistry(todd.Registry):
    pass
--------------------------------------------------------------------------------
/configs/strategies/ddp.py:
--------------------------------------------------------------------------------
_base_ = [
    'cuda.py',
]

trainer = dict(strategy=dict(type='DDPStrategy'))
validator = dict(strategy=dict(type='DDPStrategy'))
--------------------------------------------------------------------------------
/vq/algorithms/vqkd/__init__.py:
--------------------------------------------------------------------------------
from . import quantizers, teachers
from .autoencoder import *
from .base import *
from .connector import *
from .registries import *
--------------------------------------------------------------------------------
/configs/cluster/README.md:
--------------------------------------------------------------------------------
# Cluster

Training:

```bash
auto_torchrun -m vq.train cluster/clip_8192_imagenet_ddp configs/cluster/clip_8192_imagenet_ddp.py
```
--------------------------------------------------------------------------------
/vq/algorithms/vqkd/registries.py:
--------------------------------------------------------------------------------
__all__ = [
    'VQTeacherRegistry',
]

from vq import VQModelRegistry


class VQTeacherRegistry(VQModelRegistry):
    pass
--------------------------------------------------------------------------------
/vq/utils/stores.py:
--------------------------------------------------------------------------------
__all__ = [
    'Store',
]

from todd.utils import StoreMeta


class Store(metaclass=StoreMeta):
    DEBUG: bool
    PRETRAINED: str
--------------------------------------------------------------------------------
/configs/strategies/fsdp.py:
--------------------------------------------------------------------------------
_base_ = [
    'ddp.py',
]

trainer = dict(
    strategy=dict(type='FSDPStrategy'),
    wrap_model=dict(sync_module_states=True, use_orig_params=True),
)
--------------------------------------------------------------------------------
/vq/algorithms/vqgan/__init__.py:
--------------------------------------------------------------------------------
from . import discriminators, losses
from .autoencoder import *
from .model import *
from .quantizer import *
from .registries import *
from .trainer import *
--------------------------------------------------------------------------------
/vq/algorithms/vqgan/registries.py:
--------------------------------------------------------------------------------
__all__ = [
    'VQDiscriminatorRegistry',
]

from vq import VQModelRegistry


class VQDiscriminatorRegistry(VQModelRegistry):
    pass
--------------------------------------------------------------------------------
/vq/algorithms/vqkd/teachers/__init__.py:
--------------------------------------------------------------------------------
from .base import *
from .clip import *
from .convnext import *
from .dino import *
from .mae import *
from .torchvision import *
from .vit import *
--------------------------------------------------------------------------------
/configs/cvqvae/quantizer.py:
--------------------------------------------------------------------------------
callback = dict(
    type='CVQVAECallback',
    ema=dict(),
    anchor=dict(type='NearestAnchor'),
)
_export_ = dict(trainer=dict(model=dict(quantizer=dict(callbacks=[callback]))))
--------------------------------------------------------------------------------
/configs/vqgan/dd2_aglwg075.py:
--------------------------------------------------------------------------------
model = dict(
    discriminator=dict(depth=2),
    adaptive_generator_loss_weight_gain=0.75,
)
runner = dict(model=model)

_export_ = dict(trainer=runner, validator=runner)
--------------------------------------------------------------------------------
/configs/vqgan/f8.py:
--------------------------------------------------------------------------------
model = dict(
    encoder=dict(width_mults=(1, 1, 2, 4)),
    decoder=dict(width_mults=(4, 2, 1, 1)),
)
runner = dict(model=model)

_export_ = dict(trainer=runner, validator=runner)
--------------------------------------------------------------------------------
/vq/tasks/sequence_modeling/models/__init__.py:
--------------------------------------------------------------------------------
from .base import *
from .c2i import *
from .image import *
from .registries import *
from .samplers import *
from .transformers import *
from .x2i import *
--------------------------------------------------------------------------------
/configs/llamagen/vqgan_imagenet_ddp.py:
--------------------------------------------------------------------------------
from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/llamagen/vqgan.py',
    dataset='imagenet',
    strategy='ddp',
    find_unused_parameters=True,
)
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
[flake8]
ignore = D100,D101,D102,D103,D104,D105,D107,E126,E201,E241,N805,N812,N817,W503,W504
per-file-ignores =
    vq/__init__.py:F401,F403
    vq/**/__init__.py:F401,F403
max-complexity = 12
builtins = _kwargs_
--------------------------------------------------------------------------------
/configs/ar/model.py:
--------------------------------------------------------------------------------
from typing import Any

_kwargs_: dict[str, Any]
_kwargs_ = dict(_kwargs_)
model = _kwargs_['model']

_base_ = [
    f'{model}.py',
]

_export_: dict[str, Any] = dict()
--------------------------------------------------------------------------------
/configs/vqkd/clip_8192_satin_ddp.py:
--------------------------------------------------------------------------------
from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqkd/interface.py',
    teacher='clip',
    num_embeddings=8192,
    dataset='satin',
    strategy='ddp',
)
--------------------------------------------------------------------------------
/configs/vqkd/mae_8192_imagenet_ddp.py:
--------------------------------------------------------------------------------
from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqkd/interface.py',
    teacher='mae',
    num_embeddings=8192,
    dataset='imagenet',
    strategy='ddp',
)
--------------------------------------------------------------------------------
/configs/vqkd/vit_8192_imagenet_ddp.py:
--------------------------------------------------------------------------------
from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqkd/interface.py',
    teacher='vit',
    num_embeddings=8192,
    dataset='imagenet',
    strategy='ddp',
)
--------------------------------------------------------------------------------
/vq/tasks/image_tokenization/models/quantizers/utils/ste.py:
--------------------------------------------------------------------------------
__all__ = [
    'ste',
]

import torch
import torch.distributed


def ste(z: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    return x + (z - x).detach()
--------------------------------------------------------------------------------
/configs/vqkd/clip_8192_imagenet_ddp.py:
--------------------------------------------------------------------------------
from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqkd/interface.py',
    teacher='clip',
    num_embeddings=8192,
    dataset='imagenet',
    strategy='ddp',
)
--------------------------------------------------------------------------------
/configs/vqkd/dino_8192_imagenet_ddp.py:
--------------------------------------------------------------------------------
from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqkd/interface.py',
    teacher='dino',
    num_embeddings=8192,
    dataset='imagenet',
    strategy='ddp',
)
--------------------------------------------------------------------------------
/configs/vqkd/teachers/vit.py:
--------------------------------------------------------------------------------
decoder = dict(out_chans=768)
distiller = dict(
    teacher=dict(
        type='ViTTeacher',
        model=dict(type='vit_b_16', weights='.ViT_B_16_Weights.IMAGENET1K_V1'),
        downsample_factor=16,
    ),
)
--------------------------------------------------------------------------------
/configs/vqkd/vit_16384_imagenet_ddp.py:
--------------------------------------------------------------------------------
from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqkd/interface.py',
    teacher='vit',
    num_embeddings=16384,
    dataset='imagenet',
    strategy='ddp',
)
--------------------------------------------------------------------------------
/configs/cluster/clip_8192_imagenet_ddp.py:
--------------------------------------------------------------------------------
from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/cluster/interface.py',
    encoder='clip',
    num_embeddings=8192,
    dataset='imagenet',
    strategy='ddp',
)
--------------------------------------------------------------------------------
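The straight-through estimator in /vq/tasks/image_tokenization/models/quantizers/utils/ste.py above is small enough to sanity-check in isolation: the forward value equals `z`, while the gradient skips the non-differentiable quantization and flows to `x` as an identity. A quick check, using the `ste` definition verbatim:

```python
import torch


def ste(z: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    return x + (z - x).detach()


x = torch.zeros(3, requires_grad=True)
z = torch.ones(3)  # stands in for the quantized codes
y = ste(z, x)
assert torch.equal(y, z)  # forward: exactly the quantized value
y.sum().backward()
assert torch.equal(x.grad, torch.ones(3))  # backward: identity gradient to x
```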
/configs/vqkd/clip_8192_sa_med2d_20m_ddp.py:
--------------------------------------------------------------------------------
from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqkd/interface.py',
    teacher='clip',
    num_embeddings=8192,
    dataset='sa_med2d_20m',
    strategy='ddp',
)
--------------------------------------------------------------------------------
/configs/vqkd/convnext_8192_imagenet_ddp.py:
--------------------------------------------------------------------------------
from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqkd/interface.py',
    teacher='convnext',
    num_embeddings=8192,
    dataset='imagenet',
    strategy='ddp',
)
--------------------------------------------------------------------------------
/configs/vqkd/evaclip_8192_imagenet_ddp.py:
--------------------------------------------------------------------------------
from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqkd/interface.py',
    teacher='evaclip',
    num_embeddings=8192,
    dataset='imagenet',
    strategy='ddp',
)
--------------------------------------------------------------------------------
/vq/datasets/__init__.py:
--------------------------------------------------------------------------------
from .base import *
from .coco import *
from .concat import *
from .imagenet import *
from .laion_aesthetics import *
from .sa_med2d import *
from .satin import *
from .split import *
from .vanilla import *
--------------------------------------------------------------------------------
/vq/tasks/image_classification/registries.py:
--------------------------------------------------------------------------------
__all__ = [
    'VQICModelRegistry',
]

from vq import VQModelRegistry

from ..registries import VQICRegistry


class VQICModelRegistry(VQICRegistry, VQModelRegistry):
    pass
--------------------------------------------------------------------------------
/configs/llamagen/vqgan_imagenet_384_ddp.py:
--------------------------------------------------------------------------------
from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/llamagen/vqgan.py',
    dataset='imagenet',
    strategy='ddp',
    find_unused_parameters=True,
    image_size=384,
)
--------------------------------------------------------------------------------
/configs/vqkd/clip_8192_laion_aesthetics_ddp.py:
--------------------------------------------------------------------------------
from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqkd/interface.py',
    teacher='clip',
    num_embeddings=8192,
    dataset='laion_aesthetics',
    strategy='ddp',
)
--------------------------------------------------------------------------------
/configs/cluster/encoders/interface.py:
--------------------------------------------------------------------------------
from typing import Any

from todd.configs import PyConfig

_kwargs_: dict[str, Any]
_kwargs_ = dict(_kwargs_)
encoder = _kwargs_['encoder']

_export_ = PyConfig.load(f'configs/cluster/encoders/{encoder}.py')
--------------------------------------------------------------------------------
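/configs/cluster/encoders/interface.py above illustrates the interface pattern used throughout these configs: keyword arguments passed to `PyConfig.load` surface inside the loaded file as the `_kwargs_` builtin (declared to flake8 via `builtins = _kwargs_`), are defensively copied, and select a concrete fragment. The call side, mirroring /configs/cluster/clip_8192_imagenet_ddp.py:

```python
from todd.configs import PyConfig

# encoder='clip' arrives in the interface files as _kwargs_['encoder']
# and dispatches to configs/cluster/encoders/clip.py; the other keys are
# consumed by the dataset and strategy fragments.
config = PyConfig.load(
    'configs/cluster/interface.py',
    encoder='clip',
    num_embeddings=8192,
    dataset='imagenet',
    strategy='ddp',
)
```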
/configs/vqkd/README.md:
--------------------------------------------------------------------------------
# VQ-KD

[[arXiv](https://arxiv.org/abs/2208.06366)] [[GitHub](https://github.com/microsoft/unilm/tree/master/beit2)]

Training:

```bash
auto_torchrun -m vq.train vqkd/clip_8192_imagenet_ddp configs/vqkd/clip_8192_imagenet_ddp.py
```
--------------------------------------------------------------------------------
/vq/tasks/sequence_modeling/runners/registries.py:
--------------------------------------------------------------------------------
__all__ = [
    'VQSMMetricRegistry',
]

from vq.runners import VQMetricRegistry

from ..registries import VQSMRunnerRegistry


class VQSMMetricRegistry(VQSMRunnerRegistry, VQMetricRegistry):
    pass
--------------------------------------------------------------------------------
/configs/cvqvae/README.md:
--------------------------------------------------------------------------------
# CVQ-VAE

[[arXiv](https://arxiv.org/abs/2307.15139)] [[GitHub](https://github.com/lyndonzheng/CVQ-VAE)]

Training:

```bash
auto_torchrun -m vq.train cvqvae/8192_dd2_aglwg075_imagenet_ddp configs/cvqvae/8192_dd2_aglwg075_imagenet_ddp.py
```
--------------------------------------------------------------------------------
/configs/fsq/64000_imagenet_ddp.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/fsq/interface.py',
    codebook_size=64000,
    dataset='imagenet',
    strategy='ddp',
    find_unused_parameters=True,
)
--------------------------------------------------------------------------------
/configs/fsq/8000_imagenet_ddp.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/fsq/interface.py',
    codebook_size=8000,
    dataset='imagenet',
    strategy='ddp',
    find_unused_parameters=True,
)
--------------------------------------------------------------------------------
/configs/llamagen/vqgan_128_imagenet_384_ddp.py:
--------------------------------------------------------------------------------
from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/llamagen/vqgan.py',
    dataset='imagenet',
    strategy='ddp',
    find_unused_parameters=True,
    image_size=384,
    num_embeddings=128,
)
--------------------------------------------------------------------------------
/configs/vqgan/8192_satin_ddp.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqgan/interface.py',
    num_embeddings=8192,
    dataset='satin',
    strategy='ddp',
    find_unused_parameters=True,
)
--------------------------------------------------------------------------------
/configs/vqkd/teachers/evaclip.py:
--------------------------------------------------------------------------------
decoder = dict(out_chans=768)
distiller = dict(
    teacher=dict(
        type='EVACLIPTeacher',
        model=dict(model_name='EVA02-CLIP-B-16', pretrained='eva02_clip'),
        downsample_factor=16,
        image_wh=(224, 224),
    ),
)
--------------------------------------------------------------------------------
/configs/ar/transformers/gpt2.py:
--------------------------------------------------------------------------------
from typing import Any

_kwargs_: dict[str, Any]
_kwargs_ = dict(_kwargs_)
transformer_size = _kwargs_['transformer_size']

_export_ = dict(
    type='GPT2Transformer',
    transformer=f'pretrained/huggingface/gpt2-{transformer_size}',
)
--------------------------------------------------------------------------------
/configs/ic/imagenet_ddp.py:
--------------------------------------------------------------------------------
from typing import Any

from todd.configs import PyConfig

_kwargs_: dict[str, Any]
_kwargs_ = dict(_kwargs_)

_export_ = PyConfig.load(
    'configs/ic/interface.py',
    dataset='imagenet',
    strategy='ddp',
    **_kwargs_,
)
--------------------------------------------------------------------------------
/configs/vqgan/1024_imagenet_ddp.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqgan/interface.py',
    num_embeddings=1024,
    dataset='imagenet',
    strategy='ddp',
    find_unused_parameters=True,
)
--------------------------------------------------------------------------------
/configs/vqgan/16384_imagenet_ddp.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqgan/interface.py',
    num_embeddings=16384,
    dataset='imagenet',
    strategy='ddp',
    find_unused_parameters=True,
)
--------------------------------------------------------------------------------
/configs/vqgan/8192_imagenet_ddp.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqgan/interface.py',
    num_embeddings=8192,
    dataset='imagenet',
    strategy='ddp',
    find_unused_parameters=True,
)
--------------------------------------------------------------------------------
/configs/strategies/cuda.py:
--------------------------------------------------------------------------------
trainer = dict(
    strategy=dict(type='CUDAStrategy'),
    dataloader=dict(sampler=dict(type='DistributedSampler', shuffle=True)),
)
validator = dict(
    strategy=dict(type='CUDAStrategy'),
    dataloader=dict(sampler=dict(type='DistributedSampler', shuffle=False)),
)
--------------------------------------------------------------------------------
/configs/vq/distance.py:
--------------------------------------------------------------------------------
from typing import Any

_kwargs_: dict[str, Any]
_kwargs_ = dict(_kwargs_)
strategy = _kwargs_['distance']

type_ = f'{strategy}Distance'
runner = dict(model=dict(quantizer=dict(distance=dict(type=type_))))

_export_ = dict(trainer=runner, validator=runner)
--------------------------------------------------------------------------------
/configs/vqgan/8192_sa_med2d_20m_ddp.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqgan/interface.py',
    num_embeddings=8192,
    dataset='sa_med2d_20m',
    strategy='ddp',
    find_unused_parameters=True,
)
--------------------------------------------------------------------------------
/configs/vqkd/openclip_L_14_8192_imagenet_ddp.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqkd/interface.py',
    teacher='openclip_L_14',
    num_embeddings=8192,
    dataset='imagenet',
    strategy='ddp',
)
--------------------------------------------------------------------------------
/configs/vqgan/8192_laion_aesthetics_ddp.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqgan/interface.py',
    num_embeddings=8192,
    dataset='laion_aesthetics',
    strategy='ddp',
    find_unused_parameters=True,
)
--------------------------------------------------------------------------------
/configs/vqkd/openclip_bigG_14_8192_imagenet_ddp.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

from todd.configs import PyConfig

_export_ = PyConfig.load(
    'configs/vqkd/interface.py',
    teacher='openclip_bigG_14',
    num_embeddings=8192,
    dataset='imagenet',
    strategy='ddp',
)
--------------------------------------------------------------------------------
/configs/vq/embedding_dim.py:
--------------------------------------------------------------------------------
from typing import Any

_kwargs_: dict[str, Any]
_kwargs_ = dict(_kwargs_)
embedding_dim = _kwargs_['embedding_dim']

runner = dict(
    model=dict(quantizer=dict(embedding=dict(embedding_dim=embedding_dim))),
)

_export_ = dict(trainer=runner, validator=runner)
--------------------------------------------------------------------------------
/configs/vq/interface.py:
--------------------------------------------------------------------------------
_base_ = [
    'num_embeddings.py',
    'embedding_dim.py',
    'distance.py',
]

type_ = 'torch_nn_modules_sparse_Embedding'  # pylint: disable=invalid-name
runner = dict(model=dict(quantizer=dict(embedding=dict(type=type_))))

_export_ = dict(trainer=runner, validator=runner)
--------------------------------------------------------------------------------
/configs/vq/num_embeddings.py:
--------------------------------------------------------------------------------
from typing import Any

_kwargs_: dict[str, Any]
_kwargs_ = dict(_kwargs_)
num_embeddings = _kwargs_['num_embeddings']

runner = dict(
    model=dict(quantizer=dict(embedding=dict(num_embeddings=num_embeddings))),
)

_export_ = dict(trainer=runner, validator=runner)
--------------------------------------------------------------------------------
/vq/algorithms/sq/quantizers.py:
--------------------------------------------------------------------------------
__all__ = [
    'ScalarQuantizer',
]

from vq.tasks.image_tokenization.models import VQITQuantizerRegistry
from vq.tasks.image_tokenization.models.quantizers import BaseQuantizer


@VQITQuantizerRegistry.register_()
class ScalarQuantizer(BaseQuantizer):
    pass
--------------------------------------------------------------------------------
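Taken together, the four /configs/vq fragments above (interface.py stacked on num_embeddings.py, embedding_dim.py, and distance.py) assemble the quantizer's embedding spec. A hedged sketch of the merged result for illustrative values — 8192 embeddings of dimension 256; `'Cosine'` is an assumed example distance name, not a confirmed registry entry:

```python
# Hypothetical merge of the configs/vq fragments; the key layout follows
# the fragments themselves, the concrete values are placeholders.
runner = dict(
    model=dict(
        quantizer=dict(
            embedding=dict(
                type='torch_nn_modules_sparse_Embedding',
                num_embeddings=8192,
                embedding_dim=256,
            ),
            distance=dict(type='CosineDistance'),
        ),
    ),
)
_export_ = dict(trainer=runner, validator=runner)
```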
/vq/algorithms/vqgan/losses/registries.py:
--------------------------------------------------------------------------------
__all__ = [
    'VQDiscriminatorLossRegistry',
    'VQGeneratorLossRegistry',
]

from vq.models import VQLossRegistry


class VQDiscriminatorLossRegistry(VQLossRegistry):
    pass


class VQGeneratorLossRegistry(VQLossRegistry):
    pass
--------------------------------------------------------------------------------
/configs/exps/llamagen_vqgan_imagenet_ddp-no_refine_layer.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

_base_ = [
    '../llamagen/vqgan_imagenet_ddp.py',
]

coder = dict(attention_layer=None, refine_layer=None)
runner = dict(model=dict(encoder=coder, decoder=coder))

_export_ = dict(trainer=runner, validator=runner)
--------------------------------------------------------------------------------
/configs/llamagen/vqgan_stylegan2_imagenet_ddp.py:
--------------------------------------------------------------------------------
_base_ = [
    'vqgan_imagenet_ddp.py',
]

discriminator = dict(
    _delete_=True,
    type='StyleGAN2Discriminator',
    image_size=256,
)
runner = dict(model=dict(discriminator=discriminator))

_export_ = dict(trainer=runner, validator=runner)
--------------------------------------------------------------------------------
/vq/algorithms/ar/image.py:
--------------------------------------------------------------------------------
__all__ = [
    'ImageMixin',
]

import enum
from typing import TypeVar

from vq.tasks.sequence_modeling.models import ImageModel

from .base import BaseMixin

T = TypeVar('T', bound=enum.Enum)


class ImageMixin(BaseMixin[T], ImageModel[T]):
    pass
--------------------------------------------------------------------------------
/vq/algorithms/ar/c2i.py:
--------------------------------------------------------------------------------
__all__ = [
    'ARC2I',
]

import todd.tasks.large_multimodal_model as lmm

from vq import VQModelRegistry
from vq.tasks.sequence_modeling.models import C2I

from .x2i import X2IMixin


@VQModelRegistry.register_()
class ARC2I(X2IMixin[lmm.C2IEnum], C2I):
    pass
--------------------------------------------------------------------------------
/vq/algorithms/vqgan/discriminators/base.py:
--------------------------------------------------------------------------------
__all__ = [
    'BaseDiscriminator',
]

from abc import ABC, abstractmethod

import torch
from torch import nn


class BaseDiscriminator(nn.Module, ABC):

    @abstractmethod
    def forward(self, image: torch.Tensor) -> torch.Tensor:
        pass
--------------------------------------------------------------------------------
/configs/vqkd/teachers/interface.py:
--------------------------------------------------------------------------------
from typing import Any

from todd.configs import PyConfig

_kwargs_: dict[str, Any]
_kwargs_ = dict(_kwargs_)
teacher = _kwargs_['teacher']

model = PyConfig.load(f'configs/vqkd/teachers/{teacher}.py')
runner = dict(model=model)

_export_ = dict(trainer=runner, validator=runner)
--------------------------------------------------------------------------------
/configs/vqkd/teachers/openclip_L_14.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

_base_ = [
    'openclip.py',
]

decoder = dict(out_chans=768)
distiller = dict(
    teacher=dict(
        model=dict(model_name='ViT-L-14', pretrained='laion2B_s32B_b82k'),
        downsample_factor=14,
        image_wh=(224, 224),
    ),
)
--------------------------------------------------------------------------------
/configs/vqkd/teachers/openclip_H_14.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

_base_ = [
    'openclip.py',
]

decoder = dict(out_chans=1024)
distiller = dict(
    teacher=dict(
        model=dict(model_name='ViT-H-14', pretrained='laion2B_s32B_b79k'),
        downsample_factor=14,
        image_wh=(224, 224),
    ),
)
--------------------------------------------------------------------------------
/configs/fsq/README.md:
--------------------------------------------------------------------------------
# FSQ

[[arXiv](https://arxiv.org/abs/2309.15505)] [[GitHub](https://github.com/google-research/google-research/tree/master/fsq)]

Training:

```bash
auto_torchrun -m vq.train fsq/8000_imagenet_ddp configs/fsq/8000_imagenet_ddp.py
auto_torchrun -m vq.train fsq/64000_imagenet_ddp configs/fsq/64000_imagenet_ddp.py
```
--------------------------------------------------------------------------------
/configs/vqgan/8192_stylegan2_imagenet_ddp.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

_base_ = [
    '8192_imagenet_ddp.py',
]

discriminator = dict(
    _delete_=True,
    type='StyleGAN2Discriminator',
    image_size=256,
)
runner = dict(model=dict(discriminator=discriminator))

_export_ = dict(trainer=runner, validator=runner)
--------------------------------------------------------------------------------
/configs/vqkd/teachers/openclip_bigG_14.py:
--------------------------------------------------------------------------------
# pylint: disable=invalid-name

_base_ = [
    'openclip.py',
]

decoder = dict(out_chans=1280)
distiller = dict(
    teacher=dict(
        model=dict(model_name='ViT-bigG-14', pretrained='laion2b_s39b_b160k'),
        downsample_factor=14,
        image_wh=(224, 224),
    ),
)
--------------------------------------------------------------------------------
/configs/llamagen/c2i_medium_imagenet_ddp.py:
--------------------------------------------------------------------------------
from typing import Any

from todd.configs import PyConfig

_kwargs_: dict[str, Any]
_kwargs_ = dict(_kwargs_)

_export_ = PyConfig.load(
    'configs/llamagen/ar.py',
    model='c2i',
    transformer_size='medium',
    dataset='imagenet',
    strategy='ddp',
    **_kwargs_,
)
--------------------------------------------------------------------------------
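The two stylegan2 configs above replace the inherited discriminator wholesale rather than merging into it, which is what `_delete_=True` signals. A sketch of the assumed merge semantics (the mm-style config convention, which todd's PyConfig appears to follow here; the inherited value is a hypothetical stand-in):

```python
# Inherited discriminator from the base config (hypothetical stand-in):
inherited = dict(type='PatchGANDiscriminator', depth=2)
# The override marks the dict for replacement instead of a deep merge:
override = dict(_delete_=True, type='StyleGAN2Discriminator', image_size=256)
# Assumed result: inherited keys are dropped entirely.
merged = dict(type='StyleGAN2Discriminator', image_size=256)
```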
/vq/runners/registries.py:
--------------------------------------------------------------------------------
__all__ = [
    'VQCallbackRegistry',
    'VQMetricRegistry',
]

from todd.runners import CallbackRegistry, MetricRegistry

from ..registries import VQRegistry


class VQCallbackRegistry(VQRegistry, CallbackRegistry):
    pass


class VQMetricRegistry(VQRegistry, MetricRegistry):
    pass
--------------------------------------------------------------------------------
/configs/ar/cfg.py:
--------------------------------------------------------------------------------
from typing import Any

_kwargs_: dict[str, Any]
_kwargs_ = dict(_kwargs_)
dropout = _kwargs_.get('dropout', 0.1)
alpha = _kwargs_.get('alpha', 1.75)

_export_ = dict(
    trainer=dict(model=dict(cfg=dropout)),
    validator=dict(
        model=dict(cfg=dropout, transformer=dict(sampler=dict(cfg=alpha))),
    ),
)
--------------------------------------------------------------------------------
/configs/datasets/transforms/none.py:
--------------------------------------------------------------------------------
from typing import Any

_kwargs_: dict[str, Any]
_kwargs_ = dict(_kwargs_)
image_size = _kwargs_['image_size']

transforms = [
    dict(type='Resize', size=image_size, interpolation=3),
    dict(type='CenterCrop', size=image_size),
    dict(type='PILToTensor'),
]

_export_ = dict(transforms=transforms)
--------------------------------------------------------------------------------
/configs/datasets/transforms/default.py:
--------------------------------------------------------------------------------
from typing import Any

_kwargs_: dict[str, Any]
_kwargs_ = dict(_kwargs_)
image_size = _kwargs_['image_size']

transforms = [
    dict(type='RandomResizedCrop', size=image_size, interpolation=3),
    dict(type='RandomHorizontalFlip'),
    dict(type='PILToTensor'),
]

_export_ = dict(transforms=transforms)
--------------------------------------------------------------------------------
/vq/tasks/image_tokenization/registries.py:
--------------------------------------------------------------------------------
__all__ = [
    'VQITModelRegistry',
    'VQITRunnerRegistry',
]

from vq import VQModelRegistry, VQRunnerRegistry

from ..registries import VQITRegistry


class VQITModelRegistry(VQITRegistry, VQModelRegistry):
    pass


class VQITRunnerRegistry(VQITRegistry, VQRunnerRegistry):
    pass
--------------------------------------------------------------------------------
/vq/tasks/sequence_modeling/registries.py:
--------------------------------------------------------------------------------
__all__ = [
    'VQSMModelRegistry',
    'VQSMRunnerRegistry',
]

from vq import VQModelRegistry, VQRunnerRegistry

from ..registries import VQSMRegistry


class VQSMModelRegistry(VQSMRegistry, VQModelRegistry):
    pass


class VQSMRunnerRegistry(VQSMRegistry, VQRunnerRegistry):
    pass
--------------------------------------------------------------------------------
/configs/strategies/interface.py:
--------------------------------------------------------------------------------
from typing import Any

_kwargs_: dict[str, Any]
dict(_kwargs_) 5 | strategy = _kwargs_['strategy'] 6 | find_unused_parameters = _kwargs_.get('find_unused_parameters', False) 7 | 8 | _base_ = [f'{strategy}.py'] 9 | 10 | if find_unused_parameters: 11 | _base_.append('find_unused_parameters.py') 12 | 13 | _export_: dict[str, Any] = dict() 14 | -------------------------------------------------------------------------------- /configs/vqkd/teachers/mae.py: -------------------------------------------------------------------------------- 1 | decoder = dict(out_chans=768) 2 | distiller = dict( 3 | teacher=dict( 4 | type='MAETeacher', 5 | model=dict( 6 | type='mae_vit_base_patch16', 7 | init_weights=dict( 8 | pretrained='pretrained/mae/mae_pretrain_vit_base.pth', 9 | ), 10 | ), 11 | downsample_factor=16, 12 | ), 13 | ) 14 | -------------------------------------------------------------------------------- /configs/ar/c2i_gpt2_medium_imagenet_ddp.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _export_ = PyConfig.load( 9 | 'configs/ar/interface.py', 10 | model='c2i', 11 | transformer='gpt2', 12 | transformer_size='medium', 13 | dataset='imagenet', 14 | strategy='ddp', 15 | **_kwargs_, 16 | ) 17 | -------------------------------------------------------------------------------- /configs/ar/c2i_llama_medium_imagenet_ddp.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _export_ = PyConfig.load( 9 | 'configs/ar/interface.py', 10 | model='c2i', 11 | transformer='llama', 12 | transformer_size='medium', 13 | dataset='imagenet', 14 | strategy='ddp', 15 | **_kwargs_, 16 | ) 17 | -------------------------------------------------------------------------------- /configs/vqkd/teachers/convnext.py: -------------------------------------------------------------------------------- 1 | decoder = dict(out_chans=1024) 2 | distiller = dict( 3 | teacher=dict( 4 | type='ConvNeXtTeacher', 5 | model=dict( 6 | type='convnext_base', 7 | weights='.ConvNeXt_Base_Weights.IMAGENET1K_V1', 8 | ), 9 | downsample_factor=(32, 32), 10 | image_wh=(224, 224), # TODO: remove this after rebuttal 11 | ), 12 | ) 13 | -------------------------------------------------------------------------------- /vq/algorithms/vqkd/quantizers/base.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'VQKDQuantizer', 3 | ] 4 | 5 | import todd 6 | 7 | from vq.algorithms.vq import VectorQuantizer 8 | from vq.tasks.image_tokenization.models import VQITQuantizerRegistry 9 | 10 | 11 | @VQITQuantizerRegistry.register_() 12 | class VQKDQuantizer(VectorQuantizer): 13 | 14 | def _init_weights(self, config: todd.Config) -> bool: 15 | return False 16 | -------------------------------------------------------------------------------- /configs/llamagen/vqgan_256_f8_imagenet_ddp.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _base_ = [ 6 | PyConfig.load( 7 | 'configs/llamagen/vqgan.py', 8 | dataset='imagenet', 9 | strategy='ddp', 10 | find_unused_parameters=True, 11 | num_embeddings=256, 12 | ), 13 | '../vqgan/f8.py', 14 | ] 15 | 16 | _export_: dict[str, Any] = dict() 17 | 
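The interface configs above are all parameterized through `_kwargs_`: a caller passes keyword arguments to `PyConfig.load`, the loaded file reads them out of `_kwargs_`, assembles its `_base_` list accordingly, and publishes the result through `_export_`. A minimal usage sketch under that assumption (the argument values are illustrative):

```python
# Sketch of driving a _kwargs_-parameterized interface config. Assumption:
# todd's PyConfig.load exposes keyword arguments to the loaded file as
# `_kwargs_` and returns its `_export_` mapping, as the configs above rely on.
from todd.configs import PyConfig

# configs/strategies/interface.py resolves `strategy` to ddp.py and, because
# the flag is set, additionally appends find_unused_parameters.py to _base_.
strategy = PyConfig.load(
    'configs/strategies/interface.py',
    strategy='ddp',
    find_unused_parameters=True,
)
```

The `_kwargs_ = dict(_kwargs_)` idiom at the top of each file makes a local copy, presumably so that subsequent `pop` and `setdefault` calls do not leak back into the caller's keyword arguments.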
-------------------------------------------------------------------------------- /vq/tasks/image_reconstruction/registries.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'VQIRModelRegistry', 3 | 'VQIRLossRegistry', 4 | ] 5 | 6 | from vq import VQModelRegistry 7 | from vq.models import VQLossRegistry 8 | 9 | from ..registries import VQIRRegistry 10 | 11 | 12 | class VQIRModelRegistry(VQIRRegistry, VQModelRegistry): 13 | pass 14 | 15 | 16 | class VQIRLossRegistry(VQIRModelRegistry, VQLossRegistry): 17 | pass 18 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/models/quantizers/registries.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'VQITQuantizerCallbackRegistry', 3 | 'VQITQuantizerLossRegistry', 4 | ] 5 | 6 | from ..registries import VQITLossRegistry, VQITQuantizerRegistry 7 | 8 | 9 | class VQITQuantizerCallbackRegistry(VQITQuantizerRegistry): 10 | pass 11 | 12 | 13 | class VQITQuantizerLossRegistry(VQITQuantizerRegistry, VQITLossRegistry): 14 | pass 15 | -------------------------------------------------------------------------------- /configs/datasets/transforms/weak.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | _kwargs_: dict[str, Any] 4 | _kwargs_ = dict(_kwargs_) 5 | image_size = _kwargs_['image_size'] 6 | 7 | transforms = [ 8 | dict(type='Resize', size=image_size, interpolation=3), 9 | dict(type='RandomCrop', size=image_size), 10 | dict(type='RandomHorizontalFlip'), 11 | dict(type='PILToTensor'), 12 | ] 13 | 14 | _export_ = dict(transforms=transforms) 15 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/runners/registries.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'VQITCallbackRegistry', 3 | 'VQITMetricRegistry', 4 | ] 5 | 6 | from vq.runners import VQCallbackRegistry, VQMetricRegistry 7 | 8 | from ..registries import VQITRunnerRegistry 9 | 10 | 11 | class VQITCallbackRegistry(VQITRunnerRegistry, VQCallbackRegistry): 12 | pass 13 | 14 | 15 | class VQITMetricRegistry(VQITRunnerRegistry, VQMetricRegistry): 16 | pass 17 | -------------------------------------------------------------------------------- /configs/ar/transformers/interface.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | transformer = _kwargs_['transformer'] 8 | 9 | transformer = PyConfig.load( 10 | f'configs/ar/transformers/{transformer}.py', 11 | **_kwargs_, 12 | ) 13 | runner = dict(model=dict(transformer=transformer)) 14 | 15 | _export_ = dict(trainer=runner, validator=runner) 16 | -------------------------------------------------------------------------------- /vq/tasks/registries.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'VQITRegistry', 3 | 'VQSMRegistry', 4 | 'VQIRRegistry', 5 | 'VQICRegistry', 6 | ] 7 | 8 | from ..registries import VQTaskRegistry 9 | 10 | 11 | class VQITRegistry(VQTaskRegistry): 12 | pass 13 | 14 | 15 | class VQSMRegistry(VQTaskRegistry): 16 | pass 17 | 18 | 19 | class VQIRRegistry(VQTaskRegistry): 20 | pass 21 | 22 | 23 | class VQICRegistry(VQTaskRegistry): 24 | pass 25 | 
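Transform configs such as `configs/datasets/transforms/weak.py` above describe torchvision-style transforms as plain dicts. The repository presumably materializes them through its own registries; a stand-alone sketch of the equivalent, assuming each `type` maps to a `torchvision.transforms` class and `interpolation=3` denotes PIL bicubic:

```python
# Stand-alone sketch of building the transform dicts with torchvision.
# Assumption: `type` names map directly to torchvision.transforms classes;
# this is not the repository's actual builder.
import torchvision.transforms as T


def build_transforms(specs: list[dict]) -> T.Compose:
    built = []
    for spec in specs:
        spec = dict(spec)  # copy, so pop() does not mutate the config
        cls = getattr(T, spec.pop('type'))
        built.append(cls(**spec))
    return T.Compose(built)


pipeline = build_transforms([
    dict(type='Resize', size=256, interpolation=3),
    dict(type='RandomCrop', size=256),
    dict(type='RandomHorizontalFlip'),
    dict(type='PILToTensor'),
])
```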
-------------------------------------------------------------------------------- /configs/llamagen/vqgan_128x16_f8_imagenet_ddp.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _base_ = [ 6 | PyConfig.load( 7 | 'configs/llamagen/vqgan.py', 8 | dataset='imagenet', 9 | strategy='ddp', 10 | find_unused_parameters=True, 11 | num_embeddings=128, 12 | embedding_dim=16, 13 | ), 14 | '../vqgan/f8.py', 15 | ] 16 | 17 | _export_: dict[str, Any] = dict() 18 | -------------------------------------------------------------------------------- /vq/models/registries.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'VQEncoderRegistry', 3 | 'VQDecoderRegistry', 4 | 'VQLossRegistry', 5 | ] 6 | 7 | from todd.models import LossRegistry 8 | 9 | from ..registries import VQModelRegistry, VQRegistry 10 | 11 | 12 | class VQEncoderRegistry(VQModelRegistry): 13 | pass 14 | 15 | 16 | class VQDecoderRegistry(VQModelRegistry): 17 | pass 18 | 19 | 20 | class VQLossRegistry(VQRegistry, LossRegistry): 21 | pass 22 | -------------------------------------------------------------------------------- /configs/datasets/interface.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | _kwargs_: dict[str, Any] 4 | _kwargs_ = dict(_kwargs_) 5 | dataset = _kwargs_['dataset'] 6 | batch_size_in_total = _kwargs_.get('batch_size_in_total', False) 7 | 8 | _base_ = [ 9 | f'{dataset}.py', 10 | 'batch_size.py', 11 | 'transforms/interface.py', 12 | ] 13 | 14 | if batch_size_in_total: 15 | _base_.append('batch_size_in_total.py') 16 | 17 | _export_: dict[str, Any] = dict() 18 | -------------------------------------------------------------------------------- /tools/debugpy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ssh -fN -L ${PORT}:localhost:${PORT} -p 9000 root@$1 3 | pipenv run python \ 4 | -m debugpy \ 5 | --connect localhost:${PORT} \ 6 | $(pipenv --venv)/bin/torchrun \ 7 | --nnodes ${ARNOLD_WORKER_NUM} \ 8 | --nproc-per-node ${ARNOLD_WORKER_GPU} \ 9 | --node-rank ${ARNOLD_ID} \ 10 | --master-addr ${METIS_WORKER_0_HOST} \ 11 | --master-port ${METIS_WORKER_0_PORT} \ 12 | ${@:2} 13 | -------------------------------------------------------------------------------- /configs/datasets/transforms/strong.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | _kwargs_: dict[str, Any] 4 | _kwargs_ = dict(_kwargs_) 5 | image_size = _kwargs_['image_size'] 6 | 7 | transforms = [ 8 | dict( 9 | type='RandomResizedCrop', 10 | size=image_size, 11 | interpolation=3, 12 | scale=(0.8, 1.0), 13 | ), 14 | dict(type='RandomHorizontalFlip'), 15 | dict(type='PILToTensor'), 16 | ] 17 | 18 | _export_ = dict(transforms=transforms) 19 | -------------------------------------------------------------------------------- /configs/decoder/llamagen.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault( 9 | 'ir_config', 10 | 'configs/llamagen/vqgan_imagenet_ddp.py', 11 | ) 12 | 13 | _export_ = PyConfig.load( 14 | 'configs/decoder/interface.py', 15 | num_embeddings=8192, 16 | dataset='imagenet', 17 | strategy='ddp', 18 | 
find_unused_parameters=True, 19 | **_kwargs_, 20 | ) 21 | -------------------------------------------------------------------------------- /configs/decoder/vqgan.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault( 9 | 'ir_config', 10 | 'configs/vqgan/8192_dd2_aglwg075_imagenet_ddp.py', 11 | ) 12 | 13 | _export_ = PyConfig.load( 14 | 'configs/decoder/interface.py', 15 | num_embeddings=8192, 16 | dataset='imagenet', 17 | strategy='ddp', 18 | find_unused_parameters=True, 19 | **_kwargs_, 20 | ) 21 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/models/registries.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'VQITConnectorRegistry', 3 | 'VQITLossRegistry', 4 | 'VQITQuantizerRegistry', 5 | ] 6 | 7 | from vq.models import VQLossRegistry 8 | 9 | from ..registries import VQITModelRegistry 10 | 11 | 12 | class VQITConnectorRegistry(VQITModelRegistry): 13 | pass 14 | 15 | 16 | class VQITLossRegistry(VQITModelRegistry, VQLossRegistry): 17 | pass 18 | 19 | 20 | class VQITQuantizerRegistry(VQITModelRegistry): 21 | pass 22 | -------------------------------------------------------------------------------- /configs/cvqvae/8192_dd2_aglwg075_imagenet_ddp.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=invalid-name 2 | 3 | from typing import Any 4 | 5 | from todd.configs import PyConfig 6 | 7 | _kwargs_: dict[str, Any] 8 | _kwargs_ = dict(_kwargs_) 9 | 10 | _kwargs_.setdefault('distance', 'Cosine') 11 | 12 | _base_ = [ 13 | PyConfig.load( 14 | 'configs/vqgan/8192_dd2_aglwg075_imagenet_ddp.py', 15 | **_kwargs_, 16 | ), 17 | 'custom_imports.py', 18 | 'quantizer.py', 19 | ] 20 | 21 | _export_: dict[str, Any] = dict() 22 | -------------------------------------------------------------------------------- /configs/cluster/encoders/dino.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('embedding_dim', 768) 9 | 10 | _base_ = [ 11 | PyConfig.load('configs/vq/embedding_dim.py', **_kwargs_), 12 | ] 13 | 14 | teacher = dict( 15 | type='DINOTeacher', 16 | downsample_factor=16, 17 | ) 18 | runner = dict(model=dict(encoder=dict(teacher=teacher))) 19 | 20 | _export_ = dict(trainer=runner, validator=runner) 21 | -------------------------------------------------------------------------------- /configs/ar/transformers/llama.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | _kwargs_: dict[str, Any] 4 | _kwargs_ = dict(_kwargs_) 5 | transformer_size = _kwargs_['transformer_size'] 6 | 7 | TRANSFORMERS = dict( 8 | medium=dict( 9 | num_hidden_layers=24, 10 | num_attention_heads=16, 11 | hidden_size=1024, 12 | intermediate_size=2816, 13 | rms_norm_eps=1e-5, 14 | ), 15 | ) 16 | 17 | _export_ = dict( 18 | type='LlamaTransformer', 19 | transformer=TRANSFORMERS[transformer_size], 20 | ) 21 | -------------------------------------------------------------------------------- /configs/ic/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 
_kwargs_: dict[str, Any] 4 | _kwargs_ = dict(_kwargs_) 5 | it_config = _kwargs_['it_config'] 6 | it_state_dict = _kwargs_['it_state_dict'] 7 | 8 | model = dict( 9 | type='VQModelRegistry.VQICModelRegistry.BaseModel', 10 | freeze=dict(type='NamedModulesFilter', name='_it'), 11 | filter_state_dict=True, 12 | it=dict(config=it_config, state_dicts=[it_state_dict], strict=False), 13 | ) 14 | runner = dict(model=model) 15 | 16 | _export_ = dict(trainer=runner, validator=runner) 17 | -------------------------------------------------------------------------------- /configs/cluster/encoders/clip.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('embedding_dim', 768) 9 | 10 | _base_ = [ 11 | PyConfig.load('configs/vq/embedding_dim.py', **_kwargs_), 12 | ] 13 | 14 | teacher = dict( 15 | type='CLIPTeacher', 16 | downsample_factor=16, 17 | with_proj=False, 18 | ) 19 | runner = dict(model=dict(encoder=dict(teacher=teacher))) 20 | 21 | _export_ = dict(trainer=runner, validator=runner) 22 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/models/quantizers/losses.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'BaseLoss', 3 | ] 4 | 5 | from abc import ABC, abstractmethod 6 | 7 | import torch 8 | import torch.distributed 9 | from todd.models import losses 10 | from todd.runners import Memo 11 | 12 | 13 | class BaseLoss(losses.BaseLoss, ABC): 14 | 15 | @abstractmethod 16 | def forward( # pylint: disable=arguments-differ 17 | self, 18 | z: torch.Tensor, 19 | x: torch.Tensor, 20 | memo: Memo, 21 | ) -> torch.Tensor: 22 | pass 23 | -------------------------------------------------------------------------------- /vq/algorithms/ar/base.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'BaseMixin', 3 | ] 4 | 5 | import enum 6 | from abc import abstractmethod 7 | from typing import Any, TypeVar 8 | 9 | import torch 10 | from todd.runners import Memo 11 | 12 | from vq.tasks.sequence_modeling.models import BaseModel as BaseSMModel 13 | 14 | T = TypeVar('T', bound=enum.Enum) 15 | 16 | 17 | class BaseMixin(BaseSMModel[T]): 18 | 19 | @abstractmethod 20 | def sample( 21 | self, 22 | logits: torch.Tensor, 23 | memo: Memo, 24 | ) -> tuple[Any, Memo]: 25 | pass 26 | -------------------------------------------------------------------------------- /configs/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'VQDatasetRegistry.ImageNetDataset' # noqa: E501 pylint: disable=invalid-name 2 | 3 | _export_ = dict( 4 | trainer=dict( 5 | dataset=dict( 6 | type=dataset_type, 7 | name='imagenet_train', 8 | num_categories=1000, 9 | split='train', 10 | ), 11 | ), 12 | validator=dict( 13 | dataset=dict( 14 | type=dataset_type, 15 | name='imagenet_val', 16 | num_categories=1000, 17 | split='val', 18 | ), 19 | ), 20 | ) 21 | -------------------------------------------------------------------------------- /configs/cluster/encoders/vit.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('embedding_dim', 768) 9 | 10 | 
_base_ = [ 11 | PyConfig.load('configs/vq/embedding_dim.py', **_kwargs_), 12 | ] 13 | 14 | teacher = dict( 15 | type='ViTTeacher', 16 | model=dict(type='vit_b_16', weights='.ViT_B_16_Weights.IMAGENET1K_V1'), 17 | downsample_factor=16, 18 | ) 19 | runner = dict(model=dict(encoder=dict(teacher=teacher))) 20 | 21 | _export_ = dict(trainer=runner, validator=runner) 22 | -------------------------------------------------------------------------------- /configs/datasets/vanilla.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=invalid-name 2 | 3 | dataset_type = 'VQDatasetRegistry.Dataset' 4 | access_layer_type = 'PILAccessLayer' 5 | 6 | _export_ = dict( 7 | trainer=dict( 8 | dataset=dict( 9 | type=dataset_type, 10 | name='vanilla_train', 11 | access_layer=dict(type=access_layer_type), 12 | ), 13 | ), 14 | validator=dict( 15 | dataset=dict( 16 | type=dataset_type, 17 | name='vanilla_val', 18 | access_layer=dict(type=access_layer_type), 19 | ), 20 | ), 21 | ) 22 | -------------------------------------------------------------------------------- /configs/datasets/ffhq.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'vanilla.py', 3 | ] 4 | 5 | data_root = 'data/ffhq-256/' # pylint: disable=invalid-name 6 | 7 | _export_ = dict( 8 | trainer=dict( 9 | dataset=dict( 10 | name='ffhq_train', 11 | num_categories=1, 12 | access_layer=dict(data_root=data_root, task_name='train'), 13 | ), 14 | ), 15 | validator=dict( 16 | dataset=dict( 17 | name='ffhq_val', 18 | num_categories=1, 19 | access_layer=dict(data_root=data_root, task_name='val'), 20 | ), 21 | ), 22 | ) 23 | -------------------------------------------------------------------------------- /configs/datasets/celeba_hq.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'vanilla.py', 3 | ] 4 | 5 | data_root = 'data/celeba-hq-256/' # pylint: disable=invalid-name 6 | 7 | _export_ = dict( 8 | trainer=dict( 9 | dataset=dict( 10 | name='celeba_hq_train', 11 | num_categories=1, 12 | access_layer=dict(data_root=data_root, task_name='train'), 13 | ), 14 | ), 15 | validator=dict( 16 | dataset=dict( 17 | name='celeba_hq_val', 18 | num_categories=1, 19 | access_layer=dict(data_root=data_root, task_name='val'), 20 | ), 21 | ), 22 | ) 23 | -------------------------------------------------------------------------------- /configs/datasets/coco_2014.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'VQDatasetRegistry.COCODataset' # pylint: disable=invalid-name 2 | 3 | _export_ = dict( 4 | trainer=dict( 5 | dataset=dict( 6 | type=dataset_type, 7 | name='coco_2014_train', 8 | num_categories=80, 9 | split='train', 10 | year=2014, 11 | ), 12 | ), 13 | validator=dict( 14 | dataset=dict( 15 | type=dataset_type, 16 | name='coco_2014_val', 17 | num_categories=80, 18 | split='val', 19 | year=2014, 20 | ), 21 | ), 22 | ) 23 | -------------------------------------------------------------------------------- /configs/datasets/sa_med2d_20m.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'VQDatasetRegistry.SAMed2DDataset' # noqa: E501 pylint: disable=invalid-name 2 | 3 | _export_ = dict( 4 | trainer=dict( 5 | dataset=dict( 6 | type=dataset_type, 7 | name='sa_med2d_train', 8 | num_categories=1, 9 | split='v1', 10 | train=True, 11 | ), 12 | ), 13 | validator=dict( 14 | dataset=dict( 15 | type=dataset_type, 16 | 
name='sa_med2d_val', 17 | num_categories=1, 18 | split='v1', 19 | train=False, 20 | ), 21 | ), 22 | ) 23 | -------------------------------------------------------------------------------- /configs/ar/interface.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('batch_sizes', (96, 96)) 9 | _kwargs_.setdefault('num_workers', (5, 5)) 10 | _kwargs_.setdefault('image_size', 256) 11 | _kwargs_.setdefault('augmentation', 'weak') 12 | _kwargs_.setdefault('batch_size_in_total', True) 13 | 14 | _base_ = [ 15 | PyConfig.load('configs/datasets/interface.py', **_kwargs_), 16 | '../strategies/interface.py', 17 | 'model.py', 18 | 'runner.py', 19 | 'custom_imports.py', 20 | ] 21 | 22 | _export_: dict[str, Any] = dict() 23 | -------------------------------------------------------------------------------- /configs/datasets/batch_size.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | _kwargs_: dict[str, Any] 4 | _kwargs_ = dict(_kwargs_) 5 | trainer_batch_size, validator_batch_size = _kwargs_['batch_sizes'] 6 | trainer_num_workers, validator_num_workers = _kwargs_['num_workers'] 7 | 8 | _export_ = dict( 9 | trainer=dict( 10 | dataloader=dict( 11 | batch_size=trainer_batch_size, 12 | num_workers=trainer_num_workers, 13 | ), 14 | ), 15 | validator=dict( 16 | dataloader=dict( 17 | batch_size=validator_batch_size, 18 | num_workers=validator_num_workers, 19 | ), 20 | ), 21 | ) 22 | -------------------------------------------------------------------------------- /configs/vqgan/interface.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('batch_sizes', (96, 96)) 9 | _kwargs_.setdefault('num_workers', (5, 5)) 10 | _kwargs_.setdefault('image_size', 256) 11 | _kwargs_.setdefault('augmentation', 'strong') 12 | _kwargs_.setdefault('batch_size_in_total', True) 13 | 14 | _base_ = [ 15 | PyConfig.load('configs/datasets/interface.py', **_kwargs_), 16 | '../strategies/interface.py', 17 | 'model.py', 18 | 'runner.py', 19 | 'custom_imports.py', 20 | ] 21 | 22 | _export_: dict[str, Any] = dict() 23 | -------------------------------------------------------------------------------- /configs/cluster/interface.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('batch_sizes', (256, 256)) 9 | _kwargs_.setdefault('num_workers', (10, 10)) 10 | _kwargs_.setdefault('image_size', 256) 11 | _kwargs_.setdefault('augmentation', 'strong') 12 | _kwargs_.setdefault('batch_size_in_total', True) 13 | 14 | _base_ = [ 15 | PyConfig.load('configs/datasets/interface.py', **_kwargs_), 16 | '../strategies/interface.py', 17 | 'model.py', 18 | 'runner.py', 19 | 'custom_imports.py', 20 | ] 21 | 22 | _export_: dict[str, Any] = dict() 23 | -------------------------------------------------------------------------------- /configs/fsq/interface.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: 
dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('batch_sizes', (96, 96)) 9 | _kwargs_.setdefault('num_workers', (5, 5)) 10 | _kwargs_.setdefault('image_size', 256) 11 | _kwargs_.setdefault('augmentation', 'strong') 12 | _kwargs_.setdefault('batch_size_in_total', True) 13 | 14 | _base_ = [ 15 | PyConfig.load('configs/datasets/interface.py', **_kwargs_), 16 | '../strategies/interface.py', 17 | 'model.py', 18 | '../vqgan/runner.py', 19 | 'custom_imports.py', 20 | ] 21 | 22 | _export_: dict[str, Any] = dict() 23 | -------------------------------------------------------------------------------- /configs/vqkd/interface.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('batch_sizes', (512, 512)) 9 | _kwargs_.setdefault('num_workers', (10, 10)) 10 | _kwargs_.setdefault('image_size', 224) 11 | _kwargs_.setdefault('augmentation', 'default') 12 | _kwargs_.setdefault('batch_size_in_total', True) 13 | 14 | _base_ = [ 15 | PyConfig.load('configs/datasets/interface.py', **_kwargs_), 16 | '../strategies/interface.py', 17 | 'model.py', 18 | 'runner.py', 19 | 'custom_imports.py', 20 | ] 21 | 22 | _export_: dict[str, Any] = dict() 23 | -------------------------------------------------------------------------------- /vq/algorithms/vqgan/quantizer.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'VQGANQuantizer', 3 | ] 4 | 5 | import todd 6 | 7 | from vq.algorithms.vq import VectorQuantizer 8 | from vq.tasks.image_tokenization.models import VQITQuantizerRegistry 9 | 10 | 11 | @VQITQuantizerRegistry.register_() 12 | class VQGANQuantizer(VectorQuantizer): 13 | 14 | def _init_weights(self, config: todd.Config) -> bool: 15 | if config == todd.Config(type='vqgan'): 16 | config = todd.Config( 17 | type='uniform_', 18 | a=-1.0 / self.codebook_size, 19 | b=1.0 / self.codebook_size, 20 | ) 21 | return super()._init_weights(config) 22 | -------------------------------------------------------------------------------- /vq/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'ImageNetDataset', 3 | ] 4 | 5 | from typing import cast 6 | 7 | import todd 8 | from PIL import Image 9 | 10 | from ..registries import VQDatasetRegistry 11 | from .base import BaseMixin, T 12 | 13 | 14 | @VQDatasetRegistry.register_() 15 | class ImageNetDataset( 16 | BaseMixin[str, Image.Image], 17 | todd.datasets.ImageNetDataset, 18 | ): 19 | 20 | def __getitem__(self, index: int) -> T: 21 | item = super().__getitem__(index) 22 | image = item['image'] 23 | item = cast(T, item) 24 | item['original_image'] = image 25 | item['image'] = self.encode(image) 26 | return item 27 | -------------------------------------------------------------------------------- /configs/datasets/laion_aesthetics.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'VQDatasetRegistry.LAIONAestheticsDataset' # noqa: E501 pylint: disable=invalid-name 2 | 3 | _export_ = dict( 4 | trainer=dict( 5 | dataset=dict( 6 | type=dataset_type, 7 | name='laion_aesthetics_train', 8 | num_categories=1, 9 | split='v2_6.5plus', 10 | train=True, 11 | ), 12 | ), 13 | validator=dict( 14 | dataset=dict( 15 | type=dataset_type, 16 | name='laion_aesthetics_val', 17 | num_categories=1, 18 | split='v2_6.5plus', 
19 | train=False, 20 | ), 21 | ), 22 | ) 23 | -------------------------------------------------------------------------------- /configs/llamagen/c2i_medium_vqgan_imagenet_ddp.py: -------------------------------------------------------------------------------- 1 | """C2I Medium with VQGAN tokenizer. 2 | 3 | Example: 4 | bash tools/torchrun.sh -m vq.train llamagen/c2i_medium_vqgan_imagenet_ddp \ 5 | configs/llamagen/c2i_medium_vqgan_imagenet_ddp.py --config-options \ 6 | ir_state_dict::work_dirs/llamagen/vqgan_imagenet_ddp/checkpoints/iter_1/model.\ 7 | pth 8 | """ 9 | 10 | from typing import Any 11 | 12 | from todd.configs import PyConfig 13 | 14 | _kwargs_: dict[str, Any] 15 | _kwargs_ = dict(_kwargs_) 16 | _kwargs_.setdefault('ir_config', 'configs/llamagen/vqgan_imagenet_ddp.py') 17 | 18 | _export_ = PyConfig.load( 19 | 'configs/llamagen/c2i_medium_imagenet_ddp.py', 20 | **_kwargs_, 21 | ) 22 | -------------------------------------------------------------------------------- /vq/datasets/coco.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'COCODataset', 3 | ] 4 | 5 | from typing import cast 6 | 7 | import todd 8 | from PIL import Image 9 | 10 | from ..registries import VQDatasetRegistry 11 | from .base import BaseMixin, T 12 | 13 | 14 | @VQDatasetRegistry.register_() 15 | class COCODataset(BaseMixin[str, Image.Image], todd.datasets.COCODataset): 16 | 17 | def __getitem__(self, index: int) -> T: # type: ignore[override] 18 | item = super().__getitem__(index) 19 | image = item['image'] 20 | item_ = cast(T, item) 21 | item_['original_image'] = image 22 | item_['image'] = self.encode(image) 23 | item_['category'] = 0 24 | return item_ 25 | -------------------------------------------------------------------------------- /configs/cluster/encoders/mae.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('embedding_dim', 768) 9 | 10 | _base_ = [ 11 | PyConfig.load('configs/vq/embedding_dim.py', **_kwargs_), 12 | ] 13 | 14 | teacher = dict( 15 | type='MAETeacher', 16 | model=dict( 17 | type='mae_vit_base_patch16', 18 | init_weights=dict( 19 | pretrained='pretrained/mae/mae_pretrain_vit_base.pth', 20 | ), 21 | ), 22 | downsample_factor=16, 23 | ) 24 | runner = dict(model=dict(encoder=dict(teacher=teacher))) 25 | 26 | _export_ = dict(trainer=runner, validator=runner) 27 | -------------------------------------------------------------------------------- /configs/datasets/transforms/interface.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | image_size = _kwargs_['image_size'] 8 | augmentation = _kwargs_['augmentation'] 9 | 10 | config = PyConfig.load( 11 | f'configs/datasets/transforms/{augmentation}.py', 12 | image_size=image_size, 13 | ) 14 | config.image_size = image_size 15 | trainer = dict(dataset=config) 16 | 17 | config = PyConfig.load( 18 | 'configs/datasets/transforms/none.py', 19 | image_size=image_size, 20 | ) 21 | config.image_size = image_size 22 | validator = dict(dataset=config) 23 | 24 | _export_ = dict(trainer=trainer, validator=validator) 25 | -------------------------------------------------------------------------------- /vq/registries.py: 
-------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'VQRegistry', 3 | 'VQDatasetRegistry', 4 | 'VQModelRegistry', 5 | 'VQRunnerRegistry', 6 | 'VQTaskRegistry', 7 | ] 8 | 9 | import todd 10 | from todd.registries import ( 11 | DatasetRegistry, 12 | ModelRegistry, 13 | RunnerRegistry, 14 | TaskRegistry, 15 | ) 16 | 17 | 18 | class VQRegistry(todd.Registry): 19 | pass 20 | 21 | 22 | class VQDatasetRegistry(VQRegistry, DatasetRegistry): 23 | pass 24 | 25 | 26 | class VQModelRegistry(VQRegistry, ModelRegistry): 27 | pass 28 | 29 | 30 | class VQRunnerRegistry(VQRegistry, RunnerRegistry): 31 | pass 32 | 33 | 34 | class VQTaskRegistry(VQRegistry, TaskRegistry): 35 | pass 36 | -------------------------------------------------------------------------------- /configs/llamagen/ar.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('batch_sizes', (256, 256)) 9 | _kwargs_.setdefault('num_workers', (10, 10)) 10 | _kwargs_.setdefault('transformer', 'llama') 11 | 12 | _base_ = [ 13 | PyConfig.load('configs/ar/interface.py', **_kwargs_), 14 | '../ar/cfg.py', 15 | ] 16 | 17 | model = dict(transformer=dict(sampler=dict(type='BaseSampler'))) 18 | trainer = dict( 19 | model=model, 20 | optimizers=dict(betas=(0.9, 0.95), weight_decay=0.05), 21 | iters=250_000, 22 | ) 23 | validator = dict(model=model) 24 | 25 | _export_ = dict(trainer=trainer, validator=validator) 26 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | This repository requires **Python 3.11** or higher. While PyTorch 2.4 is recommended, earlier versions may also be compatible. Ensure that you also install a matching version of TorchVision. 4 | 5 | After installing `torch` and `torchvision`, use the following command to install the remaining dependencies: 6 | 7 | ```bash 8 | GIT_LFS_SKIP_SMUDGE=1 pip install -e . 9 | pip install git+https://github.com/LutingWang/CLIP.git # TODO: remove this dependency 10 | ``` 11 | 12 | For experiments involving **StyleGAN**, install MMCV using: 13 | 14 | ```bash 15 | mim install mmcv 16 | ``` 17 | 18 | If you prefer to set up the environment manually, refer to the script provided in `setup.sh`. 
19 | -------------------------------------------------------------------------------- /tools/model_ema.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | 4 | import todd 5 | import torch 6 | 7 | 8 | def parse_args() -> argparse.Namespace: 9 | parser = argparse.ArgumentParser(description='Extract EMA') 10 | parser.add_argument('path', type=pathlib.Path) 11 | parser.add_argument('attr', nargs='?', default='["callbacks"][2]["shadow"]') 12 | parser.add_argument('--out', default='model_ema.pth') 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | def main() -> None: 18 | args = parse_args() 19 | path: pathlib.Path = args.path 20 | state_dict = torch.load(path, 'cpu') 21 | model_ema = todd.patches.py.get_(state_dict, args.attr) 22 | torch.save(model_ema, path.parent / args.out) 23 | 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /configs/ar/x2i.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | _kwargs_: dict[str, Any] 4 | _kwargs_ = dict(_kwargs_) 5 | ir_config = _kwargs_['ir_config'] 6 | it_state_dict = _kwargs_.get('it_state_dict') 7 | ir_state_dict = _kwargs_['ir_state_dict'] 8 | 9 | _base_ = [ 10 | 'transformers/interface.py', 11 | ] 12 | 13 | state_dicts = ([ir_state_dict] 14 | if it_state_dict is None else [ir_state_dict, it_state_dict]) 15 | model = dict( 16 | type='ARC2I', 17 | freeze=dict(type='NamedModulesFilter', name='_ir'), 18 | filter_state_dict=True, 19 | transformer=dict(sampler=dict(type='TopKTopPSampler')), 20 | ir=dict(config=ir_config, state_dicts=state_dicts, strict=False), 21 | ) 22 | runner = dict(model=model) 23 | 24 | _export_ = dict(trainer=runner, validator=runner) 25 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/models/quantizers/utils/quantizer_holder.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'QuantizerHolderMixin', 3 | ] 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | import todd 8 | 9 | from ..base import BaseQuantizer 10 | 11 | if TYPE_CHECKING: 12 | from vq.algorithms.vq import VectorQuantizer 13 | 14 | 15 | class QuantizerHolderMixin(todd.utils.HolderMixin[BaseQuantizer]): 16 | 17 | @property 18 | def quantizer(self) -> BaseQuantizer: 19 | return self._instance 20 | 21 | @property 22 | def vector_quantizer(self) -> 'VectorQuantizer': 23 | from vq.algorithms.vq import ( # noqa: E501 pylint: disable=import-outside-toplevel 24 | VectorQuantizer, 25 | ) 26 | assert isinstance(self.quantizer, VectorQuantizer) 27 | return self.quantizer 28 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | project_root=$(dirname $(realpath $0)) 4 | 5 | curl https://raw.githubusercontent.com/LutingWang/todd/main/bin/pipenv_install | bash -s -- 3.11.10 6 | 7 | pipenv run pip install ~/wheels/torch-2.4.1+cu121-cp311-cp311-linux_x86_64.whl 8 | pipenv run pip install -i https://download.pytorch.org/whl/cu121 torchvision==0.19.1+cu121 9 | 10 | pipenv run pip install \ 11 | accelerate \ 12 | debugpy \ 13 | "protobuf<=3.20.1" \ 14 | scikit-image \ 15 | torch_fidelity \ 16 | transformers==4.35.2 17 | 18 | pipenv run pip install openmim 19 | pipenv run mim install mmcv 20 | pipenv run pip 
install git+https://github.com/lvis-dataset/lvis-api.git@lvis_challenge_2021 21 | make install_todd 22 | 23 | pipenv run pip install git+https://github.com/LutingWang/CLIP.git # TODO: remove this dependency -------------------------------------------------------------------------------- /configs/ic/interface.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('batch_sizes', (512, 64)) 9 | _kwargs_.setdefault('num_workers', (10, 5)) 10 | _kwargs_.setdefault('image_size', 256) 11 | _kwargs_.setdefault('augmentation', 'strong') 12 | 13 | _base_ = [ 14 | PyConfig.load('configs/datasets/interface.py', **_kwargs_), 15 | '../strategies/interface.py', 16 | 'model.py', 17 | 'runner.py', 18 | 'custom_imports.py', 19 | ] 20 | 21 | model = dict( 22 | type='VQModelRegistry.VQICModelRegistry.BaseModel', 23 | freeze=dict(type='NamedModulesFilter', name='_it'), 24 | filter_state_dict=True, 25 | num_categories=1000, 26 | ) 27 | runner = dict(model=model) 28 | 29 | _export_ = dict(trainer=runner, validator=runner) 30 | -------------------------------------------------------------------------------- /vq/algorithms/vq/callbacks/normalize.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'NormalizeCallback', 3 | ] 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from todd.runners import Memo 8 | 9 | from vq.tasks.image_tokenization.models.quantizers import ( 10 | VQITQuantizerCallbackRegistry, 11 | ) 12 | from vq.tasks.image_tokenization.models.quantizers.callbacks import ( 13 | BaseCallback, 14 | ) 15 | 16 | from .update import UpdateMixin 17 | 18 | 19 | @VQITQuantizerCallbackRegistry.register_() 20 | class NormalizeCallback(UpdateMixin, BaseCallback): 21 | 22 | def before_encode(self, x: torch.Tensor, memo: Memo) -> torch.Tensor: 23 | x = super().before_encode(x, memo) 24 | x = F.normalize(x) 25 | 26 | e = self.vector_quantizer.embedding.weight 27 | e = F.normalize(e) 28 | self._update_embedding(e) 29 | return x 30 | -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | "options": { 3 | "env": { 4 | "CPU_": "12", 5 | "DRY_RUN": "True", 6 | "GPU": "2", 7 | "MEMORY": "60", 8 | "PORT": "5678" 9 | } 10 | }, 11 | "tasks": [ 12 | { 13 | "command": "tools/run.sh -m vq.train vqkd_clip_8192_imagenet_ddp configs/vqkd_clip_8192_imagenet_ddp.py", 14 | "isBackground": true, 15 | "label": "VQ-KD CLIP 8192 ImageNet DDP", 16 | "problemMatcher": { 17 | "background": { 18 | "activeOnStart": true, 19 | "beginsPattern": ".", 20 | "endsPattern": "." 
21 | }, 22 | "pattern": [ 23 | { 24 | "file": 1, 25 | "location": 2, 26 | "message": 3, 27 | "regexp": "File \"(?!\\/)([^\"]+)\", line (\\d+), in (\\w+)" 28 | } 29 | ] 30 | }, 31 | "type": "shell" 32 | } 33 | ], 34 | "version": "2.0.0" 35 | } 36 | -------------------------------------------------------------------------------- /vq/algorithms/vqkd/teachers/convnext.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'ConvNeXtTeacher', 3 | ] 4 | 5 | from typing import cast 6 | 7 | import einops 8 | import torch 9 | from todd.models import ConvNeXtRegistry 10 | from torch import nn 11 | from torchvision import models 12 | 13 | from ..registries import VQTeacherRegistry 14 | from .torchvision import TorchVisionTeacher 15 | 16 | 17 | @VQTeacherRegistry.register_() 18 | class ConvNeXtTeacher(TorchVisionTeacher[models.ConvNeXt]): 19 | REGISTRY = ConvNeXtRegistry 20 | 21 | @property 22 | def out_channels(self) -> int: 23 | layer_norm = cast(nn.LayerNorm, self._model.classifier[0]) 24 | return layer_norm.normalized_shape[0] 25 | 26 | def _forward(self, image: torch.Tensor, return_2d: bool) -> torch.Tensor: 27 | x = self._model.features(image) 28 | if not return_2d: 29 | x = einops.rearrange(x, 'b c h w -> b (h w) c') 30 | return x 31 | -------------------------------------------------------------------------------- /vq/datasets/vanilla.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'Dataset', 3 | ] 4 | 5 | import todd 6 | import torch 7 | import torchvision.transforms.functional as F 8 | from PIL import Image 9 | 10 | from ..registries import VQDatasetRegistry 11 | from .base import BaseMixin, T 12 | 13 | 14 | @VQDatasetRegistry.register_() 15 | class Dataset(BaseMixin[str, Image.Image], todd.datasets.PILDataset[T]): 16 | 17 | def _transform(self, image: Image.Image) -> torch.Tensor: 18 | if self._transforms is None: 19 | return F.pil_to_tensor(image) 20 | return self._transforms(image) 21 | 22 | def __getitem__(self, index: int) -> T: 23 | key, image = self._access(index) 24 | tensor = self._transform(image) 25 | encoded_tensor = self.encode(tensor) 26 | return T( 27 | id_=key, 28 | original_image=tensor, 29 | image=encoded_tensor, 30 | category=0, 31 | ) 32 | -------------------------------------------------------------------------------- /docs/data.md: -------------------------------------------------------------------------------- 1 | # Data 2 | 3 | Our experiments primarily use the **ImageNet-1k** dataset. Please organize the dataset in the following structure: 4 | 5 | ```text 6 | data/imagenet/ 7 | ├── annotations 8 | │   ├── train.json 9 | │   └── val.json 10 | ├── train 11 | │   ├── n1440764 12 | │   │   ├── 18.JPEG 13 | │   │   └── ... 14 | │   └── ... 15 | ├── val 16 | │   ├── n1440764 17 | │   │   ├── 293.JPEG 18 | │   │   └── ... 19 | │   └── ... 20 | └── synsets.json 21 | ``` 22 | 23 | Both ``train.json`` and ``val.json`` have the following structure: 24 | 25 | ```text 26 | [{"image":"12925.JPEG","synset_id":449},...] 
27 | ``` 28 | 29 | ``synsets.json`` maps synset IDs to their corresponding information: 30 | 31 | ```text 32 | {"1":{"WNID":"n02119789","words":"kit fox, Vulpes macrotis",...},...} 33 | ``` 34 | 35 | For detailed instructions on downloading and preparing the dataset, refer to 36 | -------------------------------------------------------------------------------- /vq/datasets/sa_med2d.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'SAMed2DDataset', 3 | ] 4 | 5 | from typing import cast 6 | 7 | import todd 8 | from PIL import Image 9 | 10 | from ..registries import VQDatasetRegistry 11 | from .base import T 12 | from .split import SplitMixin 13 | 14 | 15 | @VQDatasetRegistry.register_() 16 | class SAMed2DDataset( 17 | SplitMixin[str, Image.Image], 18 | todd.datasets.SAMed2DDataset, 19 | ): 20 | 21 | def __init__(self, *args, **kwargs) -> None: 22 | super().__init__(*args, shuffle=True, **kwargs) 23 | 24 | def __getitem__(self, index: int) -> T: # type: ignore[override] 25 | item = super().__getitem__( # type: ignore[safe-super] 26 | self._align_index(index), 27 | ) 28 | image = item['image'] 29 | item_ = cast(T, item) 30 | item_['original_image'] = image 31 | item_['image'] = self.encode(image) 32 | item_['category'] = 0 33 | return item_ 34 | -------------------------------------------------------------------------------- /configs/vqgan/README.md: -------------------------------------------------------------------------------- 1 | # VQGAN 2 | 3 | [[arXiv](https://arxiv.org/abs/2012.09841)] [[GitHub](https://github.com/CompVis/taming-transformers)] 4 | 5 | Training: 6 | 7 | ```bash 8 | auto_torchrun -m vq.train vqgan/8192_dd2_aglwg075_imagenet_ddp configs/vqgan/8192_dd2_aglwg075_imagenet_ddp.py 9 | ``` 10 | 11 | 25 | -------------------------------------------------------------------------------- /configs/fid/interface.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('batch_sizes', (128, 128)) 9 | _kwargs_.setdefault('num_workers', (10, 10)) 10 | _kwargs_.setdefault('image_size', 256) 11 | _kwargs_.setdefault('augmentation', 'none') 12 | 13 | _base_ = [ 14 | PyConfig.load('configs/datasets/interface.py', **_kwargs_), 15 | '../strategies/interface.py', 16 | ] 17 | 18 | runner = dict( 19 | type='VQRunnerRegistry.BaseValidator', 20 | model=dict(type='FIDModel'), 21 | callbacks=[ 22 | dict( 23 | type='LogCallback', 24 | interval=20, 25 | collect_env=dict(), 26 | eta=dict(type='EMA_ETA', ema=dict(decay=0.9)), 27 | with_file_handler=True, 28 | ), 29 | dict(type='FIDCallback'), 30 | ], 31 | ) 32 | 33 | _export_ = dict(trainer=runner, validator=runner) 34 | -------------------------------------------------------------------------------- /configs/fsq/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | codebook_size = _kwargs_['codebook_size'] 8 | 9 | _base_ = [ 10 | '../sq/interface.py', 11 | PyConfig.load('configs/vqgan/model.py', num_embeddings=None, **_kwargs_), 12 | ] 13 | 14 | NUM_SCALARS_PER_CHANNEL = { 15 | 8000: [8, 8, 5, 5, 5], 16 | 64000: [8, 8, 8, 5, 5, 5], 17 | } 18 | model = dict( 19 | post_encode=[ 20 | dict(type='ConvConnector', out_channels=256), 21 | 
dict(type='ConvConnector'), 22 | ], 23 | quantizer=dict( 24 | _delete_=True, 25 | type='FiniteScalarQuantizer', 26 | num_scalars_per_channel=NUM_SCALARS_PER_CHANNEL[codebook_size], 27 | ), 28 | pre_decode=[ 29 | dict(type='ConvConnector', out_channels=256), 30 | dict(type='ConvConnector'), 31 | ], 32 | ) 33 | runner = dict(model=model) 34 | 35 | _export_ = dict(trainer=runner, validator=runner) 36 | -------------------------------------------------------------------------------- /vq/models/autoencoders.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'BaseEncoder', 3 | 'BaseDecoder', 4 | ] 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import torch 9 | from todd.runners import Memo 10 | from torch import nn 11 | 12 | 13 | class BaseEncoder(nn.Module, ABC): 14 | 15 | @property 16 | @abstractmethod 17 | def out_channels(self) -> int: 18 | pass 19 | 20 | @abstractmethod 21 | def forward( 22 | self, 23 | image: torch.Tensor, 24 | memo: Memo, 25 | ) -> tuple[torch.Tensor, Memo]: 26 | pass 27 | 28 | 29 | class BaseDecoder(nn.Module, ABC): 30 | 31 | @property 32 | @abstractmethod 33 | def in_channels(self) -> int: 34 | pass 35 | 36 | @property 37 | @abstractmethod 38 | def last_parameter(self) -> nn.Parameter: 39 | pass 40 | 41 | @abstractmethod 42 | def forward( 43 | self, 44 | z: torch.Tensor, 45 | memo: Memo, 46 | ) -> tuple[torch.Tensor, Memo]: 47 | pass 48 | -------------------------------------------------------------------------------- /vq/tasks/sequence_modeling/models/registries.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'VQSMTransformerRegistry', 3 | 'VQSMSamplerRegistry', 4 | ] 5 | 6 | from typing import TYPE_CHECKING, Any 7 | 8 | import todd 9 | from todd.bases.registries import Item 10 | 11 | from ..registries import VQSMModelRegistry 12 | 13 | if TYPE_CHECKING: 14 | from .samplers import BaseSampler 15 | 16 | 17 | class VQSMTransformerRegistry(VQSMModelRegistry): 18 | pass 19 | 20 | 21 | class VQSMSamplerRegistry(VQSMModelRegistry): 22 | 23 | @classmethod 24 | def _build(cls, item: Item, config: todd.Config) -> Any: 25 | config = config.copy() 26 | cfg = config.pop('cfg', None) 27 | sampler: 'BaseSampler' = todd.RegistryMeta._build(cls, item, config) 28 | if cfg is not None: 29 | from .samplers import ( # pylint: disable=import-outside-toplevel 30 | CFGSampler, 31 | ) 32 | sampler = CFGSampler(sampler=sampler, alpha=cfg) 33 | return sampler 34 | -------------------------------------------------------------------------------- /vq/algorithms/vqkd/teachers/torchvision.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'TorchVisionTeacher', 3 | ] 4 | 5 | from typing import TypeVar 6 | 7 | import todd 8 | from todd.bases.registries import Item 9 | from todd.models import TorchVisionRegistry 10 | from torch import nn 11 | 12 | from .base import BaseTeacher 13 | 14 | T = TypeVar('T', bound=nn.Module) 15 | 16 | 17 | class TorchVisionTeacher(BaseTeacher[T]): 18 | REGISTRY: type[TorchVisionRegistry] 19 | 20 | def __init__( 21 | self, 22 | *args, 23 | mean=todd.datasets.IMAGENET_MEAN_255, 24 | std=todd.datasets.IMAGENET_STD_255, 25 | **kwargs, 26 | ) -> None: 27 | super().__init__(*args, mean=mean, std=std, **kwargs) 28 | 29 | @classmethod 30 | def model_build_pre_hook( 31 | cls, 32 | config: todd.Config, 33 | registry: todd.RegistryMeta, 34 | item: Item, 35 | ) -> todd.Config: 36 | config.model = 
cls.REGISTRY.build_or_return(config.model) 37 | return config 38 | -------------------------------------------------------------------------------- /vq/algorithms/ar/transformers/gpt.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'GPT2Transformer', 3 | ] 4 | 5 | import os 6 | 7 | import todd 8 | from todd.bases.registries import Item 9 | from transformers import GPT2LMHeadModel 10 | 11 | from vq.tasks.sequence_modeling.models import VQSMTransformerRegistry 12 | from vq.utils import Store 13 | 14 | from .hf import HFTransformer 15 | 16 | 17 | @VQSMTransformerRegistry.register_() 18 | class GPT2Transformer(HFTransformer): 19 | _transformer: GPT2LMHeadModel 20 | 21 | @classmethod 22 | def transformer_build_pre_hook( 23 | cls, 24 | config: todd.Config, 25 | registry: todd.RegistryMeta, 26 | item: Item, 27 | ) -> todd.Config: 28 | transformer = ( 29 | 'pretrained/huggingface/gpt2' 30 | if todd.Store.DRY_RUN else config.transformer 31 | ) 32 | transformer = os.path.join(Store.PRETRAINED, transformer) 33 | config.transformer = GPT2LMHeadModel.from_pretrained(transformer) 34 | return config 35 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/models/connectors/base.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'BaseConnector', 3 | ] 4 | 5 | import torch 6 | from todd.runners import Memo 7 | from torch import nn 8 | 9 | from ..registries import VQITConnectorRegistry 10 | 11 | 12 | @VQITConnectorRegistry.register_() 13 | class BaseConnector(nn.Module): 14 | 15 | def __init__( 16 | self, 17 | *args, 18 | in_channels: int, 19 | out_channels: int, 20 | **kwargs, 21 | ) -> None: 22 | super().__init__(*args, **kwargs) 23 | self._in_channels = in_channels 24 | self._out_channels = out_channels 25 | 26 | @property 27 | def in_channels(self) -> int: 28 | return self._in_channels 29 | 30 | @property 31 | def out_channels(self) -> int: 32 | return self._out_channels 33 | 34 | def forward( 35 | self, 36 | x: torch.Tensor, 37 | memo: Memo, 38 | ) -> tuple[torch.Tensor, Memo]: 39 | assert x.shape[1] == self._in_channels == self._out_channels 40 | return x, memo 41 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "cSpell.words": [ 3 | "aglw", 4 | "aglwg", 5 | "autoencoders", 6 | "beit", 7 | "CelebA", 8 | "codebook", 9 | "ConvNeXt", 10 | "convs", 11 | "CVQVAE", 12 | "debias", 13 | "downsample", 14 | "EVACLIP", 15 | "FFHQ", 16 | "HFAR", 17 | "huggingface", 18 | "ILSVRC", 19 | "ImageNet", 20 | "kmeans", 21 | "LAION", 22 | "LlamaGen", 23 | "logit", 24 | "LPIPS", 25 | "mult", 26 | "multihead", 27 | "multimodal", 28 | "multinomial", 29 | "mults", 30 | "OpenClip", 31 | "PatchGAN", 32 | "pretrained", 33 | "PSNR", 34 | "quantizer", 35 | "quantizers", 36 | "regexes", 37 | "SSIM", 38 | "StyleGAN", 39 | "tokenizes", 40 | "unbatched", 41 | "uncondition", 42 | "VGGFace2", 43 | "ViTbase", 44 | "VQGAN", 45 | "VQIC", 46 | "VQIR", 47 | "VQIT", 48 | "VQKD", 49 | "VQSM" 50 | ], 51 | "python.analysis.exclude": [ 52 | "**/data", 53 | "**/pretrained", 54 | "**/tensorboards", 55 | "**/work_dirs" 56 | ] 57 | } 58 | -------------------------------------------------------------------------------- /vq/datasets/laion_aesthetics.py: -------------------------------------------------------------------------------- 1 | 
__all__ = [ 2 | 'LAIONAestheticsDataset', 3 | ] 4 | 5 | import random 6 | from typing import cast 7 | 8 | import todd 9 | from PIL import Image 10 | from todd.utils import retry 11 | 12 | from ..registries import VQDatasetRegistry 13 | from .base import T 14 | from .split import SplitMixin 15 | 16 | 17 | @VQDatasetRegistry.register_() 18 | class LAIONAestheticsDataset( 19 | SplitMixin[str, Image.Image], 20 | todd.datasets.LAIONAestheticsDataset, 21 | ): 22 | 23 | def _rand_index(self) -> int: 24 | return random.randint(0, len(self) - 1) # nosec B311 25 | 26 | @retry(10) 27 | def __getitem__(self, index: int, *, retry_times: int) -> T: 28 | if retry_times > 0: 29 | index = self._rand_index() 30 | item = super().__getitem__( # type: ignore[safe-super] 31 | self._align_index(index), 32 | ) 33 | image = item['image'] 34 | item_ = cast(T, item) 35 | item_['original_image'] = image 36 | item_['image'] = self.encode(image) 37 | item_['category'] = 0 38 | return item_ 39 | -------------------------------------------------------------------------------- /configs/cluster/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('embedding_dim', 32) 9 | _kwargs_.setdefault('distance', 'Cosine') 10 | 11 | _base_ = [ 12 | PyConfig.load('configs/vq/interface.py', **_kwargs_), 13 | 'encoders/interface.py', 14 | ] 15 | 16 | model = dict( 17 | type='VQModelRegistry.Cluster', 18 | encoder=dict(type='ClusterEncoder'), 19 | post_encode=dict(type='BaseConnector'), 20 | quantizer=dict( 21 | type='VQGANQuantizer', 22 | losses=dict(vqgan_loss=dict(type='CodebookLoss')), 23 | init_weights=dict(type='vqgan'), 24 | callbacks=[ 25 | dict( 26 | type='CVQVAECallback', 27 | ema=dict(), 28 | anchor=dict(type='NearestAnchor', sync=True), 29 | ), 30 | ], 31 | ), 32 | freeze=dict(type='NamedModulesFilter', name='_encoder'), 33 | filter_state_dict=True, 34 | ) 35 | runner = dict(model=model) 36 | 37 | _export_ = dict(trainer=runner, validator=runner) 38 | -------------------------------------------------------------------------------- /vq/algorithms/vqkd/connector.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'VQKDConnector', 3 | ] 4 | 5 | import torch 6 | from todd.runners import Memo 7 | from torch import nn 8 | 9 | from vq.tasks.image_tokenization.models import VQITConnectorRegistry 10 | from vq.tasks.image_tokenization.models.connectors import BaseConnector 11 | 12 | 13 | @VQITConnectorRegistry.register_() 14 | class VQKDConnector(BaseConnector): 15 | 16 | def __init__(self, *args, **kwargs) -> None: 17 | super().__init__(*args, **kwargs) 18 | conv1 = nn.Conv2d( 19 | self._in_channels, 20 | self._in_channels, 21 | 3, 22 | padding=1, 23 | ) 24 | tanh = nn.Tanh() 25 | conv2 = nn.Conv2d( 26 | self._in_channels, 27 | self._out_channels, 28 | 3, 29 | padding=1, 30 | ) 31 | self._sequential = nn.Sequential(conv1, tanh, conv2) 32 | 33 | def forward( 34 | self, 35 | x: torch.Tensor, 36 | memo: Memo, 37 | ) -> tuple[torch.Tensor, Memo]: 38 | x = self._sequential(x) 39 | return x, memo 40 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/models/quantizers/callbacks/lazy_init_weights.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'LazyInitWeightsMixin', 3 | ] 4 | 5 | 
from abc import abstractmethod 6 | 7 | import todd 8 | import torch 9 | from todd.runners import Memo 10 | 11 | from ..base import BaseQuantizer 12 | from .base import BaseCallback 13 | 14 | 15 | class LazyInitWeightsMixin(BaseCallback): 16 | 17 | @abstractmethod 18 | def lazy_init_weights( 19 | self, 20 | config: todd.Config, 21 | x: torch.Tensor, 22 | memo: Memo, 23 | ) -> None: 24 | pass 25 | 26 | def before_init_weights(self, config: todd.Config) -> None: 27 | super().before_init_weights(config) 28 | lazy_init_weights = config.pop('lazy_init_weights', todd.Config()) 29 | 30 | def forward_pre_hook( 31 | module: BaseQuantizer, 32 | args: tuple[torch.Tensor, Memo], 33 | ) -> None: 34 | x, memo = args 35 | self.lazy_init_weights(lazy_init_weights, x, memo) 36 | handle.remove() 37 | 38 | handle = self.quantizer.register_forward_pre_hook(forward_pre_hook) 39 | -------------------------------------------------------------------------------- /configs/vqgan/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('embedding_dim', 256) 9 | _kwargs_.setdefault('distance', 'L2') 10 | 11 | _base_ = [ 12 | PyConfig.load('configs/vq/interface.py', **_kwargs_), 13 | ] 14 | 15 | model = dict( 16 | type='VQModelRegistry.VQGAN', 17 | encoder=dict(type='VQGANEncoder'), 18 | post_encode=dict(type='ConvConnector'), 19 | quantizer=dict( 20 | type='VQGANQuantizer', 21 | losses=dict(vqgan_loss=dict(type='VQGANLoss')), 22 | init_weights=dict(type='vqgan'), 23 | ), 24 | pre_decode=dict(type='ConvConnector'), 25 | decoder=dict(type='VQGANDecoder'), 26 | discriminator=dict(type='PatchGANDiscriminator'), 27 | reconstruct_losses=dict( 28 | l1_r_loss=dict(type='L1Loss'), 29 | lpips_r_loss=dict(type='LPIPSLoss'), 30 | ), 31 | generator_loss=dict(type='VQGANGeneratorLoss'), 32 | discriminator_loss=dict(type='VQGANDiscriminatorLoss'), 33 | ) 34 | runner = dict(model=model) 35 | 36 | _export_ = dict(trainer=runner, validator=runner) 37 | -------------------------------------------------------------------------------- /configs/llamagen/vqgan.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('batch_sizes', (128, 128)) 9 | _kwargs_.setdefault('num_workers', (10, 10)) 10 | _kwargs_.setdefault('num_embeddings', 16384) 11 | _kwargs_.setdefault('embedding_dim', 8) 12 | 13 | _base_ = [ 14 | PyConfig.load('configs/vqgan/interface.py', **_kwargs_), 15 | ] 16 | 17 | model = dict( 18 | quantizer=dict(callbacks=[ 19 | dict(type='NormalizeCallback'), 20 | ]), 21 | reconstruct_losses=dict( 22 | l1_r_loss=None, 23 | l2_r_loss=dict(type='MSELoss'), 24 | ), 25 | generator_loss=dict(weight=0.5), 26 | discriminator_loss=dict(weight=0.5), 27 | adaptive_generator_loss_weight_gain=None, 28 | ) 29 | trainer = dict( 30 | model=model, 31 | discriminator_start=20_000, 32 | optimizers=dict( 33 | generator=dict(lr=1e-4, betas=(0.9, 0.95)), 34 | discriminator=dict(lr=1e-4, betas=(0.9, 0.95)), 35 | ), 36 | iters=400_000, 37 | ) 38 | validator = dict(model=model) 39 | 40 | _export_ = dict(trainer=trainer, validator=validator) 41 | -------------------------------------------------------------------------------- /configs/ar/README.md: 
-------------------------------------------------------------------------------- 1 | # Auto-Regressive Proposal Networks 2 | 3 | The proposal network loads a pre-trained image reconstruction model. The config file of the image reconstruction model is specified via the `ir_config` field of `--config-options`. The state dict of the image reconstruction model is specified via the `ir_state_dict` field of `--config-options`. If the image reconstruction model relies on an image tokenizer that performs feature reconstruction, like VQ-KD, the `ir_state_dict` will not include the state dict of the image tokenizer. Therefore, an additional `it_state_dict` field is required to specify the state dict of the image tokenizer. 4 | 5 | ```bash 6 | auto_torchrun -m vq.train ar/c2i_llama_medium_cfg_imagenet_ddp configs/ar/c2i_llama_medium_cfg_imagenet_ddp.py \ 7 | --config-options \ 8 | ir_config::work_dirs/decoder/llamagen/vqkd_clip_8192_imagenet_ddp/llamagen_8192_dd2_aglwg075_imagenet_ddp/llamagen.py \ 9 | it_state_dict::work_dirs/vqkd/clip_8192_imagenet_ddp/checkpoints/iter_250000/model.pth \ 10 | ir_state_dict::work_dirs/decoder/llamagen/vqkd_clip_8192_imagenet_ddp/llamagen_8192_dd2_aglwg075_imagenet_ddp/checkpoints/iter_400000/model.pth 11 | ``` 12 | -------------------------------------------------------------------------------- /vq/algorithms/vqgan/losses/generator.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'BaseGeneratorLoss', 3 | 'VQGANGeneratorLoss', 4 | 'NonSaturatingLoss', 5 | ] 6 | 7 | from abc import ABC, abstractmethod 8 | 9 | import todd 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from .registries import VQGeneratorLossRegistry 14 | 15 | 16 | class BaseGeneratorLoss(todd.models.losses.BaseLoss, ABC): 17 | 18 | @abstractmethod 19 | def forward( # pylint: disable=arguments-differ 20 | self, 21 | logits_fake: torch.Tensor, 22 | ) -> torch.Tensor: 23 | pass 24 | 25 | 26 | @VQGeneratorLossRegistry.register_() 27 | class VQGANGeneratorLoss(BaseGeneratorLoss): 28 | 29 | def forward(self, logits_fake: torch.Tensor) -> torch.Tensor: 30 | loss = -logits_fake 31 | return self._reduce(loss) 32 | 33 | 34 | @VQGeneratorLossRegistry.register_() 35 | class NonSaturatingLoss(BaseGeneratorLoss): 36 | 37 | def forward(self, logits_fake: torch.Tensor) -> torch.Tensor: 38 | loss = F.binary_cross_entropy_with_logits( 39 | logits_fake, 40 | torch.ones_like(logits_fake), 41 | ) 42 | return self._reduce(loss) 43 | -------------------------------------------------------------------------------- /vq/runners/metrics/loss.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'ImageLossMetric', 3 | ] 4 | 5 | from typing import TypeVar 6 | 7 | import einops 8 | import torch 9 | from todd.patches.py_ import get_ 10 | from todd.runners import Memo 11 | from todd.runners.metrics import LossMetric 12 | from torch import nn 13 | 14 | from ...datasets import Batch 15 | from ..base import BaseMixin 16 | from ..registries import VQMetricRegistry 17 | 18 | T = TypeVar('T', bound=nn.Module) 19 | 20 | 21 | @VQMetricRegistry.register_() 22 | class ImageLossMetric(LossMetric[T]): 23 | runner: BaseMixin[T] 24 | 25 | def __init__(self, *args, pred_image: str, image: str, **kwargs) -> None: 26 | inputs = dict(pred_image=pred_image, image=image) 27 | super().__init__(*args, inputs=inputs, **kwargs) 28 | 29 | def _forward(self, batch: Batch, memo: Memo) -> tuple[torch.Tensor, Memo]: 30 | inputs = {k: 
get_(memo, v) for k, v in self._inputs.items()} 31 | inputs = { 32 | k: self.runner.dataset.decode(v) / 255 33 | for k, v in inputs.items() 34 | } 35 | loss = self._loss(**inputs) 36 | loss = einops.reduce(loss, 'b ... -> b', reduction='mean') 37 | return loss, memo 38 | -------------------------------------------------------------------------------- /vq/utils/builders.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'build_module_list', 3 | 'build_module_dict', 4 | 'build_sequential', 5 | ] 6 | 7 | from typing import Iterable 8 | 9 | import todd 10 | from todd.patches.torch import ModuleDict, ModuleList, Sequential 11 | 12 | 13 | def build_module_list( 14 | registry: todd.RegistryMeta, 15 | config: Iterable[todd.Config | None], 16 | **kwargs, 17 | ) -> ModuleList: 18 | module_list = [ 19 | registry.build_or_return(c, **kwargs) for c in config if c is not None 20 | ] 21 | return ModuleList(module_list) 22 | 23 | 24 | def build_module_dict( 25 | registry: todd.RegistryMeta, 26 | config: todd.Config, 27 | **kwargs, 28 | ) -> ModuleDict: 29 | module_dict = { 30 | k: registry.build_or_return(v, **kwargs) 31 | for k, v in config.items() 32 | if v is not None 33 | } 34 | return ModuleDict(module_dict) 35 | 36 | 37 | def build_sequential( 38 | registry: todd.RegistryMeta, 39 | config: Iterable[todd.Config | None], 40 | **kwargs, 41 | ) -> Sequential: 42 | sequential = [ 43 | registry.build_or_return(c, **kwargs) for c in config if c is not None 44 | ] 45 | return Sequential(*sequential) 46 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | from setuptools import setup 4 | 5 | 6 | def todd_ai() -> str: 7 | path = pathlib.Path(__file__) 8 | path = path.parent / '.todd_version' 9 | todd_version = path.read_text().strip() 10 | return ( 11 | 'todd_ai[optional,dev,lint,doc,test] @ ' 12 | f'git+https://github.com/LutingWang/todd.git@{todd_version}' 13 | ) 14 | 15 | 16 | def symlink_configs() -> None: 17 | path = pathlib.Path(__file__) 18 | path = path.parent / 'vq' / 'configs' 19 | if path.exists(): 20 | return 21 | path.symlink_to('../configs') 22 | 23 | 24 | def symlink_todd_version() -> None: 25 | path = pathlib.Path(__file__) 26 | path = path.parent / 'vq' / '.todd_version' 27 | if path.exists(): 28 | return 29 | path.symlink_to('../.todd_version') 30 | 31 | 32 | symlink_configs() 33 | symlink_todd_version() 34 | setup( 35 | install_requires=[ 36 | 'accelerate', 37 | 'debugpy', 38 | 'protobuf<=3.20.1', 39 | 'scikit-image', 40 | 'torch_fidelity', 41 | 'transformers==4.35.2', 42 | 'openmim', 43 | 'lvis @ git+https://github.com/lvis-dataset/lvis-api.git@' 44 | 'lvis_challenge_2021', 45 | todd_ai(), 46 | ], 47 | ) 48 | -------------------------------------------------------------------------------- /vq/datasets/satin.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'SATINDataset', 3 | ] 4 | 5 | from typing import Any 6 | 7 | import todd 8 | 9 | from ..registries import VQDatasetRegistry 10 | from .base import T 11 | from .split import SplitMixin 12 | 13 | 14 | @VQDatasetRegistry.register_() 15 | class SATINDataset( 16 | SplitMixin[int, dict[str, Any]], 17 | todd.datasets.SATINDataset, 18 | ): 19 | 20 | def __init__(self, *args, **kwargs) -> None: 21 | super().__init__(*args, shuffle=True, **kwargs) 22 | 23 | def _access( # type: ignore[override] 
24 | self, 25 | index: int, 26 | ) -> tuple[int, dict[str, Any]]: 27 | return todd.datasets.SATINDataset._access( # type: ignore[return-value] # noqa: E501 pylint: disable=line-too-long 28 | self, # type: ignore[arg-type] 29 | index, 30 | ) 31 | 32 | def __getitem__(self, index: int) -> T: # type: ignore[override] 33 | item = super().__getitem__( # type: ignore[safe-super] 34 | self._align_index(index), 35 | ) 36 | id_ = item['id_'] 37 | image = item['image'] 38 | return T( 39 | id_=str(id_), 40 | image=self.encode(image), 41 | original_image=image, 42 | category=0, 43 | ) 44 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/models/quantizers/callbacks/base.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'BaseCallback', 3 | ] 4 | 5 | import todd 6 | import torch 7 | from todd.runners import Memo 8 | 9 | from ..utils import QuantizerHolderMixin 10 | 11 | 12 | class BaseCallback(QuantizerHolderMixin): 13 | 14 | def before_init_weights(self, config: todd.Config) -> None: 15 | pass 16 | 17 | def after_init_weights(self, config: todd.Config, recursive: bool) -> bool: 18 | return recursive 19 | 20 | def before_encode(self, x: torch.Tensor, memo: Memo) -> torch.Tensor: 21 | return x 22 | 23 | def after_encode( 24 | self, 25 | x: torch.Tensor, 26 | quant: torch.Tensor, 27 | memo: Memo, 28 | ) -> torch.Tensor: 29 | return quant 30 | 31 | def before_decode(self, quant: torch.Tensor, memo: Memo) -> torch.Tensor: 32 | return quant 33 | 34 | def after_decode(self, z: torch.Tensor, memo: Memo) -> torch.Tensor: 35 | return z 36 | 37 | def before_loss( 38 | self, 39 | z: torch.Tensor, 40 | x: torch.Tensor, 41 | memo: Memo, 42 | ) -> tuple[torch.Tensor, torch.Tensor]: 43 | return z, x 44 | 45 | def after_loss(self, loss: torch.Tensor, memo: Memo) -> torch.Tensor: 46 | return loss 47 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | SHELL := zsh 2 | 3 | current_todd_version := $$(cat .todd_version) 4 | latest_todd_version := $(shell curl -H "Accept: application/vnd.github.sha" -s https://api.github.com/repos/LutingWang/todd/commits/main) 5 | 6 | define install_todd 7 | pipenv run pip uninstall -y todd_ai 8 | GIT_LFS_SKIP_SMUDGE=1 pipenv run pip install \ 9 | "todd_ai @ git+https://github.com/LutingWang/todd.git@$(1)" 10 | pipenv run pip uninstall -y opencv-python opencv-python-headless 11 | pipenv run pip install opencv-python-headless 12 | endef 13 | 14 | .PHONY: lint commit install_todd todd tb 15 | 16 | lint: 17 | pipenv run pre-commit run -a 18 | 19 | commit: 20 | pipenv run cz c 21 | 22 | install_todd: 23 | $(call install_todd,$(current_todd_version)) 24 | 25 | todd: 26 | if [[ "$(latest_todd_version)" == "$(current_todd_version)" ]]; then \ 27 | echo "No changes since last build."; \ 28 | exit 1; \ 29 | fi 30 | $(call install_todd,$(latest_todd_version)) 31 | echo $(latest_todd_version) > .todd_version 32 | 33 | tb: 34 | mkdir -p tensorboards 35 | for work_dir in work_dirs/*; do \ 36 | name=$$(basename $${work_dir}); \ 37 | tb_dir=$${work_dir}/tensorboard; \ 38 | if [[ -d $${tb_dir} ]]; then \ 39 | ln -sfT $$(realpath $${tb_dir}) tensorboards/$${name}; \ 40 | fi; \ 41 | done 42 | tensorboard --logdir tensorboards --bind_all 43 | -------------------------------------------------------------------------------- /vq/algorithms/vq/distances.py: 
-------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'VQITQuantizerDistanceRegistry', 3 | 'BaseDistance', 4 | 'L2Distance', 5 | 'CosineDistance', 6 | ] 7 | 8 | from abc import ABC, abstractmethod 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torch import nn 13 | 14 | from vq.tasks.image_tokenization.models import VQITQuantizerRegistry 15 | 16 | 17 | class VQITQuantizerDistanceRegistry(VQITQuantizerRegistry): 18 | pass 19 | 20 | 21 | class BaseDistance(nn.Module, ABC): 22 | 23 | @abstractmethod 24 | def forward(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor: 25 | pass 26 | 27 | 28 | @VQITQuantizerDistanceRegistry.register_() 29 | class L2Distance(BaseDistance): 30 | 31 | def forward(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor: 32 | return torch.cdist(x, e) 33 | 34 | 35 | @VQITQuantizerDistanceRegistry.register_() 36 | class CosineDistance(BaseDistance): 37 | 38 | @staticmethod 39 | def cosine_similarity(x: torch.Tensor, e: torch.Tensor) -> torch.Tensor: 40 | # fixed CUDA out of memory error of torch.cosine_similarity 41 | x = F.normalize(x) 42 | e = F.normalize(e) 43 | return torch.einsum('x d, e d -> x e', x, e) 44 | 45 | def forward(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor: 46 | return 1 - self.cosine_similarity(x, e) 47 | -------------------------------------------------------------------------------- /configs/datasets/hq_faces.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'vanilla.py', 3 | ] 4 | 5 | trainer = dict( 6 | dataset=dict( 7 | name='hq_faces_train', 8 | num_categories=1, 9 | access_layer=dict( 10 | type='ConcatAccessLayer', 11 | access_layers=dict( 12 | celeba_hq=dict( 13 | type='PILAccessLayer', 14 | data_root='data/celeba-hq-256/', 15 | task_name='train', 16 | ), 17 | ffhq=dict( 18 | type='PILAccessLayer', 19 | data_root='data/ffhq-256/', 20 | task_name='train', 21 | ), 22 | ), 23 | ), 24 | ), 25 | ) 26 | validator = dict( 27 | dataset=dict( 28 | name='hq_faces_val', 29 | num_categories=1, 30 | access_layer=dict( 31 | type='ConcatAccessLayer', 32 | access_layers=dict( 33 | celeba_hq=dict( 34 | type='PILAccessLayer', 35 | data_root='data/celeba-hq-256/', 36 | task_name='val', 37 | ), 38 | ffhq=dict( 39 | type='PILAccessLayer', 40 | data_root='data/ffhq-256/', 41 | task_name='val', 42 | ), 43 | ), 44 | ), 45 | ), 46 | ) 47 | -------------------------------------------------------------------------------- /vq/tasks/sequence_modeling/runners/metrics.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'AccuracyMetric', 3 | ] 4 | 5 | from typing import TypeVar 6 | 7 | import torch 8 | from todd.patches.py_ import get_ 9 | from todd.runners import Memo 10 | from todd.runners.metrics import Metric 11 | from torch import nn 12 | 13 | from vq.datasets import Batch 14 | 15 | from .registries import VQSMMetricRegistry 16 | 17 | T = TypeVar('T', bound=nn.Module) 18 | 19 | 20 | @VQSMMetricRegistry.register_() 21 | class AccuracyMetric(Metric[T]): 22 | 23 | def __init__( 24 | self, 25 | *args, 26 | pred: str, 27 | target: str, 28 | **kwargs, 29 | ) -> None: 30 | super().__init__(*args, **kwargs) 31 | self._pred = pred 32 | self._target = target 33 | 34 | def _forward(self, batch: Batch, memo: Memo) -> tuple[torch.Tensor, Memo]: 35 | log: Memo | None = memo.get('log') 36 | pred: torch.Tensor = get_(memo, self._pred) 37 | target: torch.Tensor = get_(memo, self._target) 38 | 
assert not pred.dtype.is_floating_point 39 | assert not target.dtype.is_floating_point 40 | pred = pred.reshape(pred.shape[0], -1) 41 | target = target.reshape(target.shape[0], -1) 42 | accuracy = pred == target 43 | accuracy = accuracy.float().mean(-1) 44 | if log is not None: 45 | log[self._name] = f'{accuracy.mean():.3f}' 46 | return accuracy, memo 47 | -------------------------------------------------------------------------------- /vq/tasks/sequence_modeling/runners/base.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'BaseMixin', 3 | 'BaseTrainer', 4 | 'BaseValidator', 5 | ] 6 | 7 | from typing import TypeVar 8 | 9 | import todd 10 | from todd.bases.registries import Item 11 | from torch import nn 12 | 13 | from vq import VQModelRegistry 14 | from vq.datasets import BaseMixin as BaseDatasetMixin 15 | from vq.runners import BaseMixin as BaseMixin_ 16 | from vq.runners import BaseTrainer as BaseTrainer_ 17 | from vq.runners import BaseValidator as BaseValidator_ 18 | 19 | from ..registries import VQSMRunnerRegistry 20 | 21 | T = TypeVar('T', bound=nn.Module) 22 | 23 | 24 | class BaseMixin(BaseMixin_[T]): 25 | 26 | @classmethod 27 | def model_build_pre_hook( 28 | cls, 29 | config: todd.Config, 30 | registry: todd.RegistryMeta, 31 | item: Item, 32 | ) -> todd.Config: 33 | # ensure that dataset is built 34 | config = cls.dataset_build_pre_hook(config, registry, item) 35 | dataset: BaseDatasetMixin = config.dataset 36 | config.model = VQModelRegistry.build_or_return( 37 | config.model, 38 | num_categories=dataset.num_categories, 39 | ) 40 | return config 41 | 42 | 43 | @VQSMRunnerRegistry.register_() 44 | class BaseTrainer(BaseMixin[T], BaseTrainer_[T]): 45 | pass 46 | 47 | 48 | @VQSMRunnerRegistry.register_() 49 | class BaseValidator(BaseMixin[T], BaseValidator_[T]): 50 | pass 51 | -------------------------------------------------------------------------------- /vq/fid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from todd.configs import PyConfig 4 | from todd.patches.py_ import DictAction 5 | from todd.registries import RunnerRegistry 6 | from torch import nn 7 | 8 | from .runners import BaseValidator 9 | from .utils import log 10 | 11 | 12 | def parse_args() -> argparse.Namespace: 13 | parser = argparse.ArgumentParser(description='Test') 14 | parser.add_argument('dataset') 15 | parser.add_argument('--train', action='store_true') 16 | parser.add_argument('--strategy', default='cuda') 17 | parser.add_argument('--override', action=DictAction, default=dict()) 18 | parser.add_argument('--autocast', action='store_true') 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | def main() -> None: 24 | args = parse_args() 25 | config: PyConfig = PyConfig.load( 26 | 'configs/fid/interface.py', 27 | dataset=args.dataset, 28 | strategy=args.strategy, 29 | ) 30 | config = config.trainer if args.train else config.validator 31 | config.override(args.override) 32 | 33 | name = args.dataset 34 | if args.train: 35 | name += '_train' 36 | name += f'_{args.strategy}' 37 | 38 | runner: BaseValidator[nn.Module] = RunnerRegistry.build( 39 | config, 40 | name=f'fid/{name}', 41 | autocast=args.autocast, 42 | ) 43 | log(runner, args, config) 44 | runner.run() 45 | 46 | 47 | if __name__ == '__main__': 48 | main() 49 | -------------------------------------------------------------------------------- /vq/algorithms/ar/transformers/llama.py: 
-------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'LlamaTransformer', 3 | ] 4 | 5 | import todd 6 | from todd.bases.registries import Item 7 | from torch import nn 8 | from transformers import LlamaConfig, LlamaForCausalLM 9 | 10 | from vq.tasks.sequence_modeling.models import VQSMTransformerRegistry 11 | 12 | from .hf import HFTransformer 13 | 14 | 15 | @VQSMTransformerRegistry.register_() 16 | class LlamaTransformer(HFTransformer): 17 | _transformer: LlamaForCausalLM 18 | 19 | @classmethod 20 | def transformer_build_pre_hook( 21 | cls, 22 | config: todd.Config, 23 | registry: todd.RegistryMeta, 24 | item: Item, 25 | ) -> todd.Config: 26 | llama_config = LlamaConfig(**config.transformer) 27 | config.transformer = LlamaForCausalLM(llama_config) 28 | return config 29 | 30 | def init_weights(self, config: todd.Config) -> bool: 31 | 32 | def initializer(module: nn.Module) -> None: 33 | if isinstance(module, nn.Linear): 34 | module.weight.data.normal_(mean=0.0, std=0.02) 35 | if module.bias is not None: 36 | module.bias.data.zero_() 37 | elif isinstance(module, nn.Embedding): 38 | module.weight.data.normal_(mean=0.0, std=0.02) 39 | 40 | super().init_weights(config) 41 | self._transformer.apply(initializer) 42 | nn.init.constant_(self._transformer.lm_head.weight, 0.0) 43 | return False 44 | -------------------------------------------------------------------------------- /vq/algorithms/vq/utils.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'QuantStatistics', 3 | ] 4 | 5 | import functools 6 | from typing_extensions import Self 7 | 8 | import torch 9 | import torch.distributed 10 | from todd.patches.torch import get_world_size 11 | 12 | 13 | class QuantStatistics: 14 | 15 | def __init__( 16 | self, 17 | *args, 18 | quant: torch.Tensor, 19 | codebook_size: int, 20 | sync: bool = False, 21 | **kwargs, 22 | ) -> None: 23 | super().__init__(*args, **kwargs) 24 | self._quant = quant 25 | self._codebook_size = codebook_size 26 | self._sync = sync and get_world_size() > 1 27 | 28 | @staticmethod 29 | def sync_decorator(func): 30 | 31 | @functools.wraps(func) 32 | def wrapper(self: Self, *args, **kwargs) -> torch.Tensor: 33 | tensor = func(self, *args, **kwargs) 34 | if self._sync: 35 | torch.distributed.all_reduce(tensor) 36 | return tensor 37 | 38 | return wrapper 39 | 40 | @sync_decorator 41 | def bin_count(self) -> torch.Tensor: 42 | return self._quant.bincount(minlength=self._codebook_size) 43 | 44 | @sync_decorator 45 | def num_elements(self) -> torch.Tensor: 46 | return self._quant.new_tensor(self._quant.numel()) 47 | 48 | def frequency(self) -> torch.Tensor: 49 | bin_count = self.bin_count() 50 | numel = self.num_elements() 51 | frequency = bin_count / numel 52 | return frequency 53 | -------------------------------------------------------------------------------- /vq/datasets/split.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'SplitMixin', 3 | ] 4 | 5 | from abc import ABC 6 | from typing import TypeVar 7 | 8 | import todd 9 | 10 | from ..registries import VQDatasetRegistry 11 | from .base import BaseMixin 12 | 13 | KT = TypeVar('KT') 14 | VT = TypeVar('VT') 15 | 16 | 17 | @VQDatasetRegistry.register_() 18 | class SplitMixin(BaseMixin[KT, VT], ABC): 19 | 20 | def __init__( 21 | self, 22 | *args, 23 | train: bool, 24 | num_val_samples: int = 25_000, 25 | shuffle: bool = False, 26 | **kwargs, 27 | ) -> None: 28 | self._train = 
train 29 | self._num_val_samples = num_val_samples 30 | self._shuffle = shuffle 31 | super().__init__(*args, **kwargs) 32 | 33 | def __len__(self) -> int: 34 | if todd.Store.DRY_RUN: 35 | return super().__len__() 36 | if self._train: 37 | return super().__len__() - self._num_val_samples 38 | return self._num_val_samples 39 | 40 | def _align_index(self, index: int) -> int: 41 | if todd.Store.DRY_RUN or not self._shuffle: 42 | return index if self._train else super().__len__() - 1 - index 43 | 44 | chunk_size = super().__len__() // self._num_val_samples 45 | assert chunk_size > 1 46 | 47 | if self._train: 48 | chunk_id = index // (chunk_size - 1) 49 | chunk_id = min(chunk_id, self._num_val_samples - 1) 50 | return index + 1 + chunk_id 51 | return index * chunk_size 52 | -------------------------------------------------------------------------------- /vq/runners/metrics/fid.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'FIDMetric', 3 | ] 4 | 5 | from typing import TypeVar 6 | 7 | import todd.tasks.image_generation as ig 8 | from todd.patches.py_ import get_ 9 | from todd.patches.torch import load 10 | from todd.runners import Memo 11 | from todd.runners.metrics import BaseMetric 12 | from torch import nn 13 | 14 | from ...datasets import Batch 15 | from ..base import BaseMixin 16 | from ..registries import VQMetricRegistry 17 | 18 | T = TypeVar('T', bound=nn.Module) 19 | 20 | 21 | @VQMetricRegistry.register_() 22 | class FIDMetric(BaseMetric[T]): 23 | runner: BaseMixin[T] 24 | 25 | def __init__( 26 | self, 27 | *args, 28 | pred: str, 29 | eps: float = 1e-6, 30 | **kwargs, 31 | ) -> None: 32 | super().__init__(*args, **kwargs) 33 | self._pred = pred 34 | self._statistician = ig.Statistician() 35 | self._eps = eps 36 | 37 | def forward(self, batch: Batch, memo: Memo) -> Memo: 38 | pred_image = get_(memo, self._pred) 39 | image = self.runner.dataset.decode(pred_image) 40 | self._statistician(image) 41 | return memo 42 | 43 | def summary(self, memo: Memo) -> float: 44 | from ...utils import Store # pylint: disable=import-outside-toplevel 45 | gt_statistics = load( 46 | self.runner.dataset.fid_path, 47 | 'cpu', 48 | directory=Store.PRETRAINED, 49 | ) 50 | pred_statistics = self._statistician.summarize() 51 | return ig.fid(gt_statistics, pred_statistics, self._eps) 52 | -------------------------------------------------------------------------------- /vq/utils/fid.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'FIDModel', 3 | 'FIDCallback', 4 | ] 5 | 6 | from typing import Any, TypeVar, cast 7 | 8 | import todd.tasks.image_generation as ig 9 | import torch 10 | from todd.registries import ModelRegistry 11 | from todd.runners import CallbackRegistry, Memo 12 | from todd.runners.callbacks import BaseCallback 13 | from todd.utils import Store 14 | from torch import nn 15 | 16 | from ..datasets.base import T 17 | from ..runners import BaseValidator 18 | 19 | ModuleType = TypeVar('ModuleType', bound=nn.Module) 20 | 21 | 22 | @ModelRegistry.register_() 23 | class FIDModel(ig.Statistician): 24 | 25 | def forward( # type: ignore[override] # pylint: disable=arguments-differ 26 | self, 27 | runner: Any, 28 | batch: T, 29 | memo: Memo, 30 | *args, 31 | mode: None = None, 32 | **kwargs, 33 | ) -> Memo: 34 | assert mode is None 35 | original_image = batch['original_image'] 36 | if Store.cuda: 37 | original_image = original_image.cuda() 38 | super().forward(original_image) 39 | return memo 40 | 41 
| 42 | @CallbackRegistry.register_() 43 | class FIDCallback(BaseCallback[ModuleType]): 44 | runner: BaseValidator[ModuleType] 45 | 46 | def after_run(self, memo: Memo) -> None: 47 | super().after_run(memo) 48 | runner = self.runner 49 | model = cast(FIDModel, runner.strategy.module) 50 | statistics = model.summarize() 51 | torch.save(statistics, runner.dataset.fid_path) 52 | memo['statistics'] = statistics 53 | -------------------------------------------------------------------------------- /configs/datasets/satin.py: -------------------------------------------------------------------------------- 1 | NUM_SAMPLES = { 2 | 'NASC-TG2': 20000, 3 | 'WHU-RS19': 1005, 4 | 'RSSCN7': 2800, 5 | 'RS_C11': 1232, 6 | 'SIRI-WHU': 2400, 7 | 'NWPU-RESISC45': 31500, 8 | 'PatternNet': 30400, 9 | 'RSD46-WHU': 17516, 10 | 'CLRS': 15000, 11 | 'Optimal-31': 1860, 12 | 'Airbus-Wind-Turbines-Patches': 71504, 13 | 'USTC_SmokeRS': 6225, 14 | 'Satellite-Images-of-Hurricane-Damage': 10000, 15 | 'Million-AID': 10000, 16 | 'UC_Merced_LandUse_MultiLabel': 2100, 17 | 'MLRSNet': 109161, 18 | 'MultiScene': 14000, 19 | 'RSI-CB256': 24747, 20 | 'AID_MultiLabel': 3000, 21 | } 22 | 23 | trainer = dict( 24 | dataset=dict( 25 | type='VQDatasetRegistry.ConcatDataset', 26 | name='satin_train', 27 | num_categories=1, 28 | datasets=[ 29 | dict( 30 | type='VQDatasetRegistry.SATINDataset', 31 | split=split, 32 | num_val_samples=num_samples // 10, 33 | train=True, 34 | ) for split, num_samples in NUM_SAMPLES.items() 35 | ], 36 | ), 37 | ) 38 | validator = dict( 39 | dataset=dict( 40 | type='VQDatasetRegistry.ConcatDataset', 41 | name='satin_val', 42 | num_categories=1, 43 | datasets=[ 44 | dict( 45 | type='VQDatasetRegistry.SATINDataset', 46 | split=split, 47 | num_val_samples=num_samples // 10, 48 | train=False, 49 | ) for split, num_samples in NUM_SAMPLES.items() 50 | ], 51 | ), 52 | ) 53 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/runners/callbacks.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'Tokens', 3 | 'TokenizeCallback', 4 | ] 5 | 6 | import pathlib 7 | from typing import TypedDict, TypeVar 8 | 9 | import einops 10 | import torch 11 | from todd.patches.torch import get_rank 12 | from todd.runners import Memo 13 | from todd.runners.callbacks import BaseCallback 14 | from torch import nn 15 | 16 | from vq.datasets import Batch 17 | 18 | from .registries import VQITCallbackRegistry 19 | 20 | T = TypeVar('T', bound=nn.Module) 21 | 22 | 23 | class Tokens(TypedDict): 24 | id_: list[str] 25 | category: torch.Tensor 26 | tokens: torch.Tensor 27 | 28 | 29 | @VQITCallbackRegistry.register_() 30 | class TokenizeCallback(BaseCallback[T]): 31 | 32 | @property 33 | def token_dir(self) -> pathlib.Path: 34 | return self.runner.work_dir / 'tokens' 35 | 36 | def bind(self, *args, **kwargs) -> None: 37 | super().bind(*args, **kwargs) 38 | self.token_dir.mkdir(parents=True, exist_ok=True) 39 | 40 | def after_run_iter(self, batch: Batch, memo: Memo) -> None: 41 | super().after_run_iter(batch, memo) 42 | quantizer_memo = memo['quantizer'] 43 | quant = quantizer_memo['quant'] 44 | b, _, h, w = quantizer_memo['x_shape'] 45 | tokens = einops.rearrange(quant, '(b h w) -> b h w', b=b, h=h, w=w) 46 | torch.save( 47 | Tokens( 48 | id_=batch['id_'], 49 | category=batch['category'], 50 | tokens=tokens, 51 | ), 52 | self.token_dir / f'{self.runner.iter_}_{get_rank()}.pth', 53 | ) 54 | 
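The token shards written by `TokenizeCallback` above are saved once per iteration and per rank, so downstream code has to merge them. A minimal sketch of such a loader, assuming the default `work_dir/tokens` layout; `load_tokens` is a hypothetical helper, not part of this repository:

```python
import pathlib

import torch


def load_tokens(token_dir: str) -> dict[str, torch.Tensor]:
    """Merge the per-iteration, per-rank shards saved by TokenizeCallback."""
    tokens: dict[str, torch.Tensor] = {}
    for path in sorted(pathlib.Path(token_dir).glob('*_*.pth')):
        shard = torch.load(path, map_location='cpu')
        # `shard` follows the `Tokens` TypedDict: `tokens` is (b, h, w),
        # so pair each token grid with its sample id.
        for id_, grid in zip(shard['id_'], shard['tokens']):
            tokens[id_] = grid
    return tokens
```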
-------------------------------------------------------------------------------- /vq/algorithms/vq/callbacks/update.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'UpdateMixin', 3 | ] 4 | 5 | import todd 6 | import torch 7 | from todd.bases.registries import BuildPreHookMixin, Item 8 | from todd.utils import EMA 9 | 10 | from vq.tasks.image_tokenization.models.quantizers.callbacks import ( 11 | BaseCallback, 12 | ) 13 | 14 | 15 | class UpdateMixin(BuildPreHookMixin, BaseCallback): 16 | 17 | def __init__( 18 | self, 19 | *args, 20 | ema: EMA | None = None, 21 | **kwargs, 22 | ) -> None: 23 | super().__init__(*args, **kwargs) 24 | if ema is not None: 25 | self._ema = ema 26 | 27 | @classmethod 28 | def ema_build_pre_hook_mixin( 29 | cls, 30 | config: todd.Config, 31 | registry: todd.RegistryMeta, 32 | item: Item, 33 | ) -> todd.Config: 34 | if (ema := config.get('ema')) is not None: 35 | config.ema = EMA(**ema) 36 | return config 37 | 38 | @classmethod 39 | def build_pre_hook( 40 | cls, 41 | config: todd.Config, 42 | registry: todd.RegistryMeta, 43 | item: Item, 44 | ) -> todd.Config: 45 | config = super().build_pre_hook(config, registry, item) 46 | config = cls.ema_build_pre_hook_mixin(config, registry, item) 47 | return config 48 | 49 | @property 50 | def with_ema(self) -> bool: 51 | return hasattr(self, '_ema') 52 | 53 | def _update_embedding(self, e: torch.Tensor) -> None: 54 | if todd.Store.DRY_RUN: 55 | assert todd.utils.is_sync(e) 56 | self.vector_quantizer.embedding.weight.data = e 57 | -------------------------------------------------------------------------------- /docs/inference.md: -------------------------------------------------------------------------------- 1 | # Inference 2 | 3 | TBD. 
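Until this page is filled in, `vq.test` is the closest entry point for running a trained model. A hedged sketch, reusing the VQ-KD checkpoint path from configs/decoder/README.md; substitute your own run name, config, and checkpoint:

```bash
auto_torchrun -m vq.test vqkd_clip_8192_imagenet_ddp \
    configs/vqkd/clip_8192_imagenet_ddp.py \
    --load-model-from work_dirs/vqkd/clip_8192_imagenet_ddp/checkpoints/iter_250000/model.pth
```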
4 | 5 | 43 | -------------------------------------------------------------------------------- /vq/algorithms/cluster/autoencoders.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'ClusterEncoder', 3 | ] 4 | 5 | import todd 6 | import torch 7 | from todd.bases.registries import BuildPreHookMixin, Item 8 | from todd.runners import Memo 9 | 10 | from vq.algorithms.vqkd import VQTeacherRegistry 11 | from vq.algorithms.vqkd.teachers import BaseTeacher 12 | from vq.models import BaseEncoder, VQEncoderRegistry 13 | 14 | 15 | @VQEncoderRegistry.register_() 16 | class ClusterEncoder(BuildPreHookMixin, BaseEncoder): 17 | 18 | def __init__(self, *args, teacher: BaseTeacher, **kwargs) -> None: 19 | super().__init__(*args, **kwargs) 20 | self._teacher = teacher 21 | 22 | @classmethod 23 | def teacher_build_pre_hook( 24 | cls, 25 | config: todd.Config, 26 | registry: todd.RegistryMeta, 27 | item: Item, 28 | ) -> todd.Config: 29 | config.teacher = VQTeacherRegistry.build_or_return(config.teacher) 30 | return config 31 | 32 | @classmethod 33 | def build_pre_hook( 34 | cls, 35 | config: todd.Config, 36 | registry: todd.RegistryMeta, 37 | item: Item, 38 | ) -> todd.Config: 39 | config = super().build_pre_hook(config, registry, item) 40 | config = cls.teacher_build_pre_hook(config, registry, item) 41 | return config 42 | 43 | @property 44 | def out_channels(self) -> int: 45 | return self._teacher.out_channels 46 | 47 | def forward( 48 | self, 49 | image: torch.Tensor, 50 | memo: Memo, 51 | ) -> tuple[torch.Tensor, Memo]: 52 | x = self._teacher(memo['original_image'], return_2d=True) 53 | return x, memo 54 | -------------------------------------------------------------------------------- /vq/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib 3 | import pathlib 4 | 5 | from todd.configs import PyConfig 6 | from todd.patches.py_ import DictAction 7 | from torch import nn 8 | 9 | from .registries import VQRunnerRegistry 10 | from .runners import BaseValidator 11 | from .utils import log 12 | 13 | 14 | def parse_args() -> argparse.Namespace: 15 | parser = argparse.ArgumentParser(description='Test') 16 | parser.add_argument('name') 17 | parser.add_argument('config', type=pathlib.Path) 18 | parser.add_argument('--config-options', action=DictAction, default=dict()) 19 | parser.add_argument('--override', action=DictAction, default=dict()) 20 | parser.add_argument('--visual') 21 | parser.add_argument('--tokenize', action='store_true') 22 | parser.add_argument('--autocast', action='store_true') 23 | parser.add_argument('--load-model-from', required=True, nargs='+') 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | def main() -> None: 29 | args = parse_args() 30 | config = PyConfig.load(args.config, **args.config_options) 31 | config.override(args.override) 32 | 33 | for custom_import in config.get('custom_imports', []): 34 | importlib.import_module(custom_import) 35 | 36 | runner: BaseValidator[nn.Module] = VQRunnerRegistry.build( 37 | config.validator, 38 | name=f'{args.name}_test', 39 | visual=args.visual, 40 | tokenize=args.tokenize, 41 | autocast=args.autocast, 42 | ) 43 | log(runner, args, config) 44 | runner.strategy.load_model_from(args.load_model_from, strict=False) 45 | runner.run() 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /vq/algorithms/ar/transformers/base.py: 
-------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'BaseTransformer', 3 | ] 4 | 5 | import enum 6 | from abc import abstractmethod 7 | from typing import TypeVar 8 | 9 | import todd.tasks.large_multimodal_model as lmm 10 | import torch 11 | from todd.runners import Memo 12 | 13 | from vq.tasks.sequence_modeling.models import ( 14 | BaseTransformer as BaseSMTransformer, 15 | ) 16 | 17 | T = TypeVar('T', bound=enum.Enum) 18 | KVCache = tuple[tuple[torch.Tensor, torch.Tensor], ...] 19 | 20 | 21 | class BaseTransformer(BaseSMTransformer): 22 | 23 | @abstractmethod 24 | def _inference( 25 | self, 26 | tokens: torch.Tensor, 27 | kv_cache: KVCache | None, 28 | memo: Memo, 29 | ) -> tuple[torch.Tensor, KVCache, Memo]: 30 | pass 31 | 32 | @torch.no_grad() 33 | def inference( 34 | self, 35 | tokens: torch.Tensor, 36 | kv_cache: KVCache | None, 37 | memo: Memo, 38 | ) -> tuple[torch.Tensor, KVCache, Memo]: 39 | return self._inference(tokens, kv_cache, memo) 40 | 41 | def _generate( 42 | self, 43 | tokens: torch.Tensor, 44 | length: int, 45 | codebook: lmm.Codebook[T], 46 | memo: Memo, 47 | ) -> tuple[torch.Tensor, Memo]: 48 | assert tokens.shape[1] < length 49 | token = tokens 50 | kv_cache = None 51 | while tokens.shape[1] < length: 52 | logits, kv_cache, memo = self.inference(token, kv_cache, memo) 53 | token, memo = self.sample(logits[:, [-1]], codebook, memo) 54 | tokens = torch.cat((tokens, token), 1) 55 | assert tokens.shape[1] == length 56 | return tokens, memo 57 | -------------------------------------------------------------------------------- /vq/algorithms/vqkd/teachers/dino.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'DINOTeacher', 3 | ] 4 | 5 | import todd 6 | import torch 7 | from todd.bases.registries import Item 8 | from todd.models.modules import DINO 9 | from todd.registries import InitWeightsMixin 10 | 11 | from vq.utils import Store 12 | 13 | from ..registries import VQTeacherRegistry 14 | from .base import BaseTeacher 15 | 16 | 17 | @VQTeacherRegistry.register_() 18 | class DINOTeacher(InitWeightsMixin, BaseTeacher[DINO]): 19 | 20 | def __init__(self, *args, **kwargs) -> None: 21 | super().__init__( 22 | *args, 23 | mean=todd.datasets.IMAGENET_MEAN_255, 24 | std=todd.datasets.IMAGENET_STD_255, 25 | **kwargs, 26 | ) 27 | 28 | @classmethod 29 | def model_build_pre_hook( 30 | cls, 31 | config: todd.Config, 32 | registry: todd.RegistryMeta, 33 | item: Item, 34 | ) -> todd.Config: 35 | config = super().model_build_pre_hook(config, registry, item) 36 | model = config.model 37 | if isinstance(model, todd.Config): 38 | config.model = DINO(**model) 39 | return config 40 | 41 | @property 42 | def out_channels(self) -> int: 43 | return self._model.width 44 | 45 | def init_weights(self, config: todd.Config) -> bool: 46 | super().init_weights(config) 47 | self._model.load_pretrained( 48 | 'pretrained/dino/vitbase16.pth', 49 | directory=Store.PRETRAINED, 50 | ) 51 | return False 52 | 53 | def _forward(self, image: torch.Tensor, return_2d: bool) -> torch.Tensor: 54 | _, x = self._model(image, return_2d) 55 | return x 56 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/models/connectors/conv.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'ConvConnector', 3 | ] 4 | 5 | import todd 6 | import torch 7 | from todd.bases.registries import BuildPreHookMixin, Item 8 | from 
todd.runners import Memo 9 | from torch import nn 10 | 11 | from ..registries import VQITConnectorRegistry 12 | from .base import BaseConnector 13 | 14 | 15 | @VQITConnectorRegistry.register_() 16 | class ConvConnector(BuildPreHookMixin, BaseConnector): 17 | 18 | def __init__(self, *args, conv: nn.Conv2d, **kwargs) -> None: 19 | super().__init__(*args, **kwargs) 20 | assert conv.in_channels == self._in_channels 21 | assert conv.out_channels == self._out_channels 22 | self._conv = conv 23 | 24 | @classmethod 25 | def conv_build_pre_hook( 26 | cls, 27 | config: todd.Config, 28 | registry: todd.RegistryMeta, 29 | item: Item, 30 | ) -> todd.Config: 31 | conv = config.conv if 'conv' in config else todd.Config(kernel_size=1) 32 | config.conv = nn.Conv2d( 33 | config.in_channels, 34 | config.out_channels, 35 | **conv, 36 | ) 37 | return config 38 | 39 | @classmethod 40 | def build_pre_hook( 41 | cls, 42 | config: todd.Config, 43 | registry: todd.RegistryMeta, 44 | item: Item, 45 | ) -> todd.Config: 46 | config = super().build_pre_hook(config, registry, item) 47 | config = cls.conv_build_pre_hook(config, registry, item) 48 | return config 49 | 50 | def forward( 51 | self, 52 | x: torch.Tensor, 53 | memo: Memo, 54 | ) -> tuple[torch.Tensor, Memo]: 55 | x = self._conv(x) 56 | return x, memo 57 | -------------------------------------------------------------------------------- /configs/decoder/interface.py: -------------------------------------------------------------------------------- 1 | """Interface for pixel decoders. 2 | 3 | Example: 4 | bash tools/torchrun.sh -m vq.train decoder/vqgan/\ 5 | vqkd_clip_8192_vqgan_8192_dd2_aglwg075_imagenet_ddp configs/decoder/vqgan.py \ 6 | --config-options it_config::configs/vqkd/clip_8192_imagenet_ddp.py decoder::\ 7 | vqkd --load-model-from work_dirs/vqkd/clip_8192_imagenet_ddp/checkpoints/\ 8 | iter_1/model.pth 9 | """ 10 | 11 | from typing import Any 12 | 13 | from todd.configs import PyConfig 14 | 15 | _kwargs_: dict[str, Any] 16 | _kwargs_ = dict(_kwargs_) 17 | it_config = _kwargs_['it_config'] 18 | ir_config = _kwargs_['ir_config'] 19 | decoder = _kwargs_.get('decoder') 20 | 21 | it = PyConfig.load(it_config, **_kwargs_) 22 | ir = PyConfig.load(ir_config, **_kwargs_) 23 | 24 | _base_ = [ir] 25 | 26 | it_model = it.validator.model 27 | model = dict( 28 | encoder=dict(_delete_=True, **it_model.encoder), 29 | post_encode=dict(_delete_=True, **it_model.post_encode), 30 | quantizer=dict(_delete_=True, **it_model.quantizer), 31 | freeze=dict( 32 | type='NamedModulesFilter', 33 | names=('_encoder', '_post_encode', '_quantizer'), 34 | ), 35 | filter_state_dict=True, 36 | ) 37 | 38 | if decoder is not None: 39 | decoder_config = PyConfig.load( 40 | f'configs/decoder/{decoder}.py', 41 | **_kwargs_, 42 | ) 43 | model['decoder'] = dict(_delete_=True, **decoder_config) 44 | 45 | params = dict( 46 | type='NamedParametersFilter', 47 | modules=dict(type='NamedModulesFilter', names=('_decoder', '_pre_decode')), 48 | ) 49 | 50 | _export_ = dict( 51 | trainer=dict( 52 | model=model, 53 | optimizers=dict(generator=dict(params=[dict(params=params)])), 54 | ), 55 | validator=dict(model=model), 56 | custom_imports=ir.custom_imports + it.custom_imports, 57 | ) 58 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/tokenize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib 3 | import pathlib 4 | 5 | from todd.configs import PyConfig 6 | from 
todd.patches.py_ import DictAction 7 | from torch import nn 8 | 9 | from vq.registries import VQRunnerRegistry 10 | from vq.utils import log 11 | 12 | from .runners import Tokenizer 13 | 14 | 15 | def parse_args() -> argparse.Namespace: 16 | parser = argparse.ArgumentParser(description='Tokenize') 17 | parser.add_argument('name') 18 | parser.add_argument('config', type=pathlib.Path) 19 | parser.add_argument('--config-options', action=DictAction, default=dict()) 20 | parser.add_argument('--train', action='store_true') 21 | parser.add_argument('--override', action=DictAction, default=dict()) 22 | parser.add_argument('--autocast', action='store_true') 23 | parser.add_argument('--load-model-from', required=True, nargs='+') 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | def main() -> None: 29 | args = parse_args() 30 | config = PyConfig.load(args.config, **args.config_options) 31 | config.validator.type = 'VQITRunnerRegistry.Tokenizer' 32 | if args.train: 33 | config.validator.dataset = config.trainer.dataset 34 | config.override(args.override) 35 | 36 | for custom_import in config.get('custom_imports', []): 37 | importlib.import_module(custom_import) 38 | 39 | name = f'{args.name}_tokenize' 40 | if args.train: 41 | name += '_train' 42 | 43 | runner: Tokenizer[nn.Module] = VQRunnerRegistry.build( 44 | config.validator, 45 | name=name, 46 | autocast=args.autocast, 47 | ) 48 | log(runner, args, config) 49 | runner.strategy.load_model_from(args.load_model_from, strict=False) 50 | runner.run() 51 | 52 | 53 | if __name__ == '__main__': 54 | main() 55 | -------------------------------------------------------------------------------- /configs/decoder/README.md: -------------------------------------------------------------------------------- 1 | # Pixel Decoders 2 | 3 | For tokenizers like VQ-KD and Cluster, we need to train pixel decoders to reconstruct images from tokens.
4 | 5 | ```bash 6 | # VQ-KD 7 | auto_torchrun -m vq.train \ 8 | decoder/llamagen/vqkd_clip_8192_imagenet_ddp/llamagen_8192_dd2_aglwg075_imagenet_ddp \ 9 | configs/decoder/llamagen.py \ 10 | --config-options it_config::configs/vqkd/clip_8192_imagenet_ddp.py \ 11 | --load-model-from work_dirs/vqkd/clip_8192_imagenet_ddp/checkpoints/iter_250000/model.pth 12 | 13 | # Cluster 14 | auto_torchrun -m vq.train \ 15 | decoder/llamagen/cluster_clip_8192_imagenet_ddp/llamagen_8192_dd2_aglwg075_imagenet_ddp \ 16 | configs/decoder/llamagen.py \ 17 | --config-options it_config::configs/cluster/clip_8192_imagenet_ddp.py \ 18 | --load-model-from work_dirs/cluster/clip_8192_imagenet_ddp/checkpoints/iter_100000/model.pth 19 | ``` 20 | 21 | 44 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/runners/tokenizer.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'Tokenizer', 3 | ] 4 | 5 | from typing import TypeVar, cast 6 | 7 | import todd 8 | from todd.bases.registries import Item 9 | from todd.runners import Memo 10 | from torch import nn 11 | 12 | from vq.datasets import Batch 13 | from vq.runners import BaseValidator 14 | 15 | from ..models import BaseModel 16 | from ..registries import VQITRunnerRegistry 17 | 18 | T = TypeVar('T', bound=nn.Module) 19 | 20 | 21 | @VQITRunnerRegistry.register_() 22 | class Tokenizer(BaseValidator[T]): 23 | 24 | @classmethod 25 | def callbacks_build_pre_hook( 26 | cls, 27 | config: todd.Config, 28 | registry: todd.RegistryMeta, 29 | item: Item, 30 | ) -> todd.Config: 31 | config.update(tokenize=True) 32 | if not todd.Store.DRY_RUN: 33 | config.callbacks = [ 34 | dict( 35 | type='LogCallback', 36 | interval=50, 37 | collect_env=dict(), 38 | with_file_handler=True, 39 | eta=dict(type='EMA_ETA', ema=dict(decay=0.9)), 40 | ), 41 | ] 42 | return super().callbacks_build_pre_hook(config, registry, item) 43 | 44 | def _run_iter(self, batch: Batch, memo: Memo, *args, **kwargs) -> Memo: 45 | if todd.Store.DRY_RUN: 46 | return super()._run_iter(batch, memo, *args, **kwargs) 47 | model = cast(BaseModel, self.strategy.module) 48 | original_image = batch['original_image'] 49 | image = batch['image'] 50 | if todd.Store.cuda: 51 | original_image = original_image.cuda() 52 | image = image.cuda() 53 | memo.update(original_image=original_image, image=image) 54 | _, memo = model.encode_to_quant(image, memo) 55 | return memo 56 | -------------------------------------------------------------------------------- /vq/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib 3 | import pathlib 4 | 5 | from todd.configs import PyConfig 6 | from todd.patches.py_ import DictAction 7 | from todd.utils import init_seed 8 | from torch import nn 9 | 10 | from .registries import VQRunnerRegistry 11 | from .runners import BaseMixin as BaseRunnerMixin 12 | from .utils import log 13 | 14 | 15 | def parse_args() -> argparse.Namespace: 16 | parser = argparse.ArgumentParser(description='Train') 17 | parser.add_argument('name') 18 | parser.add_argument('config', type=pathlib.Path) 19 | parser.add_argument('--config-options', action=DictAction, default=dict()) 20 | parser.add_argument('--override', action=DictAction, default=dict()) 21 | parser.add_argument('--seed', type=int, default=3407) 22 | parser.add_argument('--autocast', action='store_true') 23 | parser.add_argument('--load-model-from', nargs='+', default=[]) 24 | 
parser.add_argument('--load-from') 25 | parser.add_argument('--auto-resume', action='store_true') 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def main() -> None: 31 | args = parse_args() 32 | config = PyConfig.load(args.config, **args.config_options) 33 | config.override(args.override) 34 | init_seed(args.seed) 35 | 36 | for custom_import in config.get('custom_imports', []): 37 | importlib.import_module(custom_import) 38 | 39 | trainer: BaseRunnerMixin[nn.Module] = VQRunnerRegistry.build( 40 | config.trainer, 41 | name=args.name, 42 | load_from=args.load_from, 43 | auto_resume=args.auto_resume, 44 | autocast=args.autocast, 45 | ) 46 | log(trainer, args, config) 47 | if args.load_model_from: 48 | trainer.strategy.load_model_from(args.load_model_from, strict=False) 49 | trainer.run() 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /vq/datasets/base.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'Batch', 3 | 'BaseMixin', 4 | ] 5 | 6 | from abc import ABC 7 | from typing import TypedDict, TypeVar 8 | 9 | import torch 10 | from todd.datasets import BaseDataset 11 | 12 | KT = TypeVar('KT') 13 | VT = TypeVar('VT') 14 | 15 | 16 | class T(TypedDict): 17 | id_: str 18 | original_image: torch.Tensor 19 | image: torch.Tensor 20 | category: int 21 | 22 | 23 | class Batch(TypedDict): 24 | id_: list[str] 25 | original_image: torch.Tensor 26 | image: torch.Tensor 27 | category: torch.Tensor 28 | 29 | 30 | class BaseMixin(BaseDataset[T, KT, VT], ABC): 31 | 32 | def __init__( 33 | self, 34 | *args, 35 | name: str, 36 | num_categories: int, 37 | image_size: int, 38 | fid_path: str | None = None, 39 | **kwargs, 40 | ) -> None: 41 | super().__init__(*args, **kwargs) 42 | self._name = name 43 | self._num_categories = num_categories 44 | self._image_size = image_size 45 | self._fid_path = fid_path 46 | 47 | @property 48 | def name(self) -> str: 49 | return self._name 50 | 51 | @property 52 | def num_categories(self) -> int: 53 | return self._num_categories 54 | 55 | @property 56 | def image_size(self) -> int: 57 | return self._image_size 58 | 59 | @property 60 | def fid_path(self) -> str: 61 | if self._fid_path is None: 62 | return f'pretrained/fid/{self._name}.pth' 63 | return self._fid_path 64 | 65 | @classmethod 66 | def encode(cls, images: torch.Tensor) -> torch.Tensor: 67 | return images / 127.5 - 1.0 68 | 69 | @classmethod 70 | def decode(cls, images: torch.Tensor) -> torch.Tensor: 71 | images = (images + 1) * 127.5 72 | images = images.clamp(0, 255).to(torch.uint8) 73 | return images 74 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "args": [ 5 | "vqgan_16384_imagenet", 6 | "configs/vqgan_16384_imagenet.py" 7 | ], 8 | "env": { 9 | "DRY_RUN": "True" 10 | }, 11 | "module": "vq.train", 12 | "name": "VQGAN 16384 ImageNet", 13 | "request": "launch", 14 | "type": "debugpy" 15 | }, 16 | { 17 | "args": [ 18 | "vqkd_clip_8192_imagenet", 19 | "configs/vqkd_clip_8192_imagenet.py" 20 | ], 21 | "env": { 22 | "DRY_RUN": "True" 23 | }, 24 | "module": "vq.train", 25 | "name": "VQ-KD CLIP 8192 ImageNet", 26 | "request": "launch", 27 | "type": "debugpy" 28 | }, 29 | { 30 | "justMyCode": false, 31 | "listen": { 32 | "host": "0.0.0.0", 33 | "port": 5678 34 | }, 35 | "name": "Run 
VQ-KD CLIP 8192 ImageNet DDP", 36 | "pathMappings": [ 37 | { 38 | "localRoot": "${workspaceFolder}", 39 | "remoteRoot": "." 40 | } 41 | ], 42 | "preLaunchTask": "VQ-KD CLIP 8192 ImageNet DDP", 43 | "request": "attach", 44 | "type": "debugpy" 45 | }, 46 | { 47 | "args": [ 48 | "cvqvae_8192_imagenet", 49 | "configs/cvqvae_8192_imagenet.py" 50 | ], 51 | "env": { 52 | "DRY_RUN": "True" 53 | }, 54 | "module": "vq.train", 55 | "name": "CVQ-VAE 8192 ImageNet", 56 | "request": "launch", 57 | "type": "debugpy" 58 | }, 59 | { 60 | "args": [ 61 | "cvqvae_8192_imagenet", 62 | "configs/cvqvae_8192_imagenet.py" 63 | ], 64 | "env": { 65 | "DRY_RUN": "True" 66 | }, 67 | "module": "vq.val", 68 | "name": "Val CVQ-VAE 8192 ImageNet", 69 | "request": "launch", 70 | "type": "debugpy" 71 | } 72 | ], 73 | "version": "0.2.0" 74 | } 75 | -------------------------------------------------------------------------------- /vq/algorithms/vqkd/teachers/clip.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'CLIPTeacher', 3 | ] 4 | 5 | import os 6 | 7 | import clip.model 8 | import todd 9 | import torch 10 | from todd.bases.registries import Item 11 | 12 | from vq.utils import Store 13 | 14 | from ..registries import VQTeacherRegistry 15 | from .base import BaseTeacher 16 | 17 | 18 | @VQTeacherRegistry.register_() 19 | class CLIPTeacher(BaseTeacher[clip.model.CLIP]): 20 | 21 | def __init__( 22 | self, 23 | *args, 24 | with_proj: bool = True, 25 | **kwargs, 26 | ) -> None: 27 | super().__init__( 28 | *args, 29 | mean=todd.datasets.CLIP_MEAN_255, 30 | std=todd.datasets.CLIP_STD_255, 31 | **kwargs, 32 | ) 33 | self._with_proj = with_proj 34 | 35 | if not with_proj: 36 | self.visual.proj = None 37 | 38 | @classmethod 39 | def model_build_pre_hook( 40 | cls, 41 | config: todd.Config, 42 | registry: todd.RegistryMeta, 43 | item: Item, 44 | ) -> todd.Config: 45 | config = super().model_build_pre_hook(config, registry, item) 46 | model = config.model 47 | if isinstance(model, todd.Config): 48 | model.setdefault('name', 'pretrained/clip/ViT-B-16.pt') 49 | model.name = os.path.join(Store.PRETRAINED, model.name) 50 | model, _ = clip.load(**model) 51 | config.model = model 52 | return config 53 | 54 | @property 55 | def visual(self) -> clip.model.VisionTransformer: 56 | return self._model.visual 57 | 58 | @property 59 | def out_channels(self) -> int: 60 | if self._with_proj: 61 | return self.visual.output_dim 62 | return self.visual.class_embedding.numel() 63 | 64 | def _forward(self, image: torch.Tensor, return_2d: bool) -> torch.Tensor: 65 | return self._model.encode_image(image, return_2d) 66 | -------------------------------------------------------------------------------- /vq/tasks/image_classification/optimizers.py: -------------------------------------------------------------------------------- 1 | # TODO 2 | # flake8: noqa 3 | 4 | __all__ = [ 5 | 'LARSOptimizer', 6 | ] 7 | 8 | import torch 9 | from todd.registries import OptimizerRegistry 10 | 11 | 12 | @OptimizerRegistry.register_() 13 | class LARSOptimizer(torch.optim.Optimizer): 14 | 15 | def __init__( 16 | self, 17 | params, 18 | lr=0, 19 | weight_decay=0, 20 | momentum=0.9, 21 | trust_coefficient=0.001, 22 | ): 23 | defaults = dict( 24 | lr=lr, 25 | weight_decay=weight_decay, 26 | momentum=momentum, 27 | trust_coefficient=trust_coefficient, 28 | ) 29 | super().__init__(params, defaults) 30 | 31 | @torch.no_grad() 32 | def step(self): 33 | for g in self.param_groups: 34 | for p in g['params']: 35 | dp = p.grad 36 | 37 
| if dp is None: 38 | continue 39 | 40 | if p.ndim > 1: # if not normalization gamma/beta or bias 41 | dp = dp.add(p, alpha=g['weight_decay']) 42 | param_norm = torch.norm(p) 43 | update_norm = torch.norm(dp) 44 | one = torch.ones_like(param_norm) 45 | q = torch.where( 46 | param_norm > 0., 47 | torch.where( 48 | update_norm > 0, ( 49 | g['trust_coefficient'] * param_norm 50 | / update_norm 51 | ), one 52 | ), one 53 | ) 54 | dp = dp.mul(q) 55 | 56 | param_state = self.state[p] 57 | if 'mu' not in param_state: 58 | param_state['mu'] = torch.zeros_like(p) 59 | mu = param_state['mu'] 60 | mu.mul_(g['momentum']).add_(dp) 61 | p.add_(mu, alpha=-g['lr']) 62 | -------------------------------------------------------------------------------- /vq/algorithms/utils/losses.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'CosineEmbeddingLoss', 3 | ] 4 | 5 | import todd 6 | import torch 7 | from todd.bases.registries import Item 8 | from todd.models import losses 9 | 10 | from vq.models import VQLossRegistry 11 | 12 | 13 | @VQLossRegistry.register_() 14 | class CosineEmbeddingLoss(losses.BaseLoss): 15 | 16 | def __init__( 17 | self, 18 | *args, 19 | cosine_embedding: losses.CosineEmbeddingLoss, 20 | **kwargs, 21 | ) -> None: 22 | super().__init__(*args, **kwargs) 23 | self._cosine_embedding = cosine_embedding 24 | 25 | @classmethod 26 | def cosine_embedding_build_pre_hook( 27 | cls, 28 | config: todd.Config, 29 | registry: todd.RegistryMeta, 30 | item: Item, 31 | ) -> todd.Config: 32 | config.cosine_embedding = losses.CosineEmbeddingLoss( 33 | reduction='none', 34 | **config.cosine_embedding, 35 | ) 36 | return config 37 | 38 | @classmethod 39 | def build_pre_hook( 40 | cls, 41 | config: todd.Config, 42 | registry: todd.RegistryMeta, 43 | item: Item, 44 | ) -> todd.Config: 45 | config = super().build_pre_hook(config, registry, item) 46 | config = cls.cosine_embedding_build_pre_hook(config, registry, item) 47 | return config 48 | 49 | def forward( # pylint: disable=arguments-differ 50 | self, 51 | pred_image: torch.Tensor, 52 | image: torch.Tensor, 53 | ) -> torch.Tensor: 54 | assert pred_image.shape == image.shape 55 | shape = pred_image.shape 56 | pred_image = pred_image.flatten(0, -2) 57 | image = image.flatten(0, -2) 58 | target = pred_image.new_ones(pred_image.shape[0]) 59 | loss: torch.Tensor = self._cosine_embedding( 60 | pred_image, 61 | image, 62 | target, 63 | ) 64 | loss = loss.reshape(shape[:-1]) 65 | return self._reduce(loss) 66 | -------------------------------------------------------------------------------- /vq/algorithms/ar/transformers/hf.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'HFTransformer', 3 | ] 4 | 5 | import enum 6 | from abc import abstractmethod 7 | from typing import TypeVar 8 | 9 | import todd 10 | import todd.tasks.large_multimodal_model as lmm 11 | import torch 12 | from todd.bases.registries import Item 13 | from todd.runners import Memo 14 | from transformers import PreTrainedModel 15 | 16 | from .base import BaseTransformer, KVCache 17 | 18 | T = TypeVar('T', bound=enum.Enum) 19 | 20 | 21 | class HFTransformer(BaseTransformer): 22 | 23 | def __init__(self, *args, transformer: PreTrainedModel, **kwargs) -> None: 24 | super().__init__(*args, **kwargs) 25 | transformer.resize_token_embeddings(self._vocabulary_size) 26 | self._transformer = transformer 27 | 28 | @classmethod 29 | @abstractmethod 30 | def transformer_build_pre_hook( 31 | cls, 32 | config: 
todd.Config, 33 | registry: todd.RegistryMeta, 34 | item: Item, 35 | ) -> todd.Config: 36 | pass 37 | 38 | @classmethod 39 | def build_pre_hook( 40 | cls, 41 | config: todd.Config, 42 | registry: todd.RegistryMeta, 43 | item: Item, 44 | ) -> todd.Config: 45 | config = super().build_pre_hook(config, registry, item) 46 | config = cls.transformer_build_pre_hook(config, registry, item) 47 | return config 48 | 49 | def init_weights(self, config: todd.Config) -> bool: 50 | return False 51 | 52 | def _inference( 53 | self, 54 | tokens: torch.Tensor, 55 | kv_cache: KVCache | None, 56 | memo: Memo, 57 | ) -> tuple[torch.Tensor, KVCache, Memo]: 58 | output = self._transformer(tokens, past_key_values=kv_cache) 59 | return output['logits'], output['past_key_values'], memo 60 | 61 | def forward( 62 | self, 63 | data: lmm.InterleavedData[T], 64 | memo: Memo, 65 | ) -> tuple[torch.Tensor, Memo]: 66 | tokens = data.tokens 67 | output = self._transformer(tokens, labels=tokens) 68 | memo['logits'] = output['logits'] 69 | return output['loss'], memo 70 | -------------------------------------------------------------------------------- /vq/algorithms/cluster/base.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'Cluster', 3 | ] 4 | 5 | from typing import Literal, TypeVar 6 | 7 | import todd 8 | import torch 9 | from todd.runners import Memo 10 | from todd.runners.callbacks import TensorBoardCallback 11 | from torch import nn 12 | 13 | from vq import VQModelRegistry 14 | from vq.datasets import Batch 15 | from vq.runners import BaseMixin as BaseRunnerMixin 16 | from vq.tasks.image_tokenization.models import BaseModel as BaseITModel 17 | from vq.utils import get_memo 18 | 19 | T = TypeVar('T', bound=nn.Module) 20 | 21 | 22 | @VQModelRegistry.register_() 23 | class Cluster(BaseITModel): 24 | 25 | def forward( 26 | self, 27 | runner: BaseRunnerMixin[T], 28 | batch: Batch, 29 | memo: Memo, 30 | *args, 31 | mode: Literal['train'] | None, 32 | **kwargs, 33 | ) -> Memo: 34 | log: Memo | None = memo.get('log') 35 | tensorboard: TensorBoardCallback[T] | None = memo.get('tensorboard') 36 | 37 | original_image: torch.Tensor = batch['original_image'] 38 | image: torch.Tensor = batch['image'] 39 | if todd.Store.cuda: 40 | original_image = original_image.cuda() 41 | image = image.cuda() 42 | memo.update(original_image=original_image, image=image) 43 | 44 | encoder_memo = get_memo(memo, 'encoder') 45 | encoder_memo['original_image'] = original_image 46 | x, memo = self.encode(image, memo) 47 | memo['x'] = x 48 | 49 | z, q_loss, memo = self.quantize(x, memo) 50 | memo.update(z=z, loss=q_loss) 51 | 52 | losses = dict(q_loss=q_loss) 53 | if 'loss' in memo['quantizer']: 54 | losses.update(memo['quantizer']['loss']) 55 | if log is not None: 56 | log.update({k: f'{v:.3f}' for k, v in losses.items()}) 57 | if tensorboard is not None: 58 | for k, v in losses.items(): 59 | tensorboard.summary_writer.add_scalar( 60 | tensorboard.tag(k), 61 | v.float(), 62 | runner.iter_, 63 | ) 64 | 65 | return memo 66 | -------------------------------------------------------------------------------- /vq/algorithms/vqgan/losses/discriminator.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'BaseDiscriminatorLoss', 3 | 'VQGANDiscriminatorLoss', 4 | 'R1GradientPenalty', 5 | ] 6 | 7 | from abc import ABC, abstractmethod 8 | 9 | import todd 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from ..discriminators import 
BaseDiscriminator 14 | from .registries import VQDiscriminatorLossRegistry 15 | 16 | 17 | class BaseDiscriminatorLoss(todd.models.losses.BaseLoss, ABC): 18 | 19 | @abstractmethod 20 | def forward( # pylint: disable=arguments-differ 21 | self, 22 | logits_fake: torch.Tensor, 23 | logits_real: torch.Tensor, 24 | ) -> torch.Tensor: 25 | pass 26 | 27 | 28 | @VQDiscriminatorLossRegistry.register_() 29 | class VQGANDiscriminatorLoss(BaseDiscriminatorLoss): 30 | 31 | def forward( 32 | self, 33 | logits_fake: torch.Tensor, 34 | logits_real: torch.Tensor, 35 | ) -> torch.Tensor: 36 | loss_fake = F.relu(1. + logits_fake) 37 | loss_real = F.relu(1. - logits_real) 38 | loss = (loss_fake + loss_real) / 2 39 | return self._reduce(loss) 40 | 41 | 42 | class R1GradientPenalty(todd.models.losses.BaseLoss): 43 | """https://arxiv.org/abs/1801.04406.""" 44 | 45 | def forward( # pylint: disable=arguments-differ 46 | self, 47 | discriminator: BaseDiscriminator, 48 | image: torch.Tensor, 49 | ) -> torch.Tensor: 50 | image = image.clone().requires_grad_() 51 | 52 | training = { 53 | module: module.training 54 | for module in discriminator.modules() 55 | } 56 | discriminator.eval() 57 | logits_real = discriminator(image) 58 | for module, mode in training.items(): 59 | module.training = mode 60 | 61 | gradients, = torch.autograd.grad( 62 | logits_real, 63 | image, 64 | torch.ones_like(logits_real), 65 | create_graph=True, 66 | ) 67 | gradient_penalty: torch.Tensor = gradients.norm(2, (1, 2, 3)) 68 | gradient_penalty = gradient_penalty.pow(2) 69 | return self._reduce(gradient_penalty) 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vector Quantization 2 | 3 | This repository is the official implementation of "[Image Understanding Makes for A Good Tokenizer for Image Generation](https://arxiv.org/abs/2411.04406)". 4 | 5 | ![Static Badge](https://img.shields.io/badge/NeurIPS-2024-purple) 6 | 7 | 8 | 9 | ## Overview 10 | 11 | ![Overview](docs/assets/fig1.jpg) 12 | 13 | Image understanding (IU) and image generation (IG) have long been central to computer vision research. While many studies explore how IG models can aid IU, few investigate the reverse—**using IU models to enhance IG**. 14 | 15 | This work bridges the gap by introducing IU-based tokenizers in the AutoRegressive (AR) IG framework. Specifically, we evaluate the following tokenizers: 16 | 17 | - [VQGAN](configs/vqgan/README.md) 18 | - [CVQ-VAE](configs/cvqvae/README.md) 19 | - [FSQ](configs/fsq/README.md) 20 | - [VQ-KD](configs/vqkd/README.md) 21 | - [Cluster](configs/cluster/README.md) 22 | 23 | The VQ-KD and Cluster tokenizers leverage pretrained models such as CLIP, delivering superior results compared to traditional tokenizers. The following sections provide detailed instructions for training and validating these tokenizers. 24 | 25 | ## Preparation 26 | 27 | Please follow [data.md](docs/data.md) and [installation.md](docs/installation.md) to prepare the data and environment. 28 | 29 | Use [pretrained_models.md](docs/pretrained_models.md) to download the pretrained models. 30 | 31 | Generate the FID cache as described in [metrics.md](docs/metrics.md#cache). 32 | 33 | ## Framework 34 | 35 | Please refer to [training.md](docs/training.md) and [validation.md](docs/validation.md) for detailed instructions on training and validating the tokenizers. The model card is available in [model_card.md](docs/model_card.md).
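Trained tokenizers can also be rebuilt programmatically via the `vq.utils.load` helper (defined in `vq/utils/misc.py`). Below is a minimal sketch; the config and checkpoint paths are illustrative and should be replaced with those of your own experiment:

```python
from vq.utils import load

# rebuild the validator's model from a config and load the checkpoint weights
model = load(
    config='configs/vqkd/clip_8192_imagenet_ddp.py',
    runner_type='validator',
    state_dicts=['work_dirs/vqkd/clip_8192_imagenet_ddp/checkpoints/iter_250000/model.pth'],
)
model.eval()
```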
36 | 37 | ## Acknowledgments 38 | 39 | This project draws inspiration from the following works: 40 | 41 | - [VQGAN](https://github.com/CompVis/taming-transformers) 42 | - [BEiT v2](https://github.com/microsoft/unilm/tree/master/beit2) 43 | - [CVQ-VAE](https://github.com/lyndonzheng/CVQ-VAE) 44 | - [LlamaGen](https://github.com/FoundationVision/LlamaGen) 45 | 46 | For a full list of influential works, please refer to our paper. 47 | -------------------------------------------------------------------------------- /configs/cluster/runner.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=invalid-name 2 | 3 | VQITMetricRegistry = 'VQMetricRegistry.VQITMetricRegistry' 4 | trainer = dict( 5 | type='BaseTrainer', 6 | callbacks=[ 7 | dict(type='OptimizeCallback'), 8 | dict( 9 | type='LogCallback', 10 | interval=50, 11 | collect_env=dict(), 12 | with_file_handler=True, 13 | eta=dict(type='EMA_ETA', ema=dict(decay=0.9)), 14 | priority=dict(init=-1), 15 | ), 16 | dict(type='GitCallback', diff='HEAD'), 17 | dict( 18 | type='TensorBoardCallback', 19 | interval=50, 20 | summary_writer=dict(), 21 | main_tag='train', 22 | ), 23 | dict(type='CheckpointCallback', interval=1e4), 24 | ], 25 | optimizer=dict( 26 | type='Adam', 27 | lr=5.4e-5, 28 | betas=(0.5, 0.9), 29 | params=[ 30 | dict( 31 | params=dict( 32 | type='NamedParametersFilter', 33 | modules=dict(type='NamedModulesFilter', name='_quantizer'), 34 | ), 35 | ), 36 | ], 37 | ), 38 | iters=1e5, 39 | ) 40 | validator = dict( 41 | type='BaseValidator', 42 | callbacks=[ 43 | dict( 44 | type='MetricCallback', 45 | metrics=dict( 46 | loss=dict( 47 | type='ReadyMadeMetric', 48 | attr='["loss"]', 49 | ), 50 | codebook_usage=dict( 51 | type=f'{VQITMetricRegistry}.CodebookUsageMetric', 52 | quant='["quantizer"]["quant"]', 53 | ), 54 | codebook_ppl=dict( 55 | type=f'{VQITMetricRegistry}.CodebookPPLMetric', 56 | quant='["quantizer"]["quant"]', 57 | ), 58 | ), 59 | ), 60 | dict( 61 | type='LogCallback', 62 | interval=50, 63 | collect_env=dict(), 64 | with_file_handler=True, 65 | eta=dict(type='EMA_ETA', ema=dict(decay=0.9)), 66 | ), 67 | ], 68 | ) 69 | 70 | _export_ = dict(trainer=trainer, validator=validator) 71 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/models/connectors/composed.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'ComposedConnector', 3 | ] 4 | 5 | from typing import cast 6 | 7 | import todd 8 | import torch 9 | from todd.bases.registries import BuildPreHookMixin, Item 10 | from todd.patches.torch import Sequential 11 | from todd.runners import Memo 12 | 13 | from ..registries import VQITConnectorRegistry 14 | from .base import BaseConnector 15 | 16 | 17 | @VQITConnectorRegistry.register_() 18 | class ComposedConnector(BuildPreHookMixin, BaseConnector): 19 | 20 | def __init__(self, *args, connectors: Sequential, **kwargs) -> None: 21 | super().__init__(*args, **kwargs) 22 | assert cast(BaseConnector, connectors[0]).in_channels == \ 23 | self._in_channels 24 | assert cast(BaseConnector, connectors[-1]).out_channels == \ 25 | self._out_channels 26 | self._connectors = connectors 27 | 28 | @classmethod 29 | def connectors_build_pre_hook( 30 | cls, 31 | config: todd.Config, 32 | registry: todd.RegistryMeta, 33 | item: Item, 34 | ) -> todd.Config: 35 | in_channels = config.in_channels 36 | 37 | connector: BaseConnector 38 | connectors: list[BaseConnector] = [] 39 | for 
connector in config.connectors: 40 | connector = VQITConnectorRegistry.build_or_return( 41 | connector, 42 | in_channels=in_channels, 43 | out_channels=config.out_channels, 44 | ) 45 | in_channels = connector.out_channels 46 | connectors.append(connector) 47 | 48 | config.connectors = Sequential(*connectors, unpack_args=True) 49 | return config 50 | 51 | @classmethod 52 | def build_pre_hook( 53 | cls, 54 | config: todd.Config, 55 | registry: todd.RegistryMeta, 56 | item: Item, 57 | ) -> todd.Config: 58 | config = super().build_pre_hook(config, registry, item) 59 | config = cls.connectors_build_pre_hook(config, registry, item) 60 | return config 61 | 62 | def forward( 63 | self, 64 | x: torch.Tensor, 65 | memo: Memo, 66 | ) -> tuple[torch.Tensor, Memo]: 67 | return self._connectors(x, memo) 68 | -------------------------------------------------------------------------------- /vq/datasets/concat.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'ConcatDataset', 3 | ] 4 | 5 | from typing import Any, Never 6 | 7 | import todd 8 | import torch 9 | from todd.bases.registries import Item 10 | from todd.datasets.access_layers import BaseAccessLayer 11 | from todd.registries import DatasetRegistry 12 | 13 | from ..registries import VQDatasetRegistry 14 | from .base import BaseMixin 15 | 16 | 17 | class PseudoAccessLayer(BaseAccessLayer[Any, Never]): 18 | 19 | def __getitem__(self, *args, **kwargs) -> Never: 20 | raise NotImplementedError 21 | 22 | def __setitem__(self, *args, **kwargs) -> Never: 23 | raise NotImplementedError 24 | 25 | def __delitem__(self, *args, **kwargs) -> Never: 26 | raise NotImplementedError 27 | 28 | def __len__(self, *args, **kwargs) -> Never: 29 | raise NotImplementedError 30 | 31 | def __iter__(self, *args, **kwargs) -> Never: 32 | raise NotImplementedError 33 | 34 | @property 35 | def exists(self) -> Never: 36 | raise NotImplementedError 37 | 38 | def touch(self, *args, **kwargs) -> Never: 39 | raise NotImplementedError 40 | 41 | 42 | @VQDatasetRegistry.register_() 43 | class ConcatDataset(BaseMixin[Any, Never], torch.utils.data.ConcatDataset): 44 | 45 | @classmethod 46 | def datasets_build_pre_hook( 47 | cls, 48 | config: todd.Config, 49 | registry: todd.RegistryMeta, 50 | item: Item, 51 | ) -> todd.Config: 52 | config = cls.transforms_build_pre_hook(config, registry, item) 53 | config.datasets = [ 54 | DatasetRegistry.build_or_return( 55 | dataset, 56 | name=config.name, 57 | num_categories=config.num_categories, 58 | image_size=config.image_size, 59 | transforms=config.transforms, 60 | ) for dataset in config.datasets 61 | ] 62 | return config 63 | 64 | @classmethod 65 | def build_pre_hook( 66 | cls, 67 | config: todd.Config, 68 | registry: todd.RegistryMeta, 69 | item: Item, 70 | ) -> todd.Config: 71 | config = cls.datasets_build_pre_hook(config, registry, item) 72 | return super().build_pre_hook(config, registry, item) 73 | -------------------------------------------------------------------------------- /configs/ic/runner.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | _kwargs_: dict[str, Any] 4 | _kwargs_ = dict(_kwargs_) 5 | iters = _kwargs_.get('iters', 30_000) 6 | 7 | trainer = dict( 8 | type='BaseTrainer', 9 | callbacks=[ 10 | dict(type='OptimizeCallback'), 11 | dict( 12 | type='LRScheduleCallback', 13 | lr_scheduler=dict(type='CosineAnnealingLR', T_max=iters), 14 | interval=1, 15 | ), 16 | dict( 17 | type='LogCallback', 18 | 
interval=50, 19 | collect_env=dict(), 20 | with_file_handler=True, 21 | eta=dict(type='EMA_ETA', ema=dict(decay=0.9)), 22 | priority=dict(init=-1), 23 | ), 24 | dict(type='GitCallback', diff='HEAD'), 25 | dict( 26 | type='TensorBoardCallback', 27 | interval=50, 28 | summary_writer=dict(), 29 | main_tag='train', 30 | ), 31 | dict(type='CheckpointCallback', interval=1e3), 32 | ], 33 | optimizer=dict( 34 | type='LARSOptimizer', 35 | params=[ 36 | dict( 37 | params=dict( 38 | type='NamedParametersFilter', 39 | modules=dict( 40 | type='NamedModulesFilter', 41 | name='_head', 42 | ), 43 | ), 44 | ), 45 | ], 46 | lr=1.6, 47 | ), 48 | iters=iters, 49 | ) 50 | validator = dict( 51 | type='BaseValidator', 52 | callbacks=[ 53 | dict( 54 | type='MetricCallback', 55 | metrics=dict( 56 | accuracy=dict( 57 | type='AccuracyMetric', 58 | topk=5, 59 | logits='["logits"]', 60 | target='["category"]', 61 | ), 62 | loss=dict( 63 | type='ReadyMadeMetric', 64 | attr='["loss"]', 65 | ), 66 | ), 67 | ), 68 | dict( 69 | type='LogCallback', 70 | interval=1, 71 | collect_env=dict(), 72 | eta=dict(type='EMA_ETA', ema=dict(decay=0.9)), 73 | with_file_handler=True, 74 | ), 75 | ], 76 | ) 77 | 78 | _export_ = dict(trainer=trainer, validator=validator) 79 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/runners/metrics.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'CodebookMixin', 3 | 'CodebookUsageMetric', 4 | 'CodebookPPLMetric', 5 | ] 6 | 7 | from abc import ABC, abstractmethod 8 | from typing import TypeVar, cast 9 | 10 | import torch 11 | import torch.distributed as dist 12 | from todd.patches.py_ import get_ 13 | from todd.runners import Memo 14 | from todd.runners.metrics import BaseMetric 15 | from torch import nn 16 | 17 | from vq.datasets import Batch 18 | 19 | from ..models import BaseModel 20 | from .registries import VQITMetricRegistry 21 | 22 | T = TypeVar('T', bound=nn.Module) 23 | 24 | 25 | class CodebookMixin(BaseMetric[T], ABC): 26 | 27 | def __init__(self, *args, quant: str, **kwargs) -> None: 28 | super().__init__(*args, **kwargs) 29 | self._quant = quant 30 | self._counts: torch.Tensor | int = 0 31 | 32 | def bind(self, *args, **kwargs) -> None: 33 | super().bind(*args, **kwargs) 34 | module = cast(BaseModel, self.runner.strategy.module) 35 | self._codebook_size = module.quantizer.codebook_size 36 | 37 | def forward(self, batch: Batch, memo: Memo) -> Memo: 38 | quant: torch.Tensor = get_(memo, self._quant) 39 | quant = quant.flatten() 40 | counts = torch.bincount( 41 | quant, 42 | minlength=self._codebook_size, 43 | ) 44 | self._counts = self._counts + counts 45 | return memo 46 | 47 | @abstractmethod 48 | def _summary(self, memo: Memo, counts: torch.Tensor) -> float: 49 | pass 50 | 51 | def summary(self, memo: Memo) -> float: 52 | if isinstance(self._counts, int): 53 | return 0. 
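# `_counts` is a Tensor from here on (forward() ran at least once); clone it so the in-place all_reduce below does not corrupt the per-rank buffer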
54 | counts = self._counts.clone() 55 | dist.all_reduce(counts) 56 | return self._summary(memo, counts) 57 | 58 | 59 | @VQITMetricRegistry.register_() 60 | class CodebookUsageMetric(CodebookMixin[T], BaseMetric[T]): 61 | 62 | def _summary(self, memo: Memo, counts: torch.Tensor) -> float: 63 | return counts.bool().sum().item() / self._codebook_size 64 | 65 | 66 | @VQITMetricRegistry.register_() 67 | class CodebookPPLMetric(CodebookMixin[T], BaseMetric[T]): 68 | 69 | def _summary(self, memo: Memo, counts: torch.Tensor) -> float: 70 | probabilities = counts / counts.sum() 71 | categorical = torch.distributions.Categorical(probabilities) 72 | entropy: torch.Tensor = categorical.entropy() 73 | return entropy.item() 74 | -------------------------------------------------------------------------------- /docs/validation.md: -------------------------------------------------------------------------------- 1 | # Validation 2 | 3 | The validation command follows a similar pattern to the training command: 4 | 5 | ```bash 6 | auto_torchrun -m vq.val ${EXPERIMENT_NAME} ${CONFIG_NAME} \ 7 | --config-options ... \ 8 | --override ... \ 9 | --visual ... \ 10 | --autocast \ 11 | --load-model-from ... \ 12 | --load-from ... 13 | ``` 14 | 15 | `--config-options`, `--override`, `--autocast`, and `--load-model-from` have the same meanings as in the training command. 16 | 17 | If the `--visual` argument is present, visualizations will be saved under `work_dirs/${EXPERIMENT_NAME}/unbatched_visuals`. The value of `--visual` specifies a regex to filter the images to visualize. In most cases, `--visual pred_image` is sufficient to save the reconstructed images. 18 | 19 | In `vq.train`, `--load-from` is used to resume training. However, in `vq.val`, `--load-from` is used to specify the checkpoints to validate. By default, `vq.val` automatically validates all checkpoints found under `work_dirs/${EXPERIMENT_NAME}/checkpoints`. If `--load-from iter_{15..26}0000` is provided, only `work_dirs/${EXPERIMENT_NAME}/checkpoints/iter_{15..26}0000` are validated. 20 | 21 | To construct a validation command from the training command, simply replace `vq.train` with `vq.val`. For example, the validation command for the VQ-KD decoder is: 22 | 23 | ```bash 24 | # auto_torchrun -m vq.train \ 25 | # decoder/llamagen/vqkd_clip_8192_imagenet_ddp/llamagen_8192_dd2_aglwg075_imagenet_ddp \ 26 | # configs/decoder/llamagen.py \ 27 | # --config-options it_config::configs/vqkd/clip_8192_imagenet_ddp.py \ 28 | # --load-model-from work_dirs/vqkd/clip_8192_imagenet_ddp/checkpoints/iter_250000/model.pth 29 | 30 | auto_torchrun -m vq.val \ 31 | decoder/llamagen/vqkd_clip_8192_imagenet_ddp/llamagen_8192_dd2_aglwg075_imagenet_ddp \ 32 | configs/decoder/llamagen.py \ 33 | --config-options it_config::configs/vqkd/clip_8192_imagenet_ddp.py \ 34 | --load-model-from work_dirs/vqkd/clip_8192_imagenet_ddp/checkpoints/iter_250000/model.pth 35 | ``` 36 | 37 | To test a single checkpoint, we also provide the `vq.test` command: 38 | 39 | ```bash 40 | auto_torchrun -m vq.test ${EXPERIMENT_NAME} ${CONFIG_NAME} \ 41 | --config-options ... \ 42 | --override ... \ 43 | --visual ... \ 44 | --autocast \ 45 | --load-model-from ... 46 | ``` 47 | 48 | Compared to `vq.val`, the `--load-from` argument is removed. The `vq.test` command validates a single checkpoint specified by `--load-model-from`.
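For example, a converted VQGAN checkpoint (see [pretrained_models.md](pretrained_models.md)) can be tested with:

```bash
auto_torchrun -m vq.test vqgan/16384_dd2_aglwg075_imagenet_ddp configs/vqgan/16384_dd2_aglwg075_imagenet_ddp.py --load-model-from pretrained/taming-transformers/vqgan_imagenet_f16_16384.pth.converted --visual pred_image
```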
49 | -------------------------------------------------------------------------------- /vq/utils/misc.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'get_memo', 3 | 'device', 4 | 'log', 5 | 'load', 6 | 'todd_version', 7 | ] 8 | 9 | import argparse 10 | import functools 11 | import importlib 12 | import pathlib 13 | import sys 14 | from typing import TYPE_CHECKING, Iterable, Literal, TypeVar, cast 15 | 16 | import todd 17 | import torch 18 | from todd.configs import PyConfig 19 | from todd.patches.torch import get_rank, load_state_dict, load_state_dict_ 20 | from todd.registries import ModelRegistry 21 | from todd.runners import Memo 22 | from torch import nn 23 | 24 | if TYPE_CHECKING: 25 | from ..runners import BaseMixin as BaseRunnerMixin 26 | 27 | T = TypeVar('T', bound=nn.Module) 28 | 29 | 30 | def get_memo(memo: Memo, key: str) -> Memo: 31 | memo_: Memo 32 | if key in memo: 33 | memo_ = memo[key] 34 | assert isinstance(memo_, dict) 35 | else: 36 | memo_ = dict() 37 | memo[key] = memo_ 38 | return memo_ 39 | 40 | 41 | @functools.lru_cache(1) 42 | def device() -> str: 43 | return 'cuda' if todd.Store.cuda else 'cpu' 44 | 45 | 46 | def log( 47 | runner: 'BaseRunnerMixin[T]', 48 | args: argparse.Namespace, 49 | config: PyConfig, 50 | ) -> None: 51 | if get_rank() != 0: 52 | return 53 | 54 | runner.logger.info("Command\n" + ' '.join(sys.argv)) 55 | runner.logger.info(f"Args\n{vars(args)}") 56 | runner.logger.info(f"Config\n{config.dumps()}") 57 | 58 | if 'config' in args: 59 | config_name = cast(pathlib.Path, args.config).name 60 | PyConfig(config).dump(runner.work_dir / config_name) 61 | 62 | 63 | def load( 64 | *args, 65 | config: str, 66 | runner_type: Literal['trainer', 'validator'] = 'validator', 67 | state_dicts: Iterable[torch.serialization.FILE_LIKE], 68 | **kwargs, 69 | ) -> nn.Module: 70 | config_ = PyConfig.load(config) 71 | for custom_import in config_.get('custom_imports', []): 72 | importlib.import_module(custom_import) 73 | config_ = config_[runner_type].model 74 | model: nn.Module = ModelRegistry.build(config_) 75 | 76 | state_dict = load_state_dict_(state_dicts) # type: ignore[arg-type] 77 | load_state_dict(model, state_dict, *args, **kwargs) 78 | return model 79 | 80 | 81 | @todd.utils.EnvRegistry.register_(force=True) 82 | def todd_version(verbose: bool = False) -> str: 83 | with open('.todd_version') as f: 84 | commit_id = f.read().strip() 85 | version = f'{todd.utils.todd_version()}+{commit_id}' 86 | return version 87 | -------------------------------------------------------------------------------- /vq/tasks/sequence_modeling/models/c2i.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'C2I', 3 | ] 4 | 5 | from typing import TypeVar 6 | 7 | import todd 8 | import todd.tasks.large_multimodal_model as lmm 9 | import torch 10 | from todd.bases.registries import Item 11 | from todd.runners import Memo 12 | from torch import nn 13 | 14 | from vq.tasks.image_reconstruction import BaseModel as BaseIRModel 15 | 16 | from ..runners import BaseMixin as BaseRunnerMixin 17 | from .registries import VQSMModelRegistry 18 | from .x2i import X2I 19 | 20 | T = TypeVar('T', bound=nn.Module) 21 | 22 | 23 | @VQSMModelRegistry.register_() 24 | class C2I(X2I[lmm.C2IEnum]): 25 | 26 | @classmethod 27 | def vocabulary_size_build_pre_hook( 28 | cls, 29 | config: todd.Config, 30 | registry: todd.RegistryMeta, 31 | item: Item, 32 | ) -> todd.Config: 33 | if 'vocabulary_size' in config: 
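# an explicit vocabulary_size in the config takes precedence over the derived num_categories + cfg + codebook_size below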
34 | return config 35 | config = cls.ir_build_pre_hook( 36 | config, 37 | registry, 38 | item, 39 | ) 40 | cfg = bool(config.get('cfg')) 41 | ir: BaseIRModel = config.ir 42 | config.vocabulary_size = config.num_categories + cfg + ir.codebook_size 43 | return config 44 | 45 | @classmethod 46 | def build_pre_hook( 47 | cls, 48 | config: todd.Config, 49 | registry: todd.RegistryMeta, 50 | item: Item, 51 | ) -> todd.Config: 52 | config = cls.vocabulary_size_build_pre_hook(config, registry, item) 53 | return super().build_pre_hook(config, registry, item) 54 | 55 | @property 56 | def num_categories(self) -> int: 57 | return self._num_categories + bool(self._cfg) 58 | 59 | def uncondition_tokens( 60 | self, 61 | condition_tokens: torch.Tensor, 62 | ) -> torch.Tensor: 63 | return torch.full_like(condition_tokens, self._num_categories) 64 | 65 | def _data( 66 | self, 67 | runner: BaseRunnerMixin[T], 68 | memo: Memo, 69 | ) -> tuple[lmm.C2IData, Memo]: 70 | image_tokens, memo = self.encode_image_tokens(memo['image'], memo) 71 | memo['image_tokens'] = image_tokens 72 | 73 | category_tokens = memo['category'] 74 | if self._cfg is not None: 75 | category_tokens, memo = self.dropout_tokens(category_tokens, memo) 76 | memo['category_tokens'] = category_tokens 77 | 78 | data = lmm.C2IData( 79 | category_tokens, 80 | image_tokens, 81 | self.num_categories, 82 | self._ir.codebook_size, 83 | ) 84 | return data, memo 85 | -------------------------------------------------------------------------------- /tools/fid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import pathlib 4 | 5 | import todd.tasks.image_generation as ig 6 | from todd.configs import PyConfig 7 | from todd.patches.py_ import DictAction 8 | from todd.patches.torch import get_rank, load 9 | from todd.registries import RunnerRegistry 10 | from torch import nn 11 | 12 | from vq.runners import BaseValidator 13 | from vq.utils import Store, log 14 | 15 | 16 | def parse_args() -> argparse.Namespace: 17 | parser = argparse.ArgumentParser(description='Test') 18 | parser.add_argument('reference') 19 | parser.add_argument('data_root') 20 | parser.add_argument('--strategy', default='cuda') 21 | parser.add_argument('--override', action=DictAction, default=dict()) 22 | parser.add_argument('--autocast', action='store_true') 23 | parser.add_argument('--work-dir', type=pathlib.Path) 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | def main() -> None: 29 | args = parse_args() 30 | reference: ig.Statistics = load( 31 | f'pretrained/fid/{args.reference}.pth', 32 | directory=Store.PRETRAINED, 33 | ) 34 | 35 | config: PyConfig = PyConfig.load( 36 | 'configs/fid/interface.py', 37 | dataset='vanilla', 38 | strategy=args.strategy, 39 | ) 40 | config = config.validator 41 | 42 | if args.work_dir is None: 43 | fid_path = '/dev/null' 44 | else: 45 | work_dir: pathlib.Path = args.work_dir 46 | fid_path = args.work_dir / 'fid.pth' 47 | 48 | config.update( 49 | dataset=dict( 50 | name=__file__.replace('/', '_'), 51 | num_categories=1, 52 | fid_path=fid_path, 53 | access_layer=dict( 54 | data_root=args.data_root, 55 | subfolder_action='none', 56 | suffix='png', 57 | ), 58 | ), 59 | ) 60 | config.override(args.override) 61 | 62 | name: str = args.data_root 63 | name = name.replace('/', '_') 64 | 65 | runner: BaseValidator[nn.Module] = RunnerRegistry.build( 66 | config, 67 | name=f'fid/{name}', 68 | autocast=args.autocast, 69 | ) 70 | log(runner, args, config) 71 | memo = 
runner.run() 72 | 73 | fid = ig.fid(reference, memo['statistics']) 74 | 75 | if get_rank() == 0: 76 | runner.logger.info(f'FID: {fid}') 77 | if args.work_dir is not None: 78 | output_file = work_dir / 'fid.txt' 79 | with open(output_file, 'a') as f: 80 | f.write(f"{datetime.datetime.now()}\n") 81 | f.write(f"{args=}\n") 82 | f.write(f"{fid=}\n") 83 | 84 | 85 | if __name__ == '__main__': 86 | main() 87 | -------------------------------------------------------------------------------- /vq/tasks/sequence_modeling/models/transformers.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'BaseTransformer', 3 | ] 4 | 5 | import enum 6 | from abc import ABC, abstractmethod 7 | from typing import TypeVar 8 | 9 | import todd 10 | import todd.tasks.large_multimodal_model as lmm 11 | import torch 12 | from todd.bases.registries import BuildPreHookMixin, Item 13 | from todd.runners import Memo 14 | from torch import nn 15 | 16 | from vq.utils import get_memo 17 | 18 | from .registries import VQSMSamplerRegistry 19 | from .samplers import BaseSampler 20 | 21 | T = TypeVar('T', bound=enum.Enum) 22 | 23 | 24 | class BaseTransformer(BuildPreHookMixin, nn.Module, ABC): 25 | 26 | def __init__( 27 | self, 28 | *args, 29 | vocabulary_size: int, 30 | sampler: BaseSampler, 31 | **kwargs, 32 | ) -> None: 33 | super().__init__(*args, **kwargs) 34 | self._vocabulary_size = vocabulary_size 35 | self._sampler = sampler 36 | 37 | @classmethod 38 | def sampler_build_pre_hook( 39 | cls, 40 | config: todd.Config, 41 | registry: todd.RegistryMeta, 42 | item: Item, 43 | ) -> todd.Config: 44 | config.sampler = VQSMSamplerRegistry.build_or_return(config.sampler) 45 | return config 46 | 47 | @classmethod 48 | def build_pre_hook( 49 | cls, 50 | config: todd.Config, 51 | registry: todd.RegistryMeta, 52 | item: Item, 53 | ) -> todd.Config: 54 | config = super().build_pre_hook(config, registry, item) 55 | config = cls.sampler_build_pre_hook(config, registry, item) 56 | return config 57 | 58 | def sample( 59 | self, 60 | logits: torch.Tensor, 61 | codebook: lmm.Codebook[T], 62 | memo: Memo, 63 | ) -> tuple[torch.Tensor, Memo]: 64 | tokens, memo['sampler'] = self._sampler( 65 | logits, 66 | codebook.start, 67 | codebook.end, 68 | get_memo(memo, 'sampler'), 69 | ) 70 | return tokens, memo 71 | 72 | @abstractmethod 73 | def _generate( 74 | self, 75 | tokens: torch.Tensor, 76 | length: int, 77 | codebook: lmm.Codebook[T], 78 | memo: Memo, 79 | ) -> tuple[torch.Tensor, Memo]: 80 | pass 81 | 82 | @torch.no_grad() 83 | def generate( 84 | self, 85 | tokens: torch.Tensor, 86 | length: int, 87 | codebook: lmm.Codebook[T], 88 | memo: Memo, 89 | ) -> tuple[torch.Tensor, Memo]: 90 | return self._generate(tokens, length, codebook, memo) 91 | 92 | @abstractmethod 93 | def forward( 94 | self, 95 | data: lmm.InterleavedData[T], 96 | memo: Memo, 97 | ) -> tuple[torch.Tensor, Memo]: 98 | pass 99 | -------------------------------------------------------------------------------- /vq/algorithms/vqkd/teachers/base.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'BaseTeacher', 3 | ] 4 | 5 | from abc import ABC, abstractmethod 6 | from typing import Generic, TypeVar 7 | 8 | import einops 9 | import todd 10 | import torch 11 | import torch.nn.functional as F 12 | from todd.bases.registries import BuildPreHookMixin, Item 13 | from todd.models import MeanStdMixin 14 | from torch import nn 15 | 16 | T = TypeVar('T', bound=nn.Module) 17 | 18 | 19 | 
class BaseTeacher(BuildPreHookMixin, MeanStdMixin, ABC, Generic[T]): 20 | 21 | def __init__( 22 | self, 23 | *args, 24 | model: T, 25 | downsample_factor: int, 26 | image_wh: tuple[int, int] | None = None, 27 | output_wh: tuple[int, int] | None = None, 28 | **kwargs, 29 | ) -> None: 30 | super().__init__(*args, **kwargs) 31 | if image_wh is None and output_wh is not None: 32 | image_wh = ( 33 | output_wh[0] * downsample_factor, 34 | output_wh[1] * downsample_factor, 35 | ) 36 | self._model = model 37 | self._downsample_factor = downsample_factor 38 | self._image_wh = image_wh 39 | self._output_wh = output_wh 40 | 41 | @classmethod 42 | def model_build_pre_hook( 43 | cls, 44 | config: todd.Config, 45 | registry: todd.RegistryMeta, 46 | item: Item, 47 | ) -> todd.Config: 48 | config.setdefault('model', todd.Config()) 49 | return config 50 | 51 | @classmethod 52 | def build_pre_hook( 53 | cls, 54 | config: todd.Config, 55 | registry: todd.RegistryMeta, 56 | item: Item, 57 | ) -> todd.Config: 58 | config = super().build_pre_hook(config, registry, item) 59 | config = cls.model_build_pre_hook(config, registry, item) 60 | return config 61 | 62 | @property 63 | @abstractmethod 64 | def out_channels(self) -> int: 65 | pass 66 | 67 | @abstractmethod 68 | def _forward(self, image: torch.Tensor, return_2d: bool) -> torch.Tensor: 69 | pass 70 | 71 | def forward( 72 | self, 73 | original_image: torch.Tensor, 74 | return_2d: bool = False, 75 | ) -> torch.Tensor: 76 | image = self.normalize(original_image) 77 | 78 | if self._image_wh is not None: 79 | w, h = self._image_wh 80 | image = F.interpolate(image, (h, w), mode='bicubic') 81 | 82 | x = self._forward(image, return_2d or self._output_wh is not None) 83 | x = x.float() 84 | 85 | if self._output_wh is not None: 86 | w, h = self._output_wh 87 | x = F.interpolate(x, (h, w), mode='bicubic') 88 | 89 | if not return_2d and x.ndim != 3: 90 | x = einops.rearrange(x, 'b c h w -> b (h w) c') 91 | 92 | return x 93 | -------------------------------------------------------------------------------- /configs/vqkd/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from todd.configs import PyConfig 4 | 5 | _kwargs_: dict[str, Any] 6 | _kwargs_ = dict(_kwargs_) 7 | 8 | _kwargs_.setdefault('embedding_dim', 32) 9 | _kwargs_.setdefault('distance', 'Cosine') 10 | 11 | _base_ = [ 12 | PyConfig.load('configs/vq/interface.py', **_kwargs_), 13 | 'teachers/interface.py', 14 | ] 15 | 16 | model = dict( 17 | type='VQModelRegistry.VQKD', 18 | encoder=dict(type='VQKDEncoder'), 19 | post_encode=dict(type='BaseConnector'), 20 | quantizer=dict( 21 | type='VQKDQuantizer', 22 | callbacks=[dict(type='VQKDCallback', ema=dict())], 23 | losses=dict( 24 | commitment_loss=dict(type='CommitmentLoss', mse=dict(norm=True)), 25 | ), 26 | ), 27 | pre_decode=dict(type='BaseConnector'), 28 | decoder=dict(type='VQKDDecoder', depth=1), 29 | distiller=dict( 30 | type='VQKDDistiller', 31 | teacher=dict(output_wh=(14, 14)), 32 | teacher_hook_pipeline=dict( 33 | target_features=dict( 34 | type='SingleOperator', 35 | args=tuple(), 36 | atom=dict( 37 | type=( 38 | 'TaskRegistry.KDRegistry.KDDistillerRegistry.' 39 | 'KDHookRegistry.Hook' 40 | ), 41 | path='', 42 | ), 43 | ), 44 | ), 45 | student_hook_pipeline=dict(), 46 | adapt_pipeline=dict( 47 | pred_features=dict( 48 | type='SingleOperator', 49 | args=('pred_features', ), 50 | atom=dict( 51 | type=( 52 | 'TaskRegistry.KDRegistry.KDDistillerRegistry.' 
53 | 'KDAdaptRegistry.Model' 54 | ), 55 | model=dict( 56 | type='einops_layers_torch_Rearrange', 57 | pattern='b c h w -> b (h w) c', 58 | ), 59 | ), 60 | ), 61 | ), 62 | loss_pipeline=dict( 63 | r_loss=dict( 64 | type='SingleOperator', 65 | args=('pred_features', 'target_features'), 66 | atom=dict( 67 | type=( 68 | 'ModelRegistry.LossRegistry.VQLossRegistry.' 69 | 'CosineEmbeddingLoss' 70 | ), 71 | cosine_embedding=dict(), 72 | ), 73 | ), 74 | ), 75 | ), 76 | no_grad=dict( 77 | type='NamedParametersFilter', 78 | modules=dict( 79 | type='NamedModulesFilter', 80 | name='_quantizer', 81 | ), 82 | ), 83 | ) 84 | runner = dict(model=model) 85 | 86 | _export_ = dict(trainer=runner, validator=runner) 87 | -------------------------------------------------------------------------------- /docs/training.md: -------------------------------------------------------------------------------- 1 | # Tokenizers 2 | 3 | The training command is as follows: 4 | 5 | ```bash 6 | auto_torchrun -m vq.train ${EXPERIMENT_NAME} ${CONFIG_NAME} \ 7 | --config-options ... \ 8 | --override ... \ 9 | --autocast \ 10 | --load-model-from ... \ 11 | --load-from ... \ 12 | --auto-resume 13 | ``` 14 | 15 | The `auto_torchrun` command is installed by the `todd_ai` package and is equivalent to `torchrun --nproc-per-node=${GPUS} --master-port=${PORT}`. You can always use `torchrun` as a workaround should `auto_torchrun` fail. 16 | 17 | Check out `work_dirs/${EXPERIMENT_NAME}` for the training products. Specifically, the checkpoints are stored under `work_dirs/${EXPERIMENT_NAME}/checkpoints`. 18 | 19 | The `${CONFIG_NAME}` argument follows the format: 20 | 21 | ```text 22 | configs/{model}/{codebook size}_{architecture detail}_{dataset}_{strategy}.py 23 | ``` 24 | 25 | - `model` is the name of the tokenizer. For example, `vqgan`, `cvqvae`, `fsq`, `vqkd`, or `cluster`. 26 | - `codebook size` is the number of tokens that can be used by the tokenizer. By default, VQ-KD uses 8192 tokens. 27 | - `architecture detail` is the specific model architecture used in the tokenizer. For example, `dd2_aglwg075` refers to a model with the depth of the discriminator being `2` and the adaptive generator loss weight gain being `0.75`. 28 | - `dataset` is usually just `imagenet`. 29 | - `strategy` is the parallel strategy used for training. Both `ddp` and `fsdp` are supported. 30 | 31 | All other arguments are optional: 32 | 33 | - `--config-options` and `--override` are related to config files: 34 | - `--config-options` passes options to the config file. 35 | - `--override` overrides the config file at runtime. 36 | - `--autocast` enables automatic mixed precision training. 37 | - `--load-model-from` specifies pretrained models to be loaded. For example, training a pixel decoder for VQ-KD requires loading the pretrained VQ-KD tokenizer. 38 | - `--load-from` and `--auto-resume` enable resumption of training. 39 | - `--load-from work_dirs/${EXPERIMENT_NAME}/checkpoints/iter_${n}` resumes training from iteration `n`. 40 | - `--auto-resume` automatically resumes training from the latest checkpoint. 41 | 42 | > If a training script uses `--load-model-from` and either `--load-from` or `--auto-resume`, the override `--override .trainer.callbacks[-1].load_state_dict:dict\(strict=False\)` should be specified. 43 | 44 | This project adopts a two-stage framework: 45 | 46 | - Tokenizers encode images into tokens. Decoders are included for reconstructing images from tokens.
47 | - [VQGAN](../configs/vqgan/README.md) 48 | - [CVQ-VAE](../configs/cvqvae/README.md) 49 | - [FSQ](../configs/fsq/README.md) 50 | - [VQ-KD](../configs/vqkd/README.md) 51 | - [Cluster](../configs/cluster/README.md) 52 | - [decoder](../configs/decoder/README.md) for both VQ-KD and Cluster 53 | - Proposal Networks generate image tokens for image synthesis. 54 | - [AR](../configs/ar/README.md) 55 | -------------------------------------------------------------------------------- /vq/tasks/image_tokenization/demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | from typing import Iterator 4 | 5 | from todd.configs import PyConfig 6 | from todd.datasets.access_layers import PILAccessLayer 7 | from todd.patches.py_ import DictAction 8 | from todd.patches.torch import get_world_size 9 | 10 | from vq import VQRunnerRegistry 11 | from vq.datasets import Dataset 12 | from vq.utils import log 13 | 14 | from ...runners.base import BaseValidator 15 | 16 | # TODO 17 | 18 | 19 | class SingletonAccessLayer(PILAccessLayer): 20 | 21 | def __init__(self, *args, singleton: str | pathlib.Path, **kwargs) -> None: 22 | if isinstance(singleton, str): 23 | singleton = pathlib.Path(singleton) 24 | singleton = singleton.absolute() 25 | super().__init__( 26 | *args, 27 | data_root=singleton.parent, 28 | suffix=singleton.suffix.removeprefix('.'), 29 | **kwargs, 30 | ) 31 | 32 | self._singleton = singleton 33 | 34 | def _files(self) -> Iterator[pathlib.Path]: 35 | yield self._singleton 36 | 37 | 38 | def parse_args() -> argparse.Namespace: 39 | parser = argparse.ArgumentParser(description='Inference') 40 | parser.add_argument('name') 41 | parser.add_argument('config', type=pathlib.Path) 42 | parser.add_argument('singleton') 43 | parser.add_argument('--config-options', action=DictAction, default=dict()) 44 | parser.add_argument('--override', action=DictAction, default=dict()) 45 | parser.add_argument('--load-model-from', required=True, nargs='+') 46 | parser.add_argument('--save') 47 | args = parser.parse_args() 48 | return args 49 | 50 | 51 | def main() -> None: 52 | assert get_world_size() <= 1 53 | 54 | args = parse_args() 55 | 56 | config: PyConfig = PyConfig.load(args.config, **args.config_options) 57 | config = PyConfig( 58 | type=config.validator.type, 59 | strategy=config.validator.strategy, 60 | model=config.validator.model, 61 | callbacks=[], 62 | dataset=dict( 63 | type=Dataset.__name__, 64 | name=config.validator.dataset.name, 65 | num_categories=config.validator.dataset.num_categories, 66 | image_size=config.validator.dataset.image_size, 67 | access_layer=SingletonAccessLayer(singleton=args.singleton), 68 | transforms=config.validator.dataset.transforms, 69 | ), 70 | dataloader=dict(batch_size=1, num_workers=0), 71 | name=f'{args.name}_inference', 72 | save=args.save, 73 | ) 74 | 75 | config.override(args.override) 76 | 77 | runner: BaseValidator = VQRunnerRegistry.build(config) 78 | log(runner, args, config) 79 | runner.strategy.load_model_from(args.load_model_from, strict=False) 80 | memo = runner.run() 81 | runner.logger.info("\n%s", memo['quantize']['quant']) # TODO 82 | 83 | 84 | if __name__ == '__main__': 85 | main() 86 | -------------------------------------------------------------------------------- /docs/pretrained_models.md: -------------------------------------------------------------------------------- 1 | # Pretrained Models 2 | 3 | The following checkpoints are required to run the code: 4 | 5 | ```bash 6 | python 
tools/prepare_checkpoints.py pytorch_fid 7 | 8 | python tools/prepare_checkpoints.py lpips 9 | python tools/convert_checkpoints.py lpips pretrained/lpips/vgg.pth 10 | 11 | python tools/prepare_checkpoints.py torchvision --weights .VGG16_Weights.DEFAULT 12 | ``` 13 | 14 | The following checkpoints are the IU models used in the paper: 15 | 16 | ```bash 17 | 18 | python tools/prepare_checkpoints.py clip --weights ViT-B/16 19 | python tools/prepare_checkpoints.py dino 20 | python tools/prepare_checkpoints.py torchvision --weights .ViT_B_16_Weights.DEFAULT 21 | python tools/prepare_checkpoints.py mae 22 | ``` 23 | 24 | The following checkpoints are used to initialize the AR proposal networks: 25 | 26 | ```bash 27 | python tools/prepare_checkpoints.py huggingface 28 | ``` 29 | 30 | VQGAN and VQ-KD checkpoints from their original repos can be loaded by our code after conversion: 31 | 32 | ```bash 33 | python tools/prepare_checkpoints.py taming_transformers 34 | python tools/convert_checkpoints.py taming_transformers pretrained/taming-transformers/vqgan_imagenet_f16_1024.pth --check configs/vqgan/1024_imagenet_ddp.py 35 | python tools/convert_checkpoints.py taming_transformers pretrained/taming-transformers/vqgan_imagenet_f16_16384.pth --check configs/vqgan/16384_dd2_aglwg075_imagenet_ddp.py 36 | 37 | python tools/prepare_checkpoints.py beitv2 38 | python tools/convert_checkpoints.py beitv2 pretrained/beitv2/vqkd_encoder_base_decoder_1x768x12_clip.pth 39 | python tools/convert_checkpoints.py beitv2 pretrained/beitv2/vqkd_encoder_base_decoder_1x768x12_clip.pth --check configs/vqkd/clip_8192_imagenet_ddp.py --suffix .converted.with_decoder --with-decoder 40 | python tools/convert_checkpoints.py beitv2 pretrained/beitv2/vqkd_encoder_base_decoder_1x768x12_dino.pth 41 | python tools/convert_checkpoints.py beitv2 pretrained/beitv2/vqkd_encoder_base_decoder_1x768x12_dino.pth --check configs/vqkd/dino_8192_imagenet_ddp.py --suffix .converted.with_decoder --with-decoder 42 | ``` 43 | 44 | After generating the FID cache, you can run the following commands to validate the pretrained models: 45 | 46 | ```bash 47 | auto_torchrun -m vq.test vqgan/16384_dd2_aglwg075_imagenet_ddp configs/vqgan/16384_dd2_aglwg075_imagenet_ddp.py --load-model-from pretrained/taming-transformers/vqgan_imagenet_f16_16384.pth.converted --visual pred_image 48 | # {'lpips_loss': 0.28323277831077576, 'l1_image_loss': 0.06811775267124176, 'mse_image_loss': 0.013179616071283817, 'psnr': 19.970359802246094, 'ssim': 0.5023356676101685, 'fid': 4.980832106065748, 'codebook_usage': 0.059326171875, 'codebook_ppl': 6.812368392944336} 49 | 50 | auto_torchrun -m vq.test vqkd/clip_8192_imagenet_ddp configs/vqkd/clip_8192_imagenet_ddp.py --load-model-from pretrained/beitv2/vqkd_encoder_base_decoder_1x768x12_clip.pth.converted.with_decoder 51 | # {'cosine_embedding_r_loss': 0.16431047022342682, 'codebook_usage': 1.0, 'codebook_ppl': 8.94822883605957} 52 | ``` 53 | -------------------------------------------------------------------------------- /vq/algorithms/vqgan/discriminators/patchgan.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'PatchGANDiscriminator', 3 | ] 4 | 5 | import todd 6 | import torch 7 | from torch import nn 8 | 9 | from ..registries import VQDiscriminatorRegistry 10 | from .base import BaseDiscriminator 11 | 12 | # TODO: refactor 13 | 14 | 15 | @VQDiscriminatorRegistry.register_() 16 | class PatchGANDiscriminator(BaseDiscriminator): 17 | 18 | def __init__( 19 | self,
20 | *args, 21 | in_channels: int = 3, 22 | width: int = 64, 23 | depth: int = 3, 24 | kernel_size: int = 4, 25 | padding: int = 1, 26 | **kwargs, 27 | ) -> None: 28 | super().__init__(*args, **kwargs) 29 | sequence: list[nn.Module] = [ 30 | nn.Conv2d( 31 | in_channels, 32 | width, 33 | kernel_size=kernel_size, 34 | stride=2, 35 | padding=padding, 36 | ), 37 | nn.LeakyReLU(0.2, True), 38 | ] 39 | nf_mult = 1 40 | nf_mult_prev = 1 41 | for n in range(1, depth): 42 | # gradually increase the number of filters 43 | nf_mult_prev = nf_mult 44 | nf_mult = min(2**n, 8) 45 | sequence += [ 46 | nn.Conv2d( 47 | width * nf_mult_prev, 48 | width * nf_mult, 49 | kernel_size=kernel_size, 50 | stride=2, 51 | padding=padding, 52 | bias=False, 53 | ), 54 | nn.BatchNorm2d(width * nf_mult), 55 | nn.LeakyReLU(0.2, True), 56 | ] 57 | 58 | nf_mult_prev = nf_mult 59 | nf_mult = min(2**depth, 8) 60 | sequence += [ 61 | nn.Conv2d( 62 | width * nf_mult_prev, 63 | width * nf_mult, 64 | kernel_size=kernel_size, 65 | stride=1, 66 | padding=padding, 67 | bias=False, 68 | ), 69 | nn.BatchNorm2d(width * nf_mult), 70 | nn.LeakyReLU(0.2, True), 71 | ] 72 | 73 | sequence += [ 74 | nn.Conv2d( 75 | width * nf_mult, 76 | 1, 77 | kernel_size=kernel_size, 78 | stride=1, 79 | padding=padding, 80 | ), 81 | ] 82 | self._discriminator = nn.Sequential(*sequence) 83 | 84 | def init_weights(self, config: todd.Config) -> bool: 85 | 86 | def weights_init(m): 87 | classname = m.__class__.__name__ 88 | if classname.find('Conv') != -1: 89 | nn.init.normal_(m.weight.data, 0.0, 0.02) 90 | elif classname.find('BatchNorm') != -1: 91 | nn.init.normal_(m.weight.data, 1.0, 0.02) 92 | nn.init.constant_(m.bias.data, 0) 93 | 94 | # todd.logger.debug(f'Initializing {self.__class__.__name__} weights') 95 | self.apply(weights_init) 96 | return False 97 | 98 | def forward(self, image: torch.Tensor) -> torch.Tensor: 99 | return self._discriminator(image) 100 | --------------------------------------------------------------------------------