├── .gitignore ├── LICENSE ├── LICENSE.apache-2.0 ├── LICENSE.cc-by-nc-sa-4.0 ├── README.md ├── assets └── figures │ ├── T2I_samples │ ├── Eiffel tower on a desert._temp_1.0_top_k_1024_top_p_0.95.jpg │ ├── Eiffel tower on a mountain._temp_1.0_top_k_1024_top_p_0.95.jpg │ ├── a painting by RENÉ MAGRITTE_temp_1.0_top_k_1024_top_p_0.95.jpg │ ├── a painting by Vincent Van Gogh_temp_1.0_top_k_1024_top_p_0.95.jpg │ ├── a painting of a cat with sunglasses in the frame._temp_1.0_top_k_1024_top_p_0.95.jpg │ └── a painting of a dog with sunglasses in the frame._temp_1.0_top_k_1024_top_p_0.95.jpg │ ├── overview_figure.png │ ├── sampling_speed_comparison.png │ └── teaser.png ├── compute_metrics.py ├── compute_rfid.py ├── configs ├── cc3m │ └── cc3m-rqtransformer-8x8x4-650M.yaml ├── ffhq │ ├── stage1 │ │ └── ffhq256-rqvae-8x8x4.yaml │ └── stage2 │ │ └── ffhq256-rqtransformer-8x8x4-350M.yaml ├── imagenet256 │ ├── stage1 │ │ └── in256-rqvae-8x8x4.yaml │ └── stage2 │ │ ├── in256-rqtransformer-8x8x4-1400M.yaml │ │ ├── in256-rqtransformer-8x8x4-3800M.yaml │ │ ├── in256-rqtransformer-8x8x4-480M.yaml │ │ └── in256-rqtransformer-8x8x4-800M.yaml ├── lsun-bedroom │ ├── stage1 │ │ └── bedroom256-rqvae-8x8x4.yaml │ └── stage2 │ │ └── bedroom256-rqtransformer-8x8x4-600M.yaml ├── lsun-cat │ ├── stage1 │ │ └── cat256-rqvae-8x8x4.yaml │ └── stage2 │ │ └── cat256-rqtransformer-8x8x4-600M.yaml └── lsun-church │ ├── stage1 │ └── church256-rqvae-8x8x4.yaml │ └── stage2 │ └── lsun-church256-sqgan-8x8x4-350M-simp.yaml ├── data ├── README.md └── cc3m │ ├── README.md │ └── download_cc3m.py ├── main_sampling_fid.py ├── main_sampling_txt2img.py ├── main_stage1.py ├── measure_throughput ├── __main__.py └── rq_defaults.yaml ├── notebooks ├── T2I_sampling.ipynb ├── notebook_utils.py └── rqvae ├── requirements.txt ├── rqvae ├── __init__.py ├── img_datasets │ ├── __init__.py │ ├── assets │ │ ├── ffhqtrain.txt │ │ └── ffhqvalidation.txt │ ├── ffhq.py │ ├── lsun.py │ └── transforms.py ├── losses │ ├── __init__.py │ └── vqgan │ │ ├── __init__.py │ │ ├── discriminator.py │ │ ├── gan_loss.py │ │ ├── lpips.py │ │ └── lpips_utils.py ├── metrics │ ├── IS.py │ ├── __init__.py │ ├── assets │ │ ├── ImageNet_label.txt │ │ └── wnid_to_idx.yaml │ ├── clip_score.py │ ├── fid.py │ └── inception.py ├── models │ ├── __init__.py │ ├── ema.py │ ├── interfaces.py │ ├── rqtransformer │ │ ├── __init__.py │ │ ├── attentions.py │ │ ├── configs.py │ │ ├── primitives.py │ │ └── transformers.py │ └── rqvae │ │ ├── __init__.py │ │ ├── layers.py │ │ ├── modules.py │ │ ├── quantizations.py │ │ └── rqvae.py ├── optimizer │ ├── __init__.py │ ├── loss.py │ ├── optimizer.py │ └── scheduler.py ├── trainers │ ├── __init__.py │ ├── accumulator.py │ ├── trainer.py │ └── trainer_rqvae.py ├── txtimg_datasets │ ├── __init__.py │ ├── cc3m.py │ ├── coco.py │ ├── tokenizers │ │ ├── __init__.py │ │ ├── pretrained │ │ │ ├── bert-base-uncased-vocab.txt │ │ │ ├── bpe-16k-merges.txt │ │ │ ├── bpe-16k-vocab.json │ │ │ ├── bpe-30k-merges.txt │ │ │ ├── bpe-30k-vocab.json │ │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ │ ├── merges.txt │ │ │ └── vocab.json │ │ ├── simple_tokenizer.py │ │ └── utils.py │ └── transforms.py └── utils │ ├── __init__.py │ ├── config.py │ ├── dist.py │ ├── profiler.py │ ├── setup.py │ ├── utils.py │ └── writer.py └── setup.cfg /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / 
packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The `source codes` are licensed under [Apache 2.0](LICENSE.apache-2.0) License. 2 | The `stage2 pretrained weights` are licensed under [CC-BY-NC-SA 4.0](LICENSE.cc-by-nc-sa-4.0) License. 
-------------------------------------------------------------------------------- /assets/figures/T2I_samples/Eiffel tower on a desert._temp_1.0_top_k_1024_top_p_0.95.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/rq-vae-transformer/341395e562ac347f5eb62db9f5f08b9f2cc42a60/assets/figures/T2I_samples/Eiffel tower on a desert._temp_1.0_top_k_1024_top_p_0.95.jpg -------------------------------------------------------------------------------- /assets/figures/T2I_samples/Eiffel tower on a mountain._temp_1.0_top_k_1024_top_p_0.95.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/rq-vae-transformer/341395e562ac347f5eb62db9f5f08b9f2cc42a60/assets/figures/T2I_samples/Eiffel tower on a mountain._temp_1.0_top_k_1024_top_p_0.95.jpg -------------------------------------------------------------------------------- /assets/figures/T2I_samples/a painting by RENÉ MAGRITTE_temp_1.0_top_k_1024_top_p_0.95.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/rq-vae-transformer/341395e562ac347f5eb62db9f5f08b9f2cc42a60/assets/figures/T2I_samples/a painting by RENÉ MAGRITTE_temp_1.0_top_k_1024_top_p_0.95.jpg -------------------------------------------------------------------------------- /assets/figures/T2I_samples/a painting by Vincent Van Gogh_temp_1.0_top_k_1024_top_p_0.95.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/rq-vae-transformer/341395e562ac347f5eb62db9f5f08b9f2cc42a60/assets/figures/T2I_samples/a painting by Vincent Van Gogh_temp_1.0_top_k_1024_top_p_0.95.jpg -------------------------------------------------------------------------------- /assets/figures/T2I_samples/a painting of a cat with sunglasses in the frame._temp_1.0_top_k_1024_top_p_0.95.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/rq-vae-transformer/341395e562ac347f5eb62db9f5f08b9f2cc42a60/assets/figures/T2I_samples/a painting of a cat with sunglasses in the frame._temp_1.0_top_k_1024_top_p_0.95.jpg -------------------------------------------------------------------------------- /assets/figures/T2I_samples/a painting of a dog with sunglasses in the frame._temp_1.0_top_k_1024_top_p_0.95.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/rq-vae-transformer/341395e562ac347f5eb62db9f5f08b9f2cc42a60/assets/figures/T2I_samples/a painting of a dog with sunglasses in the frame._temp_1.0_top_k_1024_top_p_0.95.jpg -------------------------------------------------------------------------------- /assets/figures/overview_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/rq-vae-transformer/341395e562ac347f5eb62db9f5f08b9f2cc42a60/assets/figures/overview_figure.png -------------------------------------------------------------------------------- /assets/figures/sampling_speed_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/rq-vae-transformer/341395e562ac347f5eb62db9f5f08b9f2cc42a60/assets/figures/sampling_speed_comparison.png 
-------------------------------------------------------------------------------- /assets/figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/rq-vae-transformer/341395e562ac347f5eb62db9f5f08b9f2cc42a60/assets/figures/teaser.png -------------------------------------------------------------------------------- /compute_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | import logging 17 | from pathlib import Path 18 | 19 | from omegaconf import OmegaConf 20 | 21 | from rqvae.metrics import compute_fid, compute_IS, compute_clip_score 22 | 23 | 24 | DATASET_STATS_FOR_FID = { 25 | 'imagenet': 'assets/fid_stats/imagenet_256_train.npz', 26 | 'ffhq': 'assets/fid_stats/ffhq_256_train.npz', 27 | 'lsun_bedroom': 'assets/fid_stats/lsun_256_bedroom.npz', 28 | 'lsun_cat': 'assets/fid_stats/lsun_256_cat.npz', 29 | 'lsun_church': 'assets/fid_stats/lsun_256_church.npz', 30 | 'cc3m': 'assets/fid_stats/cc3m_256_val.npz', 31 | 'coco_2014val': 'assets/fid_stats/coco_256_val.npz', 32 | } 33 | 34 | 35 | def compute_metrics(fake_path, ref_dataset): 36 | results = {} 37 | 38 | ref_stat_path = DATASET_STATS_FOR_FID[ref_dataset] 39 | results['fid'] = compute_fid(fake_path, ref_stat_path) 40 | 41 | if ref_dataset == 'imagenet': 42 | IS_mean, IS_std = compute_IS(fake_path) 43 | results['IS_mean'] = IS_mean 44 | results['IS_std'] = IS_std 45 | 46 | if ref_dataset in ['cc3m']: 47 | results['clip_score'] = compute_clip_score(fake_path, dataset_name=ref_dataset) 48 | 49 | return results 50 | 51 | 52 | if __name__ == '__main__': 53 | 54 | @dataclass 55 | class Arguments: 56 | fake_path: str 57 | ref_dataset: str 58 | 59 | @staticmethod 60 | def verify(args): 61 | datasets = set(DATASET_STATS_FOR_FID.keys()) 62 | if args.ref_dataset not in datasets: 63 | raise ValueError(f"No dataset info found: {args.ref_dataset}") 64 | 65 | args = OmegaConf.structured(Arguments) 66 | args = OmegaConf.merge(args, OmegaConf.from_cli()) # type: Arguments 67 | 68 | Arguments.verify(args) 69 | 70 | log_path = Path(args.fake_path) 71 | log_path = log_path / 'metrics.log' 72 | 73 | logging.basicConfig( 74 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 75 | datefmt="%m/%d/%Y %H:%M:%S", 76 | level=logging.INFO, 77 | handlers=[logging.FileHandler(log_path), logging.StreamHandler()] 78 | ) 79 | 80 | logging.info('=' * 80) 81 | logging.info(f'{args}') 82 | 83 | results = compute_metrics(args.fake_path, args.ref_dataset) 84 | 85 | logging.info('=' * 80) 86 | -------------------------------------------------------------------------------- /compute_rfid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 16 | import logging 17 | import os 18 | 19 | import torch 20 | 21 | from rqvae.img_datasets import create_dataset 22 | from rqvae.models import create_model 23 | from rqvae.metrics.fid import compute_rfid 24 | from rqvae.utils.config import load_config, augment_arch_defaults 25 | 26 | 27 | def load_model(path, ema=False): 28 | 29 | model_config = os.path.join(os.path.dirname(path), 'config.yaml') 30 | config = load_config(model_config) 31 | config.arch = augment_arch_defaults(config.arch) 32 | 33 | model, _ = create_model(config.arch, ema=False) 34 | ckpt = torch.load(path)['state_dict_ema'] if ema else torch.load(path)['state_dict'] 35 | model.load_state_dict(ckpt) 36 | 37 | return model, config 38 | 39 | 40 | def setup_logger(result_path): 41 | log_fname = os.path.join(result_path, 'rfid.log') 42 | logging.basicConfig( 43 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 44 | datefmt="%m/%d/%Y %H:%M:%S", 45 | level=logging.INFO, 46 | handlers=[ 47 | logging.FileHandler(log_fname), logging.StreamHandler() 48 | ] 49 | ) 50 | logger = logging.getLogger(__name__) 51 | return logger 52 | 53 | 54 | if __name__ == '__main__': 55 | """ 56 | Computes rFID, i.e., FID between val images and reconstructed images. 57 | Log is saved to `rfid.log` in the same directory as the given vqvae model. 
58 | """ 59 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 60 | parser.add_argument('--batch-size', type=int, default=100, 61 | help='Batch size to use') 62 | parser.add_argument('--split', type=str, default='val') 63 | parser.add_argument('--vqvae', type=str, default='', required=True, 64 | help='vqvae path for recon FID') 65 | 66 | args = parser.parse_args() 67 | 68 | result_path = os.path.dirname(args.vqvae) 69 | logger = setup_logger(result_path) 70 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 71 | 72 | vqvae_model, config = load_model(args.vqvae) 73 | vqvae_model = vqvae_model.to(device) 74 | vqvae_model = torch.nn.DataParallel(vqvae_model).eval() 75 | logger.info(f'vqvae model loaded from {args.vqvae}') 76 | 77 | dataset_trn, dataset_val = create_dataset(config, is_eval=True, logger=logger) 78 | dataset = dataset_val if args.split in ['val', 'valid'] else dataset_trn 79 | logger.info(f'measuring rFID on {config.dataset.type}/{args.split}') 80 | 81 | rfid = compute_rfid(dataset, vqvae_model, batch_size=args.batch_size, device=device) 82 | logger.info(f'rFID: {rfid:.4f}') 83 | -------------------------------------------------------------------------------- /configs/cc3m/cc3m-rqtransformer-8x8x4-650M.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | dataset: cc3m 3 | txt_tok_name: bpe16k_huggingface 4 | vocab_size_txt: 16384 5 | vocab_size: 16384 6 | image_resolution: 256 7 | context_length: 32 8 | transforms: dalle-vqvae 9 | bpe_dropout: 0.1 10 | 11 | arch: 12 | type: rq-transformer 13 | block_size: [ 8, 8, 4 ] 14 | 15 | embed_dim: 1280 16 | input_embed_dim: 256 17 | shared_tok_emb: true 18 | shared_cls_emb: true 19 | 20 | input_emb_vqvae: true 21 | head_emb_vqvae: true 22 | cumsum_depth_ctx: true 23 | 24 | vocab_size_cond: 16384 25 | block_size_cond: 32 26 | 27 | body: 28 | n_layer: 26 29 | block: 30 | n_head: 20 31 | head: 32 | n_layer: 4 33 | block: 34 | n_head: 20 35 | 36 | loss: 37 | type: soft_target_cross_entropy 38 | stochastic_codes: true 39 | temp: 0.5 40 | txt_weight: 0.1 41 | img_weight: 0.9 42 | 43 | optimizer: 44 | type: adamW 45 | init_lr: 0.0005 46 | weight_decay: 0.0001 47 | betas: [0.9, 0.95] 48 | warmup: 49 | epoch: 0 50 | multiplier: 1 51 | buffer_epoch: 0 52 | min_lr: 0.0 53 | mode: fix 54 | start_from_zero: True 55 | max_gn: 1.0 56 | 57 | experiment: 58 | amp: True 59 | batch_size: 32 60 | total_batch_size: 2048 61 | epochs: 100 62 | save_ckpt_freq: 1 63 | test_freq: 1 64 | sample: 65 | top_k: 16384 66 | top_p: 0.7 67 | -------------------------------------------------------------------------------- /configs/ffhq/stage1/ffhq256-rqvae-8x8x4.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | type: ffhq 3 | transforms: 4 | type: ffhq256x256 5 | 6 | arch: 7 | type: rq-vae 8 | code_hier: 1 9 | hparams: 10 | bottleneck_type: rq 11 | embed_dim: 256 12 | n_embed: 2048 13 | latent_shape: [ 8, 8, 256 ] 14 | code_shape: [ 8, 8, 4 ] 15 | shared_codebook: true 16 | decay: 0.99 17 | restart_unused_codes: true 18 | 19 | loss_type: mse 20 | latent_loss_weight: 0.25 21 | ddconfig: 22 | double_z: false 23 | z_channels: 256 24 | resolution: 256 25 | in_channels: 3 26 | out_ch: 3 27 | ch: 128 28 | ch_mult: [ 1, 1, 2, 2, 4, 4] 29 | num_res_blocks: 2 30 | attn_resolutions: [ 16 ] 31 | dropout: 0.00 32 | checkpointing: true 33 | 34 | 35 | optimizer: 36 | # Original VQ-GAN: lr = 4.5e-06 * (batch size) -> 5.4e-5 for 
batch size 12 37 | type: adam 38 | init_lr: 4.0e-5 39 | weight_decay: 0.0 40 | betas: [0.5, 0.9] 41 | warmup: 42 | epoch: 5 # 5% of total epochs 43 | multiplier: 1 44 | buffer_epoch: 0 45 | min_lr: 4.0e-5 46 | mode: fix 47 | 48 | 49 | experiment: 50 | batch_size: 32 51 | epochs: 150 52 | save_ckpt_freq: 5 53 | test_freq: 1 54 | 55 | gan: 56 | disc: 57 | arch: 58 | in_channels: 3 59 | num_layers: 2 60 | use_actnorm: False 61 | ndf: 64 62 | spectral_norm: False 63 | 64 | loss: 65 | disc_loss: hinge 66 | gen_loss: vanilla 67 | disc_weight: 0.75 68 | perceptual_weight: 1.0 69 | disc_start: 0 -------------------------------------------------------------------------------- /configs/ffhq/stage2/ffhq256-rqtransformer-8x8x4-350M.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | type: ffhq 3 | vocab_size: 2048 4 | transforms: 5 | type: ffhq256x256 6 | 7 | arch: 8 | type: rq-transformer 9 | block_size: [ 8, 8, 4 ] 10 | 11 | embed_dim: 1024 12 | input_embed_dim: 256 13 | shared_tok_emb: true 14 | shared_cls_emb: true 15 | 16 | input_emb_vqvae: true 17 | head_emb_vqvae: true 18 | cumsum_depth_ctx: true 19 | 20 | vocab_size_cond: 1 21 | block_size_cond: 1 22 | 23 | body: 24 | n_layer: 24 25 | block: 26 | n_head: 16 27 | head: 28 | n_layer: 4 29 | block: 30 | n_head: 16 31 | 32 | loss: 33 | type: soft_target_cross_entropy 34 | stochastic_codes: true 35 | temp: 0.5 36 | 37 | optimizer: 38 | type: adamW 39 | init_lr: 0.0005 40 | weight_decay: 0.0001 41 | betas: [0.9, 0.95] 42 | warmup: 43 | epoch: 0 44 | multiplier: 1 45 | buffer_epoch: 0 46 | min_lr: 0.0 47 | mode: fix 48 | start_from_zero: True 49 | max_gn: 1.0 50 | 51 | experiment: 52 | amp: True 53 | batch_size: 32 54 | total_batch_size: 128 55 | epochs: 200 56 | save_ckpt_freq: 10 57 | test_freq: 5 58 | sample: 59 | top_k: 250 60 | top_p: 1.0 61 | -------------------------------------------------------------------------------- /configs/imagenet256/stage1/in256-rqvae-8x8x4.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | type: imagenet 3 | transforms: 4 | type: imagenet256x256 5 | 6 | arch: 7 | type: rq-vae 8 | code_hier: 1 9 | hparams: 10 | bottleneck_type: rq 11 | embed_dim: 256 12 | n_embed: 16384 13 | latent_shape: [ 8, 8, 256 ] # could be inferred: H=W=resolution / (2 ** num_down), D=embed_dim 14 | code_shape: [ 8, 8, 4 ] 15 | shared_codebook: true 16 | decay: 0.99 17 | restart_unused_codes: true 18 | 19 | loss_type: mse 20 | latent_loss_weight: 0.25 21 | ddconfig: 22 | double_z: false 23 | z_channels: 256 24 | resolution: 256 25 | in_channels: 3 26 | out_ch: 3 27 | ch: 128 28 | ch_mult: [ 1, 1, 2, 2, 4, 4] 29 | num_res_blocks: 2 30 | attn_resolutions: [ 8 ] 31 | dropout: 0.00 32 | checkpointing: true 33 | 34 | 35 | optimizer: 36 | # Original VQ-GAN: lr = 4.5e-06 * (batch size) -> 5.4e-5 for batch size 12 37 | type: adam 38 | init_lr: 4.0e-5 39 | weight_decay: 0.0 40 | betas: [0.5, 0.9] 41 | warmup: 42 | epoch: 0.5 # 5% of total epochs 43 | multiplier: 1 44 | buffer_epoch: 0 45 | min_lr: 4.0e-5 46 | mode: fix 47 | 48 | 49 | experiment: 50 | batch_size: 32 51 | epochs: 10 52 | save_ckpt_freq: 5 53 | test_freq: 1 54 | 55 | gan: 56 | disc: 57 | arch: 58 | in_channels: 3 59 | num_layers: 2 60 | use_actnorm: False 61 | ndf: 64 62 | spectral_norm: False 63 | 64 | loss: 65 | disc_loss: hinge 66 | gen_loss: vanilla 67 | disc_weight: 0.75 68 | perceptual_weight: 1.0 69 | disc_start: 0 70 | 
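The stage-1 RQ-VAE configs above (FFHQ, ImageNet, and the LSUN ones further below) share this schema. As the inline comment on `latent_shape` notes, the 8x8 spatial size follows from `resolution` and the number of downsampling steps implied by `ch_mult`. A quick sanity check of that relation — a minimal sketch assuming only `omegaconf`, which the repo's scripts already use, and the config path shown above:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load('configs/imagenet256/stage1/in256-rqvae-8x8x4.yaml')
hp = cfg.arch.hparams

# H = W = resolution / 2**num_down, D = embed_dim (see the comment on latent_shape)
num_down = len(hp.ddconfig.ch_mult) - 1             # 6 channel multipliers -> 5 downsamplings
side = hp.ddconfig.resolution // (2 ** num_down)    # 256 / 32 = 8
assert list(hp.latent_shape) == [side, side, hp.embed_dim]   # [8, 8, 256]
assert list(hp.code_shape) == [side, side, 4]                # 8x8 codes with depth 4
```

The same relation holds for the f16 and f8 variants listed in `measure_throughput/rq_defaults.yaml` (5 and 4 channel multipliers giving 16x16 and 32x32 latent grids, respectively).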
-------------------------------------------------------------------------------- /configs/imagenet256/stage2/in256-rqtransformer-8x8x4-1400M.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | type: imagenet 3 | vocab_size: 16384 4 | transforms: 5 | type: imagenet256x256 6 | 7 | arch: 8 | type: rq-transformer 9 | block_size: [ 8, 8, 4 ] 10 | 11 | embed_dim: 1536 12 | input_embed_dim: 256 13 | shared_tok_emb: true 14 | shared_cls_emb: true 15 | 16 | input_emb_vqvae: true 17 | head_emb_vqvae: true 18 | cumsum_depth_ctx: true 19 | 20 | vocab_size_cond: 1000 21 | block_size_cond: 1 22 | 23 | body: 24 | n_layer: 42 25 | block: 26 | n_head: 24 27 | head: 28 | n_layer: 6 29 | block: 30 | n_head: 24 31 | 32 | loss: 33 | type: soft_target_cross_entropy 34 | stochastic_codes: true 35 | temp: 0.5 36 | 37 | optimizer: 38 | type: adamW 39 | init_lr: 0.0005 40 | weight_decay: 0.0001 41 | betas: [0.9, 0.95] 42 | warmup: 43 | epoch: 0 44 | multiplier: 1 45 | buffer_epoch: 0 46 | min_lr: 0.0 47 | mode: fix 48 | start_from_zero: True 49 | max_gn: 1.0 50 | 51 | experiment: 52 | amp: True 53 | batch_size: 8 54 | total_batch_size: 2048 55 | epochs: 100 56 | save_ckpt_freq: 2 57 | test_freq: 2 58 | sample: 59 | top_k: 16384 60 | top_p: 0.92 61 | -------------------------------------------------------------------------------- /configs/imagenet256/stage2/in256-rqtransformer-8x8x4-3800M.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | type: imagenet 3 | vocab_size: 16384 4 | transforms: 5 | type: imagenet256x256 6 | 7 | arch: 8 | type: rq-transformer 9 | block_size: [ 8, 8, 4 ] 10 | 11 | embed_dim: 2560 12 | input_embed_dim: 256 13 | shared_tok_emb: true 14 | shared_cls_emb: true 15 | 16 | input_emb_vqvae: true 17 | head_emb_vqvae: true 18 | cumsum_depth_ctx: true 19 | 20 | vocab_size_cond: 1000 21 | block_size_cond: 1 22 | 23 | body: 24 | n_layer: 42 25 | block: 26 | n_head: 40 27 | head: 28 | n_layer: 6 29 | block: 30 | n_head: 40 31 | 32 | loss: 33 | type: soft_target_cross_entropy 34 | stochastic_codes: true 35 | temp: 0.5 36 | 37 | optimizer: 38 | type: adamW 39 | init_lr: 0.0005 40 | weight_decay: 0.0001 41 | betas: [0.9, 0.95] 42 | warmup: 43 | epoch: 0 44 | multiplier: 1 45 | buffer_epoch: 0 46 | min_lr: 0.0 47 | mode: fix 48 | start_from_zero: True 49 | max_gn: 1.0 50 | 51 | experiment: 52 | amp: True 53 | batch_size: 8 54 | total_batch_size: 2048 55 | epochs: 100 56 | save_ckpt_freq: 2 57 | test_freq: 2 58 | sample: 59 | top_k: 16384 60 | top_p: 0.92 61 | -------------------------------------------------------------------------------- /configs/imagenet256/stage2/in256-rqtransformer-8x8x4-480M.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | type: imagenet 3 | vocab_size: 16384 4 | transforms: 5 | type: imagenet256x256 6 | 7 | arch: 8 | type: rq-transformer 9 | block_size: [ 8, 8, 4 ] 10 | 11 | embed_dim: 1536 12 | input_embed_dim: 256 13 | shared_tok_emb: true 14 | shared_cls_emb: true 15 | 16 | input_emb_vqvae: true 17 | head_emb_vqvae: true 18 | cumsum_depth_ctx: true 19 | 20 | vocab_size_cond: 1000 21 | block_size_cond: 1 22 | 23 | body: 24 | n_layer: 12 25 | block: 26 | n_head: 24 27 | head: 28 | n_layer: 4 29 | block: 30 | n_head: 24 31 | 32 | loss: 33 | type: soft_target_cross_entropy 34 | stochastic_codes: true 35 | temp: 0.5 36 | 37 | optimizer: 38 | type: adamW 39 | init_lr: 0.0005 40 | weight_decay: 0.0001 41 | betas: [0.9, 0.95] 42 | 
warmup: 43 | epoch: 0 44 | multiplier: 1 45 | buffer_epoch: 0 46 | min_lr: 0.0 47 | mode: fix 48 | start_from_zero: True 49 | max_gn: 1.0 50 | 51 | experiment: 52 | amp: True 53 | batch_size: 32 54 | total_batch_size: 2048 55 | epochs: 100 56 | save_ckpt_freq: 2 57 | test_freq: 2 58 | sample: 59 | top_k: 16384 60 | top_p: 0.92 61 | -------------------------------------------------------------------------------- /configs/imagenet256/stage2/in256-rqtransformer-8x8x4-800M.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | type: imagenet 3 | vocab_size: 16384 4 | transforms: 5 | type: imagenet256x256 6 | 7 | arch: 8 | type: rq-transformer 9 | block_size: [ 8, 8, 4 ] 10 | 11 | embed_dim: 1536 12 | input_embed_dim: 256 13 | shared_tok_emb: true 14 | shared_cls_emb: true 15 | 16 | input_emb_vqvae: true 17 | head_emb_vqvae: true 18 | cumsum_depth_ctx: true 19 | 20 | vocab_size_cond: 1000 21 | block_size_cond: 1 22 | 23 | body: 24 | n_layer: 24 25 | block: 26 | n_head: 24 27 | head: 28 | n_layer: 4 29 | block: 30 | n_head: 24 31 | 32 | loss: 33 | type: soft_target_cross_entropy 34 | stochastic_codes: true 35 | temp: 0.5 36 | 37 | optimizer: 38 | type: adamW 39 | init_lr: 0.0005 40 | weight_decay: 0.0001 41 | betas: [0.9, 0.95] 42 | warmup: 43 | epoch: 0 44 | multiplier: 1 45 | buffer_epoch: 0 46 | min_lr: 0.0 47 | mode: fix 48 | start_from_zero: True 49 | max_gn: 1.0 50 | 51 | experiment: 52 | amp: True 53 | batch_size: 32 54 | total_batch_size: 2048 55 | epochs: 100 56 | save_ckpt_freq: 2 57 | test_freq: 2 58 | sample: 59 | top_k: 16384 60 | top_p: 0.92 61 | -------------------------------------------------------------------------------- /configs/lsun-bedroom/stage1/bedroom256-rqvae-8x8x4.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | type: LSUN-bedroom 3 | transforms: 4 | type: LSUN-bedroom 5 | 6 | arch: 7 | type: rq-vae 8 | code_hier: 1 9 | hparams: 10 | bottleneck_type: rq 11 | embed_dim: 256 12 | n_embed: 16384 13 | latent_shape: [ 8, 8, 256 ] 14 | code_shape: [ 8, 8, 4 ] 15 | shared_codebook: true 16 | decay: 0.99 17 | restart_unused_codes: true 18 | 19 | loss_type: mse 20 | latent_loss_weight: 0.25 21 | ddconfig: 22 | double_z: false 23 | z_channels: 256 24 | resolution: 256 25 | in_channels: 3 26 | out_ch: 3 27 | ch: 128 28 | ch_mult: [ 1, 1, 2, 2, 4, 4 ] 29 | num_res_blocks: 2 30 | attn_resolutions: [ 8 ] 31 | dropout: 0.00 32 | checkpointing: true 33 | 34 | 35 | optimizer: 36 | type: adam 37 | init_lr: 4.0e-5 38 | weight_decay: 0.0 39 | betas: [0.5, 0.9] 40 | warmup: 41 | epoch: 0 42 | multiplier: 1 43 | buffer_epoch: 0 44 | min_lr: 4.0e-5 45 | mode: fix 46 | 47 | 48 | experiment: 49 | batch_size: 32 50 | epochs: 1 51 | save_ckpt_freq: 1 52 | test_freq: 1 53 | 54 | gan: 55 | disc: 56 | arch: 57 | in_channels: 3 58 | num_layers: 2 59 | use_actnorm: False 60 | ndf: 64 61 | spectral_norm: False 62 | 63 | loss: 64 | disc_loss: hinge 65 | gen_loss: vanilla 66 | disc_weight: 0.75 67 | perceptual_weight: 1.0 68 | disc_start: 0 -------------------------------------------------------------------------------- /configs/lsun-bedroom/stage2/bedroom256-rqtransformer-8x8x4-600M.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | type: LSUN-bedroom 3 | vocab_size: 16384 4 | transforms: 5 | type: LSUN-bedroom 6 | 7 | arch: 8 | type: rq-transformer 9 | block_size: [ 8, 8, 4 ] 10 | 11 | embed_dim: 1280 12 | input_embed_dim: 256 13 | shared_tok_emb: 
true 14 | shared_cls_emb: true 15 | 16 | input_emb_vqvae: true 17 | head_emb_vqvae: true 18 | cumsum_depth_ctx: true 19 | 20 | vocab_size_cond: 1 21 | block_size_cond: 1 22 | 23 | body: 24 | n_layer: 26 25 | block: 26 | n_head: 20 27 | head: 28 | n_layer: 4 29 | block: 30 | n_head: 20 31 | 32 | loss: 33 | type: soft_target_cross_entropy 34 | stochastic_codes: true 35 | temp: 0.5 36 | 37 | optimizer: 38 | type: adamW 39 | init_lr: 0.0005 40 | weight_decay: 0.0001 41 | betas: [0.9, 0.95] 42 | warmup: 43 | epoch: 0 44 | multiplier: 1 45 | buffer_epoch: 0 46 | min_lr: 0.0 47 | mode: fix 48 | start_from_zero: True 49 | max_gn: 1.0 50 | 51 | experiment: 52 | amp: True 53 | batch_size: 32 54 | total_batch_size: 2048 55 | epochs: 100 56 | save_ckpt_freq: 10 57 | test_freq: 50 58 | sample: 59 | top_k: 250 60 | top_p: 1.0 61 | -------------------------------------------------------------------------------- /configs/lsun-cat/stage1/cat256-rqvae-8x8x4.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | type: LSUN-cat 3 | transforms: 4 | type: LSUN-cat 5 | 6 | arch: 7 | type: rq-vae 8 | code_hier: 1 9 | hparams: 10 | bottleneck_type: rq 11 | embed_dim: 256 12 | n_embed: 16384 13 | latent_shape: [ 8, 8, 256 ] 14 | code_shape: [ 8, 8, 4 ] 15 | shared_codebook: true 16 | decay: 0.99 17 | restart_unused_codes: true 18 | 19 | loss_type: mse 20 | latent_loss_weight: 0.25 21 | ddconfig: 22 | double_z: false 23 | z_channels: 256 24 | resolution: 256 25 | in_channels: 3 26 | out_ch: 3 27 | ch: 128 28 | ch_mult: [ 1, 1, 2, 2, 4, 4 ] 29 | num_res_blocks: 2 30 | attn_resolutions: [ 8 ] 31 | dropout: 0.00 32 | checkpointing: true 33 | 34 | 35 | optimizer: 36 | type: adam 37 | init_lr: 4.0e-5 38 | weight_decay: 0.0 39 | betas: [0.5, 0.9] 40 | warmup: 41 | epoch: 0 42 | multiplier: 1 43 | buffer_epoch: 0 44 | min_lr: 4.0e-5 45 | mode: fix 46 | 47 | 48 | experiment: 49 | batch_size: 32 50 | epochs: 1 51 | save_ckpt_freq: 1 52 | test_freq: 1 53 | 54 | gan: 55 | disc: 56 | arch: 57 | in_channels: 3 58 | num_layers: 2 59 | use_actnorm: False 60 | ndf: 64 61 | spectral_norm: False 62 | 63 | loss: 64 | disc_loss: hinge 65 | gen_loss: vanilla 66 | disc_weight: 0.75 67 | perceptual_weight: 1.0 68 | disc_start: 0 -------------------------------------------------------------------------------- /configs/lsun-cat/stage2/cat256-rqtransformer-8x8x4-600M.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | type: LSUN-cat 3 | vocab_size: 16384 4 | transforms: 5 | type: LSUN-cat 6 | 7 | arch: 8 | type: rq-transformer 9 | block_size: [ 8, 8, 4 ] 10 | 11 | embed_dim: 1280 12 | input_embed_dim: 256 13 | shared_tok_emb: true 14 | shared_cls_emb: true 15 | 16 | input_emb_vqvae: true 17 | head_emb_vqvae: true 18 | cumsum_depth_ctx: true 19 | 20 | vocab_size_cond: 1 21 | block_size_cond: 1 22 | 23 | body: 24 | n_layer: 26 25 | block: 26 | n_head: 20 27 | head: 28 | n_layer: 4 29 | block: 30 | n_head: 20 31 | 32 | loss: 33 | type: soft_target_cross_entropy 34 | stochastic_codes: true 35 | temp: 0.5 36 | 37 | optimizer: 38 | type: adamW 39 | init_lr: 0.0005 40 | weight_decay: 0.0001 41 | betas: [0.9, 0.95] 42 | warmup: 43 | epoch: 0 44 | multiplier: 1 45 | buffer_epoch: 0 46 | min_lr: 0.0 47 | mode: fix 48 | start_from_zero: True 49 | max_gn: 1.0 50 | 51 | experiment: 52 | amp: True 53 | batch_size: 32 54 | total_batch_size: 2048 55 | epochs: 100 56 | save_ckpt_freq: 2 57 | test_freq: 50 58 | sample: 59 | top_k: 250 60 | top_p: 1.0 61 | 
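Each stage-2 config ends with an `experiment.sample` block fixing the default top-k / top-p values used at sampling time (`main_sampling_txt2img.py` exposes matching `--top-k` / `--top-p` flags to override them). The snippet below is only a generic illustration of what those two knobs do to a single logit vector before multinomial sampling — it is not the repo's sampler, which lives under `rqvae/models/rqtransformer/`:

```python
import torch

def filter_logits(logits, top_k=None, top_p=None):
    """Generic top-k / nucleus (top-p) filtering of a 1D logit vector."""
    logits = logits.clone()
    if top_k is not None and top_k < logits.numel():
        kth = torch.topk(logits, top_k).values[-1]
        logits[logits < kth] = float('-inf')            # keep only the k largest logits
    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cumprobs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        drop = cumprobs > top_p
        drop[1:] = drop[:-1].clone()                    # keep the token that crosses top_p
        drop[0] = False
        logits[sorted_idx[drop]] = float('-inf')
    return logits

# e.g. the LSUN configs above: top_k=250, top_p=1.0 over a 16384-entry codebook
probs = torch.softmax(filter_logits(torch.randn(16384), top_k=250, top_p=1.0), dim=-1)
next_code = torch.multinomial(probs, num_samples=1)
```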
-------------------------------------------------------------------------------- /configs/lsun-church/stage1/church256-rqvae-8x8x4.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | type: LSUN-church 3 | transforms: 4 | type: LSUN-church 5 | 6 | arch: 7 | type: rq-vae 8 | code_hier: 1 9 | hparams: 10 | bottleneck_type: rq 11 | embed_dim: 256 12 | n_embed: 16384 13 | latent_shape: [ 8, 8, 256 ] 14 | code_shape: [ 8, 8, 4 ] 15 | shared_codebook: true 16 | decay: 0.99 17 | restart_unused_codes: true 18 | 19 | loss_type: mse 20 | latent_loss_weight: 0.25 21 | ddconfig: 22 | double_z: false 23 | z_channels: 256 24 | resolution: 256 25 | in_channels: 3 26 | out_ch: 3 27 | ch: 128 28 | ch_mult: [ 1, 1, 2, 2, 4, 4 ] 29 | num_res_blocks: 2 30 | attn_resolutions: [ 8 ] 31 | dropout: 0.00 32 | checkpointing: true 33 | 34 | 35 | optimizer: 36 | type: adam 37 | init_lr: 4.0e-5 38 | weight_decay: 0.0 39 | betas: [0.5, 0.9] 40 | warmup: 41 | epoch: 0 42 | multiplier: 1 43 | buffer_epoch: 0 44 | min_lr: 4.0e-5 45 | mode: fix 46 | 47 | 48 | experiment: 49 | batch_size: 32 50 | epochs: 1 51 | save_ckpt_freq: 1 52 | test_freq: 1 53 | 54 | gan: 55 | disc: 56 | arch: 57 | in_channels: 3 58 | num_layers: 2 59 | use_actnorm: False 60 | ndf: 64 61 | spectral_norm: False 62 | 63 | loss: 64 | disc_loss: hinge 65 | gen_loss: vanilla 66 | disc_weight: 0.75 67 | perceptual_weight: 1.0 68 | disc_start: 0 -------------------------------------------------------------------------------- /configs/lsun-church/stage2/lsun-church256-sqgan-8x8x4-350M-simp.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | type: LSUN-church 3 | vocab_size: 16384 4 | transforms: 5 | type: LSUN-church 6 | 7 | arch: 8 | type: rq-transformer 9 | block_size: [ 8, 8, 4 ] 10 | 11 | embed_dim: 1024 12 | input_embed_dim: 256 13 | shared_tok_emb: true 14 | shared_cls_emb: true 15 | 16 | input_emb_vqvae: true 17 | head_emb_vqvae: true 18 | cumsum_depth_ctx: true 19 | 20 | vocab_size_cond: 1 21 | block_size_cond: 1 22 | 23 | body: 24 | n_layer: 24 25 | block: 26 | n_head: 16 27 | head: 28 | n_layer: 4 29 | block: 30 | n_head: 16 31 | 32 | loss: 33 | type: soft_target_cross_entropy 34 | stochastic_codes: true 35 | temp: 0.5 36 | 37 | optimizer: 38 | type: adamW 39 | init_lr: 0.0005 40 | weight_decay: 0.0001 41 | betas: [0.9, 0.95] 42 | warmup: 43 | epoch: 0 44 | multiplier: 1 45 | buffer_epoch: 0 46 | min_lr: 0.0 47 | mode: fix 48 | start_from_zero: True 49 | max_gn: 1.0 50 | 51 | experiment: 52 | amp: True 53 | batch_size: 32 54 | total_batch_size: 256 55 | epochs: 300 56 | save_ckpt_freq: 10 57 | test_freq: 5 58 | sample: 59 | top_k: 250 60 | top_p: 1.0 61 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # How to Download the Datasets 2 | 3 | In this document, we introduce how to prepare the datasets used in this study. 4 | When you use the pre-computed features to evaluate FID and IS, note that you do not need to download the datasets below. 5 | After you download each dataset, please use the directory path, which includes the datasets below, as `root` for the `Dataset` classes in `../rqvae/img_datasets` and `../rqvae/txtimg_datasets`. 6 | If you have already downloaded a dataset, you can use its path as `root` for the `Dataset` classes. 
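For instance, for ImageNet (loaded through `torchvision.datasets.ImageNet`, see the ImageNet section below), the directory holding the prepared train/val data is exactly what gets passed as `root` — a minimal sketch with a placeholder path:

```python
import torchvision

root = '/path/to/imagenet'   # placeholder: directory containing the prepared train/val data
train_set = torchvision.datasets.ImageNet(root, split='train')
val_set = torchvision.datasets.ImageNet(root, split='val')
print(len(train_set), len(val_set))
```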
7 | 8 | 9 | ## FFHQ 10 | Before you download the FFHQ dataset, you can refer to the details in [the official repository](https://github.com/NVlabs/ffhq-dataset). 11 | You can download the zip file for FFHQ images 1024x1024 at [this link](https://drive.google.com/file/d/1WvlAIvuochQn_L_f9p3OdFdTiSLlnnhv/view?usp=sharing). 12 | After downloading the zip file, please unzip it into the `root` directory for `class FFHQ` in `../rqvae/img_datasets/ffhq.py`. 13 | 14 | 15 | ## LSUN-{Church, Bedroom} 16 | Before you download LSUN-{Church, Bedroom}, you can refer to the details in [the official repository](https://github.com/fyu/lsun). 17 | After cloning the official LSUN repository, you can easily download the two datasets using the scripts below. 18 | 19 | ```bash 20 | git clone https://github.com/fyu/lsun.git 21 | cd lsun 22 | python3 download.py -c church_outdoor -o $CHURCH_DIR_FOR_ROOT # your root directory 23 | python3 download.py -c bedroom -o $BEDROOM_DIR_FOR_ROOT # your root directory 24 | ``` 25 | 26 | ## LSUN-Cat 27 | To download the LSUN-Cat dataset, you can refer to [the official LSUN homepage](http://dl.yf.io/lsun/objects/). 28 | Otherwise, use the codes below to download `cat.zip` and unzip it. 29 | ```bash 30 | mkdir $CAT_DIR_FOR_ROOT # your root directory 31 | cd $CAT_DIR_FOR_ROOT 32 | wget http://dl.yf.io/lsun/objects/cat.zip 33 | unzip cat.zip 34 | ``` 35 | If `$CAT_DIR_FOR_ROOT` does not exist, make `$CAT_DIR_FOR_ROOT` first. 36 | 37 | ## ImageNet 38 | For ImageNet, we use `torchvision.datasets.ImageNet` in this repository. 39 | Since the ImageNet dataset is no longer publicly accessible, please download the train/val [datasets](https://image-net.org/download.php). 40 | Then, move the train/val datasets into a directory, which is used for `root` for `torchvision.datasets.ImageNet`. 41 | 42 | ## Conceptual Captions (CC-3M) 43 | For the CC-3M dataset, only Image URLs are provided instead of the image file. 44 | To download the images and prepare (image_path, text) pairs, please refer to `./cc3m/README.md`. 45 | 46 | 47 | ## MS-COCO 48 | You have to make a `$COCO_ROOT_DIR` directory. Then, make `$COCO_ROOT_DIR/ìmages` and `$COCO_ROOT_DIR/annotations` for downloading images and annotations, respectively. 49 | ```bash 50 | mkdir $COCO_ROOT_DIR # your root directory 51 | cd $COCO_ROOT_DIR 52 | mkdir images 53 | mkdir annotations 54 | ``` 55 | 56 | You can download the images and annotations at [the official homepage](http://images.cocodataset.org/zips/train2014.zip). 57 | Of course, you can use the scripts below. 58 | 59 | - To download MS-COCO images 60 | ```bash 61 | cd $COCO_ROOT_DIR/images 62 | wget http://images.cocodataset.org/zips/val2014.zip 63 | unzip val2014zip 64 | ``` 65 | 66 | - To download MS-COCO annotations 67 | ```bash 68 | cd $COCO_ROOT_DIR/annotations 69 | wget https://twg.kakaocdn.net/brainrepo/etc/RQVAE/54599b4b2286fdc2252d927aa3fd55eb/captions_val2014_30K_samples.json 70 | ``` 71 | 72 | -------------------------------------------------------------------------------- /data/cc3m/README.md: -------------------------------------------------------------------------------- 1 | # Download CC-3M Dataset 2 | 3 | To reproduce the results of RQ-Transformer trained on [CC-3M](https://ai.google.com/research/ConceptualCaptions/), 4 | we provide `download_cc3m.py` to download the available images of CC-3M. 5 | 6 | Since CC-3M datasets only provide pairs of a text caption and an image URL, you have to download all images first. 
7 | Please follow the instructions below to successfully download the images of CC-3M and prepare its text-image pairs. 8 | 9 | 10 | 11 | ## Step 1: Download (text, image url) tsv files 12 | 13 | First of all, you have to download `Train_GCC-training.tsv` or `Validation_GCC-1.1.0-Validation.tsv` files. 14 | Please download the tsv files at [the public CC-3M homepage](https://ai.google.com/research/ConceptualCaptions/download). 15 | The tsv files include the pairs of (text caption, image URL). 16 | 17 | In here, we assume that the two tsv files are downloaded at `$CC3M_ROOT_DIR/Train_GCC-training.tsv` and `$CC3M_ROOT_DIR/Validation_GCC-1.1.0-Validation.tsv`. 18 | 19 | 20 | 21 | ## Step 2: Download images from their URL and prepare the (text, image filename) pairs. 22 | 23 | If the tsv files are prepared, you can download the images in the tsv files. 24 | After you download all available images, you have to prepare the (text, filename) pairs to use `../../rqvae/txtimg_datasets/cc3m.py`. 25 | 26 | If you want to download train images, 27 | ``` 28 | $ mkdir $CC3M_ROOT_DIR # your root directory 29 | $ python download_cc3m.py --split=train --save-dir=$CC3M_ROOT_DIR 30 | $ ls $CC3M_ROOT_DIR 31 | Train_GCC-training.tsv train_list.txt train 32 | ``` 33 | 34 | If you want to download validation images, 35 | ```bash 36 | $ python download_cc3m.py --split=val --save-dir=$CC3M_ROOT_DIR 37 | ``` 38 | 39 | `train_list.txt` and `val_list.txt` contain the pairs of (image filename, text). 40 | For example, when `$CC3M_ROOT_DIR=./`, the validation (image filename, text) pairs are saved as below. 41 | ``` 42 | ./val/246f243992061b252d986b1c2e0cebba author : a life in photography -- in pictures 43 | ./val/6a8701ad0c70e74b243acade5bb90870 photograph of the sign being repaired by brave person 44 | ./val/d4473adfd46c43218ae4774dbbbe8b12 the player staring intently at a computer screen . 45 | ./val/9d487a6594f2fde0b759cd67ae9d63fa the - bedroom stone cottage can sleep people 46 | ./val/f2c84277d1878b1285bcd4637f2df3e8 party in the park under cherry blossoms 47 | 48 | ``` 49 | 50 | After downloading the images and preparing `train_list.txt` and `val_list.txt`, 51 | you can use `$CC3M_ROOT_DIR` as the `root` in `../../rqvae/txtimg_datasets/cc3m.py`. -------------------------------------------------------------------------------- /data/cc3m/download_cc3m.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from multiprocessing import Pool 16 | import os 17 | import io 18 | import requests 19 | import logging 20 | import hashlib 21 | 22 | import argparse 23 | from PIL import Image 24 | import pandas as pd 25 | from torchvision.transforms import functional as F 26 | from tqdm import tqdm 27 | 28 | def get_parser(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('-s', '--split', type=str, default='val', help='split of cc3m: train or val') 31 | parser.add_argument('-dir', '--save-dir', type=str, default=None, help='dir path for downloading images') 32 | return parser 33 | 34 | parser = get_parser() 35 | args = parser.parse_args() 36 | if args.save_dir is None: 37 | current_dir = os.getcwd() 38 | else: 39 | current_dir = args.save_dir 40 | base_dir = os.path.join(current_dir, args.split) 41 | 42 | assert args.split in ['train', 'val'] 43 | if args.split == 'train': 44 | file_path = os.path.join(current_dir, 'Train_GCC-training.tsv') 45 | assert os.path.exists(file_path), 'download tsv files from https://ai.google.com/research/ConceptualCaptions/download' 46 | else: 47 | file_path = os.path.join(current_dir, 'Validation_GCC-1.1.0-Validation.tsv') 48 | assert os.path.exists(file_path), 'download tsv files from https://ai.google.com/research/ConceptualCaptions/download' 49 | 50 | os.makedirs(base_dir, exist_ok=True) 51 | print(f'Images are downloaded into {base_dir}') 52 | 53 | # set up url requests 54 | requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning) 55 | 56 | # Load data 57 | print(f'Load tsv file: {file_path}') 58 | df = pd.read_csv(file_path, delimiter='\t', header=None) 59 | 60 | url_caption_list = [(url, caption) for index, caption, url in df.itertuples()] 61 | print(f'Loaded {len(url_caption_list)} urls') 62 | 63 | def download_url_with_hashing(url_caption): 64 | try: 65 | url, caption = url_caption 66 | filename = hashlib.md5(url.encode('utf-8')).hexdigest() 67 | filepath = os.path.join(base_dir, filename) # concat to get filepath 68 | if not os.path.isfile(filepath): 69 | req = requests.get(url, stream=True, timeout=3, verify=False).raw 70 | image = Image.open(req).convert('RGB') 71 | 72 | min_image_size = 346 73 | new_size = image.size 74 | if min(new_size) > min_image_size: 75 | ratio = min(new_size) / min_image_size 76 | new_size = [int(x / ratio) for x in new_size] 77 | image = image.resize(new_size,) 78 | image.save(filepath, 'jpeg') # save PIL image 79 | return 0, caption, os.path.join('./', args.split, filename) 80 | return 0, caption, os.path.join('./', args.split, filename) 81 | except Exception as e: 82 | url, caption = url_caption 83 | print(" ".join(repr(e).splitlines())) 84 | print(url) 85 | return 1, caption, url 86 | 87 | with Pool(128) as p: 88 | retcodes = [] 89 | for retcode in tqdm(p.imap_unordered(download_url_with_hashing, url_caption_list), total=len(url_caption_list)): 90 | retcodes.append(retcode) 91 | print('Download DONE') 92 | 93 | 94 | okay_count = 0 95 | print(f"Write (caption filename) tsv files into {args.split}_list.txt") 96 | with open(os.path.join(current_dir, f'{args.split}_list.txt'), 'w') as f: 97 | with open(os.path.join(current_dir, f'{args.split}_error_list.txt', 'w')) as fnot: 98 | for retcode, text, imgpath in tqdm(retcodes, total=len(retcodes)): 99 | if retcode == 0: 100 | okay_count += 1 101 | f.write(f'{imgpath}\t{text}\n') 102 | else: 103 | fnot.write(f'{imgpath}\t{text}\n') 104 | 105 | print(f"Total {okay_count} / {len(retcodes)} pairs are prepared.") 106 | 
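`download_cc3m.py` writes each successfully downloaded pair as a tab-separated `image_path<TAB>caption` line, as shown in the README examples above. Below is a minimal, hypothetical helper for reading such a list file back — it is not part of the repo; the actual consumer is `rqvae/txtimg_datasets/cc3m.py`:

```python
import os

def read_pairs(list_path, root='.'):
    """Read (image_path, caption) pairs from train_list.txt / val_list.txt."""
    pairs = []
    with open(list_path, encoding='utf-8') as f:
        for line in f:
            img_path, _, caption = line.rstrip('\n').partition('\t')
            if img_path:  # skip blank lines
                pairs.append((os.path.join(root, img_path), caption))
    return pairs

# e.g. pairs = read_pairs(os.path.join(cc3m_root, 'val_list.txt'), root=cc3m_root)
```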
-------------------------------------------------------------------------------- /main_sampling_txt2img.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import argparse 17 | from pathlib import Path 18 | 19 | import numpy as np 20 | import torch 21 | import torchvision 22 | import torch.distributed as dist 23 | from torch.utils.tensorboard import SummaryWriter 24 | 25 | import rqvae.utils.dist as dist_utils 26 | from rqvae.txtimg_datasets.cc3m import Cc3mTextOnly 27 | from rqvae.txtimg_datasets.coco import CocoTextOnly 28 | from rqvae.metrics.fid import compute_statistics_from_files 29 | from rqvae.utils.utils import set_seed, save_pickle 30 | from rqvae.utils.config import load_config 31 | 32 | from main_sampling_fid import (setup_logging, 33 | load_model, 34 | compute_metrics) 35 | 36 | 37 | def get_parser(): 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('-a', '--model-ar-path', type=str, default=None) 40 | parser.add_argument('-v', '--model-vqvae-path', type=str, default=None) 41 | parser.add_argument('--dataset', type=str, default='cc3m', choices=['cc3m', 'coco_2014val']) 42 | 43 | parser.add_argument('-t', '--temp', type=float, default=None) 44 | parser.add_argument('--top-k', type=int, nargs='+', default=None) 45 | parser.add_argument('--top-p', type=float, nargs='+', default=None) 46 | parser.add_argument('-bs', '--batch-size', type=int, default=100, help='batch size (per gpu)') 47 | 48 | parser.add_argument('--ema', action='store_true') 49 | parser.add_argument('--save-dir', type=str, default=None) 50 | parser.add_argument('--no-tensorboard', action='store_false', dest='tensorboard') 51 | parser.add_argument('--no-stats-saving', action='store_false', dest='stats_saving') 52 | parser.add_argument('--seed', type=int, default=None) 53 | 54 | parser.add_argument('--world_size', default=-1, type=int, help='number of nodes for distributed training') 55 | parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training') 56 | parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend') 57 | parser.add_argument('--timeout', type=int, default=86400, help='time limit (s) to wait for other nodes in DDP') 58 | 59 | return parser 60 | 61 | 62 | def add_default_args(args): 63 | 64 | config_path = Path(args.model_ar_path).parent / 'config.yaml' 65 | config = load_config(config_path) 66 | 67 | if args.temp is None: 68 | args.temp = config.sampling.temp 69 | 70 | if args.top_k is None: 71 | args.top_k = config.sampling.top_k 72 | 73 | if args.top_p is None: 74 | args.top_p = config.sampling.top_p 75 | 76 | 77 | def get_text_loader(args, config, distenv): 78 | valid_transform = [ 79 | torchvision.transforms.Resize(size=(config.dataset.image_resolution, config.dataset.image_resolution)), 80 | 
torchvision.transforms.ToTensor(), 81 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 82 | ] 83 | if args.dataset == 'cc3m': 84 | root = config.dataset.get('root', 'data/cc3m') 85 | dataset_val = Cc3mTextOnly( 86 | root, 87 | split='val', 88 | tok_name=config.dataset.txt_tok_name, 89 | transform=valid_transform, 90 | context_length=config.dataset.context_length, 91 | dropout=None, 92 | ) 93 | elif args.dataset == 'coco_2014val': 94 | root = config.dataset.get('root', 'data/coco') 95 | dataset_val = CocoTextOnly( 96 | root, 97 | split='val', 98 | tok_name=config.dataset.txt_tok_name, 99 | transform=valid_transform, 100 | context_length=config.dataset.context_length, 101 | dropout=None, 102 | ) 103 | else: 104 | raise NotImplementedError 105 | sampler = torch.utils.data.distributed.DistributedSampler( 106 | dataset_val, 107 | num_replicas=distenv.world_size, 108 | rank=distenv.world_rank, 109 | shuffle=False 110 | ) 111 | loader = torch.utils.data.DataLoader( 112 | dataset_val, 113 | sampler=sampler, 114 | shuffle=False, 115 | batch_size=args.batch_size, 116 | num_workers=0 117 | ) 118 | return loader 119 | 120 | 121 | def main(args): 122 | torch.backends.cudnn.benchmark = True 123 | distenv = dist_utils.initialize(args) 124 | device = torch.device('cuda', distenv.local_rank) 125 | torch.cuda.set_device(device) 126 | 127 | if args.seed: 128 | seed = set_seed(args.seed + distenv.world_rank) 129 | else: 130 | seed = set_seed(None) 131 | 132 | result_path, logger, writer = setup_logging(args, seed, distenv, device) 133 | 134 | # load the checkpoint of RQ-Transformer 135 | model_ar, config = load_model(args.model_ar_path, ema=args.ema) 136 | 137 | # load the checkpoint of RQ-VAE 138 | vqvae_path = args.model_vqvae_path 139 | model_vqvae, _ = load_model(vqvae_path) 140 | 141 | model_vqvae = model_vqvae.to(device) 142 | model_ar = model_ar.to(device) 143 | 144 | model_ar = dist_utils.dataparallel_and_sync(distenv, model_ar) 145 | 146 | model_vqvae.eval() 147 | model_ar.eval() 148 | 149 | loader = get_text_loader(args, config, distenv) 150 | batch_size = args.batch_size 151 | num_batches = len(loader) 152 | if distenv.master: 153 | logger.info(f'[state] batch_size (per gpu): {batch_size}') 154 | logger.info(f'[state] n_batches: {len(loader)}x{batch_size*distenv.world_size}' 155 | f'={len(loader) * batch_size* distenv.world_size}') 156 | 157 | sample_shape = model_ar.module.get_block_size() 158 | 159 | def get_initial_sample(n_samples): 160 | return torch.zeros(n_samples, *sample_shape, dtype=torch.long, device=device) 161 | 162 | for batch_idx, (_, txts) in enumerate(loader): 163 | 164 | # Sampling quantized codes 165 | txts = txts.to(device) 166 | partial_sample = get_initial_sample(txts.shape[0]) 167 | pixels = model_ar.module.sample(partial_sample, 168 | model_vqvae, 169 | cond=txts, 170 | temperature=args.temp, 171 | top_k=args.top_k, 172 | top_p=args.top_p, 173 | amp=True, 174 | fast=True, 175 | is_tqdm=distenv.master, 176 | desc=f"(sampling {batch_idx+1}/{num_batches})", 177 | ) 178 | 179 | # Decoding the sampled codes into RGB images 180 | pixels = torch.cat([model_vqvae.decode_code(pixels[i:i+1]) for i in range(pixels.size(0))], dim=0) 181 | pixels = pixels * 0.5 + 0.5 182 | pixels = torch.clamp(pixels, 0, 1) 183 | pixels = dist_utils.all_gather_cat(distenv, pixels) 184 | targets = dist_utils.all_gather_cat(distenv, txts) 185 | 186 | if distenv.master: 187 | # (M * B) -> (M, B) -> (B, M) -> (B * M) 188 | # to retain sample order same as in the dataset 189 | pixels = 
pixels.reshape(distenv.world_size, -1, *pixels.shape[1:]) 190 | pixels = pixels.transpose(0, 1) 191 | pixels = pixels.reshape(-1, *pixels.shape[2:]) 192 | 193 | logger.info(f'sync pixels: {pixels.shape}') 194 | save_pickle( 195 | os.path.join(result_path, f'samples_({batch_idx+1}_{num_batches}).pkl'), 196 | pixels.cpu().numpy(), 197 | ) 198 | 199 | targets = targets.reshape(distenv.world_size, -1, *targets.shape[1:]) 200 | targets = targets.transpose(0, 1) 201 | targets = targets.reshape(-1, *targets.shape[2:]) 202 | np.savez( 203 | os.path.join(result_path, f'targets_({batch_idx + 1}_{num_batches}).npz'), 204 | targets=targets.cpu().numpy(), 205 | ) 206 | 207 | if writer: 208 | grid = torchvision.utils.make_grid(pixels[:100], nrow=10) 209 | writer.add_image('samples', grid, batch_idx) 210 | 211 | if os.environ.get("SMOKE_TEST", 0): 212 | break 213 | 214 | logger.info(f'[state] end of sampling.') 215 | if dist.is_initialized(): 216 | dist.barrier() 217 | 218 | if distenv.master: 219 | 220 | # compute and save stats 221 | if args.stats_saving: 222 | mu_gen, sigma_gen, acts = compute_statistics_from_files(result_path, device=device, return_acts=True) 223 | acts_path = Path(result_path).joinpath('acts.npz') 224 | np.savez(acts_path, acts=acts, mu=mu_gen, sigma=sigma_gen) 225 | logger.info(f'[state] stat saved at {acts_path}') 226 | 227 | metrics = compute_metrics(result_path, args.dataset) 228 | metrics_repr = ', '.join([f'{key}: {value}' for key, value in metrics.items()]) 229 | logger.info(f'metrics: {metrics_repr}') 230 | 231 | # close the tb writer 232 | if writer: 233 | writer.close() 234 | 235 | 236 | if __name__ == '__main__': 237 | parser = get_parser() 238 | args = parser.parse_args() 239 | add_default_args(args) 240 | print(args) 241 | if (not args.model_ar_path or not args.model_vqvae_path): 242 | raise Exception("Both ar_path and vqvae_path are needed for sampling") 243 | main(args) 244 | -------------------------------------------------------------------------------- /main_stage1.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import argparse 17 | import math 18 | 19 | import torch 20 | import torch.distributed as dist 21 | 22 | import rqvae.utils.dist as dist_utils 23 | from rqvae.models import create_model 24 | from rqvae.trainers import create_trainer 25 | from rqvae.img_datasets import create_dataset 26 | from rqvae.optimizer import create_optimizer, create_scheduler 27 | from rqvae.utils.utils import set_seed, compute_model_size, get_num_conv_linear_layers 28 | from rqvae.utils.setup import setup 29 | 30 | 31 | parser = argparse.ArgumentParser() 32 | 33 | parser.add_argument('-m', '--model-config', type=str, default='./configs/c10-igpt.yaml') 34 | parser.add_argument('-r', '--result-path', type=str, default='./results.tmp') 35 | parser.add_argument('-l', '--load-path', type=str, default='') 36 | parser.add_argument('-t', '--test-batch-size', type=int, default=200) 37 | parser.add_argument('-e', '--test-epoch', type=int, default=-1) 38 | parser.add_argument('-p', '--postfix', type=str, default='') 39 | parser.add_argument('--seed', type=int, default=0) 40 | 41 | parser.add_argument('--world_size', default=-1, type=int, help='number of nodes for distributed training') 42 | parser.add_argument('--local_rank', default=-1, type=int, help='local rank for distributed training') 43 | parser.add_argument('--node_rank', default=-1, type=int, help='node rank for distributed training') 44 | parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend') 45 | parser.add_argument('--timeout', type=int, default=86400, help='time limit (s) to wait for other nodes in DDP') 46 | parser.add_argument('--eval', action='store_true') 47 | parser.add_argument('--resume', action='store_true') 48 | 49 | args, extra_args = parser.parse_known_args() 50 | 51 | set_seed(args.seed) 52 | 53 | 54 | if __name__ == '__main__': 55 | 56 | config, logger, writer = setup(args, extra_args) 57 | distenv = config.runtime.distenv 58 | 59 | torch.backends.cudnn.benchmark = True 60 | device = torch.device('cuda', distenv.local_rank) 61 | torch.cuda.set_device(device) 62 | 63 | dataset_trn, dataset_val = create_dataset(config, is_eval=args.eval, logger=logger) 64 | model, model_ema = create_model(config.arch, ema=config.arch.ema is not None) 65 | model = model.to(device) 66 | if model_ema: 67 | model_ema = model_ema.to(device) 68 | trainer = create_trainer(config) 69 | 70 | train_epochs = config.experiment.epochs 71 | steps_per_epoch = math.ceil(len(dataset_trn) / (config.experiment.batch_size * distenv.world_size)) 72 | epoch_st = 0 73 | 74 | if distenv.master: 75 | logger.info(f'#conv+linear layers: {get_num_conv_linear_layers(model)}') 76 | 77 | if not args.eval: 78 | optimizer = create_optimizer(model, config) 79 | scheduler = create_scheduler( 80 | optimizer, config.optimizer.warmup, steps_per_epoch, 81 | config.experiment.epochs, distenv 82 | ) 83 | 84 | disc_state_dict = None 85 | if not args.load_path == '': 86 | ckpt = torch.load(args.load_path, map_location='cpu') 87 | model.load_state_dict(ckpt['state_dict']) 88 | disc_state_dict = ckpt.get('discriminator', None) 89 | if model_ema: 90 | model_ema.load_state_dict(ckpt['state_dict_ema']) 91 | 92 | if args.resume: 93 | optimizer.load_state_dict(ckpt['optimizer']) 94 | scheduler.load_state_dict(ckpt['scheduler']) 95 | epoch_st = ckpt['epoch'] 96 | 97 | if distenv.master: 98 | logger.info(f'{args.load_path} model is loaded') 99 | if args.resume: 100 | logger.info(f'Optimizer, scheduelr, and epoch is resumed') 101 | 102 | if distenv.master: 103 
| print(model) 104 | compute_model_size(model, logger) 105 | 106 | if distenv.master and not args.eval: 107 | logger.info(optimizer.__repr__()) 108 | 109 | model = dist_utils.dataparallel_and_sync(distenv, model) 110 | if model_ema: 111 | model_ema = dist_utils.dataparallel_and_sync(distenv, model_ema) 112 | trainer = trainer(model, model_ema, dataset_trn, dataset_val, config, writer, 113 | device, distenv, disc_state_dict=disc_state_dict) 114 | if args.eval: 115 | trainer.eval(valid=False, verbose=True) 116 | trainer.eval(valid=True, verbose=True) 117 | if model_ema: 118 | trainer.eval(valid=True, ema=True, verbose=True) 119 | else: 120 | trainer.run_epoch(optimizer, scheduler, epoch_st) 121 | 122 | dist.barrier() 123 | 124 | if distenv.master: 125 | writer.close() # may prevent from a file stable error in brain cloud.. 126 | -------------------------------------------------------------------------------- /measure_throughput/rq_defaults.yaml: -------------------------------------------------------------------------------- 1 | rqtransformer: 2 | type: rq-transformer 3 | block_size: null 4 | 5 | embed_dim: null 6 | input_embed_dim: 256 7 | shared_tok_emb: true 8 | shared_cls_emb: true 9 | 10 | input_emb_vqvae: true 11 | head_emb_vqvae: true 12 | cumsum_depth_ctx: true 13 | 14 | vocab_size_cond: 1000 15 | block_size_cond: 1 16 | 17 | body: 18 | n_layer: null 19 | block: 20 | n_head: null 21 | head: 22 | n_layer: null 23 | block: 24 | n_head: null 25 | 26 | rqvae_f32: 27 | type: rq-vae 28 | code_hier: 1 29 | hparams: 30 | bottleneck_type: rq 31 | embed_dim: 256 32 | n_embed: null 33 | latent_shape: [8, 8, 256] 34 | code_shape: null 35 | shared_codebook: true 36 | decay: 0.99 37 | restart_unused_codes: true 38 | 39 | loss_type: mse 40 | latent_loss_weight: 0.25 41 | 42 | ddconfig: 43 | double_z: false 44 | z_channels: 256 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: [ 1, 1, 2, 2, 4, 4 ] 50 | num_res_blocks: 2 51 | attn_resolutions: [ 8 ] 52 | dropout: 0.00 53 | 54 | rqvae_f16: 55 | type: rq-vae 56 | code_hier: 1 57 | hparams: 58 | bottleneck_type: rq 59 | embed_dim: 256 60 | n_embed: null 61 | latent_shape: [16, 16, 256] 62 | code_shape: null 63 | shared_codebook: true 64 | decay: 0.99 65 | restart_unused_codes: true 66 | use_padding_idx: false 67 | masked_dropout: 0.0 68 | 69 | loss_type: mse 70 | latent_loss_weight: 0.25 71 | 72 | ddconfig: 73 | double_z: false 74 | z_channels: 256 75 | resolution: 256 76 | in_channels: 3 77 | out_ch: 3 78 | ch: 128 79 | ch_mult: [ 1, 1, 2, 2, 4 ] 80 | num_res_blocks: 2 81 | attn_resolutions: [ 16 ] 82 | dropout: 0.00 83 | 84 | rqvae_f8: 85 | type: rq-vae 86 | code_hier: 1 87 | hparams: 88 | bottleneck_type: rq 89 | embed_dim: 256 90 | n_embed: null 91 | latent_shape: [32, 32, 256] 92 | code_shape: null 93 | shared_codebook: true 94 | decay: 0.99 95 | restart_unused_codes: true 96 | use_padding_idx: false 97 | masked_dropout: 0.0 98 | 99 | loss_type: mse 100 | latent_loss_weight: 0.25 101 | 102 | ddconfig: 103 | double_z: false 104 | z_channels: 256 105 | resolution: 256 106 | in_channels: 3 107 | out_ch: 3 108 | ch: 128 109 | ch_mult: [ 1, 2, 2, 4 ] 110 | num_res_blocks: 2 111 | attn_resolutions: [ 32 ] 112 | dropout: 0.00 113 | -------------------------------------------------------------------------------- /notebooks/notebook_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import sys 17 | 18 | from PIL import Image 19 | import yaml 20 | import numpy as np 21 | import torch 22 | import torchvision 23 | import clip 24 | import torch.nn.functional as F 25 | 26 | from rqvae.utils.config import load_config, augment_arch_defaults 27 | from rqvae.models import create_model 28 | from rqvae.txtimg_datasets.tokenizers import create_tokenizer 29 | 30 | 31 | class TextEncoder: 32 | def __init__(self, tokenizer_name, context_length=64, lowercase=True): 33 | self.tokenizer = create_tokenizer(tokenizer_name, lowercase=lowercase) 34 | self.context_length = context_length 35 | 36 | 37 | self.tokenizer.add_special_tokens(["[PAD]"]) 38 | self.tokenizer.enable_padding(length=self.context_length, 39 | pad_id=self.tokenizer.token_to_id("[PAD]")) 40 | self.tokenizer.enable_truncation(max_length=self.context_length) 41 | 42 | def encode(self, texts): 43 | output = self.tokenizer.encode(texts) 44 | ids = output.ids 45 | 46 | if not isinstance(ids, torch.LongTensor): 47 | ids = torch.LongTensor(ids) 48 | 49 | return ids 50 | 51 | def __call__(self, texts): 52 | return self.encode(texts) 53 | 54 | 55 | def load_model(path, ema=False): 56 | model_config = os.path.join(os.path.dirname(path), 'config.yaml') 57 | config = load_config(model_config) 58 | config.arch = augment_arch_defaults(config.arch) 59 | 60 | model, _ = create_model(config.arch, ema=False) 61 | if ema: 62 | ckpt = torch.load(path, map_location='cpu')['state_dict_ema'] 63 | else: 64 | ckpt = torch.load(path, map_location='cpu')['state_dict'] 65 | model.load_state_dict(ckpt) 66 | 67 | return model, config 68 | 69 | def get_initial_sample(batch_sample_shape, device=torch.device('cuda')): 70 | partial_sample = torch.zeros(*batch_sample_shape, 71 | dtype=torch.long, 72 | device=device) 73 | return partial_sample 74 | 75 | @torch.no_grad() 76 | def get_clip_score(pixels, texts, model_clip, preprocess_clip, device=torch.device('cuda')): 77 | # pixels: 0~1 valued tensors 78 | pixels = pixels.cpu().numpy() 79 | pixels = np.transpose(pixels, (0, 2, 3, 1)) 80 | 81 | images = [preprocess_clip(Image.fromarray((pixel*255).astype(np.uint8))) 82 | for pixel in pixels] 83 | images = torch.stack(images, dim=0).to(device=device) 84 | texts = clip.tokenize(texts).to(device=device) 85 | 86 | image_features = model_clip.encode_image(images) 87 | text_features = model_clip.encode_text(texts) 88 | 89 | scores = F.cosine_similarity(image_features, text_features).squeeze() 90 | 91 | return scores 92 | 93 | @torch.no_grad() 94 | def get_generated_images_by_texts(model_ar, 95 | model_vqvae, 96 | text_encoder, 97 | model_clip, 98 | preprocess_clip, 99 | text_prompts, 100 | num_samples, 101 | temperature, 102 | top_k, 103 | top_p, 104 | amp=True, 105 | fast=True, 106 | is_tqdm=True, 107 | ): 108 | 109 | sample_shape = model_ar.get_block_size() 110 | 111 | text_cond = text_encoder(text_prompts).unsqueeze(0).repeat(num_samples, 1).cuda() 112 | 
113 | initial_codes = get_initial_sample([num_samples, *sample_shape]) 114 | generated_codes = model_ar.sample(initial_codes, 115 | model_vqvae, 116 | cond=text_cond, 117 | temperature=temperature, 118 | top_k=top_k, 119 | top_p=top_p, 120 | amp=amp, 121 | fast=fast, 122 | is_tqdm=is_tqdm, 123 | ) 124 | pixels = torch.cat([model_vqvae.decode_code(generated_codes[i:i+1]) 125 | for i in range(generated_codes.size(0)) 126 | ], dim=0) 127 | 128 | clip_scores = get_clip_score(pixels, 129 | text_prompts, 130 | model_clip, 131 | preprocess_clip, 132 | ) 133 | 134 | reranked_idxs = clip_scores.argsort(descending=True) 135 | reranked_pixels = pixels[reranked_idxs] 136 | 137 | return reranked_pixels 138 | -------------------------------------------------------------------------------- /notebooks/rqvae: -------------------------------------------------------------------------------- 1 | ../rqvae -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict==1.9 2 | ipdb==0.13.3 3 | lmdb==1.0.0 4 | pyflakes==2.2.0 5 | PyYAML==5.3.1 6 | tqdm==4.46.0 7 | tensorboard==2.3.0 8 | scikit-learn==0.24.0 9 | omegaconf 10 | pickle5 11 | matplotlib 12 | tokenizers>=0.10.2 13 | ftfy 14 | regex 15 | datadings 16 | pycocotools 17 | git+https://github.com/openai/CLIP.git 18 | -------------------------------------------------------------------------------- /rqvae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/rq-vae-transformer/341395e562ac347f5eb62db9f5f08b9f2cc42a60/rqvae/__init__.py -------------------------------------------------------------------------------- /rqvae/img_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
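
# A minimal text-to-image sampling sketch using the helpers defined in
# notebooks/notebook_utils.py above (load_model, TextEncoder, get_generated_images_by_texts).
# Assumptions: it is run from the notebooks/ directory (which symlinks ../rqvae), a GPU is
# available, and the checkpoint paths and tokenizer name below are placeholders -- point them
# at a downloaded stage-1/stage-2 pair and reuse the tokenizer/context length of that model's
# training config.
import clip
from torchvision.utils import save_image

from notebook_utils import TextEncoder, get_generated_images_by_texts, load_model

model_vqvae, _ = load_model('path/to/stage1/model.pt')   # placeholder checkpoint path
model_ar, _ = load_model('path/to/stage2/model.pt')      # placeholder checkpoint path
model_vqvae = model_vqvae.cuda().eval()
model_ar = model_ar.cuda().eval()

model_clip, preprocess_clip = clip.load('ViT-B/32', device='cpu')
model_clip = model_clip.cuda().eval()

text_encoder = TextEncoder(tokenizer_name='bpe16k', context_length=64)  # placeholder values

pixels = get_generated_images_by_texts(model_ar, model_vqvae, text_encoder,
                                       model_clip, preprocess_clip,
                                       'a painting of a sunset over the ocean',
                                       num_samples=16,
                                       temperature=1.0, top_k=1024, top_p=0.95)
save_image(pixels.clamp(0, 1), 'samples.png', nrow=4)    # samples come back sorted by CLIP score
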
14 | 15 | import os 16 | 17 | import torch 18 | from torch.utils.data import Subset 19 | import torchvision 20 | from torchvision.datasets import ImageNet 21 | 22 | from .lsun import LSUNClass 23 | from .ffhq import FFHQ 24 | from .transforms import create_transforms 25 | 26 | SMOKE_TEST = bool(os.environ.get("SMOKE_TEST", 0)) 27 | 28 | 29 | def create_dataset(config, is_eval=False, logger=None): 30 | transforms_trn = create_transforms(config.dataset, split='train', is_eval=is_eval) 31 | transforms_val = create_transforms(config.dataset, split='val', is_eval=is_eval) 32 | 33 | root = config.dataset.get('root', None) 34 | 35 | if config.dataset.type == 'imagenet': 36 | root = root if root else 'data/imagenet' 37 | dataset_trn = ImageNet(root, split='train', transform=transforms_trn) 38 | dataset_val = ImageNet(root, split='val', transform=transforms_val) 39 | elif config.dataset.type == 'imagenet_u': 40 | root = root if root else 'data/imagenet' 41 | 42 | def target_transform(_): 43 | return 0 44 | dataset_trn = ImageNet(root, split='train', transform=transforms_trn, target_transform=target_transform) 45 | dataset_val = ImageNet(root, split='val', transform=transforms_val, target_transform=target_transform) 46 | elif config.dataset.type == 'ffhq': 47 | root = root if root else 'data/ffhq' 48 | dataset_trn = FFHQ(root, split='train', transform=transforms_trn) 49 | dataset_val = FFHQ(root, split='val', transform=transforms_val) 50 | elif config.dataset.type in ['LSUN-cat', 'LSUN-church', 'LSUN-bedroom']: 51 | root = root if root else 'data/lsun' 52 | category_name = config.dataset.type.split('-')[-1] 53 | dataset_trn = LSUNClass(root, category_name=category_name, transform=transforms_trn) 54 | dataset_val = LSUNClass(root, category_name=category_name, transform=transforms_trn) 55 | else: 56 | raise ValueError('%s not supported...' % config.dataset.type) 57 | 58 | if SMOKE_TEST: 59 | dataset_len = config.experiment.total_batch_size * 2 60 | dataset_trn = torch.utils.data.Subset(dataset_trn, torch.randperm(len(dataset_trn))[:dataset_len]) 61 | dataset_val = torch.utils.data.Subset(dataset_val, torch.randperm(len(dataset_val))[:dataset_len]) 62 | 63 | if logger is not None: 64 | logger.info(f'#train samples: {len(dataset_trn)}, #valid samples: {len(dataset_val)}') 65 | 66 | return dataset_trn, dataset_val 67 | -------------------------------------------------------------------------------- /rqvae/img_datasets/ffhq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
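
# A minimal sketch of driving create_dataset() from rqvae/img_datasets/__init__.py above with
# an OmegaConf config.  The root path and the 'ffhq256x256' transform key are assumptions;
# in practice both come from the YAML files under configs/.
from omegaconf import OmegaConf
from rqvae.img_datasets import create_dataset

config = OmegaConf.create({
    'dataset': {
        'type': 'ffhq',                        # 'imagenet', 'LSUN-church', ... are also supported
        'root': 'data/ffhq',                   # assumed local copy of the FFHQ images
        'transforms': {'type': 'ffhq256x256'},
    },
})
dataset_trn, dataset_val = create_dataset(config, is_eval=False)
x, _ = dataset_trn[0]                          # image tensor normalized to [-1, 1]
print(len(dataset_trn), len(dataset_val), x.shape)
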
14 | 15 | import os 16 | from pathlib import Path 17 | 18 | import torchvision 19 | 20 | 21 | class ImageFolder(torchvision.datasets.VisionDataset): 22 | 23 | def __init__(self, root, train_list_file, val_list_file, split='train', **kwargs): 24 | 25 | root = Path(root) 26 | super().__init__(root, **kwargs) 27 | 28 | self.train_list_file = train_list_file 29 | self.val_list_file = val_list_file 30 | 31 | self.split = self._verify_split(split) 32 | 33 | self.loader = torchvision.datasets.folder.default_loader 34 | self.extensions = torchvision.datasets.folder.IMG_EXTENSIONS 35 | 36 | if self.split == 'trainval': 37 | fname_list = os.listdir(self.root) 38 | samples = [self.root.joinpath(fname) for fname in fname_list 39 | if fname.lower().endswith(self.extensions)] 40 | else: 41 | listfile = self.train_list_file if self.split == 'train' else self.val_list_file 42 | with open(listfile, 'r') as f: 43 | samples = [self.root.joinpath(line.strip()) for line in f.readlines()] 44 | 45 | self.samples = samples 46 | 47 | def _verify_split(self, split): 48 | if split not in self.valid_splits: 49 | msg = "Unknown split {} .".format(split) 50 | msg += "Valid splits are {{}}.".format(", ".join(self.valid_splits)) 51 | raise ValueError(msg) 52 | return split 53 | 54 | @property 55 | def valid_splits(self): 56 | return 'train', 'val', 'trainval' 57 | 58 | def __len__(self): 59 | return len(self.samples) 60 | 61 | def __getitem__(self, index, with_transform=True): 62 | path = self.samples[index] 63 | sample = self.loader(path) 64 | if self.transforms is not None and with_transform: 65 | sample, _ = self.transforms(sample, None) 66 | return sample, 0 67 | 68 | 69 | class FFHQ(ImageFolder): 70 | train_list_file = Path(__file__).parent.joinpath('assets/ffhqtrain.txt') 71 | val_list_file = Path(__file__).parent.joinpath('assets/ffhqvalidation.txt') 72 | 73 | def __init__(self, root, split='train', **kwargs): 74 | super().__init__(root, FFHQ.train_list_file, FFHQ.val_list_file, split, **kwargs) 75 | -------------------------------------------------------------------------------- /rqvae/img_datasets/lsun.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
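
# A small usage sketch for the FFHQ dataset class defined in rqvae/img_datasets/ffhq.py above.
# FFHQ reads its train/val file lists from assets/ffhqtrain.txt and assets/ffhqvalidation.txt,
# so `root` (an assumed local path here) only needs to contain the listed image files.
import torchvision.transforms as T
from rqvae.img_datasets.ffhq import FFHQ

tf = T.Compose([T.Resize(256), T.CenterCrop(256), T.ToTensor()])
ffhq_val = FFHQ('data/ffhq', split='val', transform=tf)
img, label = ffhq_val[0]                       # label is always 0: the dataset is unconditional
print(len(ffhq_val), img.shape)
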
14 | 15 | import io 16 | import os 17 | from pathlib import Path 18 | import pickle 19 | import string 20 | from typing import Tuple, Any 21 | 22 | import torchvision 23 | import lmdb 24 | from PIL import Image 25 | 26 | 27 | class LSUNClass(torchvision.datasets.VisionDataset): 28 | 29 | subpaths = {'church': 'church/church_outdoor_train_lmdb', 30 | 'church_val': 'church/church_outdoor_val_lmdb', 31 | 'bedroom': 'bedroom/bedroom_train_lmdb', 32 | 'bedroom_val': 'bedroom/bedroom_val_lmdb', 33 | 'cat': 'cat', 34 | } 35 | valid_categories = ['church', 'bedroom', 'cat'] 36 | 37 | def __init__(self, root, category_name='church', transform=None): 38 | 39 | assert category_name in LSUNClass.valid_categories 40 | root = Path(root) / LSUNClass.subpaths[category_name] 41 | print(root) 42 | 43 | super(LSUNClass, self).__init__(root, transform=transform) 44 | 45 | self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False) 46 | with self.env.begin(write=False) as txn: 47 | self.length = txn.stat()["entries"] 48 | cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters) 49 | cache_file = os.path.join(root, cache_file) 50 | if os.path.isfile(cache_file): 51 | self.keys = pickle.load(open(cache_file, "rb")) 52 | else: 53 | with self.env.begin(write=False) as txn: 54 | self.keys = [key for key in txn.cursor().iternext(keys=True, values=False)] 55 | pickle.dump(self.keys, open(cache_file, "wb")) 56 | 57 | self.exception_idx = [29343, 88863] if category_name == 'cat' else [] 58 | 59 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 60 | index = index if index not in self.exception_idx else index - 1 61 | 62 | img, target = None, None 63 | env = self.env 64 | with env.begin(write=False) as txn: 65 | imgbuf = txn.get(self.keys[index]) 66 | 67 | buf = io.BytesIO() 68 | buf.write(imgbuf) 69 | buf.seek(0) 70 | img = Image.open(buf).convert("RGB") 71 | 72 | if self.transform is not None: 73 | img = self.transform(img) 74 | 75 | return img, 0 76 | 77 | def __len__(self) -> int: 78 | return self.length 79 | -------------------------------------------------------------------------------- /rqvae/img_datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
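
# A sketch of the raw LMDB access pattern that LSUNClass in rqvae/img_datasets/lsun.py above
# is built on.  The path is an assumption: it presumes the official LSUN archives are unpacked
# as <root>/church/church_outdoor_train_lmdb, matching LSUNClass.subpaths.
import io
import lmdb
from PIL import Image

env = lmdb.open('data/lsun/church/church_outdoor_train_lmdb',
                max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    n_images = txn.stat()['entries']                                   # number of stored images
    first_key = next(txn.cursor().iternext(keys=True, values=False))   # keys index JPEG buffers
    img = Image.open(io.BytesIO(txn.get(first_key))).convert('RGB')
print(n_images, img.size)
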
14 | 15 | import torchvision.transforms as transforms 16 | 17 | def create_transforms(config, split='train', is_eval=False): 18 | if config.transforms.type == 'imagenet256x256': 19 | if split == 'train' and not is_eval: 20 | transforms_ = [ 21 | transforms.Resize(256), 22 | transforms.RandomCrop(256), 23 | transforms.RandomHorizontalFlip(p=0.5), 24 | transforms.ToTensor(), 25 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 26 | ] 27 | else: 28 | transforms_ = [ 29 | transforms.Resize(256), 30 | transforms.CenterCrop(256), 31 | transforms.Resize((256, 256)), 32 | transforms.ToTensor(), 33 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 34 | ] 35 | elif 'ffhq' in config.transforms.type: 36 | resolution = int(config.transforms.type.split('_')[0].split('x')[-1]) 37 | if split == 'train' and not is_eval: 38 | transforms_ = [ 39 | transforms.RandomResizedCrop(resolution, scale=(0.75, 1.0), ratio=(1.0, 1.0)), 40 | transforms.RandomHorizontalFlip(p=0.5), 41 | transforms.ToTensor(), 42 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 43 | ] 44 | else: 45 | transforms_ = [ 46 | transforms.Resize(resolution), 47 | transforms.CenterCrop(resolution), 48 | transforms.ToTensor(), 49 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 50 | ] 51 | elif config.transforms.type in ['LSUN', 'LSUN-cat', 'LSUN-church', 'LSUN-bedroom']: 52 | resolution = 256 # only 256 resolution is supoorted for LSUN 53 | transforms_ = [ 54 | transforms.Resize(resolution), 55 | transforms.CenterCrop(resolution), 56 | transforms.ToTensor(), 57 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 58 | ] 59 | elif config.transforms.type == 'none': 60 | transforms_ = [] 61 | else: 62 | raise NotImplementedError('%s not implemented..' % config.transforms.type) 63 | 64 | transforms_ = transforms.Compose(transforms_) 65 | 66 | return transforms_ 67 | -------------------------------------------------------------------------------- /rqvae/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/rq-vae-transformer/341395e562ac347f5eb62db9f5f08b9f2cc42a60/rqvae/losses/__init__.py -------------------------------------------------------------------------------- /rqvae/losses/vqgan/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
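
# A sketch for create_transforms() in rqvae/img_datasets/transforms.py above.  The transforms
# normalize with mean=std=0.5, so tensors land in [-1, 1]; 0.5 * (x + 1) maps them back to
# [0, 1] for saving.  'example.jpg' is a placeholder input image.
from omegaconf import OmegaConf
from PIL import Image
from torchvision.utils import save_image

from rqvae.img_datasets.transforms import create_transforms

dataset_cfg = OmegaConf.create({'transforms': {'type': 'imagenet256x256'}})
val_tf = create_transforms(dataset_cfg, split='val', is_eval=True)

x = val_tf(Image.open('example.jpg').convert('RGB'))   # -> tensor of shape (3, 256, 256) in [-1, 1]
save_image(0.5 * (x + 1.0), 'example_256.png')
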
14 | 15 | from rqvae.optimizer import create_scheduler 16 | from rqvae.optimizer.optimizer import create_resnet_optimizer 17 | 18 | from .discriminator import NLayerDiscriminator, weights_init 19 | from .gan_loss import hinge_d_loss, vanilla_d_loss, vanilla_g_loss 20 | from .lpips import LPIPS 21 | 22 | 23 | def create_vqgan_loss(loss_config): 24 | 25 | disc_loss_type = loss_config.disc_loss 26 | if disc_loss_type == "hinge": 27 | disc_loss = hinge_d_loss 28 | elif disc_loss_type == "vanilla": 29 | disc_loss = vanilla_d_loss 30 | else: 31 | raise ValueError(f"Unknown GAN loss '{disc_loss_type}'.") 32 | 33 | gen_loss_type = loss_config.gen_loss 34 | if gen_loss_type == 'vanilla': 35 | gen_loss = vanilla_g_loss 36 | else: 37 | raise ValueError(f"Unknown GAN loss '{gen_loss_type}'.") 38 | 39 | perceptual_loss = LPIPS() 40 | 41 | return disc_loss, gen_loss, perceptual_loss 42 | 43 | 44 | def create_discriminator_with_optimizer_scheduler(disc_config, steps_per_epoch, max_epoch, distenv=None): 45 | model = NLayerDiscriminator(input_nc=disc_config.arch.in_channels, 46 | n_layers=disc_config.arch.num_layers, 47 | use_actnorm=disc_config.arch.use_actnorm, 48 | ndf=disc_config.arch.ndf, 49 | ).apply(weights_init) 50 | 51 | optimizer = create_resnet_optimizer(model, disc_config.optimizer) 52 | scheduler = create_scheduler(optimizer, 53 | config=disc_config.optimizer.warmup, 54 | steps_per_epoch=steps_per_epoch, 55 | max_epoch=max_epoch, 56 | distenv=distenv) 57 | 58 | return model, optimizer, scheduler 59 | -------------------------------------------------------------------------------- /rqvae/losses/vqgan/discriminator.py: -------------------------------------------------------------------------------- 1 | """Adapted and modified from https://github.com/CompVis/taming-transformers""" 2 | import functools 3 | import torch 4 | import torch.nn as nn 5 | 6 | # ActNorm, weights_init, NLayerDiscriminator 7 | 8 | 9 | class ActNorm(nn.Module): 10 | def __init__(self, num_features, logdet=False, affine=True, 11 | allow_reverse_init=False): 12 | assert affine 13 | super().__init__() 14 | self.logdet = logdet 15 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 16 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 17 | self.allow_reverse_init = allow_reverse_init 18 | 19 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 20 | 21 | def initialize(self, input): 22 | with torch.no_grad(): 23 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 24 | mean = ( 25 | flatten.mean(1) 26 | .unsqueeze(1) 27 | .unsqueeze(2) 28 | .unsqueeze(3) 29 | .permute(1, 0, 2, 3) 30 | ) 31 | std = ( 32 | flatten.std(1) 33 | .unsqueeze(1) 34 | .unsqueeze(2) 35 | .unsqueeze(3) 36 | .permute(1, 0, 2, 3) 37 | ) 38 | 39 | self.loc.data.copy_(-mean) 40 | self.scale.data.copy_(1 / (std + 1e-6)) 41 | 42 | def forward(self, input, reverse=False): 43 | if reverse: 44 | return self.reverse(input) 45 | if len(input.shape) == 2: 46 | input = input[:,:,None,None] 47 | squeeze = True 48 | else: 49 | squeeze = False 50 | 51 | _, _, height, width = input.shape 52 | 53 | if self.training and self.initialized.item() == 0: 54 | self.initialize(input) 55 | self.initialized.fill_(1) 56 | 57 | h = self.scale * (input + self.loc) 58 | 59 | if squeeze: 60 | h = h.squeeze(-1).squeeze(-1) 61 | 62 | if self.logdet: 63 | log_abs = torch.log(torch.abs(self.scale)) 64 | logdet = height*width*torch.sum(log_abs) 65 | logdet = logdet * torch.ones(input.shape[0]).to(input) 66 | return h, logdet 67 | 
68 | return h 69 | 70 | def reverse(self, output): 71 | if self.training and self.initialized.item() == 0: 72 | if not self.allow_reverse_init: 73 | raise RuntimeError( 74 | "Initializing ActNorm in reverse direction is " 75 | "disabled by default. Use allow_reverse_init=True to enable." 76 | ) 77 | else: 78 | self.initialize(output) 79 | self.initialized.fill_(1) 80 | 81 | if len(output.shape) == 2: 82 | output = output[:,:,None,None] 83 | squeeze = True 84 | else: 85 | squeeze = False 86 | 87 | h = output / self.scale - self.loc 88 | 89 | if squeeze: 90 | h = h.squeeze(-1).squeeze(-1) 91 | return h 92 | 93 | 94 | def weights_init(m): 95 | classname = m.__class__.__name__ 96 | if classname.find('Conv') != -1: 97 | nn.init.normal_(m.weight.data, 0.0, 0.02) 98 | elif classname.find('BatchNorm') != -1: 99 | nn.init.normal_(m.weight.data, 1.0, 0.02) 100 | nn.init.constant_(m.bias.data, 0) 101 | 102 | 103 | class NLayerDiscriminator(nn.Module): 104 | """Defines a PatchGAN discriminator as in Pix2Pix 105 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 106 | """ 107 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 108 | """Construct a PatchGAN discriminator 109 | Parameters: 110 | input_nc (int) -- the number of channels in input images 111 | ndf (int) -- the number of filters in the last conv layer 112 | n_layers (int) -- the number of conv layers in the discriminator 113 | norm_layer -- normalization layer 114 | """ 115 | super(NLayerDiscriminator, self).__init__() 116 | if not use_actnorm: 117 | norm_layer = nn.BatchNorm2d 118 | else: 119 | norm_layer = ActNorm 120 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 121 | use_bias = norm_layer.func != nn.BatchNorm2d 122 | else: 123 | use_bias = norm_layer != nn.BatchNorm2d 124 | 125 | kw = 4 126 | padw = 1 127 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 128 | nf_mult = 1 129 | nf_mult_prev = 1 130 | for n in range(1, n_layers): # gradually increase the number of filters 131 | nf_mult_prev = nf_mult 132 | nf_mult = min(2 ** n, 8) 133 | sequence += [ 134 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 135 | norm_layer(ndf * nf_mult), 136 | nn.LeakyReLU(0.2, True) 137 | ] 138 | 139 | nf_mult_prev = nf_mult 140 | nf_mult = min(2 ** n_layers, 8) 141 | sequence += [ 142 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 143 | norm_layer(ndf * nf_mult), 144 | nn.LeakyReLU(0.2, True) 145 | ] 146 | 147 | sequence += [ 148 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 149 | self.main = nn.Sequential(*sequence) 150 | 151 | def forward(self, input0, input1=None): 152 | """Standard forward.""" 153 | return self.main(input0), self.main(input1) if input1 is not None else None 154 | -------------------------------------------------------------------------------- /rqvae/losses/vqgan/gan_loss.py: -------------------------------------------------------------------------------- 1 | """Adapted and modified from https://github.com/CompVis/taming-transformers""" 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def hinge_d_loss(logits_real, logits_fake, reduction='mean'): 8 | reduce_op = torch.mean if reduction == 'mean' else torch.sum 9 | loss_real = reduce_op(F.relu(1. 
- logits_real)) 10 | loss_fake = reduce_op(F.relu(1. + logits_fake)) 11 | d_loss = 0.5 * (loss_real + loss_fake) 12 | return d_loss 13 | 14 | 15 | def vanilla_g_loss(logits_fake, reduction='mean'): 16 | if reduction == 'mean': 17 | return -torch.mean(logits_fake) 18 | elif reduction == 'sum': 19 | return -torch.sum(logits_fake) 20 | 21 | 22 | def vanilla_d_loss(logits_real, logits_fake, reduction='mean'): 23 | reduce_op = torch.mean if reduction == 'mean' else torch.sum 24 | d_loss = 0.5 * ( 25 | reduce_op(torch.nn.functional.softplus(-logits_real)) + 26 | reduce_op(torch.nn.functional.softplus(logits_fake))) 27 | return d_loss 28 | -------------------------------------------------------------------------------- /rqvae/losses/vqgan/lpips.py: -------------------------------------------------------------------------------- 1 | """Adapted and modified from https://github.com/CompVis/taming-transformers""" 2 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from .lpips_utils import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name) 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name is not "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target, reduction='mean'): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | if reduction == 'none': 55 | return val 56 | elif reduction == 'mean': 57 | return torch.mean(val) 58 | elif reduction == 'sum': 59 | return torch.sum(val) 60 | 61 | 62 | class ScalingLayer(nn.Module): 63 | def __init__(self): 64 | super(ScalingLayer, self).__init__() 65 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 66 | self.register_buffer('scale', 
torch.Tensor([.458, .448, .450])[None, :, None, None]) 67 | 68 | def forward(self, inp): 69 | return (inp - self.shift) / self.scale 70 | 71 | 72 | class NetLinLayer(nn.Module): 73 | """ A single linear layer which does a 1x1 conv """ 74 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 75 | super(NetLinLayer, self).__init__() 76 | layers = [nn.Dropout(), ] if (use_dropout) else [] 77 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 78 | self.model = nn.Sequential(*layers) 79 | 80 | 81 | class vgg16(torch.nn.Module): 82 | def __init__(self, requires_grad=False, pretrained=True): 83 | super(vgg16, self).__init__() 84 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 85 | self.slice1 = torch.nn.Sequential() 86 | self.slice2 = torch.nn.Sequential() 87 | self.slice3 = torch.nn.Sequential() 88 | self.slice4 = torch.nn.Sequential() 89 | self.slice5 = torch.nn.Sequential() 90 | self.N_slices = 5 91 | for x in range(4): 92 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 93 | for x in range(4, 9): 94 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 95 | for x in range(9, 16): 96 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 97 | for x in range(16, 23): 98 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 99 | for x in range(23, 30): 100 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 101 | if not requires_grad: 102 | for param in self.parameters(): 103 | param.requires_grad = False 104 | 105 | def forward(self, X): 106 | h = self.slice1(X) 107 | h_relu1_2 = h 108 | h = self.slice2(h) 109 | h_relu2_2 = h 110 | h = self.slice3(h) 111 | h_relu3_3 = h 112 | h = self.slice4(h) 113 | h_relu4_3 = h 114 | h = self.slice5(h) 115 | h_relu5_3 = h 116 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 117 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 118 | return out 119 | 120 | 121 | def normalize_tensor(x,eps=1e-10): 122 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 123 | return x/(norm_factor+eps) 124 | 125 | 126 | def spatial_average(x, keepdim=True): 127 | return x.mean([2,3],keepdim=keepdim) 128 | -------------------------------------------------------------------------------- /rqvae/losses/vqgan/lpips_utils.py: -------------------------------------------------------------------------------- 1 | """Adapted and modified from https://github.com/CompVis/taming-transformers""" 2 | import os, hashlib, pathlib 3 | import requests 4 | from tqdm import tqdm 5 | 6 | URL_MAP = { 7 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 8 | } 9 | 10 | CKPT_MAP = { 11 | "vgg_lpips": "vgg.pth" 12 | } 13 | 14 | MD5_MAP = { 15 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 16 | } 17 | 18 | 19 | def download(url, local_path, chunk_size=1024): 20 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 21 | with requests.get(url, stream=True) as r: 22 | total_size = int(r.headers.get("content-length", 0)) 23 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 24 | with open(local_path, "wb") as f: 25 | for data in r.iter_content(chunk_size=chunk_size): 26 | if data: 27 | f.write(data) 28 | pbar.update(chunk_size) 29 | 30 | 31 | def md5_hash(path): 32 | with open(path, "rb") as f: 33 | content = f.read() 34 | return hashlib.md5(content).hexdigest() 35 | 36 | 37 | def get_ckpt_path(name, root=None, check=False): 38 | assert name in URL_MAP 39 | if root is None: 40 | 
root = pathlib.Path(__file__).parent.absolute() 41 | root = os.path.join(root, '.caches') 42 | path = os.path.join(root, CKPT_MAP[name]) 43 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 44 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 45 | download(URL_MAP[name], path) 46 | md5 = md5_hash(path) 47 | assert md5 == MD5_MAP[name], md5 48 | return path 49 | -------------------------------------------------------------------------------- /rqvae/metrics/IS.py: -------------------------------------------------------------------------------- 1 | """Utils for Inception Score calculation. 2 | Borrowed from: 3 | PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 4 | The MIT License (MIT) 5 | See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 6 | """ 7 | from pathlib import Path 8 | 9 | import torch 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from .fid import get_inception_model, create_dataset_from_files 14 | 15 | 16 | def inception_softmax(inception_model, images): 17 | with torch.no_grad(): 18 | logits = inception_model.get_logits(images) 19 | ps = torch.nn.functional.softmax(logits, dim=1) 20 | return ps 21 | 22 | 23 | @torch.no_grad() 24 | def calculate_kl_div(ps, splits: int): 25 | scores = [] 26 | num_samples = ps.shape[0] 27 | for j in range(splits): 28 | part = ps[(j * num_samples // splits):((j + 1) * num_samples // splits), :] 29 | kl = part * (torch.log(part) - torch.log(torch.unsqueeze(torch.mean(part, 0), 0))) 30 | kl = torch.mean(torch.sum(kl, 1)) 31 | kl = torch.exp(kl) 32 | scores.append(kl.unsqueeze(0)) 33 | scores = torch.cat(scores, 0) 34 | m_scores = torch.mean(scores).detach().cpu().numpy() 35 | m_std = torch.std(scores).detach().cpu().numpy() 36 | return m_scores, m_std 37 | 38 | 39 | @torch.no_grad() 40 | def compute_inception_score_from_dataset(dataset, 41 | splits, 42 | batch_size, 43 | device=torch.device('cuda'), 44 | inception_model=None, 45 | disable_tqdm=False): 46 | """ 47 | Args: 48 | - dataset: dataset returning **float (0~1)** images 49 | """ 50 | if inception_model is None: 51 | inception_model = get_inception_model().to(device) 52 | 53 | data_loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=16) 54 | 55 | inception_model.eval() 56 | probs_list = [] 57 | 58 | for imgs in tqdm(data_loader, disable=disable_tqdm): 59 | imgs = imgs[0].to(device) 60 | logits = inception_model.get_logits(imgs) 61 | probs = torch.nn.functional.softmax(logits, dim=-1) 62 | probs_list.append(probs) 63 | 64 | probs_list = torch.cat(probs_list, 0) 65 | m_scores, m_std = calculate_kl_div(probs_list, splits=splits) 66 | 67 | return m_scores, m_std 68 | 69 | 70 | def compute_inception_score_from_files(path, 71 | splits=10, 72 | batch_size=500, 73 | device=torch.device('cuda'), 74 | inception_model=None, 75 | disable_tqdm=False): 76 | 77 | dataset = create_dataset_from_files(path) 78 | return compute_inception_score_from_dataset(dataset, 79 | splits, 80 | batch_size, 81 | device, 82 | inception_model, 83 | disable_tqdm) 84 | -------------------------------------------------------------------------------- /rqvae/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .fid import compute_rfid, compute_fid 16 | from .IS import compute_inception_score_from_files as compute_IS 17 | from .clip_score import compute_clip_score 18 | -------------------------------------------------------------------------------- /rqvae/metrics/clip_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import torch 17 | from torch.nn import functional as F 18 | 19 | import clip 20 | from PIL import Image 21 | 22 | from .fid import create_dataset_from_files 23 | 24 | from rqvae.txtimg_datasets.cc3m import Cc3mRawTextOnly 25 | from rqvae.txtimg_datasets.coco import CocoRawTextOnly 26 | 27 | 28 | def get_clip(): 29 | model_clip, preprocess_clip = clip.load("ViT-B/32", device='cpu') 30 | return model_clip, preprocess_clip 31 | 32 | 33 | @torch.no_grad() 34 | def clip_score(pixels, texts, model_clip, preprocess_clip, device=torch.device('cuda')): 35 | # pixels: 0~1 valued tensors 36 | pixels = pixels.cpu().numpy() 37 | pixels = np.transpose(pixels, (0, 2, 3, 1)) 38 | 39 | images = [preprocess_clip(Image.fromarray((pixel*255).astype(np.uint8))) for pixel in pixels] 40 | images = torch.stack(images, dim=0).to(device=device) 41 | texts = clip.tokenize(texts).to(device=device) 42 | 43 | image_features = model_clip.encode_image(images) 44 | text_features = model_clip.encode_text(texts) 45 | 46 | scores = F.cosine_similarity(image_features, text_features).squeeze() 47 | 48 | return scores 49 | 50 | 51 | def compute_clip_score(fake_path, 52 | dataset_name='cc3m', 53 | dataset_root=None, 54 | split='val', 55 | batch_size=100, 56 | device=torch.device('cuda'), 57 | ): 58 | 59 | model_clip, preprocess_clip = get_clip() 60 | model_clip.to(device=device) 61 | model_clip.eval() 62 | 63 | img_dataset = create_dataset_from_files(fake_path) 64 | 65 | if dataset_name == 'cc3m': 66 | root = dataset_root if dataset_root else 'data/cc3m' 67 | txt_dataset = Cc3mRawTextOnly(root, split=split) 68 | elif dataset_name == 'coco': 69 | root = dataset_root if dataset_root else 'data/coco' 70 | txt_dataset = CocoRawTextOnly(root, split=split) 71 | else: 72 | raise ValueError(f'Unsupported dataset: {dataset_name}') 73 | 74 | # Here we assume that the order of imgs is same as the order of txts, 75 | # possibly has some duplicates at the end due to the distributed sampler. 
76 | assert len(img_dataset) >= len(txt_dataset) 77 | img_dataset = torch.utils.data.Subset(img_dataset, np.arange(len(txt_dataset))) 78 | 79 | img_loader = torch.utils.data.DataLoader(img_dataset, batch_size=batch_size) 80 | txt_loader = torch.utils.data.DataLoader(txt_dataset, batch_size=batch_size) 81 | 82 | scores = [] 83 | for (imgs,), txts in zip(img_loader, txt_loader): 84 | score = clip_score(imgs, txts, model_clip, preprocess_clip) 85 | scores.append(score.cpu().numpy()) 86 | 87 | scores = np.concatenate(scores) 88 | scores_avg = scores.mean() 89 | 90 | return scores_avg 91 | -------------------------------------------------------------------------------- /rqvae/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .ema import ExponentialMovingAverage 16 | from .rqvae import get_rqvae 17 | from .rqtransformer import get_rqtransformer 18 | 19 | 20 | def create_model(config, ema=False): 21 | model_type = config.type.lower() 22 | 23 | if model_type == 'rq-transformer': 24 | model = get_rqtransformer(config) 25 | model_ema = get_rqtransformer(config) if ema else None 26 | elif model_type == 'rq-vae': 27 | model = get_rqvae(config) 28 | model_ema = get_rqvae(config) if ema else None 29 | else: 30 | raise ValueError(f'{model_type} is invalid..') 31 | 32 | if ema: 33 | model_ema = ExponentialMovingAverage(model_ema, config.ema) 34 | model_ema.eval() 35 | model_ema.update(model, step=-1) 36 | 37 | return model, model_ema 38 | -------------------------------------------------------------------------------- /rqvae/models/ema.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import torch 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class ExponentialMovingAverage(torch.nn.Module): 22 | def __init__(self, init_module, mu): 23 | super(ExponentialMovingAverage, self).__init__() 24 | 25 | self.module = init_module 26 | self.mu = mu 27 | 28 | def forward(self, x, *args, **kwargs): 29 | return self.module(x, *args, **kwargs) 30 | 31 | def update(self, module, step=None): 32 | if step is None: 33 | mu = self.mu 34 | else: 35 | mu = min(self.mu, (1. + step) / (10. 
+ step)) 36 | 37 | state_dict = {} 38 | with torch.no_grad(): 39 | for (name, m1), (name2, m2) in zip(self.module.state_dict().items(), module.state_dict().items()): 40 | if name != name2: 41 | logger.warning('[ExpoentialMovingAverage] not matched keys %s, %s', name, name2) 42 | 43 | if step is not None and step < 0: 44 | state_dict[name] = m2.clone().detach() 45 | else: 46 | state_dict[name] = ((mu * m1) + ((1.0 - mu) * m2)).clone().detach() 47 | 48 | self.module.load_state_dict(state_dict) 49 | 50 | def compute_loss(self, *args, **kwargs): 51 | return self.module.compute_loss(*args, **kwargs) 52 | 53 | def get_recon_imgs(self, *args, **kwargs): 54 | return self.module.get_recon_imgs(*args, **kwargs) 55 | -------------------------------------------------------------------------------- /rqvae/models/interfaces.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import abc 16 | 17 | from torch import nn 18 | 19 | 20 | class Stage1Model(nn.Module, metaclass=abc.ABCMeta): 21 | 22 | @abc.abstractmethod 23 | def get_codes(self, *args, **kwargs): 24 | """Generate the code from the input.""" 25 | pass 26 | 27 | @abc.abstractmethod 28 | def decode_code(self, *args, **kwargs): 29 | """Generate the decoded image from the given code.""" 30 | pass 31 | 32 | @abc.abstractmethod 33 | def get_recon_imgs(self, *args, **kwargs): 34 | """Scales the real and recon images properly. 35 | """ 36 | pass 37 | 38 | @abc.abstractmethod 39 | def compute_loss(self, *args, **kwargs): 40 | """Compute the losses necessary for training. 41 | 42 | return { 43 | 'loss_total': ..., 44 | 'loss_recon': ..., 45 | 'loss_latent': ..., 46 | 'codes': ..., 47 | ... 48 | } 49 | """ 50 | pass 51 | 52 | 53 | class Stage2Model(nn.Module, metaclass=abc.ABCMeta): 54 | 55 | @abc.abstractmethod 56 | def compute_loss(self, *args, **kwargs): 57 | """Compute the losses necessary for training. 58 | Typically, it would be the cross-entropy of the AR prediction w.r.t. the ground truth. 59 | """ 60 | pass 61 | 62 | def _init_weights(self, module): 63 | if isinstance(module, (nn.Linear, nn.Embedding)): 64 | module.weight.data.normal_(mean=0.0, std=0.02) 65 | if isinstance(module, nn.Linear) and module.bias is not None: 66 | module.bias.data.zero_() 67 | elif isinstance(module, nn.LayerNorm): 68 | module.bias.data.zero_() 69 | module.weight.data.fill_(1.0) 70 | 71 | def get_block_size(self): 72 | return self.block_size 73 | -------------------------------------------------------------------------------- /rqvae/models/rqtransformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .configs import RQTransformerConfig 16 | from .transformers import RQTransformer 17 | 18 | 19 | def get_rqtransformer(config: RQTransformerConfig): 20 | return RQTransformer(config) 21 | -------------------------------------------------------------------------------- /rqvae/models/rqtransformer/attentions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | from typing import Iterable 17 | 18 | import torch 19 | from torch import nn 20 | from torch.nn import functional as F 21 | 22 | from .configs import AttentionBlockConfig, AttentionStackConfig 23 | 24 | 25 | class GELU(nn.Module): 26 | def __init__(self, version='v1'): 27 | super().__init__() 28 | assert version == 'v1' or version == 'v2' 29 | 30 | self.version = version 31 | 32 | def forward(self, x): 33 | if self.version == 'v1': 34 | return F.gelu(x) 35 | else: 36 | return x * torch.sigmoid(1.702 * x) 37 | 38 | 39 | class MultiSelfAttention(nn.Module): 40 | """ 41 | Optimized by batched matmul operations 42 | """ 43 | 44 | def __init__(self, config: AttentionBlockConfig, mask=True): 45 | super().__init__() 46 | assert config.embed_dim % config.n_head == 0 47 | # key, query, value projections for all heads 48 | self.key = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias) 49 | self.query = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias) 50 | self.value = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias) 51 | # regularization 52 | self.attn_drop = nn.Dropout(config.attn_pdrop, inplace=False) 53 | self.resid_drop = nn.Dropout(config.resid_pdrop, inplace=True) 54 | # output projection 55 | self.proj = nn.Linear(config.embed_dim, config.embed_dim, config.attn_bias) 56 | 57 | self.n_head = config.n_head 58 | self.mask = mask 59 | 60 | def forward(self, x, caching=False, past_kv=None): 61 | (B, T, C) = x.shape 62 | 63 | if not caching: 64 | assert past_kv is None 65 | 66 | x = x.transpose(0, 1).contiguous() # (B, T, C) -> (T, B, C) 67 | 68 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 69 | k = self.key(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1) # (B*nh, T, hs) 70 | q = self.query(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1) # (B*nh, T, hs) 71 | v = self.value(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1) # (B*nh, T, hs) 72 | 73 | if past_kv is not 
None: 74 | past_key, past_value = past_kv 75 | k = torch.cat([past_key, k], dim=-2) 76 | v = torch.cat([past_value, v], dim=-2) 77 | T_past = past_key.shape[1] 78 | else: 79 | T_past = 0 80 | 81 | if caching: 82 | present = torch.stack([k, v]) 83 | else: 84 | present = None 85 | 86 | # Tensor shape below: (B * nh, T, hs) X (B * nh, hs, T_past+T) -> (B * nh, T, T_past+T) 87 | att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))) 88 | if self.mask: 89 | mask = torch.tril(torch.ones(T_past+T, T_past+T, device=x.device, dtype=torch.bool)) 90 | mask = mask.view(1, T_past+T, T_past+T) 91 | att = att.masked_fill(~mask[:, T_past:T_past+T, :T_past+T], float('-inf')) 92 | att = F.softmax(att, dim=-1) 93 | att = self.attn_drop(att) 94 | 95 | y = torch.bmm(att, v) # (B*nh, T, T_past+T) X (B*nh, T_past+T, hs) -> (B*nh, T, hs) 96 | y = y.transpose(0, 1).contiguous().view(T, B, C) # re-assemble all head outputs side by side 97 | 98 | # output projection 99 | y = self.resid_drop(self.proj(y)) 100 | 101 | if caching: 102 | return y.transpose(0, 1).contiguous(), present # (T, B, C) -> (B, T, C) 103 | else: 104 | return y.transpose(0, 1).contiguous() # (T, B, C) -> (B, T, C) 105 | 106 | 107 | class AttentionBlock(nn.Module): 108 | """ an unassuming Transformer block """ 109 | 110 | def __init__(self, config: AttentionBlockConfig): 111 | super().__init__() 112 | 113 | self.ln1 = nn.LayerNorm(config.embed_dim) 114 | self.ln2 = nn.LayerNorm(config.embed_dim) 115 | 116 | self.attn = MultiSelfAttention(config, mask=True) 117 | self.mlp = nn.Sequential( 118 | nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=config.mlp_bias), 119 | GELU(config.gelu), 120 | nn.Linear(4 * config.embed_dim, config.embed_dim, bias=config.mlp_bias), 121 | nn.Dropout(config.resid_pdrop, inplace=True), 122 | ) 123 | self._cache = None 124 | 125 | def forward(self, x): 126 | 127 | attn = self.attn(self.ln1(x)) 128 | 129 | x = x + attn 130 | x = x + self.mlp(self.ln2(x)) 131 | 132 | return x 133 | 134 | def cached_forward(self, x_present): 135 | 136 | attn, present = self.attn(self.ln1(x_present), caching=True, past_kv=self._cache['past_kv']) 137 | self._cache['past_kv'] = present 138 | 139 | x_present = x_present + attn 140 | x_present = x_present + self.mlp(self.ln2(x_present)) 141 | 142 | return x_present 143 | 144 | def init_cache(self): 145 | self._cache = {'past_kv': None} 146 | 147 | 148 | class AttentionStack(nn.Module): 149 | 150 | blocks: Iterable[AttentionBlock] 151 | 152 | def __init__(self, config: AttentionStackConfig): 153 | super().__init__() 154 | 155 | self.blocks = nn.ModuleList([AttentionBlock(config.block) for _ in range(config.n_layer)]) 156 | 157 | def forward(self, x): 158 | for block in self.blocks: 159 | x = block(x) 160 | return x 161 | 162 | def cached_forward(self, x_present): 163 | for block in self.blocks: 164 | x_present = block.cached_forward(x_present) 165 | return x_present 166 | 167 | def init_cache(self): 168 | for block in self.blocks: 169 | block.init_cache() 170 | -------------------------------------------------------------------------------- /rqvae/models/rqtransformer/configs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import List, Optional, Any 16 | from dataclasses import dataclass 17 | 18 | from omegaconf import OmegaConf, MISSING 19 | 20 | 21 | @dataclass 22 | class AttentionBlockConfig: 23 | embed_dim: int = MISSING 24 | n_head: int = MISSING 25 | mlp_bias: bool = True 26 | attn_bias: bool = True 27 | attn_pdrop: float = 0.0 28 | resid_pdrop: float = 0.1 29 | gelu: str = 'v1' 30 | 31 | 32 | @dataclass 33 | class AttentionStackConfig: 34 | n_layer: int = MISSING 35 | block: AttentionBlockConfig = AttentionBlockConfig() 36 | 37 | 38 | @dataclass 39 | class RQTransformerConfig: 40 | 41 | type: str = 'rq-transformer' 42 | ema: Optional[bool] = None 43 | ar_hierarchy: Optional[bool] = None 44 | 45 | vocab_size: Any = MISSING 46 | block_size: List[int] = MISSING 47 | 48 | vocab_size_cond: int = 0 49 | block_size_cond: int = 0 50 | 51 | embed_dim: int = MISSING 52 | input_embed_dim: Optional[int] = None 53 | use_padding_emb: bool = False 54 | 55 | input_emb_vqvae: bool = False 56 | head_emb_vqvae: bool = False 57 | scaled_head_emb_vqvae: bool = False 58 | cumsum_depth_ctx: bool = False 59 | shared_tok_emb: bool = False 60 | 61 | embd_pdrop: float = 0.0 62 | 63 | body: AttentionStackConfig = AttentionStackConfig() 64 | head: AttentionStackConfig = AttentionStackConfig() 65 | 66 | shared_cls_emb: bool = False 67 | 68 | @classmethod 69 | def create(cls, config): 70 | defaults = OmegaConf.structured(cls(embed_dim=config.embed_dim)) 71 | defaults.body.block.embed_dim = defaults.embed_dim 72 | defaults.head.block.embed_dim = defaults.embed_dim 73 | return OmegaConf.merge(defaults, config) 74 | -------------------------------------------------------------------------------- /rqvae/models/rqtransformer/primitives.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Iterable, List, Optional, Tuple, Union 16 | 17 | import numpy as np 18 | import torch 19 | import torch.nn as nn 20 | from torch import Tensor 21 | 22 | Size = Union[Tuple[int, ...], List[int], torch.Size] 23 | 24 | 25 | class TupleEmbedding(nn.Embedding): 26 | r"""A simple lookup table that stores embeddings of multiple dictionaries and fixed size. 27 | 28 | This module intends to represent a tuple (x_1, ..., x_k) from the product of k (possibly differently-sized) 29 | dictionaries by the tuple of the embeddings of individual entries. 
30 | The input to the module is a list of tuples of indices, and the output is the corresponding 31 | tuple embeddings. 32 | 33 | Args: 34 | num_embeddings (int or tuple): list of the sizes of each dictionary of embeddings 35 | embedding_dim (int): the size of each embedding vector 36 | 37 | Shape: 38 | - Input: :math:`(*, D)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract 39 | - Output: :math:`(*, D, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` 40 | """ 41 | 42 | def __init__(self, num_embeddings: Union[int, Iterable[int]], embedding_dim, **kwargs) -> None: 43 | 44 | if 'padding_idx' in kwargs: 45 | raise ValueError('padding_idx argument not supported') 46 | 47 | if isinstance(num_embeddings, int): 48 | num_embeddings = (num_embeddings,) 49 | 50 | self.num_embeddings_per_dict = num_embeddings 51 | self.embedding_dim = embedding_dim 52 | 53 | super(TupleEmbedding, self).__init__(num_embeddings=sum(self.num_embeddings_per_dict), 54 | embedding_dim=embedding_dim, 55 | **kwargs) 56 | 57 | self.register_buffer('offsets', None) 58 | self.offsets = torch.tensor(np.cumsum([0] + self.num_embeddings_per_dict[:-1]), dtype=torch.long) 59 | 60 | self.reset_parameters() 61 | 62 | def reset_parameters(self) -> None: 63 | self.weight.data.normal_(mean=0.0, std=0.02) 64 | 65 | def forward(self, x: Tensor) -> Tensor: 66 | (*rem, D) = x.shape 67 | assert D == len(self.num_embeddings_per_dict) 68 | 69 | offsets = self.offsets.view(*[1 for _ in rem], D) 70 | x_emb = super(TupleEmbedding, self).forward(x + offsets) 71 | 72 | return x_emb 73 | 74 | 75 | class LogitMask(nn.Module): 76 | def __init__(self, vocab_size: Iterable[int], value=-1e6): 77 | super().__init__() 78 | 79 | self.vocab_size = vocab_size 80 | self.mask_cond = [vocab_size[0]]*len(vocab_size) != vocab_size 81 | self.value = value 82 | 83 | def forward(self, logits: Tensor) -> Tensor: 84 | if not self.mask_cond: 85 | return logits 86 | else: 87 | for idx, vocab_size in enumerate(self.vocab_size): 88 | logits[:, idx, vocab_size:].fill_(-float('Inf')) 89 | return logits 90 | 91 | 92 | class BatchLinear(nn.Module): 93 | r"""Applies multiple linear transformations to multiple vectors in a batched way: 94 | 95 | .. math:: 96 | y_i = x_i A_i^T + b_i \text{for } i=1, \cdots, n_{vectors} 97 | 98 | Args: 99 | n_vectors (int): number of linear transformations (=number of vectors in input) 100 | in_features (int): size of each input sample 101 | out_features (int): size of each output sample 102 | bias (bool): If set to ``False``, the layer will not learn an additive bias. 103 | Default: ``True`` 104 | """ 105 | bias: Optional[Tensor] 106 | 107 | def __init__(self, n_vectors: int, in_features: int, out_features: int, bias: bool = True) -> None: 108 | super().__init__() 109 | self.n_vectors = n_vectors 110 | self.in_features = in_features 111 | self.out_features = out_features 112 | 113 | self.weight = nn.Parameter(torch.Tensor(n_vectors, in_features, out_features)) 114 | 115 | if bias: 116 | self.bias = nn.Parameter(torch.Tensor(n_vectors, out_features)) 117 | else: 118 | self.register_parameter('bias', None) 119 | 120 | self.reset_parameters() 121 | 122 | def reset_parameters(self) -> None: 123 | self.weight.data.normal_(mean=0.0, std=0.02) 124 | if self.bias is not None: 125 | self.bias.data.zero_() 126 | 127 | def forward(self, input: Tensor, indices=None) -> Tensor: 128 | """ 129 | Inputs: 130 | input (Tensor): A tensor to which linear transfs. 
are applied 131 | indices (optional, List[int]): List of indices of linear transf. to be applied in a batched manner. 132 | If 'None', all linear transforms are applied. 133 | Shapes: 134 | - input: (*, n_vectors, in_channel) 135 | - output: (*, n_vectors, out_channel) 136 | Output: 137 | Tensor(shape=[..., n_vectors, out_channel]) 138 | """ 139 | (*rem, n_vectors, in_ch) = input.shape 140 | 141 | if indices: 142 | assert n_vectors == len(indices) 143 | weight = self.weight[indices] 144 | if self.bias is not None: 145 | bias = self.bias[indices] 146 | else: 147 | bias = None 148 | else: 149 | weight = self.weight 150 | bias = self.bias 151 | 152 | output = torch.einsum('bij,ijk->bik', 153 | input.view(-1, n_vectors, in_ch), 154 | weight, 155 | ) 156 | 157 | if bias is not None: 158 | output = output + bias.unsqueeze(0) 159 | 160 | return output.reshape(*rem, n_vectors, -1) 161 | 162 | def extra_repr(self) -> str: 163 | return 'n_vectors={}, in_features={}, out_features={}, bias={}'.format( 164 | self.n_vectors, self.in_features, self.out_features, self.bias is not None 165 | ) 166 | -------------------------------------------------------------------------------- /rqvae/models/rqvae/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
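Before moving into the RQ-VAE package, a quick usage sketch for the two primitives defined above: `TupleEmbedding` embeds a depth-D tuple of code indices by offsetting each slot into one shared lookup table, and `BatchLinear` applies a separate linear map per slot. The sizes below are illustrative only; note that `num_embeddings` is passed as a list of per-dictionary sizes.

    import torch
    from rqvae.models.rqtransformer.primitives import TupleEmbedding, BatchLinear

    emb = TupleEmbedding(num_embeddings=[16, 32], embedding_dim=8)
    codes = torch.randint(0, 16, (4, 10, 2))           # (*, D): last dim indexes the D dictionaries
    codes[..., 1] = torch.randint(0, 32, (4, 10))      # second slot draws from the 32-entry dictionary
    tuple_emb = emb(codes)                             # -> (4, 10, 2, 8), one vector per slot

    lin = BatchLinear(n_vectors=2, in_features=8, out_features=8)
    out = lin(tuple_emb)                               # per-slot linear maps, shape preserved: (4, 10, 2, 8)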
14 | 15 | from .rqvae import RQVAE 16 | 17 | def get_rqvae(config): 18 | 19 | hps = config.hparams 20 | ddconfig = config.ddconfig 21 | 22 | model = RQVAE(**hps, ddconfig=ddconfig, checkpointing=config.checkpointing) 23 | 24 | return model 25 | 26 | -------------------------------------------------------------------------------- /rqvae/models/rqvae/layers.py: -------------------------------------------------------------------------------- 1 | """borrowed and modified from https://github.com/CompVis/taming-transformers""" 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.utils.checkpoint import checkpoint 6 | 7 | import math 8 | import numpy as np 9 | 10 | 11 | def nonlinearity(x): 12 | # swish 13 | return F.silu(x, inplace=True) # x*torch.sigmoid(x) 14 | 15 | 16 | def Normalize(in_channels): 17 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 18 | 19 | 20 | class Upsample(nn.Module): 21 | def __init__(self, in_channels, with_conv): 22 | super().__init__() 23 | self.with_conv = with_conv 24 | if self.with_conv: 25 | self.conv = torch.nn.Conv2d(in_channels, 26 | in_channels, 27 | kernel_size=3, 28 | stride=1, 29 | padding=1) 30 | 31 | def forward(self, x): 32 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 33 | if self.with_conv: 34 | x = self.conv(x) 35 | return x 36 | 37 | 38 | class Downsample(nn.Module): 39 | def __init__(self, in_channels, with_conv): 40 | super().__init__() 41 | self.with_conv = with_conv 42 | if self.with_conv: 43 | # no asymmetric padding in torch conv, must do it ourselves 44 | self.conv = torch.nn.Conv2d(in_channels, 45 | in_channels, 46 | kernel_size=3, 47 | stride=2, 48 | padding=0) 49 | 50 | def forward(self, x): 51 | if self.with_conv: 52 | pad = (0,1,0,1) 53 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 54 | x = self.conv(x) 55 | else: 56 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 57 | return x 58 | 59 | 60 | class ResnetBlock(nn.Module): 61 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 62 | dropout, temb_channels=512): 63 | super().__init__() 64 | self.in_channels = in_channels 65 | out_channels = in_channels if out_channels is None else out_channels 66 | self.out_channels = out_channels 67 | self.use_conv_shortcut = conv_shortcut 68 | self.checkpointing = False 69 | 70 | self.norm1 = Normalize(in_channels) 71 | self.conv1 = torch.nn.Conv2d(in_channels, 72 | out_channels, 73 | kernel_size=3, 74 | stride=1, 75 | padding=1) 76 | if temb_channels > 0: 77 | self.temb_proj = torch.nn.Linear(temb_channels, 78 | out_channels) 79 | self.norm2 = Normalize(out_channels) 80 | self.dropout = torch.nn.Dropout(dropout, inplace=True) 81 | self.conv2 = torch.nn.Conv2d(out_channels, 82 | out_channels, 83 | kernel_size=3, 84 | stride=1, 85 | padding=1) 86 | if self.in_channels != self.out_channels: 87 | if self.use_conv_shortcut: 88 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 89 | out_channels, 90 | kernel_size=3, 91 | stride=1, 92 | padding=1) 93 | else: 94 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 95 | out_channels, 96 | kernel_size=1, 97 | stride=1, 98 | padding=0) 99 | 100 | def _forward(self, x, temb): 101 | h = x 102 | h = self.norm1(h) 103 | h = nonlinearity(h) 104 | h = self.conv1(h) 105 | 106 | if temb is not None: 107 | h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] 108 | 109 | h = self.norm2(h) 110 | h = nonlinearity(h) 111 | h = self.dropout(h) 112 | h 
= self.conv2(h) 113 | 114 | if self.in_channels != self.out_channels: 115 | if self.use_conv_shortcut: 116 | x = self.conv_shortcut(x) 117 | else: 118 | x = self.nin_shortcut(x) 119 | 120 | return x+h 121 | 122 | def forward(self, x, temb): 123 | if self.checkpointing and self.training: 124 | out = checkpoint(self._forward, x, temb) 125 | else: 126 | out = self._forward(x, temb) 127 | return out 128 | 129 | 130 | class AttnBlock(nn.Module): 131 | def __init__(self, in_channels): 132 | super().__init__() 133 | self.in_channels = in_channels 134 | 135 | self.norm = Normalize(in_channels) 136 | self.q = torch.nn.Conv2d(in_channels, 137 | in_channels, 138 | kernel_size=1, 139 | stride=1, 140 | padding=0) 141 | self.k = torch.nn.Conv2d(in_channels, 142 | in_channels, 143 | kernel_size=1, 144 | stride=1, 145 | padding=0) 146 | self.v = torch.nn.Conv2d(in_channels, 147 | in_channels, 148 | kernel_size=1, 149 | stride=1, 150 | padding=0) 151 | self.proj_out = torch.nn.Conv2d(in_channels, 152 | in_channels, 153 | kernel_size=1, 154 | stride=1, 155 | padding=0) 156 | 157 | 158 | def forward(self, x): 159 | h_ = x 160 | h_ = self.norm(h_) 161 | q = self.q(h_) 162 | k = self.k(h_) 163 | v = self.v(h_) 164 | 165 | # compute attention 166 | b,c,h,w = q.shape 167 | q = q.reshape(b,c,h*w) 168 | q = q.permute(0,2,1) # b,hw,c 169 | k = k.reshape(b,c,h*w) # b,c,hw 170 | w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 171 | w_ = w_ * (int(c)**(-0.5)) 172 | w_ = torch.nn.functional.softmax(w_, dim=2) 173 | 174 | # attend to values 175 | v = v.reshape(b,c,h*w) 176 | w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) 177 | h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 178 | h_ = h_.reshape(b,c,h,w) 179 | 180 | h_ = self.proj_out(h_) 181 | 182 | return x+h_ -------------------------------------------------------------------------------- /rqvae/models/rqvae/modules.py: -------------------------------------------------------------------------------- 1 | """borrowed and modified from https://github.com/CompVis/taming-transformers""" 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .layers import (AttnBlock, Downsample, Normalize, ResnetBlock, Upsample, nonlinearity) 8 | 9 | 10 | class Encoder(nn.Module): 11 | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, 12 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 13 | resolution, z_channels, double_z=True, **ignore_kwargs): 14 | super().__init__() 15 | self.ch = ch 16 | self.temb_ch = 0 17 | self.num_resolutions = len(ch_mult) 18 | self.num_res_blocks = num_res_blocks 19 | self.resolution = resolution 20 | self.in_channels = in_channels 21 | 22 | # downsampling 23 | self.conv_in = torch.nn.Conv2d(in_channels, 24 | self.ch, 25 | kernel_size=3, 26 | stride=1, 27 | padding=1) 28 | 29 | curr_res = resolution 30 | in_ch_mult = (1,)+tuple(ch_mult) 31 | self.down = nn.ModuleList() 32 | for i_level in range(self.num_resolutions): 33 | block = nn.ModuleList() 34 | attn = nn.ModuleList() 35 | block_in = ch*in_ch_mult[i_level] 36 | block_out = ch*ch_mult[i_level] 37 | for i_block in range(self.num_res_blocks): 38 | block.append(ResnetBlock(in_channels=block_in, 39 | out_channels=block_out, 40 | temb_channels=self.temb_ch, 41 | dropout=dropout)) 42 | block_in = block_out 43 | if curr_res in attn_resolutions: 44 | attn.append(AttnBlock(block_in)) 45 | down = nn.Module() 46 | down.block = block 47 | down.attn = attn 
48 | if i_level != self.num_resolutions-1: 49 | down.downsample = Downsample(block_in, resamp_with_conv) 50 | curr_res = curr_res // 2 51 | self.down.append(down) 52 | 53 | # middle 54 | self.mid = nn.Module() 55 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 56 | out_channels=block_in, 57 | temb_channels=self.temb_ch, 58 | dropout=dropout) 59 | self.mid.attn_1 = AttnBlock(block_in) 60 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 61 | out_channels=block_in, 62 | temb_channels=self.temb_ch, 63 | dropout=dropout) 64 | 65 | # end 66 | self.norm_out = Normalize(block_in) 67 | self.conv_out = torch.nn.Conv2d(block_in, 68 | 2*z_channels if double_z else z_channels, 69 | kernel_size=3, 70 | stride=1, 71 | padding=1) 72 | 73 | def forward(self, x): 74 | # timestep embedding 75 | temb = None 76 | 77 | # downsampling 78 | hs = [self.conv_in(x)] 79 | for i_level in range(self.num_resolutions): 80 | for i_block in range(self.num_res_blocks): 81 | h = self.down[i_level].block[i_block](hs[-1], temb) 82 | if len(self.down[i_level].attn) > 0: 83 | h = self.down[i_level].attn[i_block](h) 84 | hs.append(h) 85 | if i_level != self.num_resolutions-1: 86 | hs.append(self.down[i_level].downsample(hs[-1])) 87 | 88 | # middle 89 | h = hs[-1] 90 | h = self.mid.block_1(h, temb) 91 | h = self.mid.attn_1(h) 92 | h = self.mid.block_2(h, temb) 93 | 94 | # end 95 | h = self.norm_out(h) 96 | h = nonlinearity(h) 97 | h = self.conv_out(h) 98 | return h 99 | 100 | 101 | class Decoder(nn.Module): 102 | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, 103 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 104 | resolution, z_channels, give_pre_end=False, **ignorekwargs): 105 | super().__init__() 106 | self.ch = ch 107 | self.temb_ch = 0 108 | self.num_resolutions = len(ch_mult) 109 | self.num_res_blocks = num_res_blocks 110 | self.resolution = resolution 111 | self.in_channels = in_channels 112 | self.give_pre_end = give_pre_end 113 | 114 | # compute in_ch_mult, block_in and curr_res at lowest res 115 | in_ch_mult = (1,)+tuple(ch_mult) 116 | block_in = ch*ch_mult[self.num_resolutions-1] 117 | curr_res = resolution // 2**(self.num_resolutions-1) 118 | self.z_shape = (1, z_channels, curr_res, curr_res) 119 | print("Working with z of shape {} = {} dimensions.".format( 120 | self.z_shape, np.prod(self.z_shape))) 121 | 122 | # z to block_in 123 | self.conv_in = torch.nn.Conv2d(z_channels, 124 | block_in, 125 | kernel_size=3, 126 | stride=1, 127 | padding=1) 128 | 129 | # middle 130 | self.mid = nn.Module() 131 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 132 | out_channels=block_in, 133 | temb_channels=self.temb_ch, 134 | dropout=dropout) 135 | self.mid.attn_1 = AttnBlock(block_in) 136 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 137 | out_channels=block_in, 138 | temb_channels=self.temb_ch, 139 | dropout=dropout) 140 | 141 | # upsampling 142 | self.up = nn.ModuleList() 143 | for i_level in reversed(range(self.num_resolutions)): 144 | block = nn.ModuleList() 145 | attn = nn.ModuleList() 146 | block_out = ch*ch_mult[i_level] 147 | for i_block in range(self.num_res_blocks+1): 148 | block.append(ResnetBlock(in_channels=block_in, 149 | out_channels=block_out, 150 | temb_channels=self.temb_ch, 151 | dropout=dropout)) 152 | block_in = block_out 153 | if curr_res in attn_resolutions: 154 | attn.append(AttnBlock(block_in)) 155 | up = nn.Module() 156 | up.block = block 157 | up.attn = attn 158 | if i_level != 0: 159 | up.upsample = Upsample(block_in, 
resamp_with_conv) 160 | curr_res = curr_res * 2 161 | self.up.insert(0, up) # prepend to get consistent order 162 | 163 | # end 164 | self.norm_out = Normalize(block_in) 165 | self.conv_out = torch.nn.Conv2d(block_in, 166 | out_ch, 167 | kernel_size=3, 168 | stride=1, 169 | padding=1) 170 | 171 | def forward(self, z): 172 | #assert z.shape[1:] == self.z_shape[1:] 173 | self.last_z_shape = z.shape 174 | 175 | # timestep embedding 176 | temb = None 177 | 178 | # z to block_in 179 | h = self.conv_in(z) 180 | 181 | # middle 182 | h = self.mid.block_1(h, temb) 183 | h = self.mid.attn_1(h) 184 | h = self.mid.block_2(h, temb) 185 | 186 | # upsampling 187 | for i_level in reversed(range(self.num_resolutions)): 188 | for i_block in range(self.num_res_blocks+1): 189 | h = self.up[i_level].block[i_block](h, temb) 190 | if len(self.up[i_level].attn) > 0: 191 | h = self.up[i_level].attn[i_block](h) 192 | if i_level != 0: 193 | h = self.up[i_level].upsample(h) 194 | 195 | # end 196 | if self.give_pre_end: 197 | return h 198 | 199 | h = self.norm_out(h) 200 | h = nonlinearity(h) 201 | h = self.conv_out(h) 202 | return h 203 | -------------------------------------------------------------------------------- /rqvae/models/rqvae/rqvae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
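One note on the Encoder/Decoder pair above before the RQ-VAE wrapper: the latent grid resolution is `resolution / 2**(len(ch_mult) - 1)`, since every level except the last downsamples by 2 and the Decoder mirrors this with upsampling. A shape-check sketch with illustrative hyperparameters (not the values from this repo's YAML configs):

    import torch
    from rqvae.models.rqvae.modules import Encoder, Decoder

    ddconfig = dict(ch=64, out_ch=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,
                    attn_resolutions=(16,), dropout=0.0, in_channels=3,
                    resolution=256, z_channels=256, double_z=False)

    enc, dec = Encoder(**ddconfig), Decoder(**ddconfig)
    z = enc(torch.randn(1, 3, 256, 256))    # -> (1, 256, 16, 16): 256 / 2**(5-1) = 16
    x_rec = dec(z)                          # -> (1, 3, 256, 256)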
14 | 15 | import numpy as np 16 | import torch 17 | from torch import nn 18 | from torch.nn import functional as F 19 | 20 | from ..interfaces import Stage1Model 21 | from .quantizations import RQBottleneck 22 | from .modules import Encoder, Decoder 23 | from .layers import ResnetBlock 24 | 25 | 26 | class RQVAE(Stage1Model): 27 | def __init__(self, 28 | *, 29 | embed_dim=64, 30 | n_embed=512, 31 | decay=0.99, 32 | loss_type='mse', 33 | latent_loss_weight=0.25, 34 | bottleneck_type='rq', 35 | ddconfig=None, 36 | checkpointing=False, 37 | **kwargs): 38 | super().__init__() 39 | 40 | assert loss_type in ['mse', 'l1'] 41 | 42 | self.encoder = Encoder(**ddconfig) 43 | self.decoder = Decoder(**ddconfig) 44 | 45 | def set_checkpointing(m): 46 | if isinstance(m, ResnetBlock): 47 | m.checkpointing = checkpointing 48 | 49 | self.encoder.apply(set_checkpointing) 50 | self.decoder.apply(set_checkpointing) 51 | 52 | if bottleneck_type == 'rq': 53 | latent_shape = kwargs['latent_shape'] 54 | code_shape = kwargs['code_shape'] 55 | shared_codebook = kwargs['shared_codebook'] 56 | restart_unused_codes = kwargs['restart_unused_codes'] 57 | self.quantizer = RQBottleneck(latent_shape=latent_shape, 58 | code_shape=code_shape, 59 | n_embed=n_embed, 60 | decay=decay, 61 | shared_codebook=shared_codebook, 62 | restart_unused_codes=restart_unused_codes, 63 | ) 64 | self.code_shape = code_shape 65 | else: 66 | raise ValueError("invalid 'bottleneck_type' (must be 'rq')") 67 | 68 | self.quant_conv = nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) 69 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 70 | 71 | self.loss_type = loss_type 72 | self.latent_loss_weight = latent_loss_weight 73 | 74 | def forward(self, xs): 75 | z_e = self.encode(xs) 76 | z_q, quant_loss, code = self.quantizer(z_e) 77 | out = self.decode(z_q) 78 | return out, quant_loss, code 79 | 80 | def encode(self, x): 81 | z_e = self.encoder(x) 82 | z_e = self.quant_conv(z_e).permute(0, 2, 3, 1).contiguous() 83 | return z_e 84 | 85 | def decode(self, z_q): 86 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 87 | z_q = self.post_quant_conv(z_q) 88 | out = self.decoder(z_q) 89 | return out 90 | 91 | @torch.no_grad() 92 | def get_codes(self, xs): 93 | z_e = self.encode(xs) 94 | _, _, code = self.quantizer(z_e) 95 | return code 96 | 97 | @torch.no_grad() 98 | def get_soft_codes(self, xs, temp=1.0, stochastic=False): 99 | assert hasattr(self.quantizer, 'get_soft_codes') 100 | 101 | z_e = self.encode(xs) 102 | soft_code, code = self.quantizer.get_soft_codes(z_e, temp=temp, stochastic=stochastic) 103 | return soft_code, code 104 | 105 | @torch.no_grad() 106 | def decode_code(self, code): 107 | z_q = self.quantizer.embed_code(code) 108 | decoded = self.decode(z_q) 109 | return decoded 110 | 111 | def get_recon_imgs(self, xs_real, xs_recon): 112 | 113 | xs_real = xs_real * 0.5 + 0.5 114 | xs_recon = xs_recon * 0.5 + 0.5 115 | xs_recon = torch.clamp(xs_recon, 0, 1) 116 | 117 | return xs_real, xs_recon 118 | 119 | def compute_loss(self, out, quant_loss, code, xs=None, valid=False): 120 | 121 | if self.loss_type == 'mse': 122 | loss_recon = F.mse_loss(out, xs, reduction='mean') 123 | elif self.loss_type == 'l1': 124 | loss_recon = F.l1_loss(out, xs, reduction='mean') 125 | else: 126 | raise ValueError('incompatible loss type') 127 | 128 | loss_latent = quant_loss 129 | 130 | if valid: 131 | loss_recon = loss_recon * xs.shape[0] * xs.shape[1] 132 | loss_latent = loss_latent * xs.shape[0] 133 | 134 | loss_total = loss_recon + 
self.latent_loss_weight * loss_latent 135 | 136 | return { 137 | 'loss_total': loss_total, 138 | 'loss_recon': loss_recon, 139 | 'loss_latent': loss_latent, 140 | 'codes': [code] 141 | } 142 | 143 | def get_last_layer(self): 144 | return self.decoder.conv_out.weight 145 | 146 | @torch.no_grad() 147 | def get_code_emb_with_depth(self, code): 148 | return self.quantizer.embed_code_with_depth(code) 149 | 150 | @torch.no_grad() 151 | def decode_partial_code(self, code, code_idx, decode_type='select'): 152 | r""" 153 | Use partial codebooks and decode the codebook features. 154 | If decode_type == 'select', the (code_idx)-th codebook features are decoded. 155 | If decode_type == 'add', the [0,1,...,code_idx]-th codebook features are added and decoded. 156 | """ 157 | z_q = self.quantizer.embed_partial_code(code, code_idx, decode_type) 158 | decoded = self.decode(z_q) 159 | return decoded 160 | 161 | @torch.no_grad() 162 | def forward_partial_code(self, xs, code_idx, decode_type='select'): 163 | r""" 164 | Reconstuct an input using partial codebooks. 165 | """ 166 | code = self.get_codes(xs) 167 | out = self.decode_partial_code(code, code_idx, decode_type) 168 | return out 169 | -------------------------------------------------------------------------------- /rqvae/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .optimizer import create_optimizer 16 | from .scheduler import create_scheduler 17 | -------------------------------------------------------------------------------- /rqvae/optimizer/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
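Before the loss utilities, it is worth summarizing the small RQVAE interface above: `forward` returns the reconstruction, the quantization loss and the discrete code map, while `decode_code` / `decode_partial_code` reconstruct images from codes alone (optionally using only the first few quantization depths). A hedged inference sketch, assuming `model` is an `RQVAE` built by `get_rqvae(config)` from one of the stage-1 configs and `xs` is a batch of images normalized to [-1, 1]:

    import torch

    with torch.no_grad():
        out, quant_loss, code = model(xs)         # reconstruction, latent loss, code map
        codes = model.get_codes(xs)               # discrete codes (one index per position and depth)
        recon = model.decode_code(codes)          # decode from codes alone
        coarse = model.decode_partial_code(codes, code_idx=0, decode_type='add')  # depth-0 features only

    xs_real, xs_rec = model.get_recon_imgs(xs, recon)   # rescale both to [0, 1] for visualization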
14 | 15 | import math 16 | import torch 17 | 18 | import numpy as np 19 | from torch.nn import functional as F 20 | 21 | LOG_SCALE_MIN = -7 22 | 23 | 24 | def compute_entropy(x, normalized=False): 25 | if not normalized: 26 | x /= np.sum(x) 27 | h = -np.sum(x * np.log(x + 1e-10)) 28 | return h 29 | 30 | def update_codebook_with_entropy(codebook, code): 31 | code_h, code_w = code.shape[1:] 32 | try: 33 | code = code.view(-1).cpu().numpy() 34 | except: 35 | code = code.view(-1).numpy() 36 | code, code_cnt = np.unique(code, return_counts=True) 37 | code_cnt = code_cnt.astype(np.float32) / (code_h*code_w) 38 | codebook[code] += code_cnt 39 | code_ent_ = compute_entropy(codebook) 40 | return codebook, code_ent_ 41 | 42 | 43 | 44 | def torch_compute_entropy(x, normalized=False): 45 | if not normalized: 46 | x = x / torch.sum(x, dim=-1, keepdim=True) 47 | h = -torch.sum(x * torch.log(x + 1e-10), dim=-1) 48 | return h 49 | 50 | 51 | def to_one_hot(tensor, n, fill_with=1.): 52 | # we perform one hot encore with respect to the last axis 53 | one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_() 54 | if tensor.is_cuda: 55 | one_hot = one_hot.cuda() 56 | one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) 57 | return one_hot 58 | 59 | 60 | def log_sum_exp(x, axis=1): 61 | """ numerically stable log_sum_exp implementation that prevents overflow """ 62 | # TF ordering -> NCHW format 63 | m, _ = torch.max(x, dim=axis) 64 | m2, _ = torch.max(x, dim=axis, keepdim=True) 65 | return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis) + 1e-7) 66 | 67 | 68 | def log_prob_from_logits(x, axis=1): 69 | """ numerically stable log_softmax implementation that prevents overflow """ 70 | # TF ordering -> NCHW format 71 | m, _ = torch.max(x, dim=axis, keepdim=True) 72 | return x - m - torch.log(torch.sum(torch.exp(x - m), dim=axis, keepdim=True) + 1e-7) 73 | 74 | 75 | def soft_target_cross_entropy(input, target, reduction='mean'): 76 | loss = torch.sum(-target * log_prob_from_logits(input, axis=-1), dim=-1) 77 | if reduction == 'mean': 78 | return loss.mean() 79 | elif reduction == 'sum': 80 | return loss.sum() 81 | elif reduction == 'none': 82 | return loss 83 | else: 84 | raise ValueError() 85 | -------------------------------------------------------------------------------- /rqvae/optimizer/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
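A small sanity check for the loss helpers above: `soft_target_cross_entropy` reduces to the usual cross-entropy when the soft targets are one-hot, which is a convenient way to verify the numerically stable `log_prob_from_logits`. Sketch (CPU, small tensors):

    import torch
    from torch.nn import functional as F
    from rqvae.optimizer.loss import to_one_hot, soft_target_cross_entropy

    logits = torch.randn(4, 10)
    hard = torch.randint(0, 10, (4,))
    soft = to_one_hot(hard, 10)                    # (4, 10) one-hot targets

    a = soft_target_cross_entropy(logits, soft)    # mean over the batch by default
    b = F.cross_entropy(logits, hard)
    assert torch.allclose(a, b, atol=1e-5)         # equal up to the 1e-7 stabilizer inside the log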
14 | 15 | import torch 16 | 17 | def create_resnet_optimizer(model, config): 18 | optimizer_type = config.type.lower() 19 | if optimizer_type == 'adamw': 20 | optimizer = torch.optim.AdamW( 21 | model.parameters(), lr=config.init_lr, weight_decay=config.weight_decay, 22 | betas=config.betas 23 | ) 24 | elif optimizer_type == 'adam': 25 | optimizer = torch.optim.Adam( 26 | model.parameters(), lr=config.init_lr, weight_decay=config.weight_decay, betas=config.betas 27 | ) 28 | elif optimizer_type == 'sgd': 29 | optimizer = torch.optim.SGD( 30 | model.parameters(), lr=config.init_lr, weight_decay=config.weight_decay, momentum=0.9 31 | ) 32 | else: 33 | raise ValueError(f'{optimizer_type} invalid..') 34 | return optimizer 35 | 36 | 37 | def create_optimizer(model, config): 38 | arch_type = config.arch.type.lower() 39 | if 'rq-vae' in config.arch.type: 40 | optimizer = create_resnet_optimizer(model, config.optimizer) 41 | else: 42 | raise ValueError(f'{arch_type} invalid..') 43 | return optimizer 44 | -------------------------------------------------------------------------------- /rqvae/optimizer/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
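For reference, `create_optimizer` above only reads the `arch.type` and `optimizer` sub-configs of the experiment YAML. A minimal sketch with placeholder hyperparameters (not the values used by the released configs):

    import torch.nn as nn
    from omegaconf import OmegaConf
    from rqvae.optimizer import create_optimizer

    config = OmegaConf.create({
        'arch': {'type': 'rq-vae'},
        'optimizer': {'type': 'adamw', 'init_lr': 4e-4, 'weight_decay': 1e-4, 'betas': [0.5, 0.9]},
    })
    model = nn.Linear(8, 8)                    # stand-in for the stage-1 model
    opt = create_optimizer(model, config)      # -> torch.optim.AdamW over model.parameters()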
14 | 15 | import math 16 | import torch 17 | from torch.optim.lr_scheduler import CosineAnnealingLR 18 | 19 | 20 | def create_scheduler(optimizer, config, steps_per_epoch, max_epoch, distenv=None): 21 | 22 | multiplier = config.multiplier 23 | warmup_steps = config.epoch * steps_per_epoch 24 | buffer_steps = config.buffer_epoch * steps_per_epoch 25 | final_steps = max_epoch * steps_per_epoch 26 | min_lr = config.min_lr 27 | mode = config.mode 28 | start_from_zero = config.start_from_zero 29 | 30 | scheduler = CosineAnnealingLR( 31 | optimizer, T_max=final_steps - warmup_steps - buffer_steps, eta_min=min_lr 32 | ) 33 | 34 | if warmup_steps > 0.0: 35 | if mode == 'linear': 36 | multiplier = max(1.0, multiplier * distenv.world_size) 37 | elif mode == 'sqrt': 38 | multiplier = max(1.0, multiplier * math.sqrt(distenv.world_size)) 39 | elif mode == 'fix': 40 | multiplier = max(1.0, multiplier) 41 | elif mode == 'none': 42 | pass 43 | else: 44 | raise NotImplementedError(f'{mode} is not a valid warmup policy') 45 | warmup = GradualWarmup( 46 | optimizer, 47 | steps=warmup_steps, 48 | buffer_steps=buffer_steps, 49 | multiplier=multiplier, 50 | start_from_zero=start_from_zero 51 | ) 52 | else: 53 | warmup = None 54 | 55 | scheduler = Scheduler(warmup_scheduler=warmup, after_scheduler=scheduler) 56 | 57 | return scheduler 58 | 59 | 60 | class GradualWarmup(torch.optim.lr_scheduler._LRScheduler): 61 | def __init__(self, optimizer, steps, buffer_steps, multiplier, start_from_zero=True, last_epoch=-1): 62 | self.steps = steps 63 | self.t_steps = steps + buffer_steps 64 | self.multiplier = multiplier 65 | self.start_from_zero = start_from_zero 66 | 67 | super(GradualWarmup, self).__init__(optimizer, last_epoch) 68 | 69 | def get_lr(self): 70 | if self.last_epoch > self.steps: 71 | return [group['lr'] for group in self.optimizer.param_groups] 72 | 73 | if self.start_from_zero: 74 | multiplier = self.multiplier * min(1.0, (self.last_epoch / self.steps)) 75 | else: 76 | multiplier = 1 + ((self.multiplier - 1) * min(1.0, (self.last_epoch / self.steps))) 77 | return [lr * multiplier for lr in self.base_lrs] 78 | 79 | 80 | class Scheduler: 81 | def __init__(self, warmup_scheduler, after_scheduler): 82 | self.warmup_scheduler = warmup_scheduler 83 | self.after_scheduler = after_scheduler 84 | 85 | def step(self, epoch=None): 86 | if self.warmup_scheduler is not None: 87 | self.warmup_scheduler.step(epoch=epoch) 88 | 89 | if self.warmup_scheduler is None or \ 90 | self.warmup_scheduler.last_epoch > self.warmup_scheduler.t_steps: 91 | self.after_scheduler.step(epoch=epoch) 92 | 93 | def get_last_lr(self): 94 | if self.warmup_scheduler is not None and \ 95 | self.warmup_scheduler.last_epoch <= self.warmup_scheduler.t_steps: 96 | return self.warmup_scheduler.get_last_lr() 97 | else: 98 | return self.after_scheduler.get_last_lr() 99 | 100 | def state_dict(self): 101 | return { 102 | 'warmup': None if self.warmup_scheduler is None else self.warmup_scheduler.state_dict(), 103 | 'after': self.after_scheduler.state_dict() 104 | } 105 | 106 | def load_state_dict(self, state_dict): 107 | if self.warmup_scheduler is not None: 108 | self.warmup_scheduler.load_state_dict(state_dict['warmup']) 109 | self.after_scheduler.load_state_dict(state_dict['after']) 110 | -------------------------------------------------------------------------------- /rqvae/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 
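An aside on the scheduler above: `create_scheduler` composes a `GradualWarmup` phase with `CosineAnnealingLR`, and the warmup multiplier is scaled by `distenv.world_size` only in the 'linear' and 'sqrt' modes. A self-contained sketch with made-up warmup settings:

    import torch
    from omegaconf import OmegaConf
    from rqvae.optimizer import create_scheduler

    opt = torch.optim.AdamW(torch.nn.Linear(8, 8).parameters(), lr=4e-4)
    sched_cfg = OmegaConf.create({'multiplier': 1.0, 'epoch': 1, 'buffer_epoch': 0,
                                  'min_lr': 1e-6, 'mode': 'fix', 'start_from_zero': True})
    scheduler = create_scheduler(opt, sched_cfg, steps_per_epoch=100, max_epoch=10)

    for _ in range(1000):       # one scheduler.step() per optimizer step
        opt.step()
        scheduler.step()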
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .trainer_rqvae import Trainer as TrainerRQVAE 16 | 17 | STAGE1_ARCH_TYPE = [ 18 | 'rq-vae' 19 | ] 20 | 21 | 22 | def create_trainer(config): 23 | if config.arch.type in STAGE1_ARCH_TYPE: 24 | return TrainerRQVAE 25 | 26 | else: 27 | raise ValueError('architecture type not supported') 28 | -------------------------------------------------------------------------------- /rqvae/trainers/accumulator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Iterable 16 | 17 | import torch 18 | 19 | import rqvae.utils.dist as dist_utils 20 | from rqvae.optimizer.loss import torch_compute_entropy 21 | 22 | 23 | def assign_code(codebook, code): 24 | 25 | if len(code.shape) == 3: 26 | code = code.reshape(*code.shape, 1) 27 | 28 | n_codebooks = codebook.shape[0] 29 | code_h, code_w, code_d = code.shape[1:] 30 | chunks = torch.chunk(code, chunks=n_codebooks, dim=-1) 31 | for i, chunk in enumerate(chunks): 32 | uniques, counts = torch.unique(chunk.view(-1), return_counts=True) 33 | freqs = counts.to(dtype=torch.float) / (code_h * code_w) 34 | codebook[i][uniques] += freqs 35 | 36 | 37 | class SummaryStage1: 38 | def __init__(self, loss_total, loss_recon, loss_latent, ent_codes_w_pad, ent_codes_wo_pad): 39 | self.loss_total = loss_total 40 | self.loss_recon = loss_recon 41 | self.loss_latent = loss_latent 42 | self.ent_codes_w_pad = ent_codes_w_pad 43 | self.ent_codes_wo_pad = ent_codes_wo_pad 44 | 45 | def print_line(self): 46 | loss_total = self.loss_total.item() 47 | loss_recon = self.loss_recon.item() 48 | loss_latent = self.loss_latent.item() 49 | 50 | line = f"loss_total: {loss_total:.4f}, loss_recon: {loss_recon:.4f}, loss_latent: {loss_latent:.4f}, " 51 | 52 | if self.ent_codes_w_pad is not None: 53 | for level, ent_code in enumerate(self.ent_codes_w_pad): 54 | ent_code_str = '[' + ', '.join([f'{ent:.4f}' for ent in ent_code]) + ']' 55 | line += f"""w/ pad entropy-level-{level}: {ent_code_str}, """ 56 | 57 | for level, ent_code in enumerate(self.ent_codes_wo_pad): 58 | ent_code_str = '[' + ', '.join([f'{ent:.4f}' for ent in ent_code]) + ']' 59 | line += f"""w/o pad entropy-level-{level}: {ent_code_str}, """ 60 | 61 | return line 62 | 63 | def __getitem__(self, item): 64 | return getattr(self, item) 65 | 66 | 
def __setitem__(self, key, value): 67 | setattr(self, key, value) 68 | 69 | 70 | class AccmStage1: 71 | def __init__(self, n_codebook=1, codebook_size=512, code_hier=1, use_padding_idx=False, device='cpu'): 72 | self.n_codebook = n_codebook 73 | self.max_codebook_size = self.codebook_size = codebook_size 74 | self.use_padding_idx = use_padding_idx 75 | 76 | if isinstance(codebook_size, Iterable): 77 | self.max_codebook_size = max(codebook_size) 78 | 79 | if self.use_padding_idx: 80 | self.max_codebook_size += 1 81 | 82 | self.code_hier = code_hier 83 | self.device = device 84 | 85 | self.init() 86 | 87 | def init(self): 88 | self.loss_total = torch.zeros(1, device=self.device) 89 | self.loss_recon = torch.zeros(1, device=self.device) 90 | self.loss_latent = torch.zeros(1, device=self.device) 91 | 92 | self.codebooks = [torch.zeros(self.n_codebook, self.max_codebook_size, device=self.device) 93 | for _ in range(self.code_hier)] 94 | self.counter = 0 95 | 96 | @torch.no_grad() 97 | def update(self, 98 | loss_total, 99 | loss_recon, 100 | loss_latent, 101 | codes, 102 | count=None, 103 | sync=False, 104 | distenv=None): 105 | 106 | if sync: 107 | loss_total = dist_utils.all_gather_cat(distenv, loss_total.unsqueeze(0)).sum() 108 | loss_recon = dist_utils.all_gather_cat(distenv, loss_recon.unsqueeze(0)).sum() 109 | loss_latent = dist_utils.all_gather_cat(distenv, loss_latent.unsqueeze(0)).sum() 110 | codes = [dist_utils.all_gather_cat(distenv, code) for code in codes] 111 | 112 | self.loss_total += loss_total.detach() 113 | self.loss_recon += loss_recon.detach() 114 | self.loss_latent += loss_latent.detach() 115 | 116 | for i in range(self.code_hier): 117 | assign_code(self.codebooks[i], codes[i].detach()) 118 | 119 | self.counter += count if not sync else count * distenv.world_size 120 | 121 | @torch.no_grad() 122 | def get_summary(self, n_samples=None): 123 | n_samples = n_samples if n_samples else self.counter 124 | 125 | loss_total = self.loss_total / n_samples 126 | loss_recon = self.loss_recon / n_samples 127 | loss_latent = self.loss_latent / n_samples 128 | 129 | if self.use_padding_idx: 130 | ent_codes_w_pad = [torch_compute_entropy(self.codebooks[i]) for i in range(self.code_hier)] 131 | ent_codes_wo_pad = [torch_compute_entropy(self.codebooks[i][:, :-1]) for i in range(self.code_hier)] 132 | else: 133 | ent_codes_w_pad = None 134 | ent_codes_wo_pad = [torch_compute_entropy(self.codebooks[i][:, :-1]) for i in range(self.code_hier)] 135 | 136 | summary = SummaryStage1(loss_total=loss_total, 137 | loss_recon=loss_recon, 138 | loss_latent=loss_latent, 139 | ent_codes_w_pad=ent_codes_w_pad, 140 | ent_codes_wo_pad=ent_codes_wo_pad) 141 | 142 | return summary 143 | 144 | 145 | class SummaryStage1WithGAN: 146 | def __init__(self, ent_codes_w_pad, ent_codes_wo_pad, **kwargs): 147 | for k, v in kwargs.items(): 148 | self[k] = v 149 | self.ent_codes_w_pad = ent_codes_w_pad 150 | self.ent_codes_wo_pad = ent_codes_wo_pad 151 | 152 | def print_line(self): 153 | line = "" 154 | for name, value in self.metrics.items(): 155 | line += f"{name}: {value.item():.4f}, " 156 | 157 | if self.ent_codes_w_pad is not None: 158 | for level, ent_code in enumerate(self.ent_codes_w_pad): 159 | ent_code_str = '[' + ', '.join([f'{ent:.4f}' for ent in ent_code]) + ']' 160 | line += f"""w/ pad entropy-level-{level}: {ent_code_str}, """ 161 | 162 | for level, ent_code in enumerate(self.ent_codes_wo_pad): 163 | ent_code_str = '[' + ', '.join([f'{ent:.4f}' for ent in ent_code]) + ']' 164 | line += f"""w/o pad 
entropy-level-{level}: {ent_code_str}""" 165 | 166 | return line 167 | 168 | @property 169 | def metrics(self): 170 | def is_scalar(value): 171 | return (isinstance(value, torch.Tensor) and value.numel() == 1) or isinstance(value, float) 172 | 173 | return {key: value for key, value in self.__dict__.items() if is_scalar(value)} 174 | 175 | def __getitem__(self, item): 176 | return getattr(self, item) 177 | 178 | def __setitem__(self, key, value): 179 | setattr(self, key, value) 180 | 181 | 182 | class AccmStage1WithGAN: 183 | def __init__(self, metric_names, n_codebook=1, codebook_size=512, code_hier=1, use_padding_idx=False, device='cpu'): 184 | self.n_codebook = n_codebook 185 | self.max_codebook_size = self.codebook_size = codebook_size 186 | self.use_padding_idx = use_padding_idx 187 | 188 | if isinstance(codebook_size, list): 189 | self.max_codebook_size = max(codebook_size) 190 | 191 | if self.use_padding_idx: 192 | self.max_codebook_size += 1 193 | 194 | self.code_hier = code_hier 195 | self.device = device 196 | 197 | self.metrics_sum = {n: torch.zeros(1, device=self.device) for n in metric_names} 198 | 199 | self.codebooks = [torch.zeros(self.n_codebook, self.max_codebook_size, device=self.device) 200 | for _ in range(self.code_hier)] 201 | self.counter = 0 202 | 203 | @torch.no_grad() 204 | def update(self, 205 | codes, 206 | metrics_to_add, 207 | count=None, 208 | sync=False, 209 | distenv=None): 210 | 211 | if sync: 212 | codes = [dist_utils.all_gather_cat(distenv, code) for code in codes] 213 | for name, value in metrics_to_add.items(): 214 | gathered_value = dist_utils.all_gather_cat(distenv, value.unsqueeze(0)) 215 | gathered_value = gathered_value.sum().detach() 216 | metrics_to_add[name] = gathered_value 217 | 218 | for name, value in metrics_to_add.items(): 219 | if name not in self.metrics_sum: 220 | raise KeyError(f'unexpected metric name: {name}') 221 | self.metrics_sum[name] += value 222 | 223 | for i in range(self.code_hier): 224 | assign_code(self.codebooks[i], codes[i].detach()) 225 | 226 | self.counter += count if not sync else count * distenv.world_size 227 | 228 | @torch.no_grad() 229 | def get_summary(self, n_samples=None): 230 | n_samples = n_samples if n_samples else self.counter 231 | 232 | metrics_avg = {k: v / n_samples for k, v in self.metrics_sum.items()} 233 | 234 | if self.use_padding_idx: 235 | ent_codes_w_pad = [torch_compute_entropy(self.codebooks[i]) for i in range(self.code_hier)] 236 | ent_codes_wo_pad = [torch_compute_entropy(self.codebooks[i][:, :-1]) for i in range(self.code_hier)] 237 | else: 238 | ent_codes_w_pad = None 239 | ent_codes_wo_pad = [torch_compute_entropy(self.codebooks[i]) for i in range(self.code_hier)] 240 | 241 | summary = SummaryStage1WithGAN(ent_codes_w_pad=ent_codes_w_pad, 242 | ent_codes_wo_pad=ent_codes_wo_pad, 243 | **metrics_avg) 244 | 245 | return summary 246 | -------------------------------------------------------------------------------- /rqvae/trainers/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import logging 17 | 18 | import torch 19 | 20 | from torch.utils.data.dataloader import DataLoader 21 | from torch.cuda.amp import GradScaler 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | SMOKE_TEST = bool(os.environ.get("SMOKE_TEST", 0)) 26 | 27 | 28 | class TrainerTemplate(): 29 | 30 | def __init__(self, 31 | model, 32 | model_ema, 33 | dataset_trn, 34 | dataset_val, 35 | config, 36 | writer, 37 | device, 38 | distenv, 39 | model_aux=None, 40 | *, 41 | disc_state_dict=None, # only used in VQGAN trainer 42 | ): 43 | super().__init__() 44 | 45 | num_workers = 16 46 | 47 | if SMOKE_TEST: 48 | if not torch.distributed.is_initialized(): 49 | num_workers = 0 50 | config.experiment.test_freq = 1 51 | config.experiment.save_ckpt_freq = 1 52 | 53 | self.model = model 54 | self.model_ema = model_ema 55 | self.model_aux = model_aux 56 | 57 | self.config = config 58 | self.writer = writer 59 | self.device = device 60 | self.distenv = distenv 61 | 62 | self.dataset_trn = dataset_trn 63 | self.dataset_val = dataset_val 64 | 65 | self.sampler_trn = torch.utils.data.distributed.DistributedSampler( 66 | self.dataset_trn, 67 | num_replicas=self.distenv.world_size, 68 | rank=self.distenv.world_rank, 69 | shuffle=True, 70 | seed=self.config.seed, 71 | ) 72 | self.loader_trn = DataLoader( 73 | self.dataset_trn, sampler=self.sampler_trn, shuffle=False, pin_memory=True, 74 | batch_size=config.experiment.batch_size, 75 | num_workers=num_workers, 76 | ) 77 | 78 | self.sampler_val = torch.utils.data.distributed.DistributedSampler( 79 | self.dataset_val, 80 | num_replicas=self.distenv.world_size, 81 | rank=self.distenv.world_rank, 82 | shuffle=False 83 | ) 84 | self.loader_val = DataLoader( 85 | self.dataset_val, sampler=self.sampler_val, shuffle=False, pin_memory=True, 86 | batch_size=config.experiment.batch_size, 87 | num_workers=num_workers 88 | ) 89 | 90 | def train(self, optimizer=None, scheduler=None, scaler=None, epoch=0): 91 | raise NotImplementedError 92 | 93 | def eval(self, valid=True, ema=False, verbose=False, epoch=0): 94 | raise NotImplementedError 95 | 96 | def run_epoch(self, optimizer=None, scheduler=None, epoch_st=0): 97 | scaler = GradScaler() if self.config.experiment.amp else None 98 | 99 | for i in range(epoch_st, self.config.experiment.epochs): 100 | self.sampler_trn.set_epoch(i) 101 | torch.cuda.empty_cache() 102 | summary_trn = self.train(optimizer, scheduler, scaler, epoch=i) 103 | if i == 0 or (i+1) % self.config.experiment.test_freq == 0: 104 | torch.cuda.empty_cache() 105 | summary_val = self.eval(epoch=i) 106 | if self.model_ema is not None: 107 | summary_val_ema = self.eval(ema=True, epoch=i) 108 | 109 | if self.distenv.master: 110 | self.logging(summary_trn, scheduler=scheduler, epoch=i+1, mode='train') 111 | 112 | if i == 0 or (i+1) % self.config.experiment.test_freq == 0: 113 | self.logging(summary_val, scheduler=scheduler, epoch=i+1, mode='valid') 114 | if self.model_ema is not None: 115 | self.logging(summary_val_ema, scheduler=scheduler, epoch=i+1, mode='valid_ema') 116 | 117 | if (i+1) % 
self.config.experiment.save_ckpt_freq == 0: 118 | self.save_ckpt(optimizer, scheduler, i+1) 119 | 120 | def save_ckpt(self, optimizer, scheduler, epoch): 121 | ckpt_path = os.path.join(self.config.result_path, 'epoch%d_model.pt' % epoch) 122 | logger.info("epoch: %d, saving %s", epoch, ckpt_path) 123 | ckpt = { 124 | 'epoch': epoch, 125 | 'state_dict': self.model.module.state_dict(), 126 | 'optimizer': optimizer.state_dict(), 127 | 'scheduler': scheduler.state_dict() 128 | } 129 | if self.model_ema is not None: 130 | ckpt.update(state_dict_ema=self.model_ema.module.module.state_dict()) 131 | torch.save(ckpt, ckpt_path) 132 | -------------------------------------------------------------------------------- /rqvae/txtimg_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import torch.utils.data 17 | 18 | from .transforms import create_transforms 19 | from .coco import Coco 20 | from .cc3m import Cc3m 21 | 22 | 23 | def create_datasets(config, is_eval=False, logger=None): 24 | 25 | data_config = config.dataset 26 | 27 | train_transform = create_transforms(data_config, split='train', is_eval=is_eval) 28 | valid_transform = create_transforms(data_config, split='valid', is_eval=is_eval) 29 | 30 | root = data_config.get('root', None) 31 | 32 | if data_config.dataset == 'coco': 33 | root = root if root else 'data/coco' 34 | train_ds_cls = Coco 35 | valid_ds_cls = Coco 36 | elif data_config.dataset == 'cc3m': 37 | root = root if root else 'data/cc3m' 38 | train_ds_cls = Cc3m 39 | valid_ds_cls = Cc3m 40 | else: 41 | raise NotImplementedError(data_config.dataset) 42 | 43 | train_dataset = train_ds_cls(root, 44 | split='train', 45 | tok_name=data_config.txt_tok_name, 46 | transform=train_transform, 47 | context_length=data_config.context_length, 48 | dropout=data_config.bpe_dropout) 49 | valid_dataset = valid_ds_cls(root, 50 | split='val', 51 | tok_name=data_config.txt_tok_name, 52 | transform=valid_transform, 53 | context_length=data_config.context_length, 54 | dropout=None) 55 | 56 | if bool(os.environ.get("SMOKE_TEST", 0)): 57 | dataset_len = config.experiment.total_batch_size * 2 58 | train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset))[:dataset_len]) 59 | valid_dataset = torch.utils.data.Subset(valid_dataset, torch.randperm(len(valid_dataset))[:dataset_len]) 60 | 61 | if logger is not None: 62 | logger.info(f'#train samples: {len(train_dataset)}, #valid samples: {len(valid_dataset)}') 63 | 64 | return train_dataset, valid_dataset 65 | -------------------------------------------------------------------------------- /rqvae/txtimg_datasets/cc3m.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import torch 18 | import torch.utils.data 19 | from torchvision.datasets import VisionDataset 20 | from PIL import Image 21 | from tqdm import tqdm 22 | 23 | from .tokenizers import create_tokenizer 24 | 25 | 26 | class Cc3m(VisionDataset): 27 | splits = {'train', 'val'} 28 | 29 | def __init__(self, root, split, tok_name, transform=None, context_length=77, dropout=None): 30 | assert split in self.splits, f'{split} is not in {self.splits}' 31 | super().__init__(root, transform=transform) 32 | 33 | self.split = split 34 | self.tokenizer = create_tokenizer(tok_name, lowercase=True, dropout=dropout) 35 | self.context_length = context_length 36 | 37 | self.tokenizer.add_special_tokens(["[PAD]"]) 38 | self.tokenizer.enable_padding(length=self.context_length, 39 | pad_id=self.tokenizer.token_to_id("[PAD]")) 40 | self.tokenizer.enable_truncation(max_length=self.context_length) 41 | 42 | self.items = [] 43 | 44 | for line in open(f'{self.root}/{split}_list.txt', 'r').readlines(): 45 | toks = line.strip().split('\t') 46 | assert len(toks) == 2 47 | (imgpath, text) = toks 48 | self.items.append((os.path.join(self.root, imgpath), text)) 49 | 50 | def __len__(self): 51 | return len(self.items) 52 | 53 | def __getitem__(self, item): 54 | imgpath, text = self.items[item] 55 | 56 | img = Image.open(imgpath).convert('RGB') 57 | if self.transform: 58 | img = self.transform(img) 59 | 60 | output = self.tokenizer.encode(text) 61 | ids = output.ids 62 | if not isinstance(ids, torch.LongTensor): 63 | ids = torch.LongTensor(ids) 64 | 65 | return img, ids 66 | 67 | 68 | class Cc3mRawTextOnly(torch.utils.data.Dataset): 69 | 70 | def __init__(self, root, split): 71 | 72 | self.root = root 73 | self.items = [] 74 | for line in open(f'{self.root}/{split}_list.txt', 'r').readlines(): 75 | toks = line.strip().split('\t') 76 | assert len(toks) == 2 77 | (_, text) = toks 78 | self.items.append(text) 79 | 80 | def __len__(self): 81 | return len(self.items) 82 | 83 | def __getitem__(self, item): 84 | text = self.items[item] 85 | return text 86 | 87 | 88 | class Cc3mTextOnly(Cc3m): 89 | 90 | def __getitem__(self, item): 91 | _, text = self.items[item] 92 | 93 | output = self.tokenizer.encode(text) 94 | ids = output.ids 95 | if not isinstance(ids, torch.LongTensor): 96 | ids = torch.LongTensor(ids) 97 | 98 | return 0, ids 99 | 100 | 101 | class Cc3mRawText(VisionDataset): 102 | splits = {'train', 'val'} 103 | 104 | def __init__(self, root, split, transform=None): 105 | assert split in self.splits, f'{split} is not in {self.splits}' 106 | super().__init__(root, transform=transform) 107 | 108 | self.split = split 109 | self.items = [] 110 | 111 | for line in open(f'{self.root}/{split}_list.txt', 'r').readlines(): 112 | toks = line.strip().split('\t') 113 | assert len(toks) == 2 114 | (imgpath, text) = toks 115 | self.items.append((os.path.join(self.root, imgpath), text)) 116 | 117 | def __len__(self): 118 | return 
len(self.items) 119 | 120 | def __getitem__(self, item): 121 | imgpath, text = self.items[item] 122 | 123 | img = Image.open(imgpath).convert('RGB') 124 | if self.transform: 125 | img = self.transform(img) 126 | 127 | return img, text 128 | -------------------------------------------------------------------------------- /rqvae/txtimg_datasets/coco.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import random 16 | 17 | import torch 18 | 19 | from torchvision.datasets import CocoCaptions, VisionDataset 20 | 21 | from .tokenizers import create_tokenizer 22 | 23 | 24 | class Coco(VisionDataset): 25 | splits = {'val'} 26 | 27 | def __init__(self, root, split, tok_name, transform=None, context_length=77, dropout=None): 28 | assert split in self.splits, f'{split} is not in {self.splits}' 29 | super().__init__(root, transform=transform) 30 | 31 | self.split = split 32 | self.tokenizer = create_tokenizer(tok_name, lowercase=True, dropout=dropout) 33 | self.context_length = context_length 34 | 35 | self.dataset = CocoCaptions(root=f'{self.root}/images/val2014', 36 | annFile=f'{self.root}/annotations/captions_val2014_30K_samples.json') 37 | 38 | self.tokenizer.add_special_tokens(["[PAD]"]) 39 | self.tokenizer.enable_padding(length=self.context_length, 40 | pad_id=self.tokenizer.token_to_id("[PAD]")) 41 | self.tokenizer.enable_truncation(max_length=self.context_length) 42 | 43 | def __len__(self): 44 | return len(self.dataset) 45 | 46 | def __getitem__(self, item): 47 | img, text = self.dataset[item] 48 | 49 | if self.transform: 50 | img = self.transform(img) 51 | 52 | # text = ' '.join(text) # text is a list of sentences. Concat them. 
53 | if self.split == 'train': 54 | rnd_txt = random.randint(0, len(text)-1) 55 | text = text[rnd_txt] 56 | else: 57 | text = text[0] 58 | 59 | output = self.tokenizer.encode(text) 60 | ids = output.ids 61 | if not isinstance(ids, torch.LongTensor): 62 | ids = torch.LongTensor(ids) 63 | 64 | return img, ids 65 | 66 | 67 | class CocoTextOnly(Coco): 68 | 69 | def __getitem__(self, item): 70 | _, text = self.dataset[item] 71 | 72 | text = text[0] 73 | 74 | output = self.tokenizer.encode(text) 75 | ids = output.ids 76 | if not isinstance(ids, torch.LongTensor): 77 | ids = torch.LongTensor(ids) 78 | 79 | return 0, ids 80 | 81 | 82 | class CocoRawText(VisionDataset): 83 | splits = {'val'} 84 | 85 | def __init__(self, root, split, transform=None): 86 | assert split in self.splits, f'{split} is not in {self.splits}' 87 | super().__init__(root, transform=transform) 88 | 89 | self.split = split 90 | 91 | self.dataset = CocoCaptions(root=f'{self.root}/images/val2014', 92 | annFile=f'{self.root}/annotations/captions_val2014_30K_samples.json') 93 | 94 | def __len__(self): 95 | return len(self.dataset) 96 | 97 | def __getitem__(self, item): 98 | img, text = self.dataset[item] 99 | 100 | if self.transform: 101 | img = self.transform(img) 102 | 103 | text = text[0] 104 | 105 | return img, text 106 | 107 | 108 | class CocoRawTextOnly(CocoRawText): 109 | def __getitem__(self, item): 110 | _, text = self.dataset[item] 111 | return text 112 | -------------------------------------------------------------------------------- /rqvae/txtimg_datasets/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
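Before the tokenizer registry: both text-image datasets above follow the same contract, returning an image tensor plus a fixed-length `LongTensor` of token ids (padded with `[PAD]` and truncated to `context_length`). A hedged loading sketch, assuming CC-3M has been prepared under `data/cc3m` as described in `data/README.md`; the tokenizer name, crop size and context length here are illustrative:

    from torchvision import transforms
    from rqvae.txtimg_datasets.cc3m import Cc3m

    tf = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()])
    ds = Cc3m('data/cc3m', split='val', tok_name='bpe16k_huggingface',
              transform=tf, context_length=64)
    img, ids = ds[0]        # img: (3, 256, 256) float tensor, ids: LongTensor of length 64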
14 | 15 | from functools import partial 16 | 17 | from tokenizers import BertWordPieceTokenizer, ByteLevelBPETokenizer, CharBPETokenizer 18 | from .simple_tokenizer import SimpleTokenizer 19 | from .utils import bert_vocab, gpt2_vocab, gpt2_merges 20 | from .utils import huggingface_bpe_16k_vocab, huggingface_bpe_16k_merges 21 | from .utils import huggingface_bpe_30k_vocab, huggingface_bpe_30k_merges 22 | 23 | 24 | TOKENIZERS = { 25 | 'simple': SimpleTokenizer, 26 | 'bert_huggingface': partial(BertWordPieceTokenizer, vocab=bert_vocab()), 27 | 'gpt2_huggingface': partial(ByteLevelBPETokenizer.from_file, 28 | vocab_filename=gpt2_vocab(), 29 | merges_filename=gpt2_merges()), 30 | 'bpe16k_huggingface': partial(CharBPETokenizer.from_file, 31 | vocab_filename=huggingface_bpe_16k_vocab(), 32 | merges_filename=huggingface_bpe_16k_merges(), 33 | unk_token="[UNK]"), 34 | 'bpe30k_huggingface': partial(CharBPETokenizer.from_file, 35 | vocab_filename=huggingface_bpe_30k_vocab(), 36 | merges_filename=huggingface_bpe_30k_merges(), 37 | unk_token="[UNK]") 38 | } 39 | 40 | 41 | def create_tokenizer(tok_name, *args, **kwargs): 42 | if tok_name == 'simple' or tok_name == 'bert_huggingface': 43 | filtered_keys = [key for key in kwargs.keys() if key != 'dropout'] 44 | filtered_dict = {key: kwargs[key] for key in filtered_keys} 45 | return TOKENIZERS[tok_name](*args, **filtered_dict) 46 | else: 47 | return TOKENIZERS[tok_name](*args, **kwargs) 48 | -------------------------------------------------------------------------------- /rqvae/txtimg_datasets/tokenizers/pretrained/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/rq-vae-transformer/341395e562ac347f5eb62db9f5f08b9f2cc42a60/rqvae/txtimg_datasets/tokenizers/pretrained/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /rqvae/txtimg_datasets/tokenizers/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | """modified from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py""" 2 | import gzip 3 | import html 4 | import random 5 | from functools import lru_cache 6 | from collections import namedtuple 7 | 8 | import ftfy 9 | import torch 10 | import regex as re 11 | 12 | from .utils import default_bpe 13 | 14 | 15 | TOKENIZER_OUTPUT = namedtuple('output', ['tokens', 'ids']) 16 | 17 | 18 | @lru_cache() 19 | def bytes_to_unicode(): 20 | """ 21 | Returns list of utf-8 byte and a corresponding list of unicode strings. 22 | The reversible bpe codes work on unicode strings. 23 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 24 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 25 | This is a signficant percentage of your normal, say, 32K bpe vocab. 26 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 27 | And avoids mapping to whitespace/control characters the bpe code barfs on. 28 | """ 29 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 30 | cs = bs[:] 31 | n = 0 32 | for b in range(2**8): 33 | if b not in bs: 34 | bs.append(b) 35 | cs.append(2**8+n) 36 | n += 1 37 | cs = [chr(n) for n in cs] 38 | return dict(zip(bs, cs)) 39 | 40 | 41 | def get_pairs(word): 42 | """Return set of symbol pairs in a word. 
43 | Word is represented as tuple of symbols (symbols being variable-length strings). 44 | """ 45 | pairs = set() 46 | prev_char = word[0] 47 | for char in word[1:]: 48 | pairs.add((prev_char, char)) 49 | prev_char = char 50 | return pairs 51 | 52 | 53 | def basic_clean(text): 54 | text = ftfy.fix_text(text) 55 | text = html.unescape(html.unescape(text)) 56 | return text.strip() 57 | 58 | 59 | def whitespace_clean(text): 60 | text = re.sub(r'\s+', ' ', text) 61 | text = text.strip() 62 | return text 63 | 64 | 65 | class SimpleTokenizer(object): 66 | def __init__(self, bpe_path=None, lowercase=True): 67 | 68 | assert lowercase 69 | bpe_path = default_bpe() if bpe_path is None else bpe_path 70 | 71 | self.byte_encoder = bytes_to_unicode() 72 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 73 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 74 | merges = merges[1:49152-256-2+1] 75 | merges = [tuple(merge.split()) for merge in merges] 76 | vocab = list(bytes_to_unicode().values()) 77 | vocab = vocab + [v+'</w>' for v in vocab] 78 | for merge in merges: 79 | vocab.append(''.join(merge)) 80 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 81 | self.encoder = dict(zip(vocab, range(len(vocab)))) 82 | self.decoder = {v: k for k, v in self.encoder.items()} 83 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 84 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 85 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 86 | 87 | def bpe(self, token): 88 | if token in self.cache: 89 | return self.cache[token] 90 | word = tuple(token[:-1]) + (token[-1] + '</w>',) 91 | pairs = get_pairs(word) 92 | 93 | if not pairs: 94 | return token+'</w>' 95 | 96 | while True: 97 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 98 | if bigram not in self.bpe_ranks: 99 | break 100 | first, second = bigram 101 | new_word = [] 102 | i = 0 103 | while i < len(word): 104 | try: 105 | j = word.index(first, i) 106 | new_word.extend(word[i:j]) 107 | i = j 108 | except: 109 | new_word.extend(word[i:]) 110 | break 111 | 112 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 113 | new_word.append(first+second) 114 | i += 2 115 | else: 116 | new_word.append(word[i]) 117 | i += 1 118 | new_word = tuple(new_word) 119 | word = new_word 120 | if len(word) == 1: 121 | break 122 | else: 123 | pairs = get_pairs(word) 124 | word = ' '.join(word) 125 | self.cache[token] = word 126 | return word 127 | 128 | def enable_padding(self, *args, **kwargs): 129 | self.context_length = kwargs['length'] 130 | 131 | def enable_truncation(self, *args, **kwargs): 132 | pass # just for matching the template of huggingface tokenizers 133 | 134 | def encode(self, text): 135 | sot_token = self.encoder["<|startoftext|>"] 136 | eot_token = self.encoder["<|endoftext|>"] 137 | tokens = self._encode(text) 138 | 139 | # Some text exceeds maximum context_length. 
Random crop to fit under context_length 140 | start = 0 141 | end = len(tokens) - (self.context_length - 2) 142 | if end > 0: 143 | start = random.randint(0, end) 144 | 145 | tokens = [sot_token] + tokens[start:start+self.context_length - 2] + [eot_token] 146 | result = torch.zeros(self.context_length, dtype=torch.long) 147 | 148 | result[:len(tokens)] = torch.tensor(tokens) 149 | 150 | output = TOKENIZER_OUTPUT(tokens=None, ids=result) 151 | 152 | return output 153 | 154 | def _encode(self, text): 155 | bpe_tokens = [] 156 | text = whitespace_clean(basic_clean(text)).lower() 157 | for token in re.findall(self.pat, text): 158 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 159 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 160 | 161 | return bpe_tokens 162 | 163 | def decode(self, tokens): 164 | text = ''.join([self.decoder[token] for token in tokens]) 165 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 166 | return text 167 | 168 | def decode_batch(self, tokenss): 169 | return [self.decode(tokens) for tokens in tokenss] 170 | -------------------------------------------------------------------------------- /rqvae/txtimg_datasets/tokenizers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
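# Path helpers for the tokenizer assets bundled under `pretrained/`. Every function is
# memoized with functools.lru_cache, so repeated calls return a cached absolute path;
# for example, default_bpe() resolves to pretrained/bpe_simple_vocab_16e6.txt.gz, which
# SimpleTokenizer loads whenever no explicit bpe_path is supplied.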
14 | 15 | import os 16 | from functools import lru_cache 17 | 18 | 19 | @lru_cache() 20 | def default_bpe(): 21 | # used in the original CLIP implementation 22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), 23 | "pretrained", 24 | "bpe_simple_vocab_16e6.txt.gz") 25 | 26 | 27 | @lru_cache() 28 | def bert_vocab(): 29 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), 30 | "pretrained", 31 | "bert-base-uncased-vocab.txt") 32 | 33 | 34 | @lru_cache() 35 | def gpt2_vocab(): 36 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), 37 | "pretrained", 38 | "vocab.json") 39 | 40 | 41 | @lru_cache() 42 | def gpt2_merges(): 43 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), 44 | "pretrained", 45 | "merges.txt") 46 | 47 | 48 | @lru_cache() 49 | def huggingface_bpe_16k_vocab(): 50 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), 51 | "pretrained", 52 | "bpe-16k-vocab.json") 53 | 54 | 55 | @lru_cache() 56 | def huggingface_bpe_16k_merges(): 57 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), 58 | "pretrained", 59 | "bpe-16k-merges.txt") 60 | 61 | 62 | @lru_cache() 63 | def huggingface_bpe_30k_vocab(): 64 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), 65 | "pretrained", 66 | "bpe-30k-vocab.json") 67 | 68 | 69 | @lru_cache() 70 | def huggingface_bpe_30k_merges(): 71 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), 72 | "pretrained", 73 | "bpe-30k-merges.txt") 74 | -------------------------------------------------------------------------------- /rqvae/txtimg_datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
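# Usage sketch (illustrative comment, not part of the original module): given a dataset
# config exposing `transforms` and `image_resolution`, the create_transforms factory at
# the bottom of this file assembles a torchvision pipeline, roughly as
#
#     from PIL import Image
#     from rqvae.txtimg_datasets.transforms import create_transforms
#
#     transform = create_transforms(config, split='train')         # e.g. config.transforms == 'dalle-vqvae'
#     tensor = transform(Image.open('sample.jpg').convert('RGB'))  # 'sample.jpg' is a placeholder path
#
# With the 'dalle-vqvae' option the result is a 3 x image_resolution x image_resolution
# tensor normalized to [-1, 1]; the option and file name above are assumptions chosen
# for illustration only.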
14 | 15 | import torch 16 | import torch.nn as nn 17 | import torchvision.transforms as transforms 18 | import torchvision.transforms.functional as F 19 | 20 | try: 21 | from torchvision.transforms.functional import get_image_size 22 | except ImportError: 23 | from torchvision.transforms.functional import _get_image_size as get_image_size 24 | 25 | 26 | class AugmentationDALLE(nn.Module): 27 | def __init__(self, size): 28 | super().__init__() 29 | 30 | self.size = size 31 | 32 | def forward(self, img): 33 | w, h = get_image_size(img) 34 | s_min = min(w, h) 35 | 36 | off_h = torch.randint(low=3 * (h - s_min) // 8, 37 | high=max(3 * (h - s_min) // 8 + 1, 5 * (h - s_min) // 8), 38 | size=(1,)).item() 39 | off_w = torch.randint(low=3 * (w - s_min) // 8, 40 | high=max(3 * (w - s_min) // 8 + 1, 5 * (w - s_min) // 8), 41 | size=(1,)).item() 42 | 43 | img = F.crop(img, top=off_h, left=off_w, height=s_min, width=s_min) 44 | 45 | t_max = max(min(s_min, round(9 / 8 * self.size)), self.size) 46 | t = torch.randint(low=self.size, high=t_max + 1, size=(1,)).item() 47 | img = F.resize(img, [t, t]) 48 | return img 49 | 50 | 51 | class Rescale(nn.Module): 52 | def __init__(self): 53 | super().__init__() 54 | 55 | def forward(self, img): 56 | return (1 - 2 * 0.1) * img + 0.1 57 | 58 | 59 | def create_transforms(config, split='train', is_eval=False): 60 | if config.transforms == 'dalle': 61 | if split == 'train' and not is_eval: 62 | transforms_ = [ 63 | AugmentationDALLE(size=config.image_resolution), 64 | transforms.RandomCrop(size=(config.image_resolution, config.image_resolution)), 65 | transforms.ToTensor(), 66 | Rescale() 67 | ] 68 | else: 69 | transforms_ = [ 70 | transforms.Resize(size=(config.image_resolution, config.image_resolution)), 71 | transforms.ToTensor(), 72 | Rescale() 73 | ] 74 | elif config.transforms == 'dalle-vqvae': 75 | if split == 'train' and not is_eval: 76 | transforms_ = [ 77 | AugmentationDALLE(size=config.image_resolution), 78 | transforms.RandomCrop(size=(config.image_resolution, config.image_resolution)), 79 | transforms.ToTensor(), 80 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 81 | ] 82 | else: 83 | transforms_ = [ 84 | transforms.Resize(size=(config.image_resolution, config.image_resolution)), 85 | transforms.ToTensor(), 86 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 87 | ] 88 | elif config.transforms == 'clip': 89 | if split == 'train' and not is_eval: 90 | transforms_ = [ 91 | transforms.Resize(size=(config.image_resolution, config.image_resolution)), 92 | transforms.RandomResizedCrop(size=config.image_resolution, scale=(0.8, 1.0)), 93 | transforms.ToTensor(), 94 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 95 | ] 96 | else: 97 | transforms_ = [ 98 | transforms.Resize(size=(config.image_resolution, config.image_resolution)), 99 | transforms.ToTensor(), 100 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 101 | ] 102 | elif config.transforms == 'clip-dvae': 103 | if split == 'train' and not is_eval: 104 | transforms_ = [ 105 | transforms.Resize(size=(config.image_resolution, config.image_resolution)), 106 | transforms.RandomResizedCrop(size=config.image_resolution, scale=(0.8, 1.0)), 107 | transforms.ToTensor(), 108 | Rescale() 109 | ] 110 | else: 111 | transforms_ = [ 112 | transforms.Resize(size=(config.image_resolution, config.image_resolution)), 113 | transforms.ToTensor(), 114 | Rescale() 115 | ] 116 | elif config.transforms == 'none': 117 | transforms_ = [] 118 | else: 119 | raise NotImplementedError('%s not 
implemented..' % config.transforms) 120 | 121 | transforms_ = transforms.Compose(transforms_) 122 | 123 | return transforms_ 124 | -------------------------------------------------------------------------------- /rqvae/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/rq-vae-transformer/341395e562ac347f5eb62db9f5f08b9f2cc42a60/rqvae/utils/__init__.py -------------------------------------------------------------------------------- /rqvae/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from omegaconf import OmegaConf, DictConfig 4 | from easydict import EasyDict as edict 5 | import yaml 6 | 7 | from rqvae.models.rqtransformer.configs import RQTransformerConfig 8 | 9 | 10 | def easydict_to_dict(obj): 11 | if not isinstance(obj, edict): 12 | return obj 13 | else: 14 | return {k: easydict_to_dict(v) for k, v in obj.items()} 15 | 16 | 17 | def load_config(config_path): 18 | with open(config_path) as f: 19 | config = yaml.load(f, Loader=yaml.FullLoader) 20 | config = easydict_to_dict(config) 21 | config = OmegaConf.create(config) 22 | return config 23 | 24 | 25 | def is_stage1_arch(arch_type): 26 | return not ('transformer' in arch_type) 27 | 28 | 29 | def augment_arch_defaults(arch_config): 30 | 31 | if arch_config.type == 'rq-vae': 32 | arch_defaults = OmegaConf.create( 33 | { 34 | 'ema': None, 35 | 'hparams': { 36 | 'loss_type': 'l1', 37 | 'restart_unused_codes': False, 38 | 'use_padding_idx': False, 39 | 'masked_dropout': 0.0, 40 | }, 41 | 'checkpointing': False, 42 | } 43 | ) 44 | elif arch_config.type == 'rq-transformer': 45 | arch_defaults = RQTransformerConfig.create(arch_config) 46 | else: 47 | raise NotImplementedError 48 | 49 | return OmegaConf.merge(arch_defaults, arch_config) 50 | 51 | 52 | def augment_optimizer_defaults(optim_config): 53 | 54 | defaults = OmegaConf.create( 55 | { 56 | 'type': 'adamW', 57 | 'max_gn': None, 58 | 'warmup': { 59 | 'mode': 'linear', 60 | 'start_from_zero': (True if optim_config.warmup.epoch > 0 else False), 61 | }, 62 | } 63 | ) 64 | return OmegaConf.merge(defaults, optim_config) 65 | 66 | 67 | def augment_defaults(config): 68 | 69 | defaults = OmegaConf.create( 70 | { 71 | 'arch': augment_arch_defaults(config.arch), 72 | 'dataset': { 73 | 'transform': {'type': None}, 74 | }, 75 | 'optimizer': augment_optimizer_defaults(config.optimizer), 76 | 'experiment': { 77 | 'test_freq': 10, 78 | 'amp': False, 79 | }, 80 | } 81 | ) 82 | 83 | if 'gan' in config: 84 | gan_defaults = OmegaConf.merge(defaults.optimizer, config.gan.disc.get('optimizer', {})) 85 | defaults.gan = OmegaConf.create( 86 | { 87 | 'disc': {'optimizer': gan_defaults}, 88 | } 89 | ) 90 | 91 | if not is_stage1_arch(config.arch.type): 92 | 93 | model_aux_path = config.vqvae.ckpt 94 | model_aux_config_path = os.path.join(os.path.dirname(model_aux_path), 'config.yaml') 95 | stage1_arch_config = load_config(model_aux_config_path).arch 96 | 97 | config.vqvae = stage1_arch_config 98 | config.vqvae.ckpt = model_aux_path 99 | 100 | defaults.vqvae = augment_arch_defaults(config.vqvae) 101 | defaults.arch.vocab_size = config.dataset.vocab_size 102 | defaults.experiment.sample = {'top_k': None, 'top_p': None} 103 | 104 | if config.get('loss', {}).get('type', '') == 'soft_target_cross_entropy': 105 | defaults.loss = {'temp': 1.0, 'stochastic_codes': False} 106 | else: 107 | defaults.loss = {'type': 'cross_entropy', 'temp': 1.0, 
'stochastic_codes': False} 108 | 109 | config = OmegaConf.merge(defaults, config) 110 | 111 | return config 112 | 113 | 114 | def augment_dist_defaults(config, distenv): 115 | config = config.copy() 116 | 117 | local_batch_size = config.experiment.batch_size 118 | world_batch_size = distenv.world_size * local_batch_size 119 | total_batch_size = config.experiment.get('total_batch_size', world_batch_size) 120 | 121 | if total_batch_size % world_batch_size != 0: 122 | raise ValueError('total batch size must be divisible by world batch size') 123 | else: 124 | grad_accm_steps = total_batch_size // world_batch_size 125 | 126 | config.optimizer.grad_accm_steps = grad_accm_steps 127 | config.experiment.total_batch_size = total_batch_size 128 | 129 | return config 130 | 131 | 132 | def config_setup(args, distenv, config_path, extra_args=()): 133 | 134 | if args.eval: 135 | config = load_config(config_path) 136 | config = augment_defaults(config) 137 | 138 | if hasattr(args, 'test_batch_size'): 139 | config.experiment.batch_size = args.test_batch_size 140 | if not hasattr(config, 'seed'): 141 | config.seed = args.seed 142 | 143 | elif args.resume: 144 | config = load_config(config_path) 145 | if distenv.world_size != config.runtime.distenv.world_size: 146 | raise ValueError("world_size not identical to the resuming config") 147 | config.runtime = {'args': vars(args), 'distenv': distenv} 148 | 149 | else: # training 150 | config_path = args.model_config 151 | config = load_config(config_path) 152 | 153 | extra_config = OmegaConf.from_dotlist(extra_args) 154 | config = OmegaConf.merge(config, extra_config) 155 | 156 | config = augment_defaults(config) 157 | config = augment_dist_defaults(config, distenv) 158 | 159 | config.seed = args.seed 160 | config.runtime = {'args': vars(args), 'extra_config': extra_config, 'distenv': distenv} 161 | 162 | return config 163 | -------------------------------------------------------------------------------- /rqvae/utils/dist.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import datetime 3 | import os 4 | import collections 5 | import torch 6 | import torch.distributed as dist 7 | 8 | from torch.nn.parallel import DistributedDataParallel 9 | 10 | 11 | def update_argument_parser(parser): 12 | parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend') 13 | parser.add_argument( 14 | '--local_rank', default=-1, type=int, 15 | help='Used for multi-process training. 
Can either be manually set ' + 16 | 'or automatically set by using \'python -m torch.distributed.launch\'.') 17 | return parser 18 | 19 | 20 | @dataclass 21 | class DistEnv: 22 | world_size: int 23 | world_rank: int 24 | local_rank: int 25 | num_gpus: int 26 | master: bool 27 | device_name: str 28 | 29 | 30 | def initialize(args, logger=None): 31 | 32 | args.rank = int(os.environ.get("RANK", 0)) 33 | args.world_size = int(os.environ.get('WORLD_SIZE', 1)) 34 | args.local_rank = int(os.environ.get('LOCAL_RANK', 0)) 35 | 36 | if args.world_size > 1: 37 | 38 | os.environ["RANK"] = str(args.rank) 39 | os.environ["WORLD_SIZE"] = str(args.world_size) 40 | os.environ["LOCAL_RANK"] = str(args.local_rank) 41 | 42 | print(f'[dist] Distributed: wait dist process group:{args.local_rank}') 43 | dist.init_process_group(backend=args.dist_backend, init_method='env://', 44 | world_size=args.world_size, 45 | timeout=datetime.timedelta(0, args.timeout)) 46 | assert (args.world_size == dist.get_world_size()) 47 | print( 48 | f"""[dist] Distributed: success device:{args.local_rank}, """, 49 | f"""{dist.get_rank()}/{dist.get_world_size()}""" 50 | ) 51 | distenv = DistEnv(world_size=dist.get_world_size(), 52 | world_rank=dist.get_rank(), 53 | local_rank=args.local_rank, 54 | num_gpus=1, 55 | master=(dist.get_rank() == 0), 56 | device_name=torch.cuda.get_device_name(), 57 | ) 58 | else: 59 | print('[dist] Single processed') 60 | distenv = DistEnv(1, 0, 0, torch.cuda.device_count(), True, torch.cuda.get_device_name()) 61 | 62 | print(f'[dist] {distenv}') 63 | 64 | if logger is not None: 65 | logger.info(distenv) 66 | 67 | return distenv 68 | 69 | 70 | def dataparallel_and_sync(distenv, model, find_unused_parameters=False): 71 | 72 | if dist.is_initialized(): 73 | model = DistributedDataParallel( 74 | model, device_ids=[distenv.local_rank], output_device=distenv.local_rank, 75 | find_unused_parameters=find_unused_parameters 76 | ) 77 | for _, param in model.state_dict().items(): 78 | dist.broadcast(param, 0) 79 | 80 | dist.barrier() 81 | else: 82 | model = torch.nn.DataParallel(model) 83 | torch.cuda.synchronize() 84 | 85 | return model 86 | 87 | 88 | def param_sync(param): 89 | dist.broadcast(param, 0) 90 | dist.barrier() 91 | torch.cuda.synchronize() 92 | 93 | 94 | @torch.no_grad() 95 | def all_gather_cat(distenv, tensor, dim=0): 96 | if distenv.world_size == 1: 97 | return tensor 98 | 99 | g_tensor = [torch.ones_like(tensor) for _ in range(distenv.world_size)] 100 | dist.all_gather(g_tensor, tensor) 101 | g_tensor = torch.cat(g_tensor, dim=dim) 102 | 103 | return g_tensor 104 | -------------------------------------------------------------------------------- /rqvae/utils/profiler.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class Profiler: 4 | opts_model_size = {'trainable-only', 'transformer-block-only'} 5 | 6 | def __init__(self, logger): 7 | self._logger = logger 8 | 9 | def get_model_size(self, model, opt=None): 10 | if opt is None: 11 | self._logger.info( 12 | "[OPTION: ALL] #parameters: %.4fM", sum(p.numel() for p in model.parameters()) / 1e6 13 | ) 14 | else: 15 | assert opt in self.opts_model_size, f'{opt} is not in {self.opts_model_size}' 16 | 17 | if opt == 'trainable-only': 18 | self._logger.info( 19 | "[OPTION: %s] #parameters: %.4fM", opt, 20 | sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 21 | ) 22 | else: 23 | if hasattr(model, 'blocks'): 24 | self._logger.info( 25 | "[OPTION: %s] #parameters: %.4fM", opt, 26 | sum(p.numel() 
for p in model.blocks.parameters()) / 1e6 27 | ) 28 | -------------------------------------------------------------------------------- /rqvae/utils/setup.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import logging 3 | import inspect 4 | import os 5 | import shutil 6 | from pathlib import Path 7 | 8 | from omegaconf import OmegaConf 9 | import torch 10 | 11 | from .writer import Writer 12 | from .config import config_setup 13 | from .dist import initialize as dist_init 14 | 15 | 16 | def logger_setup(log_path, eval=False): 17 | 18 | log_fname = os.path.join(log_path, 'val.log' if eval else 'train.log') 19 | 20 | for hdlr in logging.root.handlers: 21 | logging.root.removeHandler(hdlr) 22 | 23 | logging.basicConfig( 24 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 25 | datefmt="%m/%d/%Y %H:%M:%S", 26 | level=logging.INFO, 27 | handlers=[ 28 | logging.FileHandler(log_fname), logging.StreamHandler() 29 | ], 30 | ) 31 | main_filename, *_ = inspect.getframeinfo(inspect.currentframe().f_back.f_back) 32 | 33 | logger = logging.getLogger(Path(main_filename).name) 34 | writer = Writer(log_path) 35 | 36 | return logger, writer 37 | 38 | 39 | def setup(args, extra_args=()): 40 | """ 41 | meaning of args.result_path: 42 | - if args.eval, directory where the model is 43 | - if args.resume, no meaning 44 | - otherwise, path to store the logs 45 | 46 | Returns: 47 | config, logger, writer 48 | """ 49 | 50 | distenv = dist_init(args) 51 | 52 | args.result_path = Path(args.result_path).absolute().as_posix() 53 | args.model_config = Path(args.model_config).absolute().resolve().as_posix() 54 | 55 | now = datetime.now().strftime('%d%m%Y_%H%M%S') 56 | 57 | if args.eval: 58 | config_path = Path(args.result_path).joinpath('config.yaml') 59 | log_path = Path(args.result_path).joinpath('val', now) 60 | 61 | elif args.resume: 62 | load_path = Path(args.load_path) 63 | if not load_path.is_file(): 64 | raise ValueError("load_path must be a valid filename") 65 | 66 | config_path = load_path.parent.joinpath('config.yaml').absolute() 67 | log_path = load_path.parent.parent.joinpath(now) 68 | 69 | else: 70 | config_path = Path(args.model_config).absolute() 71 | task_name = config_path.stem 72 | if args.postfix: 73 | task_name += f'__{args.postfix}' 74 | log_path = Path(args.result_path).joinpath(task_name, now) 75 | 76 | config = config_setup(args, distenv, config_path, extra_args=extra_args) 77 | config.result_path = log_path.absolute().resolve().as_posix() 78 | 79 | if distenv.master: 80 | if not log_path.exists(): 81 | os.makedirs(log_path) 82 | logger, writer = logger_setup(log_path) 83 | logger.info(distenv) 84 | logger.info(f'log_path: {log_path}') 85 | logger.info('\n' + OmegaConf.to_yaml(config)) 86 | OmegaConf.save(config, log_path.joinpath('config.yaml')) 87 | 88 | src_dir = Path(os.getcwd()).joinpath('rqvae') 89 | shutil.copytree(src_dir, log_path.joinpath('rqvae')) 90 | logger.info(f'source copied to {log_path}/rqvae') 91 | else: 92 | logger, writer, log_path = None, None, None 93 | 94 | return config, logger, writer 95 | -------------------------------------------------------------------------------- /rqvae/utils/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import pickle 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from tqdm import tqdm 8 | from torch.nn import functional as F 9 | 10 | 11 | def save_pickle(fname, data): 12 | with open(fname, 
'wb') as fp: 13 | pickle.dump(data, fp, pickle.HIGHEST_PROTOCOL) 14 | 15 | 16 | def compute_p_norm(model): 17 | norm = 0 18 | for k, v in model.state_dict().items(): 19 | v = v.detach().clone() 20 | norm += torch.sum(v.view(-1).pow_(2)) 21 | return norm 22 | 23 | 24 | def get_num_conv_linear_layers(model): 25 | cnt = 0 26 | weight_modules = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.ConvTranspose2d) 27 | for mn, m in model.named_modules(): 28 | for pn, p in m.named_parameters(): 29 | if pn.endswith('weight') and isinstance(m, weight_modules): 30 | cnt += 1 31 | return cnt 32 | 33 | 34 | def compute_model_size(model, logger): 35 | if logger is not None: 36 | logger.info( 37 | "#parameters: %.4fM", sum(p.numel() for p in model.parameters()) / 1000 / 1000 38 | ) 39 | 40 | 41 | def set_seed(seed=None): 42 | if seed is None: 43 | seed = random.getrandbits(32) 44 | random.seed(seed) 45 | np.random.seed(seed) 46 | torch.manual_seed(seed) 47 | torch.cuda.manual_seed_all(seed) 48 | return seed 49 | 50 | 51 | def np2tn(array): 52 | if len(array.shape) == 4: 53 | return torch.from_numpy(np.transpose(array, (3, 2, 0, 1))) 54 | elif len(array.shape) == 2: 55 | return torch.from_numpy(array.T) 56 | else: 57 | raise ValueError('invalid shape') 58 | 59 | 60 | def top_k_logits(logits, k): 61 | v, ix = torch.topk(logits, k) 62 | out = logits.clone() 63 | out[out < v[:, [-1]]] = -float('Inf') 64 | return out 65 | 66 | 67 | def top_p_probs(probs, p): 68 | sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) 69 | cum_probs = torch.cumsum(sorted_probs, dim=-1) 70 | 71 | sorted_idx_remove_cond = cum_probs >= p 72 | 73 | sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone() 74 | sorted_idx_remove_cond[..., 0] = 0 75 | 76 | indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond) 77 | probs = probs.masked_fill(indices_to_remove, 0.0) 78 | norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True) 79 | return norm_probs 80 | 81 | 82 | def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None): 83 | """Take a 2-dim tensor, apply softmax along each row, and sample from 84 | each multinomial distribution defined by the rows. 85 | 86 | Args: 87 | logits: 2-dim tensor of shape (n_samples, logit_dim) 88 | temperature (float): softmax temperature 89 | top_k (Optional[int]): if given, sample only using `top_k` logits 90 | top_p (Optional[float]): if given, sample only using `top_p` logits 91 | 92 | Returns: 93 | samples: 1-dim integer tensor of shape (n_samples,) 94 | """ 95 | 96 | logits = logits.to(dtype=torch.float32) 97 | logits = logits / temperature 98 | 99 | # optionally crop probabilities to only the top k options 100 | if top_k is not None: 101 | logits = top_k_logits(logits, top_k) 102 | 103 | if torch.sum(torch.isnan(logits)): 104 | print('WARNING... 
NaN observed') 105 | logits[torch.isnan(logits)] = -float('Inf') 106 | 107 | # apply softmax to convert to probabilities 108 | probs = F.softmax(logits, dim=-1) 109 | 110 | if top_p is not None: 111 | probs = top_p_probs(probs, top_p) 112 | 113 | try: 114 | samples = torch.multinomial(probs, num_samples=1) 115 | except RuntimeError: 116 | print(probs) 117 | print(logits) 118 | print('isinf, ', torch.sum(torch.isinf(probs))) 119 | print('isnan, ', torch.sum(torch.isnan(probs))) 120 | print('is negative', torch.sum(probs < 0)) 121 | raise 122 | 123 | return samples.view(-1) 124 | -------------------------------------------------------------------------------- /rqvae/utils/writer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch.utils.tensorboard import SummaryWriter 4 | 5 | 6 | class Writer: 7 | def __init__(self, result_path): 8 | self.result_path = result_path 9 | 10 | self.writer_trn = SummaryWriter(os.path.join(result_path, 'train')) 11 | self.writer_val = SummaryWriter(os.path.join(result_path, 'valid')) 12 | self.writer_val_ema = SummaryWriter(os.path.join(result_path, 'valid_ema')) 13 | 14 | def _get_writer(self, mode): 15 | if mode == 'train': 16 | writer = self.writer_trn 17 | elif mode == 'valid': 18 | writer = self.writer_val 19 | elif mode == 'valid_ema': 20 | writer = self.writer_val_ema 21 | else: 22 | raise ValueError(f'{mode} is not valid..') 23 | 24 | return writer 25 | 26 | def add_scalar(self, tag, scalar, mode, epoch=0): 27 | writer = self._get_writer(mode) 28 | writer.add_scalar(tag, scalar, epoch) 29 | 30 | def add_image(self, tag, image, mode, epoch=0): 31 | writer = self._get_writer(mode) 32 | writer.add_image(tag, image, epoch) 33 | 34 | def add_text(self, tag, text, mode, epoch=0): 35 | writer = self._get_writer(mode) 36 | writer.add_text(tag, text, epoch) 37 | 38 | def close(self): 39 | self.writer_trn.close() 40 | self.writer_val.close() 41 | self.writer_val_ema.close() 42 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E261, E226 4 | --------------------------------------------------------------------------------
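# Illustrative usage sketch for the sampling helpers defined in rqvae/utils/utils.py
# (comment-only sketch, not a file of the repository; the tensor shape and the filtering
# values below are assumptions):
#
#     import torch
#     from rqvae.utils.utils import set_seed, sample_from_logits
#
#     set_seed(0)
#     logits = torch.randn(4, 16384)  # (n_samples, vocab_size), dummy logits
#     codes = sample_from_logits(logits, temperature=1.0, top_k=1024, top_p=0.95)
#     print(codes.shape)              # torch.Size([4])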