├── .gitignore ├── LICENSE ├── NLVR.py ├── Pretrain.py ├── README.md ├── VE.py ├── VQA.py ├── accelerators ├── __init__.py ├── accelerator.py └── apex_ddp_accelerator.py ├── configs ├── GLUE.yaml ├── NLVR.yaml ├── Pretrain.yaml ├── VE.yaml ├── VQA.yaml ├── config_bert.json ├── gen_coco.yaml ├── image_ft.yaml ├── image_sampling.yaml └── linear_probe.yaml ├── dataset ├── __init__.py ├── caption_dataset.py ├── dalle_transforms.py ├── dist_dataset.py ├── gen_dataset.py ├── nlvr_dataset.py ├── randaugment.py ├── utils.py ├── ve_dataset.py └── vqa_dataset.py ├── eval_coco.py ├── gen_coco.py ├── glue.py ├── image_finetune.py ├── image_linprobe.py ├── image_sampling.py ├── img.png ├── models ├── DALLE-pytorch │ ├── .DS_Store │ ├── .github │ │ └── workflows │ │ │ └── python-publish.yml │ ├── .gitignore │ ├── LICENSE │ ├── MANIFEST.in │ ├── README.md │ ├── docker │ │ └── Dockerfile │ ├── examples │ │ └── rainbow_dalle.ipynb │ ├── generate.py │ ├── install_apex.sh │ ├── install_deepspeed.sh │ ├── setup.py │ ├── train_dalle.py │ └── train_vae.py ├── __init__.py ├── bert.py ├── dall_e │ ├── __init__.py │ ├── decoder.py │ ├── encoder.py │ └── utils.py ├── dalle_pytorch │ ├── __init__.py │ ├── attention.py │ ├── dalle_pytorch.py │ ├── distributed_backends │ │ ├── __init__.py │ │ ├── deepspeed_backend.py │ │ ├── distributed_backend.py │ │ ├── dummy_backend.py │ │ └── horovod_backend.py │ ├── distributed_utils.py │ ├── loader.py │ ├── reversible.py │ ├── tokenizer.py │ ├── transformer.py │ └── vae.py ├── dalle_utils.py ├── davinci_pretrain.py ├── model_glue.py ├── model_image_sampling.py ├── model_imageft.py ├── model_linearprobe.py ├── model_nlvr.py ├── model_ve.py ├── model_vqa.py ├── modeling_discrete_vae.py ├── resnet.py ├── tokenization_bert.py ├── vit.py └── xbert.py ├── optim ├── __init__.py ├── adafactor.py ├── adahessian.py ├── adamp.py ├── adamw.py ├── lars.py ├── lookahead.py ├── nadam.py ├── novograd.py ├── nvnovograd.py ├── optim_factory.py ├── radam.py ├── rmsprop_tf.py └── sgdp.py ├── output └── pretrain │ └── REAMDE.md ├── requirements.txt ├── scheduler ├── __init__.py ├── cosine_lr.py ├── plateau_lr.py ├── scheduler.py ├── scheduler_factory.py ├── step_lr.py └── tanh_lr.py ├── taming ├── __init__.py ├── lr_scheduler.py ├── main.py ├── models │ ├── __init__.py │ ├── cond_transformer.py │ └── vqgan.py ├── modules │ ├── .DS_Store │ ├── __init__.py │ ├── autoencoder │ │ └── lpips │ │ │ └── vgg.pth │ ├── diffusionmodules │ │ ├── __init__.py │ │ └── model.py │ ├── discriminator │ │ ├── __init__.py │ │ └── model.py │ ├── losses │ │ ├── __init__.py │ │ ├── lpips.py │ │ ├── segmentation.py │ │ └── vqperceptual.py │ ├── misc │ │ ├── __init__.py │ │ └── coord.py │ ├── transformer │ │ ├── __init__.py │ │ ├── mingpt.py │ │ └── permuter.py │ ├── util.py │ └── vqvae │ │ ├── __init__.py │ │ └── quantize.py └── util.py ├── util ├── checkpointer.py ├── hdfs_io.py └── torch_io.py ├── utils.py └── vqaTools ├── __init__.py ├── answer_list.json ├── vqa.py └── vqaEval.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | data 3 | images 4 | results 5 | *.pyc 6 | __pycache__ 7 | apex 8 | coco.tar 9 | visualgenome.tar -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023, ByteDance Inc. 2 | All rights reserved. 
3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | * Neither the name of bytedance.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /accelerators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shizhediao/DaVinci/283ea6f2977f61b113aee042ad426244d3f20f69/accelerators/__init__.py -------------------------------------------------------------------------------- /accelerators/accelerator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Created on Feb-19-21 16:36 4 | accelerator.py 5 | Description: Base class for accelerators, making it easy to plug in other acceleration schemes later. 6 | ''' 7 | 8 | from logging import Logger 9 | 10 | import torch 11 | from torch.optim import Optimizer 12 | 13 | Net = torch.nn.Module 14 | 15 | 16 | class Accelerator: 17 | """ 18 | Accelerator is the base class for all accelerators; newly added accelerators should inherit from this class. 19 | """ 20 | 21 | def __init__(self, cfg, logger) -> None: 22 | self.cfg = cfg 23 | self.logger = logger 24 | 25 | def set_up(self, model: Net): 26 | raise NotImplementedError("Set Up method not implemented in Accelerator, please check! ") 27 | 28 | def broadcast(self): 29 | raise NotImplementedError("Broadcast method not implemented in Accelerator, please check! ")
30 | 31 | def backward_step(self, loss: torch.Tensor): 32 | loss.backward() 33 | 34 | def optimizer_step(self, optimizer: Optimizer, model: Net, grad_norm: float) -> float: 35 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 36 | grad_norm) 37 | return float(total_norm) 38 | 39 | -------------------------------------------------------------------------------- /accelerators/apex_ddp_accelerator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Created on Nov-18-20 15:21 4 | ddp_accelerator.py 5 | @author: liuzhen.nlp 6 | Description: 7 | ''' 8 | 9 | import os 10 | import random 11 | import sys 12 | from typing import Tuple, Union, Optional, Any 13 | import numpy as np 14 | 15 | import torch 16 | import torch.distributed as distributed 17 | from torch.optim import Optimizer 18 | from torch.optim.lr_scheduler import LambdaLR 19 | 20 | Net = torch.nn.Module 21 | 22 | from .accelerator import Accelerator 23 | 24 | try: 25 | from apex import amp 26 | from apex.parallel import DistributedDataParallel as Apex_DDP 27 | from apex.parallel import convert_syncbn_model 28 | except ImportError: 29 | print('no apex! Please install from https://www.github.com/nvidia/apex') 30 | 31 | 32 | class ApexDDPAccelerator(Accelerator): 33 | """ 34 | ApexDDPAccelerator uses apex DistributedDataParallel for distributed accelerated training. 35 | """ 36 | 37 | def __init__(self, cfg, logger): 38 | super().__init__(cfg, logger) 39 | self.accelerator_rng_seed = self.cfg.RNG_SEED 40 | self.accelerator_syncbn = self.cfg.SYNCBN 41 | self.accelerator_fp16_opt_level = self.cfg.FP16_OPT_LEVEL 42 | self.accelerator_fp16_loss_scale = self.cfg.FP16_LOSS_SCALE 43 | 44 | def set_up(self, model: Net, optimizer: Optimizer, lr_scheduler: LambdaLR, 45 | local_rank: int, world_size: int, rank: int) -> Tuple[Apex_DDP, Optimizer, LambdaLR]: 46 | """ 47 | Initialize the ApexDDPAccelerator, including process_group and apex_ddp initialization. 48 | """ 49 | torch.backends.cudnn.benchmark = False 50 | random.seed(self.accelerator_rng_seed) 51 | np.random.seed(self.accelerator_rng_seed) 52 | torch.random.manual_seed(self.accelerator_rng_seed) 53 | torch.cuda.manual_seed_all(self.accelerator_rng_seed) 54 | master_address = os.environ.get('MASTER_ADDR', "127.0.0.1") 55 | master_port = int(os.environ.get('MASTER_PORT', 34171)) 56 | 57 | torch.cuda.set_device(local_rank) 58 | model = model.cuda() 59 | if not torch.distributed.is_initialized(): 60 | distributed.init_process_group( 61 | backend='nccl', 62 | init_method='tcp://{}:{}'.format(master_address, master_port), 63 | world_size=world_size, 64 | rank=rank, 65 | group_name='mtorch') 66 | print( 67 | f'ApexDDPAccelerator distributed, size: {world_size}, rank: {rank}, local rank: {local_rank}') 68 | sys.stdout.flush() 69 | 70 | self.broadcast(model) 71 | apex_model, optimizer = self.configure_ddp(model, optimizer) 72 | 73 | if self.accelerator_syncbn: 74 | apex_model = self.configure_sync_batchnorm(apex_model) 75 | return apex_model, optimizer, lr_scheduler 76 | 77 | def broadcast(self, model: Net, src=0) -> None: 78 | """ 79 | Broadcast the parameters of model. 80 | """ 81 | for v in model.state_dict().values(): 82 | distributed.broadcast(v, src) 83 | 84 | def configure_ddp(self, model: Net, optimizer: Optimizer) -> Tuple[Apex_DDP, Optimizer]: 85 | """ 86 | Initialize apex_ddp. 87 | """ 88 | model, optimizer = amp.initialize(model, optimizer, 89 | opt_level=self.accelerator_fp16_opt_level, 90 | keep_batchnorm_fp32=None, # from True to None 91 | loss_scale=self.accelerator_fp16_loss_scale,
92 | max_loss_scale=1024.0, 93 | min_loss_scale=1.0) 94 | 95 | apex_model = Apex_DDP(model, delay_allreduce=True) 96 | self.ddp_model = apex_model 97 | return apex_model, optimizer 98 | 99 | def configure_sync_batchnorm(self, model: Net) -> Net: 100 | """ 101 | Convert ``torch.nn.modules.batchnorm._BatchNorm`` modules in model to :class:`apex.parallel.SyncBatchNorm`. 102 | """ 103 | model = convert_syncbn_model(model) 104 | return model 105 | 106 | def backward_step(self, loss: torch.Tensor, optimizer: Optimizer): 107 | """ 108 | backward step 109 | """ 110 | with amp.scale_loss(loss, optimizer) as scaled_loss: 111 | scaled_loss.backward() 112 | 113 | def optimizer_step(self, optimizer: Optimizer, model: Net, grad_norm: float) -> float: 114 | """ 115 | Gradient clipping 116 | """ 117 | total_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 118 | grad_norm) 119 | return float(total_norm) 120 | -------------------------------------------------------------------------------- /configs/GLUE.yaml: -------------------------------------------------------------------------------- 1 | image_res: 384 2 | second_input_size: 384 3 | vision_width: 1024 4 | embed_dim: 256 5 | batch_size_train: 32 6 | batch_size_test: 16 7 | alpha: 0.4 8 | warm_up: False 9 | 10 | # unused hyperparams 11 | max_length: 25 12 | num_beams: 1 13 | temperature: 1 14 | top_k: 0 15 | top_p: 1 16 | repetition_penalty: 1 17 | length_penalty: 1 18 | early_stopping: false 19 | num_return_sequences: 1 20 | init_encoder: False 21 | init_decoder: False 22 | last_hidden_id_shift: 1 23 | 24 | bert_config: 'configs/config_bert.json' 25 | # dalle 26 | discrete_vae_weight_path: "vqgan_ckpt" 27 | discrete_vae_type: "vqgan" 28 | 29 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.02} 30 | schedular: {sched: linear, last_epoch: -1, epochs: 10, warmup_epochs: 1} 31 | 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /configs/NLVR.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['data/nlvr_train.json'] 2 | val_file: ['data/nlvr_dev.json'] 3 | test_file: ['data/nlvr_test.json'] 4 | 5 | image_root: './images/nlvr2/' 6 | 7 | image_res: 384 8 | second_input_size: 384 9 | vision_width: 1024 10 | embed_dim: 256 11 | batch_size_train: 10 12 | 13 | bert_config: 'configs/config_bert.json' 14 | 15 | max_length: 25 16 | num_beams: 1 17 | temperature: 1 18 | top_k: 0 19 | top_p: 1 20 | repetition_penalty: 1 21 | length_penalty: 1 22 | early_stopping: false 23 | num_return_sequences: 1 24 | init_encoder: False 25 | init_decoder: False 26 | last_hidden_id_shift: 1 27 | 28 | # dalle 29 | discrete_vae_weight_path: "vqgan_ckpt" 30 | discrete_vae_type: "vqgan" 31 | 32 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.0} 33 | schedular: {sched: linear, last_epoch: -1, epochs: 9, warmup_epochs: 1} 34 | 35 | 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /configs/Pretrain.yaml: -------------------------------------------------------------------------------- 1 | train_file: ["path/to/image_text_pair/data"] 2 | c4_train_file: ["path/to/c4/data"] 3 | 4 | image_name: "binary" 5 | caption_name: "desc" 6 | train_file_tokenized: false 7 | train_dataset_size: 1335283 8 | checkpoint_frequent: 10000 9 | bert_config: 'configs/config_bert.json' 10 | init_encoder: False 11 | init_decoder: False 12 | 13 | image_res: 256 14 | second_input_size: 256 15 | vision_width: 1024 16 | embed_dim: 256 17 |
batch_size: 64 18 | batch_size_c4: 64 19 | temp: 0.07 20 | dalle_goal: "mask" 21 | prefix_image: "dynamic" 22 | max_prefix_image_epoch: -1 23 | 24 | context_max_length: 25 25 | max_length: 25 26 | enc_max_words: 96 27 | dec_max_words: 256 28 | enc_max_tokens: 256 29 | dec_max_tokens: 256 30 | enc_dec_max_words: 512 31 | loss_pair_alpha: 1 32 | loss_image_generation_alpha: 1 33 | c4_alpha: 1 34 | loss_mim_alpha: 0 35 | 36 | num_beams: 1 37 | temperature: 1 38 | top_k: 0 39 | top_p: 1 40 | repetition_penalty: 1 41 | length_penalty: 1 42 | early_stopping: false 43 | num_return_sequences: 1 44 | eos: '[SEP]' 45 | 46 | # dalle 47 | discrete_vae_weight_path: "vqgan_ckpt" 48 | discrete_vae_type: "vqgan" 49 | 50 | optimizer: {opt: adamW, lr: 2e-4, weight_decay: 0.01} 51 | schedular: {sched: linear, last_epoch: -1, epochs: 40, warmup_epochs: 1} 52 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0} -------------------------------------------------------------------------------- /configs/VE.yaml: -------------------------------------------------------------------------------- 1 | train_file: 'data/ve_train.json' 2 | val_file: 'data/ve_dev.json' 3 | test_file: 'data/ve_test.json' 4 | 5 | image_root: './images/snlive_images' 6 | 7 | image_res: 576 8 | second_input_size: 576 9 | vision_width: 1024 10 | embed_dim: 256 11 | batch_size_train: 8 12 | batch_size_test: 16 13 | alpha: 0.4 14 | warm_up: False 15 | 16 | # useless hyperparams 17 | max_length: 25 18 | num_beams: 1 19 | temperature: 1 20 | top_k: 0 21 | top_p: 1 22 | repetition_penalty: 1 23 | length_penalty: 1 24 | early_stopping: false 25 | num_return_sequences: 1 26 | init_encoder: False 27 | init_decoder: False 28 | last_hidden_id_shift: 1 29 | 30 | bert_config: 'configs/config_bert.json' 31 | # dalle 32 | discrete_vae_weight_path: "vqgan_ckpt" 33 | discrete_vae_type: "vqgan" 34 | 35 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.01} 36 | schedular: {sched: linear, last_epoch: -1, epochs: 3, warmup_epochs: 1} 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /configs/VQA.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['./data/vqa_train.json', 2 | './data/vqa_val.json', 3 | './data/vg_qa.json'] 4 | val_file: ['./data/vqa_val.json'] 5 | test_file: ['./data/vqa_test.json'] 6 | answer_list: './vqaTools/answer_list.json' 7 | 8 | vqa_root: './images/coco/' 9 | vg_root: './images/visualgenome/' 10 | 11 | image_res: 480 12 | second_input_size: 240 13 | vision_width: 1024 14 | embed_dim: 256 15 | temp: 0.07 16 | batch_size_train: 24 17 | batch_size_test: 24 18 | 19 | alpha: 0.4 20 | num_answers: 3129 21 | context_max_length: 25 22 | max_length: 25 23 | num_beams: 1 24 | temperature: 1 25 | top_k: 0 26 | top_p: 1 27 | repetition_penalty: 1 28 | length_penalty: 1 29 | early_stopping: false 30 | num_return_sequences: 1 31 | eos: '[SEP]' 32 | loss_type: 'bce' 33 | k_test: 128 34 | 35 | bert_config: 'configs/config_bert.json' 36 | init_encoder: False 37 | init_decoder: False 38 | 39 | # dalle 40 | discrete_vae_weight_path: "vqgan_ckpt" 41 | discrete_vae_type: "vqgan" 42 | 43 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.02} 44 | schedular: {sched: linear, last_epoch: -1, epochs: 8, warmup_epochs: 4} 45 | 46 | -------------------------------------------------------------------------------- /configs/config_bert.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 257, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 6, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "text_vocab_size": 30522, 20 | "visual_vocab_size": 1024, 21 | "fusion_layer": 6, 22 | "encoder_width": 768 23 | } -------------------------------------------------------------------------------- /configs/gen_coco.yaml: -------------------------------------------------------------------------------- 1 | train_file: './data/coco_train.json' 2 | val_file: './data/coco_test.json' 3 | test_file: './data/coco_test.json' 4 | 5 | bert_config: 'configs/config_bert.json' 6 | init_encoder: False 7 | init_decoder: False 8 | 9 | image_root: './images/coco/' 10 | 11 | image_res: 576 12 | second_input_size: 192 13 | vision_width: 1024 14 | embed_dim: 256 15 | batch_size_train: 24 16 | batch_size_test: 6 17 | temp: 0.07 18 | 19 | max_length: 20 20 | num_beams: 2 21 | temperature: 1 22 | top_k: 0 23 | top_p: 1 24 | repetition_penalty: 1 25 | length_penalty: 1 26 | early_stopping: False 27 | num_return_sequences: 1 28 | eos: '[SEP]' 29 | prompt: 'a picture of ' 30 | label_smoothing: 0.1 31 | 32 | # dalle 33 | discrete_vae_weight_path: "vqgan_ckpt" 34 | discrete_vae_type: "vqgan" 35 | 36 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.02} 37 | schedular: {sched: linear, last_epoch: -1, epochs: 15, warmup_epochs: 2} 38 | -------------------------------------------------------------------------------- /configs/image_ft.yaml: -------------------------------------------------------------------------------- 1 | dataset: imagenet 2 | root_dir: ./ 3 | 4 | image_res: 256 5 | second_input_size: 128 6 | vision_width: 1024 7 | embed_dim: 256 8 | temp: 0.07 9 | batch_size_train: 512 10 | batch_size_test: 512 11 | 12 | alpha: 0.4 13 | num_answers: 3129 14 | context_max_length: 25 15 | max_length: 25 16 | num_beams: 1 17 | temperature: 1 18 | top_k: 0 19 | top_p: 1 20 | repetition_penalty: 1 21 | length_penalty: 1 22 | early_stopping: false 23 | num_return_sequences: 1 24 | eos: '[SEP]' 25 | loss_type: 'kl' # bce or kl 26 | 27 | bert_config: 'configs/config_bert.json' 28 | init_encoder: False 29 | init_decoder: False 30 | 31 | # dalle 32 | discrete_vae_weight_path: "./vqgan_ckpt" 33 | discrete_vae_type: "vqgan" 34 | 35 | lr: 5e-4 36 | epochs: 100 37 | optimizer: "adamw" -------------------------------------------------------------------------------- /configs/image_sampling.yaml: -------------------------------------------------------------------------------- 1 | train_file: [ 2 | "hdfs://haruna/home/byte_ailab_litg/user/zhangxinsong/datasets/vlm/coco_vg_bs64", 3 | ] 4 | val_file: './data/coco_val_image_generation.json' 5 | c4_train_file: ["hdfs://haruna/home/byte_ailab_litg/user/zhangxinsong/datasets/c4/en_split"] 6 | image_root: './images/coco/' 7 | 8 | image_name: "binary" 9 | caption_name: "desc" 10 | train_file_tokenized: false 11 | train_dataset_size: 1335283 12 | checkpoint_frequent: 10000 13 | 14 | bert_config: 'configs/config_bert.json' 15 | init_encoder: False 16 | init_decoder: False 17 | 18 | image_res: 256 19 | second_input_size: 256 20 | vision_width: 
1024 21 | embed_dim: 256 22 | batch_size: 1 23 | temp: 0.07 24 | num_images: 128 25 | image_per_round: 8 26 | 27 | context_max_length: 25 28 | max_length: 257 29 | enc_max_words: 96 30 | dec_max_words: 256 31 | enc_max_tokens: 256 32 | dec_max_tokens: 256 33 | enc_dec_max_words: 512 34 | c4_alpha: 1 35 | 36 | num_beams: 1 37 | temperature: 1.0 38 | top_k: 4000 39 | top_p: 0.9 40 | repetition_penalty: 1 41 | length_penalty: 1 42 | early_stopping: false 43 | num_return_sequences: 1 44 | eos: '[SEP]' 45 | 46 | # dalle 47 | discrete_vae_weight_path: "vqgan_ckpt" 48 | discrete_vae_type: "vqgan" 49 | 50 | optimizer: {opt: adamW, lr: 2e-4, weight_decay: 0.01} 51 | schedular: {sched: linear, last_epoch: -1, epochs: 80, warmup_epochs: 4} 52 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 8, CLIP_GRAD_NORM: 1.0} -------------------------------------------------------------------------------- /configs/linear_probe.yaml: -------------------------------------------------------------------------------- 1 | dataset: imagenet 2 | root_dir: ./ 3 | 4 | image_res: 256 5 | second_input_size: 128 6 | vision_width: 1024 7 | embed_dim: 256 8 | temp: 0.07 9 | batch_size_train: 256 10 | batch_size_test: 256 11 | train_dataset_size: 1281167 12 | 13 | alpha: 0.4 14 | num_answers: 3129 15 | context_max_length: 25 16 | max_length: 25 17 | num_beams: 1 18 | temperature: 1 19 | top_k: 0 20 | top_p: 1 21 | repetition_penalty: 1 22 | length_penalty: 1 23 | early_stopping: false 24 | num_return_sequences: 1 25 | eos: '[SEP]' 26 | loss_type: 'kl' 27 | 28 | bert_config: 'configs/config_bert.json' 29 | init_encoder: False 30 | init_decoder: False 31 | 32 | # dalle 33 | discrete_vae_weight_path: "./vqgan_ckpt" 34 | discrete_vae_type: "vqgan" 35 | 36 | optimizer: {opt: adamW, lr: 6e-4, weight_decay: 0.01} 37 | schedular: {sched: linear, last_epoch: -1, epochs: 100, warmup_epochs: 10} 38 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0} -------------------------------------------------------------------------------- /dataset/dist_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import sys 4 | from typing import List, Any 5 | import warnings 6 | import random 7 | from itertools import cycle 8 | import torch 9 | from torch.utils.data import IterableDataset 10 | 11 | from util.hdfs_io import hopen, hlist_files 12 | 13 | 14 | class DistLineReadingDataset(IterableDataset): # pylint: disable=W0223 15 | """ 16 | iterate a set of folders. 
17 | """ 18 | 19 | def __init__(self, 20 | data_path: str, 21 | rank: int = 0, 22 | world_size: int = 1, 23 | shuffle: bool = False, 24 | repeat: bool = False): 25 | super().__init__() 26 | self.shuffle = shuffle 27 | self.rank = rank 28 | self.world_size = world_size 29 | 30 | self.files = hlist_files(data_path.split(',')) 31 | self.files = [f for f in self.files if f.find('_SUCCESS') < 0] 32 | self.is_hdfs = data_path.startswith('hdfs') 33 | 34 | self.repeat = repeat 35 | print('[DATA]--all dataset containing {} files.'.format(len(self.files))) 36 | if len(self.files) % self.world_size != 0: 37 | print('[DATA]--Whole dataset file num %s cannot split to worldsize %s ' % 38 | (len(self.files), self.world_size)) 39 | sys.stdout.flush() 40 | 41 | def generate(self): 42 | if self.world_size == 1 or len(self.files) == 1: 43 | cur_dataloader_files = self.files 44 | else: 45 | cur_dataloader_files = split_shard( 46 | self.files, self.rank, self.world_size) 47 | 48 | while True: 49 | if self.shuffle: 50 | random.shuffle(cur_dataloader_files) 51 | worker_info = torch.utils.data.get_worker_info() 52 | 53 | if worker_info is not None: 54 | if len(cur_dataloader_files) % worker_info.num_workers != 0: 55 | print('[DATA]--current dataloader %s file num %s cannot split to worker_num %s ' % 56 | (self.rank, len(cur_dataloader_files), worker_info.num_workers)) 57 | cur_worker_files = split_shard( 58 | cur_dataloader_files, worker_info.id, worker_info.num_workers) 59 | if worker_info.id == 0: 60 | print("[DataLoader] --> Rank:{} Workers:[{} ~ {}][{}] Size of process file:{} ...".format( 61 | self.rank, 0, worker_info.num_workers - 1, worker_info.id, len(cur_dataloader_files))) 62 | else: 63 | cur_worker_files = cur_dataloader_files 64 | 65 | if self.shuffle: 66 | random.shuffle(cur_worker_files) 67 | for filepath in cur_worker_files: 68 | if self.is_hdfs: 69 | with hopen(filepath, 'r') as reader: 70 | for line in reader: 71 | yield line.decode() 72 | continue 73 | with open(filepath, 'r') as reader: 74 | for line in reader: 75 | yield line 76 | 77 | if not self.repeat: 78 | break 79 | 80 | def __iter__(self): 81 | return self.generate() 82 | 83 | 84 | def split_shard(data: List[Any], shard_idx: int, shard_size: int): 85 | num = len(data) 86 | if num < shard_size: 87 | raise RuntimeError("num:{} < shard size:{}".format(num, shard_size)) 88 | start_idx = (num * shard_idx) // shard_size 89 | end_idx = (num * (shard_idx + 1)) // shard_size 90 | return data[start_idx: end_idx] 91 | -------------------------------------------------------------------------------- /dataset/gen_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | from dataset.utils import pre_caption 6 | 7 | 8 | class gen_dataset(Dataset): 9 | def __init__(self, ann_file, transform, image_root, split='train', max_words=30, prompt=''): 10 | self.ann = json.load(open(ann_file,'r')) 11 | self.transform = transform 12 | 13 | self.image_root = image_root 14 | self.max_words = max_words 15 | self.split = split # 5 captions per image if not in train set 16 | self.prompt = prompt 17 | 18 | def __len__(self): 19 | return len(self.ann) 20 | 21 | def __getitem__(self, index): 22 | ann = self.ann[index] 23 | 24 | image_path = os.path.join(self.image_root, ann['image']) 25 | image = Image.open(image_path).convert('RGB') 26 | image = self.transform(image) 27 | 28 | if self.split == 'train': 29 | caption = self.prompt + 
pre_caption(ann['caption'], self.max_words) 30 | return image, caption 31 | else: 32 | if "nocaps" in image_path: 33 | fname = ann["id"] 34 | else: 35 | fname = ann['image'] 36 | caption = self.prompt + pre_caption(ann['caption'][0], self.max_words) 37 | return image, caption, fname 38 | -------------------------------------------------------------------------------- /dataset/nlvr_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | from dataset.utils import pre_caption 6 | 7 | 8 | class nlvr_dataset(Dataset): 9 | def __init__(self, ann_file, transform, image_root): 10 | self.ann = [] 11 | for f in ann_file: 12 | self.ann += json.load(open(f,'r')) 13 | self.transform = transform 14 | self.image_root = image_root 15 | self.max_words = 30 16 | 17 | def __len__(self): 18 | return len(self.ann) 19 | 20 | 21 | def __getitem__(self, index): 22 | 23 | ann = self.ann[index] 24 | 25 | image0_path = os.path.join(self.image_root,ann['images'][0]) 26 | image0 = Image.open(image0_path).convert('RGB') 27 | image0 = self.transform(image0) 28 | 29 | image1_path = os.path.join(self.image_root,ann['images'][1]) 30 | image1 = Image.open(image1_path).convert('RGB') 31 | image1 = self.transform(image1) 32 | 33 | sentence = pre_caption(ann['sentence'], self.max_words) 34 | 35 | if ann['label']=='True': 36 | label = 1 37 | else: 38 | label = 0 39 | 40 | return image0, image1, sentence, label
-------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def pre_question(question,max_ques_words): 4 | question = re.sub( 5 | r"([,.'!?\"()*#:;~])", 6 | '', 7 | question.lower(), 8 | ).replace('-', ' ').replace('/', ' ') 9 | question = question.rstrip(' ') 10 | 11 | #truncate question 12 | question_words = question.split(' ') 13 | if len(question_words)>max_ques_words: 14 | question = ' '.join(question_words[:max_ques_words]) 15 | 16 | return question 17 | 18 | 19 | def pre_caption(caption,max_words): 20 | caption = re.sub( 21 | r"([,.'!?\"()*#:;~])", 22 | '', 23 | caption.lower(), 24 | ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person') 25 | 26 | caption = re.sub( 27 | r"\s{2,}", 28 | ' ', 29 | caption, 30 | ) 31 | caption = caption.rstrip('\n') 32 | caption = caption.strip(' ') 33 | 34 | #truncate caption 35 | caption_words = caption.split(' ') 36 | if len(caption_words)>max_words: 37 | caption = ' '.join(caption_words[:max_words]) 38 | 39 | return caption 40 | 41 | 42 | # from vqaTools.vqaEval import VQAEval 43 | # from refTools.evaluation.refEvaluation import RefEvaluation 44 | 45 | import json 46 | import os 47 | import numpy as np 48 | import torch 49 | import torch.distributed as dist 50 | import torch.nn.functional as F 51 | 52 | import utils 53 | from tqdm import tqdm 54 | 55 | 56 | # def vqa_eval(vqa, result_file, test_ques_path): 57 | # vqaRes = vqa.loadRes(result_file, test_ques_path) 58 | # # create vqaEval object by taking vqa and vqaRes 59 | # vqaEval = VQAEval(vqa, vqaRes, n=2) # n is precision of accuracy (number of places after decimal), default is 2 60 | # # evaluate results 61 | # vqaEval.evaluate() 62 | 63 | # # print accuracies 64 | # print("\n") 65 | # print("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall'])) 66 | # print("Per Answer Type Accuracy is the following:") 67 | # for ansType in 
vqaEval.accuracy['perAnswerType']: 68 | # print("%s : %.02f" % (ansType, vqaEval.accuracy['perAnswerType'][ansType])) 69 | # print("\n") 70 | 71 | # return vqaEval 72 | 73 | 74 | 75 | def collect_result(result, result_dir, filename, is_json=True, is_list=True): 76 | if is_json: 77 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank())) 78 | final_result_file = os.path.join(result_dir, '%s.json'%filename) 79 | json.dump(result,open(result_file,'w')) 80 | else: 81 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,utils.get_rank())) 82 | final_result_file = os.path.join(result_dir, '%s.pth'%filename) 83 | torch.save(result,result_file) 84 | 85 | dist.barrier() 86 | 87 | result = None 88 | if utils.is_main_process(): 89 | # combine results from all processes 90 | if is_list: 91 | result = [] 92 | else: 93 | result = {} 94 | for rank in range(utils.get_world_size()): 95 | if is_json: 96 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank)) 97 | res = json.load(open(result_file,'r')) 98 | else: 99 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,rank)) 100 | res = torch.load(result_file) 101 | if is_list: 102 | result += res 103 | else: 104 | result.update(res) 105 | 106 | return result 107 | 108 | 109 | def save_result(result, result_dir, filename, is_json=True, is_list=True): 110 | if is_json: 111 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank())) 112 | final_result_file = os.path.join(result_dir, '%s.json'%filename) 113 | json.dump(result,open(result_file,'w')) 114 | else: 115 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,utils.get_rank())) 116 | final_result_file = os.path.join(result_dir, '%s.pth'%filename) 117 | torch.save(result,result_file) 118 | 119 | dist.barrier() 120 | 121 | if utils.is_main_process(): 122 | # combine results from all processes 123 | if is_list: 124 | result = [] 125 | else: 126 | result = {} 127 | for rank in range(utils.get_world_size()): 128 | if is_json: 129 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank)) 130 | res = json.load(open(result_file,'r')) 131 | else: 132 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,rank)) 133 | res = torch.load(result_file) 134 | if is_list: 135 | result += res 136 | else: 137 | result.update(res) 138 | if is_json: 139 | json.dump(result,open(final_result_file,'w')) 140 | else: 141 | torch.save(result,final_result_file) 142 | 143 | print('result file saved to %s'%final_result_file) 144 | dist.barrier() 145 | return final_result_file 146 | 147 | 148 | # IoU function 149 | def computeIoU(box1, box2): 150 | # each box is of [x1, y1, w, h] 151 | inter_x1 = max(box1[0], box2[0]) 152 | inter_y1 = max(box1[1], box2[1]) 153 | inter_x2 = min(box1[0]+box1[2]-1, box2[0]+box2[2]-1) 154 | inter_y2 = min(box1[1]+box1[3]-1, box2[1]+box2[3]-1) 155 | 156 | if inter_x1 < inter_x2 and inter_y1 < inter_y2: 157 | inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1) 158 | else: 159 | inter = 0 160 | union = box1[2]*box1[3] + box2[2]*box2[3] - inter 161 | return float(inter)/union 162 | 163 | 164 | -------------------------------------------------------------------------------- /dataset/ve_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | from dataset.utils import pre_caption 6 | 7 | 8 | class ve_dataset(Dataset): 9 | def 
__init__(self, ann_file, transform, image_root, max_words=30): 10 | self.ann = json.load(open(ann_file,'r')) 11 | self.transform = transform 12 | self.image_root = image_root 13 | self.max_words = max_words 14 | self.labels = {'entailment':2,'neutral':1,'contradiction':0} 15 | 16 | def __len__(self): 17 | return len(self.ann) 18 | 19 | 20 | def __getitem__(self, index): 21 | 22 | ann = self.ann[index] 23 | 24 | image_path = os.path.join(self.image_root,'%s.jpg'%ann['image']) 25 | image = Image.open(image_path).convert('RGB') 26 | image = self.transform(image) 27 | 28 | sentence = pre_caption(ann['sentence'], self.max_words) 29 | 30 | return image, sentence, self.labels[ann['label']] 31 | -------------------------------------------------------------------------------- /dataset/vqa_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from dataset.utils import pre_question 7 | from collections import Counter 8 | 9 | class vqa_dataset(Dataset): 10 | def __init__(self, ann_file, transform, vqa_root, vg_root, eos='[SEP]', split="train", max_ques_words=30, answer_list='./vqaTools/answer_list.json'): 11 | self.split = split 12 | self.ann = [] 13 | for f in ann_file: 14 | self.ann += json.load(open(f,'r')) 15 | 16 | self.transform = transform 17 | self.vqa_root = vqa_root 18 | self.vg_root = vg_root 19 | self.max_ques_words = max_ques_words 20 | self.eos = eos 21 | 22 | # self.max_ques_words = 50 23 | # self.answer_list = ["oov"]+json.load(open(answer_list,'r')) 24 | # self.answer2id = {ans:i for i, ans in enumerate(self.answer_list)} 25 | 26 | if split=='test': 27 | self.max_ques_words = 50 # do not limit question length during test 28 | self.answer_list = json.load(open(answer_list,'r')) 29 | 30 | 31 | def __len__(self): 32 | return len(self.ann) 33 | 34 | def __getitem__(self, index): 35 | 36 | ann = self.ann[index] 37 | 38 | if ann['dataset']=='vqa': 39 | image_path = os.path.join(self.vqa_root,ann['image']) 40 | elif ann['dataset']=='vg': 41 | image_path = os.path.join(self.vg_root,ann['image']) 42 | 43 | image = Image.open(image_path).convert('RGB') 44 | image = self.transform(image) 45 | 46 | if self.split == 'test': 47 | question = pre_question(ann['question'],self.max_ques_words) 48 | question_id = ann['question_id'] 49 | return image, question, question_id 50 | 51 | elif self.split=='train' or self.split=='val': 52 | 53 | question = pre_question(ann['question'],self.max_ques_words) 54 | question_id = ann['question_id'] 55 | 56 | if ann['dataset']=='vqa': 57 | answer_weight = {} 58 | for answer in ann['answer']: 59 | if answer in answer_weight.keys(): 60 | answer_weight[answer] += 1/len(ann['answer']) 61 | else: 62 | answer_weight[answer] = 1/len(ann['answer']) 63 | 64 | answers = list(answer_weight.keys()) 65 | weights = list(answer_weight.values()) 66 | 67 | elif ann['dataset']=='vg': 68 | answers = [ann['answer']] 69 | weights = [0.5] 70 | 71 | # answers = [answer+self.eos for answer in answers] 72 | return image, question, answers, weights, question_id -------------------------------------------------------------------------------- /eval_coco.py: -------------------------------------------------------------------------------- 1 | # Write and Paint: Generative Vision-Language Models are Unified Modal Learners (https://arxiv.org/abs/2206.07699) 2 | # Github: https://github.com/shizhediao/DaVinci 3 | # Copyright (c) 2023, ByteDance 
Inc. 4 | # All rights reserved. 5 | 6 | import os 7 | import re 8 | import json 9 | import base64 10 | import numpy as np 11 | import pandas as pd 12 | from tqdm import tqdm 13 | from pycocotools.coco import COCO 14 | from pycocoevalcap.eval import COCOEvalCap 15 | from collections import defaultdict 16 | import time 17 | import argparse 18 | import traceback 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--model_path', required=True, help="for example, gen_coco_3998754_20211015-002418") 22 | args = parser.parse_args() 23 | 24 | # model_path = "./results/gen_coco_3998754_20211015-002418" 25 | model_path = f"./results/{args.model_path}" 26 | 27 | def bpe_post_process(text): 28 | return re.sub(r" ?##", "", text) 29 | bpe_post_process("r . j . russell , 31 , has been wearing con ##st ##ric ##tive sports bra ##s since i developed breasts in high school .") 30 | 31 | # convert the reference first 32 | for split in ["test", "val"]: 33 | out_data = {"info": [], "images": [], "licenses": [], "type": "caption", "annotations": []} 34 | with open(f"./data/coco_{split}.json") as fi, open(f"./data/coco_{split}_converted.json", "w") as fo: 35 | for sample in json.load(fi): 36 | # ann0 = {} 37 | fname = sample["image"] 38 | id0 = fname[13:-3] 39 | for i, caption in enumerate(sample['caption']): 40 | ann0 = {} 41 | ann0["caption"] = caption 42 | ann0["image_id"] = id0 43 | ann0["id"] = i 44 | out_data["annotations"].append(ann0) 45 | out_data['images'].append({'id': id0}) 46 | json.dump(out_data, fo) 47 | 48 | # convert the generation 49 | def convert_format(in_fname, out_fname): 50 | with open(in_fname) as f: 51 | test_gen = json.load(f) 52 | out_gen = [] 53 | used_ids = set() 54 | for sample in test_gen: 55 | fname = sample["images"] 56 | cid = fname[13:-3] 57 | caption = sample['generated'] 58 | if cid in used_ids: 59 | # print(cid) 60 | continue 61 | else: 62 | used_ids.add(cid) 63 | out_gen.append({'image_id': cid, 'caption': bpe_post_process(caption)}) 64 | with open(out_fname, "w") as f: 65 | json.dump(out_gen, f) 66 | 67 | def eval_gen(annotation_file, org_results_file, results_file): 68 | convert_format(org_results_file, results_file) 69 | # create coco object and coco_result object 70 | coco = COCO(annotation_file) 71 | coco_result = coco.loadRes(results_file) 72 | 73 | # create coco_eval object by taking coco and coco_result 74 | coco_eval = COCOEvalCap(coco, coco_result) 75 | 76 | # evaluate on a subset of images by setting 77 | # coco_eval.params['image_id'] = coco_result.getImgIds() 78 | # please remove this line when evaluating the full validation set 79 | coco_eval.params['image_id'] = coco_result.getImgIds() 80 | 81 | # evaluate results 82 | # SPICE will take a few minutes the first time, but speeds up due to caching 83 | coco_eval.evaluate() 84 | 85 | CIDEr = 0 86 | # print output evaluation scores 87 | for metric, score in coco_eval.eval.items(): 88 | print(f'{metric}: {score:.3f}') 89 | if metric == "CIDEr": CIDEr = score 90 | return CIDEr 91 | 92 | epoch2cider = {} 93 | max_CIDEr = 0 94 | for epoch in range(-1, 40): 95 | try: 96 | print(time.strftime("%a, %d %b %Y %H:%M:%S +0000: ", time.localtime()) + 'Epoch ', epoch, flush=True) 97 | annotation_file = './data/coco_test_converted.json' 98 | org_results_file = f'./{model_path}/gen_val_result_epoch{epoch}.json' 99 | results_file = f'./{model_path}/gen_val_result_epoch{epoch}_converted.json' 100 | 101 | CIDEr = eval_gen(annotation_file, org_results_file, results_file) 102 | epoch2cider[epoch] = int(CIDEr * 10000) 
/ 100 103 | max_CIDEr = max(max_CIDEr, CIDEr) 104 | except Exception as e: 105 | traceback.print_exc() 106 | print(time.strftime("%a, %d %b %Y %H:%M:%S +0000: ", time.localtime()) + 'MAX CIDEr = ', max_CIDEr, flush=True) 107 | print("epoch2cider", epoch2cider) -------------------------------------------------------------------------------- /img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shizhediao/DaVinci/283ea6f2977f61b113aee042ad426244d3f20f69/img.png -------------------------------------------------------------------------------- /models/DALLE-pytorch/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shizhediao/DaVinci/283ea6f2977f61b113aee042ad426244d3f20f69/models/DALLE-pytorch/.DS_Store -------------------------------------------------------------------------------- /models/DALLE-pytorch/.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /models/DALLE-pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | # dall-e generation outputs 2 | outputs/ 3 | *.pt 4 | taming/ 5 | wandb/ 6 | dalle-ds-cp/ 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # Visual Studio Code 95 | .vscode 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | -------------------------------------------------------------------------------- /models/DALLE-pytorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /models/DALLE-pytorch/MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include dalle_pytorch *.txt 2 | -------------------------------------------------------------------------------- /models/DALLE-pytorch/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | 2 | ARG IMG_TAG=1.8.1-cuda10.2-cudnn7-devel 3 | ARG IMG_REPO=pytorch 4 | 5 | FROM pytorch/$IMG_REPO:$IMG_TAG 6 | 7 | RUN apt-get -y update && apt-get -y install git gcc llvm-9-dev cmake libaio-dev vim wget 8 | 9 | RUN git clone https://github.com/microsoft/DeepSpeed.git /tmp/DeepSpeed 10 | RUN cd /tmp/DeepSpeed && DS_BUILD_OPS=1 ./install.sh -r 11 | RUN pip install git+https://github.com/lucidrains/DALLE-pytorch.git 12 | 13 | WORKDIR dalle 14 | -------------------------------------------------------------------------------- /models/DALLE-pytorch/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from tqdm import tqdm 4 | 5 | # torch 6 | 7 | import torch 8 | 9 | from einops import repeat 10 | 11 | # vision imports 12 | 13 | from PIL import Image 14 | from torchvision.utils import make_grid, save_image 15 | 16 | # dalle related classes and utils 17 | 18 | from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE 19 | from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer 20 | 21 | # argument parsing 22 | 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument('--dalle_path', type = str, required = True, 26 | help='path to your trained DALL-E') 27 | 28 | parser.add_argument('--vqgan_model_path', type=str, default = None, 29 | help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)') 30 | 31 | parser.add_argument('--vqgan_config_path', type=str, default = None, 32 | help='path to your trained VQGAN config. This should be a .yaml file. 
(only valid when taming option is enabled)') 33 | 34 | parser.add_argument('--text', type = str, required = True, 35 | help='your text prompt') 36 | 37 | parser.add_argument('--num_images', type = int, default = 128, required = False, 38 | help='number of images') 39 | 40 | parser.add_argument('--batch_size', type = int, default = 4, required = False, 41 | help='batch size') 42 | 43 | parser.add_argument('--top_k', type = float, default = 0.9, required = False, 44 | help='top k filter threshold') 45 | 46 | parser.add_argument('--outputs_dir', type = str, default = './outputs', required = False, 47 | help='output directory') 48 | 49 | parser.add_argument('--bpe_path', type = str, 50 | help='path to your huggingface BPE json file') 51 | 52 | parser.add_argument('--hug', dest='hug', action = 'store_true') 53 | 54 | parser.add_argument('--chinese', dest='chinese', action = 'store_true') 55 | 56 | parser.add_argument('--taming', dest='taming', action='store_true') 57 | 58 | parser.add_argument('--gentxt', dest='gentxt', action='store_true') 59 | 60 | args = parser.parse_args() 61 | 62 | # helper fns 63 | 64 | def exists(val): 65 | return val is not None 66 | 67 | # tokenizer 68 | 69 | if exists(args.bpe_path): 70 | klass = HugTokenizer if args.hug else YttmTokenizer 71 | tokenizer = klass(args.bpe_path) 72 | elif args.chinese: 73 | tokenizer = ChineseTokenizer() 74 | 75 | # load DALL-E 76 | 77 | dalle_path = Path(args.dalle_path) 78 | 79 | assert dalle_path.exists(), 'trained DALL-E must exist' 80 | 81 | load_obj = torch.load(str(dalle_path)) 82 | dalle_params, vae_params, weights, vae_class_name, version = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights'), load_obj.pop('vae_class_name', None), load_obj.pop('version', None) 83 | 84 | # friendly print 85 | 86 | if exists(version): 87 | print(f'Loading a model trained with DALLE-pytorch version {version}') 88 | else: 89 | print('You are loading a model trained on an older version of DALL-E pytorch - it may not be compatible with the most recent version') 90 | 91 | # load VAE 92 | 93 | if args.taming: 94 | vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path) 95 | elif vae_params is not None: 96 | vae = DiscreteVAE(**vae_params) 97 | else: 98 | vae = OpenAIDiscreteVAE() 99 | 100 | assert not (exists(vae_class_name) and vae.__class__.__name__ != vae_class_name), f'you trained DALL-E using {vae_class_name} but are trying to generate with {vae.__class__.__name__} - please make sure you are passing in the correct paths and settings for the VAE to use for generation' 101 | 102 | # reconstitute DALL-E 103 | 104 | dalle = DALLE(vae = vae, **dalle_params).cuda() 105 | 106 | dalle.load_state_dict(weights) 107 | 108 | # generate images 109 | 110 | image_size = vae.image_size 111 | 112 | texts = args.text.split('|') 113 | 114 | for j, text in tqdm(enumerate(texts)): 115 | if args.gentxt: 116 | text_tokens, gen_texts = dalle.generate_texts(tokenizer, text=text, filter_thres = args.top_k) 117 | text = gen_texts[0] 118 | else: 119 | text_tokens = tokenizer.tokenize([text], dalle.text_seq_len).cuda() 120 | 121 | text_tokens = repeat(text_tokens, '() n -> b n', b = args.num_images) 122 | 123 | outputs = [] 124 | 125 | for text_chunk in tqdm(text_tokens.split(args.batch_size), desc = f'generating images for - {text}'): 126 | output = dalle.generate_images(text_chunk, filter_thres = args.top_k) 127 | outputs.append(output) 128 | 129 | outputs = torch.cat(outputs) 130 | 131 | # save all images 132 | 133 | file_name = text 134 | 
outputs_dir = Path(args.outputs_dir) / file_name.replace(' ', '_')[:(100)] 135 | outputs_dir.mkdir(parents = True, exist_ok = True) 136 | 137 | for i, image in tqdm(enumerate(outputs), desc = 'saving images'): 138 | save_image(image, outputs_dir / f'{i}.jpg', normalize=True) 139 | with open(outputs_dir / 'caption.txt', 'w') as f: 140 | f.write(file_name) 141 | 142 | print(f'created {args.num_images} images at "{str(outputs_dir)}"') 143 | -------------------------------------------------------------------------------- /models/DALLE-pytorch/install_apex.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/NVIDIA/apex.git /tmp/apex 2 | cd /tmp/apex && pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 3 | -------------------------------------------------------------------------------- /models/DALLE-pytorch/install_deepspeed.sh: -------------------------------------------------------------------------------- 1 | sudo apt-get -y install llvm-9-dev cmake 2 | git clone https://github.com/microsoft/DeepSpeed.git /tmp/Deepspeed 3 | cd /tmp/Deepspeed && DS_BUILD_SPARSE_ATTN=1 ./install.sh -s 4 | -------------------------------------------------------------------------------- /models/DALLE-pytorch/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'dalle-pytorch', 5 | packages = find_packages(), 6 | include_package_data = True, 7 | version = '1.2.1', 8 | license='MIT', 9 | description = 'DALL-E - Pytorch', 10 | author = 'Phil Wang', 11 | author_email = 'lucidrains@gmail.com', 12 | url = 'https://github.com/lucidrains/dalle-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'attention mechanism', 16 | 'transformers', 17 | 'text-to-image' 18 | ], 19 | install_requires=[ 20 | 'axial_positional_embedding', 21 | 'DALL-E', 22 | 'einops>=0.3.2', 23 | 'ftfy', 24 | 'g-mlp-pytorch', 25 | 'pillow', 26 | 'regex', 27 | 'rotary-embedding-torch', 28 | 'taming-transformers-rom1504', 29 | 'tokenizers', 30 | 'torch>=1.6', 31 | 'torchvision', 32 | 'transformers', 33 | 'tqdm', 34 | 'youtokentome', 35 | 'WebDataset' 36 | ], 37 | classifiers=[ 38 | 'Development Status :: 4 - Beta', 39 | 'Intended Audience :: Developers', 40 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 41 | 'License :: OSI Approved :: MIT License', 42 | 'Programming Language :: Python :: 3.6', 43 | ], 44 | ) 45 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shizhediao/DaVinci/283ea6f2977f61b113aee042ad426244d3f20f69/models/__init__.py -------------------------------------------------------------------------------- /models/dall_e/__init__.py: -------------------------------------------------------------------------------- 1 | import io, requests 2 | import torch 3 | import torch.nn as nn 4 | 5 | from models.dall_e.encoder import Encoder 6 | from models.dall_e.decoder import Decoder 7 | from models.dall_e.utils import map_pixels, unmap_pixels 8 | 9 | def load_model(path: str, device: torch.device = None) -> nn.Module: 10 | if path.startswith('http://') or path.startswith('https://'): 11 | resp = requests.get(path) 12 | resp.raise_for_status() 13 | 14 | with io.BytesIO(resp.content) as buf: 15 | return 
torch.load(buf, map_location=device) 16 | else: 17 | with open(path, 'rb') as f: 18 | return torch.load(f, map_location=device) 19 | -------------------------------------------------------------------------------- /models/dall_e/decoder.py: -------------------------------------------------------------------------------- 1 | import attr 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from collections import OrderedDict 9 | from functools import partial 10 | from models.dall_e.utils import Conv2d 11 | 12 | @attr.s(eq=False, repr=False) 13 | class DecoderBlock(nn.Module): 14 | n_in: int = attr.ib(validator=lambda i, a, x: x >= 1) 15 | n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 ==0) 16 | n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1) 17 | 18 | device: torch.device = attr.ib(default=None) 19 | requires_grad: bool = attr.ib(default=False) 20 | 21 | def __attrs_post_init__(self) -> None: 22 | super().__init__() 23 | self.n_hid = self.n_out // 4 24 | self.post_gain = 1 / (self.n_layers ** 2) 25 | 26 | make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad) 27 | self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity() 28 | self.res_path = nn.Sequential(OrderedDict([ 29 | ('relu_1', nn.ReLU()), 30 | ('conv_1', make_conv(self.n_in, self.n_hid, 1)), 31 | ('relu_2', nn.ReLU()), 32 | ('conv_2', make_conv(self.n_hid, self.n_hid, 3)), 33 | ('relu_3', nn.ReLU()), 34 | ('conv_3', make_conv(self.n_hid, self.n_hid, 3)), 35 | ('relu_4', nn.ReLU()), 36 | ('conv_4', make_conv(self.n_hid, self.n_out, 3)),])) 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | return self.id_path(x) + self.post_gain * self.res_path(x) 40 | 41 | @attr.s(eq=False, repr=False) 42 | class Decoder(nn.Module): 43 | group_count: int = 4 44 | n_init: int = attr.ib(default=128, validator=lambda i, a, x: x >= 8) 45 | n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64) 46 | n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1) 47 | output_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1) 48 | vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512) 49 | 50 | device: torch.device = attr.ib(default=torch.device('cpu')) 51 | requires_grad: bool = attr.ib(default=False) 52 | use_mixed_precision: bool = attr.ib(default=True) 53 | 54 | def __attrs_post_init__(self) -> None: 55 | super().__init__() 56 | 57 | blk_range = range(self.n_blk_per_group) 58 | n_layers = self.group_count * self.n_blk_per_group 59 | make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad) 60 | make_blk = partial(DecoderBlock, n_layers=n_layers, device=self.device, 61 | requires_grad=self.requires_grad) 62 | 63 | self.blocks = nn.Sequential(OrderedDict([ 64 | ('input', make_conv(self.vocab_size, self.n_init, 1, use_float16=False)), 65 | ('group_1', nn.Sequential(OrderedDict([ 66 | *[(f'block_{i + 1}', make_blk(self.n_init if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range], 67 | ('upsample', nn.Upsample(scale_factor=2, mode='nearest')), 68 | ]))), 69 | ('group_2', nn.Sequential(OrderedDict([ 70 | *[(f'block_{i + 1}', make_blk(8 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range], 71 | ('upsample', nn.Upsample(scale_factor=2, mode='nearest')), 72 | ]))), 73 | ('group_3', nn.Sequential(OrderedDict([ 74 | *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 
0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range], 75 | ('upsample', nn.Upsample(scale_factor=2, mode='nearest')), 76 | ]))), 77 | ('group_4', nn.Sequential(OrderedDict([ 78 | *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 1 * self.n_hid, 1 * self.n_hid)) for i in blk_range], 79 | ]))), 80 | ('output', nn.Sequential(OrderedDict([ 81 | ('relu', nn.ReLU()), 82 | ('conv', make_conv(1 * self.n_hid, 2 * self.output_channels, 1)), 83 | ]))), 84 | ])) 85 | 86 | def forward(self, x: torch.Tensor) -> torch.Tensor: 87 | if len(x.shape) != 4: 88 | raise ValueError(f'input shape {x.shape} is not 4d') 89 | if x.shape[1] != self.vocab_size: 90 | raise ValueError(f'input has {x.shape[1]} channels but model built for {self.vocab_size}') 91 | if x.dtype != torch.float32: 92 | raise ValueError('input must have dtype torch.float32') 93 | 94 | return self.blocks(x) 95 | -------------------------------------------------------------------------------- /models/dall_e/encoder.py: -------------------------------------------------------------------------------- 1 | import attr 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from collections import OrderedDict 9 | from functools import partial 10 | from models.dall_e.utils import Conv2d 11 | 12 | @attr.s(eq=False, repr=False) 13 | class EncoderBlock(nn.Module): 14 | n_in: int = attr.ib(validator=lambda i, a, x: x >= 1) 15 | n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 ==0) 16 | n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1) 17 | 18 | device: torch.device = attr.ib(default=None) 19 | requires_grad: bool = attr.ib(default=False) 20 | 21 | def __attrs_post_init__(self) -> None: 22 | super().__init__() 23 | self.n_hid = self.n_out // 4 24 | self.post_gain = 1 / (self.n_layers ** 2) 25 | 26 | make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad) 27 | self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity() 28 | self.res_path = nn.Sequential(OrderedDict([ 29 | ('relu_1', nn.ReLU()), 30 | ('conv_1', make_conv(self.n_in, self.n_hid, 3)), 31 | ('relu_2', nn.ReLU()), 32 | ('conv_2', make_conv(self.n_hid, self.n_hid, 3)), 33 | ('relu_3', nn.ReLU()), 34 | ('conv_3', make_conv(self.n_hid, self.n_hid, 3)), 35 | ('relu_4', nn.ReLU()), 36 | ('conv_4', make_conv(self.n_hid, self.n_out, 1)),])) 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | return self.id_path(x) + self.post_gain * self.res_path(x) 40 | 41 | @attr.s(eq=False, repr=False) 42 | class Encoder(nn.Module): 43 | group_count: int = 4 44 | n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64) 45 | n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1) 46 | input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1) 47 | vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512) 48 | 49 | device: torch.device = attr.ib(default=torch.device('cpu')) 50 | requires_grad: bool = attr.ib(default=False) 51 | use_mixed_precision: bool = attr.ib(default=True) 52 | 53 | def __attrs_post_init__(self) -> None: 54 | super().__init__() 55 | 56 | blk_range = range(self.n_blk_per_group) 57 | n_layers = self.group_count * self.n_blk_per_group 58 | make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad) 59 | make_blk = partial(EncoderBlock, n_layers=n_layers, device=self.device, 60 | requires_grad=self.requires_grad) 61 | 62 | self.blocks = 
nn.Sequential(OrderedDict([ 63 | ('input', make_conv(self.input_channels, 1 * self.n_hid, 7)), 64 | ('group_1', nn.Sequential(OrderedDict([ 65 | *[(f'block_{i + 1}', make_blk(1 * self.n_hid, 1 * self.n_hid)) for i in blk_range], 66 | ('pool', nn.MaxPool2d(kernel_size=2)), 67 | ]))), 68 | ('group_2', nn.Sequential(OrderedDict([ 69 | *[(f'block_{i + 1}', make_blk(1 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range], 70 | ('pool', nn.MaxPool2d(kernel_size=2)), 71 | ]))), 72 | ('group_3', nn.Sequential(OrderedDict([ 73 | *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range], 74 | ('pool', nn.MaxPool2d(kernel_size=2)), 75 | ]))), 76 | ('group_4', nn.Sequential(OrderedDict([ 77 | *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range], 78 | ]))), 79 | ('output', nn.Sequential(OrderedDict([ 80 | ('relu', nn.ReLU()), 81 | ('conv', make_conv(8 * self.n_hid, self.vocab_size, 1, use_float16=False)), 82 | ]))), 83 | ])) 84 | 85 | def forward(self, x: torch.Tensor) -> torch.Tensor: 86 | if len(x.shape) != 4: 87 | raise ValueError(f'input shape {x.shape} is not 4d') 88 | if x.shape[1] != self.input_channels: 89 | raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}') 90 | if x.dtype != torch.float32: 91 | raise ValueError('input must have dtype torch.float32') 92 | 93 | return self.blocks(x) 94 | -------------------------------------------------------------------------------- /models/dall_e/utils.py: -------------------------------------------------------------------------------- 1 | import attr 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | logit_laplace_eps: float = 0.1 9 | 10 | @attr.s(eq=False) 11 | class Conv2d(nn.Module): 12 | n_in: int = attr.ib(validator=lambda i, a, x: x >= 1) 13 | n_out: int = attr.ib(validator=lambda i, a, x: x >= 1) 14 | kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1) 15 | 16 | use_float16: bool = attr.ib(default=True) 17 | device: torch.device = attr.ib(default=torch.device('cpu')) 18 | requires_grad: bool = attr.ib(default=False) 19 | 20 | def __attrs_post_init__(self) -> None: 21 | super().__init__() 22 | 23 | w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32, 24 | device=self.device, requires_grad=self.requires_grad) 25 | w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2)) 26 | 27 | b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device, 28 | requires_grad=self.requires_grad) 29 | self.w, self.b = nn.Parameter(w), nn.Parameter(b) 30 | 31 | def forward(self, x: torch.Tensor) -> torch.Tensor: 32 | if self.use_float16 and 'cuda' in self.w.device.type: 33 | if x.dtype != torch.float16: 34 | x = x.half() 35 | 36 | w, b = self.w.half(), self.b.half() 37 | else: 38 | if x.dtype != torch.float32: 39 | x = x.float() 40 | 41 | w, b = self.w, self.b 42 | 43 | return F.conv2d(x, w, b, padding=(self.kw - 1) // 2) 44 | 45 | def map_pixels(x: torch.Tensor) -> torch.Tensor: 46 | if x.dtype != torch.float: 47 | raise ValueError('expected input to have type float') 48 | 49 | return (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps 50 | 51 | def unmap_pixels(x: torch.Tensor) -> torch.Tensor: 52 | if len(x.shape) != 4: 53 | raise ValueError('expected input to be 4d') 54 | if x.dtype != torch.float: 55 | raise ValueError('expected input to have type float') 56 | 57 | return 
torch.clamp((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0, 1) 58 | -------------------------------------------------------------------------------- /models/dalle_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from models.dalle_pytorch.dalle_pytorch import DALLE, CLIP, DiscreteVAE 2 | from models.dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE 3 | 4 | from pkg_resources import get_distribution 5 | # __version__ = get_distribution('dalle_pytorch').version 6 | -------------------------------------------------------------------------------- /models/dalle_pytorch/distributed_backends/__init__.py: -------------------------------------------------------------------------------- 1 | from .deepspeed_backend import DeepSpeedBackend 2 | from .distributed_backend import DistributedBackend 3 | from .dummy_backend import DummyBackend 4 | from .horovod_backend import HorovodBackend 5 | -------------------------------------------------------------------------------- /models/dalle_pytorch/distributed_backends/deepspeed_backend.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | 6 | from .distributed_backend import DistributedBackend 7 | 8 | 9 | class DeepSpeedBackend(DistributedBackend): 10 | """Distributed backend using the DeepSpeed engine.""" 11 | 12 | BACKEND_MODULE_NAME = 'deepspeed' 13 | BACKEND_NAME = 'DeepSpeed' 14 | 15 | def wrap_arg_parser(self, parser): 16 | if not self.has_backend(): 17 | parser.add_argument( 18 | '--deepspeed', 19 | type=lambda _: False, 20 | help=( 21 | 'whether to use DeepSpeed ' 22 | "(ignored since it's not available)" 23 | ), 24 | ) 25 | else: 26 | parser = self.backend_module.add_config_arguments(parser) 27 | 28 | parser.add_argument( 29 | '--local_rank', 30 | type=int, 31 | default=-1, 32 | help='local rank passed from distributed launcher', 33 | ) 34 | return parser 35 | 36 | def _initialize(self): 37 | self.backend_module.init_distributed() 38 | if torch.cuda.is_available(): 39 | torch.cuda.set_device(self._get_local_rank()) 40 | 41 | @staticmethod 42 | def _require_torch_distributed_init(): 43 | """Raise an error when `torch.distributed` has not been 44 | initialized yet. 45 | """ 46 | assert torch.distributed.is_initialized(), \ 47 | ('`torch.distributed` is not initialized; please call ' 48 | '`DeepSpeedBackend.initialize` at the start of your script') 49 | 50 | def _get_world_size(self): 51 | self._require_torch_distributed_init() 52 | return torch.distributed.get_world_size() 53 | 54 | def _get_rank(self): 55 | self._require_torch_distributed_init() 56 | return torch.distributed.get_rank() 57 | 58 | def _get_local_rank(self): 59 | self._require_torch_distributed_init() 60 | return int(os.environ['LOCAL_RANK']) 61 | 62 | def _local_barrier(self): 63 | self._require_torch_distributed_init() 64 | torch.distributed.barrier() 65 | 66 | def _check_args(self, args, optimizer, lr_scheduler, kwargs): 67 | """Return an appropriate optimizer and learning rate scheduler 68 | after checking the values passed to `distribute`. 69 | """ 70 | self._check_argvs(args, optimizer, lr_scheduler, kwargs) 71 | (optimizer, lr_scheduler) = self._check_config( 72 | args, optimizer, lr_scheduler, kwargs) 73 | return (optimizer, lr_scheduler) 74 | 75 | def _check_argvs(self, args, optimizer, lr_scheduler, kwargs): 76 | """Apply several sanity checks to the given command 77 | line arguments. 
78 | """ 79 | has_json_config = (hasattr(args, 'deepspeed_config') 80 | and args.deepspeed_config is not None) 81 | has_dict_config = 'config_params' in kwargs 82 | if ( 83 | # No config given 84 | (not has_json_config and not has_dict_config) 85 | # JSON config file does not exist 86 | or (not has_dict_config 87 | and not os.path.isfile(args.deepspeed_config)) 88 | ): 89 | # Let DeepSpeed handle these argument errors. 90 | return 91 | 92 | if not args.deepspeed: 93 | print( 94 | 'WARNING: DeepSpeed backend was selected; setting ' 95 | '`args.deepspeed = True`' 96 | ) 97 | args.deepspeed = True 98 | 99 | if has_json_config and has_dict_config: 100 | print( 101 | 'WARNING: DeepSpeed config was given as both JSON file and ' 102 | 'Python dictionary. Python dictionary takes precedence.' 103 | ) 104 | 105 | def _check_config(self, args, optimizer, lr_scheduler, kwargs): 106 | """Return an appropriate optimizer and learning rate scheduler 107 | for the DeepSpeed configuration. 108 | """ 109 | if 'config_params' in kwargs: 110 | config = kwargs['config_params'] 111 | else: 112 | with open(args.deepspeed_config, 'r') as json_config_file: 113 | config = json.load(json_config_file) 114 | 115 | if 'optimizer' in config and optimizer is not None: 116 | print( 117 | 'WARNING: Optimizer encountered in both DeepSpeed config and ' 118 | 'keyword arguments. Optimizer in DeepSpeed config ' 119 | 'takes precedence.' 120 | ) 121 | optimizer = None 122 | 123 | if 'scheduler' in config and lr_scheduler is not None: 124 | print( 125 | 'WARNING: Learning rate scheduler encountered in both ' 126 | 'DeepSpeed config and keyword arguments. Learning rate ' 127 | 'scheduler in DeepSpeed config takes precedence.' 128 | ) 129 | # For the LR scheduler, the JSON config already has 130 | # precedence. We do this for forward compatibility. 131 | lr_scheduler = None 132 | 133 | return (optimizer, lr_scheduler) 134 | 135 | def _distribute( 136 | self, 137 | args=None, 138 | model=None, 139 | optimizer=None, 140 | model_parameters=None, 141 | training_data=None, 142 | lr_scheduler=None, 143 | **kwargs, 144 | ): 145 | """Return a distributed model engine, optimizer, dataloader, and 146 | learning rate scheduler. These are obtained by wrapping the 147 | given values with the backend. 148 | 149 | For the other or other possible arguments, 150 | see `deepspeed.initialize`. 151 | """ 152 | (optimizer, lr_scheduler) = self._check_args( 153 | args, optimizer, lr_scheduler, kwargs) 154 | 155 | return self.backend_module.initialize( 156 | args=args, 157 | model=model, 158 | optimizer=optimizer, 159 | model_parameters=model_parameters, 160 | training_data=training_data, 161 | lr_scheduler=lr_scheduler, 162 | **kwargs, 163 | ) 164 | 165 | def _average_all(self, tensor): 166 | self._require_torch_distributed_init() 167 | # We copy because modification happens in-place 168 | averaged = tensor.detach().clone() 169 | # We use `all_reduce` because it is better supported than `reduce` 170 | torch.distributed.all_reduce(averaged, torch.distributed.ReduceOp.SUM) 171 | return averaged / self.get_world_size() 172 | -------------------------------------------------------------------------------- /models/dalle_pytorch/distributed_backends/distributed_backend.py: -------------------------------------------------------------------------------- 1 | """ 2 | An abstract backend for distributed deep learning. 3 | 4 | Provides several standard utility methods under a common API. 
5 | Please check the documentation of the class `DistributedBackend` for 6 | details to implement a new backend. 7 | """ 8 | 9 | from importlib import import_module 10 | 11 | 12 | class DistributedBackend: 13 | """An abstract backend class for distributed deep learning. 14 | 15 | Provides several standard utility methods under a common API. 16 | Variables that must be overridden: 17 | - BACKEND_MODULE_NAME 18 | - BACKEND_NAME 19 | Methods that must be overridden: 20 | - wrap_arg_parser 21 | - _initialize 22 | - _get_world_size 23 | - _get_rank 24 | - _get_local_rank 25 | - _local_barrier 26 | - _distribute 27 | - _average_all 28 | """ 29 | 30 | BACKEND_MODULE_NAME = None 31 | """Name of the module to import for the backend.""" 32 | BACKEND_NAME = None 33 | """Name of the backend for printing.""" 34 | 35 | ROOT_RANK = 0 36 | 37 | backend_module = None 38 | """The module to access the backend.""" 39 | is_initialized = False 40 | """Whether the backend is initialized.""" 41 | 42 | def __init__(self): 43 | if self.BACKEND_MODULE_NAME is None: 44 | raise NotImplementedError('BACKEND_MODULE_NAME is not set') 45 | if self.BACKEND_NAME is None: 46 | raise NotImplementedError('BACKEND_NAME is not set') 47 | 48 | def has_backend(self): 49 | """Return whether the backend module is now imported.""" 50 | try: 51 | self.backend_module = import_module(self.BACKEND_MODULE_NAME) 52 | except ModuleNotFoundError: 53 | return False 54 | return True 55 | 56 | def check_batch_size(self, batch_size): 57 | """Check whether the batch size makes sense for distribution.""" 58 | assert batch_size >= self.get_world_size(), \ 59 | (f"batch size can't be smaller than number of processes " 60 | f'({batch_size} < {self.get_world_size()})') 61 | 62 | def wrap_arg_parser(self, parser): 63 | """Add arguments to support optional distributed backend usage.""" 64 | raise NotImplementedError 65 | 66 | def initialize(self): 67 | """Initialize the distributed backend.""" 68 | self._initialize() 69 | self.is_initialized = True 70 | 71 | def _initialize(self): 72 | """Initialize the distributed backend.""" 73 | raise NotImplementedError 74 | 75 | def require_init(self): 76 | """Raise an error when the backend has not been initialized yet.""" 77 | assert self.is_initialized, \ 78 | (f'{BACKEND_NAME} backend has not been initialized; please call ' 79 | f'`distributed_utils.initialize` at the start of your script to ' 80 | f'allow optional distributed usage') 81 | 82 | def get_world_size(self): 83 | """Return the amount of distributed processes.""" 84 | self.require_init() 85 | return self._get_world_size() 86 | 87 | def _get_world_size(self): 88 | """Return the amount of distributed processes.""" 89 | raise NotImplementedError 90 | 91 | def get_rank(self): 92 | """Return the global rank of the calling worker process.""" 93 | self.require_init() 94 | return self._get_rank() 95 | 96 | def _get_rank(self): 97 | """Return the global rank of the calling worker process.""" 98 | raise NotImplementedError 99 | 100 | def get_local_rank(self): 101 | """Return the local rank of the calling worker process. 102 | The local rank is the rank based on a single node's processes. 103 | """ 104 | self.require_init() 105 | return self._get_local_rank() 106 | 107 | def _get_local_rank(self): 108 | """Return the local rank of the calling worker process. 109 | The local rank is the rank based on a single node's processes. 
110 | """ 111 | raise NotImplementedError 112 | 113 | def is_root_worker(self): 114 | """Return whether the calling worker has the root rank.""" 115 | return self.get_rank() == self.ROOT_RANK 116 | 117 | def is_local_root_worker(self): 118 | """Return whether the calling worker has the root rank on this node.""" 119 | return self.get_local_rank() == self.ROOT_RANK 120 | 121 | def local_barrier(self): 122 | """Wait until all processes on this node have called this function.""" 123 | self.require_init() 124 | self._local_barrier() 125 | 126 | def _local_barrier(self): 127 | """Wait until all processes on this node have called this function.""" 128 | raise NotImplementedError 129 | 130 | def distribute( 131 | self, 132 | args=None, 133 | model=None, 134 | optimizer=None, 135 | model_parameters=None, 136 | training_data=None, 137 | lr_scheduler=None, 138 | **kwargs, 139 | ): 140 | """Return a distributed model engine, optimizer, dataloader, and 141 | learning rate scheduler. These are obtained by wrapping the 142 | given values with the backend. 143 | """ 144 | self.require_init() 145 | return self._distribute( 146 | args, 147 | model, 148 | optimizer, 149 | model_parameters, 150 | training_data, 151 | lr_scheduler, 152 | **kwargs, 153 | ) 154 | 155 | def _distribute( 156 | self, 157 | args=None, 158 | model=None, 159 | optimizer=None, 160 | model_parameters=None, 161 | training_data=None, 162 | lr_scheduler=None, 163 | **kwargs, 164 | ): 165 | """Return a distributed model engine, optimizer, dataloader, and 166 | learning rate scheduler. These are obtained by wrapping the 167 | given values with the backend. 168 | """ 169 | raise NotImplementedError 170 | 171 | def average_all(self, tensor): 172 | """Return the average of `tensor` over all workers.""" 173 | self.require_init() 174 | return self._average_all(tensor) 175 | 176 | def _average_all(self, tensor): 177 | """Return the average of `tensor` over all workers.""" 178 | raise NotImplementedError 179 | -------------------------------------------------------------------------------- /models/dalle_pytorch/distributed_backends/dummy_backend.py: -------------------------------------------------------------------------------- 1 | from .distributed_backend import DistributedBackend 2 | 3 | 4 | class DummyBackend(DistributedBackend): 5 | """Acts like a distributed backend. 6 | 7 | Used as a stand-in replacement to obtain a non-distributed program. 8 | """ 9 | 10 | # We define this so we can use `super().__init__` but want this to 11 | # throw an error upon import. 12 | BACKEND_MODULE_NAME = 'NO MODULE' 13 | BACKEND_NAME = 'Dummy' 14 | 15 | def has_backend(self): 16 | return True 17 | 18 | def wrap_arg_parser(self, parser): 19 | return parser 20 | 21 | def _initialize(self): 22 | pass 23 | 24 | def _get_world_size(self): 25 | return 1 26 | 27 | def _get_rank(self): 28 | return self.ROOT_RANK 29 | 30 | def _get_local_rank(self): 31 | return self.ROOT_RANK 32 | 33 | def _local_barrier(self): 34 | pass 35 | 36 | def _distribute( 37 | self, 38 | _args=None, 39 | model=None, 40 | optimizer=None, 41 | _model_parameters=None, 42 | training_data=None, 43 | lr_scheduler=None, 44 | **_kwargs, 45 | ): 46 | """Return the model, optimizer, dataloader, and learning rate scheduler 47 | as is. 
48 | """ 49 | return (model, optimizer, training_data, lr_scheduler) 50 | 51 | def _average_all(self, tensor): 52 | return tensor 53 | -------------------------------------------------------------------------------- /models/dalle_pytorch/distributed_backends/horovod_backend.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .distributed_backend import DistributedBackend 4 | 5 | 6 | class HorovodBackend(DistributedBackend): 7 | """Distributed backend using Horovod.""" 8 | 9 | BACKEND_MODULE_NAME = 'horovod.torch' 10 | BACKEND_NAME = 'Horovod' 11 | 12 | def wrap_arg_parser(self, parser): 13 | return parser 14 | 15 | def check_batch_size(self, batch_size): 16 | # Horovod uses the local batch size to determine the effective 17 | # batch size. 18 | pass 19 | 20 | def _initialize(self): 21 | self.backend_module.init() 22 | if torch.cuda.is_available(): 23 | torch.cuda.set_device(self._get_local_rank()) 24 | 25 | def _get_world_size(self): 26 | return self.backend_module.size() 27 | 28 | def _get_rank(self): 29 | return self.backend_module.rank() 30 | 31 | def _get_local_rank(self): 32 | return self.backend_module.local_rank() 33 | 34 | def _local_barrier(self): 35 | # Actually a global barrier but works for our purposes. 36 | self.backend_module.join() 37 | 38 | def _distribute( 39 | self, 40 | _args=None, 41 | model=None, 42 | optimizer=None, 43 | _model_parameters=None, 44 | training_data=None, 45 | lr_scheduler=None, 46 | **_kwargs, 47 | ): 48 | optimizer = self.backend_module.DistributedOptimizer(optimizer) 49 | self.backend_module.broadcast_parameters( 50 | model.state_dict(), root_rank=self.ROOT_RANK) 51 | self.backend_module.broadcast_optimizer_state( 52 | optimizer, root_rank=self.ROOT_RANK) 53 | return (model, optimizer, training_data, lr_scheduler) 54 | 55 | def _average_all(self, tensor): 56 | # Reduce op is average by default 57 | averaged = self.backend_module.allreduce(tensor) 58 | return averaged 59 | -------------------------------------------------------------------------------- /models/dalle_pytorch/distributed_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for optional distributed execution. 3 | 4 | To use, 5 | 1. set the `BACKENDS` to the ones you want to make available, 6 | 2. in the script, wrap the argument parser with `wrap_arg_parser`, 7 | 3. in the script, set and use the backend by calling 8 | `set_backend_from_args`. 9 | 10 | You can check whether a backend is in use with the `using_backend` 11 | function. 12 | """ 13 | 14 | from models.dalle_pytorch.distributed_backends import \ 15 | DeepSpeedBackend, \ 16 | DummyBackend, \ 17 | HorovodBackend 18 | 19 | _DEFAULT_BACKEND = DummyBackend() 20 | """Which backend to use by default. Assumed to be _not_ distributed.""" 21 | 22 | BACKENDS = [ 23 | _DEFAULT_BACKEND, 24 | DeepSpeedBackend(), 25 | HorovodBackend(), 26 | ] 27 | 28 | is_distributed = None 29 | """Whether we are distributed.""" 30 | backend = None 31 | """Backend in usage.""" 32 | 33 | 34 | def wrap_arg_parser(parser): 35 | """Add arguments to support optional distributed backend usage.""" 36 | parser.add_argument( 37 | '--distributed_backend', 38 | '--distr_backend', 39 | type=str, 40 | default=None, 41 | help='which distributed backend to use. 
Do not distribute by default', 42 | ) 43 | for distr_backend in BACKENDS: 44 | parser = distr_backend.wrap_arg_parser(parser) 45 | return parser 46 | 47 | 48 | def set_backend_from_args(args): 49 | """Set and return the backend based on the given `args`.""" 50 | global is_distributed, backend 51 | 52 | # Handle this specially for backwards compatibility. 53 | if args.deepspeed: 54 | args.distributed_backend = DeepSpeedBackend.BACKEND_NAME 55 | 56 | if not args.distributed_backend: 57 | is_distributed = False 58 | backend = _DEFAULT_BACKEND 59 | return backend 60 | 61 | backend_name = args.distributed_backend.lower() 62 | for distr_backend in BACKENDS: 63 | if distr_backend.BACKEND_NAME.lower() == backend_name: 64 | backend = distr_backend 65 | if not backend.has_backend(): 66 | raise ModuleNotFoundError( 67 | f'{backend.BACKEND_NAME} backend selected but ' 68 | 'module not available' 69 | ) 70 | 71 | print(f'Using {backend.BACKEND_NAME} for distributed execution') 72 | is_distributed = True 73 | return backend 74 | 75 | raise ValueError( 76 | 'unknown backend; please check `distributed_utils.BACKENDS`') 77 | 78 | 79 | def require_set_backend(): 80 | """Raise an `AssertionError` when the backend has not been set.""" 81 | assert backend is not None, ( 82 | 'distributed backend is not set. Please call ' 83 | '`distributed_utils.set_backend_from_args` at the start of your script' 84 | ) 85 | 86 | 87 | def using_backend(test_backend): 88 | """Return whether the backend is set to `test_backend`. 89 | 90 | `test_backend` may be a string of the name of the backend or 91 | its class. 92 | """ 93 | require_set_backend() 94 | if isinstance(test_backend, str): 95 | return backend.BACKEND_NAME == test_backend 96 | return isinstance(backend, test_backend) 97 | -------------------------------------------------------------------------------- /models/dalle_pytorch/loader.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from random import randint, choice 3 | 4 | import PIL 5 | 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms as T 8 | 9 | 10 | class TextImageDataset(Dataset): 11 | def __init__(self, 12 | folder, 13 | text_len=256, 14 | image_size=128, 15 | truncate_captions=False, 16 | resize_ratio=0.75, 17 | tokenizer=None, 18 | shuffle=False 19 | ): 20 | """ 21 | @param folder: Folder containing images and text files matched by their paths' respective "stem" 22 | @param truncate_captions: Rather than throw an exception, captions which are too long will be truncated. 
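@param text_len: Maximum caption length (in tokens) passed to the tokenizer.
@param image_size: Target side length used by the RandomResizedCrop transform.
@param resize_ratio: Lower bound of the random-crop scale range.
@param tokenizer: Object exposing tokenize(text, text_len, truncate_text=...).
@param shuffle: If True, unreadable or empty samples are replaced by a random sample instead of the next index.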
23 | """ 24 | super().__init__() 25 | self.shuffle = shuffle 26 | path = Path(folder) 27 | 28 | text_files = [*path.glob('**/*.txt')] 29 | image_files = [ 30 | *path.glob('**/*.png'), *path.glob('**/*.jpg'), 31 | *path.glob('**/*.jpeg'), *path.glob('**/*.bmp') 32 | ] 33 | 34 | text_files = {text_file.stem: text_file for text_file in text_files} 35 | image_files = {image_file.stem: image_file for image_file in image_files} 36 | 37 | keys = (image_files.keys() & text_files.keys()) 38 | 39 | self.keys = list(keys) 40 | self.text_files = {k: v for k, v in text_files.items() if k in keys} 41 | self.image_files = {k: v for k, v in image_files.items() if k in keys} 42 | self.text_len = text_len 43 | self.truncate_captions = truncate_captions 44 | self.resize_ratio = resize_ratio 45 | self.tokenizer = tokenizer 46 | self.image_transform = T.Compose([ 47 | T.Lambda(lambda img: img.convert('RGB') 48 | if img.mode != 'RGB' else img), 49 | T.RandomResizedCrop(image_size, 50 | scale=(self.resize_ratio, 1.), 51 | ratio=(1., 1.)), 52 | T.ToTensor() 53 | ]) 54 | 55 | def __len__(self): 56 | return len(self.keys) 57 | 58 | def random_sample(self): 59 | return self.__getitem__(randint(0, self.__len__() - 1)) 60 | 61 | def sequential_sample(self, ind): 62 | if ind >= self.__len__() - 1: 63 | return self.__getitem__(0) 64 | return self.__getitem__(ind + 1) 65 | 66 | def skip_sample(self, ind): 67 | if self.shuffle: 68 | return self.random_sample() 69 | return self.sequential_sample(ind=ind) 70 | 71 | def __getitem__(self, ind): 72 | key = self.keys[ind] 73 | 74 | text_file = self.text_files[key] 75 | image_file = self.image_files[key] 76 | 77 | descriptions = text_file.read_text().split('\n') 78 | descriptions = list(filter(lambda t: len(t) > 0, descriptions)) 79 | try: 80 | description = choice(descriptions) 81 | except IndexError as zero_captions_in_file_ex: 82 | print(f"An exception occurred trying to load file {text_file}.") 83 | print(f"Skipping index {ind}") 84 | return self.skip_sample(ind) 85 | 86 | tokenized_text = self.tokenizer.tokenize( 87 | description, 88 | self.text_len, 89 | truncate_text=self.truncate_captions 90 | ).squeeze(0) 91 | try: 92 | image_tensor = self.image_transform(PIL.Image.open(image_file)) 93 | except (PIL.UnidentifiedImageError, OSError) as corrupt_image_exceptions: 94 | print(f"An exception occurred trying to load file {image_file}.") 95 | print(f"Skipping index {ind}") 96 | return self.skip_sample(ind) 97 | 98 | # Success 99 | return tokenized_text, image_tensor 100 | -------------------------------------------------------------------------------- /models/dalle_pytorch/reversible.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from operator import itemgetter 4 | from torch.autograd.function import Function 5 | from torch.utils.checkpoint import get_device_states, set_device_states 6 | 7 | # for routing arguments into the functions of the reversible layer 8 | def route_args(router, args, depth): 9 | routed_args = [(dict(), dict()) for _ in range(depth)] 10 | matched_keys = [key for key in args.keys() if key in router] 11 | 12 | for key in matched_keys: 13 | val = args[key] 14 | for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): 15 | new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) 16 | routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) 17 | return routed_args 18 | 19 | # following example for saving and 
setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 20 | class Deterministic(nn.Module): 21 | def __init__(self, net): 22 | super().__init__() 23 | self.net = net 24 | self.cpu_state = None 25 | self.cuda_in_fwd = None 26 | self.gpu_devices = None 27 | self.gpu_states = None 28 | 29 | def record_rng(self, *args): 30 | self.cpu_state = torch.get_rng_state() 31 | if torch.cuda._initialized: 32 | self.cuda_in_fwd = True 33 | self.gpu_devices, self.gpu_states = get_device_states(*args) 34 | 35 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs): 36 | if record_rng: 37 | self.record_rng(*args) 38 | 39 | if not set_rng: 40 | return self.net(*args, **kwargs) 41 | 42 | rng_devices = [] 43 | if self.cuda_in_fwd: 44 | rng_devices = self.gpu_devices 45 | 46 | with torch.random.fork_rng(devices=rng_devices, enabled=True): 47 | torch.set_rng_state(self.cpu_state) 48 | if self.cuda_in_fwd: 49 | set_device_states(self.gpu_devices, self.gpu_states) 50 | return self.net(*args, **kwargs) 51 | 52 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 53 | # once multi-GPU is confirmed working, refactor and send PR back to source 54 | class ReversibleBlock(nn.Module): 55 | def __init__(self, f, g): 56 | super().__init__() 57 | self.f = Deterministic(f) 58 | self.g = Deterministic(g) 59 | 60 | def forward(self, x, f_args = {}, g_args = {}): 61 | x1, x2 = torch.chunk(x, 2, dim=2) 62 | y1, y2 = None, None 63 | 64 | with torch.no_grad(): 65 | y1 = x1 + self.f(x2, record_rng=self.training, **f_args) 66 | y2 = x2 + self.g(y1, record_rng=self.training, **g_args) 67 | 68 | return torch.cat([y1, y2], dim=2) 69 | 70 | def backward_pass(self, y, dy, f_args = {}, g_args = {}): 71 | y1, y2 = torch.chunk(y, 2, dim=2) 72 | del y 73 | 74 | dy1, dy2 = torch.chunk(dy, 2, dim=2) 75 | del dy 76 | 77 | with torch.enable_grad(): 78 | y1.requires_grad = True 79 | gy1 = self.g(y1, set_rng=True, **g_args) 80 | torch.autograd.backward(gy1, dy2) 81 | 82 | with torch.no_grad(): 83 | x2 = y2 - gy1 84 | del y2, gy1 85 | 86 | dx1 = dy1 + y1.grad 87 | del dy1 88 | y1.grad = None 89 | 90 | with torch.enable_grad(): 91 | x2.requires_grad = True 92 | fx2 = self.f(x2, set_rng=True, **f_args) 93 | torch.autograd.backward(fx2, dx1, retain_graph=True) 94 | 95 | with torch.no_grad(): 96 | x1 = y1 - fx2 97 | del y1, fx2 98 | 99 | dx2 = dy2 + x2.grad 100 | del dy2 101 | x2.grad = None 102 | 103 | x = torch.cat([x1, x2.detach()], dim=2) 104 | dx = torch.cat([dx1, dx2], dim=2) 105 | 106 | return x, dx 107 | 108 | class _ReversibleFunction(Function): 109 | @staticmethod 110 | def forward(ctx, x, blocks, args): 111 | ctx.args = args 112 | for block, kwarg in zip(blocks, args): 113 | x = block(x, **kwarg) 114 | ctx.y = x.detach() 115 | ctx.blocks = blocks 116 | return x 117 | 118 | @staticmethod 119 | def backward(ctx, dy): 120 | y = ctx.y 121 | args = ctx.args 122 | for block, kwargs in zip(ctx.blocks[::-1], args[::-1]): 123 | y, dy = block.backward_pass(y, dy, **kwargs) 124 | return dy, None, None 125 | 126 | class SequentialSequence(nn.Module): 127 | def __init__(self, layers, args_route = {}, layer_dropout = 0.): 128 | super().__init__() 129 | assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' 130 | self.layers = layers 131 | self.args_route = args_route 132 | self.layer_dropout = layer_dropout 133 | 134 | def forward(self, x, **kwargs): 135 | args 
= route_args(self.args_route, kwargs, len(self.layers)) 136 | layers_and_args = list(zip(self.layers, args)) 137 | 138 | for (f, g), (f_args, g_args) in layers_and_args: 139 | x = x + f(x, **f_args) 140 | x = x + g(x, **g_args) 141 | return x 142 | 143 | class ReversibleSequence(nn.Module): 144 | def __init__(self, blocks, args_route = {}): 145 | super().__init__() 146 | self.args_route = args_route 147 | self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks]) 148 | 149 | def forward(self, x, **kwargs): 150 | x = torch.cat([x, x], dim=-1) 151 | 152 | blocks = self.blocks 153 | args = route_args(self.args_route, kwargs, len(blocks)) 154 | args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args)) 155 | 156 | out = _ReversibleFunction.apply(x, blocks, args) 157 | return torch.stack(out.chunk(2, dim=-1)).mean(dim=0) 158 | -------------------------------------------------------------------------------- /models/model_glue.py: -------------------------------------------------------------------------------- 1 | # Write and Paint: Generative Vision-Language Models are Unified Modal Learners (https://arxiv.org/abs/2206.07699) 2 | # Github: https://github.com/shizhediao/DaVinci 3 | # Copyright (c) 2023, ByteDance Inc. 4 | # All rights reserved. 5 | 6 | from models.xbert import BertConfig 7 | from models.davinci_pretrain import DaVinci 8 | 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | class DaVinciGLUE(nn.Module): 13 | def __init__(self, 14 | encoder = None, 15 | text_decoder = None, 16 | tokenizer = None, 17 | config = None, 18 | num_labels = None, 19 | ): 20 | super().__init__() 21 | self.last_hidden_id_shift = config['last_hidden_id_shift'] 22 | self.tokenizer = tokenizer 23 | self.davinci = DaVinci(encoder, text_decoder, tokenizer, config, init_deit=False, init_dalle=True) 24 | bert_config = BertConfig.from_json_file(config['bert_config']) 25 | self.num_labels = num_labels 26 | self.cls_head = nn.Sequential( 27 | nn.Linear(bert_config.hidden_size, bert_config.hidden_size), 28 | nn.ReLU(), 29 | nn.Linear(bert_config.hidden_size, num_labels) 30 | ) 31 | 32 | def forward(self, text, targets, alpha=0, train=True): 33 | last_state_ids = text.attention_mask.sum(1) - self.last_hidden_id_shift 34 | output = self.davinci(image=None, 35 | context=text, 36 | gen_text=text, 37 | last_state_ids = last_state_ids, 38 | is_ve = True, 39 | train=train, decode=False) 40 | prediction = self.cls_head(output) 41 | if train: 42 | if self.num_labels == 1: 43 | loss = F.mse_loss(prediction.squeeze(), targets.squeeze()) 44 | else: 45 | loss = F.cross_entropy(prediction, targets) 46 | return loss 47 | else: 48 | return prediction 49 | -------------------------------------------------------------------------------- /models/model_imageft.py: -------------------------------------------------------------------------------- 1 | # Write and Paint: Generative Vision-Language Models are Unified Modal Learners (https://arxiv.org/abs/2206.07699) 2 | # Github: https://github.com/shizhediao/DaVinci 3 | # Copyright (c) 2023, ByteDance Inc. 4 | # All rights reserved. 
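# Note on the class below: each image is paired with an empty dummy context and
# the fixed prompt "a picture of "; the DaVinci backbone returns the hidden
# state selected by `last_state_ids` (the last non-padded prompt position), and
# a two-layer MLP maps that 3*hidden_size feature to the 1000 ImageNet classes.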
5 | 6 | from models.davinci_pretrain import DaVinci 7 | from torch import nn 8 | 9 | class DaVinciImageFT(nn.Module): 10 | def __init__(self, 11 | encoder = None, 12 | text_decoder = None, 13 | tokenizer = None, 14 | config = None, 15 | ): 16 | super().__init__() 17 | self.last_hidden_id_shift = 1 18 | self.tokenizer = tokenizer 19 | self.davinci = DaVinci(encoder, text_decoder, tokenizer, config, init_deit=False, init_dalle=True, imagenet=True) 20 | emb_dim = self.davinci.config_decoder.hidden_size 21 | self.fc = nn.Sequential( 22 | nn.Linear(3*emb_dim, emb_dim), 23 | nn.ReLU(), 24 | nn.Linear(emb_dim, 1000) 25 | ) 26 | def forward(self, image, text=None, train=True): 27 | dummy_text = self.tokenizer([""] * image.size(0), return_tensors='pt').to(image.device) 28 | text_inputs = self.tokenizer(["a picture of "]*image.size(0), return_tensors="pt").to(image.device) 29 | last_state_ids = text_inputs.attention_mask.sum(1) - self.last_hidden_id_shift 30 | hidden_states = self.davinci(image, 31 | dummy_text, 32 | text_inputs, 33 | last_state_ids = last_state_ids, 34 | imagenet=True, 35 | train=train, decode=False) 36 | logits = self.fc(hidden_states) 37 | return logits 38 | -------------------------------------------------------------------------------- /models/model_linearprobe.py: -------------------------------------------------------------------------------- 1 | # Write and Paint: Generative Vision-Language Models are Unified Modal Learners (https://arxiv.org/abs/2206.07699) 2 | # Github: https://github.com/shizhediao/DaVinci 3 | # Copyright (c) 2023, ByteDance Inc. 4 | # All rights reserved. 5 | 6 | from models.davinci_pretrain import DaVinci 7 | from torch import nn 8 | 9 | class DaVinciLinearProbe(nn.Module): 10 | def __init__(self, 11 | encoder = None, 12 | text_decoder = None, 13 | tokenizer = None, 14 | config = None, 15 | n_labels = None, 16 | ): 17 | super().__init__() 18 | self.last_hidden_id_shift = 1 19 | self.tokenizer = tokenizer 20 | self.davinci = DaVinci(encoder, text_decoder, tokenizer, config, init_deit=False, init_dalle=True) 21 | emb_dim = self.davinci.config_decoder.hidden_size 22 | self.fc = nn.Linear(emb_dim * 3, n_labels) 23 | 24 | def forward(self, image, text=None, train=True): 25 | dummy_text = self.tokenizer([""] * image.size(0), return_tensors='pt').to(image.device) 26 | text_inputs = self.tokenizer(["a picture of "]*image.size(0), return_tensors="pt").to(image.device) 27 | last_state_ids = text_inputs.attention_mask.sum(1) - self.last_hidden_id_shift 28 | hidden_states = self.davinci(image, 29 | dummy_text, 30 | text_inputs, 31 | last_state_ids = last_state_ids, 32 | imagenet=True, 33 | train=train, decode=False) 34 | logits = self.fc(hidden_states) 35 | return logits 36 | -------------------------------------------------------------------------------- /models/model_nlvr.py: -------------------------------------------------------------------------------- 1 | # Write and Paint: Generative Vision-Language Models are Unified Modal Learners (https://arxiv.org/abs/2206.07699) 2 | # Github: https://github.com/shizhediao/DaVinci 3 | # Copyright (c) 2023, ByteDance Inc. 4 | # All rights reserved. 
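# Note on the class below: the NLVR statement is encoded twice, once with each
# of the two input images (shared DaVinci backbone, empty dummy context); the
# two pooled outputs are concatenated and passed to a two-way classification MLP.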
5 | 6 | from models.xbert import BertConfig, BertModel 7 | from models.davinci_pretrain import DaVinci 8 | 9 | import torch 10 | from torch import nn 11 | import torch.nn.functional as F 12 | 13 | class DaVinciNLVR(nn.Module): 14 | def __init__(self, 15 | encoder = None, 16 | text_decoder = None, 17 | tokenizer = None, 18 | config = None, 19 | ): 20 | super().__init__() 21 | self.last_hidden_id_shift = config['last_hidden_id_shift'] 22 | self.tokenizer = tokenizer 23 | self.davinci = DaVinci(encoder, text_decoder, tokenizer, config, init_deit=False, init_dalle=True) 24 | bert_config = BertConfig.from_json_file(config['bert_config']) 25 | self.cls_head = nn.Sequential( 26 | nn.Linear(bert_config.hidden_size * 2, bert_config.hidden_size), 27 | nn.ReLU(), 28 | nn.Linear(bert_config.hidden_size, 2) 29 | ) 30 | 31 | def forward(self, image0, image1, text, targets, alpha=0, train=True): 32 | dummy_input = self.tokenizer([""] * image0.size(0), return_tensors='pt').to(image0.device) 33 | last_state_ids = text.attention_mask.sum(1) - self.last_hidden_id_shift 34 | output0 = self.davinci(image0, 35 | dummy_input, 36 | text, 37 | last_state_ids = last_state_ids, 38 | is_nlvr = True, 39 | train=train, decode=False) 40 | output1 = self.davinci(image1, 41 | dummy_input, 42 | text, 43 | last_state_ids = last_state_ids, 44 | is_nlvr = True, 45 | train=train, decode=False) 46 | 47 | hidden_state = torch.cat([output0, output1], dim=1) 48 | prediction = self.cls_head(hidden_state) 49 | 50 | if train: 51 | loss = F.cross_entropy(prediction, targets) 52 | return loss 53 | else: 54 | return prediction 55 | -------------------------------------------------------------------------------- /models/model_ve.py: -------------------------------------------------------------------------------- 1 | # Write and Paint: Generative Vision-Language Models are Unified Modal Learners (https://arxiv.org/abs/2206.07699) 2 | # Github: https://github.com/shizhediao/DaVinci 3 | # Copyright (c) 2023, ByteDance Inc. 4 | # All rights reserved. 
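# Note on the class below: for visual entailment the image and the hypothesis
# sentence go through the shared DaVinci backbone (empty dummy context), and the
# hidden state at the last non-padded token is mapped by an MLP to the three
# entailment labels, trained with cross-entropy.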
5 | 6 | from models.xbert import BertConfig, BertModel 7 | from models.davinci_pretrain import DaVinci 8 | 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | class DaVinciVE(nn.Module): 13 | def __init__(self, 14 | encoder = None, 15 | text_decoder = None, 16 | tokenizer = None, 17 | config = None, 18 | ): 19 | super().__init__() 20 | self.last_hidden_id_shift = config['last_hidden_id_shift'] 21 | self.tokenizer = tokenizer 22 | self.davinci = DaVinci(encoder, text_decoder, tokenizer, config, init_deit=False, init_dalle=True) 23 | bert_config = BertConfig.from_json_file(config['bert_config']) 24 | self.cls_head = nn.Sequential( 25 | nn.Linear(bert_config.hidden_size, bert_config.hidden_size), 26 | nn.ReLU(), 27 | nn.Linear(bert_config.hidden_size, 3) 28 | ) 29 | 30 | 31 | def forward(self, image, text, targets, alpha=0, train=True): 32 | dummy_input = self.tokenizer([""] * image.size(0), return_tensors='pt').to(image.device) 33 | last_state_ids = text.attention_mask.sum(1) - self.last_hidden_id_shift 34 | output = self.davinci(image, 35 | dummy_input, 36 | text, 37 | last_state_ids = last_state_ids, 38 | is_ve = True, 39 | train=train, decode=False) 40 | prediction = self.cls_head(output) 41 | if train: 42 | loss = F.cross_entropy(prediction, targets) 43 | return loss 44 | else: 45 | return prediction 46 | -------------------------------------------------------------------------------- /models/model_vqa.py: -------------------------------------------------------------------------------- 1 | # Write and Paint: Generative Vision-Language Models are Unified Modal Learners (https://arxiv.org/abs/2206.07699) 2 | # Github: https://github.com/shizhediao/DaVinci 3 | # Copyright (c) 2023, ByteDance Inc. 4 | # All rights reserved. 5 | 6 | from models.davinci_pretrain import DaVinci 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | import numpy as np 13 | 14 | class DaVinciVQA(nn.Module): 15 | def __init__(self, 16 | encoder = None, 17 | text_decoder = None, 18 | tokenizer = None, 19 | config = None, 20 | ): 21 | super().__init__() 22 | 23 | self.tokenizer = tokenizer 24 | self.davinci = DaVinci(encoder, text_decoder, tokenizer, config, init_deit=False, init_dalle=True) 25 | 26 | def forward(self, image, quesiton, answer=None, alpha=0, k=None, weights=None, train=True): 27 | if train: 28 | loss, logits = self.davinci(image, 29 | quesiton, 30 | answer, 31 | is_vqa = True, 32 | k = k, 33 | train=train, decode=False, weights=weights) 34 | return loss 35 | else: 36 | topk_ids, topk_probs = self.davinci(image, 37 | quesiton, 38 | answer, 39 | is_vqa = True, 40 | k = k, 41 | train=train, decode=False) 42 | return topk_ids, topk_probs 43 | -------------------------------------------------------------------------------- /optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adamp import AdamP 2 | from .adamw import AdamW 3 | from .adafactor import Adafactor 4 | from .adahessian import Adahessian 5 | from .lookahead import Lookahead 6 | from .nadam import Nadam 7 | from .novograd import NovoGrad 8 | from .nvnovograd import NvNovoGrad 9 | from .radam import RAdam 10 | from .rmsprop_tf import RMSpropTF 11 | from .sgdp import SGDP 12 | 13 | from .optim_factory import create_optimizer -------------------------------------------------------------------------------- /optim/adamp.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdamP Optimizer 
Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | class AdamP(Optimizer): 17 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 18 | weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False): 19 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 20 | delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) 21 | super(AdamP, self).__init__(params, defaults) 22 | 23 | def _channel_view(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | def _layer_view(self, x): 27 | return x.view(1, -1) 28 | 29 | def _cosine_similarity(self, x, y, eps, view_func): 30 | x = view_func(x) 31 | y = view_func(y) 32 | 33 | x_norm = x.norm(dim=1).add_(eps) 34 | y_norm = y.norm(dim=1).add_(eps) 35 | dot = (x * y).sum(dim=1) 36 | 37 | return dot.abs() / x_norm / y_norm 38 | 39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 40 | wd = 1 41 | expand_size = [-1] + [1] * (len(p.shape) - 1) 42 | for view_func in [self._channel_view, self._layer_view]: 43 | 44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 45 | 46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 49 | wd = wd_ratio 50 | 51 | return perturb, wd 52 | 53 | return perturb, wd 54 | 55 | def step(self, closure=None): 56 | loss = None 57 | if closure is not None: 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | if p.grad is None: 63 | continue 64 | 65 | grad = p.grad.data 66 | beta1, beta2 = group['betas'] 67 | nesterov = group['nesterov'] 68 | 69 | state = self.state[p] 70 | 71 | # State initialization 72 | if len(state) == 0: 73 | state['step'] = 0 74 | state['exp_avg'] = torch.zeros_like(p.data) 75 | state['exp_avg_sq'] = torch.zeros_like(p.data) 76 | 77 | # Adam 78 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 79 | 80 | state['step'] += 1 81 | bias_correction1 = 1 - beta1 ** state['step'] 82 | bias_correction2 = 1 - beta2 ** state['step'] 83 | 84 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 85 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 86 | 87 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 88 | step_size = group['lr'] / bias_correction1 89 | 90 | if nesterov: 91 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom 92 | else: 93 | perturb = exp_avg / denom 94 | 95 | # Projection 96 | wd_ratio = 1 97 | if len(p.shape) > 1: 98 | perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps']) 99 | 100 | # Weight decay 101 | if group['weight_decay'] > 0: 102 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio) 103 | 104 | # Step 105 | p.data.add_(-step_size, perturb) 106 | 107 | return loss 108 | -------------------------------------------------------------------------------- /optim/adamw.py: -------------------------------------------------------------------------------- 1 | """ AdamW Optimizer 2 | Impl copied from PyTorch master 3 | """ 4 | import math 
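# Illustrative usage sketch (names such as `model` are assumed, not from this file):
#   optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
# Note that step() applies the decoupled weight decay directly to p.data before
# the Adam update, as described in the Decoupled Weight Decay Regularization paper.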
5 | import torch 6 | from torch.optim.optimizer import Optimizer 7 | 8 | 9 | class AdamW(Optimizer): 10 | r"""Implements AdamW algorithm. 11 | 12 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. 13 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 14 | 15 | Arguments: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float, optional): learning rate (default: 1e-3) 19 | betas (Tuple[float, float], optional): coefficients used for computing 20 | running averages of gradient and its square (default: (0.9, 0.999)) 21 | eps (float, optional): term added to the denominator to improve 22 | numerical stability (default: 1e-8) 23 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 25 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 26 | (default: False) 27 | 28 | .. _Adam\: A Method for Stochastic Optimization: 29 | https://arxiv.org/abs/1412.6980 30 | .. _Decoupled Weight Decay Regularization: 31 | https://arxiv.org/abs/1711.05101 32 | .. _On the Convergence of Adam and Beyond: 33 | https://openreview.net/forum?id=ryQu7f-RZ 34 | """ 35 | 36 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 37 | weight_decay=1e-2, amsgrad=False): 38 | if not 0.0 <= lr: 39 | raise ValueError("Invalid learning rate: {}".format(lr)) 40 | if not 0.0 <= eps: 41 | raise ValueError("Invalid epsilon value: {}".format(eps)) 42 | if not 0.0 <= betas[0] < 1.0: 43 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 44 | if not 0.0 <= betas[1] < 1.0: 45 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 46 | defaults = dict(lr=lr, betas=betas, eps=eps, 47 | weight_decay=weight_decay, amsgrad=amsgrad) 48 | super(AdamW, self).__init__(params, defaults) 49 | 50 | def __setstate__(self, state): 51 | super(AdamW, self).__setstate__(state) 52 | for group in self.param_groups: 53 | group.setdefault('amsgrad', False) 54 | 55 | def step(self, closure=None): 56 | """Performs a single optimization step. 57 | 58 | Arguments: 59 | closure (callable, optional): A closure that reevaluates the model 60 | and returns the loss. 61 | """ 62 | loss = None 63 | if closure is not None: 64 | loss = closure() 65 | 66 | for group in self.param_groups: 67 | for p in group['params']: 68 | if p.grad is None: 69 | continue 70 | 71 | # Perform stepweight decay 72 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 73 | 74 | # Perform optimization step 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 78 | amsgrad = group['amsgrad'] 79 | 80 | state = self.state[p] 81 | 82 | # State initialization 83 | if len(state) == 0: 84 | state['step'] = 0 85 | # Exponential moving average of gradient values 86 | state['exp_avg'] = torch.zeros_like(p.data) 87 | # Exponential moving average of squared gradient values 88 | state['exp_avg_sq'] = torch.zeros_like(p.data) 89 | if amsgrad: 90 | # Maintains max of all exp. moving avg. of sq. grad. 
values 91 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 92 | 93 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 94 | if amsgrad: 95 | max_exp_avg_sq = state['max_exp_avg_sq'] 96 | beta1, beta2 = group['betas'] 97 | 98 | state['step'] += 1 99 | bias_correction1 = 1 - beta1 ** state['step'] 100 | bias_correction2 = 1 - beta2 ** state['step'] 101 | 102 | # Decay the first and second moment running average coefficient 103 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 104 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 105 | if amsgrad: 106 | # Maintains the maximum of all 2nd moment running avg. till now 107 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 108 | # Use the max. for normalizing running avg. of gradient 109 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 110 | else: 111 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 112 | 113 | step_size = group['lr'] / bias_correction1 114 | 115 | p.data.addcdiv_(-step_size, exp_avg, denom) 116 | 117 | return loss 118 | -------------------------------------------------------------------------------- /optim/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /optim/lookahead.py: -------------------------------------------------------------------------------- 1 | """ Lookahead Optimizer Wrapper. 
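Wraps a base optimizer: every `lookahead_k` steps the fast weights are pulled
toward a slow copy by the interpolation factor `lookahead_alpha` (see `update_slow`).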
2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | from collections import defaultdict 10 | 11 | 12 | class Lookahead(Optimizer): 13 | def __init__(self, base_optimizer, alpha=0.5, k=6): 14 | if not 0.0 <= alpha <= 1.0: 15 | raise ValueError(f'Invalid slow update rate: {alpha}') 16 | if not 1 <= k: 17 | raise ValueError(f'Invalid lookahead steps: {k}') 18 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 19 | self.base_optimizer = base_optimizer 20 | self.param_groups = self.base_optimizer.param_groups 21 | self.defaults = base_optimizer.defaults 22 | self.defaults.update(defaults) 23 | self.state = defaultdict(dict) 24 | # manually add our defaults to the param groups 25 | for name, default in defaults.items(): 26 | for group in self.param_groups: 27 | group.setdefault(name, default) 28 | 29 | def update_slow(self, group): 30 | for fast_p in group["params"]: 31 | if fast_p.grad is None: 32 | continue 33 | param_state = self.state[fast_p] 34 | if 'slow_buffer' not in param_state: 35 | param_state['slow_buffer'] = torch.empty_like(fast_p.data) 36 | param_state['slow_buffer'].copy_(fast_p.data) 37 | slow = param_state['slow_buffer'] 38 | slow.add_(group['lookahead_alpha'], fast_p.data - slow) 39 | fast_p.data.copy_(slow) 40 | 41 | def sync_lookahead(self): 42 | for group in self.param_groups: 43 | self.update_slow(group) 44 | 45 | def step(self, closure=None): 46 | #assert id(self.param_groups) == id(self.base_optimizer.param_groups) 47 | loss = self.base_optimizer.step(closure) 48 | for group in self.param_groups: 49 | group['lookahead_step'] += 1 50 | if group['lookahead_step'] % group['lookahead_k'] == 0: 51 | self.update_slow(group) 52 | return loss 53 | 54 | def state_dict(self): 55 | fast_state_dict = self.base_optimizer.state_dict() 56 | slow_state = { 57 | (id(k) if isinstance(k, torch.Tensor) else k): v 58 | for k, v in self.state.items() 59 | } 60 | fast_state = fast_state_dict['state'] 61 | param_groups = fast_state_dict['param_groups'] 62 | return { 63 | 'state': fast_state, 64 | 'slow_state': slow_state, 65 | 'param_groups': param_groups, 66 | } 67 | 68 | def load_state_dict(self, state_dict): 69 | fast_state_dict = { 70 | 'state': state_dict['state'], 71 | 'param_groups': state_dict['param_groups'], 72 | } 73 | self.base_optimizer.load_state_dict(fast_state_dict) 74 | 75 | # We want to restore the slow state, but share param_groups reference 76 | # with base_optimizer. 
This is a bit redundant but least code 77 | slow_state_new = False 78 | if 'slow_state' not in state_dict: 79 | print('Loading state_dict from optimizer without Lookahead applied.') 80 | state_dict['slow_state'] = defaultdict(dict) 81 | slow_state_new = True 82 | slow_state_dict = { 83 | 'state': state_dict['slow_state'], 84 | 'param_groups': state_dict['param_groups'], # this is pointless but saves code 85 | } 86 | super(Lookahead, self).load_state_dict(slow_state_dict) 87 | self.param_groups = self.base_optimizer.param_groups # make both ref same container 88 | if slow_state_new: 89 | # reapply defaults to catch missing lookahead specific ones 90 | for name, default in self.defaults.items(): 91 | for group in self.param_groups: 92 | group.setdefault(name, default) 93 | -------------------------------------------------------------------------------- /optim/nadam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | 5 | class Nadam(Optimizer): 6 | """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). 7 | 8 | It has been proposed in `Incorporating Nesterov Momentum into Adam`__. 9 | 10 | Arguments: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float, optional): learning rate (default: 2e-3) 14 | betas (Tuple[float, float], optional): coefficients used for computing 15 | running averages of gradient and its square 16 | eps (float, optional): term added to the denominator to improve 17 | numerical stability (default: 1e-8) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | schedule_decay (float, optional): momentum schedule decay (default: 4e-3) 20 | 21 | __ http://cs229.stanford.edu/proj2015/054_report.pdf 22 | __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf 23 | 24 | Originally taken from: https://github.com/pytorch/pytorch/pull/1408 25 | NOTE: Has potential issues but does work well on some problems. 26 | """ 27 | 28 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, 29 | weight_decay=0, schedule_decay=4e-3): 30 | defaults = dict(lr=lr, betas=betas, eps=eps, 31 | weight_decay=weight_decay, schedule_decay=schedule_decay) 32 | super(Nadam, self).__init__(params, defaults) 33 | 34 | def step(self, closure=None): 35 | """Performs a single optimization step. 36 | 37 | Arguments: 38 | closure (callable, optional): A closure that reevaluates the model 39 | and returns the loss. 40 | """ 41 | loss = None 42 | if closure is not None: 43 | loss = closure() 44 | 45 | for group in self.param_groups: 46 | for p in group['params']: 47 | if p.grad is None: 48 | continue 49 | grad = p.grad.data 50 | state = self.state[p] 51 | 52 | # State initialization 53 | if len(state) == 0: 54 | state['step'] = 0 55 | state['m_schedule'] = 1. 56 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 57 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 58 | 59 | # Warming momentum schedule 60 | m_schedule = state['m_schedule'] 61 | schedule_decay = group['schedule_decay'] 62 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 63 | beta1, beta2 = group['betas'] 64 | eps = group['eps'] 65 | state['step'] += 1 66 | t = state['step'] 67 | 68 | if group['weight_decay'] != 0: 69 | grad = grad.add(group['weight_decay'], p.data) 70 | 71 | momentum_cache_t = beta1 * \ 72 | (1. - 0.5 * (0.96 ** (t * schedule_decay))) 73 | momentum_cache_t_1 = beta1 * \ 74 | (1. 
- 0.5 * (0.96 ** ((t + 1) * schedule_decay))) 75 | m_schedule_new = m_schedule * momentum_cache_t 76 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 77 | state['m_schedule'] = m_schedule_new 78 | 79 | # Decay the first and second moment running average coefficient 80 | exp_avg.mul_(beta1).add_(1. - beta1, grad) 81 | exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad) 82 | exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t) 83 | denom = exp_avg_sq_prime.sqrt_().add_(eps) 84 | 85 | p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom) 86 | p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom) 87 | 88 | return loss 89 | -------------------------------------------------------------------------------- /optim/novograd.py: -------------------------------------------------------------------------------- 1 | """NovoGrad Optimizer. 2 | Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd 3 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` 4 | - https://arxiv.org/abs/1905.11286 5 | """ 6 | 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | import math 10 | 11 | 12 | class NovoGrad(Optimizer): 13 | def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0): 14 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 15 | super(NovoGrad, self).__init__(params, defaults) 16 | self._lr = lr 17 | self._beta1 = betas[0] 18 | self._beta2 = betas[1] 19 | self._eps = eps 20 | self._wd = weight_decay 21 | self._grad_averaging = grad_averaging 22 | 23 | self._momentum_initialized = False 24 | 25 | def step(self, closure=None): 26 | loss = None 27 | if closure is not None: 28 | loss = closure() 29 | 30 | if not self._momentum_initialized: 31 | for group in self.param_groups: 32 | for p in group['params']: 33 | if p.grad is None: 34 | continue 35 | state = self.state[p] 36 | grad = p.grad.data 37 | if grad.is_sparse: 38 | raise RuntimeError('NovoGrad does not support sparse gradients') 39 | 40 | v = torch.norm(grad)**2 41 | m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data 42 | state['step'] = 0 43 | state['v'] = v 44 | state['m'] = m 45 | state['grad_ema'] = None 46 | self._momentum_initialized = True 47 | 48 | for group in self.param_groups: 49 | for p in group['params']: 50 | if p.grad is None: 51 | continue 52 | state = self.state[p] 53 | state['step'] += 1 54 | 55 | step, v, m = state['step'], state['v'], state['m'] 56 | grad_ema = state['grad_ema'] 57 | 58 | grad = p.grad.data 59 | g2 = torch.norm(grad)**2 60 | grad_ema = g2 if grad_ema is None else grad_ema * \ 61 | self._beta2 + g2 * (1. - self._beta2) 62 | grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps) 63 | 64 | if self._grad_averaging: 65 | grad *= (1. - self._beta1) 66 | 67 | g2 = torch.norm(grad)**2 68 | v = self._beta2*v + (1. 
- self._beta2)*g2 69 | m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data) 70 | bias_correction1 = 1 - self._beta1 ** step 71 | bias_correction2 = 1 - self._beta2 ** step 72 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 73 | 74 | state['v'], state['m'] = v, m 75 | state['grad_ema'] = grad_ema 76 | p.data.add_(-step_size, m) 77 | return loss 78 | -------------------------------------------------------------------------------- /optim/nvnovograd.py: -------------------------------------------------------------------------------- 1 | """ Nvidia NovoGrad Optimizer. 2 | Original impl by Nvidia from Jasper example: 3 | - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper 4 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` 5 | - https://arxiv.org/abs/1905.11286 6 | """ 7 | 8 | import torch 9 | from torch.optim.optimizer import Optimizer 10 | import math 11 | 12 | 13 | class NvNovoGrad(Optimizer): 14 | """ 15 | Implements Novograd algorithm. 16 | 17 | Args: 18 | params (iterable): iterable of parameters to optimize or dicts defining 19 | parameter groups 20 | lr (float, optional): learning rate (default: 1e-3) 21 | betas (Tuple[float, float], optional): coefficients used for computing 22 | running averages of gradient and its square (default: (0.95, 0.98)) 23 | eps (float, optional): term added to the denominator to improve 24 | numerical stability (default: 1e-8) 25 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 26 | grad_averaging: gradient averaging 27 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 28 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 29 | (default: False) 30 | """ 31 | 32 | def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, 33 | weight_decay=0, grad_averaging=False, amsgrad=False): 34 | if not 0.0 <= lr: 35 | raise ValueError("Invalid learning rate: {}".format(lr)) 36 | if not 0.0 <= eps: 37 | raise ValueError("Invalid epsilon value: {}".format(eps)) 38 | if not 0.0 <= betas[0] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 40 | if not 0.0 <= betas[1] < 1.0: 41 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 42 | defaults = dict(lr=lr, betas=betas, eps=eps, 43 | weight_decay=weight_decay, 44 | grad_averaging=grad_averaging, 45 | amsgrad=amsgrad) 46 | 47 | super(NvNovoGrad, self).__init__(params, defaults) 48 | 49 | def __setstate__(self, state): 50 | super(NvNovoGrad, self).__setstate__(state) 51 | for group in self.param_groups: 52 | group.setdefault('amsgrad', False) 53 | 54 | def step(self, closure=None): 55 | """Performs a single optimization step. 56 | 57 | Arguments: 58 | closure (callable, optional): A closure that reevaluates the model 59 | and returns the loss. 
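        Note: unlike Adam, the second-moment state kept below is a single scalar per parameter
        tensor (a running average of the squared gradient norm), which is what makes NovoGrad
        a layer-wise adaptive method.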
60 | """ 61 | loss = None 62 | if closure is not None: 63 | loss = closure() 64 | 65 | for group in self.param_groups: 66 | for p in group['params']: 67 | if p.grad is None: 68 | continue 69 | grad = p.grad.data 70 | if grad.is_sparse: 71 | raise RuntimeError('Sparse gradients are not supported.') 72 | amsgrad = group['amsgrad'] 73 | 74 | state = self.state[p] 75 | 76 | # State initialization 77 | if len(state) == 0: 78 | state['step'] = 0 79 | # Exponential moving average of gradient values 80 | state['exp_avg'] = torch.zeros_like(p.data) 81 | # Exponential moving average of squared gradient values 82 | state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 83 | if amsgrad: 84 | # Maintains max of all exp. moving avg. of sq. grad. values 85 | state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 86 | 87 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 88 | if amsgrad: 89 | max_exp_avg_sq = state['max_exp_avg_sq'] 90 | beta1, beta2 = group['betas'] 91 | 92 | state['step'] += 1 93 | 94 | norm = torch.sum(torch.pow(grad, 2)) 95 | 96 | if exp_avg_sq == 0: 97 | exp_avg_sq.copy_(norm) 98 | else: 99 | exp_avg_sq.mul_(beta2).add_(1 - beta2, norm) 100 | 101 | if amsgrad: 102 | # Maintains the maximum of all 2nd moment running avg. till now 103 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 104 | # Use the max. for normalizing running avg. of gradient 105 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 106 | else: 107 | denom = exp_avg_sq.sqrt().add_(group['eps']) 108 | 109 | grad.div_(denom) 110 | if group['weight_decay'] != 0: 111 | grad.add_(group['weight_decay'], p.data) 112 | if group['grad_averaging']: 113 | grad.mul_(1 - beta1) 114 | exp_avg.mul_(beta1).add_(grad) 115 | 116 | p.data.add_(-group['lr'], exp_avg) 117 | 118 | return loss 119 | -------------------------------------------------------------------------------- /optim/optim_factory.py: -------------------------------------------------------------------------------- 1 | """ Optimizer Factory w/ Custom Weight Decay 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch 5 | from torch import optim as optim 6 | 7 | from .adafactor import Adafactor 8 | from .adahessian import Adahessian 9 | from .adamp import AdamP 10 | from .lookahead import Lookahead 11 | from .nadam import Nadam 12 | from .novograd import NovoGrad 13 | from .nvnovograd import NvNovoGrad 14 | from .radam import RAdam 15 | from .rmsprop_tf import RMSpropTF 16 | from .sgdp import SGDP 17 | 18 | try: 19 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 20 | has_apex = True 21 | except ImportError: 22 | has_apex = False 23 | 24 | 25 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()): 26 | decay = [] 27 | no_decay = [] 28 | for name, param in model.named_parameters(): 29 | if not param.requires_grad: 30 | continue # frozen weights 31 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 32 | no_decay.append(param) 33 | else: 34 | decay.append(param) 35 | return [ 36 | {'params': no_decay, 'weight_decay': 0.}, 37 | {'params': decay, 'weight_decay': weight_decay}] 38 | 39 | 40 | def create_optimizer(args, model, filter_bias_and_bn=True): 41 | opt_lower = args.opt.lower() 42 | weight_decay = args.weight_decay 43 | if weight_decay and filter_bias_and_bn: 44 | skip = {} 45 | if hasattr(model, 'no_weight_decay'): 46 | skip = model.no_weight_decay() 47 | parameters = add_weight_decay(model, weight_decay, skip) 48 | weight_decay = 0. 
49 | else: 50 | parameters = model.parameters() 51 | 52 | if 'fused' in opt_lower: 53 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 54 | 55 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 56 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 57 | opt_args['eps'] = args.opt_eps 58 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 59 | opt_args['betas'] = args.opt_betas 60 | if hasattr(args, 'opt_args') and args.opt_args is not None: 61 | opt_args.update(args.opt_args) 62 | 63 | opt_split = opt_lower.split('_') 64 | opt_lower = opt_split[-1] 65 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 66 | opt_args.pop('eps', None) 67 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 68 | elif opt_lower == 'momentum': 69 | opt_args.pop('eps', None) 70 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 71 | elif opt_lower == 'adam': 72 | optimizer = optim.Adam(parameters, **opt_args) 73 | elif opt_lower == 'adamw': 74 | optimizer = optim.AdamW(parameters, **opt_args) 75 | elif opt_lower == 'nadam': 76 | optimizer = Nadam(parameters, **opt_args) 77 | elif opt_lower == 'radam': 78 | optimizer = RAdam(parameters, **opt_args) 79 | elif opt_lower == 'adamp': 80 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 81 | elif opt_lower == 'sgdp': 82 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 83 | elif opt_lower == 'adadelta': 84 | optimizer = optim.Adadelta(parameters, **opt_args) 85 | elif opt_lower == 'adafactor': 86 | if not args.lr: 87 | opt_args['lr'] = None 88 | optimizer = Adafactor(parameters, **opt_args) 89 | elif opt_lower == 'adahessian': 90 | optimizer = Adahessian(parameters, **opt_args) 91 | elif opt_lower == 'rmsprop': 92 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 93 | elif opt_lower == 'rmsproptf': 94 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 95 | elif opt_lower == 'novograd': 96 | optimizer = NovoGrad(parameters, **opt_args) 97 | elif opt_lower == 'nvnovograd': 98 | optimizer = NvNovoGrad(parameters, **opt_args) 99 | elif opt_lower == 'fusedsgd': 100 | opt_args.pop('eps', None) 101 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 102 | elif opt_lower == 'fusedmomentum': 103 | opt_args.pop('eps', None) 104 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 105 | elif opt_lower == 'fusedadam': 106 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 107 | elif opt_lower == 'fusedadamw': 108 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 109 | elif opt_lower == 'fusedlamb': 110 | optimizer = FusedLAMB(parameters, **opt_args) 111 | elif opt_lower == 'fusednovograd': 112 | opt_args.setdefault('betas', (0.95, 0.98)) 113 | optimizer = FusedNovoGrad(parameters, **opt_args) 114 | else: 115 | assert False and "Invalid optimizer" 116 | raise ValueError 117 | 118 | if len(opt_split) > 1: 119 | if opt_split[0] == 'lookahead': 120 | optimizer = Lookahead(optimizer) 121 | 122 | return optimizer 123 | -------------------------------------------------------------------------------- /optim/radam.py: -------------------------------------------------------------------------------- 1 | """RAdam Optimizer. 
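RAdam rectifies the variance of the adaptive learning rate: while the approximated SMA length
N_sma stays below 5 the update falls back to a plain momentum step, and the adaptive
denominator is only used once enough steps have accumulated.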
2 | Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam 3 | Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265 4 | """ 5 | import math 6 | import torch 7 | from torch.optim.optimizer import Optimizer, required 8 | 9 | 10 | class RAdam(Optimizer): 11 | 12 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 13 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 14 | self.buffer = [[None, None, None] for ind in range(10)] 15 | super(RAdam, self).__init__(params, defaults) 16 | 17 | def __setstate__(self, state): 18 | super(RAdam, self).__setstate__(state) 19 | 20 | def step(self, closure=None): 21 | 22 | loss = None 23 | if closure is not None: 24 | loss = closure() 25 | 26 | for group in self.param_groups: 27 | 28 | for p in group['params']: 29 | if p.grad is None: 30 | continue 31 | grad = p.grad.data.float() 32 | if grad.is_sparse: 33 | raise RuntimeError('RAdam does not support sparse gradients') 34 | 35 | p_data_fp32 = p.data.float() 36 | 37 | state = self.state[p] 38 | 39 | if len(state) == 0: 40 | state['step'] = 0 41 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 42 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 43 | else: 44 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 45 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 46 | 47 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 48 | beta1, beta2 = group['betas'] 49 | 50 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 51 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 52 | 53 | state['step'] += 1 54 | buffered = self.buffer[int(state['step'] % 10)] 55 | if state['step'] == buffered[0]: 56 | N_sma, step_size = buffered[1], buffered[2] 57 | else: 58 | buffered[0] = state['step'] 59 | beta2_t = beta2 ** state['step'] 60 | N_sma_max = 2 / (1 - beta2) - 1 61 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 62 | buffered[1] = N_sma 63 | 64 | # more conservative since it's an approximated value 65 | if N_sma >= 5: 66 | step_size = group['lr'] * math.sqrt( 67 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 68 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 69 | else: 70 | step_size = group['lr'] / (1 - beta1 ** state['step']) 71 | buffered[2] = step_size 72 | 73 | if group['weight_decay'] != 0: 74 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 75 | 76 | # more conservative since it's an approximated value 77 | if N_sma >= 5: 78 | denom = exp_avg_sq.sqrt().add_(group['eps']) 79 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 80 | else: 81 | p_data_fp32.add_(-step_size, exp_avg) 82 | 83 | p.data.copy_(p_data_fp32) 84 | 85 | return loss 86 | 87 | 88 | class PlainRAdam(Optimizer): 89 | 90 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 91 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 92 | 93 | super(PlainRAdam, self).__init__(params, defaults) 94 | 95 | def __setstate__(self, state): 96 | super(PlainRAdam, self).__setstate__(state) 97 | 98 | def step(self, closure=None): 99 | 100 | loss = None 101 | if closure is not None: 102 | loss = closure() 103 | 104 | for group in self.param_groups: 105 | 106 | for p in group['params']: 107 | if p.grad is None: 108 | continue 109 | grad = p.grad.data.float() 110 | if grad.is_sparse: 111 | raise RuntimeError('RAdam does not support sparse gradients') 112 | 113 | p_data_fp32 = 
p.data.float() 114 | 115 | state = self.state[p] 116 | 117 | if len(state) == 0: 118 | state['step'] = 0 119 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 120 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 121 | else: 122 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 123 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 124 | 125 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 126 | beta1, beta2 = group['betas'] 127 | 128 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 129 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 130 | 131 | state['step'] += 1 132 | beta2_t = beta2 ** state['step'] 133 | N_sma_max = 2 / (1 - beta2) - 1 134 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 135 | 136 | if group['weight_decay'] != 0: 137 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 138 | 139 | # more conservative since it's an approximated value 140 | if N_sma >= 5: 141 | step_size = group['lr'] * math.sqrt( 142 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 143 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 144 | denom = exp_avg_sq.sqrt().add_(group['eps']) 145 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 146 | else: 147 | step_size = group['lr'] / (1 - beta1 ** state['step']) 148 | p_data_fp32.add_(-step_size, exp_avg) 149 | 150 | p.data.copy_(p_data_fp32) 151 | 152 | return loss 153 | -------------------------------------------------------------------------------- /optim/rmsprop_tf.py: -------------------------------------------------------------------------------- 1 | """ RMSProp modified to behave like Tensorflow impl 2 | 3 | Originally cut & paste from PyTorch RMSProp 4 | https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py 5 | Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE 6 | 7 | Modifications Copyright 2020 Ross Wightman 8 | """ 9 | 10 | import torch 11 | from torch.optim import Optimizer 12 | 13 | 14 | class RMSpropTF(Optimizer): 15 | """Implements RMSprop algorithm (TensorFlow style epsilon) 16 | 17 | NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt 18 | and a few other modifications to closer match Tensorflow for matching hyper-params. 19 | 20 | Noteworthy changes include: 21 | 1. Epsilon applied inside square-root 22 | 2. square_avg initialized to ones 23 | 3. LR scaling of update accumulated in momentum buffer 24 | 25 | Proposed by G. Hinton in his 26 | `course `_. 27 | 28 | The centered version first appears in `Generating Sequences 29 | With Recurrent Neural Networks `_. 
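    Construction sketch (illustrative; assumes `model` is an nn.Module):
        optimizer = RMSpropTF(model.parameters(), lr=1e-2, alpha=0.9, eps=1e-10, momentum=0.9)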
30 | 31 | Arguments: 32 | params (iterable): iterable of parameters to optimize or dicts defining 33 | parameter groups 34 | lr (float, optional): learning rate (default: 1e-2) 35 | momentum (float, optional): momentum factor (default: 0) 36 | alpha (float, optional): smoothing (decay) constant (default: 0.9) 37 | eps (float, optional): term added to the denominator to improve 38 | numerical stability (default: 1e-10) 39 | centered (bool, optional) : if ``True``, compute the centered RMSProp, 40 | the gradient is normalized by an estimation of its variance 41 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 42 | decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101 43 | lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer 44 | update as per defaults in Tensorflow 45 | 46 | """ 47 | 48 | def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, 49 | decoupled_decay=False, lr_in_momentum=True): 50 | if not 0.0 <= lr: 51 | raise ValueError("Invalid learning rate: {}".format(lr)) 52 | if not 0.0 <= eps: 53 | raise ValueError("Invalid epsilon value: {}".format(eps)) 54 | if not 0.0 <= momentum: 55 | raise ValueError("Invalid momentum value: {}".format(momentum)) 56 | if not 0.0 <= weight_decay: 57 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 58 | if not 0.0 <= alpha: 59 | raise ValueError("Invalid alpha value: {}".format(alpha)) 60 | 61 | defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay, 62 | decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum) 63 | super(RMSpropTF, self).__init__(params, defaults) 64 | 65 | def __setstate__(self, state): 66 | super(RMSpropTF, self).__setstate__(state) 67 | for group in self.param_groups: 68 | group.setdefault('momentum', 0) 69 | group.setdefault('centered', False) 70 | 71 | def step(self, closure=None): 72 | """Performs a single optimization step. 73 | 74 | Arguments: 75 | closure (callable, optional): A closure that reevaluates the model 76 | and returns the loss. 77 | """ 78 | loss = None 79 | if closure is not None: 80 | loss = closure() 81 | 82 | for group in self.param_groups: 83 | for p in group['params']: 84 | if p.grad is None: 85 | continue 86 | grad = p.grad.data 87 | if grad.is_sparse: 88 | raise RuntimeError('RMSprop does not support sparse gradients') 89 | state = self.state[p] 90 | 91 | # State initialization 92 | if len(state) == 0: 93 | state['step'] = 0 94 | state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero 95 | if group['momentum'] > 0: 96 | state['momentum_buffer'] = torch.zeros_like(p.data) 97 | if group['centered']: 98 | state['grad_avg'] = torch.zeros_like(p.data) 99 | 100 | square_avg = state['square_avg'] 101 | one_minus_alpha = 1. 
- group['alpha'] 102 | 103 | state['step'] += 1 104 | 105 | if group['weight_decay'] != 0: 106 | if 'decoupled_decay' in group and group['decoupled_decay']: 107 | p.data.add_(-group['weight_decay'], p.data) 108 | else: 109 | grad = grad.add(group['weight_decay'], p.data) 110 | 111 | # Tensorflow order of ops for updating squared avg 112 | square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg) 113 | # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original 114 | 115 | if group['centered']: 116 | grad_avg = state['grad_avg'] 117 | grad_avg.add_(one_minus_alpha, grad - grad_avg) 118 | # grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original 119 | avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() # eps moved in sqrt 120 | else: 121 | avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt 122 | 123 | if group['momentum'] > 0: 124 | buf = state['momentum_buffer'] 125 | # Tensorflow accumulates the LR scaling in the momentum buffer 126 | if 'lr_in_momentum' in group and group['lr_in_momentum']: 127 | buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg) 128 | p.data.add_(-buf) 129 | else: 130 | # PyTorch scales the param update by LR 131 | buf.mul_(group['momentum']).addcdiv_(grad, avg) 132 | p.data.add_(-group['lr'], buf) 133 | else: 134 | p.data.addcdiv_(-group['lr'], grad, avg) 135 | 136 | return loss 137 | -------------------------------------------------------------------------------- /optim/sgdp.py: -------------------------------------------------------------------------------- 1 | """ 2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 
8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | class SGDP(Optimizer): 17 | def __init__(self, params, lr=required, momentum=0, dampening=0, 18 | weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): 19 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, 20 | nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) 21 | super(SGDP, self).__init__(params, defaults) 22 | 23 | def _channel_view(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | def _layer_view(self, x): 27 | return x.view(1, -1) 28 | 29 | def _cosine_similarity(self, x, y, eps, view_func): 30 | x = view_func(x) 31 | y = view_func(y) 32 | 33 | x_norm = x.norm(dim=1).add_(eps) 34 | y_norm = y.norm(dim=1).add_(eps) 35 | dot = (x * y).sum(dim=1) 36 | 37 | return dot.abs() / x_norm / y_norm 38 | 39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 40 | wd = 1 41 | expand_size = [-1] + [1] * (len(p.shape) - 1) 42 | for view_func in [self._channel_view, self._layer_view]: 43 | 44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 45 | 46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 49 | wd = wd_ratio 50 | 51 | return perturb, wd 52 | 53 | return perturb, wd 54 | 55 | def step(self, closure=None): 56 | loss = None 57 | if closure is not None: 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | weight_decay = group['weight_decay'] 62 | momentum = group['momentum'] 63 | dampening = group['dampening'] 64 | nesterov = group['nesterov'] 65 | 66 | for p in group['params']: 67 | if p.grad is None: 68 | continue 69 | grad = p.grad.data 70 | state = self.state[p] 71 | 72 | # State initialization 73 | if len(state) == 0: 74 | state['momentum'] = torch.zeros_like(p.data) 75 | 76 | # SGD 77 | buf = state['momentum'] 78 | buf.mul_(momentum).add_(1 - dampening, grad) 79 | if nesterov: 80 | d_p = grad + momentum * buf 81 | else: 82 | d_p = buf 83 | 84 | # Projection 85 | wd_ratio = 1 86 | if len(p.shape) > 1: 87 | d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) 88 | 89 | # Weight decay 90 | if weight_decay != 0: 91 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) 92 | 93 | # Step 94 | p.data.add_(-group['lr'], d_p) 95 | 96 | return loss 97 | -------------------------------------------------------------------------------- /output/pretrain/REAMDE.md: -------------------------------------------------------------------------------- 1 | # placeholder -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | axial_positional_embedding 2 | omegaconf 3 | pytorch_lightning 4 | g-mlp-pytorch 5 | rotary-embedding-torch 6 | timm==0.4.9 7 | transformers==4.10.1 8 | ruamel_yaml 9 | opencv-python 10 | scikit-image 11 | pandas 12 | tqdm 13 | pycocotools 14 | pycocoevalcap 15 | einops 16 | attrs 17 | tensorboardX 18 | pytorch-lightning 19 | protobuf~=3.19.0 20 | accelerate 21 | datasets >= 1.8.0 22 | scipy 23 | scikit-learn -------------------------------------------------------------------------------- /scheduler/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .cosine_lr import CosineLRScheduler 2 | from .plateau_lr import PlateauLRScheduler 3 | from .step_lr import StepLRScheduler 4 | from .tanh_lr import TanhLRScheduler 5 | from .scheduler_factory import create_scheduler 6 | -------------------------------------------------------------------------------- /scheduler/cosine_lr.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | 12 | from .scheduler import Scheduler 13 | 14 | from pdb import set_trace as breakpoint 15 | 16 | _logger = logging.getLogger(__name__) 17 | 18 | 19 | class CosineLRScheduler(Scheduler): 20 | """ 21 | Cosine decay with restarts. 22 | This is described in the paper https://arxiv.org/abs/1608.03983. 23 | 24 | Inspiration from 25 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 26 | """ 27 | 28 | def __init__(self, 29 | optimizer: torch.optim.Optimizer, 30 | t_initial: int, 31 | t_mul: float = 1., 32 | lr_min: float = 0., 33 | decay_rate: float = 1., 34 | warmup_t=0, 35 | warmup_lr_init=0, 36 | warmup_prefix=True, 37 | cycle_limit=0, 38 | t_in_epochs=True, 39 | noise_range_t=None, 40 | noise_pct=0.67, 41 | noise_std=1.0, 42 | noise_seed=42, 43 | initialize=True) -> None: 44 | super().__init__( 45 | optimizer, param_group_field="lr", 46 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 47 | initialize=initialize) 48 | 49 | assert t_initial > 0 50 | assert lr_min >= 0 51 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 52 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 53 | "rate since t_initial = t_mul = eta_mul = 1.") 54 | self.t_initial = t_initial 55 | self.t_mul = t_mul 56 | self.lr_min = lr_min 57 | self.decay_rate = decay_rate 58 | self.cycle_limit = cycle_limit 59 | self.warmup_t = warmup_t 60 | self.warmup_lr_init = warmup_lr_init 61 | self.warmup_prefix = warmup_prefix 62 | self.t_in_epochs = t_in_epochs 63 | if self.warmup_t: 64 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 65 | super().update_groups(self.warmup_lr_init) 66 | else: 67 | self.warmup_steps = [1 for _ in self.base_values] 68 | 69 | def _get_lr(self, t): 70 | if t < self.warmup_t: 71 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 72 | else: 73 | if self.warmup_prefix: 74 | t = t - self.warmup_t 75 | 76 | if self.t_mul != 1: 77 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 78 | t_i = self.t_mul ** i * self.t_initial 79 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 80 | else: 81 | i = t // self.t_initial 82 | t_i = self.t_initial 83 | t_curr = t - (self.t_initial * i) 84 | 85 | gamma = self.decay_rate ** i 86 | lr_min = self.lr_min * gamma 87 | lr_max_values = [v * gamma for v in self.base_values] 88 | 89 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 90 | lrs = [ 91 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 92 | ] 93 | else: 94 | lrs = [self.lr_min for _ in self.base_values] 95 | 96 | return lrs 97 | 98 | def get_epoch_values(self, epoch: int): 99 | 
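        # per-group LRs when the schedule is stepped once per epoch; returning None lets the
        # base Scheduler skip the update so only get_update_values drives per-iteration stepping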
if self.t_in_epochs: 100 | return self._get_lr(epoch) 101 | else: 102 | return None 103 | 104 | def get_update_values(self, num_updates: int): 105 | if not self.t_in_epochs: 106 | return self._get_lr(num_updates) 107 | else: 108 | return None 109 | 110 | def get_cycle_length(self, cycles=0): 111 | if not cycles: 112 | cycles = self.cycle_limit 113 | cycles = max(1, cycles) 114 | if self.t_mul == 1.0: 115 | return self.t_initial * cycles 116 | else: 117 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 118 | -------------------------------------------------------------------------------- /scheduler/plateau_lr.py: -------------------------------------------------------------------------------- 1 | """ Plateau Scheduler 2 | 3 | Adapts PyTorch plateau scheduler and allows application of noise, warmup. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | 9 | from .scheduler import Scheduler 10 | 11 | 12 | class PlateauLRScheduler(Scheduler): 13 | """Decay the LR by a factor every time the validation loss plateaus.""" 14 | 15 | def __init__(self, 16 | optimizer, 17 | decay_rate=0.1, 18 | patience_t=10, 19 | verbose=True, 20 | threshold=1e-4, 21 | cooldown_t=0, 22 | warmup_t=0, 23 | warmup_lr_init=0, 24 | lr_min=0, 25 | mode='max', 26 | noise_range_t=None, 27 | noise_type='normal', 28 | noise_pct=0.67, 29 | noise_std=1.0, 30 | noise_seed=None, 31 | initialize=True, 32 | ): 33 | super().__init__(optimizer, 'lr', initialize=initialize) 34 | 35 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 36 | self.optimizer, 37 | patience=patience_t, 38 | factor=decay_rate, 39 | verbose=verbose, 40 | threshold=threshold, 41 | cooldown=cooldown_t, 42 | mode=mode, 43 | min_lr=lr_min 44 | ) 45 | 46 | self.noise_range = noise_range_t 47 | self.noise_pct = noise_pct 48 | self.noise_type = noise_type 49 | self.noise_std = noise_std 50 | self.noise_seed = noise_seed if noise_seed is not None else 42 51 | self.warmup_t = warmup_t 52 | self.warmup_lr_init = warmup_lr_init 53 | if self.warmup_t: 54 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 55 | super().update_groups(self.warmup_lr_init) 56 | else: 57 | self.warmup_steps = [1 for _ in self.base_values] 58 | self.restore_lr = None 59 | 60 | def state_dict(self): 61 | return { 62 | 'best': self.lr_scheduler.best, 63 | 'last_epoch': self.lr_scheduler.last_epoch, 64 | } 65 | 66 | def load_state_dict(self, state_dict): 67 | self.lr_scheduler.best = state_dict['best'] 68 | if 'last_epoch' in state_dict: 69 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 70 | 71 | # override the base class step fn completely 72 | def step(self, epoch, metric=None): 73 | if epoch <= self.warmup_t: 74 | lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] 75 | super().update_groups(lrs) 76 | else: 77 | if self.restore_lr is not None: 78 | # restore actual LR from before our last noise perturbation before stepping base 79 | for i, param_group in enumerate(self.optimizer.param_groups): 80 | param_group['lr'] = self.restore_lr[i] 81 | self.restore_lr = None 82 | 83 | self.lr_scheduler.step(metric, epoch) # step the base scheduler 84 | 85 | if self.noise_range is not None: 86 | if isinstance(self.noise_range, (list, tuple)): 87 | apply_noise = self.noise_range[0] <= epoch < self.noise_range[1] 88 | else: 89 | apply_noise = epoch >= self.noise_range 90 | if apply_noise: 91 | self._apply_noise(epoch) 92 | 93 | def _apply_noise(self, epoch): 94 | 
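        # perturb each group's LR by a bounded random factor; re-seeding with noise_seed + epoch
        # keeps the perturbation reproducible across restarts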
g = torch.Generator() 95 | g.manual_seed(self.noise_seed + epoch) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | 105 | # apply the noise on top of previous LR, cache the old value so we can restore for normal 106 | # stepping of base scheduler 107 | restore_lr = [] 108 | for i, param_group in enumerate(self.optimizer.param_groups): 109 | old_lr = float(param_group['lr']) 110 | restore_lr.append(old_lr) 111 | new_lr = old_lr + old_lr * noise 112 | param_group['lr'] = new_lr 113 | self.restore_lr = restore_lr 114 | -------------------------------------------------------------------------------- /scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | import torch 4 | 5 | 6 | class Scheduler: 7 | """ Parameter Scheduler Base Class 8 | A scheduler base class that can be used to schedule any optimizer parameter groups. 9 | 10 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 11 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 12 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 13 | 14 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 15 | 16 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 17 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 18 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 19 | 20 | Based on ideas from: 21 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 22 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 23 | """ 24 | 25 | def __init__(self, 26 | optimizer: torch.optim.Optimizer, 27 | param_group_field: str, 28 | noise_range_t=None, 29 | noise_type='normal', 30 | noise_pct=0.67, 31 | noise_std=1.0, 32 | noise_seed=None, 33 | initialize: bool = True) -> None: 34 | self.optimizer = optimizer 35 | self.param_group_field = param_group_field 36 | self._initial_param_group_field = f"initial_{param_group_field}" 37 | if initialize: 38 | for i, group in enumerate(self.optimizer.param_groups): 39 | if param_group_field not in group: 40 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 41 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 42 | else: 43 | for i, group in enumerate(self.optimizer.param_groups): 44 | if self._initial_param_group_field not in group: 45 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 46 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 47 | self.metric = None # any point to having this for all? 
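        # the noise settings below are consumed by _add_noise(), which jitters the scheduled
        # values once the step/epoch count enters noise_range_t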
48 | self.noise_range_t = noise_range_t 49 | self.noise_pct = noise_pct 50 | self.noise_type = noise_type 51 | self.noise_std = noise_std 52 | self.noise_seed = noise_seed if noise_seed is not None else 42 53 | self.update_groups(self.base_values) 54 | 55 | def state_dict(self) -> Dict[str, Any]: 56 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 57 | 58 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 59 | self.__dict__.update(state_dict) 60 | 61 | def get_epoch_values(self, epoch: int): 62 | return None 63 | 64 | def get_update_values(self, num_updates: int): 65 | return None 66 | 67 | def step(self, epoch: int, metric: float = None) -> None: 68 | self.metric = metric 69 | values = self.get_epoch_values(epoch) 70 | if values is not None: 71 | values = self._add_noise(values, epoch) 72 | self.update_groups(values) 73 | 74 | def step_update(self, num_updates: int, metric: float = None): 75 | self.metric = metric 76 | values = self.get_update_values(num_updates) 77 | if values is not None: 78 | values = self._add_noise(values, num_updates) 79 | self.update_groups(values) 80 | 81 | def update_groups(self, values): 82 | if not isinstance(values, (list, tuple)): 83 | values = [values] * len(self.optimizer.param_groups) 84 | for param_group, value in zip(self.optimizer.param_groups, values): 85 | param_group[self.param_group_field] = value 86 | 87 | def _add_noise(self, lrs, t): 88 | if self.noise_range_t is not None: 89 | if isinstance(self.noise_range_t, (list, tuple)): 90 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 91 | else: 92 | apply_noise = t >= self.noise_range_t 93 | if apply_noise: 94 | g = torch.Generator() 95 | g.manual_seed(self.noise_seed + t) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | lrs = [v + v * noise for v in lrs] 105 | return lrs 106 | -------------------------------------------------------------------------------- /scheduler/scheduler_factory.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .cosine_lr import CosineLRScheduler 5 | from .tanh_lr import TanhLRScheduler 6 | from .step_lr import StepLRScheduler 7 | from .plateau_lr import PlateauLRScheduler 8 | from torch.optim.lr_scheduler import LambdaLR 9 | 10 | 11 | def create_scheduler(args, optimizer): 12 | num_epochs = args.epochs 13 | 14 | if getattr(args, 'lr_noise', None) is not None: 15 | lr_noise = getattr(args, 'lr_noise') 16 | if isinstance(lr_noise, (list, tuple)): 17 | noise_range = [n * num_epochs for n in lr_noise] 18 | if len(noise_range) == 1: 19 | noise_range = noise_range[0] 20 | else: 21 | noise_range = lr_noise * num_epochs 22 | else: 23 | noise_range = None 24 | 25 | lr_scheduler = None 26 | if args.sched == 'cosine': 27 | lr_scheduler = CosineLRScheduler( 28 | optimizer, 29 | t_initial=num_epochs, 30 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 31 | lr_min=args.min_lr, 32 | decay_rate=args.decay_rate, 33 | warmup_lr_init=args.warmup_lr, 34 | warmup_t=args.warmup_epochs, 35 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 36 | t_in_epochs=True, 37 | noise_range_t=noise_range, 38 | noise_pct=getattr(args, 
'lr_noise_pct', 0.67), 39 | noise_std=getattr(args, 'lr_noise_std', 1.), 40 | noise_seed=getattr(args, 'seed', 42), 41 | ) 42 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 43 | elif args.sched == 'tanh': 44 | lr_scheduler = TanhLRScheduler( 45 | optimizer, 46 | t_initial=num_epochs, 47 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 48 | lr_min=args.min_lr, 49 | warmup_lr_init=args.warmup_lr, 50 | warmup_t=args.warmup_epochs, 51 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 52 | t_in_epochs=True, 53 | noise_range_t=noise_range, 54 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 55 | noise_std=getattr(args, 'lr_noise_std', 1.), 56 | noise_seed=getattr(args, 'seed', 42), 57 | ) 58 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 59 | elif args.sched == 'step': 60 | lr_scheduler = StepLRScheduler( 61 | optimizer, 62 | decay_t=args.decay_epochs, 63 | decay_rate=args.decay_rate, 64 | warmup_lr_init=args.warmup_lr, 65 | warmup_t=args.warmup_epochs, 66 | noise_range_t=noise_range, 67 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 68 | noise_std=getattr(args, 'lr_noise_std', 1.), 69 | noise_seed=getattr(args, 'seed', 42), 70 | ) 71 | elif args.sched == 'linear': 72 | """ 73 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 74 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 75 | Args: 76 | optimizer (:class:`~torch.optim.Optimizer`): 77 | The optimizer for which to schedule the learning rate. 78 | num_warmup_steps (:obj:`int`): 79 | The number of steps for the warmup phase. 80 | num_training_steps (:obj:`int`): 81 | The total number of training steps. 82 | last_epoch (:obj:`int`, `optional`, defaults to -1): 83 | The index of the last epoch when resuming training. 84 | Return: 85 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 86 | """ 87 | 88 | def lr_lambda(current_step: int): 89 | if current_step < args.num_warmup_steps: 90 | return float(current_step) / float(max(1, args.num_warmup_steps)) 91 | return max( 92 | 0.0, float(args.num_training_steps - current_step) / float(max(1, args.num_training_steps - args.num_warmup_steps)) 93 | ) 94 | 95 | lr_scheduler = LambdaLR(optimizer, lr_lambda, args.last_epoch) 96 | 97 | elif args.sched == 'plateau': 98 | mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max' 99 | lr_scheduler = PlateauLRScheduler( 100 | optimizer, 101 | decay_rate=args.decay_rate, 102 | patience_t=args.patience_epochs, 103 | lr_min=args.min_lr, 104 | mode=mode, 105 | warmup_lr_init=args.warmup_lr, 106 | warmup_t=args.warmup_epochs, 107 | cooldown_t=0, 108 | noise_range_t=noise_range, 109 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 110 | noise_std=getattr(args, 'lr_noise_std', 1.), 111 | noise_seed=getattr(args, 'seed', 42), 112 | ) 113 | 114 | return lr_scheduler, num_epochs 115 | -------------------------------------------------------------------------------- /scheduler/step_lr.py: -------------------------------------------------------------------------------- 1 | """ Step Scheduler 2 | 3 | Basic step LR schedule with warmup, noise. 
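After warmup, the LR for each group at step t is base_lr * decay_rate ** (t // decay_t); an
illustrative configuration is StepLRScheduler(optimizer, decay_t=30, decay_rate=0.1,
warmup_t=5, warmup_lr_init=1e-6).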
4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import math 8 | import torch 9 | 10 | from .scheduler import Scheduler 11 | 12 | 13 | class StepLRScheduler(Scheduler): 14 | """ 15 | """ 16 | 17 | def __init__(self, 18 | optimizer: torch.optim.Optimizer, 19 | decay_t: float, 20 | decay_rate: float = 1., 21 | warmup_t=0, 22 | warmup_lr_init=0, 23 | t_in_epochs=True, 24 | noise_range_t=None, 25 | noise_pct=0.67, 26 | noise_std=1.0, 27 | noise_seed=42, 28 | initialize=True, 29 | ) -> None: 30 | super().__init__( 31 | optimizer, param_group_field="lr", 32 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 33 | initialize=initialize) 34 | 35 | self.decay_t = decay_t 36 | self.decay_rate = decay_rate 37 | self.warmup_t = warmup_t 38 | self.warmup_lr_init = warmup_lr_init 39 | self.t_in_epochs = t_in_epochs 40 | if self.warmup_t: 41 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 42 | super().update_groups(self.warmup_lr_init) 43 | else: 44 | self.warmup_steps = [1 for _ in self.base_values] 45 | 46 | def _get_lr(self, t): 47 | if t < self.warmup_t: 48 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 49 | else: 50 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] 51 | return lrs 52 | 53 | def get_epoch_values(self, epoch: int): 54 | if self.t_in_epochs: 55 | return self._get_lr(epoch) 56 | else: 57 | return None 58 | 59 | def get_update_values(self, num_updates: int): 60 | if not self.t_in_epochs: 61 | return self._get_lr(num_updates) 62 | else: 63 | return None 64 | -------------------------------------------------------------------------------- /scheduler/tanh_lr.py: -------------------------------------------------------------------------------- 1 | """ TanH Scheduler 2 | 3 | TanH schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | 12 | from .scheduler import Scheduler 13 | 14 | 15 | _logger = logging.getLogger(__name__) 16 | 17 | 18 | class TanhLRScheduler(Scheduler): 19 | """ 20 | Hyberbolic-Tangent decay with restarts. 
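Within a cycle the LR moves from lr_max toward lr_min following
lr_min + 0.5 * (lr_max - lr_min) * (1 - tanh(lb * (1 - tr) + ub * tr)), where tr is the
fraction of the cycle completed (defaults lb=-6, ub=4).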
21 | This is described in the paper https://arxiv.org/abs/1806.01593 22 | """ 23 | 24 | def __init__(self, 25 | optimizer: torch.optim.Optimizer, 26 | t_initial: int, 27 | lb: float = -6., 28 | ub: float = 4., 29 | t_mul: float = 1., 30 | lr_min: float = 0., 31 | decay_rate: float = 1., 32 | warmup_t=0, 33 | warmup_lr_init=0, 34 | warmup_prefix=False, 35 | cycle_limit=0, 36 | t_in_epochs=True, 37 | noise_range_t=None, 38 | noise_pct=0.67, 39 | noise_std=1.0, 40 | noise_seed=42, 41 | initialize=True) -> None: 42 | super().__init__( 43 | optimizer, param_group_field="lr", 44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 45 | initialize=initialize) 46 | 47 | assert t_initial > 0 48 | assert lr_min >= 0 49 | assert lb < ub 50 | assert cycle_limit >= 0 51 | assert warmup_t >= 0 52 | assert warmup_lr_init >= 0 53 | self.lb = lb 54 | self.ub = ub 55 | self.t_initial = t_initial 56 | self.t_mul = t_mul 57 | self.lr_min = lr_min 58 | self.decay_rate = decay_rate 59 | self.cycle_limit = cycle_limit 60 | self.warmup_t = warmup_t 61 | self.warmup_lr_init = warmup_lr_init 62 | self.warmup_prefix = warmup_prefix 63 | self.t_in_epochs = t_in_epochs 64 | if self.warmup_t: 65 | t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) 66 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] 67 | super().update_groups(self.warmup_lr_init) 68 | else: 69 | self.warmup_steps = [1 for _ in self.base_values] 70 | 71 | def _get_lr(self, t): 72 | if t < self.warmup_t: 73 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 74 | else: 75 | if self.warmup_prefix: 76 | t = t - self.warmup_t 77 | 78 | if self.t_mul != 1: 79 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 80 | t_i = self.t_mul ** i * self.t_initial 81 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 82 | else: 83 | i = t // self.t_initial 84 | t_i = self.t_initial 85 | t_curr = t - (self.t_initial * i) 86 | 87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 88 | gamma = self.decay_rate ** i 89 | lr_min = self.lr_min * gamma 90 | lr_max_values = [v * gamma for v in self.base_values] 91 | 92 | tr = t_curr / t_i 93 | lrs = [ 94 | lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(self.lb * (1. 
- tr) + self.ub * tr)) 95 | for lr_max in lr_max_values 96 | ] 97 | else: 98 | lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values] 99 | return lrs 100 | 101 | def get_epoch_values(self, epoch: int): 102 | if self.t_in_epochs: 103 | return self._get_lr(epoch) 104 | else: 105 | return None 106 | 107 | def get_update_values(self, num_updates: int): 108 | if not self.t_in_epochs: 109 | return self._get_lr(num_updates) 110 | else: 111 | return None 112 | 113 | def get_cycle_length(self, cycles=0): 114 | if not cycles: 115 | cycles = self.cycle_limit 116 | cycles = max(1, cycles) 117 | if self.t_mul == 1.0: 118 | return self.t_initial * cycles 119 | else: 120 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 121 | -------------------------------------------------------------------------------- /taming/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shizhediao/DaVinci/283ea6f2977f61b113aee042ad426244d3f20f69/taming/__init__.py -------------------------------------------------------------------------------- /taming/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /taming/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shizhediao/DaVinci/283ea6f2977f61b113aee042ad426244d3f20f69/taming/models/__init__.py -------------------------------------------------------------------------------- /taming/modules/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shizhediao/DaVinci/283ea6f2977f61b113aee042ad426244d3f20f69/taming/modules/.DS_Store -------------------------------------------------------------------------------- /taming/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shizhediao/DaVinci/283ea6f2977f61b113aee042ad426244d3f20f69/taming/modules/__init__.py -------------------------------------------------------------------------------- /taming/modules/autoencoder/lpips/vgg.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shizhediao/DaVinci/283ea6f2977f61b113aee042ad426244d3f20f69/taming/modules/autoencoder/lpips/vgg.pth -------------------------------------------------------------------------------- /taming/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shizhediao/DaVinci/283ea6f2977f61b113aee042ad426244d3f20f69/taming/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /taming/modules/discriminator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shizhediao/DaVinci/283ea6f2977f61b113aee042ad426244d3f20f69/taming/modules/discriminator/__init__.py -------------------------------------------------------------------------------- /taming/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from taming.modules.util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /taming/modules/losses/__init__.py: 
-------------------------------------------------------------------------------- 1 | from taming.modules.losses.vqperceptual import DummyLoss 2 | 3 | -------------------------------------------------------------------------------- /taming/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from taming.util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name is not "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = 
models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /taming/modules/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BCELoss(nn.Module): 6 | def forward(self, prediction, target): 7 | loss = F.binary_cross_entropy_with_logits(prediction,target) 8 | return loss, {} 9 | 10 | 11 | class BCELossWithQuant(nn.Module): 12 | def __init__(self, codebook_weight=1.): 13 | super().__init__() 14 | self.codebook_weight = codebook_weight 15 | 16 | def forward(self, qloss, target, prediction, split): 17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 18 | loss = bce_loss + self.codebook_weight*qloss 19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), 20 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 21 | "{}/quant_loss".format(split): qloss.detach().mean() 22 | } 23 | -------------------------------------------------------------------------------- /taming/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from taming.modules.losses.lpips import LPIPS 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | 8 | 9 | class DummyLoss(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | 14 | def adopt_weight(weight, global_step, threshold=0, value=0.): 15 | if global_step < threshold: 16 | weight = value 17 | return weight 18 | 19 | 20 | def hinge_d_loss(logits_real, logits_fake): 21 | loss_real = torch.mean(F.relu(1. - logits_real)) 22 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) 23 | d_loss = 0.5 * (loss_real + loss_fake) 24 | return d_loss 25 | 26 | 27 | def vanilla_d_loss(logits_real, logits_fake): 28 | d_loss = 0.5 * ( 29 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 30 | torch.mean(torch.nn.functional.softplus(logits_fake))) 31 | return d_loss 32 | 33 | 34 | class VQLPIPSWithDiscriminator(nn.Module): 35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 38 | disc_ndf=64, disc_loss="hinge"): 39 | super().__init__() 40 | assert disc_loss in ["hinge", "vanilla"] 41 | self.codebook_weight = codebook_weight 42 | self.pixel_weight = pixelloss_weight 43 | self.perceptual_loss = LPIPS().eval() 44 | self.perceptual_weight = perceptual_weight 45 | 46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 47 | n_layers=disc_num_layers, 48 | use_actnorm=use_actnorm, 49 | ndf=disc_ndf 50 | ).apply(weights_init) 51 | self.discriminator_iter_start = disc_start 52 | if disc_loss == "hinge": 53 | self.disc_loss = hinge_d_loss 54 | elif disc_loss == "vanilla": 55 | self.disc_loss = vanilla_d_loss 56 | else: 57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 59 | self.disc_factor = disc_factor 60 | self.discriminator_weight = disc_weight 61 | self.disc_conditional = disc_conditional 62 | 63 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 64 | if last_layer is not None: 65 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 66 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 67 | else: 68 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 69 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 70 | 71 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 72 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 73 | d_weight = d_weight * self.discriminator_weight 74 | return d_weight 75 | 76 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 77 | global_step, last_layer=None, cond=None, split="train"): 78 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 79 | if self.perceptual_weight > 0: 80 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 81 | rec_loss = rec_loss + self.perceptual_weight * p_loss 82 | else: 83 | p_loss = torch.tensor([0.0]) 84 | 85 | nll_loss = rec_loss 86 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 87 | nll_loss = torch.mean(nll_loss) 88 | 89 | # now the GAN part 90 | if optimizer_idx == 0: 91 | # generator update 92 | if cond is None: 93 | assert not self.disc_conditional 94 | logits_fake = self.discriminator(reconstructions.contiguous()) 95 | else: 96 | assert self.disc_conditional 97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 98 | g_loss = -torch.mean(logits_fake) 99 | 100 | try: 101 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 102 | except RuntimeError: 103 | assert not self.training 104 | d_weight = torch.tensor(0.0) 105 | 106 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 107 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 108 | 109 | 
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 110 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 111 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 112 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 113 | "{}/p_loss".format(split): p_loss.detach().mean(), 114 | "{}/d_weight".format(split): d_weight.detach(), 115 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 116 | "{}/g_loss".format(split): g_loss.detach().mean(), 117 | } 118 | return loss, log 119 | 120 | if optimizer_idx == 1: 121 | # second pass for discriminator update 122 | if cond is None: 123 | logits_real = self.discriminator(inputs.contiguous().detach()) 124 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 125 | else: 126 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 127 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 128 | 129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 130 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 131 | 132 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 133 | "{}/logits_real".format(split): logits_real.detach().mean(), 134 | "{}/logits_fake".format(split): logits_fake.detach().mean() 135 | } 136 | return d_loss, log 137 | -------------------------------------------------------------------------------- /taming/modules/misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shizhediao/DaVinci/283ea6f2977f61b113aee042ad426244d3f20f69/taming/modules/misc/__init__.py -------------------------------------------------------------------------------- /taming/modules/misc/coord.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CoordStage(object): 4 | def __init__(self, n_embed, down_factor): 5 | self.n_embed = n_embed 6 | self.down_factor = down_factor 7 | 8 | def eval(self): 9 | return self 10 | 11 | def encode(self, c): 12 | """fake vqmodel interface""" 13 | assert 0.0 <= c.min() and c.max() <= 1.0 14 | b,ch,h,w = c.shape 15 | assert ch == 1 16 | 17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, 18 | mode="area") 19 | c = c.clamp(0.0, 1.0) 20 | c = self.n_embed*c 21 | c_quant = c.round() 22 | c_ind = c_quant.to(dtype=torch.long) 23 | 24 | info = None, None, c_ind 25 | return c_quant, None, info 26 | 27 | def decode(self, c): 28 | c = c/self.n_embed 29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, 30 | mode="nearest") 31 | return c 32 | -------------------------------------------------------------------------------- /taming/modules/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shizhediao/DaVinci/283ea6f2977f61b113aee042ad426244d3f20f69/taming/modules/transformer/__init__.py -------------------------------------------------------------------------------- /taming/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | 
allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 
77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /taming/modules/vqvae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shizhediao/DaVinci/283ea6f2977f61b113aee042ad426244d3f20f69/taming/modules/vqvae/__init__.py -------------------------------------------------------------------------------- /taming/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: 
{}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 73 | key : str 74 | key/to/value, path like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not. 84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | -------------------------------------------------------------------------------- /util/hdfs_io.py: -------------------------------------------------------------------------------- 1 | # Write and Paint: Generative Vision-Language Models are Unified Modal Learners (https://arxiv.org/abs/2206.07699) 2 | # Github: https://github.com/shizhediao/DaVinci 3 | # Copyright (c) 2023, ByteDance Inc. 4 | # All rights reserved. 
5 | 6 | #!/usr/bin/env python 7 | # -*- coding: utf-8 -*- 8 | import sys 9 | from typing import IO, Any, List 10 | 11 | import shutil 12 | import subprocess 13 | from contextlib import contextmanager 14 | import os 15 | import glob 16 | import threading 17 | 18 | HADOOP_BIN = 'HADOOP_ROOT_LOGGER=ERROR,console /opt/tiger/yarn_deploy/hadoop/bin/hdfs' 19 | 20 | __all__ = ['hlist_files', 'hopen', 'hexists', 'hmkdir', 'hglob', 'hisdir', 'hcountline'] 21 | 22 | 23 | @contextmanager # type: ignore 24 | def hopen(hdfs_path: str, mode: str = "r") -> IO[Any]: 25 | """ 26 | 打开一个 hdfs 文件, 用 contextmanager. 27 | 28 | Args: 29 | hfdfs_path (str): hdfs文件路径 30 | mode (str): 打开模式,支持 ["r", "w", "wa"] 31 | """ 32 | pipe = None 33 | if mode.startswith("r"): 34 | pipe = subprocess.Popen( 35 | "{} dfs -text {}".format(HADOOP_BIN, hdfs_path), shell=True, stdout=subprocess.PIPE) 36 | yield pipe.stdout 37 | pipe.stdout.close() # type: ignore 38 | pipe.wait() 39 | return 40 | if mode == "wa" or mode == "a": 41 | pipe = subprocess.Popen( 42 | "{} dfs -appendToFile - {}".format(HADOOP_BIN, hdfs_path), shell=True, stdin=subprocess.PIPE) 43 | yield pipe.stdin 44 | pipe.stdin.close() # type: ignore 45 | pipe.wait() 46 | return 47 | if mode.startswith("w"): 48 | pipe = subprocess.Popen( 49 | "{} dfs -put -f - {}".format(HADOOP_BIN, hdfs_path), shell=True, stdin=subprocess.PIPE) 50 | yield pipe.stdin 51 | pipe.stdin.close() # type: ignore 52 | pipe.wait() 53 | return 54 | raise RuntimeError("unsupported io mode: {}".format(mode)) 55 | 56 | 57 | def hlist_files(folders: List[str]) -> List[str]: 58 | """ 59 | 罗列一些 hdfs 路径下的文件。 60 | 61 | Args: 62 | folders (List): hdfs文件路径的list 63 | Returns: 64 | 一个list of hdfs 路径 65 | """ 66 | files = [] 67 | for folder in folders: 68 | if folder.startswith('hdfs'): 69 | pipe = subprocess.Popen("{} dfs -ls {}".format(HADOOP_BIN, folder), shell=True, 70 | stdout=subprocess.PIPE) 71 | # output, _ = pipe.communicate() 72 | for line in pipe.stdout: # type: ignore 73 | line = line.strip() 74 | # drwxr-xr-x - user group 4 file 75 | if len(line.split()) < 5: 76 | continue 77 | files.append(line.split()[-1].decode("utf8")) 78 | pipe.stdout.close() # type: ignore 79 | pipe.wait() 80 | else: 81 | if os.path.isdir(folder): 82 | files.extend([os.path.join(folder, d) for d in os.listdir(folder)]) 83 | elif os.path.isfile(folder): 84 | files.append(folder) 85 | else: 86 | print('Path {} is invalid'.format(folder)) 87 | sys.stdout.flush() 88 | 89 | return files 90 | 91 | 92 | def hexists(file_path: str) -> bool: 93 | """ hdfs capable to check whether a file_path is exists """ 94 | if file_path.startswith('hdfs'): 95 | return os.system("{} dfs -test -e {}".format(HADOOP_BIN, file_path)) == 0 96 | return os.path.exists(file_path) 97 | 98 | 99 | def hisdir(file_path: str) -> bool: 100 | """ hdfs capable to check whether a file_path is a dir """ 101 | if file_path.startswith('hdfs'): 102 | flag1 = os.system("{} dfs -test -e {}".format(HADOOP_BIN, file_path)) # 0:路径存在 103 | flag2 = os.system("{} dfs -test -f {}".format(HADOOP_BIN, file_path)) # 0:是文件 1:不是文件 104 | flag = ((flag1 == 0) and (flag2 == 1)) 105 | return flag 106 | return os.path.isdir(file_path) 107 | 108 | 109 | def hmkdir(file_path: str) -> bool: 110 | """ hdfs mkdir """ 111 | if file_path.startswith('hdfs'): 112 | os.system("{} dfs -mkdir -p {}".format(HADOOP_BIN, file_path)) 113 | else: 114 | os.mkdir(file_path) 115 | return True 116 | 117 | 118 | def hcopy(from_path: str, to_path: str) -> bool: 119 | """ hdfs copy """ 120 | if 
to_path.startswith("hdfs"): 121 | if from_path.startswith("hdfs"): 122 | os.system("{} dfs -cp -f {} {}".format(HADOOP_BIN, from_path, to_path)) 123 | else: 124 | os.system("{} dfs -copyFromLocal -f {} {}".format(HADOOP_BIN, from_path, to_path)) 125 | else: 126 | if from_path.startswith("hdfs"): 127 | os.system("{} dfs -text {} > {}".format(HADOOP_BIN, from_path, to_path)) 128 | else: 129 | shutil.copy(from_path, to_path) 130 | return True 131 | 132 | 133 | def hglob(search_path, sort_by_time=False): 134 | """ hdfs glob """ 135 | if search_path.startswith("hdfs"): 136 | if sort_by_time: 137 | hdfs_command = HADOOP_BIN + ' dfs -ls %s | sort -k6,7' % search_path 138 | else: 139 | hdfs_command = HADOOP_BIN + ' dfs -ls %s' % search_path 140 | path_list = [] 141 | files = os.popen(hdfs_command).read() 142 | files = files.split("\n") 143 | for file in files: 144 | if 'hdfs' in file: 145 | startindex = file.index('hdfs') 146 | path_list.append(file[startindex:]) 147 | return path_list 148 | else: 149 | files = glob.glob(search_path) 150 | if sort_by_time: 151 | files = sorted(files, key=lambda x: os.path.getmtime(x)) 152 | return files 153 | 154 | 155 | def htext_list(files, target_folder): 156 | for fn in files: 157 | name = fn.split('/')[-1] 158 | hdfs_command = HADOOP_BIN + ' dfs -text %s > %s/%s' % (fn, target_folder, name) 159 | os.system(hdfs_command) 160 | 161 | 162 | def hmget(files, target_folder, num_thread=16): 163 | """ 将整个hdfs 文件夹 get下来,但是不是简单的get,因为一些hdfs文件是压缩的,需要解压""" 164 | part = len(files) // num_thread 165 | thread_list = [] 166 | for i in range(num_thread): 167 | start = part * i 168 | if i == num_thread - 1: 169 | end = len(files) 170 | else: 171 | end = start + part 172 | t = threading.Thread(target=htext_list, kwargs={ 173 | 'files': files[start:end], 'target_folder': target_folder}) 174 | thread_list.append(t) 175 | 176 | for t in thread_list: 177 | t.setDaemon(True) 178 | t.start() 179 | 180 | for t in thread_list: 181 | t.join() 182 | 183 | 184 | def hcountline(path): 185 | ''' 186 | count line in file 187 | ''' 188 | count = 0 189 | if path.startswith('hdfs'): 190 | with hopen(path, 'r') as f: 191 | for line in f: 192 | count += 1 193 | else: 194 | with open(path, 'r') as f: 195 | for line in f: 196 | count += 1 197 | return count 198 | -------------------------------------------------------------------------------- /util/torch_io.py: -------------------------------------------------------------------------------- 1 | # Write and Paint: Generative Vision-Language Models are Unified Modal Learners (https://arxiv.org/abs/2206.07699) 2 | # Github: https://github.com/shizhediao/DaVinci 3 | # Copyright (c) 2023, ByteDance Inc. 4 | # All rights reserved. 
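# Overview: hdfs-aware wrappers around torch.load / torch.save. load() streams an "hdfs://" checkpoint
# through hopen() into an in-memory io.BytesIO buffer before handing it to torch.load; save() writes the
# torch.save output through an hopen() pipe (`hdfs dfs -put`); plain local paths go straight to
# torch.load / torch.save.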
5 | 6 | #!/usr/bin/env python 7 | # -*- coding: utf-8 -*- 8 | ''' torch model hdfs io wrapper ''' 9 | 10 | import io 11 | import torch 12 | 13 | from .hdfs_io import hopen 14 | 15 | 16 | def load(filepath: str, **kwargs): 17 | """ load model """ 18 | if not filepath.startswith("hdfs://"): 19 | return torch.load(filepath, **kwargs) 20 | with hopen(filepath, "rb") as reader: 21 | accessor = io.BytesIO(reader.read()) 22 | state_dict = torch.load(accessor, **kwargs) 23 | del accessor 24 | return state_dict 25 | 26 | 27 | def save(obj, filepath: str, **kwargs): 28 | """ save model """ 29 | if filepath.startswith("hdfs://"): 30 | with hopen(filepath, "wb") as writer: 31 | torch.save(obj, writer, **kwargs) 32 | else: 33 | torch.save(obj, filepath, **kwargs) 34 | -------------------------------------------------------------------------------- /vqaTools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'aagrawal' 2 | --------------------------------------------------------------------------------
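A minimal usage sketch for the hdfs-aware I/O helpers above; the checkpoint paths below are hypothetical placeholders rather than paths shipped with this repository:

    from util import hdfs_io, torch_io

    ckpt_path = "hdfs://some/namespace/davinci/checkpoint_latest.pth"  # hypothetical path
    if hdfs_io.hexists(ckpt_path):
        # torch_io.load pipes the file through `hdfs dfs -text` and buffers it in memory
        state_dict = torch_io.load(ckpt_path, map_location="cpu")
        # a plain local path is handled by torch.save directly
        torch_io.save(state_dict, "/tmp/checkpoint_local.pth")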